Compare commits

..

59 Commits

Author SHA1 Message Date
-LAN-
c7700ac176 chore(docker): bump version (#25092)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-03 20:25:44 +08:00
Stream
d011ddfc64 chore(version): bump version to 1.8.1 (#25060) 2025-09-03 18:54:07 +08:00
zxhlyh
67cc70ad61 fix: model credential name (#25081)
Co-authored-by: hjlarry <hjlarry@163.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-03 18:23:57 +08:00
-LAN-
a384ae9140 Fix advanced chat workflow event handler signature mismatch (#25078)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-03 16:22:13 +08:00
17hz
a7627882a7 fix: Boolean type control is not displayed (#25031)
Co-authored-by: WTW0313 <twwu@dify.ai>
2025-09-03 15:39:09 +08:00
NeatGuyCoding
8eae7a95be Hotfix translation error (#25035) 2025-09-03 15:23:04 +08:00
dswl23
dabf266048 Fix: handle 204 No Content response in MCP client (#25040) 2025-09-03 15:22:42 +08:00
Asuka Minato
462e764a3c typevar example (#25064)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-03 14:54:38 +08:00
github-actions[bot]
0e8a37dca8 chore: translate i18n files (#25061)
Co-authored-by: zxhlyh <16177003+zxhlyh@users.noreply.github.com>
2025-09-03 14:48:53 +08:00
zyileven
bffbe54120 fix: Solve the problem of opening remarks appearing in the chat cont… (#25067) 2025-09-03 14:48:30 +08:00
非法操作
b673560b92 feat: improve multi model credentials (#25009)
Co-authored-by: Claude <noreply@anthropic.com>
2025-09-03 13:52:31 +08:00
zxhlyh
9e125e2029 Refactor/model credential (#24994) 2025-09-03 13:36:59 +08:00
-LAN-
b88146c443 chore: consolidate type checking in style workflow (#25053) 2025-09-03 13:34:43 +08:00
-LAN-
c40cb7fd59 [Chore/Refactor] Update .gitignore to exclude pyrightconfig.json while preserving api/pyrightconfig.json (#25055) 2025-09-03 13:34:07 +08:00
-LAN-
9d5956cef8 [Chore/Refactor] Switch from MyPy to Basedpyright for type checking (#25047)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-03 11:52:26 +08:00
湛露先生
1fff4620e6 clean console apis and rag cleans. (#25042)
Signed-off-by: zhanluxianshen <zhanluxianshen@163.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-03 11:25:18 +08:00
-LAN-
c3820f55f4 chore: translate Chinese comments to English in ClickZetta Volume storage module (#25037) 2025-09-03 10:57:58 +08:00
17hz
60c5bdd62f fix: remove redundant z-index from Field component (#25034) 2025-09-03 10:39:07 +08:00
Will
5092e5f631 fix: workflow not published (#25030) 2025-09-03 10:07:31 +08:00
NeatGuyCoding
c0bd35594e feat: add test containers based tests for tools manage service (#25028) 2025-09-03 09:20:16 +08:00
Yongtao Huang
bc9efa7ea8 Refactor: use DatasourceType.XX.value instead of hardcoded (#25015)
Signed-off-by: Yongtao Huang <yongtaoh2022@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-03 08:56:48 +08:00
-LAN-
f540d0b747 chore: remove ty type checker from reformat script and pre-commit hooks (#25021) 2025-09-03 08:56:23 +08:00
-LAN-
7bcaa513fa chore: remove duplicate test helper classes from api root directory (#25024) 2025-09-03 08:56:00 +08:00
Will
d33dfee8a3 fix: EndUser is not bound to a Session (#25010) 2025-09-02 21:37:21 +08:00
Will
b5216df4fe fix: xxx is not bound to a Session (#24966) 2025-09-02 21:37:06 +08:00
GuanMu
25a11bfafc Export DSL from history (#24939)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-02 21:36:52 +08:00
Yongtao Huang
8fcc864fb7 Post fix of #23224 (#25007) 2025-09-02 20:59:08 +08:00
NeatGuyCoding
ed5ed0306e minor fix: fix the check of subscription capacity limit (#24991)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-02 19:14:30 +08:00
Asuka Minato
a418c43d32 example add more type check (#24999)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-02 19:13:43 +08:00
17hz
5aa8c9c8df fix: refresh UI after user profile change (#24998) 2025-09-02 18:57:35 +08:00
17hz
32972b45db fix: remove unnecessary modal visibility toggle on error in name save (#25001) 2025-09-02 18:57:24 +08:00
17hz
af351b1723 fix: ensure the modal closed by level (#24984) 2025-09-02 17:06:10 +08:00
Bowen Liang
af88266212 chore: run ty check CI action only when api code changed (#24986) 2025-09-02 16:59:11 +08:00
-LAN-
b14119b531 feat: add development environment setup commands to Makefile (#24976) 2025-09-02 16:24:21 +08:00
Novice
68c75f221b fix: workflow log status filter add parial success status (#24977) 2025-09-02 16:24:03 +08:00
Bowen Liang
7b379e2a61 chore: apply ty checks on api code with script and ci action (#24653) 2025-09-02 16:05:13 +08:00
17hz
c373b734bc feat: make secretInput type field prevent browser auto-fill (#24971) 2025-09-02 16:04:12 +08:00
17hz
2ac8f8003f refactor: update radio component to handle boolean values instead of numeric (#24956) 2025-09-02 15:11:42 +08:00
17hz
d6b3df8f6f fix: API Key Authorization Configuration Model Form render default value (#24963) 2025-09-02 14:52:05 +08:00
湛露先生
deea07e905 make clean() function in index_processor_base abstractmethod (#24959)
Signed-off-by: zhanluxianshen <zhanluxianshen@163.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-02 14:48:45 +08:00
lyzno1
0caa94bd1c fix: add Indonesian (id-ID) language support and improve language selector (#24951) 2025-09-02 14:44:59 +08:00
-LAN-
a32dde5428 Fix: Resolve workflow_node_execution primary key conflicts with UUID v7 (#24643)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-02 14:18:29 +08:00
Yongtao Huang
067b0d07c4 Fix: ensure InstalledApp deletion uses model instances instead of Row (#24942)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-02 11:59:38 +08:00
17hz
044f96bd93 feat: LLM prompt Jinja2 template now support more variables (#24944) 2025-09-02 11:59:31 +08:00
Novice
ca96350707 chore: optimize SQL queries that perform partial full table scans (#24786) 2025-09-02 11:46:11 +08:00
Yongtao Huang
be3af1e234 Migrate SQLAlchemy from 1.x to 2.0 with automated and manual adjustments (#23224)
Co-authored-by: Yongtao Huang <99629139+hyongtao-db@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-02 10:30:19 +08:00
github-actions[bot]
2e89d29c87 chore: translate i18n files (#24934)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-09-02 10:16:14 +08:00
Zhedong Cen
e4eb9f7c55 fix(i18n): align zh-Hant indexMethodEconomyTip with zh-Hans (#24933) 2025-09-02 09:57:39 +08:00
znn
dd6547de06 downvote with reason (#24922) 2025-09-02 09:57:04 +08:00
Atif
84d09b8b8a fix: API key input uses password type and no autocomplete (#24864)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-09-02 09:37:24 +08:00
17hz
2c462154f7 fix: email input cannot scroll (#24930) 2025-09-02 09:35:53 +08:00
NeatGuyCoding
b810efdb3f Feature add test containers tool transform service (#24927) 2025-09-02 09:30:55 +08:00
17hz
ae04ccc445 fix: npx typo error (#24929) 2025-09-02 09:20:51 +08:00
Charles Liu
f7ac1192ae replace the secret field from obfuscated to full-masked value (#24800)
Co-authored-by: charles liu <dearcharles.liu@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-02 09:19:20 +08:00
jiangbo721
e048588a88 fix: remove duplicated code (#24893) 2025-09-02 08:58:31 +08:00
Frederick2313072
2042353526 fix:score threshold (#24897) 2025-09-02 08:58:14 +08:00
wlleiiwang
9486715929 FEAT: Tencent Vector optimize BM25 initialization to reduce loading time (#24915)
Co-authored-by: wlleiiwang <wlleiiwang@tencent.com>
2025-09-01 21:08:41 +08:00
湛露先生
64319c0d56 fix close session twice. (#24917)
Signed-off-by: zhanluxianshen <zhanluxianshen@163.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-01 21:08:01 +08:00
耐小心
acd209a890 fix: prevent database connection leaks in chatflow mode by using Session-managed queries (#24656)
Co-authored-by: 王锶奇 <wangsiqi2@tal.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-01 18:22:42 +08:00
652 changed files with 13819 additions and 36624 deletions

View File

@@ -1,6 +1,5 @@
#!/bin/bash
npm add -g pnpm@10.15.0
corepack enable
cd web && pnpm install
pipx install uv

View File

@@ -42,11 +42,7 @@ jobs:
- name: Run Unit tests
run: |
uv run --project api bash dev/pytest/pytest_unit_tests.sh
- name: Run ty check
run: |
cd api
uv add --dev ty
uv run ty check || true
- name: Run pyrefly check
run: |
cd api
@@ -66,15 +62,6 @@ jobs:
- name: Run dify config tests
run: uv run --project api dev/pytest/pytest_config_tests.py
- name: MyPy Cache
uses: actions/cache@v4
with:
path: api/.mypy_cache
key: mypy-${{ matrix.python-version }}-${{ runner.os }}-${{ hashFiles('api/uv.lock') }}
- name: Run MyPy Checks
run: dev/mypy-check
- name: Set up dotenvs
run: |
cp docker/.env.example docker/.env

View File

@@ -2,8 +2,6 @@ name: autofix.ci
on:
pull_request:
branches: ["main"]
push:
branches: ["main"]
permissions:
contents: read

View File

@@ -44,6 +44,10 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
run: uv sync --project api --dev
- name: Run Basedpyright Checks
if: steps.changed-files.outputs.any_changed == 'true'
run: dev/basedpyright-check
- name: Dotenv check
if: steps.changed-files.outputs.any_changed == 'true'
run: uv run --project api dotenv-linter ./api/.env.example ./web/.env.example

.gitignore
View File

@@ -123,10 +123,12 @@ venv.bak/
# mkdocs documentation
/site
# mypy
# type checking
.mypy_cache/
.dmypy.json
dmypy.json
pyrightconfig.json
!api/pyrightconfig.json
# Pyre type checker
.pyre/
@@ -195,7 +197,6 @@ sdks/python-client/dify_client.egg-info
.vscode/*
!.vscode/launch.json.template
!.vscode/README.md
pyrightconfig.json
api/.vscode
# vscode Code History Extension
.history
@@ -218,6 +219,3 @@ mise.toml
.roo/
api/.env.backup
/clickzetta
# mcp
.serena

View File

@@ -32,7 +32,7 @@ uv run --project api pytest tests/integration_tests/ # Integration tests
./dev/reformat # Run all formatters and linters
uv run --project api ruff check --fix ./ # Fix linting issues
uv run --project api ruff format ./ # Format code
uv run --project api mypy . # Type checking
uv run --directory api basedpyright # Type checking
```
### Frontend (Web)
@@ -59,7 +59,6 @@ pnpm test # Run Jest tests
- Use type hints for all functions and class attributes
- No `Any` types unless absolutely necessary
- Implement special methods (`__repr__`, `__str__`) appropriately
- **Logging**: Never use `str(e)` in `logger.exception()` calls. Use `logger.exception("message", exc_info=e)` instead
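A minimal sketch of the logging rule above (an added illustration, not part of the diffed file), assuming a standard `logging` logger; `parse_config` is a hypothetical helper:

```python
import json
import logging

logger = logging.getLogger(__name__)

def parse_config(raw: str) -> dict:
    # Illustrative helper showing the preferred logger.exception pattern.
    try:
        return json.loads(raw)
    except ValueError as e:
        # Preferred: logger.exception records the traceback automatically;
        # the exception can also be passed explicitly via exc_info.
        logger.exception("Failed to parse config", exc_info=e)
        # Discouraged per the guideline above: logger.exception(str(e))
        raise
```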
### TypeScript/JavaScript

View File

@@ -4,6 +4,48 @@ WEB_IMAGE=$(DOCKER_REGISTRY)/dify-web
API_IMAGE=$(DOCKER_REGISTRY)/dify-api
VERSION=latest
# Backend Development Environment Setup
.PHONY: dev-setup prepare-docker prepare-web prepare-api
# Default dev setup target
dev-setup: prepare-docker prepare-web prepare-api
@echo "✅ Backend development environment setup complete!"
# Step 1: Prepare Docker middleware
prepare-docker:
@echo "🐳 Setting up Docker middleware..."
@cp -n docker/middleware.env.example docker/middleware.env 2>/dev/null || echo "Docker middleware.env already exists"
@cd docker && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p dify-middlewares-dev up -d
@echo "✅ Docker middleware started"
# Step 2: Prepare web environment
prepare-web:
@echo "🌐 Setting up web environment..."
@cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists"
@cd web && pnpm install
@cd web && pnpm build
@echo "✅ Web environment prepared (not started)"
# Step 3: Prepare API environment
prepare-api:
@echo "🔧 Setting up API environment..."
@cp -n api/.env.example api/.env 2>/dev/null || echo "API .env already exists"
@cd api && uv sync --dev
@cd api && uv run flask db upgrade
@echo "✅ API environment prepared (not started)"
# Clean dev environment
dev-clean:
@echo "⚠️ Stopping Docker containers..."
@cd docker && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p dify-middlewares-dev down
@echo "🗑️ Removing volumes..."
@rm -rf docker/volumes/db
@rm -rf docker/volumes/redis
@rm -rf docker/volumes/plugin_daemon
@rm -rf docker/volumes/weaviate
@rm -rf api/storage
@echo "✅ Cleanup complete"
# Build Docker images
build-web:
@echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..."
@@ -39,5 +81,21 @@ build-push-web: build-web push-web
build-push-all: build-all push-all
@echo "All Docker images have been built and pushed."
# Help target
help:
@echo "Development Setup Targets:"
@echo " make dev-setup - Run all setup steps for backend dev environment"
@echo " make prepare-docker - Set up Docker middleware"
@echo " make prepare-web - Set up web environment"
@echo " make prepare-api - Set up API environment"
@echo " make dev-clean - Stop Docker middleware containers"
@echo ""
@echo "Docker Build Targets:"
@echo " make build-web - Build web Docker image"
@echo " make build-api - Build API Docker image"
@echo " make build-all - Build all Docker images"
@echo " make push-all - Push all Docker images"
@echo " make build-push-all - Build and push all Docker images"
# Phony targets
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help

View File

@@ -434,9 +434,6 @@ HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
HTTP_REQUEST_NODE_SSL_VERIFY=True
# Webhook request configuration
WEBHOOK_REQUEST_BODY_MAX_SIZE=10485760
# Respect X-* headers to redirect clients
RESPECT_XFORWARD_HEADERS_ENABLED=false
@@ -505,12 +502,6 @@ ENABLE_CLEAN_MESSAGES=false
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
ENABLE_DATASETS_QUEUE_MONITOR=false
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true
ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK=true
# Interval time in minutes for polling scheduled workflows(default: 1 min)
WORKFLOW_SCHEDULE_POLLER_INTERVAL=1
WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE=100
# Maximum number of scheduled workflows to dispatch per tick (0 for unlimited)
WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK=0
# Position configuration
POSITION_TOOL_PINS=

View File

@@ -54,7 +54,7 @@
"--loglevel",
"DEBUG",
"-Q",
"dataset,generation,mail,ops_trace,app_deletion,workflow"
"dataset,generation,mail,ops_trace,app_deletion"
]
}
]

View File

@@ -108,5 +108,5 @@ uv run celery -A app.celery beat
../dev/reformat # Run all formatters and linters
uv run ruff check --fix ./ # Fix linting issues
uv run ruff format ./ # Format code
uv run mypy . # Type checking
uv run basedpyright . # Type checking
```

View File

@@ -1,11 +0,0 @@
from tests.integration_tests.utils.parent_class import ParentClass
class ChildClass(ParentClass):
"""Test child class for module import helper tests"""
def __init__(self, name):
super().__init__(name)
def get_name(self):
return f"Child: {self.name}"

View File

@@ -571,7 +571,7 @@ def old_metadata_migration():
for document in documents:
if document.doc_metadata:
doc_metadata = document.doc_metadata
for key, value in doc_metadata.items():
for key in doc_metadata:
for field in BuiltInField:
if field.value == key:
break
@@ -1207,55 +1207,6 @@ def setup_system_tool_oauth_client(provider, client_params):
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
@click.command("setup-system-trigger-oauth-client", help="Setup system trigger oauth client.")
@click.option("--provider", prompt=True, help="Provider name")
@click.option("--client-params", prompt=True, help="Client Params")
def setup_system_trigger_oauth_client(provider, client_params):
"""
Setup system trigger oauth client
"""
from core.plugin.entities.plugin import TriggerProviderID
from models.trigger import TriggerOAuthSystemClient
provider_id = TriggerProviderID(provider)
provider_name = provider_id.provider_name
plugin_id = provider_id.plugin_id
try:
# json validate
click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
click.echo(click.style("Client params validated successfully.", fg="green"))
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
click.echo(click.style("Client params encrypted successfully.", fg="green"))
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
return
deleted_count = (
db.session.query(TriggerOAuthSystemClient)
.filter_by(
provider=provider_name,
plugin_id=plugin_id,
)
.delete()
)
if deleted_count > 0:
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
oauth_client = TriggerOAuthSystemClient(
provider=provider_name,
plugin_id=plugin_id,
encrypted_oauth_params=oauth_client_params,
)
db.session.add(oauth_client)
db.session.commit()
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
"""
Find draft variables that reference non-existent apps.

View File

@@ -147,17 +147,6 @@ class CodeExecutionSandboxConfig(BaseSettings):
)
class TriggerConfig(BaseSettings):
"""
Configuration for trigger
"""
WEBHOOK_REQUEST_BODY_MAX_SIZE: PositiveInt = Field(
description="Maximum allowed size for webhook request bodies in bytes",
default=10485760,
)
class PluginConfig(BaseSettings):
"""
Plugin configs
@@ -882,22 +871,6 @@ class CeleryScheduleTasksConfig(BaseSettings):
description="Enable check upgradable plugin task",
default=True,
)
ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK: bool = Field(
description="Enable workflow schedule poller task",
default=True,
)
WORKFLOW_SCHEDULE_POLLER_INTERVAL: int = Field(
description="Workflow schedule poller interval in minutes",
default=1,
)
WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE: int = Field(
description="Maximum number of schedules to process in each poll batch",
default=100,
)
WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK: int = Field(
description="Maximum schedules to dispatch per tick (0=unlimited, circuit breaker)",
default=0,
)
class PositionConfig(BaseSettings):
@@ -1021,7 +994,6 @@ class FeatureConfig(
AuthConfig, # Changed from OAuthConfig to AuthConfig
BillingConfig,
CodeExecutionSandboxConfig,
TriggerConfig,
PluginConfig,
MarketplaceConfig,
DataSetConfig,
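The removed `TriggerConfig` fields above pair with the `.env.example` entries deleted earlier: each `BaseSettings` field is populated from an environment variable of the same name. A self-contained sketch of that mechanism, assuming standard pydantic-settings behaviour (the class name mirrors the removed config; nothing beyond that is implied):

```python
import os

from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings

class TriggerConfig(BaseSettings):
    # Mirrors the removed field; the value comes from the environment when set.
    WEBHOOK_REQUEST_BODY_MAX_SIZE: PositiveInt = Field(
        description="Maximum allowed size for webhook request bodies in bytes",
        default=10485760,
    )

os.environ["WEBHOOK_REQUEST_BODY_MAX_SIZE"] = "5242880"
print(TriggerConfig().WEBHOOK_REQUEST_BODY_MAX_SIZE)  # -> 5242880
```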

View File

@@ -29,7 +29,7 @@ class NacosSettingsSource(RemoteSettingsSource):
try:
content = NacosHttpClient().http_request("/nacos/v1/cs/configs", method="GET", headers={}, params=params)
self.remote_configs = self._parse_config(content)
except Exception as e:
except Exception:
logger.exception("[get-access-token] exception occurred")
raise

View File

@@ -27,7 +27,7 @@ class NacosHttpClient:
response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params)
response.raise_for_status()
return response.text
except requests.exceptions.RequestException as e:
except requests.RequestException as e:
return f"Request to Nacos failed: {e}"
def _inject_auth_info(self, headers, params, module="config"):
@@ -77,6 +77,6 @@ class NacosHttpClient:
self.token = response_data.get("accessToken")
self.token_ttl = response_data.get("tokenTtl", 18000)
self.token_expire_time = current_time + self.token_ttl - 10
except Exception as e:
except Exception:
logger.exception("[get-access-token] exception occur")
raise

View File

@@ -19,6 +19,7 @@ language_timezone_mapping = {
"fa-IR": "Asia/Tehran",
"sl-SI": "Europe/Ljubljana",
"th-TH": "Asia/Bangkok",
"id-ID": "Asia/Jakarta",
}
languages = list(language_timezone_mapping.keys())

View File

@@ -8,7 +8,6 @@ if TYPE_CHECKING:
from core.model_runtime.entities.model_entities import AIModelEntity
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.trigger.provider import PluginTriggerProviderController
from core.workflow.entities.variable_pool import VariablePool
@@ -34,11 +33,3 @@ plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(Cont
plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar(
ContextVar("plugin_model_schemas")
)
plugin_trigger_providers: RecyclableContextVar[dict[str, "PluginTriggerProviderController"]] = RecyclableContextVar(
ContextVar("plugin_trigger_providers")
)
plugin_trigger_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
ContextVar("plugin_trigger_providers_lock")
)

View File

@@ -67,7 +67,6 @@ from .app import (
workflow_draft_variable,
workflow_run,
workflow_statistic,
workflow_trigger,
)
# Import auth controllers
@@ -181,6 +180,5 @@ from .workspace import (
models,
plugin,
tool_providers,
trigger_providers,
workspace,
)

View File

@@ -130,15 +130,19 @@ class InsertExploreAppApi(Resource):
app.is_public = False
with Session(db.engine) as session:
installed_apps = session.execute(
select(InstalledApp).where(
InstalledApp.app_id == recommended_app.app_id,
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
installed_apps = (
session.execute(
select(InstalledApp).where(
InstalledApp.app_id == recommended_app.app_id,
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
)
)
).all()
.scalars()
.all()
)
for installed_app in installed_apps:
db.session.delete(installed_app)
for installed_app in installed_apps:
session.delete(installed_app)
db.session.delete(recommended_app)
db.session.commit()
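This hunk follows the SQLAlchemy 1.x-to-2.0 migration in #23224: rows are fetched via `select()` and unwrapped with `.scalars()` so the same `Session` that loaded them also deletes them. A self-contained sketch of the pattern, using a stand-in model (the real `InstalledApp` has more columns):

```python
from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class InstalledApp(Base):  # minimal stand-in for the model used above
    __tablename__ = "installed_apps"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    app_id: Mapped[str] = mapped_column(String)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # 2.0 style: execute(select(...)) yields Row objects; .scalars() unwraps
    # them into ORM instances, which session.delete() accepts.
    installed_apps = session.execute(
        select(InstalledApp).where(InstalledApp.app_id == "some-app-id")
    ).scalars().all()
    for installed_app in installed_apps:
        session.delete(installed_app)
    session.commit()
```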

View File

@@ -84,7 +84,7 @@ class BaseApiKeyListResource(Resource):
flask_restx.abort(
400,
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
code="max_keys_exceeded",
custom="max_keys_exceeded",
)
key = ApiToken.generate_api_key(self.token_prefix, 24)

View File

@@ -237,9 +237,14 @@ class AppExportApi(Resource):
# Add include_secret params
parser = reqparse.RequestParser()
parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
parser.add_argument("workflow_id", type=str, location="args")
args = parser.parse_args()
return {"data": AppDslService.export_dsl(app_model=app_model, include_secret=args["include_secret"])}
return {
"data": AppDslService.export_dsl(
app_model=app_model, include_secret=args["include_secret"], workflow_id=args.get("workflow_id")
)
}
class AppNameApi(Resource):

View File

@@ -12,7 +12,6 @@ from controllers.console.app.error import (
)
from controllers.console.wraps import account_initialization_required, setup_required
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.llm_generator.llm_generator import LLMGenerator
@@ -126,11 +125,13 @@ class InstructionGenerateApi(Resource):
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
parser.add_argument("ideal_output", type=str, required=False, default="", location="json")
args = parser.parse_args()
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
code_provider: type[CodeNodeProvider] | None = next(
(p for p in providers if p.is_accept_language(args["language"])), None
code_template = (
Python3CodeProvider.get_default_code()
if args["language"] == "python"
else (JavascriptCodeProvider.get_default_code())
if args["language"] == "javascript"
else ""
)
code_template = code_provider.get_default_code() if code_provider else ""
try:
# Generate from nothing for a workflow node
if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":

View File

@@ -24,7 +24,6 @@ from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from factories import file_factory, variable_factory
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
@@ -39,7 +38,6 @@ from models.workflow import Workflow
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError
from services.errors.llm import InvokeRateLimitError
from services.trigger_debug_service import TriggerDebugService
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
logger = logging.getLogger(__name__)
@@ -528,7 +526,7 @@ class PublishedWorkflowApi(Resource):
)
app_model.workflow_id = workflow.id
db.session.commit()
db.session.commit() # NOTE: this is necessary for update app_model.workflow_id
workflow_created_at = TimestampField().format(workflow.created_at)
@@ -808,132 +806,6 @@ class DraftWorkflowNodeLastRunApi(Resource):
return node_exec
class DraftWorkflowTriggerNodeApi(Resource):
"""
Single node debug - Polling API for trigger events
Path: /apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/trigger
"""
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
def post(self, app_model: App, node_id: str):
"""
Poll for trigger events and execute single node when event arrives
"""
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("trigger_name", type=str, required=True, location="json")
parser.add_argument("subscription_id", type=str, required=True, location="json")
args = parser.parse_args()
trigger_name = args["trigger_name"]
subscription_id = args["subscription_id"]
event = TriggerDebugService.poll_event(
tenant_id=app_model.tenant_id,
user_id=current_user.id,
app_id=app_model.id,
subscription_id=subscription_id,
node_id=node_id,
trigger_name=trigger_name,
)
if not event:
return jsonable_encoder({"status": "waiting"})
try:
workflow_service = WorkflowService()
draft_workflow = workflow_service.get_draft_workflow(app_model)
if not draft_workflow:
raise ValueError("Workflow not found")
user_inputs = event.model_dump()
node_execution = workflow_service.run_draft_workflow_node(
app_model=app_model,
draft_workflow=draft_workflow,
node_id=node_id,
user_inputs=user_inputs,
account=current_user,
query="",
files=[],
)
return jsonable_encoder(node_execution)
except Exception:
logger.exception("Error running draft workflow trigger node")
return jsonable_encoder(
{
"status": "error",
}
), 500
class DraftWorkflowTriggerRunApi(Resource):
"""
Full workflow debug - Polling API for trigger events
Path: /apps/<uuid:app_id>/workflows/draft/trigger/run
"""
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
def post(self, app_model: App):
"""
Poll for trigger events and execute full workflow when event arrives
"""
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="json", nullable=False)
parser.add_argument("trigger_name", type=str, required=True, location="json", nullable=False)
parser.add_argument("subscription_id", type=str, required=True, location="json", nullable=False)
args = parser.parse_args()
node_id = args["node_id"]
trigger_name = args["trigger_name"]
subscription_id = args["subscription_id"]
event = TriggerDebugService.poll_event(
tenant_id=app_model.tenant_id,
user_id=current_user.id,
app_id=app_model.id,
subscription_id=subscription_id,
node_id=node_id,
trigger_name=trigger_name,
)
if not event:
return jsonable_encoder({"status": "waiting"})
workflow_args = {
"inputs": event.model_dump(),
"query": "",
"files": [],
}
external_trace_id = get_external_trace_id(request)
if external_trace_id:
workflow_args["external_trace_id"] = external_trace_id
try:
response = AppGenerateService.generate(
app_model=app_model,
user=current_user,
args=workflow_args,
invoke_from=InvokeFrom.DEBUGGER,
streaming=True,
)
return helper.compact_generate_response(response)
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except Exception:
logger.exception("Error running draft workflow trigger run")
return jsonable_encoder(
{
"status": "error",
}
), 500
api.add_resource(
DraftWorkflowApi,
"/apps/<uuid:app_id>/workflows/draft",
@@ -958,14 +830,6 @@ api.add_resource(
DraftWorkflowNodeRunApi,
"/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run",
)
api.add_resource(
DraftWorkflowTriggerNodeApi,
"/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/trigger",
)
api.add_resource(
DraftWorkflowTriggerRunApi,
"/apps/<uuid:app_id>/workflows/draft/trigger/run",
)
api.add_resource(
AdvancedChatDraftRunIterationNodeApi,
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run",

View File

@@ -1,249 +0,0 @@
import logging
from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
from libs.login import current_user, login_required
from models.model import Account, AppMode
from models.workflow import AppTrigger, AppTriggerStatus, WorkflowWebhookTrigger
logger = logging.getLogger(__name__)
from services.workflow_plugin_trigger_service import WorkflowPluginTriggerService
class PluginTriggerApi(Resource):
"""Workflow Plugin Trigger API"""
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
def post(self, app_model):
"""Create plugin trigger"""
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=False, location="json")
parser.add_argument("provider_id", type=str, required=False, location="json")
parser.add_argument("trigger_name", type=str, required=False, location="json")
parser.add_argument("subscription_id", type=str, required=False, location="json")
args = parser.parse_args()
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
if not current_user.is_editor:
raise Forbidden()
plugin_trigger = WorkflowPluginTriggerService.create_plugin_trigger(
app_id=app_model.id,
tenant_id=current_user.current_tenant_id,
node_id=args["node_id"],
provider_id=args["provider_id"],
trigger_name=args["trigger_name"],
subscription_id=args["subscription_id"],
)
return jsonable_encoder(plugin_trigger)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
def get(self, app_model):
"""Get plugin trigger"""
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
args = parser.parse_args()
plugin_trigger = WorkflowPluginTriggerService.get_plugin_trigger(
app_id=app_model.id,
node_id=args["node_id"],
)
return jsonable_encoder(plugin_trigger)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
def put(self, app_model):
"""Update plugin trigger"""
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
parser.add_argument("subscription_id", type=str, required=True, location="json", help="Subscription ID")
args = parser.parse_args()
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
if not current_user.is_editor:
raise Forbidden()
plugin_trigger = WorkflowPluginTriggerService.update_plugin_trigger(
app_id=app_model.id,
node_id=args["node_id"],
subscription_id=args["subscription_id"],
)
return jsonable_encoder(plugin_trigger)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
def delete(self, app_model):
"""Delete plugin trigger"""
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
args = parser.parse_args()
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
if not current_user.is_editor:
raise Forbidden()
WorkflowPluginTriggerService.delete_plugin_trigger(
app_id=app_model.id,
node_id=args["node_id"],
)
return {"result": "success"}, 204
class WebhookTriggerApi(Resource):
"""Webhook Trigger API"""
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(webhook_trigger_fields)
def get(self, app_model):
"""Get webhook trigger for a node"""
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
args = parser.parse_args()
node_id = args["node_id"]
with Session(db.engine) as session:
# Get webhook trigger for this app and node
webhook_trigger = (
session.query(WorkflowWebhookTrigger)
.filter(
WorkflowWebhookTrigger.app_id == app_model.id,
WorkflowWebhookTrigger.node_id == node_id,
)
.first()
)
if not webhook_trigger:
raise NotFound("Webhook trigger not found for this node")
# Add computed fields for marshal_with
base_url = dify_config.SERVICE_API_URL
webhook_trigger.webhook_url = f"{base_url}/triggers/webhook/{webhook_trigger.webhook_id}" # type: ignore
webhook_trigger.webhook_debug_url = f"{base_url}/triggers/webhook-debug/{webhook_trigger.webhook_id}" # type: ignore
return webhook_trigger
class AppTriggersApi(Resource):
"""App Triggers list API"""
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(triggers_list_fields)
def get(self, app_model):
"""Get app triggers list"""
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
with Session(db.engine) as session:
# Get all triggers for this app using select API
triggers = (
session.execute(
select(AppTrigger)
.where(
AppTrigger.tenant_id == current_user.current_tenant_id,
AppTrigger.app_id == app_model.id,
)
.order_by(AppTrigger.created_at.desc(), AppTrigger.id.desc())
)
.scalars()
.all()
)
# Add computed icon field for each trigger
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
for trigger in triggers:
if trigger.trigger_type == "trigger-plugin":
trigger.icon = url_prefix + trigger.provider_name + "/icon" # type: ignore
else:
trigger.icon = "" # type: ignore
return {"data": triggers}
class AppTriggerEnableApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(trigger_fields)
def post(self, app_model):
"""Update app trigger (enable/disable)"""
parser = reqparse.RequestParser()
parser.add_argument("trigger_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("enable_trigger", type=bool, required=True, nullable=False, location="json")
args = parser.parse_args()
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
if not current_user.is_editor:
raise Forbidden()
trigger_id = args["trigger_id"]
with Session(db.engine) as session:
# Find the trigger using select
trigger = session.execute(
select(AppTrigger).where(
AppTrigger.id == trigger_id,
AppTrigger.tenant_id == current_user.current_tenant_id,
AppTrigger.app_id == app_model.id,
)
).scalar_one_or_none()
if not trigger:
raise NotFound("Trigger not found")
# Update status based on enable_trigger boolean
trigger.status = AppTriggerStatus.ENABLED if args["enable_trigger"] else AppTriggerStatus.DISABLED
session.commit()
session.refresh(trigger)
# Add computed icon field
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
if trigger.trigger_type == "trigger-plugin":
trigger.icon = url_prefix + trigger.provider_name + "/icon" # type: ignore
else:
trigger.icon = "" # type: ignore
return trigger
api.add_resource(WebhookTriggerApi, "/apps/<uuid:app_id>/workflows/triggers/webhook")
api.add_resource(PluginTriggerApi, "/apps/<uuid:app_id>/workflows/triggers/plugin")
api.add_resource(AppTriggersApi, "/apps/<uuid:app_id>/triggers")
api.add_resource(AppTriggerEnableApi, "/apps/<uuid:app_id>/trigger-enable")

View File

@@ -81,7 +81,7 @@ class OAuthDataSourceBinding(Resource):
return {"error": "Invalid code"}, 400
try:
oauth_provider.get_access_token(code)
except requests.exceptions.HTTPError as e:
except requests.HTTPError as e:
logger.exception(
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
)
@@ -104,7 +104,7 @@ class OAuthDataSourceSync(Resource):
return {"error": "Invalid provider"}, 400
try:
oauth_provider.sync_data_source(binding_id)
except requests.exceptions.HTTPError as e:
except requests.HTTPError as e:
logger.exception(
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
)

View File

@@ -130,7 +130,7 @@ class ResetPasswordSendEmailApi(Resource):
language = "en-US"
try:
account = AccountService.get_user_through_email(args["email"])
except AccountRegisterError as are:
except AccountRegisterError:
raise AccountInFreezeError()
if account is None:
@@ -162,7 +162,7 @@ class EmailCodeLoginSendEmailApi(Resource):
language = "en-US"
try:
account = AccountService.get_user_through_email(args["email"])
except AccountRegisterError as are:
except AccountRegisterError:
raise AccountInFreezeError()
if account is None:
@@ -200,7 +200,7 @@ class EmailCodeLoginApi(Resource):
AccountService.revoke_email_code_login_token(args["token"])
try:
account = AccountService.get_user_through_email(user_email)
except AccountRegisterError as are:
except AccountRegisterError:
raise AccountInFreezeError()
if account:
tenants = TenantService.get_join_tenants(account)
@@ -223,7 +223,7 @@ class EmailCodeLoginApi(Resource):
)
except WorkSpaceNotAllowedCreateError:
raise NotAllowedCreateWorkspace()
except AccountRegisterError as are:
except AccountRegisterError:
raise AccountInFreezeError()
except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded()

View File

@@ -80,7 +80,7 @@ class OAuthCallback(Resource):
try:
token = oauth_provider.get_access_token(code)
user_info = oauth_provider.get_user_info(token)
except requests.exceptions.RequestException as e:
except requests.RequestException as e:
error_text = e.response.text if e.response else str(e)
logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
return {"error": "OAuth process failed"}, 400

View File

@@ -10,6 +10,7 @@ from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
from core.indexing_runner import IndexingRunner
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db
@@ -214,7 +215,7 @@ class DataSourceNotionApi(Resource):
workspace_id = notion_info["workspace_id"]
for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type="notion_import",
datasource_type=DatasourceType.NOTION.value,
notion_info={
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],

View File

@@ -22,6 +22,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.plugin.entities.plugin import ModelProviderID
from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
@@ -422,7 +423,9 @@ class DatasetIndexingEstimateApi(Resource):
if file_details:
for file_detail in file_details:
extract_setting = ExtractSetting(
datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"]
datasource_type=DatasourceType.FILE.value,
upload_file=file_detail,
document_model=args["doc_form"],
)
extract_settings.append(extract_setting)
elif args["info_list"]["data_source_type"] == "notion_import":
@@ -431,7 +434,7 @@ class DatasetIndexingEstimateApi(Resource):
workspace_id = notion_info["workspace_id"]
for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type="notion_import",
datasource_type=DatasourceType.NOTION.value,
notion_info={
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
@@ -445,7 +448,7 @@ class DatasetIndexingEstimateApi(Resource):
website_info_list = args["info_list"]["website_info_list"]
for url in website_info_list["urls"]:
extract_setting = ExtractSetting(
datasource_type="website_crawl",
datasource_type=DatasourceType.WEBSITE.value,
website_info={
"provider": website_info_list["provider"],
"job_id": website_info_list["job_id"],

View File

@@ -40,6 +40,7 @@ from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from extensions.ext_database import db
from fields.document_fields import (
@@ -354,9 +355,6 @@ class DatasetInitApi(Resource):
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
knowledge_config = KnowledgeConfig(**args)
if knowledge_config.indexing_technique == "high_quality":
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
@@ -428,7 +426,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
raise NotFound("File not found.")
extract_setting = ExtractSetting(
datasource_type="upload_file", upload_file=file, document_model=document.doc_form
datasource_type=DatasourceType.FILE.value, upload_file=file, document_model=document.doc_form
)
indexing_runner = IndexingRunner()
@@ -488,13 +486,13 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
raise NotFound("File not found.")
extract_setting = ExtractSetting(
datasource_type="upload_file", upload_file=file_detail, document_model=document.doc_form
datasource_type=DatasourceType.FILE.value, upload_file=file_detail, document_model=document.doc_form
)
extract_settings.append(extract_setting)
elif document.data_source_type == "notion_import":
extract_setting = ExtractSetting(
datasource_type="notion_import",
datasource_type=DatasourceType.NOTION.value,
notion_info={
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
@@ -506,7 +504,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
extract_settings.append(extract_setting)
elif document.data_source_type == "website_crawl":
extract_setting = ExtractSetting(
datasource_type="website_crawl",
datasource_type=DatasourceType.WEBSITE.value,
website_info={
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],

View File

@@ -61,7 +61,6 @@ class ConversationApi(InstalledAppResource):
ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
WebConversationService.unpin(app_model, conversation_id, current_user)
return {"result": "success"}, 204

View File

@@ -67,7 +67,7 @@ class ModelProviderCredentialApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
args = parser.parse_args()
model_provider_service = ModelProviderService()
@@ -94,7 +94,7 @@ class ModelProviderCredentialApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
args = parser.parse_args()
model_provider_service = ModelProviderService()

View File

@@ -219,7 +219,11 @@ class ModelProviderModelCredentialApi(Resource):
model_load_balancing_service = ModelLoadBalancingService()
is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
tenant_id=tenant_id,
provider=provider,
model=args["model"],
model_type=args["model_type"],
config_from=args.get("config_from", ""),
)
if args.get("config_from", "") == "predefined-model":
@@ -263,7 +267,7 @@ class ModelProviderModelCredentialApi(Resource):
choices=[mt.value for mt in ModelType],
location="json",
)
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
@@ -309,7 +313,7 @@ class ModelProviderModelCredentialApi(Resource):
)
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
args = parser.parse_args()
model_provider_service = ModelProviderService()

View File

@@ -516,20 +516,18 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
parser.add_argument("provider", type=str, required=True, location="args")
parser.add_argument("action", type=str, required=True, location="args")
parser.add_argument("parameter", type=str, required=True, location="args")
parser.add_argument("credential_id", type=str, required=False, location="args")
parser.add_argument("provider_type", type=str, required=True, location="args")
args = parser.parse_args()
try:
options = PluginParameterService.get_dynamic_select_options(
tenant_id=tenant_id,
user_id=user_id,
plugin_id=args["plugin_id"],
provider=args["provider"],
action=args["action"],
parameter=args["parameter"],
credential_id=args["credential_id"],
provider_type=args["provider_type"],
tenant_id,
user_id,
args["plugin_id"],
args["provider"],
args["action"],
args["parameter"],
args["provider_type"],
)
except PluginDaemonClientSideError as e:
raise ValueError(e)

View File

@@ -22,8 +22,8 @@ from core.mcp.error import MCPAuthError, MCPError
from core.mcp.mcp_client import MCPClient
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.tools.entities.tool_entities import CredentialType
from libs.helper import StrLen, alphanumeric, uuid_value
from libs.login import login_required
from services.plugin.oauth_service import OAuthProxyService

View File

@@ -1,589 +0,0 @@
import logging
from flask import make_response, redirect, request
from flask_restx import Resource, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden
from configs import dify_config
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import TriggerProviderID
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.trigger.entities.entities import SubscriptionBuilderUpdater
from core.trigger.trigger_manager import TriggerManager
from extensions.ext_database import db
from libs.login import current_user, login_required
from models.account import Account
from services.plugin.oauth_service import OAuthProxyService
from services.trigger.trigger_provider_service import TriggerProviderService
from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService
from services.workflow_plugin_trigger_service import WorkflowPluginTriggerService
logger = logging.getLogger(__name__)
class TriggerProviderListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
"""List all trigger providers for the current tenant"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
return jsonable_encoder(TriggerProviderService.list_trigger_providers(user.current_tenant_id))
class TriggerProviderInfoApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
"""Get info for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
return jsonable_encoder(
TriggerProviderService.get_trigger_provider(user.current_tenant_id, TriggerProviderID(provider))
)
class TriggerSubscriptionListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
"""List all trigger subscriptions for the current tenant's provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
try:
return jsonable_encoder(
TriggerProviderService.list_trigger_provider_subscriptions(
tenant_id=user.current_tenant_id, provider_id=TriggerProviderID(provider)
)
)
except Exception as e:
logger.exception("Error listing trigger providers", exc_info=e)
raise
class TriggerSubscriptionBuilderCreateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
"""Add a new subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credential_type", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
try:
credential_type = CredentialType.of(args.get("credential_type") or CredentialType.UNAUTHORIZED.value)
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
tenant_id=user.current_tenant_id,
user_id=user.id,
provider_id=TriggerProviderID(provider),
credential_type=credential_type,
)
return jsonable_encoder({"subscription_builder": subscription_builder})
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:
logger.exception("Error adding provider credential", exc_info=e)
raise
class TriggerSubscriptionBuilderGetApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider, subscription_builder_id):
"""Get a subscription instance for a trigger provider"""
return jsonable_encoder(
TriggerSubscriptionBuilderService.get_subscription_builder_by_id(subscription_builder_id)
)
class TriggerSubscriptionBuilderVerifyApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
"""Verify a subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
# The credentials of the subscription builder
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
args = parser.parse_args()
try:
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
tenant_id=user.current_tenant_id,
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
credentials=args.get("credentials", None),
),
)
return TriggerSubscriptionBuilderService.verify_trigger_subscription_builder(
tenant_id=user.current_tenant_id,
user_id=user.id,
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
)
except Exception as e:
logger.exception("Error verifying provider credential", exc_info=e)
raise
class TriggerSubscriptionBuilderUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
"""Update a subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
parser = reqparse.RequestParser()
# The name of the subscription builder
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
# The parameters of the subscription builder
parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
# The properties of the subscription builder
parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
# The credentials of the subscription builder
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
args = parser.parse_args()
try:
return jsonable_encoder(
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
tenant_id=user.current_tenant_id,
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
name=args.get("name", None),
parameters=args.get("parameters", None),
properties=args.get("properties", None),
credentials=args.get("credentials", None),
),
)
)
except Exception as e:
logger.exception("Error updating provider credential", exc_info=e)
raise
class TriggerSubscriptionBuilderLogsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider, subscription_builder_id):
"""Get the request logs for a subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
try:
logs = TriggerSubscriptionBuilderService.list_logs(subscription_builder_id)
return jsonable_encoder({"logs": [log.model_dump(mode="json") for log in logs]})
except Exception as e:
logger.exception("Error getting request logs for subscription builder", exc_info=e)
raise
class TriggerSubscriptionBuilderBuildApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
"""Build a subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
# The name of the subscription builder
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
# The parameters of the subscription builder
parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
# The properties of the subscription builder
parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
# The credentials of the subscription builder
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
args = parser.parse_args()
try:
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
tenant_id=user.current_tenant_id,
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
name=args.get("name", None),
parameters=args.get("parameters", None),
properties=args.get("properties", None),
),
)
TriggerSubscriptionBuilderService.build_trigger_subscription_builder(
tenant_id=user.current_tenant_id,
user_id=user.id,
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
)
return 200
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:
logger.exception("Error building provider credential", exc_info=e)
raise
class TriggerSubscriptionDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, subscription_id):
"""Delete a subscription instance"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
try:
with Session(db.engine) as session:
# Delete trigger provider subscription
TriggerProviderService.delete_trigger_provider(
session=session,
tenant_id=user.current_tenant_id,
subscription_id=subscription_id,
)
# Delete plugin triggers
WorkflowPluginTriggerService.delete_plugin_trigger_by_subscription(
session=session,
tenant_id=user.current_tenant_id,
subscription_id=subscription_id,
)
session.commit()
return {"result": "success"}
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:
logger.exception("Error deleting provider credential", exc_info=e)
raise
class TriggerOAuthAuthorizeApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
"""Initiate OAuth authorization flow for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
try:
provider_id = TriggerProviderID(provider)
plugin_id = provider_id.plugin_id
provider_name = provider_id.provider_name
tenant_id = user.current_tenant_id
# Get OAuth client configuration
oauth_client_params = TriggerProviderService.get_oauth_client(
tenant_id=tenant_id,
provider_id=provider_id,
)
if oauth_client_params is None:
raise Forbidden("No OAuth client configuration found for this trigger provider")
# Create subscription builder
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
tenant_id=tenant_id,
user_id=user.id,
provider_id=provider_id,
credential_type=CredentialType.OAUTH2,
)
# Create OAuth handler and proxy context
oauth_handler = OAuthHandler()
context_id = OAuthProxyService.create_proxy_context(
user_id=user.id,
tenant_id=tenant_id,
plugin_id=plugin_id,
provider=provider_name,
extra_data={
"subscription_builder_id": subscription_builder.id,
},
)
# Build redirect URI for callback
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
# Get authorization URL
authorization_url_response = oauth_handler.get_authorization_url(
tenant_id=tenant_id,
user_id=user.id,
plugin_id=plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=oauth_client_params,
)
# Create response with cookie
response = make_response(
jsonable_encoder(
{
"authorization_url": authorization_url_response.authorization_url,
"subscription_builder_id": subscription_builder.id,
"subscription_builder": subscription_builder,
}
)
)
response.set_cookie(
"context_id",
context_id,
httponly=True,
samesite="Lax",
max_age=OAuthProxyService.__MAX_AGE__,
)
return response
except Exception as e:
logger.exception("Error initiating OAuth flow", exc_info=e)
raise
class TriggerOAuthCallbackApi(Resource):
@setup_required
def get(self, provider):
"""Handle OAuth callback for trigger provider"""
context_id = request.cookies.get("context_id")
if not context_id:
raise Forbidden("context_id not found")
# Use and validate proxy context
context = OAuthProxyService.use_proxy_context(context_id)
if context is None:
raise Forbidden("Invalid context_id")
# Parse provider ID
provider_id = TriggerProviderID(provider)
plugin_id = provider_id.plugin_id
provider_name = provider_id.provider_name
user_id = context.get("user_id")
tenant_id = context.get("tenant_id")
subscription_builder_id = context.get("subscription_builder_id")
# Get OAuth client configuration
oauth_client_params = TriggerProviderService.get_oauth_client(
tenant_id=tenant_id,
provider_id=provider_id,
)
if oauth_client_params is None:
raise Forbidden("No OAuth client configuration found for this trigger provider")
# Get OAuth credentials from callback
oauth_handler = OAuthHandler()
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
credentials_response = oauth_handler.get_credentials(
tenant_id=tenant_id,
user_id=user_id,
plugin_id=plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=oauth_client_params,
request=request,
)
credentials = credentials_response.credentials
expires_at = credentials_response.expires_at
if not credentials:
raise Exception("Failed to get OAuth credentials")
# Update subscription builder
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
tenant_id=tenant_id,
provider_id=provider_id,
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
credentials=credentials,
credential_expires_at=expires_at,
),
)
# Redirect to OAuth callback page
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
class TriggerOAuthClientManageApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
"""Get OAuth client configuration for a provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
try:
provider_id = TriggerProviderID(provider)
# Get custom OAuth client params if exists
custom_params = TriggerProviderService.get_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
)
# Check if custom client is enabled
is_custom_enabled = TriggerProviderService.is_oauth_custom_client_enabled(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
)
# Check if there's a system OAuth client
system_client = TriggerProviderService.get_oauth_client(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
)
provider_controller = TriggerManager.get_trigger_provider(user.current_tenant_id, provider_id)
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
return jsonable_encoder(
{
"configured": bool(custom_params or system_client),
"oauth_client_schema": provider_controller.get_oauth_client_schema(),
"custom_configured": bool(custom_params),
"custom_enabled": is_custom_enabled,
"redirect_uri": redirect_uri,
"params": custom_params if custom_params else {},
}
)
except Exception as e:
logger.exception("Error getting OAuth client", exc_info=e)
raise
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
"""Configure custom OAuth client for a provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
parser.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
args = parser.parse_args()
try:
provider_id = TriggerProviderID(provider)
return TriggerProviderService.save_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
client_params=args.get("client_params"),
enabled=args.get("enabled"),
)
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:
logger.exception("Error configuring OAuth client", exc_info=e)
raise
@setup_required
@login_required
@account_initialization_required
def delete(self, provider):
"""Remove custom OAuth client configuration"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
try:
provider_id = TriggerProviderID(provider)
return TriggerProviderService.delete_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
)
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:
logger.exception("Error removing OAuth client", exc_info=e)
raise
# Trigger Subscription
api.add_resource(TriggerProviderListApi, "/workspaces/current/triggers")
api.add_resource(TriggerProviderInfoApi, "/workspaces/current/trigger-provider/<path:provider>/info")
api.add_resource(TriggerSubscriptionListApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/list")
api.add_resource(
TriggerSubscriptionDeleteApi,
"/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/delete",
)
# Trigger Subscription Builder
api.add_resource(
TriggerSubscriptionBuilderCreateApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/create",
)
api.add_resource(
TriggerSubscriptionBuilderGetApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/<path:subscription_builder_id>",
)
api.add_resource(
TriggerSubscriptionBuilderUpdateApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/update/<path:subscription_builder_id>",
)
api.add_resource(
TriggerSubscriptionBuilderVerifyApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify/<path:subscription_builder_id>",
)
api.add_resource(
TriggerSubscriptionBuilderBuildApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/build/<path:subscription_builder_id>",
)
api.add_resource(
TriggerSubscriptionBuilderLogsApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/logs/<path:subscription_builder_id>",
)
# OAuth
api.add_resource(
TriggerOAuthAuthorizeApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/oauth/authorize"
)
api.add_resource(TriggerOAuthCallbackApi, "/oauth/plugin/<path:provider>/trigger/callback")
api.add_resource(TriggerOAuthClientManageApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/client")
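# Illustrative sketch of the builder lifecycle wired up above (update -> verify -> build),
# written as a hypothetical console client. The base URL, auth header, provider id,
# builder id and payload values are placeholders/assumptions, not part of this changeset.
import requests

BASE = "https://dify.example.com/console/api"          # assumed console API root
HEADERS = {"Authorization": "Bearer <session-token>"}  # assumed console auth
provider = "some-org/github_trigger/github"            # hypothetical provider id
builder_id = "<id returned by the create endpoint>"

# Fill in builder fields (all optional: name, parameters, properties, credentials).
requests.post(
    f"{BASE}/workspaces/current/trigger-provider/{provider}/subscriptions/builder/update/{builder_id}",
    json={"name": "repo events", "parameters": {"repository": "octocat/hello-world"}},
    headers=HEADERS,
)
# Verify the credentials currently held by the builder.
requests.post(
    f"{BASE}/workspaces/current/trigger-provider/{provider}/subscriptions/builder/verify/{builder_id}",
    json={"credentials": {"api_key": "..."}},
    headers=HEADERS,
)
# Build the subscription (admin/owner only).
requests.post(
    f"{BASE}/workspaces/current/trigger-provider/{provider}/subscriptions/builder/build/{builder_id}",
    json={},
    headers=HEADERS,
)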

View File

@@ -9,9 +9,10 @@ from controllers.console.app.mcp_server import AppMCPServerStatus
from controllers.mcp import mcp_ns
from core.app.app_config.entities import VariableEntity
from core.mcp import types as mcp_types
from core.mcp.server.streamable_http import handle_mcp_request
from extensions.ext_database import db
from libs import helper
from models.model import App, AppMCPServer, AppMode
from models.model import App, AppMCPServer, AppMode, EndUser
class MCPRequestError(Exception):
@@ -194,6 +195,50 @@ class MCPAppApi(Resource):
except ValidationError as e:
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")
mcp_server_handler = MCPServerStreamableHTTPRequestHandler(app, request, converted_user_input_form)
response = mcp_server_handler.handle()
return helper.compact_generate_response(response)
def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str, session: Session) -> EndUser | None:
"""Get end user from existing session - optimized query"""
return (
session.query(EndUser)
.where(EndUser.tenant_id == tenant_id)
.where(EndUser.session_id == mcp_server_id)
.where(EndUser.type == "mcp")
.first()
)
def _create_end_user(
self, client_name: str, tenant_id: str, app_id: str, mcp_server_id: str, session: Session
) -> EndUser:
"""Create end user in existing session"""
end_user = EndUser(
tenant_id=tenant_id,
app_id=app_id,
type="mcp",
name=client_name,
session_id=mcp_server_id,
)
session.add(end_user)
session.flush() # Use flush instead of commit to keep transaction open
session.refresh(end_user)
return end_user
def _handle_mcp_request(
self,
app: App,
mcp_server: AppMCPServer,
mcp_request: mcp_types.ClientRequest,
user_input_form: list[VariableEntity],
session: Session,
request_id: Union[int, str],
) -> mcp_types.JSONRPCResponse | mcp_types.JSONRPCError | None:
"""Handle MCP request and return response"""
end_user = self._retrieve_end_user(mcp_server.tenant_id, mcp_server.id, session)
if not end_user and isinstance(mcp_request.root, mcp_types.InitializeRequest):
client_info = mcp_request.root.params.clientInfo
client_name = f"{client_info.name}@{client_info.version}"
# Commit the session before creating end user to avoid transaction conflicts
session.commit()
with Session(db.engine, expire_on_commit=False) as create_session, create_session.begin():
end_user = self._create_end_user(client_name, app.tenant_id, app.id, mcp_server.id, create_session)
return handle_mcp_request(app, mcp_request, user_input_form, mcp_server, end_user, request_id)
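# Note on the flow above: the EndUser row for an MCP server is created only when the
# client sends an InitializeRequest (named "<client name>@<client version>"), using its
# own short-lived session; every later request reuses the row found by tenant_id,
# session_id == mcp_server.id and type == "mcp".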

View File

@@ -55,7 +55,7 @@ class AudioApi(Resource):
file = request.files["file"]
try:
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user)
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user.id)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:

View File

@@ -59,7 +59,7 @@ class FilePreviewApi(Resource):
args = file_preview_parser.parse_args()
# Validate file ownership and get file objects
message_file, upload_file = self._validate_file_ownership(file_id, app_model.id)
_, upload_file = self._validate_file_ownership(file_id, app_model.id)
# Get file content generator
try:

View File

@@ -410,7 +410,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
DocumentService.document_create_args_validate(knowledge_config)
try:
documents, batch = DocumentService.save_document_with_dataset_id(
documents, _ = DocumentService.save_document_with_dataset_id(
dataset=dataset,
knowledge_config=knowledge_config,
account=dataset.created_by_account,

View File

@@ -291,27 +291,28 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str]
if not user_id:
user_id = "DEFAULT-USER"
end_user = (
db.session.query(EndUser)
.where(
EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id,
EndUser.session_id == user_id,
EndUser.type == "service_api",
with Session(db.engine, expire_on_commit=False) as session:
end_user = (
session.query(EndUser)
.where(
EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id,
EndUser.session_id == user_id,
EndUser.type == "service_api",
)
.first()
)
.first()
)
if end_user is None:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type="service_api",
is_anonymous=user_id == "DEFAULT-USER",
session_id=user_id,
)
db.session.add(end_user)
db.session.commit()
if end_user is None:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type="service_api",
is_anonymous=user_id == "DEFAULT-USER",
session_id=user_id,
)
session.add(end_user)
session.commit()
return end_user
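# expire_on_commit=False matters here: end_user is returned after the session block
# exits, so its already-loaded attributes must remain readable on the detached
# instance instead of triggering a refresh against a closed session.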

View File

@@ -1,7 +0,0 @@
from flask import Blueprint
# Create trigger blueprint
bp = Blueprint("trigger", __name__, url_prefix="/triggers")
# Import routes after blueprint creation to avoid circular imports
from . import trigger, webhook

View File

@@ -1,41 +0,0 @@
import logging
import re
from flask import jsonify, request
from werkzeug.exceptions import NotFound
from controllers.trigger import bp
from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService
from services.trigger_service import TriggerService
logger = logging.getLogger(__name__)
UUID_PATTERN = r"^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$"
UUID_MATCHER = re.compile(UUID_PATTERN)
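# e.g. matches "123e4567-e89b-42d3-a456-426614174000" (lowercase, version 4, variant 8/9/a/b)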
@bp.route("/plugin/<string:endpoint_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
def trigger_endpoint(endpoint_id: str):
"""
Handle endpoint trigger calls.
"""
# endpoint_id must be UUID
if not UUID_MATCHER.match(endpoint_id):
raise NotFound("Invalid endpoint ID")
handling_chain = [
TriggerService.process_endpoint,
TriggerSubscriptionBuilderService.process_builder_validation_endpoint,
]
try:
for handler in handling_chain:
response = handler(endpoint_id, request)
if response:
break
if not response:
raise NotFound("Endpoint not found")
return response
except ValueError as e:
raise NotFound(str(e))
except Exception as e:
logger.exception("Webhook processing failed for {endpoint_id}")
return jsonify({"error": "Internal server error", "message": str(e)}), 500

View File

@@ -1,46 +0,0 @@
import logging
from flask import jsonify
from werkzeug.exceptions import NotFound, RequestEntityTooLarge
from controllers.trigger import bp
from services.webhook_service import WebhookService
logger = logging.getLogger(__name__)
@bp.route("/webhook/<string:webhook_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
@bp.route("/webhook-debug/<string:webhook_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
def handle_webhook(webhook_id: str):
"""
Handle webhook trigger calls.
This endpoint receives webhook calls and processes them according to the
configured webhook trigger settings.
"""
try:
# Get webhook trigger, workflow, and node configuration
webhook_trigger, workflow, node_config = WebhookService.get_webhook_trigger_and_workflow(webhook_id)
# Extract request data
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
# Validate request against node configuration
validation_result = WebhookService.validate_webhook_request(webhook_data, node_config)
if not validation_result["valid"]:
return jsonify({"error": "Bad Request", "message": validation_result["error"]}), 400
# Process webhook call (send to Celery)
WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)
# Return configured response
response_data, status_code = WebhookService.generate_webhook_response(node_config)
return jsonify(response_data), status_code
except ValueError as e:
raise NotFound(str(e))
except RequestEntityTooLarge:
raise
except Exception as e:
logger.exception("Webhook processing failed for %s", webhook_id)
return jsonify({"error": "Internal server error", "message": str(e)}), 500

View File

@@ -73,8 +73,6 @@ class ConversationApi(WebApiResource):
ConversationService.delete(app_model, conversation_id, end_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
WebConversationService.unpin(app_model, conversation_id, end_user)
return {"result": "success"}, 204

View File

@@ -4,6 +4,7 @@ from functools import wraps
from flask import request
from flask_restx import Resource
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError
@@ -49,18 +50,19 @@ def decode_jwt_token():
decoded = PassportService().verify(tk)
app_code = decoded.get("app_code")
app_id = decoded.get("app_id")
app_model = db.session.scalar(select(App).where(App.id == app_id))
site = db.session.scalar(select(Site).where(Site.code == app_code))
if not app_model:
raise NotFound()
if not app_code or not site:
raise BadRequest("Site URL is no longer valid.")
if app_model.enable_site is False:
raise BadRequest("Site is disabled.")
end_user_id = decoded.get("end_user_id")
end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id))
if not end_user:
raise NotFound()
with Session(db.engine, expire_on_commit=False) as session:
app_model = session.scalar(select(App).where(App.id == app_id))
site = session.scalar(select(Site).where(Site.code == app_code))
if not app_model:
raise NotFound()
if not app_code or not site:
raise BadRequest("Site URL is no longer valid.")
if app_model.enable_site is False:
raise BadRequest("Site is disabled.")
end_user_id = decoded.get("end_user_id")
end_user = session.scalar(select(EndUser).where(EndUser.id == end_user_id))
if not end_user:
raise NotFound()
# for enterprise webapp auth
app_web_auth_enabled = False

View File

@@ -334,7 +334,8 @@ class BaseAgentRunner(AppRunner):
"""
Save agent thought
"""
agent_thought = db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id).first()
stmt = select(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id)
agent_thought = db.session.scalar(stmt)
if not agent_thought:
raise ValueError("agent thought not found")
@@ -492,7 +493,8 @@ class BaseAgentRunner(AppRunner):
return result
def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all()
stmt = select(MessageFile).where(MessageFile.message_id == message.id)
files = db.session.scalars(stmt).all()
if not files:
return UserPromptMessage(content=message.query)
if message.app_model_config:

View File

@@ -26,7 +26,7 @@ class MoreLikeThisConfigManager:
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
try:
return AppConfigModel.model_validate(config).model_dump(), ["more_like_this"]
except ValidationError as e:
except ValidationError:
raise ValueError(
"more_like_this must be of dict type and enabled in more_like_this must be of boolean type"
)

View File

@@ -450,6 +450,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
worker_thread.start()
# release database connection, because the following new thread operations may take a long time
db.session.refresh(workflow)
db.session.refresh(message)
# db.session.refresh(user)
db.session.close()
# return response or stream generator
response = self._handle_advanced_chat_response(
application_generate_entity=application_generate_entity,

View File

@@ -72,7 +72,9 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
with Session(db.engine, expire_on_commit=False) as session:
app_record = session.scalar(select(App).where(App.id == app_config.app_id))
if not app_record:
raise ValueError("App not found")

View File

@@ -118,7 +118,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute]
else:
response_chunk.update(sub_stream_response.to_dict())

View File

@@ -73,7 +73,6 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from events.message_event import message_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Conversation, EndUser, Message, MessageFile
@@ -311,13 +310,8 @@ class AdvancedChatAppGenerateTaskPipeline:
err = self._base_task_pipeline._handle_error(event=event, session=session, message_id=self._message_id)
yield self._base_task_pipeline._error_to_stream_response(err)
def _handle_workflow_started_event(
self, event: QueueWorkflowStartedEvent, *, graph_runtime_state: Optional[GraphRuntimeState] = None, **kwargs
) -> Generator[StreamResponse, None, None]:
def _handle_workflow_started_event(self, *args, **kwargs) -> Generator[StreamResponse, None, None]:
"""Handle workflow started events."""
# Override graph runtime state - this is a side effect but necessary
graph_runtime_state = event.graph_runtime_state
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
self._workflow_run_id = workflow_execution.id_
@@ -338,15 +332,14 @@ class AdvancedChatAppGenerateTaskPipeline:
"""Handle node retry events."""
self._ensure_workflow_initialized()
with self._database_session() as session:
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id, event=event
)
node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id, event=event
)
node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_retry_resp:
yield node_retry_resp
@@ -380,13 +373,12 @@ class AdvancedChatAppGenerateTaskPipeline:
self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
)
with self._database_session() as session:
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
self._save_output_for_event(event, workflow_node_execution.id)
@@ -939,10 +931,6 @@ class AdvancedChatAppGenerateTaskPipeline:
self._task_state.metadata.usage = usage
else:
self._task_state.metadata.usage = LLMUsage.empty_usage()
message_was_created.send(
message,
application_generate_entity=self._application_generate_entity,
)
def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
"""

View File

@@ -1,6 +1,8 @@
import logging
from typing import cast
from sqlalchemy import select
from core.agent.cot_chat_agent_runner import CotChatAgentRunner
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
from core.agent.entities import AgentEntity
@@ -44,8 +46,8 @@ class AgentChatAppRunner(AppRunner):
"""
app_config = application_generate_entity.app_config
app_config = cast(AgentChatAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
app_stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(app_stmt)
if not app_record:
raise ValueError("App not found")
@@ -182,11 +184,12 @@ class AgentChatAppRunner(AppRunner):
if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
conversation_result = db.session.query(Conversation).where(Conversation.id == conversation.id).first()
conversation_stmt = select(Conversation).where(Conversation.id == conversation.id)
conversation_result = db.session.scalar(conversation_stmt)
if conversation_result is None:
raise ValueError("Conversation not found")
message_result = db.session.query(Message).where(Message.id == message.id).first()
msg_stmt = select(Message).where(Message.id == message.id)
message_result = db.session.scalar(msg_stmt)
if message_result is None:
raise ValueError("Message not found")
db.session.close()

View File

@@ -159,7 +159,7 @@ class AppQueueManager:
def _check_for_sqlalchemy_models(self, data: Any):
# from entity to dict or list
if isinstance(data, dict):
for key, value in data.items():
for value in data.values():
self._check_for_sqlalchemy_models(value)
elif isinstance(data, list):
for item in data:

View File

@@ -1,6 +1,8 @@
import logging
from typing import cast
from sqlalchemy import select
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.chat.app_config_manager import ChatAppConfig
@@ -42,8 +44,8 @@ class ChatAppRunner(AppRunner):
"""
app_config = application_generate_entity.app_config
app_config = cast(ChatAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(stmt)
if not app_record:
raise ValueError("App not found")

View File

@@ -6,6 +6,7 @@ from typing import Any, Literal, Union, overload
from flask import Flask, copy_current_request_context, current_app
from pydantic import ValidationError
from sqlalchemy import select
from configs import dify_config
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
@@ -248,17 +249,14 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source
:param stream: is stream
"""
message = (
db.session.query(Message)
.where(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
)
.first()
stmt = select(Message).where(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
)
message = db.session.scalar(stmt)
if not message:
raise MessageNotExistsError()

View File

@@ -1,6 +1,8 @@
import logging
from typing import cast
from sqlalchemy import select
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.completion.app_config_manager import CompletionAppConfig
@@ -35,8 +37,8 @@ class CompletionAppRunner(AppRunner):
"""
app_config = application_generate_entity.app_config
app_config = cast(CompletionAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(stmt)
if not app_record:
raise ValueError("App not found")

View File

@@ -3,6 +3,9 @@ import logging
from collections.abc import Generator
from typing import Optional, Union, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager
@@ -83,11 +86,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig:
if conversation:
app_model_config = (
db.session.query(AppModelConfig)
.where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
.first()
stmt = select(AppModelConfig).where(
AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id
)
app_model_config = db.session.scalar(stmt)
if not app_model_config:
raise AppModelConfigBrokenError()
@@ -253,7 +255,8 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param conversation_id: conversation id
:return: conversation
"""
conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
with Session(db.engine, expire_on_commit=False) as session:
conversation = session.scalar(select(Conversation).where(Conversation.id == conversation_id))
if not conversation:
raise ConversationNotExistsError("Conversation not exists")
@@ -266,7 +269,8 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param message_id: message id
:return: message
"""
message = db.session.query(Message).where(Message.id == message_id).first()
with Session(db.engine, expire_on_commit=False) as session:
message = session.scalar(select(Message).where(Message.id == message_id))
if message is None:
raise MessageNotExistsError("Message not exists")

View File

@@ -54,8 +54,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
streaming: Literal[True],
call_depth: int,
workflow_thread_pool_id: Optional[str],
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
root_node_id: Optional[str] = None,
) -> Generator[Mapping | str, None, None]: ...
@overload
@@ -70,8 +68,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
streaming: Literal[False],
call_depth: int,
workflow_thread_pool_id: Optional[str],
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
root_node_id: Optional[str] = None,
) -> Mapping[str, Any]: ...
@overload
@@ -86,8 +82,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
streaming: bool,
call_depth: int,
workflow_thread_pool_id: Optional[str],
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
root_node_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
def generate(
@@ -101,8 +95,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
root_node_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
files: Sequence[Mapping[str, Any]] = args.get("files") or []
@@ -138,20 +130,17 @@ class WorkflowAppGenerator(BaseAppGenerator):
**extract_external_trace_id_from_args(args),
}
workflow_run_id = str(uuid.uuid4())
if triggered_from in (WorkflowRunTriggeredFrom.DEBUGGING, WorkflowRunTriggeredFrom.APP_RUN):
# start node get inputs
inputs = self._prepare_user_inputs(
user_inputs=inputs,
variables=app_config.variables,
tenant_id=app_model.tenant_id,
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
)
# init application generate entity
application_generate_entity = WorkflowAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
file_upload_config=file_extra_config,
inputs=inputs,
inputs=self._prepare_user_inputs(
user_inputs=inputs,
variables=app_config.variables,
tenant_id=app_model.tenant_id,
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
),
files=list(system_files),
user_id=user.id,
stream=streaming,
@@ -170,10 +159,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
if triggered_from is not None:
# Use explicitly provided triggered_from (for async triggers)
workflow_triggered_from = triggered_from
elif invoke_from == InvokeFrom.DEBUGGER:
if invoke_from == InvokeFrom.DEBUGGER:
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
@@ -201,7 +187,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
root_node_id=root_node_id,
)
def _generate(
@@ -217,7 +202,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
streaming: bool = True,
workflow_thread_pool_id: Optional[str] = None,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
root_node_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
@@ -255,7 +239,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
"context": context,
"workflow_thread_pool_id": workflow_thread_pool_id,
"variable_loader": variable_loader,
"root_node_id": root_node_id,
},
)
@@ -452,7 +435,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
context: contextvars.Context,
variable_loader: VariableLoader,
workflow_thread_pool_id: Optional[str] = None,
root_node_id: Optional[str] = None,
) -> None:
"""
Generate worker in a new thread.
@@ -496,7 +478,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
variable_loader=variable_loader,
workflow=workflow,
system_user_id=system_user_id,
root_node_id=root_node_id,
)
try:

View File

@@ -34,7 +34,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
workflow_thread_pool_id: Optional[str] = None,
workflow: Workflow,
system_user_id: str,
root_node_id: Optional[str] = None,
) -> None:
super().__init__(
queue_manager=queue_manager,
@@ -45,7 +44,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
self.workflow_thread_pool_id = workflow_thread_pool_id
self._workflow = workflow
self._sys_user_id = system_user_id
self._root_node_id = root_node_id
def run(self) -> None:
"""
@@ -95,7 +93,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
)
# init graph
graph = self._init_graph(graph_config=self._workflow.graph_dict, root_node_id=self._root_node_id)
graph = self._init_graph(graph_config=self._workflow.graph_dict)
# RUN WORKFLOW
workflow_entry = WorkflowEntry(

View File

@@ -89,7 +89,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute]
else:
response_chunk.update(sub_stream_response.to_dict())
yield response_chunk

View File

@@ -300,16 +300,15 @@ class WorkflowAppGenerateTaskPipeline:
"""Handle node retry events."""
self._ensure_workflow_initialized()
with self._database_session() as session:
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id,
event=event,
)
response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id,
event=event,
)
response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
yield response

View File

@@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, Optional, cast
from typing import Any, cast
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import (
@@ -79,7 +79,7 @@ class WorkflowBasedAppRunner:
self._variable_loader = variable_loader
self._app_id = app_id
def _init_graph(self, graph_config: Mapping[str, Any], root_node_id: Optional[str] = None) -> Graph:
def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
"""
Init graph
"""
@@ -92,7 +92,7 @@ class WorkflowBasedAppRunner:
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
# init graph
graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id)
graph = Graph.init(graph_config=graph_config)
if not graph:
raise ValueError("graph not found in workflow")

View File

@@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
from enum import Enum
from enum import StrEnum
from typing import Any, Optional
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
@@ -11,7 +11,7 @@ from core.file import File, FileUploadConfig
from core.model_runtime.entities.model_entities import AIModelEntity
class InvokeFrom(Enum):
class InvokeFrom(StrEnum):
"""
Invoke From.
"""

View File

@@ -1,6 +1,8 @@
import logging
from typing import Optional
from sqlalchemy import select
from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
@@ -25,9 +27,8 @@ class AnnotationReplyFeature:
:param invoke_from: invoke from
:return:
"""
annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id).first()
)
stmt = select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id)
annotation_setting = db.session.scalar(stmt)
if not annotation_setting:
return None

View File

@@ -96,7 +96,11 @@ class RateLimit:
if isinstance(generator, Mapping):
return generator
else:
return RateLimitGenerator(rate_limit=self, generator=generator, request_id=request_id)
return RateLimitGenerator(
rate_limit=self,
generator=generator, # ty: ignore [invalid-argument-type]
request_id=request_id,
)
class RateLimitGenerator:

View File

@@ -50,7 +50,7 @@ class BasedGenerateTaskPipeline:
if isinstance(e, InvokeAuthorizationError):
err = InvokeAuthorizationError("Incorrect API key provided")
elif isinstance(e, InvokeError | ValueError):
err = e
err = e # ty: ignore [invalid-assignment]
else:
description = getattr(e, "description", None)
err = Exception(description if description is not None else str(e))

View File

@@ -3,6 +3,8 @@ from threading import Thread
from typing import Optional, Union
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import (
@@ -84,7 +86,8 @@ class MessageCycleManager:
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
with flask_app.app_context():
# get conversation and message
conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
stmt = select(Conversation).where(Conversation.id == conversation_id)
conversation = db.session.scalar(stmt)
if not conversation:
return
@@ -98,7 +101,7 @@ class MessageCycleManager:
try:
name = LLMGenerator.generate_conversation_name(app_model.tenant_id, query)
conversation.name = name
except Exception as e:
except Exception:
if dify_config.DEBUG:
logger.exception("generate conversation name failed, conversation_id: %s", conversation_id)
pass
@@ -143,7 +146,8 @@ class MessageCycleManager:
:param event: event
:return:
"""
message_file = db.session.query(MessageFile).where(MessageFile.id == event.message_file_id).first()
with Session(db.engine, expire_on_commit=False) as session:
message_file = session.scalar(select(MessageFile).where(MessageFile.id == event.message_file_id))
if message_file and message_file.url is not None:
# get tool file id
@@ -183,7 +187,8 @@ class MessageCycleManager:
:param message_id: message id
:return:
"""
message_file = db.session.query(MessageFile).where(MessageFile.id == message_id).first()
with Session(db.engine, expire_on_commit=False) as session:
message_file = session.scalar(select(MessageFile).where(MessageFile.id == message_id))
event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE
return MessageStreamResponse(

View File

@@ -1,6 +1,8 @@
import logging
from collections.abc import Sequence
from sqlalchemy import select
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
@@ -49,7 +51,8 @@ class DatasetIndexToolCallbackHandler:
for document in documents:
if document.metadata is not None:
document_id = document.metadata["document_id"]
dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
dataset_document_stmt = select(DatasetDocument).where(DatasetDocument.id == document_id)
dataset_document = db.session.scalar(dataset_document_stmt)
if not dataset_document:
_logger.warning(
"Expected DatasetDocument record to exist, but none was found, document_id=%s",
@@ -57,17 +60,14 @@ class DatasetIndexToolCallbackHandler:
)
continue
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = (
db.session.query(ChildChunk)
.where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
.first()
child_chunk_stmt = select(ChildChunk).where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
child_chunk = db.session.scalar(child_chunk_stmt)
if child_chunk:
segment = (
_ = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == child_chunk.segment_id)
.update(

View File

@@ -1,5 +1,6 @@
import json
import logging
import re
from collections import defaultdict
from collections.abc import Iterator, Sequence
from json import JSONDecodeError
@@ -343,7 +344,65 @@ class ProviderConfiguration(BaseModel):
with Session(db.engine) as new_session:
return _validate(new_session)
def create_provider_credential(self, credentials: dict, credential_name: str) -> None:
def _generate_provider_credential_name(self, session) -> str:
"""
Generate a unique credential name for provider.
:return: credential name
"""
return self._generate_next_api_key_name(
session=session,
query_factory=lambda: select(ProviderCredential).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
),
)
def _generate_custom_model_credential_name(self, model: str, model_type: ModelType, session) -> str:
"""
Generate a unique credential name for custom model.
:return: credential name
"""
return self._generate_next_api_key_name(
session=session,
query_factory=lambda: select(ProviderModelCredential).where(
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
),
)
def _generate_next_api_key_name(self, session, query_factory) -> str:
"""
Generate next available API KEY name by finding the highest numbered suffix.
:param session: database session
:param query_factory: function that returns the SQLAlchemy query
:return: next available API KEY name
"""
try:
stmt = query_factory()
credential_records = session.execute(stmt).scalars().all()
if not credential_records:
return "API KEY 1"
# Extract numbers from API KEY pattern using list comprehension
pattern = re.compile(r"^API KEY\s+(\d+)$")
numbers = [
int(match.group(1))
for cr in credential_records
if cr.credential_name and (match := pattern.match(cr.credential_name.strip()))
]
# Return next sequential number
next_number = max(numbers, default=0) + 1
return f"API KEY {next_number}"
except Exception as e:
logger.warning("Error generating next credential name: %s", str(e))
return "API KEY 1"
def create_provider_credential(self, credentials: dict, credential_name: str | None) -> None:
"""
Add custom provider credentials.
:param credentials: provider credentials
@@ -351,8 +410,11 @@ class ProviderConfiguration(BaseModel):
:return:
"""
with Session(db.engine) as session:
if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
if credential_name:
if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
else:
credential_name = self._generate_provider_credential_name(session)
credentials = self.validate_provider_credentials(credentials=credentials, session=session)
provider_record = self._get_provider_record(session)
@@ -395,7 +457,7 @@ class ProviderConfiguration(BaseModel):
self,
credentials: dict,
credential_id: str,
credential_name: str,
credential_name: str | None,
) -> None:
"""
update a saved provider credential (by credential_id).
@@ -406,7 +468,7 @@ class ProviderConfiguration(BaseModel):
:return:
"""
with Session(db.engine) as session:
if self._check_provider_credential_name_exists(
if credential_name and self._check_provider_credential_name_exists(
credential_name=credential_name, session=session, exclude_id=credential_id
):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
@@ -428,9 +490,9 @@ class ProviderConfiguration(BaseModel):
try:
# Update credential
credential_record.encrypted_config = json.dumps(credentials)
credential_record.credential_name = credential_name
credential_record.updated_at = naive_utc_now()
if credential_name:
credential_record.credential_name = credential_name
session.commit()
if provider_record and provider_record.credential_id == credential_id:
@@ -532,13 +594,7 @@ class ProviderConfiguration(BaseModel):
cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
)
lb_credentials_cache.delete()
lb_config.credential_id = None
lb_config.encrypted_config = None
lb_config.enabled = False
lb_config.name = "__delete__"
lb_config.updated_at = naive_utc_now()
session.add(lb_config)
session.delete(lb_config)
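# Load-balancing configs tied to the removed credential are now hard-deleted here,
# instead of being disabled and renamed to "__delete__" as before; the "__delete__"
# checks removed further down in this changeset go away for the same reason.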
# Check if this is the currently active credential
provider_record = self._get_provider_record(session)
@@ -822,7 +878,7 @@ class ProviderConfiguration(BaseModel):
return _validate(new_session)
def create_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict, credential_name: str
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None
) -> None:
"""
Create a custom model credential.
@@ -833,10 +889,15 @@ class ProviderConfiguration(BaseModel):
:return:
"""
with Session(db.engine) as session:
if self._check_custom_model_credential_name_exists(
model=model, model_type=model_type, credential_name=credential_name, session=session
):
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
if credential_name:
if self._check_custom_model_credential_name_exists(
model=model, model_type=model_type, credential_name=credential_name, session=session
):
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
else:
credential_name = self._generate_custom_model_credential_name(
model=model, model_type=model_type, session=session
)
# validate custom model config
credentials = self.validate_custom_model_credentials(
model_type=model_type, model=model, credentials=credentials, session=session
@@ -880,7 +941,7 @@ class ProviderConfiguration(BaseModel):
raise
def update_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict, credential_name: str, credential_id: str
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str
) -> None:
"""
Update a custom model credential.
@@ -893,7 +954,7 @@ class ProviderConfiguration(BaseModel):
:return:
"""
with Session(db.engine) as session:
if self._check_custom_model_credential_name_exists(
if credential_name and self._check_custom_model_credential_name_exists(
model=model,
model_type=model_type,
credential_name=credential_name,
@@ -925,8 +986,9 @@ class ProviderConfiguration(BaseModel):
try:
# Update credential
credential_record.encrypted_config = json.dumps(credentials)
credential_record.credential_name = credential_name
credential_record.updated_at = naive_utc_now()
if credential_name:
credential_record.credential_name = credential_name
session.commit()
if provider_model_record and provider_model_record.credential_id == credential_id:
@@ -982,12 +1044,7 @@ class ProviderConfiguration(BaseModel):
cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
)
lb_credentials_cache.delete()
lb_config.credential_id = None
lb_config.encrypted_config = None
lb_config.enabled = False
lb_config.name = "__delete__"
lb_config.updated_at = naive_utc_now()
session.add(lb_config)
session.delete(lb_config)
# Check if this is the currently active credential
provider_model_record = self._get_custom_model_record(model_type, model, session=session)
@@ -1054,6 +1111,7 @@ class ProviderConfiguration(BaseModel):
provider_name=self.provider.provider,
model_name=model,
model_type=model_type.to_origin_model_type(),
is_valid=True,
credential_id=credential_id,
)
else:
@@ -1605,11 +1663,9 @@ class ProviderConfiguration(BaseModel):
if config.credential_source_type != "custom_model"
]
if len(provider_model_lb_configs) > 1:
load_balancing_enabled = True
if any(config.name == "__delete__" for config in provider_model_lb_configs):
has_invalid_load_balancing_configs = True
load_balancing_enabled = model_setting.load_balancing_enabled
# display a warning when the user enables load balancing but fewer than 2 configs are available
has_invalid_load_balancing_configs = load_balancing_enabled and len(provider_model_lb_configs) < 2
provider_models.append(
ModelWithProviderEntity(
@@ -1631,6 +1687,8 @@ class ProviderConfiguration(BaseModel):
for model_configuration in self.custom_configuration.models:
if model_configuration.model_type not in model_types:
continue
if model_configuration.unadded_to_model_list:
continue
if model and model != model_configuration.model:
continue
try:
@@ -1663,11 +1721,9 @@ class ProviderConfiguration(BaseModel):
if config.credential_source_type != "provider"
]
if len(custom_model_lb_configs) > 1:
load_balancing_enabled = True
if any(config.name == "__delete__" for config in custom_model_lb_configs):
has_invalid_load_balancing_configs = True
load_balancing_enabled = model_setting.load_balancing_enabled
# display a warning when the user enables load balancing but fewer than 2 configs are available
has_invalid_load_balancing_configs = load_balancing_enabled and len(custom_model_lb_configs) < 2
if len(model_configuration.available_model_credentials) > 0 and not model_configuration.credentials:
status = ModelStatus.CREDENTIAL_REMOVED

View File

@@ -111,11 +111,21 @@ class CustomModelConfiguration(BaseModel):
current_credential_id: Optional[str] = None
current_credential_name: Optional[str] = None
available_model_credentials: list[CredentialConfiguration] = []
unadded_to_model_list: Optional[bool] = False
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
class UnaddedModelConfiguration(BaseModel):
"""
Model class for provider unadded model configuration.
"""
model: str
model_type: ModelType
class CustomConfiguration(BaseModel):
"""
Model class for provider custom configuration.
@@ -123,6 +133,7 @@ class CustomConfiguration(BaseModel):
provider: Optional[CustomProviderConfiguration] = None
models: list[CustomModelConfiguration] = []
can_added_models: list[UnaddedModelConfiguration] = []
class ModelLoadBalancingConfiguration(BaseModel):
@@ -144,6 +155,7 @@ class ModelSettings(BaseModel):
model: str
model_type: ModelType
enabled: bool = True
load_balancing_enabled: bool = False
load_balancing_configs: list[ModelLoadBalancingConfiguration] = []
# pydantic configs
@@ -192,9 +204,8 @@ class ProviderConfig(BasicProviderConfig):
scope: AppSelectorScope | ModelSelectorScope | ToolSelectorScope | None = None
required: bool = False
default: Optional[Union[int, str, float, bool, list]] = None
default: Optional[Union[int, str, float, bool]] = None
options: Optional[list[Option]] = None
multiple: bool | None = False
label: Optional[I18nObject] = None
help: Optional[I18nObject] = None
url: Optional[str] = None

View File

@@ -43,9 +43,9 @@ class APIBasedExtensionRequestor:
timeout=self.timeout,
proxies=proxies,
)
except requests.exceptions.Timeout:
except requests.Timeout:
raise ValueError("request timeout")
except requests.exceptions.ConnectionError:
except requests.ConnectionError:
raise ValueError("request connection error")
if response.status_code != 200:

View File

@@ -91,7 +91,7 @@ class Extensible:
# Find extension class
extension_class = None
for name, obj in vars(mod).items():
for obj in vars(mod).values():
if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
extension_class = obj
break
@@ -123,7 +123,7 @@ class Extensible:
)
)
except Exception as e:
except Exception:
logger.exception("Error scanning extensions")
raise

View File

@@ -41,9 +41,3 @@ class Extension:
assert module_extension.extension_class is not None
t: type = module_extension.extension_class
return t
def validate_form_schema(self, module: ExtensionModule, extension_name: str, config: dict) -> None:
module_extension = self.module_extension(module, extension_name)
form_schema = module_extension.form_schema
# TODO validate form_schema

View File

@@ -1,5 +1,7 @@
from typing import Optional
from sqlalchemy import select
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
from core.external_data_tool.base import ExternalDataTool
from core.helper import encrypter
@@ -28,13 +30,11 @@ class ApiExternalDataTool(ExternalDataTool):
api_based_extension_id = config.get("api_based_extension_id")
if not api_based_extension_id:
raise ValueError("api_based_extension_id is required")
# get api_based_extension
api_based_extension = (
db.session.query(APIBasedExtension)
.where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
stmt = select(APIBasedExtension).where(
APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id
)
api_based_extension = db.session.scalar(stmt)
if not api_based_extension:
raise ValueError("api_based_extension_id is invalid")
@@ -52,13 +52,11 @@ class ApiExternalDataTool(ExternalDataTool):
raise ValueError(f"config is required, config: {self.config}")
api_based_extension_id = self.config.get("api_based_extension_id")
assert api_based_extension_id is not None, "api_based_extension_id is required"
# get api_based_extension
api_based_extension = (
db.session.query(APIBasedExtension)
.where(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
stmt = select(APIBasedExtension).where(
APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id
)
api_based_extension = db.session.scalar(stmt)
if not api_based_extension:
raise ValueError(
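
Several files in this diff make the same move from the legacy Query API to 2.0-style statements. A minimal sketch with a hypothetical model (`SomeModel`, `some_id`, and `db` are placeholders, not names from the diff):

    from sqlalchemy import select

    # Legacy 1.x style (being replaced):
    #   obj = db.session.query(SomeModel).where(SomeModel.id == some_id).first()

    # 2.0 style: build a Select statement, then ask the session to run it.
    stmt = select(SomeModel).where(SomeModel.id == some_id)
    obj = db.session.scalar(stmt)   # first result or None, like .first()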

View File

@@ -22,7 +22,6 @@ class ExternalDataToolFactory:
:param config: the form config data
:return:
"""
code_based_extension.validate_form_schema(ExtensionModule.EXTERNAL_DATA_TOOL, name, config)
extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
# FIXME mypy issue here, figure out how to fix it
extension_class.validate_config(tenant_id, config) # type: ignore

View File

@@ -11,6 +11,10 @@ def obfuscated_token(token: str) -> str:
return token[:6] + "*" * 12 + token[-2:]
def full_mask_token(token_length=20):
return "*" * token_length
def encrypt_token(tenant_id: str, token: str):
from models.account import Tenant
from models.engine import db
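
Illustrative behaviour of the two masking helpers above (the sample token is made up):

    obfuscated_token("sk-abcdef0123456789abcd")  # -> 'sk-abc************cd'
    full_mask_token()                            # -> '********************' (20 chars)
    full_mask_token(8)                           # -> '********'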

View File

@@ -42,7 +42,7 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
for plugin in response.json()["data"]["plugins"]:
try:
result.append(MarketplacePluginDeclaration(**plugin))
except Exception as e:
except Exception:
pass
return result

View File

@@ -47,7 +47,7 @@ def get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type]
def load_single_subclass_from_source(
*, module_name: str, script_path: AnyStr, parent_type: type, use_lazy_loader: bool = False
*, module_name: str, script_path: str, parent_type: type, use_lazy_loader: bool = False
) -> type:
"""
Load a single subclass from the source

View File

@@ -1,7 +1,7 @@
import os
from collections import OrderedDict
from collections.abc import Callable
from typing import Any
from typing import TypeVar
from configs import dify_config
from core.tools.utils.yaml_utils import load_yaml_file
@@ -72,11 +72,14 @@ def pin_position_map(original_position_map: dict[str, int], pin_list: list[str])
return position_map
T = TypeVar("T")
def is_filtered(
include_set: set[str],
exclude_set: set[str],
data: Any,
name_func: Callable[[Any], str],
data: T,
name_func: Callable[[T], str],
) -> bool:
"""
Check if the object should be filtered out.
@@ -103,9 +106,9 @@ def is_filtered(
def sort_by_position_map(
position_map: dict[str, int],
data: list[Any],
name_func: Callable[[Any], str],
) -> list[Any]:
data: list[T],
name_func: Callable[[T], str],
):
"""
Sort the objects by the position map.
If the name of the object is not in the position map, it will be put at the end.
@@ -122,9 +125,9 @@ def sort_by_position_map(
def sort_to_dict_by_position_map(
position_map: dict[str, int],
data: list[Any],
name_func: Callable[[Any], str],
) -> OrderedDict[str, Any]:
data: list[T],
name_func: Callable[[T], str],
):
"""
Sort the objects into an ordered dict by the position map.
If the name of the object is not in the position map, it will be put at the end.
@@ -134,4 +137,4 @@ def sort_to_dict_by_position_map(
:return: an OrderedDict with the sorted pairs of name and object
"""
sorted_items = sort_by_position_map(position_map, data, name_func)
return OrderedDict([(name_func(item), item) for item in sorted_items])
return OrderedDict((name_func(item), item) for item in sorted_items)
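
A minimal sketch of what the TypeVar buys here: with `T` threaded through the signature, the element type of the input survives into the result instead of collapsing to `Any`. This is a simplified reimplementation for illustration, not the actual helper:

    from collections.abc import Callable
    from typing import TypeVar

    T = TypeVar("T")

    def sort_by_position(
        position_map: dict[str, int], data: list[T], name_func: Callable[[T], str]
    ) -> list[T]:
        # names missing from the map sort after every mapped position
        return sorted(data, key=lambda item: position_map.get(name_func(item), len(position_map)))

    tools = ["web_search", "calculator", "weather"]
    ordered = sort_by_position({"calculator": 0, "weather": 1}, tools, lambda name: name)
    # `ordered` is inferred as list[str]; with Any it would be list[Any]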

View File

@@ -1,128 +0,0 @@
import contextlib
from copy import deepcopy
from typing import Any, Optional, Protocol
from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter
class ProviderConfigCache(Protocol):
"""
Interface for provider configuration cache operations
"""
def get(self) -> Optional[dict]:
"""Get cached provider configuration"""
...
def set(self, config: dict[str, Any]) -> None:
"""Cache provider configuration"""
...
def delete(self) -> None:
"""Delete cached provider configuration"""
...
class ProviderConfigEncrypter:
tenant_id: str
config: list[BasicProviderConfig]
provider_config_cache: ProviderConfigCache
def __init__(
self,
tenant_id: str,
config: list[BasicProviderConfig],
provider_config_cache: ProviderConfigCache,
):
self.tenant_id = tenant_id
self.config = config
self.provider_config_cache = provider_config_cache
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
"""
deep copy data
"""
return deepcopy(data)
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
"""
encrypt tool credentials with tenant id
return a deep copy of credentials with encrypted values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
data[field_name] = encrypted
return data
def mask_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
"""
mask credentials
return a deep copy of credentials with masked values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
if len(data[field_name]) > 6:
data[field_name] = (
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
)
else:
data[field_name] = "*" * len(data[field_name])
return data
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
return self.mask_credentials(data)
def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
"""
decrypt tool credentials with tenant id
return a deep copy of credentials with decrypted values
"""
cached_credentials = self.provider_config_cache.get()
if cached_credentials:
return cached_credentials
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
with contextlib.suppress(Exception):
# if the value is None or empty string, skip decrypt
if not data[field_name]:
continue
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
self.provider_config_cache.set(data)
return data
def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache

View File

@@ -5,9 +5,10 @@ import re
import threading
import time
import uuid
from typing import Any, Optional, cast
from typing import Any, Optional
from flask import current_app
from sqlalchemy import select
from sqlalchemy.orm.exc import ObjectDeletedError
from configs import dify_config
@@ -18,6 +19,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
@@ -56,13 +58,11 @@ class IndexingRunner:
if not dataset:
raise ValueError("no dataset found")
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
stmt = select(DatasetProcessRule).where(
DatasetProcessRule.id == dataset_document.dataset_process_rule_id
)
processing_rule = db.session.scalar(stmt)
if not processing_rule:
raise ValueError("no process rule found")
index_type = dataset_document.doc_form
@@ -123,11 +123,8 @@ class IndexingRunner:
db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
db.session.commit()
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
)
stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
processing_rule = db.session.scalar(stmt)
if not processing_rule:
raise ValueError("no process rule found")
@@ -208,7 +205,6 @@ class IndexingRunner:
child_documents.append(child_document)
document.children = child_documents
documents.append(document)
# build index
index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
@@ -310,7 +306,8 @@ class IndexingRunner:
# delete image files and related db records
image_upload_file_ids = get_image_upload_file_ids(document.page_content)
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
stmt = select(UploadFile).where(UploadFile.id == upload_file_id)
image_file = db.session.scalar(stmt)
if image_file is None:
continue
try:
@@ -339,14 +336,14 @@ class IndexingRunner:
if dataset_document.data_source_type == "upload_file":
if not data_source_info or "upload_file_id" not in data_source_info:
raise ValueError("no upload file found")
file_detail = (
db.session.query(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]).one_or_none()
)
stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"])
file_detail = db.session.scalars(stmt).one_or_none()
if file_detail:
extract_setting = ExtractSetting(
datasource_type="upload_file", upload_file=file_detail, document_model=dataset_document.doc_form
datasource_type=DatasourceType.FILE.value,
upload_file=file_detail,
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
elif dataset_document.data_source_type == "notion_import":
@@ -357,7 +354,7 @@ class IndexingRunner:
):
raise ValueError("no notion import info found")
extract_setting = ExtractSetting(
datasource_type="notion_import",
datasource_type=DatasourceType.NOTION.value,
notion_info={
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
@@ -377,7 +374,7 @@ class IndexingRunner:
):
raise ValueError("no website import info found")
extract_setting = ExtractSetting(
datasource_type="website_crawl",
datasource_type=DatasourceType.WEBSITE.value,
website_info={
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
@@ -400,7 +397,6 @@ class IndexingRunner:
)
# replace doc id to document model id
text_docs = cast(list[Document], text_docs)
for text_doc in text_docs:
if text_doc.metadata is not None:
text_doc.metadata["document_id"] = dataset_document.id
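
The string literals replaced above imply the enum values; a hypothetical stand-in for `DatasourceType` that matches them (the real definition lives in core.rag.extractor.entity.datasource_type):

    from enum import StrEnum

    # Stand-in only; member values are inferred from the literals this diff replaces.
    class DatasourceType(StrEnum):
        FILE = "upload_file"
        NOTION = "notion_import"
        WEBSITE = "website_crawl"

    DatasourceType.FILE.value  # "upload_file", used instead of the raw string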

View File

@@ -56,11 +56,8 @@ class LLMGenerator:
prompts = [UserPromptMessage(content=prompt)]
with measure_time() as timer:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
),
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
)
answer = cast(str, response.message.content)
cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL)
@@ -69,7 +66,7 @@ class LLMGenerator:
try:
result_dict = json.loads(cleaned_answer)
answer = result_dict["Your Output"]
except json.JSONDecodeError as e:
except json.JSONDecodeError:
logger.exception("Failed to generate name after answer, use query instead")
answer = query
name = answer.strip()
@@ -113,13 +110,10 @@ class LLMGenerator:
prompt_messages = [UserPromptMessage(content=prompt)]
try:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(prompt_messages),
model_parameters={"max_tokens": 256, "temperature": 0},
stream=False,
),
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages),
model_parameters={"max_tokens": 256, "temperature": 0},
stream=False,
)
text_content = response.message.get_text_content()
@@ -162,11 +156,8 @@ class LLMGenerator:
)
try:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
rule_config["prompt"] = cast(str, response.message.content)
@@ -212,11 +203,8 @@ class LLMGenerator:
try:
try:
# the first step to generate the task prompt
prompt_content = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
prompt_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
except InvokeError as e:
error = str(e)
@@ -248,11 +236,8 @@ class LLMGenerator:
statement_messages = [UserPromptMessage(content=statement_generate_prompt)]
try:
parameter_content = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
),
parameter_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
)
rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content))
except InvokeError as e:
@@ -260,11 +245,8 @@ class LLMGenerator:
error_step = "generate variables"
try:
statement_content = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
),
statement_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
)
rule_config["opening_statement"] = cast(str, statement_content.message.content)
except InvokeError as e:
@@ -307,11 +289,8 @@ class LLMGenerator:
prompt_messages = [UserPromptMessage(content=prompt)]
model_parameters = model_config.get("completion_params", {})
try:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
generated_code = cast(str, response.message.content)
@@ -338,13 +317,10 @@ class LLMGenerator:
prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)]
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters={"temperature": 0.01, "max_tokens": 2000},
stream=False,
),
response: LLMResult = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters={"temperature": 0.01, "max_tokens": 2000},
stream=False,
)
answer = cast(str, response.message.content)
@@ -367,11 +343,8 @@ class LLMGenerator:
model_parameters = model_config.get("model_parameters", {})
try:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
raw_content = response.message.content
@@ -555,11 +528,8 @@ class LLMGenerator:
model_parameters = {"temperature": 0.4}
try:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
generated_raw = cast(str, response.message.content)
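
The `cast(LLMResult, ...)` wrappers above become plain annotated assignments. One way that pattern satisfies a type checker is when the call is overloaded on its `stream` flag, roughly as in this sketch (hypothetical signatures, not Dify's actual stubs):

    from collections.abc import Generator
    from typing import Literal, overload

    class LLMResult: ...

    @overload
    def invoke_llm(*, stream: Literal[False]) -> LLMResult: ...
    @overload
    def invoke_llm(*, stream: Literal[True] = ...) -> Generator[str, None, None]: ...
    def invoke_llm(*, stream: bool = True) -> LLMResult | Generator[str, None, None]:
        if stream:
            return (chunk for chunk in ("partial ", "output"))
        return LLMResult()

    # stream=False selects the first overload, so no cast() is needed:
    response: LLMResult = invoke_llm(stream=False)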

View File

@@ -101,7 +101,7 @@ def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackSta
def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
"""Check if the server supports OAuth 2.0 Resource Discovery."""
b_scheme, b_netloc, b_path, b_params, b_query, b_fragment = urlparse(server_url, "", True)
b_scheme, b_netloc, b_path, _, b_query, b_fragment = urlparse(server_url, "", True)
url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource{b_path}"
if b_query:
url_for_resource_discovery += f"?{b_query}"
@@ -117,7 +117,7 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
else:
return False, ""
return False, ""
except httpx.RequestError as e:
except httpx.RequestError:
# Resource discovery not supported, fall back to well-known OAuth metadata
return False, ""
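
For reference, the well-known URL assembled above, traced through with a made-up server URL (the unused `params` slot from `urlparse` is discarded as `_`):

    from urllib.parse import urlparse

    server_url = "https://mcp.example.com/base?team=dev"   # hypothetical example
    scheme, netloc, path, _, query, _fragment = urlparse(server_url, "", True)
    url = f"{scheme}://{netloc}/.well-known/oauth-protected-resource{path}"
    if query:
        url += f"?{query}"
    # url == "https://mcp.example.com/.well-known/oauth-protected-resource/base?team=dev"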

View File

@@ -246,6 +246,10 @@ class StreamableHTTPTransport:
logger.debug("Received 202 Accepted")
return
if response.status_code == 204:
logger.debug("Received 204 No Content")
return
if response.status_code == 404:
if isinstance(message.root, JSONRPCRequest):
self._send_session_terminated_error(
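
The new branch above treats 204 like 202: a success with no body, so there is nothing further to parse or stream. A condensed sketch of the early-return pattern, assuming `response` is the HTTP response object and `logger` is the module logger:

    # Bodyless success responses end processing for this message.
    if response.status_code in (202, 204):
        logger.debug("Received %s; no content to process", response.status_code)
        return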

View File

@@ -2,7 +2,7 @@ import logging
from collections.abc import Callable
from contextlib import AbstractContextManager, ExitStack
from types import TracebackType
from typing import Any, Optional, cast
from typing import Any, Optional
from urllib.parse import urlparse
from core.mcp.client.sse_client import sse_client
@@ -116,8 +116,7 @@ class MCPClient:
self._session_context = ClientSession(*streams)
self._session = self._exit_stack.enter_context(self._session_context)
session = cast(ClientSession, self._session)
session.initialize()
self._session.initialize()
return
except MCPAuthError:

View File

@@ -110,9 +110,9 @@ class TokenBufferMemory:
else:
message_limit = 500
stmt = stmt.limit(message_limit)
msg_limit_stmt = stmt.limit(message_limit)
messages = db.session.scalars(stmt).all()
messages = db.session.scalars(msg_limit_stmt).all()
# instead of all messages from the conversation, we only need to extract messages
# that belong to the thread of the last message
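
The rename above reflects that SQLAlchemy `Select` objects are generative: `.limit()` returns a new statement rather than mutating its receiver, so the limited statement must be bound to a name before execution. A minimal sketch with assumed model names:

    from sqlalchemy import select

    stmt = select(Message).where(Message.conversation_id == conversation_id)
    msg_limit_stmt = stmt.limit(500)            # new Select; `stmt` is unchanged
    messages = db.session.scalars(msg_limit_stmt).all()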

View File

@@ -1,6 +1,7 @@
from typing import Optional
from pydantic import BaseModel, Field
from sqlalchemy import select
from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor
from core.helper.encrypter import decrypt_token
@@ -87,10 +88,9 @@ class ApiModeration(Moderation):
@staticmethod
def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]:
extension = (
db.session.query(APIBasedExtension)
.where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
stmt = select(APIBasedExtension).where(
APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id
)
extension = db.session.scalar(stmt)
return extension

View File

@@ -20,7 +20,6 @@ class ModerationFactory:
:param config: the form config data
:return:
"""
code_based_extension.validate_form_schema(ExtensionModule.MODERATION, name, config)
extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
# FIXME: mypy error, try to fix it instead of using type: ignore
extension_class.validate_config(tenant_id, config) # type: ignore

View File

@@ -135,7 +135,7 @@ class OutputModeration(BaseModel):
result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
return result
except Exception as e:
except Exception:
logger.exception("Moderation Output error, app_id: %s", app_id)
return None

View File

@@ -5,6 +5,7 @@ from typing import Optional
from urllib.parse import urljoin
from opentelemetry.trace import Link, Status, StatusCode
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from core.ops.aliyun_trace.data_exporter.traceclient import (
@@ -263,15 +264,15 @@ class AliyunDataTrace(BaseTraceInstance):
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")
app = session.query(App).where(App.id == app_id).first()
app_stmt = select(App).where(App.id == app_id)
app = session.scalar(app_stmt)
if not app:
raise ValueError(f"App with id {app_id} not found")
if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
service_account = session.query(Account).where(Account.id == app.created_by).first()
account_stmt = select(Account).where(Account.id == app.created_by)
service_account = session.scalar(account_stmt)
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
current_tenant = (

View File

@@ -72,7 +72,7 @@ class TraceClient:
else:
logger.debug("AliyunTrace API check failed: Unexpected status code: %s", response.status_code)
return False
except requests.exceptions.RequestException as e:
except requests.RequestException as e:
logger.debug("AliyunTrace API check failed: %s", str(e))
raise ValueError(f"AliyunTrace API check failed: {str(e)}")

View File

@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.ops.entities.config_entity import BaseTracingConfig
@@ -44,14 +45,15 @@ class BaseTraceInstance(ABC):
"""
with Session(db.engine, expire_on_commit=False) as session:
# Get the app to find its creator
app = session.query(App).where(App.id == app_id).first()
app_stmt = select(App).where(App.id == app_id)
app = session.scalar(app_stmt)
if not app:
raise ValueError(f"App with id {app_id} not found")
if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
service_account = session.query(Account).where(Account.id == app.created_by).first()
account_stmt = select(Account).where(Account.id == app.created_by)
service_account = session.scalar(account_stmt)
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")

View File

@@ -226,9 +226,9 @@ class OpsTraceManager:
if not trace_config_data:
return None
# decrypt_token
app = db.session.query(App).where(App.id == app_id).first()
stmt = select(App).where(App.id == app_id)
app = db.session.scalar(stmt)
if not app:
raise ValueError("App not found")
@@ -295,20 +295,19 @@ class OpsTraceManager:
@classmethod
def get_app_config_through_message_id(cls, message_id: str):
app_model_config = None
message_data = db.session.query(Message).where(Message.id == message_id).first()
message_stmt = select(Message).where(Message.id == message_id)
message_data = db.session.scalar(message_stmt)
if not message_data:
return None
conversation_id = message_data.conversation_id
conversation_data = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
conversation_stmt = select(Conversation).where(Conversation.id == conversation_id)
conversation_data = db.session.scalar(conversation_stmt)
if not conversation_data:
return None
if conversation_data.app_model_config_id:
app_model_config = (
db.session.query(AppModelConfig)
.where(AppModelConfig.id == conversation_data.app_model_config_id)
.first()
)
config_stmt = select(AppModelConfig).where(AppModelConfig.id == conversation_data.app_model_config_id)
app_model_config = db.session.scalar(config_stmt)
elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs:
app_model_config = conversation_data.override_model_configs
@@ -850,7 +849,7 @@ class TraceQueueManager:
if self.trace_instance:
trace_task.app_id = self.app_id
trace_manager_queue.put(trace_task)
except Exception as e:
except Exception:
logger.exception("Error adding trace task, trace_type %s", trace_task.trace_type)
finally:
self.start_timer()
@@ -869,7 +868,7 @@ class TraceQueueManager:
tasks = self.collect_tasks()
if tasks:
self.send_to_celery(tasks)
except Exception as e:
except Exception:
logger.exception("Error processing trace tasks")
def start_timer(self):

View File

@@ -1,6 +1,8 @@
from collections.abc import Generator, Mapping
from typing import Optional, Union
from sqlalchemy import select
from controllers.service_api.wraps import create_or_update_end_user_for_user_id
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
@@ -192,10 +194,11 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
"""
get the user by user id
"""
user = db.session.query(EndUser).where(EndUser.id == user_id).first()
stmt = select(EndUser).where(EndUser.id == user_id)
user = db.session.scalar(stmt)
if not user:
user = db.session.query(Account).where(Account.id == user_id).first()
stmt = select(Account).where(Account.id == user_id)
user = db.session.scalar(stmt)
if not user:
raise ValueError("user not found")

View File

@@ -13,7 +13,6 @@ from core.plugin.entities.base import BasePluginEntity
from core.plugin.entities.endpoint import EndpointProviderDeclaration
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderEntity
from core.trigger.entities.entities import TriggerProviderEntity
class PluginInstallationSource(enum.StrEnum):
@@ -63,7 +62,6 @@ class PluginCategory(enum.StrEnum):
Model = "model"
Extension = "extension"
AgentStrategy = "agent-strategy"
Trigger = "trigger"
class PluginDeclaration(BaseModel):
@@ -71,7 +69,6 @@ class PluginDeclaration(BaseModel):
tools: Optional[list[str]] = Field(default_factory=list[str])
models: Optional[list[str]] = Field(default_factory=list[str])
endpoints: Optional[list[str]] = Field(default_factory=list[str])
triggers: Optional[list[str]] = Field(default_factory=list[str])
class Meta(BaseModel):
minimum_dify_version: Optional[str] = Field(default=None, pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$")
@@ -92,7 +89,6 @@ class PluginDeclaration(BaseModel):
repo: Optional[str] = Field(default=None)
verified: bool = Field(default=False)
tool: Optional[ToolProviderEntity] = None
trigger: Optional[TriggerProviderEntity] = None
model: Optional[ProviderEntity] = None
endpoint: Optional[EndpointProviderDeclaration] = None
agent_strategy: Optional[AgentStrategyProviderEntity] = None
@@ -108,8 +104,6 @@ class PluginDeclaration(BaseModel):
values["category"] = PluginCategory.Model
elif values.get("agent_strategy"):
values["category"] = PluginCategory.AgentStrategy
elif values.get("trigger"):
values["category"] = PluginCategory.Trigger
else:
values["category"] = PluginCategory.Extension
return values
@@ -190,10 +184,6 @@ class ToolProviderID(GenericProviderID):
self.plugin_name = f"{self.provider_name}_tool"
class TriggerProviderID(GenericProviderID):
pass
class PluginDependency(BaseModel):
class Type(enum.StrEnum):
Github = PluginInstallationSource.Github.value

View File

@@ -1,4 +1,3 @@
import enum
from collections.abc import Mapping, Sequence
from datetime import datetime
from enum import StrEnum
@@ -14,7 +13,6 @@ from core.plugin.entities.parameters import PluginParameterOption
from core.plugin.entities.plugin import PluginDeclaration, PluginEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin
from core.trigger.entities.entities import TriggerProviderEntity
T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
@@ -198,48 +196,3 @@ class PluginListResponse(BaseModel):
class PluginDynamicSelectOptionsResponse(BaseModel):
options: Sequence[PluginParameterOption] = Field(description="The options of the dynamic select.")
class PluginTriggerProviderEntity(BaseModel):
provider: str
plugin_unique_identifier: str
plugin_id: str
declaration: TriggerProviderEntity
class CredentialType(enum.StrEnum):
API_KEY = "api-key"
OAUTH2 = "oauth2"
UNAUTHORIZED = "unauthorized"
def get_name(self):
if self == CredentialType.API_KEY:
return "API KEY"
elif self == CredentialType.OAUTH2:
return "AUTH"
elif self == CredentialType.UNAUTHORIZED:
return "UNAUTHORIZED"
else:
return self.value.replace("-", " ").upper()
def is_editable(self):
return self == CredentialType.API_KEY
def is_validate_allowed(self):
return self == CredentialType.API_KEY
@classmethod
def values(cls):
return [item.value for item in cls]
@classmethod
def of(cls, credential_type: str) -> "CredentialType":
type_name = credential_type.lower()
if type_name == "api-key":
return cls.API_KEY
elif type_name == "oauth2":
return cls.OAUTH2
elif type_name == "unauthorized":
return cls.UNAUTHORIZED
else:
raise ValueError(f"Invalid credential type: {credential_type}")

View File

@@ -1,7 +1,5 @@
from collections.abc import Mapping
from typing import Any, Literal, Optional
from flask import Response
from pydantic import BaseModel, ConfigDict, Field, field_validator
from core.entities.provider_entities import BasicProviderConfig
@@ -239,34 +237,3 @@ class RequestFetchAppInfo(BaseModel):
"""
app_id: str
class Event(BaseModel):
variables: Mapping[str, Any]
class TriggerInvokeResponse(BaseModel):
event: Event
cancelled: Optional[bool] = False
class PluginTriggerDispatchResponse(BaseModel):
triggers: list[str]
raw_http_response: str
class TriggerSubscriptionResponse(BaseModel):
subscription: dict[str, Any]
class TriggerValidateProviderCredentialsResponse(BaseModel):
result: bool
class TriggerDispatchResponse:
triggers: list[str]
response: Response
def __init__(self, triggers: list[str], response: Response):
self.triggers = triggers
self.response = response

View File

@@ -64,7 +64,7 @@ class BasePluginClient:
response = requests.request(
method=method, url=str(url), headers=headers, data=data, params=params, stream=stream, files=files
)
except requests.exceptions.ConnectionError:
except requests.ConnectionError:
logger.exception("Request to Plugin Daemon Service failed")
raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed")

View File

@@ -15,7 +15,6 @@ class DynamicSelectClient(BasePluginClient):
provider: str,
action: str,
credentials: Mapping[str, Any],
credential_type: str,
parameter: str,
) -> PluginDynamicSelectOptionsResponse:
"""
@@ -30,7 +29,6 @@ class DynamicSelectClient(BasePluginClient):
"data": {
"provider": GenericProviderID(provider).provider_name,
"credentials": credentials,
"credential_type": credential_type,
"provider_action": action,
"parameter": parameter,
},

Some files were not shown because too many files have changed in this diff.