mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 20:09:20 +08:00
Support OAuth Integration for Plugin Tools (#22550)
Co-authored-by: zxhlyh <jasonapring2015@outlook.com> Co-authored-by: Yeuoly <admin@srmxy.cn>
This commit is contained in:
1
.github/workflows/build-push.yml
vendored
1
.github/workflows/build-push.yml
vendored
@@ -6,6 +6,7 @@ on:
|
|||||||
- "main"
|
- "main"
|
||||||
- "deploy/dev"
|
- "deploy/dev"
|
||||||
- "deploy/enterprise"
|
- "deploy/enterprise"
|
||||||
|
- "build/**"
|
||||||
tags:
|
tags:
|
||||||
- "*"
|
- "*"
|
||||||
|
|
||||||
|
|||||||
@@ -5,17 +5,17 @@
|
|||||||
SECRET_KEY=
|
SECRET_KEY=
|
||||||
|
|
||||||
# Console API base URL
|
# Console API base URL
|
||||||
CONSOLE_API_URL=http://127.0.0.1:5001
|
CONSOLE_API_URL=http://localhost:5001
|
||||||
CONSOLE_WEB_URL=http://127.0.0.1:3000
|
CONSOLE_WEB_URL=http://localhost:3000
|
||||||
|
|
||||||
# Service API base URL
|
# Service API base URL
|
||||||
SERVICE_API_URL=http://127.0.0.1:5001
|
SERVICE_API_URL=http://localhost:5001
|
||||||
|
|
||||||
# Web APP base URL
|
# Web APP base URL
|
||||||
APP_WEB_URL=http://127.0.0.1:3000
|
APP_WEB_URL=http://localhost:3000
|
||||||
|
|
||||||
# Files URL
|
# Files URL
|
||||||
FILES_URL=http://127.0.0.1:5001
|
FILES_URL=http://localhost:5001
|
||||||
|
|
||||||
# INTERNAL_FILES_URL is used for plugin daemon communication within Docker network.
|
# INTERNAL_FILES_URL is used for plugin daemon communication within Docker network.
|
||||||
# Set this to the internal Docker service URL for proper plugin file access.
|
# Set this to the internal Docker service URL for proper plugin file access.
|
||||||
@@ -138,8 +138,8 @@ SUPABASE_API_KEY=your-access-key
|
|||||||
SUPABASE_URL=your-server-url
|
SUPABASE_URL=your-server-url
|
||||||
|
|
||||||
# CORS configuration
|
# CORS configuration
|
||||||
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,*
|
||||||
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
|
||||||
|
|
||||||
# Vector database configuration
|
# Vector database configuration
|
||||||
# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, opengauss, tablestore, matrixone
|
# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, opengauss, tablestore, matrixone
|
||||||
|
|||||||
@@ -2,19 +2,22 @@ import base64
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import secrets
|
import secrets
|
||||||
from typing import Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import click
|
import click
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
|
from pydantic import TypeAdapter
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from constants.languages import languages
|
from constants.languages import languages
|
||||||
|
from core.plugin.entities.plugin import ToolProviderID
|
||||||
from core.rag.datasource.vdb.vector_factory import Vector
|
from core.rag.datasource.vdb.vector_factory import Vector
|
||||||
from core.rag.datasource.vdb.vector_type import VectorType
|
from core.rag.datasource.vdb.vector_type import VectorType
|
||||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||||
from core.rag.models.document import Document
|
from core.rag.models.document import Document
|
||||||
|
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
|
||||||
from events.app_event import app_was_created
|
from events.app_event import app_was_created
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
@@ -27,6 +30,7 @@ from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, D
|
|||||||
from models.dataset import Document as DatasetDocument
|
from models.dataset import Document as DatasetDocument
|
||||||
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
|
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
|
||||||
from models.provider import Provider, ProviderModel
|
from models.provider import Provider, ProviderModel
|
||||||
|
from models.tools import ToolOAuthSystemClient
|
||||||
from services.account_service import AccountService, RegisterService, TenantService
|
from services.account_service import AccountService, RegisterService, TenantService
|
||||||
from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs
|
from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs
|
||||||
from services.plugin.data_migration import PluginDataMigration
|
from services.plugin.data_migration import PluginDataMigration
|
||||||
@@ -1155,3 +1159,49 @@ def remove_orphaned_files_on_storage(force: bool):
|
|||||||
click.echo(click.style(f"Removed {removed_files} orphaned files without errors.", fg="green"))
|
click.echo(click.style(f"Removed {removed_files} orphaned files without errors.", fg="green"))
|
||||||
else:
|
else:
|
||||||
click.echo(click.style(f"Removed {removed_files} orphaned files, with {error_files} errors.", fg="yellow"))
|
click.echo(click.style(f"Removed {removed_files} orphaned files, with {error_files} errors.", fg="yellow"))
|
||||||
|
|
||||||
|
|
||||||
|
@click.command("setup-system-tool-oauth-client", help="Setup system tool oauth client.")
|
||||||
|
@click.option("--provider", prompt=True, help="Provider name")
|
||||||
|
@click.option("--client-params", prompt=True, help="Client Params")
|
||||||
|
def setup_system_tool_oauth_client(provider, client_params):
|
||||||
|
"""
|
||||||
|
Setup system tool oauth client
|
||||||
|
"""
|
||||||
|
provider_id = ToolProviderID(provider)
|
||||||
|
provider_name = provider_id.provider_name
|
||||||
|
plugin_id = provider_id.plugin_id
|
||||||
|
|
||||||
|
try:
|
||||||
|
# json validate
|
||||||
|
click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
|
||||||
|
client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
|
||||||
|
click.echo(click.style("Client params validated successfully.", fg="green"))
|
||||||
|
|
||||||
|
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
|
||||||
|
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
|
||||||
|
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
|
||||||
|
click.echo(click.style("Client params encrypted successfully.", fg="green"))
|
||||||
|
except Exception as e:
|
||||||
|
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||||
|
return
|
||||||
|
|
||||||
|
deleted_count = (
|
||||||
|
db.session.query(ToolOAuthSystemClient)
|
||||||
|
.filter_by(
|
||||||
|
provider=provider_name,
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
)
|
||||||
|
.delete()
|
||||||
|
)
|
||||||
|
if deleted_count > 0:
|
||||||
|
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
|
||||||
|
|
||||||
|
oauth_client = ToolOAuthSystemClient(
|
||||||
|
provider=provider_name,
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
encrypted_oauth_params=oauth_client_params,
|
||||||
|
)
|
||||||
|
db.session.add(oauth_client)
|
||||||
|
db.session.commit()
|
||||||
|
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
|
||||||
HIDDEN_VALUE = "[__HIDDEN__]"
|
HIDDEN_VALUE = "[__HIDDEN__]"
|
||||||
|
UNKNOWN_VALUE = "[__UNKNOWN__]"
|
||||||
UUID_NIL = "00000000-0000-0000-0000-000000000000"
|
UUID_NIL = "00000000-0000-0000-0000-000000000000"
|
||||||
|
|
||||||
DEFAULT_FILE_NUMBER_LIMITS = 3
|
DEFAULT_FILE_NUMBER_LIMITS = 3
|
||||||
|
|||||||
@@ -1,23 +1,32 @@
|
|||||||
import io
|
import io
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from flask import redirect, send_file
|
from flask import make_response, redirect, request, send_file
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from flask_restful import Resource, reqparse
|
from flask_restful import (
|
||||||
from sqlalchemy.orm import Session
|
Resource,
|
||||||
|
reqparse,
|
||||||
|
)
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
|
from controllers.console.wraps import (
|
||||||
|
account_initialization_required,
|
||||||
|
enterprise_license_required,
|
||||||
|
setup_required,
|
||||||
|
)
|
||||||
from core.mcp.auth.auth_flow import auth, handle_callback
|
from core.mcp.auth.auth_flow import auth, handle_callback
|
||||||
from core.mcp.auth.auth_provider import OAuthClientProvider
|
from core.mcp.auth.auth_provider import OAuthClientProvider
|
||||||
from core.mcp.error import MCPAuthError, MCPError
|
from core.mcp.error import MCPAuthError, MCPError
|
||||||
from core.mcp.mcp_client import MCPClient
|
from core.mcp.mcp_client import MCPClient
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from extensions.ext_database import db
|
from core.plugin.entities.plugin import ToolProviderID
|
||||||
from libs.helper import alphanumeric, uuid_value
|
from core.plugin.impl.oauth import OAuthHandler
|
||||||
|
from core.tools.entities.tool_entities import CredentialType
|
||||||
|
from libs.helper import StrLen, alphanumeric, uuid_value
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
|
from services.plugin.oauth_service import OAuthProxyService
|
||||||
from services.tools.api_tools_manage_service import ApiToolManageService
|
from services.tools.api_tools_manage_service import ApiToolManageService
|
||||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||||
from services.tools.mcp_tools_mange_service import MCPToolManageService
|
from services.tools.mcp_tools_mange_service import MCPToolManageService
|
||||||
@@ -89,7 +98,7 @@ class ToolBuiltinProviderInfoApi(Resource):
|
|||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
tenant_id = user.current_tenant_id
|
||||||
|
|
||||||
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(user_id, tenant_id, provider))
|
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
|
||||||
|
|
||||||
|
|
||||||
class ToolBuiltinProviderDeleteApi(Resource):
|
class ToolBuiltinProviderDeleteApi(Resource):
|
||||||
@@ -98,17 +107,47 @@ class ToolBuiltinProviderDeleteApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider):
|
def post(self, provider):
|
||||||
user = current_user
|
user = current_user
|
||||||
|
|
||||||
if not user.is_admin_or_owner:
|
if not user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
|
tenant_id = user.current_tenant_id
|
||||||
|
req = reqparse.RequestParser()
|
||||||
|
req.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||||
|
args = req.parse_args()
|
||||||
|
|
||||||
|
return BuiltinToolManageService.delete_builtin_tool_provider(
|
||||||
|
tenant_id,
|
||||||
|
provider,
|
||||||
|
args["credential_id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ToolBuiltinProviderAddApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def post(self, provider):
|
||||||
|
user = current_user
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
tenant_id = user.current_tenant_id
|
||||||
|
|
||||||
return BuiltinToolManageService.delete_builtin_tool_provider(
|
parser = reqparse.RequestParser()
|
||||||
user_id,
|
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
tenant_id,
|
parser.add_argument("name", type=StrLen(30), required=False, nullable=False, location="json")
|
||||||
provider,
|
parser.add_argument("type", type=str, required=True, nullable=False, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args["type"] not in CredentialType.values():
|
||||||
|
raise ValueError(f"Invalid credential type: {args['type']}")
|
||||||
|
|
||||||
|
return BuiltinToolManageService.add_builtin_tool_provider(
|
||||||
|
user_id=user_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
provider=provider,
|
||||||
|
credentials=args["credentials"],
|
||||||
|
name=args["name"],
|
||||||
|
api_type=CredentialType.of(args["type"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -126,19 +165,20 @@ class ToolBuiltinProviderUpdateApi(Resource):
|
|||||||
tenant_id = user.current_tenant_id
|
tenant_id = user.current_tenant_id
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||||
|
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||||
|
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
result = BuiltinToolManageService.update_builtin_tool_provider(
|
||||||
result = BuiltinToolManageService.update_builtin_tool_provider(
|
user_id=user_id,
|
||||||
session=session,
|
tenant_id=tenant_id,
|
||||||
user_id=user_id,
|
provider=provider,
|
||||||
tenant_id=tenant_id,
|
credential_id=args["credential_id"],
|
||||||
provider_name=provider,
|
credentials=args.get("credentials", None),
|
||||||
credentials=args["credentials"],
|
name=args.get("name", ""),
|
||||||
)
|
)
|
||||||
session.commit()
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@@ -149,9 +189,11 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
|
|||||||
def get(self, provider):
|
def get(self, provider):
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
return BuiltinToolManageService.get_builtin_tool_provider_credentials(
|
return jsonable_encoder(
|
||||||
tenant_id=tenant_id,
|
BuiltinToolManageService.get_builtin_tool_provider_credentials(
|
||||||
provider_name=provider,
|
tenant_id=tenant_id,
|
||||||
|
provider_name=provider,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -344,12 +386,15 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider):
|
def get(self, provider, credential_type):
|
||||||
user = current_user
|
user = current_user
|
||||||
|
|
||||||
tenant_id = user.current_tenant_id
|
tenant_id = user.current_tenant_id
|
||||||
|
|
||||||
return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, tenant_id)
|
return jsonable_encoder(
|
||||||
|
BuiltinToolManageService.list_builtin_provider_credentials_schema(
|
||||||
|
provider, CredentialType.of(credential_type), tenant_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ToolApiProviderSchemaApi(Resource):
|
class ToolApiProviderSchemaApi(Resource):
|
||||||
@@ -586,15 +631,12 @@ class ToolApiListApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
user = current_user
|
user = current_user
|
||||||
|
|
||||||
user_id = user.id
|
|
||||||
tenant_id = user.current_tenant_id
|
tenant_id = user.current_tenant_id
|
||||||
|
|
||||||
return jsonable_encoder(
|
return jsonable_encoder(
|
||||||
[
|
[
|
||||||
provider.to_dict()
|
provider.to_dict()
|
||||||
for provider in ApiToolManageService.list_api_tools(
|
for provider in ApiToolManageService.list_api_tools(
|
||||||
user_id,
|
|
||||||
tenant_id,
|
tenant_id,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
@@ -631,6 +673,179 @@ class ToolLabelsApi(Resource):
|
|||||||
return jsonable_encoder(ToolLabelsService.list_tool_labels())
|
return jsonable_encoder(ToolLabelsService.list_tool_labels())
|
||||||
|
|
||||||
|
|
||||||
|
class ToolPluginOAuthApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def get(self, provider):
|
||||||
|
tool_provider = ToolProviderID(provider)
|
||||||
|
plugin_id = tool_provider.plugin_id
|
||||||
|
provider_name = tool_provider.provider_name
|
||||||
|
|
||||||
|
# todo check permission
|
||||||
|
user = current_user
|
||||||
|
|
||||||
|
if not user.is_admin_or_owner:
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
|
tenant_id = user.current_tenant_id
|
||||||
|
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider)
|
||||||
|
if oauth_client_params is None:
|
||||||
|
raise Forbidden("no oauth available client config found for this tool provider")
|
||||||
|
|
||||||
|
oauth_handler = OAuthHandler()
|
||||||
|
context_id = OAuthProxyService.create_proxy_context(
|
||||||
|
user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
|
||||||
|
)
|
||||||
|
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
|
||||||
|
authorization_url_response = oauth_handler.get_authorization_url(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
user_id=user.id,
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
provider=provider_name,
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
system_credentials=oauth_client_params,
|
||||||
|
)
|
||||||
|
response = make_response(jsonable_encoder(authorization_url_response))
|
||||||
|
response.set_cookie(
|
||||||
|
"context_id",
|
||||||
|
context_id,
|
||||||
|
httponly=True,
|
||||||
|
samesite="Lax",
|
||||||
|
max_age=OAuthProxyService.__MAX_AGE__,
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
class ToolOAuthCallback(Resource):
|
||||||
|
@setup_required
|
||||||
|
def get(self, provider):
|
||||||
|
context_id = request.cookies.get("context_id")
|
||||||
|
if not context_id:
|
||||||
|
raise Forbidden("context_id not found")
|
||||||
|
|
||||||
|
context = OAuthProxyService.use_proxy_context(context_id)
|
||||||
|
if context is None:
|
||||||
|
raise Forbidden("Invalid context_id")
|
||||||
|
|
||||||
|
tool_provider = ToolProviderID(provider)
|
||||||
|
plugin_id = tool_provider.plugin_id
|
||||||
|
provider_name = tool_provider.provider_name
|
||||||
|
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
|
||||||
|
|
||||||
|
oauth_handler = OAuthHandler()
|
||||||
|
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id, provider)
|
||||||
|
if oauth_client_params is None:
|
||||||
|
raise Forbidden("no oauth available client config found for this tool provider")
|
||||||
|
|
||||||
|
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
|
||||||
|
credentials = oauth_handler.get_credentials(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
provider=provider_name,
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
system_credentials=oauth_client_params,
|
||||||
|
request=request,
|
||||||
|
).credentials
|
||||||
|
|
||||||
|
if not credentials:
|
||||||
|
raise Exception("the plugin credentials failed")
|
||||||
|
|
||||||
|
# add credentials to database
|
||||||
|
BuiltinToolManageService.add_builtin_tool_provider(
|
||||||
|
user_id=user_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
provider=provider,
|
||||||
|
credentials=dict(credentials),
|
||||||
|
api_type=CredentialType.OAUTH2,
|
||||||
|
)
|
||||||
|
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||||
|
|
||||||
|
|
||||||
|
class ToolBuiltinProviderSetDefaultApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def post(self, provider):
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
return BuiltinToolManageService.set_default_provider(
|
||||||
|
tenant_id=current_user.current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ToolOAuthCustomClient(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def post(self, provider):
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||||
|
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
user = current_user
|
||||||
|
|
||||||
|
if not user.is_admin_or_owner:
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
|
return BuiltinToolManageService.save_custom_oauth_client_params(
|
||||||
|
tenant_id=user.current_tenant_id,
|
||||||
|
provider=provider,
|
||||||
|
client_params=args.get("client_params", {}),
|
||||||
|
enable_oauth_custom_client=args.get("enable_oauth_custom_client", True),
|
||||||
|
)
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def get(self, provider):
|
||||||
|
return jsonable_encoder(
|
||||||
|
BuiltinToolManageService.get_custom_oauth_client_params(
|
||||||
|
tenant_id=current_user.current_tenant_id, provider=provider
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def delete(self, provider):
|
||||||
|
return jsonable_encoder(
|
||||||
|
BuiltinToolManageService.delete_custom_oauth_client_params(
|
||||||
|
tenant_id=current_user.current_tenant_id, provider=provider
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def get(self, provider):
|
||||||
|
return jsonable_encoder(
|
||||||
|
BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema(
|
||||||
|
tenant_id=current_user.current_tenant_id, provider_name=provider
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ToolBuiltinProviderGetCredentialInfoApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def get(self, provider):
|
||||||
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
|
return jsonable_encoder(
|
||||||
|
BuiltinToolManageService.get_builtin_tool_provider_credential_info(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
provider=provider,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ToolProviderMCPApi(Resource):
|
class ToolProviderMCPApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@@ -794,17 +1009,33 @@ class ToolMCPCallbackApi(Resource):
|
|||||||
# tool provider
|
# tool provider
|
||||||
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
|
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
|
||||||
|
|
||||||
|
# tool oauth
|
||||||
|
api.add_resource(ToolPluginOAuthApi, "/oauth/plugin/<path:provider>/tool/authorization-url")
|
||||||
|
api.add_resource(ToolOAuthCallback, "/oauth/plugin/<path:provider>/tool/callback")
|
||||||
|
api.add_resource(ToolOAuthCustomClient, "/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client")
|
||||||
|
|
||||||
# builtin tool provider
|
# builtin tool provider
|
||||||
api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/tools")
|
api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/tools")
|
||||||
api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/info")
|
api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/info")
|
||||||
|
api.add_resource(ToolBuiltinProviderAddApi, "/workspaces/current/tool-provider/builtin/<path:provider>/add")
|
||||||
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete")
|
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete")
|
||||||
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update")
|
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update")
|
||||||
|
api.add_resource(
|
||||||
|
ToolBuiltinProviderSetDefaultApi, "/workspaces/current/tool-provider/builtin/<path:provider>/default-credential"
|
||||||
|
)
|
||||||
|
api.add_resource(
|
||||||
|
ToolBuiltinProviderGetCredentialInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credential/info"
|
||||||
|
)
|
||||||
api.add_resource(
|
api.add_resource(
|
||||||
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials"
|
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials"
|
||||||
)
|
)
|
||||||
api.add_resource(
|
api.add_resource(
|
||||||
ToolBuiltinProviderCredentialsSchemaApi,
|
ToolBuiltinProviderCredentialsSchemaApi,
|
||||||
"/workspaces/current/tool-provider/builtin/<path:provider>/credentials_schema",
|
"/workspaces/current/tool-provider/builtin/<path:provider>/credential/schema/<path:credential_type>",
|
||||||
|
)
|
||||||
|
api.add_resource(
|
||||||
|
ToolBuiltinProviderGetOauthClientSchemaApi,
|
||||||
|
"/workspaces/current/tool-provider/builtin/<path:provider>/oauth/client-schema",
|
||||||
)
|
)
|
||||||
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<path:provider>/icon")
|
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<path:provider>/icon")
|
||||||
|
|
||||||
|
|||||||
@@ -175,6 +175,7 @@ class PluginInvokeToolApi(Resource):
|
|||||||
provider=payload.provider,
|
provider=payload.provider,
|
||||||
tool_name=payload.tool,
|
tool_name=payload.tool,
|
||||||
tool_parameters=payload.tool_parameters,
|
tool_parameters=payload.tool_parameters,
|
||||||
|
credential_id=payload.credential_id,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ class AgentToolEntity(BaseModel):
|
|||||||
tool_name: str
|
tool_name: str
|
||||||
tool_parameters: dict[str, Any] = Field(default_factory=dict)
|
tool_parameters: dict[str, Any] = Field(default_factory=dict)
|
||||||
plugin_unique_identifier: str | None = None
|
plugin_unique_identifier: str | None = None
|
||||||
|
credential_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class AgentPromptEntity(BaseModel):
|
class AgentPromptEntity(BaseModel):
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from typing import Any, Optional
|
|||||||
|
|
||||||
from core.agent.entities import AgentInvokeMessage
|
from core.agent.entities import AgentInvokeMessage
|
||||||
from core.agent.plugin_entities import AgentStrategyParameter
|
from core.agent.plugin_entities import AgentStrategyParameter
|
||||||
|
from core.plugin.entities.request import InvokeCredentials
|
||||||
|
|
||||||
|
|
||||||
class BaseAgentStrategy(ABC):
|
class BaseAgentStrategy(ABC):
|
||||||
@@ -18,11 +19,12 @@ class BaseAgentStrategy(ABC):
|
|||||||
conversation_id: Optional[str] = None,
|
conversation_id: Optional[str] = None,
|
||||||
app_id: Optional[str] = None,
|
app_id: Optional[str] = None,
|
||||||
message_id: Optional[str] = None,
|
message_id: Optional[str] = None,
|
||||||
|
credentials: Optional[InvokeCredentials] = None,
|
||||||
) -> Generator[AgentInvokeMessage, None, None]:
|
) -> Generator[AgentInvokeMessage, None, None]:
|
||||||
"""
|
"""
|
||||||
Invoke the agent strategy.
|
Invoke the agent strategy.
|
||||||
"""
|
"""
|
||||||
yield from self._invoke(params, user_id, conversation_id, app_id, message_id)
|
yield from self._invoke(params, user_id, conversation_id, app_id, message_id, credentials)
|
||||||
|
|
||||||
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
|
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
|
||||||
"""
|
"""
|
||||||
@@ -38,5 +40,6 @@ class BaseAgentStrategy(ABC):
|
|||||||
conversation_id: Optional[str] = None,
|
conversation_id: Optional[str] = None,
|
||||||
app_id: Optional[str] = None,
|
app_id: Optional[str] = None,
|
||||||
message_id: Optional[str] = None,
|
message_id: Optional[str] = None,
|
||||||
|
credentials: Optional[InvokeCredentials] = None,
|
||||||
) -> Generator[AgentInvokeMessage, None, None]:
|
) -> Generator[AgentInvokeMessage, None, None]:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from typing import Any, Optional
|
|||||||
from core.agent.entities import AgentInvokeMessage
|
from core.agent.entities import AgentInvokeMessage
|
||||||
from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter
|
from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter
|
||||||
from core.agent.strategy.base import BaseAgentStrategy
|
from core.agent.strategy.base import BaseAgentStrategy
|
||||||
|
from core.plugin.entities.request import InvokeCredentials, PluginInvokeContext
|
||||||
from core.plugin.impl.agent import PluginAgentClient
|
from core.plugin.impl.agent import PluginAgentClient
|
||||||
from core.plugin.utils.converter import convert_parameters_to_plugin_format
|
from core.plugin.utils.converter import convert_parameters_to_plugin_format
|
||||||
|
|
||||||
@@ -40,6 +41,7 @@ class PluginAgentStrategy(BaseAgentStrategy):
|
|||||||
conversation_id: Optional[str] = None,
|
conversation_id: Optional[str] = None,
|
||||||
app_id: Optional[str] = None,
|
app_id: Optional[str] = None,
|
||||||
message_id: Optional[str] = None,
|
message_id: Optional[str] = None,
|
||||||
|
credentials: Optional[InvokeCredentials] = None,
|
||||||
) -> Generator[AgentInvokeMessage, None, None]:
|
) -> Generator[AgentInvokeMessage, None, None]:
|
||||||
"""
|
"""
|
||||||
Invoke the agent strategy.
|
Invoke the agent strategy.
|
||||||
@@ -58,4 +60,5 @@ class PluginAgentStrategy(BaseAgentStrategy):
|
|||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
|
context=PluginInvokeContext(credentials=credentials or InvokeCredentials()),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ class AgentConfigManager:
|
|||||||
"provider_id": tool["provider_id"],
|
"provider_id": tool["provider_id"],
|
||||||
"tool_name": tool["tool_name"],
|
"tool_name": tool["tool_name"],
|
||||||
"tool_parameters": tool.get("tool_parameters", {}),
|
"tool_parameters": tool.get("tool_parameters", {}),
|
||||||
|
"credential_id": tool.get("credential_id", None),
|
||||||
}
|
}
|
||||||
|
|
||||||
agent_tools.append(AgentToolEntity(**agent_tool_properties))
|
agent_tools.append(AgentToolEntity(**agent_tool_properties))
|
||||||
|
|||||||
84
api/core/helper/provider_cache.py
Normal file
84
api/core/helper/provider_cache.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
import json
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from json import JSONDecodeError
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from extensions.ext_redis import redis_client
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderCredentialsCache(ABC):
|
||||||
|
"""Base class for provider credentials cache"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self.cache_key = self._generate_cache_key(**kwargs)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _generate_cache_key(self, **kwargs) -> str:
|
||||||
|
"""Generate cache key based on subclass implementation"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get(self) -> Optional[dict]:
|
||||||
|
"""Get cached provider credentials"""
|
||||||
|
cached_credentials = redis_client.get(self.cache_key)
|
||||||
|
if cached_credentials:
|
||||||
|
try:
|
||||||
|
cached_credentials = cached_credentials.decode("utf-8")
|
||||||
|
return dict(json.loads(cached_credentials))
|
||||||
|
except JSONDecodeError:
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
def set(self, config: dict[str, Any]) -> None:
|
||||||
|
"""Cache provider credentials"""
|
||||||
|
redis_client.setex(self.cache_key, 86400, json.dumps(config))
|
||||||
|
|
||||||
|
def delete(self) -> None:
|
||||||
|
"""Delete cached provider credentials"""
|
||||||
|
redis_client.delete(self.cache_key)
|
||||||
|
|
||||||
|
|
||||||
|
class SingletonProviderCredentialsCache(ProviderCredentialsCache):
|
||||||
|
"""Cache for tool single provider credentials"""
|
||||||
|
|
||||||
|
def __init__(self, tenant_id: str, provider_type: str, provider_identity: str):
|
||||||
|
super().__init__(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
provider_type=provider_type,
|
||||||
|
provider_identity=provider_identity,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _generate_cache_key(self, **kwargs) -> str:
|
||||||
|
tenant_id = kwargs["tenant_id"]
|
||||||
|
provider_type = kwargs["provider_type"]
|
||||||
|
identity_name = kwargs["provider_identity"]
|
||||||
|
identity_id = f"{provider_type}.{identity_name}"
|
||||||
|
return f"{provider_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
|
||||||
|
|
||||||
|
|
||||||
|
class ToolProviderCredentialsCache(ProviderCredentialsCache):
|
||||||
|
"""Cache for tool provider credentials"""
|
||||||
|
|
||||||
|
def __init__(self, tenant_id: str, provider: str, credential_id: str):
|
||||||
|
super().__init__(tenant_id=tenant_id, provider=provider, credential_id=credential_id)
|
||||||
|
|
||||||
|
def _generate_cache_key(self, **kwargs) -> str:
|
||||||
|
tenant_id = kwargs["tenant_id"]
|
||||||
|
provider = kwargs["provider"]
|
||||||
|
credential_id = kwargs["credential_id"]
|
||||||
|
return f"tool_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}"
|
||||||
|
|
||||||
|
|
||||||
|
class NoOpProviderCredentialCache:
|
||||||
|
"""No-op provider credential cache"""
|
||||||
|
|
||||||
|
def get(self) -> Optional[dict]:
|
||||||
|
"""Get cached provider credentials"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
def set(self, config: dict[str, Any]) -> None:
|
||||||
|
"""Cache provider credentials"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def delete(self) -> None:
|
||||||
|
"""Delete cached provider credentials"""
|
||||||
|
pass
|
||||||
@@ -1,51 +0,0 @@
|
|||||||
import json
|
|
||||||
from enum import Enum
|
|
||||||
from json import JSONDecodeError
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from extensions.ext_redis import redis_client
|
|
||||||
|
|
||||||
|
|
||||||
class ToolProviderCredentialsCacheType(Enum):
|
|
||||||
PROVIDER = "tool_provider"
|
|
||||||
ENDPOINT = "endpoint"
|
|
||||||
|
|
||||||
|
|
||||||
class ToolProviderCredentialsCache:
|
|
||||||
def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType):
|
|
||||||
self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
|
|
||||||
|
|
||||||
def get(self) -> Optional[dict]:
|
|
||||||
"""
|
|
||||||
Get cached model provider credentials.
|
|
||||||
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
cached_provider_credentials = redis_client.get(self.cache_key)
|
|
||||||
if cached_provider_credentials:
|
|
||||||
try:
|
|
||||||
cached_provider_credentials = cached_provider_credentials.decode("utf-8")
|
|
||||||
cached_provider_credentials = json.loads(cached_provider_credentials)
|
|
||||||
except JSONDecodeError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return dict(cached_provider_credentials)
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def set(self, credentials: dict) -> None:
|
|
||||||
"""
|
|
||||||
Cache model provider credentials.
|
|
||||||
|
|
||||||
:param credentials: provider credentials
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
redis_client.setex(self.cache_key, 86400, json.dumps(credentials))
|
|
||||||
|
|
||||||
def delete(self) -> None:
|
|
||||||
"""
|
|
||||||
Delete cached model provider credentials.
|
|
||||||
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
redis_client.delete(self.cache_key)
|
|
||||||
@@ -1,16 +1,20 @@
|
|||||||
|
from core.helper.provider_cache import SingletonProviderCredentialsCache
|
||||||
from core.plugin.entities.request import RequestInvokeEncrypt
|
from core.plugin.entities.request import RequestInvokeEncrypt
|
||||||
from core.tools.utils.configuration import ProviderConfigEncrypter
|
from core.tools.utils.encryption import create_provider_encrypter
|
||||||
from models.account import Tenant
|
from models.account import Tenant
|
||||||
|
|
||||||
|
|
||||||
class PluginEncrypter:
|
class PluginEncrypter:
|
||||||
@classmethod
|
@classmethod
|
||||||
def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict:
|
def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict:
|
||||||
encrypter = ProviderConfigEncrypter(
|
encrypter, cache = create_provider_encrypter(
|
||||||
tenant_id=tenant.id,
|
tenant_id=tenant.id,
|
||||||
config=payload.config,
|
config=payload.config,
|
||||||
provider_type=payload.namespace,
|
cache=SingletonProviderCredentialsCache(
|
||||||
provider_identity=payload.identity,
|
tenant_id=tenant.id,
|
||||||
|
provider_type=payload.namespace,
|
||||||
|
provider_identity=payload.identity,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if payload.opt == "encrypt":
|
if payload.opt == "encrypt":
|
||||||
@@ -22,7 +26,7 @@ class PluginEncrypter:
|
|||||||
"data": encrypter.decrypt(payload.data),
|
"data": encrypter.decrypt(payload.data),
|
||||||
}
|
}
|
||||||
elif payload.opt == "clear":
|
elif payload.opt == "clear":
|
||||||
encrypter.delete_tool_credentials_cache()
|
cache.delete()
|
||||||
return {
|
return {
|
||||||
"data": {},
|
"data": {},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import Any
|
from typing import Any, Optional
|
||||||
|
|
||||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||||
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
|
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
|
||||||
@@ -23,6 +23,7 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation):
|
|||||||
provider: str,
|
provider: str,
|
||||||
tool_name: str,
|
tool_name: str,
|
||||||
tool_parameters: dict[str, Any],
|
tool_parameters: dict[str, Any],
|
||||||
|
credential_id: Optional[str] = None,
|
||||||
) -> Generator[ToolInvokeMessage, None, None]:
|
) -> Generator[ToolInvokeMessage, None, None]:
|
||||||
"""
|
"""
|
||||||
invoke tool
|
invoke tool
|
||||||
@@ -30,7 +31,7 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation):
|
|||||||
# get tool runtime
|
# get tool runtime
|
||||||
try:
|
try:
|
||||||
tool_runtime = ToolManager.get_tool_runtime_from_plugin(
|
tool_runtime = ToolManager.get_tool_runtime_from_plugin(
|
||||||
tool_type, tenant_id, provider, tool_name, tool_parameters
|
tool_type, tenant_id, provider, tool_name, tool_parameters, credential_id
|
||||||
)
|
)
|
||||||
response = ToolEngine.generic_invoke(
|
response = ToolEngine.generic_invoke(
|
||||||
tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1
|
tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1
|
||||||
|
|||||||
@@ -27,6 +27,20 @@ from core.workflow.nodes.question_classifier.entities import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class InvokeCredentials(BaseModel):
|
||||||
|
tool_credentials: dict[str, str] = Field(
|
||||||
|
default_factory=dict,
|
||||||
|
description="Map of tool provider to credential id, used to store the credential id for the tool provider.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginInvokeContext(BaseModel):
|
||||||
|
credentials: Optional[InvokeCredentials] = Field(
|
||||||
|
default_factory=InvokeCredentials,
|
||||||
|
description="Credentials context for the plugin invocation or backward invocation.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RequestInvokeTool(BaseModel):
|
class RequestInvokeTool(BaseModel):
|
||||||
"""
|
"""
|
||||||
Request to invoke a tool
|
Request to invoke a tool
|
||||||
@@ -36,6 +50,7 @@ class RequestInvokeTool(BaseModel):
|
|||||||
provider: str
|
provider: str
|
||||||
tool: str
|
tool: str
|
||||||
tool_parameters: dict
|
tool_parameters: dict
|
||||||
|
credential_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class BaseRequestInvokeModel(BaseModel):
|
class BaseRequestInvokeModel(BaseModel):
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from core.plugin.entities.plugin import GenericProviderID
|
|||||||
from core.plugin.entities.plugin_daemon import (
|
from core.plugin.entities.plugin_daemon import (
|
||||||
PluginAgentProviderEntity,
|
PluginAgentProviderEntity,
|
||||||
)
|
)
|
||||||
|
from core.plugin.entities.request import PluginInvokeContext
|
||||||
from core.plugin.impl.base import BasePluginClient
|
from core.plugin.impl.base import BasePluginClient
|
||||||
|
|
||||||
|
|
||||||
@@ -83,6 +84,7 @@ class PluginAgentClient(BasePluginClient):
|
|||||||
conversation_id: Optional[str] = None,
|
conversation_id: Optional[str] = None,
|
||||||
app_id: Optional[str] = None,
|
app_id: Optional[str] = None,
|
||||||
message_id: Optional[str] = None,
|
message_id: Optional[str] = None,
|
||||||
|
context: Optional[PluginInvokeContext] = None,
|
||||||
) -> Generator[AgentInvokeMessage, None, None]:
|
) -> Generator[AgentInvokeMessage, None, None]:
|
||||||
"""
|
"""
|
||||||
Invoke the agent with the given tenant, user, plugin, provider, name and parameters.
|
Invoke the agent with the given tenant, user, plugin, provider, name and parameters.
|
||||||
@@ -99,6 +101,7 @@ class PluginAgentClient(BasePluginClient):
|
|||||||
"conversation_id": conversation_id,
|
"conversation_id": conversation_id,
|
||||||
"app_id": app_id,
|
"app_id": app_id,
|
||||||
"message_id": message_id,
|
"message_id": message_id,
|
||||||
|
"context": context.model_dump() if context else {},
|
||||||
"data": {
|
"data": {
|
||||||
"agent_strategy_provider": agent_provider_id.provider_name,
|
"agent_strategy_provider": agent_provider_id.provider_name,
|
||||||
"agent_strategy": agent_strategy,
|
"agent_strategy": agent_strategy,
|
||||||
|
|||||||
@@ -15,27 +15,32 @@ class OAuthHandler(BasePluginClient):
|
|||||||
user_id: str,
|
user_id: str,
|
||||||
plugin_id: str,
|
plugin_id: str,
|
||||||
provider: str,
|
provider: str,
|
||||||
|
redirect_uri: str,
|
||||||
system_credentials: Mapping[str, Any],
|
system_credentials: Mapping[str, Any],
|
||||||
) -> PluginOAuthAuthorizationUrlResponse:
|
) -> PluginOAuthAuthorizationUrlResponse:
|
||||||
response = self._request_with_plugin_daemon_response_stream(
|
try:
|
||||||
"POST",
|
response = self._request_with_plugin_daemon_response_stream(
|
||||||
f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url",
|
"POST",
|
||||||
PluginOAuthAuthorizationUrlResponse,
|
f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url",
|
||||||
data={
|
PluginOAuthAuthorizationUrlResponse,
|
||||||
"user_id": user_id,
|
data={
|
||||||
"data": {
|
"user_id": user_id,
|
||||||
"provider": provider,
|
"data": {
|
||||||
"system_credentials": system_credentials,
|
"provider": provider,
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
"system_credentials": system_credentials,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
headers={
|
||||||
headers={
|
"X-Plugin-ID": plugin_id,
|
||||||
"X-Plugin-ID": plugin_id,
|
"Content-Type": "application/json",
|
||||||
"Content-Type": "application/json",
|
},
|
||||||
},
|
)
|
||||||
)
|
for resp in response:
|
||||||
for resp in response:
|
return resp
|
||||||
return resp
|
raise ValueError("No response received from plugin daemon for authorization URL request.")
|
||||||
raise ValueError("No response received from plugin daemon for authorization URL request.")
|
except Exception as e:
|
||||||
|
raise ValueError(f"Error getting authorization URL: {e}")
|
||||||
|
|
||||||
def get_credentials(
|
def get_credentials(
|
||||||
self,
|
self,
|
||||||
@@ -43,6 +48,7 @@ class OAuthHandler(BasePluginClient):
|
|||||||
user_id: str,
|
user_id: str,
|
||||||
plugin_id: str,
|
plugin_id: str,
|
||||||
provider: str,
|
provider: str,
|
||||||
|
redirect_uri: str,
|
||||||
system_credentials: Mapping[str, Any],
|
system_credentials: Mapping[str, Any],
|
||||||
request: Request,
|
request: Request,
|
||||||
) -> PluginOAuthCredentialsResponse:
|
) -> PluginOAuthCredentialsResponse:
|
||||||
@@ -50,30 +56,33 @@ class OAuthHandler(BasePluginClient):
|
|||||||
Get credentials from the given request.
|
Get credentials from the given request.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# encode request to raw http request
|
try:
|
||||||
raw_request_bytes = self._convert_request_to_raw_data(request)
|
# encode request to raw http request
|
||||||
|
raw_request_bytes = self._convert_request_to_raw_data(request)
|
||||||
response = self._request_with_plugin_daemon_response_stream(
|
response = self._request_with_plugin_daemon_response_stream(
|
||||||
"POST",
|
"POST",
|
||||||
f"plugin/{tenant_id}/dispatch/oauth/get_credentials",
|
f"plugin/{tenant_id}/dispatch/oauth/get_credentials",
|
||||||
PluginOAuthCredentialsResponse,
|
PluginOAuthCredentialsResponse,
|
||||||
data={
|
data={
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"data": {
|
"data": {
|
||||||
"provider": provider,
|
"provider": provider,
|
||||||
"system_credentials": system_credentials,
|
"redirect_uri": redirect_uri,
|
||||||
# for json serialization
|
"system_credentials": system_credentials,
|
||||||
"raw_http_request": binascii.hexlify(raw_request_bytes).decode(),
|
# for json serialization
|
||||||
|
"raw_http_request": binascii.hexlify(raw_request_bytes).decode(),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
headers={
|
||||||
headers={
|
"X-Plugin-ID": plugin_id,
|
||||||
"X-Plugin-ID": plugin_id,
|
"Content-Type": "application/json",
|
||||||
"Content-Type": "application/json",
|
},
|
||||||
},
|
)
|
||||||
)
|
for resp in response:
|
||||||
for resp in response:
|
return resp
|
||||||
return resp
|
raise ValueError("No response received from plugin daemon for authorization URL request.")
|
||||||
raise ValueError("No response received from plugin daemon for authorization URL request.")
|
except Exception as e:
|
||||||
|
raise ValueError(f"Error getting credentials: {e}")
|
||||||
|
|
||||||
def _convert_request_to_raw_data(self, request: Request) -> bytes:
|
def _convert_request_to_raw_data(self, request: Request) -> bytes:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from pydantic import BaseModel
|
|||||||
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
|
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
|
||||||
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
|
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
|
||||||
from core.plugin.impl.base import BasePluginClient
|
from core.plugin.impl.base import BasePluginClient
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
from core.tools.entities.tool_entities import CredentialType, ToolInvokeMessage, ToolParameter
|
||||||
|
|
||||||
|
|
||||||
class PluginToolManager(BasePluginClient):
|
class PluginToolManager(BasePluginClient):
|
||||||
@@ -78,6 +78,7 @@ class PluginToolManager(BasePluginClient):
|
|||||||
tool_provider: str,
|
tool_provider: str,
|
||||||
tool_name: str,
|
tool_name: str,
|
||||||
credentials: dict[str, Any],
|
credentials: dict[str, Any],
|
||||||
|
credential_type: CredentialType,
|
||||||
tool_parameters: dict[str, Any],
|
tool_parameters: dict[str, Any],
|
||||||
conversation_id: Optional[str] = None,
|
conversation_id: Optional[str] = None,
|
||||||
app_id: Optional[str] = None,
|
app_id: Optional[str] = None,
|
||||||
@@ -102,6 +103,7 @@ class PluginToolManager(BasePluginClient):
|
|||||||
"provider": tool_provider_id.provider_name,
|
"provider": tool_provider_id.provider_name,
|
||||||
"tool": tool_name,
|
"tool": tool_name,
|
||||||
"credentials": credentials,
|
"credentials": credentials,
|
||||||
|
"credential_type": credential_type,
|
||||||
"tool_parameters": tool_parameters,
|
"tool_parameters": tool_parameters,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from openai import BaseModel
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.tools.entities.tool_entities import ToolInvokeFrom
|
from core.tools.entities.tool_entities import CredentialType, ToolInvokeFrom
|
||||||
|
|
||||||
|
|
||||||
class ToolRuntime(BaseModel):
|
class ToolRuntime(BaseModel):
|
||||||
@@ -17,6 +17,7 @@ class ToolRuntime(BaseModel):
|
|||||||
invoke_from: Optional[InvokeFrom] = None
|
invoke_from: Optional[InvokeFrom] = None
|
||||||
tool_invoke_from: Optional[ToolInvokeFrom] = None
|
tool_invoke_from: Optional[ToolInvokeFrom] = None
|
||||||
credentials: dict[str, Any] = Field(default_factory=dict)
|
credentials: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
credential_type: CredentialType = Field(default=CredentialType.API_KEY)
|
||||||
runtime_parameters: dict[str, Any] = Field(default_factory=dict)
|
runtime_parameters: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,13 @@ from core.helper.module_import_helper import load_single_subclass_from_source
|
|||||||
from core.tools.__base.tool_provider import ToolProviderController
|
from core.tools.__base.tool_provider import ToolProviderController
|
||||||
from core.tools.__base.tool_runtime import ToolRuntime
|
from core.tools.__base.tool_runtime import ToolRuntime
|
||||||
from core.tools.builtin_tool.tool import BuiltinTool
|
from core.tools.builtin_tool.tool import BuiltinTool
|
||||||
from core.tools.entities.tool_entities import ToolEntity, ToolProviderEntity, ToolProviderType
|
from core.tools.entities.tool_entities import (
|
||||||
|
CredentialType,
|
||||||
|
OAuthSchema,
|
||||||
|
ToolEntity,
|
||||||
|
ToolProviderEntity,
|
||||||
|
ToolProviderType,
|
||||||
|
)
|
||||||
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
|
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
|
||||||
from core.tools.errors import (
|
from core.tools.errors import (
|
||||||
ToolProviderNotFoundError,
|
ToolProviderNotFoundError,
|
||||||
@@ -39,10 +45,18 @@ class BuiltinToolProviderController(ToolProviderController):
|
|||||||
credential_dict = provider_yaml.get("credentials_for_provider", {}).get(credential, {})
|
credential_dict = provider_yaml.get("credentials_for_provider", {}).get(credential, {})
|
||||||
credentials_schema.append(credential_dict)
|
credentials_schema.append(credential_dict)
|
||||||
|
|
||||||
|
oauth_schema = None
|
||||||
|
if provider_yaml.get("oauth_schema", None) is not None:
|
||||||
|
oauth_schema = OAuthSchema(
|
||||||
|
client_schema=provider_yaml.get("oauth_schema", {}).get("client_schema", []),
|
||||||
|
credentials_schema=provider_yaml.get("oauth_schema", {}).get("credentials_schema", []),
|
||||||
|
)
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
entity=ToolProviderEntity(
|
entity=ToolProviderEntity(
|
||||||
identity=provider_yaml["identity"],
|
identity=provider_yaml["identity"],
|
||||||
credentials_schema=credentials_schema,
|
credentials_schema=credentials_schema,
|
||||||
|
oauth_schema=oauth_schema,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -97,10 +111,39 @@ class BuiltinToolProviderController(ToolProviderController):
|
|||||||
|
|
||||||
:return: the credentials schema
|
:return: the credentials schema
|
||||||
"""
|
"""
|
||||||
if not self.entity.credentials_schema:
|
return self.get_credentials_schema_by_type(CredentialType.API_KEY.value)
|
||||||
return []
|
|
||||||
|
|
||||||
return self.entity.credentials_schema.copy()
|
def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]:
|
||||||
|
"""
|
||||||
|
returns the credentials schema of the provider
|
||||||
|
|
||||||
|
:param credential_type: the type of the credential
|
||||||
|
:return: the credentials schema of the provider
|
||||||
|
"""
|
||||||
|
if credential_type == CredentialType.OAUTH2.value:
|
||||||
|
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
|
||||||
|
if credential_type == CredentialType.API_KEY.value:
|
||||||
|
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
|
||||||
|
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||||
|
|
||||||
|
def get_oauth_client_schema(self) -> list[ProviderConfig]:
|
||||||
|
"""
|
||||||
|
returns the oauth client schema of the provider
|
||||||
|
|
||||||
|
:return: the oauth client schema
|
||||||
|
"""
|
||||||
|
return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else []
|
||||||
|
|
||||||
|
def get_supported_credential_types(self) -> list[str]:
|
||||||
|
"""
|
||||||
|
returns the credential support type of the provider
|
||||||
|
"""
|
||||||
|
types = []
|
||||||
|
if self.entity.credentials_schema is not None and len(self.entity.credentials_schema) > 0:
|
||||||
|
types.append(CredentialType.API_KEY.value)
|
||||||
|
if self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) > 0:
|
||||||
|
types.append(CredentialType.OAUTH2.value)
|
||||||
|
return types
|
||||||
|
|
||||||
def get_tools(self) -> list[BuiltinTool]:
|
def get_tools(self) -> list[BuiltinTool]:
|
||||||
"""
|
"""
|
||||||
@@ -123,7 +166,11 @@ class BuiltinToolProviderController(ToolProviderController):
|
|||||||
|
|
||||||
:return: whether the provider needs credentials
|
:return: whether the provider needs credentials
|
||||||
"""
|
"""
|
||||||
return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0
|
return (
|
||||||
|
self.entity.credentials_schema is not None
|
||||||
|
and len(self.entity.credentials_schema) != 0
|
||||||
|
or (self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) != 0)
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def provider_type(self) -> ToolProviderType:
|
def provider_type(self) -> ToolProviderType:
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field, field_validator
|
|||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.tools.__base.tool import ToolParameter
|
from core.tools.__base.tool import ToolParameter
|
||||||
from core.tools.entities.common_entities import I18nObject
|
from core.tools.entities.common_entities import I18nObject
|
||||||
from core.tools.entities.tool_entities import ToolProviderType
|
from core.tools.entities.tool_entities import CredentialType, ToolProviderType
|
||||||
|
|
||||||
|
|
||||||
class ToolApiEntity(BaseModel):
|
class ToolApiEntity(BaseModel):
|
||||||
@@ -87,3 +87,22 @@ class ToolProviderApiEntity(BaseModel):
|
|||||||
def optional_field(self, key: str, value: Any) -> dict:
|
def optional_field(self, key: str, value: Any) -> dict:
|
||||||
"""Return dict with key-value if value is truthy, empty dict otherwise."""
|
"""Return dict with key-value if value is truthy, empty dict otherwise."""
|
||||||
return {key: value} if value else {}
|
return {key: value} if value else {}
|
||||||
|
|
||||||
|
|
||||||
|
class ToolProviderCredentialApiEntity(BaseModel):
|
||||||
|
id: str = Field(description="The unique id of the credential")
|
||||||
|
name: str = Field(description="The name of the credential")
|
||||||
|
provider: str = Field(description="The provider of the credential")
|
||||||
|
credential_type: CredentialType = Field(description="The type of the credential")
|
||||||
|
is_default: bool = Field(
|
||||||
|
default=False, description="Whether the credential is the default credential for the provider in the workspace"
|
||||||
|
)
|
||||||
|
credentials: dict = Field(description="The credentials of the provider")
|
||||||
|
|
||||||
|
|
||||||
|
class ToolProviderCredentialInfoApiEntity(BaseModel):
|
||||||
|
supported_credential_types: list[str] = Field(description="The supported credential types of the provider")
|
||||||
|
is_oauth_custom_client_enabled: bool = Field(
|
||||||
|
default=False, description="Whether the OAuth custom client is enabled for the provider"
|
||||||
|
)
|
||||||
|
credentials: list[ToolProviderCredentialApiEntity] = Field(description="The credentials of the provider")
|
||||||
|
|||||||
@@ -370,10 +370,18 @@ class ToolEntity(BaseModel):
|
|||||||
return v or []
|
return v or []
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthSchema(BaseModel):
|
||||||
|
client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client")
|
||||||
|
credentials_schema: list[ProviderConfig] = Field(
|
||||||
|
default_factory=list, description="The schema of the OAuth credentials"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ToolProviderEntity(BaseModel):
|
class ToolProviderEntity(BaseModel):
|
||||||
identity: ToolProviderIdentity
|
identity: ToolProviderIdentity
|
||||||
plugin_id: Optional[str] = None
|
plugin_id: Optional[str] = None
|
||||||
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
|
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
|
||||||
|
oauth_schema: Optional[OAuthSchema] = None
|
||||||
|
|
||||||
|
|
||||||
class ToolProviderEntityWithPlugin(ToolProviderEntity):
|
class ToolProviderEntityWithPlugin(ToolProviderEntity):
|
||||||
@@ -453,6 +461,7 @@ class ToolSelector(BaseModel):
|
|||||||
options: Optional[list[PluginParameterOption]] = None
|
options: Optional[list[PluginParameterOption]] = None
|
||||||
|
|
||||||
provider_id: str = Field(..., description="The id of the provider")
|
provider_id: str = Field(..., description="The id of the provider")
|
||||||
|
credential_id: Optional[str] = Field(default=None, description="The id of the credential")
|
||||||
tool_name: str = Field(..., description="The name of the tool")
|
tool_name: str = Field(..., description="The name of the tool")
|
||||||
tool_description: str = Field(..., description="The description of the tool")
|
tool_description: str = Field(..., description="The description of the tool")
|
||||||
tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form")
|
tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form")
|
||||||
@@ -460,3 +469,36 @@ class ToolSelector(BaseModel):
|
|||||||
|
|
||||||
def to_plugin_parameter(self) -> dict[str, Any]:
|
def to_plugin_parameter(self) -> dict[str, Any]:
|
||||||
return self.model_dump()
|
return self.model_dump()
|
||||||
|
|
||||||
|
|
||||||
|
class CredentialType(enum.StrEnum):
|
||||||
|
API_KEY = "api-key"
|
||||||
|
OAUTH2 = "oauth2"
|
||||||
|
|
||||||
|
def get_name(self):
|
||||||
|
if self == CredentialType.API_KEY:
|
||||||
|
return "API KEY"
|
||||||
|
elif self == CredentialType.OAUTH2:
|
||||||
|
return "AUTH"
|
||||||
|
else:
|
||||||
|
return self.value.replace("-", " ").upper()
|
||||||
|
|
||||||
|
def is_editable(self):
|
||||||
|
return self == CredentialType.API_KEY
|
||||||
|
|
||||||
|
def is_validate_allowed(self):
|
||||||
|
return self == CredentialType.API_KEY
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def values(cls):
|
||||||
|
return [item.value for item in cls]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def of(cls, credential_type: str) -> "CredentialType":
|
||||||
|
type_name = credential_type.lower()
|
||||||
|
if type_name == "api-key":
|
||||||
|
return cls.API_KEY
|
||||||
|
elif type_name == "oauth2":
|
||||||
|
return cls.OAUTH2
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ class PluginTool(Tool):
|
|||||||
tool_provider=self.entity.identity.provider,
|
tool_provider=self.entity.identity.provider,
|
||||||
tool_name=self.entity.identity.name,
|
tool_name=self.entity.identity.name,
|
||||||
credentials=self.runtime.credentials,
|
credentials=self.runtime.credentials,
|
||||||
|
credential_type=self.runtime.credential_type,
|
||||||
tool_parameters=tool_parameters,
|
tool_parameters=tool_parameters,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
|
|||||||
from yarl import URL
|
from yarl import URL
|
||||||
|
|
||||||
import contexts
|
import contexts
|
||||||
|
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||||
from core.plugin.entities.plugin import ToolProviderID
|
from core.plugin.entities.plugin import ToolProviderID
|
||||||
from core.plugin.impl.tool import PluginToolManager
|
from core.plugin.impl.tool import PluginToolManager
|
||||||
from core.tools.__base.tool_provider import ToolProviderController
|
from core.tools.__base.tool_provider import ToolProviderController
|
||||||
@@ -17,6 +18,7 @@ from core.tools.mcp_tool.provider import MCPToolProviderController
|
|||||||
from core.tools.mcp_tool.tool import MCPTool
|
from core.tools.mcp_tool.tool import MCPTool
|
||||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||||
from core.tools.plugin_tool.tool import PluginTool
|
from core.tools.plugin_tool.tool import PluginTool
|
||||||
|
from core.tools.utils.uuid_utils import is_valid_uuid
|
||||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from services.tools.mcp_tools_mange_service import MCPToolManageService
|
from services.tools.mcp_tools_mange_service import MCPToolManageService
|
||||||
@@ -24,7 +26,6 @@ from services.tools.mcp_tools_mange_service import MCPToolManageService
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from core.workflow.nodes.tool.entities import ToolEntity
|
from core.workflow.nodes.tool.entities import ToolEntity
|
||||||
|
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.agent.entities import AgentToolEntity
|
from core.agent.entities import AgentToolEntity
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
@@ -41,16 +42,17 @@ from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProvider
|
|||||||
from core.tools.entities.common_entities import I18nObject
|
from core.tools.entities.common_entities import I18nObject
|
||||||
from core.tools.entities.tool_entities import (
|
from core.tools.entities.tool_entities import (
|
||||||
ApiProviderAuthType,
|
ApiProviderAuthType,
|
||||||
|
CredentialType,
|
||||||
ToolInvokeFrom,
|
ToolInvokeFrom,
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
ToolProviderType,
|
ToolProviderType,
|
||||||
)
|
)
|
||||||
from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError
|
from core.tools.errors import ToolProviderNotFoundError
|
||||||
from core.tools.tool_label_manager import ToolLabelManager
|
from core.tools.tool_label_manager import ToolLabelManager
|
||||||
from core.tools.utils.configuration import (
|
from core.tools.utils.configuration import (
|
||||||
ProviderConfigEncrypter,
|
|
||||||
ToolParameterConfigurationManager,
|
ToolParameterConfigurationManager,
|
||||||
)
|
)
|
||||||
|
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
|
||||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
|
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
|
||||||
@@ -68,8 +70,11 @@ class ToolManager:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
|
def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
get the hardcoded provider
|
get the hardcoded provider
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if len(cls._hardcoded_providers) == 0:
|
if len(cls._hardcoded_providers) == 0:
|
||||||
# init the builtin providers
|
# init the builtin providers
|
||||||
cls.load_hardcoded_providers_cache()
|
cls.load_hardcoded_providers_cache()
|
||||||
@@ -113,7 +118,12 @@ class ToolManager:
|
|||||||
contexts.plugin_tool_providers.set({})
|
contexts.plugin_tool_providers.set({})
|
||||||
contexts.plugin_tool_providers_lock.set(Lock())
|
contexts.plugin_tool_providers_lock.set(Lock())
|
||||||
|
|
||||||
|
plugin_tool_providers = contexts.plugin_tool_providers.get()
|
||||||
|
if provider in plugin_tool_providers:
|
||||||
|
return plugin_tool_providers[provider]
|
||||||
|
|
||||||
with contexts.plugin_tool_providers_lock.get():
|
with contexts.plugin_tool_providers_lock.get():
|
||||||
|
# double check
|
||||||
plugin_tool_providers = contexts.plugin_tool_providers.get()
|
plugin_tool_providers = contexts.plugin_tool_providers.get()
|
||||||
if provider in plugin_tool_providers:
|
if provider in plugin_tool_providers:
|
||||||
return plugin_tool_providers[provider]
|
return plugin_tool_providers[provider]
|
||||||
@@ -131,25 +141,7 @@ class ToolManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
plugin_tool_providers[provider] = controller
|
plugin_tool_providers[provider] = controller
|
||||||
|
return controller
|
||||||
return controller
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None:
|
|
||||||
"""
|
|
||||||
get the builtin tool
|
|
||||||
|
|
||||||
:param provider: the name of the provider
|
|
||||||
:param tool_name: the name of the tool
|
|
||||||
:param tenant_id: the id of the tenant
|
|
||||||
:return: the provider, the tool
|
|
||||||
"""
|
|
||||||
provider_controller = cls.get_builtin_provider(provider, tenant_id)
|
|
||||||
tool = provider_controller.get_tool(tool_name)
|
|
||||||
if tool is None:
|
|
||||||
raise ToolNotFoundError(f"tool {tool_name} not found")
|
|
||||||
|
|
||||||
return tool
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_tool_runtime(
|
def get_tool_runtime(
|
||||||
@@ -160,6 +152,7 @@ class ToolManager:
|
|||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||||
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
|
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
|
||||||
|
credential_id: Optional[str] = None,
|
||||||
) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]:
|
) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]:
|
||||||
"""
|
"""
|
||||||
get the tool runtime
|
get the tool runtime
|
||||||
@@ -170,6 +163,7 @@ class ToolManager:
|
|||||||
:param tenant_id: the tenant id
|
:param tenant_id: the tenant id
|
||||||
:param invoke_from: invoke from
|
:param invoke_from: invoke from
|
||||||
:param tool_invoke_from: the tool invoke from
|
:param tool_invoke_from: the tool invoke from
|
||||||
|
:param credential_id: the credential id
|
||||||
|
|
||||||
:return: the tool
|
:return: the tool
|
||||||
"""
|
"""
|
||||||
@@ -193,49 +187,70 @@ class ToolManager:
|
|||||||
)
|
)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
builtin_provider = None
|
||||||
if isinstance(provider_controller, PluginToolProviderController):
|
if isinstance(provider_controller, PluginToolProviderController):
|
||||||
provider_id_entity = ToolProviderID(provider_id)
|
provider_id_entity = ToolProviderID(provider_id)
|
||||||
# get credentials
|
# get specific credentials
|
||||||
builtin_provider: BuiltinToolProvider | None = (
|
if is_valid_uuid(credential_id):
|
||||||
db.session.query(BuiltinToolProvider)
|
try:
|
||||||
.filter(
|
builtin_provider = (
|
||||||
BuiltinToolProvider.tenant_id == tenant_id,
|
db.session.query(BuiltinToolProvider)
|
||||||
(BuiltinToolProvider.provider == str(provider_id_entity))
|
.filter(
|
||||||
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
|
BuiltinToolProvider.tenant_id == tenant_id,
|
||||||
)
|
BuiltinToolProvider.id == credential_id,
|
||||||
.first()
|
)
|
||||||
)
|
.first()
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
builtin_provider = None
|
||||||
|
logger.info(f"Error getting builtin provider {credential_id}:{e}", exc_info=True)
|
||||||
|
# if the provider has been deleted, raise an error
|
||||||
|
if builtin_provider is None:
|
||||||
|
raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}")
|
||||||
|
|
||||||
|
# fallback to the default provider
|
||||||
if builtin_provider is None:
|
if builtin_provider is None:
|
||||||
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
|
# use the default provider
|
||||||
|
builtin_provider = (
|
||||||
|
db.session.query(BuiltinToolProvider)
|
||||||
|
.filter(
|
||||||
|
BuiltinToolProvider.tenant_id == tenant_id,
|
||||||
|
(BuiltinToolProvider.provider == str(provider_id_entity))
|
||||||
|
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
|
||||||
|
)
|
||||||
|
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if builtin_provider is None:
|
||||||
|
raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
|
||||||
else:
|
else:
|
||||||
builtin_provider = (
|
builtin_provider = (
|
||||||
db.session.query(BuiltinToolProvider)
|
db.session.query(BuiltinToolProvider)
|
||||||
.filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
|
.filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
|
||||||
|
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
if builtin_provider is None:
|
if builtin_provider is None:
|
||||||
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
|
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
|
||||||
|
|
||||||
# decrypt the credentials
|
encrypter, _ = create_provider_encrypter(
|
||||||
credentials = builtin_provider.credentials
|
|
||||||
tool_configuration = ProviderConfigEncrypter(
|
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
config=[
|
||||||
provider_type=provider_controller.provider_type.value,
|
x.to_basic_provider_config()
|
||||||
provider_identity=provider_controller.entity.identity.name,
|
for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type)
|
||||||
|
],
|
||||||
|
cache=ToolProviderCredentialsCache(
|
||||||
|
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
decrypted_credentials = tool_configuration.decrypt(credentials)
|
|
||||||
|
|
||||||
return cast(
|
return cast(
|
||||||
BuiltinTool,
|
BuiltinTool,
|
||||||
builtin_tool.fork_tool_runtime(
|
builtin_tool.fork_tool_runtime(
|
||||||
runtime=ToolRuntime(
|
runtime=ToolRuntime(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
credentials=decrypted_credentials,
|
credentials=encrypter.decrypt(builtin_provider.credentials),
|
||||||
|
credential_type=CredentialType.of(builtin_provider.credential_type),
|
||||||
runtime_parameters={},
|
runtime_parameters={},
|
||||||
invoke_from=invoke_from,
|
invoke_from=invoke_from,
|
||||||
tool_invoke_from=tool_invoke_from,
|
tool_invoke_from=tool_invoke_from,
|
||||||
@@ -245,22 +260,16 @@ class ToolManager:
|
|||||||
|
|
||||||
elif provider_type == ToolProviderType.API:
|
elif provider_type == ToolProviderType.API:
|
||||||
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
|
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
|
||||||
|
encrypter, _ = create_tool_provider_encrypter(
|
||||||
# decrypt the credentials
|
|
||||||
tool_configuration = ProviderConfigEncrypter(
|
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()],
|
controller=api_provider,
|
||||||
provider_type=api_provider.provider_type.value,
|
|
||||||
provider_identity=api_provider.entity.identity.name,
|
|
||||||
)
|
)
|
||||||
decrypted_credentials = tool_configuration.decrypt(credentials)
|
|
||||||
|
|
||||||
return cast(
|
return cast(
|
||||||
ApiTool,
|
ApiTool,
|
||||||
api_provider.get_tool(tool_name).fork_tool_runtime(
|
api_provider.get_tool(tool_name).fork_tool_runtime(
|
||||||
runtime=ToolRuntime(
|
runtime=ToolRuntime(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
credentials=decrypted_credentials,
|
credentials=encrypter.decrypt(credentials),
|
||||||
invoke_from=invoke_from,
|
invoke_from=invoke_from,
|
||||||
tool_invoke_from=tool_invoke_from,
|
tool_invoke_from=tool_invoke_from,
|
||||||
)
|
)
|
||||||
@@ -320,6 +329,7 @@ class ToolManager:
|
|||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
invoke_from=invoke_from,
|
invoke_from=invoke_from,
|
||||||
tool_invoke_from=ToolInvokeFrom.AGENT,
|
tool_invoke_from=ToolInvokeFrom.AGENT,
|
||||||
|
credential_id=agent_tool.credential_id,
|
||||||
)
|
)
|
||||||
runtime_parameters = {}
|
runtime_parameters = {}
|
||||||
parameters = tool_entity.get_merged_runtime_parameters()
|
parameters = tool_entity.get_merged_runtime_parameters()
|
||||||
@@ -362,6 +372,7 @@ class ToolManager:
|
|||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
invoke_from=invoke_from,
|
invoke_from=invoke_from,
|
||||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
||||||
|
credential_id=workflow_tool.credential_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
parameters = tool_runtime.get_merged_runtime_parameters()
|
parameters = tool_runtime.get_merged_runtime_parameters()
|
||||||
@@ -391,6 +402,7 @@ class ToolManager:
|
|||||||
provider: str,
|
provider: str,
|
||||||
tool_name: str,
|
tool_name: str,
|
||||||
tool_parameters: dict[str, Any],
|
tool_parameters: dict[str, Any],
|
||||||
|
credential_id: Optional[str] = None,
|
||||||
) -> Tool:
|
) -> Tool:
|
||||||
"""
|
"""
|
||||||
get tool runtime from plugin
|
get tool runtime from plugin
|
||||||
@@ -402,6 +414,7 @@ class ToolManager:
|
|||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
invoke_from=InvokeFrom.SERVICE_API,
|
invoke_from=InvokeFrom.SERVICE_API,
|
||||||
tool_invoke_from=ToolInvokeFrom.PLUGIN,
|
tool_invoke_from=ToolInvokeFrom.PLUGIN,
|
||||||
|
credential_id=credential_id,
|
||||||
)
|
)
|
||||||
runtime_parameters = {}
|
runtime_parameters = {}
|
||||||
parameters = tool_entity.get_merged_runtime_parameters()
|
parameters = tool_entity.get_merged_runtime_parameters()
|
||||||
@@ -551,6 +564,22 @@ class ToolManager:
|
|||||||
|
|
||||||
return cls._builtin_tools_labels[tool_name]
|
return cls._builtin_tools_labels[tool_name]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_default_builtin_providers(cls, tenant_id: str) -> list[BuiltinToolProvider]:
|
||||||
|
"""
|
||||||
|
list all the builtin providers
|
||||||
|
"""
|
||||||
|
# according to multi credentials, select the one with is_default=True first, then created_at oldest
|
||||||
|
# for compatibility with old version
|
||||||
|
sql = """
|
||||||
|
SELECT DISTINCT ON (tenant_id, provider) id
|
||||||
|
FROM tool_builtin_providers
|
||||||
|
WHERE tenant_id = :tenant_id
|
||||||
|
ORDER BY tenant_id, provider, is_default DESC, created_at DESC
|
||||||
|
"""
|
||||||
|
ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()]
|
||||||
|
return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def list_providers_from_api(
|
def list_providers_from_api(
|
||||||
cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
|
cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
|
||||||
@@ -565,21 +594,13 @@ class ToolManager:
|
|||||||
|
|
||||||
with db.session.no_autoflush:
|
with db.session.no_autoflush:
|
||||||
if "builtin" in filters:
|
if "builtin" in filters:
|
||||||
# get builtin providers
|
|
||||||
builtin_providers = cls.list_builtin_providers(tenant_id)
|
builtin_providers = cls.list_builtin_providers(tenant_id)
|
||||||
|
|
||||||
# get db builtin providers
|
# key: provider name, value: provider
|
||||||
db_builtin_providers: list[BuiltinToolProvider] = (
|
db_builtin_providers = {
|
||||||
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
|
str(ToolProviderID(provider.provider)): provider
|
||||||
)
|
for provider in cls.list_default_builtin_providers(tenant_id)
|
||||||
|
}
|
||||||
# rewrite db_builtin_providers
|
|
||||||
for db_provider in db_builtin_providers:
|
|
||||||
tool_provider_id = str(ToolProviderID(db_provider.provider))
|
|
||||||
db_provider.provider = tool_provider_id
|
|
||||||
|
|
||||||
def find_db_builtin_provider(provider):
|
|
||||||
return next((x for x in db_builtin_providers if x.provider == provider), None)
|
|
||||||
|
|
||||||
# append builtin providers
|
# append builtin providers
|
||||||
for provider in builtin_providers:
|
for provider in builtin_providers:
|
||||||
@@ -591,10 +612,9 @@ class ToolManager:
|
|||||||
name_func=lambda x: x.identity.name,
|
name_func=lambda x: x.identity.name,
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
user_provider = ToolTransformService.builtin_provider_to_user_provider(
|
user_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||||
provider_controller=provider,
|
provider_controller=provider,
|
||||||
db_provider=find_db_builtin_provider(provider.entity.identity.name),
|
db_provider=db_builtin_providers.get(provider.entity.identity.name),
|
||||||
decrypt_credentials=False,
|
decrypt_credentials=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -604,7 +624,6 @@ class ToolManager:
|
|||||||
result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
|
result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
|
||||||
|
|
||||||
# get db api providers
|
# get db api providers
|
||||||
|
|
||||||
if "api" in filters:
|
if "api" in filters:
|
||||||
db_api_providers: list[ApiToolProvider] = (
|
db_api_providers: list[ApiToolProvider] = (
|
||||||
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
|
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
|
||||||
@@ -764,15 +783,12 @@ class ToolManager:
|
|||||||
auth_type,
|
auth_type,
|
||||||
)
|
)
|
||||||
# init tool configuration
|
# init tool configuration
|
||||||
tool_configuration = ProviderConfigEncrypter(
|
encrypter, _ = create_tool_provider_encrypter(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
|
controller=controller,
|
||||||
provider_type=controller.provider_type.value,
|
|
||||||
provider_identity=controller.entity.identity.name,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
decrypted_credentials = tool_configuration.decrypt(credentials)
|
masked_credentials = encrypter.mask_tool_credentials(encrypter.decrypt(credentials))
|
||||||
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
icon = json.loads(provider_obj.icon)
|
icon = json.loads(provider_obj.icon)
|
||||||
|
|||||||
@@ -1,12 +1,8 @@
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from core.entities.provider_entities import BasicProviderConfig
|
|
||||||
from core.helper import encrypter
|
from core.helper import encrypter
|
||||||
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
|
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
|
||||||
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
|
|
||||||
from core.tools.__base.tool import Tool
|
from core.tools.__base.tool import Tool
|
||||||
from core.tools.entities.tool_entities import (
|
from core.tools.entities.tool_entities import (
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
@@ -14,110 +10,6 @@ from core.tools.entities.tool_entities import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ProviderConfigEncrypter(BaseModel):
|
|
||||||
tenant_id: str
|
|
||||||
config: list[BasicProviderConfig]
|
|
||||||
provider_type: str
|
|
||||||
provider_identity: str
|
|
||||||
|
|
||||||
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
|
|
||||||
"""
|
|
||||||
deep copy data
|
|
||||||
"""
|
|
||||||
return deepcopy(data)
|
|
||||||
|
|
||||||
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
|
|
||||||
"""
|
|
||||||
encrypt tool credentials with tenant id
|
|
||||||
|
|
||||||
return a deep copy of credentials with encrypted values
|
|
||||||
"""
|
|
||||||
data = self._deep_copy(data)
|
|
||||||
|
|
||||||
# get fields need to be decrypted
|
|
||||||
fields = dict[str, BasicProviderConfig]()
|
|
||||||
for credential in self.config:
|
|
||||||
fields[credential.name] = credential
|
|
||||||
|
|
||||||
for field_name, field in fields.items():
|
|
||||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
|
||||||
if field_name in data:
|
|
||||||
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
|
|
||||||
data[field_name] = encrypted
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""
|
|
||||||
mask tool credentials
|
|
||||||
|
|
||||||
return a deep copy of credentials with masked values
|
|
||||||
"""
|
|
||||||
data = self._deep_copy(data)
|
|
||||||
|
|
||||||
# get fields need to be decrypted
|
|
||||||
fields = dict[str, BasicProviderConfig]()
|
|
||||||
for credential in self.config:
|
|
||||||
fields[credential.name] = credential
|
|
||||||
|
|
||||||
for field_name, field in fields.items():
|
|
||||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
|
||||||
if field_name in data:
|
|
||||||
if len(data[field_name]) > 6:
|
|
||||||
data[field_name] = (
|
|
||||||
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
data[field_name] = "*" * len(data[field_name])
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
def decrypt(self, data: dict[str, str], use_cache: bool = True) -> dict[str, str]:
|
|
||||||
"""
|
|
||||||
decrypt tool credentials with tenant id
|
|
||||||
|
|
||||||
return a deep copy of credentials with decrypted values
|
|
||||||
"""
|
|
||||||
if use_cache:
|
|
||||||
cache = ToolProviderCredentialsCache(
|
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
identity_id=f"{self.provider_type}.{self.provider_identity}",
|
|
||||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
|
|
||||||
)
|
|
||||||
cached_credentials = cache.get()
|
|
||||||
if cached_credentials:
|
|
||||||
return cached_credentials
|
|
||||||
data = self._deep_copy(data)
|
|
||||||
# get fields need to be decrypted
|
|
||||||
fields = dict[str, BasicProviderConfig]()
|
|
||||||
for credential in self.config:
|
|
||||||
fields[credential.name] = credential
|
|
||||||
|
|
||||||
for field_name, field in fields.items():
|
|
||||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
|
||||||
if field_name in data:
|
|
||||||
try:
|
|
||||||
# if the value is None or empty string, skip decrypt
|
|
||||||
if not data[field_name]:
|
|
||||||
continue
|
|
||||||
|
|
||||||
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if use_cache:
|
|
||||||
cache.set(data)
|
|
||||||
return data
|
|
||||||
|
|
||||||
def delete_tool_credentials_cache(self):
|
|
||||||
cache = ToolProviderCredentialsCache(
|
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
identity_id=f"{self.provider_type}.{self.provider_identity}",
|
|
||||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
|
|
||||||
)
|
|
||||||
cache.delete()
|
|
||||||
|
|
||||||
|
|
||||||
class ToolParameterConfigurationManager:
|
class ToolParameterConfigurationManager:
|
||||||
"""
|
"""
|
||||||
Tool parameter configuration manager
|
Tool parameter configuration manager
|
||||||
|
|||||||
142
api/core/tools/utils/encryption.py
Normal file
142
api/core/tools/utils/encryption.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
from copy import deepcopy
|
||||||
|
from typing import Any, Optional, Protocol
|
||||||
|
|
||||||
|
from core.entities.provider_entities import BasicProviderConfig
|
||||||
|
from core.helper import encrypter
|
||||||
|
from core.helper.provider_cache import SingletonProviderCredentialsCache
|
||||||
|
from core.tools.__base.tool_provider import ToolProviderController
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderConfigCache(Protocol):
|
||||||
|
"""
|
||||||
|
Interface for provider configuration cache operations
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get(self) -> Optional[dict]:
|
||||||
|
"""Get cached provider configuration"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def set(self, config: dict[str, Any]) -> None:
|
||||||
|
"""Cache provider configuration"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def delete(self) -> None:
|
||||||
|
"""Delete cached provider configuration"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderConfigEncrypter:
|
||||||
|
tenant_id: str
|
||||||
|
config: list[BasicProviderConfig]
|
||||||
|
provider_config_cache: ProviderConfigCache
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
config: list[BasicProviderConfig],
|
||||||
|
provider_config_cache: ProviderConfigCache,
|
||||||
|
):
|
||||||
|
self.tenant_id = tenant_id
|
||||||
|
self.config = config
|
||||||
|
self.provider_config_cache = provider_config_cache
|
||||||
|
|
||||||
|
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
|
||||||
|
"""
|
||||||
|
deep copy data
|
||||||
|
"""
|
||||||
|
return deepcopy(data)
|
||||||
|
|
||||||
|
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
|
||||||
|
"""
|
||||||
|
encrypt tool credentials with tenant id
|
||||||
|
|
||||||
|
return a deep copy of credentials with encrypted values
|
||||||
|
"""
|
||||||
|
data = self._deep_copy(data)
|
||||||
|
|
||||||
|
# get fields need to be decrypted
|
||||||
|
fields = dict[str, BasicProviderConfig]()
|
||||||
|
for credential in self.config:
|
||||||
|
fields[credential.name] = credential
|
||||||
|
|
||||||
|
for field_name, field in fields.items():
|
||||||
|
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||||
|
if field_name in data:
|
||||||
|
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
|
||||||
|
data[field_name] = encrypted
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
mask tool credentials
|
||||||
|
|
||||||
|
return a deep copy of credentials with masked values
|
||||||
|
"""
|
||||||
|
data = self._deep_copy(data)
|
||||||
|
|
||||||
|
# get fields need to be decrypted
|
||||||
|
fields = dict[str, BasicProviderConfig]()
|
||||||
|
for credential in self.config:
|
||||||
|
fields[credential.name] = credential
|
||||||
|
|
||||||
|
for field_name, field in fields.items():
|
||||||
|
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||||
|
if field_name in data:
|
||||||
|
if len(data[field_name]) > 6:
|
||||||
|
data[field_name] = (
|
||||||
|
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
data[field_name] = "*" * len(data[field_name])
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
decrypt tool credentials with tenant id
|
||||||
|
|
||||||
|
return a deep copy of credentials with decrypted values
|
||||||
|
"""
|
||||||
|
cached_credentials = self.provider_config_cache.get()
|
||||||
|
if cached_credentials:
|
||||||
|
return cached_credentials
|
||||||
|
|
||||||
|
data = self._deep_copy(data)
|
||||||
|
# get fields need to be decrypted
|
||||||
|
fields = dict[str, BasicProviderConfig]()
|
||||||
|
for credential in self.config:
|
||||||
|
fields[credential.name] = credential
|
||||||
|
|
||||||
|
for field_name, field in fields.items():
|
||||||
|
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||||
|
if field_name in data:
|
||||||
|
try:
|
||||||
|
# if the value is None or empty string, skip decrypt
|
||||||
|
if not data[field_name]:
|
||||||
|
continue
|
||||||
|
|
||||||
|
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.provider_config_cache.set(data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
|
||||||
|
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
|
||||||
|
|
||||||
|
|
||||||
|
def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController):
|
||||||
|
cache = SingletonProviderCredentialsCache(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
provider_type=controller.provider_type.value,
|
||||||
|
provider_identity=controller.entity.identity.name,
|
||||||
|
)
|
||||||
|
encrypt = ProviderConfigEncrypter(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
|
||||||
|
provider_config_cache=cache,
|
||||||
|
)
|
||||||
|
return encrypt, cache
|
||||||
187
api/core/tools/utils/system_oauth_encryption.py
Normal file
187
api/core/tools/utils/system_oauth_encryption.py
Normal file
@@ -0,0 +1,187 @@
|
|||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from Crypto.Cipher import AES
|
||||||
|
from Crypto.Random import get_random_bytes
|
||||||
|
from Crypto.Util.Padding import pad, unpad
|
||||||
|
from pydantic import TypeAdapter
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthEncryptionError(Exception):
|
||||||
|
"""OAuth encryption/decryption specific error"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SystemOAuthEncrypter:
|
||||||
|
"""
|
||||||
|
A simple OAuth parameters encrypter using AES-CBC encryption.
|
||||||
|
|
||||||
|
This class provides methods to encrypt and decrypt OAuth parameters
|
||||||
|
using AES-CBC mode with a key derived from the application's SECRET_KEY.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, secret_key: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
Initialize the OAuth encrypter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If SECRET_KEY is not configured or empty
|
||||||
|
"""
|
||||||
|
secret_key = secret_key or dify_config.SECRET_KEY or ""
|
||||||
|
|
||||||
|
# Generate a fixed 256-bit key using SHA-256
|
||||||
|
self.key = hashlib.sha256(secret_key.encode()).digest()
|
||||||
|
|
||||||
|
def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str:
|
||||||
|
"""
|
||||||
|
Encrypt OAuth parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Base64-encoded encrypted string
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
OAuthEncryptionError: If encryption fails
|
||||||
|
ValueError: If oauth_params is invalid
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Generate random IV (16 bytes)
|
||||||
|
iv = get_random_bytes(16)
|
||||||
|
|
||||||
|
# Create AES cipher (CBC mode)
|
||||||
|
cipher = AES.new(self.key, AES.MODE_CBC, iv)
|
||||||
|
|
||||||
|
# Encrypt data
|
||||||
|
padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size)
|
||||||
|
encrypted_data = cipher.encrypt(padded_data)
|
||||||
|
|
||||||
|
# Combine IV and encrypted data
|
||||||
|
combined = iv + encrypted_data
|
||||||
|
|
||||||
|
# Return base64 encoded string
|
||||||
|
return base64.b64encode(combined).decode()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e
|
||||||
|
|
||||||
|
def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]:
|
||||||
|
"""
|
||||||
|
Decrypt OAuth parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encrypted_data: Base64-encoded encrypted string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Decrypted OAuth parameters dictionary
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
OAuthEncryptionError: If decryption fails
|
||||||
|
ValueError: If encrypted_data is invalid
|
||||||
|
"""
|
||||||
|
if not isinstance(encrypted_data, str):
|
||||||
|
raise ValueError("encrypted_data must be a string")
|
||||||
|
|
||||||
|
if not encrypted_data:
|
||||||
|
raise ValueError("encrypted_data cannot be empty")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Base64 decode
|
||||||
|
combined = base64.b64decode(encrypted_data)
|
||||||
|
|
||||||
|
# Check minimum length (IV + at least one AES block)
|
||||||
|
if len(combined) < 32: # 16 bytes IV + 16 bytes minimum encrypted data
|
||||||
|
raise ValueError("Invalid encrypted data format")
|
||||||
|
|
||||||
|
# Separate IV and encrypted data
|
||||||
|
iv = combined[:16]
|
||||||
|
encrypted_data_bytes = combined[16:]
|
||||||
|
|
||||||
|
# Create AES cipher
|
||||||
|
cipher = AES.new(self.key, AES.MODE_CBC, iv)
|
||||||
|
|
||||||
|
# Decrypt data
|
||||||
|
decrypted_data = cipher.decrypt(encrypted_data_bytes)
|
||||||
|
unpadded_data = unpad(decrypted_data, AES.block_size)
|
||||||
|
|
||||||
|
# Parse JSON
|
||||||
|
oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
|
||||||
|
|
||||||
|
if not isinstance(oauth_params, dict):
|
||||||
|
raise ValueError("Decrypted data is not a valid dictionary")
|
||||||
|
|
||||||
|
return oauth_params
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
|
# Factory function for creating encrypter instances
|
||||||
|
def create_system_oauth_encrypter(secret_key: Optional[str] = None) -> SystemOAuthEncrypter:
|
||||||
|
"""
|
||||||
|
Create an OAuth encrypter instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SystemOAuthEncrypter instance
|
||||||
|
"""
|
||||||
|
return SystemOAuthEncrypter(secret_key=secret_key)
|
||||||
|
|
||||||
|
|
||||||
|
# Global encrypter instance (for backward compatibility)
|
||||||
|
_oauth_encrypter: Optional[SystemOAuthEncrypter] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_system_oauth_encrypter() -> SystemOAuthEncrypter:
|
||||||
|
"""
|
||||||
|
Get the global OAuth encrypter instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SystemOAuthEncrypter instance
|
||||||
|
"""
|
||||||
|
global _oauth_encrypter
|
||||||
|
if _oauth_encrypter is None:
|
||||||
|
_oauth_encrypter = SystemOAuthEncrypter()
|
||||||
|
return _oauth_encrypter
|
||||||
|
|
||||||
|
|
||||||
|
# Convenience functions for backward compatibility
|
||||||
|
def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str:
|
||||||
|
"""
|
||||||
|
Encrypt OAuth parameters using the global encrypter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
oauth_params: OAuth parameters dictionary
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Base64-encoded encrypted string
|
||||||
|
"""
|
||||||
|
return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params)
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]:
|
||||||
|
"""
|
||||||
|
Decrypt OAuth parameters using the global encrypter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encrypted_data: Base64-encoded encrypted string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Decrypted OAuth parameters dictionary
|
||||||
|
"""
|
||||||
|
return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data)
|
||||||
@@ -1,7 +1,9 @@
|
|||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
|
||||||
def is_valid_uuid(uuid_str: str) -> bool:
|
def is_valid_uuid(uuid_str: str | None) -> bool:
|
||||||
|
if uuid_str is None or len(uuid_str) == 0:
|
||||||
|
return False
|
||||||
try:
|
try:
|
||||||
uuid.UUID(uuid_str)
|
uuid.UUID(uuid_str)
|
||||||
return True
|
return True
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from collections.abc import Generator, Mapping, Sequence
|
|||||||
from typing import Any, Optional, cast
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
from packaging.version import Version
|
from packaging.version import Version
|
||||||
|
from pydantic import ValidationError
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
@@ -13,10 +14,16 @@ from core.agent.strategy.plugin import PluginAgentStrategy
|
|||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
from core.model_manager import ModelInstance, ModelManager
|
from core.model_manager import ModelInstance, ModelManager
|
||||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||||
|
from core.plugin.entities.request import InvokeCredentials
|
||||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||||
from core.plugin.impl.plugin import PluginInstaller
|
from core.plugin.impl.plugin import PluginInstaller
|
||||||
from core.provider_manager import ProviderManager
|
from core.provider_manager import ProviderManager
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
|
from core.tools.entities.tool_entities import (
|
||||||
|
ToolIdentity,
|
||||||
|
ToolInvokeMessage,
|
||||||
|
ToolParameter,
|
||||||
|
ToolProviderType,
|
||||||
|
)
|
||||||
from core.tools.tool_manager import ToolManager
|
from core.tools.tool_manager import ToolManager
|
||||||
from core.variables.segments import StringSegment
|
from core.variables.segments import StringSegment
|
||||||
from core.workflow.entities.node_entities import NodeRunResult
|
from core.workflow.entities.node_entities import NodeRunResult
|
||||||
@@ -84,6 +91,7 @@ class AgentNode(ToolNode):
|
|||||||
for_log=True,
|
for_log=True,
|
||||||
strategy=strategy,
|
strategy=strategy,
|
||||||
)
|
)
|
||||||
|
credentials = self._generate_credentials(parameters=parameters)
|
||||||
|
|
||||||
# get conversation id
|
# get conversation id
|
||||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||||
@@ -94,6 +102,7 @@ class AgentNode(ToolNode):
|
|||||||
user_id=self.user_id,
|
user_id=self.user_id,
|
||||||
app_id=self.app_id,
|
app_id=self.app_id,
|
||||||
conversation_id=conversation_id.text if conversation_id else None,
|
conversation_id=conversation_id.text if conversation_id else None,
|
||||||
|
credentials=credentials,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield RunCompletedEvent(
|
yield RunCompletedEvent(
|
||||||
@@ -246,6 +255,7 @@ class AgentNode(ToolNode):
|
|||||||
tool_name=tool.get("tool_name", ""),
|
tool_name=tool.get("tool_name", ""),
|
||||||
tool_parameters=parameters,
|
tool_parameters=parameters,
|
||||||
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
|
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
|
||||||
|
credential_id=tool.get("credential_id", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
extra = tool.get("extra", {})
|
extra = tool.get("extra", {})
|
||||||
@@ -276,6 +286,7 @@ class AgentNode(ToolNode):
|
|||||||
{
|
{
|
||||||
**tool_runtime.entity.model_dump(mode="json"),
|
**tool_runtime.entity.model_dump(mode="json"),
|
||||||
"runtime_parameters": runtime_parameters,
|
"runtime_parameters": runtime_parameters,
|
||||||
|
"credential_id": tool.get("credential_id", None),
|
||||||
"provider_type": provider_type.value,
|
"provider_type": provider_type.value,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -305,6 +316,27 @@ class AgentNode(ToolNode):
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def _generate_credentials(
|
||||||
|
self,
|
||||||
|
parameters: dict[str, Any],
|
||||||
|
) -> InvokeCredentials:
|
||||||
|
"""
|
||||||
|
Generate credentials based on the given agent parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
credentials = InvokeCredentials()
|
||||||
|
|
||||||
|
# generate credentials for tools selector
|
||||||
|
credentials.tool_credentials = {}
|
||||||
|
for tool in parameters.get("tools", []):
|
||||||
|
if tool.get("credential_id"):
|
||||||
|
try:
|
||||||
|
identity = ToolIdentity.model_validate(tool.get("identity", {}))
|
||||||
|
credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
|
||||||
|
except ValidationError:
|
||||||
|
continue
|
||||||
|
return credentials
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _extract_variable_selector_to_variable_mapping(
|
def _extract_variable_selector_to_variable_mapping(
|
||||||
cls,
|
cls,
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ class ToolEntity(BaseModel):
|
|||||||
tool_name: str
|
tool_name: str
|
||||||
tool_label: str # redundancy
|
tool_label: str # redundancy
|
||||||
tool_configurations: dict[str, Any]
|
tool_configurations: dict[str, Any]
|
||||||
|
credential_id: str | None = None
|
||||||
plugin_unique_identifier: str | None = None # redundancy
|
plugin_unique_identifier: str | None = None # redundancy
|
||||||
|
|
||||||
@field_validator("tool_configurations", mode="before")
|
@field_validator("tool_configurations", mode="before")
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ def handle(sender, **kwargs):
|
|||||||
provider_id=tool_entity.provider_id,
|
provider_id=tool_entity.provider_id,
|
||||||
tool_name=tool_entity.tool_name,
|
tool_name=tool_entity.tool_name,
|
||||||
tenant_id=app.tenant_id,
|
tenant_id=app.tenant_id,
|
||||||
|
credential_id=tool_entity.credential_id,
|
||||||
)
|
)
|
||||||
manager = ToolParameterConfigurationManager(
|
manager = ToolParameterConfigurationManager(
|
||||||
tenant_id=app.tenant_id,
|
tenant_id=app.tenant_id,
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ def init_app(app: DifyApp):
|
|||||||
reset_email,
|
reset_email,
|
||||||
reset_encrypt_key_pair,
|
reset_encrypt_key_pair,
|
||||||
reset_password,
|
reset_password,
|
||||||
|
setup_system_tool_oauth_client,
|
||||||
upgrade_db,
|
upgrade_db,
|
||||||
vdb_migrate,
|
vdb_migrate,
|
||||||
)
|
)
|
||||||
@@ -40,6 +41,7 @@ def init_app(app: DifyApp):
|
|||||||
clear_free_plan_tenant_expired_logs,
|
clear_free_plan_tenant_expired_logs,
|
||||||
clear_orphaned_file_records,
|
clear_orphaned_file_records,
|
||||||
remove_orphaned_files_on_storage,
|
remove_orphaned_files_on_storage,
|
||||||
|
setup_system_tool_oauth_client,
|
||||||
]
|
]
|
||||||
for cmd in cmds_to_register:
|
for cmd in cmds_to_register:
|
||||||
app.cli.add_command(cmd)
|
app.cli.add_command(cmd)
|
||||||
|
|||||||
41
api/migrations/versions/2025_05_15_1635-16081485540c_.py
Normal file
41
api/migrations/versions/2025_05_15_1635-16081485540c_.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
"""empty message
|
||||||
|
|
||||||
|
Revision ID: 16081485540c
|
||||||
|
Revises: d28f2004b072
|
||||||
|
Create Date: 2025-05-15 16:35:39.113777
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import models as models
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '16081485540c'
|
||||||
|
down_revision = '2adcbe1f5dfb'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table('tenant_plugin_auto_upgrade_strategies',
|
||||||
|
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||||
|
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('strategy_setting', sa.String(length=16), server_default='fix_only', nullable=False),
|
||||||
|
sa.Column('upgrade_time_of_day', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('upgrade_mode', sa.String(length=16), server_default='exclude', nullable=False),
|
||||||
|
sa.Column('exclude_plugins', sa.ARRAY(sa.String(length=255)), nullable=False),
|
||||||
|
sa.Column('include_plugins', sa.ARRAY(sa.String(length=255)), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('id', name='tenant_plugin_auto_upgrade_strategy_pkey'),
|
||||||
|
sa.UniqueConstraint('tenant_id', name='unique_tenant_plugin_auto_upgrade_strategy')
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_table('tenant_plugin_auto_upgrade_strategies')
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -12,7 +12,7 @@ import sqlalchemy as sa
|
|||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision = '4474872b0ee6'
|
revision = '4474872b0ee6'
|
||||||
down_revision = '2adcbe1f5dfb'
|
down_revision = '16081485540c'
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,62 @@
|
|||||||
|
"""tool oauth
|
||||||
|
|
||||||
|
Revision ID: 71f5020c6470
|
||||||
|
Revises: 4474872b0ee6
|
||||||
|
Create Date: 2025-06-24 17:05:43.118647
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import models as models
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '71f5020c6470'
|
||||||
|
down_revision = '1c9ba48be8e4'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table('tool_oauth_system_clients',
|
||||||
|
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||||
|
sa.Column('plugin_id', sa.String(length=512), nullable=False),
|
||||||
|
sa.Column('provider', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('encrypted_oauth_params', sa.Text(), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('id', name='tool_oauth_system_client_pkey'),
|
||||||
|
sa.UniqueConstraint('plugin_id', 'provider', name='tool_oauth_system_client_plugin_id_provider_idx')
|
||||||
|
)
|
||||||
|
op.create_table('tool_oauth_tenant_clients',
|
||||||
|
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||||
|
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('plugin_id', sa.String(length=512), nullable=False),
|
||||||
|
sa.Column('provider', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False),
|
||||||
|
sa.Column('encrypted_oauth_params', sa.Text(), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('id', name='tool_oauth_tenant_client_pkey'),
|
||||||
|
sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_tool_oauth_tenant_client')
|
||||||
|
)
|
||||||
|
|
||||||
|
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('name', sa.String(length=256), server_default=sa.text("'API KEY 1'::character varying"), nullable=False))
|
||||||
|
batch_op.add_column(sa.Column('is_default', sa.Boolean(), server_default=sa.text('false'), nullable=False))
|
||||||
|
batch_op.add_column(sa.Column('credential_type', sa.String(length=32), server_default=sa.text("'api-key'::character varying"), nullable=False))
|
||||||
|
batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique')
|
||||||
|
batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider', 'name'])
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
|
||||||
|
batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique')
|
||||||
|
batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider'])
|
||||||
|
batch_op.drop_column('credential_type')
|
||||||
|
batch_op.drop_column('is_default')
|
||||||
|
batch_op.drop_column('name')
|
||||||
|
|
||||||
|
op.drop_table('tool_oauth_tenant_clients')
|
||||||
|
op.drop_table('tool_oauth_system_clients')
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -21,6 +21,43 @@ from .model import Account, App, Tenant
|
|||||||
from .types import StringUUID
|
from .types import StringUUID
|
||||||
|
|
||||||
|
|
||||||
|
# system level tool oauth client params (client_id, client_secret, etc.)
|
||||||
|
class ToolOAuthSystemClient(Base):
|
||||||
|
__tablename__ = "tool_oauth_system_clients"
|
||||||
|
__table_args__ = (
|
||||||
|
db.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"),
|
||||||
|
db.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"),
|
||||||
|
)
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||||
|
plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False)
|
||||||
|
provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||||
|
# oauth params of the tool provider
|
||||||
|
encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False)
|
||||||
|
|
||||||
|
|
||||||
|
# tenant level tool oauth client params (client_id, client_secret, etc.)
|
||||||
|
class ToolOAuthTenantClient(Base):
|
||||||
|
__tablename__ = "tool_oauth_tenant_clients"
|
||||||
|
__table_args__ = (
|
||||||
|
db.PrimaryKeyConstraint("id", name="tool_oauth_tenant_client_pkey"),
|
||||||
|
db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"),
|
||||||
|
)
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||||
|
# tenant id
|
||||||
|
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
|
plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False)
|
||||||
|
provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||||
|
enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
|
||||||
|
# oauth params of the tool provider
|
||||||
|
encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def oauth_params(self) -> dict:
|
||||||
|
return cast(dict, json.loads(self.encrypted_oauth_params or "{}"))
|
||||||
|
|
||||||
|
|
||||||
class BuiltinToolProvider(Base):
|
class BuiltinToolProvider(Base):
|
||||||
"""
|
"""
|
||||||
This table stores the tool provider information for built-in tools for each tenant.
|
This table stores the tool provider information for built-in tools for each tenant.
|
||||||
@@ -29,12 +66,14 @@ class BuiltinToolProvider(Base):
|
|||||||
__tablename__ = "tool_builtin_providers"
|
__tablename__ = "tool_builtin_providers"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"),
|
db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"),
|
||||||
# one tenant can only have one tool provider with the same name
|
db.UniqueConstraint("tenant_id", "provider", "name", name="unique_builtin_tool_provider"),
|
||||||
db.UniqueConstraint("tenant_id", "provider", name="unique_builtin_tool_provider"),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# id of the tool provider
|
# id of the tool provider
|
||||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||||
|
name: Mapped[str] = mapped_column(
|
||||||
|
db.String(256), nullable=False, server_default=db.text("'API KEY 1'::character varying")
|
||||||
|
)
|
||||||
# id of the tenant
|
# id of the tenant
|
||||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
|
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
|
||||||
# who created this tool provider
|
# who created this tool provider
|
||||||
@@ -49,6 +88,11 @@ class BuiltinToolProvider(Base):
|
|||||||
updated_at: Mapped[datetime] = mapped_column(
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
||||||
)
|
)
|
||||||
|
is_default: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||||
|
# credential type, e.g., "api-key", "oauth2"
|
||||||
|
credential_type: Mapped[str] = mapped_column(
|
||||||
|
db.String(32), nullable=False, server_default=db.text("'api-key'::character varying")
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def credentials(self) -> dict:
|
def credentials(self) -> dict:
|
||||||
@@ -68,7 +112,7 @@ class ApiToolProvider(Base):
|
|||||||
|
|
||||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||||
# name of the api provider
|
# name of the api provider
|
||||||
name = db.Column(db.String(255), nullable=False)
|
name = db.Column(db.String(255), nullable=False, server_default=db.text("'API KEY 1'::character varying"))
|
||||||
# icon
|
# icon
|
||||||
icon = db.Column(db.String(255), nullable=False)
|
icon = db.Column(db.String(255), nullable=False)
|
||||||
# original schema
|
# original schema
|
||||||
@@ -281,18 +325,19 @@ class MCPToolProvider(Base):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def decrypted_credentials(self) -> dict:
|
def decrypted_credentials(self) -> dict:
|
||||||
|
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||||
from core.tools.mcp_tool.provider import MCPToolProviderController
|
from core.tools.mcp_tool.provider import MCPToolProviderController
|
||||||
from core.tools.utils.configuration import ProviderConfigEncrypter
|
from core.tools.utils.encryption import create_provider_encrypter
|
||||||
|
|
||||||
provider_controller = MCPToolProviderController._from_db(self)
|
provider_controller = MCPToolProviderController._from_db(self)
|
||||||
|
|
||||||
tool_configuration = ProviderConfigEncrypter(
|
encrypter, _ = create_provider_encrypter(
|
||||||
tenant_id=self.tenant_id,
|
tenant_id=self.tenant_id,
|
||||||
config=list(provider_controller.get_credentials_schema()),
|
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||||
provider_type=provider_controller.provider_type.value,
|
cache=NoOpProviderCredentialCache(),
|
||||||
provider_identity=provider_controller.provider_id,
|
|
||||||
)
|
)
|
||||||
return tool_configuration.decrypt(self.credentials, use_cache=False)
|
|
||||||
|
return encrypter.decrypt(self.credentials) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class ToolModelInvoke(Base):
|
class ToolModelInvoke(Base):
|
||||||
|
|||||||
@@ -575,13 +575,26 @@ class AppDslService:
|
|||||||
raise ValueError("Missing draft workflow configuration, please check.")
|
raise ValueError("Missing draft workflow configuration, please check.")
|
||||||
|
|
||||||
workflow_dict = workflow.to_dict(include_secret=include_secret)
|
workflow_dict = workflow.to_dict(include_secret=include_secret)
|
||||||
|
# TODO: refactor: we need a better way to filter workspace related data from nodes
|
||||||
for node in workflow_dict.get("graph", {}).get("nodes", []):
|
for node in workflow_dict.get("graph", {}).get("nodes", []):
|
||||||
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
|
node_data = node.get("data", {})
|
||||||
dataset_ids = node["data"].get("dataset_ids", [])
|
if not node_data:
|
||||||
node["data"]["dataset_ids"] = [
|
continue
|
||||||
|
data_type = node_data.get("type", "")
|
||||||
|
if data_type == NodeType.KNOWLEDGE_RETRIEVAL.value:
|
||||||
|
dataset_ids = node_data.get("dataset_ids", [])
|
||||||
|
node_data["dataset_ids"] = [
|
||||||
cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=app_model.tenant_id)
|
cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=app_model.tenant_id)
|
||||||
for dataset_id in dataset_ids
|
for dataset_id in dataset_ids
|
||||||
]
|
]
|
||||||
|
# filter credential id from tool node
|
||||||
|
if not include_secret and data_type == NodeType.TOOL.value:
|
||||||
|
node_data.pop("credential_id", None)
|
||||||
|
# filter credential id from agent node
|
||||||
|
if not include_secret and data_type == NodeType.AGENT.value:
|
||||||
|
for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []):
|
||||||
|
tool.pop("credential_id", None)
|
||||||
|
|
||||||
export_data["workflow"] = workflow_dict
|
export_data["workflow"] = workflow_dict
|
||||||
dependencies = cls._extract_dependencies_from_workflow(workflow)
|
dependencies = cls._extract_dependencies_from_workflow(workflow)
|
||||||
export_data["dependencies"] = [
|
export_data["dependencies"] = [
|
||||||
@@ -602,7 +615,15 @@ class AppDslService:
|
|||||||
if not app_model_config:
|
if not app_model_config:
|
||||||
raise ValueError("Missing app configuration, please check.")
|
raise ValueError("Missing app configuration, please check.")
|
||||||
|
|
||||||
export_data["model_config"] = app_model_config.to_dict()
|
model_config = app_model_config.to_dict()
|
||||||
|
|
||||||
|
# TODO: refactor: we need a better way to filter workspace related data from model config
|
||||||
|
# filter credential id from model config
|
||||||
|
for tool in model_config.get("agent_mode", {}).get("tools", []):
|
||||||
|
tool.pop("credential_id", None)
|
||||||
|
|
||||||
|
export_data["model_config"] = model_config
|
||||||
|
|
||||||
dependencies = cls._extract_dependencies_from_model_config(app_model_config.to_dict())
|
dependencies = cls._extract_dependencies_from_model_config(app_model_config.to_dict())
|
||||||
export_data["dependencies"] = [
|
export_data["dependencies"] = [
|
||||||
jsonable_encoder(d.model_dump())
|
jsonable_encoder(d.model_dump())
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
|
|||||||
from core.plugin.entities.parameters import PluginParameterOption
|
from core.plugin.entities.parameters import PluginParameterOption
|
||||||
from core.plugin.impl.dynamic_select import DynamicSelectClient
|
from core.plugin.impl.dynamic_select import DynamicSelectClient
|
||||||
from core.tools.tool_manager import ToolManager
|
from core.tools.tool_manager import ToolManager
|
||||||
from core.tools.utils.configuration import ProviderConfigEncrypter
|
from core.tools.utils.encryption import create_tool_provider_encrypter
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.tools import BuiltinToolProvider
|
from models.tools import BuiltinToolProvider
|
||||||
|
|
||||||
@@ -38,11 +38,9 @@ class PluginParameterService:
|
|||||||
case "tool":
|
case "tool":
|
||||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||||
# init tool configuration
|
# init tool configuration
|
||||||
tool_configuration = ProviderConfigEncrypter(
|
encrypter, _ = create_tool_provider_encrypter(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
controller=provider_controller,
|
||||||
provider_type=provider_controller.provider_type.value,
|
|
||||||
provider_identity=provider_controller.entity.identity.name,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# check if credentials are required
|
# check if credentials are required
|
||||||
@@ -63,7 +61,7 @@ class PluginParameterService:
|
|||||||
if db_record is None:
|
if db_record is None:
|
||||||
raise ValueError(f"Builtin provider {provider} not found when fetching credentials")
|
raise ValueError(f"Builtin provider {provider} not found when fetching credentials")
|
||||||
|
|
||||||
credentials = tool_configuration.decrypt(db_record.credentials)
|
credentials = encrypter.decrypt(db_record.credentials)
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(f"Invalid provider type: {provider_type}")
|
raise ValueError(f"Invalid provider type: {provider_type}")
|
||||||
|
|
||||||
|
|||||||
@@ -196,6 +196,17 @@ class PluginService:
|
|||||||
manager = PluginInstaller()
|
manager = PluginInstaller()
|
||||||
return manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier)
|
return manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_plugin_verified(tenant_id: str, plugin_unique_identifier: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the plugin is verified
|
||||||
|
"""
|
||||||
|
manager = PluginInstaller()
|
||||||
|
try:
|
||||||
|
return manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier).verified
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def fetch_install_tasks(tenant_id: str, page: int, page_size: int) -> Sequence[PluginInstallTask]:
|
def fetch_install_tasks(tenant_id: str, page: int, page_size: int) -> Sequence[PluginInstallTask]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from core.tools.entities.tool_entities import (
|
|||||||
)
|
)
|
||||||
from core.tools.tool_label_manager import ToolLabelManager
|
from core.tools.tool_label_manager import ToolLabelManager
|
||||||
from core.tools.tool_manager import ToolManager
|
from core.tools.tool_manager import ToolManager
|
||||||
from core.tools.utils.configuration import ProviderConfigEncrypter
|
from core.tools.utils.encryption import create_tool_provider_encrypter
|
||||||
from core.tools.utils.parser import ApiBasedToolSchemaParser
|
from core.tools.utils.parser import ApiBasedToolSchemaParser
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.tools import ApiToolProvider
|
from models.tools import ApiToolProvider
|
||||||
@@ -164,15 +164,11 @@ class ApiToolManageService:
|
|||||||
provider_controller.load_bundled_tools(tool_bundles)
|
provider_controller.load_bundled_tools(tool_bundles)
|
||||||
|
|
||||||
# encrypt credentials
|
# encrypt credentials
|
||||||
tool_configuration = ProviderConfigEncrypter(
|
encrypter, _ = create_tool_provider_encrypter(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
config=list(provider_controller.get_credentials_schema()),
|
controller=provider_controller,
|
||||||
provider_type=provider_controller.provider_type.value,
|
|
||||||
provider_identity=provider_controller.entity.identity.name,
|
|
||||||
)
|
)
|
||||||
|
db_provider.credentials_str = json.dumps(encrypter.encrypt(credentials))
|
||||||
encrypted_credentials = tool_configuration.encrypt(credentials)
|
|
||||||
db_provider.credentials_str = json.dumps(encrypted_credentials)
|
|
||||||
|
|
||||||
db.session.add(db_provider)
|
db.session.add(db_provider)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
@@ -297,28 +293,26 @@ class ApiToolManageService:
|
|||||||
provider_controller.load_bundled_tools(tool_bundles)
|
provider_controller.load_bundled_tools(tool_bundles)
|
||||||
|
|
||||||
# get original credentials if exists
|
# get original credentials if exists
|
||||||
tool_configuration = ProviderConfigEncrypter(
|
encrypter, cache = create_tool_provider_encrypter(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
config=list(provider_controller.get_credentials_schema()),
|
controller=provider_controller,
|
||||||
provider_type=provider_controller.provider_type.value,
|
|
||||||
provider_identity=provider_controller.entity.identity.name,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
original_credentials = tool_configuration.decrypt(provider.credentials)
|
original_credentials = encrypter.decrypt(provider.credentials)
|
||||||
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
|
masked_credentials = encrypter.mask_tool_credentials(original_credentials)
|
||||||
# check if the credential has changed, save the original credential
|
# check if the credential has changed, save the original credential
|
||||||
for name, value in credentials.items():
|
for name, value in credentials.items():
|
||||||
if name in masked_credentials and value == masked_credentials[name]:
|
if name in masked_credentials and value == masked_credentials[name]:
|
||||||
credentials[name] = original_credentials[name]
|
credentials[name] = original_credentials[name]
|
||||||
|
|
||||||
credentials = tool_configuration.encrypt(credentials)
|
credentials = encrypter.encrypt(credentials)
|
||||||
provider.credentials_str = json.dumps(credentials)
|
provider.credentials_str = json.dumps(credentials)
|
||||||
|
|
||||||
db.session.add(provider)
|
db.session.add(provider)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
# delete cache
|
# delete cache
|
||||||
tool_configuration.delete_tool_credentials_cache()
|
cache.delete()
|
||||||
|
|
||||||
# update labels
|
# update labels
|
||||||
ToolLabelManager.update_tool_labels(provider_controller, labels)
|
ToolLabelManager.update_tool_labels(provider_controller, labels)
|
||||||
@@ -416,15 +410,13 @@ class ApiToolManageService:
|
|||||||
|
|
||||||
# decrypt credentials
|
# decrypt credentials
|
||||||
if db_provider.id:
|
if db_provider.id:
|
||||||
tool_configuration = ProviderConfigEncrypter(
|
encrypter, _ = create_tool_provider_encrypter(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
config=list(provider_controller.get_credentials_schema()),
|
controller=provider_controller,
|
||||||
provider_type=provider_controller.provider_type.value,
|
|
||||||
provider_identity=provider_controller.entity.identity.name,
|
|
||||||
)
|
)
|
||||||
decrypted_credentials = tool_configuration.decrypt(credentials)
|
decrypted_credentials = encrypter.decrypt(credentials)
|
||||||
# check if the credential has changed, save the original credential
|
# check if the credential has changed, save the original credential
|
||||||
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
|
masked_credentials = encrypter.mask_tool_credentials(decrypted_credentials)
|
||||||
for name, value in credentials.items():
|
for name, value in credentials.items():
|
||||||
if name in masked_credentials and value == masked_credentials[name]:
|
if name in masked_credentials and value == masked_credentials[name]:
|
||||||
credentials[name] = decrypted_credentials[name]
|
credentials[name] = decrypted_credentials[name]
|
||||||
@@ -446,7 +438,7 @@ class ApiToolManageService:
|
|||||||
return {"result": result or "empty response"}
|
return {"result": result or "empty response"}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def list_api_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
|
def list_api_tools(tenant_id: str) -> list[ToolProviderApiEntity]:
|
||||||
"""
|
"""
|
||||||
list api tools
|
list api tools
|
||||||
"""
|
"""
|
||||||
@@ -474,7 +466,7 @@ class ApiToolManageService:
|
|||||||
for tool in tools or []:
|
for tool in tools or []:
|
||||||
user_provider.tools.append(
|
user_provider.tools.append(
|
||||||
ToolTransformService.convert_tool_entity_to_api_entity(
|
ToolTransformService.convert_tool_entity_to_api_entity(
|
||||||
tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels
|
tenant_id=tenant_id, tool=tool, labels=labels
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,28 +1,84 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
|
from collections.abc import Mapping
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
||||||
from core.helper.position_helper import is_filtered
|
from core.helper.position_helper import is_filtered
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
|
||||||
from core.plugin.entities.plugin import ToolProviderID
|
from core.plugin.entities.plugin import ToolProviderID
|
||||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||||
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
||||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
from core.tools.entities.api_entities import (
|
||||||
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
|
ToolApiEntity,
|
||||||
|
ToolProviderApiEntity,
|
||||||
|
ToolProviderCredentialApiEntity,
|
||||||
|
ToolProviderCredentialInfoApiEntity,
|
||||||
|
)
|
||||||
|
from core.tools.entities.tool_entities import CredentialType
|
||||||
|
from core.tools.errors import ToolProviderNotFoundError
|
||||||
|
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||||
from core.tools.tool_label_manager import ToolLabelManager
|
from core.tools.tool_label_manager import ToolLabelManager
|
||||||
from core.tools.tool_manager import ToolManager
|
from core.tools.tool_manager import ToolManager
|
||||||
from core.tools.utils.configuration import ProviderConfigEncrypter
|
from core.tools.utils.encryption import create_provider_encrypter
|
||||||
|
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.tools import BuiltinToolProvider
|
from extensions.ext_redis import redis_client
|
||||||
|
from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient
|
||||||
|
from services.plugin.plugin_service import PluginService
|
||||||
from services.tools.tools_transform_service import ToolTransformService
|
from services.tools.tools_transform_service import ToolTransformService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BuiltinToolManageService:
|
class BuiltinToolManageService:
|
||||||
|
__MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def delete_custom_oauth_client_params(tenant_id: str, provider: str):
|
||||||
|
"""
|
||||||
|
delete custom oauth client params
|
||||||
|
"""
|
||||||
|
tool_provider = ToolProviderID(provider)
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
session.query(ToolOAuthTenantClient).filter_by(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
provider=tool_provider.provider_name,
|
||||||
|
plugin_id=tool_provider.plugin_id,
|
||||||
|
).delete()
|
||||||
|
session.commit()
|
||||||
|
return {"result": "success"}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_builtin_tool_provider_oauth_client_schema(tenant_id: str, provider_name: str):
|
||||||
|
"""
|
||||||
|
get builtin tool provider oauth client schema
|
||||||
|
"""
|
||||||
|
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
||||||
|
verified = not isinstance(provider, PluginToolProviderController) or PluginService.is_plugin_verified(
|
||||||
|
tenant_id, provider.plugin_unique_identifier
|
||||||
|
)
|
||||||
|
|
||||||
|
is_oauth_custom_client_enabled = BuiltinToolManageService.is_oauth_custom_client_enabled(
|
||||||
|
tenant_id, provider_name
|
||||||
|
)
|
||||||
|
is_system_oauth_params_exists = verified and BuiltinToolManageService.is_oauth_system_client_exists(
|
||||||
|
provider_name
|
||||||
|
)
|
||||||
|
result = {
|
||||||
|
"schema": provider.get_oauth_client_schema(),
|
||||||
|
"is_oauth_custom_client_enabled": is_oauth_custom_client_enabled,
|
||||||
|
"is_system_oauth_params_exists": is_system_oauth_params_exists,
|
||||||
|
"client_params": BuiltinToolManageService.get_custom_oauth_client_params(tenant_id, provider_name),
|
||||||
|
"redirect_uri": f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_name}/tool/callback",
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]:
|
def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]:
|
||||||
"""
|
"""
|
||||||
@@ -36,27 +92,11 @@ class BuiltinToolManageService:
|
|||||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||||
tools = provider_controller.get_tools()
|
tools = provider_controller.get_tools()
|
||||||
|
|
||||||
tool_provider_configurations = ProviderConfigEncrypter(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
|
||||||
provider_type=provider_controller.provider_type.value,
|
|
||||||
provider_identity=provider_controller.entity.identity.name,
|
|
||||||
)
|
|
||||||
# check if user has added the provider
|
|
||||||
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
|
|
||||||
|
|
||||||
credentials = {}
|
|
||||||
if builtin_provider is not None:
|
|
||||||
# get credentials
|
|
||||||
credentials = builtin_provider.credentials
|
|
||||||
credentials = tool_provider_configurations.decrypt(credentials)
|
|
||||||
|
|
||||||
result: list[ToolApiEntity] = []
|
result: list[ToolApiEntity] = []
|
||||||
for tool in tools or []:
|
for tool in tools or []:
|
||||||
result.append(
|
result.append(
|
||||||
ToolTransformService.convert_tool_entity_to_api_entity(
|
ToolTransformService.convert_tool_entity_to_api_entity(
|
||||||
tool=tool,
|
tool=tool,
|
||||||
credentials=credentials,
|
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
labels=ToolLabelManager.get_tool_labels(provider_controller),
|
labels=ToolLabelManager.get_tool_labels(provider_controller),
|
||||||
)
|
)
|
||||||
@@ -65,25 +105,15 @@ class BuiltinToolManageService:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_builtin_tool_provider_info(user_id: str, tenant_id: str, provider: str):
|
def get_builtin_tool_provider_info(tenant_id: str, provider: str):
|
||||||
"""
|
"""
|
||||||
get builtin tool provider info
|
get builtin tool provider info
|
||||||
"""
|
"""
|
||||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||||
tool_provider_configurations = ProviderConfigEncrypter(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
|
||||||
provider_type=provider_controller.provider_type.value,
|
|
||||||
provider_identity=provider_controller.entity.identity.name,
|
|
||||||
)
|
|
||||||
# check if user has added the provider
|
# check if user has added the provider
|
||||||
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
|
builtin_provider = BuiltinToolManageService.get_builtin_provider(provider, tenant_id)
|
||||||
|
if builtin_provider is None:
|
||||||
credentials = {}
|
raise ValueError(f"you have not added provider {provider}")
|
||||||
if builtin_provider is not None:
|
|
||||||
# get credentials
|
|
||||||
credentials = builtin_provider.credentials
|
|
||||||
credentials = tool_provider_configurations.decrypt(credentials)
|
|
||||||
|
|
||||||
entity = ToolTransformService.builtin_provider_to_user_provider(
|
entity = ToolTransformService.builtin_provider_to_user_provider(
|
||||||
provider_controller=provider_controller,
|
provider_controller=provider_controller,
|
||||||
@@ -92,128 +122,407 @@ class BuiltinToolManageService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
entity.original_credentials = {}
|
entity.original_credentials = {}
|
||||||
|
|
||||||
return entity
|
return entity
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def list_builtin_provider_credentials_schema(provider_name: str, tenant_id: str):
|
def list_builtin_provider_credentials_schema(provider_name: str, credential_type: CredentialType, tenant_id: str):
|
||||||
"""
|
"""
|
||||||
list builtin provider credentials schema
|
list builtin provider credentials schema
|
||||||
|
|
||||||
|
:param credential_type: credential type
|
||||||
:param provider_name: the name of the provider
|
:param provider_name: the name of the provider
|
||||||
:param tenant_id: the id of the tenant
|
:param tenant_id: the id of the tenant
|
||||||
:return: the list of tool providers
|
:return: the list of tool providers
|
||||||
"""
|
"""
|
||||||
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
||||||
return jsonable_encoder(provider.get_credentials_schema())
|
return provider.get_credentials_schema_by_type(credential_type)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_builtin_tool_provider(
|
def update_builtin_tool_provider(
|
||||||
session: Session, user_id: str, tenant_id: str, provider_name: str, credentials: dict
|
user_id: str,
|
||||||
|
tenant_id: str,
|
||||||
|
provider: str,
|
||||||
|
credential_id: str,
|
||||||
|
credentials: dict | None = None,
|
||||||
|
name: str | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
update builtin tool provider
|
update builtin tool provider
|
||||||
"""
|
"""
|
||||||
# get if the provider exists
|
with Session(db.engine) as session:
|
||||||
provider = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
|
# get if the provider exists
|
||||||
|
db_provider = (
|
||||||
try:
|
session.query(BuiltinToolProvider)
|
||||||
# get provider
|
.filter(
|
||||||
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
BuiltinToolProvider.tenant_id == tenant_id,
|
||||||
if not provider_controller.need_credentials:
|
BuiltinToolProvider.id == credential_id,
|
||||||
raise ValueError(f"provider {provider_name} does not need credentials")
|
)
|
||||||
tool_configuration = ProviderConfigEncrypter(
|
.first()
|
||||||
tenant_id=tenant_id,
|
|
||||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
|
||||||
provider_type=provider_controller.provider_type.value,
|
|
||||||
provider_identity=provider_controller.entity.identity.name,
|
|
||||||
)
|
)
|
||||||
|
if db_provider is None:
|
||||||
|
raise ValueError(f"you have not added provider {provider}")
|
||||||
|
|
||||||
# get original credentials if exists
|
try:
|
||||||
if provider is not None:
|
if CredentialType.of(db_provider.credential_type).is_editable() and credentials:
|
||||||
original_credentials = tool_configuration.decrypt(provider.credentials)
|
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||||
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
|
if not provider_controller.need_credentials:
|
||||||
# check if the credential has changed, save the original credential
|
raise ValueError(f"provider {provider} does not need credentials")
|
||||||
for name, value in credentials.items():
|
|
||||||
if name in masked_credentials and value == masked_credentials[name]:
|
|
||||||
credentials[name] = original_credentials[name]
|
|
||||||
# validate credentials
|
|
||||||
provider_controller.validate_credentials(user_id, credentials)
|
|
||||||
# encrypt credentials
|
|
||||||
credentials = tool_configuration.encrypt(credentials)
|
|
||||||
except (
|
|
||||||
PluginDaemonClientSideError,
|
|
||||||
ToolProviderNotFoundError,
|
|
||||||
ToolNotFoundError,
|
|
||||||
ToolProviderCredentialValidationError,
|
|
||||||
) as e:
|
|
||||||
raise ValueError(str(e))
|
|
||||||
|
|
||||||
if provider is None:
|
encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
|
||||||
# create provider
|
tenant_id, db_provider, provider, provider_controller
|
||||||
provider = BuiltinToolProvider(
|
)
|
||||||
tenant_id=tenant_id,
|
|
||||||
user_id=user_id,
|
|
||||||
provider=provider_name,
|
|
||||||
encrypted_credentials=json.dumps(credentials),
|
|
||||||
)
|
|
||||||
|
|
||||||
db.session.add(provider)
|
original_credentials = encrypter.decrypt(db_provider.credentials)
|
||||||
else:
|
new_credentials: dict = {
|
||||||
provider.encrypted_credentials = json.dumps(credentials)
|
key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE)
|
||||||
|
for key, value in credentials.items()
|
||||||
|
}
|
||||||
|
|
||||||
# delete cache
|
if CredentialType.of(db_provider.credential_type).is_validate_allowed():
|
||||||
tool_configuration.delete_tool_credentials_cache()
|
provider_controller.validate_credentials(user_id, new_credentials)
|
||||||
|
|
||||||
db.session.commit()
|
# encrypt credentials
|
||||||
|
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(new_credentials))
|
||||||
|
|
||||||
|
cache.delete()
|
||||||
|
|
||||||
|
# update name if provided
|
||||||
|
if name and name != db_provider.name:
|
||||||
|
# check if the name is already used
|
||||||
|
if (
|
||||||
|
session.query(BuiltinToolProvider)
|
||||||
|
.filter_by(tenant_id=tenant_id, provider=provider, name=name)
|
||||||
|
.count()
|
||||||
|
> 0
|
||||||
|
):
|
||||||
|
raise ValueError(f"the credential name '{name}' is already used")
|
||||||
|
|
||||||
|
db_provider.name = name
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
except Exception as e:
|
||||||
|
session.rollback()
|
||||||
|
raise ValueError(str(e))
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str):
|
def add_builtin_tool_provider(
|
||||||
|
user_id: str,
|
||||||
|
api_type: CredentialType,
|
||||||
|
tenant_id: str,
|
||||||
|
provider: str,
|
||||||
|
credentials: dict,
|
||||||
|
name: str | None = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
add builtin tool provider
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
|
||||||
|
with redis_client.lock(lock, timeout=20):
|
||||||
|
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||||
|
if not provider_controller.need_credentials:
|
||||||
|
raise ValueError(f"provider {provider} does not need credentials")
|
||||||
|
|
||||||
|
provider_count = (
|
||||||
|
session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count()
|
||||||
|
)
|
||||||
|
|
||||||
|
# check if the provider count is reached the limit
|
||||||
|
if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__:
|
||||||
|
raise ValueError(f"you have reached the maximum number of providers for {provider}")
|
||||||
|
|
||||||
|
# validate credentials if allowed
|
||||||
|
if CredentialType.of(api_type).is_validate_allowed():
|
||||||
|
provider_controller.validate_credentials(user_id, credentials)
|
||||||
|
|
||||||
|
# generate name if not provided
|
||||||
|
if name is None or name == "":
|
||||||
|
name = BuiltinToolManageService.generate_builtin_tool_provider_name(
|
||||||
|
session=session, tenant_id=tenant_id, provider=provider, credential_type=api_type
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# check if the name is already used
|
||||||
|
if (
|
||||||
|
session.query(BuiltinToolProvider)
|
||||||
|
.filter_by(tenant_id=tenant_id, provider=provider, name=name)
|
||||||
|
.count()
|
||||||
|
> 0
|
||||||
|
):
|
||||||
|
raise ValueError(f"the credential name '{name}' is already used")
|
||||||
|
|
||||||
|
# create encrypter
|
||||||
|
encrypter, _ = create_provider_encrypter(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
config=[
|
||||||
|
x.to_basic_provider_config()
|
||||||
|
for x in provider_controller.get_credentials_schema_by_type(api_type)
|
||||||
|
],
|
||||||
|
cache=NoOpProviderCredentialCache(),
|
||||||
|
)
|
||||||
|
|
||||||
|
db_provider = BuiltinToolProvider(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
provider=provider,
|
||||||
|
encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
|
||||||
|
credential_type=api_type.value,
|
||||||
|
name=name,
|
||||||
|
)
|
||||||
|
|
||||||
|
session.add(db_provider)
|
||||||
|
session.commit()
|
||||||
|
except Exception as e:
|
||||||
|
session.rollback()
|
||||||
|
raise ValueError(str(e))
|
||||||
|
return {"result": "success"}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_tool_encrypter(
|
||||||
|
tenant_id: str,
|
||||||
|
db_provider: BuiltinToolProvider,
|
||||||
|
provider: str,
|
||||||
|
provider_controller: BuiltinToolProviderController,
|
||||||
|
):
|
||||||
|
encrypter, cache = create_provider_encrypter(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
config=[
|
||||||
|
x.to_basic_provider_config()
|
||||||
|
for x in provider_controller.get_credentials_schema_by_type(db_provider.credential_type)
|
||||||
|
],
|
||||||
|
cache=ToolProviderCredentialsCache(tenant_id=tenant_id, provider=provider, credential_id=db_provider.id),
|
||||||
|
)
|
||||||
|
return encrypter, cache
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_builtin_tool_provider_name(
|
||||||
|
session: Session, tenant_id: str, provider: str, credential_type: CredentialType
|
||||||
|
) -> str:
|
||||||
|
try:
|
||||||
|
db_providers = (
|
||||||
|
session.query(BuiltinToolProvider)
|
||||||
|
.filter_by(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
provider=provider,
|
||||||
|
credential_type=credential_type.value,
|
||||||
|
)
|
||||||
|
.order_by(BuiltinToolProvider.created_at.desc())
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get the default name pattern
|
||||||
|
default_pattern = f"{credential_type.get_name()}"
|
||||||
|
|
||||||
|
# Find all names that match the default pattern: "{default_pattern} {number}"
|
||||||
|
pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$"
|
||||||
|
numbers = []
|
||||||
|
|
||||||
|
for db_provider in db_providers:
|
||||||
|
if db_provider.name:
|
||||||
|
match = re.match(pattern, db_provider.name.strip())
|
||||||
|
if match:
|
||||||
|
numbers.append(int(match.group(1)))
|
||||||
|
|
||||||
|
# If no default pattern names found, start with 1
|
||||||
|
if not numbers:
|
||||||
|
return f"{default_pattern} 1"
|
||||||
|
|
||||||
|
# Find the next number
|
||||||
|
max_number = max(numbers)
|
||||||
|
return f"{default_pattern} {max_number + 1}"
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error generating next provider name for {provider}: {str(e)}")
|
||||||
|
# fallback
|
||||||
|
return f"{credential_type.get_name()} 1"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_builtin_tool_provider_credentials(
|
||||||
|
tenant_id: str, provider_name: str
|
||||||
|
) -> list[ToolProviderCredentialApiEntity]:
|
||||||
"""
|
"""
|
||||||
get builtin tool provider credentials
|
get builtin tool provider credentials
|
||||||
"""
|
"""
|
||||||
provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
|
with db.session.no_autoflush:
|
||||||
|
providers = (
|
||||||
|
db.session.query(BuiltinToolProvider)
|
||||||
|
.filter_by(tenant_id=tenant_id, provider=provider_name)
|
||||||
|
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
if provider_obj is None:
|
if len(providers) == 0:
|
||||||
return {}
|
return []
|
||||||
|
|
||||||
provider_controller = ToolManager.get_builtin_provider(provider_obj.provider, tenant_id)
|
default_provider = providers[0]
|
||||||
tool_configuration = ProviderConfigEncrypter(
|
default_provider.is_default = True
|
||||||
tenant_id=tenant_id,
|
provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id)
|
||||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
|
||||||
provider_type=provider_controller.provider_type.value,
|
credentials: list[ToolProviderCredentialApiEntity] = []
|
||||||
provider_identity=provider_controller.entity.identity.name,
|
encrypters = {}
|
||||||
)
|
for provider in providers:
|
||||||
credentials = tool_configuration.decrypt(provider_obj.credentials)
|
credential_type = provider.credential_type
|
||||||
credentials = tool_configuration.mask_tool_credentials(credentials)
|
if credential_type not in encrypters:
|
||||||
return credentials
|
encrypters[credential_type] = BuiltinToolManageService.create_tool_encrypter(
|
||||||
|
tenant_id, provider, provider.provider, provider_controller
|
||||||
|
)[0]
|
||||||
|
encrypter = encrypters[credential_type]
|
||||||
|
decrypt_credential = encrypter.mask_tool_credentials(encrypter.decrypt(provider.credentials))
|
||||||
|
credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
|
||||||
|
provider=provider,
|
||||||
|
credentials=decrypt_credential,
|
||||||
|
)
|
||||||
|
credentials.append(credential_entity)
|
||||||
|
return credentials
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str):
|
def get_builtin_tool_provider_credential_info(tenant_id: str, provider: str) -> ToolProviderCredentialInfoApiEntity:
|
||||||
|
"""
|
||||||
|
get builtin tool provider credential info
|
||||||
|
"""
|
||||||
|
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||||
|
supported_credential_types = provider_controller.get_supported_credential_types()
|
||||||
|
credentials = BuiltinToolManageService.get_builtin_tool_provider_credentials(tenant_id, provider)
|
||||||
|
credential_info = ToolProviderCredentialInfoApiEntity(
|
||||||
|
supported_credential_types=supported_credential_types,
|
||||||
|
is_oauth_custom_client_enabled=BuiltinToolManageService.is_oauth_custom_client_enabled(tenant_id, provider),
|
||||||
|
credentials=credentials,
|
||||||
|
)
|
||||||
|
|
||||||
|
return credential_info
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def delete_builtin_tool_provider(tenant_id: str, provider: str, credential_id: str):
|
||||||
"""
|
"""
|
||||||
delete tool provider
|
delete tool provider
|
||||||
"""
|
"""
|
||||||
provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
|
with Session(db.engine) as session:
|
||||||
|
db_provider = (
|
||||||
|
session.query(BuiltinToolProvider)
|
||||||
|
.filter(
|
||||||
|
BuiltinToolProvider.tenant_id == tenant_id,
|
||||||
|
BuiltinToolProvider.id == credential_id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
if provider_obj is None:
|
if db_provider is None:
|
||||||
raise ValueError(f"you have not added provider {provider_name}")
|
raise ValueError(f"you have not added provider {provider}")
|
||||||
|
|
||||||
db.session.delete(provider_obj)
|
session.delete(db_provider)
|
||||||
db.session.commit()
|
session.commit()
|
||||||
|
|
||||||
# delete cache
|
# delete cache
|
||||||
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||||
tool_configuration = ProviderConfigEncrypter(
|
_, cache = BuiltinToolManageService.create_tool_encrypter(
|
||||||
tenant_id=tenant_id,
|
tenant_id, db_provider, provider, provider_controller
|
||||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
)
|
||||||
provider_type=provider_controller.provider_type.value,
|
cache.delete()
|
||||||
provider_identity=provider_controller.entity.identity.name,
|
|
||||||
)
|
|
||||||
tool_configuration.delete_tool_credentials_cache()
|
|
||||||
|
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_default_provider(tenant_id: str, user_id: str, provider: str, id: str):
|
||||||
|
"""
|
||||||
|
set default provider
|
||||||
|
"""
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
# get provider
|
||||||
|
target_provider = session.query(BuiltinToolProvider).filter_by(id=id).first()
|
||||||
|
if target_provider is None:
|
||||||
|
raise ValueError("provider not found")
|
||||||
|
|
||||||
|
# clear default provider
|
||||||
|
session.query(BuiltinToolProvider).filter_by(
|
||||||
|
tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True
|
||||||
|
).update({"is_default": False})
|
||||||
|
|
||||||
|
# set new default provider
|
||||||
|
target_provider.is_default = True
|
||||||
|
session.commit()
|
||||||
|
return {"result": "success"}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_oauth_system_client_exists(provider_name: str) -> bool:
|
||||||
|
"""
|
||||||
|
check if oauth system client exists
|
||||||
|
"""
|
||||||
|
tool_provider = ToolProviderID(provider_name)
|
||||||
|
with Session(db.engine).no_autoflush as session:
|
||||||
|
system_client: ToolOAuthSystemClient | None = (
|
||||||
|
session.query(ToolOAuthSystemClient)
|
||||||
|
.filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
return system_client is not None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_oauth_custom_client_enabled(tenant_id: str, provider: str) -> bool:
|
||||||
|
"""
|
||||||
|
check if oauth custom client is enabled
|
||||||
|
"""
|
||||||
|
tool_provider = ToolProviderID(provider)
|
||||||
|
with Session(db.engine).no_autoflush as session:
|
||||||
|
user_client: ToolOAuthTenantClient | None = (
|
||||||
|
session.query(ToolOAuthTenantClient)
|
||||||
|
.filter_by(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
provider=tool_provider.provider_name,
|
||||||
|
plugin_id=tool_provider.plugin_id,
|
||||||
|
enabled=True,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
return user_client is not None and user_client.enabled
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_oauth_client(tenant_id: str, provider: str) -> Mapping[str, Any] | None:
|
||||||
|
"""
|
||||||
|
get builtin tool provider
|
||||||
|
"""
|
||||||
|
tool_provider = ToolProviderID(provider)
|
||||||
|
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||||
|
encrypter, _ = create_provider_encrypter(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
|
||||||
|
cache=NoOpProviderCredentialCache(),
|
||||||
|
)
|
||||||
|
with Session(db.engine).no_autoflush as session:
|
||||||
|
user_client: ToolOAuthTenantClient | None = (
|
||||||
|
session.query(ToolOAuthTenantClient)
|
||||||
|
.filter_by(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
provider=tool_provider.provider_name,
|
||||||
|
plugin_id=tool_provider.plugin_id,
|
||||||
|
enabled=True,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
oauth_params: Mapping[str, Any] | None = None
|
||||||
|
if user_client:
|
||||||
|
oauth_params = encrypter.decrypt(user_client.oauth_params)
|
||||||
|
return oauth_params
|
||||||
|
|
||||||
|
# only verified provider can use custom oauth client
|
||||||
|
is_verified = not isinstance(provider, PluginToolProviderController) or PluginService.is_plugin_verified(
|
||||||
|
tenant_id, provider.plugin_unique_identifier
|
||||||
|
)
|
||||||
|
if not is_verified:
|
||||||
|
return oauth_params
|
||||||
|
|
||||||
|
system_client: ToolOAuthSystemClient | None = (
|
||||||
|
session.query(ToolOAuthSystemClient)
|
||||||
|
.filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if system_client:
|
||||||
|
try:
|
||||||
|
oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Error decrypting system oauth params: {e}")
|
||||||
|
|
||||||
|
return oauth_params
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_builtin_tool_provider_icon(provider: str):
|
def get_builtin_tool_provider_icon(provider: str):
|
||||||
"""
|
"""
|
||||||
@@ -234,9 +543,7 @@ class BuiltinToolManageService:
|
|||||||
|
|
||||||
with db.session.no_autoflush:
|
with db.session.no_autoflush:
|
||||||
# get all user added providers
|
# get all user added providers
|
||||||
db_providers: list[BuiltinToolProvider] = (
|
db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id)
|
||||||
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or []
|
|
||||||
)
|
|
||||||
|
|
||||||
# rewrite db_providers
|
# rewrite db_providers
|
||||||
for db_provider in db_providers:
|
for db_provider in db_providers:
|
||||||
@@ -275,7 +582,6 @@ class BuiltinToolManageService:
|
|||||||
ToolTransformService.convert_tool_entity_to_api_entity(
|
ToolTransformService.convert_tool_entity_to_api_entity(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
tool=tool,
|
tool=tool,
|
||||||
credentials=user_builtin_provider.original_credentials,
|
|
||||||
labels=ToolLabelManager.get_tool_labels(provider_controller),
|
labels=ToolLabelManager.get_tool_labels(provider_controller),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -287,43 +593,153 @@ class BuiltinToolManageService:
|
|||||||
return BuiltinToolProviderSort.sort(result)
|
return BuiltinToolProviderSort.sort(result)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
|
def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
|
||||||
try:
|
"""
|
||||||
full_provider_name = provider_name
|
This method is used to fetch the builtin provider from the database
|
||||||
provider_id_entity = ToolProviderID(provider_name)
|
1.if the default provider exists, return the default provider
|
||||||
provider_name = provider_id_entity.provider_name
|
2.if the default provider does not exist, return the oldest provider
|
||||||
if provider_id_entity.organization != "langgenius":
|
"""
|
||||||
provider_obj = (
|
with Session(db.engine) as session:
|
||||||
db.session.query(BuiltinToolProvider)
|
try:
|
||||||
.filter(
|
full_provider_name = provider_name
|
||||||
BuiltinToolProvider.tenant_id == tenant_id,
|
provider_id_entity = ToolProviderID(provider_name)
|
||||||
BuiltinToolProvider.provider == full_provider_name,
|
provider_name = provider_id_entity.provider_name
|
||||||
|
|
||||||
|
if provider_id_entity.organization != "langgenius":
|
||||||
|
provider = (
|
||||||
|
session.query(BuiltinToolProvider)
|
||||||
|
.filter(
|
||||||
|
BuiltinToolProvider.tenant_id == tenant_id,
|
||||||
|
BuiltinToolProvider.provider == full_provider_name,
|
||||||
|
)
|
||||||
|
.order_by(
|
||||||
|
BuiltinToolProvider.is_default.desc(), # default=True first
|
||||||
|
BuiltinToolProvider.created_at.asc(), # oldest first
|
||||||
|
)
|
||||||
|
.first()
|
||||||
)
|
)
|
||||||
.first()
|
else:
|
||||||
)
|
provider = (
|
||||||
else:
|
session.query(BuiltinToolProvider)
|
||||||
provider_obj = (
|
.filter(
|
||||||
db.session.query(BuiltinToolProvider)
|
BuiltinToolProvider.tenant_id == tenant_id,
|
||||||
.filter(
|
(BuiltinToolProvider.provider == provider_name)
|
||||||
BuiltinToolProvider.tenant_id == tenant_id,
|
| (BuiltinToolProvider.provider == full_provider_name),
|
||||||
(BuiltinToolProvider.provider == provider_name)
|
)
|
||||||
| (BuiltinToolProvider.provider == full_provider_name),
|
.order_by(
|
||||||
|
BuiltinToolProvider.is_default.desc(), # default=True first
|
||||||
|
BuiltinToolProvider.created_at.asc(), # oldest first
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if provider is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
provider.provider = ToolProviderID(provider.provider).to_string()
|
||||||
|
return provider
|
||||||
|
except Exception:
|
||||||
|
# it's an old provider without organization
|
||||||
|
return (
|
||||||
|
session.query(BuiltinToolProvider)
|
||||||
|
.filter(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name)
|
||||||
|
.order_by(
|
||||||
|
BuiltinToolProvider.is_default.desc(), # default=True first
|
||||||
|
BuiltinToolProvider.created_at.asc(), # oldest first
|
||||||
)
|
)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
if provider_obj is None:
|
@staticmethod
|
||||||
return None
|
def save_custom_oauth_client_params(
|
||||||
|
tenant_id: str,
|
||||||
|
provider: str,
|
||||||
|
client_params: Optional[dict] = None,
|
||||||
|
enable_oauth_custom_client: Optional[bool] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
setup oauth custom client
|
||||||
|
"""
|
||||||
|
if client_params is None and enable_oauth_custom_client is None:
|
||||||
|
return {"result": "success"}
|
||||||
|
|
||||||
provider_obj.provider = ToolProviderID(provider_obj.provider).to_string()
|
tool_provider = ToolProviderID(provider)
|
||||||
return provider_obj
|
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||||
except Exception:
|
if not provider_controller:
|
||||||
# it's an old provider without organization
|
raise ToolProviderNotFoundError(f"Provider {provider} not found")
|
||||||
return (
|
|
||||||
db.session.query(BuiltinToolProvider)
|
if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)):
|
||||||
.filter(
|
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
|
||||||
BuiltinToolProvider.tenant_id == tenant_id,
|
|
||||||
(BuiltinToolProvider.provider == provider_name),
|
with Session(db.engine) as session:
|
||||||
|
custom_client_params = (
|
||||||
|
session.query(ToolOAuthTenantClient)
|
||||||
|
.filter_by(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
plugin_id=tool_provider.plugin_id,
|
||||||
|
provider=tool_provider.provider_name,
|
||||||
)
|
)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# if the record does not exist, create a basic record
|
||||||
|
if custom_client_params is None:
|
||||||
|
custom_client_params = ToolOAuthTenantClient(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
plugin_id=tool_provider.plugin_id,
|
||||||
|
provider=tool_provider.provider_name,
|
||||||
|
)
|
||||||
|
session.add(custom_client_params)
|
||||||
|
|
||||||
|
if client_params is not None:
|
||||||
|
encrypter, _ = create_provider_encrypter(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
|
||||||
|
cache=NoOpProviderCredentialCache(),
|
||||||
|
)
|
||||||
|
original_params = encrypter.decrypt(custom_client_params.oauth_params)
|
||||||
|
new_params: dict = {
|
||||||
|
key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
|
||||||
|
for key, value in client_params.items()
|
||||||
|
}
|
||||||
|
custom_client_params.encrypted_oauth_params = json.dumps(encrypter.encrypt(new_params))
|
||||||
|
|
||||||
|
if enable_oauth_custom_client is not None:
|
||||||
|
custom_client_params.enabled = enable_oauth_custom_client
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
return {"result": "success"}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_custom_oauth_client_params(tenant_id: str, provider: str):
|
||||||
|
"""
|
||||||
|
get custom oauth client params
|
||||||
|
"""
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
tool_provider = ToolProviderID(provider)
|
||||||
|
custom_oauth_client_params: ToolOAuthTenantClient | None = (
|
||||||
|
session.query(ToolOAuthTenantClient)
|
||||||
|
.filter_by(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
plugin_id=tool_provider.plugin_id,
|
||||||
|
provider=tool_provider.provider_name,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if custom_oauth_client_params is None:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||||
|
if not provider_controller:
|
||||||
|
raise ToolProviderNotFoundError(f"Provider {provider} not found")
|
||||||
|
|
||||||
|
if not isinstance(provider_controller, BuiltinToolProviderController):
|
||||||
|
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
|
||||||
|
|
||||||
|
encrypter, _ = create_provider_encrypter(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
|
||||||
|
cache=NoOpProviderCredentialCache(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return encrypter.mask_tool_credentials(encrypter.decrypt(custom_oauth_client_params.oauth_params))
|
||||||
|
|||||||
@@ -7,13 +7,14 @@ from sqlalchemy import or_
|
|||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
|
||||||
from core.helper import encrypter
|
from core.helper import encrypter
|
||||||
|
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||||
from core.mcp.error import MCPAuthError, MCPError
|
from core.mcp.error import MCPAuthError, MCPError
|
||||||
from core.mcp.mcp_client import MCPClient
|
from core.mcp.mcp_client import MCPClient
|
||||||
from core.tools.entities.api_entities import ToolProviderApiEntity
|
from core.tools.entities.api_entities import ToolProviderApiEntity
|
||||||
from core.tools.entities.common_entities import I18nObject
|
from core.tools.entities.common_entities import I18nObject
|
||||||
from core.tools.entities.tool_entities import ToolProviderType
|
from core.tools.entities.tool_entities import ToolProviderType
|
||||||
from core.tools.mcp_tool.provider import MCPToolProviderController
|
from core.tools.mcp_tool.provider import MCPToolProviderController
|
||||||
from core.tools.utils.configuration import ProviderConfigEncrypter
|
from core.tools.utils.encryption import ProviderConfigEncrypter
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.tools import MCPToolProvider
|
from models.tools import MCPToolProvider
|
||||||
from services.tools.tools_transform_service import ToolTransformService
|
from services.tools.tools_transform_service import ToolTransformService
|
||||||
@@ -69,6 +70,7 @@ class MCPToolManageService:
|
|||||||
MCPToolProvider.server_url_hash == server_url_hash,
|
MCPToolProvider.server_url_hash == server_url_hash,
|
||||||
MCPToolProvider.server_identifier == server_identifier,
|
MCPToolProvider.server_identifier == server_identifier,
|
||||||
),
|
),
|
||||||
|
MCPToolProvider.tenant_id == tenant_id,
|
||||||
)
|
)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
@@ -197,8 +199,7 @@ class MCPToolManageService:
|
|||||||
tool_configuration = ProviderConfigEncrypter(
|
tool_configuration = ProviderConfigEncrypter(
|
||||||
tenant_id=mcp_provider.tenant_id,
|
tenant_id=mcp_provider.tenant_id,
|
||||||
config=list(provider_controller.get_credentials_schema()),
|
config=list(provider_controller.get_credentials_schema()),
|
||||||
provider_type=provider_controller.provider_type.value,
|
provider_config_cache=NoOpProviderCredentialCache(),
|
||||||
provider_identity=provider_controller.provider_id,
|
|
||||||
)
|
)
|
||||||
credentials = tool_configuration.encrypt(credentials)
|
credentials = tool_configuration.encrypt(credentials)
|
||||||
mcp_provider.updated_at = datetime.now()
|
mcp_provider.updated_at = datetime.now()
|
||||||
|
|||||||
@@ -5,21 +5,23 @@ from typing import Any, Optional, Union, cast
|
|||||||
from yarl import URL
|
from yarl import URL
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||||
from core.mcp.types import Tool as MCPTool
|
from core.mcp.types import Tool as MCPTool
|
||||||
from core.tools.__base.tool import Tool
|
from core.tools.__base.tool import Tool
|
||||||
from core.tools.__base.tool_runtime import ToolRuntime
|
from core.tools.__base.tool_runtime import ToolRuntime
|
||||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||||
from core.tools.custom_tool.provider import ApiToolProviderController
|
from core.tools.custom_tool.provider import ApiToolProviderController
|
||||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity
|
||||||
from core.tools.entities.common_entities import I18nObject
|
from core.tools.entities.common_entities import I18nObject
|
||||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||||
from core.tools.entities.tool_entities import (
|
from core.tools.entities.tool_entities import (
|
||||||
ApiProviderAuthType,
|
ApiProviderAuthType,
|
||||||
|
CredentialType,
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
ToolProviderType,
|
ToolProviderType,
|
||||||
)
|
)
|
||||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||||
from core.tools.utils.configuration import ProviderConfigEncrypter
|
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
|
||||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||||
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
|
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
|
||||||
@@ -119,7 +121,12 @@ class ToolTransformService:
|
|||||||
result.plugin_unique_identifier = provider_controller.plugin_unique_identifier
|
result.plugin_unique_identifier = provider_controller.plugin_unique_identifier
|
||||||
|
|
||||||
# get credentials schema
|
# get credentials schema
|
||||||
schema = {x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema()}
|
schema = {
|
||||||
|
x.to_basic_provider_config().name: x
|
||||||
|
for x in provider_controller.get_credentials_schema_by_type(
|
||||||
|
CredentialType.of(db_provider.credential_type) if db_provider else CredentialType.API_KEY
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
for name, value in schema.items():
|
for name, value in schema.items():
|
||||||
if result.masked_credentials:
|
if result.masked_credentials:
|
||||||
@@ -136,15 +143,23 @@ class ToolTransformService:
|
|||||||
credentials = db_provider.credentials
|
credentials = db_provider.credentials
|
||||||
|
|
||||||
# init tool configuration
|
# init tool configuration
|
||||||
tool_configuration = ProviderConfigEncrypter(
|
encrypter, _ = create_provider_encrypter(
|
||||||
tenant_id=db_provider.tenant_id,
|
tenant_id=db_provider.tenant_id,
|
||||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
config=[
|
||||||
provider_type=provider_controller.provider_type.value,
|
x.to_basic_provider_config()
|
||||||
provider_identity=provider_controller.entity.identity.name,
|
for x in provider_controller.get_credentials_schema_by_type(
|
||||||
|
CredentialType.of(db_provider.credential_type)
|
||||||
|
)
|
||||||
|
],
|
||||||
|
cache=ToolProviderCredentialsCache(
|
||||||
|
tenant_id=db_provider.tenant_id,
|
||||||
|
provider=db_provider.provider,
|
||||||
|
credential_id=db_provider.id,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
# decrypt the credentials and mask the credentials
|
# decrypt the credentials and mask the credentials
|
||||||
decrypted_credentials = tool_configuration.decrypt(data=credentials)
|
decrypted_credentials = encrypter.decrypt(data=credentials)
|
||||||
masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials)
|
masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials)
|
||||||
|
|
||||||
result.masked_credentials = masked_credentials
|
result.masked_credentials = masked_credentials
|
||||||
result.original_credentials = decrypted_credentials
|
result.original_credentials = decrypted_credentials
|
||||||
@@ -287,16 +302,14 @@ class ToolTransformService:
|
|||||||
|
|
||||||
if decrypt_credentials:
|
if decrypt_credentials:
|
||||||
# init tool configuration
|
# init tool configuration
|
||||||
tool_configuration = ProviderConfigEncrypter(
|
encrypter, _ = create_tool_provider_encrypter(
|
||||||
tenant_id=db_provider.tenant_id,
|
tenant_id=db_provider.tenant_id,
|
||||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
controller=provider_controller,
|
||||||
provider_type=provider_controller.provider_type.value,
|
|
||||||
provider_identity=provider_controller.entity.identity.name,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# decrypt the credentials and mask the credentials
|
# decrypt the credentials and mask the credentials
|
||||||
decrypted_credentials = tool_configuration.decrypt(data=credentials)
|
decrypted_credentials = encrypter.decrypt(data=credentials)
|
||||||
masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials)
|
masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials)
|
||||||
|
|
||||||
result.masked_credentials = masked_credentials
|
result.masked_credentials = masked_credentials
|
||||||
|
|
||||||
@@ -306,7 +319,6 @@ class ToolTransformService:
|
|||||||
def convert_tool_entity_to_api_entity(
|
def convert_tool_entity_to_api_entity(
|
||||||
tool: Union[ApiToolBundle, WorkflowTool, Tool],
|
tool: Union[ApiToolBundle, WorkflowTool, Tool],
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
credentials: dict | None = None,
|
|
||||||
labels: list[str] | None = None,
|
labels: list[str] | None = None,
|
||||||
) -> ToolApiEntity:
|
) -> ToolApiEntity:
|
||||||
"""
|
"""
|
||||||
@@ -316,7 +328,7 @@ class ToolTransformService:
|
|||||||
# fork tool runtime
|
# fork tool runtime
|
||||||
tool = tool.fork_tool_runtime(
|
tool = tool.fork_tool_runtime(
|
||||||
runtime=ToolRuntime(
|
runtime=ToolRuntime(
|
||||||
credentials=credentials or {},
|
credentials={},
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -357,6 +369,19 @@ class ToolTransformService:
|
|||||||
labels=labels or [],
|
labels=labels or [],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def convert_builtin_provider_to_credential_entity(
|
||||||
|
provider: BuiltinToolProvider, credentials: dict
|
||||||
|
) -> ToolProviderCredentialApiEntity:
|
||||||
|
return ToolProviderCredentialApiEntity(
|
||||||
|
id=provider.id,
|
||||||
|
name=provider.name,
|
||||||
|
provider=provider.provider,
|
||||||
|
credential_type=CredentialType.of(provider.credential_type),
|
||||||
|
is_default=provider.is_default,
|
||||||
|
credentials=credentials,
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def convert_mcp_schema_to_parameter(schema: dict) -> list["ToolParameter"]:
|
def convert_mcp_schema_to_parameter(schema: dict) -> list["ToolParameter"]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -0,0 +1,619 @@
|
|||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from Crypto.Cipher import AES
|
||||||
|
from Crypto.Random import get_random_bytes
|
||||||
|
from Crypto.Util.Padding import pad
|
||||||
|
|
||||||
|
from core.tools.utils.system_oauth_encryption import (
|
||||||
|
OAuthEncryptionError,
|
||||||
|
SystemOAuthEncrypter,
|
||||||
|
create_system_oauth_encrypter,
|
||||||
|
decrypt_system_oauth_params,
|
||||||
|
encrypt_system_oauth_params,
|
||||||
|
get_system_oauth_encrypter,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSystemOAuthEncrypter:
|
||||||
|
"""Test cases for SystemOAuthEncrypter class"""
|
||||||
|
|
||||||
|
def test_init_with_secret_key(self):
|
||||||
|
"""Test initialization with provided secret key"""
|
||||||
|
secret_key = "test_secret_key"
|
||||||
|
encrypter = SystemOAuthEncrypter(secret_key=secret_key)
|
||||||
|
expected_key = hashlib.sha256(secret_key.encode()).digest()
|
||||||
|
assert encrypter.key == expected_key
|
||||||
|
|
||||||
|
def test_init_with_none_secret_key(self):
|
||||||
|
"""Test initialization with None secret key falls back to config"""
|
||||||
|
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
|
||||||
|
mock_config.SECRET_KEY = "config_secret"
|
||||||
|
encrypter = SystemOAuthEncrypter(secret_key=None)
|
||||||
|
expected_key = hashlib.sha256(b"config_secret").digest()
|
||||||
|
assert encrypter.key == expected_key
|
||||||
|
|
||||||
|
def test_init_with_empty_secret_key(self):
|
||||||
|
"""Test initialization with empty secret key"""
|
||||||
|
encrypter = SystemOAuthEncrypter(secret_key="")
|
||||||
|
expected_key = hashlib.sha256(b"").digest()
|
||||||
|
assert encrypter.key == expected_key
|
||||||
|
|
||||||
|
def test_init_without_secret_key_uses_config(self):
|
||||||
|
"""Test initialization without secret key uses config"""
|
||||||
|
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
|
||||||
|
mock_config.SECRET_KEY = "default_secret"
|
||||||
|
encrypter = SystemOAuthEncrypter()
|
||||||
|
expected_key = hashlib.sha256(b"default_secret").digest()
|
||||||
|
assert encrypter.key == expected_key
|
||||||
|
|
||||||
|
def test_encrypt_oauth_params_basic(self):
|
||||||
|
"""Test basic OAuth parameters encryption"""
|
||||||
|
encrypter = SystemOAuthEncrypter("test_secret")
|
||||||
|
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||||
|
|
||||||
|
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||||
|
|
||||||
|
assert isinstance(encrypted, str)
|
||||||
|
assert len(encrypted) > 0
|
||||||
|
# Should be valid base64
|
||||||
|
try:
|
||||||
|
base64.b64decode(encrypted)
|
||||||
|
except Exception:
|
||||||
|
pytest.fail("Encrypted result is not valid base64")
|
||||||
|
|
||||||
|
def test_encrypt_oauth_params_empty_dict(self):
|
||||||
|
"""Test encryption with empty dictionary"""
|
||||||
|
encrypter = SystemOAuthEncrypter("test_secret")
|
||||||
|
oauth_params = {}
|
||||||
|
|
||||||
|
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||||
|
assert isinstance(encrypted, str)
|
||||||
|
assert len(encrypted) > 0
|
||||||
|
|
||||||
|
def test_encrypt_oauth_params_complex_data(self):
|
||||||
|
"""Test encryption with complex data structures"""
|
||||||
|
encrypter = SystemOAuthEncrypter("test_secret")
|
||||||
|
oauth_params = {
|
||||||
|
"client_id": "test_id",
|
||||||
|
"client_secret": "test_secret",
|
||||||
|
"scopes": ["read", "write", "admin"],
|
||||||
|
"metadata": {"issuer": "test_issuer", "expires_in": 3600, "is_active": True},
|
||||||
|
"numeric_value": 42,
|
||||||
|
"boolean_value": False,
|
||||||
|
"null_value": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||||
|
assert isinstance(encrypted, str)
|
||||||
|
assert len(encrypted) > 0
|
||||||
|
|
||||||
|
def test_encrypt_oauth_params_unicode_data(self):
|
||||||
|
"""Test encryption with unicode data"""
|
||||||
|
encrypter = SystemOAuthEncrypter("test_secret")
|
||||||
|
oauth_params = {"client_id": "test_id", "client_secret": "test_secret", "description": "This is a test case 🚀"}
|
||||||
|
|
||||||
|
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||||
|
assert isinstance(encrypted, str)
|
||||||
|
assert len(encrypted) > 0
|
||||||
|
|
||||||
|
def test_encrypt_oauth_params_large_data(self):
|
||||||
|
"""Test encryption with large data"""
|
||||||
|
encrypter = SystemOAuthEncrypter("test_secret")
|
||||||
|
oauth_params = {
|
||||||
|
"client_id": "test_id",
|
||||||
|
"large_data": "x" * 10000, # 10KB of data
|
||||||
|
}
|
||||||
|
|
||||||
|
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||||
|
assert isinstance(encrypted, str)
|
||||||
|
assert len(encrypted) > 0
|
||||||
|
|
||||||
|
def test_encrypt_oauth_params_invalid_input(self):
|
||||||
|
"""Test encryption with invalid input types"""
|
||||||
|
encrypter = SystemOAuthEncrypter("test_secret")
|
||||||
|
|
||||||
|
with pytest.raises(Exception): # noqa: B017
|
||||||
|
encrypter.encrypt_oauth_params(None) # type: ignore
|
||||||
|
|
||||||
|
with pytest.raises(Exception): # noqa: B017
|
||||||
|
encrypter.encrypt_oauth_params("not_a_dict") # type: ignore
|
||||||
|
|
||||||
|
def test_decrypt_oauth_params_basic(self):
|
||||||
|
"""Test basic OAuth parameters decryption"""
|
||||||
|
encrypter = SystemOAuthEncrypter("test_secret")
|
||||||
|
original_params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||||
|
|
||||||
|
encrypted = encrypter.encrypt_oauth_params(original_params)
|
||||||
|
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||||
|
|
||||||
|
assert decrypted == original_params
|
||||||
|
|
||||||
|
def test_decrypt_oauth_params_empty_dict(self):
|
||||||
|
"""Test decryption of empty dictionary"""
|
||||||
|
encrypter = SystemOAuthEncrypter("test_secret")
|
||||||
|
original_params = {}
|
||||||
|
|
||||||
|
encrypted = encrypter.encrypt_oauth_params(original_params)
|
||||||
|
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||||
|
|
||||||
|
assert decrypted == original_params
|
||||||
|
|
||||||
|
def test_decrypt_oauth_params_complex_data(self):
|
||||||
|
"""Test decryption with complex data structures"""
|
||||||
|
encrypter = SystemOAuthEncrypter("test_secret")
|
||||||
|
original_params = {
|
||||||
|
"client_id": "test_id",
|
||||||
|
"client_secret": "test_secret",
|
||||||
|
"scopes": ["read", "write", "admin"],
|
||||||
|
"metadata": {"issuer": "test_issuer", "expires_in": 3600, "is_active": True},
|
||||||
|
"numeric_value": 42,
|
||||||
|
"boolean_value": False,
|
||||||
|
"null_value": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
encrypted = encrypter.encrypt_oauth_params(original_params)
|
||||||
|
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||||
|
|
||||||
|
assert decrypted == original_params
|
||||||
|
|
||||||
|
def test_decrypt_oauth_params_unicode_data(self):
|
||||||
|
"""Test decryption with unicode data"""
|
||||||
|
encrypter = SystemOAuthEncrypter("test_secret")
|
||||||
|
original_params = {
|
||||||
|
"client_id": "test_id",
|
||||||
|
"client_secret": "test_secret",
|
||||||
|
"description": "This is a test case 🚀",
|
||||||
|
}
|
||||||
|
|
||||||
|
encrypted = encrypter.encrypt_oauth_params(original_params)
|
||||||
|
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||||
|
|
||||||
|
assert decrypted == original_params
|
||||||
|
|
||||||
|
def test_decrypt_oauth_params_large_data(self):
|
||||||
|
"""Test decryption with large data"""
|
||||||
|
encrypter = SystemOAuthEncrypter("test_secret")
|
||||||
|
original_params = {
|
||||||
|
"client_id": "test_id",
|
||||||
|
"large_data": "x" * 10000, # 10KB of data
|
||||||
|
}
|
||||||
|
|
||||||
|
encrypted = encrypter.encrypt_oauth_params(original_params)
|
||||||
|
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||||
|
|
||||||
|
assert decrypted == original_params
|
||||||
|
|
||||||
|
def test_decrypt_oauth_params_invalid_base64(self):
|
||||||
|
"""Test decryption with invalid base64 data"""
|
||||||
|
encrypter = SystemOAuthEncrypter("test_secret")
|
||||||
|
|
||||||
|
with pytest.raises(OAuthEncryptionError):
|
||||||
|
encrypter.decrypt_oauth_params("invalid_base64!")
|
||||||
|
|
||||||
|
def test_decrypt_oauth_params_empty_string(self):
|
||||||
|
"""Test decryption with empty string"""
|
||||||
|
encrypter = SystemOAuthEncrypter("test_secret")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
encrypter.decrypt_oauth_params("")
|
||||||
|
|
||||||
|
assert "encrypted_data cannot be empty" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_decrypt_oauth_params_non_string_input(self):
|
||||||
|
"""Test decryption with non-string input"""
|
||||||
|
encrypter = SystemOAuthEncrypter("test_secret")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
encrypter.decrypt_oauth_params(123) # type: ignore
|
||||||
|
|
||||||
|
assert "encrypted_data must be a string" in str(exc_info.value)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
encrypter.decrypt_oauth_params(None) # type: ignore
|
||||||
|
|
||||||
|
assert "encrypted_data must be a string" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_decrypt_oauth_params_too_short_data(self):
|
||||||
|
"""Test decryption with too short encrypted data"""
|
||||||
|
encrypter = SystemOAuthEncrypter("test_secret")
|
||||||
|
|
||||||
|
# Create data that's too short (less than 32 bytes)
|
||||||
|
short_data = base64.b64encode(b"short").decode()
|
||||||
|
|
||||||
|
with pytest.raises(OAuthEncryptionError) as exc_info:
|
||||||
|
encrypter.decrypt_oauth_params(short_data)
|
||||||
|
|
||||||
|
assert "Invalid encrypted data format" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_decrypt_oauth_params_corrupted_data(self):
|
||||||
|
"""Test decryption with corrupted data"""
|
||||||
|
encrypter = SystemOAuthEncrypter("test_secret")
|
||||||
|
|
||||||
|
# Create corrupted data (valid base64 but invalid encrypted content)
|
||||||
|
corrupted_data = base64.b64encode(b"x" * 48).decode() # 48 bytes of garbage
|
||||||
|
|
||||||
|
with pytest.raises(OAuthEncryptionError):
|
||||||
|
encrypter.decrypt_oauth_params(corrupted_data)
|
||||||
|
|
||||||
|
def test_decrypt_oauth_params_wrong_key(self):
|
||||||
|
"""Test decryption with wrong key"""
|
||||||
|
encrypter1 = SystemOAuthEncrypter("secret1")
|
||||||
|
encrypter2 = SystemOAuthEncrypter("secret2")
|
||||||
|
|
||||||
|
original_params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||||
|
encrypted = encrypter1.encrypt_oauth_params(original_params)
|
||||||
|
|
||||||
|
with pytest.raises(OAuthEncryptionError):
|
||||||
|
encrypter2.decrypt_oauth_params(encrypted)
|
||||||
|
|
||||||
|
def test_encryption_decryption_consistency(self):
|
||||||
|
"""Test that encryption and decryption are consistent"""
|
||||||
|
encrypter = SystemOAuthEncrypter("test_secret")
|
||||||
|
|
||||||
|
test_cases = [
|
||||||
|
{},
|
||||||
|
{"simple": "value"},
|
||||||
|
{"client_id": "id", "client_secret": "secret"},
|
||||||
|
{"complex": {"nested": {"deep": "value"}}},
|
||||||
|
{"unicode": "test 🚀"},
|
||||||
|
{"numbers": 42, "boolean": True, "null": None},
|
||||||
|
{"array": [1, 2, 3, "four", {"five": 5}]},
|
||||||
|
]
|
||||||
|
|
||||||
|
for original_params in test_cases:
|
||||||
|
encrypted = encrypter.encrypt_oauth_params(original_params)
|
||||||
|
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||||
|
assert decrypted == original_params, f"Failed for case: {original_params}"
|
||||||
|
|
||||||
|
def test_encryption_randomness(self):
|
||||||
|
"""Test that encryption produces different results for same input"""
|
||||||
|
encrypter = SystemOAuthEncrypter("test_secret")
|
||||||
|
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||||
|
|
||||||
|
encrypted1 = encrypter.encrypt_oauth_params(oauth_params)
|
||||||
|
encrypted2 = encrypter.encrypt_oauth_params(oauth_params)
|
||||||
|
|
||||||
|
# Should be different due to random IV
|
||||||
|
assert encrypted1 != encrypted2
|
||||||
|
|
||||||
|
# But should decrypt to same result
|
||||||
|
decrypted1 = encrypter.decrypt_oauth_params(encrypted1)
|
||||||
|
decrypted2 = encrypter.decrypt_oauth_params(encrypted2)
|
||||||
|
assert decrypted1 == decrypted2 == oauth_params
|
||||||
|
|
||||||
|
def test_different_secret_keys_produce_different_results(self):
|
||||||
|
"""Test that different secret keys produce different encrypted results"""
|
||||||
|
encrypter1 = SystemOAuthEncrypter("secret1")
|
||||||
|
encrypter2 = SystemOAuthEncrypter("secret2")
|
||||||
|
|
||||||
|
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||||
|
|
||||||
|
encrypted1 = encrypter1.encrypt_oauth_params(oauth_params)
|
||||||
|
encrypted2 = encrypter2.encrypt_oauth_params(oauth_params)
|
||||||
|
|
||||||
|
# Should produce different encrypted results
|
||||||
|
assert encrypted1 != encrypted2
|
||||||
|
|
||||||
|
# But each should decrypt correctly with its own key
|
||||||
|
decrypted1 = encrypter1.decrypt_oauth_params(encrypted1)
|
||||||
|
decrypted2 = encrypter2.decrypt_oauth_params(encrypted2)
|
||||||
|
assert decrypted1 == decrypted2 == oauth_params
|
||||||
|
|
||||||
|
@patch("core.tools.utils.system_oauth_encryption.get_random_bytes")
|
||||||
|
def test_encrypt_oauth_params_crypto_error(self, mock_get_random_bytes):
|
||||||
|
"""Test encryption when crypto operation fails"""
|
||||||
|
mock_get_random_bytes.side_effect = Exception("Crypto error")
|
||||||
|
|
||||||
|
encrypter = SystemOAuthEncrypter("test_secret")
|
||||||
|
oauth_params = {"client_id": "test_id"}
|
||||||
|
|
||||||
|
with pytest.raises(OAuthEncryptionError) as exc_info:
|
||||||
|
encrypter.encrypt_oauth_params(oauth_params)
|
||||||
|
|
||||||
|
assert "Encryption failed" in str(exc_info.value)
|
||||||
|
|
||||||
|
@patch("core.tools.utils.system_oauth_encryption.TypeAdapter")
|
||||||
|
def test_encrypt_oauth_params_serialization_error(self, mock_type_adapter):
|
||||||
|
"""Test encryption when JSON serialization fails"""
|
||||||
|
mock_type_adapter.return_value.dump_json.side_effect = Exception("Serialization error")
|
||||||
|
|
||||||
|
encrypter = SystemOAuthEncrypter("test_secret")
|
||||||
|
oauth_params = {"client_id": "test_id"}
|
||||||
|
|
||||||
|
with pytest.raises(OAuthEncryptionError) as exc_info:
|
||||||
|
encrypter.encrypt_oauth_params(oauth_params)
|
||||||
|
|
||||||
|
assert "Encryption failed" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_decrypt_oauth_params_invalid_json(self):
|
||||||
|
"""Test decryption with invalid JSON data"""
|
||||||
|
encrypter = SystemOAuthEncrypter("test_secret")
|
||||||
|
|
||||||
|
# Create valid encrypted data but with invalid JSON content
|
||||||
|
iv = get_random_bytes(16)
|
||||||
|
cipher = AES.new(encrypter.key, AES.MODE_CBC, iv)
|
||||||
|
invalid_json = b"invalid json content"
|
||||||
|
padded_data = pad(invalid_json, AES.block_size)
|
||||||
|
encrypted_data = cipher.encrypt(padded_data)
|
||||||
|
combined = iv + encrypted_data
|
||||||
|
encoded = base64.b64encode(combined).decode()
|
||||||
|
|
||||||
|
with pytest.raises(OAuthEncryptionError):
|
||||||
|
encrypter.decrypt_oauth_params(encoded)
|
||||||
|
|
||||||
|
def test_key_derivation_consistency(self):
|
||||||
|
"""Test that key derivation is consistent"""
|
||||||
|
secret_key = "test_secret"
|
||||||
|
encrypter1 = SystemOAuthEncrypter(secret_key)
|
||||||
|
encrypter2 = SystemOAuthEncrypter(secret_key)
|
||||||
|
|
||||||
|
assert encrypter1.key == encrypter2.key
|
||||||
|
|
||||||
|
# Keys should be 32 bytes (256 bits)
|
||||||
|
assert len(encrypter1.key) == 32
|
||||||
|
|
||||||
|
|
||||||
|
class TestFactoryFunctions:
|
||||||
|
"""Test cases for factory functions"""
|
||||||
|
|
||||||
|
def test_create_system_oauth_encrypter_with_secret(self):
|
||||||
|
"""Test factory function with secret key"""
|
||||||
|
secret_key = "test_secret"
|
||||||
|
encrypter = create_system_oauth_encrypter(secret_key)
|
||||||
|
|
||||||
|
assert isinstance(encrypter, SystemOAuthEncrypter)
|
||||||
|
expected_key = hashlib.sha256(secret_key.encode()).digest()
|
||||||
|
assert encrypter.key == expected_key
|
||||||
|
|
||||||
|
def test_create_system_oauth_encrypter_without_secret(self):
|
||||||
|
"""Test factory function without secret key"""
|
||||||
|
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
|
||||||
|
mock_config.SECRET_KEY = "config_secret"
|
||||||
|
encrypter = create_system_oauth_encrypter()
|
||||||
|
|
||||||
|
assert isinstance(encrypter, SystemOAuthEncrypter)
|
||||||
|
expected_key = hashlib.sha256(b"config_secret").digest()
|
||||||
|
assert encrypter.key == expected_key
|
||||||
|
|
||||||
|
def test_create_system_oauth_encrypter_with_none_secret(self):
|
||||||
|
"""Test factory function with None secret key"""
|
||||||
|
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
|
||||||
|
mock_config.SECRET_KEY = "config_secret"
|
||||||
|
encrypter = create_system_oauth_encrypter(None)
|
||||||
|
|
||||||
|
assert isinstance(encrypter, SystemOAuthEncrypter)
|
||||||
|
expected_key = hashlib.sha256(b"config_secret").digest()
|
||||||
|
assert encrypter.key == expected_key
|
||||||
|
|
||||||
|
|
||||||
|
class TestGlobalEncrypterInstance:
|
||||||
|
"""Test cases for global encrypter instance"""
|
||||||
|
|
||||||
|
def test_get_system_oauth_encrypter_singleton(self):
|
||||||
|
"""Test that get_system_oauth_encrypter returns singleton instance"""
|
||||||
|
# Clear the global instance first
|
||||||
|
import core.tools.utils.system_oauth_encryption
|
||||||
|
|
||||||
|
core.tools.utils.system_oauth_encryption._oauth_encrypter = None
|
||||||
|
|
||||||
|
encrypter1 = get_system_oauth_encrypter()
|
||||||
|
encrypter2 = get_system_oauth_encrypter()
|
||||||
|
|
||||||
|
assert encrypter1 is encrypter2
|
||||||
|
assert isinstance(encrypter1, SystemOAuthEncrypter)
|
||||||
|
|
||||||
|
def test_get_system_oauth_encrypter_uses_config(self):
|
||||||
|
"""Test that global encrypter uses config"""
|
||||||
|
# Clear the global instance first
|
||||||
|
import core.tools.utils.system_oauth_encryption
|
||||||
|
|
||||||
|
core.tools.utils.system_oauth_encryption._oauth_encrypter = None
|
||||||
|
|
||||||
|
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
|
||||||
|
mock_config.SECRET_KEY = "global_secret"
|
||||||
|
encrypter = get_system_oauth_encrypter()
|
||||||
|
|
||||||
|
expected_key = hashlib.sha256(b"global_secret").digest()
|
||||||
|
assert encrypter.key == expected_key
|
||||||
|
|
||||||
|
|
||||||
|
class TestConvenienceFunctions:
|
||||||
|
"""Test cases for convenience functions"""
|
||||||
|
|
||||||
|
def test_encrypt_system_oauth_params(self):
|
||||||
|
"""Test encrypt_system_oauth_params convenience function"""
|
||||||
|
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||||
|
|
||||||
|
encrypted = encrypt_system_oauth_params(oauth_params)
|
||||||
|
|
||||||
|
assert isinstance(encrypted, str)
|
||||||
|
assert len(encrypted) > 0
|
||||||
|
|
||||||
|
def test_decrypt_system_oauth_params(self):
|
||||||
|
"""Test decrypt_system_oauth_params convenience function"""
|
||||||
|
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||||
|
|
||||||
|
encrypted = encrypt_system_oauth_params(oauth_params)
|
||||||
|
decrypted = decrypt_system_oauth_params(encrypted)
|
||||||
|
|
||||||
|
assert decrypted == oauth_params
|
||||||
|
|
||||||
|
def test_convenience_functions_consistency(self):
|
||||||
|
"""Test that convenience functions work consistently"""
|
||||||
|
test_cases = [
|
||||||
|
{},
|
||||||
|
{"simple": "value"},
|
||||||
|
{"client_id": "id", "client_secret": "secret"},
|
||||||
|
{"complex": {"nested": {"deep": "value"}}},
|
||||||
|
{"unicode": "test 🚀"},
|
||||||
|
{"numbers": 42, "boolean": True, "null": None},
|
||||||
|
]
|
||||||
|
|
||||||
|
for original_params in test_cases:
|
||||||
|
encrypted = encrypt_system_oauth_params(original_params)
|
||||||
|
decrypted = decrypt_system_oauth_params(encrypted)
|
||||||
|
assert decrypted == original_params, f"Failed for case: {original_params}"
|
||||||
|
|
||||||
|
def test_convenience_functions_with_errors(self):
|
||||||
|
"""Test convenience functions with error conditions"""
|
||||||
|
# Test encryption with invalid input
|
||||||
|
with pytest.raises(Exception): # noqa: B017
|
||||||
|
encrypt_system_oauth_params(None) # type: ignore
|
||||||
|
|
||||||
|
# Test decryption with invalid input
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
decrypt_system_oauth_params("")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
decrypt_system_oauth_params(None) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class TestErrorHandling:
|
||||||
|
"""Test cases for error handling"""
|
||||||
|
|
||||||
|
def test_oauth_encryption_error_inheritance(self):
|
||||||
|
"""Test that OAuthEncryptionError is a proper exception"""
|
||||||
|
error = OAuthEncryptionError("Test error")
|
||||||
|
assert isinstance(error, Exception)
|
||||||
|
assert str(error) == "Test error"
|
||||||
|
|
||||||
|
def test_oauth_encryption_error_with_cause(self):
|
||||||
|
"""Test OAuthEncryptionError with cause"""
|
||||||
|
original_error = ValueError("Original error")
|
||||||
|
error = OAuthEncryptionError("Wrapper error")
|
||||||
|
error.__cause__ = original_error
|
||||||
|
|
||||||
|
assert isinstance(error, Exception)
|
||||||
|
assert str(error) == "Wrapper error"
|
||||||
|
assert error.__cause__ is original_error
|
||||||
|
|
||||||
|
def test_error_messages_are_informative(self):
|
||||||
|
"""Test that error messages are informative"""
|
||||||
|
encrypter = SystemOAuthEncrypter("test_secret")
|
||||||
|
|
||||||
|
# Test empty string error
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
encrypter.decrypt_oauth_params("")
|
||||||
|
assert "encrypted_data cannot be empty" in str(exc_info.value)
|
||||||
|
|
||||||
|
# Test non-string error
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
encrypter.decrypt_oauth_params(123) # type: ignore
|
||||||
|
assert "encrypted_data must be a string" in str(exc_info.value)
|
||||||
|
|
||||||
|
# Test invalid format error
|
||||||
|
short_data = base64.b64encode(b"short").decode()
|
||||||
|
with pytest.raises(OAuthEncryptionError) as exc_info:
|
||||||
|
encrypter.decrypt_oauth_params(short_data)
|
||||||
|
assert "Invalid encrypted data format" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
class TestEdgeCases:
|
||||||
|
"""Test cases for edge cases and boundary conditions"""
|
||||||
|
|
||||||
|
def test_very_long_secret_key(self):
|
||||||
|
"""Test with very long secret key"""
|
||||||
|
long_secret = "x" * 10000
|
||||||
|
encrypter = SystemOAuthEncrypter(long_secret)
|
||||||
|
|
||||||
|
# Key should still be 32 bytes due to SHA-256
|
||||||
|
assert len(encrypter.key) == 32
|
||||||
|
|
||||||
|
# Should still work normally
|
||||||
|
oauth_params = {"client_id": "test_id"}
|
||||||
|
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||||
|
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||||
|
assert decrypted == oauth_params
|
||||||
|
|
||||||
|
def test_special_characters_in_secret_key(self):
|
||||||
|
"""Test with special characters in secret key"""
|
||||||
|
special_secret = "!@#$%^&*()_+-=[]{}|;':\",./<>?`~test🚀"
|
||||||
|
encrypter = SystemOAuthEncrypter(special_secret)
|
||||||
|
|
||||||
|
oauth_params = {"client_id": "test_id"}
|
||||||
|
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||||
|
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||||
|
assert decrypted == oauth_params
|
||||||
|
|
||||||
|
def test_empty_values_in_oauth_params(self):
|
||||||
|
"""Test with empty values in oauth params"""
|
||||||
|
oauth_params = {
|
||||||
|
"client_id": "",
|
||||||
|
"client_secret": "",
|
||||||
|
"empty_dict": {},
|
||||||
|
"empty_list": [],
|
||||||
|
"empty_string": "",
|
||||||
|
"zero": 0,
|
||||||
|
"false": False,
|
||||||
|
"none": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
encrypter = SystemOAuthEncrypter("test_secret")
|
||||||
|
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||||
|
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||||
|
assert decrypted == oauth_params
|
||||||
|
|
||||||
|
def test_deeply_nested_oauth_params(self):
|
||||||
|
"""Test with deeply nested oauth params"""
|
||||||
|
oauth_params = {"level1": {"level2": {"level3": {"level4": {"level5": {"deep_value": "found"}}}}}}
|
||||||
|
|
||||||
|
encrypter = SystemOAuthEncrypter("test_secret")
|
||||||
|
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||||
|
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||||
|
assert decrypted == oauth_params
|
||||||
|
|
||||||
|
def test_oauth_params_with_all_json_types(self):
|
||||||
|
"""Test with all JSON-supported data types"""
|
||||||
|
oauth_params = {
|
||||||
|
"string": "test_string",
|
||||||
|
"integer": 42,
|
||||||
|
"float": 3.14159,
|
||||||
|
"boolean_true": True,
|
||||||
|
"boolean_false": False,
|
||||||
|
"null_value": None,
|
||||||
|
"empty_string": "",
|
||||||
|
"array": [1, "two", 3.0, True, False, None],
|
||||||
|
"object": {"nested_string": "nested_value", "nested_number": 123, "nested_bool": True},
|
||||||
|
}
|
||||||
|
|
||||||
|
encrypter = SystemOAuthEncrypter("test_secret")
|
||||||
|
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||||
|
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||||
|
assert decrypted == oauth_params
|
||||||
|
|
||||||
|
|
||||||
|
class TestPerformance:
|
||||||
|
"""Test cases for performance considerations"""
|
||||||
|
|
||||||
|
def test_large_oauth_params(self):
|
||||||
|
"""Test with large oauth params"""
|
||||||
|
large_value = "x" * 100000 # 100KB
|
||||||
|
oauth_params = {"client_id": "test_id", "large_data": large_value}
|
||||||
|
|
||||||
|
encrypter = SystemOAuthEncrypter("test_secret")
|
||||||
|
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||||
|
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||||
|
assert decrypted == oauth_params
|
||||||
|
|
||||||
|
def test_many_fields_oauth_params(self):
|
||||||
|
"""Test with many fields in oauth params"""
|
||||||
|
oauth_params = {f"field_{i}": f"value_{i}" for i in range(1000)}
|
||||||
|
|
||||||
|
encrypter = SystemOAuthEncrypter("test_secret")
|
||||||
|
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||||
|
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||||
|
assert decrypted == oauth_params
|
||||||
|
|
||||||
|
def test_repeated_encryption_decryption(self):
|
||||||
|
"""Test repeated encryption and decryption operations"""
|
||||||
|
encrypter = SystemOAuthEncrypter("test_secret")
|
||||||
|
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||||
|
|
||||||
|
# Test multiple rounds of encryption/decryption
|
||||||
|
for i in range(100):
|
||||||
|
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||||
|
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||||
|
assert decrypted == oauth_params
|
||||||
@@ -18,7 +18,6 @@ import AppIcon from '@/app/components/base/app-icon'
|
|||||||
import Button from '@/app/components/base/button'
|
import Button from '@/app/components/base/button'
|
||||||
import Indicator from '@/app/components/header/indicator'
|
import Indicator from '@/app/components/header/indicator'
|
||||||
import Switch from '@/app/components/base/switch'
|
import Switch from '@/app/components/base/switch'
|
||||||
import Toast from '@/app/components/base/toast'
|
|
||||||
import ConfigContext from '@/context/debug-configuration'
|
import ConfigContext from '@/context/debug-configuration'
|
||||||
import type { AgentTool } from '@/types/app'
|
import type { AgentTool } from '@/types/app'
|
||||||
import { type Collection, CollectionType } from '@/app/components/tools/types'
|
import { type Collection, CollectionType } from '@/app/components/tools/types'
|
||||||
@@ -26,8 +25,6 @@ import { MAX_TOOLS_NUM } from '@/config'
|
|||||||
import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback'
|
import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback'
|
||||||
import Tooltip from '@/app/components/base/tooltip'
|
import Tooltip from '@/app/components/base/tooltip'
|
||||||
import { DefaultToolIcon } from '@/app/components/base/icons/src/public/other'
|
import { DefaultToolIcon } from '@/app/components/base/icons/src/public/other'
|
||||||
import ConfigCredential from '@/app/components/tools/setting/build-in/config-credentials'
|
|
||||||
import { updateBuiltInToolCredential } from '@/service/tools'
|
|
||||||
import cn from '@/utils/classnames'
|
import cn from '@/utils/classnames'
|
||||||
import ToolPicker from '@/app/components/workflow/block-selector/tool-picker'
|
import ToolPicker from '@/app/components/workflow/block-selector/tool-picker'
|
||||||
import type { ToolDefaultValue, ToolValue } from '@/app/components/workflow/block-selector/types'
|
import type { ToolDefaultValue, ToolValue } from '@/app/components/workflow/block-selector/types'
|
||||||
@@ -57,13 +54,7 @@ const AgentTools: FC = () => {
|
|||||||
|
|
||||||
const formattingChangedDispatcher = useFormattingChangedDispatcher()
|
const formattingChangedDispatcher = useFormattingChangedDispatcher()
|
||||||
const [currentTool, setCurrentTool] = useState<AgentToolWithMoreInfo>(null)
|
const [currentTool, setCurrentTool] = useState<AgentToolWithMoreInfo>(null)
|
||||||
const currentCollection = useMemo(() => {
|
|
||||||
if (!currentTool) return null
|
|
||||||
const collection = collectionList.find(collection => canFindTool(collection.id, currentTool?.provider_id) && collection.type === currentTool?.provider_type)
|
|
||||||
return collection
|
|
||||||
}, [currentTool, collectionList])
|
|
||||||
const [isShowSettingTool, setIsShowSettingTool] = useState(false)
|
const [isShowSettingTool, setIsShowSettingTool] = useState(false)
|
||||||
const [isShowSettingAuth, setShowSettingAuth] = useState(false)
|
|
||||||
const tools = (modelConfig?.agentConfig?.tools as AgentTool[] || []).map((item) => {
|
const tools = (modelConfig?.agentConfig?.tools as AgentTool[] || []).map((item) => {
|
||||||
const collection = collectionList.find(
|
const collection = collectionList.find(
|
||||||
collection =>
|
collection =>
|
||||||
@@ -100,17 +91,6 @@ const AgentTools: FC = () => {
|
|||||||
formattingChangedDispatcher()
|
formattingChangedDispatcher()
|
||||||
}
|
}
|
||||||
|
|
||||||
const handleToolAuthSetting = (value: AgentToolWithMoreInfo) => {
|
|
||||||
const newModelConfig = produce(modelConfig, (draft) => {
|
|
||||||
const tool = (draft.agentConfig.tools).find((item: any) => item.provider_id === value?.collection?.id && item.tool_name === value?.tool_name)
|
|
||||||
if (tool)
|
|
||||||
(tool as AgentTool).notAuthor = false
|
|
||||||
})
|
|
||||||
setModelConfig(newModelConfig)
|
|
||||||
setIsShowSettingTool(false)
|
|
||||||
formattingChangedDispatcher()
|
|
||||||
}
|
|
||||||
|
|
||||||
const [isDeleting, setIsDeleting] = useState<number>(-1)
|
const [isDeleting, setIsDeleting] = useState<number>(-1)
|
||||||
const getToolValue = (tool: ToolDefaultValue) => {
|
const getToolValue = (tool: ToolDefaultValue) => {
|
||||||
return {
|
return {
|
||||||
@@ -144,6 +124,20 @@ const AgentTools: FC = () => {
|
|||||||
return item.provider_name
|
return item.provider_name
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const handleAuthorizationItemClick = useCallback((credentialId: string) => {
|
||||||
|
const newModelConfig = produce(modelConfig, (draft) => {
|
||||||
|
const tool = (draft.agentConfig.tools).find((item: any) => item.provider_id === currentTool?.provider_id)
|
||||||
|
if (tool)
|
||||||
|
(tool as AgentTool).credential_id = credentialId
|
||||||
|
})
|
||||||
|
setCurrentTool({
|
||||||
|
...currentTool,
|
||||||
|
credential_id: credentialId,
|
||||||
|
} as any)
|
||||||
|
setModelConfig(newModelConfig)
|
||||||
|
formattingChangedDispatcher()
|
||||||
|
}, [currentTool, modelConfig, setModelConfig, formattingChangedDispatcher])
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<Panel
|
<Panel
|
||||||
@@ -299,7 +293,7 @@ const AgentTools: FC = () => {
|
|||||||
{item.notAuthor && (
|
{item.notAuthor && (
|
||||||
<Button variant='secondary' size='small' onClick={() => {
|
<Button variant='secondary' size='small' onClick={() => {
|
||||||
setCurrentTool(item)
|
setCurrentTool(item)
|
||||||
setShowSettingAuth(true)
|
setIsShowSettingTool(true)
|
||||||
}}>
|
}}>
|
||||||
{t('tools.notAuthorized')}
|
{t('tools.notAuthorized')}
|
||||||
<Indicator className='ml-2' color='orange' />
|
<Indicator className='ml-2' color='orange' />
|
||||||
@@ -319,21 +313,8 @@ const AgentTools: FC = () => {
|
|||||||
isModel={currentTool?.collection?.type === CollectionType.model}
|
isModel={currentTool?.collection?.type === CollectionType.model}
|
||||||
onSave={handleToolSettingChange}
|
onSave={handleToolSettingChange}
|
||||||
onHide={() => setIsShowSettingTool(false)}
|
onHide={() => setIsShowSettingTool(false)}
|
||||||
/>
|
credentialId={currentTool?.credential_id}
|
||||||
)}
|
onAuthorizationItemClick={handleAuthorizationItemClick}
|
||||||
{isShowSettingAuth && (
|
|
||||||
<ConfigCredential
|
|
||||||
collection={currentCollection as any}
|
|
||||||
onCancel={() => setShowSettingAuth(false)}
|
|
||||||
onSaved={async (value) => {
|
|
||||||
await updateBuiltInToolCredential((currentCollection as any).name, value)
|
|
||||||
Toast.notify({
|
|
||||||
type: 'success',
|
|
||||||
message: t('common.api.actionSuccess'),
|
|
||||||
})
|
|
||||||
handleToolAuthSetting(currentTool)
|
|
||||||
setShowSettingAuth(false)
|
|
||||||
}}
|
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
</>
|
</>
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ import Icon from '@/app/components/plugins/card/base/card-icon'
|
|||||||
import OrgInfo from '@/app/components/plugins/card/base/org-info'
|
import OrgInfo from '@/app/components/plugins/card/base/org-info'
|
||||||
import Description from '@/app/components/plugins/card/base/description'
|
import Description from '@/app/components/plugins/card/base/description'
|
||||||
import TabSlider from '@/app/components/base/tab-slider-plain'
|
import TabSlider from '@/app/components/base/tab-slider-plain'
|
||||||
|
|
||||||
import Button from '@/app/components/base/button'
|
import Button from '@/app/components/base/button'
|
||||||
import Form from '@/app/components/header/account-setting/model-provider-page/model-modal/Form'
|
import Form from '@/app/components/header/account-setting/model-provider-page/model-modal/Form'
|
||||||
import { addDefaultValue, toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema'
|
import { addDefaultValue, toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema'
|
||||||
@@ -25,6 +24,10 @@ import I18n from '@/context/i18n'
|
|||||||
import { getLanguage } from '@/i18n/language'
|
import { getLanguage } from '@/i18n/language'
|
||||||
import cn from '@/utils/classnames'
|
import cn from '@/utils/classnames'
|
||||||
import type { ToolWithProvider } from '@/app/components/workflow/types'
|
import type { ToolWithProvider } from '@/app/components/workflow/types'
|
||||||
|
import {
|
||||||
|
AuthCategory,
|
||||||
|
PluginAuthInAgent,
|
||||||
|
} from '@/app/components/plugins/plugin-auth'
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
showBackButton?: boolean
|
showBackButton?: boolean
|
||||||
@@ -36,6 +39,8 @@ type Props = {
|
|||||||
readonly?: boolean
|
readonly?: boolean
|
||||||
onHide: () => void
|
onHide: () => void
|
||||||
onSave?: (value: Record<string, any>) => void
|
onSave?: (value: Record<string, any>) => void
|
||||||
|
credentialId?: string
|
||||||
|
onAuthorizationItemClick?: (id: string) => void
|
||||||
}
|
}
|
||||||
|
|
||||||
const SettingBuiltInTool: FC<Props> = ({
|
const SettingBuiltInTool: FC<Props> = ({
|
||||||
@@ -48,6 +53,8 @@ const SettingBuiltInTool: FC<Props> = ({
|
|||||||
readonly,
|
readonly,
|
||||||
onHide,
|
onHide,
|
||||||
onSave,
|
onSave,
|
||||||
|
credentialId,
|
||||||
|
onAuthorizationItemClick,
|
||||||
}) => {
|
}) => {
|
||||||
const { locale } = useContext(I18n)
|
const { locale } = useContext(I18n)
|
||||||
const language = getLanguage(locale)
|
const language = getLanguage(locale)
|
||||||
@@ -197,8 +204,20 @@ const SettingBuiltInTool: FC<Props> = ({
|
|||||||
</div>
|
</div>
|
||||||
<div className='system-md-semibold mt-1 text-text-primary'>{currTool?.label[language]}</div>
|
<div className='system-md-semibold mt-1 text-text-primary'>{currTool?.label[language]}</div>
|
||||||
{!!currTool?.description[language] && (
|
{!!currTool?.description[language] && (
|
||||||
<Description className='mt-3' text={currTool.description[language]} descriptionLineRows={2}></Description>
|
<Description className='mb-2 mt-3 h-auto' text={currTool.description[language]} descriptionLineRows={2}></Description>
|
||||||
)}
|
)}
|
||||||
|
{
|
||||||
|
collection.allow_delete && collection.type === CollectionType.builtIn && (
|
||||||
|
<PluginAuthInAgent
|
||||||
|
pluginPayload={{
|
||||||
|
provider: collection.name,
|
||||||
|
category: AuthCategory.tool,
|
||||||
|
}}
|
||||||
|
credentialId={credentialId}
|
||||||
|
onAuthorizationItemClick={onAuthorizationItemClick}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
</div>
|
</div>
|
||||||
{/* form */}
|
{/* form */}
|
||||||
<div className='h-full'>
|
<div className='h-full'>
|
||||||
|
|||||||
@@ -117,7 +117,7 @@ const Question: FC<QuestionProps> = ({
|
|||||||
</div>
|
</div>
|
||||||
<div
|
<div
|
||||||
ref={contentRef}
|
ref={contentRef}
|
||||||
className='w-full rounded-2xl bg-background-gradient-bg-fill-chat-bubble-bg-3 px-4 py-3 text-sm text-text-primary'
|
className='bg-background-gradient-bg-fill-chat-bubble-bg-3 w-full rounded-2xl px-4 py-3 text-sm text-text-primary'
|
||||||
style={theme?.chatBubbleColorStyle ? CssTransform(theme.chatBubbleColorStyle) : {}}
|
style={theme?.chatBubbleColorStyle ? CssTransform(theme.chatBubbleColorStyle) : {}}
|
||||||
>
|
>
|
||||||
{
|
{
|
||||||
|
|||||||
177
web/app/components/base/form/components/base/base-field.tsx
Normal file
177
web/app/components/base/form/components/base/base-field.tsx
Normal file
@@ -0,0 +1,177 @@
|
|||||||
|
import {
|
||||||
|
isValidElement,
|
||||||
|
memo,
|
||||||
|
useMemo,
|
||||||
|
} from 'react'
|
||||||
|
import type { AnyFieldApi } from '@tanstack/react-form'
|
||||||
|
import { useStore } from '@tanstack/react-form'
|
||||||
|
import cn from '@/utils/classnames'
|
||||||
|
import Input from '@/app/components/base/input'
|
||||||
|
import PureSelect from '@/app/components/base/select/pure'
|
||||||
|
import type { FormSchema } from '@/app/components/base/form/types'
|
||||||
|
import { FormTypeEnum } from '@/app/components/base/form/types'
|
||||||
|
import { useRenderI18nObject } from '@/hooks/use-i18n'
|
||||||
|
|
||||||
|
export type BaseFieldProps = {
|
||||||
|
fieldClassName?: string
|
||||||
|
labelClassName?: string
|
||||||
|
inputContainerClassName?: string
|
||||||
|
inputClassName?: string
|
||||||
|
formSchema: FormSchema
|
||||||
|
field: AnyFieldApi
|
||||||
|
disabled?: boolean
|
||||||
|
}
|
||||||
|
const BaseField = ({
|
||||||
|
fieldClassName,
|
||||||
|
labelClassName,
|
||||||
|
inputContainerClassName,
|
||||||
|
inputClassName,
|
||||||
|
formSchema,
|
||||||
|
field,
|
||||||
|
disabled,
|
||||||
|
}: BaseFieldProps) => {
|
||||||
|
const renderI18nObject = useRenderI18nObject()
|
||||||
|
const {
|
||||||
|
label,
|
||||||
|
required,
|
||||||
|
placeholder,
|
||||||
|
options,
|
||||||
|
labelClassName: formLabelClassName,
|
||||||
|
show_on = [],
|
||||||
|
} = formSchema
|
||||||
|
|
||||||
|
const memorizedLabel = useMemo(() => {
|
||||||
|
if (isValidElement(label))
|
||||||
|
return label
|
||||||
|
|
||||||
|
if (typeof label === 'string')
|
||||||
|
return label
|
||||||
|
|
||||||
|
if (typeof label === 'object' && label !== null)
|
||||||
|
return renderI18nObject(label as Record<string, string>)
|
||||||
|
}, [label, renderI18nObject])
|
||||||
|
const memorizedPlaceholder = useMemo(() => {
|
||||||
|
if (typeof placeholder === 'string')
|
||||||
|
return placeholder
|
||||||
|
|
||||||
|
if (typeof placeholder === 'object' && placeholder !== null)
|
||||||
|
return renderI18nObject(placeholder as Record<string, string>)
|
||||||
|
}, [placeholder, renderI18nObject])
|
||||||
|
const memorizedOptions = useMemo(() => {
|
||||||
|
return options?.map((option) => {
|
||||||
|
return {
|
||||||
|
label: typeof option.label === 'string' ? option.label : renderI18nObject(option.label),
|
||||||
|
value: option.value,
|
||||||
|
}
|
||||||
|
}) || []
|
||||||
|
}, [options, renderI18nObject])
|
||||||
|
const value = useStore(field.form.store, s => s.values[field.name])
|
||||||
|
const values = useStore(field.form.store, (s) => {
|
||||||
|
return show_on.reduce((acc, condition) => {
|
||||||
|
acc[condition.variable] = s.values[condition.variable]
|
||||||
|
return acc
|
||||||
|
}, {} as Record<string, any>)
|
||||||
|
})
|
||||||
|
const show = useMemo(() => {
|
||||||
|
return show_on.every((condition) => {
|
||||||
|
const conditionValue = values[condition.variable]
|
||||||
|
return conditionValue === condition.value
|
||||||
|
})
|
||||||
|
}, [values, show_on])
|
||||||
|
|
||||||
|
if (!show)
|
||||||
|
return null
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className={cn(fieldClassName)}>
|
||||||
|
<div className={cn(labelClassName, formLabelClassName)}>
|
||||||
|
{memorizedLabel}
|
||||||
|
{
|
||||||
|
required && !isValidElement(label) && (
|
||||||
|
<span className='ml-1 text-text-destructive-secondary'>*</span>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
</div>
|
||||||
|
<div className={cn(inputContainerClassName)}>
|
||||||
|
{
|
||||||
|
formSchema.type === FormTypeEnum.textInput && (
|
||||||
|
<Input
|
||||||
|
id={field.name}
|
||||||
|
name={field.name}
|
||||||
|
className={cn(inputClassName)}
|
||||||
|
value={value || ''}
|
||||||
|
onChange={e => field.handleChange(e.target.value)}
|
||||||
|
onBlur={field.handleBlur}
|
||||||
|
disabled={disabled}
|
||||||
|
placeholder={memorizedPlaceholder}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
formSchema.type === FormTypeEnum.secretInput && (
|
||||||
|
<Input
|
||||||
|
id={field.name}
|
||||||
|
name={field.name}
|
||||||
|
type='password'
|
||||||
|
className={cn(inputClassName)}
|
||||||
|
value={value || ''}
|
||||||
|
onChange={e => field.handleChange(e.target.value)}
|
||||||
|
onBlur={field.handleBlur}
|
||||||
|
disabled={disabled}
|
||||||
|
placeholder={memorizedPlaceholder}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
formSchema.type === FormTypeEnum.textNumber && (
|
||||||
|
<Input
|
||||||
|
id={field.name}
|
||||||
|
name={field.name}
|
||||||
|
type='number'
|
||||||
|
className={cn(inputClassName)}
|
||||||
|
value={value || ''}
|
||||||
|
onChange={e => field.handleChange(e.target.value)}
|
||||||
|
onBlur={field.handleBlur}
|
||||||
|
disabled={disabled}
|
||||||
|
placeholder={memorizedPlaceholder}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
formSchema.type === FormTypeEnum.select && (
|
||||||
|
<PureSelect
|
||||||
|
value={value}
|
||||||
|
onChange={v => field.handleChange(v)}
|
||||||
|
disabled={disabled}
|
||||||
|
placeholder={memorizedPlaceholder}
|
||||||
|
options={memorizedOptions}
|
||||||
|
triggerPopupSameWidth
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
formSchema.type === FormTypeEnum.radio && (
|
||||||
|
<div className='flex items-center space-x-2'>
|
||||||
|
{
|
||||||
|
memorizedOptions.map(option => (
|
||||||
|
<div
|
||||||
|
key={option.value}
|
||||||
|
className={cn(
|
||||||
|
'system-sm-regular hover:bg-components-option-card-option-hover-bg hover:border-components-option-card-option-hover-border flex h-8 grow cursor-pointer items-center justify-center rounded-lg border border-components-option-card-option-border bg-components-option-card-option-bg p-2 text-text-secondary',
|
||||||
|
value === option.value && 'border-components-option-card-option-selected-border bg-components-option-card-option-selected-bg text-text-primary shadow-xs',
|
||||||
|
)}
|
||||||
|
onClick={() => field.handleChange(option.value)}
|
||||||
|
>
|
||||||
|
{option.label}
|
||||||
|
</div>
|
||||||
|
))
|
||||||
|
}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default memo(BaseField)
|
||||||
115
web/app/components/base/form/components/base/base-form.tsx
Normal file
115
web/app/components/base/form/components/base/base-form.tsx
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
import {
|
||||||
|
memo,
|
||||||
|
useCallback,
|
||||||
|
useImperativeHandle,
|
||||||
|
} from 'react'
|
||||||
|
import type {
|
||||||
|
AnyFieldApi,
|
||||||
|
AnyFormApi,
|
||||||
|
} from '@tanstack/react-form'
|
||||||
|
import { useForm } from '@tanstack/react-form'
|
||||||
|
import type {
|
||||||
|
FormRef,
|
||||||
|
FormSchema,
|
||||||
|
} from '@/app/components/base/form/types'
|
||||||
|
import {
|
||||||
|
BaseField,
|
||||||
|
} from '.'
|
||||||
|
import type {
|
||||||
|
BaseFieldProps,
|
||||||
|
} from '.'
|
||||||
|
import cn from '@/utils/classnames'
|
||||||
|
import {
|
||||||
|
useGetFormValues,
|
||||||
|
useGetValidators,
|
||||||
|
} from '@/app/components/base/form/hooks'
|
||||||
|
|
||||||
|
export type BaseFormProps = {
|
||||||
|
formSchemas?: FormSchema[]
|
||||||
|
defaultValues?: Record<string, any>
|
||||||
|
formClassName?: string
|
||||||
|
ref?: FormRef
|
||||||
|
disabled?: boolean
|
||||||
|
formFromProps?: AnyFormApi
|
||||||
|
} & Pick<BaseFieldProps, 'fieldClassName' | 'labelClassName' | 'inputContainerClassName' | 'inputClassName'>
|
||||||
|
|
||||||
|
const BaseForm = ({
|
||||||
|
formSchemas = [],
|
||||||
|
defaultValues,
|
||||||
|
formClassName,
|
||||||
|
fieldClassName,
|
||||||
|
labelClassName,
|
||||||
|
inputContainerClassName,
|
||||||
|
inputClassName,
|
||||||
|
ref,
|
||||||
|
disabled,
|
||||||
|
formFromProps,
|
||||||
|
}: BaseFormProps) => {
|
||||||
|
const formFromHook = useForm({
|
||||||
|
defaultValues,
|
||||||
|
})
|
||||||
|
const form: any = formFromProps || formFromHook
|
||||||
|
const { getFormValues } = useGetFormValues(form, formSchemas)
|
||||||
|
const { getValidators } = useGetValidators()
|
||||||
|
|
||||||
|
useImperativeHandle(ref, () => {
|
||||||
|
return {
|
||||||
|
getForm() {
|
||||||
|
return form
|
||||||
|
},
|
||||||
|
getFormValues: (option) => {
|
||||||
|
return getFormValues(option)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}, [form, getFormValues])
|
||||||
|
|
||||||
|
const renderField = useCallback((field: AnyFieldApi) => {
|
||||||
|
const formSchema = formSchemas?.find(schema => schema.name === field.name)
|
||||||
|
|
||||||
|
if (formSchema) {
|
||||||
|
return (
|
||||||
|
<BaseField
|
||||||
|
field={field}
|
||||||
|
formSchema={formSchema}
|
||||||
|
fieldClassName={fieldClassName}
|
||||||
|
labelClassName={labelClassName}
|
||||||
|
inputContainerClassName={inputContainerClassName}
|
||||||
|
inputClassName={inputClassName}
|
||||||
|
disabled={disabled}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return null
|
||||||
|
}, [formSchemas, fieldClassName, labelClassName, inputContainerClassName, inputClassName, disabled])
|
||||||
|
|
||||||
|
const renderFieldWrapper = useCallback((formSchema: FormSchema) => {
|
||||||
|
const validators = getValidators(formSchema)
|
||||||
|
const {
|
||||||
|
name,
|
||||||
|
} = formSchema
|
||||||
|
|
||||||
|
return (
|
||||||
|
<form.Field
|
||||||
|
key={name}
|
||||||
|
name={name}
|
||||||
|
validators={validators}
|
||||||
|
>
|
||||||
|
{renderField}
|
||||||
|
</form.Field>
|
||||||
|
)
|
||||||
|
}, [renderField, form, getValidators])
|
||||||
|
|
||||||
|
if (!formSchemas?.length)
|
||||||
|
return null
|
||||||
|
|
||||||
|
return (
|
||||||
|
<form
|
||||||
|
className={cn(formClassName)}
|
||||||
|
>
|
||||||
|
{formSchemas.map(renderFieldWrapper)}
|
||||||
|
</form>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default memo(BaseForm)
|
||||||
2
web/app/components/base/form/components/base/index.tsx
Normal file
2
web/app/components/base/form/components/base/index.tsx
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
export { default as BaseForm, type BaseFormProps } from './base-form'
|
||||||
|
export { default as BaseField, type BaseFieldProps } from './base-field'
|
||||||
23
web/app/components/base/form/form-scenarios/auth/index.tsx
Normal file
23
web/app/components/base/form/form-scenarios/auth/index.tsx
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
import { memo } from 'react'
|
||||||
|
import { BaseForm } from '../../components/base'
|
||||||
|
import type { BaseFormProps } from '../../components/base'
|
||||||
|
|
||||||
|
const AuthForm = ({
|
||||||
|
formSchemas = [],
|
||||||
|
defaultValues,
|
||||||
|
ref,
|
||||||
|
formFromProps,
|
||||||
|
}: BaseFormProps) => {
|
||||||
|
return (
|
||||||
|
<BaseForm
|
||||||
|
ref={ref}
|
||||||
|
formSchemas={formSchemas}
|
||||||
|
defaultValues={defaultValues}
|
||||||
|
formClassName='space-y-4'
|
||||||
|
labelClassName='h-6 flex items-center mb-1 system-sm-medium text-text-secondary'
|
||||||
|
formFromProps={formFromProps}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default memo(AuthForm)
|
||||||
3
web/app/components/base/form/hooks/index.ts
Normal file
3
web/app/components/base/form/hooks/index.ts
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
export * from './use-check-validated'
|
||||||
|
export * from './use-get-form-values'
|
||||||
|
export * from './use-get-validators'
|
||||||
48
web/app/components/base/form/hooks/use-check-validated.ts
Normal file
48
web/app/components/base/form/hooks/use-check-validated.ts
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
import { useCallback } from 'react'
|
||||||
|
import type { AnyFormApi } from '@tanstack/react-form'
|
||||||
|
import { useToastContext } from '@/app/components/base/toast'
|
||||||
|
import type { FormSchema } from '@/app/components/base/form/types'
|
||||||
|
|
||||||
|
export const useCheckValidated = (form: AnyFormApi, FormSchemas: FormSchema[]) => {
|
||||||
|
const { notify } = useToastContext()
|
||||||
|
|
||||||
|
const checkValidated = useCallback(() => {
|
||||||
|
const allError = form?.getAllErrors()
|
||||||
|
const values = form.state.values
|
||||||
|
|
||||||
|
if (allError) {
|
||||||
|
const fields = allError.fields
|
||||||
|
const errorArray = Object.keys(fields).reduce((acc: string[], key: string) => {
|
||||||
|
const currentSchema = FormSchemas.find(schema => schema.name === key)
|
||||||
|
const { show_on = [] } = currentSchema || {}
|
||||||
|
const showOnValues = show_on.reduce((acc, condition) => {
|
||||||
|
acc[condition.variable] = values[condition.variable]
|
||||||
|
return acc
|
||||||
|
}, {} as Record<string, any>)
|
||||||
|
const show = show_on?.every((condition) => {
|
||||||
|
const conditionValue = showOnValues[condition.variable]
|
||||||
|
return conditionValue === condition.value
|
||||||
|
})
|
||||||
|
const errors: any[] = show ? fields[key].errors : []
|
||||||
|
|
||||||
|
return [...acc, ...errors]
|
||||||
|
}, [] as string[])
|
||||||
|
|
||||||
|
if (errorArray.length) {
|
||||||
|
notify({
|
||||||
|
type: 'error',
|
||||||
|
message: errorArray[0],
|
||||||
|
})
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}, [form, notify, FormSchemas])
|
||||||
|
|
||||||
|
return {
|
||||||
|
checkValidated,
|
||||||
|
}
|
||||||
|
}
|
||||||
44
web/app/components/base/form/hooks/use-get-form-values.ts
Normal file
44
web/app/components/base/form/hooks/use-get-form-values.ts
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
import { useCallback } from 'react'
|
||||||
|
import type { AnyFormApi } from '@tanstack/react-form'
|
||||||
|
import { useCheckValidated } from './use-check-validated'
|
||||||
|
import type {
|
||||||
|
FormSchema,
|
||||||
|
GetValuesOptions,
|
||||||
|
} from '../types'
|
||||||
|
import { getTransformedValuesWhenSecretInputPristine } from '../utils'
|
||||||
|
|
||||||
|
export const useGetFormValues = (form: AnyFormApi, formSchemas: FormSchema[]) => {
|
||||||
|
const { checkValidated } = useCheckValidated(form, formSchemas)
|
||||||
|
|
||||||
|
const getFormValues = useCallback((
|
||||||
|
{
|
||||||
|
needCheckValidatedValues,
|
||||||
|
needTransformWhenSecretFieldIsPristine,
|
||||||
|
}: GetValuesOptions,
|
||||||
|
) => {
|
||||||
|
const values = form?.store.state.values || {}
|
||||||
|
if (!needCheckValidatedValues) {
|
||||||
|
return {
|
||||||
|
values,
|
||||||
|
isCheckValidated: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (checkValidated()) {
|
||||||
|
return {
|
||||||
|
values: needTransformWhenSecretFieldIsPristine ? getTransformedValuesWhenSecretInputPristine(formSchemas, form) : values,
|
||||||
|
isCheckValidated: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
return {
|
||||||
|
values: {},
|
||||||
|
isCheckValidated: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, [form, checkValidated, formSchemas])
|
||||||
|
|
||||||
|
return {
|
||||||
|
getFormValues,
|
||||||
|
}
|
||||||
|
}
|
||||||
36
web/app/components/base/form/hooks/use-get-validators.ts
Normal file
36
web/app/components/base/form/hooks/use-get-validators.ts
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
import { useCallback } from 'react'
|
||||||
|
import { useTranslation } from 'react-i18next'
|
||||||
|
import type { FormSchema } from '../types'
|
||||||
|
|
||||||
|
export const useGetValidators = () => {
|
||||||
|
const { t } = useTranslation()
|
||||||
|
const getValidators = useCallback((formSchema: FormSchema) => {
|
||||||
|
const {
|
||||||
|
name,
|
||||||
|
validators,
|
||||||
|
required,
|
||||||
|
} = formSchema
|
||||||
|
let mergedValidators = validators
|
||||||
|
if (required && !validators) {
|
||||||
|
mergedValidators = {
|
||||||
|
onMount: ({ value }: any) => {
|
||||||
|
if (!value)
|
||||||
|
return t('common.errorMsg.fieldRequired', { field: name })
|
||||||
|
},
|
||||||
|
onChange: ({ value }: any) => {
|
||||||
|
if (!value)
|
||||||
|
return t('common.errorMsg.fieldRequired', { field: name })
|
||||||
|
},
|
||||||
|
onBlur: ({ value }: any) => {
|
||||||
|
if (!value)
|
||||||
|
return t('common.errorMsg.fieldRequired', { field: name })
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return mergedValidators
|
||||||
|
}, [t])
|
||||||
|
|
||||||
|
return {
|
||||||
|
getValidators,
|
||||||
|
}
|
||||||
|
}
|
||||||
76
web/app/components/base/form/types.ts
Normal file
76
web/app/components/base/form/types.ts
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
import type {
|
||||||
|
ForwardedRef,
|
||||||
|
ReactNode,
|
||||||
|
} from 'react'
|
||||||
|
import type {
|
||||||
|
AnyFormApi,
|
||||||
|
FieldValidators,
|
||||||
|
} from '@tanstack/react-form'
|
||||||
|
|
||||||
|
export type TypeWithI18N<T = string> = {
|
||||||
|
en_US: T
|
||||||
|
zh_Hans: T
|
||||||
|
[key: string]: T
|
||||||
|
}
|
||||||
|
|
||||||
|
export type FormShowOnObject = {
|
||||||
|
variable: string
|
||||||
|
value: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export enum FormTypeEnum {
|
||||||
|
textInput = 'text-input',
|
||||||
|
textNumber = 'number-input',
|
||||||
|
secretInput = 'secret-input',
|
||||||
|
select = 'select',
|
||||||
|
radio = 'radio',
|
||||||
|
boolean = 'boolean',
|
||||||
|
files = 'files',
|
||||||
|
file = 'file',
|
||||||
|
modelSelector = 'model-selector',
|
||||||
|
toolSelector = 'tool-selector',
|
||||||
|
multiToolSelector = 'array[tools]',
|
||||||
|
appSelector = 'app-selector',
|
||||||
|
dynamicSelect = 'dynamic-select',
|
||||||
|
}
|
||||||
|
|
||||||
|
export type FormOption = {
|
||||||
|
label: TypeWithI18N | string
|
||||||
|
value: string
|
||||||
|
show_on?: FormShowOnObject[]
|
||||||
|
icon?: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export type AnyValidators = FieldValidators<any, any, any, any, any, any, any, any, any, any>
|
||||||
|
|
||||||
|
export type FormSchema = {
|
||||||
|
type: FormTypeEnum
|
||||||
|
name: string
|
||||||
|
label: string | ReactNode | TypeWithI18N
|
||||||
|
required: boolean
|
||||||
|
default?: any
|
||||||
|
tooltip?: string | TypeWithI18N
|
||||||
|
show_on?: FormShowOnObject[]
|
||||||
|
url?: string
|
||||||
|
scope?: string
|
||||||
|
help?: string | TypeWithI18N
|
||||||
|
placeholder?: string | TypeWithI18N
|
||||||
|
options?: FormOption[]
|
||||||
|
labelClassName?: string
|
||||||
|
validators?: AnyValidators
|
||||||
|
}
|
||||||
|
|
||||||
|
export type FormValues = Record<string, any>
|
||||||
|
|
||||||
|
export type GetValuesOptions = {
|
||||||
|
needTransformWhenSecretFieldIsPristine?: boolean
|
||||||
|
needCheckValidatedValues?: boolean
|
||||||
|
}
|
||||||
|
export type FormRefObject = {
|
||||||
|
getForm: () => AnyFormApi
|
||||||
|
getFormValues: (obj: GetValuesOptions) => {
|
||||||
|
values: Record<string, any>
|
||||||
|
isCheckValidated: boolean
|
||||||
|
}
|
||||||
|
}
|
||||||
|
export type FormRef = ForwardedRef<FormRefObject>
|
||||||
1
web/app/components/base/form/utils/index.ts
Normal file
1
web/app/components/base/form/utils/index.ts
Normal file
@@ -0,0 +1 @@
|
|||||||
|
export * from './secret-input'
|
||||||
29
web/app/components/base/form/utils/secret-input/index.ts
Normal file
29
web/app/components/base/form/utils/secret-input/index.ts
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
import type { AnyFormApi } from '@tanstack/react-form'
|
||||||
|
import type { FormSchema } from '@/app/components/base/form/types'
|
||||||
|
import { FormTypeEnum } from '@/app/components/base/form/types'
|
||||||
|
|
||||||
|
export const transformFormSchemasSecretInput = (isPristineSecretInputNames: string[], values: Record<string, any>) => {
|
||||||
|
const transformedValues: Record<string, any> = { ...values }
|
||||||
|
|
||||||
|
isPristineSecretInputNames.forEach((name) => {
|
||||||
|
if (transformedValues[name])
|
||||||
|
transformedValues[name] = '[__HIDDEN__]'
|
||||||
|
})
|
||||||
|
|
||||||
|
return transformedValues
|
||||||
|
}
|
||||||
|
|
||||||
|
export const getTransformedValuesWhenSecretInputPristine = (formSchemas: FormSchema[], form: AnyFormApi) => {
|
||||||
|
const values = form?.store.state.values || {}
|
||||||
|
const isPristineSecretInputNames: string[] = []
|
||||||
|
for (let i = 0; i < formSchemas.length; i++) {
|
||||||
|
const schema = formSchemas[i]
|
||||||
|
if (schema.type === FormTypeEnum.secretInput) {
|
||||||
|
const fieldMeta = form?.getFieldMeta(schema.name)
|
||||||
|
if (fieldMeta?.isPristine)
|
||||||
|
isPristineSecretInputNames.push(schema.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return transformFormSchemasSecretInput(isPristineSecretInputNames, values)
|
||||||
|
}
|
||||||
127
web/app/components/base/modal/modal.tsx
Normal file
127
web/app/components/base/modal/modal.tsx
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
import { memo } from 'react'
|
||||||
|
import { useTranslation } from 'react-i18next'
|
||||||
|
import { RiCloseLine } from '@remixicon/react'
|
||||||
|
import {
|
||||||
|
PortalToFollowElem,
|
||||||
|
PortalToFollowElemContent,
|
||||||
|
} from '@/app/components/base/portal-to-follow-elem'
|
||||||
|
import Button from '@/app/components/base/button'
|
||||||
|
import type { ButtonProps } from '@/app/components/base/button'
|
||||||
|
import cn from '@/utils/classnames'
|
||||||
|
|
||||||
|
type ModalProps = {
|
||||||
|
onClose?: () => void
|
||||||
|
size?: 'sm' | 'md'
|
||||||
|
title: string
|
||||||
|
subTitle?: string
|
||||||
|
children?: React.ReactNode
|
||||||
|
confirmButtonText?: string
|
||||||
|
onConfirm?: () => void
|
||||||
|
cancelButtonText?: string
|
||||||
|
onCancel?: () => void
|
||||||
|
showExtraButton?: boolean
|
||||||
|
extraButtonText?: string
|
||||||
|
extraButtonVariant?: ButtonProps['variant']
|
||||||
|
onExtraButtonClick?: () => void
|
||||||
|
footerSlot?: React.ReactNode
|
||||||
|
bottomSlot?: React.ReactNode
|
||||||
|
disabled?: boolean
|
||||||
|
}
|
||||||
|
const Modal = ({
|
||||||
|
onClose,
|
||||||
|
size = 'sm',
|
||||||
|
title,
|
||||||
|
subTitle,
|
||||||
|
children,
|
||||||
|
confirmButtonText,
|
||||||
|
onConfirm,
|
||||||
|
cancelButtonText,
|
||||||
|
onCancel,
|
||||||
|
showExtraButton,
|
||||||
|
extraButtonVariant = 'warning',
|
||||||
|
extraButtonText,
|
||||||
|
onExtraButtonClick,
|
||||||
|
footerSlot,
|
||||||
|
bottomSlot,
|
||||||
|
disabled,
|
||||||
|
}: ModalProps) => {
|
||||||
|
const { t } = useTranslation()
|
||||||
|
|
||||||
|
return (
|
||||||
|
<PortalToFollowElem open>
|
||||||
|
<PortalToFollowElemContent
|
||||||
|
className='z-[9998] flex h-full w-full items-center justify-center bg-background-overlay'
|
||||||
|
onClick={onClose}
|
||||||
|
>
|
||||||
|
<div
|
||||||
|
className={cn(
|
||||||
|
'max-h-[80%] w-[480px] overflow-y-auto rounded-2xl border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-xs',
|
||||||
|
size === 'sm' && 'w-[480px',
|
||||||
|
size === 'md' && 'w-[640px]',
|
||||||
|
)}
|
||||||
|
onClick={e => e.stopPropagation()}
|
||||||
|
>
|
||||||
|
<div className='title-2xl-semi-bold relative p-6 pb-3 pr-14 text-text-primary'>
|
||||||
|
{title}
|
||||||
|
{
|
||||||
|
subTitle && (
|
||||||
|
<div className='system-xs-regular mt-1 text-text-tertiary'>
|
||||||
|
{subTitle}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
<div
|
||||||
|
className='absolute right-5 top-5 flex h-8 w-8 cursor-pointer items-center justify-center rounded-lg'
|
||||||
|
onClick={onClose}
|
||||||
|
>
|
||||||
|
<RiCloseLine className='h-5 w-5 text-text-tertiary' />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{
|
||||||
|
children && (
|
||||||
|
<div className='px-6 py-3'>{children}</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
<div className='flex justify-between p-6 pt-5'>
|
||||||
|
<div>
|
||||||
|
{footerSlot}
|
||||||
|
</div>
|
||||||
|
<div className='flex items-center'>
|
||||||
|
{
|
||||||
|
showExtraButton && (
|
||||||
|
<>
|
||||||
|
<Button
|
||||||
|
variant={extraButtonVariant}
|
||||||
|
onClick={onExtraButtonClick}
|
||||||
|
disabled={disabled}
|
||||||
|
>
|
||||||
|
{extraButtonText || t('common.operation.remove')}
|
||||||
|
</Button>
|
||||||
|
<div className='mx-3 h-4 w-[1px] bg-divider-regular'></div>
|
||||||
|
</>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
<Button
|
||||||
|
onClick={onCancel}
|
||||||
|
disabled={disabled}
|
||||||
|
>
|
||||||
|
{cancelButtonText || t('common.operation.cancel')}
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
className='ml-2'
|
||||||
|
variant='primary'
|
||||||
|
onClick={onConfirm}
|
||||||
|
disabled={disabled}
|
||||||
|
>
|
||||||
|
{confirmButtonText || t('common.operation.save')}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{bottomSlot}
|
||||||
|
</div>
|
||||||
|
</PortalToFollowElemContent>
|
||||||
|
</PortalToFollowElem>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default memo(Modal)
|
||||||
@@ -39,6 +39,9 @@ type PureSelectProps = {
|
|||||||
itemClassName?: string
|
itemClassName?: string
|
||||||
title?: string
|
title?: string
|
||||||
},
|
},
|
||||||
|
placeholder?: string
|
||||||
|
disabled?: boolean
|
||||||
|
triggerPopupSameWidth?: boolean
|
||||||
}
|
}
|
||||||
const PureSelect = ({
|
const PureSelect = ({
|
||||||
options,
|
options,
|
||||||
@@ -47,6 +50,9 @@ const PureSelect = ({
|
|||||||
containerProps,
|
containerProps,
|
||||||
triggerProps,
|
triggerProps,
|
||||||
popupProps,
|
popupProps,
|
||||||
|
placeholder,
|
||||||
|
disabled,
|
||||||
|
triggerPopupSameWidth,
|
||||||
}: PureSelectProps) => {
|
}: PureSelectProps) => {
|
||||||
const { t } = useTranslation()
|
const { t } = useTranslation()
|
||||||
const {
|
const {
|
||||||
@@ -74,7 +80,7 @@ const PureSelect = ({
|
|||||||
}, [onOpenChange])
|
}, [onOpenChange])
|
||||||
|
|
||||||
const selectedOption = options.find(option => option.value === value)
|
const selectedOption = options.find(option => option.value === value)
|
||||||
const triggerText = selectedOption?.label || t('common.placeholder.select')
|
const triggerText = selectedOption?.label || placeholder || t('common.placeholder.select')
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<PortalToFollowElem
|
<PortalToFollowElem
|
||||||
@@ -82,6 +88,7 @@ const PureSelect = ({
|
|||||||
offset={offset || 4}
|
offset={offset || 4}
|
||||||
open={mergedOpen}
|
open={mergedOpen}
|
||||||
onOpenChange={handleOpenChange}
|
onOpenChange={handleOpenChange}
|
||||||
|
triggerPopupSameWidth={triggerPopupSameWidth}
|
||||||
>
|
>
|
||||||
<PortalToFollowElemTrigger
|
<PortalToFollowElemTrigger
|
||||||
onClick={() => handleOpenChange(!mergedOpen)}
|
onClick={() => handleOpenChange(!mergedOpen)}
|
||||||
@@ -135,6 +142,7 @@ const PureSelect = ({
|
|||||||
)}
|
)}
|
||||||
title={option.label}
|
title={option.label}
|
||||||
onClick={() => {
|
onClick={() => {
|
||||||
|
if (disabled) return
|
||||||
onChange?.(option.value)
|
onChange?.(option.value)
|
||||||
handleOpenChange(false)
|
handleOpenChange(false)
|
||||||
}}
|
}}
|
||||||
|
|||||||
@@ -0,0 +1,50 @@
|
|||||||
|
import {
|
||||||
|
memo,
|
||||||
|
useState,
|
||||||
|
} from 'react'
|
||||||
|
import Button from '@/app/components/base/button'
|
||||||
|
import type { ButtonProps } from '@/app/components/base/button'
|
||||||
|
import ApiKeyModal from './api-key-modal'
|
||||||
|
import type { PluginPayload } from '../types'
|
||||||
|
|
||||||
|
export type AddApiKeyButtonProps = {
|
||||||
|
pluginPayload: PluginPayload
|
||||||
|
buttonVariant?: ButtonProps['variant']
|
||||||
|
buttonText?: string
|
||||||
|
disabled?: boolean
|
||||||
|
onUpdate?: () => void
|
||||||
|
}
|
||||||
|
const AddApiKeyButton = ({
|
||||||
|
pluginPayload,
|
||||||
|
buttonVariant = 'secondary-accent',
|
||||||
|
buttonText = 'use api key',
|
||||||
|
disabled,
|
||||||
|
onUpdate,
|
||||||
|
}: AddApiKeyButtonProps) => {
|
||||||
|
const [isApiKeyModalOpen, setIsApiKeyModalOpen] = useState(false)
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<Button
|
||||||
|
className='w-full'
|
||||||
|
variant={buttonVariant}
|
||||||
|
onClick={() => setIsApiKeyModalOpen(true)}
|
||||||
|
disabled={disabled}
|
||||||
|
>
|
||||||
|
{buttonText}
|
||||||
|
</Button>
|
||||||
|
{
|
||||||
|
isApiKeyModalOpen && (
|
||||||
|
<ApiKeyModal
|
||||||
|
pluginPayload={pluginPayload}
|
||||||
|
onClose={() => setIsApiKeyModalOpen(false)}
|
||||||
|
onUpdate={onUpdate}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
</>
|
||||||
|
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default memo(AddApiKeyButton)
|
||||||
@@ -0,0 +1,259 @@
|
|||||||
|
import {
|
||||||
|
memo,
|
||||||
|
useCallback,
|
||||||
|
useMemo,
|
||||||
|
useState,
|
||||||
|
} from 'react'
|
||||||
|
import { useTranslation } from 'react-i18next'
|
||||||
|
import {
|
||||||
|
RiClipboardLine,
|
||||||
|
RiEqualizer2Line,
|
||||||
|
RiInformation2Fill,
|
||||||
|
} from '@remixicon/react'
|
||||||
|
import Button from '@/app/components/base/button'
|
||||||
|
import type { ButtonProps } from '@/app/components/base/button'
|
||||||
|
import OAuthClientSettings from './oauth-client-settings'
|
||||||
|
import cn from '@/utils/classnames'
|
||||||
|
import type { PluginPayload } from '../types'
|
||||||
|
import { openOAuthPopup } from '@/hooks/use-oauth'
|
||||||
|
import Badge from '@/app/components/base/badge'
|
||||||
|
import {
|
||||||
|
useGetPluginOAuthClientSchemaHook,
|
||||||
|
useGetPluginOAuthUrlHook,
|
||||||
|
} from '../hooks/use-credential'
|
||||||
|
import type { FormSchema } from '@/app/components/base/form/types'
|
||||||
|
import { FormTypeEnum } from '@/app/components/base/form/types'
|
||||||
|
import ActionButton from '@/app/components/base/action-button'
|
||||||
|
import { useRenderI18nObject } from '@/hooks/use-i18n'
|
||||||
|
|
||||||
|
export type AddOAuthButtonProps = {
|
||||||
|
pluginPayload: PluginPayload
|
||||||
|
buttonVariant?: ButtonProps['variant']
|
||||||
|
buttonText?: string
|
||||||
|
className?: string
|
||||||
|
buttonLeftClassName?: string
|
||||||
|
buttonRightClassName?: string
|
||||||
|
dividerClassName?: string
|
||||||
|
disabled?: boolean
|
||||||
|
onUpdate?: () => void
|
||||||
|
}
|
||||||
|
const AddOAuthButton = ({
|
||||||
|
pluginPayload,
|
||||||
|
buttonVariant = 'primary',
|
||||||
|
buttonText = 'use oauth',
|
||||||
|
className,
|
||||||
|
buttonLeftClassName,
|
||||||
|
buttonRightClassName,
|
||||||
|
dividerClassName,
|
||||||
|
disabled,
|
||||||
|
onUpdate,
|
||||||
|
}: AddOAuthButtonProps) => {
|
||||||
|
const { t } = useTranslation()
|
||||||
|
const renderI18nObject = useRenderI18nObject()
|
||||||
|
const [isOAuthSettingsOpen, setIsOAuthSettingsOpen] = useState(false)
|
||||||
|
const { mutateAsync: getPluginOAuthUrl } = useGetPluginOAuthUrlHook(pluginPayload)
|
||||||
|
const { data, isLoading } = useGetPluginOAuthClientSchemaHook(pluginPayload)
|
||||||
|
const {
|
||||||
|
schema = [],
|
||||||
|
is_oauth_custom_client_enabled,
|
||||||
|
is_system_oauth_params_exists,
|
||||||
|
client_params,
|
||||||
|
redirect_uri,
|
||||||
|
} = data || {}
|
||||||
|
const isConfigured = is_system_oauth_params_exists || is_oauth_custom_client_enabled
|
||||||
|
const handleOAuth = useCallback(async () => {
|
||||||
|
const { authorization_url } = await getPluginOAuthUrl()
|
||||||
|
|
||||||
|
if (authorization_url) {
|
||||||
|
openOAuthPopup(
|
||||||
|
authorization_url,
|
||||||
|
() => onUpdate?.(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}, [getPluginOAuthUrl, onUpdate])
|
||||||
|
|
||||||
|
const renderCustomLabel = useCallback((item: FormSchema) => {
|
||||||
|
return (
|
||||||
|
<div className='w-full'>
|
||||||
|
<div className='mb-4 flex rounded-xl bg-background-section-burn p-4'>
|
||||||
|
<div className='mr-3 flex h-9 w-9 shrink-0 items-center justify-center rounded-lg border-[0.5px] border-components-card-border bg-components-card-bg shadow-lg'>
|
||||||
|
<RiInformation2Fill className='h-5 w-5 text-text-accent' />
|
||||||
|
</div>
|
||||||
|
<div className='w-0 grow'>
|
||||||
|
<div className='system-sm-regular mb-1.5'>
|
||||||
|
{t('plugin.auth.clientInfo')}
|
||||||
|
</div>
|
||||||
|
{
|
||||||
|
redirect_uri && (
|
||||||
|
<div className='system-sm-medium flex w-full py-0.5'>
|
||||||
|
<div className='w-0 grow break-words'>{redirect_uri}</div>
|
||||||
|
<ActionButton
|
||||||
|
className='shrink-0'
|
||||||
|
onClick={() => {
|
||||||
|
navigator.clipboard.writeText(redirect_uri || '')
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<RiClipboardLine className='h-4 w-4' />
|
||||||
|
</ActionButton>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div className='system-sm-medium flex h-6 items-center text-text-secondary'>
|
||||||
|
{renderI18nObject(item.label as Record<string, string>)}
|
||||||
|
{
|
||||||
|
item.required && (
|
||||||
|
<span className='ml-1 text-text-destructive-secondary'>*</span>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}, [t, redirect_uri, renderI18nObject])
|
||||||
|
const memorizedSchemas = useMemo(() => {
|
||||||
|
const result: FormSchema[] = schema.map((item, index) => {
|
||||||
|
return {
|
||||||
|
...item,
|
||||||
|
label: index === 0 ? renderCustomLabel(item) : item.label,
|
||||||
|
labelClassName: index === 0 ? 'h-auto' : undefined,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
if (is_system_oauth_params_exists) {
|
||||||
|
result.unshift({
|
||||||
|
name: '__oauth_client__',
|
||||||
|
label: t('plugin.auth.oauthClient'),
|
||||||
|
type: FormTypeEnum.radio,
|
||||||
|
options: [
|
||||||
|
{
|
||||||
|
label: t('plugin.auth.default'),
|
||||||
|
value: 'default',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: t('plugin.auth.custom'),
|
||||||
|
value: 'custom',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
required: false,
|
||||||
|
default: is_oauth_custom_client_enabled ? 'custom' : 'default',
|
||||||
|
} as FormSchema)
|
||||||
|
result.forEach((item, index) => {
|
||||||
|
if (index > 0) {
|
||||||
|
item.show_on = [
|
||||||
|
{
|
||||||
|
variable: '__oauth_client__',
|
||||||
|
value: 'custom',
|
||||||
|
},
|
||||||
|
]
|
||||||
|
if (client_params)
|
||||||
|
item.default = client_params[item.name] || item.default
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}, [schema, renderCustomLabel, t, is_system_oauth_params_exists, is_oauth_custom_client_enabled, client_params])
|
||||||
|
|
||||||
|
const __auth_client__ = useMemo(() => {
|
||||||
|
if (isConfigured) {
|
||||||
|
if (is_oauth_custom_client_enabled)
|
||||||
|
return 'custom'
|
||||||
|
return 'default'
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
if (is_system_oauth_params_exists)
|
||||||
|
return 'default'
|
||||||
|
return 'custom'
|
||||||
|
}
|
||||||
|
}, [isConfigured, is_oauth_custom_client_enabled, is_system_oauth_params_exists])
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
{
|
||||||
|
isConfigured && (
|
||||||
|
<Button
|
||||||
|
variant={buttonVariant}
|
||||||
|
className={cn(
|
||||||
|
'w-full px-0 py-0 hover:bg-components-button-primary-bg',
|
||||||
|
className,
|
||||||
|
)}
|
||||||
|
disabled={disabled}
|
||||||
|
onClick={handleOAuth}
|
||||||
|
>
|
||||||
|
<div className={cn(
|
||||||
|
'flex h-full w-0 grow items-center justify-center rounded-l-lg pl-0.5 hover:bg-components-button-primary-bg-hover',
|
||||||
|
buttonLeftClassName,
|
||||||
|
)}>
|
||||||
|
<div
|
||||||
|
className='truncate'
|
||||||
|
title={buttonText}
|
||||||
|
>
|
||||||
|
{buttonText}
|
||||||
|
</div>
|
||||||
|
{
|
||||||
|
is_oauth_custom_client_enabled && (
|
||||||
|
<Badge
|
||||||
|
className={cn(
|
||||||
|
'ml-1 mr-0.5',
|
||||||
|
buttonVariant === 'primary' && 'border-text-primary-on-surface bg-components-badge-bg-dimm text-text-primary-on-surface',
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
{t('plugin.auth.custom')}
|
||||||
|
</Badge>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
</div>
|
||||||
|
<div className={cn(
|
||||||
|
'h-4 w-[1px] shrink-0 bg-text-primary-on-surface opacity-[0.15]',
|
||||||
|
dividerClassName,
|
||||||
|
)}></div>
|
||||||
|
<div
|
||||||
|
className={cn(
|
||||||
|
'flex h-full w-8 shrink-0 items-center justify-center rounded-r-lg hover:bg-components-button-primary-bg-hover',
|
||||||
|
buttonRightClassName,
|
||||||
|
)}
|
||||||
|
onClick={(e) => {
|
||||||
|
e.stopPropagation()
|
||||||
|
setIsOAuthSettingsOpen(true)
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<RiEqualizer2Line className='h-4 w-4' />
|
||||||
|
</div>
|
||||||
|
</Button>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
!isConfigured && (
|
||||||
|
<Button
|
||||||
|
variant={buttonVariant}
|
||||||
|
onClick={() => setIsOAuthSettingsOpen(true)}
|
||||||
|
disabled={disabled}
|
||||||
|
className='w-full'
|
||||||
|
>
|
||||||
|
<RiEqualizer2Line className='mr-0.5 h-4 w-4' />
|
||||||
|
{t('plugin.auth.setupOAuth')}
|
||||||
|
</Button>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
isOAuthSettingsOpen && (
|
||||||
|
<OAuthClientSettings
|
||||||
|
pluginPayload={pluginPayload}
|
||||||
|
onClose={() => setIsOAuthSettingsOpen(false)}
|
||||||
|
disabled={disabled || isLoading}
|
||||||
|
schemas={memorizedSchemas}
|
||||||
|
onAuth={handleOAuth}
|
||||||
|
editValues={{
|
||||||
|
...client_params,
|
||||||
|
__oauth_client__: __auth_client__,
|
||||||
|
}}
|
||||||
|
hasOriginalClientParams={Object.keys(client_params || {}).length > 0}
|
||||||
|
onUpdate={onUpdate}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
</>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default memo(AddOAuthButton)
|
||||||
@@ -0,0 +1,181 @@
|
|||||||
|
import {
|
||||||
|
memo,
|
||||||
|
useCallback,
|
||||||
|
useMemo,
|
||||||
|
useRef,
|
||||||
|
useState,
|
||||||
|
} from 'react'
|
||||||
|
import { useTranslation } from 'react-i18next'
|
||||||
|
import { RiExternalLinkLine } from '@remixicon/react'
|
||||||
|
import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security'
|
||||||
|
import Modal from '@/app/components/base/modal/modal'
|
||||||
|
import { CredentialTypeEnum } from '../types'
|
||||||
|
import AuthForm from '@/app/components/base/form/form-scenarios/auth'
|
||||||
|
import type { FormRefObject } from '@/app/components/base/form/types'
|
||||||
|
import { FormTypeEnum } from '@/app/components/base/form/types'
|
||||||
|
import { useToastContext } from '@/app/components/base/toast'
|
||||||
|
import Loading from '@/app/components/base/loading'
|
||||||
|
import type { PluginPayload } from '../types'
|
||||||
|
import {
|
||||||
|
useAddPluginCredentialHook,
|
||||||
|
useGetPluginCredentialSchemaHook,
|
||||||
|
useUpdatePluginCredentialHook,
|
||||||
|
} from '../hooks/use-credential'
|
||||||
|
import { useRenderI18nObject } from '@/hooks/use-i18n'
|
||||||
|
|
||||||
|
export type ApiKeyModalProps = {
|
||||||
|
pluginPayload: PluginPayload
|
||||||
|
onClose?: () => void
|
||||||
|
editValues?: Record<string, any>
|
||||||
|
onRemove?: () => void
|
||||||
|
disabled?: boolean
|
||||||
|
onUpdate?: () => void
|
||||||
|
}
|
||||||
|
const ApiKeyModal = ({
|
||||||
|
pluginPayload,
|
||||||
|
onClose,
|
||||||
|
editValues,
|
||||||
|
onRemove,
|
||||||
|
disabled,
|
||||||
|
onUpdate,
|
||||||
|
}: ApiKeyModalProps) => {
|
||||||
|
const { t } = useTranslation()
|
||||||
|
const { notify } = useToastContext()
|
||||||
|
const [doingAction, setDoingAction] = useState(false)
|
||||||
|
const doingActionRef = useRef(doingAction)
|
||||||
|
const handleSetDoingAction = useCallback((value: boolean) => {
|
||||||
|
doingActionRef.current = value
|
||||||
|
setDoingAction(value)
|
||||||
|
}, [])
|
||||||
|
const { data = [], isLoading } = useGetPluginCredentialSchemaHook(pluginPayload, CredentialTypeEnum.API_KEY)
|
||||||
|
const formSchemas = useMemo(() => {
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
type: FormTypeEnum.textInput,
|
||||||
|
name: '__name__',
|
||||||
|
label: t('plugin.auth.authorizationName'),
|
||||||
|
required: false,
|
||||||
|
},
|
||||||
|
...data,
|
||||||
|
]
|
||||||
|
}, [data, t])
|
||||||
|
const defaultValues = formSchemas.reduce((acc, schema) => {
|
||||||
|
if (schema.default)
|
||||||
|
acc[schema.name] = schema.default
|
||||||
|
return acc
|
||||||
|
}, {} as Record<string, any>)
|
||||||
|
const helpField = formSchemas.find(schema => schema.url && schema.help)
|
||||||
|
const renderI18nObject = useRenderI18nObject()
|
||||||
|
const { mutateAsync: addPluginCredential } = useAddPluginCredentialHook(pluginPayload)
|
||||||
|
const { mutateAsync: updatePluginCredential } = useUpdatePluginCredentialHook(pluginPayload)
|
||||||
|
const formRef = useRef<FormRefObject>(null)
|
||||||
|
const handleConfirm = useCallback(async () => {
|
||||||
|
if (doingActionRef.current)
|
||||||
|
return
|
||||||
|
const {
|
||||||
|
isCheckValidated,
|
||||||
|
values,
|
||||||
|
} = formRef.current?.getFormValues({
|
||||||
|
needCheckValidatedValues: true,
|
||||||
|
needTransformWhenSecretFieldIsPristine: true,
|
||||||
|
}) || { isCheckValidated: false, values: {} }
|
||||||
|
if (!isCheckValidated)
|
||||||
|
return
|
||||||
|
|
||||||
|
try {
|
||||||
|
const {
|
||||||
|
__name__,
|
||||||
|
__credential_id__,
|
||||||
|
...restValues
|
||||||
|
} = values
|
||||||
|
|
||||||
|
handleSetDoingAction(true)
|
||||||
|
if (editValues) {
|
||||||
|
await updatePluginCredential({
|
||||||
|
credentials: restValues,
|
||||||
|
credential_id: __credential_id__,
|
||||||
|
name: __name__ || '',
|
||||||
|
})
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
await addPluginCredential({
|
||||||
|
credentials: restValues,
|
||||||
|
type: CredentialTypeEnum.API_KEY,
|
||||||
|
name: __name__ || '',
|
||||||
|
})
|
||||||
|
}
|
||||||
|
notify({
|
||||||
|
type: 'success',
|
||||||
|
message: t('common.api.actionSuccess'),
|
||||||
|
})
|
||||||
|
|
||||||
|
onClose?.()
|
||||||
|
onUpdate?.()
|
||||||
|
}
|
||||||
|
finally {
|
||||||
|
handleSetDoingAction(false)
|
||||||
|
}
|
||||||
|
}, [addPluginCredential, onClose, onUpdate, updatePluginCredential, notify, t, editValues, handleSetDoingAction])
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Modal
|
||||||
|
size='md'
|
||||||
|
title={t('plugin.auth.useApiAuth')}
|
||||||
|
subTitle={t('plugin.auth.useApiAuthDesc')}
|
||||||
|
onClose={onClose}
|
||||||
|
onCancel={onClose}
|
||||||
|
footerSlot={
|
||||||
|
helpField && (
|
||||||
|
<a
|
||||||
|
className='system-xs-regular mr-2 flex items-center py-2 text-text-accent'
|
||||||
|
href={helpField?.url}
|
||||||
|
target='_blank'
|
||||||
|
>
|
||||||
|
<span className='break-all'>
|
||||||
|
{renderI18nObject(helpField?.help as any)}
|
||||||
|
</span>
|
||||||
|
<RiExternalLinkLine className='ml-1 h-3 w-3' />
|
||||||
|
</a>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
bottomSlot={
|
||||||
|
<div className='flex items-center justify-center bg-background-section-burn py-3 text-xs text-text-tertiary'>
|
||||||
|
<Lock01 className='mr-1 h-3 w-3 text-text-tertiary' />
|
||||||
|
{t('common.modelProvider.encrypted.front')}
|
||||||
|
<a
|
||||||
|
className='mx-1 text-text-accent'
|
||||||
|
target='_blank' rel='noopener noreferrer'
|
||||||
|
href='https://pycryptodome.readthedocs.io/en/latest/src/cipher/oaep.html'
|
||||||
|
>
|
||||||
|
PKCS1_OAEP
|
||||||
|
</a>
|
||||||
|
{t('common.modelProvider.encrypted.back')}
|
||||||
|
</div>
|
||||||
|
}
|
||||||
|
onConfirm={handleConfirm}
|
||||||
|
showExtraButton={!!editValues}
|
||||||
|
onExtraButtonClick={onRemove}
|
||||||
|
disabled={disabled || isLoading || doingAction}
|
||||||
|
>
|
||||||
|
{
|
||||||
|
isLoading && (
|
||||||
|
<div className='flex h-40 items-center justify-center'>
|
||||||
|
<Loading />
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
!isLoading && !!data.length && (
|
||||||
|
<AuthForm
|
||||||
|
ref={formRef}
|
||||||
|
formSchemas={formSchemas}
|
||||||
|
defaultValues={editValues || defaultValues}
|
||||||
|
disabled={disabled}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
</Modal>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default memo(ApiKeyModal)
|
||||||
104
web/app/components/plugins/plugin-auth/authorize/index.tsx
Normal file
104
web/app/components/plugins/plugin-auth/authorize/index.tsx
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
import {
|
||||||
|
memo,
|
||||||
|
useMemo,
|
||||||
|
} from 'react'
|
||||||
|
import { useTranslation } from 'react-i18next'
|
||||||
|
import AddOAuthButton from './add-oauth-button'
|
||||||
|
import type { AddOAuthButtonProps } from './add-oauth-button'
|
||||||
|
import AddApiKeyButton from './add-api-key-button'
|
||||||
|
import type { AddApiKeyButtonProps } from './add-api-key-button'
|
||||||
|
import type { PluginPayload } from '../types'
|
||||||
|
|
||||||
|
type AuthorizeProps = {
|
||||||
|
pluginPayload: PluginPayload
|
||||||
|
theme?: 'primary' | 'secondary'
|
||||||
|
showDivider?: boolean
|
||||||
|
canOAuth?: boolean
|
||||||
|
canApiKey?: boolean
|
||||||
|
disabled?: boolean
|
||||||
|
onUpdate?: () => void
|
||||||
|
}
|
||||||
|
const Authorize = ({
|
||||||
|
pluginPayload,
|
||||||
|
theme = 'primary',
|
||||||
|
showDivider = true,
|
||||||
|
canOAuth,
|
||||||
|
canApiKey,
|
||||||
|
disabled,
|
||||||
|
onUpdate,
|
||||||
|
}: AuthorizeProps) => {
|
||||||
|
const { t } = useTranslation()
|
||||||
|
const oAuthButtonProps: AddOAuthButtonProps = useMemo(() => {
|
||||||
|
if (theme === 'secondary') {
|
||||||
|
return {
|
||||||
|
buttonText: !canApiKey ? t('plugin.auth.useOAuthAuth') : t('plugin.auth.addOAuth'),
|
||||||
|
buttonVariant: 'secondary',
|
||||||
|
className: 'hover:bg-components-button-secondary-bg',
|
||||||
|
buttonLeftClassName: 'hover:bg-components-button-secondary-bg-hover',
|
||||||
|
buttonRightClassName: 'hover:bg-components-button-secondary-bg-hover',
|
||||||
|
dividerClassName: 'bg-divider-regular opacity-100',
|
||||||
|
pluginPayload,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
buttonText: !canApiKey ? t('plugin.auth.useOAuthAuth') : t('plugin.auth.addOAuth'),
|
||||||
|
pluginPayload,
|
||||||
|
}
|
||||||
|
}, [canApiKey, theme, pluginPayload, t])
|
||||||
|
|
||||||
|
const apiKeyButtonProps: AddApiKeyButtonProps = useMemo(() => {
|
||||||
|
if (theme === 'secondary') {
|
||||||
|
return {
|
||||||
|
pluginPayload,
|
||||||
|
buttonVariant: 'secondary',
|
||||||
|
buttonText: !canOAuth ? t('plugin.auth.useApiAuth') : t('plugin.auth.addApi'),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
pluginPayload,
|
||||||
|
buttonText: !canOAuth ? t('plugin.auth.useApiAuth') : t('plugin.auth.addApi'),
|
||||||
|
buttonVariant: !canOAuth ? 'primary' : 'secondary-accent',
|
||||||
|
}
|
||||||
|
}, [canOAuth, theme, pluginPayload, t])
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<div className='flex items-center space-x-1.5'>
|
||||||
|
{
|
||||||
|
canOAuth && (
|
||||||
|
<div className='min-w-0 flex-[1]'>
|
||||||
|
<AddOAuthButton
|
||||||
|
{...oAuthButtonProps}
|
||||||
|
disabled={disabled}
|
||||||
|
onUpdate={onUpdate}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
showDivider && canOAuth && canApiKey && (
|
||||||
|
<div className='system-2xs-medium-uppercase flex shrink-0 flex-col items-center justify-between text-text-tertiary'>
|
||||||
|
<div className='h-2 w-[1px] bg-divider-subtle'></div>
|
||||||
|
or
|
||||||
|
<div className='h-2 w-[1px] bg-divider-subtle'></div>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
canApiKey && (
|
||||||
|
<div className='min-w-0 flex-[1]'>
|
||||||
|
<AddApiKeyButton
|
||||||
|
{...apiKeyButtonProps}
|
||||||
|
disabled={disabled}
|
||||||
|
onUpdate={onUpdate}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
</div>
|
||||||
|
</>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default memo(Authorize)
|
||||||
@@ -0,0 +1,188 @@
|
|||||||
|
import {
|
||||||
|
memo,
|
||||||
|
useCallback,
|
||||||
|
useRef,
|
||||||
|
useState,
|
||||||
|
} from 'react'
|
||||||
|
import { RiExternalLinkLine } from '@remixicon/react'
|
||||||
|
import {
|
||||||
|
useForm,
|
||||||
|
useStore,
|
||||||
|
} from '@tanstack/react-form'
|
||||||
|
import { useTranslation } from 'react-i18next'
|
||||||
|
import Modal from '@/app/components/base/modal/modal'
|
||||||
|
import {
|
||||||
|
useDeletePluginOAuthCustomClientHook,
|
||||||
|
useInvalidPluginOAuthClientSchemaHook,
|
||||||
|
useSetPluginOAuthCustomClientHook,
|
||||||
|
} from '../hooks/use-credential'
|
||||||
|
import type { PluginPayload } from '../types'
|
||||||
|
import AuthForm from '@/app/components/base/form/form-scenarios/auth'
|
||||||
|
import type {
|
||||||
|
FormRefObject,
|
||||||
|
FormSchema,
|
||||||
|
} from '@/app/components/base/form/types'
|
||||||
|
import { useToastContext } from '@/app/components/base/toast'
|
||||||
|
import Button from '@/app/components/base/button'
|
||||||
|
import { useRenderI18nObject } from '@/hooks/use-i18n'
|
||||||
|
|
||||||
|
type OAuthClientSettingsProps = {
|
||||||
|
pluginPayload: PluginPayload
|
||||||
|
onClose?: () => void
|
||||||
|
editValues?: Record<string, any>
|
||||||
|
disabled?: boolean
|
||||||
|
schemas: FormSchema[]
|
||||||
|
onAuth?: () => Promise<void>
|
||||||
|
hasOriginalClientParams?: boolean
|
||||||
|
onUpdate?: () => void
|
||||||
|
}
|
||||||
|
const OAuthClientSettings = ({
|
||||||
|
pluginPayload,
|
||||||
|
onClose,
|
||||||
|
editValues,
|
||||||
|
disabled,
|
||||||
|
schemas,
|
||||||
|
onAuth,
|
||||||
|
hasOriginalClientParams,
|
||||||
|
onUpdate,
|
||||||
|
}: OAuthClientSettingsProps) => {
|
||||||
|
const { t } = useTranslation()
|
||||||
|
const { notify } = useToastContext()
|
||||||
|
const [doingAction, setDoingAction] = useState(false)
|
||||||
|
const doingActionRef = useRef(doingAction)
|
||||||
|
const handleSetDoingAction = useCallback((value: boolean) => {
|
||||||
|
doingActionRef.current = value
|
||||||
|
setDoingAction(value)
|
||||||
|
}, [])
|
||||||
|
const defaultValues = schemas.reduce((acc, schema) => {
|
||||||
|
if (schema.default)
|
||||||
|
acc[schema.name] = schema.default
|
||||||
|
return acc
|
||||||
|
}, {} as Record<string, any>)
|
||||||
|
const { mutateAsync: setPluginOAuthCustomClient } = useSetPluginOAuthCustomClientHook(pluginPayload)
|
||||||
|
const invalidPluginOAuthClientSchema = useInvalidPluginOAuthClientSchemaHook(pluginPayload)
|
||||||
|
const formRef = useRef<FormRefObject>(null)
|
||||||
|
const handleConfirm = useCallback(async () => {
|
||||||
|
if (doingActionRef.current)
|
||||||
|
return
|
||||||
|
|
||||||
|
try {
|
||||||
|
const {
|
||||||
|
isCheckValidated,
|
||||||
|
values,
|
||||||
|
} = formRef.current?.getFormValues({
|
||||||
|
needCheckValidatedValues: true,
|
||||||
|
needTransformWhenSecretFieldIsPristine: true,
|
||||||
|
}) || { isCheckValidated: false, values: {} }
|
||||||
|
if (!isCheckValidated)
|
||||||
|
throw new Error('error')
|
||||||
|
const {
|
||||||
|
__oauth_client__,
|
||||||
|
...restValues
|
||||||
|
} = values
|
||||||
|
|
||||||
|
handleSetDoingAction(true)
|
||||||
|
await setPluginOAuthCustomClient({
|
||||||
|
client_params: restValues,
|
||||||
|
enable_oauth_custom_client: __oauth_client__ === 'custom',
|
||||||
|
})
|
||||||
|
notify({
|
||||||
|
type: 'success',
|
||||||
|
message: t('common.api.actionSuccess'),
|
||||||
|
})
|
||||||
|
|
||||||
|
onClose?.()
|
||||||
|
onUpdate?.()
|
||||||
|
invalidPluginOAuthClientSchema()
|
||||||
|
}
|
||||||
|
finally {
|
||||||
|
handleSetDoingAction(false)
|
||||||
|
}
|
||||||
|
}, [onClose, onUpdate, invalidPluginOAuthClientSchema, setPluginOAuthCustomClient, notify, t, handleSetDoingAction])
|
||||||
|
|
||||||
|
const handleConfirmAndAuthorize = useCallback(async () => {
|
||||||
|
await handleConfirm()
|
||||||
|
if (onAuth)
|
||||||
|
await onAuth()
|
||||||
|
}, [handleConfirm, onAuth])
|
||||||
|
const { mutateAsync: deletePluginOAuthCustomClient } = useDeletePluginOAuthCustomClientHook(pluginPayload)
|
||||||
|
const handleRemove = useCallback(async () => {
|
||||||
|
if (doingActionRef.current)
|
||||||
|
return
|
||||||
|
|
||||||
|
try {
|
||||||
|
handleSetDoingAction(true)
|
||||||
|
await deletePluginOAuthCustomClient()
|
||||||
|
notify({
|
||||||
|
type: 'success',
|
||||||
|
message: t('common.api.actionSuccess'),
|
||||||
|
})
|
||||||
|
onClose?.()
|
||||||
|
onUpdate?.()
|
||||||
|
invalidPluginOAuthClientSchema()
|
||||||
|
}
|
||||||
|
finally {
|
||||||
|
handleSetDoingAction(false)
|
||||||
|
}
|
||||||
|
}, [onUpdate, invalidPluginOAuthClientSchema, deletePluginOAuthCustomClient, notify, t, handleSetDoingAction, onClose])
|
||||||
|
const form = useForm({
|
||||||
|
defaultValues: editValues || defaultValues,
|
||||||
|
})
|
||||||
|
const __oauth_client__ = useStore(form.store, s => s.values.__oauth_client__)
|
||||||
|
const helpField = schemas.find(schema => schema.url && schema.help)
|
||||||
|
const renderI18nObject = useRenderI18nObject()
|
||||||
|
return (
|
||||||
|
<Modal
|
||||||
|
title={t('plugin.auth.oauthClientSettings')}
|
||||||
|
confirmButtonText={t('plugin.auth.saveAndAuth')}
|
||||||
|
cancelButtonText={t('plugin.auth.saveOnly')}
|
||||||
|
extraButtonText={t('common.operation.cancel')}
|
||||||
|
showExtraButton
|
||||||
|
extraButtonVariant='secondary'
|
||||||
|
onExtraButtonClick={onClose}
|
||||||
|
onClose={onClose}
|
||||||
|
onCancel={handleConfirm}
|
||||||
|
onConfirm={handleConfirmAndAuthorize}
|
||||||
|
disabled={disabled || doingAction}
|
||||||
|
footerSlot={
|
||||||
|
__oauth_client__ === 'custom' && hasOriginalClientParams && (
|
||||||
|
<div className='grow'>
|
||||||
|
<Button
|
||||||
|
variant='secondary'
|
||||||
|
className='text-components-button-destructive-secondary-text'
|
||||||
|
disabled={disabled || doingAction || !editValues}
|
||||||
|
onClick={handleRemove}
|
||||||
|
>
|
||||||
|
{t('common.operation.remove')}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
>
|
||||||
|
<>
|
||||||
|
<AuthForm
|
||||||
|
formFromProps={form}
|
||||||
|
ref={formRef}
|
||||||
|
formSchemas={schemas}
|
||||||
|
defaultValues={editValues || defaultValues}
|
||||||
|
disabled={disabled}
|
||||||
|
/>
|
||||||
|
{
|
||||||
|
helpField && __oauth_client__ === 'custom' && (
|
||||||
|
<a
|
||||||
|
className='system-xs-regular mt-4 flex items-center text-text-accent'
|
||||||
|
href={helpField?.url}
|
||||||
|
target='_blank'
|
||||||
|
>
|
||||||
|
<span className='break-all'>
|
||||||
|
{renderI18nObject(helpField?.help as any)}
|
||||||
|
</span>
|
||||||
|
<RiExternalLinkLine className='ml-1 h-3 w-3' />
|
||||||
|
</a>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
</Modal>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default memo(OAuthClientSettings)
|
||||||
113
web/app/components/plugins/plugin-auth/authorized-in-node.tsx
Normal file
113
web/app/components/plugins/plugin-auth/authorized-in-node.tsx
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
import {
|
||||||
|
memo,
|
||||||
|
useCallback,
|
||||||
|
useState,
|
||||||
|
} from 'react'
|
||||||
|
import { useTranslation } from 'react-i18next'
|
||||||
|
import { RiArrowDownSLine } from '@remixicon/react'
|
||||||
|
import Button from '@/app/components/base/button'
|
||||||
|
import Indicator from '@/app/components/header/indicator'
|
||||||
|
import cn from '@/utils/classnames'
|
||||||
|
import type {
|
||||||
|
Credential,
|
||||||
|
PluginPayload,
|
||||||
|
} from './types'
|
||||||
|
import {
|
||||||
|
Authorized,
|
||||||
|
usePluginAuth,
|
||||||
|
} from '.'
|
||||||
|
|
||||||
|
type AuthorizedInNodeProps = {
|
||||||
|
pluginPayload: PluginPayload
|
||||||
|
onAuthorizationItemClick: (id: string) => void
|
||||||
|
credentialId?: string
|
||||||
|
}
|
||||||
|
const AuthorizedInNode = ({
|
||||||
|
pluginPayload,
|
||||||
|
onAuthorizationItemClick,
|
||||||
|
credentialId,
|
||||||
|
}: AuthorizedInNodeProps) => {
|
||||||
|
const { t } = useTranslation()
|
||||||
|
const [isOpen, setIsOpen] = useState(false)
|
||||||
|
const {
|
||||||
|
canApiKey,
|
||||||
|
canOAuth,
|
||||||
|
credentials,
|
||||||
|
disabled,
|
||||||
|
invalidPluginCredentialInfo,
|
||||||
|
} = usePluginAuth(pluginPayload, isOpen || !!credentialId)
|
||||||
|
const renderTrigger = useCallback((open?: boolean) => {
|
||||||
|
let label = ''
|
||||||
|
let removed = false
|
||||||
|
if (!credentialId) {
|
||||||
|
label = t('plugin.auth.workspaceDefault')
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
const credential = credentials.find(c => c.id === credentialId)
|
||||||
|
label = credential ? credential.name : t('plugin.auth.authRemoved')
|
||||||
|
removed = !credential
|
||||||
|
}
|
||||||
|
return (
|
||||||
|
<Button
|
||||||
|
size='small'
|
||||||
|
className={cn(
|
||||||
|
open && !removed && 'bg-components-button-ghost-bg-hover',
|
||||||
|
removed && 'bg-transparent text-text-destructive',
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
<Indicator
|
||||||
|
className='mr-1.5'
|
||||||
|
color={removed ? 'red' : 'green'}
|
||||||
|
/>
|
||||||
|
{label}
|
||||||
|
<RiArrowDownSLine
|
||||||
|
className={cn(
|
||||||
|
'h-3.5 w-3.5 text-components-button-ghost-text',
|
||||||
|
removed && 'text-text-destructive',
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
</Button>
|
||||||
|
)
|
||||||
|
}, [credentialId, credentials, t])
|
||||||
|
const extraAuthorizationItems: Credential[] = [
|
||||||
|
{
|
||||||
|
id: '__workspace_default__',
|
||||||
|
name: t('plugin.auth.workspaceDefault'),
|
||||||
|
provider: '',
|
||||||
|
is_default: !credentialId,
|
||||||
|
isWorkspaceDefault: true,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
const handleAuthorizationItemClick = useCallback((id: string) => {
|
||||||
|
onAuthorizationItemClick(id)
|
||||||
|
setIsOpen(false)
|
||||||
|
}, [
|
||||||
|
onAuthorizationItemClick,
|
||||||
|
setIsOpen,
|
||||||
|
])
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Authorized
|
||||||
|
pluginPayload={pluginPayload}
|
||||||
|
credentials={credentials}
|
||||||
|
canOAuth={canOAuth}
|
||||||
|
canApiKey={canApiKey}
|
||||||
|
renderTrigger={renderTrigger}
|
||||||
|
isOpen={isOpen}
|
||||||
|
onOpenChange={setIsOpen}
|
||||||
|
offset={4}
|
||||||
|
placement='bottom-end'
|
||||||
|
triggerPopupSameWidth={false}
|
||||||
|
popupClassName='w-[360px]'
|
||||||
|
disabled={disabled}
|
||||||
|
disableSetDefault
|
||||||
|
onItemClick={handleAuthorizationItemClick}
|
||||||
|
extraAuthorizationItems={extraAuthorizationItems}
|
||||||
|
showItemSelectedIcon
|
||||||
|
selectedCredentialId={credentialId || '__workspace_default__'}
|
||||||
|
onUpdate={invalidPluginCredentialInfo}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default memo(AuthorizedInNode)
|
||||||
342
web/app/components/plugins/plugin-auth/authorized/index.tsx
Normal file
342
web/app/components/plugins/plugin-auth/authorized/index.tsx
Normal file
@@ -0,0 +1,342 @@
|
|||||||
|
import {
|
||||||
|
memo,
|
||||||
|
useCallback,
|
||||||
|
useRef,
|
||||||
|
useState,
|
||||||
|
} from 'react'
|
||||||
|
import {
|
||||||
|
RiArrowDownSLine,
|
||||||
|
} from '@remixicon/react'
|
||||||
|
import { useTranslation } from 'react-i18next'
|
||||||
|
import {
|
||||||
|
PortalToFollowElem,
|
||||||
|
PortalToFollowElemContent,
|
||||||
|
PortalToFollowElemTrigger,
|
||||||
|
} from '@/app/components/base/portal-to-follow-elem'
|
||||||
|
import type {
|
||||||
|
PortalToFollowElemOptions,
|
||||||
|
} from '@/app/components/base/portal-to-follow-elem'
|
||||||
|
import Button from '@/app/components/base/button'
|
||||||
|
import Indicator from '@/app/components/header/indicator'
|
||||||
|
import cn from '@/utils/classnames'
|
||||||
|
import Confirm from '@/app/components/base/confirm'
|
||||||
|
import Authorize from '../authorize'
|
||||||
|
import type { Credential } from '../types'
|
||||||
|
import { CredentialTypeEnum } from '../types'
|
||||||
|
import ApiKeyModal from '../authorize/api-key-modal'
|
||||||
|
import Item from './item'
|
||||||
|
import { useToastContext } from '@/app/components/base/toast'
|
||||||
|
import type { PluginPayload } from '../types'
|
||||||
|
import {
|
||||||
|
useDeletePluginCredentialHook,
|
||||||
|
useSetPluginDefaultCredentialHook,
|
||||||
|
useUpdatePluginCredentialHook,
|
||||||
|
} from '../hooks/use-credential'
|
||||||
|
|
||||||
|
type AuthorizedProps = {
|
||||||
|
pluginPayload: PluginPayload
|
||||||
|
credentials: Credential[]
|
||||||
|
canOAuth?: boolean
|
||||||
|
canApiKey?: boolean
|
||||||
|
disabled?: boolean
|
||||||
|
renderTrigger?: (open?: boolean) => React.ReactNode
|
||||||
|
isOpen?: boolean
|
||||||
|
onOpenChange?: (open: boolean) => void
|
||||||
|
offset?: PortalToFollowElemOptions['offset']
|
||||||
|
placement?: PortalToFollowElemOptions['placement']
|
||||||
|
triggerPopupSameWidth?: boolean
|
||||||
|
popupClassName?: string
|
||||||
|
disableSetDefault?: boolean
|
||||||
|
onItemClick?: (id: string) => void
|
||||||
|
extraAuthorizationItems?: Credential[]
|
||||||
|
showItemSelectedIcon?: boolean
|
||||||
|
selectedCredentialId?: string
|
||||||
|
onUpdate?: () => void
|
||||||
|
}
|
||||||
|
const Authorized = ({
|
||||||
|
pluginPayload,
|
||||||
|
credentials,
|
||||||
|
canOAuth,
|
||||||
|
canApiKey,
|
||||||
|
disabled,
|
||||||
|
renderTrigger,
|
||||||
|
isOpen,
|
||||||
|
onOpenChange,
|
||||||
|
offset = 8,
|
||||||
|
placement = 'bottom-start',
|
||||||
|
triggerPopupSameWidth = true,
|
||||||
|
popupClassName,
|
||||||
|
disableSetDefault,
|
||||||
|
onItemClick,
|
||||||
|
extraAuthorizationItems,
|
||||||
|
showItemSelectedIcon,
|
||||||
|
selectedCredentialId,
|
||||||
|
onUpdate,
|
||||||
|
}: AuthorizedProps) => {
|
||||||
|
const { t } = useTranslation()
|
||||||
|
const { notify } = useToastContext()
|
||||||
|
const [isLocalOpen, setIsLocalOpen] = useState(false)
|
||||||
|
const mergedIsOpen = isOpen ?? isLocalOpen
|
||||||
|
const setMergedIsOpen = useCallback((open: boolean) => {
|
||||||
|
if (onOpenChange)
|
||||||
|
onOpenChange(open)
|
||||||
|
|
||||||
|
setIsLocalOpen(open)
|
||||||
|
}, [onOpenChange])
|
||||||
|
const oAuthCredentials = credentials.filter(credential => credential.credential_type === CredentialTypeEnum.OAUTH2)
|
||||||
|
const apiKeyCredentials = credentials.filter(credential => credential.credential_type === CredentialTypeEnum.API_KEY)
|
||||||
|
const pendingOperationCredentialId = useRef<string | null>(null)
|
||||||
|
const [deleteCredentialId, setDeleteCredentialId] = useState<string | null>(null)
|
||||||
|
const { mutateAsync: deletePluginCredential } = useDeletePluginCredentialHook(pluginPayload)
|
||||||
|
const openConfirm = useCallback((credentialId?: string) => {
|
||||||
|
if (credentialId)
|
||||||
|
pendingOperationCredentialId.current = credentialId
|
||||||
|
|
||||||
|
setDeleteCredentialId(pendingOperationCredentialId.current)
|
||||||
|
}, [])
|
||||||
|
const closeConfirm = useCallback(() => {
|
||||||
|
setDeleteCredentialId(null)
|
||||||
|
pendingOperationCredentialId.current = null
|
||||||
|
}, [])
|
||||||
|
const [doingAction, setDoingAction] = useState(false)
|
||||||
|
const doingActionRef = useRef(doingAction)
|
||||||
|
const handleSetDoingAction = useCallback((doing: boolean) => {
|
||||||
|
doingActionRef.current = doing
|
||||||
|
setDoingAction(doing)
|
||||||
|
}, [])
|
||||||
|
const handleConfirm = useCallback(async () => {
|
||||||
|
if (doingActionRef.current)
|
||||||
|
return
|
||||||
|
if (!pendingOperationCredentialId.current) {
|
||||||
|
setDeleteCredentialId(null)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
handleSetDoingAction(true)
|
||||||
|
await deletePluginCredential({ credential_id: pendingOperationCredentialId.current })
|
||||||
|
notify({
|
||||||
|
type: 'success',
|
||||||
|
message: t('common.api.actionSuccess'),
|
||||||
|
})
|
||||||
|
onUpdate?.()
|
||||||
|
setDeleteCredentialId(null)
|
||||||
|
pendingOperationCredentialId.current = null
|
||||||
|
}
|
||||||
|
finally {
|
||||||
|
handleSetDoingAction(false)
|
||||||
|
}
|
||||||
|
}, [deletePluginCredential, onUpdate, notify, t, handleSetDoingAction])
|
||||||
|
const [editValues, setEditValues] = useState<Record<string, any> | null>(null)
|
||||||
|
const handleEdit = useCallback((id: string, values: Record<string, any>) => {
|
||||||
|
pendingOperationCredentialId.current = id
|
||||||
|
setEditValues(values)
|
||||||
|
}, [])
|
||||||
|
const handleRemove = useCallback(() => {
|
||||||
|
setDeleteCredentialId(pendingOperationCredentialId.current)
|
||||||
|
}, [])
|
||||||
|
const { mutateAsync: setPluginDefaultCredential } = useSetPluginDefaultCredentialHook(pluginPayload)
|
||||||
|
const handleSetDefault = useCallback(async (id: string) => {
|
||||||
|
if (doingActionRef.current)
|
||||||
|
return
|
||||||
|
try {
|
||||||
|
handleSetDoingAction(true)
|
||||||
|
await setPluginDefaultCredential(id)
|
||||||
|
notify({
|
||||||
|
type: 'success',
|
||||||
|
message: t('common.api.actionSuccess'),
|
||||||
|
})
|
||||||
|
onUpdate?.()
|
||||||
|
}
|
||||||
|
finally {
|
||||||
|
handleSetDoingAction(false)
|
||||||
|
}
|
||||||
|
}, [setPluginDefaultCredential, onUpdate, notify, t, handleSetDoingAction])
|
||||||
|
const { mutateAsync: updatePluginCredential } = useUpdatePluginCredentialHook(pluginPayload)
|
||||||
|
const handleRename = useCallback(async (payload: {
|
||||||
|
credential_id: string
|
||||||
|
name: string
|
||||||
|
}) => {
|
||||||
|
if (doingActionRef.current)
|
||||||
|
return
|
||||||
|
try {
|
||||||
|
handleSetDoingAction(true)
|
||||||
|
await updatePluginCredential(payload)
|
||||||
|
notify({
|
||||||
|
type: 'success',
|
||||||
|
message: t('common.api.actionSuccess'),
|
||||||
|
})
|
||||||
|
onUpdate?.()
|
||||||
|
}
|
||||||
|
finally {
|
||||||
|
handleSetDoingAction(false)
|
||||||
|
}
|
||||||
|
}, [updatePluginCredential, notify, t, handleSetDoingAction, onUpdate])
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<PortalToFollowElem
|
||||||
|
open={mergedIsOpen}
|
||||||
|
onOpenChange={setMergedIsOpen}
|
||||||
|
placement={placement}
|
||||||
|
offset={offset}
|
||||||
|
triggerPopupSameWidth={triggerPopupSameWidth}
|
||||||
|
>
|
||||||
|
<PortalToFollowElemTrigger
|
||||||
|
onClick={() => setMergedIsOpen(!mergedIsOpen)}
|
||||||
|
asChild
|
||||||
|
>
|
||||||
|
{
|
||||||
|
renderTrigger
|
||||||
|
? renderTrigger(mergedIsOpen)
|
||||||
|
: (
|
||||||
|
<Button
|
||||||
|
className={cn(
|
||||||
|
'w-full',
|
||||||
|
isOpen && 'bg-components-button-secondary-bg-hover',
|
||||||
|
)}>
|
||||||
|
<Indicator className='mr-2' />
|
||||||
|
{credentials.length}
|
||||||
|
{
|
||||||
|
credentials.length > 1
|
||||||
|
? t('plugin.auth.authorizations')
|
||||||
|
: t('plugin.auth.authorization')
|
||||||
|
}
|
||||||
|
<RiArrowDownSLine className='ml-0.5 h-4 w-4' />
|
||||||
|
</Button>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
</PortalToFollowElemTrigger>
|
||||||
|
<PortalToFollowElemContent className='z-[100]'>
|
||||||
|
<div className={cn(
|
||||||
|
'max-h-[360px] overflow-y-auto rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur shadow-lg',
|
||||||
|
popupClassName,
|
||||||
|
)}>
|
||||||
|
<div className='py-1'>
|
||||||
|
{
|
||||||
|
!!extraAuthorizationItems?.length && (
|
||||||
|
<div className='p-1'>
|
||||||
|
{
|
||||||
|
extraAuthorizationItems.map(credential => (
|
||||||
|
<Item
|
||||||
|
key={credential.id}
|
||||||
|
credential={credential}
|
||||||
|
disabled={disabled}
|
||||||
|
onItemClick={onItemClick}
|
||||||
|
disableRename
|
||||||
|
disableEdit
|
||||||
|
disableDelete
|
||||||
|
disableSetDefault
|
||||||
|
showSelectedIcon={showItemSelectedIcon}
|
||||||
|
selectedCredentialId={selectedCredentialId}
|
||||||
|
/>
|
||||||
|
))
|
||||||
|
}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
!!oAuthCredentials.length && (
|
||||||
|
<div className='p-1'>
|
||||||
|
<div className={cn(
|
||||||
|
'system-xs-medium px-3 pb-0.5 pt-1 text-text-tertiary',
|
||||||
|
showItemSelectedIcon && 'pl-7',
|
||||||
|
)}>
|
||||||
|
OAuth
|
||||||
|
</div>
|
||||||
|
{
|
||||||
|
oAuthCredentials.map(credential => (
|
||||||
|
<Item
|
||||||
|
key={credential.id}
|
||||||
|
credential={credential}
|
||||||
|
disabled={disabled}
|
||||||
|
disableEdit
|
||||||
|
onDelete={openConfirm}
|
||||||
|
onSetDefault={handleSetDefault}
|
||||||
|
onRename={handleRename}
|
||||||
|
disableSetDefault={disableSetDefault}
|
||||||
|
onItemClick={onItemClick}
|
||||||
|
showSelectedIcon={showItemSelectedIcon}
|
||||||
|
selectedCredentialId={selectedCredentialId}
|
||||||
|
/>
|
||||||
|
))
|
||||||
|
}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
!!apiKeyCredentials.length && (
|
||||||
|
<div className='p-1'>
|
||||||
|
<div className={cn(
|
||||||
|
'system-xs-medium px-3 pb-0.5 pt-1 text-text-tertiary',
|
||||||
|
showItemSelectedIcon && 'pl-7',
|
||||||
|
)}>
|
||||||
|
API Keys
|
||||||
|
</div>
|
||||||
|
{
|
||||||
|
apiKeyCredentials.map(credential => (
|
||||||
|
<Item
|
||||||
|
key={credential.id}
|
||||||
|
credential={credential}
|
||||||
|
disabled={disabled}
|
||||||
|
onDelete={openConfirm}
|
||||||
|
onEdit={handleEdit}
|
||||||
|
onSetDefault={handleSetDefault}
|
||||||
|
disableSetDefault={disableSetDefault}
|
||||||
|
disableRename
|
||||||
|
onItemClick={onItemClick}
|
||||||
|
onRename={handleRename}
|
||||||
|
showSelectedIcon={showItemSelectedIcon}
|
||||||
|
selectedCredentialId={selectedCredentialId}
|
||||||
|
/>
|
||||||
|
))
|
||||||
|
}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
</div>
|
||||||
|
<div className='h-[1px] bg-divider-subtle'></div>
|
||||||
|
<div className='p-2'>
|
||||||
|
<Authorize
|
||||||
|
pluginPayload={pluginPayload}
|
||||||
|
theme='secondary'
|
||||||
|
showDivider={false}
|
||||||
|
canOAuth={canOAuth}
|
||||||
|
canApiKey={canApiKey}
|
||||||
|
disabled={disabled}
|
||||||
|
onUpdate={onUpdate}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</PortalToFollowElemContent>
|
||||||
|
</PortalToFollowElem>
|
||||||
|
{
|
||||||
|
deleteCredentialId && (
|
||||||
|
<Confirm
|
||||||
|
isShow
|
||||||
|
title={t('datasetDocuments.list.delete.title')}
|
||||||
|
isDisabled={doingAction}
|
||||||
|
onCancel={closeConfirm}
|
||||||
|
onConfirm={handleConfirm}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
!!editValues && (
|
||||||
|
<ApiKeyModal
|
||||||
|
pluginPayload={pluginPayload}
|
||||||
|
editValues={editValues}
|
||||||
|
onClose={() => {
|
||||||
|
setEditValues(null)
|
||||||
|
pendingOperationCredentialId.current = null
|
||||||
|
}}
|
||||||
|
onRemove={handleRemove}
|
||||||
|
disabled={disabled || doingAction}
|
||||||
|
onUpdate={onUpdate}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
</>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default memo(Authorized)
|
||||||
219
web/app/components/plugins/plugin-auth/authorized/item.tsx
Normal file
219
web/app/components/plugins/plugin-auth/authorized/item.tsx
Normal file
@@ -0,0 +1,219 @@
|
|||||||
|
import {
|
||||||
|
memo,
|
||||||
|
useMemo,
|
||||||
|
useState,
|
||||||
|
} from 'react'
|
||||||
|
import { useTranslation } from 'react-i18next'
|
||||||
|
import {
|
||||||
|
RiCheckLine,
|
||||||
|
RiDeleteBinLine,
|
||||||
|
RiEditLine,
|
||||||
|
RiEqualizer2Line,
|
||||||
|
} from '@remixicon/react'
|
||||||
|
import Indicator from '@/app/components/header/indicator'
|
||||||
|
import Badge from '@/app/components/base/badge'
|
||||||
|
import ActionButton from '@/app/components/base/action-button'
|
||||||
|
import Tooltip from '@/app/components/base/tooltip'
|
||||||
|
import Button from '@/app/components/base/button'
|
||||||
|
import Input from '@/app/components/base/input'
|
||||||
|
import cn from '@/utils/classnames'
|
||||||
|
import type { Credential } from '../types'
|
||||||
|
import { CredentialTypeEnum } from '../types'
|
||||||
|
|
||||||
|
type ItemProps = {
|
||||||
|
credential: Credential
|
||||||
|
disabled?: boolean
|
||||||
|
onDelete?: (id: string) => void
|
||||||
|
onEdit?: (id: string, values: Record<string, any>) => void
|
||||||
|
onSetDefault?: (id: string) => void
|
||||||
|
onRename?: (payload: {
|
||||||
|
credential_id: string
|
||||||
|
name: string
|
||||||
|
}) => void
|
||||||
|
disableRename?: boolean
|
||||||
|
disableEdit?: boolean
|
||||||
|
disableDelete?: boolean
|
||||||
|
disableSetDefault?: boolean
|
||||||
|
onItemClick?: (id: string) => void
|
||||||
|
showSelectedIcon?: boolean
|
||||||
|
selectedCredentialId?: string
|
||||||
|
}
|
||||||
|
const Item = ({
|
||||||
|
credential,
|
||||||
|
disabled,
|
||||||
|
onDelete,
|
||||||
|
onEdit,
|
||||||
|
onSetDefault,
|
||||||
|
onRename,
|
||||||
|
disableRename,
|
||||||
|
disableEdit,
|
||||||
|
disableDelete,
|
||||||
|
disableSetDefault,
|
||||||
|
onItemClick,
|
||||||
|
showSelectedIcon,
|
||||||
|
selectedCredentialId,
|
||||||
|
}: ItemProps) => {
|
||||||
|
const { t } = useTranslation()
|
||||||
|
const [renaming, setRenaming] = useState(false)
|
||||||
|
const [renameValue, setRenameValue] = useState(credential.name)
|
||||||
|
const isOAuth = credential.credential_type === CredentialTypeEnum.OAUTH2
|
||||||
|
const showAction = useMemo(() => {
|
||||||
|
return !(disableRename && disableEdit && disableDelete && disableSetDefault)
|
||||||
|
}, [disableRename, disableEdit, disableDelete, disableSetDefault])
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
key={credential.id}
|
||||||
|
className={cn(
|
||||||
|
'group flex h-8 items-center rounded-lg p-1 hover:bg-state-base-hover',
|
||||||
|
renaming && 'bg-state-base-hover',
|
||||||
|
)}
|
||||||
|
onClick={() => onItemClick?.(credential.id === '__workspace_default__' ? '' : credential.id)}
|
||||||
|
>
|
||||||
|
{
|
||||||
|
renaming && (
|
||||||
|
<div className='flex w-full items-center space-x-1'>
|
||||||
|
<Input
|
||||||
|
wrapperClassName='grow rounded-[6px]'
|
||||||
|
className='h-6'
|
||||||
|
value={renameValue}
|
||||||
|
onChange={e => setRenameValue(e.target.value)}
|
||||||
|
placeholder={t('common.placeholder.input')}
|
||||||
|
onClick={e => e.stopPropagation()}
|
||||||
|
/>
|
||||||
|
<Button
|
||||||
|
size='small'
|
||||||
|
variant='primary'
|
||||||
|
onClick={(e) => {
|
||||||
|
e.stopPropagation()
|
||||||
|
onRename?.({
|
||||||
|
credential_id: credential.id,
|
||||||
|
name: renameValue,
|
||||||
|
})
|
||||||
|
setRenaming(false)
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{t('common.operation.save')}
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
size='small'
|
||||||
|
onClick={(e) => {
|
||||||
|
e.stopPropagation()
|
||||||
|
setRenaming(false)
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{t('common.operation.cancel')}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
!renaming && (
|
||||||
|
<div className='flex w-0 grow items-center space-x-1.5'>
|
||||||
|
{
|
||||||
|
showSelectedIcon && (
|
||||||
|
<div className='h-4 w-4'>
|
||||||
|
{
|
||||||
|
selectedCredentialId === credential.id && (
|
||||||
|
<RiCheckLine className='h-4 w-4 text-text-accent' />
|
||||||
|
)
|
||||||
|
}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
<Indicator className='ml-2 mr-1.5 shrink-0' />
|
||||||
|
<div
|
||||||
|
className='system-md-regular truncate text-text-secondary'
|
||||||
|
title={credential.name}
|
||||||
|
>
|
||||||
|
{credential.name}
|
||||||
|
</div>
|
||||||
|
{
|
||||||
|
credential.is_default && (
|
||||||
|
<Badge className='shrink-0'>
|
||||||
|
{t('plugin.auth.default')}
|
||||||
|
</Badge>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
showAction && !renaming && (
|
||||||
|
<div className='ml-2 hidden shrink-0 items-center group-hover:flex'>
|
||||||
|
{
|
||||||
|
!credential.is_default && !disableSetDefault && (
|
||||||
|
<Button
|
||||||
|
size='small'
|
||||||
|
disabled={disabled}
|
||||||
|
onClick={(e) => {
|
||||||
|
e.stopPropagation()
|
||||||
|
onSetDefault?.(credential.id)
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{t('plugin.auth.setDefault')}
|
||||||
|
</Button>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
!disableRename && (
|
||||||
|
<Tooltip popupContent={t('common.operation.rename')}>
|
||||||
|
<ActionButton
|
||||||
|
disabled={disabled}
|
||||||
|
onClick={(e) => {
|
||||||
|
e.stopPropagation()
|
||||||
|
setRenaming(true)
|
||||||
|
setRenameValue(credential.name)
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<RiEditLine className='h-4 w-4 text-text-tertiary' />
|
||||||
|
</ActionButton>
|
||||||
|
</Tooltip>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
!isOAuth && !disableEdit && (
|
||||||
|
<Tooltip popupContent={t('common.operation.edit')}>
|
||||||
|
<ActionButton
|
||||||
|
disabled={disabled}
|
||||||
|
onClick={(e) => {
|
||||||
|
e.stopPropagation()
|
||||||
|
onEdit?.(
|
||||||
|
credential.id,
|
||||||
|
{
|
||||||
|
...credential.credentials,
|
||||||
|
__name__: credential.name,
|
||||||
|
__credential_id__: credential.id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<RiEqualizer2Line className='h-4 w-4 text-text-tertiary' />
|
||||||
|
</ActionButton>
|
||||||
|
</Tooltip>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
!disableDelete && (
|
||||||
|
<Tooltip popupContent={t('common.operation.delete')}>
|
||||||
|
<ActionButton
|
||||||
|
className='hover:bg-transparent'
|
||||||
|
disabled={disabled}
|
||||||
|
onClick={(e) => {
|
||||||
|
e.stopPropagation()
|
||||||
|
onDelete?.(credential.id)
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<RiDeleteBinLine className='h-4 w-4 text-text-tertiary hover:text-text-destructive' />
|
||||||
|
</ActionButton>
|
||||||
|
</Tooltip>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default memo(Item)
|
||||||
@@ -0,0 +1,88 @@
|
|||||||
|
import {
|
||||||
|
useAddPluginCredential,
|
||||||
|
useDeletePluginCredential,
|
||||||
|
useDeletePluginOAuthCustomClient,
|
||||||
|
useGetPluginCredentialInfo,
|
||||||
|
useGetPluginCredentialSchema,
|
||||||
|
useGetPluginOAuthClientSchema,
|
||||||
|
useGetPluginOAuthUrl,
|
||||||
|
useInvalidPluginCredentialInfo,
|
||||||
|
useInvalidPluginOAuthClientSchema,
|
||||||
|
useSetPluginDefaultCredential,
|
||||||
|
useSetPluginOAuthCustomClient,
|
||||||
|
useUpdatePluginCredential,
|
||||||
|
} from '@/service/use-plugins-auth'
|
||||||
|
import { useGetApi } from './use-get-api'
|
||||||
|
import type { PluginPayload } from '../types'
|
||||||
|
import type { CredentialTypeEnum } from '../types'
|
||||||
|
|
||||||
|
export const useGetPluginCredentialInfoHook = (pluginPayload: PluginPayload, enable?: boolean) => {
|
||||||
|
const apiMap = useGetApi(pluginPayload)
|
||||||
|
return useGetPluginCredentialInfo(enable ? apiMap.getCredentialInfo : '')
|
||||||
|
}
|
||||||
|
|
||||||
|
export const useDeletePluginCredentialHook = (pluginPayload: PluginPayload) => {
|
||||||
|
const apiMap = useGetApi(pluginPayload)
|
||||||
|
|
||||||
|
return useDeletePluginCredential(apiMap.deleteCredential)
|
||||||
|
}
|
||||||
|
|
||||||
|
export const useInvalidPluginCredentialInfoHook = (pluginPayload: PluginPayload) => {
|
||||||
|
const apiMap = useGetApi(pluginPayload)
|
||||||
|
|
||||||
|
return useInvalidPluginCredentialInfo(apiMap.getCredentialInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
export const useSetPluginDefaultCredentialHook = (pluginPayload: PluginPayload) => {
|
||||||
|
const apiMap = useGetApi(pluginPayload)
|
||||||
|
|
||||||
|
return useSetPluginDefaultCredential(apiMap.setDefaultCredential)
|
||||||
|
}
|
||||||
|
|
||||||
|
export const useGetPluginCredentialSchemaHook = (pluginPayload: PluginPayload, credentialType: CredentialTypeEnum) => {
|
||||||
|
const apiMap = useGetApi(pluginPayload)
|
||||||
|
|
||||||
|
return useGetPluginCredentialSchema(apiMap.getCredentialSchema(credentialType))
|
||||||
|
}
|
||||||
|
|
||||||
|
export const useAddPluginCredentialHook = (pluginPayload: PluginPayload) => {
|
||||||
|
const apiMap = useGetApi(pluginPayload)
|
||||||
|
|
||||||
|
return useAddPluginCredential(apiMap.addCredential)
|
||||||
|
}
|
||||||
|
|
||||||
|
export const useUpdatePluginCredentialHook = (pluginPayload: PluginPayload) => {
|
||||||
|
const apiMap = useGetApi(pluginPayload)
|
||||||
|
|
||||||
|
return useUpdatePluginCredential(apiMap.updateCredential)
|
||||||
|
}
|
||||||
|
|
||||||
|
export const useGetPluginOAuthUrlHook = (pluginPayload: PluginPayload) => {
|
||||||
|
const apiMap = useGetApi(pluginPayload)
|
||||||
|
|
||||||
|
return useGetPluginOAuthUrl(apiMap.getOauthUrl)
|
||||||
|
}
|
||||||
|
|
||||||
|
export const useGetPluginOAuthClientSchemaHook = (pluginPayload: PluginPayload) => {
|
||||||
|
const apiMap = useGetApi(pluginPayload)
|
||||||
|
|
||||||
|
return useGetPluginOAuthClientSchema(apiMap.getOauthClientSchema)
|
||||||
|
}
|
||||||
|
|
||||||
|
export const useInvalidPluginOAuthClientSchemaHook = (pluginPayload: PluginPayload) => {
|
||||||
|
const apiMap = useGetApi(pluginPayload)
|
||||||
|
|
||||||
|
return useInvalidPluginOAuthClientSchema(apiMap.getOauthClientSchema)
|
||||||
|
}
|
||||||
|
|
||||||
|
export const useSetPluginOAuthCustomClientHook = (pluginPayload: PluginPayload) => {
|
||||||
|
const apiMap = useGetApi(pluginPayload)
|
||||||
|
|
||||||
|
return useSetPluginOAuthCustomClient(apiMap.setCustomOauthClient)
|
||||||
|
}
|
||||||
|
|
||||||
|
export const useDeletePluginOAuthCustomClientHook = (pluginPayload: PluginPayload) => {
|
||||||
|
const apiMap = useGetApi(pluginPayload)
|
||||||
|
|
||||||
|
return useDeletePluginOAuthCustomClient(apiMap.deleteCustomOAuthClient)
|
||||||
|
}
|
||||||
41
web/app/components/plugins/plugin-auth/hooks/use-get-api.ts
Normal file
41
web/app/components/plugins/plugin-auth/hooks/use-get-api.ts
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
import {
|
||||||
|
AuthCategory,
|
||||||
|
} from '../types'
|
||||||
|
import type {
|
||||||
|
CredentialTypeEnum,
|
||||||
|
PluginPayload,
|
||||||
|
} from '../types'
|
||||||
|
|
||||||
|
export const useGetApi = ({ category = AuthCategory.tool, provider }: PluginPayload) => {
|
||||||
|
if (category === AuthCategory.tool) {
|
||||||
|
return {
|
||||||
|
getCredentialInfo: `/workspaces/current/tool-provider/builtin/${provider}/credential/info`,
|
||||||
|
setDefaultCredential: `/workspaces/current/tool-provider/builtin/${provider}/default-credential`,
|
||||||
|
getCredentials: `/workspaces/current/tool-provider/builtin/${provider}/credentials`,
|
||||||
|
addCredential: `/workspaces/current/tool-provider/builtin/${provider}/add`,
|
||||||
|
updateCredential: `/workspaces/current/tool-provider/builtin/${provider}/update`,
|
||||||
|
deleteCredential: `/workspaces/current/tool-provider/builtin/${provider}/delete`,
|
||||||
|
getCredentialSchema: (credential_type: CredentialTypeEnum) => `/workspaces/current/tool-provider/builtin/${provider}/credential/schema/${credential_type}`,
|
||||||
|
getOauthUrl: `/oauth/plugin/${provider}/tool/authorization-url`,
|
||||||
|
getOauthClientSchema: `/workspaces/current/tool-provider/builtin/${provider}/oauth/client-schema`,
|
||||||
|
setCustomOauthClient: `/workspaces/current/tool-provider/builtin/${provider}/oauth/custom-client`,
|
||||||
|
getCustomOAuthClientValues: `/workspaces/current/tool-provider/builtin/${provider}/oauth/custom-client`,
|
||||||
|
deleteCustomOAuthClient: `/workspaces/current/tool-provider/builtin/${provider}/oauth/custom-client`,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
getCredentialInfo: '',
|
||||||
|
setDefaultCredential: '',
|
||||||
|
getCredentials: '',
|
||||||
|
addCredential: '',
|
||||||
|
updateCredential: '',
|
||||||
|
deleteCredential: '',
|
||||||
|
getCredentialSchema: () => '',
|
||||||
|
getOauthUrl: '',
|
||||||
|
getOauthClientSchema: '',
|
||||||
|
setCustomOauthClient: '',
|
||||||
|
getCustomOAuthClientValues: '',
|
||||||
|
deleteCustomOAuthClient: '',
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
import { useAppContext } from '@/context/app-context'
|
||||||
|
import {
|
||||||
|
useGetPluginCredentialInfoHook,
|
||||||
|
useInvalidPluginCredentialInfoHook,
|
||||||
|
} from './use-credential'
|
||||||
|
import { CredentialTypeEnum } from '../types'
|
||||||
|
import type { PluginPayload } from '../types'
|
||||||
|
|
||||||
|
export const usePluginAuth = (pluginPayload: PluginPayload, enable?: boolean) => {
|
||||||
|
const { data } = useGetPluginCredentialInfoHook(pluginPayload, enable)
|
||||||
|
const { isCurrentWorkspaceManager } = useAppContext()
|
||||||
|
const isAuthorized = !!data?.credentials.length
|
||||||
|
const canOAuth = data?.supported_credential_types.includes(CredentialTypeEnum.OAUTH2)
|
||||||
|
const canApiKey = data?.supported_credential_types.includes(CredentialTypeEnum.API_KEY)
|
||||||
|
const invalidPluginCredentialInfo = useInvalidPluginCredentialInfoHook(pluginPayload)
|
||||||
|
|
||||||
|
return {
|
||||||
|
isAuthorized,
|
||||||
|
canOAuth,
|
||||||
|
canApiKey,
|
||||||
|
credentials: data?.credentials || [],
|
||||||
|
disabled: !isCurrentWorkspaceManager,
|
||||||
|
invalidPluginCredentialInfo,
|
||||||
|
}
|
||||||
|
}
|
||||||
6
web/app/components/plugins/plugin-auth/index.tsx
Normal file
6
web/app/components/plugins/plugin-auth/index.tsx
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
export { default as PluginAuth } from './plugin-auth'
|
||||||
|
export { default as Authorized } from './authorized'
|
||||||
|
export { default as AuthorizedInNode } from './authorized-in-node'
|
||||||
|
export { default as PluginAuthInAgent } from './plugin-auth-in-agent'
|
||||||
|
export { usePluginAuth } from './hooks/use-plugin-auth'
|
||||||
|
export * from './types'
|
||||||
123
web/app/components/plugins/plugin-auth/plugin-auth-in-agent.tsx
Normal file
123
web/app/components/plugins/plugin-auth/plugin-auth-in-agent.tsx
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
import {
|
||||||
|
memo,
|
||||||
|
useCallback,
|
||||||
|
useState,
|
||||||
|
} from 'react'
|
||||||
|
import { RiArrowDownSLine } from '@remixicon/react'
|
||||||
|
import { useTranslation } from 'react-i18next'
|
||||||
|
import Authorize from './authorize'
|
||||||
|
import Authorized from './authorized'
|
||||||
|
import type {
|
||||||
|
Credential,
|
||||||
|
PluginPayload,
|
||||||
|
} from './types'
|
||||||
|
import { usePluginAuth } from './hooks/use-plugin-auth'
|
||||||
|
import Button from '@/app/components/base/button'
|
||||||
|
import Indicator from '@/app/components/header/indicator'
|
||||||
|
import cn from '@/utils/classnames'
|
||||||
|
|
||||||
|
type PluginAuthInAgentProps = {
|
||||||
|
pluginPayload: PluginPayload
|
||||||
|
credentialId?: string
|
||||||
|
onAuthorizationItemClick?: (id: string) => void
|
||||||
|
}
|
||||||
|
const PluginAuthInAgent = ({
|
||||||
|
pluginPayload,
|
||||||
|
credentialId,
|
||||||
|
onAuthorizationItemClick,
|
||||||
|
}: PluginAuthInAgentProps) => {
|
||||||
|
const { t } = useTranslation()
|
||||||
|
const [isOpen, setIsOpen] = useState(false)
|
||||||
|
const {
|
||||||
|
isAuthorized,
|
||||||
|
canOAuth,
|
||||||
|
canApiKey,
|
||||||
|
credentials,
|
||||||
|
disabled,
|
||||||
|
invalidPluginCredentialInfo,
|
||||||
|
} = usePluginAuth(pluginPayload, true)
|
||||||
|
|
||||||
|
const extraAuthorizationItems: Credential[] = [
|
||||||
|
{
|
||||||
|
id: '__workspace_default__',
|
||||||
|
name: t('plugin.auth.workspaceDefault'),
|
||||||
|
provider: '',
|
||||||
|
is_default: !credentialId,
|
||||||
|
isWorkspaceDefault: true,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
const handleAuthorizationItemClick = useCallback((id: string) => {
|
||||||
|
onAuthorizationItemClick?.(id)
|
||||||
|
setIsOpen(false)
|
||||||
|
}, [
|
||||||
|
onAuthorizationItemClick,
|
||||||
|
setIsOpen,
|
||||||
|
])
|
||||||
|
|
||||||
|
const renderTrigger = useCallback((isOpen?: boolean) => {
|
||||||
|
let label = ''
|
||||||
|
let removed = false
|
||||||
|
if (!credentialId) {
|
||||||
|
label = t('plugin.auth.workspaceDefault')
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
const credential = credentials.find(c => c.id === credentialId)
|
||||||
|
label = credential ? credential.name : t('plugin.auth.authRemoved')
|
||||||
|
removed = !credential
|
||||||
|
}
|
||||||
|
return (
|
||||||
|
<Button
|
||||||
|
className={cn(
|
||||||
|
'w-full',
|
||||||
|
isOpen && 'bg-components-button-secondary-bg-hover',
|
||||||
|
removed && 'text-text-destructive',
|
||||||
|
)}>
|
||||||
|
<Indicator
|
||||||
|
className='mr-2'
|
||||||
|
color={removed ? 'red' : 'green'}
|
||||||
|
/>
|
||||||
|
{label}
|
||||||
|
<RiArrowDownSLine className='ml-0.5 h-4 w-4' />
|
||||||
|
</Button>
|
||||||
|
)
|
||||||
|
}, [credentialId, credentials, t])
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
{
|
||||||
|
!isAuthorized && (
|
||||||
|
<Authorize
|
||||||
|
pluginPayload={pluginPayload}
|
||||||
|
canOAuth={canOAuth}
|
||||||
|
canApiKey={canApiKey}
|
||||||
|
disabled={disabled}
|
||||||
|
onUpdate={invalidPluginCredentialInfo}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
isAuthorized && (
|
||||||
|
<Authorized
|
||||||
|
pluginPayload={pluginPayload}
|
||||||
|
credentials={credentials}
|
||||||
|
canOAuth={canOAuth}
|
||||||
|
canApiKey={canApiKey}
|
||||||
|
disabled={disabled}
|
||||||
|
disableSetDefault
|
||||||
|
onItemClick={handleAuthorizationItemClick}
|
||||||
|
extraAuthorizationItems={extraAuthorizationItems}
|
||||||
|
showItemSelectedIcon
|
||||||
|
renderTrigger={renderTrigger}
|
||||||
|
isOpen={isOpen}
|
||||||
|
onOpenChange={setIsOpen}
|
||||||
|
selectedCredentialId={credentialId || '__workspace_default__'}
|
||||||
|
onUpdate={invalidPluginCredentialInfo}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
</>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default memo(PluginAuthInAgent)
|
||||||
59
web/app/components/plugins/plugin-auth/plugin-auth.tsx
Normal file
59
web/app/components/plugins/plugin-auth/plugin-auth.tsx
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
import { memo } from 'react'
|
||||||
|
import Authorize from './authorize'
|
||||||
|
import Authorized from './authorized'
|
||||||
|
import type { PluginPayload } from './types'
|
||||||
|
import { usePluginAuth } from './hooks/use-plugin-auth'
|
||||||
|
import cn from '@/utils/classnames'
|
||||||
|
|
||||||
|
type PluginAuthProps = {
|
||||||
|
pluginPayload: PluginPayload
|
||||||
|
children?: React.ReactNode
|
||||||
|
className?: string
|
||||||
|
}
|
||||||
|
const PluginAuth = ({
|
||||||
|
pluginPayload,
|
||||||
|
children,
|
||||||
|
className,
|
||||||
|
}: PluginAuthProps) => {
|
||||||
|
const {
|
||||||
|
isAuthorized,
|
||||||
|
canOAuth,
|
||||||
|
canApiKey,
|
||||||
|
credentials,
|
||||||
|
disabled,
|
||||||
|
invalidPluginCredentialInfo,
|
||||||
|
} = usePluginAuth(pluginPayload, !!pluginPayload.provider)
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className={cn(!isAuthorized && className)}>
|
||||||
|
{
|
||||||
|
!isAuthorized && (
|
||||||
|
<Authorize
|
||||||
|
pluginPayload={pluginPayload}
|
||||||
|
canOAuth={canOAuth}
|
||||||
|
canApiKey={canApiKey}
|
||||||
|
disabled={disabled}
|
||||||
|
onUpdate={invalidPluginCredentialInfo}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
isAuthorized && !children && (
|
||||||
|
<Authorized
|
||||||
|
pluginPayload={pluginPayload}
|
||||||
|
credentials={credentials}
|
||||||
|
canOAuth={canOAuth}
|
||||||
|
canApiKey={canApiKey}
|
||||||
|
disabled={disabled}
|
||||||
|
onUpdate={invalidPluginCredentialInfo}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
isAuthorized && children
|
||||||
|
}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default memo(PluginAuth)
|
||||||
25
web/app/components/plugins/plugin-auth/types.ts
Normal file
25
web/app/components/plugins/plugin-auth/types.ts
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
export enum AuthCategory {
|
||||||
|
tool = 'tool',
|
||||||
|
datasource = 'datasource',
|
||||||
|
model = 'model',
|
||||||
|
}
|
||||||
|
|
||||||
|
export type PluginPayload = {
|
||||||
|
category: AuthCategory
|
||||||
|
provider: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export enum CredentialTypeEnum {
|
||||||
|
OAUTH2 = 'oauth2',
|
||||||
|
API_KEY = 'api-key',
|
||||||
|
}
|
||||||
|
|
||||||
|
export type Credential = {
|
||||||
|
id: string
|
||||||
|
name: string
|
||||||
|
provider: string
|
||||||
|
credential_type?: CredentialTypeEnum
|
||||||
|
is_default: boolean
|
||||||
|
credentials?: Record<string, any>
|
||||||
|
isWorkspaceDefault?: boolean
|
||||||
|
}
|
||||||
10
web/app/components/plugins/plugin-auth/utils.ts
Normal file
10
web/app/components/plugins/plugin-auth/utils.ts
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
export const transformFormSchemasSecretInput = (isPristineSecretInputNames: string[], values: Record<string, any>) => {
|
||||||
|
const transformedValues: Record<string, any> = { ...values }
|
||||||
|
|
||||||
|
isPristineSecretInputNames.forEach((name) => {
|
||||||
|
if (transformedValues[name])
|
||||||
|
transformedValues[name] = '[__HIDDEN__]'
|
||||||
|
})
|
||||||
|
|
||||||
|
return transformedValues
|
||||||
|
}
|
||||||
@@ -1,17 +1,9 @@
|
|||||||
import React, { useMemo, useState } from 'react'
|
import React, { useMemo } from 'react'
|
||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
import { useAppContext } from '@/context/app-context'
|
|
||||||
import Button from '@/app/components/base/button'
|
|
||||||
import Toast from '@/app/components/base/toast'
|
|
||||||
import Indicator from '@/app/components/header/indicator'
|
|
||||||
import ToolItem from '@/app/components/tools/provider/tool-item'
|
import ToolItem from '@/app/components/tools/provider/tool-item'
|
||||||
import ConfigCredential from '@/app/components/tools/setting/build-in/config-credentials'
|
|
||||||
import {
|
import {
|
||||||
useAllToolProviders,
|
useAllToolProviders,
|
||||||
useBuiltinTools,
|
useBuiltinTools,
|
||||||
useInvalidateAllToolProviders,
|
|
||||||
useRemoveProviderCredentials,
|
|
||||||
useUpdateProviderCredentials,
|
|
||||||
} from '@/service/use-tools'
|
} from '@/service/use-tools'
|
||||||
import type { PluginDetail } from '@/app/components/plugins/types'
|
import type { PluginDetail } from '@/app/components/plugins/types'
|
||||||
|
|
||||||
@@ -23,35 +15,14 @@ const ActionList = ({
|
|||||||
detail,
|
detail,
|
||||||
}: Props) => {
|
}: Props) => {
|
||||||
const { t } = useTranslation()
|
const { t } = useTranslation()
|
||||||
const { isCurrentWorkspaceManager } = useAppContext()
|
|
||||||
const providerBriefInfo = detail.declaration.tool.identity
|
const providerBriefInfo = detail.declaration.tool.identity
|
||||||
const providerKey = `${detail.plugin_id}/${providerBriefInfo.name}`
|
const providerKey = `${detail.plugin_id}/${providerBriefInfo.name}`
|
||||||
const { data: collectionList = [] } = useAllToolProviders()
|
const { data: collectionList = [] } = useAllToolProviders()
|
||||||
const invalidateAllToolProviders = useInvalidateAllToolProviders()
|
|
||||||
const provider = useMemo(() => {
|
const provider = useMemo(() => {
|
||||||
return collectionList.find(collection => collection.name === providerKey)
|
return collectionList.find(collection => collection.name === providerKey)
|
||||||
}, [collectionList, providerKey])
|
}, [collectionList, providerKey])
|
||||||
const { data } = useBuiltinTools(providerKey)
|
const { data } = useBuiltinTools(providerKey)
|
||||||
|
|
||||||
const [showSettingAuth, setShowSettingAuth] = useState(false)
|
|
||||||
|
|
||||||
const handleCredentialSettingUpdate = () => {
|
|
||||||
invalidateAllToolProviders()
|
|
||||||
Toast.notify({
|
|
||||||
type: 'success',
|
|
||||||
message: t('common.api.actionSuccess'),
|
|
||||||
})
|
|
||||||
setShowSettingAuth(false)
|
|
||||||
}
|
|
||||||
|
|
||||||
const { mutate: updatePermission, isPending } = useUpdateProviderCredentials({
|
|
||||||
onSuccess: handleCredentialSettingUpdate,
|
|
||||||
})
|
|
||||||
|
|
||||||
const { mutate: removePermission } = useRemoveProviderCredentials({
|
|
||||||
onSuccess: handleCredentialSettingUpdate,
|
|
||||||
})
|
|
||||||
|
|
||||||
if (!data || !provider)
|
if (!data || !provider)
|
||||||
return null
|
return null
|
||||||
|
|
||||||
@@ -60,26 +31,7 @@ const ActionList = ({
|
|||||||
<div className='mb-1 py-1'>
|
<div className='mb-1 py-1'>
|
||||||
<div className='system-sm-semibold-uppercase mb-1 flex h-6 items-center justify-between text-text-secondary'>
|
<div className='system-sm-semibold-uppercase mb-1 flex h-6 items-center justify-between text-text-secondary'>
|
||||||
{t('plugin.detailPanel.actionNum', { num: data.length, action: data.length > 1 ? 'actions' : 'action' })}
|
{t('plugin.detailPanel.actionNum', { num: data.length, action: data.length > 1 ? 'actions' : 'action' })}
|
||||||
{provider.is_team_authorization && provider.allow_delete && (
|
|
||||||
<Button
|
|
||||||
variant='secondary'
|
|
||||||
size='small'
|
|
||||||
onClick={() => setShowSettingAuth(true)}
|
|
||||||
disabled={!isCurrentWorkspaceManager}
|
|
||||||
>
|
|
||||||
<Indicator className='mr-2' color={'green'} />
|
|
||||||
{t('tools.auth.authorized')}
|
|
||||||
</Button>
|
|
||||||
)}
|
|
||||||
</div>
|
</div>
|
||||||
{!provider.is_team_authorization && provider.allow_delete && (
|
|
||||||
<Button
|
|
||||||
variant='primary'
|
|
||||||
className='w-full'
|
|
||||||
onClick={() => setShowSettingAuth(true)}
|
|
||||||
disabled={!isCurrentWorkspaceManager}
|
|
||||||
>{t('workflow.nodes.tool.authorize')}</Button>
|
|
||||||
)}
|
|
||||||
</div>
|
</div>
|
||||||
<div className='flex flex-col gap-2'>
|
<div className='flex flex-col gap-2'>
|
||||||
{data.map(tool => (
|
{data.map(tool => (
|
||||||
@@ -93,18 +45,6 @@ const ActionList = ({
|
|||||||
/>
|
/>
|
||||||
))}
|
))}
|
||||||
</div>
|
</div>
|
||||||
{showSettingAuth && (
|
|
||||||
<ConfigCredential
|
|
||||||
collection={provider}
|
|
||||||
onCancel={() => setShowSettingAuth(false)}
|
|
||||||
onSaved={async value => updatePermission({
|
|
||||||
providerName: provider.name,
|
|
||||||
credentials: value,
|
|
||||||
})}
|
|
||||||
onRemove={async () => removePermission(provider.name)}
|
|
||||||
isSaving={isPending}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -36,6 +36,9 @@ import { useInvalidateAllToolProviders } from '@/service/use-tools'
|
|||||||
import { API_PREFIX } from '@/config'
|
import { API_PREFIX } from '@/config'
|
||||||
import cn from '@/utils/classnames'
|
import cn from '@/utils/classnames'
|
||||||
import { getMarketplaceUrl } from '@/utils/var'
|
import { getMarketplaceUrl } from '@/utils/var'
|
||||||
|
import { PluginAuth } from '@/app/components/plugins/plugin-auth'
|
||||||
|
import { AuthCategory } from '@/app/components/plugins/plugin-auth'
|
||||||
|
import { useAllToolProviders } from '@/service/use-tools'
|
||||||
|
|
||||||
const i18nPrefix = 'plugin.action'
|
const i18nPrefix = 'plugin.action'
|
||||||
|
|
||||||
@@ -68,7 +71,14 @@ const DetailHeader = ({
|
|||||||
meta,
|
meta,
|
||||||
plugin_id,
|
plugin_id,
|
||||||
} = detail
|
} = detail
|
||||||
const { author, category, name, label, description, icon, verified } = detail.declaration
|
const { author, category, name, label, description, icon, verified, tool } = detail.declaration
|
||||||
|
const isTool = category === PluginType.tool
|
||||||
|
const providerBriefInfo = tool?.identity
|
||||||
|
const providerKey = `${plugin_id}/${providerBriefInfo?.name}`
|
||||||
|
const { data: collectionList = [] } = useAllToolProviders(isTool)
|
||||||
|
const provider = useMemo(() => {
|
||||||
|
return collectionList.find(collection => collection.name === providerKey)
|
||||||
|
}, [collectionList, providerKey])
|
||||||
const isFromGitHub = source === PluginSource.github
|
const isFromGitHub = source === PluginSource.github
|
||||||
const isFromMarketplace = source === PluginSource.marketplace
|
const isFromMarketplace = source === PluginSource.marketplace
|
||||||
|
|
||||||
@@ -262,7 +272,17 @@ const DetailHeader = ({
|
|||||||
</ActionButton>
|
</ActionButton>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<Description className='mt-3' text={description[locale]} descriptionLineRows={2}></Description>
|
<Description className='mb-2 mt-3 h-auto' text={description[locale]} descriptionLineRows={2}></Description>
|
||||||
|
{
|
||||||
|
category === PluginType.tool && (
|
||||||
|
<PluginAuth
|
||||||
|
pluginPayload={{
|
||||||
|
provider: provider?.name || '',
|
||||||
|
category: AuthCategory.tool,
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
{isShowPluginInfo && (
|
{isShowPluginInfo && (
|
||||||
<PluginInfo
|
<PluginInfo
|
||||||
repository={isFromGitHub ? meta?.repo : ''}
|
repository={isFromGitHub ? meta?.repo : ''}
|
||||||
|
|||||||
@@ -3,9 +3,6 @@ import type { FC } from 'react'
|
|||||||
import React, { useMemo, useState } from 'react'
|
import React, { useMemo, useState } from 'react'
|
||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
import Link from 'next/link'
|
import Link from 'next/link'
|
||||||
import {
|
|
||||||
RiArrowLeftLine,
|
|
||||||
} from '@remixicon/react'
|
|
||||||
import {
|
import {
|
||||||
PortalToFollowElem,
|
PortalToFollowElem,
|
||||||
PortalToFollowElemContent,
|
PortalToFollowElemContent,
|
||||||
@@ -15,24 +12,17 @@ import ToolTrigger from '@/app/components/plugins/plugin-detail-panel/tool-selec
|
|||||||
import ToolItem from '@/app/components/plugins/plugin-detail-panel/tool-selector/tool-item'
|
import ToolItem from '@/app/components/plugins/plugin-detail-panel/tool-selector/tool-item'
|
||||||
import ToolPicker from '@/app/components/workflow/block-selector/tool-picker'
|
import ToolPicker from '@/app/components/workflow/block-selector/tool-picker'
|
||||||
import ToolForm from '@/app/components/workflow/nodes/tool/components/tool-form'
|
import ToolForm from '@/app/components/workflow/nodes/tool/components/tool-form'
|
||||||
import Button from '@/app/components/base/button'
|
|
||||||
import Indicator from '@/app/components/header/indicator'
|
|
||||||
import ToolCredentialForm from '@/app/components/plugins/plugin-detail-panel/tool-selector/tool-credentials-form'
|
|
||||||
import Toast from '@/app/components/base/toast'
|
|
||||||
import Textarea from '@/app/components/base/textarea'
|
import Textarea from '@/app/components/base/textarea'
|
||||||
import Divider from '@/app/components/base/divider'
|
import Divider from '@/app/components/base/divider'
|
||||||
import TabSlider from '@/app/components/base/tab-slider-plain'
|
import TabSlider from '@/app/components/base/tab-slider-plain'
|
||||||
import ReasoningConfigForm from '@/app/components/plugins/plugin-detail-panel/tool-selector/reasoning-config-form'
|
import ReasoningConfigForm from '@/app/components/plugins/plugin-detail-panel/tool-selector/reasoning-config-form'
|
||||||
import { generateFormValue, getPlainValue, getStructureValue, toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema'
|
import { generateFormValue, getPlainValue, getStructureValue, toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema'
|
||||||
|
|
||||||
import { useAppContext } from '@/context/app-context'
|
|
||||||
import {
|
import {
|
||||||
useAllBuiltInTools,
|
useAllBuiltInTools,
|
||||||
useAllCustomTools,
|
useAllCustomTools,
|
||||||
useAllMCPTools,
|
useAllMCPTools,
|
||||||
useAllWorkflowTools,
|
useAllWorkflowTools,
|
||||||
useInvalidateAllBuiltInTools,
|
useInvalidateAllBuiltInTools,
|
||||||
useUpdateProviderCredentials,
|
|
||||||
} from '@/service/use-tools'
|
} from '@/service/use-tools'
|
||||||
import { useInvalidateInstalledPluginList } from '@/service/use-plugins'
|
import { useInvalidateInstalledPluginList } from '@/service/use-plugins'
|
||||||
import { usePluginInstalledCheck } from '@/app/components/plugins/plugin-detail-panel/tool-selector/hooks'
|
import { usePluginInstalledCheck } from '@/app/components/plugins/plugin-detail-panel/tool-selector/hooks'
|
||||||
@@ -46,6 +36,10 @@ import { MARKETPLACE_API_PREFIX } from '@/config'
|
|||||||
import type { Node } from 'reactflow'
|
import type { Node } from 'reactflow'
|
||||||
import type { NodeOutPutVar } from '@/app/components/workflow/types'
|
import type { NodeOutPutVar } from '@/app/components/workflow/types'
|
||||||
import cn from '@/utils/classnames'
|
import cn from '@/utils/classnames'
|
||||||
|
import {
|
||||||
|
AuthCategory,
|
||||||
|
PluginAuthInAgent,
|
||||||
|
} from '@/app/components/plugins/plugin-auth'
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
disabled?: boolean
|
disabled?: boolean
|
||||||
@@ -196,23 +190,6 @@ const ToolSelector: FC<Props> = ({
|
|||||||
} as any)
|
} as any)
|
||||||
}
|
}
|
||||||
|
|
||||||
// authorization
|
|
||||||
const { isCurrentWorkspaceManager } = useAppContext()
|
|
||||||
const [isShowSettingAuth, setShowSettingAuth] = useState(false)
|
|
||||||
const handleCredentialSettingUpdate = () => {
|
|
||||||
invalidateAllBuiltinTools()
|
|
||||||
Toast.notify({
|
|
||||||
type: 'success',
|
|
||||||
message: t('common.api.actionSuccess'),
|
|
||||||
})
|
|
||||||
setShowSettingAuth(false)
|
|
||||||
onShowChange(false)
|
|
||||||
}
|
|
||||||
|
|
||||||
const { mutate: updatePermission } = useUpdateProviderCredentials({
|
|
||||||
onSuccess: handleCredentialSettingUpdate,
|
|
||||||
})
|
|
||||||
|
|
||||||
// install from marketplace
|
// install from marketplace
|
||||||
const currentTool = useMemo(() => {
|
const currentTool = useMemo(() => {
|
||||||
return currentProvider?.tools.find(tool => tool.name === value?.tool_name)
|
return currentProvider?.tools.find(tool => tool.name === value?.tool_name)
|
||||||
@@ -226,6 +203,12 @@ const ToolSelector: FC<Props> = ({
|
|||||||
invalidateAllBuiltinTools()
|
invalidateAllBuiltinTools()
|
||||||
invalidateInstalledPluginList()
|
invalidateInstalledPluginList()
|
||||||
}
|
}
|
||||||
|
const handleAuthorizationItemClick = (id: string) => {
|
||||||
|
onSelect({
|
||||||
|
...value,
|
||||||
|
credential_id: id,
|
||||||
|
} as any)
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
@@ -264,7 +247,6 @@ const ToolSelector: FC<Props> = ({
|
|||||||
onSwitchChange={handleEnabledChange}
|
onSwitchChange={handleEnabledChange}
|
||||||
onDelete={onDelete}
|
onDelete={onDelete}
|
||||||
noAuth={currentProvider && currentTool && !currentProvider.is_team_authorization}
|
noAuth={currentProvider && currentTool && !currentProvider.is_team_authorization}
|
||||||
onAuth={() => setShowSettingAuth(true)}
|
|
||||||
uninstalled={!currentProvider && inMarketPlace}
|
uninstalled={!currentProvider && inMarketPlace}
|
||||||
versionMismatch={currentProvider && inMarketPlace && !currentTool}
|
versionMismatch={currentProvider && inMarketPlace && !currentTool}
|
||||||
installInfo={manifest?.latest_package_identifier}
|
installInfo={manifest?.latest_package_identifier}
|
||||||
@@ -284,171 +266,131 @@ const ToolSelector: FC<Props> = ({
|
|||||||
)}
|
)}
|
||||||
</PortalToFollowElemTrigger>
|
</PortalToFollowElemTrigger>
|
||||||
<PortalToFollowElemContent>
|
<PortalToFollowElemContent>
|
||||||
<div className={cn('relative max-h-[642px] min-h-20 w-[361px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur pb-4 shadow-lg backdrop-blur-sm', !isShowSettingAuth && 'overflow-y-auto pb-2')}>
|
<div className={cn('relative max-h-[642px] min-h-20 w-[361px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur pb-4 shadow-lg backdrop-blur-sm', 'overflow-y-auto pb-2')}>
|
||||||
{!isShowSettingAuth && (
|
<>
|
||||||
<>
|
<div className='system-xl-semibold px-4 pb-1 pt-3.5 text-text-primary'>{t(`plugin.detailPanel.toolSelector.${isEdit ? 'toolSetting' : 'title'}`)}</div>
|
||||||
<div className='system-xl-semibold px-4 pb-1 pt-3.5 text-text-primary'>{t(`plugin.detailPanel.toolSelector.${isEdit ? 'toolSetting' : 'title'}`)}</div>
|
{/* base form */}
|
||||||
{/* base form */}
|
<div className='flex flex-col gap-3 px-4 py-2'>
|
||||||
<div className='flex flex-col gap-3 px-4 py-2'>
|
<div className='flex flex-col gap-1'>
|
||||||
<div className='flex flex-col gap-1'>
|
<div className='system-sm-semibold flex h-6 items-center text-text-secondary'>{t('plugin.detailPanel.toolSelector.toolLabel')}</div>
|
||||||
<div className='system-sm-semibold flex h-6 items-center text-text-secondary'>{t('plugin.detailPanel.toolSelector.toolLabel')}</div>
|
<ToolPicker
|
||||||
<ToolPicker
|
placement='bottom'
|
||||||
placement='bottom'
|
offset={offset}
|
||||||
offset={offset}
|
trigger={
|
||||||
trigger={
|
<ToolTrigger
|
||||||
<ToolTrigger
|
open={panelShowState || isShowChooseTool}
|
||||||
open={panelShowState || isShowChooseTool}
|
value={value}
|
||||||
value={value}
|
provider={currentProvider}
|
||||||
provider={currentProvider}
|
|
||||||
/>
|
|
||||||
}
|
|
||||||
isShow={panelShowState || isShowChooseTool}
|
|
||||||
onShowChange={trigger ? onPanelShowStateChange as any : setIsShowChooseTool}
|
|
||||||
disabled={false}
|
|
||||||
supportAddCustomTool
|
|
||||||
onSelect={handleSelectTool}
|
|
||||||
onSelectMultiple={handleSelectMultipleTool}
|
|
||||||
scope={scope}
|
|
||||||
selectedTools={selectedTools}
|
|
||||||
canChooseMCPTool={canChooseMCPTool}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
<div className='flex flex-col gap-1'>
|
|
||||||
<div className='system-sm-semibold flex h-6 items-center text-text-secondary'>{t('plugin.detailPanel.toolSelector.descriptionLabel')}</div>
|
|
||||||
<Textarea
|
|
||||||
className='resize-none'
|
|
||||||
placeholder={t('plugin.detailPanel.toolSelector.descriptionPlaceholder')}
|
|
||||||
value={value?.extra?.description || ''}
|
|
||||||
onChange={handleDescriptionChange}
|
|
||||||
disabled={!value?.provider_name}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
{/* authorization */}
|
|
||||||
{currentProvider && currentProvider.type === CollectionType.builtIn && currentProvider.allow_delete && (
|
|
||||||
<>
|
|
||||||
<Divider className='my-1 w-full' />
|
|
||||||
<div className='px-4 py-2'>
|
|
||||||
{!currentProvider.is_team_authorization && (
|
|
||||||
<Button
|
|
||||||
variant='primary'
|
|
||||||
className={cn('w-full shrink-0')}
|
|
||||||
onClick={() => setShowSettingAuth(true)}
|
|
||||||
disabled={!isCurrentWorkspaceManager}
|
|
||||||
>
|
|
||||||
{t('tools.auth.unauthorized')}
|
|
||||||
</Button>
|
|
||||||
)}
|
|
||||||
{currentProvider.is_team_authorization && (
|
|
||||||
<Button
|
|
||||||
variant='secondary'
|
|
||||||
className={cn('w-full shrink-0')}
|
|
||||||
onClick={() => setShowSettingAuth(true)}
|
|
||||||
disabled={!isCurrentWorkspaceManager}
|
|
||||||
>
|
|
||||||
<Indicator className='mr-2' color={'green'} />
|
|
||||||
{t('tools.auth.authorized')}
|
|
||||||
</Button>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
</>
|
|
||||||
)}
|
|
||||||
{/* tool settings */}
|
|
||||||
{(currentToolSettings.length > 0 || currentToolParams.length > 0) && currentProvider?.is_team_authorization && (
|
|
||||||
<>
|
|
||||||
<Divider className='my-1 w-full' />
|
|
||||||
{/* tabs */}
|
|
||||||
{nodeId && showTabSlider && (
|
|
||||||
<TabSlider
|
|
||||||
className='mt-1 shrink-0 px-4'
|
|
||||||
itemClassName='py-3'
|
|
||||||
noBorderBottom
|
|
||||||
smallItem
|
|
||||||
value={currType}
|
|
||||||
onChange={(value) => {
|
|
||||||
setCurrType(value)
|
|
||||||
}}
|
|
||||||
options={[
|
|
||||||
{ value: 'settings', text: t('plugin.detailPanel.toolSelector.settings')! },
|
|
||||||
{ value: 'params', text: t('plugin.detailPanel.toolSelector.params')! },
|
|
||||||
]}
|
|
||||||
/>
|
/>
|
||||||
)}
|
}
|
||||||
{nodeId && showTabSlider && currType === 'params' && (
|
isShow={panelShowState || isShowChooseTool}
|
||||||
<div className='px-4 py-2'>
|
onShowChange={trigger ? onPanelShowStateChange as any : setIsShowChooseTool}
|
||||||
|
disabled={false}
|
||||||
|
supportAddCustomTool
|
||||||
|
onSelect={handleSelectTool}
|
||||||
|
onSelectMultiple={handleSelectMultipleTool}
|
||||||
|
scope={scope}
|
||||||
|
selectedTools={selectedTools}
|
||||||
|
canChooseMCPTool={canChooseMCPTool}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<div className='flex flex-col gap-1'>
|
||||||
|
<div className='system-sm-semibold flex h-6 items-center text-text-secondary'>{t('plugin.detailPanel.toolSelector.descriptionLabel')}</div>
|
||||||
|
<Textarea
|
||||||
|
className='resize-none'
|
||||||
|
placeholder={t('plugin.detailPanel.toolSelector.descriptionPlaceholder')}
|
||||||
|
value={value?.extra?.description || ''}
|
||||||
|
onChange={handleDescriptionChange}
|
||||||
|
disabled={!value?.provider_name}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{/* authorization */}
|
||||||
|
{currentProvider && currentProvider.type === CollectionType.builtIn && currentProvider.allow_delete && (
|
||||||
|
<>
|
||||||
|
<Divider className='my-1 w-full' />
|
||||||
|
<div className='px-4 py-2'>
|
||||||
|
<PluginAuthInAgent
|
||||||
|
pluginPayload={{
|
||||||
|
provider: currentProvider.name,
|
||||||
|
category: AuthCategory.tool,
|
||||||
|
}}
|
||||||
|
credentialId={value?.credential_id}
|
||||||
|
onAuthorizationItemClick={handleAuthorizationItemClick}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
{/* tool settings */}
|
||||||
|
{(currentToolSettings.length > 0 || currentToolParams.length > 0) && currentProvider?.is_team_authorization && (
|
||||||
|
<>
|
||||||
|
<Divider className='my-1 w-full' />
|
||||||
|
{/* tabs */}
|
||||||
|
{nodeId && showTabSlider && (
|
||||||
|
<TabSlider
|
||||||
|
className='mt-1 shrink-0 px-4'
|
||||||
|
itemClassName='py-3'
|
||||||
|
noBorderBottom
|
||||||
|
smallItem
|
||||||
|
value={currType}
|
||||||
|
onChange={(value) => {
|
||||||
|
setCurrType(value)
|
||||||
|
}}
|
||||||
|
options={[
|
||||||
|
{ value: 'settings', text: t('plugin.detailPanel.toolSelector.settings')! },
|
||||||
|
{ value: 'params', text: t('plugin.detailPanel.toolSelector.params')! },
|
||||||
|
]}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{nodeId && showTabSlider && currType === 'params' && (
|
||||||
|
<div className='px-4 py-2'>
|
||||||
|
<div className='system-xs-regular text-text-tertiary'>{t('plugin.detailPanel.toolSelector.paramsTip1')}</div>
|
||||||
|
<div className='system-xs-regular text-text-tertiary'>{t('plugin.detailPanel.toolSelector.paramsTip2')}</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
{/* user settings only */}
|
||||||
|
{userSettingsOnly && (
|
||||||
|
<div className='p-4 pb-1'>
|
||||||
|
<div className='system-sm-semibold-uppercase text-text-primary'>{t('plugin.detailPanel.toolSelector.settings')}</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
{/* reasoning config only */}
|
||||||
|
{nodeId && reasoningConfigOnly && (
|
||||||
|
<div className='mb-1 p-4 pb-1'>
|
||||||
|
<div className='system-sm-semibold-uppercase text-text-primary'>{t('plugin.detailPanel.toolSelector.params')}</div>
|
||||||
|
<div className='pb-1'>
|
||||||
<div className='system-xs-regular text-text-tertiary'>{t('plugin.detailPanel.toolSelector.paramsTip1')}</div>
|
<div className='system-xs-regular text-text-tertiary'>{t('plugin.detailPanel.toolSelector.paramsTip1')}</div>
|
||||||
<div className='system-xs-regular text-text-tertiary'>{t('plugin.detailPanel.toolSelector.paramsTip2')}</div>
|
<div className='system-xs-regular text-text-tertiary'>{t('plugin.detailPanel.toolSelector.paramsTip2')}</div>
|
||||||
</div>
|
</div>
|
||||||
)}
|
</div>
|
||||||
{/* user settings only */}
|
)}
|
||||||
{userSettingsOnly && (
|
{/* user settings form */}
|
||||||
<div className='p-4 pb-1'>
|
{(currType === 'settings' || userSettingsOnly) && (
|
||||||
<div className='system-sm-semibold-uppercase text-text-primary'>{t('plugin.detailPanel.toolSelector.settings')}</div>
|
<div className='px-4 py-2'>
|
||||||
</div>
|
<ToolForm
|
||||||
)}
|
inPanel
|
||||||
{/* reasoning config only */}
|
readOnly={false}
|
||||||
{nodeId && reasoningConfigOnly && (
|
|
||||||
<div className='mb-1 p-4 pb-1'>
|
|
||||||
<div className='system-sm-semibold-uppercase text-text-primary'>{t('plugin.detailPanel.toolSelector.params')}</div>
|
|
||||||
<div className='pb-1'>
|
|
||||||
<div className='system-xs-regular text-text-tertiary'>{t('plugin.detailPanel.toolSelector.paramsTip1')}</div>
|
|
||||||
<div className='system-xs-regular text-text-tertiary'>{t('plugin.detailPanel.toolSelector.paramsTip2')}</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
{/* user settings form */}
|
|
||||||
{(currType === 'settings' || userSettingsOnly) && (
|
|
||||||
<div className='px-4 py-2'>
|
|
||||||
<ToolForm
|
|
||||||
inPanel
|
|
||||||
readOnly={false}
|
|
||||||
nodeId={nodeId}
|
|
||||||
schema={settingsFormSchemas as any}
|
|
||||||
value={getPlainValue(value?.settings || {})}
|
|
||||||
onChange={handleSettingsFormChange}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
{/* reasoning config form */}
|
|
||||||
{nodeId && (currType === 'params' || reasoningConfigOnly) && (
|
|
||||||
<ReasoningConfigForm
|
|
||||||
value={value?.parameters || {}}
|
|
||||||
onChange={handleParamsFormChange}
|
|
||||||
schemas={paramsFormSchemas as any}
|
|
||||||
nodeOutputVars={nodeOutputVars}
|
|
||||||
availableNodes={availableNodes}
|
|
||||||
nodeId={nodeId}
|
nodeId={nodeId}
|
||||||
|
schema={settingsFormSchemas as any}
|
||||||
|
value={getPlainValue(value?.settings || {})}
|
||||||
|
onChange={handleSettingsFormChange}
|
||||||
/>
|
/>
|
||||||
)}
|
</div>
|
||||||
</>
|
)}
|
||||||
)}
|
{/* reasoning config form */}
|
||||||
</>
|
{nodeId && (currType === 'params' || reasoningConfigOnly) && (
|
||||||
)}
|
<ReasoningConfigForm
|
||||||
{/* authorization panel */}
|
value={value?.parameters || {}}
|
||||||
{isShowSettingAuth && currentProvider && (
|
onChange={handleParamsFormChange}
|
||||||
<>
|
schemas={paramsFormSchemas as any}
|
||||||
<div className='relative flex flex-col gap-1 pt-3.5'>
|
nodeOutputVars={nodeOutputVars}
|
||||||
<div className='absolute -top-2 left-2 w-[345px] rounded-t-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur pt-2 backdrop-blur-sm'></div>
|
availableNodes={availableNodes}
|
||||||
<div
|
nodeId={nodeId}
|
||||||
className='system-xs-semibold-uppercase flex h-6 cursor-pointer items-center gap-1 px-3 text-text-accent-secondary'
|
/>
|
||||||
onClick={() => setShowSettingAuth(false)}
|
)}
|
||||||
>
|
</>
|
||||||
<RiArrowLeftLine className='h-4 w-4' />
|
)}
|
||||||
BACK
|
</>
|
||||||
</div>
|
|
||||||
<div className='system-xl-semibold px-4 text-text-primary'>{t('tools.auth.setupModalTitle')}</div>
|
|
||||||
<div className='system-xs-regular px-4 text-text-tertiary'>{t('tools.auth.setupModalTitleDescription')}</div>
|
|
||||||
</div>
|
|
||||||
<ToolCredentialForm
|
|
||||||
collection={currentProvider}
|
|
||||||
onCancel={() => setShowSettingAuth(false)}
|
|
||||||
onSaved={async value => updatePermission({
|
|
||||||
providerName: currentProvider.name,
|
|
||||||
credentials: value,
|
|
||||||
})}
|
|
||||||
/>
|
|
||||||
</>
|
|
||||||
)}
|
|
||||||
</div>
|
</div>
|
||||||
</PortalToFollowElemContent>
|
</PortalToFollowElemContent>
|
||||||
</PortalToFollowElem>
|
</PortalToFollowElem>
|
||||||
|
|||||||
@@ -30,7 +30,6 @@ type Props = {
|
|||||||
onSwitchChange?: (value: boolean) => void
|
onSwitchChange?: (value: boolean) => void
|
||||||
onDelete?: () => void
|
onDelete?: () => void
|
||||||
noAuth?: boolean
|
noAuth?: boolean
|
||||||
onAuth?: () => void
|
|
||||||
isError?: boolean
|
isError?: boolean
|
||||||
errorTip?: any
|
errorTip?: any
|
||||||
uninstalled?: boolean
|
uninstalled?: boolean
|
||||||
@@ -38,6 +37,7 @@ type Props = {
|
|||||||
onInstall?: () => void
|
onInstall?: () => void
|
||||||
versionMismatch?: boolean
|
versionMismatch?: boolean
|
||||||
open: boolean
|
open: boolean
|
||||||
|
authRemoved?: boolean
|
||||||
canChooseMCPTool?: boolean,
|
canChooseMCPTool?: boolean,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -53,13 +53,13 @@ const ToolItem = ({
|
|||||||
onSwitchChange,
|
onSwitchChange,
|
||||||
onDelete,
|
onDelete,
|
||||||
noAuth,
|
noAuth,
|
||||||
onAuth,
|
|
||||||
uninstalled,
|
uninstalled,
|
||||||
installInfo,
|
installInfo,
|
||||||
onInstall,
|
onInstall,
|
||||||
isError,
|
isError,
|
||||||
errorTip,
|
errorTip,
|
||||||
versionMismatch,
|
versionMismatch,
|
||||||
|
authRemoved,
|
||||||
canChooseMCPTool,
|
canChooseMCPTool,
|
||||||
}: Props) => {
|
}: Props) => {
|
||||||
const { t } = useTranslation()
|
const { t } = useTranslation()
|
||||||
@@ -125,11 +125,17 @@ const ToolItem = ({
|
|||||||
<McpToolNotSupportTooltip />
|
<McpToolNotSupportTooltip />
|
||||||
)}
|
)}
|
||||||
{!isError && !uninstalled && !versionMismatch && noAuth && (
|
{!isError && !uninstalled && !versionMismatch && noAuth && (
|
||||||
<Button variant='secondary' size='small' onClick={onAuth}>
|
<Button variant='secondary' size='small'>
|
||||||
{t('tools.notAuthorized')}
|
{t('tools.notAuthorized')}
|
||||||
<Indicator className='ml-2' color='orange' />
|
<Indicator className='ml-2' color='orange' />
|
||||||
</Button>
|
</Button>
|
||||||
)}
|
)}
|
||||||
|
{!isError && !uninstalled && !versionMismatch && authRemoved && (
|
||||||
|
<Button variant='secondary' size='small'>
|
||||||
|
{t('plugin.auth.authRemoved')}
|
||||||
|
<Indicator className='ml-2' color='red' />
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
{!isError && !uninstalled && versionMismatch && installInfo && (
|
{!isError && !uninstalled && versionMismatch && installInfo && (
|
||||||
<div onClick={e => e.stopPropagation()}>
|
<div onClick={e => e.stopPropagation()}>
|
||||||
<SwitchPluginVersion
|
<SwitchPluginVersion
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ export type ToolDefaultValue = {
|
|||||||
params: Record<string, any>
|
params: Record<string, any>
|
||||||
paramSchemas: Record<string, any>[]
|
paramSchemas: Record<string, any>[]
|
||||||
output_schema: Record<string, any>
|
output_schema: Record<string, any>
|
||||||
|
credential_id?: string
|
||||||
meta?: PluginMeta
|
meta?: PluginMeta
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -46,4 +47,5 @@ export type ToolValue = {
|
|||||||
parameters?: Record<string, any>
|
parameters?: Record<string, any>
|
||||||
enabled?: boolean
|
enabled?: boolean
|
||||||
extra?: Record<string, any>
|
extra?: Record<string, any>
|
||||||
|
credential_id?: string
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -59,6 +59,12 @@ import { useLogs } from '@/app/components/workflow/run/hooks'
|
|||||||
import PanelWrap from '../before-run-form/panel-wrap'
|
import PanelWrap from '../before-run-form/panel-wrap'
|
||||||
import SpecialResultPanel from '@/app/components/workflow/run/special-result-panel'
|
import SpecialResultPanel from '@/app/components/workflow/run/special-result-panel'
|
||||||
import { Stop } from '@/app/components/base/icons/src/vender/line/mediaAndDevices'
|
import { Stop } from '@/app/components/base/icons/src/vender/line/mediaAndDevices'
|
||||||
|
import {
|
||||||
|
AuthorizedInNode,
|
||||||
|
PluginAuth,
|
||||||
|
} from '@/app/components/plugins/plugin-auth'
|
||||||
|
import { AuthCategory } from '@/app/components/plugins/plugin-auth'
|
||||||
|
import { canFindTool } from '@/utils'
|
||||||
|
|
||||||
type BasePanelProps = {
|
type BasePanelProps = {
|
||||||
children: ReactNode
|
children: ReactNode
|
||||||
@@ -221,6 +227,22 @@ const BasePanel: FC<BasePanelProps> = ({
|
|||||||
return {}
|
return {}
|
||||||
})()
|
})()
|
||||||
|
|
||||||
|
const buildInTools = useStore(s => s.buildInTools)
|
||||||
|
const currCollection = useMemo(() => {
|
||||||
|
return buildInTools.find(item => canFindTool(item.id, data.provider_id))
|
||||||
|
}, [buildInTools, data.provider_id])
|
||||||
|
const showPluginAuth = useMemo(() => {
|
||||||
|
return data.type === BlockEnum.Tool && currCollection?.allow_delete
|
||||||
|
}, [currCollection, data.type])
|
||||||
|
const handleAuthorizationItemClick = useCallback((credential_id: string) => {
|
||||||
|
handleNodeDataUpdateWithSyncDraft({
|
||||||
|
id,
|
||||||
|
data: {
|
||||||
|
credential_id,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}, [handleNodeDataUpdateWithSyncDraft, id])
|
||||||
|
|
||||||
if(logParams.showSpecialResultPanel) {
|
if(logParams.showSpecialResultPanel) {
|
||||||
return (
|
return (
|
||||||
<div className={cn(
|
<div className={cn(
|
||||||
@@ -353,12 +375,42 @@ const BasePanel: FC<BasePanelProps> = ({
|
|||||||
onChange={handleDescriptionChange}
|
onChange={handleDescriptionChange}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
<div className='pl-4'>
|
{
|
||||||
<Tab
|
showPluginAuth && (
|
||||||
value={tabType}
|
<PluginAuth
|
||||||
onChange={setTabType}
|
className='px-4 pb-2'
|
||||||
/>
|
pluginPayload={{
|
||||||
</div>
|
provider: currCollection?.name || '',
|
||||||
|
category: AuthCategory.tool,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<div className='flex items-center justify-between pl-4 pr-3'>
|
||||||
|
<Tab
|
||||||
|
value={tabType}
|
||||||
|
onChange={setTabType}
|
||||||
|
/>
|
||||||
|
<AuthorizedInNode
|
||||||
|
pluginPayload={{
|
||||||
|
provider: currCollection?.name || '',
|
||||||
|
category: AuthCategory.tool,
|
||||||
|
}}
|
||||||
|
onAuthorizationItemClick={handleAuthorizationItemClick}
|
||||||
|
credentialId={data.credential_id}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</PluginAuth>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
!showPluginAuth && (
|
||||||
|
<div className='flex items-center justify-between pl-4 pr-3'>
|
||||||
|
<Tab
|
||||||
|
value={tabType}
|
||||||
|
onChange={setTabType}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
<Split />
|
<Split />
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
|||||||
@@ -5,10 +5,8 @@ import Split from '../_base/components/split'
|
|||||||
import type { ToolNodeType } from './types'
|
import type { ToolNodeType } from './types'
|
||||||
import useConfig from './use-config'
|
import useConfig from './use-config'
|
||||||
import ToolForm from './components/tool-form'
|
import ToolForm from './components/tool-form'
|
||||||
import Button from '@/app/components/base/button'
|
|
||||||
import Field from '@/app/components/workflow/nodes/_base/components/field'
|
import Field from '@/app/components/workflow/nodes/_base/components/field'
|
||||||
import type { NodePanelProps } from '@/app/components/workflow/types'
|
import type { NodePanelProps } from '@/app/components/workflow/types'
|
||||||
import ConfigCredential from '@/app/components/tools/setting/build-in/config-credentials'
|
|
||||||
import Loading from '@/app/components/base/loading'
|
import Loading from '@/app/components/base/loading'
|
||||||
import OutputVars, { VarItem } from '@/app/components/workflow/nodes/_base/components/output-vars'
|
import OutputVars, { VarItem } from '@/app/components/workflow/nodes/_base/components/output-vars'
|
||||||
import StructureOutputItem from '@/app/components/workflow/nodes/_base/components/variable/object-child-tree-panel/show'
|
import StructureOutputItem from '@/app/components/workflow/nodes/_base/components/variable/object-child-tree-panel/show'
|
||||||
@@ -32,10 +30,6 @@ const Panel: FC<NodePanelProps<ToolNodeType>> = ({
|
|||||||
setToolSettingValue,
|
setToolSettingValue,
|
||||||
currCollection,
|
currCollection,
|
||||||
isShowAuthBtn,
|
isShowAuthBtn,
|
||||||
showSetAuth,
|
|
||||||
showSetAuthModal,
|
|
||||||
hideSetAuthModal,
|
|
||||||
handleSaveAuth,
|
|
||||||
isLoading,
|
isLoading,
|
||||||
outputSchema,
|
outputSchema,
|
||||||
hasObjectOutput,
|
hasObjectOutput,
|
||||||
@@ -52,19 +46,6 @@ const Panel: FC<NodePanelProps<ToolNodeType>> = ({
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<div className='pt-2'>
|
<div className='pt-2'>
|
||||||
{!readOnly && isShowAuthBtn && (
|
|
||||||
<>
|
|
||||||
<div className='px-4'>
|
|
||||||
<Button
|
|
||||||
variant='primary'
|
|
||||||
className='w-full'
|
|
||||||
onClick={showSetAuthModal}
|
|
||||||
>
|
|
||||||
{t(`${i18nPrefix}.authorize`)}
|
|
||||||
</Button>
|
|
||||||
</div>
|
|
||||||
</>
|
|
||||||
)}
|
|
||||||
{!isShowAuthBtn && (
|
{!isShowAuthBtn && (
|
||||||
<div className='relative'>
|
<div className='relative'>
|
||||||
{toolInputVarSchema.length > 0 && (
|
{toolInputVarSchema.length > 0 && (
|
||||||
@@ -109,15 +90,6 @@ const Panel: FC<NodePanelProps<ToolNodeType>> = ({
|
|||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
{showSetAuth && (
|
|
||||||
<ConfigCredential
|
|
||||||
collection={currCollection!}
|
|
||||||
onCancel={hideSetAuthModal}
|
|
||||||
onSaved={handleSaveAuth}
|
|
||||||
isHideRemoveBtn
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
|
|
||||||
<div>
|
<div>
|
||||||
<OutputVars>
|
<OutputVars>
|
||||||
<>
|
<>
|
||||||
|
|||||||
@@ -93,6 +93,7 @@ export type CommonNodeType<T = {}> = {
|
|||||||
error_strategy?: ErrorHandleTypeEnum
|
error_strategy?: ErrorHandleTypeEnum
|
||||||
retry_config?: WorkflowRetryConfig
|
retry_config?: WorkflowRetryConfig
|
||||||
default_value?: DefaultValueForm[]
|
default_value?: DefaultValueForm[]
|
||||||
|
credential_id?: string
|
||||||
} & T & Partial<Pick<ToolDefaultValue, 'provider_id' | 'provider_type' | 'provider_name' | 'tool_name'>>
|
} & T & Partial<Pick<ToolDefaultValue, 'provider_id' | 'provider_type' | 'provider_name' | 'tool_name'>>
|
||||||
|
|
||||||
export type CommonEdgeType = {
|
export type CommonEdgeType = {
|
||||||
|
|||||||
@@ -214,6 +214,29 @@ const translation = {
|
|||||||
requestAPlugin: 'Request a plugin',
|
requestAPlugin: 'Request a plugin',
|
||||||
publishPlugins: 'Publish plugins',
|
publishPlugins: 'Publish plugins',
|
||||||
difyVersionNotCompatible: 'The current Dify version is not compatible with this plugin, please upgrade to the minimum version required: {{minimalDifyVersion}}',
|
difyVersionNotCompatible: 'The current Dify version is not compatible with this plugin, please upgrade to the minimum version required: {{minimalDifyVersion}}',
|
||||||
|
auth: {
|
||||||
|
default: 'Default',
|
||||||
|
custom: 'Custom',
|
||||||
|
setDefault: 'Set as default',
|
||||||
|
useOAuth: 'Use OAuth',
|
||||||
|
useOAuthAuth: 'Use OAuth Authorization',
|
||||||
|
addOAuth: 'Add OAuth',
|
||||||
|
setupOAuth: 'Setup OAuth Client',
|
||||||
|
useApi: 'Use API Key',
|
||||||
|
addApi: 'Add API Key',
|
||||||
|
useApiAuth: 'API Key Authorization Configuration',
|
||||||
|
useApiAuthDesc: 'After configuring credentials, all members within the workspace can use this tool when orchestrating applications.',
|
||||||
|
oauthClientSettings: 'OAuth Client Settings',
|
||||||
|
saveOnly: 'Save only',
|
||||||
|
saveAndAuth: 'Save and Authorize',
|
||||||
|
authorization: 'Authorization',
|
||||||
|
authorizations: 'Authorizations',
|
||||||
|
authorizationName: 'Authorization Name',
|
||||||
|
workspaceDefault: 'Workspace Default',
|
||||||
|
authRemoved: 'Auth removed',
|
||||||
|
clientInfo: 'As no system client secrets found for this tool provider, setup it manually is required, for redirect_uri, please use',
|
||||||
|
oauthClient: 'OAuth Client',
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
export default translation
|
export default translation
|
||||||
|
|||||||
@@ -214,6 +214,29 @@ const translation = {
|
|||||||
requestAPlugin: '申请插件',
|
requestAPlugin: '申请插件',
|
||||||
publishPlugins: '发布插件',
|
publishPlugins: '发布插件',
|
||||||
difyVersionNotCompatible: '当前 Dify 版本不兼容该插件,其最低版本要求为 {{minimalDifyVersion}}',
|
difyVersionNotCompatible: '当前 Dify 版本不兼容该插件,其最低版本要求为 {{minimalDifyVersion}}',
|
||||||
|
auth: {
|
||||||
|
default: '默认',
|
||||||
|
custom: '自定义',
|
||||||
|
setDefault: '设为默认',
|
||||||
|
useOAuth: '使用 OAuth',
|
||||||
|
useOAuthAuth: '使用 OAuth 授权',
|
||||||
|
addOAuth: '添加 OAuth',
|
||||||
|
setupOAuth: '设置 OAuth 客户端',
|
||||||
|
useApi: '使用 API Key',
|
||||||
|
addApi: '添加 API Key',
|
||||||
|
useApiAuth: 'API Key 授权配置',
|
||||||
|
useApiAuthDesc: '配置凭据后,工作区内的所有成员在编排应用时都可以使用此工具。',
|
||||||
|
oauthClientSettings: 'OAuth 客户端设置',
|
||||||
|
saveOnly: '仅保存',
|
||||||
|
saveAndAuth: '保存并授权',
|
||||||
|
authorization: '凭据',
|
||||||
|
authorizations: '凭据',
|
||||||
|
authorizationName: '凭据名称',
|
||||||
|
workspaceDefault: '工作区默认',
|
||||||
|
authRemoved: '凭据已移除',
|
||||||
|
clientInfo: '由于未找到此工具提供者的系统客户端密钥,因此需要手动设置,对于 redirect_uri,请使用',
|
||||||
|
oauthClient: 'OAuth 客户端',
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
export default translation
|
export default translation
|
||||||
|
|||||||
161
web/service/use-plugins-auth.ts
Normal file
161
web/service/use-plugins-auth.ts
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
import {
|
||||||
|
useMutation,
|
||||||
|
useQuery,
|
||||||
|
} from '@tanstack/react-query'
|
||||||
|
import { del, get, post } from './base'
|
||||||
|
import { useInvalid } from './use-base'
|
||||||
|
import type {
|
||||||
|
Credential,
|
||||||
|
CredentialTypeEnum,
|
||||||
|
} from '@/app/components/plugins/plugin-auth/types'
|
||||||
|
import type { FormSchema } from '@/app/components/base/form/types'
|
||||||
|
|
||||||
|
const NAME_SPACE = 'plugins-auth'
|
||||||
|
|
||||||
|
export const useGetPluginCredentialInfo = (
|
||||||
|
url: string,
|
||||||
|
) => {
|
||||||
|
return useQuery({
|
||||||
|
enabled: !!url,
|
||||||
|
queryKey: [NAME_SPACE, 'credential-info', url],
|
||||||
|
queryFn: () => get<{
|
||||||
|
supported_credential_types: string[]
|
||||||
|
credentials: Credential[]
|
||||||
|
is_oauth_custom_client_enabled: boolean
|
||||||
|
}>(url),
|
||||||
|
staleTime: 0,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
export const useInvalidPluginCredentialInfo = (
|
||||||
|
url: string,
|
||||||
|
) => {
|
||||||
|
return useInvalid([NAME_SPACE, 'credential-info', url])
|
||||||
|
}
|
||||||
|
|
||||||
|
export const useSetPluginDefaultCredential = (
|
||||||
|
url: string,
|
||||||
|
) => {
|
||||||
|
return useMutation({
|
||||||
|
mutationFn: (id: string) => {
|
||||||
|
return post(url, { body: { id } })
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
export const useGetPluginCredentialList = (
|
||||||
|
url: string,
|
||||||
|
) => {
|
||||||
|
return useQuery({
|
||||||
|
queryKey: [NAME_SPACE, 'credential-list', url],
|
||||||
|
queryFn: () => get(url),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
export const useAddPluginCredential = (
|
||||||
|
url: string,
|
||||||
|
) => {
|
||||||
|
return useMutation({
|
||||||
|
mutationFn: (params: {
|
||||||
|
credentials: Record<string, any>
|
||||||
|
type: CredentialTypeEnum
|
||||||
|
name?: string
|
||||||
|
}) => {
|
||||||
|
return post(url, { body: params })
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
export const useUpdatePluginCredential = (
|
||||||
|
url: string,
|
||||||
|
) => {
|
||||||
|
return useMutation({
|
||||||
|
mutationFn: (params: {
|
||||||
|
credential_id: string
|
||||||
|
credentials?: Record<string, any>
|
||||||
|
name?: string
|
||||||
|
}) => {
|
||||||
|
return post(url, { body: params })
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
export const useDeletePluginCredential = (
|
||||||
|
url: string,
|
||||||
|
) => {
|
||||||
|
return useMutation({
|
||||||
|
mutationFn: (params: { credential_id: string }) => {
|
||||||
|
return post(url, { body: params })
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
export const useGetPluginCredentialSchema = (
|
||||||
|
url: string,
|
||||||
|
) => {
|
||||||
|
return useQuery({
|
||||||
|
queryKey: [NAME_SPACE, 'credential-schema', url],
|
||||||
|
queryFn: () => get<FormSchema[]>(url),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
export const useGetPluginOAuthUrl = (
|
||||||
|
url: string,
|
||||||
|
) => {
|
||||||
|
return useMutation({
|
||||||
|
mutationKey: [NAME_SPACE, 'oauth-url', url],
|
||||||
|
mutationFn: () => {
|
||||||
|
return get<
|
||||||
|
{
|
||||||
|
authorization_url: string
|
||||||
|
state: string
|
||||||
|
context_id: string
|
||||||
|
}>(url)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
export const useGetPluginOAuthClientSchema = (
|
||||||
|
url: string,
|
||||||
|
) => {
|
||||||
|
return useQuery({
|
||||||
|
queryKey: [NAME_SPACE, 'oauth-client-schema', url],
|
||||||
|
queryFn: () => get<{
|
||||||
|
schema: FormSchema[]
|
||||||
|
is_oauth_custom_client_enabled: boolean
|
||||||
|
is_system_oauth_params_exists?: boolean
|
||||||
|
client_params?: Record<string, any>
|
||||||
|
redirect_uri?: string
|
||||||
|
}>(url),
|
||||||
|
staleTime: 0,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
export const useInvalidPluginOAuthClientSchema = (
|
||||||
|
url: string,
|
||||||
|
) => {
|
||||||
|
return useInvalid([NAME_SPACE, 'oauth-client-schema', url])
|
||||||
|
}
|
||||||
|
|
||||||
|
export const useSetPluginOAuthCustomClient = (
|
||||||
|
url: string,
|
||||||
|
) => {
|
||||||
|
return useMutation({
|
||||||
|
mutationFn: (params: {
|
||||||
|
client_params: Record<string, any>
|
||||||
|
enable_oauth_custom_client: boolean
|
||||||
|
}) => {
|
||||||
|
return post<{ result: string }>(url, { body: params })
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
export const useDeletePluginOAuthCustomClient = (
|
||||||
|
url: string,
|
||||||
|
) => {
|
||||||
|
return useMutation({
|
||||||
|
mutationFn: () => {
|
||||||
|
return del<{ result: string }>(url)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -16,10 +16,11 @@ import {
|
|||||||
const NAME_SPACE = 'tools'
|
const NAME_SPACE = 'tools'
|
||||||
|
|
||||||
const useAllToolProvidersKey = [NAME_SPACE, 'allToolProviders']
|
const useAllToolProvidersKey = [NAME_SPACE, 'allToolProviders']
|
||||||
export const useAllToolProviders = () => {
|
export const useAllToolProviders = (enabled = true) => {
|
||||||
return useQuery<Collection[]>({
|
return useQuery<Collection[]>({
|
||||||
queryKey: useAllToolProvidersKey,
|
queryKey: useAllToolProvidersKey,
|
||||||
queryFn: () => get<Collection[]>('/workspaces/current/tool-providers'),
|
queryFn: () => get<Collection[]>('/workspaces/current/tool-providers'),
|
||||||
|
enabled,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -130,6 +130,7 @@ export type AgentTool = {
|
|||||||
enabled: boolean
|
enabled: boolean
|
||||||
isDeleted?: boolean
|
isDeleted?: boolean
|
||||||
notAuthor?: boolean
|
notAuthor?: boolean
|
||||||
|
credential_id?: string
|
||||||
}
|
}
|
||||||
|
|
||||||
export type ToolItem = {
|
export type ToolItem = {
|
||||||
|
|||||||
Reference in New Issue
Block a user