mirror of
https://github.com/langgenius/dify.git
synced 2026-04-12 17:19:21 +08:00
Compare commits
177 Commits
fix/templa
...
test/log-r
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d52d80681e | ||
|
|
bac7da83f5 | ||
|
|
0fa063c640 | ||
|
|
40d35304ea | ||
|
|
89821d66bb | ||
|
|
09d84e900c | ||
|
|
a8746bff30 | ||
|
|
c4d8bf0ce9 | ||
|
|
9cca605bac | ||
|
|
dbd23f91e5 | ||
|
|
9387cc088c | ||
|
|
11f7a89e25 | ||
|
|
654d522b31 | ||
|
|
31e6ef77a6 | ||
|
|
e56c847210 | ||
|
|
e00172199a | ||
|
|
04f47836d8 | ||
|
|
faaca822e4 | ||
|
|
dc0f053925 | ||
|
|
517726da3a | ||
|
|
1d6c03eddf | ||
|
|
fdfccd1205 | ||
|
|
b30e7ced0a | ||
|
|
11770439be | ||
|
|
d89c5f7146 | ||
|
|
4a475bf1cd | ||
|
|
10be9cfbbf | ||
|
|
c20e0ad90d | ||
|
|
22f64d60bb | ||
|
|
7b7d332239 | ||
|
|
b1d189324a | ||
|
|
00fb468f2e | ||
|
|
bbbb6e04cb | ||
|
|
f5161d9add | ||
|
|
787251f00e | ||
|
|
cfe21f0826 | ||
|
|
196f691865 | ||
|
|
7a5bb1cfac | ||
|
|
b80d55b764 | ||
|
|
dd71625f52 | ||
|
|
19936d23d1 | ||
|
|
decf0f3da0 | ||
|
|
7242a67f84 | ||
|
|
c4884eb669 | ||
|
|
d49f3327e4 | ||
|
|
633e68a2f7 | ||
|
|
809f48f733 | ||
|
|
578b1b45ea | ||
|
|
86c3c58e64 | ||
|
|
8d803a26eb | ||
|
|
aa3129c2a9 | ||
|
|
97c924fe29 | ||
|
|
591c463e4b | ||
|
|
e1691fddaa | ||
|
|
b4d4351203 | ||
|
|
f7b1348623 | ||
|
|
2619c7553a | ||
|
|
f79d8baf63 | ||
|
|
bbdcbac544 | ||
|
|
d552680e72 | ||
|
|
df43c6ab8a | ||
|
|
cd47a47c3b | ||
|
|
e5d4235f1b | ||
|
|
f60aa36fa0 | ||
|
|
b2bcb6d21a | ||
|
|
b6cea71023 | ||
|
|
6462328620 | ||
|
|
fd86cadf67 | ||
|
|
c43c72c1a3 | ||
|
|
d77c2e4d17 | ||
|
|
1a7898dff1 | ||
|
|
af662b100b | ||
|
|
595df172a8 | ||
|
|
70bc5ca7f4 | ||
|
|
30617feff8 | ||
|
|
756864c85b | ||
|
|
c8c94ef870 | ||
|
|
10d51ada59 | ||
|
|
00f3a53f1c | ||
|
|
d2f0551170 | ||
|
|
cba2b9b2ad | ||
|
|
029d5d36ac | ||
|
|
8d897153a5 | ||
|
|
2e914808ea | ||
|
|
d00a72a435 | ||
|
|
36580221aa | ||
|
|
e686cc9eab | ||
|
|
66196459d5 | ||
|
|
a5387b304e | ||
|
|
beb1448441 | ||
|
|
272102c06d | ||
|
|
36406cd62f | ||
|
|
87c41c88a3 | ||
|
|
095c56a646 | ||
|
|
244c132656 | ||
|
|
043ec46c33 | ||
|
|
0e4f19eee0 | ||
|
|
ff34969f21 | ||
|
|
9a7245e1df | ||
|
|
4906eeac18 | ||
|
|
4da93ba579 | ||
|
|
319ecdd312 | ||
|
|
0c1ec35244 | ||
|
|
46375aacdb | ||
|
|
e6d4331994 | ||
|
|
2a0abc51b1 | ||
|
|
3bb67885ef | ||
|
|
e682749d03 | ||
|
|
9b83b0aadd | ||
|
|
0cac330bc2 | ||
|
|
fb8114792a | ||
|
|
eab6f65409 | ||
|
|
915023b809 | ||
|
|
f104839672 | ||
|
|
6841a09667 | ||
|
|
e937c8c72e | ||
|
|
960bb8a9b4 | ||
|
|
9b36059292 | ||
|
|
a4acc64afd | ||
|
|
25c69ac540 | ||
|
|
96a0b9991e | ||
|
|
2913d17fe2 | ||
|
|
d9e45a1abe | ||
|
|
24b4289d6c | ||
|
|
fb6ccccc3d | ||
|
|
8b74ae683a | ||
|
|
dd08957381 | ||
|
|
407323f817 | ||
|
|
2e2c87c5a1 | ||
|
|
f4522fd695 | ||
|
|
760a2c656c | ||
|
|
8940decd1b | ||
|
|
0c4193bd91 | ||
|
|
cd40cde790 | ||
|
|
c60c754ac9 | ||
|
|
ef80d3b707 | ||
|
|
24e8d21b3f | ||
|
|
d823da18db | ||
|
|
1e3df09fc6 | ||
|
|
75a10c276c | ||
|
|
50050527eb | ||
|
|
a39b185627 | ||
|
|
15270f09af | ||
|
|
f6a5ac0698 | ||
|
|
2b79da722b | ||
|
|
71d69e43cd | ||
|
|
5bc6e8a433 | ||
|
|
68076f2e22 | ||
|
|
8c38363038 | ||
|
|
345ac8333c | ||
|
|
2375047ef0 | ||
|
|
857a48012e | ||
|
|
208fe3d7de | ||
|
|
92cddbcc02 | ||
|
|
599b53c9cb | ||
|
|
062b173c66 | ||
|
|
db690013fd | ||
|
|
e93bfe3d41 | ||
|
|
ab910c736c | ||
|
|
4047a6bb12 | ||
|
|
df2478dc26 | ||
|
|
4cc3f6045b | ||
|
|
1550316b8d | ||
|
|
87394d2512 | ||
|
|
bad59c95bc | ||
|
|
9f138ef246 | ||
|
|
6453fc4973 | ||
|
|
f62f926537 | ||
|
|
b3dafd913b | ||
|
|
b2d8a7eaf1 | ||
|
|
3e54414191 | ||
|
|
a173546c8d | ||
|
|
aa69d90489 | ||
|
|
4ba1292455 | ||
|
|
bb01c31f30 | ||
|
|
cd90b2ca9e | ||
|
|
9a65350cf7 |
@@ -1,4 +1,4 @@
|
||||
FROM mcr.microsoft.com/devcontainers/python:3.12-bullseye
|
||||
FROM mcr.microsoft.com/devcontainers/python:3.12-bookworm
|
||||
|
||||
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
|
||||
&& apt-get -y install libgmp-dev libmpfr-dev libmpc-dev
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
#!/bin/bash
|
||||
WORKSPACE_ROOT=$(pwd)
|
||||
|
||||
corepack enable
|
||||
cd web && pnpm install
|
||||
pipx install uv
|
||||
|
||||
echo 'alias start-api="cd /workspaces/dify/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
|
||||
echo 'alias start-worker="cd /workspaces/dify/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage"' >> ~/.bashrc
|
||||
echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc
|
||||
echo 'alias start-web-prod="cd /workspaces/dify/web && pnpm build && pnpm start"' >> ~/.bashrc
|
||||
echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d"' >> ~/.bashrc
|
||||
echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down"' >> ~/.bashrc
|
||||
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
|
||||
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage\"" >> ~/.bashrc
|
||||
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
|
||||
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
|
||||
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
|
||||
echo "alias stop-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down\"" >> ~/.bashrc
|
||||
|
||||
source /home/vscode/.bashrc
|
||||
|
||||
|
||||
3
.github/ISSUE_TEMPLATE/config.yml
vendored
3
.github/ISSUE_TEMPLATE/config.yml
vendored
@@ -1,5 +1,8 @@
|
||||
blank_issues_enabled: false
|
||||
contact_links:
|
||||
- name: "\U0001F510 Security Vulnerabilities"
|
||||
url: "https://github.com/langgenius/dify/security/advisories/new"
|
||||
about: Report security vulnerabilities through GitHub Security Advisories to ensure responsible disclosure. 💡 Please do not report security vulnerabilities in public issues.
|
||||
- name: "\U0001F4A1 Model Providers & Plugins"
|
||||
url: "https://github.com/langgenius/dify-official-plugins/issues/new/choose"
|
||||
about: Report issues with official plugins or model providers, you will need to provide the plugin version and other relevant details.
|
||||
|
||||
4
.github/workflows/autofix.yml
vendored
4
.github/workflows/autofix.yml
vendored
@@ -15,10 +15,12 @@ jobs:
|
||||
# Use uv to ensure we have the same ruff version in CI and locally.
|
||||
- uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
python-version: "3.12"
|
||||
python-version: "3.11"
|
||||
- run: |
|
||||
cd api
|
||||
uv sync --dev
|
||||
# fmt first to avoid line too long
|
||||
uv run ruff format ..
|
||||
# Fix lint errors
|
||||
uv run ruff check --fix .
|
||||
# Format code
|
||||
|
||||
3
.github/workflows/build-push.yml
vendored
3
.github/workflows/build-push.yml
vendored
@@ -8,8 +8,7 @@ on:
|
||||
- "deploy/enterprise"
|
||||
- "build/**"
|
||||
- "release/e-*"
|
||||
- "deploy/rag-dev"
|
||||
- "feat/rag-2"
|
||||
- "hotfix/**"
|
||||
tags:
|
||||
- "*"
|
||||
|
||||
|
||||
4
.github/workflows/deploy-dev.yml
vendored
4
.github/workflows/deploy-dev.yml
vendored
@@ -4,7 +4,7 @@ on:
|
||||
workflow_run:
|
||||
workflows: ["Build and Push API & Web"]
|
||||
branches:
|
||||
- "deploy/rag-dev"
|
||||
- "deploy/dev"
|
||||
types:
|
||||
- completed
|
||||
|
||||
@@ -13,7 +13,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.event.workflow_run.head_branch == 'deploy/rag-dev'
|
||||
github.event.workflow_run.head_branch == 'deploy/dev'
|
||||
steps:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@v0.1.8
|
||||
|
||||
6
.gitignore
vendored
6
.gitignore
vendored
@@ -230,4 +230,8 @@ api/.env.backup
|
||||
|
||||
# Benchmark
|
||||
scripts/stress-test/setup/config/
|
||||
scripts/stress-test/reports/
|
||||
scripts/stress-test/reports/
|
||||
|
||||
# mcp
|
||||
.playwright-mcp/
|
||||
.serena/
|
||||
89
AGENTS.md
89
AGENTS.md
@@ -4,84 +4,51 @@
|
||||
|
||||
Dify is an open-source platform for developing LLM applications with an intuitive interface combining agentic AI workflows, RAG pipelines, agent capabilities, and model management.
|
||||
|
||||
The codebase consists of:
|
||||
The codebase is split into:
|
||||
|
||||
- **Backend API** (`/api`): Python Flask application with Domain-Driven Design architecture
|
||||
- **Frontend Web** (`/web`): Next.js 15 application with TypeScript and React 19
|
||||
- **Backend API** (`/api`): Python Flask application organized with Domain-Driven Design
|
||||
- **Frontend Web** (`/web`): Next.js 15 application using TypeScript and React 19
|
||||
- **Docker deployment** (`/docker`): Containerized deployment configurations
|
||||
|
||||
## Development Commands
|
||||
## Backend Workflow
|
||||
|
||||
### Backend (API)
|
||||
- Run backend CLI commands through `uv run --project api <command>`.
|
||||
|
||||
All Python commands must be prefixed with `uv run --project api`:
|
||||
- Backend QA gate requires passing `make lint`, `make type-check`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh` before review.
|
||||
|
||||
```bash
|
||||
# Start development servers
|
||||
./dev/start-api # Start API server
|
||||
./dev/start-worker # Start Celery worker
|
||||
- Use Makefile targets for linting and formatting; `make lint` and `make type-check` cover the required checks.
|
||||
|
||||
# Run tests
|
||||
uv run --project api pytest # Run all tests
|
||||
uv run --project api pytest tests/unit_tests/ # Unit tests only
|
||||
uv run --project api pytest tests/integration_tests/ # Integration tests
|
||||
- Integration tests are CI-only and are not expected to run in the local environment.
|
||||
|
||||
# Code quality
|
||||
./dev/reformat # Run all formatters and linters
|
||||
uv run --project api ruff check --fix ./ # Fix linting issues
|
||||
uv run --project api ruff format ./ # Format code
|
||||
uv run --directory api basedpyright # Type checking
|
||||
```
|
||||
|
||||
### Frontend (Web)
|
||||
## Frontend Workflow
|
||||
|
||||
```bash
|
||||
cd web
|
||||
pnpm lint # Run ESLint
|
||||
pnpm eslint-fix # Fix ESLint issues
|
||||
pnpm test # Run Jest tests
|
||||
pnpm lint
|
||||
pnpm lint:fix
|
||||
pnpm test
|
||||
```
|
||||
|
||||
## Testing Guidelines
|
||||
## Testing & Quality Practices
|
||||
|
||||
### Backend Testing
|
||||
- Follow TDD: red → green → refactor.
|
||||
- Use `pytest` for backend tests with Arrange-Act-Assert structure.
|
||||
- Enforce strong typing; avoid `Any` and prefer explicit type annotations.
|
||||
- Write self-documenting code; only add comments that explain intent.
|
||||
|
||||
- Use `pytest` for all backend tests
|
||||
- Write tests first (TDD approach)
|
||||
- Test structure: Arrange-Act-Assert
|
||||
## Language Style
|
||||
|
||||
## Code Style Requirements
|
||||
- **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`).
|
||||
- **TypeScript**: Use the strict config, lean on ESLint + Prettier workflows, and avoid `any` types.
|
||||
|
||||
### Python
|
||||
## General Practices
|
||||
|
||||
- Use type hints for all functions and class attributes
|
||||
- No `Any` types unless absolutely necessary
|
||||
- Implement special methods (`__repr__`, `__str__`) appropriately
|
||||
- Prefer editing existing files; add new documentation only when requested.
|
||||
- Inject dependencies through constructors and preserve clean architecture boundaries.
|
||||
- Handle errors with domain-specific exceptions at the correct layer.
|
||||
|
||||
### TypeScript/JavaScript
|
||||
## Project Conventions
|
||||
|
||||
- Strict TypeScript configuration
|
||||
- ESLint with Prettier integration
|
||||
- Avoid `any` type
|
||||
|
||||
## Important Notes
|
||||
|
||||
- **Environment Variables**: Always use UV for Python commands: `uv run --project api <command>`
|
||||
- **Comments**: Only write meaningful comments that explain "why", not "what"
|
||||
- **File Creation**: Always prefer editing existing files over creating new ones
|
||||
- **Documentation**: Don't create documentation files unless explicitly requested
|
||||
- **Code Quality**: Always run `./dev/reformat` before committing backend changes
|
||||
|
||||
## Common Development Tasks
|
||||
|
||||
### Adding a New API Endpoint
|
||||
|
||||
1. Create controller in `/api/controllers/`
|
||||
1. Add service logic in `/api/services/`
|
||||
1. Update routes in controller's `__init__.py`
|
||||
1. Write tests in `/api/tests/`
|
||||
|
||||
## Project-Specific Conventions
|
||||
|
||||
- All async tasks use Celery with Redis as broker
|
||||
- **Internationalization**: Frontend supports multiple languages with English (`web/i18n/en-US/`) as the source. All user-facing text must use i18n keys, no hardcoded strings. Edit corresponding module files in `en-US/` directory for translations.
|
||||
- Backend architecture adheres to DDD and Clean Architecture principles.
|
||||
- Async work runs through Celery with Redis as the broker.
|
||||
- Frontend user-facing strings must use `web/i18n/en-US/`; avoid hardcoded text.
|
||||
|
||||
6
Makefile
6
Makefile
@@ -26,7 +26,6 @@ prepare-web:
|
||||
@echo "🌐 Setting up web environment..."
|
||||
@cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists"
|
||||
@cd web && pnpm install
|
||||
@cd web && pnpm build
|
||||
@echo "✅ Web environment prepared (not started)"
|
||||
|
||||
# Step 3: Prepare API environment
|
||||
@@ -61,8 +60,9 @@ check:
|
||||
@echo "✅ Code check complete"
|
||||
|
||||
lint:
|
||||
@echo "🔧 Running ruff format and check with fixes..."
|
||||
@uv run --directory api --dev sh -c 'ruff format ./api && ruff check --fix ./api'
|
||||
@echo "🔧 Running ruff format, check with fixes, and import linter..."
|
||||
@uv run --project api --dev sh -c 'ruff format ./api && ruff check --fix ./api'
|
||||
@uv run --directory api --dev lint-imports
|
||||
@echo "✅ Linting complete"
|
||||
|
||||
type-check:
|
||||
|
||||
24
README.md
24
README.md
@@ -40,18 +40,18 @@
|
||||
|
||||
<p align="center">
|
||||
<a href="./README.md"><img alt="README in English" src="https://img.shields.io/badge/English-d9d9d9"></a>
|
||||
<a href="./README_TW.md"><img alt="繁體中文文件" src="https://img.shields.io/badge/繁體中文-d9d9d9"></a>
|
||||
<a href="./README_CN.md"><img alt="简体中文版自述文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
|
||||
<a href="./README_JA.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
|
||||
<a href="./README_ES.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a>
|
||||
<a href="./README_FR.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a>
|
||||
<a href="./README_KL.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
|
||||
<a href="./README_KR.md"><img alt="README in Korean" src="https://img.shields.io/badge/한국어-d9d9d9"></a>
|
||||
<a href="./README_AR.md"><img alt="README بالعربية" src="https://img.shields.io/badge/العربية-d9d9d9"></a>
|
||||
<a href="./README_TR.md"><img alt="Türkçe README" src="https://img.shields.io/badge/Türkçe-d9d9d9"></a>
|
||||
<a href="./README_VI.md"><img alt="README Tiếng Việt" src="https://img.shields.io/badge/Ti%E1%BA%BFng%20Vi%E1%BB%87t-d9d9d9"></a>
|
||||
<a href="./README_DE.md"><img alt="README in Deutsch" src="https://img.shields.io/badge/German-d9d9d9"></a>
|
||||
<a href="./README_BN.md"><img alt="README in বাংলা" src="https://img.shields.io/badge/বাংলা-d9d9d9"></a>
|
||||
<a href="./docs/zh-TW/README.md"><img alt="繁體中文文件" src="https://img.shields.io/badge/繁體中文-d9d9d9"></a>
|
||||
<a href="./docs/zh-CN/README.md"><img alt="简体中文文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
|
||||
<a href="./docs/ja-JP/README.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
|
||||
<a href="./docs/es-ES/README.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a>
|
||||
<a href="./docs/fr-FR/README.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a>
|
||||
<a href="./docs/tlh/README.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
|
||||
<a href="./docs/ko-KR/README.md"><img alt="README in Korean" src="https://img.shields.io/badge/한국어-d9d9d9"></a>
|
||||
<a href="./docs/ar-SA/README.md"><img alt="README بالعربية" src="https://img.shields.io/badge/العربية-d9d9d9"></a>
|
||||
<a href="./docs/tr-TR/README.md"><img alt="Türkçe README" src="https://img.shields.io/badge/Türkçe-d9d9d9"></a>
|
||||
<a href="./docs/vi-VN/README.md"><img alt="README Tiếng Việt" src="https://img.shields.io/badge/Ti%E1%BA%BFng%20Vi%E1%BB%87t-d9d9d9"></a>
|
||||
<a href="./docs/de-DE/README.md"><img alt="README in Deutsch" src="https://img.shields.io/badge/German-d9d9d9"></a>
|
||||
<a href="./docs/bn-BD/README.md"><img alt="README in বাংলা" src="https://img.shields.io/badge/বাংলা-d9d9d9"></a>
|
||||
</p>
|
||||
|
||||
Dify is an open-source platform for developing LLM applications. Its intuitive interface combines agentic AI workflows, RAG pipelines, agent capabilities, model management, observability features, and more—allowing you to quickly move from prototype to production.
|
||||
|
||||
@@ -304,6 +304,8 @@ BAIDU_VECTOR_DB_API_KEY=dify
|
||||
BAIDU_VECTOR_DB_DATABASE=dify
|
||||
BAIDU_VECTOR_DB_SHARD=1
|
||||
BAIDU_VECTOR_DB_REPLICAS=3
|
||||
BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER=DEFAULT_ANALYZER
|
||||
BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE=COARSE_MODE
|
||||
|
||||
# Upstash configuration
|
||||
UPSTASH_VECTOR_URL=your-server-url
|
||||
@@ -406,6 +408,9 @@ SSRF_DEFAULT_TIME_OUT=5
|
||||
SSRF_DEFAULT_CONNECT_TIME_OUT=5
|
||||
SSRF_DEFAULT_READ_TIME_OUT=5
|
||||
SSRF_DEFAULT_WRITE_TIME_OUT=5
|
||||
SSRF_POOL_MAX_CONNECTIONS=100
|
||||
SSRF_POOL_MAX_KEEPALIVE_CONNECTIONS=20
|
||||
SSRF_POOL_KEEPALIVE_EXPIRY=5.0
|
||||
|
||||
BATCH_UPLOAD_LIMIT=10
|
||||
KEYWORD_DATA_SOURCE_TYPE=database
|
||||
@@ -416,10 +421,14 @@ WORKFLOW_FILE_UPLOAD_LIMIT=10
|
||||
# CODE EXECUTION CONFIGURATION
|
||||
CODE_EXECUTION_ENDPOINT=http://127.0.0.1:8194
|
||||
CODE_EXECUTION_API_KEY=dify-sandbox
|
||||
CODE_EXECUTION_SSL_VERIFY=True
|
||||
CODE_EXECUTION_POOL_MAX_CONNECTIONS=100
|
||||
CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS=20
|
||||
CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY=5.0
|
||||
CODE_MAX_NUMBER=9223372036854775807
|
||||
CODE_MIN_NUMBER=-9223372036854775808
|
||||
CODE_MAX_STRING_LENGTH=80000
|
||||
TEMPLATE_TRANSFORM_MAX_LENGTH=80000
|
||||
CODE_MAX_STRING_LENGTH=400000
|
||||
TEMPLATE_TRANSFORM_MAX_LENGTH=400000
|
||||
CODE_MAX_STRING_ARRAY_LENGTH=30
|
||||
CODE_MAX_OBJECT_ARRAY_LENGTH=30
|
||||
CODE_MAX_NUMBER_ARRAY_LENGTH=1000
|
||||
@@ -459,7 +468,6 @@ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000
|
||||
WORKFLOW_MAX_EXECUTION_STEPS=500
|
||||
WORKFLOW_MAX_EXECUTION_TIME=1200
|
||||
WORKFLOW_CALL_MAX_DEPTH=5
|
||||
WORKFLOW_PARALLEL_DEPTH_LIMIT=3
|
||||
MAX_VARIABLE_SIZE=204800
|
||||
|
||||
# GraphEngine Worker Pool Configuration
|
||||
|
||||
@@ -30,6 +30,7 @@ select = [
|
||||
"RUF022", # unsorted-dunder-all
|
||||
"S506", # unsafe-yaml-load
|
||||
"SIM", # flake8-simplify rules
|
||||
"T201", # print-found
|
||||
"TRY400", # error-instead-of-exception
|
||||
"TRY401", # verbose-log-message
|
||||
"UP", # pyupgrade rules
|
||||
@@ -91,11 +92,18 @@ ignore = [
|
||||
"configs/*" = [
|
||||
"N802", # invalid-function-name
|
||||
]
|
||||
"core/model_runtime/callbacks/base_callback.py" = [
|
||||
"T201",
|
||||
]
|
||||
"core/workflow/callbacks/workflow_logging_callback.py" = [
|
||||
"T201",
|
||||
]
|
||||
"libs/gmpy2_pkcs10aep_cipher.py" = [
|
||||
"N803", # invalid-argument-name
|
||||
]
|
||||
"tests/*" = [
|
||||
"F811", # redefined-while-unused
|
||||
"T201", # allow print in tests
|
||||
]
|
||||
|
||||
[lint.pyflakes]
|
||||
|
||||
@@ -80,10 +80,10 @@
|
||||
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
|
||||
|
||||
```bash
|
||||
uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation
|
||||
uv run celery -A app.celery worker -P gevent -c 2 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation
|
||||
```
|
||||
|
||||
Addition, if you want to debug the celery scheduled tasks, you can use the following command in another terminal:
|
||||
Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:
|
||||
|
||||
```bash
|
||||
uv run celery -A app.celery beat
|
||||
|
||||
@@ -50,6 +50,7 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_commands,
|
||||
ext_compress,
|
||||
ext_database,
|
||||
ext_elasticsearch,
|
||||
ext_hosting_provider,
|
||||
ext_import_modules,
|
||||
ext_logging,
|
||||
@@ -82,6 +83,7 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_migrate,
|
||||
ext_redis,
|
||||
ext_storage,
|
||||
ext_elasticsearch,
|
||||
ext_celery,
|
||||
ext_login,
|
||||
ext_mail,
|
||||
|
||||
@@ -1,20 +1,11 @@
|
||||
import logging
|
||||
|
||||
import psycogreen.gevent as pscycogreen_gevent # type: ignore
|
||||
from grpc.experimental import gevent as grpc_gevent # type: ignore
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _log(message: str):
|
||||
print(message, flush=True)
|
||||
|
||||
|
||||
# grpc gevent
|
||||
grpc_gevent.init_gevent()
|
||||
_log("gRPC patched with gevent.")
|
||||
print("gRPC patched with gevent.", flush=True) # noqa: T201
|
||||
pscycogreen_gevent.patch_psycopg()
|
||||
_log("psycopg2 patched with gevent.")
|
||||
print("psycopg2 patched with gevent.", flush=True) # noqa: T201
|
||||
|
||||
|
||||
from app import app, celery
|
||||
|
||||
663
api/commands.py
663
api/commands.py
@@ -10,6 +10,7 @@ from flask import current_app
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
@@ -61,31 +62,30 @@ def reset_password(email, new_password, password_confirm):
|
||||
if str(new_password).strip() != str(password_confirm).strip():
|
||||
click.echo(click.style("Passwords do not match.", fg="red"))
|
||||
return
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
account = session.query(Account).where(Account.email == email).one_or_none()
|
||||
|
||||
account = db.session.query(Account).where(Account.email == email).one_or_none()
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
try:
|
||||
valid_password(new_password)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
|
||||
return
|
||||
|
||||
try:
|
||||
valid_password(new_password)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
|
||||
return
|
||||
# generate password salt
|
||||
salt = secrets.token_bytes(16)
|
||||
base64_salt = base64.b64encode(salt).decode()
|
||||
|
||||
# generate password salt
|
||||
salt = secrets.token_bytes(16)
|
||||
base64_salt = base64.b64encode(salt).decode()
|
||||
|
||||
# encrypt password with salt
|
||||
password_hashed = hash_password(new_password, salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
db.session.commit()
|
||||
AccountService.reset_login_error_rate_limit(email)
|
||||
click.echo(click.style("Password reset successfully.", fg="green"))
|
||||
# encrypt password with salt
|
||||
password_hashed = hash_password(new_password, salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
AccountService.reset_login_error_rate_limit(email)
|
||||
click.echo(click.style("Password reset successfully.", fg="green"))
|
||||
|
||||
|
||||
@click.command("reset-email", help="Reset the account email.")
|
||||
@@ -100,22 +100,21 @@ def reset_email(email, new_email, email_confirm):
|
||||
if str(new_email).strip() != str(email_confirm).strip():
|
||||
click.echo(click.style("New emails do not match.", fg="red"))
|
||||
return
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
account = session.query(Account).where(Account.email == email).one_or_none()
|
||||
|
||||
account = db.session.query(Account).where(Account.email == email).one_or_none()
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
try:
|
||||
email_validate(new_email)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
|
||||
return
|
||||
|
||||
try:
|
||||
email_validate(new_email)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
|
||||
return
|
||||
|
||||
account.email = new_email
|
||||
db.session.commit()
|
||||
click.echo(click.style("Email updated successfully.", fg="green"))
|
||||
account.email = new_email
|
||||
click.echo(click.style("Email updated successfully.", fg="green"))
|
||||
|
||||
|
||||
@click.command(
|
||||
@@ -139,25 +138,24 @@ def reset_encrypt_key_pair():
|
||||
if dify_config.EDITION != "SELF_HOSTED":
|
||||
click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red"))
|
||||
return
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
tenants = session.query(Tenant).all()
|
||||
for tenant in tenants:
|
||||
if not tenant:
|
||||
click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
|
||||
return
|
||||
|
||||
tenants = db.session.query(Tenant).all()
|
||||
for tenant in tenants:
|
||||
if not tenant:
|
||||
click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
|
||||
return
|
||||
tenant.encrypt_public_key = generate_key_pair(tenant.id)
|
||||
|
||||
tenant.encrypt_public_key = generate_key_pair(tenant.id)
|
||||
session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
|
||||
session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
|
||||
|
||||
db.session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
|
||||
db.session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
|
||||
db.session.commit()
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.",
|
||||
fg="green",
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@click.command("vdb-migrate", help="Migrate vector db.")
|
||||
@@ -182,14 +180,15 @@ def migrate_annotation_vector_database():
|
||||
try:
|
||||
# get apps info
|
||||
per_page = 50
|
||||
apps = (
|
||||
db.session.query(App)
|
||||
.where(App.status == "normal")
|
||||
.order_by(App.created_at.desc())
|
||||
.limit(per_page)
|
||||
.offset((page - 1) * per_page)
|
||||
.all()
|
||||
)
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
apps = (
|
||||
session.query(App)
|
||||
.where(App.status == "normal")
|
||||
.order_by(App.created_at.desc())
|
||||
.limit(per_page)
|
||||
.offset((page - 1) * per_page)
|
||||
.all()
|
||||
)
|
||||
if not apps:
|
||||
break
|
||||
except SQLAlchemyError:
|
||||
@@ -203,26 +202,27 @@ def migrate_annotation_vector_database():
|
||||
)
|
||||
try:
|
||||
click.echo(f"Creating app annotation index: {app.id}")
|
||||
app_annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
|
||||
)
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
app_annotation_setting = (
|
||||
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
|
||||
)
|
||||
|
||||
if not app_annotation_setting:
|
||||
skipped_count = skipped_count + 1
|
||||
click.echo(f"App annotation setting disabled: {app.id}")
|
||||
continue
|
||||
# get dataset_collection_binding info
|
||||
dataset_collection_binding = (
|
||||
db.session.query(DatasetCollectionBinding)
|
||||
.where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
|
||||
.first()
|
||||
)
|
||||
if not dataset_collection_binding:
|
||||
click.echo(f"App annotation collection binding not found: {app.id}")
|
||||
continue
|
||||
annotations = db.session.scalars(
|
||||
select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
|
||||
).all()
|
||||
if not app_annotation_setting:
|
||||
skipped_count = skipped_count + 1
|
||||
click.echo(f"App annotation setting disabled: {app.id}")
|
||||
continue
|
||||
# get dataset_collection_binding info
|
||||
dataset_collection_binding = (
|
||||
session.query(DatasetCollectionBinding)
|
||||
.where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
|
||||
.first()
|
||||
)
|
||||
if not dataset_collection_binding:
|
||||
click.echo(f"App annotation collection binding not found: {app.id}")
|
||||
continue
|
||||
annotations = session.scalars(
|
||||
select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
|
||||
).all()
|
||||
dataset = Dataset(
|
||||
id=app.id,
|
||||
tenant_id=app.tenant_id,
|
||||
@@ -739,18 +739,18 @@ where sites.id is null limit 1000"""
|
||||
try:
|
||||
app = db.session.query(App).where(App.id == app_id).first()
|
||||
if not app:
|
||||
print(f"App {app_id} not found")
|
||||
logger.info("App %s not found", app_id)
|
||||
continue
|
||||
|
||||
tenant = app.tenant
|
||||
if tenant:
|
||||
accounts = tenant.get_accounts()
|
||||
if not accounts:
|
||||
print(f"Fix failed for app {app.id}")
|
||||
logger.info("Fix failed for app %s", app.id)
|
||||
continue
|
||||
|
||||
account = accounts[0]
|
||||
print(f"Fixing missing site for app {app.id}")
|
||||
logger.info("Fixing missing site for app %s", app.id)
|
||||
app_was_created.send(app, account=account)
|
||||
except Exception:
|
||||
failed_app_ids.append(app_id)
|
||||
@@ -1448,41 +1448,52 @@ def transform_datasource_credentials():
|
||||
notion_credentials_tenant_mapping[tenant_id] = []
|
||||
notion_credentials_tenant_mapping[tenant_id].append(notion_credential)
|
||||
for tenant_id, notion_tenant_credentials in notion_credentials_tenant_mapping.items():
|
||||
# check notion plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if notion_plugin_id not in installed_plugins_ids:
|
||||
if notion_plugin_unique_identifier:
|
||||
# install notion plugin
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier])
|
||||
auth_count = 0
|
||||
for notion_tenant_credential in notion_tenant_credentials:
|
||||
auth_count += 1
|
||||
# get credential oauth params
|
||||
access_token = notion_tenant_credential.access_token
|
||||
# notion info
|
||||
notion_info = notion_tenant_credential.source_info
|
||||
workspace_id = notion_info.get("workspace_id")
|
||||
workspace_name = notion_info.get("workspace_name")
|
||||
workspace_icon = notion_info.get("workspace_icon")
|
||||
new_credentials = {
|
||||
"integration_secret": encrypter.encrypt_token(tenant_id, access_token),
|
||||
"workspace_id": workspace_id,
|
||||
"workspace_name": workspace_name,
|
||||
"workspace_icon": workspace_icon,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="notion_datasource",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=notion_plugin_id,
|
||||
auth_type=oauth_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url=workspace_icon or "default",
|
||||
is_default=False,
|
||||
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
|
||||
if not tenant:
|
||||
continue
|
||||
try:
|
||||
# check notion plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if notion_plugin_id not in installed_plugins_ids:
|
||||
if notion_plugin_unique_identifier:
|
||||
# install notion plugin
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier])
|
||||
auth_count = 0
|
||||
for notion_tenant_credential in notion_tenant_credentials:
|
||||
auth_count += 1
|
||||
# get credential oauth params
|
||||
access_token = notion_tenant_credential.access_token
|
||||
# notion info
|
||||
notion_info = notion_tenant_credential.source_info
|
||||
workspace_id = notion_info.get("workspace_id")
|
||||
workspace_name = notion_info.get("workspace_name")
|
||||
workspace_icon = notion_info.get("workspace_icon")
|
||||
new_credentials = {
|
||||
"integration_secret": encrypter.encrypt_token(tenant_id, access_token),
|
||||
"workspace_id": workspace_id,
|
||||
"workspace_name": workspace_name,
|
||||
"workspace_icon": workspace_icon,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="notion_datasource",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=notion_plugin_id,
|
||||
auth_type=oauth_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url=workspace_icon or "default",
|
||||
is_default=False,
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_notion_count += 1
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Error transforming notion credentials: {str(e)}, tenant_id: {tenant_id}", fg="red"
|
||||
)
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_notion_count += 1
|
||||
continue
|
||||
db.session.commit()
|
||||
# deal firecrawl credentials
|
||||
deal_firecrawl_count = 0
|
||||
@@ -1495,37 +1506,48 @@ def transform_datasource_credentials():
|
||||
firecrawl_credentials_tenant_mapping[tenant_id] = []
|
||||
firecrawl_credentials_tenant_mapping[tenant_id].append(firecrawl_credential)
|
||||
for tenant_id, firecrawl_tenant_credentials in firecrawl_credentials_tenant_mapping.items():
|
||||
# check firecrawl plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if firecrawl_plugin_id not in installed_plugins_ids:
|
||||
if firecrawl_plugin_unique_identifier:
|
||||
# install firecrawl plugin
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier])
|
||||
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
|
||||
if not tenant:
|
||||
continue
|
||||
try:
|
||||
# check firecrawl plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if firecrawl_plugin_id not in installed_plugins_ids:
|
||||
if firecrawl_plugin_unique_identifier:
|
||||
# install firecrawl plugin
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier])
|
||||
|
||||
auth_count = 0
|
||||
for firecrawl_tenant_credential in firecrawl_tenant_credentials:
|
||||
auth_count += 1
|
||||
# get credential api key
|
||||
credentials_json = json.loads(firecrawl_tenant_credential.credentials)
|
||||
api_key = credentials_json.get("config", {}).get("api_key")
|
||||
base_url = credentials_json.get("config", {}).get("base_url")
|
||||
new_credentials = {
|
||||
"firecrawl_api_key": api_key,
|
||||
"base_url": base_url,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="firecrawl",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=firecrawl_plugin_id,
|
||||
auth_type=api_key_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url="default",
|
||||
is_default=False,
|
||||
auth_count = 0
|
||||
for firecrawl_tenant_credential in firecrawl_tenant_credentials:
|
||||
auth_count += 1
|
||||
# get credential api key
|
||||
credentials_json = json.loads(firecrawl_tenant_credential.credentials)
|
||||
api_key = credentials_json.get("config", {}).get("api_key")
|
||||
base_url = credentials_json.get("config", {}).get("base_url")
|
||||
new_credentials = {
|
||||
"firecrawl_api_key": api_key,
|
||||
"base_url": base_url,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="firecrawl",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=firecrawl_plugin_id,
|
||||
auth_type=api_key_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url="default",
|
||||
is_default=False,
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_firecrawl_count += 1
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Error transforming firecrawl credentials: {str(e)}, tenant_id: {tenant_id}", fg="red"
|
||||
)
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_firecrawl_count += 1
|
||||
continue
|
||||
db.session.commit()
|
||||
# deal jina credentials
|
||||
deal_jina_count = 0
|
||||
@@ -1538,36 +1560,45 @@ def transform_datasource_credentials():
|
||||
jina_credentials_tenant_mapping[tenant_id] = []
|
||||
jina_credentials_tenant_mapping[tenant_id].append(jina_credential)
|
||||
for tenant_id, jina_tenant_credentials in jina_credentials_tenant_mapping.items():
|
||||
# check jina plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if jina_plugin_id not in installed_plugins_ids:
|
||||
if jina_plugin_unique_identifier:
|
||||
# install jina plugin
|
||||
print(jina_plugin_unique_identifier)
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier])
|
||||
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
|
||||
if not tenant:
|
||||
continue
|
||||
try:
|
||||
# check jina plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if jina_plugin_id not in installed_plugins_ids:
|
||||
if jina_plugin_unique_identifier:
|
||||
# install jina plugin
|
||||
logger.debug("Installing Jina plugin %s", jina_plugin_unique_identifier)
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier])
|
||||
|
||||
auth_count = 0
|
||||
for jina_tenant_credential in jina_tenant_credentials:
|
||||
auth_count += 1
|
||||
# get credential api key
|
||||
credentials_json = json.loads(jina_tenant_credential.credentials)
|
||||
api_key = credentials_json.get("config", {}).get("api_key")
|
||||
new_credentials = {
|
||||
"integration_secret": api_key,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="jina",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=jina_plugin_id,
|
||||
auth_type=api_key_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url="default",
|
||||
is_default=False,
|
||||
auth_count = 0
|
||||
for jina_tenant_credential in jina_tenant_credentials:
|
||||
auth_count += 1
|
||||
# get credential api key
|
||||
credentials_json = json.loads(jina_tenant_credential.credentials)
|
||||
api_key = credentials_json.get("config", {}).get("api_key")
|
||||
new_credentials = {
|
||||
"integration_secret": api_key,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="jina",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=jina_plugin_id,
|
||||
auth_type=api_key_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url="default",
|
||||
is_default=False,
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_jina_count += 1
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style(f"Error transforming jina credentials: {str(e)}, tenant_id: {tenant_id}", fg="red")
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_jina_count += 1
|
||||
continue
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
@@ -1793,3 +1824,295 @@ def migrate_oss(
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
click.echo(click.style(f"Failed to update DB storage_type: {str(e)}", fg="red"))
|
||||
|
||||
|
||||
# Elasticsearch Migration Commands
|
||||
@click.group()
|
||||
def elasticsearch():
|
||||
"""Elasticsearch migration and management commands."""
|
||||
pass
|
||||
|
||||
|
||||
@elasticsearch.command()
|
||||
@click.option(
|
||||
"--tenant-id",
|
||||
help="Migrate data for specific tenant only",
|
||||
)
|
||||
@click.option(
|
||||
"--start-date",
|
||||
help="Start date for migration (YYYY-MM-DD format)",
|
||||
)
|
||||
@click.option(
|
||||
"--end-date",
|
||||
help="End date for migration (YYYY-MM-DD format)",
|
||||
)
|
||||
@click.option(
|
||||
"--data-type",
|
||||
type=click.Choice(["workflow_runs", "app_logs", "node_executions", "all"]),
|
||||
default="all",
|
||||
help="Type of data to migrate",
|
||||
)
|
||||
@click.option(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of records to process in each batch",
|
||||
)
|
||||
@click.option(
|
||||
"--dry-run",
|
||||
is_flag=True,
|
||||
help="Perform a dry run without actually migrating data",
|
||||
)
|
||||
def migrate(
|
||||
tenant_id: str | None,
|
||||
start_date: str | None,
|
||||
end_date: str | None,
|
||||
data_type: str,
|
||||
batch_size: int,
|
||||
dry_run: bool,
|
||||
):
|
||||
"""
|
||||
Migrate workflow log data from PostgreSQL to Elasticsearch.
|
||||
"""
|
||||
from datetime import datetime
|
||||
|
||||
from extensions.ext_elasticsearch import elasticsearch as es_extension
|
||||
from services.elasticsearch_migration_service import ElasticsearchMigrationService
|
||||
|
||||
if not es_extension.is_available():
|
||||
click.echo("Error: Elasticsearch is not available. Please check your configuration.", err=True)
|
||||
return
|
||||
|
||||
# Parse dates
|
||||
start_dt = None
|
||||
end_dt = None
|
||||
|
||||
if start_date:
|
||||
try:
|
||||
start_dt = datetime.strptime(start_date, "%Y-%m-%d")
|
||||
except ValueError:
|
||||
click.echo(f"Error: Invalid start date format '{start_date}'. Use YYYY-MM-DD.", err=True)
|
||||
return
|
||||
|
||||
if end_date:
|
||||
try:
|
||||
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
|
||||
except ValueError:
|
||||
click.echo(f"Error: Invalid end date format '{end_date}'. Use YYYY-MM-DD.", err=True)
|
||||
return
|
||||
|
||||
# Initialize migration service
|
||||
migration_service = ElasticsearchMigrationService(batch_size=batch_size)
|
||||
|
||||
click.echo(f"Starting {'dry run' if dry_run else 'migration'} to Elasticsearch...")
|
||||
click.echo(f"Tenant ID: {tenant_id or 'All tenants'}")
|
||||
click.echo(f"Date range: {start_date or 'No start'} to {end_date or 'No end'}")
|
||||
click.echo(f"Data type: {data_type}")
|
||||
click.echo(f"Batch size: {batch_size}")
|
||||
click.echo()
|
||||
|
||||
total_stats = {
|
||||
"workflow_runs": {},
|
||||
"app_logs": {},
|
||||
"node_executions": {},
|
||||
}
|
||||
|
||||
try:
|
||||
# Migrate workflow runs
|
||||
if data_type in ["workflow_runs", "all"]:
|
||||
click.echo("Migrating WorkflowRun data...")
|
||||
stats = migration_service.migrate_workflow_runs(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_dt,
|
||||
end_date=end_dt,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
total_stats["workflow_runs"] = stats
|
||||
|
||||
click.echo(f" Total records: {stats['total_records']}")
|
||||
click.echo(f" Migrated: {stats['migrated_records']}")
|
||||
click.echo(f" Failed: {stats['failed_records']}")
|
||||
if stats.get("duration"):
|
||||
click.echo(f" Duration: {stats['duration']:.2f}s")
|
||||
click.echo()
|
||||
|
||||
# Migrate app logs
|
||||
if data_type in ["app_logs", "all"]:
|
||||
click.echo("Migrating WorkflowAppLog data...")
|
||||
stats = migration_service.migrate_workflow_app_logs(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_dt,
|
||||
end_date=end_dt,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
total_stats["app_logs"] = stats
|
||||
|
||||
click.echo(f" Total records: {stats['total_records']}")
|
||||
click.echo(f" Migrated: {stats['migrated_records']}")
|
||||
click.echo(f" Failed: {stats['failed_records']}")
|
||||
if stats.get("duration"):
|
||||
click.echo(f" Duration: {stats['duration']:.2f}s")
|
||||
click.echo()
|
||||
|
||||
# Migrate node executions
|
||||
if data_type in ["node_executions", "all"]:
|
||||
click.echo("Migrating WorkflowNodeExecution data...")
|
||||
stats = migration_service.migrate_workflow_node_executions(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_dt,
|
||||
end_date=end_dt,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
total_stats["node_executions"] = stats
|
||||
|
||||
click.echo(f" Total records: {stats['total_records']}")
|
||||
click.echo(f" Migrated: {stats['migrated_records']}")
|
||||
click.echo(f" Failed: {stats['failed_records']}")
|
||||
if stats.get("duration"):
|
||||
click.echo(f" Duration: {stats['duration']:.2f}s")
|
||||
click.echo()
|
||||
|
||||
# Summary
|
||||
total_migrated = sum(stats.get("migrated_records", 0) for stats in total_stats.values())
|
||||
total_failed = sum(stats.get("failed_records", 0) for stats in total_stats.values())
|
||||
|
||||
click.echo("Migration Summary:")
|
||||
click.echo(f" Total migrated: {total_migrated}")
|
||||
click.echo(f" Total failed: {total_failed}")
|
||||
|
||||
# Show errors if any
|
||||
all_errors = []
|
||||
for stats in total_stats.values():
|
||||
all_errors.extend(stats.get("errors", []))
|
||||
|
||||
if all_errors:
|
||||
click.echo(f" Errors ({len(all_errors)}):")
|
||||
for error in all_errors[:10]: # Show first 10 errors
|
||||
click.echo(f" - {error}")
|
||||
if len(all_errors) > 10:
|
||||
click.echo(f" ... and {len(all_errors) - 10} more errors")
|
||||
|
||||
if dry_run:
|
||||
click.echo("\nThis was a dry run. No data was actually migrated.")
|
||||
else:
|
||||
click.echo(f"\nMigration {'completed successfully' if total_failed == 0 else 'completed with errors'}!")
|
||||
|
||||
except Exception as e:
|
||||
click.echo(f"Error: Migration failed: {str(e)}", err=True)
|
||||
logger.exception("Migration failed")
|
||||
|
||||
|
||||
@elasticsearch.command()
|
||||
@click.option(
|
||||
"--tenant-id",
|
||||
required=True,
|
||||
help="Tenant ID to validate",
|
||||
)
|
||||
@click.option(
|
||||
"--sample-size",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Number of records to sample for validation",
|
||||
)
|
||||
def validate(tenant_id: str, sample_size: int):
|
||||
"""
|
||||
Validate migrated data by comparing samples from PostgreSQL and Elasticsearch.
|
||||
"""
|
||||
from extensions.ext_elasticsearch import elasticsearch as es_extension
|
||||
from services.elasticsearch_migration_service import ElasticsearchMigrationService
|
||||
|
||||
if not es_extension.is_available():
|
||||
click.echo("Error: Elasticsearch is not available. Please check your configuration.", err=True)
|
||||
return
|
||||
|
||||
migration_service = ElasticsearchMigrationService()
|
||||
|
||||
click.echo(f"Validating migration for tenant: {tenant_id}")
|
||||
click.echo(f"Sample size: {sample_size}")
|
||||
click.echo()
|
||||
|
||||
try:
|
||||
results = migration_service.validate_migration(tenant_id, sample_size)
|
||||
|
||||
click.echo("Validation Results:")
|
||||
|
||||
for data_type, stats in results.items():
|
||||
if data_type == "errors":
|
||||
continue
|
||||
|
||||
click.echo(f"\n{data_type.replace('_', ' ').title()}:")
|
||||
click.echo(f" Total sampled: {stats['total']}")
|
||||
click.echo(f" Matched: {stats['matched']}")
|
||||
click.echo(f" Mismatched: {stats['mismatched']}")
|
||||
click.echo(f" Missing in ES: {stats['missing']}")
|
||||
|
||||
if stats['total'] > 0:
|
||||
accuracy = (stats['matched'] / stats['total']) * 100
|
||||
click.echo(f" Accuracy: {accuracy:.1f}%")
|
||||
|
||||
if results["errors"]:
|
||||
click.echo(f"\nValidation Errors ({len(results['errors'])}):")
|
||||
for error in results["errors"][:10]:
|
||||
click.echo(f" - {error}")
|
||||
if len(results["errors"]) > 10:
|
||||
click.echo(f" ... and {len(results['errors']) - 10} more errors")
|
||||
|
||||
except Exception as e:
|
||||
click.echo(f"Error: Validation failed: {str(e)}", err=True)
|
||||
logger.exception("Validation failed")
|
||||
|
||||
|
||||
@elasticsearch.command()
|
||||
def status():
|
||||
"""
|
||||
Check Elasticsearch connection and index status.
|
||||
"""
|
||||
from extensions.ext_elasticsearch import elasticsearch as es_extension
|
||||
|
||||
if not es_extension.is_available():
|
||||
click.echo("Error: Elasticsearch is not available. Please check your configuration.", err=True)
|
||||
return
|
||||
|
||||
try:
|
||||
es_client = es_extension.client
|
||||
|
||||
# Cluster health
|
||||
health = es_client.cluster.health()
|
||||
click.echo("Elasticsearch Cluster Status:")
|
||||
click.echo(f" Status: {health['status']}")
|
||||
click.echo(f" Nodes: {health['number_of_nodes']}")
|
||||
click.echo(f" Data nodes: {health['number_of_data_nodes']}")
|
||||
click.echo()
|
||||
|
||||
# Index information
|
||||
index_pattern = "dify-*"
|
||||
|
||||
try:
|
||||
indices = es_client.indices.get(index=index_pattern)
|
||||
|
||||
click.echo(f"Indices matching '{index_pattern}':")
|
||||
total_docs = 0
|
||||
total_size = 0
|
||||
|
||||
for index_name, index_info in indices.items():
|
||||
stats = es_client.indices.stats(index=index_name)
|
||||
docs = stats['indices'][index_name]['total']['docs']['count']
|
||||
size_bytes = stats['indices'][index_name]['total']['store']['size_in_bytes']
|
||||
size_mb = size_bytes / (1024 * 1024)
|
||||
|
||||
total_docs += docs
|
||||
total_size += size_mb
|
||||
|
||||
click.echo(f" {index_name}: {docs:,} docs, {size_mb:.1f} MB")
|
||||
|
||||
click.echo(f"\nTotal: {total_docs:,} documents, {total_size:.1f} MB")
|
||||
|
||||
except Exception as e:
|
||||
if "index_not_found_exception" in str(e):
|
||||
click.echo(f"No indices found matching pattern '{index_pattern}'")
|
||||
else:
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
click.echo(f"Error: Failed to get Elasticsearch status: {str(e)}", err=True)
|
||||
logger.exception("Status check failed")
|
||||
|
||||
@@ -113,6 +113,21 @@ class CodeExecutionSandboxConfig(BaseSettings):
|
||||
default=10.0,
|
||||
)
|
||||
|
||||
CODE_EXECUTION_POOL_MAX_CONNECTIONS: PositiveInt = Field(
|
||||
description="Maximum number of concurrent connections for the code execution HTTP client",
|
||||
default=100,
|
||||
)
|
||||
|
||||
CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS: PositiveInt = Field(
|
||||
description="Maximum number of persistent keep-alive connections for the code execution HTTP client",
|
||||
default=20,
|
||||
)
|
||||
|
||||
CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY: PositiveFloat | None = Field(
|
||||
description="Keep-alive expiry in seconds for idle connections (set to None to disable)",
|
||||
default=5.0,
|
||||
)
|
||||
|
||||
CODE_MAX_NUMBER: PositiveInt = Field(
|
||||
description="Maximum allowed numeric value in code execution",
|
||||
default=9223372036854775807,
|
||||
@@ -135,7 +150,7 @@ class CodeExecutionSandboxConfig(BaseSettings):
|
||||
|
||||
CODE_MAX_STRING_LENGTH: PositiveInt = Field(
|
||||
description="Maximum allowed length for strings in code execution",
|
||||
default=80000,
|
||||
default=400_000,
|
||||
)
|
||||
|
||||
CODE_MAX_STRING_ARRAY_LENGTH: PositiveInt = Field(
|
||||
@@ -153,6 +168,11 @@ class CodeExecutionSandboxConfig(BaseSettings):
|
||||
default=1000,
|
||||
)
|
||||
|
||||
CODE_EXECUTION_SSL_VERIFY: bool = Field(
|
||||
description="Enable or disable SSL verification for code execution requests",
|
||||
default=True,
|
||||
)
|
||||
|
||||
|
||||
class PluginConfig(BaseSettings):
|
||||
"""
|
||||
@@ -404,6 +424,21 @@ class HttpConfig(BaseSettings):
|
||||
default=5,
|
||||
)
|
||||
|
||||
SSRF_POOL_MAX_CONNECTIONS: PositiveInt = Field(
|
||||
description="Maximum number of concurrent connections for the SSRF HTTP client",
|
||||
default=100,
|
||||
)
|
||||
|
||||
SSRF_POOL_MAX_KEEPALIVE_CONNECTIONS: PositiveInt = Field(
|
||||
description="Maximum number of persistent keep-alive connections for the SSRF HTTP client",
|
||||
default=20,
|
||||
)
|
||||
|
||||
SSRF_POOL_KEEPALIVE_EXPIRY: PositiveFloat | None = Field(
|
||||
description="Keep-alive expiry in seconds for idle SSRF connections (set to None to disable)",
|
||||
default=5.0,
|
||||
)
|
||||
|
||||
RESPECT_XFORWARD_HEADERS_ENABLED: bool = Field(
|
||||
description="Enable handling of X-Forwarded-For, X-Forwarded-Proto, and X-Forwarded-Port headers"
|
||||
" when the app is behind a single trusted reverse proxy.",
|
||||
@@ -542,16 +577,16 @@ class WorkflowConfig(BaseSettings):
|
||||
default=5,
|
||||
)
|
||||
|
||||
WORKFLOW_PARALLEL_DEPTH_LIMIT: PositiveInt = Field(
|
||||
description="Maximum allowed depth for nested parallel executions",
|
||||
default=3,
|
||||
)
|
||||
|
||||
MAX_VARIABLE_SIZE: PositiveInt = Field(
|
||||
description="Maximum size in bytes for a single variable in workflows. Default to 200 KB.",
|
||||
default=200 * 1024,
|
||||
)
|
||||
|
||||
TEMPLATE_TRANSFORM_MAX_LENGTH: PositiveInt = Field(
|
||||
description="Maximum number of characters allowed in Template Transform node output",
|
||||
default=400_000,
|
||||
)
|
||||
|
||||
# GraphEngine Worker Pool Configuration
|
||||
GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field(
|
||||
description="Minimum number of workers per GraphEngine instance",
|
||||
@@ -624,6 +659,67 @@ class RepositoryConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class ElasticsearchConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for Elasticsearch integration
|
||||
"""
|
||||
|
||||
ELASTICSEARCH_ENABLED: bool = Field(
|
||||
description="Enable Elasticsearch for workflow logs storage",
|
||||
default=False,
|
||||
)
|
||||
|
||||
ELASTICSEARCH_HOSTS: list[str] = Field(
|
||||
description="List of Elasticsearch hosts",
|
||||
default=["http://localhost:9200"],
|
||||
)
|
||||
|
||||
ELASTICSEARCH_USERNAME: str | None = Field(
|
||||
description="Elasticsearch username for authentication",
|
||||
default=None,
|
||||
)
|
||||
|
||||
ELASTICSEARCH_PASSWORD: str | None = Field(
|
||||
description="Elasticsearch password for authentication",
|
||||
default=None,
|
||||
)
|
||||
|
||||
ELASTICSEARCH_USE_SSL: bool = Field(
|
||||
description="Use SSL/TLS for Elasticsearch connections",
|
||||
default=False,
|
||||
)
|
||||
|
||||
ELASTICSEARCH_VERIFY_CERTS: bool = Field(
|
||||
description="Verify SSL certificates for Elasticsearch connections",
|
||||
default=True,
|
||||
)
|
||||
|
||||
ELASTICSEARCH_CA_CERTS: str | None = Field(
|
||||
description="Path to CA certificates file for Elasticsearch SSL verification",
|
||||
default=None,
|
||||
)
|
||||
|
||||
ELASTICSEARCH_TIMEOUT: int = Field(
|
||||
description="Elasticsearch request timeout in seconds",
|
||||
default=30,
|
||||
)
|
||||
|
||||
ELASTICSEARCH_MAX_RETRIES: int = Field(
|
||||
description="Maximum number of retries for Elasticsearch requests",
|
||||
default=3,
|
||||
)
|
||||
|
||||
ELASTICSEARCH_INDEX_PREFIX: str = Field(
|
||||
description="Prefix for Elasticsearch indices",
|
||||
default="dify",
|
||||
)
|
||||
|
||||
ELASTICSEARCH_RETENTION_DAYS: int = Field(
|
||||
description="Number of days to retain data in Elasticsearch",
|
||||
default=30,
|
||||
)
|
||||
|
||||
|
||||
class AuthConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for authentication and OAuth
|
||||
@@ -1073,6 +1169,7 @@ class FeatureConfig(
|
||||
AuthConfig, # Changed from OAuthConfig to AuthConfig
|
||||
BillingConfig,
|
||||
CodeExecutionSandboxConfig,
|
||||
ElasticsearchConfig,
|
||||
PluginConfig,
|
||||
MarketplaceConfig,
|
||||
DataSetConfig,
|
||||
|
||||
@@ -41,3 +41,13 @@ class BaiduVectorDBConfig(BaseSettings):
|
||||
description="Number of replicas for the Baidu Vector Database (default is 3)",
|
||||
default=3,
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER: str = Field(
|
||||
description="Analyzer type for inverted index in Baidu Vector Database (default is DEFAULT_ANALYZER)",
|
||||
default="DEFAULT_ANALYZER",
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE: str = Field(
|
||||
description="Parser mode for inverted index in Baidu Vector Database (default is COARSE_MODE)",
|
||||
default="COARSE_MODE",
|
||||
)
|
||||
|
||||
@@ -37,3 +37,15 @@ class OceanBaseVectorConfig(BaseSettings):
|
||||
"with older versions",
|
||||
default=False,
|
||||
)
|
||||
|
||||
OCEANBASE_FULLTEXT_PARSER: str | None = Field(
|
||||
description=(
|
||||
"Fulltext parser to use for text indexing. "
|
||||
"Built-in options: 'ngram' (N-gram tokenizer for English/numbers), "
|
||||
"'beng' (Basic English tokenizer), 'space' (Space-based tokenizer), "
|
||||
"'ngram2' (Improved N-gram tokenizer), 'ik' (Chinese tokenizer). "
|
||||
"External plugins (require installation): 'japanese_ftparser' (Japanese tokenizer), "
|
||||
"'thai_ftparser' (Thai tokenizer). Default is 'ik'"
|
||||
),
|
||||
default="ik",
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ import logging
|
||||
import os
|
||||
import time
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -30,10 +30,10 @@ class NacosHttpClient:
|
||||
params = {}
|
||||
try:
|
||||
self._inject_auth_info(headers, params)
|
||||
response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params)
|
||||
response = httpx.request(method, url="http://" + self.server + url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
except requests.RequestException as e:
|
||||
except httpx.RequestError as e:
|
||||
return f"Request to Nacos failed: {e}"
|
||||
|
||||
def _inject_auth_info(self, headers: dict[str, str], params: dict[str, str], module: str = "config") -> None:
|
||||
@@ -78,7 +78,7 @@ class NacosHttpClient:
|
||||
params = {"username": self.username, "password": self.password}
|
||||
url = "http://" + self.server + "/nacos/v1/auth/login"
|
||||
try:
|
||||
resp = requests.request("POST", url, headers=None, params=params)
|
||||
resp = httpx.request("POST", url, headers=None, params=params)
|
||||
resp.raise_for_status()
|
||||
response_data = resp.json()
|
||||
self.token = response_data.get("accessToken")
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from configs import dify_config
|
||||
from libs.collection_utils import convert_to_lower_and_upper_set
|
||||
|
||||
HIDDEN_VALUE = "[__HIDDEN__]"
|
||||
UNKNOWN_VALUE = "[__UNKNOWN__]"
|
||||
@@ -6,24 +7,39 @@ UUID_NIL = "00000000-0000-0000-0000-000000000000"
|
||||
|
||||
DEFAULT_FILE_NUMBER_LIMITS = 3
|
||||
|
||||
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
|
||||
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
|
||||
IMAGE_EXTENSIONS = convert_to_lower_and_upper_set({"jpg", "jpeg", "png", "webp", "gif", "svg"})
|
||||
|
||||
VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "webm"]
|
||||
VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS])
|
||||
VIDEO_EXTENSIONS = convert_to_lower_and_upper_set({"mp4", "mov", "mpeg", "webm"})
|
||||
|
||||
AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "amr", "mpga"]
|
||||
AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
|
||||
AUDIO_EXTENSIONS = convert_to_lower_and_upper_set({"mp3", "m4a", "wav", "amr", "mpga"})
|
||||
|
||||
|
||||
_doc_extensions: list[str]
|
||||
_doc_extensions: set[str]
|
||||
if dify_config.ETL_TYPE == "Unstructured":
|
||||
_doc_extensions = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"]
|
||||
_doc_extensions.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
|
||||
_doc_extensions = {
|
||||
"txt",
|
||||
"markdown",
|
||||
"md",
|
||||
"mdx",
|
||||
"pdf",
|
||||
"html",
|
||||
"htm",
|
||||
"xlsx",
|
||||
"xls",
|
||||
"vtt",
|
||||
"properties",
|
||||
"doc",
|
||||
"docx",
|
||||
"csv",
|
||||
"eml",
|
||||
"msg",
|
||||
"pptx",
|
||||
"xml",
|
||||
"epub",
|
||||
}
|
||||
if dify_config.UNSTRUCTURED_API_URL:
|
||||
_doc_extensions.append("ppt")
|
||||
_doc_extensions.add("ppt")
|
||||
else:
|
||||
_doc_extensions = [
|
||||
_doc_extensions = {
|
||||
"txt",
|
||||
"markdown",
|
||||
"md",
|
||||
@@ -37,5 +53,5 @@ else:
|
||||
"csv",
|
||||
"vtt",
|
||||
"properties",
|
||||
]
|
||||
DOCUMENT_EXTENSIONS = _doc_extensions + [ext.upper() for ext in _doc_extensions]
|
||||
}
|
||||
DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions)
|
||||
|
||||
@@ -1,31 +1,10 @@
|
||||
from importlib import import_module
|
||||
|
||||
from flask import Blueprint
|
||||
from flask_restx import Namespace
|
||||
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
from .app.app_import import AppImportApi, AppImportCheckDependenciesApi, AppImportConfirmApi
|
||||
from .explore.audio import ChatAudioApi, ChatTextApi
|
||||
from .explore.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi
|
||||
from .explore.conversation import (
|
||||
ConversationApi,
|
||||
ConversationListApi,
|
||||
ConversationPinApi,
|
||||
ConversationRenameApi,
|
||||
ConversationUnPinApi,
|
||||
)
|
||||
from .explore.message import (
|
||||
MessageFeedbackApi,
|
||||
MessageListApi,
|
||||
MessageMoreLikeThisApi,
|
||||
MessageSuggestedQuestionApi,
|
||||
)
|
||||
from .explore.workflow import (
|
||||
InstalledAppWorkflowRunApi,
|
||||
InstalledAppWorkflowTaskStopApi,
|
||||
)
|
||||
from .files import FileApi, FilePreviewApi, FileSupportTypeApi
|
||||
from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
|
||||
|
||||
bp = Blueprint("console", __name__, url_prefix="/console/api")
|
||||
|
||||
api = ExternalApi(
|
||||
@@ -35,23 +14,23 @@ api = ExternalApi(
|
||||
description="Console management APIs for app configuration, monitoring, and administration",
|
||||
)
|
||||
|
||||
# Create namespace
|
||||
console_ns = Namespace("console", description="Console management API operations", path="/")
|
||||
|
||||
# File
|
||||
api.add_resource(FileApi, "/files/upload")
|
||||
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")
|
||||
api.add_resource(FileSupportTypeApi, "/files/support-type")
|
||||
RESOURCE_MODULES = (
|
||||
"controllers.console.app.app_import",
|
||||
"controllers.console.explore.audio",
|
||||
"controllers.console.explore.completion",
|
||||
"controllers.console.explore.conversation",
|
||||
"controllers.console.explore.message",
|
||||
"controllers.console.explore.workflow",
|
||||
"controllers.console.files",
|
||||
"controllers.console.remote_files",
|
||||
)
|
||||
|
||||
# Remote files
|
||||
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
|
||||
api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
|
||||
|
||||
# Import App
|
||||
api.add_resource(AppImportApi, "/apps/imports")
|
||||
api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm")
|
||||
api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies")
|
||||
for module_name in RESOURCE_MODULES:
|
||||
import_module(module_name)
|
||||
|
||||
# Ensure resource modules are imported so route decorators are evaluated.
|
||||
# Import other controllers
|
||||
from . import (
|
||||
admin,
|
||||
@@ -150,77 +129,6 @@ from .workspace import (
|
||||
workspace,
|
||||
)
|
||||
|
||||
# Explore Audio
|
||||
api.add_resource(ChatAudioApi, "/installed-apps/<uuid:installed_app_id>/audio-to-text", endpoint="installed_app_audio")
|
||||
api.add_resource(ChatTextApi, "/installed-apps/<uuid:installed_app_id>/text-to-audio", endpoint="installed_app_text")
|
||||
|
||||
# Explore Completion
|
||||
api.add_resource(
|
||||
CompletionApi, "/installed-apps/<uuid:installed_app_id>/completion-messages", endpoint="installed_app_completion"
|
||||
)
|
||||
api.add_resource(
|
||||
CompletionStopApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop",
|
||||
endpoint="installed_app_stop_completion",
|
||||
)
|
||||
api.add_resource(
|
||||
ChatApi, "/installed-apps/<uuid:installed_app_id>/chat-messages", endpoint="installed_app_chat_completion"
|
||||
)
|
||||
api.add_resource(
|
||||
ChatStopApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop",
|
||||
endpoint="installed_app_stop_chat_completion",
|
||||
)
|
||||
|
||||
# Explore Conversation
|
||||
api.add_resource(
|
||||
ConversationRenameApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name",
|
||||
endpoint="installed_app_conversation_rename",
|
||||
)
|
||||
api.add_resource(
|
||||
ConversationListApi, "/installed-apps/<uuid:installed_app_id>/conversations", endpoint="installed_app_conversations"
|
||||
)
|
||||
api.add_resource(
|
||||
ConversationApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>",
|
||||
endpoint="installed_app_conversation",
|
||||
)
|
||||
api.add_resource(
|
||||
ConversationPinApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin",
|
||||
endpoint="installed_app_conversation_pin",
|
||||
)
|
||||
api.add_resource(
|
||||
ConversationUnPinApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin",
|
||||
endpoint="installed_app_conversation_unpin",
|
||||
)
|
||||
|
||||
|
||||
# Explore Message
|
||||
api.add_resource(MessageListApi, "/installed-apps/<uuid:installed_app_id>/messages", endpoint="installed_app_messages")
|
||||
api.add_resource(
|
||||
MessageFeedbackApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks",
|
||||
endpoint="installed_app_message_feedback",
|
||||
)
|
||||
api.add_resource(
|
||||
MessageMoreLikeThisApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this",
|
||||
endpoint="installed_app_more_like_this",
|
||||
)
|
||||
api.add_resource(
|
||||
MessageSuggestedQuestionApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions",
|
||||
endpoint="installed_app_suggested_question",
|
||||
)
|
||||
# Explore Workflow
|
||||
api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps/<uuid:installed_app_id>/workflows/run")
|
||||
api.add_resource(
|
||||
InstalledAppWorkflowTaskStopApi, "/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop"
|
||||
)
|
||||
|
||||
api.add_namespace(console_ns)
|
||||
|
||||
__all__ = [
|
||||
|
||||
@@ -19,6 +19,7 @@ from core.ops.ops_trace_manager import OpsTraceManager
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
|
||||
from libs.login import login_required
|
||||
from libs.validators import validate_description_length
|
||||
from models import Account, App
|
||||
from services.app_dsl_service import AppDslService, ImportMode
|
||||
from services.app_service import AppService
|
||||
@@ -28,12 +29,6 @@ from services.feature_service import FeatureService
|
||||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if description and len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
@console_ns.route("/apps")
|
||||
class AppListApi(Resource):
|
||||
@api.doc("list_apps")
|
||||
@@ -138,7 +133,7 @@ class AppListApi(Resource):
|
||||
"""Create app"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, location="json")
|
||||
parser.add_argument("description", type=_validate_description_length, location="json")
|
||||
parser.add_argument("description", type=validate_description_length, location="json")
|
||||
parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
@@ -219,7 +214,7 @@ class AppApi(Resource):
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("description", type=_validate_description_length, location="json")
|
||||
parser.add_argument("description", type=validate_description_length, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
@@ -297,7 +292,7 @@ class AppCopyApi(Resource):
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, location="json")
|
||||
parser.add_argument("description", type=_validate_description_length, location="json")
|
||||
parser.add_argument("description", type=validate_description_length, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
|
||||
@@ -20,7 +20,10 @@ from services.app_dsl_service import AppDslService, ImportStatus
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from .. import console_ns
|
||||
|
||||
|
||||
@console_ns.route("/apps/imports")
|
||||
class AppImportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -74,6 +77,7 @@ class AppImportApi(Resource):
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/apps/imports/<string:import_id>/confirm")
|
||||
class AppImportConfirmApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -98,6 +102,7 @@ class AppImportConfirmApi(Resource):
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/apps/imports/<string:app_id>/check-dependencies")
|
||||
class AppImportCheckDependenciesApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from datetime import datetime
|
||||
|
||||
import pytz # pip install pytz
|
||||
import sqlalchemy as sa
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
@@ -70,7 +71,7 @@ class CompletionConversationApi(Resource):
|
||||
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
query = db.select(Conversation).where(
|
||||
query = sa.select(Conversation).where(
|
||||
Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
|
||||
)
|
||||
|
||||
@@ -236,7 +237,7 @@ class ChatConversationApi(Resource):
|
||||
.subquery()
|
||||
)
|
||||
|
||||
query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
|
||||
query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
|
||||
|
||||
if args["keyword"]:
|
||||
keyword_filter = f"%{args['keyword']}%"
|
||||
|
||||
@@ -62,6 +62,9 @@ class ChatMessageListApi(Resource):
|
||||
@account_initialization_required
|
||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model):
|
||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
|
||||
parser.add_argument("first_id", type=uuid_value, location="args")
|
||||
|
||||
@@ -50,8 +50,9 @@ class DailyMessageStatistic(Resource):
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
@@ -187,8 +188,9 @@ class DailyTerminalsStatistic(Resource):
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
@@ -259,8 +261,9 @@ class DailyTokenCostStatistic(Resource):
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
@@ -340,8 +343,9 @@ FROM
|
||||
messages m
|
||||
ON c.id = m.conversation_id
|
||||
WHERE
|
||||
c.app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
c.app_id = :app_id
|
||||
AND m.invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
@@ -426,8 +430,9 @@ LEFT JOIN
|
||||
message_feedbacks mf
|
||||
ON mf.message_id=m.id AND mf.rating='like'
|
||||
WHERE
|
||||
m.app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
m.app_id = :app_id
|
||||
AND m.invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
@@ -502,8 +507,9 @@ class AverageResponseTimeStatistic(Resource):
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
@@ -576,8 +582,9 @@ class TokensPerSecondStatistic(Resource):
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
@@ -9,7 +9,6 @@ from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
@@ -797,24 +796,6 @@ class ConvertToWorkflowApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/config")
|
||||
class WorkflowConfigApi(Resource):
|
||||
"""Resource for workflow configuration."""
|
||||
|
||||
@api.doc("get_workflow_config")
|
||||
@api.doc(description="Get workflow configuration")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Workflow configuration retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App):
|
||||
return {
|
||||
"parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows")
|
||||
class PublishedAllWorkflowApi(Resource):
|
||||
@api.doc("get_all_published_workflows")
|
||||
|
||||
@@ -2,7 +2,7 @@ from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import ApiKeyAuthFailedError
|
||||
from libs.login import login_required
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
@@ -10,6 +10,7 @@ from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
from ..wraps import account_initialization_required, setup_required
|
||||
|
||||
|
||||
@console_ns.route("/api-key-auth/data-source")
|
||||
class ApiKeyAuthDataSource(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -33,6 +34,7 @@ class ApiKeyAuthDataSource(Resource):
|
||||
return {"sources": []}
|
||||
|
||||
|
||||
@console_ns.route("/api-key-auth/data-source/binding")
|
||||
class ApiKeyAuthDataSourceBinding(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -54,6 +56,7 @@ class ApiKeyAuthDataSourceBinding(Resource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/api-key-auth/data-source/<uuid:binding_id>")
|
||||
class ApiKeyAuthDataSourceBindingDelete(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -66,8 +69,3 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
|
||||
ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
api.add_resource(ApiKeyAuthDataSource, "/api-key-auth/data-source")
|
||||
api.add_resource(ApiKeyAuthDataSourceBinding, "/api-key-auth/data-source/binding")
|
||||
api.add_resource(ApiKeyAuthDataSourceBindingDelete, "/api-key-auth/data-source/<uuid:binding_id>")
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
from flask import current_app, redirect, request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields
|
||||
@@ -119,7 +119,7 @@ class OAuthDataSourceBinding(Resource):
|
||||
return {"error": "Invalid code"}, 400
|
||||
try:
|
||||
oauth_provider.get_access_token(code)
|
||||
except requests.HTTPError as e:
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.exception(
|
||||
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
|
||||
)
|
||||
@@ -152,7 +152,7 @@ class OAuthDataSourceSync(Resource):
|
||||
return {"error": "Invalid provider"}, 400
|
||||
try:
|
||||
oauth_provider.sync_data_source(binding_id)
|
||||
except requests.HTTPError as e:
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.exception(
|
||||
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
EmailAlreadyInUseError,
|
||||
EmailCodeError,
|
||||
@@ -25,6 +25,7 @@ from services.billing_service import BillingService
|
||||
from services.errors.account import AccountNotFoundError, AccountRegisterError
|
||||
|
||||
|
||||
@console_ns.route("/email-register/send-email")
|
||||
class EmailRegisterSendEmailApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@@ -52,6 +53,7 @@ class EmailRegisterSendEmailApi(Resource):
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
@console_ns.route("/email-register/validity")
|
||||
class EmailRegisterCheckApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@@ -92,6 +94,7 @@ class EmailRegisterCheckApi(Resource):
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
|
||||
|
||||
@console_ns.route("/email-register")
|
||||
class EmailRegisterResetApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@@ -148,8 +151,3 @@ class EmailRegisterResetApi(Resource):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
return account
|
||||
|
||||
|
||||
api.add_resource(EmailRegisterSendEmailApi, "/email-register/send-email")
|
||||
api.add_resource(EmailRegisterCheckApi, "/email-register/validity")
|
||||
api.add_resource(EmailRegisterResetApi, "/email-register")
|
||||
|
||||
@@ -221,8 +221,3 @@ class ForgotPasswordResetApi(Resource):
|
||||
TenantService.create_tenant_member(tenant, account, role="owner")
|
||||
account.current_tenant = tenant
|
||||
tenant_was_created.send(tenant)
|
||||
|
||||
|
||||
api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")
|
||||
api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity")
|
||||
api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets")
|
||||
|
||||
@@ -7,7 +7,7 @@ from flask_restx import Resource, reqparse
|
||||
import services
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
AuthenticationFailedError,
|
||||
EmailCodeError,
|
||||
@@ -34,6 +34,7 @@ from services.errors.workspace import WorkSpaceNotAllowedCreateError, Workspaces
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
@console_ns.route("/login")
|
||||
class LoginApi(Resource):
|
||||
"""Resource for user login."""
|
||||
|
||||
@@ -91,6 +92,7 @@ class LoginApi(Resource):
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
|
||||
|
||||
@console_ns.route("/logout")
|
||||
class LogoutApi(Resource):
|
||||
@setup_required
|
||||
def get(self):
|
||||
@@ -102,6 +104,7 @@ class LogoutApi(Resource):
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/reset-password")
|
||||
class ResetPasswordSendEmailApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@@ -130,6 +133,7 @@ class ResetPasswordSendEmailApi(Resource):
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
@console_ns.route("/email-code-login")
|
||||
class EmailCodeLoginSendEmailApi(Resource):
|
||||
@setup_required
|
||||
def post(self):
|
||||
@@ -162,6 +166,7 @@ class EmailCodeLoginSendEmailApi(Resource):
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
@console_ns.route("/email-code-login/validity")
|
||||
class EmailCodeLoginApi(Resource):
|
||||
@setup_required
|
||||
def post(self):
|
||||
@@ -218,6 +223,7 @@ class EmailCodeLoginApi(Resource):
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
|
||||
|
||||
@console_ns.route("/refresh-token")
|
||||
class RefreshTokenApi(Resource):
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
@@ -229,11 +235,3 @@ class RefreshTokenApi(Resource):
|
||||
return {"result": "success", "data": new_token_pair.model_dump()}
|
||||
except Exception as e:
|
||||
return {"result": "fail", "data": str(e)}, 401
|
||||
|
||||
|
||||
api.add_resource(LoginApi, "/login")
|
||||
api.add_resource(LogoutApi, "/logout")
|
||||
api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
|
||||
api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")
|
||||
api.add_resource(ResetPasswordSendEmailApi, "/reset-password")
|
||||
api.add_resource(RefreshTokenApi, "/refresh-token")
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
from flask import current_app, redirect, request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy import select
|
||||
@@ -101,8 +101,10 @@ class OAuthCallback(Resource):
|
||||
try:
|
||||
token = oauth_provider.get_access_token(code)
|
||||
user_info = oauth_provider.get_user_info(token)
|
||||
except requests.RequestException as e:
|
||||
error_text = e.response.text if e.response else str(e)
|
||||
except httpx.RequestError as e:
|
||||
error_text = str(e)
|
||||
if isinstance(e, httpx.HTTPStatusError):
|
||||
error_text = e.response.text
|
||||
logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
|
||||
return {"error": "OAuth process failed"}, 400
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ from models.account import Account
|
||||
from models.model import OAuthProviderApp
|
||||
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
|
||||
|
||||
from .. import api
|
||||
from .. import console_ns
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
@@ -86,6 +86,7 @@ def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProvid
|
||||
return decorated
|
||||
|
||||
|
||||
@console_ns.route("/oauth/provider")
|
||||
class OAuthServerAppApi(Resource):
|
||||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
@@ -108,6 +109,7 @@ class OAuthServerAppApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/oauth/provider/authorize")
|
||||
class OAuthServerUserAuthorizeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -125,6 +127,7 @@ class OAuthServerUserAuthorizeApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/oauth/provider/token")
|
||||
class OAuthServerUserTokenApi(Resource):
|
||||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
@@ -180,6 +183,7 @@ class OAuthServerUserTokenApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/oauth/provider/account")
|
||||
class OAuthServerUserAccountApi(Resource):
|
||||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
@@ -194,9 +198,3 @@ class OAuthServerUserAccountApi(Resource):
|
||||
"timezone": account.timezone,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
api.add_resource(OAuthServerAppApi, "/oauth/provider")
|
||||
api.add_resource(OAuthServerUserAuthorizeApi, "/oauth/provider/authorize")
|
||||
api.add_resource(OAuthServerUserTokenApi, "/oauth/provider/token")
|
||||
api.add_resource(OAuthServerUserAccountApi, "/oauth/provider/account")
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from flask_restx import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.model import Account
|
||||
from services.billing_service import BillingService
|
||||
|
||||
|
||||
@console_ns.route("/billing/subscription")
|
||||
class Subscription(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -26,6 +27,7 @@ class Subscription(Resource):
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/billing/invoices")
|
||||
class Invoices(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -36,7 +38,3 @@ class Invoices(Resource):
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
assert current_user.current_tenant_id is not None
|
||||
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)
|
||||
|
||||
|
||||
api.add_resource(Subscription, "/billing/subscription")
|
||||
api.add_resource(Invoices, "/billing/invoices")
|
||||
|
||||
@@ -6,10 +6,11 @@ from libs.helper import extract_remote_ip
|
||||
from libs.login import login_required
|
||||
from services.billing_service import BillingService
|
||||
|
||||
from .. import api
|
||||
from .. import console_ns
|
||||
from ..wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
|
||||
|
||||
@console_ns.route("/compliance/download")
|
||||
class ComplianceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -30,6 +31,3 @@ class ComplianceApi(Resource):
|
||||
ip=ip_address,
|
||||
device_info=device_info,
|
||||
)
|
||||
|
||||
|
||||
api.add_resource(ComplianceApi, "/compliance/download")
|
||||
|
||||
@@ -9,7 +9,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
@@ -27,6 +27,10 @@ from services.datasource_provider_service import DatasourceProviderService
|
||||
from tasks.document_indexing_sync_task import document_indexing_sync_task
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/data-source/integrates",
|
||||
"/data-source/integrates/<uuid:binding_id>/<string:action>",
|
||||
)
|
||||
class DataSourceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -109,6 +113,7 @@ class DataSourceApi(Resource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/notion/pre-import/pages")
|
||||
class DataSourceNotionListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -196,6 +201,10 @@ class DataSourceNotionListApi(Resource):
|
||||
return {"notion_info": {**workspace_info, "pages": pages}}, 200
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview",
|
||||
"/datasets/notion-indexing-estimate",
|
||||
)
|
||||
class DataSourceNotionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -269,6 +278,7 @@ class DataSourceNotionApi(Resource):
|
||||
return response.model_dump(), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/notion/sync")
|
||||
class DataSourceNotionDatasetSyncApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -285,6 +295,7 @@ class DataSourceNotionDatasetSyncApi(Resource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync")
|
||||
class DataSourceNotionDocumentSyncApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -301,16 +312,3 @@ class DataSourceNotionDocumentSyncApi(Resource):
|
||||
raise NotFound("Document not found.")
|
||||
document_indexing_sync_task.delay(dataset_id_str, document_id_str)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
api.add_resource(DataSourceApi, "/data-source/integrates", "/data-source/integrates/<uuid:binding_id>/<string:action>")
|
||||
api.add_resource(DataSourceNotionListApi, "/notion/pre-import/pages")
|
||||
api.add_resource(
|
||||
DataSourceNotionApi,
|
||||
"/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview",
|
||||
"/datasets/notion-indexing-estimate",
|
||||
)
|
||||
api.add_resource(DataSourceNotionDatasetSyncApi, "/datasets/<uuid:dataset_id>/notion/sync")
|
||||
api.add_resource(
|
||||
DataSourceNotionDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync"
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import flask_restx
|
||||
from typing import Any, cast
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||
@@ -30,24 +31,20 @@ from fields.app_fields import related_app_list
|
||||
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
|
||||
from fields.document_fields import document_status_fields
|
||||
from libs.login import login_required
|
||||
from libs.validators import validate_description_length
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
from models.account import Account
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
def _validate_name(name: str) -> str:
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if description and len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
@console_ns.route("/datasets")
|
||||
class DatasetListApi(Resource):
|
||||
@api.doc("get_datasets")
|
||||
@@ -92,7 +89,7 @@ class DatasetListApi(Resource):
|
||||
for embedding_model in embedding_models:
|
||||
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
|
||||
|
||||
data = marshal(datasets, dataset_detail_fields)
|
||||
data = cast(list[dict[str, Any]], marshal(datasets, dataset_detail_fields))
|
||||
for item in data:
|
||||
# convert embedding_model_provider to plugin standard format
|
||||
if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
|
||||
@@ -147,7 +144,7 @@ class DatasetListApi(Resource):
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=_validate_description_length,
|
||||
type=validate_description_length,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
@@ -192,7 +189,7 @@ class DatasetListApi(Resource):
|
||||
name=args["name"],
|
||||
description=args["description"],
|
||||
indexing_technique=args["indexing_technique"],
|
||||
account=current_user,
|
||||
account=cast(Account, current_user),
|
||||
permission=DatasetPermissionEnum.ONLY_ME,
|
||||
provider=args["provider"],
|
||||
external_knowledge_api_id=args["external_knowledge_api_id"],
|
||||
@@ -224,7 +221,7 @@ class DatasetApi(Resource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
data = marshal(dataset, dataset_detail_fields)
|
||||
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.embedding_model_provider:
|
||||
provider_id = ModelProviderID(dataset.embedding_model_provider)
|
||||
@@ -288,7 +285,7 @@ class DatasetApi(Resource):
|
||||
help="type is required. Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length)
|
||||
parser.add_argument("description", location="json", store_missing=False, type=validate_description_length)
|
||||
parser.add_argument(
|
||||
"indexing_technique",
|
||||
type=str,
|
||||
@@ -369,7 +366,7 @@ class DatasetApi(Resource):
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
result_data = marshal(dataset, dataset_detail_fields)
|
||||
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
if data.get("partial_member_list") and data.get("permission") == "partial_members":
|
||||
@@ -688,7 +685,7 @@ class DatasetApiKeyApi(Resource):
|
||||
)
|
||||
|
||||
if current_key_count >= self.max_keys:
|
||||
flask_restx.abort(
|
||||
api.abort(
|
||||
400,
|
||||
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
|
||||
code="max_keys_exceeded",
|
||||
@@ -733,7 +730,7 @@ class DatasetApiDeleteApi(Resource):
|
||||
)
|
||||
|
||||
if key is None:
|
||||
flask_restx.abort(404, message="API key not found")
|
||||
api.abort(404, message="API key not found")
|
||||
|
||||
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
|
||||
db.session.commit()
|
||||
@@ -782,7 +779,6 @@ class DatasetRetrievalSettingApi(Resource):
|
||||
| VectorType.TIDB_VECTOR
|
||||
| VectorType.CHROMA
|
||||
| VectorType.PGVECTO_RS
|
||||
| VectorType.BAIDU
|
||||
| VectorType.VIKINGDB
|
||||
| VectorType.UPSTASH
|
||||
):
|
||||
@@ -809,6 +805,7 @@ class DatasetRetrievalSettingApi(Resource):
|
||||
| VectorType.TENCENT
|
||||
| VectorType.MATRIXONE
|
||||
| VectorType.CLICKZETTA
|
||||
| VectorType.BAIDU
|
||||
):
|
||||
return {
|
||||
"retrieval_method": [
|
||||
@@ -838,7 +835,6 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||
| VectorType.TIDB_VECTOR
|
||||
| VectorType.CHROMA
|
||||
| VectorType.PGVECTO_RS
|
||||
| VectorType.BAIDU
|
||||
| VectorType.VIKINGDB
|
||||
| VectorType.UPSTASH
|
||||
):
|
||||
@@ -863,6 +859,7 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||
| VectorType.HUAWEI_CLOUD
|
||||
| VectorType.MATRIXONE
|
||||
| VectorType.CLICKZETTA
|
||||
| VectorType.BAIDU
|
||||
):
|
||||
return {
|
||||
"retrieval_method": [
|
||||
|
||||
@@ -4,6 +4,7 @@ from argparse import ArgumentTypeError
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||
@@ -54,6 +55,7 @@ from fields.document_fields import (
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import login_required
|
||||
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from models.account import Account
|
||||
from models.dataset import DocumentPipelineExecutionLog
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
|
||||
@@ -211,13 +213,13 @@ class DatasetDocumentListApi(Resource):
|
||||
|
||||
if sort == "hit_count":
|
||||
sub_query = (
|
||||
db.select(DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
|
||||
sa.select(DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
|
||||
.group_by(DocumentSegment.document_id)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by(
|
||||
sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)),
|
||||
sort_logic(sa.func.coalesce(sub_query.c.total_hit_count, 0)),
|
||||
sort_logic(Document.position),
|
||||
)
|
||||
elif sort == "created_at":
|
||||
@@ -417,7 +419,9 @@ class DatasetInitApi(Resource):
|
||||
|
||||
try:
|
||||
dataset, documents, batch = DocumentService.save_document_without_dataset_id(
|
||||
tenant_id=current_user.current_tenant_id, knowledge_config=knowledge_config, account=current_user
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
knowledge_config=knowledge_config,
|
||||
account=cast(Account, current_user),
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@@ -451,7 +455,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||
raise DocumentAlreadyFinishedError()
|
||||
|
||||
data_process_rule = document.dataset_process_rule
|
||||
data_process_rule_dict = data_process_rule.to_dict()
|
||||
data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {}
|
||||
|
||||
response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}
|
||||
|
||||
@@ -513,7 +517,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
if not documents:
|
||||
return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200
|
||||
data_process_rule = documents[0].dataset_process_rule
|
||||
data_process_rule_dict = data_process_rule.to_dict()
|
||||
data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {}
|
||||
extract_settings = []
|
||||
for document in documents:
|
||||
if document.indexing_status in {"completed", "error"}:
|
||||
@@ -752,7 +756,7 @@ class DocumentApi(DocumentResource):
|
||||
}
|
||||
else:
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
document_process_rules = document.dataset_process_rule.to_dict()
|
||||
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
|
||||
data_source_info = document.data_source_detail_dict
|
||||
response = {
|
||||
"id": document.id,
|
||||
@@ -1072,7 +1076,9 @@ class DocumentRenameApi(DocumentResource):
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
DatasetService.check_dataset_operator_permission(current_user, dataset)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_operator_permission(cast(Account, current_user), dataset)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
@@ -1113,6 +1119,7 @@ class WebsiteDocumentSyncApi(DocumentResource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/pipeline-execution-log")
|
||||
class DocumentPipelineExecutionLogApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -1146,29 +1153,3 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
|
||||
"input_data": log.input_data,
|
||||
"datasource_node_id": log.datasource_node_id,
|
||||
}, 200
|
||||
|
||||
|
||||
api.add_resource(GetProcessRuleApi, "/datasets/process-rule")
|
||||
api.add_resource(DatasetDocumentListApi, "/datasets/<uuid:dataset_id>/documents")
|
||||
api.add_resource(DatasetInitApi, "/datasets/init")
|
||||
api.add_resource(
|
||||
DocumentIndexingEstimateApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate"
|
||||
)
|
||||
api.add_resource(DocumentBatchIndexingEstimateApi, "/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-estimate")
|
||||
api.add_resource(DocumentBatchIndexingStatusApi, "/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-status")
|
||||
api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status")
|
||||
api.add_resource(DocumentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
|
||||
api.add_resource(
|
||||
DocumentProcessingApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/<string:action>"
|
||||
)
|
||||
api.add_resource(DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata")
|
||||
api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/status/<string:action>/batch")
|
||||
api.add_resource(DocumentPauseApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause")
|
||||
api.add_resource(DocumentRecoverApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume")
|
||||
api.add_resource(DocumentRetryApi, "/datasets/<uuid:dataset_id>/retry")
|
||||
api.add_resource(DocumentRenameApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename")
|
||||
|
||||
api.add_resource(WebsiteDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")
|
||||
api.add_resource(
|
||||
DocumentPipelineExecutionLogApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/pipeline-execution-log"
|
||||
)
|
||||
|
||||
@@ -7,7 +7,7 @@ from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
from controllers.console.datasets.error import (
|
||||
ChildChunkDeleteIndexError,
|
||||
@@ -37,6 +37,7 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
|
||||
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
|
||||
class DatasetDocumentSegmentListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -139,6 +140,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>")
|
||||
class DatasetDocumentSegmentApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -193,6 +195,7 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
|
||||
class DatasetDocumentSegmentAddApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -244,6 +247,7 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
|
||||
class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -345,6 +349,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import",
|
||||
"/datasets/batch_import_status/<uuid:job_id>",
|
||||
)
|
||||
class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -384,7 +392,12 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
# send batch add segments task
|
||||
redis_client.setnx(indexing_cache_key, "waiting")
|
||||
batch_create_segment_to_index_task.delay(
|
||||
str(job_id), upload_file_id, dataset_id, document_id, current_user.current_tenant_id, current_user.id
|
||||
str(job_id),
|
||||
upload_file_id,
|
||||
dataset_id,
|
||||
document_id,
|
||||
current_user.current_tenant_id,
|
||||
current_user.id,
|
||||
)
|
||||
except Exception as e:
|
||||
return {"error": str(e)}, 500
|
||||
@@ -393,7 +406,9 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, job_id):
|
||||
def get(self, job_id=None, dataset_id=None, document_id=None):
|
||||
if job_id is None:
|
||||
raise NotFound("The job does not exist.")
|
||||
job_id = str(job_id)
|
||||
indexing_cache_key = f"segment_batch_import_{job_id}"
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
@@ -403,6 +418,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
return {"job_id": job_id, "job_status": cache_result.decode()}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks")
|
||||
class ChildChunkAddApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -457,7 +473,8 @@ class ChildChunkAddApi(Resource):
|
||||
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset)
|
||||
content = args["content"]
|
||||
child_chunk = SegmentService.create_child_chunk(content, segment, document, dataset)
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
|
||||
@@ -546,13 +563,17 @@ class ChildChunkAddApi(Resource):
|
||||
parser.add_argument("chunks", type=list, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
chunks = [ChildChunkUpdateArgs(**chunk) for chunk in args.get("chunks")]
|
||||
chunks_data = args["chunks"]
|
||||
chunks = [ChildChunkUpdateArgs(**chunk) for chunk in chunks_data]
|
||||
child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
return {"data": marshal(child_chunks, child_chunk_fields)}, 200
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks/<uuid:child_chunk_id>"
|
||||
)
|
||||
class ChildChunkUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -660,33 +681,8 @@ class ChildChunkUpdateApi(Resource):
|
||||
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
child_chunk = SegmentService.update_child_chunk(
|
||||
args.get("content"), child_chunk, segment, document, dataset
|
||||
)
|
||||
content = args["content"]
|
||||
child_chunk = SegmentService.update_child_chunk(content, child_chunk, segment, document, dataset)
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
|
||||
|
||||
|
||||
api.add_resource(DatasetDocumentSegmentListApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
|
||||
api.add_resource(
|
||||
DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>"
|
||||
)
|
||||
api.add_resource(DatasetDocumentSegmentAddApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
|
||||
api.add_resource(
|
||||
DatasetDocumentSegmentUpdateApi,
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
DatasetDocumentSegmentBatchImportApi,
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import",
|
||||
"/datasets/batch_import_status/<uuid:job_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
ChildChunkAddApi,
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks",
|
||||
)
|
||||
api.add_resource(
|
||||
ChildChunkUpdateApi,
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks/<uuid:child_chunk_id>",
|
||||
)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import cast
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal, reqparse
|
||||
@@ -9,13 +11,14 @@ from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
from services.knowledge_service import ExternalDatasetTestService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
def _validate_name(name: str) -> str:
|
||||
if not name or len(name) < 1 or len(name) > 100:
|
||||
raise ValueError("Name must be between 1 to 100 characters.")
|
||||
return name
|
||||
@@ -274,7 +277,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
||||
response = HitTestingService.external_retrieve(
|
||||
dataset=dataset,
|
||||
query=args["query"],
|
||||
account=current_user,
|
||||
account=cast(Account, current_user),
|
||||
external_retrieval_model=args["external_retrieval_model"],
|
||||
metadata_filtering_conditions=args["metadata_filtering_conditions"],
|
||||
)
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services.dataset_service
|
||||
import services
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
@@ -20,6 +21,7 @@ from core.errors.error import (
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.hit_testing_fields import hit_testing_record_fields
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
||||
@@ -59,7 +61,7 @@ class DatasetsHitTestingBase:
|
||||
response = HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=args["query"],
|
||||
account=current_user,
|
||||
account=cast(Account, current_user),
|
||||
retrieval_model=args["retrieval_model"],
|
||||
external_retrieval_model=args["external_retrieval_model"],
|
||||
limit=10,
|
||||
|
||||
@@ -4,7 +4,7 @@ from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
|
||||
from fields.dataset_fields import dataset_metadata_fields
|
||||
from libs.login import login_required
|
||||
@@ -16,6 +16,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
||||
from services.metadata_service import MetadataService
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/metadata")
|
||||
class DatasetMetadataCreateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -50,6 +51,7 @@ class DatasetMetadataCreateApi(Resource):
|
||||
return MetadataService.get_dataset_metadatas(dataset), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
|
||||
class DatasetMetadataApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -60,6 +62,7 @@ class DatasetMetadataApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
name = args["name"]
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
metadata_id_str = str(metadata_id)
|
||||
@@ -68,7 +71,7 @@ class DatasetMetadataApi(Resource):
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name"))
|
||||
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, name)
|
||||
return metadata, 200
|
||||
|
||||
@setup_required
|
||||
@@ -87,6 +90,7 @@ class DatasetMetadataApi(Resource):
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/metadata/built-in")
|
||||
class DatasetMetadataBuiltInFieldApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -97,6 +101,7 @@ class DatasetMetadataBuiltInFieldApi(Resource):
|
||||
return {"fields": built_in_fields}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
|
||||
class DatasetMetadataBuiltInFieldActionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -116,6 +121,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/metadata")
|
||||
class DocumentMetadataEditApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -136,10 +142,3 @@ class DocumentMetadataEditApi(Resource):
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
api.add_resource(DatasetMetadataCreateApi, "/datasets/<uuid:dataset_id>/metadata")
|
||||
api.add_resource(DatasetMetadataApi, "/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
|
||||
api.add_resource(DatasetMetadataBuiltInFieldApi, "/datasets/metadata/built-in")
|
||||
api.add_resource(DatasetMetadataBuiltInFieldActionApi, "/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
|
||||
api.add_resource(DocumentMetadataEditApi, "/datasets/<uuid:dataset_id>/documents/metadata")
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from flask import make_response, redirect, request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from libs.helper import StrLen
|
||||
from libs.login import login_required
|
||||
@@ -19,6 +19,7 @@ from services.datasource_provider_service import DatasourceProviderService
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
|
||||
|
||||
@console_ns.route("/oauth/plugin/<path:provider_id>/datasource/get-authorization-url")
|
||||
class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -68,6 +69,7 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
||||
return response
|
||||
|
||||
|
||||
@console_ns.route("/oauth/plugin/<path:provider_id>/datasource/callback")
|
||||
class DatasourceOAuthCallback(Resource):
|
||||
@setup_required
|
||||
def get(self, provider_id: str):
|
||||
@@ -123,6 +125,7 @@ class DatasourceOAuthCallback(Resource):
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>")
|
||||
class DatasourceAuth(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -165,6 +168,7 @@ class DatasourceAuth(Resource):
|
||||
return {"result": datasources}, 200
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete")
|
||||
class DatasourceAuthDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -188,6 +192,7 @@ class DatasourceAuthDeleteApi(Resource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update")
|
||||
class DatasourceAuthUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -213,6 +218,7 @@ class DatasourceAuthUpdateApi(Resource):
|
||||
return {"result": "success"}, 201
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/list")
|
||||
class DatasourceAuthListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -225,6 +231,7 @@ class DatasourceAuthListApi(Resource):
|
||||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/default-list")
|
||||
class DatasourceHardCodeAuthListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -237,6 +244,7 @@ class DatasourceHardCodeAuthListApi(Resource):
|
||||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client")
|
||||
class DatasourceAuthOauthCustomClient(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -271,6 +279,7 @@ class DatasourceAuthOauthCustomClient(Resource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/default")
|
||||
class DatasourceAuthDefaultApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -291,6 +300,7 @@ class DatasourceAuthDefaultApi(Resource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name")
|
||||
class DatasourceUpdateProviderNameApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -311,52 +321,3 @@ class DatasourceUpdateProviderNameApi(Resource):
|
||||
credential_id=args["credential_id"],
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
api.add_resource(
|
||||
DatasourcePluginOAuthAuthorizationUrl,
|
||||
"/oauth/plugin/<path:provider_id>/datasource/get-authorization-url",
|
||||
)
|
||||
api.add_resource(
|
||||
DatasourceOAuthCallback,
|
||||
"/oauth/plugin/<path:provider_id>/datasource/callback",
|
||||
)
|
||||
api.add_resource(
|
||||
DatasourceAuth,
|
||||
"/auth/plugin/datasource/<path:provider_id>",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceAuthUpdateApi,
|
||||
"/auth/plugin/datasource/<path:provider_id>/update",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceAuthDeleteApi,
|
||||
"/auth/plugin/datasource/<path:provider_id>/delete",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceAuthListApi,
|
||||
"/auth/plugin/datasource/list",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceHardCodeAuthListApi,
|
||||
"/auth/plugin/datasource/default-list",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceAuthOauthCustomClient,
|
||||
"/auth/plugin/datasource/<path:provider_id>/custom-client",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceAuthDefaultApi,
|
||||
"/auth/plugin/datasource/<path:provider_id>/default",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceUpdateProviderNameApi,
|
||||
"/auth/plugin/datasource/<path:provider_id>/update-name",
|
||||
)
|
||||
|
||||
@@ -4,7 +4,7 @@ from flask_restx import ( # type: ignore
|
||||
)
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import current_user, login_required
|
||||
@@ -13,6 +13,7 @@ from models.dataset import Pipeline
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
|
||||
class DataSourceContentPreviewApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -49,9 +50,3 @@ class DataSourceContentPreviewApi(Resource):
|
||||
credential_id=args.get("credential_id"),
|
||||
)
|
||||
return preview_content, 200
|
||||
|
||||
|
||||
api.add_resource(
|
||||
DataSourceContentPreviewApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview",
|
||||
)
|
||||
|
||||
@@ -4,7 +4,7 @@ from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
enterprise_license_required,
|
||||
@@ -20,18 +20,19 @@ from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
def _validate_name(name: str) -> str:
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
def _validate_description_length(description: str) -> str:
|
||||
if len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/templates")
|
||||
class PipelineTemplateListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -45,6 +46,7 @@ class PipelineTemplateListApi(Resource):
|
||||
return pipeline_templates, 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/templates/<string:template_id>")
|
||||
class PipelineTemplateDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -57,6 +59,7 @@ class PipelineTemplateDetailApi(Resource):
|
||||
return pipeline_template, 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/customized/templates/<string:template_id>")
|
||||
class CustomizedPipelineTemplateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -73,7 +76,7 @@ class CustomizedPipelineTemplateApi(Resource):
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=str,
|
||||
type=_validate_description_length,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
@@ -112,6 +115,7 @@ class CustomizedPipelineTemplateApi(Resource):
|
||||
return {"data": template.yaml_content}, 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<string:pipeline_id>/customized/publish")
|
||||
class PublishCustomizedPipelineTemplateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -129,7 +133,7 @@ class PublishCustomizedPipelineTemplateApi(Resource):
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=str,
|
||||
type=_validate_description_length,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
@@ -144,21 +148,3 @@ class PublishCustomizedPipelineTemplateApi(Resource):
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args)
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
api.add_resource(
|
||||
PipelineTemplateListApi,
|
||||
"/rag/pipeline/templates",
|
||||
)
|
||||
api.add_resource(
|
||||
PipelineTemplateDetailApi,
|
||||
"/rag/pipeline/templates/<string:template_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
CustomizedPipelineTemplateApi,
|
||||
"/rag/pipeline/customized/templates/<string:template_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
PublishCustomizedPipelineTemplateApi,
|
||||
"/rag/pipelines/<string:pipeline_id>/customized/publish",
|
||||
)
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from flask_login import current_user # type: ignore # type: ignore
|
||||
from flask_restx import Resource, marshal, reqparse # type: ignore
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
@@ -20,18 +20,7 @@ from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo,
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/dataset")
|
||||
class CreateRagPipelineDatasetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -84,6 +73,7 @@ class CreateRagPipelineDatasetApi(Resource):
|
||||
return import_info, 201
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/empty-dataset")
|
||||
class CreateEmptyRagPipelineDatasetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -108,7 +98,3 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
|
||||
),
|
||||
)
|
||||
return marshal(dataset, dataset_detail_fields), 201
|
||||
|
||||
|
||||
api.add_resource(CreateRagPipelineDatasetApi, "/rag/pipeline/dataset")
|
||||
api.add_resource(CreateEmptyRagPipelineDatasetApi, "/rag/pipeline/empty-dataset")
|
||||
|
||||
@@ -1,24 +1,22 @@
|
||||
import logging
|
||||
from typing import Any, NoReturn
|
||||
from typing import NoReturn
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
DraftWorkflowNotExist,
|
||||
)
|
||||
from controllers.console.app.workflow_draft_variable import (
|
||||
_WORKFLOW_DRAFT_VARIABLE_FIELDS,
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
|
||||
_WORKFLOW_DRAFT_VARIABLE_FIELDS, # type: ignore[private-usage]
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, # type: ignore[private-usage]
|
||||
)
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.variables.segment_group import SegmentGroup
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from extensions.ext_database import db
|
||||
@@ -34,32 +32,6 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList,
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _convert_values_to_json_serializable_object(value: Segment) -> Any:
|
||||
if isinstance(value, FileSegment):
|
||||
return value.value.model_dump()
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
return [i.model_dump() for i in value.value]
|
||||
elif isinstance(value, SegmentGroup):
|
||||
return [_convert_values_to_json_serializable_object(i) for i in value.value]
|
||||
else:
|
||||
return value.value
|
||||
|
||||
|
||||
def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
|
||||
value = variable.get_value()
|
||||
# create a copy of the value to avoid affecting the model cache.
|
||||
value = value.model_copy(deep=True)
|
||||
# Refresh the url signature before returning it to client.
|
||||
if isinstance(value, FileSegment):
|
||||
file = value.value
|
||||
file.remote_url = file.generate_url()
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
files = value.value
|
||||
for file in files:
|
||||
file.remote_url = file.generate_url()
|
||||
return _convert_values_to_json_serializable_object(value)
|
||||
|
||||
|
||||
def _create_pagination_parser():
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
@@ -104,13 +76,14 @@ def _api_prerequisite(f):
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def wrapper(*args, **kwargs):
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables")
|
||||
class RagPipelineVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
|
||||
@@ -168,6 +141,7 @@ def validate_node_id(node_id: str) -> NoReturn | None:
|
||||
return None
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/variables")
|
||||
class RagPipelineNodeVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
@@ -190,6 +164,7 @@ class RagPipelineNodeVariableCollectionApi(Resource):
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>")
|
||||
class RagPipelineVariableApi(Resource):
|
||||
_PATCH_NAME_FIELD = "name"
|
||||
_PATCH_VALUE_FIELD = "value"
|
||||
@@ -284,6 +259,7 @@ class RagPipelineVariableApi(Resource):
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>/reset")
|
||||
class RagPipelineVariableResetApi(Resource):
|
||||
@_api_prerequisite
|
||||
def put(self, pipeline: Pipeline, variable_id: str):
|
||||
@@ -325,6 +301,7 @@ def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList
|
||||
return draft_vars
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/system-variables")
|
||||
class RagPipelineSystemVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
@@ -332,6 +309,7 @@ class RagPipelineSystemVariableCollectionApi(Resource):
|
||||
return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/environment-variables")
|
||||
class RagPipelineEnvironmentVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
def get(self, pipeline: Pipeline):
|
||||
@@ -364,26 +342,3 @@ class RagPipelineEnvironmentVariableCollectionApi(Resource):
|
||||
)
|
||||
|
||||
return {"items": env_vars_list}
|
||||
|
||||
|
||||
api.add_resource(
|
||||
RagPipelineVariableCollectionApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineNodeVariableCollectionApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/variables",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineVariableApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>"
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineVariableResetApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>/reset"
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineSystemVariableCollectionApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/system-variables"
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineEnvironmentVariableCollectionApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/environment-variables",
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ from flask_restx import Resource, marshal_with, reqparse # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
@@ -20,6 +20,7 @@ from services.app_dsl_service import ImportStatus
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/imports")
|
||||
class RagPipelineImportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -66,6 +67,7 @@ class RagPipelineImportApi(Resource):
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/imports/<string:import_id>/confirm")
|
||||
class RagPipelineImportConfirmApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -90,6 +92,7 @@ class RagPipelineImportConfirmApi(Resource):
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/imports/<string:pipeline_id>/check-dependencies")
|
||||
class RagPipelineImportCheckDependenciesApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -107,6 +110,7 @@ class RagPipelineImportCheckDependenciesApi(Resource):
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<string:pipeline_id>/exports")
|
||||
class RagPipelineExportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -118,30 +122,13 @@ class RagPipelineExportApi(Resource):
|
||||
|
||||
# Add include_secret params
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("include_secret", type=bool, default=False, location="args")
|
||||
parser.add_argument("include_secret", type=str, default="false", location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
export_service = RagPipelineDslService(session)
|
||||
result = export_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=args["include_secret"])
|
||||
result = export_service.export_rag_pipeline_dsl(
|
||||
pipeline=pipeline, include_secret=args["include_secret"] == "true"
|
||||
)
|
||||
|
||||
return {"data": result}, 200
|
||||
|
||||
|
||||
# Import Rag Pipeline
|
||||
api.add_resource(
|
||||
RagPipelineImportApi,
|
||||
"/rag/pipelines/imports",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineImportConfirmApi,
|
||||
"/rag/pipelines/imports/<string:import_id>/confirm",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineImportCheckDependenciesApi,
|
||||
"/rag/pipelines/imports/<string:pipeline_id>/check-dependencies",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineExportApi,
|
||||
"/rag/pipelines/<string:pipeline_id>/exports",
|
||||
)
|
||||
|
||||
@@ -9,8 +9,7 @@ from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
ConversationCompletedError,
|
||||
DraftWorkflowNotExist,
|
||||
@@ -51,6 +50,7 @@ from services.rag_pipeline.rag_pipeline_transform_service import RagPipelineTran
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft")
|
||||
class DraftRagPipelineApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -148,6 +148,7 @@ class DraftRagPipelineApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
|
||||
class RagPipelineDraftRunIterationNodeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -182,6 +183,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run")
|
||||
class RagPipelineDraftRunLoopNodeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -216,6 +218,7 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run")
|
||||
class DraftRagPipelineRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -250,6 +253,7 @@ class DraftRagPipelineRunApi(Resource):
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/run")
|
||||
class PublishedRagPipelineRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -370,6 +374,7 @@ class PublishedRagPipelineRunApi(Resource):
|
||||
#
|
||||
# return result
|
||||
#
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
|
||||
class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -412,6 +417,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run")
|
||||
class RagPipelineDraftDatasourceNodeRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -454,6 +460,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run")
|
||||
class RagPipelineDraftNodeRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -487,6 +494,7 @@ class RagPipelineDraftNodeRunApi(Resource):
|
||||
return workflow_node_execution
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs/tasks/<string:task_id>/stop")
|
||||
class RagPipelineTaskStopApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -505,6 +513,7 @@ class RagPipelineTaskStopApi(Resource):
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/publish")
|
||||
class PublishedRagPipelineApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -560,6 +569,7 @@ class PublishedRagPipelineApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs")
|
||||
class DefaultRagPipelineBlockConfigsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -578,6 +588,7 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
|
||||
return rag_pipeline_service.get_default_block_configs()
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>")
|
||||
class DefaultRagPipelineBlockConfigApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -609,18 +620,7 @@ class DefaultRagPipelineBlockConfigApi(Resource):
|
||||
return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters)
|
||||
|
||||
|
||||
class RagPipelineConfigApi(Resource):
|
||||
"""Resource for rag pipeline configuration."""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, pipeline_id):
|
||||
return {
|
||||
"parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows")
|
||||
class PublishedAllRagPipelineApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -669,6 +669,7 @@ class PublishedAllRagPipelineApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>")
|
||||
class RagPipelineByIdApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -726,6 +727,7 @@ class RagPipelineByIdApi(Resource):
|
||||
return workflow
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters")
|
||||
class PublishedRagPipelineSecondStepApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -751,6 +753,7 @@ class PublishedRagPipelineSecondStepApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters")
|
||||
class PublishedRagPipelineFirstStepApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -776,6 +779,7 @@ class PublishedRagPipelineFirstStepApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters")
|
||||
class DraftRagPipelineFirstStepApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -801,6 +805,7 @@ class DraftRagPipelineFirstStepApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters")
|
||||
class DraftRagPipelineSecondStepApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -827,6 +832,7 @@ class DraftRagPipelineSecondStepApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs")
|
||||
class RagPipelineWorkflowRunListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -848,6 +854,7 @@ class RagPipelineWorkflowRunListApi(Resource):
|
||||
return result
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs/<uuid:run_id>")
|
||||
class RagPipelineWorkflowRunDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -866,6 +873,7 @@ class RagPipelineWorkflowRunDetailApi(Resource):
|
||||
return workflow_run
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs/<uuid:run_id>/node-executions")
|
||||
class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -889,6 +897,7 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
|
||||
return {"data": node_executions}
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/datasource-plugins")
|
||||
class DatasourceListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -904,6 +913,7 @@ class DatasourceListApi(Resource):
|
||||
return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(tenant_id))
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/last-run")
|
||||
class RagPipelineWorkflowLastRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -925,6 +935,7 @@ class RagPipelineWorkflowLastRunApi(Resource):
|
||||
return node_exec
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/transform/datasets/<uuid:dataset_id>")
|
||||
class RagPipelineTransformApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -942,6 +953,7 @@ class RagPipelineTransformApi(Resource):
|
||||
return result
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect")
|
||||
class RagPipelineDatasourceVariableApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -971,6 +983,7 @@ class RagPipelineDatasourceVariableApi(Resource):
|
||||
return workflow_node_execution
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/recommended-plugins")
|
||||
class RagPipelineRecommendedPluginApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -979,118 +992,3 @@ class RagPipelineRecommendedPluginApi(Resource):
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
recommended_plugins = rag_pipeline_service.get_recommended_plugins()
|
||||
return recommended_plugins
|
||||
|
||||
|
||||
api.add_resource(
|
||||
DraftRagPipelineApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineConfigApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/config",
|
||||
)
|
||||
api.add_resource(
|
||||
DraftRagPipelineRunApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run",
|
||||
)
|
||||
api.add_resource(
|
||||
PublishedRagPipelineRunApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/run",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineTaskStopApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs/tasks/<string:task_id>/stop",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineDraftNodeRunApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelinePublishedDatasourceNodeRunApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
RagPipelineDraftDatasourceNodeRunApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
RagPipelineDraftRunIterationNodeApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
RagPipelineDraftRunLoopNodeApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
PublishedRagPipelineApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/publish",
|
||||
)
|
||||
api.add_resource(
|
||||
PublishedAllRagPipelineApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows",
|
||||
)
|
||||
api.add_resource(
|
||||
DefaultRagPipelineBlockConfigsApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs",
|
||||
)
|
||||
api.add_resource(
|
||||
DefaultRagPipelineBlockConfigApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineByIdApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineWorkflowRunListApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineWorkflowRunDetailApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs/<uuid:run_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineWorkflowRunNodeExecutionListApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs/<uuid:run_id>/node-executions",
|
||||
)
|
||||
api.add_resource(
|
||||
DatasourceListApi,
|
||||
"/rag/pipelines/datasource-plugins",
|
||||
)
|
||||
api.add_resource(
|
||||
PublishedRagPipelineSecondStepApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters",
|
||||
)
|
||||
api.add_resource(
|
||||
PublishedRagPipelineFirstStepApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters",
|
||||
)
|
||||
api.add_resource(
|
||||
DraftRagPipelineSecondStepApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters",
|
||||
)
|
||||
api.add_resource(
|
||||
DraftRagPipelineFirstStepApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineWorkflowLastRunApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/last-run",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineTransformApi,
|
||||
"/rag/pipelines/transform/datasets/<uuid:dataset_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineDatasourceVariableApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
RagPipelineRecommendedPluginApi,
|
||||
"/rag/pipelines/recommended-plugins",
|
||||
)
|
||||
|
||||
@@ -26,9 +26,15 @@ from services.errors.audio import (
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
|
||||
from .. import console_ns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/audio-to-text",
|
||||
endpoint="installed_app_audio",
|
||||
)
|
||||
class ChatAudioApi(InstalledAppResource):
|
||||
def post(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
@@ -65,6 +71,10 @@ class ChatAudioApi(InstalledAppResource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/text-to-audio",
|
||||
endpoint="installed_app_text",
|
||||
)
|
||||
class ChatTextApi(InstalledAppResource):
|
||||
def post(self, installed_app):
|
||||
from flask_restx import reqparse
|
||||
|
||||
@@ -33,10 +33,16 @@ from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
from .. import console_ns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# define completion api for user
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/completion-messages",
|
||||
endpoint="installed_app_completion",
|
||||
)
|
||||
class CompletionApi(InstalledAppResource):
|
||||
def post(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
@@ -87,6 +93,10 @@ class CompletionApi(InstalledAppResource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop",
|
||||
endpoint="installed_app_stop_completion",
|
||||
)
|
||||
class CompletionStopApi(InstalledAppResource):
|
||||
def post(self, installed_app, task_id):
|
||||
app_model = installed_app.app
|
||||
@@ -100,6 +110,10 @@ class CompletionStopApi(InstalledAppResource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/chat-messages",
|
||||
endpoint="installed_app_chat_completion",
|
||||
)
|
||||
class ChatApi(InstalledAppResource):
|
||||
def post(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
@@ -153,6 +167,10 @@ class ChatApi(InstalledAppResource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop",
|
||||
endpoint="installed_app_stop_chat_completion",
|
||||
)
|
||||
class ChatStopApi(InstalledAppResource):
|
||||
def post(self, installed_app, task_id):
|
||||
app_model = installed_app.app
|
||||
|
||||
@@ -16,7 +16,13 @@ from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
|
||||
from services.web_conversation_service import WebConversationService
|
||||
|
||||
from .. import console_ns
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/conversations",
|
||||
endpoint="installed_app_conversations",
|
||||
)
|
||||
class ConversationListApi(InstalledAppResource):
|
||||
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
||||
def get(self, installed_app):
|
||||
@@ -52,6 +58,10 @@ class ConversationListApi(InstalledAppResource):
|
||||
raise NotFound("Last Conversation Not Exists.")
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>",
|
||||
endpoint="installed_app_conversation",
|
||||
)
|
||||
class ConversationApi(InstalledAppResource):
|
||||
def delete(self, installed_app, c_id):
|
||||
app_model = installed_app.app
|
||||
@@ -70,6 +80,10 @@ class ConversationApi(InstalledAppResource):
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name",
|
||||
endpoint="installed_app_conversation_rename",
|
||||
)
|
||||
class ConversationRenameApi(InstalledAppResource):
|
||||
@marshal_with(simple_conversation_fields)
|
||||
def post(self, installed_app, c_id):
|
||||
@@ -95,6 +109,10 @@ class ConversationRenameApi(InstalledAppResource):
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin",
|
||||
endpoint="installed_app_conversation_pin",
|
||||
)
|
||||
class ConversationPinApi(InstalledAppResource):
|
||||
def patch(self, installed_app, c_id):
|
||||
app_model = installed_app.app
|
||||
@@ -114,6 +132,10 @@ class ConversationPinApi(InstalledAppResource):
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin",
|
||||
endpoint="installed_app_conversation_unpin",
|
||||
)
|
||||
class ConversationUnPinApi(InstalledAppResource):
|
||||
def patch(self, installed_app, c_id):
|
||||
app_model = installed_app.app
|
||||
|
||||
@@ -36,9 +36,15 @@ from services.errors.message import (
|
||||
)
|
||||
from services.message_service import MessageService
|
||||
|
||||
from .. import console_ns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/messages",
|
||||
endpoint="installed_app_messages",
|
||||
)
|
||||
class MessageListApi(InstalledAppResource):
|
||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||
def get(self, installed_app):
|
||||
@@ -66,6 +72,10 @@ class MessageListApi(InstalledAppResource):
|
||||
raise NotFound("First Message Not Exists.")
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks",
|
||||
endpoint="installed_app_message_feedback",
|
||||
)
|
||||
class MessageFeedbackApi(InstalledAppResource):
|
||||
def post(self, installed_app, message_id):
|
||||
app_model = installed_app.app
|
||||
@@ -93,6 +103,10 @@ class MessageFeedbackApi(InstalledAppResource):
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this",
|
||||
endpoint="installed_app_more_like_this",
|
||||
)
|
||||
class MessageMoreLikeThisApi(InstalledAppResource):
|
||||
def get(self, installed_app, message_id):
|
||||
app_model = installed_app.app
|
||||
@@ -139,6 +153,10 @@ class MessageMoreLikeThisApi(InstalledAppResource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions",
|
||||
endpoint="installed_app_suggested_question",
|
||||
)
|
||||
class MessageSuggestedQuestionApi(InstalledAppResource):
|
||||
def get(self, installed_app, message_id):
|
||||
app_model = installed_app.app
|
||||
|
||||
@@ -27,9 +27,12 @@ from models.model import AppMode, InstalledApp
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
from .. import console_ns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
|
||||
class InstalledAppWorkflowRunApi(InstalledAppResource):
|
||||
def post(self, installed_app: InstalledApp):
|
||||
"""
|
||||
@@ -70,6 +73,7 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop")
|
||||
class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
|
||||
def post(self, installed_app: InstalledApp, task_id: str):
|
||||
"""
|
||||
|
||||
@@ -26,9 +26,12 @@ from libs.login import login_required
|
||||
from models import Account
|
||||
from services.file_service import FileService
|
||||
|
||||
from . import console_ns
|
||||
|
||||
PREVIEW_WORDS_LIMIT = 3000
|
||||
|
||||
|
||||
@console_ns.route("/files/upload")
|
||||
class FileApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -88,6 +91,7 @@ class FileApi(Resource):
|
||||
return upload_file, 201
|
||||
|
||||
|
||||
@console_ns.route("/files/<uuid:file_id>/preview")
|
||||
class FilePreviewApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -98,6 +102,7 @@ class FilePreviewApi(Resource):
|
||||
return {"content": text}
|
||||
|
||||
|
||||
@console_ns.route("/files/support-type")
|
||||
class FileSupportTypeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
||||
@@ -19,7 +19,10 @@ from fields.file_fields import file_fields_with_signed_url, remote_file_info_fie
|
||||
from models.account import Account
|
||||
from services.file_service import FileService
|
||||
|
||||
from . import console_ns
|
||||
|
||||
|
||||
@console_ns.route("/remote-files/<path:url>")
|
||||
class RemoteFileInfoApi(Resource):
|
||||
@marshal_with(remote_file_info_fields)
|
||||
def get(self, url):
|
||||
@@ -35,6 +38,7 @@ class RemoteFileInfoApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/remote-files/upload")
|
||||
class RemoteFileUploadApi(Resource):
|
||||
@marshal_with(file_fields_with_signed_url)
|
||||
def post(self):
|
||||
|
||||
@@ -2,7 +2,6 @@ import logging
|
||||
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
@@ -10,9 +9,12 @@ from controllers.console.wraps import (
|
||||
from core.schemas.schema_manager import SchemaManager
|
||||
from libs.login import login_required
|
||||
|
||||
from . import console_ns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@console_ns.route("/spec/schema-definitions")
|
||||
class SpecSchemaDefinitionsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -30,6 +32,3 @@ class SpecSchemaDefinitionsApi(Resource):
|
||||
logger.exception("Failed to get schema definitions from local registry")
|
||||
# Return empty array as fallback
|
||||
return [], 200
|
||||
|
||||
|
||||
api.add_resource(SpecSchemaDefinitionsApi, "/spec/schema-definitions")
|
||||
|
||||
@@ -3,7 +3,7 @@ from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.tag_fields import dataset_tag_fields
|
||||
from libs.login import login_required
|
||||
@@ -17,6 +17,7 @@ def _validate_name(name):
|
||||
return name
|
||||
|
||||
|
||||
@console_ns.route("/tags")
|
||||
class TagListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -52,6 +53,7 @@ class TagListApi(Resource):
|
||||
return response, 200
|
||||
|
||||
|
||||
@console_ns.route("/tags/<uuid:tag_id>")
|
||||
class TagUpdateDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -89,6 +91,7 @@ class TagUpdateDeleteApi(Resource):
|
||||
return 204
|
||||
|
||||
|
||||
@console_ns.route("/tag-bindings/create")
|
||||
class TagBindingCreateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -114,6 +117,7 @@ class TagBindingCreateApi(Resource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/tag-bindings/remove")
|
||||
class TagBindingDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -133,9 +137,3 @@ class TagBindingDeleteApi(Resource):
|
||||
TagService.delete_tag_binding(args)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
api.add_resource(TagListApi, "/tags")
|
||||
api.add_resource(TagUpdateDeleteApi, "/tags/<uuid:tag_id>")
|
||||
api.add_resource(TagBindingCreateApi, "/tag-bindings/create")
|
||||
api.add_resource(TagBindingDeleteApi, "/tag-bindings/remove")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from packaging import version
|
||||
|
||||
@@ -57,7 +57,11 @@ class VersionApi(Resource):
|
||||
return result
|
||||
|
||||
try:
|
||||
response = requests.get(check_update_url, {"current_version": args["current_version"]}, timeout=(3, 10))
|
||||
response = httpx.get(
|
||||
check_update_url,
|
||||
params={"current_version": args["current_version"]},
|
||||
timeout=httpx.Timeout(connect=3, read=10),
|
||||
)
|
||||
except Exception as error:
|
||||
logger.warning("Check update version error: %s.", str(error))
|
||||
result["version"] = args["current_version"]
|
||||
|
||||
@@ -24,20 +24,14 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
||||
NOTE: user_id is not trusted, it could be maliciously set to any value.
|
||||
As a result, it could only be considered as an end user id.
|
||||
"""
|
||||
if not user_id:
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
|
||||
is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
if not user_id:
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
|
||||
user_model = None
|
||||
|
||||
user_model = (
|
||||
session.query(EndUser)
|
||||
.where(
|
||||
EndUser.id == user_id,
|
||||
EndUser.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not user_model:
|
||||
if is_anonymous:
|
||||
user_model = (
|
||||
session.query(EndUser)
|
||||
.where(
|
||||
@@ -46,11 +40,21 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
||||
)
|
||||
.first()
|
||||
)
|
||||
else:
|
||||
user_model = (
|
||||
session.query(EndUser)
|
||||
.where(
|
||||
EndUser.id == user_id,
|
||||
EndUser.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not user_model:
|
||||
user_model = EndUser(
|
||||
tenant_id=tenant_id,
|
||||
type="service_api",
|
||||
is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value,
|
||||
is_anonymous=is_anonymous,
|
||||
session_id=user_id,
|
||||
)
|
||||
session.add(user_model)
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from typing import Literal
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services.dataset_service
|
||||
import services
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
|
||||
from controllers.service_api.wraps import (
|
||||
@@ -17,6 +17,7 @@ from core.provider_manager import ProviderManager
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from fields.tag_fields import build_dataset_tag_fields
|
||||
from libs.login import current_user
|
||||
from libs.validators import validate_description_length
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, DatasetPermissionEnum
|
||||
from models.provider_ids import ModelProviderID
|
||||
@@ -31,12 +32,6 @@ def _validate_name(name):
|
||||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if description and len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
# Define parsers for dataset operations
|
||||
dataset_create_parser = reqparse.RequestParser()
|
||||
dataset_create_parser.add_argument(
|
||||
@@ -48,7 +43,7 @@ dataset_create_parser.add_argument(
|
||||
)
|
||||
dataset_create_parser.add_argument(
|
||||
"description",
|
||||
type=_validate_description_length,
|
||||
type=validate_description_length,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
@@ -101,7 +96,7 @@ dataset_update_parser.add_argument(
|
||||
type=_validate_name,
|
||||
)
|
||||
dataset_update_parser.add_argument(
|
||||
"description", location="json", store_missing=False, type=_validate_description_length
|
||||
"description", location="json", store_missing=False, type=validate_description_length
|
||||
)
|
||||
dataset_update_parser.add_argument(
|
||||
"indexing_technique",
|
||||
@@ -254,19 +249,21 @@ class DatasetListApi(DatasetApiResource):
|
||||
"""Resource for creating datasets."""
|
||||
args = dataset_create_parser.parse_args()
|
||||
|
||||
if args.get("embedding_model_provider"):
|
||||
DatasetService.check_embedding_model_setting(
|
||||
tenant_id, args.get("embedding_model_provider"), args.get("embedding_model")
|
||||
)
|
||||
embedding_model_provider = args.get("embedding_model_provider")
|
||||
embedding_model = args.get("embedding_model")
|
||||
if embedding_model_provider and embedding_model:
|
||||
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
|
||||
|
||||
retrieval_model = args.get("retrieval_model")
|
||||
if (
|
||||
args.get("retrieval_model")
|
||||
and args.get("retrieval_model").get("reranking_model")
|
||||
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
|
||||
retrieval_model
|
||||
and retrieval_model.get("reranking_model")
|
||||
and retrieval_model.get("reranking_model").get("reranking_provider_name")
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_provider_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_model_name"),
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -317,7 +314,7 @@ class DatasetApi(DatasetApiResource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
data = marshal(dataset, dataset_detail_fields)
|
||||
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
# check embedding setting
|
||||
provider_manager = ProviderManager()
|
||||
assert isinstance(current_user, Account)
|
||||
@@ -331,8 +328,8 @@ class DatasetApi(DatasetApiResource):
|
||||
for embedding_model in embedding_models:
|
||||
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
|
||||
|
||||
if data["indexing_technique"] == "high_quality":
|
||||
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
|
||||
if data.get("indexing_technique") == "high_quality":
|
||||
item_model = f"{data.get('embedding_model')}:{data.get('embedding_model_provider')}"
|
||||
if item_model in model_names:
|
||||
data["embedding_available"] = True
|
||||
else:
|
||||
@@ -341,7 +338,9 @@ class DatasetApi(DatasetApiResource):
|
||||
data["embedding_available"] = True
|
||||
|
||||
# force update search method to keyword_search if indexing_technique is economic
|
||||
data["retrieval_model_dict"]["search_method"] = "keyword_search"
|
||||
retrieval_model_dict = data.get("retrieval_model_dict")
|
||||
if retrieval_model_dict:
|
||||
retrieval_model_dict["search_method"] = "keyword_search"
|
||||
|
||||
if data.get("permission") == "partial_members":
|
||||
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
@@ -372,19 +371,24 @@ class DatasetApi(DatasetApiResource):
|
||||
data = request.get_json()
|
||||
|
||||
# check embedding model setting
|
||||
if data.get("indexing_technique") == "high_quality" or data.get("embedding_model_provider"):
|
||||
DatasetService.check_embedding_model_setting(
|
||||
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
|
||||
)
|
||||
embedding_model_provider = data.get("embedding_model_provider")
|
||||
embedding_model = data.get("embedding_model")
|
||||
if data.get("indexing_technique") == "high_quality" or embedding_model_provider:
|
||||
if embedding_model_provider and embedding_model:
|
||||
DatasetService.check_embedding_model_setting(
|
||||
dataset.tenant_id, embedding_model_provider, embedding_model
|
||||
)
|
||||
|
||||
retrieval_model = data.get("retrieval_model")
|
||||
if (
|
||||
data.get("retrieval_model")
|
||||
and data.get("retrieval_model").get("reranking_model")
|
||||
and data.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
|
||||
retrieval_model
|
||||
and retrieval_model.get("reranking_model")
|
||||
and retrieval_model.get("reranking_model").get("reranking_provider_name")
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
dataset.tenant_id,
|
||||
data.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
|
||||
data.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_provider_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_model_name"),
|
||||
)
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
@@ -397,7 +401,7 @@ class DatasetApi(DatasetApiResource):
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
result_data = marshal(dataset, dataset_detail_fields)
|
||||
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
assert isinstance(current_user, Account)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
@@ -591,9 +595,10 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
|
||||
args = tag_update_parser.parse_args()
|
||||
args["type"] = "knowledge"
|
||||
tag = TagService.update_tags(args, args.get("tag_id"))
|
||||
tag_id = args["tag_id"]
|
||||
tag = TagService.update_tags(args, tag_id)
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(args.get("tag_id"))
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||
|
||||
@@ -616,7 +621,7 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
args = tag_delete_parser.parse_args()
|
||||
TagService.delete_tag(args.get("tag_id"))
|
||||
TagService.delete_tag(args["tag_id"])
|
||||
|
||||
return 204
|
||||
|
||||
|
||||
@@ -30,7 +30,6 @@ from extensions.ext_database import db
|
||||
from fields.document_fields import document_fields, document_status_fields
|
||||
from libs.login import current_user
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.model import EndUser
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
|
||||
from services.file_service import FileService
|
||||
@@ -109,19 +108,21 @@ class DocumentAddByTextApi(DatasetApiResource):
|
||||
if text is None or name is None:
|
||||
raise ValueError("Both 'text' and 'name' must be non-null values.")
|
||||
|
||||
if args.get("embedding_model_provider"):
|
||||
DatasetService.check_embedding_model_setting(
|
||||
tenant_id, args.get("embedding_model_provider"), args.get("embedding_model")
|
||||
)
|
||||
embedding_model_provider = args.get("embedding_model_provider")
|
||||
embedding_model = args.get("embedding_model")
|
||||
if embedding_model_provider and embedding_model:
|
||||
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
|
||||
|
||||
retrieval_model = args.get("retrieval_model")
|
||||
if (
|
||||
args.get("retrieval_model")
|
||||
and args.get("retrieval_model").get("reranking_model")
|
||||
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
|
||||
retrieval_model
|
||||
and retrieval_model.get("reranking_model")
|
||||
and retrieval_model.get("reranking_model").get("reranking_provider_name")
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_provider_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_model_name"),
|
||||
)
|
||||
|
||||
if not current_user:
|
||||
@@ -188,15 +189,16 @@ class DocumentUpdateByTextApi(DatasetApiResource):
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
retrieval_model = args.get("retrieval_model")
|
||||
if (
|
||||
args.get("retrieval_model")
|
||||
and args.get("retrieval_model").get("reranking_model")
|
||||
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
|
||||
retrieval_model
|
||||
and retrieval_model.get("reranking_model")
|
||||
and retrieval_model.get("reranking_model").get("reranking_provider_name")
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_provider_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_model_name"),
|
||||
)
|
||||
|
||||
# indexing_technique is already set in dataset since this is an update
|
||||
@@ -311,8 +313,6 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
||||
if not isinstance(current_user, EndUser):
|
||||
raise ValueError("Invalid user account")
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
@@ -406,9 +406,6 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
|
||||
if not isinstance(current_user, EndUser):
|
||||
raise ValueError("Invalid user account")
|
||||
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename,
|
||||
|
||||
@@ -106,7 +106,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name"))
|
||||
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args["name"])
|
||||
return marshal(metadata, dataset_metadata_fields), 200
|
||||
|
||||
@service_api_ns.doc("delete_dataset_metadata")
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import uuid
|
||||
from typing import Literal, cast
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
DatasetEntity,
|
||||
@@ -74,6 +75,9 @@ class DatasetConfigManager:
|
||||
return None
|
||||
query_variable = config.get("dataset_query_variable")
|
||||
|
||||
metadata_model_config_dict = dataset_configs.get("metadata_model_config")
|
||||
metadata_filtering_conditions_dict = dataset_configs.get("metadata_filtering_conditions")
|
||||
|
||||
if dataset_configs["retrieval_model"] == "single":
|
||||
return DatasetEntity(
|
||||
dataset_ids=dataset_ids,
|
||||
@@ -82,18 +86,23 @@ class DatasetConfigManager:
|
||||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
|
||||
dataset_configs["retrieval_model"]
|
||||
),
|
||||
metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"),
|
||||
metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config"))
|
||||
if dataset_configs.get("metadata_model_config")
|
||||
metadata_filtering_mode=cast(
|
||||
Literal["disabled", "automatic", "manual"],
|
||||
dataset_configs.get("metadata_filtering_mode", "disabled"),
|
||||
),
|
||||
metadata_model_config=ModelConfig(**metadata_model_config_dict)
|
||||
if isinstance(metadata_model_config_dict, dict)
|
||||
else None,
|
||||
metadata_filtering_conditions=MetadataFilteringCondition(
|
||||
**dataset_configs.get("metadata_filtering_conditions", {})
|
||||
)
|
||||
if dataset_configs.get("metadata_filtering_conditions")
|
||||
metadata_filtering_conditions=MetadataFilteringCondition(**metadata_filtering_conditions_dict)
|
||||
if isinstance(metadata_filtering_conditions_dict, dict)
|
||||
else None,
|
||||
),
|
||||
)
|
||||
else:
|
||||
score_threshold_val = dataset_configs.get("score_threshold")
|
||||
reranking_model_val = dataset_configs.get("reranking_model")
|
||||
weights_val = dataset_configs.get("weights")
|
||||
|
||||
return DatasetEntity(
|
||||
dataset_ids=dataset_ids,
|
||||
retrieve_config=DatasetRetrieveConfigEntity(
|
||||
@@ -101,22 +110,23 @@ class DatasetConfigManager:
|
||||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
|
||||
dataset_configs["retrieval_model"]
|
||||
),
|
||||
top_k=dataset_configs.get("top_k", 4),
|
||||
score_threshold=dataset_configs.get("score_threshold")
|
||||
if dataset_configs.get("score_threshold_enabled", False)
|
||||
top_k=int(dataset_configs.get("top_k", 4)),
|
||||
score_threshold=float(score_threshold_val)
|
||||
if dataset_configs.get("score_threshold_enabled", False) and score_threshold_val is not None
|
||||
else None,
|
||||
reranking_model=dataset_configs.get("reranking_model"),
|
||||
weights=dataset_configs.get("weights"),
|
||||
reranking_enabled=dataset_configs.get("reranking_enabled", True),
|
||||
reranking_model=reranking_model_val if isinstance(reranking_model_val, dict) else None,
|
||||
weights=weights_val if isinstance(weights_val, dict) else None,
|
||||
reranking_enabled=bool(dataset_configs.get("reranking_enabled", True)),
|
||||
rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
|
||||
metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"),
|
||||
metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config"))
|
||||
if dataset_configs.get("metadata_model_config")
|
||||
metadata_filtering_mode=cast(
|
||||
Literal["disabled", "automatic", "manual"],
|
||||
dataset_configs.get("metadata_filtering_mode", "disabled"),
|
||||
),
|
||||
metadata_model_config=ModelConfig(**metadata_model_config_dict)
|
||||
if isinstance(metadata_model_config_dict, dict)
|
||||
else None,
|
||||
metadata_filtering_conditions=MetadataFilteringCondition(
|
||||
**dataset_configs.get("metadata_filtering_conditions", {})
|
||||
)
|
||||
if dataset_configs.get("metadata_filtering_conditions")
|
||||
metadata_filtering_conditions=MetadataFilteringCondition(**metadata_filtering_conditions_dict)
|
||||
if isinstance(metadata_filtering_conditions_dict, dict)
|
||||
else None,
|
||||
),
|
||||
)
|
||||
@@ -134,18 +144,17 @@ class DatasetConfigManager:
|
||||
config = cls.extract_dataset_config_for_legacy_compatibility(tenant_id, app_mode, config)
|
||||
|
||||
# dataset_configs
|
||||
if not config.get("dataset_configs"):
|
||||
config["dataset_configs"] = {"retrieval_model": "single"}
|
||||
if "dataset_configs" not in config or not config.get("dataset_configs"):
|
||||
config["dataset_configs"] = {}
|
||||
config["dataset_configs"]["retrieval_model"] = config["dataset_configs"].get("retrieval_model", "single")
|
||||
|
||||
if not isinstance(config["dataset_configs"], dict):
|
||||
raise ValueError("dataset_configs must be of object type")
|
||||
|
||||
if not config["dataset_configs"].get("datasets"):
|
||||
if "datasets" not in config["dataset_configs"] or not config["dataset_configs"].get("datasets"):
|
||||
config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}
|
||||
|
||||
need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get(
|
||||
"datasets", {}
|
||||
).get("datasets")
|
||||
need_manual_query_datasets = config.get("dataset_configs", {}).get("datasets", {}).get("datasets")
|
||||
|
||||
if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
|
||||
# Only check when mode is completion
|
||||
@@ -166,8 +175,8 @@ class DatasetConfigManager:
|
||||
:param config: app model config args
|
||||
"""
|
||||
# Extract dataset config for legacy compatibility
|
||||
if not config.get("agent_mode"):
|
||||
config["agent_mode"] = {"enabled": False, "tools": []}
|
||||
if "agent_mode" not in config or not config.get("agent_mode"):
|
||||
config["agent_mode"] = {}
|
||||
|
||||
if not isinstance(config["agent_mode"], dict):
|
||||
raise ValueError("agent_mode must be of object type")
|
||||
@@ -180,19 +189,22 @@ class DatasetConfigManager:
|
||||
raise ValueError("enabled in agent_mode must be of boolean type")
|
||||
|
||||
# tools
|
||||
if not config["agent_mode"].get("tools"):
|
||||
if "tools" not in config["agent_mode"] or not config["agent_mode"].get("tools"):
|
||||
config["agent_mode"]["tools"] = []
|
||||
|
||||
if not isinstance(config["agent_mode"]["tools"], list):
|
||||
raise ValueError("tools in agent_mode must be a list of objects")
|
||||
|
||||
# strategy
|
||||
if not config["agent_mode"].get("strategy"):
|
||||
if "strategy" not in config["agent_mode"] or not config["agent_mode"].get("strategy"):
|
||||
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
|
||||
|
||||
has_datasets = False
|
||||
if config["agent_mode"]["strategy"] in {PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value}:
|
||||
for tool in config["agent_mode"]["tools"]:
|
||||
if config.get("agent_mode", {}).get("strategy") in {
|
||||
PlanningStrategy.ROUTER.value,
|
||||
PlanningStrategy.REACT_ROUTER.value,
|
||||
}:
|
||||
for tool in config.get("agent_mode", {}).get("tools", []):
|
||||
key = list(tool.keys())[0]
|
||||
if key == "dataset":
|
||||
# old style, use tool name as key
|
||||
@@ -217,7 +229,7 @@ class DatasetConfigManager:
|
||||
|
||||
has_datasets = True
|
||||
|
||||
need_manual_query_datasets = has_datasets and config["agent_mode"]["enabled"]
|
||||
need_manual_query_datasets = has_datasets and config.get("agent_mode", {}).get("enabled")
|
||||
|
||||
if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
|
||||
# Only check when mode is completion
|
||||
|
||||
@@ -79,29 +79,12 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
# if only single iteration run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
# Handle single iteration or single loop run
|
||||
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
elif self.application_generate_entity.single_loop_run:
|
||||
# if only single loop run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_loop_run.node_id,
|
||||
user_inputs=dict(self.application_generate_entity.single_loop_run.inputs),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
single_iteration_run=self.application_generate_entity.single_iteration_run,
|
||||
single_loop_run=self.application_generate_entity.single_loop_run,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
|
||||
@@ -551,7 +551,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
total_steps=validated_state.node_run_steps,
|
||||
outputs=event.outputs,
|
||||
exceptions_count=event.exceptions_count,
|
||||
conversation_id=None,
|
||||
conversation_id=self._conversation_id,
|
||||
trace_manager=trace_manager,
|
||||
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
|
||||
)
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import logging
|
||||
import queue
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from enum import IntEnum, auto
|
||||
from typing import Any
|
||||
|
||||
from redis.exceptions import RedisError
|
||||
from sqlalchemy.orm import DeclarativeMeta
|
||||
|
||||
from configs import dify_config
|
||||
@@ -18,6 +20,8 @@ from core.app.entities.queue_entities import (
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PublishFrom(IntEnum):
|
||||
APPLICATION_MANAGER = auto()
|
||||
@@ -35,9 +39,8 @@ class AppQueueManager:
|
||||
self.invoke_from = invoke_from # Public accessor for invoke_from
|
||||
|
||||
user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
|
||||
redis_client.setex(
|
||||
AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
|
||||
)
|
||||
self._task_belong_cache_key = AppQueueManager._generate_task_belong_cache_key(self._task_id)
|
||||
redis_client.setex(self._task_belong_cache_key, 1800, f"{user_prefix}-{self._user_id}")
|
||||
|
||||
q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
|
||||
|
||||
@@ -79,9 +82,21 @@ class AppQueueManager:
|
||||
Stop listen to queue
|
||||
:return:
|
||||
"""
|
||||
self._clear_task_belong_cache()
|
||||
self._q.put(None)
|
||||
|
||||
def publish_error(self, e, pub_from: PublishFrom):
|
||||
def _clear_task_belong_cache(self) -> None:
|
||||
"""
|
||||
Remove the task belong cache key once listening is finished.
|
||||
"""
|
||||
try:
|
||||
redis_client.delete(self._task_belong_cache_key)
|
||||
except RedisError:
|
||||
logger.exception(
|
||||
"Failed to clear task belong cache for task %s (key: %s)", self._task_id, self._task_belong_cache_key
|
||||
)
|
||||
|
||||
def publish_error(self, e, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish error
|
||||
:param e: error
|
||||
|
||||
@@ -427,6 +427,9 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
workflow_execution_id=str(uuid.uuid4()),
|
||||
single_iteration_run=RagPipelineGenerateEntity.SingleIterationRunEntity(
|
||||
node_id=node_id, inputs=args["inputs"]
|
||||
),
|
||||
)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
@@ -465,6 +468,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
variable_loader=var_loader,
|
||||
context=contextvars.copy_context(),
|
||||
)
|
||||
|
||||
def single_loop_generate(
|
||||
@@ -559,6 +563,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
variable_loader=var_loader,
|
||||
context=contextvars.copy_context(),
|
||||
)
|
||||
|
||||
def _generate_worker(
|
||||
|
||||
@@ -86,29 +86,12 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
db.session.close()
|
||||
|
||||
# if only single iteration run is requested
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
# if only single iteration run is requested
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
# Handle single iteration or single loop run
|
||||
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
|
||||
workflow=workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
elif self.application_generate_entity.single_loop_run:
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
# if only single loop run is requested
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=workflow,
|
||||
node_id=self.application_generate_entity.single_loop_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_loop_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
single_iteration_run=self.application_generate_entity.single_iteration_run,
|
||||
single_loop_run=self.application_generate_entity.single_loop_run,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
|
||||
@@ -51,30 +51,12 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(WorkflowAppConfig, app_config)
|
||||
|
||||
# if only single iteration run is requested
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
# if only single iteration run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
# if only single iteration or single loop run is requested
|
||||
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
elif self.application_generate_entity.single_loop_run:
|
||||
# if only single loop run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_loop_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_loop_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
single_iteration_run=self.application_generate_entity.single_iteration_run,
|
||||
single_loop_run=self.application_generate_entity.single_loop_run,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
@@ -119,15 +120,81 @@ class WorkflowBasedAppRunner:
|
||||
|
||||
return graph
|
||||
|
||||
def _get_graph_and_variable_pool_of_single_iteration(
|
||||
def _prepare_single_node_execution(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
single_iteration_run: Any | None = None,
|
||||
single_loop_run: Any | None = None,
|
||||
) -> tuple[Graph, VariablePool, GraphRuntimeState]:
|
||||
"""
|
||||
Prepare graph, variable pool, and runtime state for single node execution
|
||||
(either single iteration or single loop).
|
||||
|
||||
Args:
|
||||
workflow: The workflow instance
|
||||
single_iteration_run: SingleIterationRunEntity if running single iteration, None otherwise
|
||||
single_loop_run: SingleLoopRunEntity if running single loop, None otherwise
|
||||
|
||||
Returns:
|
||||
A tuple containing (graph, variable_pool, graph_runtime_state)
|
||||
|
||||
Raises:
|
||||
ValueError: If neither single_iteration_run nor single_loop_run is specified
|
||||
"""
|
||||
# Create initial runtime state with variable pool containing environment variables
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
environment_variables=workflow.environment_variables,
|
||||
),
|
||||
start_at=time.time(),
|
||||
)
|
||||
|
||||
# Determine which type of single node execution and get graph/variable_pool
|
||||
if single_iteration_run:
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
workflow=workflow,
|
||||
node_id=single_iteration_run.node_id,
|
||||
user_inputs=dict(single_iteration_run.inputs),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
elif single_loop_run:
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=workflow,
|
||||
node_id=single_loop_run.node_id,
|
||||
user_inputs=dict(single_loop_run.inputs),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Neither single_iteration_run nor single_loop_run is specified")
|
||||
|
||||
# Return the graph, variable_pool, and the same graph_runtime_state used during graph creation
|
||||
# This ensures all nodes in the graph reference the same GraphRuntimeState instance
|
||||
return graph, variable_pool, graph_runtime_state
|
||||
|
||||
def _get_graph_and_variable_pool_for_single_node_run(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict,
|
||||
user_inputs: dict[str, Any],
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
node_type_filter_key: str, # 'iteration_id' or 'loop_id'
|
||||
node_type_label: str = "node", # 'iteration' or 'loop' for error messages
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single iteration
|
||||
Get graph and variable pool for single node execution (iteration or loop).
|
||||
|
||||
Args:
|
||||
workflow: The workflow instance
|
||||
node_id: The node ID to execute
|
||||
user_inputs: User inputs for the node
|
||||
graph_runtime_state: The graph runtime state
|
||||
node_type_filter_key: The key to filter nodes ('iteration_id' or 'loop_id')
|
||||
node_type_label: Label for error messages ('iteration' or 'loop')
|
||||
|
||||
Returns:
|
||||
A tuple containing (graph, variable_pool)
|
||||
"""
|
||||
# fetch workflow graph
|
||||
graph_config = workflow.graph_dict
|
||||
@@ -145,18 +212,22 @@ class WorkflowBasedAppRunner:
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
|
||||
# filter nodes only in iteration
|
||||
# filter nodes only in the specified node type (iteration or loop)
|
||||
main_node_config = next((n for n in graph_config.get("nodes", []) if n.get("id") == node_id), None)
|
||||
start_node_id = main_node_config.get("data", {}).get("start_node_id") if main_node_config else None
|
||||
node_configs = [
|
||||
node
|
||||
for node in graph_config.get("nodes", [])
|
||||
if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id
|
||||
if node.get("id") == node_id
|
||||
or node.get("data", {}).get(node_type_filter_key, "") == node_id
|
||||
or (start_node_id and node.get("id") == start_node_id)
|
||||
]
|
||||
|
||||
graph_config["nodes"] = node_configs
|
||||
|
||||
node_ids = [node.get("id") for node in node_configs]
|
||||
|
||||
# filter edges only in iteration
|
||||
# filter edges only in the specified node type
|
||||
edge_configs = [
|
||||
edge
|
||||
for edge in graph_config.get("edges", [])
|
||||
@@ -190,30 +261,26 @@ class WorkflowBasedAppRunner:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
||||
# fetch node config from node id
|
||||
iteration_node_config = None
|
||||
target_node_config = None
|
||||
for node in node_configs:
|
||||
if node.get("id") == node_id:
|
||||
iteration_node_config = node
|
||||
target_node_config = node
|
||||
break
|
||||
|
||||
if not iteration_node_config:
|
||||
raise ValueError("iteration node id not found in workflow graph")
|
||||
if not target_node_config:
|
||||
raise ValueError(f"{node_type_label} node id not found in workflow graph")
|
||||
|
||||
# Get node class
|
||||
node_type = NodeType(iteration_node_config.get("data", {}).get("type"))
|
||||
node_version = iteration_node_config.get("data", {}).get("version", "1")
|
||||
node_type = NodeType(target_node_config.get("data", {}).get("type"))
|
||||
node_version = target_node_config.get("data", {}).get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
environment_variables=workflow.environment_variables,
|
||||
)
|
||||
# Use the variable pool from graph_runtime_state instead of creating a new one
|
||||
variable_pool = graph_runtime_state.variable_pool
|
||||
|
||||
try:
|
||||
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=workflow.graph_dict, config=iteration_node_config
|
||||
graph_config=workflow.graph_dict, config=target_node_config
|
||||
)
|
||||
except NotImplementedError:
|
||||
variable_mapping = {}
|
||||
@@ -234,120 +301,44 @@ class WorkflowBasedAppRunner:
|
||||
|
||||
return graph, variable_pool
|
||||
|
||||
def _get_graph_and_variable_pool_of_single_iteration(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict[str, Any],
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single iteration
|
||||
"""
|
||||
return self._get_graph_and_variable_pool_for_single_node_run(
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
node_type_filter_key="iteration_id",
|
||||
node_type_label="iteration",
|
||||
)
|
||||
|
||||
def _get_graph_and_variable_pool_of_single_loop(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict,
|
||||
user_inputs: dict[str, Any],
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single loop
|
||||
"""
|
||||
# fetch workflow graph
|
||||
graph_config = workflow.graph_dict
|
||||
if not graph_config:
|
||||
raise ValueError("workflow graph not found")
|
||||
|
||||
graph_config = cast(dict[str, Any], graph_config)
|
||||
|
||||
if "nodes" not in graph_config or "edges" not in graph_config:
|
||||
raise ValueError("nodes or edges not found in workflow graph")
|
||||
|
||||
if not isinstance(graph_config.get("nodes"), list):
|
||||
raise ValueError("nodes in workflow graph must be a list")
|
||||
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
|
||||
# filter nodes only in loop
|
||||
node_configs = [
|
||||
node
|
||||
for node in graph_config.get("nodes", [])
|
||||
if node.get("id") == node_id or node.get("data", {}).get("loop_id", "") == node_id
|
||||
]
|
||||
|
||||
graph_config["nodes"] = node_configs
|
||||
|
||||
node_ids = [node.get("id") for node in node_configs]
|
||||
|
||||
# filter edges only in loop
|
||||
edge_configs = [
|
||||
edge
|
||||
for edge in graph_config.get("edges", [])
|
||||
if (edge.get("source") is None or edge.get("source") in node_ids)
|
||||
and (edge.get("target") is None or edge.get("target") in node_ids)
|
||||
]
|
||||
|
||||
graph_config["edges"] = edge_configs
|
||||
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
workflow_id=workflow.id,
|
||||
graph_config=graph_config,
|
||||
user_id="",
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
return self._get_graph_and_variable_pool_for_single_node_run(
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
node_type_filter_key="loop_id",
|
||||
node_type_label="loop",
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
||||
# fetch node config from node id
|
||||
loop_node_config = None
|
||||
for node in node_configs:
|
||||
if node.get("id") == node_id:
|
||||
loop_node_config = node
|
||||
break
|
||||
|
||||
if not loop_node_config:
|
||||
raise ValueError("loop node id not found in workflow graph")
|
||||
|
||||
# Get node class
|
||||
node_type = NodeType(loop_node_config.get("data", {}).get("type"))
|
||||
node_version = loop_node_config.get("data", {}).get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
environment_variables=workflow.environment_variables,
|
||||
)
|
||||
|
||||
try:
|
||||
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=workflow.graph_dict, config=loop_node_config
|
||||
)
|
||||
except NotImplementedError:
|
||||
variable_mapping = {}
|
||||
load_into_variable_pool(
|
||||
self._variable_loader,
|
||||
variable_pool=variable_pool,
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
)
|
||||
|
||||
WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id=workflow.tenant_id,
|
||||
)
|
||||
|
||||
return graph, variable_pool
|
||||
|
||||
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
|
||||
"""
|
||||
Handle event
|
||||
|
||||
@@ -107,7 +107,6 @@ class MessageCycleManager:
|
||||
if dify_config.DEBUG:
|
||||
logger.exception("generate conversation name failed, conversation_id: %s", conversation_id)
|
||||
|
||||
db.session.merge(conversation)
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from openai import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Import InvokeFrom locally to avoid circular import
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
|
||||
@@ -1,388 +0,0 @@
|
||||
import re
|
||||
import uuid
|
||||
from json import dumps as json_dumps
|
||||
from json import loads as json_loads
|
||||
from json.decoder import JSONDecodeError
|
||||
|
||||
from flask import request
|
||||
from requests import get
|
||||
from yaml import YAMLError, safe_load # type: ignore
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter
|
||||
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError
|
||||
|
||||
|
||||
class ApiBasedToolSchemaParser:
|
||||
@staticmethod
|
||||
def parse_openapi_to_tool_bundle(
|
||||
openapi: dict, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> list[ApiToolBundle]:
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
# set description to extra_info
|
||||
extra_info["description"] = openapi["info"].get("description", "")
|
||||
|
||||
if len(openapi["servers"]) == 0:
|
||||
raise ToolProviderNotFoundError("No server found in the openapi yaml.")
|
||||
|
||||
server_url = openapi["servers"][0]["url"]
|
||||
request_env = request.headers.get("X-Request-Env")
|
||||
if request_env:
|
||||
matched_servers = [server["url"] for server in openapi["servers"] if server["env"] == request_env]
|
||||
server_url = matched_servers[0] if matched_servers else server_url
|
||||
|
||||
# list all interfaces
|
||||
interfaces = []
|
||||
for path, path_item in openapi["paths"].items():
|
||||
methods = ["get", "post", "put", "delete", "patch", "head", "options", "trace"]
|
||||
for method in methods:
|
||||
if method in path_item:
|
||||
interfaces.append(
|
||||
{
|
||||
"path": path,
|
||||
"method": method,
|
||||
"operation": path_item[method],
|
||||
}
|
||||
)
|
||||
|
||||
# get all parameters
|
||||
bundles = []
|
||||
for interface in interfaces:
|
||||
# convert parameters
|
||||
parameters = []
|
||||
if "parameters" in interface["operation"]:
|
||||
for parameter in interface["operation"]["parameters"]:
|
||||
tool_parameter = ToolParameter(
|
||||
name=parameter["name"],
|
||||
label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]),
|
||||
human_description=I18nObject(
|
||||
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=parameter.get("required", False),
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description=parameter.get("description"),
|
||||
default=parameter["schema"]["default"]
|
||||
if "schema" in parameter and "default" in parameter["schema"]
|
||||
else None,
|
||||
placeholder=I18nObject(
|
||||
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
|
||||
),
|
||||
)
|
||||
|
||||
# check if there is a type
|
||||
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(parameter)
|
||||
if typ:
|
||||
tool_parameter.type = typ
|
||||
|
||||
parameters.append(tool_parameter)
|
||||
# create tool bundle
|
||||
# check if there is a request body
|
||||
if "requestBody" in interface["operation"]:
|
||||
request_body = interface["operation"]["requestBody"]
|
||||
if "content" in request_body:
|
||||
for content_type, content in request_body["content"].items():
|
||||
# if there is a reference, get the reference and overwrite the content
|
||||
if "schema" not in content:
|
||||
continue
|
||||
|
||||
if "$ref" in content["schema"]:
|
||||
# get the reference
|
||||
root = openapi
|
||||
reference = content["schema"]["$ref"].split("/")[1:]
|
||||
for ref in reference:
|
||||
root = root[ref]
|
||||
# overwrite the content
|
||||
interface["operation"]["requestBody"]["content"][content_type]["schema"] = root
|
||||
|
||||
# parse body parameters
|
||||
if "schema" in interface["operation"]["requestBody"]["content"][content_type]: # pyright: ignore[reportIndexIssue, reportPossiblyUnboundVariable]
|
||||
body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"] # pyright: ignore[reportIndexIssue, reportPossiblyUnboundVariable]
|
||||
required = body_schema.get("required", [])
|
||||
properties = body_schema.get("properties", {})
|
||||
for name, property in properties.items():
|
||||
tool = ToolParameter(
|
||||
name=name,
|
||||
label=I18nObject(en_US=name, zh_Hans=name),
|
||||
human_description=I18nObject(
|
||||
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=name in required,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description=property.get("description", ""),
|
||||
default=property.get("default", None),
|
||||
placeholder=I18nObject(
|
||||
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||
),
|
||||
)
|
||||
|
||||
# check if there is a type
|
||||
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
|
||||
if typ:
|
||||
tool.type = typ
|
||||
|
||||
parameters.append(tool)
|
||||
|
||||
# check if parameters is duplicated
|
||||
parameters_count = {}
|
||||
for parameter in parameters:
|
||||
if parameter.name not in parameters_count:
|
||||
parameters_count[parameter.name] = 0
|
||||
parameters_count[parameter.name] += 1
|
||||
for name, count in parameters_count.items():
|
||||
if count > 1:
|
||||
warning["duplicated_parameter"] = f"Parameter {name} is duplicated."
|
||||
|
||||
# check if there is a operation id, use $path_$method as operation id if not
|
||||
if "operationId" not in interface["operation"]:
|
||||
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
|
||||
path = interface["path"]
|
||||
if interface["path"].startswith("/"):
|
||||
path = interface["path"][1:]
|
||||
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
|
||||
path = re.sub(r"[^a-zA-Z0-9_-]", "", path)
|
||||
if not path:
|
||||
path = str(uuid.uuid4())
|
||||
|
||||
interface["operation"]["operationId"] = f"{path}_{interface['method']}"
|
||||
|
||||
bundles.append(
|
||||
ApiToolBundle(
|
||||
server_url=server_url + interface["path"],
|
||||
method=interface["method"],
|
||||
summary=interface["operation"]["description"]
|
||||
if "description" in interface["operation"]
|
||||
else interface["operation"].get("summary", None),
|
||||
operation_id=interface["operation"]["operationId"],
|
||||
parameters=parameters,
|
||||
author="",
|
||||
icon=None,
|
||||
openapi=interface["operation"],
|
||||
)
|
||||
)
|
||||
|
||||
return bundles
|
||||
|
||||
@staticmethod
|
||||
def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None:
|
||||
parameter = parameter or {}
|
||||
typ: str | None = None
|
||||
if parameter.get("format") == "binary":
|
||||
return ToolParameter.ToolParameterType.FILE
|
||||
|
||||
if "type" in parameter:
|
||||
typ = parameter["type"]
|
||||
elif "schema" in parameter and "type" in parameter["schema"]:
|
||||
typ = parameter["schema"]["type"]
|
||||
|
||||
if typ in {"integer", "number"}:
|
||||
return ToolParameter.ToolParameterType.NUMBER
|
||||
elif typ == "boolean":
|
||||
return ToolParameter.ToolParameterType.BOOLEAN
|
||||
elif typ == "string":
|
||||
return ToolParameter.ToolParameterType.STRING
|
||||
elif typ == "array":
|
||||
items = parameter.get("items") or parameter.get("schema", {}).get("items")
|
||||
return ToolParameter.ToolParameterType.FILES if items and items.get("format") == "binary" else None
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def parse_openapi_yaml_to_tool_bundle(
|
||||
yaml: str, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> list[ApiToolBundle]:
|
||||
"""
|
||||
parse openapi yaml to tool bundle
|
||||
|
||||
:param yaml: the yaml string
|
||||
:param extra_info: the extra info
|
||||
:param warning: the warning message
|
||||
:return: the tool bundle
|
||||
"""
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
openapi: dict = safe_load(yaml)
|
||||
if openapi is None:
|
||||
raise ToolApiSchemaError("Invalid openapi yaml.")
|
||||
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
|
||||
|
||||
@staticmethod
|
||||
def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None) -> dict:
|
||||
warning = warning or {}
|
||||
"""
|
||||
parse swagger to openapi
|
||||
|
||||
:param swagger: the swagger dict
|
||||
:return: the openapi dict
|
||||
"""
|
||||
# convert swagger to openapi
|
||||
info = swagger.get("info", {"title": "Swagger", "description": "Swagger", "version": "1.0.0"})
|
||||
|
||||
servers = swagger.get("servers", [])
|
||||
|
||||
if len(servers) == 0:
|
||||
raise ToolApiSchemaError("No server found in the swagger yaml.")
|
||||
|
||||
openapi = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {
|
||||
"title": info.get("title", "Swagger"),
|
||||
"description": info.get("description", "Swagger"),
|
||||
"version": info.get("version", "1.0.0"),
|
||||
},
|
||||
"servers": swagger["servers"],
|
||||
"paths": {},
|
||||
"components": {"schemas": {}},
|
||||
}
|
||||
|
||||
# check paths
|
||||
if "paths" not in swagger or len(swagger["paths"]) == 0:
|
||||
raise ToolApiSchemaError("No paths found in the swagger yaml.")
|
||||
|
||||
# convert paths
|
||||
for path, path_item in swagger["paths"].items():
|
||||
openapi["paths"][path] = {} # pyright: ignore[reportIndexIssue]
|
||||
for method, operation in path_item.items():
|
||||
if "operationId" not in operation:
|
||||
raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.")
|
||||
|
||||
if ("summary" not in operation or len(operation["summary"]) == 0) and (
|
||||
"description" not in operation or len(operation["description"]) == 0
|
||||
):
|
||||
if warning is not None:
|
||||
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
|
||||
|
||||
openapi["paths"][path][method] = { # pyright: ignore[reportIndexIssue]
|
||||
"operationId": operation["operationId"],
|
||||
"summary": operation.get("summary", ""),
|
||||
"description": operation.get("description", ""),
|
||||
"parameters": operation.get("parameters", []),
|
||||
"responses": operation.get("responses", {}),
|
||||
}
|
||||
|
||||
if "requestBody" in operation:
|
||||
openapi["paths"][path][method]["requestBody"] = operation["requestBody"] # pyright: ignore[reportIndexIssue]
|
||||
|
||||
# convert definitions
|
||||
for name, definition in swagger["definitions"].items():
|
||||
openapi["components"]["schemas"][name] = definition # pyright: ignore[reportIndexIssue, reportArgumentType]
|
||||
|
||||
return openapi
|
||||
|
||||
@staticmethod
|
||||
def parse_openai_plugin_json_to_tool_bundle(
|
||||
json: str, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> list[ApiToolBundle]:
|
||||
"""
|
||||
parse openapi plugin yaml to tool bundle
|
||||
|
||||
:param json: the json string
|
||||
:param extra_info: the extra info
|
||||
:param warning: the warning message
|
||||
:return: the tool bundle
|
||||
"""
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
try:
|
||||
openai_plugin = json_loads(json)
|
||||
api = openai_plugin["api"]
|
||||
api_url = api["url"]
|
||||
api_type = api["type"]
|
||||
except JSONDecodeError:
|
||||
raise ToolProviderNotFoundError("Invalid openai plugin json.")
|
||||
|
||||
if api_type != "openapi":
|
||||
raise ToolNotSupportedError("Only openapi is supported now.")
|
||||
|
||||
# get openapi yaml
|
||||
response = get(api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise ToolProviderNotFoundError("cannot get openapi yaml from url.")
|
||||
|
||||
return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(
|
||||
response.text, extra_info=extra_info, warning=warning
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def auto_parse_to_tool_bundle(
|
||||
content: str, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> tuple[list[ApiToolBundle], str]:
|
||||
"""
|
||||
auto parse to tool bundle
|
||||
|
||||
:param content: the content
|
||||
:param extra_info: the extra info
|
||||
:param warning: the warning message
|
||||
:return: tools bundle, schema_type
|
||||
"""
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
content = content.strip()
|
||||
loaded_content = None
|
||||
json_error = None
|
||||
yaml_error = None
|
||||
|
||||
try:
|
||||
loaded_content = json_loads(content)
|
||||
except JSONDecodeError as e:
|
||||
json_error = e
|
||||
|
||||
if loaded_content is None:
|
||||
try:
|
||||
loaded_content = safe_load(content)
|
||||
except YAMLError as e:
|
||||
yaml_error = e
|
||||
if loaded_content is None:
|
||||
raise ToolApiSchemaError(
|
||||
f"Invalid api schema, schema is neither json nor yaml. json error: {str(json_error)},"
|
||||
f" yaml error: {str(yaml_error)}"
|
||||
)
|
||||
|
||||
swagger_error = None
|
||||
openapi_error = None
|
||||
openapi_plugin_error = None
|
||||
schema_type = None
|
||||
|
||||
try:
|
||||
openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
|
||||
loaded_content, extra_info=extra_info, warning=warning
|
||||
)
|
||||
schema_type = ApiProviderSchemaType.OPENAPI.value
|
||||
return openapi, schema_type
|
||||
except ToolApiSchemaError as e:
|
||||
openapi_error = e
|
||||
|
||||
# openai parse error, fallback to swagger
|
||||
try:
|
||||
converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
|
||||
loaded_content, extra_info=extra_info, warning=warning
|
||||
)
|
||||
schema_type = ApiProviderSchemaType.SWAGGER.value
|
||||
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
|
||||
converted_swagger, extra_info=extra_info, warning=warning
|
||||
), schema_type
|
||||
except ToolApiSchemaError as e:
|
||||
swagger_error = e
|
||||
|
||||
# swagger parse error, fallback to openai plugin
|
||||
try:
|
||||
openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
|
||||
json_dumps(loaded_content), extra_info=extra_info, warning=warning
|
||||
)
|
||||
return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN.value
|
||||
except ToolNotSupportedError as e:
|
||||
# maybe it's not plugin at all
|
||||
openapi_plugin_error = e
|
||||
|
||||
raise ToolApiSchemaError(
|
||||
f"Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)},"
|
||||
f" openapi plugin error: {str(openapi_plugin_error)}"
|
||||
)
|
||||
@@ -1,17 +0,0 @@
|
||||
import re
|
||||
|
||||
|
||||
def remove_leading_symbols(text: str) -> str:
|
||||
"""
|
||||
Remove leading punctuation or symbols from the given text.
|
||||
|
||||
Args:
|
||||
text (str): The input text to process.
|
||||
|
||||
Returns:
|
||||
str: The text with leading punctuation or symbols removed.
|
||||
"""
|
||||
# Match Unicode ranges for punctuation and symbols
|
||||
# FIXME this pattern is confused quick fix for #11868 maybe refactor it later
|
||||
pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+"
|
||||
return re.sub(pattern, "", text)
|
||||
@@ -1,9 +0,0 @@
|
||||
import uuid
|
||||
|
||||
|
||||
def is_valid_uuid(uuid_str: str) -> bool:
|
||||
try:
|
||||
uuid.UUID(uuid_str)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
@@ -1,43 +0,0 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
|
||||
|
||||
class WorkflowToolConfigurationUtils:
|
||||
@classmethod
|
||||
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
|
||||
for configuration in configurations:
|
||||
WorkflowToolParameterConfiguration.model_validate(configuration)
|
||||
|
||||
@classmethod
|
||||
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
|
||||
"""
|
||||
get workflow graph variables
|
||||
"""
|
||||
nodes = graph.get("nodes", [])
|
||||
start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None)
|
||||
|
||||
if not start_node:
|
||||
return []
|
||||
|
||||
return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])]
|
||||
|
||||
@classmethod
|
||||
def check_is_synced(
|
||||
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
|
||||
):
|
||||
"""
|
||||
check is synced
|
||||
|
||||
raise ValueError if not synced
|
||||
"""
|
||||
variable_names = [variable.variable for variable in variables]
|
||||
|
||||
if len(tool_configurations) != len(variables):
|
||||
raise ValueError("parameter configuration mismatch, please republish the tool to update")
|
||||
|
||||
for parameter in tool_configurations:
|
||||
if parameter.name not in variable_names:
|
||||
raise ValueError("parameter configuration mismatch, please republish the tool to update")
|
||||
@@ -1,35 +0,0 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml # type: ignore
|
||||
from yaml import YAMLError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}) -> Any:
|
||||
"""
|
||||
Safe loading a YAML file
|
||||
:param file_path: the path of the YAML file
|
||||
:param ignore_error:
|
||||
if True, return default_value if error occurs and the error will be logged in debug level
|
||||
if False, raise error if error occurs
|
||||
:param default_value: the value returned when errors ignored
|
||||
:return: an object of the YAML content
|
||||
"""
|
||||
if not file_path or not Path(file_path).exists():
|
||||
if ignore_error:
|
||||
return default_value
|
||||
else:
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
with open(file_path, encoding="utf-8") as yaml_file:
|
||||
try:
|
||||
yaml_content = yaml.safe_load(yaml_file)
|
||||
return yaml_content or default_value
|
||||
except Exception as e:
|
||||
if ignore_error:
|
||||
return default_value
|
||||
else:
|
||||
raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e
|
||||
@@ -205,16 +205,10 @@ class ProviderConfiguration(BaseModel):
|
||||
"""
|
||||
Get custom provider record.
|
||||
"""
|
||||
# get provider
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
|
||||
stmt = select(Provider).where(
|
||||
Provider.tenant_id == self.tenant_id,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value,
|
||||
Provider.provider_name.in_(provider_names),
|
||||
Provider.provider_name.in_(self._get_provider_names()),
|
||||
)
|
||||
|
||||
return session.execute(stmt).scalar_one_or_none()
|
||||
@@ -276,7 +270,7 @@ class ProviderConfiguration(BaseModel):
|
||||
"""
|
||||
stmt = select(ProviderCredential.id).where(
|
||||
ProviderCredential.tenant_id == self.tenant_id,
|
||||
ProviderCredential.provider_name == self.provider.provider,
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderCredential.credential_name == credential_name,
|
||||
)
|
||||
if exclude_id:
|
||||
@@ -324,7 +318,7 @@ class ProviderConfiguration(BaseModel):
|
||||
try:
|
||||
stmt = select(ProviderCredential).where(
|
||||
ProviderCredential.tenant_id == self.tenant_id,
|
||||
ProviderCredential.provider_name == self.provider.provider,
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderCredential.id == credential_id,
|
||||
)
|
||||
credential_record = s.execute(stmt).scalar_one_or_none()
|
||||
@@ -374,7 +368,7 @@ class ProviderConfiguration(BaseModel):
|
||||
session=session,
|
||||
query_factory=lambda: select(ProviderCredential).where(
|
||||
ProviderCredential.tenant_id == self.tenant_id,
|
||||
ProviderCredential.provider_name == self.provider.provider,
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -387,7 +381,7 @@ class ProviderConfiguration(BaseModel):
|
||||
session=session,
|
||||
query_factory=lambda: select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
),
|
||||
@@ -423,6 +417,16 @@ class ProviderConfiguration(BaseModel):
|
||||
logger.warning("Error generating next credential name: %s", str(e))
|
||||
return "API KEY 1"
|
||||
|
||||
def _get_provider_names(self):
|
||||
"""
|
||||
The provider name might be stored in the database as either `openai` or `langgenius/openai/openai`.
|
||||
"""
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
return provider_names
|
||||
|
||||
def create_provider_credential(self, credentials: dict, credential_name: str | None):
|
||||
"""
|
||||
Add custom provider credentials.
|
||||
@@ -501,7 +505,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = select(ProviderCredential).where(
|
||||
ProviderCredential.id == credential_id,
|
||||
ProviderCredential.tenant_id == self.tenant_id,
|
||||
ProviderCredential.provider_name == self.provider.provider,
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
)
|
||||
|
||||
# Get the credential record to update
|
||||
@@ -554,7 +558,7 @@ class ProviderConfiguration(BaseModel):
|
||||
# Find all load balancing configs that use this credential_id
|
||||
stmt = select(LoadBalancingModelConfig).where(
|
||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == self.provider.provider,
|
||||
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
|
||||
LoadBalancingModelConfig.credential_id == credential_id,
|
||||
LoadBalancingModelConfig.credential_source_type == credential_source,
|
||||
)
|
||||
@@ -591,7 +595,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = select(ProviderCredential).where(
|
||||
ProviderCredential.id == credential_id,
|
||||
ProviderCredential.tenant_id == self.tenant_id,
|
||||
ProviderCredential.provider_name == self.provider.provider,
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
)
|
||||
|
||||
# Get the credential record to update
|
||||
@@ -602,7 +606,7 @@ class ProviderConfiguration(BaseModel):
|
||||
# Check if this credential is used in load balancing configs
|
||||
lb_stmt = select(LoadBalancingModelConfig).where(
|
||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == self.provider.provider,
|
||||
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
|
||||
LoadBalancingModelConfig.credential_id == credential_id,
|
||||
LoadBalancingModelConfig.credential_source_type == "provider",
|
||||
)
|
||||
@@ -624,7 +628,7 @@ class ProviderConfiguration(BaseModel):
|
||||
# if this is the last credential, we need to delete the provider record
|
||||
count_stmt = select(func.count(ProviderCredential.id)).where(
|
||||
ProviderCredential.tenant_id == self.tenant_id,
|
||||
ProviderCredential.provider_name == self.provider.provider,
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
)
|
||||
available_credentials_count = session.execute(count_stmt).scalar() or 0
|
||||
session.delete(credential_record)
|
||||
@@ -668,7 +672,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = select(ProviderCredential).where(
|
||||
ProviderCredential.id == credential_id,
|
||||
ProviderCredential.tenant_id == self.tenant_id,
|
||||
ProviderCredential.provider_name == self.provider.provider,
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
)
|
||||
credential_record = session.execute(stmt).scalar_one_or_none()
|
||||
if not credential_record:
|
||||
@@ -737,7 +741,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.id == credential_id,
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
@@ -784,7 +788,7 @@ class ProviderConfiguration(BaseModel):
|
||||
"""
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelCredential.credential_name == credential_name,
|
||||
@@ -860,7 +864,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.id == credential_id,
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
@@ -997,7 +1001,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.id == credential_id,
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
@@ -1042,7 +1046,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.id == credential_id,
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
@@ -1052,7 +1056,7 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
lb_stmt = select(LoadBalancingModelConfig).where(
|
||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == self.provider.provider,
|
||||
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
|
||||
LoadBalancingModelConfig.credential_id == credential_id,
|
||||
LoadBalancingModelConfig.credential_source_type == "custom_model",
|
||||
)
|
||||
@@ -1075,7 +1079,7 @@ class ProviderConfiguration(BaseModel):
|
||||
# if this is the last credential, we need to delete the custom model record
|
||||
count_stmt = select(func.count(ProviderModelCredential.id)).where(
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
@@ -1115,7 +1119,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.id == credential_id,
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
@@ -1157,7 +1161,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.id == credential_id,
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
@@ -1204,15 +1208,9 @@ class ProviderConfiguration(BaseModel):
|
||||
"""
|
||||
Get provider model setting.
|
||||
"""
|
||||
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
|
||||
stmt = select(ProviderModelSetting).where(
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name.in_(provider_names),
|
||||
ProviderModelSetting.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelSetting.model_name == model,
|
||||
)
|
||||
@@ -1384,15 +1382,9 @@ class ProviderConfiguration(BaseModel):
|
||||
return
|
||||
|
||||
def _switch(s: Session):
|
||||
# get preferred provider
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
|
||||
stmt = select(TenantPreferredModelProvider).where(
|
||||
TenantPreferredModelProvider.tenant_id == self.tenant_id,
|
||||
TenantPreferredModelProvider.provider_name.in_(provider_names),
|
||||
TenantPreferredModelProvider.provider_name.in_(self._get_provider_names()),
|
||||
)
|
||||
preferred_model_provider = s.execute(stmt).scalars().first()
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from enum import StrEnum
|
||||
from threading import Lock
|
||||
from typing import Any
|
||||
|
||||
from httpx import Timeout, post
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
from yarl import URL
|
||||
|
||||
@@ -13,9 +13,17 @@ from core.helper.code_executor.javascript.javascript_transformer import NodeJsTe
|
||||
from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTransformer
|
||||
from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer
|
||||
from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
from core.helper.http_client_pooling import get_pooled_http_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
code_execution_endpoint_url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT))
|
||||
CODE_EXECUTION_SSL_VERIFY = dify_config.CODE_EXECUTION_SSL_VERIFY
|
||||
_CODE_EXECUTOR_CLIENT_LIMITS = httpx.Limits(
|
||||
max_connections=dify_config.CODE_EXECUTION_POOL_MAX_CONNECTIONS,
|
||||
max_keepalive_connections=dify_config.CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS,
|
||||
keepalive_expiry=dify_config.CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY,
|
||||
)
|
||||
_CODE_EXECUTOR_CLIENT_KEY = "code_executor:http_client"
|
||||
|
||||
|
||||
class CodeExecutionError(Exception):
|
||||
@@ -38,6 +46,13 @@ class CodeLanguage(StrEnum):
|
||||
JAVASCRIPT = "javascript"
|
||||
|
||||
|
||||
def _build_code_executor_client() -> httpx.Client:
|
||||
return httpx.Client(
|
||||
verify=CODE_EXECUTION_SSL_VERIFY,
|
||||
limits=_CODE_EXECUTOR_CLIENT_LIMITS,
|
||||
)
|
||||
|
||||
|
||||
class CodeExecutor:
|
||||
dependencies_cache: dict[str, str] = {}
|
||||
dependencies_cache_lock = Lock()
|
||||
@@ -76,17 +91,21 @@ class CodeExecutor:
|
||||
"enable_network": True,
|
||||
}
|
||||
|
||||
timeout = httpx.Timeout(
|
||||
connect=dify_config.CODE_EXECUTION_CONNECT_TIMEOUT,
|
||||
read=dify_config.CODE_EXECUTION_READ_TIMEOUT,
|
||||
write=dify_config.CODE_EXECUTION_WRITE_TIMEOUT,
|
||||
pool=None,
|
||||
)
|
||||
|
||||
client = get_pooled_http_client(_CODE_EXECUTOR_CLIENT_KEY, _build_code_executor_client)
|
||||
|
||||
try:
|
||||
response = post(
|
||||
response = client.post(
|
||||
str(url),
|
||||
json=data,
|
||||
headers=headers,
|
||||
timeout=Timeout(
|
||||
connect=dify_config.CODE_EXECUTION_CONNECT_TIMEOUT,
|
||||
read=dify_config.CODE_EXECUTION_READ_TIMEOUT,
|
||||
write=dify_config.CODE_EXECUTION_WRITE_TIMEOUT,
|
||||
pool=None,
|
||||
),
|
||||
timeout=timeout,
|
||||
)
|
||||
if response.status_code == 503:
|
||||
raise CodeExecutionError("Code execution service is unavailable")
|
||||
@@ -106,8 +125,8 @@ class CodeExecutor:
|
||||
|
||||
try:
|
||||
response_data = response.json()
|
||||
except:
|
||||
raise CodeExecutionError("Failed to parse response")
|
||||
except Exception as e:
|
||||
raise CodeExecutionError("Failed to parse response") from e
|
||||
|
||||
if (code := response_data.get("code")) != 0:
|
||||
raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response_data.get('message')}")
|
||||
|
||||
59
api/core/helper/http_client_pooling.py
Normal file
59
api/core/helper/http_client_pooling.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""HTTP client pooling utilities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import threading
|
||||
from collections.abc import Callable
|
||||
|
||||
import httpx
|
||||
|
||||
ClientBuilder = Callable[[], httpx.Client]
|
||||
|
||||
|
||||
class HttpClientPoolFactory:
|
||||
"""Thread-safe factory that maintains reusable HTTP client instances."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._clients: dict[str, httpx.Client] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def get_or_create(self, key: str, builder: ClientBuilder) -> httpx.Client:
|
||||
"""Return a pooled client associated with ``key`` creating it on demand."""
|
||||
client = self._clients.get(key)
|
||||
if client is not None:
|
||||
return client
|
||||
|
||||
with self._lock:
|
||||
client = self._clients.get(key)
|
||||
if client is None:
|
||||
client = builder()
|
||||
self._clients[key] = client
|
||||
return client
|
||||
|
||||
def close_all(self) -> None:
|
||||
"""Close all pooled clients and clear the pool."""
|
||||
with self._lock:
|
||||
for client in self._clients.values():
|
||||
client.close()
|
||||
self._clients.clear()
|
||||
|
||||
|
||||
_factory = HttpClientPoolFactory()
|
||||
|
||||
|
||||
def get_pooled_http_client(key: str, builder: ClientBuilder) -> httpx.Client:
|
||||
"""Return a pooled client for the given ``key`` using ``builder`` when missing."""
|
||||
return _factory.get_or_create(key, builder)
|
||||
|
||||
|
||||
def close_all_pooled_clients() -> None:
|
||||
"""Close every client created through the pooling factory."""
|
||||
_factory.close_all()
|
||||
|
||||
|
||||
def _register_shutdown_hook() -> None:
|
||||
atexit.register(close_all_pooled_clients)
|
||||
|
||||
|
||||
_register_shutdown_hook()
|
||||
@@ -23,7 +23,7 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplaceP
|
||||
return []
|
||||
|
||||
url = str(marketplace_api_url / "api/v1/plugins/batch")
|
||||
response = httpx.post(url, json={"plugin_ids": plugin_ids})
|
||||
response = httpx.post(url, json={"plugin_ids": plugin_ids}, headers={"X-Dify-Version": dify_config.project.version})
|
||||
response.raise_for_status()
|
||||
|
||||
return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]
|
||||
@@ -36,7 +36,7 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
|
||||
return []
|
||||
|
||||
url = str(marketplace_api_url / "api/v1/plugins/batch")
|
||||
response = httpx.post(url, json={"plugin_ids": plugin_ids})
|
||||
response = httpx.post(url, json={"plugin_ids": plugin_ids}, headers={"X-Dify-Version": dify_config.project.version})
|
||||
response.raise_for_status()
|
||||
result: list[MarketplacePluginDeclaration] = []
|
||||
for plugin in response.json()["data"]["plugins"]:
|
||||
|
||||
@@ -8,27 +8,23 @@ import time
|
||||
import httpx
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.http_client_pooling import get_pooled_http_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
|
||||
|
||||
http_request_node_ssl_verify = True # Default value for http_request_node_ssl_verify is True
|
||||
try:
|
||||
config_value = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
|
||||
http_request_node_ssl_verify_lower = str(config_value).lower()
|
||||
if http_request_node_ssl_verify_lower == "true":
|
||||
http_request_node_ssl_verify = True
|
||||
elif http_request_node_ssl_verify_lower == "false":
|
||||
http_request_node_ssl_verify = False
|
||||
else:
|
||||
raise ValueError("Invalid value. HTTP_REQUEST_NODE_SSL_VERIFY should be 'True' or 'False'")
|
||||
except NameError:
|
||||
http_request_node_ssl_verify = True
|
||||
|
||||
BACKOFF_FACTOR = 0.5
|
||||
STATUS_FORCELIST = [429, 500, 502, 503, 504]
|
||||
|
||||
_SSL_VERIFIED_POOL_KEY = "ssrf:verified"
|
||||
_SSL_UNVERIFIED_POOL_KEY = "ssrf:unverified"
|
||||
_SSRF_CLIENT_LIMITS = httpx.Limits(
|
||||
max_connections=dify_config.SSRF_POOL_MAX_CONNECTIONS,
|
||||
max_keepalive_connections=dify_config.SSRF_POOL_MAX_KEEPALIVE_CONNECTIONS,
|
||||
keepalive_expiry=dify_config.SSRF_POOL_KEEPALIVE_EXPIRY,
|
||||
)
|
||||
|
||||
|
||||
class MaxRetriesExceededError(ValueError):
|
||||
"""Raised when the maximum number of retries is exceeded."""
|
||||
@@ -36,6 +32,45 @@ class MaxRetriesExceededError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
def _create_proxy_mounts() -> dict[str, httpx.HTTPTransport]:
|
||||
return {
|
||||
"http://": httpx.HTTPTransport(
|
||||
proxy=dify_config.SSRF_PROXY_HTTP_URL,
|
||||
),
|
||||
"https://": httpx.HTTPTransport(
|
||||
proxy=dify_config.SSRF_PROXY_HTTPS_URL,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _build_ssrf_client(verify: bool) -> httpx.Client:
|
||||
if dify_config.SSRF_PROXY_ALL_URL:
|
||||
return httpx.Client(
|
||||
proxy=dify_config.SSRF_PROXY_ALL_URL,
|
||||
verify=verify,
|
||||
limits=_SSRF_CLIENT_LIMITS,
|
||||
)
|
||||
|
||||
if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
|
||||
return httpx.Client(
|
||||
mounts=_create_proxy_mounts(),
|
||||
verify=verify,
|
||||
limits=_SSRF_CLIENT_LIMITS,
|
||||
)
|
||||
|
||||
return httpx.Client(verify=verify, limits=_SSRF_CLIENT_LIMITS)
|
||||
|
||||
|
||||
def _get_ssrf_client(ssl_verify_enabled: bool) -> httpx.Client:
|
||||
if not isinstance(ssl_verify_enabled, bool):
|
||||
raise ValueError("SSRF client verify flag must be a boolean")
|
||||
|
||||
return get_pooled_http_client(
|
||||
_SSL_VERIFIED_POOL_KEY if ssl_verify_enabled else _SSL_UNVERIFIED_POOL_KEY,
|
||||
lambda: _build_ssrf_client(verify=ssl_verify_enabled),
|
||||
)
|
||||
|
||||
|
||||
def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
if "allow_redirects" in kwargs:
|
||||
allow_redirects = kwargs.pop("allow_redirects")
|
||||
@@ -50,33 +85,22 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
|
||||
)
|
||||
|
||||
if "ssl_verify" not in kwargs:
|
||||
kwargs["ssl_verify"] = http_request_node_ssl_verify
|
||||
|
||||
ssl_verify = kwargs.pop("ssl_verify")
|
||||
# prioritize per-call option, which can be switched on and off inside the HTTP node on the web UI
|
||||
verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
|
||||
client = _get_ssrf_client(verify_option)
|
||||
|
||||
retries = 0
|
||||
while retries <= max_retries:
|
||||
try:
|
||||
if dify_config.SSRF_PROXY_ALL_URL:
|
||||
with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL, verify=ssl_verify) as client:
|
||||
response = client.request(method=method, url=url, **kwargs)
|
||||
elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
|
||||
proxy_mounts = {
|
||||
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=ssl_verify),
|
||||
"https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=ssl_verify),
|
||||
}
|
||||
with httpx.Client(mounts=proxy_mounts, verify=ssl_verify) as client:
|
||||
response = client.request(method=method, url=url, **kwargs)
|
||||
else:
|
||||
with httpx.Client(verify=ssl_verify) as client:
|
||||
response = client.request(method=method, url=url, **kwargs)
|
||||
response = client.request(method=method, url=url, **kwargs)
|
||||
|
||||
if response.status_code not in STATUS_FORCELIST:
|
||||
return response
|
||||
else:
|
||||
logger.warning(
|
||||
"Received status code %s for URL %s which is in the force list", response.status_code, url
|
||||
"Received status code %s for URL %s which is in the force list",
|
||||
response.status_code,
|
||||
url,
|
||||
)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
|
||||
@@ -28,7 +28,6 @@ from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.ops.utils import measure_time
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.node_events import AgentLogEvent
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models import App, Message, WorkflowNodeExecutionModel
|
||||
@@ -462,19 +461,18 @@ class LLMGenerator:
|
||||
)
|
||||
|
||||
def agent_log_of(node_execution: WorkflowNodeExecutionModel) -> Sequence:
|
||||
raw_agent_log = node_execution.execution_metadata_dict.get(WorkflowNodeExecutionMetadataKey.AGENT_LOG)
|
||||
raw_agent_log = node_execution.execution_metadata_dict.get(WorkflowNodeExecutionMetadataKey.AGENT_LOG, [])
|
||||
if not raw_agent_log:
|
||||
return []
|
||||
parsed: Sequence[AgentLogEvent] = json.loads(raw_agent_log)
|
||||
|
||||
def dict_of_event(event: AgentLogEvent):
|
||||
return {
|
||||
"status": event.status,
|
||||
"error": event.error,
|
||||
"data": event.data,
|
||||
return [
|
||||
{
|
||||
"status": event["status"],
|
||||
"error": event["error"],
|
||||
"data": event["data"],
|
||||
}
|
||||
|
||||
return [dict_of_event(event) for event in parsed]
|
||||
for event in raw_agent_log
|
||||
]
|
||||
|
||||
inputs = last_run.load_full_inputs(session, storage)
|
||||
last_run_dict = {
|
||||
|
||||
@@ -74,7 +74,7 @@ class TextPromptMessageContent(PromptMessageContent):
|
||||
Model class for text prompt message content.
|
||||
"""
|
||||
|
||||
type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT
|
||||
type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT # type: ignore
|
||||
data: str
|
||||
|
||||
|
||||
@@ -95,11 +95,11 @@ class MultiModalPromptMessageContent(PromptMessageContent):
|
||||
|
||||
|
||||
class VideoPromptMessageContent(MultiModalPromptMessageContent):
|
||||
type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO
|
||||
type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO # type: ignore
|
||||
|
||||
|
||||
class AudioPromptMessageContent(MultiModalPromptMessageContent):
|
||||
type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO
|
||||
type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO # type: ignore
|
||||
|
||||
|
||||
class ImagePromptMessageContent(MultiModalPromptMessageContent):
|
||||
@@ -111,12 +111,12 @@ class ImagePromptMessageContent(MultiModalPromptMessageContent):
|
||||
LOW = auto()
|
||||
HIGH = auto()
|
||||
|
||||
type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE
|
||||
type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE # type: ignore
|
||||
detail: DETAIL = DETAIL.LOW
|
||||
|
||||
|
||||
class DocumentPromptMessageContent(MultiModalPromptMessageContent):
|
||||
type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT
|
||||
type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT # type: ignore
|
||||
|
||||
|
||||
PromptMessageContentUnionTypes = Annotated[
|
||||
|
||||
@@ -15,7 +15,7 @@ class GPT2Tokenizer:
|
||||
use gpt2 tokenizer to get num tokens
|
||||
"""
|
||||
_tokenizer = GPT2Tokenizer.get_encoder()
|
||||
tokens = _tokenizer.encode(text)
|
||||
tokens = _tokenizer.encode(text) # type: ignore
|
||||
return len(tokens)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -196,15 +196,15 @@ def jsonable_encoder(
|
||||
return encoder(obj)
|
||||
|
||||
try:
|
||||
data = dict(obj)
|
||||
data = dict(obj) # type: ignore
|
||||
except Exception as e:
|
||||
errors: list[Exception] = []
|
||||
errors.append(e)
|
||||
try:
|
||||
data = vars(obj)
|
||||
data = vars(obj) # type: ignore
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
raise ValueError(errors) from e
|
||||
raise ValueError(str(errors)) from e
|
||||
return jsonable_encoder(
|
||||
data,
|
||||
by_alias=by_alias,
|
||||
|
||||
@@ -1,38 +1,28 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from opentelemetry.trace import Link, Status, StatusCode
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.aliyun_trace.data_exporter.traceclient import (
|
||||
TraceClient,
|
||||
build_endpoint,
|
||||
convert_datetime_to_nanoseconds,
|
||||
convert_to_span_id,
|
||||
convert_to_trace_id,
|
||||
create_link,
|
||||
generate_span_id,
|
||||
)
|
||||
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
|
||||
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata
|
||||
from core.ops.aliyun_trace.entities.semconv import (
|
||||
GEN_AI_COMPLETION,
|
||||
GEN_AI_FRAMEWORK,
|
||||
GEN_AI_MODEL_NAME,
|
||||
GEN_AI_INPUT_MESSAGE,
|
||||
GEN_AI_OUTPUT_MESSAGE,
|
||||
GEN_AI_PROMPT,
|
||||
GEN_AI_PROMPT_TEMPLATE_TEMPLATE,
|
||||
GEN_AI_PROMPT_TEMPLATE_VARIABLE,
|
||||
GEN_AI_PROVIDER_NAME,
|
||||
GEN_AI_REQUEST_MODEL,
|
||||
GEN_AI_RESPONSE_FINISH_REASON,
|
||||
GEN_AI_SESSION_ID,
|
||||
GEN_AI_SPAN_KIND,
|
||||
GEN_AI_SYSTEM,
|
||||
GEN_AI_USAGE_INPUT_TOKENS,
|
||||
GEN_AI_USAGE_OUTPUT_TOKENS,
|
||||
GEN_AI_USAGE_TOTAL_TOKENS,
|
||||
GEN_AI_USER_ID,
|
||||
INPUT_VALUE,
|
||||
OUTPUT_VALUE,
|
||||
RETRIEVAL_DOCUMENT,
|
||||
RETRIEVAL_QUERY,
|
||||
TOOL_DESCRIPTION,
|
||||
@@ -40,6 +30,18 @@ from core.ops.aliyun_trace.entities.semconv import (
|
||||
TOOL_PARAMETERS,
|
||||
GenAISpanKind,
|
||||
)
|
||||
from core.ops.aliyun_trace.utils import (
|
||||
create_common_span_attributes,
|
||||
create_links_from_trace_id,
|
||||
create_status_from_error,
|
||||
extract_retrieval_documents,
|
||||
format_input_messages,
|
||||
format_output_messages,
|
||||
format_retrieval_documents,
|
||||
get_user_id_from_message_data,
|
||||
get_workflow_node_status,
|
||||
serialize_json_data,
|
||||
)
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import AliyunConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
@@ -52,12 +54,11 @@ from core.ops.entities.trace_entity import (
|
||||
ToolTraceInfo,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.rag.models.document import Document
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.entities import WorkflowNodeExecution
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
|
||||
from extensions.ext_database import db
|
||||
from models import Account, App, EndUser, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom
|
||||
from models import WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -68,8 +69,7 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
aliyun_config: AliyunConfig,
|
||||
):
|
||||
super().__init__(aliyun_config)
|
||||
base_url = aliyun_config.endpoint.rstrip("/")
|
||||
endpoint = urljoin(base_url, f"adapt_{aliyun_config.license_key}/api/otlp/traces")
|
||||
endpoint = build_endpoint(aliyun_config.endpoint, aliyun_config.license_key)
|
||||
self.trace_client = TraceClient(service_name=aliyun_config.app_name, endpoint=endpoint)
|
||||
|
||||
def trace(self, trace_info: BaseTraceInfo):
|
||||
@@ -95,423 +95,425 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
try:
|
||||
return self.trace_client.get_project_url()
|
||||
except Exception as e:
|
||||
logger.info("Aliyun get run url failed: %s", str(e), exc_info=True)
|
||||
raise ValueError(f"Aliyun get run url failed: {str(e)}")
|
||||
logger.info("Aliyun get project url failed: %s", str(e), exc_info=True)
|
||||
raise ValueError(f"Aliyun get project url failed: {str(e)}")
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
trace_id = convert_to_trace_id(trace_info.workflow_run_id)
|
||||
links = []
|
||||
if trace_info.trace_id:
|
||||
links.append(create_link(trace_id_str=trace_info.trace_id))
|
||||
workflow_span_id = convert_to_span_id(trace_info.workflow_run_id, "workflow")
|
||||
self.add_workflow_span(trace_id, workflow_span_id, trace_info, links)
|
||||
trace_metadata = TraceMetadata(
|
||||
trace_id=convert_to_trace_id(trace_info.workflow_run_id),
|
||||
workflow_span_id=convert_to_span_id(trace_info.workflow_run_id, "workflow"),
|
||||
session_id=trace_info.metadata.get("conversation_id") or "",
|
||||
user_id=str(trace_info.metadata.get("user_id") or ""),
|
||||
links=create_links_from_trace_id(trace_info.trace_id),
|
||||
)
|
||||
|
||||
self.add_workflow_span(trace_info, trace_metadata)
|
||||
|
||||
workflow_node_executions = self.get_workflow_node_executions(trace_info)
|
||||
for node_execution in workflow_node_executions:
|
||||
node_span = self.build_workflow_node_span(node_execution, trace_id, trace_info, workflow_span_id)
|
||||
node_span = self.build_workflow_node_span(node_execution, trace_info, trace_metadata)
|
||||
self.trace_client.add_span(node_span)
|
||||
|
||||
def message_trace(self, trace_info: MessageTraceInfo):
|
||||
message_data = trace_info.message_data
|
||||
if message_data is None:
|
||||
return
|
||||
|
||||
message_id = trace_info.message_id
|
||||
user_id = get_user_id_from_message_data(message_data)
|
||||
status = create_status_from_error(trace_info.error)
|
||||
|
||||
user_id = message_data.from_account_id
|
||||
if message_data.from_end_user_id:
|
||||
end_user_data: EndUser | None = (
|
||||
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
|
||||
)
|
||||
if end_user_data is not None:
|
||||
user_id = end_user_data.session_id
|
||||
trace_metadata = TraceMetadata(
|
||||
trace_id=convert_to_trace_id(message_id),
|
||||
workflow_span_id=0,
|
||||
session_id=trace_info.metadata.get("conversation_id") or "",
|
||||
user_id=user_id,
|
||||
links=create_links_from_trace_id(trace_info.trace_id),
|
||||
)
|
||||
|
||||
status: Status = Status(StatusCode.OK)
|
||||
if trace_info.error:
|
||||
status = Status(StatusCode.ERROR, trace_info.error)
|
||||
|
||||
trace_id = convert_to_trace_id(message_id)
|
||||
links = []
|
||||
if trace_info.trace_id:
|
||||
links.append(create_link(trace_id_str=trace_info.trace_id))
|
||||
inputs_json = serialize_json_data(trace_info.inputs)
|
||||
outputs_str = str(trace_info.outputs)
|
||||
|
||||
message_span_id = convert_to_span_id(message_id, "message")
|
||||
message_span = SpanData(
|
||||
trace_id=trace_id,
|
||||
trace_id=trace_metadata.trace_id,
|
||||
parent_span_id=None,
|
||||
span_id=message_span_id,
|
||||
name="message",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "",
|
||||
GEN_AI_USER_ID: str(user_id),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
OUTPUT_VALUE: str(trace_info.outputs),
|
||||
},
|
||||
attributes=create_common_span_attributes(
|
||||
session_id=trace_metadata.session_id,
|
||||
user_id=trace_metadata.user_id,
|
||||
span_kind=GenAISpanKind.CHAIN,
|
||||
inputs=inputs_json,
|
||||
outputs=outputs_str,
|
||||
),
|
||||
status=status,
|
||||
links=links,
|
||||
links=trace_metadata.links,
|
||||
)
|
||||
self.trace_client.add_span(message_span)
|
||||
|
||||
app_model_config = getattr(trace_info.message_data, "app_model_config", {})
|
||||
pre_prompt = getattr(app_model_config, "pre_prompt", "")
|
||||
inputs_data = getattr(trace_info.message_data, "inputs", {})
|
||||
llm_span = SpanData(
|
||||
trace_id=trace_id,
|
||||
trace_id=trace_metadata.trace_id,
|
||||
parent_span_id=message_span_id,
|
||||
span_id=convert_to_span_id(message_id, "llm"),
|
||||
name="llm",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "",
|
||||
GEN_AI_USER_ID: str(user_id),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
GEN_AI_MODEL_NAME: trace_info.metadata.get("ls_model_name") or "",
|
||||
GEN_AI_SYSTEM: trace_info.metadata.get("ls_provider") or "",
|
||||
**create_common_span_attributes(
|
||||
session_id=trace_metadata.session_id,
|
||||
user_id=trace_metadata.user_id,
|
||||
span_kind=GenAISpanKind.LLM,
|
||||
inputs=inputs_json,
|
||||
outputs=outputs_str,
|
||||
),
|
||||
GEN_AI_REQUEST_MODEL: trace_info.metadata.get("ls_model_name") or "",
|
||||
GEN_AI_PROVIDER_NAME: trace_info.metadata.get("ls_provider") or "",
|
||||
GEN_AI_USAGE_INPUT_TOKENS: str(trace_info.message_tokens),
|
||||
GEN_AI_USAGE_OUTPUT_TOKENS: str(trace_info.answer_tokens),
|
||||
GEN_AI_USAGE_TOTAL_TOKENS: str(trace_info.total_tokens),
|
||||
GEN_AI_PROMPT_TEMPLATE_VARIABLE: json.dumps(inputs_data, ensure_ascii=False),
|
||||
GEN_AI_PROMPT_TEMPLATE_TEMPLATE: pre_prompt,
|
||||
GEN_AI_PROMPT: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
GEN_AI_COMPLETION: str(trace_info.outputs),
|
||||
INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
OUTPUT_VALUE: str(trace_info.outputs),
|
||||
GEN_AI_PROMPT: inputs_json,
|
||||
GEN_AI_COMPLETION: outputs_str,
|
||||
},
|
||||
status=status,
|
||||
links=trace_metadata.links,
|
||||
)
|
||||
self.trace_client.add_span(llm_span)
|
||||
|
||||
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
message_id = trace_info.message_id
|
||||
|
||||
trace_id = convert_to_trace_id(message_id)
|
||||
links = []
|
||||
if trace_info.trace_id:
|
||||
links.append(create_link(trace_id_str=trace_info.trace_id))
|
||||
trace_metadata = TraceMetadata(
|
||||
trace_id=convert_to_trace_id(message_id),
|
||||
workflow_span_id=0,
|
||||
session_id=trace_info.metadata.get("conversation_id") or "",
|
||||
user_id=str(trace_info.metadata.get("user_id") or ""),
|
||||
links=create_links_from_trace_id(trace_info.trace_id),
|
||||
)
|
||||
|
||||
documents_data = extract_retrieval_documents(trace_info.documents)
|
||||
documents_json = serialize_json_data(documents_data)
|
||||
inputs_str = str(trace_info.inputs)
|
||||
|
||||
dataset_retrieval_span = SpanData(
|
||||
trace_id=trace_id,
|
||||
trace_id=trace_metadata.trace_id,
|
||||
parent_span_id=convert_to_span_id(message_id, "message"),
|
||||
span_id=generate_span_id(),
|
||||
name="dataset_retrieval",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
RETRIEVAL_QUERY: str(trace_info.inputs),
|
||||
RETRIEVAL_DOCUMENT: json.dumps(documents_data, ensure_ascii=False),
|
||||
INPUT_VALUE: str(trace_info.inputs),
|
||||
OUTPUT_VALUE: json.dumps(documents_data, ensure_ascii=False),
|
||||
**create_common_span_attributes(
|
||||
session_id=trace_metadata.session_id,
|
||||
user_id=trace_metadata.user_id,
|
||||
span_kind=GenAISpanKind.RETRIEVER,
|
||||
inputs=inputs_str,
|
||||
outputs=documents_json,
|
||||
),
|
||||
RETRIEVAL_QUERY: inputs_str,
|
||||
RETRIEVAL_DOCUMENT: documents_json,
|
||||
},
|
||||
links=links,
|
||||
links=trace_metadata.links,
|
||||
)
|
||||
self.trace_client.add_span(dataset_retrieval_span)
|
||||
|
||||
def tool_trace(self, trace_info: ToolTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
message_id = trace_info.message_id
|
||||
status = create_status_from_error(trace_info.error)
|
||||
|
||||
status: Status = Status(StatusCode.OK)
|
||||
if trace_info.error:
|
||||
status = Status(StatusCode.ERROR, trace_info.error)
|
||||
trace_metadata = TraceMetadata(
|
||||
trace_id=convert_to_trace_id(message_id),
|
||||
workflow_span_id=0,
|
||||
session_id=trace_info.metadata.get("conversation_id") or "",
|
||||
user_id=str(trace_info.metadata.get("user_id") or ""),
|
||||
links=create_links_from_trace_id(trace_info.trace_id),
|
||||
)
|
||||
|
||||
trace_id = convert_to_trace_id(message_id)
|
||||
links = []
|
||||
if trace_info.trace_id:
|
||||
links.append(create_link(trace_id_str=trace_info.trace_id))
|
||||
tool_config_json = serialize_json_data(trace_info.tool_config)
|
||||
tool_inputs_json = serialize_json_data(trace_info.tool_inputs)
|
||||
inputs_json = serialize_json_data(trace_info.inputs)
|
||||
|
||||
tool_span = SpanData(
|
||||
trace_id=trace_id,
|
||||
trace_id=trace_metadata.trace_id,
|
||||
parent_span_id=convert_to_span_id(message_id, "message"),
|
||||
span_id=generate_span_id(),
|
||||
name=trace_info.tool_name,
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
**create_common_span_attributes(
|
||||
session_id=trace_metadata.session_id,
|
||||
user_id=trace_metadata.user_id,
|
||||
span_kind=GenAISpanKind.TOOL,
|
||||
inputs=inputs_json,
|
||||
outputs=str(trace_info.tool_outputs),
|
||||
),
|
||||
TOOL_NAME: trace_info.tool_name,
|
||||
TOOL_DESCRIPTION: json.dumps(trace_info.tool_config, ensure_ascii=False),
|
||||
TOOL_PARAMETERS: json.dumps(trace_info.tool_inputs, ensure_ascii=False),
|
||||
INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
OUTPUT_VALUE: str(trace_info.tool_outputs),
|
||||
TOOL_DESCRIPTION: tool_config_json,
|
||||
TOOL_PARAMETERS: tool_inputs_json,
|
||||
},
|
||||
status=status,
|
||||
links=links,
|
||||
links=trace_metadata.links,
|
||||
)
|
||||
self.trace_client.add_span(tool_span)
|
||||
|
||||
def get_workflow_node_executions(self, trace_info: WorkflowTraceInfo) -> Sequence[WorkflowNodeExecution]:
|
||||
# through workflow_run_id get all_nodes_execution using repository
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
# Find the app's creator account
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Get the app to find its creator
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
app_stmt = select(App).where(App.id == app_id)
|
||||
app = session.scalar(app_stmt)
|
||||
if not app:
|
||||
raise ValueError(f"App with id {app_id} not found")
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
|
||||
if not app.created_by:
|
||||
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
|
||||
account_stmt = select(Account).where(Account.id == app.created_by)
|
||||
service_account = session.scalar(account_stmt)
|
||||
if not service_account:
|
||||
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
|
||||
current_tenant = (
|
||||
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first()
|
||||
)
|
||||
if not current_tenant:
|
||||
raise ValueError(f"Current tenant not found for account {service_account.id}")
|
||||
service_account.set_tenant_id(current_tenant.tenant_id)
|
||||
service_account = self.get_service_account_with_tenant(app_id)
|
||||
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
user=service_account,
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
# Get all executions for this workflow run
|
||||
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
|
||||
workflow_run_id=trace_info.workflow_run_id
|
||||
)
|
||||
return workflow_node_executions
|
||||
|
||||
return workflow_node_execution_repository.get_by_workflow_run(workflow_run_id=trace_info.workflow_run_id)
|
||||
|
||||
def build_workflow_node_span(
|
||||
self, node_execution: WorkflowNodeExecution, trace_id: int, trace_info: WorkflowTraceInfo, workflow_span_id: int
|
||||
self, node_execution: WorkflowNodeExecution, trace_info: WorkflowTraceInfo, trace_metadata: TraceMetadata
|
||||
):
|
||||
try:
|
||||
if node_execution.node_type == NodeType.LLM:
|
||||
node_span = self.build_workflow_llm_span(trace_id, workflow_span_id, trace_info, node_execution)
|
||||
node_span = self.build_workflow_llm_span(trace_info, node_execution, trace_metadata)
|
||||
elif node_execution.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||
node_span = self.build_workflow_retrieval_span(trace_id, workflow_span_id, trace_info, node_execution)
|
||||
node_span = self.build_workflow_retrieval_span(trace_info, node_execution, trace_metadata)
|
||||
elif node_execution.node_type == NodeType.TOOL:
|
||||
node_span = self.build_workflow_tool_span(trace_id, workflow_span_id, trace_info, node_execution)
|
||||
node_span = self.build_workflow_tool_span(trace_info, node_execution, trace_metadata)
|
||||
else:
|
||||
node_span = self.build_workflow_task_span(trace_id, workflow_span_id, trace_info, node_execution)
|
||||
node_span = self.build_workflow_task_span(trace_info, node_execution, trace_metadata)
|
||||
return node_span
|
||||
except Exception as e:
|
||||
logger.debug("Error occurred in build_workflow_node_span: %s", e, exc_info=True)
|
||||
return None
|
||||
|
||||
def get_workflow_node_status(self, node_execution: WorkflowNodeExecution) -> Status:
|
||||
span_status: Status = Status(StatusCode.UNSET)
|
||||
if node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
span_status = Status(StatusCode.OK)
|
||||
elif node_execution.status in [WorkflowNodeExecutionStatus.FAILED, WorkflowNodeExecutionStatus.EXCEPTION]:
|
||||
span_status = Status(StatusCode.ERROR, str(node_execution.error))
|
||||
return span_status
|
||||
|
||||
def build_workflow_task_span(
|
||||
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
|
||||
self, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution, trace_metadata: TraceMetadata
|
||||
) -> SpanData:
|
||||
inputs_json = serialize_json_data(node_execution.inputs)
|
||||
outputs_json = serialize_json_data(node_execution.outputs)
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=workflow_span_id,
|
||||
trace_id=trace_metadata.trace_id,
|
||||
parent_span_id=trace_metadata.workflow_span_id,
|
||||
span_id=convert_to_span_id(node_execution.id, "node"),
|
||||
name=node_execution.title,
|
||||
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
|
||||
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
|
||||
attributes={
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "",
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.TASK.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
INPUT_VALUE: json.dumps(node_execution.inputs, ensure_ascii=False),
|
||||
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
|
||||
},
|
||||
status=self.get_workflow_node_status(node_execution),
|
||||
attributes=create_common_span_attributes(
|
||||
session_id=trace_metadata.session_id,
|
||||
user_id=trace_metadata.user_id,
|
||||
span_kind=GenAISpanKind.TASK,
|
||||
inputs=inputs_json,
|
||||
outputs=outputs_json,
|
||||
),
|
||||
status=get_workflow_node_status(node_execution),
|
||||
links=trace_metadata.links,
|
||||
)
|
||||
|
||||
def build_workflow_tool_span(
|
||||
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
|
||||
self, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution, trace_metadata: TraceMetadata
|
||||
) -> SpanData:
|
||||
tool_des = {}
|
||||
if node_execution.metadata:
|
||||
tool_des = node_execution.metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO, {})
|
||||
|
||||
inputs_json = serialize_json_data(node_execution.inputs or {})
|
||||
outputs_json = serialize_json_data(node_execution.outputs)
|
||||
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=workflow_span_id,
|
||||
trace_id=trace_metadata.trace_id,
|
||||
parent_span_id=trace_metadata.workflow_span_id,
|
||||
span_id=convert_to_span_id(node_execution.id, "node"),
|
||||
name=node_execution.title,
|
||||
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
|
||||
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
|
||||
attributes={
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
**create_common_span_attributes(
|
||||
session_id=trace_metadata.session_id,
|
||||
user_id=trace_metadata.user_id,
|
||||
span_kind=GenAISpanKind.TOOL,
|
||||
inputs=inputs_json,
|
||||
outputs=outputs_json,
|
||||
),
|
||||
TOOL_NAME: node_execution.title,
|
||||
TOOL_DESCRIPTION: json.dumps(tool_des, ensure_ascii=False),
|
||||
TOOL_PARAMETERS: json.dumps(node_execution.inputs or {}, ensure_ascii=False),
|
||||
INPUT_VALUE: json.dumps(node_execution.inputs or {}, ensure_ascii=False),
|
||||
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
|
||||
TOOL_DESCRIPTION: serialize_json_data(tool_des),
|
||||
TOOL_PARAMETERS: inputs_json,
|
||||
},
|
||||
status=self.get_workflow_node_status(node_execution),
|
||||
status=get_workflow_node_status(node_execution),
|
||||
links=trace_metadata.links,
|
||||
)
|
||||
|
||||
def build_workflow_retrieval_span(
|
||||
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
|
||||
self, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution, trace_metadata: TraceMetadata
|
||||
) -> SpanData:
|
||||
input_value = ""
|
||||
if node_execution.inputs:
|
||||
input_value = str(node_execution.inputs.get("query", ""))
|
||||
output_value = ""
|
||||
if node_execution.outputs:
|
||||
output_value = json.dumps(node_execution.outputs.get("result", []), ensure_ascii=False)
|
||||
input_value = str(node_execution.inputs.get("query", "")) if node_execution.inputs else ""
|
||||
output_value = serialize_json_data(node_execution.outputs.get("result", [])) if node_execution.outputs else ""
|
||||
|
||||
retrieval_documents = node_execution.outputs.get("result", []) if node_execution.outputs else []
|
||||
semantic_retrieval_documents = format_retrieval_documents(retrieval_documents)
|
||||
semantic_retrieval_documents_json = serialize_json_data(semantic_retrieval_documents)
|
||||
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=workflow_span_id,
|
||||
trace_id=trace_metadata.trace_id,
|
||||
parent_span_id=trace_metadata.workflow_span_id,
|
||||
span_id=convert_to_span_id(node_execution.id, "node"),
|
||||
name=node_execution.title,
|
||||
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
|
||||
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
|
||||
attributes={
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
**create_common_span_attributes(
|
||||
session_id=trace_metadata.session_id,
|
||||
user_id=trace_metadata.user_id,
|
||||
span_kind=GenAISpanKind.RETRIEVER,
|
||||
inputs=input_value,
|
||||
outputs=output_value,
|
||||
),
|
||||
RETRIEVAL_QUERY: input_value,
|
||||
RETRIEVAL_DOCUMENT: output_value,
|
||||
INPUT_VALUE: input_value,
|
||||
OUTPUT_VALUE: output_value,
|
||||
RETRIEVAL_DOCUMENT: semantic_retrieval_documents_json,
|
||||
},
|
||||
status=self.get_workflow_node_status(node_execution),
|
||||
status=get_workflow_node_status(node_execution),
|
||||
links=trace_metadata.links,
|
||||
)
|
||||
|
||||
def build_workflow_llm_span(
|
||||
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
|
||||
self, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution, trace_metadata: TraceMetadata
|
||||
) -> SpanData:
|
||||
process_data = node_execution.process_data or {}
|
||||
outputs = node_execution.outputs or {}
|
||||
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
|
||||
|
||||
prompts_json = serialize_json_data(process_data.get("prompts", []))
|
||||
text_output = str(outputs.get("text", ""))
|
||||
|
||||
gen_ai_input_message = format_input_messages(process_data)
|
||||
gen_ai_output_message = format_output_messages(outputs)
|
||||
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=workflow_span_id,
|
||||
trace_id=trace_metadata.trace_id,
|
||||
parent_span_id=trace_metadata.workflow_span_id,
|
||||
span_id=convert_to_span_id(node_execution.id, "node"),
|
||||
name=node_execution.title,
|
||||
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
|
||||
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
|
||||
attributes={
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "",
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
GEN_AI_MODEL_NAME: process_data.get("model_name") or "",
|
||||
GEN_AI_SYSTEM: process_data.get("model_provider") or "",
|
||||
**create_common_span_attributes(
|
||||
session_id=trace_metadata.session_id,
|
||||
user_id=trace_metadata.user_id,
|
||||
span_kind=GenAISpanKind.LLM,
|
||||
inputs=prompts_json,
|
||||
outputs=text_output,
|
||||
),
|
||||
GEN_AI_REQUEST_MODEL: process_data.get("model_name") or "",
|
||||
GEN_AI_PROVIDER_NAME: process_data.get("model_provider") or "",
|
||||
GEN_AI_USAGE_INPUT_TOKENS: str(usage_data.get("prompt_tokens", 0)),
|
||||
GEN_AI_USAGE_OUTPUT_TOKENS: str(usage_data.get("completion_tokens", 0)),
|
||||
GEN_AI_USAGE_TOTAL_TOKENS: str(usage_data.get("total_tokens", 0)),
|
||||
GEN_AI_PROMPT: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
|
||||
GEN_AI_COMPLETION: str(outputs.get("text", "")),
|
||||
GEN_AI_PROMPT: prompts_json,
|
||||
GEN_AI_COMPLETION: text_output,
|
||||
GEN_AI_RESPONSE_FINISH_REASON: outputs.get("finish_reason") or "",
|
||||
INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
|
||||
OUTPUT_VALUE: str(outputs.get("text", "")),
|
||||
GEN_AI_INPUT_MESSAGE: gen_ai_input_message,
|
||||
GEN_AI_OUTPUT_MESSAGE: gen_ai_output_message,
|
||||
},
|
||||
status=self.get_workflow_node_status(node_execution),
|
||||
status=get_workflow_node_status(node_execution),
|
||||
links=trace_metadata.links,
|
||||
)
|
||||
|
||||
def add_workflow_span(
|
||||
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, links: Sequence[Link]
|
||||
):
|
||||
def add_workflow_span(self, trace_info: WorkflowTraceInfo, trace_metadata: TraceMetadata):
|
||||
message_span_id = None
|
||||
if trace_info.message_id:
|
||||
message_span_id = convert_to_span_id(trace_info.message_id, "message")
|
||||
user_id = trace_info.metadata.get("user_id")
|
||||
status: Status = Status(StatusCode.OK)
|
||||
if trace_info.error:
|
||||
status = Status(StatusCode.ERROR, trace_info.error)
|
||||
if message_span_id: # chatflow
|
||||
status = create_status_from_error(trace_info.error)
|
||||
|
||||
inputs_json = serialize_json_data(trace_info.workflow_run_inputs)
|
||||
outputs_json = serialize_json_data(trace_info.workflow_run_outputs)
|
||||
|
||||
if message_span_id:
|
||||
message_span = SpanData(
|
||||
trace_id=trace_id,
|
||||
trace_id=trace_metadata.trace_id,
|
||||
parent_span_id=None,
|
||||
span_id=message_span_id,
|
||||
name="message",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "",
|
||||
GEN_AI_USER_ID: str(user_id),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
INPUT_VALUE: trace_info.workflow_run_inputs.get("sys.query") or "",
|
||||
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
|
||||
},
|
||||
attributes=create_common_span_attributes(
|
||||
session_id=trace_metadata.session_id,
|
||||
user_id=trace_metadata.user_id,
|
||||
span_kind=GenAISpanKind.CHAIN,
|
||||
inputs=trace_info.workflow_run_inputs.get("sys.query") or "",
|
||||
outputs=outputs_json,
|
||||
),
|
||||
status=status,
|
||||
links=links,
|
||||
links=trace_metadata.links,
|
||||
)
|
||||
self.trace_client.add_span(message_span)
|
||||
|
||||
workflow_span = SpanData(
|
||||
trace_id=trace_id,
|
||||
trace_id=trace_metadata.trace_id,
|
||||
parent_span_id=message_span_id,
|
||||
span_id=workflow_span_id,
|
||||
span_id=trace_metadata.workflow_span_id,
|
||||
name="workflow",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_USER_ID: str(user_id),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False),
|
||||
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
|
||||
},
|
||||
attributes=create_common_span_attributes(
|
||||
session_id=trace_metadata.session_id,
|
||||
user_id=trace_metadata.user_id,
|
||||
span_kind=GenAISpanKind.CHAIN,
|
||||
inputs=inputs_json,
|
||||
outputs=outputs_json,
|
||||
),
|
||||
status=status,
|
||||
links=links,
|
||||
links=trace_metadata.links,
|
||||
)
|
||||
self.trace_client.add_span(workflow_span)
|
||||
|
||||
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
|
||||
message_id = trace_info.message_id
|
||||
status: Status = Status(StatusCode.OK)
|
||||
if trace_info.error:
|
||||
status = Status(StatusCode.ERROR, trace_info.error)
|
||||
status = create_status_from_error(trace_info.error)
|
||||
|
||||
trace_id = convert_to_trace_id(message_id)
|
||||
links = []
|
||||
if trace_info.trace_id:
|
||||
links.append(create_link(trace_id_str=trace_info.trace_id))
|
||||
trace_metadata = TraceMetadata(
|
||||
trace_id=convert_to_trace_id(message_id),
|
||||
workflow_span_id=0,
|
||||
session_id=trace_info.metadata.get("conversation_id") or "",
|
||||
user_id=str(trace_info.metadata.get("user_id") or ""),
|
||||
links=create_links_from_trace_id(trace_info.trace_id),
|
||||
)
|
||||
|
||||
inputs_json = serialize_json_data(trace_info.inputs)
|
||||
suggested_question_json = serialize_json_data(trace_info.suggested_question)
|
||||
|
||||
suggested_question_span = SpanData(
|
||||
trace_id=trace_id,
|
||||
trace_id=trace_metadata.trace_id,
|
||||
parent_span_id=convert_to_span_id(message_id, "message"),
|
||||
span_id=convert_to_span_id(message_id, "suggested_question"),
|
||||
name="suggested_question",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
GEN_AI_MODEL_NAME: trace_info.metadata.get("ls_model_name") or "",
|
||||
GEN_AI_SYSTEM: trace_info.metadata.get("ls_provider") or "",
|
||||
GEN_AI_PROMPT: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
GEN_AI_COMPLETION: json.dumps(trace_info.suggested_question, ensure_ascii=False),
|
||||
INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False),
|
||||
**create_common_span_attributes(
|
||||
session_id=trace_metadata.session_id,
|
||||
user_id=trace_metadata.user_id,
|
||||
span_kind=GenAISpanKind.LLM,
|
||||
inputs=inputs_json,
|
||||
outputs=suggested_question_json,
|
||||
),
|
||||
GEN_AI_REQUEST_MODEL: trace_info.metadata.get("ls_model_name") or "",
|
||||
GEN_AI_PROVIDER_NAME: trace_info.metadata.get("ls_provider") or "",
|
||||
GEN_AI_PROMPT: inputs_json,
|
||||
GEN_AI_COMPLETION: suggested_question_json,
|
||||
},
|
||||
status=status,
|
||||
links=links,
|
||||
links=trace_metadata.links,
|
||||
)
|
||||
self.trace_client.add_span(suggested_question_span)
|
||||
|
||||
|
||||
def extract_retrieval_documents(documents: list[Document]):
|
||||
documents_data = []
|
||||
for document in documents:
|
||||
document_data = {
|
||||
"content": document.page_content,
|
||||
"metadata": {
|
||||
"dataset_id": document.metadata.get("dataset_id"),
|
||||
"doc_id": document.metadata.get("doc_id"),
|
||||
"document_id": document.metadata.get("document_id"),
|
||||
},
|
||||
"score": document.metadata.get("score"),
|
||||
}
|
||||
documents_data.append(document_data)
|
||||
return documents_data
|
||||
|
||||
@@ -7,8 +7,10 @@ import uuid
|
||||
from collections import deque
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Final
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
from opentelemetry import trace as trace_api
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
@@ -20,8 +22,12 @@ from opentelemetry.trace import Link, SpanContext, TraceFlags
|
||||
from configs import dify_config
|
||||
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
|
||||
|
||||
INVALID_SPAN_ID = 0x0000000000000000
|
||||
INVALID_TRACE_ID = 0x00000000000000000000000000000000
|
||||
INVALID_SPAN_ID: Final[int] = 0x0000000000000000
|
||||
INVALID_TRACE_ID: Final[int] = 0x00000000000000000000000000000000
|
||||
DEFAULT_TIMEOUT: Final[int] = 5
|
||||
DEFAULT_MAX_QUEUE_SIZE: Final[int] = 1000
|
||||
DEFAULT_SCHEDULE_DELAY_SEC: Final[int] = 5
|
||||
DEFAULT_MAX_EXPORT_BATCH_SIZE: Final[int] = 50
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -31,9 +37,9 @@ class TraceClient:
|
||||
self,
|
||||
service_name: str,
|
||||
endpoint: str,
|
||||
max_queue_size: int = 1000,
|
||||
schedule_delay_sec: int = 5,
|
||||
max_export_batch_size: int = 50,
|
||||
max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE,
|
||||
schedule_delay_sec: int = DEFAULT_SCHEDULE_DELAY_SEC,
|
||||
max_export_batch_size: int = DEFAULT_MAX_EXPORT_BATCH_SIZE,
|
||||
):
|
||||
self.endpoint = endpoint
|
||||
self.resource = Resource(
|
||||
@@ -63,24 +69,25 @@ class TraceClient:
|
||||
def export(self, spans: Sequence[ReadableSpan]):
|
||||
self.exporter.export(spans)
|
||||
|
||||
def api_check(self):
|
||||
def api_check(self) -> bool:
|
||||
try:
|
||||
response = requests.head(self.endpoint, timeout=5)
|
||||
response = httpx.head(self.endpoint, timeout=DEFAULT_TIMEOUT)
|
||||
if response.status_code == 405:
|
||||
return True
|
||||
else:
|
||||
logger.debug("AliyunTrace API check failed: Unexpected status code: %s", response.status_code)
|
||||
return False
|
||||
except requests.RequestException as e:
|
||||
except httpx.RequestError as e:
|
||||
logger.debug("AliyunTrace API check failed: %s", str(e))
|
||||
raise ValueError(f"AliyunTrace API check failed: {str(e)}")
|
||||
|
||||
def get_project_url(self):
|
||||
def get_project_url(self) -> str:
|
||||
return "https://arms.console.aliyun.com/#/llm"
|
||||
|
||||
def add_span(self, span_data: SpanData):
|
||||
def add_span(self, span_data: SpanData | None) -> None:
|
||||
if span_data is None:
|
||||
return
|
||||
|
||||
span: ReadableSpan = self.span_builder.build_span(span_data)
|
||||
with self.condition:
|
||||
if len(self.queue) == self.max_queue_size:
|
||||
@@ -92,14 +99,14 @@ class TraceClient:
|
||||
if len(self.queue) >= self.max_export_batch_size:
|
||||
self.condition.notify()
|
||||
|
||||
def _worker(self):
|
||||
def _worker(self) -> None:
|
||||
while not self.done:
|
||||
with self.condition:
|
||||
if len(self.queue) < self.max_export_batch_size and not self.done:
|
||||
self.condition.wait(timeout=self.schedule_delay_sec)
|
||||
self._export_batch()
|
||||
|
||||
def _export_batch(self):
|
||||
def _export_batch(self) -> None:
|
||||
spans_to_export: list[ReadableSpan] = []
|
||||
with self.condition:
|
||||
while len(spans_to_export) < self.max_export_batch_size and self.queue:
|
||||
@@ -111,7 +118,7 @@ class TraceClient:
|
||||
except Exception as e:
|
||||
logger.debug("Error exporting spans: %s", e)
|
||||
|
||||
def shutdown(self):
|
||||
def shutdown(self) -> None:
|
||||
with self.condition:
|
||||
self.done = True
|
||||
self.condition.notify_all()
|
||||
@@ -121,7 +128,7 @@ class TraceClient:
|
||||
|
||||
|
||||
class SpanBuilder:
|
||||
def __init__(self, resource):
|
||||
def __init__(self, resource: Resource) -> None:
|
||||
self.resource = resource
|
||||
self.instrumentation_scope = InstrumentationScope(
|
||||
__name__,
|
||||
@@ -167,8 +174,12 @@ class SpanBuilder:
|
||||
|
||||
|
||||
def create_link(trace_id_str: str) -> Link:
|
||||
placeholder_span_id = 0x0000000000000000
|
||||
trace_id = int(trace_id_str, 16)
|
||||
placeholder_span_id = INVALID_SPAN_ID
|
||||
try:
|
||||
trace_id = int(trace_id_str, 16)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid trace ID format: {trace_id_str}") from e
|
||||
|
||||
span_context = SpanContext(
|
||||
trace_id=trace_id, span_id=placeholder_span_id, is_remote=False, trace_flags=TraceFlags(TraceFlags.SAMPLED)
|
||||
)
|
||||
@@ -184,26 +195,29 @@ def generate_span_id() -> int:
|
||||
|
||||
|
||||
def convert_to_trace_id(uuid_v4: str | None) -> int:
|
||||
if uuid_v4 is None:
|
||||
raise ValueError("UUID cannot be None")
|
||||
try:
|
||||
uuid_obj = uuid.UUID(uuid_v4)
|
||||
return uuid_obj.int
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid UUID input: {e}")
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid UUID input: {uuid_v4}") from e
|
||||
|
||||
|
||||
def convert_string_to_id(string: str | None) -> int:
|
||||
if not string:
|
||||
return generate_span_id()
|
||||
hash_bytes = hashlib.sha256(string.encode("utf-8")).digest()
|
||||
id = int.from_bytes(hash_bytes[:8], byteorder="big", signed=False)
|
||||
return id
|
||||
return int.from_bytes(hash_bytes[:8], byteorder="big", signed=False)
|
||||
|
||||
|
||||
def convert_to_span_id(uuid_v4: str | None, span_type: str) -> int:
|
||||
if uuid_v4 is None:
|
||||
raise ValueError("UUID cannot be None")
|
||||
try:
|
||||
uuid_obj = uuid.UUID(uuid_v4)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid UUID input: {e}")
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid UUID input: {uuid_v4}") from e
|
||||
combined_key = f"{uuid_obj.hex}-{span_type}"
|
||||
return convert_string_to_id(combined_key)
|
||||
|
||||
@@ -212,5 +226,11 @@ def convert_datetime_to_nanoseconds(start_time_a: datetime | None) -> int | None
|
||||
if start_time_a is None:
|
||||
return None
|
||||
timestamp_in_seconds = start_time_a.timestamp()
|
||||
timestamp_in_nanoseconds = int(timestamp_in_seconds * 1e9)
|
||||
return timestamp_in_nanoseconds
|
||||
return int(timestamp_in_seconds * 1e9)
|
||||
|
||||
|
||||
def build_endpoint(base_url: str, license_key: str) -> str:
|
||||
if "log.aliyuncs.com" in base_url: # cms2.0 endpoint
|
||||
return urljoin(base_url, f"adapt_{license_key}/api/v1/traces")
|
||||
else: # xtrace endpoint
|
||||
return urljoin(base_url, f"adapt_{license_key}/api/otlp/traces")
|
||||
|
||||
@@ -1,18 +1,34 @@
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from opentelemetry import trace as trace_api
|
||||
from opentelemetry.sdk.trace import Event, Status, StatusCode
|
||||
from opentelemetry.sdk.trace import Event
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@dataclass
|
||||
class TraceMetadata:
|
||||
"""Metadata for trace operations, containing common attributes for all spans in a trace."""
|
||||
|
||||
trace_id: int
|
||||
workflow_span_id: int
|
||||
session_id: str
|
||||
user_id: str
|
||||
links: list[trace_api.Link]
|
||||
|
||||
|
||||
class SpanData(BaseModel):
|
||||
"""Data model for span information in Aliyun trace system."""
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
trace_id: int = Field(..., description="The unique identifier for the trace.")
|
||||
parent_span_id: int | None = Field(None, description="The ID of the parent span, if any.")
|
||||
span_id: int = Field(..., description="The unique identifier for this span.")
|
||||
name: str = Field(..., description="The name of the span.")
|
||||
attributes: dict[str, str] = Field(default_factory=dict, description="Attributes associated with the span.")
|
||||
attributes: dict[str, Any] = Field(default_factory=dict, description="Attributes associated with the span.")
|
||||
events: Sequence[Event] = Field(default_factory=list, description="Events recorded in the span.")
|
||||
links: Sequence[trace_api.Link] = Field(default_factory=list, description="Links to other spans.")
|
||||
status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")
|
||||
|
||||
@@ -1,56 +1,38 @@
|
||||
from enum import StrEnum
|
||||
from typing import Final
|
||||
|
||||
# public
|
||||
GEN_AI_SESSION_ID = "gen_ai.session.id"
|
||||
# Public attributes
|
||||
GEN_AI_SESSION_ID: Final[str] = "gen_ai.session.id"
|
||||
GEN_AI_USER_ID: Final[str] = "gen_ai.user.id"
|
||||
GEN_AI_USER_NAME: Final[str] = "gen_ai.user.name"
|
||||
GEN_AI_SPAN_KIND: Final[str] = "gen_ai.span.kind"
|
||||
GEN_AI_FRAMEWORK: Final[str] = "gen_ai.framework"
|
||||
|
||||
GEN_AI_USER_ID = "gen_ai.user.id"
|
||||
# Chain attributes
|
||||
INPUT_VALUE: Final[str] = "input.value"
|
||||
OUTPUT_VALUE: Final[str] = "output.value"
|
||||
|
||||
GEN_AI_USER_NAME = "gen_ai.user.name"
|
||||
# Retriever attributes
|
||||
RETRIEVAL_QUERY: Final[str] = "retrieval.query"
|
||||
RETRIEVAL_DOCUMENT: Final[str] = "retrieval.document"
|
||||
|
||||
GEN_AI_SPAN_KIND = "gen_ai.span.kind"
|
||||
# LLM attributes
|
||||
GEN_AI_REQUEST_MODEL: Final[str] = "gen_ai.request.model"
|
||||
GEN_AI_PROVIDER_NAME: Final[str] = "gen_ai.provider.name"
|
||||
GEN_AI_USAGE_INPUT_TOKENS: Final[str] = "gen_ai.usage.input_tokens"
|
||||
GEN_AI_USAGE_OUTPUT_TOKENS: Final[str] = "gen_ai.usage.output_tokens"
|
||||
GEN_AI_USAGE_TOTAL_TOKENS: Final[str] = "gen_ai.usage.total_tokens"
|
||||
GEN_AI_PROMPT: Final[str] = "gen_ai.prompt"
|
||||
GEN_AI_COMPLETION: Final[str] = "gen_ai.completion"
|
||||
GEN_AI_RESPONSE_FINISH_REASON: Final[str] = "gen_ai.response.finish_reason"
|
||||
|
||||
GEN_AI_FRAMEWORK = "gen_ai.framework"
|
||||
GEN_AI_INPUT_MESSAGE: Final[str] = "gen_ai.input.messages"
|
||||
GEN_AI_OUTPUT_MESSAGE: Final[str] = "gen_ai.output.messages"
|
||||
|
||||
|
||||
# Chain
|
||||
INPUT_VALUE = "input.value"
|
||||
|
||||
OUTPUT_VALUE = "output.value"
|
||||
|
||||
|
||||
# Retriever
|
||||
RETRIEVAL_QUERY = "retrieval.query"
|
||||
|
||||
RETRIEVAL_DOCUMENT = "retrieval.document"
|
||||
|
||||
|
||||
# LLM
|
||||
GEN_AI_MODEL_NAME = "gen_ai.model_name"
|
||||
|
||||
GEN_AI_SYSTEM = "gen_ai.system"
|
||||
|
||||
GEN_AI_USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens"
|
||||
|
||||
GEN_AI_USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
|
||||
|
||||
GEN_AI_USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens"
|
||||
|
||||
GEN_AI_PROMPT_TEMPLATE_TEMPLATE = "gen_ai.prompt_template.template"
|
||||
|
||||
GEN_AI_PROMPT_TEMPLATE_VARIABLE = "gen_ai.prompt_template.variable"
|
||||
|
||||
GEN_AI_PROMPT = "gen_ai.prompt"
|
||||
|
||||
GEN_AI_COMPLETION = "gen_ai.completion"
|
||||
|
||||
GEN_AI_RESPONSE_FINISH_REASON = "gen_ai.response.finish_reason"
|
||||
|
||||
# Tool
|
||||
TOOL_NAME = "tool.name"
|
||||
|
||||
TOOL_DESCRIPTION = "tool.description"
|
||||
|
||||
TOOL_PARAMETERS = "tool.parameters"
|
||||
# Tool attributes
|
||||
TOOL_NAME: Final[str] = "tool.name"
|
||||
TOOL_DESCRIPTION: Final[str] = "tool.description"
|
||||
TOOL_PARAMETERS: Final[str] = "tool.parameters"
|
||||
|
||||
|
||||
class GenAISpanKind(StrEnum):
|
||||
|
||||
190
api/core/ops/aliyun_trace/utils.py
Normal file
190
api/core/ops/aliyun_trace/utils.py
Normal file
@@ -0,0 +1,190 @@
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from opentelemetry.trace import Link, Status, StatusCode
|
||||
|
||||
from core.ops.aliyun_trace.entities.semconv import (
|
||||
GEN_AI_FRAMEWORK,
|
||||
GEN_AI_SESSION_ID,
|
||||
GEN_AI_SPAN_KIND,
|
||||
GEN_AI_USER_ID,
|
||||
INPUT_VALUE,
|
||||
OUTPUT_VALUE,
|
||||
GenAISpanKind,
|
||||
)
|
||||
from core.rag.models.document import Document
|
||||
from core.workflow.entities import WorkflowNodeExecution
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser
|
||||
|
||||
# Constants
|
||||
DEFAULT_JSON_ENSURE_ASCII = False
|
||||
DEFAULT_FRAMEWORK_NAME = "dify"
|
||||
|
||||
|
||||
def get_user_id_from_message_data(message_data) -> str:
|
||||
user_id = message_data.from_account_id
|
||||
if message_data.from_end_user_id:
|
||||
end_user_data: EndUser | None = (
|
||||
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
|
||||
)
|
||||
if end_user_data is not None:
|
||||
user_id = end_user_data.session_id
|
||||
return user_id
|
||||
|
||||
|
||||
def create_status_from_error(error: str | None) -> Status:
|
||||
if error:
|
||||
return Status(StatusCode.ERROR, error)
|
||||
return Status(StatusCode.OK)
|
||||
|
||||
|
||||
def get_workflow_node_status(node_execution: WorkflowNodeExecution) -> Status:
|
||||
if node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
return Status(StatusCode.OK)
|
||||
if node_execution.status in [WorkflowNodeExecutionStatus.FAILED, WorkflowNodeExecutionStatus.EXCEPTION]:
|
||||
return Status(StatusCode.ERROR, str(node_execution.error))
|
||||
return Status(StatusCode.UNSET)
|
||||
|
||||
|
||||
def create_links_from_trace_id(trace_id: str | None) -> list[Link]:
|
||||
from core.ops.aliyun_trace.data_exporter.traceclient import create_link
|
||||
|
||||
links = []
|
||||
if trace_id:
|
||||
links.append(create_link(trace_id_str=trace_id))
|
||||
return links
|
||||
|
||||
|
||||
def extract_retrieval_documents(documents: list[Document]) -> list[dict[str, Any]]:
|
||||
documents_data = []
|
||||
for document in documents:
|
||||
document_data = {
|
||||
"content": document.page_content,
|
||||
"metadata": {
|
||||
"dataset_id": document.metadata.get("dataset_id"),
|
||||
"doc_id": document.metadata.get("doc_id"),
|
||||
"document_id": document.metadata.get("document_id"),
|
||||
},
|
||||
"score": document.metadata.get("score"),
|
||||
}
|
||||
documents_data.append(document_data)
|
||||
return documents_data
|
||||
|
||||
|
||||
def serialize_json_data(data: Any, ensure_ascii: bool = DEFAULT_JSON_ENSURE_ASCII) -> str:
|
||||
return json.dumps(data, ensure_ascii=ensure_ascii)
|
||||
|
||||
|
||||
def create_common_span_attributes(
|
||||
session_id: str = "",
|
||||
user_id: str = "",
|
||||
span_kind: str = GenAISpanKind.CHAIN,
|
||||
framework: str = DEFAULT_FRAMEWORK_NAME,
|
||||
inputs: str = "",
|
||||
outputs: str = "",
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
GEN_AI_SESSION_ID: session_id,
|
||||
GEN_AI_USER_ID: user_id,
|
||||
GEN_AI_SPAN_KIND: span_kind,
|
||||
GEN_AI_FRAMEWORK: framework,
|
||||
INPUT_VALUE: inputs,
|
||||
OUTPUT_VALUE: outputs,
|
||||
}
|
||||
|
||||
|
||||
def format_retrieval_documents(retrieval_documents: list) -> list:
|
||||
try:
|
||||
if not isinstance(retrieval_documents, list):
|
||||
return []
|
||||
|
||||
semantic_documents = []
|
||||
for doc in retrieval_documents:
|
||||
if not isinstance(doc, dict):
|
||||
continue
|
||||
|
||||
metadata = doc.get("metadata", {})
|
||||
content = doc.get("content", "")
|
||||
title = doc.get("title", "")
|
||||
score = metadata.get("score", 0.0)
|
||||
document_id = metadata.get("document_id", "")
|
||||
|
||||
semantic_metadata = {}
|
||||
if title:
|
||||
semantic_metadata["title"] = title
|
||||
if metadata.get("source"):
|
||||
semantic_metadata["source"] = metadata["source"]
|
||||
elif metadata.get("_source"):
|
||||
semantic_metadata["source"] = metadata["_source"]
|
||||
if metadata.get("doc_metadata"):
|
||||
doc_metadata = metadata["doc_metadata"]
|
||||
if isinstance(doc_metadata, dict):
|
||||
semantic_metadata.update(doc_metadata)
|
||||
|
||||
semantic_doc = {
|
||||
"document": {"content": content, "metadata": semantic_metadata, "score": score, "id": document_id}
|
||||
}
|
||||
semantic_documents.append(semantic_doc)
|
||||
|
||||
return semantic_documents
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
def format_input_messages(process_data: Mapping[str, Any]) -> str:
|
||||
try:
|
||||
if not isinstance(process_data, dict):
|
||||
return serialize_json_data([])
|
||||
|
||||
prompts = process_data.get("prompts", [])
|
||||
if not prompts:
|
||||
return serialize_json_data([])
|
||||
|
||||
valid_roles = {"system", "user", "assistant", "tool"}
|
||||
input_messages = []
|
||||
for prompt in prompts:
|
||||
if not isinstance(prompt, dict):
|
||||
continue
|
||||
|
||||
role = prompt.get("role", "")
|
||||
text = prompt.get("text", "")
|
||||
|
||||
if not role or role not in valid_roles:
|
||||
continue
|
||||
|
||||
if text:
|
||||
message = {"role": role, "parts": [{"type": "text", "content": text}]}
|
||||
input_messages.append(message)
|
||||
|
||||
return serialize_json_data(input_messages)
|
||||
except Exception:
|
||||
return serialize_json_data([])
|
||||
|
||||
|
||||
def format_output_messages(outputs: Mapping[str, Any]) -> str:
|
||||
try:
|
||||
if not isinstance(outputs, dict):
|
||||
return serialize_json_data([])
|
||||
|
||||
text = outputs.get("text", "")
|
||||
finish_reason = outputs.get("finish_reason", "")
|
||||
|
||||
if not text:
|
||||
return serialize_json_data([])
|
||||
|
||||
valid_finish_reasons = {"stop", "length", "content_filter", "tool_call", "error"}
|
||||
if finish_reason not in valid_finish_reasons:
|
||||
finish_reason = "stop"
|
||||
|
||||
output_message = {
|
||||
"role": "assistant",
|
||||
"parts": [{"type": "text", "content": text}],
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
|
||||
return serialize_json_data([output_message])
|
||||
except Exception:
|
||||
return serialize_json_data([])
|
||||
@@ -191,7 +191,8 @@ class AliyunConfig(BaseTracingConfig):
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_endpoint_url(v, "https://tracing-analysis-dc-hz.aliyuncs.com")
|
||||
# aliyun uses two URL formats, which may include a URL path
|
||||
return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com")
|
||||
|
||||
|
||||
OPS_FILE_PATH = "ops_trace/"
|
||||
|
||||
@@ -155,7 +155,10 @@ class OpsTraceManager:
|
||||
if key in tracing_config:
|
||||
if "*" in tracing_config[key]:
|
||||
# If the key contains '*', retain the original value from the current config
|
||||
new_config[key] = current_trace_config.get(key, tracing_config[key])
|
||||
if current_trace_config:
|
||||
new_config[key] = current_trace_config.get(key, tracing_config[key])
|
||||
else:
|
||||
new_config[key] = tracing_config[key]
|
||||
else:
|
||||
# Otherwise, encrypt the key
|
||||
new_config[key] = encrypt_token(tenant_id, tracing_config[key])
|
||||
|
||||
@@ -62,7 +62,8 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
self,
|
||||
):
|
||||
try:
|
||||
project_url = f"https://wandb.ai/{self.weave_client._project_id()}"
|
||||
project_identifier = f"{self.entity}/{self.project_name}" if self.entity else self.project_name
|
||||
project_url = f"https://wandb.ai/{project_identifier}"
|
||||
return project_url
|
||||
except Exception as e:
|
||||
logger.debug("Weave get run url failed: %s", str(e))
|
||||
@@ -417,14 +418,30 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
if not login_status:
|
||||
raise ValueError("Weave login failed")
|
||||
else:
|
||||
print("Weave login successful")
|
||||
logger.info("Weave login successful")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.debug("Weave API check failed: %s", str(e))
|
||||
raise ValueError(f"Weave API check failed: {str(e)}")
|
||||
|
||||
def start_call(self, run_data: WeaveTraceModel, parent_run_id: str | None = None):
|
||||
call = self.weave_client.create_call(op=run_data.op, inputs=run_data.inputs, attributes=run_data.attributes)
|
||||
inputs = run_data.inputs
|
||||
if inputs is None:
|
||||
inputs = {}
|
||||
elif not isinstance(inputs, dict):
|
||||
inputs = {"inputs": str(inputs)}
|
||||
|
||||
attributes = run_data.attributes
|
||||
if attributes is None:
|
||||
attributes = {}
|
||||
elif not isinstance(attributes, dict):
|
||||
attributes = {"attributes": str(attributes)}
|
||||
|
||||
call = self.weave_client.create_call(
|
||||
op=run_data.op,
|
||||
inputs=inputs,
|
||||
attributes=attributes,
|
||||
)
|
||||
self.calls[run_data.id] = call
|
||||
if parent_run_id:
|
||||
self.calls[run_data.id].parent_id = parent_run_id
|
||||
@@ -432,6 +449,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
def finish_call(self, run_data: WeaveTraceModel):
|
||||
call = self.calls.get(run_data.id)
|
||||
if call:
|
||||
self.weave_client.finish_call(call=call, output=run_data.outputs, exception=run_data.exception)
|
||||
exception = Exception(run_data.exception) if run_data.exception else None
|
||||
self.weave_client.finish_call(call=call, output=run_data.outputs, exception=exception)
|
||||
else:
|
||||
raise ValueError(f"Call with id {run_data.id} not found")
|
||||
|
||||
@@ -513,6 +513,21 @@ class ProviderManager:
|
||||
|
||||
return provider_name_to_provider_load_balancing_model_configs_dict
|
||||
|
||||
@staticmethod
|
||||
def _get_provider_names(provider_name: str) -> list[str]:
|
||||
"""
|
||||
provider_name: `openai` or `langgenius/openai/openai`
|
||||
return: [`openai`, `langgenius/openai/openai`]
|
||||
"""
|
||||
provider_names = [provider_name]
|
||||
model_provider_id = ModelProviderID(provider_name)
|
||||
if model_provider_id.is_langgenius():
|
||||
if "/" in provider_name:
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
else:
|
||||
provider_names.append(str(model_provider_id))
|
||||
return provider_names
|
||||
|
||||
@staticmethod
|
||||
def get_provider_available_credentials(tenant_id: str, provider_name: str) -> list[CredentialConfiguration]:
|
||||
"""
|
||||
@@ -525,7 +540,10 @@ class ProviderManager:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = (
|
||||
select(ProviderCredential)
|
||||
.where(ProviderCredential.tenant_id == tenant_id, ProviderCredential.provider_name == provider_name)
|
||||
.where(
|
||||
ProviderCredential.tenant_id == tenant_id,
|
||||
ProviderCredential.provider_name.in_(ProviderManager._get_provider_names(provider_name)),
|
||||
)
|
||||
.order_by(ProviderCredential.created_at.desc())
|
||||
)
|
||||
|
||||
@@ -554,7 +572,7 @@ class ProviderManager:
|
||||
select(ProviderModelCredential)
|
||||
.where(
|
||||
ProviderModelCredential.tenant_id == tenant_id,
|
||||
ProviderModelCredential.provider_name == provider_name,
|
||||
ProviderModelCredential.provider_name.in_(ProviderManager._get_provider_names(provider_name)),
|
||||
ProviderModelCredential.model_name == model_name,
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
)
|
||||
|
||||
@@ -106,7 +106,9 @@ class RetrievalService:
|
||||
if exceptions:
|
||||
raise ValueError(";\n".join(exceptions))
|
||||
|
||||
# Deduplicate documents for hybrid search to avoid duplicate chunks
|
||||
if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
|
||||
all_documents = cls._deduplicate_documents(all_documents)
|
||||
data_post_processor = DataPostProcessor(
|
||||
str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
|
||||
)
|
||||
@@ -143,6 +145,40 @@ class RetrievalService:
|
||||
)
|
||||
return all_documents
|
||||
|
||||
@classmethod
|
||||
def _deduplicate_documents(cls, documents: list[Document]) -> list[Document]:
|
||||
"""Deduplicate documents based on doc_id to avoid duplicate chunks in hybrid search."""
|
||||
if not documents:
|
||||
return documents
|
||||
|
||||
unique_documents = []
|
||||
seen_doc_ids = set()
|
||||
|
||||
for document in documents:
|
||||
# For dify provider documents, use doc_id for deduplication
|
||||
if document.provider == "dify" and document.metadata is not None and "doc_id" in document.metadata:
|
||||
doc_id = document.metadata["doc_id"]
|
||||
if doc_id not in seen_doc_ids:
|
||||
seen_doc_ids.add(doc_id)
|
||||
unique_documents.append(document)
|
||||
# If duplicate, keep the one with higher score
|
||||
elif "score" in document.metadata:
|
||||
# Find existing document with same doc_id and compare scores
|
||||
for i, existing_doc in enumerate(unique_documents):
|
||||
if (
|
||||
existing_doc.metadata
|
||||
and existing_doc.metadata.get("doc_id") == doc_id
|
||||
and existing_doc.metadata.get("score", 0) < document.metadata.get("score", 0)
|
||||
):
|
||||
unique_documents[i] = document
|
||||
break
|
||||
else:
|
||||
# For non-dify documents, use content-based deduplication
|
||||
if document not in unique_documents:
|
||||
unique_documents.append(document)
|
||||
|
||||
return unique_documents
|
||||
|
||||
@classmethod
|
||||
def _get_dataset(cls, dataset_id: str) -> Dataset | None:
|
||||
with Session(db.engine) as session:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
@@ -9,11 +10,24 @@ from pymochow import MochowClient # type: ignore
|
||||
from pymochow.auth.bce_credentials import BceCredentials # type: ignore
|
||||
from pymochow.configuration import Configuration # type: ignore
|
||||
from pymochow.exception import ServerError # type: ignore
|
||||
from pymochow.model.database import Database
|
||||
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState # type: ignore
|
||||
from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex # type: ignore
|
||||
from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row # type: ignore
|
||||
from pymochow.model.schema import (
|
||||
Field,
|
||||
FilteringIndex,
|
||||
HNSWParams,
|
||||
InvertedIndex,
|
||||
InvertedIndexAnalyzer,
|
||||
InvertedIndexFieldAttribute,
|
||||
InvertedIndexParams,
|
||||
InvertedIndexParseMode,
|
||||
Schema,
|
||||
VectorIndex,
|
||||
) # type: ignore
|
||||
from pymochow.model.table import AnnSearch, BM25SearchRequest, HNSWSearchParams, Partition, Row # type: ignore
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field as VDBField
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
@@ -22,6 +36,8 @@ from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaiduConfig(BaseModel):
|
||||
endpoint: str
|
||||
@@ -30,9 +46,11 @@ class BaiduConfig(BaseModel):
|
||||
api_key: str
|
||||
database: str
|
||||
index_type: str = "HNSW"
|
||||
metric_type: str = "L2"
|
||||
metric_type: str = "IP"
|
||||
shard: int = 1
|
||||
replicas: int = 3
|
||||
inverted_index_analyzer: str = "DEFAULT_ANALYZER"
|
||||
inverted_index_parser_mode: str = "COARSE_MODE"
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@@ -49,13 +67,9 @@ class BaiduConfig(BaseModel):
|
||||
|
||||
|
||||
class BaiduVector(BaseVector):
|
||||
field_id: str = "id"
|
||||
field_vector: str = "vector"
|
||||
field_text: str = "text"
|
||||
field_metadata: str = "metadata"
|
||||
field_app_id: str = "app_id"
|
||||
field_annotation_id: str = "annotation_id"
|
||||
index_vector: str = "vector_idx"
|
||||
vector_index: str = "vector_idx"
|
||||
filtering_index: str = "filtering_idx"
|
||||
inverted_index: str = "content_inverted_idx"
|
||||
|
||||
def __init__(self, collection_name: str, config: BaiduConfig):
|
||||
super().__init__(collection_name)
|
||||
@@ -74,8 +88,6 @@ class BaiduVector(BaseVector):
|
||||
self.add_texts(texts, embeddings)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
texts = [doc.page_content for doc in documents]
|
||||
metadatas = [doc.metadata for doc in documents if doc.metadata is not None]
|
||||
total_count = len(documents)
|
||||
batch_size = 1000
|
||||
|
||||
@@ -84,29 +96,31 @@ class BaiduVector(BaseVector):
|
||||
for start in range(0, total_count, batch_size):
|
||||
end = min(start + batch_size, total_count)
|
||||
rows = []
|
||||
assert len(metadatas) == total_count, "metadatas length should be equal to total_count"
|
||||
for i in range(start, end, 1):
|
||||
metadata = documents[i].metadata
|
||||
row = Row(
|
||||
id=metadatas[i].get("doc_id", str(uuid.uuid4())),
|
||||
id=metadata.get("doc_id", str(uuid.uuid4())),
|
||||
page_content=documents[i].page_content,
|
||||
metadata=metadata,
|
||||
vector=embeddings[i],
|
||||
text=texts[i],
|
||||
metadata=json.dumps(metadatas[i]),
|
||||
app_id=metadatas[i].get("app_id", ""),
|
||||
annotation_id=metadatas[i].get("annotation_id", ""),
|
||||
)
|
||||
rows.append(row)
|
||||
table.upsert(rows=rows)
|
||||
|
||||
# rebuild vector index after upsert finished
|
||||
table.rebuild_index(self.index_vector)
|
||||
table.rebuild_index(self.vector_index)
|
||||
timeout = 3600 # 1 hour timeout
|
||||
start_time = time.time()
|
||||
while True:
|
||||
time.sleep(1)
|
||||
index = table.describe_index(self.index_vector)
|
||||
index = table.describe_index(self.vector_index)
|
||||
if index.state == IndexState.NORMAL:
|
||||
break
|
||||
if time.time() - start_time > timeout:
|
||||
raise TimeoutError(f"Index rebuild timeout after {timeout} seconds")
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
res = self._db.table(self._collection_name).query(primary_key={self.field_id: id})
|
||||
res = self._db.table(self._collection_name).query(primary_key={VDBField.PRIMARY_KEY: id})
|
||||
if res and res.code == 0:
|
||||
return True
|
||||
return False
|
||||
@@ -115,53 +129,73 @@ class BaiduVector(BaseVector):
|
||||
if not ids:
|
||||
return
|
||||
quoted_ids = [f"'{id}'" for id in ids]
|
||||
self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})")
|
||||
self._db.table(self._collection_name).delete(filter=f"{VDBField.PRIMARY_KEY} IN({', '.join(quoted_ids)})")
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
self._db.table(self._collection_name).delete(filter=f"{key} = '{value}'")
|
||||
# Escape double quotes in value to prevent injection
|
||||
escaped_value = value.replace('"', '\\"')
|
||||
self._db.table(self._collection_name).delete(filter=f'metadata["{key}"] = "{escaped_value}"')
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector]
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
filter = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
anns = AnnSearch(
|
||||
vector_field=self.field_vector,
|
||||
vector_floats=query_vector,
|
||||
params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
|
||||
filter=f"document_id IN ({document_ids})",
|
||||
)
|
||||
else:
|
||||
anns = AnnSearch(
|
||||
vector_field=self.field_vector,
|
||||
vector_floats=query_vector,
|
||||
params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
|
||||
)
|
||||
filter = f'metadata["document_id"] IN({document_ids})'
|
||||
anns = AnnSearch(
|
||||
vector_field=VDBField.VECTOR,
|
||||
vector_floats=query_vector,
|
||||
params=HNSWSearchParams(ef=kwargs.get("ef", 20), limit=kwargs.get("top_k", 4)),
|
||||
filter=filter,
|
||||
)
|
||||
res = self._db.table(self._collection_name).search(
|
||||
anns=anns,
|
||||
projections=[self.field_id, self.field_text, self.field_metadata],
|
||||
retrieve_vector=True,
|
||||
projections=[VDBField.CONTENT_KEY, VDBField.METADATA_KEY],
|
||||
retrieve_vector=False,
|
||||
)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
return self._get_search_res(res, score_threshold)
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
# baidu vector database doesn't support bm25 search on current version
|
||||
return []
|
||||
# document ids filter
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
filter = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
filter = f'metadata["document_id"] IN({document_ids})'
|
||||
|
||||
request = BM25SearchRequest(
|
||||
index_name=self.inverted_index, search_text=query, limit=kwargs.get("top_k", 4), filter=filter
|
||||
)
|
||||
res = self._db.table(self._collection_name).bm25_search(
|
||||
request=request, projections=[VDBField.CONTENT_KEY, VDBField.METADATA_KEY]
|
||||
)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
return self._get_search_res(res, score_threshold)
|
||||
|
||||
def _get_search_res(self, res, score_threshold) -> list[Document]:
|
||||
docs = []
|
||||
for row in res.rows:
|
||||
row_data = row.get("row", {})
|
||||
meta = row_data.get(self.field_metadata)
|
||||
if meta is not None:
|
||||
meta = json.loads(meta)
|
||||
score = row.get("score", 0.0)
|
||||
meta = row_data.get(VDBField.METADATA_KEY, {})
|
||||
|
||||
# Handle both JSON string and dict formats for backward compatibility
|
||||
if isinstance(meta, str):
|
||||
try:
|
||||
import json
|
||||
|
||||
meta = json.loads(meta)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
meta = {}
|
||||
elif not isinstance(meta, dict):
|
||||
meta = {}
|
||||
|
||||
if score >= score_threshold:
|
||||
meta["score"] = score
|
||||
doc = Document(page_content=row_data.get(self.field_text), metadata=meta)
|
||||
doc = Document(page_content=row_data.get(VDBField.CONTENT_KEY), metadata=meta)
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
|
||||
def delete(self):
|
||||
@@ -178,7 +212,7 @@ class BaiduVector(BaseVector):
|
||||
client = MochowClient(config)
|
||||
return client
|
||||
|
||||
def _init_database(self):
|
||||
def _init_database(self) -> Database:
|
||||
exists = False
|
||||
for db in self._client.list_databases():
|
||||
if db.database_name == self._client_config.database:
|
||||
@@ -192,10 +226,10 @@ class BaiduVector(BaseVector):
|
||||
self._client.create_database(database_name=self._client_config.database)
|
||||
except ServerError as e:
|
||||
if e.code == ServerErrCode.DB_ALREADY_EXIST:
|
||||
pass
|
||||
return self._client.database(self._client_config.database)
|
||||
else:
|
||||
raise
|
||||
return
|
||||
return self._client.database(self._client_config.database)
|
||||
|
||||
def _table_existed(self) -> bool:
|
||||
tables = self._db.list_table()
|
||||
@@ -232,7 +266,7 @@ class BaiduVector(BaseVector):
|
||||
fields = []
|
||||
fields.append(
|
||||
Field(
|
||||
self.field_id,
|
||||
VDBField.PRIMARY_KEY,
|
||||
FieldType.STRING,
|
||||
primary_key=True,
|
||||
partition_key=True,
|
||||
@@ -240,24 +274,57 @@ class BaiduVector(BaseVector):
|
||||
not_null=True,
|
||||
)
|
||||
)
|
||||
fields.append(Field(self.field_metadata, FieldType.STRING, not_null=True))
|
||||
fields.append(Field(self.field_app_id, FieldType.STRING))
|
||||
fields.append(Field(self.field_annotation_id, FieldType.STRING))
|
||||
fields.append(Field(self.field_text, FieldType.TEXT, not_null=True))
|
||||
fields.append(Field(self.field_vector, FieldType.FLOAT_VECTOR, not_null=True, dimension=dimension))
|
||||
fields.append(Field(VDBField.CONTENT_KEY, FieldType.TEXT, not_null=False))
|
||||
fields.append(Field(VDBField.METADATA_KEY, FieldType.JSON, not_null=False))
|
||||
fields.append(Field(VDBField.VECTOR, FieldType.FLOAT_VECTOR, not_null=True, dimension=dimension))
|
||||
|
||||
# Construct vector index params
|
||||
indexes = []
|
||||
indexes.append(
|
||||
VectorIndex(
|
||||
index_name="vector_idx",
|
||||
index_name=self.vector_index,
|
||||
index_type=index_type,
|
||||
field="vector",
|
||||
field=VDBField.VECTOR,
|
||||
metric_type=metric_type,
|
||||
params=HNSWParams(m=16, efconstruction=200),
|
||||
)
|
||||
)
|
||||
|
||||
# Filtering index
|
||||
indexes.append(
|
||||
FilteringIndex(
|
||||
index_name=self.filtering_index,
|
||||
fields=[VDBField.METADATA_KEY],
|
||||
)
|
||||
)
|
||||
|
||||
# Get analyzer and parse_mode from config
|
||||
analyzer = getattr(
|
||||
InvertedIndexAnalyzer,
|
||||
self._client_config.inverted_index_analyzer,
|
||||
InvertedIndexAnalyzer.DEFAULT_ANALYZER,
|
||||
)
|
||||
|
||||
parse_mode = getattr(
|
||||
InvertedIndexParseMode,
|
||||
self._client_config.inverted_index_parser_mode,
|
||||
InvertedIndexParseMode.COARSE_MODE,
|
||||
)
|
||||
|
||||
# Inverted index
|
||||
indexes.append(
|
||||
InvertedIndex(
|
||||
index_name=self.inverted_index,
|
||||
fields=[VDBField.CONTENT_KEY],
|
||||
params=InvertedIndexParams(
|
||||
analyzer=analyzer,
|
||||
parse_mode=parse_mode,
|
||||
case_sensitive=True,
|
||||
),
|
||||
field_attributes=[InvertedIndexFieldAttribute.ANALYZED],
|
||||
)
|
||||
)
|
||||
|
||||
# Create table
|
||||
self._db.create_table(
|
||||
table_name=self._collection_name,
|
||||
@@ -268,11 +335,15 @@ class BaiduVector(BaseVector):
|
||||
)
|
||||
|
||||
# Wait for table created
|
||||
timeout = 300 # 5 minutes timeout
|
||||
start_time = time.time()
|
||||
while True:
|
||||
time.sleep(1)
|
||||
table = self._db.describe_table(self._collection_name)
|
||||
if table.state == TableState.NORMAL:
|
||||
break
|
||||
if time.time() - start_time > timeout:
|
||||
raise TimeoutError(f"Table creation timeout after {timeout} seconds")
|
||||
redis_client.set(table_exist_cache_key, 1, ex=3600)
|
||||
|
||||
|
||||
@@ -296,5 +367,7 @@ class BaiduVectorFactory(AbstractVectorFactory):
|
||||
database=dify_config.BAIDU_VECTOR_DB_DATABASE or "",
|
||||
shard=dify_config.BAIDU_VECTOR_DB_SHARD,
|
||||
replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS,
|
||||
inverted_index_analyzer=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER,
|
||||
inverted_index_parser_mode=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE,
|
||||
),
|
||||
)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user