mirror of
https://github.com/langgenius/dify.git
synced 2026-04-10 05:04:22 +08:00
Compare commits
243 Commits
1.8.0
...
feat/trigg
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
037cdb3d7d | ||
|
|
7b9d01bfca | ||
|
|
bd1fcd3525 | ||
|
|
0cb0cea167 | ||
|
|
ee68a685a7 | ||
|
|
c78bd492af | ||
|
|
6857bb4406 | ||
|
|
dcf3ee6982 | ||
|
|
76850749e4 | ||
|
|
91e5e33440 | ||
|
|
11e55088c9 | ||
|
|
57c0bc9fb6 | ||
|
|
c3ebb22a4b | ||
|
|
1562d00037 | ||
|
|
e9e843b27d | ||
|
|
ec33b9908e | ||
|
|
67004368d9 | ||
|
|
50bff270b6 | ||
|
|
bd5cf1c272 | ||
|
|
d22404994a | ||
|
|
9898730cc5 | ||
|
|
b0f1e55a87 | ||
|
|
6566824807 | ||
|
|
9249a2af0d | ||
|
|
112fc3b1d1 | ||
|
|
37299b3bd7 | ||
|
|
8f65ce995a | ||
|
|
4a743e6dc1 | ||
|
|
07dda61929 | ||
|
|
0d8438ef40 | ||
|
|
96bb638969 | ||
|
|
e74962272e | ||
|
|
5a15419baf | ||
|
|
e8403977b9 | ||
|
|
add2ca85f2 | ||
|
|
fbb7b02e90 | ||
|
|
249b62c9de | ||
|
|
b433322e8d | ||
|
|
1c8850fc95 | ||
|
|
dc16f1b65a | ||
|
|
ff30395dc1 | ||
|
|
8e600f3302 | ||
|
|
5a1e0a8379 | ||
|
|
2a3ce6baa9 | ||
|
|
01b2f9cff6 | ||
|
|
ac38614171 | ||
|
|
eb95c5cd07 | ||
|
|
a799b54b9e | ||
|
|
98ba0236e6 | ||
|
|
b6c552df07 | ||
|
|
e2827e475d | ||
|
|
58cbd337b5 | ||
|
|
a91e59d544 | ||
|
|
814787677a | ||
|
|
85caa5bd0c | ||
|
|
e04083fc0e | ||
|
|
cf532e5e0d | ||
|
|
c097fc2c48 | ||
|
|
0371d71409 | ||
|
|
81ef7343d4 | ||
|
|
8e4b59c90c | ||
|
|
68f73410fc | ||
|
|
88af8ed374 | ||
|
|
015f82878e | ||
|
|
3874e58dc2 | ||
|
|
9f8c159583 | ||
|
|
d8f6f9ce19 | ||
|
|
eab03e63d4 | ||
|
|
461829274a | ||
|
|
e751c0c535 | ||
|
|
1fffc79c32 | ||
|
|
83fab4bc19 | ||
|
|
f60e28d2f5 | ||
|
|
a62d7aa3ee | ||
|
|
cc84a45244 | ||
|
|
5cf3d24018 | ||
|
|
4bdbe617fe | ||
|
|
33c867fd8c | ||
|
|
2013ceb9d2 | ||
|
|
7120c6414c | ||
|
|
5ce7b2d98d | ||
|
|
cb82198271 | ||
|
|
5e5ffaa416 | ||
|
|
4b253e1f73 | ||
|
|
dd929dbf0e | ||
|
|
97a9d34e96 | ||
|
|
602070ec9c | ||
|
|
afd8989150 | ||
|
|
694197a701 | ||
|
|
2f08306695 | ||
|
|
6acc77d86d | ||
|
|
5ddd5e49ee | ||
|
|
72f9e77368 | ||
|
|
a46c9238fa | ||
|
|
87120ad4ac | ||
|
|
7544b5ec9a | ||
|
|
ff4a62d1e7 | ||
|
|
41daa51988 | ||
|
|
d522350c99 | ||
|
|
1d1bb9451e | ||
|
|
1fce1a61d4 | ||
|
|
883a6caf96 | ||
|
|
a239c39f09 | ||
|
|
e925a8ab99 | ||
|
|
bccaf939e6 | ||
|
|
676648e0b3 | ||
|
|
4ae19e6dde | ||
|
|
4d0ff5c281 | ||
|
|
bd482eb8ef | ||
|
|
327b354cc2 | ||
|
|
5b3cc560d5 | ||
|
|
d41d4deaac | ||
|
|
208ce4e774 | ||
|
|
414ee51975 | ||
|
|
d5a521eef2 | ||
|
|
6d307cc9fc | ||
|
|
1b401063e8 | ||
|
|
60d9d0584a | ||
|
|
ffba341258 | ||
|
|
f11131f8b5 | ||
|
|
adc7134af5 | ||
|
|
10f19cd0c2 | ||
|
|
2e6e414a9e | ||
|
|
c45d676477 | ||
|
|
b8d8dddd5a | ||
|
|
c45c22b1b2 | ||
|
|
3d57a9ccdc | ||
|
|
cb04c21141 | ||
|
|
f70272f638 | ||
|
|
b4b71ded47 | ||
|
|
24e2b72b71 | ||
|
|
529791ce62 | ||
|
|
b66945b9b8 | ||
|
|
f3c5d77ad5 | ||
|
|
e5e42bc483 | ||
|
|
bdfbfa391f | ||
|
|
72acd9b483 | ||
|
|
9ed45594c6 | ||
|
|
9f528d23d4 | ||
|
|
d937cc491d | ||
|
|
863f3aeb27 | ||
|
|
c138f4c3a6 | ||
|
|
a35be05790 | ||
|
|
0fe078d25e | ||
|
|
d9420c7224 | ||
|
|
9ff6baaf52 | ||
|
|
574d00bb13 | ||
|
|
8d60e5c342 | ||
|
|
d9eb1a73af | ||
|
|
1a34ff8a67 | ||
|
|
14e7ba4818 | ||
|
|
60b5ed8e5d | ||
|
|
52e9bcbfdb | ||
|
|
20ae3eae54 | ||
|
|
0fb145e667 | ||
|
|
bcac43c812 | ||
|
|
929d9e0b3f | ||
|
|
d5e560a987 | ||
|
|
e4383d6167 | ||
|
|
d8ddbc4d87 | ||
|
|
19c0fc85e2 | ||
|
|
f32e176d6a | ||
|
|
3d5a4df9d0 | ||
|
|
a58df35ead | ||
|
|
e47bfd2ca3 | ||
|
|
f8f768873e | ||
|
|
d043e1a05a | ||
|
|
9789bd02d8 | ||
|
|
837c0ddacc | ||
|
|
7c340695d6 | ||
|
|
e87d4fbf69 | ||
|
|
d94e54923f | ||
|
|
39064197da | ||
|
|
c4496e6cf2 | ||
|
|
27d09d1783 | ||
|
|
a174ee419e | ||
|
|
79e6138ce2 | ||
|
|
5a64f69456 | ||
|
|
64c7be59b7 | ||
|
|
5c01dd97e8 | ||
|
|
89ad6ad902 | ||
|
|
ecf74d91e2 | ||
|
|
62892ed8d7 | ||
|
|
4f73bc9693 | ||
|
|
7b399cc5e5 | ||
|
|
fab5740778 | ||
|
|
30f2d756a7 | ||
|
|
0d745c64d8 | ||
|
|
add6b79231 | ||
|
|
56c51f0112 | ||
|
|
c90dad566f | ||
|
|
4adf85d7d4 | ||
|
|
7995ff1410 | ||
|
|
d2f234757b | ||
|
|
bf34437186 | ||
|
|
47f02eec96 | ||
|
|
06dd4d6e00 | ||
|
|
fbceda7b66 | ||
|
|
9d6ce3065d | ||
|
|
5cbe6bf8f8 | ||
|
|
bb718acadf | ||
|
|
4cd00efe3b | ||
|
|
22b11e4b43 | ||
|
|
4ef6ff217e | ||
|
|
87abfbf515 | ||
|
|
73e65fd838 | ||
|
|
2a29c61041 | ||
|
|
e53edb0fc2 | ||
|
|
17908fbf6b | ||
|
|
3dae108f84 | ||
|
|
34b041e9f0 | ||
|
|
917ed8cf84 | ||
|
|
85b0b8373b | ||
|
|
5bbf685035 | ||
|
|
a63d1e87b1 | ||
|
|
7129de98cd | ||
|
|
2984dbc0df | ||
|
|
392db7f611 | ||
|
|
5a427b8daa | ||
|
|
18f2e6f166 | ||
|
|
e78903302f | ||
|
|
4084ade86c | ||
|
|
6b0d919dbd | ||
|
|
a7b558b38b | ||
|
|
6aed7e3ff4 | ||
|
|
8e93a8a2e2 | ||
|
|
e38a86e37b | ||
|
|
392e3530bf | ||
|
|
833c902b2b | ||
|
|
6eaea64b3f | ||
|
|
5303b50737 | ||
|
|
6acbcfe679 | ||
|
|
16ef5ebb97 | ||
|
|
acfb95f9c2 | ||
|
|
aacea166d7 | ||
|
|
f7bb3b852a | ||
|
|
d4ff1e031a | ||
|
|
6a3d135d49 | ||
|
|
5c4bf7aabd | ||
|
|
e9c7dc7464 | ||
|
|
74ad21b145 | ||
|
|
f214eeb7b1 | ||
|
|
ae25f90f34 |
19
.claude/settings.json.example
Normal file
19
.claude/settings.json.example
Normal file
@@ -0,0 +1,19 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allow": [],
|
||||
"deny": []
|
||||
},
|
||||
"env": {
|
||||
"__comment": "Environment variables for MCP servers. Override in .claude/settings.local.json with actual values.",
|
||||
"GITHUB_PERSONAL_ACCESS_TOKEN": "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
||||
},
|
||||
"enabledMcpjsonServers": [
|
||||
"context7",
|
||||
"sequential-thinking",
|
||||
"github",
|
||||
"fetch",
|
||||
"playwright",
|
||||
"ide"
|
||||
],
|
||||
"enableAllProjectMcpServers": true
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
#!/bin/bash
|
||||
|
||||
npm add -g pnpm@10.15.0
|
||||
corepack enable
|
||||
cd web && pnpm install
|
||||
pipx install uv
|
||||
|
||||
|
||||
8
.github/workflows/api-tests.yml
vendored
8
.github/workflows/api-tests.yml
vendored
@@ -1,13 +1,7 @@
|
||||
name: Run Pytest
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- api/**
|
||||
- docker/**
|
||||
- .github/workflows/api-tests.yml
|
||||
workflow_call:
|
||||
|
||||
concurrency:
|
||||
group: api-tests-${{ github.head_ref || github.run_id }}
|
||||
|
||||
8
.github/workflows/autofix.yml
vendored
8
.github/workflows/autofix.yml
vendored
@@ -1,10 +1,9 @@
|
||||
name: autofix.ci
|
||||
on:
|
||||
workflow_call:
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
branches: ["main"]
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
branches: ["main"]
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
@@ -18,7 +17,7 @@ jobs:
|
||||
# Use uv to ensure we have the same ruff version in CI and locally.
|
||||
- uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
python-version: "3.12"
|
||||
python-version: "3.12"
|
||||
- run: |
|
||||
cd api
|
||||
uv sync --dev
|
||||
@@ -29,6 +28,7 @@ jobs:
|
||||
- name: ast-grep
|
||||
run: |
|
||||
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
|
||||
uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
|
||||
- name: mdformat
|
||||
run: |
|
||||
uvx mdformat .
|
||||
|
||||
14
.github/workflows/db-migration-test.yml
vendored
14
.github/workflows/db-migration-test.yml
vendored
@@ -1,13 +1,7 @@
|
||||
name: DB Migration Test
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- plugins/beta
|
||||
paths:
|
||||
- api/migrations/**
|
||||
- .github/workflows/db-migration-test.yml
|
||||
workflow_call:
|
||||
|
||||
concurrency:
|
||||
group: db-migration-test-${{ github.ref }}
|
||||
@@ -33,6 +27,12 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api
|
||||
- name: Ensure Offline migration are supported
|
||||
run: |
|
||||
# upgrade
|
||||
uv run --directory api flask db upgrade 'base:head' --sql
|
||||
# downgrade
|
||||
uv run --directory api flask db downgrade 'head:base' --sql
|
||||
|
||||
- name: Prepare middleware env
|
||||
run: |
|
||||
|
||||
78
.github/workflows/main-ci.yml
vendored
Normal file
78
.github/workflows/main-ci.yml
vendored
Normal file
@@ -0,0 +1,78 @@
|
||||
name: Main CI Pipeline
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
push:
|
||||
branches: ["main"]
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
checks: write
|
||||
statuses: write
|
||||
|
||||
concurrency:
|
||||
group: main-ci-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
# Check which paths were changed to determine which tests to run
|
||||
check-changes:
|
||||
name: Check Changed Files
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
api-changed: ${{ steps.changes.outputs.api }}
|
||||
web-changed: ${{ steps.changes.outputs.web }}
|
||||
vdb-changed: ${{ steps.changes.outputs.vdb }}
|
||||
migration-changed: ${{ steps.changes.outputs.migration }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: dorny/paths-filter@v3
|
||||
id: changes
|
||||
with:
|
||||
filters: |
|
||||
api:
|
||||
- 'api/**'
|
||||
- 'docker/**'
|
||||
- '.github/workflows/api-tests.yml'
|
||||
web:
|
||||
- 'web/**'
|
||||
vdb:
|
||||
- 'api/core/rag/datasource/**'
|
||||
- 'docker/**'
|
||||
- '.github/workflows/vdb-tests.yml'
|
||||
- 'api/uv.lock'
|
||||
- 'api/pyproject.toml'
|
||||
migration:
|
||||
- 'api/migrations/**'
|
||||
- '.github/workflows/db-migration-test.yml'
|
||||
|
||||
# Run tests in parallel
|
||||
api-tests:
|
||||
name: API Tests
|
||||
needs: check-changes
|
||||
if: needs.check-changes.outputs.api-changed == 'true'
|
||||
uses: ./.github/workflows/api-tests.yml
|
||||
|
||||
web-tests:
|
||||
name: Web Tests
|
||||
needs: check-changes
|
||||
if: needs.check-changes.outputs.web-changed == 'true'
|
||||
uses: ./.github/workflows/web-tests.yml
|
||||
|
||||
style-check:
|
||||
name: Style Check
|
||||
uses: ./.github/workflows/style.yml
|
||||
|
||||
vdb-tests:
|
||||
name: VDB Tests
|
||||
needs: check-changes
|
||||
if: needs.check-changes.outputs.vdb-changed == 'true'
|
||||
uses: ./.github/workflows/vdb-tests.yml
|
||||
|
||||
db-migration-test:
|
||||
name: DB Migration Test
|
||||
needs: check-changes
|
||||
if: needs.check-changes.outputs.migration-changed == 'true'
|
||||
uses: ./.github/workflows/db-migration-test.yml
|
||||
19
.github/workflows/style.yml
vendored
19
.github/workflows/style.yml
vendored
@@ -1,9 +1,7 @@
|
||||
name: Style check
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
workflow_call:
|
||||
|
||||
concurrency:
|
||||
group: style-${{ github.head_ref || github.run_id }}
|
||||
@@ -46,21 +44,10 @@ jobs:
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Ruff check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: |
|
||||
uv run --directory api ruff --version
|
||||
uv run --directory api ruff check ./
|
||||
uv run --directory api ruff format --check ./
|
||||
|
||||
- name: Dotenv check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: uv run --project api dotenv-linter ./api/.env.example ./web/.env.example
|
||||
|
||||
- name: Lint hints
|
||||
if: failure()
|
||||
run: echo "Please run 'dev/reformat' to fix the fixable linting errors."
|
||||
|
||||
web-style:
|
||||
name: Web Style
|
||||
runs-on: ubuntu-latest
|
||||
@@ -102,7 +89,9 @@ jobs:
|
||||
- name: Web style check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm run lint
|
||||
run: |
|
||||
pnpm run lint
|
||||
pnpm run eslint
|
||||
|
||||
docker-compose-template:
|
||||
name: Docker Compose Template
|
||||
|
||||
10
.github/workflows/vdb-tests.yml
vendored
10
.github/workflows/vdb-tests.yml
vendored
@@ -1,15 +1,7 @@
|
||||
name: Run VDB Tests
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- api/core/rag/datasource/**
|
||||
- docker/**
|
||||
- .github/workflows/vdb-tests.yml
|
||||
- api/uv.lock
|
||||
- api/pyproject.toml
|
||||
workflow_call:
|
||||
|
||||
concurrency:
|
||||
group: vdb-tests-${{ github.head_ref || github.run_id }}
|
||||
|
||||
6
.github/workflows/web-tests.yml
vendored
6
.github/workflows/web-tests.yml
vendored
@@ -1,11 +1,7 @@
|
||||
name: Web Tests
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- web/**
|
||||
workflow_call:
|
||||
|
||||
concurrency:
|
||||
group: web-tests-${{ github.head_ref || github.run_id }}
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -218,3 +218,6 @@ mise.toml
|
||||
.roo/
|
||||
api/.env.backup
|
||||
/clickzetta
|
||||
|
||||
# mcp
|
||||
.serena
|
||||
34
.mcp.json
Normal file
34
.mcp.json
Normal file
@@ -0,0 +1,34 @@
|
||||
{
|
||||
"mcpServers": {
|
||||
"context7": {
|
||||
"type": "http",
|
||||
"url": "https://mcp.context7.com/mcp"
|
||||
},
|
||||
"sequential-thinking": {
|
||||
"type": "stdio",
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-sequential-thinking"],
|
||||
"env": {}
|
||||
},
|
||||
"github": {
|
||||
"type": "stdio",
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-github"],
|
||||
"env": {
|
||||
"GITHUB_PERSONAL_ACCESS_TOKEN": "${GITHUB_PERSONAL_ACCESS_TOKEN}"
|
||||
}
|
||||
},
|
||||
"fetch": {
|
||||
"type": "stdio",
|
||||
"command": "uvx",
|
||||
"args": ["mcp-server-fetch"],
|
||||
"env": {}
|
||||
},
|
||||
"playwright": {
|
||||
"type": "stdio",
|
||||
"command": "npx",
|
||||
"args": ["-y", "@playwright/mcp@latest"],
|
||||
"env": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -59,6 +59,7 @@ pnpm test # Run Jest tests
|
||||
- Use type hints for all functions and class attributes
|
||||
- No `Any` types unless absolutely necessary
|
||||
- Implement special methods (`__repr__`, `__str__`) appropriately
|
||||
- **Logging**: Never use `str(e)` in `logger.exception()` calls. Use `logger.exception("message", exc_info=e)` instead
|
||||
|
||||
### TypeScript/JavaScript
|
||||
|
||||
@@ -86,3 +87,4 @@ pnpm test # Run Jest tests
|
||||
## Project-Specific Conventions
|
||||
|
||||
- All async tasks use Celery with Redis as broker
|
||||
- **Internationalization**: Frontend supports multiple languages with English (`web/i18n/en-US/`) as the source. All user-facing text must use i18n keys, no hardcoded strings. Edit corresponding module files in `en-US/` directory for translations.
|
||||
|
||||
@@ -434,6 +434,9 @@ HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
|
||||
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
|
||||
HTTP_REQUEST_NODE_SSL_VERIFY=True
|
||||
|
||||
# Webhook request configuration
|
||||
WEBHOOK_REQUEST_BODY_MAX_SIZE=10485760
|
||||
|
||||
# Respect X-* headers to redirect clients
|
||||
RESPECT_XFORWARD_HEADERS_ENABLED=false
|
||||
|
||||
@@ -502,6 +505,12 @@ ENABLE_CLEAN_MESSAGES=false
|
||||
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
|
||||
ENABLE_DATASETS_QUEUE_MONITOR=false
|
||||
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true
|
||||
ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK=true
|
||||
# Interval time in minutes for polling scheduled workflows(default: 1 min)
|
||||
WORKFLOW_SCHEDULE_POLLER_INTERVAL=1
|
||||
WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE=100
|
||||
# Maximum number of scheduled workflows to dispatch per tick (0 for unlimited)
|
||||
WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK=0
|
||||
|
||||
# Position configuration
|
||||
POSITION_TOOL_PINS=
|
||||
|
||||
2
api/.vscode/launch.json.example
vendored
2
api/.vscode/launch.json.example
vendored
@@ -54,7 +54,7 @@
|
||||
"--loglevel",
|
||||
"DEBUG",
|
||||
"-Q",
|
||||
"dataset,generation,mail,ops_trace,app_deletion"
|
||||
"dataset,generation,mail,ops_trace,app_deletion,workflow"
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
@@ -1207,6 +1207,55 @@ def setup_system_tool_oauth_client(provider, client_params):
|
||||
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
|
||||
|
||||
|
||||
@click.command("setup-system-trigger-oauth-client", help="Setup system trigger oauth client.")
|
||||
@click.option("--provider", prompt=True, help="Provider name")
|
||||
@click.option("--client-params", prompt=True, help="Client Params")
|
||||
def setup_system_trigger_oauth_client(provider, client_params):
|
||||
"""
|
||||
Setup system trigger oauth client
|
||||
"""
|
||||
from core.plugin.entities.plugin import TriggerProviderID
|
||||
from models.trigger import TriggerOAuthSystemClient
|
||||
|
||||
provider_id = TriggerProviderID(provider)
|
||||
provider_name = provider_id.provider_name
|
||||
plugin_id = provider_id.plugin_id
|
||||
|
||||
try:
|
||||
# json validate
|
||||
click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
|
||||
client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
|
||||
click.echo(click.style("Client params validated successfully.", fg="green"))
|
||||
|
||||
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
|
||||
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
|
||||
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
|
||||
click.echo(click.style("Client params encrypted successfully.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
return
|
||||
|
||||
deleted_count = (
|
||||
db.session.query(TriggerOAuthSystemClient)
|
||||
.filter_by(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
if deleted_count > 0:
|
||||
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
|
||||
|
||||
oauth_client = TriggerOAuthSystemClient(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
encrypted_oauth_params=oauth_client_params,
|
||||
)
|
||||
db.session.add(oauth_client)
|
||||
db.session.commit()
|
||||
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
|
||||
|
||||
|
||||
def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
|
||||
"""
|
||||
Find draft variables that reference non-existent apps.
|
||||
|
||||
@@ -147,6 +147,17 @@ class CodeExecutionSandboxConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class TriggerConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for trigger
|
||||
"""
|
||||
|
||||
WEBHOOK_REQUEST_BODY_MAX_SIZE: PositiveInt = Field(
|
||||
description="Maximum allowed size for webhook request bodies in bytes",
|
||||
default=10485760,
|
||||
)
|
||||
|
||||
|
||||
class PluginConfig(BaseSettings):
|
||||
"""
|
||||
Plugin configs
|
||||
@@ -871,6 +882,22 @@ class CeleryScheduleTasksConfig(BaseSettings):
|
||||
description="Enable check upgradable plugin task",
|
||||
default=True,
|
||||
)
|
||||
ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK: bool = Field(
|
||||
description="Enable workflow schedule poller task",
|
||||
default=True,
|
||||
)
|
||||
WORKFLOW_SCHEDULE_POLLER_INTERVAL: int = Field(
|
||||
description="Workflow schedule poller interval in minutes",
|
||||
default=1,
|
||||
)
|
||||
WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE: int = Field(
|
||||
description="Maximum number of schedules to process in each poll batch",
|
||||
default=100,
|
||||
)
|
||||
WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK: int = Field(
|
||||
description="Maximum schedules to dispatch per tick (0=unlimited, circuit breaker)",
|
||||
default=0,
|
||||
)
|
||||
|
||||
|
||||
class PositionConfig(BaseSettings):
|
||||
@@ -994,6 +1021,7 @@ class FeatureConfig(
|
||||
AuthConfig, # Changed from OAuthConfig to AuthConfig
|
||||
BillingConfig,
|
||||
CodeExecutionSandboxConfig,
|
||||
TriggerConfig,
|
||||
PluginConfig,
|
||||
MarketplaceConfig,
|
||||
DataSetConfig,
|
||||
|
||||
@@ -8,6 +8,7 @@ if TYPE_CHECKING:
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.trigger.provider import PluginTriggerProviderController
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
|
||||
@@ -33,3 +34,11 @@ plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(Cont
|
||||
plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar(
|
||||
ContextVar("plugin_model_schemas")
|
||||
)
|
||||
|
||||
plugin_trigger_providers: RecyclableContextVar[dict[str, "PluginTriggerProviderController"]] = RecyclableContextVar(
|
||||
ContextVar("plugin_trigger_providers")
|
||||
)
|
||||
|
||||
plugin_trigger_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
|
||||
ContextVar("plugin_trigger_providers_lock")
|
||||
)
|
||||
|
||||
@@ -67,10 +67,11 @@ from .app import (
|
||||
workflow_draft_variable,
|
||||
workflow_run,
|
||||
workflow_statistic,
|
||||
workflow_trigger,
|
||||
)
|
||||
|
||||
# Import auth controllers
|
||||
from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth
|
||||
from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth, oauth_server
|
||||
|
||||
# Import billing controllers
|
||||
from .billing import billing, compliance
|
||||
@@ -180,5 +181,6 @@ from .workspace import (
|
||||
models,
|
||||
plugin,
|
||||
tool_providers,
|
||||
trigger_providers,
|
||||
workspace,
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@ from controllers.console.app.error import (
|
||||
)
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
@@ -125,13 +126,11 @@ class InstructionGenerateApi(Resource):
|
||||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("ideal_output", type=str, required=False, default="", location="json")
|
||||
args = parser.parse_args()
|
||||
code_template = (
|
||||
Python3CodeProvider.get_default_code()
|
||||
if args["language"] == "python"
|
||||
else (JavascriptCodeProvider.get_default_code())
|
||||
if args["language"] == "javascript"
|
||||
else ""
|
||||
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
||||
code_provider: type[CodeNodeProvider] | None = next(
|
||||
(p for p in providers if p.is_accept_language(args["language"])), None
|
||||
)
|
||||
code_template = code_provider.get_default_code() if code_provider else ""
|
||||
try:
|
||||
# Generate from nothing for a workflow node
|
||||
if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
|
||||
|
||||
@@ -3,6 +3,7 @@ import logging
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from sqlalchemy import exists, select
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
from controllers.console import api
|
||||
@@ -94,21 +95,22 @@ class ChatMessageListApi(Resource):
|
||||
.all()
|
||||
)
|
||||
|
||||
has_more = False
|
||||
# Initialize has_more based on whether we have a full page
|
||||
if len(history_messages) == args["limit"]:
|
||||
current_page_first_message = history_messages[-1]
|
||||
rest_count = (
|
||||
db.session.query(Message)
|
||||
.where(
|
||||
Message.conversation_id == conversation.id,
|
||||
Message.created_at < current_page_first_message.created_at,
|
||||
Message.id != current_page_first_message.id,
|
||||
# Check if there are more messages before the current page
|
||||
has_more = db.session.scalar(
|
||||
select(
|
||||
exists().where(
|
||||
Message.conversation_id == conversation.id,
|
||||
Message.created_at < current_page_first_message.created_at,
|
||||
Message.id != current_page_first_message.id,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
if rest_count > 0:
|
||||
has_more = True
|
||||
else:
|
||||
# If we don't have a full page, there are no more messages
|
||||
has_more = False
|
||||
|
||||
history_messages = list(reversed(history_messages))
|
||||
|
||||
@@ -128,7 +130,7 @@ class MessageFeedbackApi(Resource):
|
||||
|
||||
message_id = str(args["message_id"])
|
||||
|
||||
message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||
|
||||
if not message:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
@@ -24,6 +24,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.models import File
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory, variable_factory
|
||||
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
|
||||
@@ -38,6 +39,7 @@ from models.workflow import Workflow
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.trigger_debug_service import TriggerDebugService
|
||||
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -806,6 +808,132 @@ class DraftWorkflowNodeLastRunApi(Resource):
|
||||
return node_exec
|
||||
|
||||
|
||||
class DraftWorkflowTriggerNodeApi(Resource):
|
||||
"""
|
||||
Single node debug - Polling API for trigger events
|
||||
Path: /apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/trigger
|
||||
"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
Poll for trigger events and execute single node when event arrives
|
||||
"""
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("trigger_name", type=str, required=True, location="json")
|
||||
parser.add_argument("subscription_id", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
trigger_name = args["trigger_name"]
|
||||
subscription_id = args["subscription_id"]
|
||||
event = TriggerDebugService.poll_event(
|
||||
tenant_id=app_model.tenant_id,
|
||||
user_id=current_user.id,
|
||||
app_id=app_model.id,
|
||||
subscription_id=subscription_id,
|
||||
node_id=node_id,
|
||||
trigger_name=trigger_name,
|
||||
)
|
||||
if not event:
|
||||
return jsonable_encoder({"status": "waiting"})
|
||||
|
||||
try:
|
||||
workflow_service = WorkflowService()
|
||||
draft_workflow = workflow_service.get_draft_workflow(app_model)
|
||||
if not draft_workflow:
|
||||
raise ValueError("Workflow not found")
|
||||
|
||||
user_inputs = event.model_dump()
|
||||
node_execution = workflow_service.run_draft_workflow_node(
|
||||
app_model=app_model,
|
||||
draft_workflow=draft_workflow,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
account=current_user,
|
||||
query="",
|
||||
files=[],
|
||||
)
|
||||
return jsonable_encoder(node_execution)
|
||||
except Exception:
|
||||
logger.exception("Error running draft workflow trigger node")
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"status": "error",
|
||||
}
|
||||
), 500
|
||||
|
||||
|
||||
class DraftWorkflowTriggerRunApi(Resource):
|
||||
"""
|
||||
Full workflow debug - Polling API for trigger events
|
||||
Path: /apps/<uuid:app_id>/workflows/draft/trigger/run
|
||||
"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
"""
|
||||
Poll for trigger events and execute full workflow when event arrives
|
||||
"""
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_id", type=str, required=True, location="json", nullable=False)
|
||||
parser.add_argument("trigger_name", type=str, required=True, location="json", nullable=False)
|
||||
parser.add_argument("subscription_id", type=str, required=True, location="json", nullable=False)
|
||||
args = parser.parse_args()
|
||||
node_id = args["node_id"]
|
||||
trigger_name = args["trigger_name"]
|
||||
subscription_id = args["subscription_id"]
|
||||
|
||||
event = TriggerDebugService.poll_event(
|
||||
tenant_id=app_model.tenant_id,
|
||||
user_id=current_user.id,
|
||||
app_id=app_model.id,
|
||||
subscription_id=subscription_id,
|
||||
node_id=node_id,
|
||||
trigger_name=trigger_name,
|
||||
)
|
||||
if not event:
|
||||
return jsonable_encoder({"status": "waiting"})
|
||||
|
||||
workflow_args = {
|
||||
"inputs": event.model_dump(),
|
||||
"query": "",
|
||||
"files": [],
|
||||
}
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
if external_trace_id:
|
||||
workflow_args["external_trace_id"] = external_trace_id
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
args=workflow_args,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
streaming=True,
|
||||
)
|
||||
return helper.compact_generate_response(response)
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except Exception:
|
||||
logger.exception("Error running draft workflow trigger run")
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"status": "error",
|
||||
}
|
||||
), 500
|
||||
|
||||
|
||||
api.add_resource(
|
||||
DraftWorkflowApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft",
|
||||
@@ -830,6 +958,14 @@ api.add_resource(
|
||||
DraftWorkflowNodeRunApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run",
|
||||
)
|
||||
api.add_resource(
|
||||
DraftWorkflowTriggerNodeApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/trigger",
|
||||
)
|
||||
api.add_resource(
|
||||
DraftWorkflowTriggerRunApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/trigger/run",
|
||||
)
|
||||
api.add_resource(
|
||||
AdvancedChatDraftRunIterationNodeApi,
|
||||
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run",
|
||||
|
||||
@@ -27,7 +27,9 @@ class WorkflowAppLogApi(Resource):
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("keyword", type=str, location="args")
|
||||
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
|
||||
parser.add_argument(
|
||||
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
|
||||
)
|
||||
parser.add_argument(
|
||||
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
|
||||
)
|
||||
|
||||
249
api/controllers/console/app/workflow_trigger.py
Normal file
249
api/controllers/console/app/workflow_trigger.py
Normal file
@@ -0,0 +1,249 @@
|
||||
import logging
|
||||
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
|
||||
from libs.login import current_user, login_required
|
||||
from models.model import Account, AppMode
|
||||
from models.workflow import AppTrigger, AppTriggerStatus, WorkflowWebhookTrigger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from services.workflow_plugin_trigger_service import WorkflowPluginTriggerService
|
||||
|
||||
|
||||
class PluginTriggerApi(Resource):
|
||||
"""Workflow Plugin Trigger API"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
def post(self, app_model):
|
||||
"""Create plugin trigger"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_id", type=str, required=False, location="json")
|
||||
parser.add_argument("provider_id", type=str, required=False, location="json")
|
||||
parser.add_argument("trigger_name", type=str, required=False, location="json")
|
||||
parser.add_argument("subscription_id", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
plugin_trigger = WorkflowPluginTriggerService.create_plugin_trigger(
|
||||
app_id=app_model.id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
node_id=args["node_id"],
|
||||
provider_id=args["provider_id"],
|
||||
trigger_name=args["trigger_name"],
|
||||
subscription_id=args["subscription_id"],
|
||||
)
|
||||
|
||||
return jsonable_encoder(plugin_trigger)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
def get(self, app_model):
|
||||
"""Get plugin trigger"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
|
||||
args = parser.parse_args()
|
||||
|
||||
plugin_trigger = WorkflowPluginTriggerService.get_plugin_trigger(
|
||||
app_id=app_model.id,
|
||||
node_id=args["node_id"],
|
||||
)
|
||||
|
||||
return jsonable_encoder(plugin_trigger)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
def put(self, app_model):
|
||||
"""Update plugin trigger"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
|
||||
parser.add_argument("subscription_id", type=str, required=True, location="json", help="Subscription ID")
|
||||
args = parser.parse_args()
|
||||
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
plugin_trigger = WorkflowPluginTriggerService.update_plugin_trigger(
|
||||
app_id=app_model.id,
|
||||
node_id=args["node_id"],
|
||||
subscription_id=args["subscription_id"],
|
||||
)
|
||||
|
||||
return jsonable_encoder(plugin_trigger)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
def delete(self, app_model):
|
||||
"""Delete plugin trigger"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
|
||||
args = parser.parse_args()
|
||||
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
WorkflowPluginTriggerService.delete_plugin_trigger(
|
||||
app_id=app_model.id,
|
||||
node_id=args["node_id"],
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class WebhookTriggerApi(Resource):
|
||||
"""Webhook Trigger API"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@marshal_with(webhook_trigger_fields)
|
||||
def get(self, app_model):
|
||||
"""Get webhook trigger for a node"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
|
||||
args = parser.parse_args()
|
||||
|
||||
node_id = args["node_id"]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# Get webhook trigger for this app and node
|
||||
webhook_trigger = (
|
||||
session.query(WorkflowWebhookTrigger)
|
||||
.filter(
|
||||
WorkflowWebhookTrigger.app_id == app_model.id,
|
||||
WorkflowWebhookTrigger.node_id == node_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not webhook_trigger:
|
||||
raise NotFound("Webhook trigger not found for this node")
|
||||
|
||||
# Add computed fields for marshal_with
|
||||
base_url = dify_config.SERVICE_API_URL
|
||||
webhook_trigger.webhook_url = f"{base_url}/triggers/webhook/{webhook_trigger.webhook_id}" # type: ignore
|
||||
webhook_trigger.webhook_debug_url = f"{base_url}/triggers/webhook-debug/{webhook_trigger.webhook_id}" # type: ignore
|
||||
|
||||
return webhook_trigger
|
||||
|
||||
|
||||
class AppTriggersApi(Resource):
|
||||
"""App Triggers list API"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@marshal_with(triggers_list_fields)
|
||||
def get(self, app_model):
|
||||
"""Get app triggers list"""
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# Get all triggers for this app using select API
|
||||
triggers = (
|
||||
session.execute(
|
||||
select(AppTrigger)
|
||||
.where(
|
||||
AppTrigger.tenant_id == current_user.current_tenant_id,
|
||||
AppTrigger.app_id == app_model.id,
|
||||
)
|
||||
.order_by(AppTrigger.created_at.desc(), AppTrigger.id.desc())
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
# Add computed icon field for each trigger
|
||||
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
|
||||
for trigger in triggers:
|
||||
if trigger.trigger_type == "trigger-plugin":
|
||||
trigger.icon = url_prefix + trigger.provider_name + "/icon" # type: ignore
|
||||
else:
|
||||
trigger.icon = "" # type: ignore
|
||||
|
||||
return {"data": triggers}
|
||||
|
||||
|
||||
class AppTriggerEnableApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@marshal_with(trigger_fields)
|
||||
def post(self, app_model):
|
||||
"""Update app trigger (enable/disable)"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("trigger_id", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("enable_trigger", type=bool, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
trigger_id = args["trigger_id"]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# Find the trigger using select
|
||||
trigger = session.execute(
|
||||
select(AppTrigger).where(
|
||||
AppTrigger.id == trigger_id,
|
||||
AppTrigger.tenant_id == current_user.current_tenant_id,
|
||||
AppTrigger.app_id == app_model.id,
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if not trigger:
|
||||
raise NotFound("Trigger not found")
|
||||
|
||||
# Update status based on enable_trigger boolean
|
||||
trigger.status = AppTriggerStatus.ENABLED if args["enable_trigger"] else AppTriggerStatus.DISABLED
|
||||
|
||||
session.commit()
|
||||
session.refresh(trigger)
|
||||
|
||||
# Add computed icon field
|
||||
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
|
||||
if trigger.trigger_type == "trigger-plugin":
|
||||
trigger.icon = url_prefix + trigger.provider_name + "/icon" # type: ignore
|
||||
else:
|
||||
trigger.icon = "" # type: ignore
|
||||
|
||||
return trigger
|
||||
|
||||
|
||||
api.add_resource(WebhookTriggerApi, "/apps/<uuid:app_id>/workflows/triggers/webhook")
|
||||
api.add_resource(PluginTriggerApi, "/apps/<uuid:app_id>/workflows/triggers/plugin")
|
||||
api.add_resource(AppTriggersApi, "/apps/<uuid:app_id>/triggers")
|
||||
api.add_resource(AppTriggerEnableApi, "/apps/<uuid:app_id>/trigger-enable")
|
||||
187
api/controllers/console/auth/oauth_server.py
Normal file
187
api/controllers/console/auth/oauth_server.py
Normal file
@@ -0,0 +1,187 @@
|
||||
from functools import wraps
|
||||
from typing import cast
|
||||
|
||||
import flask_login
|
||||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models.model import OAuthProviderApp
|
||||
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
|
||||
|
||||
from .. import api
|
||||
|
||||
|
||||
def oauth_server_client_id_required(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("client_id", type=str, required=True, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
client_id = parsed_args.get("client_id")
|
||||
if not client_id:
|
||||
raise BadRequest("client_id is required")
|
||||
|
||||
oauth_provider_app = OAuthServerService.get_oauth_provider_app(client_id)
|
||||
if not oauth_provider_app:
|
||||
raise NotFound("client_id is invalid")
|
||||
|
||||
kwargs["oauth_provider_app"] = oauth_provider_app
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def oauth_server_access_token_required(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
oauth_provider_app = kwargs.get("oauth_provider_app")
|
||||
if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp):
|
||||
raise BadRequest("Invalid oauth_provider_app")
|
||||
|
||||
authorization_header = request.headers.get("Authorization")
|
||||
if not authorization_header:
|
||||
raise BadRequest("Authorization header is required")
|
||||
|
||||
parts = authorization_header.strip().split(" ")
|
||||
if len(parts) != 2:
|
||||
raise BadRequest("Invalid Authorization header format")
|
||||
|
||||
token_type = parts[0].strip()
|
||||
if token_type.lower() != "bearer":
|
||||
raise BadRequest("token_type is invalid")
|
||||
|
||||
access_token = parts[1].strip()
|
||||
if not access_token:
|
||||
raise BadRequest("access_token is required")
|
||||
|
||||
account = OAuthServerService.validate_oauth_access_token(oauth_provider_app.client_id, access_token)
|
||||
if not account:
|
||||
raise BadRequest("access_token or client_id is invalid")
|
||||
|
||||
kwargs["account"] = account
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
class OAuthServerAppApi(Resource):
|
||||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("redirect_uri", type=str, required=True, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
redirect_uri = parsed_args.get("redirect_uri")
|
||||
|
||||
# check if redirect_uri is valid
|
||||
if redirect_uri not in oauth_provider_app.redirect_uris:
|
||||
raise BadRequest("redirect_uri is invalid")
|
||||
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"app_icon": oauth_provider_app.app_icon,
|
||||
"app_label": oauth_provider_app.app_label,
|
||||
"scope": oauth_provider_app.scope,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class OAuthServerUserAuthorizeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
account = cast(Account, flask_login.current_user)
|
||||
user_account_id = account.id
|
||||
|
||||
code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"code": code,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class OAuthServerUserTokenApi(Resource):
|
||||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("grant_type", type=str, required=True, location="json")
|
||||
parser.add_argument("code", type=str, required=False, location="json")
|
||||
parser.add_argument("client_secret", type=str, required=False, location="json")
|
||||
parser.add_argument("redirect_uri", type=str, required=False, location="json")
|
||||
parser.add_argument("refresh_token", type=str, required=False, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
|
||||
try:
|
||||
grant_type = OAuthGrantType(parsed_args["grant_type"])
|
||||
except ValueError:
|
||||
raise BadRequest("invalid grant_type")
|
||||
|
||||
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
|
||||
if not parsed_args["code"]:
|
||||
raise BadRequest("code is required")
|
||||
|
||||
if parsed_args["client_secret"] != oauth_provider_app.client_secret:
|
||||
raise BadRequest("client_secret is invalid")
|
||||
|
||||
if parsed_args["redirect_uri"] not in oauth_provider_app.redirect_uris:
|
||||
raise BadRequest("redirect_uri is invalid")
|
||||
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type, code=parsed_args["code"], client_id=oauth_provider_app.client_id
|
||||
)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
)
|
||||
elif grant_type == OAuthGrantType.REFRESH_TOKEN:
|
||||
if not parsed_args["refresh_token"]:
|
||||
raise BadRequest("refresh_token is required")
|
||||
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type, refresh_token=parsed_args["refresh_token"], client_id=oauth_provider_app.client_id
|
||||
)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class OAuthServerUserAccountApi(Resource):
|
||||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
@oauth_server_access_token_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp, account: Account):
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"name": account.name,
|
||||
"email": account.email,
|
||||
"avatar": account.avatar,
|
||||
"interface_language": account.interface_language,
|
||||
"timezone": account.timezone,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
api.add_resource(OAuthServerAppApi, "/oauth/provider")
|
||||
api.add_resource(OAuthServerUserAuthorizeApi, "/oauth/provider/authorize")
|
||||
api.add_resource(OAuthServerUserTokenApi, "/oauth/provider/token")
|
||||
api.add_resource(OAuthServerUserAccountApi, "/oauth/provider/account")
|
||||
@@ -516,18 +516,20 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
|
||||
parser.add_argument("provider", type=str, required=True, location="args")
|
||||
parser.add_argument("action", type=str, required=True, location="args")
|
||||
parser.add_argument("parameter", type=str, required=True, location="args")
|
||||
parser.add_argument("credential_id", type=str, required=False, location="args")
|
||||
parser.add_argument("provider_type", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
options = PluginParameterService.get_dynamic_select_options(
|
||||
tenant_id,
|
||||
user_id,
|
||||
args["plugin_id"],
|
||||
args["provider"],
|
||||
args["action"],
|
||||
args["parameter"],
|
||||
args["provider_type"],
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
plugin_id=args["plugin_id"],
|
||||
provider=args["provider"],
|
||||
action=args["action"],
|
||||
parameter=args["parameter"],
|
||||
credential_id=args["credential_id"],
|
||||
provider_type=args["provider_type"],
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
@@ -22,8 +22,8 @@ from core.mcp.error import MCPAuthError, MCPError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.tools.entities.tool_entities import CredentialType
|
||||
from libs.helper import StrLen, alphanumeric, uuid_value
|
||||
from libs.login import login_required
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
|
||||
589
api/controllers/console/workspace/trigger_providers.py
Normal file
589
api/controllers/console/workspace/trigger_providers.py
Normal file
@@ -0,0 +1,589 @@
|
||||
import logging
|
||||
|
||||
from flask import make_response, redirect, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin import TriggerProviderID
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.trigger.entities.entities import SubscriptionBuilderUpdater
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||
from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService
|
||||
from services.workflow_plugin_trigger_service import WorkflowPluginTriggerService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TriggerProviderListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
"""List all trigger providers for the current tenant"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
return jsonable_encoder(TriggerProviderService.list_trigger_providers(user.current_tenant_id))
|
||||
|
||||
|
||||
class TriggerProviderInfoApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
"""Get info for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
return jsonable_encoder(
|
||||
TriggerProviderService.get_trigger_provider(user.current_tenant_id, TriggerProviderID(provider))
|
||||
)
|
||||
|
||||
|
||||
class TriggerSubscriptionListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
"""List all trigger subscriptions for the current tenant's provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
TriggerProviderService.list_trigger_provider_subscriptions(
|
||||
tenant_id=user.current_tenant_id, provider_id=TriggerProviderID(provider)
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Error listing trigger providers", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderCreateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
"""Add a new subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_type", type=str, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
credential_type = CredentialType.of(args.get("credential_type") or CredentialType.UNAUTHORIZED.value)
|
||||
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
credential_type=credential_type,
|
||||
)
|
||||
return jsonable_encoder({"subscription_builder": subscription_builder})
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Error adding provider credential", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderGetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider, subscription_builder_id):
|
||||
"""Get a subscription instance for a trigger provider"""
|
||||
return jsonable_encoder(
|
||||
TriggerSubscriptionBuilderService.get_subscription_builder_by_id(subscription_builder_id)
|
||||
)
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderVerifyApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, subscription_builder_id):
|
||||
"""Verify a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
# The credentials of the subscription builder
|
||||
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
subscription_builder_updater=SubscriptionBuilderUpdater(
|
||||
credentials=args.get("credentials", None),
|
||||
),
|
||||
)
|
||||
return TriggerSubscriptionBuilderService.verify_trigger_subscription_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Error verifying provider credential", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, subscription_builder_id):
|
||||
"""Update a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
# The name of the subscription builder
|
||||
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
# The parameters of the subscription builder
|
||||
parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
|
||||
# The properties of the subscription builder
|
||||
parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
|
||||
# The credentials of the subscription builder
|
||||
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
subscription_builder_updater=SubscriptionBuilderUpdater(
|
||||
name=args.get("name", None),
|
||||
parameters=args.get("parameters", None),
|
||||
properties=args.get("properties", None),
|
||||
credentials=args.get("credentials", None),
|
||||
),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Error updating provider credential", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderLogsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider, subscription_builder_id):
|
||||
"""Get the request logs for a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
try:
|
||||
logs = TriggerSubscriptionBuilderService.list_logs(subscription_builder_id)
|
||||
return jsonable_encoder({"logs": [log.model_dump(mode="json") for log in logs]})
|
||||
except Exception as e:
|
||||
logger.exception("Error getting request logs for subscription builder", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderBuildApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, subscription_builder_id):
|
||||
"""Build a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
# The name of the subscription builder
|
||||
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
# The parameters of the subscription builder
|
||||
parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
|
||||
# The properties of the subscription builder
|
||||
parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
|
||||
# The credentials of the subscription builder
|
||||
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
subscription_builder_updater=SubscriptionBuilderUpdater(
|
||||
name=args.get("name", None),
|
||||
parameters=args.get("parameters", None),
|
||||
properties=args.get("properties", None),
|
||||
),
|
||||
)
|
||||
TriggerSubscriptionBuilderService.build_trigger_subscription_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
)
|
||||
return 200
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Error building provider credential", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerSubscriptionDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, subscription_id):
|
||||
"""Delete a subscription instance"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
# Delete trigger provider subscription
|
||||
TriggerProviderService.delete_trigger_provider(
|
||||
session=session,
|
||||
tenant_id=user.current_tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
)
|
||||
# Delete plugin triggers
|
||||
WorkflowPluginTriggerService.delete_plugin_trigger_by_subscription(
|
||||
session=session,
|
||||
tenant_id=user.current_tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
)
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Error deleting provider credential", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerOAuthAuthorizeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
"""Initiate OAuth authorization flow for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
try:
|
||||
provider_id = TriggerProviderID(provider)
|
||||
plugin_id = provider_id.plugin_id
|
||||
provider_name = provider_id.provider_name
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
# Get OAuth client configuration
|
||||
oauth_client_params = TriggerProviderService.get_oauth_client(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
|
||||
if oauth_client_params is None:
|
||||
raise Forbidden("No OAuth client configuration found for this trigger provider")
|
||||
|
||||
# Create subscription builder
|
||||
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=provider_id,
|
||||
credential_type=CredentialType.OAUTH2,
|
||||
)
|
||||
|
||||
# Create OAuth handler and proxy context
|
||||
oauth_handler = OAuthHandler()
|
||||
context_id = OAuthProxyService.create_proxy_context(
|
||||
user_id=user.id,
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
extra_data={
|
||||
"subscription_builder_id": subscription_builder.id,
|
||||
},
|
||||
)
|
||||
|
||||
# Build redirect URI for callback
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
|
||||
|
||||
# Get authorization URL
|
||||
authorization_url_response = oauth_handler.get_authorization_url(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=oauth_client_params,
|
||||
)
|
||||
|
||||
# Create response with cookie
|
||||
response = make_response(
|
||||
jsonable_encoder(
|
||||
{
|
||||
"authorization_url": authorization_url_response.authorization_url,
|
||||
"subscription_builder_id": subscription_builder.id,
|
||||
"subscription_builder": subscription_builder,
|
||||
}
|
||||
)
|
||||
)
|
||||
response.set_cookie(
|
||||
"context_id",
|
||||
context_id,
|
||||
httponly=True,
|
||||
samesite="Lax",
|
||||
max_age=OAuthProxyService.__MAX_AGE__,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error initiating OAuth flow", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerOAuthCallbackApi(Resource):
|
||||
@setup_required
|
||||
def get(self, provider):
|
||||
"""Handle OAuth callback for trigger provider"""
|
||||
context_id = request.cookies.get("context_id")
|
||||
if not context_id:
|
||||
raise Forbidden("context_id not found")
|
||||
|
||||
# Use and validate proxy context
|
||||
context = OAuthProxyService.use_proxy_context(context_id)
|
||||
if context is None:
|
||||
raise Forbidden("Invalid context_id")
|
||||
|
||||
# Parse provider ID
|
||||
provider_id = TriggerProviderID(provider)
|
||||
plugin_id = provider_id.plugin_id
|
||||
provider_name = provider_id.provider_name
|
||||
user_id = context.get("user_id")
|
||||
tenant_id = context.get("tenant_id")
|
||||
subscription_builder_id = context.get("subscription_builder_id")
|
||||
|
||||
# Get OAuth client configuration
|
||||
oauth_client_params = TriggerProviderService.get_oauth_client(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
|
||||
if oauth_client_params is None:
|
||||
raise Forbidden("No OAuth client configuration found for this trigger provider")
|
||||
|
||||
# Get OAuth credentials from callback
|
||||
oauth_handler = OAuthHandler()
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
|
||||
|
||||
credentials_response = oauth_handler.get_credentials(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=oauth_client_params,
|
||||
request=request,
|
||||
)
|
||||
|
||||
credentials = credentials_response.credentials
|
||||
expires_at = credentials_response.expires_at
|
||||
|
||||
if not credentials:
|
||||
raise Exception("Failed to get OAuth credentials")
|
||||
|
||||
# Update subscription builder
|
||||
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
subscription_builder_updater=SubscriptionBuilderUpdater(
|
||||
credentials=credentials,
|
||||
credential_expires_at=expires_at,
|
||||
),
|
||||
)
|
||||
# Redirect to OAuth callback page
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
|
||||
class TriggerOAuthClientManageApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
"""Get OAuth client configuration for a provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
provider_id = TriggerProviderID(provider)
|
||||
|
||||
# Get custom OAuth client params if exists
|
||||
custom_params = TriggerProviderService.get_custom_oauth_client_params(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
|
||||
# Check if custom client is enabled
|
||||
is_custom_enabled = TriggerProviderService.is_oauth_custom_client_enabled(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
|
||||
# Check if there's a system OAuth client
|
||||
system_client = TriggerProviderService.get_oauth_client(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
provider_controller = TriggerManager.get_trigger_provider(user.current_tenant_id, provider_id)
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"configured": bool(custom_params or system_client),
|
||||
"oauth_client_schema": provider_controller.get_oauth_client_schema(),
|
||||
"custom_configured": bool(custom_params),
|
||||
"custom_enabled": is_custom_enabled,
|
||||
"redirect_uri": redirect_uri,
|
||||
"params": custom_params if custom_params else {},
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error getting OAuth client", exc_info=e)
|
||||
raise
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
"""Configure custom OAuth client for a provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||
parser.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
provider_id = TriggerProviderID(provider)
|
||||
return TriggerProviderService.save_custom_oauth_client_params(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=provider_id,
|
||||
client_params=args.get("client_params"),
|
||||
enabled=args.get("enabled"),
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Error configuring OAuth client", exc_info=e)
|
||||
raise
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider):
|
||||
"""Remove custom OAuth client configuration"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
provider_id = TriggerProviderID(provider)
|
||||
|
||||
return TriggerProviderService.delete_custom_oauth_client_params(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Error removing OAuth client", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
# Trigger Subscription
|
||||
api.add_resource(TriggerProviderListApi, "/workspaces/current/triggers")
|
||||
api.add_resource(TriggerProviderInfoApi, "/workspaces/current/trigger-provider/<path:provider>/info")
|
||||
api.add_resource(TriggerSubscriptionListApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/list")
|
||||
api.add_resource(
|
||||
TriggerSubscriptionDeleteApi,
|
||||
"/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/delete",
|
||||
)
|
||||
|
||||
# Trigger Subscription Builder
|
||||
api.add_resource(
|
||||
TriggerSubscriptionBuilderCreateApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/create",
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerSubscriptionBuilderGetApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/<path:subscription_builder_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerSubscriptionBuilderUpdateApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/update/<path:subscription_builder_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerSubscriptionBuilderVerifyApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify/<path:subscription_builder_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerSubscriptionBuilderBuildApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/build/<path:subscription_builder_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerSubscriptionBuilderLogsApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/logs/<path:subscription_builder_id>",
|
||||
)
|
||||
|
||||
|
||||
# OAuth
|
||||
api.add_resource(
|
||||
TriggerOAuthAuthorizeApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/oauth/authorize"
|
||||
)
|
||||
api.add_resource(TriggerOAuthCallbackApi, "/oauth/plugin/<path:provider>/trigger/callback")
|
||||
api.add_resource(TriggerOAuthClientManageApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/client")
|
||||
@@ -1,8 +1,12 @@
|
||||
from base64 import b64encode
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from hashlib import sha1
|
||||
from hmac import new as hmac_new
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
from flask import abort, request
|
||||
|
||||
from configs import dify_config
|
||||
@@ -10,9 +14,9 @@ from extensions.ext_database import db
|
||||
from models.model import EndUser
|
||||
|
||||
|
||||
def billing_inner_api_only(view):
|
||||
def billing_inner_api_only(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if not dify_config.INNER_API:
|
||||
abort(404)
|
||||
|
||||
@@ -26,9 +30,9 @@ def billing_inner_api_only(view):
|
||||
return decorated
|
||||
|
||||
|
||||
def enterprise_inner_api_only(view):
|
||||
def enterprise_inner_api_only(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if not dify_config.INNER_API:
|
||||
abort(404)
|
||||
|
||||
@@ -78,9 +82,9 @@ def enterprise_inner_api_user_auth(view):
|
||||
return decorated
|
||||
|
||||
|
||||
def plugin_inner_api_only(view):
|
||||
def plugin_inner_api_only(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if not dify_config.PLUGIN_DAEMON_KEY:
|
||||
abort(404)
|
||||
|
||||
|
||||
@@ -1,20 +1,28 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource, reqparse
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console.app.mcp_server import AppMCPServerStatus
|
||||
from controllers.mcp import mcp_ns
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.mcp import types
|
||||
from core.mcp.server.streamable_http import MCPServerStreamableHTTPRequestHandler
|
||||
from core.mcp.types import ClientNotification, ClientRequest
|
||||
from core.mcp.utils import create_mcp_error_response
|
||||
from core.mcp import types as mcp_types
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models.model import App, AppMCPServer, AppMode
|
||||
|
||||
|
||||
class MCPRequestError(Exception):
|
||||
"""Custom exception for MCP request processing errors"""
|
||||
|
||||
def __init__(self, error_code: int, message: str):
|
||||
self.error_code = error_code
|
||||
self.message = message
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
def int_or_str(value):
|
||||
"""Validate that a value is either an integer or string."""
|
||||
if isinstance(value, (int, str)):
|
||||
@@ -63,76 +71,128 @@ class MCPAppApi(Resource):
|
||||
Raises:
|
||||
ValidationError: Invalid request format or parameters
|
||||
"""
|
||||
# Parse and validate all arguments
|
||||
args = mcp_request_parser.parse_args()
|
||||
|
||||
request_id: Optional[Union[int, str]] = args.get("id")
|
||||
mcp_request = self._parse_mcp_request(args)
|
||||
|
||||
server = db.session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
|
||||
if not server:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server Not Found")
|
||||
)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Get MCP server and app
|
||||
mcp_server, app = self._get_mcp_server_and_app(server_code, session)
|
||||
self._validate_server_status(mcp_server)
|
||||
|
||||
if server.status != AppMCPServerStatus.ACTIVE:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server is not active")
|
||||
)
|
||||
# Get user input form
|
||||
user_input_form = self._get_user_input_form(app)
|
||||
|
||||
app = db.session.query(App).where(App.id == server.app_id).first()
|
||||
# Handle notification vs request differently
|
||||
return self._process_mcp_message(mcp_request, request_id, app, mcp_server, user_input_form, session)
|
||||
|
||||
def _get_mcp_server_and_app(self, server_code: str, session: Session) -> tuple[AppMCPServer, App]:
|
||||
"""Get and validate MCP server and app in one query session"""
|
||||
mcp_server = session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
|
||||
if not mcp_server:
|
||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server Not Found")
|
||||
|
||||
app = session.query(App).where(App.id == mcp_server.app_id).first()
|
||||
if not app:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App Not Found")
|
||||
)
|
||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App Not Found")
|
||||
|
||||
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
workflow = app.workflow
|
||||
if workflow is None:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable")
|
||||
)
|
||||
return mcp_server, app
|
||||
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
def _validate_server_status(self, mcp_server: AppMCPServer) -> None:
|
||||
"""Validate MCP server status"""
|
||||
if mcp_server.status != AppMCPServerStatus.ACTIVE:
|
||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server is not active")
|
||||
|
||||
def _process_mcp_message(
|
||||
self,
|
||||
mcp_request: mcp_types.ClientRequest | mcp_types.ClientNotification,
|
||||
request_id: Optional[Union[int, str]],
|
||||
app: App,
|
||||
mcp_server: AppMCPServer,
|
||||
user_input_form: list[VariableEntity],
|
||||
session: Session,
|
||||
) -> Response:
|
||||
"""Process MCP message (notification or request)"""
|
||||
if isinstance(mcp_request, mcp_types.ClientNotification):
|
||||
return self._handle_notification(mcp_request)
|
||||
else:
|
||||
app_model_config = app.app_model_config
|
||||
if app_model_config is None:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable")
|
||||
)
|
||||
return self._handle_request(mcp_request, request_id, app, mcp_server, user_input_form, session)
|
||||
|
||||
features_dict = app_model_config.to_dict()
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
converted_user_input_form: list[VariableEntity] = []
|
||||
try:
|
||||
for item in user_input_form:
|
||||
variable_type = item.get("type", "") or list(item.keys())[0]
|
||||
variable = item[variable_type]
|
||||
converted_user_input_form.append(
|
||||
VariableEntity(
|
||||
type=variable_type,
|
||||
variable=variable.get("variable"),
|
||||
description=variable.get("description") or "",
|
||||
label=variable.get("label"),
|
||||
required=variable.get("required", False),
|
||||
max_length=variable.get("max_length"),
|
||||
options=variable.get("options") or [],
|
||||
)
|
||||
)
|
||||
except ValidationError as e:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")
|
||||
)
|
||||
def _handle_notification(self, mcp_request: mcp_types.ClientNotification) -> Response:
|
||||
"""Handle MCP notification"""
|
||||
# For notifications, only support init notification
|
||||
if mcp_request.root.method != "notifications/initialized":
|
||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Invalid notification method")
|
||||
# Return HTTP 202 Accepted for notifications (no response body)
|
||||
return Response("", status=202, content_type="application/json")
|
||||
|
||||
def _handle_request(
|
||||
self,
|
||||
mcp_request: mcp_types.ClientRequest,
|
||||
request_id: Optional[Union[int, str]],
|
||||
app: App,
|
||||
mcp_server: AppMCPServer,
|
||||
user_input_form: list[VariableEntity],
|
||||
session: Session,
|
||||
) -> Response:
|
||||
"""Handle MCP request"""
|
||||
if request_id is None:
|
||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Request ID is required")
|
||||
|
||||
result = self._handle_mcp_request(app, mcp_server, mcp_request, user_input_form, session, request_id)
|
||||
if result is None:
|
||||
# This shouldn't happen for requests, but handle gracefully
|
||||
raise MCPRequestError(mcp_types.INTERNAL_ERROR, "No response generated for request")
|
||||
|
||||
return helper.compact_generate_response(result.model_dump(by_alias=True, mode="json", exclude_none=True))
|
||||
|
||||
def _get_user_input_form(self, app: App) -> list[VariableEntity]:
|
||||
"""Get and convert user input form"""
|
||||
# Get raw user input form based on app mode
|
||||
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
if not app.workflow:
|
||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App is unavailable")
|
||||
raw_user_input_form = app.workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
if not app.app_model_config:
|
||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App is unavailable")
|
||||
features_dict = app.app_model_config.to_dict()
|
||||
raw_user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
# Convert to VariableEntity objects
|
||||
try:
|
||||
request: ClientRequest | ClientNotification = ClientRequest.model_validate(args)
|
||||
return self._convert_user_input_form(raw_user_input_form)
|
||||
except ValidationError as e:
|
||||
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")
|
||||
|
||||
def _convert_user_input_form(self, raw_form: list[dict]) -> list[VariableEntity]:
|
||||
"""Convert raw user input form to VariableEntity objects"""
|
||||
return [self._create_variable_entity(item) for item in raw_form]
|
||||
|
||||
def _create_variable_entity(self, item: dict) -> VariableEntity:
|
||||
"""Create a single VariableEntity from raw form item"""
|
||||
variable_type = item.get("type", "") or list(item.keys())[0]
|
||||
variable = item[variable_type]
|
||||
|
||||
return VariableEntity(
|
||||
type=variable_type,
|
||||
variable=variable.get("variable"),
|
||||
description=variable.get("description") or "",
|
||||
label=variable.get("label"),
|
||||
required=variable.get("required", False),
|
||||
max_length=variable.get("max_length"),
|
||||
options=variable.get("options") or [],
|
||||
)
|
||||
|
||||
def _parse_mcp_request(self, args: dict) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
|
||||
"""Parse and validate MCP request"""
|
||||
try:
|
||||
return mcp_types.ClientRequest.model_validate(args)
|
||||
except ValidationError:
|
||||
try:
|
||||
notification = ClientNotification.model_validate(args)
|
||||
request = notification
|
||||
return mcp_types.ClientNotification.model_validate(args)
|
||||
except ValidationError as e:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")
|
||||
)
|
||||
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")
|
||||
|
||||
mcp_server_handler = MCPServerStreamableHTTPRequestHandler(app, request, converted_user_input_form)
|
||||
response = mcp_server_handler.handle()
|
||||
|
||||
@@ -318,10 +318,6 @@ class DatasetApi(DatasetApiResource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
data = marshal(dataset, dataset_detail_fields)
|
||||
if data.get("permission") == "partial_members":
|
||||
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
data.update({"partial_member_list": part_users_list})
|
||||
|
||||
# check embedding setting
|
||||
provider_manager = ProviderManager()
|
||||
assert isinstance(current_user, Account)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_login import current_user
|
||||
from flask_restx import marshal, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from enum import Enum
|
||||
from enum import StrEnum, auto
|
||||
from functools import wraps
|
||||
from typing import Optional
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in # type: ignore
|
||||
from flask_login import user_logged_in
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select, update
|
||||
@@ -23,14 +23,14 @@ from models.model import ApiToken, App, EndUser
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
class WhereisUserArg(Enum):
|
||||
class WhereisUserArg(StrEnum):
|
||||
"""
|
||||
Enum for whereis_user_arg.
|
||||
"""
|
||||
|
||||
QUERY = "query"
|
||||
JSON = "json"
|
||||
FORM = "form"
|
||||
QUERY = auto()
|
||||
JSON = auto()
|
||||
FORM = auto()
|
||||
|
||||
|
||||
class FetchUserArg(BaseModel):
|
||||
|
||||
7
api/controllers/trigger/__init__.py
Normal file
7
api/controllers/trigger/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from flask import Blueprint
|
||||
|
||||
# Create trigger blueprint
|
||||
bp = Blueprint("trigger", __name__, url_prefix="/triggers")
|
||||
|
||||
# Import routes after blueprint creation to avoid circular imports
|
||||
from . import trigger, webhook
|
||||
41
api/controllers/trigger/trigger.py
Normal file
41
api/controllers/trigger/trigger.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import logging
|
||||
import re
|
||||
|
||||
from flask import jsonify, request
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.trigger import bp
|
||||
from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService
|
||||
from services.trigger_service import TriggerService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
UUID_PATTERN = r"^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$"
|
||||
UUID_MATCHER = re.compile(UUID_PATTERN)
|
||||
|
||||
|
||||
@bp.route("/plugin/<string:endpoint_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
|
||||
def trigger_endpoint(endpoint_id: str):
|
||||
"""
|
||||
Handle endpoint trigger calls.
|
||||
"""
|
||||
# endpoint_id must be UUID
|
||||
if not UUID_MATCHER.match(endpoint_id):
|
||||
raise NotFound("Invalid endpoint ID")
|
||||
handling_chain = [
|
||||
TriggerService.process_endpoint,
|
||||
TriggerSubscriptionBuilderService.process_builder_validation_endpoint,
|
||||
]
|
||||
try:
|
||||
for handler in handling_chain:
|
||||
response = handler(endpoint_id, request)
|
||||
if response:
|
||||
break
|
||||
if not response:
|
||||
raise NotFound("Endpoint not found")
|
||||
return response
|
||||
except ValueError as e:
|
||||
raise NotFound(str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Webhook processing failed for {endpoint_id}")
|
||||
return jsonify({"error": "Internal server error", "message": str(e)}), 500
|
||||
46
api/controllers/trigger/webhook.py
Normal file
46
api/controllers/trigger/webhook.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import logging
|
||||
|
||||
from flask import jsonify
|
||||
from werkzeug.exceptions import NotFound, RequestEntityTooLarge
|
||||
|
||||
from controllers.trigger import bp
|
||||
from services.webhook_service import WebhookService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@bp.route("/webhook/<string:webhook_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
|
||||
@bp.route("/webhook-debug/<string:webhook_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
|
||||
def handle_webhook(webhook_id: str):
|
||||
"""
|
||||
Handle webhook trigger calls.
|
||||
|
||||
This endpoint receives webhook calls and processes them according to the
|
||||
configured webhook trigger settings.
|
||||
"""
|
||||
try:
|
||||
# Get webhook trigger, workflow, and node configuration
|
||||
webhook_trigger, workflow, node_config = WebhookService.get_webhook_trigger_and_workflow(webhook_id)
|
||||
|
||||
# Extract request data
|
||||
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
|
||||
|
||||
# Validate request against node configuration
|
||||
validation_result = WebhookService.validate_webhook_request(webhook_data, node_config)
|
||||
if not validation_result["valid"]:
|
||||
return jsonify({"error": "Bad Request", "message": validation_result["error"]}), 400
|
||||
|
||||
# Process webhook call (send to Celery)
|
||||
WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)
|
||||
|
||||
# Return configured response
|
||||
response_data, status_code = WebhookService.generate_webhook_response(node_config)
|
||||
return jsonify(response_data), status_code
|
||||
|
||||
except ValueError as e:
|
||||
raise NotFound(str(e))
|
||||
except RequestEntityTooLarge:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Webhook processing failed for %s", webhook_id)
|
||||
return jsonify({"error": "Internal server error", "message": str(e)}), 500
|
||||
@@ -1,19 +1,20 @@
|
||||
from flask import Blueprint
|
||||
from flask_restx import Namespace
|
||||
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
from .files import FileApi
|
||||
from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
|
||||
|
||||
bp = Blueprint("web", __name__, url_prefix="/api")
|
||||
api = ExternalApi(bp)
|
||||
|
||||
# Files
|
||||
api.add_resource(FileApi, "/files/upload")
|
||||
api = ExternalApi(
|
||||
bp,
|
||||
version="1.0",
|
||||
title="Web API",
|
||||
description="Public APIs for web applications including file uploads, chat interactions, and app management",
|
||||
doc="/docs", # Enable Swagger UI at /api/docs
|
||||
)
|
||||
|
||||
# Remote files
|
||||
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
|
||||
api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
|
||||
# Create namespace
|
||||
web_ns = Namespace("web", description="Web application API operations", path="/")
|
||||
|
||||
from . import (
|
||||
app,
|
||||
@@ -21,11 +22,15 @@ from . import (
|
||||
completion,
|
||||
conversation,
|
||||
feature,
|
||||
files,
|
||||
forgot_password,
|
||||
login,
|
||||
message,
|
||||
passport,
|
||||
remote_files,
|
||||
saved_message,
|
||||
site,
|
||||
workflow,
|
||||
)
|
||||
|
||||
api.add_namespace(web_ns)
|
||||
|
||||
@@ -5,7 +5,7 @@ from flask_restx import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from controllers.common import fields
|
||||
from controllers.web import api
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import AppUnavailableError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
@@ -19,9 +19,22 @@ from services.webapp_auth_service import WebAppAuthService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@web_ns.route("/parameters")
|
||||
class AppParameterApi(WebApiResource):
|
||||
"""Resource for app variables."""
|
||||
|
||||
@web_ns.doc("Get App Parameters")
|
||||
@web_ns.doc(description="Retrieve the parameters for a specific app.")
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
404: "App Not Found",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(fields.parameters_fields)
|
||||
def get(self, app_model: App, end_user):
|
||||
"""Retrieve app parameters."""
|
||||
@@ -44,13 +57,42 @@ class AppParameterApi(WebApiResource):
|
||||
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
|
||||
|
||||
@web_ns.route("/meta")
|
||||
class AppMeta(WebApiResource):
|
||||
@web_ns.doc("Get App Meta")
|
||||
@web_ns.doc(description="Retrieve the metadata for a specific app.")
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
404: "App Not Found",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def get(self, app_model: App, end_user):
|
||||
"""Get app meta"""
|
||||
return AppService().get_app_meta(app_model)
|
||||
|
||||
|
||||
@web_ns.route("/webapp/access-mode")
|
||||
class AppAccessMode(Resource):
|
||||
@web_ns.doc("Get App Access Mode")
|
||||
@web_ns.doc(description="Retrieve the access mode for a web application (public or restricted).")
|
||||
@web_ns.doc(
|
||||
params={
|
||||
"appId": {"description": "Application ID", "type": "string", "required": False},
|
||||
"appCode": {"description": "Application code", "type": "string", "required": False},
|
||||
}
|
||||
)
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("appId", type=str, required=False, location="args")
|
||||
@@ -74,7 +116,19 @@ class AppAccessMode(Resource):
|
||||
return {"accessMode": res.access_mode}
|
||||
|
||||
|
||||
@web_ns.route("/webapp/permission")
|
||||
class AppWebAuthPermission(Resource):
|
||||
@web_ns.doc("Check App Permission")
|
||||
@web_ns.doc(description="Check if user has permission to access a web application.")
|
||||
@web_ns.doc(params={"appId": {"description": "Application ID", "type": "string", "required": True}})
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def get(self):
|
||||
user_id = "visitor"
|
||||
try:
|
||||
@@ -112,10 +166,3 @@ class AppWebAuthPermission(Resource):
|
||||
if WebAppAuthService.is_app_require_permission_check(app_id=app_id):
|
||||
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_code)
|
||||
return {"result": res}
|
||||
|
||||
|
||||
api.add_resource(AppParameterApi, "/parameters")
|
||||
api.add_resource(AppMeta, "/meta")
|
||||
# webapp auth apis
|
||||
api.add_resource(AppAccessMode, "/webapp/access-mode")
|
||||
api.add_resource(AppWebAuthPermission, "/webapp/permission")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restx import fields, marshal_with, reqparse
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
@@ -32,7 +33,26 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AudioApi(WebApiResource):
|
||||
audio_to_text_response_fields = {
|
||||
"text": fields.String,
|
||||
}
|
||||
|
||||
@marshal_with(audio_to_text_response_fields)
|
||||
@api.doc("Audio to Text")
|
||||
@api.doc(description="Convert audio file to text using speech-to-text service.")
|
||||
@api.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
413: "Audio file too large",
|
||||
415: "Unsupported audio type",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def post(self, app_model: App, end_user):
|
||||
"""Convert audio to text"""
|
||||
file = request.files["file"]
|
||||
|
||||
try:
|
||||
@@ -66,9 +86,25 @@ class AudioApi(WebApiResource):
|
||||
|
||||
|
||||
class TextApi(WebApiResource):
|
||||
def post(self, app_model: App, end_user):
|
||||
from flask_restx import reqparse
|
||||
text_to_audio_response_fields = {
|
||||
"audio_url": fields.String,
|
||||
"duration": fields.Float,
|
||||
}
|
||||
|
||||
@marshal_with(text_to_audio_response_fields)
|
||||
@api.doc("Text to Audio")
|
||||
@api.doc(description="Convert text to audio using text-to-speech service.")
|
||||
@api.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def post(self, app_model: App, end_user):
|
||||
"""Convert text to audio"""
|
||||
try:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", type=str, required=False, location="json")
|
||||
|
||||
@@ -36,6 +36,32 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# define completion api for user
|
||||
class CompletionApi(WebApiResource):
|
||||
@api.doc("Create Completion Message")
|
||||
@api.doc(description="Create a completion message for text generation applications.")
|
||||
@api.doc(
|
||||
params={
|
||||
"inputs": {"description": "Input variables for the completion", "type": "object", "required": True},
|
||||
"query": {"description": "Query text for completion", "type": "string", "required": False},
|
||||
"files": {"description": "Files to be processed", "type": "array", "required": False},
|
||||
"response_mode": {
|
||||
"description": "Response mode: blocking or streaming",
|
||||
"type": "string",
|
||||
"enum": ["blocking", "streaming"],
|
||||
"required": False,
|
||||
},
|
||||
"retriever_from": {"description": "Source of retriever", "type": "string", "required": False},
|
||||
}
|
||||
)
|
||||
@api.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
404: "App Not Found",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def post(self, app_model, end_user):
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
@@ -81,6 +107,19 @@ class CompletionApi(WebApiResource):
|
||||
|
||||
|
||||
class CompletionStopApi(WebApiResource):
|
||||
@api.doc("Stop Completion Message")
|
||||
@api.doc(description="Stop a running completion message task.")
|
||||
@api.doc(params={"task_id": {"description": "Task ID to stop", "type": "string", "required": True}})
|
||||
@api.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
404: "Task Not Found",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def post(self, app_model, end_user, task_id):
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
@@ -91,6 +130,34 @@ class CompletionStopApi(WebApiResource):
|
||||
|
||||
|
||||
class ChatApi(WebApiResource):
|
||||
@api.doc("Create Chat Message")
|
||||
@api.doc(description="Create a chat message for conversational applications.")
|
||||
@api.doc(
|
||||
params={
|
||||
"inputs": {"description": "Input variables for the chat", "type": "object", "required": True},
|
||||
"query": {"description": "User query/message", "type": "string", "required": True},
|
||||
"files": {"description": "Files to be processed", "type": "array", "required": False},
|
||||
"response_mode": {
|
||||
"description": "Response mode: blocking or streaming",
|
||||
"type": "string",
|
||||
"enum": ["blocking", "streaming"],
|
||||
"required": False,
|
||||
},
|
||||
"conversation_id": {"description": "Conversation UUID", "type": "string", "required": False},
|
||||
"parent_message_id": {"description": "Parent message UUID", "type": "string", "required": False},
|
||||
"retriever_from": {"description": "Source of retriever", "type": "string", "required": False},
|
||||
}
|
||||
)
|
||||
@api.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
404: "App Not Found",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def post(self, app_model, end_user):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
@@ -141,6 +208,19 @@ class ChatApi(WebApiResource):
|
||||
|
||||
|
||||
class ChatStopApi(WebApiResource):
|
||||
@api.doc("Stop Chat Message")
|
||||
@api.doc(description="Stop a running chat message task.")
|
||||
@api.doc(params={"task_id": {"description": "Task ID to stop", "type": "string", "required": True}})
|
||||
@api.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
404: "Task Not Found",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def post(self, app_model, end_user, task_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from flask_restx import marshal_with, reqparse
|
||||
from flask_restx import fields, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
@@ -58,6 +58,11 @@ class ConversationListApi(WebApiResource):
|
||||
|
||||
|
||||
class ConversationApi(WebApiResource):
|
||||
delete_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@marshal_with(delete_response_fields)
|
||||
def delete(self, app_model, end_user, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
@@ -94,6 +99,11 @@ class ConversationRenameApi(WebApiResource):
|
||||
|
||||
|
||||
class ConversationPinApi(WebApiResource):
|
||||
pin_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@marshal_with(pin_response_fields)
|
||||
def patch(self, app_model, end_user, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
@@ -110,6 +120,11 @@ class ConversationPinApi(WebApiResource):
|
||||
|
||||
|
||||
class ConversationUnPinApi(WebApiResource):
|
||||
unpin_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@marshal_with(unpin_response_fields)
|
||||
def patch(self, app_model, end_user, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
|
||||
@@ -1,12 +1,21 @@
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.web import api
|
||||
from controllers.web import web_ns
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
@web_ns.route("/system-features")
|
||||
class SystemFeatureApi(Resource):
|
||||
@web_ns.doc("get_system_features")
|
||||
@web_ns.doc(description="Get system feature flags and configuration")
|
||||
@web_ns.doc(responses={200: "System features retrieved successfully", 500: "Internal server error"})
|
||||
def get(self):
|
||||
"""Get system feature flags and configuration.
|
||||
|
||||
Returns the current system feature flags and configuration
|
||||
that control various functionalities across the platform.
|
||||
|
||||
Returns:
|
||||
dict: System feature configuration object
|
||||
"""
|
||||
return FeatureService.get_system_features().model_dump()
|
||||
|
||||
|
||||
api.add_resource(SystemFeatureApi, "/system-features")
|
||||
|
||||
@@ -9,14 +9,50 @@ from controllers.common.errors import (
|
||||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from fields.file_fields import file_fields
|
||||
from fields.file_fields import build_file_model
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
@web_ns.route("/files/upload")
|
||||
class FileApi(WebApiResource):
|
||||
@marshal_with(file_fields)
|
||||
@web_ns.doc("upload_file")
|
||||
@web_ns.doc(description="Upload a file for use in web applications")
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
201: "File uploaded successfully",
|
||||
400: "Bad request - invalid file or parameters",
|
||||
413: "File too large",
|
||||
415: "Unsupported file type",
|
||||
}
|
||||
)
|
||||
@marshal_with(build_file_model(web_ns))
|
||||
def post(self, app_model, end_user):
|
||||
"""Upload a file for use in web applications.
|
||||
|
||||
Accepts file uploads for use within web applications, supporting
|
||||
multiple file types with automatic validation and storage.
|
||||
|
||||
Args:
|
||||
app_model: The associated application model
|
||||
end_user: The end user uploading the file
|
||||
|
||||
Form Parameters:
|
||||
file: The file to upload (required)
|
||||
source: Optional source type (datasets or None)
|
||||
|
||||
Returns:
|
||||
dict: File information including ID, URL, and metadata
|
||||
int: HTTP status code 201 for success
|
||||
|
||||
Raises:
|
||||
NoFileUploadedError: No file provided in request
|
||||
TooManyFilesError: Multiple files provided (only one allowed)
|
||||
FilenameNotExistsError: File has no filename
|
||||
FileTooLargeError: File exceeds size limit
|
||||
UnsupportedFileTypeError: File type not supported
|
||||
"""
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ from controllers.console.auth.error import (
|
||||
)
|
||||
from controllers.console.error import EmailSendIpLimitError
|
||||
from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required
|
||||
from controllers.web import api
|
||||
from controllers.web import web_ns
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from libs.password import hash_password, valid_password
|
||||
@@ -24,10 +24,21 @@ from models.account import Account
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
@web_ns.route("/forgot-password")
|
||||
class ForgotPasswordSendEmailApi(Resource):
|
||||
@only_edition_enterprise
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@web_ns.doc("send_forgot_password_email")
|
||||
@web_ns.doc(description="Send password reset email")
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Password reset email sent successfully",
|
||||
400: "Bad request - invalid email format",
|
||||
404: "Account not found",
|
||||
429: "Too many requests - rate limit exceeded",
|
||||
}
|
||||
)
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
@@ -54,10 +65,16 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
@web_ns.route("/forgot-password/validity")
|
||||
class ForgotPasswordCheckApi(Resource):
|
||||
@only_edition_enterprise
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@web_ns.doc("check_forgot_password_token")
|
||||
@web_ns.doc(description="Verify password reset token validity")
|
||||
@web_ns.doc(
|
||||
responses={200: "Token is valid", 400: "Bad request - invalid token format", 401: "Invalid or expired token"}
|
||||
)
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=str, required=True, location="json")
|
||||
@@ -94,10 +111,21 @@ class ForgotPasswordCheckApi(Resource):
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
|
||||
|
||||
@web_ns.route("/forgot-password/resets")
|
||||
class ForgotPasswordResetApi(Resource):
|
||||
@only_edition_enterprise
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@web_ns.doc("reset_password")
|
||||
@web_ns.doc(description="Reset user password with verification token")
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Password reset successfully",
|
||||
400: "Bad request - invalid parameters or password mismatch",
|
||||
401: "Invalid or expired token",
|
||||
404: "Account not found",
|
||||
}
|
||||
)
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
@@ -141,8 +169,3 @@ class ForgotPasswordResetApi(Resource):
|
||||
account.password = base64.b64encode(password_hashed).decode()
|
||||
account.password_salt = base64.b64encode(salt).decode()
|
||||
session.commit()
|
||||
|
||||
|
||||
api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")
|
||||
api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity")
|
||||
api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets")
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from flask_restx import Resource, reqparse
|
||||
from jwt import InvalidTokenError # type: ignore
|
||||
from jwt import InvalidTokenError
|
||||
|
||||
import services
|
||||
from controllers.console.auth.error import (
|
||||
@@ -9,18 +9,30 @@ from controllers.console.auth.error import (
|
||||
)
|
||||
from controllers.console.error import AccountBannedError
|
||||
from controllers.console.wraps import only_edition_enterprise, setup_required
|
||||
from controllers.web import api
|
||||
from controllers.web import web_ns
|
||||
from libs.helper import email
|
||||
from libs.password import valid_password
|
||||
from services.account_service import AccountService
|
||||
from services.webapp_auth_service import WebAppAuthService
|
||||
|
||||
|
||||
@web_ns.route("/login")
|
||||
class LoginApi(Resource):
|
||||
"""Resource for web app email/password login."""
|
||||
|
||||
@setup_required
|
||||
@only_edition_enterprise
|
||||
@web_ns.doc("web_app_login")
|
||||
@web_ns.doc(description="Authenticate user for web application access")
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Authentication successful",
|
||||
400: "Bad request - invalid email or password format",
|
||||
401: "Authentication failed - email or password mismatch",
|
||||
403: "Account banned or login disabled",
|
||||
404: "Account not found",
|
||||
}
|
||||
)
|
||||
def post(self):
|
||||
"""Authenticate user and login."""
|
||||
parser = reqparse.RequestParser()
|
||||
@@ -51,9 +63,19 @@ class LoginApi(Resource):
|
||||
# return {"result": "success"}
|
||||
|
||||
|
||||
@web_ns.route("/email-code-login")
|
||||
class EmailCodeLoginSendEmailApi(Resource):
|
||||
@setup_required
|
||||
@only_edition_enterprise
|
||||
@web_ns.doc("send_email_code_login")
|
||||
@web_ns.doc(description="Send email verification code for login")
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Email code sent successfully",
|
||||
400: "Bad request - invalid email format",
|
||||
404: "Account not found",
|
||||
}
|
||||
)
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
@@ -74,9 +96,20 @@ class EmailCodeLoginSendEmailApi(Resource):
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
@web_ns.route("/email-code-login/validity")
|
||||
class EmailCodeLoginApi(Resource):
|
||||
@setup_required
|
||||
@only_edition_enterprise
|
||||
@web_ns.doc("verify_email_code_login")
|
||||
@web_ns.doc(description="Verify email code and complete login")
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Email code verified and login successful",
|
||||
400: "Bad request - invalid code or token",
|
||||
401: "Invalid token or expired code",
|
||||
404: "Account not found",
|
||||
}
|
||||
)
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=str, required=True, location="json")
|
||||
@@ -104,9 +137,3 @@ class EmailCodeLoginApi(Resource):
|
||||
token = WebAppAuthService.login(account=account)
|
||||
AccountService.reset_login_error_rate_limit(args["email"])
|
||||
return {"result": "success", "data": {"access_token": token}}
|
||||
|
||||
|
||||
api.add_resource(LoginApi, "/login")
|
||||
# api.add_resource(LogoutApi, "/logout")
|
||||
api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
|
||||
api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")
|
||||
|
||||
@@ -85,6 +85,11 @@ class MessageListApi(WebApiResource):
|
||||
|
||||
|
||||
class MessageFeedbackApi(WebApiResource):
|
||||
feedback_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@marshal_with(feedback_response_fields)
|
||||
def post(self, app_model, end_user, message_id):
|
||||
message_id = str(message_id)
|
||||
|
||||
@@ -152,6 +157,11 @@ class MessageMoreLikeThisApi(WebApiResource):
|
||||
|
||||
|
||||
class MessageSuggestedQuestionApi(WebApiResource):
|
||||
suggested_questions_response_fields = {
|
||||
"data": fields.List(fields.String),
|
||||
}
|
||||
|
||||
@marshal_with(suggested_questions_response_fields)
|
||||
def get(self, app_model, end_user, message_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
|
||||
@@ -7,7 +7,7 @@ from sqlalchemy import func, select
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.web import api
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import WebAppAuthRequiredError
|
||||
from extensions.ext_database import db
|
||||
from libs.passport import PassportService
|
||||
@@ -17,9 +17,19 @@ from services.feature_service import FeatureService
|
||||
from services.webapp_auth_service import WebAppAuthService, WebAppAuthType
|
||||
|
||||
|
||||
@web_ns.route("/passport")
|
||||
class PassportResource(Resource):
|
||||
"""Base resource for passport."""
|
||||
|
||||
@web_ns.doc("get_passport")
|
||||
@web_ns.doc(description="Get authentication passport for web application access")
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Passport retrieved successfully",
|
||||
401: "Unauthorized - missing app code or invalid authentication",
|
||||
404: "Application or user not found",
|
||||
}
|
||||
)
|
||||
def get(self):
|
||||
system_features = FeatureService.get_system_features()
|
||||
app_code = request.headers.get("X-App-Code")
|
||||
@@ -94,9 +104,6 @@ class PassportResource(Resource):
|
||||
}
|
||||
|
||||
|
||||
api.add_resource(PassportResource, "/passport")
|
||||
|
||||
|
||||
def decode_enterprise_webapp_user_id(jwt_token: str | None):
|
||||
"""
|
||||
Decode the enterprise user session from the Authorization header.
|
||||
|
||||
@@ -10,16 +10,44 @@ from controllers.common.errors import (
|
||||
RemoteFileUploadError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.file import helpers as file_helpers
|
||||
from core.helper import ssrf_proxy
|
||||
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
|
||||
from fields.file_fields import build_file_with_signed_url_model, build_remote_file_info_model
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
@web_ns.route("/remote-files/<path:url>")
|
||||
class RemoteFileInfoApi(WebApiResource):
|
||||
@marshal_with(remote_file_info_fields)
|
||||
@web_ns.doc("get_remote_file_info")
|
||||
@web_ns.doc(description="Get information about a remote file")
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Remote file information retrieved successfully",
|
||||
400: "Bad request - invalid URL",
|
||||
404: "Remote file not found",
|
||||
500: "Failed to fetch remote file",
|
||||
}
|
||||
)
|
||||
@marshal_with(build_remote_file_info_model(web_ns))
|
||||
def get(self, app_model, end_user, url):
|
||||
"""Get information about a remote file.
|
||||
|
||||
Retrieves basic information about a file located at a remote URL,
|
||||
including content type and content length.
|
||||
|
||||
Args:
|
||||
app_model: The associated application model
|
||||
end_user: The end user making the request
|
||||
url: URL-encoded path to the remote file
|
||||
|
||||
Returns:
|
||||
dict: Remote file information including type and length
|
||||
|
||||
Raises:
|
||||
HTTPException: If the remote file cannot be accessed
|
||||
"""
|
||||
decoded_url = urllib.parse.unquote(url)
|
||||
resp = ssrf_proxy.head(decoded_url)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
@@ -32,9 +60,42 @@ class RemoteFileInfoApi(WebApiResource):
|
||||
}
|
||||
|
||||
|
||||
@web_ns.route("/remote-files/upload")
|
||||
class RemoteFileUploadApi(WebApiResource):
|
||||
@marshal_with(file_fields_with_signed_url)
|
||||
def post(self, app_model, end_user): # Add app_model and end_user parameters
|
||||
@web_ns.doc("upload_remote_file")
|
||||
@web_ns.doc(description="Upload a file from a remote URL")
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
201: "Remote file uploaded successfully",
|
||||
400: "Bad request - invalid URL or parameters",
|
||||
413: "File too large",
|
||||
415: "Unsupported file type",
|
||||
500: "Failed to fetch remote file",
|
||||
}
|
||||
)
|
||||
@marshal_with(build_file_with_signed_url_model(web_ns))
|
||||
def post(self, app_model, end_user):
|
||||
"""Upload a file from a remote URL.
|
||||
|
||||
Downloads a file from the provided remote URL and uploads it
|
||||
to the platform storage for use in web applications.
|
||||
|
||||
Args:
|
||||
app_model: The associated application model
|
||||
end_user: The end user making the request
|
||||
|
||||
JSON Parameters:
|
||||
url: The remote URL to download the file from (required)
|
||||
|
||||
Returns:
|
||||
dict: File information including ID, signed URL, and metadata
|
||||
int: HTTP status code 201 for success
|
||||
|
||||
Raises:
|
||||
RemoteFileUploadError: Failed to fetch file from remote URL
|
||||
FileTooLargeError: File exceeds size limit
|
||||
UnsupportedFileTypeError: File type not supported
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("url", type=str, required=True, help="URL is required")
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -30,6 +30,10 @@ class SavedMessageListApi(WebApiResource):
|
||||
"data": fields.List(fields.Nested(message_fields)),
|
||||
}
|
||||
|
||||
post_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@marshal_with(saved_message_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model, end_user):
|
||||
if app_model.mode != "completion":
|
||||
@@ -42,6 +46,7 @@ class SavedMessageListApi(WebApiResource):
|
||||
|
||||
return SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"])
|
||||
|
||||
@marshal_with(post_response_fields)
|
||||
def post(self, app_model, end_user):
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
@@ -59,6 +64,11 @@ class SavedMessageListApi(WebApiResource):
|
||||
|
||||
|
||||
class SavedMessageApi(WebApiResource):
|
||||
delete_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@marshal_with(delete_response_fields)
|
||||
def delete(self, app_model, end_user, message_id):
|
||||
message_id = str(message_id)
|
||||
|
||||
|
||||
@@ -53,6 +53,18 @@ class AppSiteApi(WebApiResource):
|
||||
"custom_config": fields.Raw(attribute="custom_config"),
|
||||
}
|
||||
|
||||
@api.doc("Get App Site Info")
|
||||
@api.doc(description="Retrieve app site information and configuration.")
|
||||
@api.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
404: "App Not Found",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(app_fields)
|
||||
def get(self, app_model, end_user):
|
||||
"""Retrieve app site info."""
|
||||
|
||||
@@ -30,6 +30,24 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowRunApi(WebApiResource):
|
||||
@api.doc("Run Workflow")
|
||||
@api.doc(description="Execute a workflow with provided inputs and files.")
|
||||
@api.doc(
|
||||
params={
|
||||
"inputs": {"description": "Input variables for the workflow", "type": "object", "required": True},
|
||||
"files": {"description": "Files to be processed by the workflow", "type": "array", "required": False},
|
||||
}
|
||||
)
|
||||
@api.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
404: "App Not Found",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
"""
|
||||
Run workflow
|
||||
@@ -67,6 +85,23 @@ class WorkflowRunApi(WebApiResource):
|
||||
|
||||
|
||||
class WorkflowTaskStopApi(WebApiResource):
|
||||
@api.doc("Stop Workflow Task")
|
||||
@api.doc(description="Stop a running workflow task.")
|
||||
@api.doc(
|
||||
params={
|
||||
"task_id": {"description": "Task ID to stop", "type": "string", "required": True},
|
||||
}
|
||||
)
|
||||
@api.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
404: "Task Not Found",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def post(self, app_model: App, end_user: EndUser, task_id: str):
|
||||
"""
|
||||
Stop workflow task
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
||||
|
||||
|
||||
class MoreLikeThisConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class AppConfigModel(BaseModel):
|
||||
more_like_this: MoreLikeThisConfig = Field(default_factory=MoreLikeThisConfig)
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class MoreLikeThisConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> bool:
|
||||
@@ -6,31 +19,14 @@ class MoreLikeThisConfigManager:
|
||||
|
||||
:param config: model config args
|
||||
"""
|
||||
more_like_this = False
|
||||
more_like_this_dict = config.get("more_like_this")
|
||||
if more_like_this_dict:
|
||||
if more_like_this_dict.get("enabled"):
|
||||
more_like_this = True
|
||||
|
||||
return more_like_this
|
||||
validated_config, _ = cls.validate_and_set_defaults(config)
|
||||
return AppConfigModel.model_validate(validated_config).more_like_this.enabled
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
"""
|
||||
Validate and set defaults for more like this feature
|
||||
|
||||
:param config: app model config args
|
||||
"""
|
||||
if not config.get("more_like_this"):
|
||||
config["more_like_this"] = {"enabled": False}
|
||||
|
||||
if not isinstance(config["more_like_this"], dict):
|
||||
raise ValueError("more_like_this must be of dict type")
|
||||
|
||||
if "enabled" not in config["more_like_this"] or not config["more_like_this"]["enabled"]:
|
||||
config["more_like_this"]["enabled"] = False
|
||||
|
||||
if not isinstance(config["more_like_this"]["enabled"], bool):
|
||||
raise ValueError("enabled in more_like_this must be of boolean type")
|
||||
|
||||
return config, ["more_like_this"]
|
||||
try:
|
||||
return AppConfigModel.model_validate(config).model_dump(), ["more_like_this"]
|
||||
except ValidationError as e:
|
||||
raise ValueError(
|
||||
"more_like_this must be of dict type and enabled in more_like_this must be of boolean type"
|
||||
)
|
||||
|
||||
@@ -140,7 +140,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
environment_variables=self._workflow.environment_variables,
|
||||
# Based on the definition of `VariableUnion`,
|
||||
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
|
||||
conversation_variables=cast(list[VariableUnion], conversation_variables),
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
|
||||
# init graph
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Callable, Generator, Mapping
|
||||
from contextlib import contextmanager
|
||||
@@ -143,6 +144,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
|
||||
self._workflow_response_converter = WorkflowResponseConverter(
|
||||
application_generate_entity=application_generate_entity,
|
||||
user=user,
|
||||
)
|
||||
|
||||
self._task_state = WorkflowTaskState()
|
||||
@@ -373,7 +375,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle node succeeded events."""
|
||||
# Record files if it's an answer node or end node
|
||||
if event.node_type in [NodeType.ANSWER, NodeType.END]:
|
||||
if event.node_type in [NodeType.ANSWER, NodeType.END, NodeType.LLM]:
|
||||
self._recorded_files.extend(
|
||||
self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
|
||||
)
|
||||
@@ -896,7 +898,14 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
|
||||
def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
|
||||
message = self._get_message(session=session)
|
||||
message.answer = self._task_state.answer
|
||||
|
||||
# If there are assistant files, remove markdown image links from answer
|
||||
answer_text = self._task_state.answer
|
||||
if self._recorded_files:
|
||||
# Remove markdown image links since we're storing files separately
|
||||
answer_text = re.sub(r"!\[.*?\]\(.*?\)", "", answer_text).strip()
|
||||
|
||||
message.answer = answer_text
|
||||
message.updated_at = naive_utc_now()
|
||||
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
|
||||
message.message_metadata = self._task_state.metadata.model_dump_json()
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, final
|
||||
|
||||
@@ -14,6 +13,7 @@ from core.workflow.repositories.draft_variable_repository import (
|
||||
NoopDraftVariableSaver,
|
||||
)
|
||||
from factories import file_factory
|
||||
from libs.orjson import orjson_dumps
|
||||
from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -174,7 +174,7 @@ class BaseAppGenerator:
|
||||
def gen():
|
||||
for message in generator:
|
||||
if isinstance(message, Mapping | dict):
|
||||
yield f"data: {json.dumps(message)}\n\n"
|
||||
yield f"data: {orjson_dumps(message)}\n\n"
|
||||
else:
|
||||
yield f"event: {message}\n\n"
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import queue
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from enum import IntEnum, auto
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy.orm import DeclarativeMeta
|
||||
@@ -19,9 +19,9 @@ from core.app.entities.queue_entities import (
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
class PublishFrom(Enum):
|
||||
APPLICATION_MANAGER = 1
|
||||
TASK_PIPELINE = 2
|
||||
class PublishFrom(IntEnum):
|
||||
APPLICATION_MANAGER = auto()
|
||||
TASK_PIPELINE = auto()
|
||||
|
||||
|
||||
class AppQueueManager:
|
||||
|
||||
@@ -3,7 +3,6 @@ from collections.abc import Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
@@ -53,9 +52,7 @@ from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import (
|
||||
Account,
|
||||
CreatorUserRole,
|
||||
EndUser,
|
||||
WorkflowRun,
|
||||
)
|
||||
|
||||
|
||||
@@ -64,8 +61,10 @@ class WorkflowResponseConverter:
|
||||
self,
|
||||
*,
|
||||
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
|
||||
user: Union[Account, EndUser],
|
||||
) -> None:
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._user = user
|
||||
|
||||
def workflow_start_to_stream_response(
|
||||
self,
|
||||
@@ -92,27 +91,21 @@ class WorkflowResponseConverter:
|
||||
workflow_execution: WorkflowExecution,
|
||||
) -> WorkflowFinishStreamResponse:
|
||||
created_by = None
|
||||
workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id_))
|
||||
assert workflow_run is not None
|
||||
if workflow_run.created_by_role == CreatorUserRole.ACCOUNT:
|
||||
stmt = select(Account).where(Account.id == workflow_run.created_by)
|
||||
account = session.scalar(stmt)
|
||||
if account:
|
||||
created_by = {
|
||||
"id": account.id,
|
||||
"name": account.name,
|
||||
"email": account.email,
|
||||
}
|
||||
elif workflow_run.created_by_role == CreatorUserRole.END_USER:
|
||||
stmt = select(EndUser).where(EndUser.id == workflow_run.created_by)
|
||||
end_user = session.scalar(stmt)
|
||||
if end_user:
|
||||
created_by = {
|
||||
"id": end_user.id,
|
||||
"user": end_user.session_id,
|
||||
}
|
||||
|
||||
user = self._user
|
||||
if isinstance(user, Account):
|
||||
created_by = {
|
||||
"id": user.id,
|
||||
"name": user.name,
|
||||
"email": user.email,
|
||||
}
|
||||
elif isinstance(user, EndUser):
|
||||
created_by = {
|
||||
"id": user.id,
|
||||
"user": user.session_id,
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError(f"unknown created_by_role: {workflow_run.created_by_role}")
|
||||
raise NotImplementedError(f"User type not supported: {type(user)}")
|
||||
|
||||
# Handle the case where finished_at is None by using current time as default
|
||||
finished_at_timestamp = (
|
||||
|
||||
@@ -54,6 +54,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
streaming: Literal[True],
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: Optional[str],
|
||||
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
|
||||
root_node_id: Optional[str] = None,
|
||||
) -> Generator[Mapping | str, None, None]: ...
|
||||
|
||||
@overload
|
||||
@@ -68,6 +70,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
streaming: Literal[False],
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: Optional[str],
|
||||
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
|
||||
root_node_id: Optional[str] = None,
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
@@ -82,6 +86,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
streaming: bool,
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: Optional[str],
|
||||
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
|
||||
root_node_id: Optional[str] = None,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
|
||||
|
||||
def generate(
|
||||
@@ -95,6 +101,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
streaming: bool = True,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
|
||||
root_node_id: Optional[str] = None,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
|
||||
files: Sequence[Mapping[str, Any]] = args.get("files") or []
|
||||
|
||||
@@ -130,17 +138,20 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
**extract_external_trace_id_from_args(args),
|
||||
}
|
||||
workflow_run_id = str(uuid.uuid4())
|
||||
if triggered_from in (WorkflowRunTriggeredFrom.DEBUGGING, WorkflowRunTriggeredFrom.APP_RUN):
|
||||
# start node get inputs
|
||||
inputs = self._prepare_user_inputs(
|
||||
user_inputs=inputs,
|
||||
variables=app_config.variables,
|
||||
tenant_id=app_model.tenant_id,
|
||||
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
|
||||
)
|
||||
# init application generate entity
|
||||
application_generate_entity = WorkflowAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
file_upload_config=file_extra_config,
|
||||
inputs=self._prepare_user_inputs(
|
||||
user_inputs=inputs,
|
||||
variables=app_config.variables,
|
||||
tenant_id=app_model.tenant_id,
|
||||
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
|
||||
),
|
||||
inputs=inputs,
|
||||
files=list(system_files),
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
@@ -159,7 +170,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
# Create session factory
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
# Create workflow execution(aka workflow run) repository
|
||||
if invoke_from == InvokeFrom.DEBUGGER:
|
||||
if triggered_from is not None:
|
||||
# Use explicitly provided triggered_from (for async triggers)
|
||||
workflow_triggered_from = triggered_from
|
||||
elif invoke_from == InvokeFrom.DEBUGGER:
|
||||
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
|
||||
else:
|
||||
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
|
||||
@@ -187,6 +201,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||
root_node_id=root_node_id,
|
||||
)
|
||||
|
||||
def _generate(
|
||||
@@ -202,6 +217,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
streaming: bool = True,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
root_node_id: Optional[str] = None,
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
@@ -239,6 +255,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
"context": context,
|
||||
"workflow_thread_pool_id": workflow_thread_pool_id,
|
||||
"variable_loader": variable_loader,
|
||||
"root_node_id": root_node_id,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -435,6 +452,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
context: contextvars.Context,
|
||||
variable_loader: VariableLoader,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
root_node_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
@@ -478,6 +496,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
variable_loader=variable_loader,
|
||||
workflow=workflow,
|
||||
system_user_id=system_user_id,
|
||||
root_node_id=root_node_id,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -34,6 +34,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
workflow: Workflow,
|
||||
system_user_id: str,
|
||||
root_node_id: Optional[str] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
queue_manager=queue_manager,
|
||||
@@ -44,6 +45,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
self.workflow_thread_pool_id = workflow_thread_pool_id
|
||||
self._workflow = workflow
|
||||
self._sys_user_id = system_user_id
|
||||
self._root_node_id = root_node_id
|
||||
|
||||
def run(self) -> None:
|
||||
"""
|
||||
@@ -93,7 +95,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = self._init_graph(graph_config=self._workflow.graph_dict)
|
||||
graph = self._init_graph(graph_config=self._workflow.graph_dict, root_node_id=self._root_node_id)
|
||||
|
||||
# RUN WORKFLOW
|
||||
workflow_entry = WorkflowEntry(
|
||||
|
||||
@@ -131,6 +131,7 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
|
||||
self._workflow_response_converter = WorkflowResponseConverter(
|
||||
application_generate_entity=application_generate_entity,
|
||||
user=user,
|
||||
)
|
||||
|
||||
self._application_generate_entity = application_generate_entity
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
@@ -79,7 +79,7 @@ class WorkflowBasedAppRunner:
|
||||
self._variable_loader = variable_loader
|
||||
self._app_id = app_id
|
||||
|
||||
def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
|
||||
def _init_graph(self, graph_config: Mapping[str, Any], root_node_id: Optional[str] = None) -> Graph:
|
||||
"""
|
||||
Init graph
|
||||
"""
|
||||
@@ -92,7 +92,7 @@ class WorkflowBasedAppRunner:
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
||||
@@ -118,7 +118,7 @@ class QueueIterationNextEvent(AppQueueEvent):
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
"""iteratoin run in parallel mode run id"""
|
||||
"""iteration run in parallel mode run id"""
|
||||
node_run_index: int
|
||||
output: Optional[Any] = None # output for the current iteration
|
||||
duration: Optional[float] = None
|
||||
@@ -201,7 +201,7 @@ class QueueLoopNextEvent(AppQueueEvent):
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
"""iteratoin run in parallel mode run id"""
|
||||
"""iteration run in parallel mode run id"""
|
||||
node_run_index: int
|
||||
output: Optional[Any] = None # output for the current loop
|
||||
duration: Optional[float] = None
|
||||
@@ -382,7 +382,7 @@ class QueueNodeStartedEvent(AppQueueEvent):
|
||||
"""loop id if node is in loop"""
|
||||
start_at: datetime
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
"""iteratoin run in parallel mode run id"""
|
||||
"""iteration run in parallel mode run id"""
|
||||
agent_strategy: Optional[AgentNodeStrategyInit] = None
|
||||
|
||||
|
||||
|
||||
@@ -472,9 +472,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
:param event: agent thought event
|
||||
:return:
|
||||
"""
|
||||
agent_thought: Optional[MessageAgentThought] = (
|
||||
db.session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
|
||||
)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
agent_thought: Optional[MessageAgentThought] = (
|
||||
session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
|
||||
)
|
||||
|
||||
if agent_thought:
|
||||
return AgentThoughtStreamResponse(
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
class TaskPipilineError(ValueError):
|
||||
class TaskPipelineError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class RecordNotFoundError(TaskPipilineError):
|
||||
class RecordNotFoundError(TaskPipelineError):
|
||||
def __init__(self, record_name: str, record_id: str):
|
||||
super().__init__(f"{record_name} with id {record_id} not found")
|
||||
|
||||
|
||||
@@ -192,7 +192,7 @@ class ProviderConfig(BasicProviderConfig):
|
||||
|
||||
scope: AppSelectorScope | ModelSelectorScope | ToolSelectorScope | None = None
|
||||
required: bool = False
|
||||
default: Optional[Union[int, str, float, bool]] = None
|
||||
default: Optional[Union[int, str, float, bool, list]] = None
|
||||
options: Optional[list[Option]] = None
|
||||
label: Optional[I18nObject] = None
|
||||
help: Optional[I18nObject] = None
|
||||
|
||||
@@ -88,6 +88,7 @@ def to_prompt_message_content(
|
||||
"url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "",
|
||||
"format": f.extension.removeprefix("."),
|
||||
"mime_type": f.mime_type,
|
||||
"filename": f.filename or "",
|
||||
}
|
||||
if f.type == FileType.IMAGE:
|
||||
params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
|
||||
|
||||
@@ -3,7 +3,7 @@ import base64
|
||||
from libs import rsa
|
||||
|
||||
|
||||
def obfuscated_token(token: str):
|
||||
def obfuscated_token(token: str) -> str:
|
||||
if not token:
|
||||
return token
|
||||
if len(token) <= 8:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
@@ -23,7 +23,7 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplaceP
|
||||
return []
|
||||
|
||||
url = str(marketplace_api_url / "api/v1/plugins/batch")
|
||||
response = requests.post(url, json={"plugin_ids": plugin_ids})
|
||||
response = httpx.post(url, json={"plugin_ids": plugin_ids})
|
||||
response.raise_for_status()
|
||||
|
||||
return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]
|
||||
@@ -36,7 +36,7 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
|
||||
return []
|
||||
|
||||
url = str(marketplace_api_url / "api/v1/plugins/batch")
|
||||
response = requests.post(url, json={"plugin_ids": plugin_ids})
|
||||
response = httpx.post(url, json={"plugin_ids": plugin_ids})
|
||||
response.raise_for_status()
|
||||
result: list[MarketplacePluginDeclaration] = []
|
||||
for plugin in response.json()["data"]["plugins"]:
|
||||
@@ -50,5 +50,5 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
|
||||
|
||||
def record_install_plugin_event(plugin_unique_identifier: str):
|
||||
url = str(marketplace_api_url / "api/v1/stats/plugins/install_count")
|
||||
response = requests.post(url, json={"unique_identifier": plugin_unique_identifier})
|
||||
response = httpx.post(url, json={"unique_identifier": plugin_unique_identifier})
|
||||
response.raise_for_status()
|
||||
|
||||
128
api/core/helper/provider_encryption.py
Normal file
128
api/core/helper/provider_encryption.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import contextlib
|
||||
from copy import deepcopy
|
||||
from typing import Any, Optional, Protocol
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.helper import encrypter
|
||||
|
||||
|
||||
class ProviderConfigCache(Protocol):
|
||||
"""
|
||||
Interface for provider configuration cache operations
|
||||
"""
|
||||
|
||||
def get(self) -> Optional[dict]:
|
||||
"""Get cached provider configuration"""
|
||||
...
|
||||
|
||||
def set(self, config: dict[str, Any]) -> None:
|
||||
"""Cache provider configuration"""
|
||||
...
|
||||
|
||||
def delete(self) -> None:
|
||||
"""Delete cached provider configuration"""
|
||||
...
|
||||
|
||||
|
||||
class ProviderConfigEncrypter:
|
||||
tenant_id: str
|
||||
config: list[BasicProviderConfig]
|
||||
provider_config_cache: ProviderConfigCache
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
config: list[BasicProviderConfig],
|
||||
provider_config_cache: ProviderConfigCache,
|
||||
):
|
||||
self.tenant_id = tenant_id
|
||||
self.config = config
|
||||
self.provider_config_cache = provider_config_cache
|
||||
|
||||
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
deep copy data
|
||||
"""
|
||||
return deepcopy(data)
|
||||
|
||||
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
encrypt tool credentials with tenant id
|
||||
|
||||
return a deep copy of credentials with encrypted values
|
||||
"""
|
||||
data = self._deep_copy(data)
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
|
||||
data[field_name] = encrypted
|
||||
|
||||
return data
|
||||
|
||||
def mask_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
mask credentials
|
||||
|
||||
return a deep copy of credentials with masked values
|
||||
"""
|
||||
data = self._deep_copy(data)
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
if len(data[field_name]) > 6:
|
||||
data[field_name] = (
|
||||
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
|
||||
)
|
||||
else:
|
||||
data[field_name] = "*" * len(data[field_name])
|
||||
|
||||
return data
|
||||
|
||||
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
return self.mask_credentials(data)
|
||||
|
||||
def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
|
||||
"""
|
||||
decrypt tool credentials with tenant id
|
||||
|
||||
return a deep copy of credentials with decrypted values
|
||||
"""
|
||||
cached_credentials = self.provider_config_cache.get()
|
||||
if cached_credentials:
|
||||
return cached_credentials
|
||||
|
||||
data = self._deep_copy(data)
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
with contextlib.suppress(Exception):
|
||||
# if the value is None or empty string, skip decrypt
|
||||
if not data[field_name]:
|
||||
continue
|
||||
|
||||
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
|
||||
|
||||
self.provider_config_cache.set(data)
|
||||
return data
|
||||
|
||||
|
||||
def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
|
||||
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
|
||||
@@ -4,224 +4,259 @@ from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.web.passport import generate_session_id
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
|
||||
from core.mcp import types
|
||||
from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND
|
||||
from core.mcp.utils import create_mcp_error_response
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from extensions.ext_database import db
|
||||
from core.mcp import types as mcp_types
|
||||
from models.model import App, AppMCPServer, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPServerStreamableHTTPRequestHandler:
|
||||
def handle_mcp_request(
|
||||
app: App,
|
||||
request: mcp_types.ClientRequest,
|
||||
user_input_form: list[VariableEntity],
|
||||
mcp_server: AppMCPServer,
|
||||
end_user: EndUser | None = None,
|
||||
request_id: int | str = 1,
|
||||
) -> mcp_types.JSONRPCResponse | mcp_types.JSONRPCError:
|
||||
"""
|
||||
Apply to MCP HTTP streamable server with stateless http
|
||||
Handle MCP request and return JSON-RPC response
|
||||
|
||||
Args:
|
||||
app: The Dify app instance
|
||||
request: The JSON-RPC request message
|
||||
user_input_form: List of variable entities for the app
|
||||
mcp_server: The MCP server configuration
|
||||
end_user: Optional end user
|
||||
request_id: The request ID
|
||||
|
||||
Returns:
|
||||
JSON-RPC response or error
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity]
|
||||
):
|
||||
self.app = app
|
||||
self.request = request
|
||||
mcp_server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == self.app.id).first()
|
||||
if not mcp_server:
|
||||
raise ValueError("MCP server not found")
|
||||
self.mcp_server: AppMCPServer = mcp_server
|
||||
self.end_user = self.retrieve_end_user()
|
||||
self.user_input_form = user_input_form
|
||||
request_type = type(request.root)
|
||||
|
||||
@property
|
||||
def request_type(self):
|
||||
return type(self.request.root)
|
||||
def create_success_response(result_data: mcp_types.Result) -> mcp_types.JSONRPCResponse:
|
||||
"""Create success response with business result data"""
|
||||
return mcp_types.JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
result=result_data.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
|
||||
@property
|
||||
def parameter_schema(self):
|
||||
parameters, required = self._convert_input_form_to_parameters(self.user_input_form)
|
||||
if self.app.mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": parameters,
|
||||
"required": required,
|
||||
}
|
||||
def create_error_response(code: int, message: str) -> mcp_types.JSONRPCError:
|
||||
"""Create error response with error code and message"""
|
||||
from core.mcp.types import ErrorData
|
||||
|
||||
error_data = ErrorData(code=code, message=message)
|
||||
return mcp_types.JSONRPCError(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
error=error_data,
|
||||
)
|
||||
|
||||
# Request handler mapping using functional approach
|
||||
request_handlers = {
|
||||
mcp_types.InitializeRequest: lambda: handle_initialize(mcp_server.description),
|
||||
mcp_types.ListToolsRequest: lambda: handle_list_tools(
|
||||
app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict
|
||||
),
|
||||
mcp_types.CallToolRequest: lambda: handle_call_tool(app, request, user_input_form, end_user),
|
||||
mcp_types.PingRequest: lambda: handle_ping(),
|
||||
}
|
||||
|
||||
try:
|
||||
# Dispatch request to appropriate handler
|
||||
handler = request_handlers.get(request_type)
|
||||
if handler:
|
||||
return create_success_response(handler())
|
||||
else:
|
||||
return create_error_response(mcp_types.METHOD_NOT_FOUND, f"Method not found: {request_type.__name__}")
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Invalid params")
|
||||
return create_error_response(mcp_types.INVALID_PARAMS, str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Internal server error")
|
||||
return create_error_response(mcp_types.INTERNAL_ERROR, "Internal server error: " + str(e))
|
||||
|
||||
|
||||
def handle_ping() -> mcp_types.EmptyResult:
|
||||
"""Handle ping request"""
|
||||
return mcp_types.EmptyResult()
|
||||
|
||||
|
||||
def handle_initialize(description: str) -> mcp_types.InitializeResult:
|
||||
"""Handle initialize request"""
|
||||
capabilities = mcp_types.ServerCapabilities(
|
||||
tools=mcp_types.ToolsCapability(listChanged=False),
|
||||
)
|
||||
|
||||
return mcp_types.InitializeResult(
|
||||
protocolVersion=mcp_types.SERVER_LATEST_PROTOCOL_VERSION,
|
||||
capabilities=capabilities,
|
||||
serverInfo=mcp_types.Implementation(name="Dify", version=dify_config.project.version),
|
||||
instructions=description,
|
||||
)
|
||||
|
||||
|
||||
def handle_list_tools(
|
||||
app_name: str,
|
||||
app_mode: str,
|
||||
user_input_form: list[VariableEntity],
|
||||
description: str,
|
||||
parameters_dict: dict[str, str],
|
||||
) -> mcp_types.ListToolsResult:
|
||||
"""Handle list tools request"""
|
||||
parameter_schema = build_parameter_schema(app_mode, user_input_form, parameters_dict)
|
||||
|
||||
return mcp_types.ListToolsResult(
|
||||
tools=[
|
||||
mcp_types.Tool(
|
||||
name=app_name,
|
||||
description=description,
|
||||
inputSchema=parameter_schema,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def handle_call_tool(
|
||||
app: App,
|
||||
request: mcp_types.ClientRequest,
|
||||
user_input_form: list[VariableEntity],
|
||||
end_user: EndUser | None,
|
||||
) -> mcp_types.CallToolResult:
|
||||
"""Handle call tool request"""
|
||||
request_obj = cast(mcp_types.CallToolRequest, request.root)
|
||||
args = prepare_tool_arguments(app, request_obj.params.arguments or {})
|
||||
|
||||
if not end_user:
|
||||
raise ValueError("End user not found")
|
||||
|
||||
response = AppGenerateService.generate(
|
||||
app,
|
||||
end_user,
|
||||
args,
|
||||
InvokeFrom.SERVICE_API,
|
||||
streaming=app.mode == AppMode.AGENT_CHAT.value,
|
||||
)
|
||||
|
||||
answer = extract_answer_from_response(app, response)
|
||||
return mcp_types.CallToolResult(content=[mcp_types.TextContent(text=answer, type="text")])
|
||||
|
||||
|
||||
def build_parameter_schema(
|
||||
app_mode: str,
|
||||
user_input_form: list[VariableEntity],
|
||||
parameters_dict: dict[str, str],
|
||||
) -> dict[str, Any]:
|
||||
"""Build parameter schema for the tool"""
|
||||
parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict)
|
||||
|
||||
if app_mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "User Input/Question content"},
|
||||
**parameters,
|
||||
},
|
||||
"required": ["query", *required],
|
||||
"properties": parameters,
|
||||
"required": required,
|
||||
}
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "User Input/Question content"},
|
||||
**parameters,
|
||||
},
|
||||
"required": ["query", *required],
|
||||
}
|
||||
|
||||
@property
|
||||
def capabilities(self):
|
||||
return types.ServerCapabilities(
|
||||
tools=types.ToolsCapability(listChanged=False),
|
||||
)
|
||||
|
||||
def response(self, response: types.Result | str):
|
||||
if isinstance(response, str):
|
||||
sse_content = f"event: ping\ndata: {response}\n\n".encode()
|
||||
yield sse_content
|
||||
return
|
||||
json_response = types.JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=(self.request.root.model_extra or {}).get("id", 1),
|
||||
result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
json_data = json.dumps(jsonable_encoder(json_response))
|
||||
def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Prepare arguments based on app mode"""
|
||||
if app.mode == AppMode.WORKFLOW.value:
|
||||
return {"inputs": arguments}
|
||||
elif app.mode == AppMode.COMPLETION.value:
|
||||
return {"query": "", "inputs": arguments}
|
||||
else:
|
||||
# Chat modes - create a copy to avoid modifying original dict
|
||||
args_copy = arguments.copy()
|
||||
query = args_copy.pop("query", "")
|
||||
return {"query": query, "inputs": args_copy}
|
||||
|
||||
sse_content = f"event: message\ndata: {json_data}\n\n".encode()
|
||||
|
||||
yield sse_content
|
||||
def extract_answer_from_response(app: App, response: Any) -> str:
|
||||
"""Extract answer from app generate response"""
|
||||
answer = ""
|
||||
|
||||
def error_response(self, code: int, message: str, data=None):
|
||||
request_id = (self.request.root.model_extra or {}).get("id", 1) or 1
|
||||
return create_mcp_error_response(request_id, code, message, data)
|
||||
if isinstance(response, RateLimitGenerator):
|
||||
answer = process_streaming_response(response)
|
||||
elif isinstance(response, Mapping):
|
||||
answer = process_mapping_response(app, response)
|
||||
else:
|
||||
logger.warning("Unexpected response type: %s", type(response))
|
||||
|
||||
def handle(self):
|
||||
handle_map = {
|
||||
types.InitializeRequest: self.initialize,
|
||||
types.ListToolsRequest: self.list_tools,
|
||||
types.CallToolRequest: self.invoke_tool,
|
||||
types.InitializedNotification: self.handle_notification,
|
||||
types.PingRequest: self.handle_ping,
|
||||
}
|
||||
try:
|
||||
if self.request_type in handle_map:
|
||||
return self.response(handle_map[self.request_type]())
|
||||
else:
|
||||
return self.error_response(METHOD_NOT_FOUND, f"Method not found: {self.request_type}")
|
||||
except ValueError as e:
|
||||
logger.exception("Invalid params")
|
||||
return self.error_response(INVALID_PARAMS, str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Internal server error")
|
||||
return self.error_response(INTERNAL_ERROR, f"Internal server error: {str(e)}")
|
||||
return answer
|
||||
|
||||
def handle_notification(self):
|
||||
return "ping"
|
||||
|
||||
def handle_ping(self):
|
||||
return types.EmptyResult()
|
||||
|
||||
def initialize(self):
|
||||
request = cast(types.InitializeRequest, self.request.root)
|
||||
client_info = request.params.clientInfo
|
||||
client_name = f"{client_info.name}@{client_info.version}"
|
||||
if not self.end_user:
|
||||
end_user = EndUser(
|
||||
tenant_id=self.app.tenant_id,
|
||||
app_id=self.app.id,
|
||||
type="mcp",
|
||||
name=client_name,
|
||||
session_id=generate_session_id(),
|
||||
external_user_id=self.mcp_server.id,
|
||||
)
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
return types.InitializeResult(
|
||||
protocolVersion=types.SERVER_LATEST_PROTOCOL_VERSION,
|
||||
capabilities=self.capabilities,
|
||||
serverInfo=types.Implementation(name="Dify", version=dify_config.project.version),
|
||||
instructions=self.mcp_server.description,
|
||||
)
|
||||
|
||||
def list_tools(self):
|
||||
if not self.end_user:
|
||||
raise ValueError("User not found")
|
||||
return types.ListToolsResult(
|
||||
tools=[
|
||||
types.Tool(
|
||||
name=self.app.name,
|
||||
description=self.mcp_server.description,
|
||||
inputSchema=self.parameter_schema,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
def invoke_tool(self):
|
||||
if not self.end_user:
|
||||
raise ValueError("User not found")
|
||||
request = cast(types.CallToolRequest, self.request.root)
|
||||
args = request.params.arguments or {}
|
||||
if self.app.mode in {AppMode.WORKFLOW.value}:
|
||||
args = {"inputs": args}
|
||||
elif self.app.mode in {AppMode.COMPLETION.value}:
|
||||
args = {"query": "", "inputs": args}
|
||||
else:
|
||||
args = {"query": args["query"], "inputs": {k: v for k, v in args.items() if k != "query"}}
|
||||
response = AppGenerateService.generate(
|
||||
self.app,
|
||||
self.end_user,
|
||||
args,
|
||||
InvokeFrom.SERVICE_API,
|
||||
streaming=self.app.mode == AppMode.AGENT_CHAT.value,
|
||||
)
|
||||
answer = ""
|
||||
if isinstance(response, RateLimitGenerator):
|
||||
for item in response.generator:
|
||||
data = item
|
||||
if isinstance(data, str) and data.startswith("data: "):
|
||||
try:
|
||||
json_str = data[6:].strip()
|
||||
parsed_data = json.loads(json_str)
|
||||
if parsed_data.get("event") == "agent_thought":
|
||||
answer += parsed_data.get("thought", "")
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if isinstance(response, Mapping):
|
||||
if self.app.mode in {
|
||||
AppMode.ADVANCED_CHAT.value,
|
||||
AppMode.COMPLETION.value,
|
||||
AppMode.CHAT.value,
|
||||
AppMode.AGENT_CHAT.value,
|
||||
}:
|
||||
answer = response["answer"]
|
||||
elif self.app.mode in {AppMode.WORKFLOW.value}:
|
||||
answer = json.dumps(response["data"]["outputs"], ensure_ascii=False)
|
||||
else:
|
||||
raise ValueError("Invalid app mode")
|
||||
# Not support image yet
|
||||
return types.CallToolResult(content=[types.TextContent(text=answer, type="text")])
|
||||
|
||||
def retrieve_end_user(self):
|
||||
return (
|
||||
db.session.query(EndUser)
|
||||
.where(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
|
||||
.first()
|
||||
)
|
||||
|
||||
def _convert_input_form_to_parameters(self, user_input_form: list[VariableEntity]):
|
||||
parameters: dict[str, dict[str, Any]] = {}
|
||||
required = []
|
||||
for item in user_input_form:
|
||||
parameters[item.variable] = {}
|
||||
if item.type in (
|
||||
VariableEntityType.FILE,
|
||||
VariableEntityType.FILE_LIST,
|
||||
VariableEntityType.EXTERNAL_DATA_TOOL,
|
||||
):
|
||||
continue
|
||||
if item.required:
|
||||
required.append(item.variable)
|
||||
# if the workflow republished, the parameters not changed
|
||||
# we should not raise error here
|
||||
def process_streaming_response(response: RateLimitGenerator) -> str:
|
||||
"""Process streaming response for agent chat mode"""
|
||||
answer = ""
|
||||
for item in response.generator:
|
||||
if isinstance(item, str) and item.startswith("data: "):
|
||||
try:
|
||||
description = self.mcp_server.parameters_dict[item.variable]
|
||||
except KeyError:
|
||||
description = ""
|
||||
parameters[item.variable]["description"] = description
|
||||
if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
|
||||
parameters[item.variable]["type"] = "string"
|
||||
elif item.type == VariableEntityType.SELECT:
|
||||
parameters[item.variable]["type"] = "string"
|
||||
parameters[item.variable]["enum"] = item.options
|
||||
elif item.type == VariableEntityType.NUMBER:
|
||||
parameters[item.variable]["type"] = "float"
|
||||
return parameters, required
|
||||
json_str = item[6:].strip()
|
||||
parsed_data = json.loads(json_str)
|
||||
if parsed_data.get("event") == "agent_thought":
|
||||
answer += parsed_data.get("thought", "")
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
return answer
|
||||
|
||||
|
||||
def process_mapping_response(app: App, response: Mapping) -> str:
|
||||
"""Process mapping response based on app mode"""
|
||||
if app.mode in {
|
||||
AppMode.ADVANCED_CHAT.value,
|
||||
AppMode.COMPLETION.value,
|
||||
AppMode.CHAT.value,
|
||||
AppMode.AGENT_CHAT.value,
|
||||
}:
|
||||
return response.get("answer", "")
|
||||
elif app.mode == AppMode.WORKFLOW.value:
|
||||
return json.dumps(response["data"]["outputs"], ensure_ascii=False)
|
||||
else:
|
||||
raise ValueError("Invalid app mode: " + str(app.mode))
|
||||
|
||||
|
||||
def convert_input_form_to_parameters(
|
||||
user_input_form: list[VariableEntity],
|
||||
parameters_dict: dict[str, str],
|
||||
) -> tuple[dict[str, dict[str, Any]], list[str]]:
|
||||
"""Convert user input form to parameter schema"""
|
||||
parameters: dict[str, dict[str, Any]] = {}
|
||||
required = []
|
||||
|
||||
for item in user_input_form:
|
||||
if item.type in (
|
||||
VariableEntityType.FILE,
|
||||
VariableEntityType.FILE_LIST,
|
||||
VariableEntityType.EXTERNAL_DATA_TOOL,
|
||||
):
|
||||
continue
|
||||
parameters[item.variable] = {}
|
||||
if item.required:
|
||||
required.append(item.variable)
|
||||
# if the workflow republished, the parameters not changed
|
||||
# we should not raise error here
|
||||
description = parameters_dict.get(item.variable, "")
|
||||
parameters[item.variable]["description"] = description
|
||||
if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
|
||||
parameters[item.variable]["type"] = "string"
|
||||
elif item.type == VariableEntityType.SELECT:
|
||||
parameters[item.variable]["type"] = "string"
|
||||
parameters[item.variable]["enum"] = item.options
|
||||
elif item.type == VariableEntityType.NUMBER:
|
||||
parameters[item.variable]["type"] = "float"
|
||||
return parameters, required
|
||||
|
||||
@@ -138,5 +138,5 @@ def create_mcp_error_response(
|
||||
error=error_data,
|
||||
)
|
||||
json_data = json.dumps(jsonable_encoder(json_response))
|
||||
sse_content = f"event: message\ndata: {json_data}\n\n".encode()
|
||||
sse_content = json_data.encode()
|
||||
yield sse_content
|
||||
|
||||
@@ -31,6 +31,65 @@ class TokenBufferMemory:
|
||||
self.conversation = conversation
|
||||
self.model_instance = model_instance
|
||||
|
||||
def _build_prompt_message_with_files(
|
||||
self, message_files: list[MessageFile], text_content: str, message: Message, app_record, is_user_message: bool
|
||||
) -> PromptMessage:
|
||||
"""
|
||||
Build prompt message with files.
|
||||
:param message_files: list of MessageFile objects
|
||||
:param text_content: text content of the message
|
||||
:param message: Message object
|
||||
:param app_record: app record
|
||||
:param is_user_message: whether this is a user message
|
||||
:return: PromptMessage
|
||||
"""
|
||||
if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
|
||||
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
|
||||
elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow_run = db.session.scalar(select(WorkflowRun).where(WorkflowRun.id == message.workflow_run_id))
|
||||
if not workflow_run:
|
||||
raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
|
||||
workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
|
||||
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
||||
else:
|
||||
raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
|
||||
|
||||
detail = ImagePromptMessageContent.DETAIL.HIGH
|
||||
if file_extra_config and app_record:
|
||||
# Build files directly without filtering by belongs_to
|
||||
file_objs = [
|
||||
file_factory.build_from_message_file(
|
||||
message_file=message_file, tenant_id=app_record.tenant_id, config=file_extra_config
|
||||
)
|
||||
for message_file in message_files
|
||||
]
|
||||
if file_extra_config.image_config and file_extra_config.image_config.detail:
|
||||
detail = file_extra_config.image_config.detail
|
||||
else:
|
||||
file_objs = []
|
||||
|
||||
if not file_objs:
|
||||
if is_user_message:
|
||||
return UserPromptMessage(content=text_content)
|
||||
else:
|
||||
return AssistantPromptMessage(content=text_content)
|
||||
else:
|
||||
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||
for file in file_objs:
|
||||
prompt_message = file_manager.to_prompt_message_content(
|
||||
file,
|
||||
image_detail_config=detail,
|
||||
)
|
||||
prompt_message_contents.append(prompt_message)
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=text_content))
|
||||
|
||||
if is_user_message:
|
||||
return UserPromptMessage(content=prompt_message_contents)
|
||||
else:
|
||||
return AssistantPromptMessage(content=prompt_message_contents)
|
||||
|
||||
def get_history_prompt_messages(
|
||||
self, max_token_limit: int = 2000, message_limit: Optional[int] = None
|
||||
) -> Sequence[PromptMessage]:
|
||||
@@ -67,52 +126,46 @@ class TokenBufferMemory:
|
||||
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
for message in messages:
|
||||
files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all()
|
||||
if files:
|
||||
file_extra_config = None
|
||||
if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
|
||||
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
|
||||
elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow_run = db.session.scalar(
|
||||
select(WorkflowRun).where(WorkflowRun.id == message.workflow_run_id)
|
||||
)
|
||||
if not workflow_run:
|
||||
raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
|
||||
workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
|
||||
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
||||
else:
|
||||
raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
|
||||
|
||||
detail = ImagePromptMessageContent.DETAIL.LOW
|
||||
if file_extra_config and app_record:
|
||||
file_objs = file_factory.build_from_message_files(
|
||||
message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config
|
||||
)
|
||||
if file_extra_config.image_config and file_extra_config.image_config.detail:
|
||||
detail = file_extra_config.image_config.detail
|
||||
else:
|
||||
file_objs = []
|
||||
|
||||
if not file_objs:
|
||||
prompt_messages.append(UserPromptMessage(content=message.query))
|
||||
else:
|
||||
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||
for file in file_objs:
|
||||
prompt_message = file_manager.to_prompt_message_content(
|
||||
file,
|
||||
image_detail_config=detail,
|
||||
)
|
||||
prompt_message_contents.append(prompt_message)
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
# Process user message with files
|
||||
user_files = (
|
||||
db.session.query(MessageFile)
|
||||
.where(
|
||||
MessageFile.message_id == message.id,
|
||||
(MessageFile.belongs_to == "user") | (MessageFile.belongs_to.is_(None)),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
if user_files:
|
||||
user_prompt_message = self._build_prompt_message_with_files(
|
||||
message_files=user_files,
|
||||
text_content=message.query,
|
||||
message=message,
|
||||
app_record=app_record,
|
||||
is_user_message=True,
|
||||
)
|
||||
prompt_messages.append(user_prompt_message)
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=message.query))
|
||||
|
||||
prompt_messages.append(AssistantPromptMessage(content=message.answer))
|
||||
# Process assistant message with files
|
||||
assistant_files = (
|
||||
db.session.query(MessageFile)
|
||||
.where(MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant")
|
||||
.all()
|
||||
)
|
||||
|
||||
if assistant_files:
|
||||
assistant_prompt_message = self._build_prompt_message_with_files(
|
||||
message_files=assistant_files,
|
||||
text_content=message.answer,
|
||||
message=message,
|
||||
app_record=app_record,
|
||||
is_user_message=False,
|
||||
)
|
||||
prompt_messages.append(assistant_prompt_message)
|
||||
else:
|
||||
prompt_messages.append(AssistantPromptMessage(content=message.answer))
|
||||
|
||||
if not prompt_messages:
|
||||
return []
|
||||
|
||||
@@ -158,8 +158,6 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, LargeLanguageModel):
|
||||
raise Exception("Model type instance is not LargeLanguageModel")
|
||||
|
||||
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
|
||||
return cast(
|
||||
Union[LLMResult, Generator],
|
||||
self._round_robin_invoke(
|
||||
@@ -188,8 +186,6 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, LargeLanguageModel):
|
||||
raise Exception("Model type instance is not LargeLanguageModel")
|
||||
|
||||
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
|
||||
return cast(
|
||||
int,
|
||||
self._round_robin_invoke(
|
||||
@@ -214,8 +210,6 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
||||
raise Exception("Model type instance is not TextEmbeddingModel")
|
||||
|
||||
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
|
||||
return cast(
|
||||
TextEmbeddingResult,
|
||||
self._round_robin_invoke(
|
||||
@@ -237,8 +231,6 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
||||
raise Exception("Model type instance is not TextEmbeddingModel")
|
||||
|
||||
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
|
||||
return cast(
|
||||
list[int],
|
||||
self._round_robin_invoke(
|
||||
@@ -269,8 +261,6 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, RerankModel):
|
||||
raise Exception("Model type instance is not RerankModel")
|
||||
|
||||
self.model_type_instance = cast(RerankModel, self.model_type_instance)
|
||||
return cast(
|
||||
RerankResult,
|
||||
self._round_robin_invoke(
|
||||
@@ -295,8 +285,6 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, ModerationModel):
|
||||
raise Exception("Model type instance is not ModerationModel")
|
||||
|
||||
self.model_type_instance = cast(ModerationModel, self.model_type_instance)
|
||||
return cast(
|
||||
bool,
|
||||
self._round_robin_invoke(
|
||||
@@ -318,8 +306,6 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, Speech2TextModel):
|
||||
raise Exception("Model type instance is not Speech2TextModel")
|
||||
|
||||
self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
|
||||
return cast(
|
||||
str,
|
||||
self._round_robin_invoke(
|
||||
@@ -343,8 +329,6 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, TTSModel):
|
||||
raise Exception("Model type instance is not TTSModel")
|
||||
|
||||
self.model_type_instance = cast(TTSModel, self.model_type_instance)
|
||||
return cast(
|
||||
Iterable[bytes],
|
||||
self._round_robin_invoke(
|
||||
@@ -404,8 +388,6 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, TTSModel):
|
||||
raise Exception("Model type instance is not TTSModel")
|
||||
|
||||
self.model_type_instance = cast(TTSModel, self.model_type_instance)
|
||||
return self.model_type_instance.get_tts_model_voices(
|
||||
model=self.model, credentials=self.credentials, language=language
|
||||
)
|
||||
|
||||
@@ -87,6 +87,7 @@ class MultiModalPromptMessageContent(PromptMessageContent):
|
||||
base64_data: str = Field(default="", description="the base64 data of multi-modal file")
|
||||
url: str = Field(default="", description="the url of multi-modal file")
|
||||
mime_type: str = Field(default=..., description="the mime type of multi-modal file")
|
||||
filename: str = Field(default="", description="the filename of multi-modal file")
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
|
||||
@@ -43,7 +43,7 @@ class GPT2Tokenizer:
|
||||
except Exception:
|
||||
from os.path import abspath, dirname, join
|
||||
|
||||
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore
|
||||
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer
|
||||
|
||||
base_path = abspath(__file__)
|
||||
gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
|
||||
|
||||
@@ -330,7 +330,7 @@ class OpsTraceManager:
|
||||
except KeyError:
|
||||
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
|
||||
else:
|
||||
if tracing_provider is not None:
|
||||
if tracing_provider is None:
|
||||
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
|
||||
|
||||
app_config: Optional[App] = db.session.query(App).where(App.id == app_id).first()
|
||||
|
||||
@@ -375,16 +375,16 @@ Here is the extra instruction you need to follow:
|
||||
|
||||
# merge lines into messages with max tokens
|
||||
messages: list[str] = []
|
||||
for i in new_lines: # type: ignore
|
||||
for line in new_lines:
|
||||
if len(messages) == 0:
|
||||
messages.append(i) # type: ignore
|
||||
messages.append(line)
|
||||
else:
|
||||
if len(messages[-1]) + len(i) < max_tokens * 0.5: # type: ignore
|
||||
messages[-1] += i # type: ignore
|
||||
if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7: # type: ignore
|
||||
messages.append(i) # type: ignore
|
||||
if len(messages[-1]) + len(line) < max_tokens * 0.5:
|
||||
messages[-1] += line
|
||||
if get_prompt_tokens(messages[-1] + line) > max_tokens * 0.7:
|
||||
messages.append(line)
|
||||
else:
|
||||
messages[-1] += i # type: ignore
|
||||
messages[-1] += line
|
||||
|
||||
summaries = []
|
||||
for i in range(len(messages)):
|
||||
|
||||
@@ -13,6 +13,7 @@ from core.plugin.entities.base import BasePluginEntity
|
||||
from core.plugin.entities.endpoint import EndpointProviderDeclaration
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderEntity
|
||||
from core.trigger.entities.entities import TriggerProviderEntity
|
||||
|
||||
|
||||
class PluginInstallationSource(enum.StrEnum):
|
||||
@@ -62,6 +63,7 @@ class PluginCategory(enum.StrEnum):
|
||||
Model = "model"
|
||||
Extension = "extension"
|
||||
AgentStrategy = "agent-strategy"
|
||||
Trigger = "trigger"
|
||||
|
||||
|
||||
class PluginDeclaration(BaseModel):
|
||||
@@ -69,6 +71,7 @@ class PluginDeclaration(BaseModel):
|
||||
tools: Optional[list[str]] = Field(default_factory=list[str])
|
||||
models: Optional[list[str]] = Field(default_factory=list[str])
|
||||
endpoints: Optional[list[str]] = Field(default_factory=list[str])
|
||||
triggers: Optional[list[str]] = Field(default_factory=list[str])
|
||||
|
||||
class Meta(BaseModel):
|
||||
minimum_dify_version: Optional[str] = Field(default=None, pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$")
|
||||
@@ -89,6 +92,7 @@ class PluginDeclaration(BaseModel):
|
||||
repo: Optional[str] = Field(default=None)
|
||||
verified: bool = Field(default=False)
|
||||
tool: Optional[ToolProviderEntity] = None
|
||||
trigger: Optional[TriggerProviderEntity] = None
|
||||
model: Optional[ProviderEntity] = None
|
||||
endpoint: Optional[EndpointProviderDeclaration] = None
|
||||
agent_strategy: Optional[AgentStrategyProviderEntity] = None
|
||||
@@ -104,6 +108,8 @@ class PluginDeclaration(BaseModel):
|
||||
values["category"] = PluginCategory.Model
|
||||
elif values.get("agent_strategy"):
|
||||
values["category"] = PluginCategory.AgentStrategy
|
||||
elif values.get("trigger"):
|
||||
values["category"] = PluginCategory.Trigger
|
||||
else:
|
||||
values["category"] = PluginCategory.Extension
|
||||
return values
|
||||
@@ -184,6 +190,10 @@ class ToolProviderID(GenericProviderID):
|
||||
self.plugin_name = f"{self.provider_name}_tool"
|
||||
|
||||
|
||||
class TriggerProviderID(GenericProviderID):
|
||||
pass
|
||||
|
||||
|
||||
class PluginDependency(BaseModel):
|
||||
class Type(enum.StrEnum):
|
||||
Github = PluginInstallationSource.Github.value
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import enum
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
@@ -13,6 +14,7 @@ from core.plugin.entities.parameters import PluginParameterOption
|
||||
from core.plugin.entities.plugin import PluginDeclaration, PluginEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin
|
||||
from core.trigger.entities.entities import TriggerProviderEntity
|
||||
|
||||
T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
|
||||
|
||||
@@ -196,3 +198,48 @@ class PluginListResponse(BaseModel):
|
||||
|
||||
class PluginDynamicSelectOptionsResponse(BaseModel):
|
||||
options: Sequence[PluginParameterOption] = Field(description="The options of the dynamic select.")
|
||||
|
||||
|
||||
class PluginTriggerProviderEntity(BaseModel):
|
||||
provider: str
|
||||
plugin_unique_identifier: str
|
||||
plugin_id: str
|
||||
declaration: TriggerProviderEntity
|
||||
|
||||
|
||||
class CredentialType(enum.StrEnum):
|
||||
API_KEY = "api-key"
|
||||
OAUTH2 = "oauth2"
|
||||
UNAUTHORIZED = "unauthorized"
|
||||
|
||||
def get_name(self):
|
||||
if self == CredentialType.API_KEY:
|
||||
return "API KEY"
|
||||
elif self == CredentialType.OAUTH2:
|
||||
return "AUTH"
|
||||
elif self == CredentialType.UNAUTHORIZED:
|
||||
return "UNAUTHORIZED"
|
||||
else:
|
||||
return self.value.replace("-", " ").upper()
|
||||
|
||||
def is_editable(self):
|
||||
return self == CredentialType.API_KEY
|
||||
|
||||
def is_validate_allowed(self):
|
||||
return self == CredentialType.API_KEY
|
||||
|
||||
@classmethod
|
||||
def values(cls):
|
||||
return [item.value for item in cls]
|
||||
|
||||
@classmethod
|
||||
def of(cls, credential_type: str) -> "CredentialType":
|
||||
type_name = credential_type.lower()
|
||||
if type_name == "api-key":
|
||||
return cls.API_KEY
|
||||
elif type_name == "oauth2":
|
||||
return cls.OAUTH2
|
||||
elif type_name == "unauthorized":
|
||||
return cls.UNAUTHORIZED
|
||||
else:
|
||||
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from flask import Response
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
@@ -237,3 +239,33 @@ class RequestFetchAppInfo(BaseModel):
|
||||
"""
|
||||
|
||||
app_id: str
|
||||
|
||||
|
||||
class Event(BaseModel):
|
||||
variables: Mapping[str, Any]
|
||||
|
||||
|
||||
class TriggerInvokeResponse(BaseModel):
|
||||
event: Event
|
||||
|
||||
|
||||
class PluginTriggerDispatchResponse(BaseModel):
|
||||
triggers: list[str]
|
||||
raw_http_response: str
|
||||
|
||||
|
||||
class TriggerSubscriptionResponse(BaseModel):
|
||||
subscription: dict[str, Any]
|
||||
|
||||
|
||||
class TriggerValidateProviderCredentialsResponse(BaseModel):
|
||||
result: bool
|
||||
|
||||
|
||||
class TriggerDispatchResponse:
|
||||
triggers: list[str]
|
||||
response: Response
|
||||
|
||||
def __init__(self, triggers: list[str], response: Response):
|
||||
self.triggers = triggers
|
||||
self.response = response
|
||||
|
||||
@@ -15,6 +15,7 @@ class DynamicSelectClient(BasePluginClient):
|
||||
provider: str,
|
||||
action: str,
|
||||
credentials: Mapping[str, Any],
|
||||
credential_type: str,
|
||||
parameter: str,
|
||||
) -> PluginDynamicSelectOptionsResponse:
|
||||
"""
|
||||
@@ -29,6 +30,7 @@ class DynamicSelectClient(BasePluginClient):
|
||||
"data": {
|
||||
"provider": GenericProviderID(provider).provider_name,
|
||||
"credentials": credentials,
|
||||
"credential_type": credential_type,
|
||||
"provider_action": action,
|
||||
"parameter": parameter,
|
||||
},
|
||||
|
||||
@@ -4,10 +4,10 @@ from typing import Any, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
|
||||
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
|
||||
from core.plugin.entities.plugin_daemon import CredentialType, PluginBasicBooleanResponse, PluginToolProviderEntity
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from core.plugin.utils.chunk_merger import merge_blob_chunks
|
||||
from core.tools.entities.tool_entities import CredentialType, ToolInvokeMessage, ToolParameter
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
|
||||
|
||||
class PluginToolManager(BasePluginClient):
|
||||
|
||||
301
api/core/plugin/impl/trigger.py
Normal file
301
api/core/plugin/impl/trigger.py
Normal file
@@ -0,0 +1,301 @@
|
||||
import binascii
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from flask import Request
|
||||
|
||||
from core.plugin.entities.plugin import GenericProviderID, TriggerProviderID
|
||||
from core.plugin.entities.plugin_daemon import CredentialType, PluginTriggerProviderEntity
|
||||
from core.plugin.entities.request import (
|
||||
PluginTriggerDispatchResponse,
|
||||
TriggerDispatchResponse,
|
||||
TriggerInvokeResponse,
|
||||
TriggerSubscriptionResponse,
|
||||
TriggerValidateProviderCredentialsResponse,
|
||||
)
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from core.plugin.utils.http_parser import deserialize_response, serialize_request
|
||||
from core.trigger.entities.entities import Subscription
|
||||
|
||||
|
||||
class PluginTriggerManager(BasePluginClient):
|
||||
def fetch_trigger_providers(self, tenant_id: str) -> list[PluginTriggerProviderEntity]:
|
||||
"""
|
||||
Fetch trigger providers for the given tenant.
|
||||
"""
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict:
|
||||
for provider in json_response.get("data", []):
|
||||
declaration = provider.get("declaration", {}) or {}
|
||||
provider_id = provider.get("plugin_id") + "/" + provider.get("provider")
|
||||
for trigger in declaration.get("triggers", []):
|
||||
trigger["identity"]["provider"] = provider_id
|
||||
|
||||
return json_response
|
||||
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/triggers",
|
||||
list[PluginTriggerProviderEntity],
|
||||
params={"page": 1, "page_size": 256},
|
||||
transformer=transformer,
|
||||
)
|
||||
|
||||
for provider in response:
|
||||
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
|
||||
|
||||
# override the provider name for each trigger to plugin_id/provider_name
|
||||
for trigger in provider.declaration.triggers:
|
||||
trigger.identity.provider = provider.declaration.identity.name
|
||||
|
||||
return response
|
||||
|
||||
def fetch_trigger_provider(self, tenant_id: str, provider_id: TriggerProviderID) -> PluginTriggerProviderEntity:
|
||||
"""
|
||||
Fetch trigger provider for the given tenant and plugin.
|
||||
"""
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict:
|
||||
data = json_response.get("data")
|
||||
if data:
|
||||
for trigger in data.get("declaration", {}).get("triggers", []):
|
||||
trigger["identity"]["provider"] = str(provider_id)
|
||||
|
||||
return json_response
|
||||
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/trigger",
|
||||
PluginTriggerProviderEntity,
|
||||
params={"provider": provider_id.provider_name, "plugin_id": provider_id.plugin_id},
|
||||
transformer=transformer,
|
||||
)
|
||||
|
||||
response.declaration.identity.name = str(provider_id)
|
||||
|
||||
# override the provider name for each trigger to plugin_id/provider_name
|
||||
for trigger in response.declaration.triggers:
|
||||
trigger.identity.provider = str(provider_id)
|
||||
|
||||
return response
|
||||
|
||||
def invoke_trigger(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider: str,
|
||||
trigger: str,
|
||||
credentials: Mapping[str, str],
|
||||
credential_type: CredentialType,
|
||||
request: Request,
|
||||
parameters: Mapping[str, Any],
|
||||
) -> TriggerInvokeResponse:
|
||||
"""
|
||||
Invoke a trigger with the given parameters.
|
||||
"""
|
||||
trigger_provider_id = GenericProviderID(provider)
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/trigger/invoke",
|
||||
TriggerInvokeResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": trigger_provider_id.provider_name,
|
||||
"trigger": trigger,
|
||||
"credentials": credentials,
|
||||
"credential_type": credential_type,
|
||||
"raw_http_request": binascii.hexlify(serialize_request(request)).decode(),
|
||||
"parameters": parameters,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": trigger_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return TriggerInvokeResponse(event=resp.event)
|
||||
|
||||
raise ValueError("No response received from plugin daemon for invoke trigger")
|
||||
|
||||
def validate_provider_credentials(
|
||||
self, tenant_id: str, user_id: str, provider: str, credentials: Mapping[str, str]
|
||||
) -> bool:
|
||||
"""
|
||||
Validate the credentials of the trigger provider.
|
||||
"""
|
||||
trigger_provider_id = GenericProviderID(provider)
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/trigger/validate_credentials",
|
||||
TriggerValidateProviderCredentialsResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": trigger_provider_id.provider_name,
|
||||
"credentials": credentials,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": trigger_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp.result
|
||||
|
||||
raise ValueError("No response received from plugin daemon for validate provider credentials")
|
||||
|
||||
def dispatch_event(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider: str,
|
||||
subscription: Mapping[str, Any],
|
||||
request: Request,
|
||||
) -> TriggerDispatchResponse:
|
||||
"""
|
||||
Dispatch an event to triggers.
|
||||
"""
|
||||
trigger_provider_id = GenericProviderID(provider)
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/trigger/dispatch_event",
|
||||
PluginTriggerDispatchResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": trigger_provider_id.provider_name,
|
||||
"subscription": subscription,
|
||||
"raw_http_request": binascii.hexlify(serialize_request(request)).decode(),
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": trigger_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return TriggerDispatchResponse(
|
||||
triggers=resp.triggers,
|
||||
response=deserialize_response(binascii.unhexlify(resp.raw_http_response.encode())),
|
||||
)
|
||||
|
||||
raise ValueError("No response received from plugin daemon for dispatch event")
|
||||
|
||||
def subscribe(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider: str,
|
||||
credentials: Mapping[str, str],
|
||||
endpoint: str,
|
||||
parameters: Mapping[str, Any],
|
||||
) -> TriggerSubscriptionResponse:
|
||||
"""
|
||||
Subscribe to a trigger.
|
||||
"""
|
||||
trigger_provider_id = GenericProviderID(provider)
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/trigger/subscribe",
|
||||
TriggerSubscriptionResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": trigger_provider_id.provider_name,
|
||||
"credentials": credentials,
|
||||
"endpoint": endpoint,
|
||||
"parameters": parameters,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": trigger_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp
|
||||
|
||||
raise ValueError("No response received from plugin daemon for subscribe")
|
||||
|
||||
def unsubscribe(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider: str,
|
||||
subscription: Subscription,
|
||||
credentials: Mapping[str, str],
|
||||
) -> TriggerSubscriptionResponse:
|
||||
"""
|
||||
Unsubscribe from a trigger.
|
||||
"""
|
||||
trigger_provider_id = GenericProviderID(provider)
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/trigger/unsubscribe",
|
||||
TriggerSubscriptionResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": trigger_provider_id.provider_name,
|
||||
"subscription": subscription.model_dump(),
|
||||
"credentials": credentials,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": trigger_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp
|
||||
|
||||
raise ValueError("No response received from plugin daemon for unsubscribe")
|
||||
|
||||
def refresh(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider: str,
|
||||
subscription: Subscription,
|
||||
credentials: Mapping[str, str],
|
||||
) -> TriggerSubscriptionResponse:
|
||||
"""
|
||||
Refresh a trigger subscription.
|
||||
"""
|
||||
trigger_provider_id = GenericProviderID(provider)
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/trigger/refresh",
|
||||
TriggerSubscriptionResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": trigger_provider_id.provider_name,
|
||||
"subscription": subscription.model_dump(),
|
||||
"credentials": credentials,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": trigger_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp
|
||||
|
||||
raise ValueError("No response received from plugin daemon for refresh")
|
||||
159
api/core/plugin/utils/http_parser.py
Normal file
159
api/core/plugin/utils/http_parser.py
Normal file
@@ -0,0 +1,159 @@
|
||||
from io import BytesIO
|
||||
|
||||
from flask import Request, Response
|
||||
from werkzeug.datastructures import Headers
|
||||
|
||||
|
||||
def serialize_request(request: Request) -> bytes:
|
||||
method = request.method
|
||||
path = request.full_path.rstrip("?")
|
||||
raw = f"{method} {path} HTTP/1.1\r\n".encode()
|
||||
|
||||
for name, value in request.headers.items():
|
||||
raw += f"{name}: {value}\r\n".encode()
|
||||
|
||||
raw += b"\r\n"
|
||||
|
||||
body = request.get_data(as_text=False)
|
||||
if body:
|
||||
raw += body
|
||||
|
||||
return raw
|
||||
|
||||
|
||||
def deserialize_request(raw_data: bytes) -> Request:
|
||||
header_end = raw_data.find(b"\r\n\r\n")
|
||||
if header_end == -1:
|
||||
header_end = raw_data.find(b"\n\n")
|
||||
if header_end == -1:
|
||||
header_data = raw_data
|
||||
body = b""
|
||||
else:
|
||||
header_data = raw_data[:header_end]
|
||||
body = raw_data[header_end + 2 :]
|
||||
else:
|
||||
header_data = raw_data[:header_end]
|
||||
body = raw_data[header_end + 4 :]
|
||||
|
||||
lines = header_data.split(b"\r\n")
|
||||
if len(lines) == 1 and b"\n" in lines[0]:
|
||||
lines = header_data.split(b"\n")
|
||||
|
||||
if not lines or not lines[0]:
|
||||
raise ValueError("Empty HTTP request")
|
||||
|
||||
request_line = lines[0].decode("utf-8", errors="ignore")
|
||||
parts = request_line.split(" ", 2)
|
||||
if len(parts) < 2:
|
||||
raise ValueError(f"Invalid request line: {request_line}")
|
||||
|
||||
method = parts[0]
|
||||
full_path = parts[1]
|
||||
protocol = parts[2] if len(parts) > 2 else "HTTP/1.1"
|
||||
|
||||
if "?" in full_path:
|
||||
path, query_string = full_path.split("?", 1)
|
||||
else:
|
||||
path = full_path
|
||||
query_string = ""
|
||||
|
||||
headers = Headers()
|
||||
for line in lines[1:]:
|
||||
if not line:
|
||||
continue
|
||||
line_str = line.decode("utf-8", errors="ignore")
|
||||
if ":" not in line_str:
|
||||
continue
|
||||
name, value = line_str.split(":", 1)
|
||||
headers.add(name, value.strip())
|
||||
|
||||
host = headers.get("Host", "localhost")
|
||||
if ":" in host:
|
||||
server_name, server_port = host.rsplit(":", 1)
|
||||
else:
|
||||
server_name = host
|
||||
server_port = "80"
|
||||
|
||||
environ = {
|
||||
"REQUEST_METHOD": method,
|
||||
"PATH_INFO": path,
|
||||
"QUERY_STRING": query_string,
|
||||
"SERVER_NAME": server_name,
|
||||
"SERVER_PORT": server_port,
|
||||
"SERVER_PROTOCOL": protocol,
|
||||
"wsgi.input": BytesIO(body),
|
||||
"wsgi.url_scheme": "http",
|
||||
}
|
||||
|
||||
if "Content-Type" in headers:
|
||||
environ["CONTENT_TYPE"] = headers.get("Content-Type")
|
||||
|
||||
if "Content-Length" in headers:
|
||||
environ["CONTENT_LENGTH"] = headers.get("Content-Length")
|
||||
elif body:
|
||||
environ["CONTENT_LENGTH"] = str(len(body))
|
||||
|
||||
for name, value in headers.items():
|
||||
if name.upper() in ("CONTENT-TYPE", "CONTENT-LENGTH"):
|
||||
continue
|
||||
env_name = f"HTTP_{name.upper().replace('-', '_')}"
|
||||
environ[env_name] = value
|
||||
|
||||
return Request(environ)
|
||||
|
||||
|
||||
def serialize_response(response: Response) -> bytes:
|
||||
raw = f"HTTP/1.1 {response.status}\r\n".encode()
|
||||
|
||||
for name, value in response.headers.items():
|
||||
raw += f"{name}: {value}\r\n".encode()
|
||||
|
||||
raw += b"\r\n"
|
||||
|
||||
body = response.get_data(as_text=False)
|
||||
if body:
|
||||
raw += body
|
||||
|
||||
return raw
|
||||
|
||||
|
||||
def deserialize_response(raw_data: bytes) -> Response:
|
||||
header_end = raw_data.find(b"\r\n\r\n")
|
||||
if header_end == -1:
|
||||
header_end = raw_data.find(b"\n\n")
|
||||
if header_end == -1:
|
||||
header_data = raw_data
|
||||
body = b""
|
||||
else:
|
||||
header_data = raw_data[:header_end]
|
||||
body = raw_data[header_end + 2 :]
|
||||
else:
|
||||
header_data = raw_data[:header_end]
|
||||
body = raw_data[header_end + 4 :]
|
||||
|
||||
lines = header_data.split(b"\r\n")
|
||||
if len(lines) == 1 and b"\n" in lines[0]:
|
||||
lines = header_data.split(b"\n")
|
||||
|
||||
if not lines or not lines[0]:
|
||||
raise ValueError("Empty HTTP response")
|
||||
|
||||
status_line = lines[0].decode("utf-8", errors="ignore")
|
||||
parts = status_line.split(" ", 2)
|
||||
if len(parts) < 2:
|
||||
raise ValueError(f"Invalid status line: {status_line}")
|
||||
|
||||
status_code = int(parts[1])
|
||||
|
||||
response = Response(response=body, status=status_code)
|
||||
|
||||
for line in lines[1:]:
|
||||
if not line:
|
||||
continue
|
||||
line_str = line.decode("utf-8", errors="ignore")
|
||||
if ":" not in line_str:
|
||||
continue
|
||||
name, value = line_str.split(":", 1)
|
||||
response.headers[name] = value.strip()
|
||||
|
||||
return response
|
||||
@@ -87,7 +87,6 @@ class PromptMessageUtil:
|
||||
if isinstance(prompt_message.content, list):
|
||||
for content in prompt_message.content:
|
||||
if content.type == PromptMessageContentType.TEXT:
|
||||
content = cast(TextPromptMessageContent, content)
|
||||
text += content.data
|
||||
else:
|
||||
content = cast(ImagePromptMessageContent, content)
|
||||
|
||||
@@ -2,7 +2,7 @@ import contextlib
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
@@ -154,8 +154,8 @@ class ProviderManager:
|
||||
for provider_entity in provider_entities:
|
||||
# handle include, exclude
|
||||
if is_filtered(
|
||||
include_set=cast(set[str], dify_config.POSITION_PROVIDER_INCLUDES_SET),
|
||||
exclude_set=cast(set[str], dify_config.POSITION_PROVIDER_EXCLUDES_SET),
|
||||
include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
|
||||
exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
|
||||
data=provider_entity,
|
||||
name_func=lambda x: x.provider,
|
||||
):
|
||||
|
||||
@@ -24,7 +24,7 @@ default_retrieval_model = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 2,
|
||||
"top_k": 4,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ import uuid
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
import psycopg2.extras # type: ignore
|
||||
import psycopg2.pool # type: ignore
|
||||
import psycopg2.extras
|
||||
import psycopg2.pool
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from core.rag.models.document import Document
|
||||
|
||||
@@ -304,7 +304,7 @@ class CouchbaseVector(BaseVector):
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 2)
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
try:
|
||||
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
|
||||
search_iter = self._scope.search(
|
||||
|
||||
@@ -3,8 +3,8 @@ import uuid
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
import psycopg2.extras # type: ignore
|
||||
import psycopg2.pool # type: ignore
|
||||
import psycopg2.extras
|
||||
import psycopg2.pool
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
@@ -48,7 +48,7 @@ class OpenSearchConfig(BaseModel):
|
||||
return values
|
||||
|
||||
def create_aws_managed_iam_auth(self) -> Urllib3AWSV4SignerAuth:
|
||||
import boto3 # type: ignore
|
||||
import boto3
|
||||
|
||||
return Urllib3AWSV4SignerAuth(
|
||||
credentials=boto3.Session().get_credentials(),
|
||||
|
||||
@@ -6,8 +6,8 @@ from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
import psycopg2.errors
|
||||
import psycopg2.extras # type: ignore
|
||||
import psycopg2.pool # type: ignore
|
||||
import psycopg2.extras
|
||||
import psycopg2.pool
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
@@ -3,8 +3,8 @@ import uuid
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
import psycopg2.extras # type: ignore
|
||||
import psycopg2.pool # type: ignore
|
||||
import psycopg2.extras
|
||||
import psycopg2.pool
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
import uuid
|
||||
from collections.abc import Generator, Iterable, Sequence
|
||||
from itertools import islice
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import qdrant_client
|
||||
from flask import current_app
|
||||
@@ -426,7 +426,6 @@ class QdrantVector(BaseVector):
|
||||
|
||||
def _reload_if_needed(self):
|
||||
if isinstance(self._client, QdrantLocal):
|
||||
self._client = cast(QdrantLocal, self._client)
|
||||
self._client._load()
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
import uuid
|
||||
from collections.abc import Generator, Iterable, Sequence
|
||||
from itertools import islice
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import qdrant_client
|
||||
import requests
|
||||
@@ -398,7 +398,6 @@ class TidbOnQdrantVector(BaseVector):
|
||||
|
||||
def _reload_if_needed(self):
|
||||
if isinstance(self._client, QdrantLocal):
|
||||
self._client = cast(QdrantLocal, self._client)
|
||||
self._client._load()
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -4,7 +4,7 @@ import os
|
||||
from typing import Optional, cast
|
||||
|
||||
import pandas as pd
|
||||
from openpyxl import load_workbook # type: ignore
|
||||
from openpyxl import load_workbook
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
@@ -73,8 +73,8 @@ class ExtractProcessor:
|
||||
suffix = "." + match.group(1)
|
||||
else:
|
||||
suffix = ""
|
||||
# FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
|
||||
# https://stackoverflow.com/questions/26541416/generate-temporary-file-names-without-creating-actual-file-in-python#comment90414256_26541521
|
||||
file_path = f"{temp_dir}/{tempfile.gettempdir()}{suffix}"
|
||||
Path(file_path).write_bytes(response.content)
|
||||
extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model")
|
||||
if return_text:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
|
||||
from bs4 import BeautifulSoup # type: ignore
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional, cast
|
||||
from typing import Optional
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.extractor.helpers import detect_file_encodings
|
||||
@@ -76,7 +76,7 @@ class MarkdownExtractor(BaseExtractor):
|
||||
markdown_tups.append((current_header, current_text))
|
||||
|
||||
markdown_tups = [
|
||||
(re.sub(r"#", "", cast(str, key)).strip() if key else None, re.sub(r"<.*?>", "", value))
|
||||
(re.sub(r"#", "", key).strip() if key else None, re.sub(r"<.*?>", "", value))
|
||||
for key, value in markdown_tups
|
||||
]
|
||||
|
||||
|
||||
@@ -385,4 +385,4 @@ class NotionExtractor(BaseExtractor):
|
||||
f"No notion data source binding found for tenant {tenant_id} and notion workspace {notion_workspace_id}"
|
||||
)
|
||||
|
||||
return cast(str, data_source_binding.access_token)
|
||||
return data_source_binding.access_token
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import contextlib
|
||||
from collections.abc import Iterator
|
||||
from typing import Optional, cast
|
||||
from typing import Optional
|
||||
|
||||
from core.rag.extractor.blob.blob import Blob
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
@@ -27,7 +27,7 @@ class PdfExtractor(BaseExtractor):
|
||||
plaintext_file_exists = False
|
||||
if self._file_cache_key:
|
||||
with contextlib.suppress(FileNotFoundError):
|
||||
text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8")
|
||||
text = storage.load(self._file_cache_key).decode("utf-8")
|
||||
plaintext_file_exists = True
|
||||
return [Document(page_content=text)]
|
||||
documents = list(self.load())
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user