mirror of
https://github.com/langgenius/dify.git
synced 2026-04-10 02:29:25 +08:00
Compare commits
40 Commits
copilot/an
...
fix/vinext
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a0e778258a | ||
|
|
35caa04fe7 | ||
|
|
0706c52680 | ||
|
|
d872594945 | ||
|
|
0ae73296d7 | ||
|
|
125ece1d0c | ||
|
|
fd71e85ed4 | ||
|
|
75bbb616ea | ||
|
|
2a468da440 | ||
|
|
322d3cd555 | ||
|
|
e1df0fad2b | ||
|
|
3f515dcdda | ||
|
|
db627e75f6 | ||
|
|
1ecedab024 | ||
|
|
a0ed350871 | ||
|
|
a5832df586 | ||
|
|
a808389122 | ||
|
|
45a8967b8b | ||
|
|
3835cfe87e | ||
|
|
eaf86c521f | ||
|
|
08b3bce53c | ||
|
|
2a3cc2951b | ||
|
|
504138bb23 | ||
|
|
0ab4e16335 | ||
|
|
01991f3536 | ||
|
|
4f835107b2 | ||
|
|
3f3b788356 | ||
|
|
b9d05d3456 | ||
|
|
a480e9beb1 | ||
|
|
a59c54b3e7 | ||
|
|
7737bdc699 | ||
|
|
65637fc6b7 | ||
|
|
be6f7b8712 | ||
|
|
b257e8ed44 | ||
|
|
176d3c8c3a | ||
|
|
c72ac8a434 | ||
|
|
497feac48e | ||
|
|
8906ab8e52 | ||
|
|
03dcbeafdf | ||
|
|
bbfa28e8a7 |
33
.github/actions/setup-web/action.yml
vendored
Normal file
33
.github/actions/setup-web/action.yml
vendored
Normal file
@@ -0,0 +1,33 @@
|
||||
name: Setup Web Environment
|
||||
description: Setup pnpm, Node.js, and install web dependencies.
|
||||
|
||||
inputs:
|
||||
node-version:
|
||||
description: Node.js version to use
|
||||
required: false
|
||||
default: "22"
|
||||
install-dependencies:
|
||||
description: Whether to install web dependencies after setting up Node.js
|
||||
required: false
|
||||
default: "true"
|
||||
|
||||
runs:
|
||||
using: composite
|
||||
steps:
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@41ff72655975bd51cab0327fa583b6e92b6d3061 # v4.2.0
|
||||
with:
|
||||
package_json_file: web/package.json
|
||||
run_install: false
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0
|
||||
with:
|
||||
node-version: ${{ inputs.node-version }}
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Install dependencies
|
||||
if: ${{ inputs.install-dependencies == 'true' }}
|
||||
shell: bash
|
||||
run: pnpm --dir web install --frozen-lockfile
|
||||
13
.github/dependabot.yml
vendored
13
.github/dependabot.yml
vendored
@@ -24,6 +24,15 @@ updates:
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
open-pull-requests-limit: 2
|
||||
ignore:
|
||||
- dependency-name: "tailwind-merge"
|
||||
update-types: ["version-update:semver-major"]
|
||||
- dependency-name: "tailwindcss"
|
||||
update-types: ["version-update:semver-major"]
|
||||
- dependency-name: "react-syntax-highlighter"
|
||||
update-types: ["version-update:semver-major"]
|
||||
- dependency-name: "react-window"
|
||||
update-types: ["version-update:semver-major"]
|
||||
groups:
|
||||
lexical:
|
||||
patterns:
|
||||
@@ -33,6 +42,9 @@ updates:
|
||||
patterns:
|
||||
- "storybook"
|
||||
- "@storybook/*"
|
||||
eslint-group:
|
||||
patterns:
|
||||
- "*eslint*"
|
||||
npm-dependencies:
|
||||
patterns:
|
||||
- "*"
|
||||
@@ -41,3 +53,4 @@ updates:
|
||||
- "@lexical/*"
|
||||
- "storybook"
|
||||
- "@storybook/*"
|
||||
- "*eslint*"
|
||||
|
||||
17
.github/workflows/anti-slop.yml
vendored
Normal file
17
.github/workflows/anti-slop.yml
vendored
Normal file
@@ -0,0 +1,17 @@
|
||||
name: Anti-Slop PR Check
|
||||
|
||||
on:
|
||||
pull_request_target:
|
||||
types: [opened, edited, synchronize]
|
||||
|
||||
permissions:
|
||||
pull-requests: write
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
anti-slop:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: peakoss/anti-slop@v0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
6
.github/workflows/api-tests.yml
vendored
6
.github/workflows/api-tests.yml
vendored
@@ -22,12 +22,12 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@@ -51,7 +51,7 @@ jobs:
|
||||
run: sh .github/workflows/expose_service_ports.sh
|
||||
|
||||
- name: Set up Sandbox
|
||||
uses: hoverkraft-tech/compose-action@v2
|
||||
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
|
||||
39
.github/workflows/autofix.yml
vendored
39
.github/workflows/autofix.yml
vendored
@@ -12,22 +12,34 @@ jobs:
|
||||
if: github.repository == 'langgenius/dify'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Check Docker Compose inputs
|
||||
id: docker-compose-changes
|
||||
uses: tj-actions/changed-files@v47
|
||||
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
|
||||
with:
|
||||
files: |
|
||||
docker/generate_docker_compose
|
||||
docker/.env.example
|
||||
docker/docker-compose-template.yaml
|
||||
docker/docker-compose.yaml
|
||||
- uses: actions/setup-python@v6
|
||||
- name: Check web inputs
|
||||
id: web-changes
|
||||
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
- name: Check api inputs
|
||||
id: api-changes
|
||||
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
|
||||
with:
|
||||
files: |
|
||||
api/**
|
||||
- uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
|
||||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- uses: astral-sh/setup-uv@v7
|
||||
- uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
|
||||
- name: Generate Docker Compose
|
||||
if: steps.docker-compose-changes.outputs.any_changed == 'true'
|
||||
@@ -35,7 +47,8 @@ jobs:
|
||||
cd docker
|
||||
./generate_docker_compose
|
||||
|
||||
- run: |
|
||||
- if: steps.api-changes.outputs.any_changed == 'true'
|
||||
run: |
|
||||
cd api
|
||||
uv sync --dev
|
||||
# fmt first to avoid line too long
|
||||
@@ -46,11 +59,13 @@ jobs:
|
||||
uv run ruff format ..
|
||||
|
||||
- name: count migration progress
|
||||
if: steps.api-changes.outputs.any_changed == 'true'
|
||||
run: |
|
||||
cd api
|
||||
./cnt_base.sh
|
||||
|
||||
- name: ast-grep
|
||||
if: steps.api-changes.outputs.any_changed == 'true'
|
||||
run: |
|
||||
# ast-grep exits 1 if no matches are found; allow idempotent runs.
|
||||
uvx --from ast-grep-cli ast-grep --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all || true
|
||||
@@ -84,4 +99,16 @@ jobs:
|
||||
run: |
|
||||
uvx --python 3.13 mdformat . --exclude ".agents/skills/**"
|
||||
|
||||
- uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27
|
||||
- name: Setup web environment
|
||||
if: steps.web-changes.outputs.any_changed == 'true'
|
||||
uses: ./.github/actions/setup-web
|
||||
with:
|
||||
node-version: "24"
|
||||
|
||||
- name: ESLint autofix
|
||||
if: steps.web-changes.outputs.any_changed == 'true'
|
||||
run: |
|
||||
cd web
|
||||
pnpm eslint --concurrency=2 --prune-suppressions --quiet || true
|
||||
|
||||
- uses: autofix-ci/action@7a166d7532b277f34e16238930461bf77f9d7ed8 # v1.3.3
|
||||
|
||||
18
.github/workflows/build-push.yml
vendored
18
.github/workflows/build-push.yml
vendored
@@ -53,26 +53,26 @@ jobs:
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
|
||||
with:
|
||||
username: ${{ env.DOCKERHUB_USER }}
|
||||
password: ${{ env.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||
|
||||
- name: Extract metadata for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6.0.0
|
||||
with:
|
||||
images: ${{ env[matrix.image_name_env] }}
|
||||
|
||||
- name: Build Docker image
|
||||
id: build
|
||||
uses: docker/build-push-action@v6
|
||||
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
|
||||
with:
|
||||
context: "{{defaultContext}}:${{ matrix.context }}"
|
||||
platforms: ${{ matrix.platform }}
|
||||
@@ -91,7 +91,7 @@ jobs:
|
||||
touch "/tmp/digests/${sanitized_digest}"
|
||||
|
||||
- name: Upload digest
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
with:
|
||||
name: digests-${{ matrix.context }}-${{ env.PLATFORM_PAIR }}
|
||||
path: /tmp/digests/*
|
||||
@@ -113,21 +113,21 @@ jobs:
|
||||
context: "web"
|
||||
steps:
|
||||
- name: Download digests
|
||||
uses: actions/download-artifact@v7
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.0
|
||||
with:
|
||||
path: /tmp/digests
|
||||
pattern: digests-${{ matrix.context }}-*
|
||||
merge-multiple: true
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
|
||||
with:
|
||||
username: ${{ env.DOCKERHUB_USER }}
|
||||
password: ${{ env.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Extract metadata for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6.0.0
|
||||
with:
|
||||
images: ${{ env[matrix.image_name_env] }}
|
||||
tags: |
|
||||
|
||||
12
.github/workflows/db-migration-test.yml
vendored
12
.github/workflows/db-migration-test.yml
vendored
@@ -13,13 +13,13 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
@@ -40,7 +40,7 @@ jobs:
|
||||
cp middleware.env.example middleware.env
|
||||
|
||||
- name: Set up Middlewares
|
||||
uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
@@ -63,13 +63,13 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
@@ -94,7 +94,7 @@ jobs:
|
||||
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=mysql/' middleware.env
|
||||
|
||||
- name: Set up Middlewares
|
||||
uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
|
||||
2
.github/workflows/deploy-agent-dev.yml
vendored
2
.github/workflows/deploy-agent-dev.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
github.event.workflow_run.head_branch == 'deploy/agent-dev'
|
||||
steps:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@v1
|
||||
uses: appleboy/ssh-action@0ff4204d59e8e51228ff73bce53f80d53301dee2 # v1.2.5
|
||||
with:
|
||||
host: ${{ secrets.AGENT_DEV_SSH_HOST }}
|
||||
username: ${{ secrets.SSH_USER }}
|
||||
|
||||
2
.github/workflows/deploy-dev.yml
vendored
2
.github/workflows/deploy-dev.yml
vendored
@@ -16,7 +16,7 @@ jobs:
|
||||
github.event.workflow_run.head_branch == 'deploy/dev'
|
||||
steps:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@v1
|
||||
uses: appleboy/ssh-action@0ff4204d59e8e51228ff73bce53f80d53301dee2 # v1.2.5
|
||||
with:
|
||||
host: ${{ secrets.SSH_HOST }}
|
||||
username: ${{ secrets.SSH_USER }}
|
||||
|
||||
2
.github/workflows/deploy-hitl.yml
vendored
2
.github/workflows/deploy-hitl.yml
vendored
@@ -16,7 +16,7 @@ jobs:
|
||||
github.event.workflow_run.head_branch == 'build/feat/hitl'
|
||||
steps:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@v1
|
||||
uses: appleboy/ssh-action@0ff4204d59e8e51228ff73bce53f80d53301dee2 # v1.2.5
|
||||
with:
|
||||
host: ${{ secrets.HITL_SSH_HOST }}
|
||||
username: ${{ secrets.SSH_USER }}
|
||||
|
||||
6
.github/workflows/docker-build.yml
vendored
6
.github/workflows/docker-build.yml
vendored
@@ -32,13 +32,13 @@ jobs:
|
||||
context: "web"
|
||||
steps:
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||
|
||||
- name: Build Docker Image
|
||||
uses: docker/build-push-action@v6
|
||||
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
|
||||
with:
|
||||
push: false
|
||||
context: "{{defaultContext}}:${{ matrix.context }}"
|
||||
|
||||
2
.github/workflows/labeler.yml
vendored
2
.github/workflows/labeler.yml
vendored
@@ -9,6 +9,6 @@ jobs:
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/labeler@v6
|
||||
- uses: actions/labeler@634933edcd8ababfe52f92936142cc22ac488b1b # v6.0.1
|
||||
with:
|
||||
sync-labels: true
|
||||
|
||||
5
.github/workflows/main-ci.yml
vendored
5
.github/workflows/main-ci.yml
vendored
@@ -27,8 +27,8 @@ jobs:
|
||||
vdb-changed: ${{ steps.changes.outputs.vdb }}
|
||||
migration-changed: ${{ steps.changes.outputs.migration }}
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: dorny/paths-filter@v3
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2
|
||||
id: changes
|
||||
with:
|
||||
filters: |
|
||||
@@ -39,6 +39,7 @@ jobs:
|
||||
web:
|
||||
- 'web/**'
|
||||
- '.github/workflows/web-tests.yml'
|
||||
- '.github/actions/setup-web/**'
|
||||
vdb:
|
||||
- 'api/core/rag/datasource/**'
|
||||
- 'docker/**'
|
||||
|
||||
4
.github/workflows/pyrefly-diff-comment.yml
vendored
4
.github/workflows/pyrefly-diff-comment.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
||||
if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }}
|
||||
steps:
|
||||
- name: Download pyrefly diff artifact
|
||||
uses: actions/github-script@v8
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
@@ -49,7 +49,7 @@ jobs:
|
||||
run: unzip -o pyrefly_diff.zip
|
||||
|
||||
- name: Post comment
|
||||
uses: actions/github-script@v8
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
|
||||
8
.github/workflows/pyrefly-diff.yml
vendored
8
.github/workflows/pyrefly-diff.yml
vendored
@@ -17,12 +17,12 @@ jobs:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Checkout PR branch
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python & UV
|
||||
uses: astral-sh/setup-uv@v5
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
@@ -55,7 +55,7 @@ jobs:
|
||||
echo ${{ github.event.pull_request.number }} > pr_number.txt
|
||||
|
||||
- name: Upload pyrefly diff
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
with:
|
||||
name: pyrefly_diff
|
||||
path: |
|
||||
@@ -64,7 +64,7 @@ jobs:
|
||||
|
||||
- name: Comment PR with pyrefly diff
|
||||
if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
|
||||
uses: actions/github-script@v8
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
|
||||
2
.github/workflows/semantic-pull-request.yml
vendored
2
.github/workflows/semantic-pull-request.yml
vendored
@@ -16,6 +16,6 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check title
|
||||
uses: amannn/action-semantic-pull-request@v6.1.1
|
||||
uses: amannn/action-semantic-pull-request@48f256284bd46cdaab1048c3721360e808335d50 # v6.1.1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
2
.github/workflows/stale.yml
vendored
2
.github/workflows/stale.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
pull-requests: write
|
||||
|
||||
steps:
|
||||
- uses: actions/stale@v10
|
||||
- uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10.2.0
|
||||
with:
|
||||
days-before-issue-stale: 15
|
||||
days-before-issue-close: 3
|
||||
|
||||
36
.github/workflows/style.yml
vendored
36
.github/workflows/style.yml
vendored
@@ -19,13 +19,13 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v47
|
||||
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
|
||||
with:
|
||||
files: |
|
||||
api/**
|
||||
@@ -33,7 +33,7 @@ jobs:
|
||||
|
||||
- name: Setup UV and Python
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
with:
|
||||
enable-cache: false
|
||||
python-version: "3.12"
|
||||
@@ -67,36 +67,22 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v47
|
||||
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
.github/workflows/style.yml
|
||||
.github/actions/setup-web/**
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
package_json_file: web/package.json
|
||||
run_install: false
|
||||
|
||||
- name: Setup NodeJS
|
||||
uses: actions/setup-node@v6
|
||||
- name: Setup web environment
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
with:
|
||||
node-version: 22
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Web dependencies
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm install --frozen-lockfile
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Web style check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
@@ -134,14 +120,14 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v47
|
||||
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
|
||||
with:
|
||||
files: |
|
||||
**.sh
|
||||
@@ -152,7 +138,7 @@ jobs:
|
||||
.editorconfig
|
||||
|
||||
- name: Super-linter
|
||||
uses: super-linter/super-linter/slim@v8
|
||||
uses: super-linter/super-linter/slim@61abc07d755095a68f4987d1c2c3d1d64408f1f9 # v8.5.0
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
env:
|
||||
BASH_SEVERITY: warning
|
||||
|
||||
4
.github/workflows/tool-test-sdks.yaml
vendored
4
.github/workflows/tool-test-sdks.yaml
vendored
@@ -21,12 +21,12 @@ jobs:
|
||||
working-directory: sdks/nodejs-client
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Use Node.js
|
||||
uses: actions/setup-node@v6
|
||||
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0
|
||||
with:
|
||||
node-version: 22
|
||||
cache: ''
|
||||
|
||||
18
.github/workflows/translate-i18n-claude.yml
vendored
18
.github/workflows/translate-i18n-claude.yml
vendored
@@ -38,7 +38,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
@@ -48,18 +48,10 @@ jobs:
|
||||
git config --global user.name "github-actions[bot]"
|
||||
git config --global user.email "github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
with:
|
||||
package_json_file: web/package.json
|
||||
run_install: false
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: 22
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
install-dependencies: "false"
|
||||
|
||||
- name: Detect changed files and generate diff
|
||||
id: detect_changes
|
||||
@@ -130,7 +122,7 @@ jobs:
|
||||
|
||||
- name: Run Claude Code for Translation Sync
|
||||
if: steps.detect_changes.outputs.CHANGED_FILES != ''
|
||||
uses: anthropics/claude-code-action@v1
|
||||
uses: anthropics/claude-code-action@26ec041249acb0a944c0a47b6c0c13f05dbc5b44 # v1.0.70
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
4
.github/workflows/trigger-i18n-sync.yml
vendored
4
.github/workflows/trigger-i18n-sync.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
@@ -59,7 +59,7 @@ jobs:
|
||||
|
||||
- name: Trigger i18n sync workflow
|
||||
if: steps.detect.outputs.has_changes == 'true'
|
||||
uses: peter-evans/repository-dispatch@v3
|
||||
uses: peter-evans/repository-dispatch@28959ce8df70de7be546dd1250a005dd32156697 # v4.0.1
|
||||
with:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
event-type: i18n-sync
|
||||
|
||||
8
.github/workflows/vdb-tests.yml
vendored
8
.github/workflows/vdb-tests.yml
vendored
@@ -19,19 +19,19 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Free Disk Space
|
||||
uses: endersonmenezes/free-disk-space@v3
|
||||
uses: endersonmenezes/free-disk-space@7901478139cff6e9d44df5972fd8ab8fcade4db1 # v3.2.2
|
||||
with:
|
||||
remove_dotnet: true
|
||||
remove_haskell: true
|
||||
remove_tool_cache: true
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@@ -60,7 +60,7 @@ jobs:
|
||||
# tiflash
|
||||
|
||||
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase, OceanBase)
|
||||
uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.yaml
|
||||
|
||||
68
.github/workflows/web-tests.yml
vendored
68
.github/workflows/web-tests.yml
vendored
@@ -26,32 +26,19 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
package_json_file: web/package.json
|
||||
run_install: false
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: 22
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Run tests
|
||||
run: pnpm vitest run --reporter=blob --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage
|
||||
|
||||
- name: Upload blob report
|
||||
if: ${{ !cancelled() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
with:
|
||||
name: blob-report-${{ matrix.shardIndex }}
|
||||
path: web/.vitest-reports/*
|
||||
@@ -70,28 +57,15 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
package_json_file: web/package.json
|
||||
run_install: false
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: 22
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Download blob reports
|
||||
uses: actions/download-artifact@v6
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.0
|
||||
with:
|
||||
path: web/.vitest-reports
|
||||
pattern: blob-report-*
|
||||
@@ -419,7 +393,7 @@ jobs:
|
||||
|
||||
- name: Upload Coverage Artifact
|
||||
if: steps.coverage-summary.outputs.has_coverage == 'true'
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
with:
|
||||
name: web-coverage-report
|
||||
path: web/coverage
|
||||
@@ -435,36 +409,22 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v47
|
||||
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
.github/workflows/web-tests.yml
|
||||
.github/actions/setup-web/**
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
package_json_file: web/package.json
|
||||
run_install: false
|
||||
|
||||
- name: Setup NodeJS
|
||||
uses: actions/setup-node@v6
|
||||
- name: Setup web environment
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
with:
|
||||
node-version: 22
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Web dependencies
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm install --frozen-lockfile
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Web build check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
|
||||
@@ -56,7 +56,7 @@
|
||||
<a href="./docs/bn-BD/README.md"><img alt="README in বাংলা" src="https://img.shields.io/badge/বাংলা-d9d9d9"></a>
|
||||
</p>
|
||||
|
||||
Dify is an open-source platform for developing LLM applications. Its intuitive interface combines agentic AI workflows, RAG pipelines, agent capabilities, model management, observability features, and more—allowing you to quickly move from prototype to production.
|
||||
Dify is an open-source LLM app development platform. Its intuitive interface combines AI workflow, RAG pipeline, agent capabilities, model management, observability features (including [Opik](https://www.comet.com/docs/opik/integrations/dify), [Langfuse](https://docs.langfuse.com), and [Arize Phoenix](https://docs.arize.com/phoenix)) and more, letting you quickly go from prototype to production. Here's a list of the core features:
|
||||
|
||||
## Quick start
|
||||
|
||||
|
||||
@@ -44,9 +44,7 @@ forbidden_modules =
|
||||
allow_indirect_imports = True
|
||||
ignore_imports =
|
||||
dify_graph.nodes.agent.agent_node -> extensions.ext_database
|
||||
dify_graph.nodes.llm.file_saver -> extensions.ext_database
|
||||
dify_graph.nodes.llm.node -> extensions.ext_database
|
||||
dify_graph.nodes.tool.tool_node -> extensions.ext_database
|
||||
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
|
||||
dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
|
||||
|
||||
@@ -112,9 +110,7 @@ ignore_imports =
|
||||
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
|
||||
dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager
|
||||
dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer
|
||||
dify_graph.nodes.tool.tool_node -> models
|
||||
dify_graph.nodes.agent.agent_node -> models.model
|
||||
dify_graph.nodes.llm.file_saver -> core.helper.ssrf_proxy
|
||||
dify_graph.nodes.llm.node -> core.helper.code_executor
|
||||
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors
|
||||
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output
|
||||
@@ -135,9 +131,7 @@ ignore_imports =
|
||||
dify_graph.nodes.llm.file_saver -> core.tools.tool_file_manager
|
||||
dify_graph.nodes.tool.tool_node -> core.tools.errors
|
||||
dify_graph.nodes.agent.agent_node -> extensions.ext_database
|
||||
dify_graph.nodes.llm.file_saver -> extensions.ext_database
|
||||
dify_graph.nodes.llm.node -> extensions.ext_database
|
||||
dify_graph.nodes.tool.tool_node -> extensions.ext_database
|
||||
dify_graph.nodes.agent.agent_node -> models
|
||||
dify_graph.nodes.llm.node -> models.model
|
||||
dify_graph.nodes.agent.agent_node -> services
|
||||
|
||||
2813
api/commands.py
2813
api/commands.py
File diff suppressed because it is too large
Load Diff
71
api/commands/__init__.py
Normal file
71
api/commands/__init__.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""
|
||||
CLI command modules extracted from `commands.py`.
|
||||
"""
|
||||
|
||||
from .account import create_tenant, reset_email, reset_password
|
||||
from .plugin import (
|
||||
extract_plugins,
|
||||
extract_unique_plugins,
|
||||
install_plugins,
|
||||
install_rag_pipeline_plugins,
|
||||
migrate_data_for_plugin,
|
||||
setup_datasource_oauth_client,
|
||||
setup_system_tool_oauth_client,
|
||||
setup_system_trigger_oauth_client,
|
||||
transform_datasource_credentials,
|
||||
)
|
||||
from .retention import (
|
||||
archive_workflow_runs,
|
||||
clean_expired_messages,
|
||||
clean_workflow_runs,
|
||||
cleanup_orphaned_draft_variables,
|
||||
clear_free_plan_tenant_expired_logs,
|
||||
delete_archived_workflow_runs,
|
||||
export_app_messages,
|
||||
restore_workflow_runs,
|
||||
)
|
||||
from .storage import clear_orphaned_file_records, file_usage, migrate_oss, remove_orphaned_files_on_storage
|
||||
from .system import convert_to_agent_apps, fix_app_site_missing, reset_encrypt_key_pair, upgrade_db
|
||||
from .vector import (
|
||||
add_qdrant_index,
|
||||
migrate_annotation_vector_database,
|
||||
migrate_knowledge_vector_database,
|
||||
old_metadata_migration,
|
||||
vdb_migrate,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"add_qdrant_index",
|
||||
"archive_workflow_runs",
|
||||
"clean_expired_messages",
|
||||
"clean_workflow_runs",
|
||||
"cleanup_orphaned_draft_variables",
|
||||
"clear_free_plan_tenant_expired_logs",
|
||||
"clear_orphaned_file_records",
|
||||
"convert_to_agent_apps",
|
||||
"create_tenant",
|
||||
"delete_archived_workflow_runs",
|
||||
"export_app_messages",
|
||||
"extract_plugins",
|
||||
"extract_unique_plugins",
|
||||
"file_usage",
|
||||
"fix_app_site_missing",
|
||||
"install_plugins",
|
||||
"install_rag_pipeline_plugins",
|
||||
"migrate_annotation_vector_database",
|
||||
"migrate_data_for_plugin",
|
||||
"migrate_knowledge_vector_database",
|
||||
"migrate_oss",
|
||||
"old_metadata_migration",
|
||||
"remove_orphaned_files_on_storage",
|
||||
"reset_email",
|
||||
"reset_encrypt_key_pair",
|
||||
"reset_password",
|
||||
"restore_workflow_runs",
|
||||
"setup_datasource_oauth_client",
|
||||
"setup_system_tool_oauth_client",
|
||||
"setup_system_trigger_oauth_client",
|
||||
"transform_datasource_credentials",
|
||||
"upgrade_db",
|
||||
"vdb_migrate",
|
||||
]
|
||||
130
api/commands/account.py
Normal file
130
api/commands/account.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import base64
|
||||
import secrets
|
||||
|
||||
import click
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from constants.languages import languages
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import email as email_validate
|
||||
from libs.password import hash_password, password_pattern, valid_password
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
|
||||
|
||||
@click.command("reset-password", help="Reset the account password.")
|
||||
@click.option("--email", prompt=True, help="Account email to reset password for")
|
||||
@click.option("--new-password", prompt=True, help="New password")
|
||||
@click.option("--password-confirm", prompt=True, help="Confirm new password")
|
||||
def reset_password(email, new_password, password_confirm):
|
||||
"""
|
||||
Reset password of owner account
|
||||
Only available in SELF_HOSTED mode
|
||||
"""
|
||||
if str(new_password).strip() != str(password_confirm).strip():
|
||||
click.echo(click.style("Passwords do not match.", fg="red"))
|
||||
return
|
||||
normalized_email = email.strip().lower()
|
||||
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
|
||||
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
|
||||
try:
|
||||
valid_password(new_password)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
|
||||
return
|
||||
|
||||
# generate password salt
|
||||
salt = secrets.token_bytes(16)
|
||||
base64_salt = base64.b64encode(salt).decode()
|
||||
|
||||
# encrypt password with salt
|
||||
password_hashed = hash_password(new_password, salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||
click.echo(click.style("Password reset successfully.", fg="green"))
|
||||
|
||||
|
||||
@click.command("reset-email", help="Reset the account email.")
|
||||
@click.option("--email", prompt=True, help="Current account email")
|
||||
@click.option("--new-email", prompt=True, help="New email")
|
||||
@click.option("--email-confirm", prompt=True, help="Confirm new email")
|
||||
def reset_email(email, new_email, email_confirm):
|
||||
"""
|
||||
Replace account email
|
||||
:return:
|
||||
"""
|
||||
if str(new_email).strip() != str(email_confirm).strip():
|
||||
click.echo(click.style("New emails do not match.", fg="red"))
|
||||
return
|
||||
normalized_new_email = new_email.strip().lower()
|
||||
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
|
||||
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
|
||||
try:
|
||||
email_validate(normalized_new_email)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
|
||||
return
|
||||
|
||||
account.email = normalized_new_email
|
||||
click.echo(click.style("Email updated successfully.", fg="green"))
|
||||
|
||||
|
||||
@click.command("create-tenant", help="Create account and tenant.")
|
||||
@click.option("--email", prompt=True, help="Tenant account email.")
|
||||
@click.option("--name", prompt=True, help="Workspace name.")
|
||||
@click.option("--language", prompt=True, help="Account language, default: en-US.")
|
||||
def create_tenant(email: str, language: str | None = None, name: str | None = None):
|
||||
"""
|
||||
Create tenant account
|
||||
"""
|
||||
if not email:
|
||||
click.echo(click.style("Email is required.", fg="red"))
|
||||
return
|
||||
|
||||
# Create account
|
||||
email = email.strip().lower()
|
||||
|
||||
if "@" not in email:
|
||||
click.echo(click.style("Invalid email address.", fg="red"))
|
||||
return
|
||||
|
||||
account_name = email.split("@")[0]
|
||||
|
||||
if language not in languages:
|
||||
language = "en-US"
|
||||
|
||||
# Validates name encoding for non-Latin characters.
|
||||
name = name.strip().encode("utf-8").decode("utf-8") if name else None
|
||||
|
||||
# generate random password
|
||||
new_password = secrets.token_urlsafe(16)
|
||||
|
||||
# register account
|
||||
account = RegisterService.register(
|
||||
email=email,
|
||||
name=account_name,
|
||||
password=new_password,
|
||||
language=language,
|
||||
create_workspace_required=False,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name)
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Account and tenant created.\nAccount: {email}\nPassword: {new_password}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
467
api/commands/plugin.py
Normal file
467
api/commands/plugin.py
Normal file
@@ -0,0 +1,467 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper import encrypter
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
|
||||
from extensions.ext_database import db
|
||||
from models import Tenant
|
||||
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
|
||||
from models.provider_ids import DatasourceProviderID, ToolProviderID
|
||||
from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
|
||||
from models.tools import ToolOAuthSystemClient
|
||||
from services.plugin.data_migration import PluginDataMigration
|
||||
from services.plugin.plugin_migration import PluginMigration
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@click.command("setup-system-tool-oauth-client", help="Setup system tool oauth client.")
|
||||
@click.option("--provider", prompt=True, help="Provider name")
|
||||
@click.option("--client-params", prompt=True, help="Client Params")
|
||||
def setup_system_tool_oauth_client(provider, client_params):
|
||||
"""
|
||||
Setup system tool oauth client
|
||||
"""
|
||||
provider_id = ToolProviderID(provider)
|
||||
provider_name = provider_id.provider_name
|
||||
plugin_id = provider_id.plugin_id
|
||||
|
||||
try:
|
||||
# json validate
|
||||
click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
|
||||
client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
|
||||
click.echo(click.style("Client params validated successfully.", fg="green"))
|
||||
|
||||
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
|
||||
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
|
||||
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
|
||||
click.echo(click.style("Client params encrypted successfully.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
return
|
||||
|
||||
deleted_count = (
|
||||
db.session.query(ToolOAuthSystemClient)
|
||||
.filter_by(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
if deleted_count > 0:
|
||||
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
|
||||
|
||||
oauth_client = ToolOAuthSystemClient(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
encrypted_oauth_params=oauth_client_params,
|
||||
)
|
||||
db.session.add(oauth_client)
|
||||
db.session.commit()
|
||||
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
|
||||
|
||||
|
||||
@click.command("setup-system-trigger-oauth-client", help="Setup system trigger oauth client.")
|
||||
@click.option("--provider", prompt=True, help="Provider name")
|
||||
@click.option("--client-params", prompt=True, help="Client Params")
|
||||
def setup_system_trigger_oauth_client(provider, client_params):
|
||||
"""
|
||||
Setup system trigger oauth client
|
||||
"""
|
||||
from models.provider_ids import TriggerProviderID
|
||||
from models.trigger import TriggerOAuthSystemClient
|
||||
|
||||
provider_id = TriggerProviderID(provider)
|
||||
provider_name = provider_id.provider_name
|
||||
plugin_id = provider_id.plugin_id
|
||||
|
||||
try:
|
||||
# json validate
|
||||
click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
|
||||
client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
|
||||
click.echo(click.style("Client params validated successfully.", fg="green"))
|
||||
|
||||
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
|
||||
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
|
||||
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
|
||||
click.echo(click.style("Client params encrypted successfully.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
return
|
||||
|
||||
deleted_count = (
|
||||
db.session.query(TriggerOAuthSystemClient)
|
||||
.filter_by(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
if deleted_count > 0:
|
||||
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
|
||||
|
||||
oauth_client = TriggerOAuthSystemClient(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
encrypted_oauth_params=oauth_client_params,
|
||||
)
|
||||
db.session.add(oauth_client)
|
||||
db.session.commit()
|
||||
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
|
||||
|
||||
|
||||
@click.command("setup-datasource-oauth-client", help="Setup datasource oauth client.")
|
||||
@click.option("--provider", prompt=True, help="Provider name")
|
||||
@click.option("--client-params", prompt=True, help="Client Params")
|
||||
def setup_datasource_oauth_client(provider, client_params):
|
||||
"""
|
||||
Setup datasource oauth client
|
||||
"""
|
||||
provider_id = DatasourceProviderID(provider)
|
||||
provider_name = provider_id.provider_name
|
||||
plugin_id = provider_id.plugin_id
|
||||
|
||||
try:
|
||||
# json validate
|
||||
click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
|
||||
client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
|
||||
click.echo(click.style("Client params validated successfully.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
return
|
||||
|
||||
click.echo(click.style(f"Ready to delete existing oauth client params: {provider_name}", fg="yellow"))
|
||||
deleted_count = (
|
||||
db.session.query(DatasourceOauthParamConfig)
|
||||
.filter_by(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
if deleted_count > 0:
|
||||
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
|
||||
|
||||
click.echo(click.style(f"Ready to setup datasource oauth client: {provider_name}", fg="yellow"))
|
||||
oauth_client = DatasourceOauthParamConfig(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
system_credentials=client_params_dict,
|
||||
)
|
||||
db.session.add(oauth_client)
|
||||
db.session.commit()
|
||||
click.echo(click.style(f"provider: {provider_name}", fg="green"))
|
||||
click.echo(click.style(f"plugin_id: {plugin_id}", fg="green"))
|
||||
click.echo(click.style(f"params: {json.dumps(client_params_dict, indent=2, ensure_ascii=False)}", fg="green"))
|
||||
click.echo(click.style(f"Datasource oauth client setup successfully. id: {oauth_client.id}", fg="green"))
|
||||
|
||||
|
||||
@click.command("transform-datasource-credentials", help="Transform datasource credentials.")
|
||||
@click.option(
|
||||
"--environment", prompt=True, help="the environment to transform datasource credentials", default="online"
|
||||
)
|
||||
def transform_datasource_credentials(environment: str):
|
||||
"""
|
||||
Transform datasource credentials
|
||||
"""
|
||||
try:
|
||||
installer_manager = PluginInstaller()
|
||||
plugin_migration = PluginMigration()
|
||||
|
||||
notion_plugin_id = "langgenius/notion_datasource"
|
||||
firecrawl_plugin_id = "langgenius/firecrawl_datasource"
|
||||
jina_plugin_id = "langgenius/jina_datasource"
|
||||
if environment == "online":
|
||||
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
else:
|
||||
notion_plugin_unique_identifier = None
|
||||
firecrawl_plugin_unique_identifier = None
|
||||
jina_plugin_unique_identifier = None
|
||||
oauth_credential_type = CredentialType.OAUTH2
|
||||
api_key_credential_type = CredentialType.API_KEY
|
||||
|
||||
# deal notion credentials
|
||||
deal_notion_count = 0
|
||||
notion_credentials = db.session.query(DataSourceOauthBinding).filter_by(provider="notion").all()
|
||||
if notion_credentials:
|
||||
notion_credentials_tenant_mapping: dict[str, list[DataSourceOauthBinding]] = {}
|
||||
for notion_credential in notion_credentials:
|
||||
tenant_id = notion_credential.tenant_id
|
||||
if tenant_id not in notion_credentials_tenant_mapping:
|
||||
notion_credentials_tenant_mapping[tenant_id] = []
|
||||
notion_credentials_tenant_mapping[tenant_id].append(notion_credential)
|
||||
for tenant_id, notion_tenant_credentials in notion_credentials_tenant_mapping.items():
|
||||
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
|
||||
if not tenant:
|
||||
continue
|
||||
try:
|
||||
# check notion plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if notion_plugin_id not in installed_plugins_ids:
|
||||
if notion_plugin_unique_identifier:
|
||||
# install notion plugin
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier])
|
||||
auth_count = 0
|
||||
for notion_tenant_credential in notion_tenant_credentials:
|
||||
auth_count += 1
|
||||
# get credential oauth params
|
||||
access_token = notion_tenant_credential.access_token
|
||||
# notion info
|
||||
notion_info = notion_tenant_credential.source_info
|
||||
workspace_id = notion_info.get("workspace_id")
|
||||
workspace_name = notion_info.get("workspace_name")
|
||||
workspace_icon = notion_info.get("workspace_icon")
|
||||
new_credentials = {
|
||||
"integration_secret": encrypter.encrypt_token(tenant_id, access_token),
|
||||
"workspace_id": workspace_id,
|
||||
"workspace_name": workspace_name,
|
||||
"workspace_icon": workspace_icon,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="notion_datasource",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=notion_plugin_id,
|
||||
auth_type=oauth_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url=workspace_icon or "default",
|
||||
is_default=False,
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_notion_count += 1
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Error transforming notion credentials: {str(e)}, tenant_id: {tenant_id}", fg="red"
|
||||
)
|
||||
)
|
||||
continue
|
||||
db.session.commit()
|
||||
# deal firecrawl credentials
|
||||
deal_firecrawl_count = 0
|
||||
firecrawl_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="firecrawl").all()
|
||||
if firecrawl_credentials:
|
||||
firecrawl_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
|
||||
for firecrawl_credential in firecrawl_credentials:
|
||||
tenant_id = firecrawl_credential.tenant_id
|
||||
if tenant_id not in firecrawl_credentials_tenant_mapping:
|
||||
firecrawl_credentials_tenant_mapping[tenant_id] = []
|
||||
firecrawl_credentials_tenant_mapping[tenant_id].append(firecrawl_credential)
|
||||
for tenant_id, firecrawl_tenant_credentials in firecrawl_credentials_tenant_mapping.items():
|
||||
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
|
||||
if not tenant:
|
||||
continue
|
||||
try:
|
||||
# check firecrawl plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if firecrawl_plugin_id not in installed_plugins_ids:
|
||||
if firecrawl_plugin_unique_identifier:
|
||||
# install firecrawl plugin
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier])
|
||||
|
||||
auth_count = 0
|
||||
for firecrawl_tenant_credential in firecrawl_tenant_credentials:
|
||||
auth_count += 1
|
||||
if not firecrawl_tenant_credential.credentials:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Skipping firecrawl credential for tenant {tenant_id} due to missing credentials.",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
continue
|
||||
# get credential api key
|
||||
credentials_json = json.loads(firecrawl_tenant_credential.credentials)
|
||||
api_key = credentials_json.get("config", {}).get("api_key")
|
||||
base_url = credentials_json.get("config", {}).get("base_url")
|
||||
new_credentials = {
|
||||
"firecrawl_api_key": api_key,
|
||||
"base_url": base_url,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="firecrawl",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=firecrawl_plugin_id,
|
||||
auth_type=api_key_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url="default",
|
||||
is_default=False,
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_firecrawl_count += 1
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Error transforming firecrawl credentials: {str(e)}, tenant_id: {tenant_id}", fg="red"
|
||||
)
|
||||
)
|
||||
continue
|
||||
db.session.commit()
|
||||
# deal jina credentials
|
||||
deal_jina_count = 0
|
||||
jina_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="jinareader").all()
|
||||
if jina_credentials:
|
||||
jina_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
|
||||
for jina_credential in jina_credentials:
|
||||
tenant_id = jina_credential.tenant_id
|
||||
if tenant_id not in jina_credentials_tenant_mapping:
|
||||
jina_credentials_tenant_mapping[tenant_id] = []
|
||||
jina_credentials_tenant_mapping[tenant_id].append(jina_credential)
|
||||
for tenant_id, jina_tenant_credentials in jina_credentials_tenant_mapping.items():
|
||||
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
|
||||
if not tenant:
|
||||
continue
|
||||
try:
|
||||
# check jina plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if jina_plugin_id not in installed_plugins_ids:
|
||||
if jina_plugin_unique_identifier:
|
||||
# install jina plugin
|
||||
logger.debug("Installing Jina plugin %s", jina_plugin_unique_identifier)
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier])
|
||||
|
||||
auth_count = 0
|
||||
for jina_tenant_credential in jina_tenant_credentials:
|
||||
auth_count += 1
|
||||
if not jina_tenant_credential.credentials:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Skipping jina credential for tenant {tenant_id} due to missing credentials.",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
continue
|
||||
# get credential api key
|
||||
credentials_json = json.loads(jina_tenant_credential.credentials)
|
||||
api_key = credentials_json.get("config", {}).get("api_key")
|
||||
new_credentials = {
|
||||
"integration_secret": api_key,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="jinareader",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=jina_plugin_id,
|
||||
auth_type=api_key_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url="default",
|
||||
is_default=False,
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_jina_count += 1
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style(f"Error transforming jina credentials: {str(e)}, tenant_id: {tenant_id}", fg="red")
|
||||
)
|
||||
continue
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
return
|
||||
click.echo(click.style(f"Transforming notion successfully. deal_notion_count: {deal_notion_count}", fg="green"))
|
||||
click.echo(
|
||||
click.style(f"Transforming firecrawl successfully. deal_firecrawl_count: {deal_firecrawl_count}", fg="green")
|
||||
)
|
||||
click.echo(click.style(f"Transforming jina successfully. deal_jina_count: {deal_jina_count}", fg="green"))
|
||||
|
||||
|
||||
@click.command("migrate-data-for-plugin", help="Migrate data for plugin.")
|
||||
def migrate_data_for_plugin():
|
||||
"""
|
||||
Migrate data for plugin.
|
||||
"""
|
||||
click.echo(click.style("Starting migrate data for plugin.", fg="white"))
|
||||
|
||||
PluginDataMigration.migrate()
|
||||
|
||||
click.echo(click.style("Migrate data for plugin completed.", fg="green"))
|
||||
|
||||
|
||||
@click.command("extract-plugins", help="Extract plugins.")
|
||||
@click.option("--output_file", prompt=True, help="The file to store the extracted plugins.", default="plugins.jsonl")
|
||||
@click.option("--workers", prompt=True, help="The number of workers to extract plugins.", default=10)
|
||||
def extract_plugins(output_file: str, workers: int):
|
||||
"""
|
||||
Extract plugins.
|
||||
"""
|
||||
click.echo(click.style("Starting extract plugins.", fg="white"))
|
||||
|
||||
PluginMigration.extract_plugins(output_file, workers)
|
||||
|
||||
click.echo(click.style("Extract plugins completed.", fg="green"))
|
||||
|
||||
|
||||
@click.command("extract-unique-identifiers", help="Extract unique identifiers.")
|
||||
@click.option(
|
||||
"--output_file",
|
||||
prompt=True,
|
||||
help="The file to store the extracted unique identifiers.",
|
||||
default="unique_identifiers.json",
|
||||
)
|
||||
@click.option(
|
||||
"--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl"
|
||||
)
|
||||
def extract_unique_plugins(output_file: str, input_file: str):
|
||||
"""
|
||||
Extract unique plugins.
|
||||
"""
|
||||
click.echo(click.style("Starting extract unique plugins.", fg="white"))
|
||||
|
||||
PluginMigration.extract_unique_plugins_to_file(input_file, output_file)
|
||||
|
||||
click.echo(click.style("Extract unique plugins completed.", fg="green"))
|
||||
|
||||
|
||||
@click.command("install-plugins", help="Install plugins.")
|
||||
@click.option(
|
||||
"--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl"
|
||||
)
|
||||
@click.option(
|
||||
"--output_file", prompt=True, help="The file to store the installed plugins.", default="installed_plugins.jsonl"
|
||||
)
|
||||
@click.option("--workers", prompt=True, help="The number of workers to install plugins.", default=100)
|
||||
def install_plugins(input_file: str, output_file: str, workers: int):
|
||||
"""
|
||||
Install plugins.
|
||||
"""
|
||||
click.echo(click.style("Starting install plugins.", fg="white"))
|
||||
|
||||
PluginMigration.install_plugins(input_file, output_file, workers)
|
||||
|
||||
click.echo(click.style("Install plugins completed.", fg="green"))
|
||||
|
||||
|
||||
@click.command("install-rag-pipeline-plugins", help="Install rag pipeline plugins.")
|
||||
@click.option(
|
||||
"--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl"
|
||||
)
|
||||
@click.option(
|
||||
"--output_file", prompt=True, help="The file to store the installed plugins.", default="installed_plugins.jsonl"
|
||||
)
|
||||
@click.option("--workers", prompt=True, help="The number of workers to install plugins.", default=100)
|
||||
def install_rag_pipeline_plugins(input_file, output_file, workers):
|
||||
"""
|
||||
Install rag pipeline plugins
|
||||
"""
|
||||
click.echo(click.style("Installing rag pipeline plugins", fg="yellow"))
|
||||
plugin_migration = PluginMigration()
|
||||
plugin_migration.install_rag_pipeline_plugins(
|
||||
input_file,
|
||||
output_file,
|
||||
workers,
|
||||
)
|
||||
click.echo(click.style("Installing rag pipeline plugins successfully", fg="green"))
|
||||
830
api/commands/retention.py
Normal file
830
api/commands/retention.py
Normal file
@@ -0,0 +1,830 @@
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
import sqlalchemy as sa
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs
|
||||
from services.retention.conversation.messages_clean_policy import create_message_clean_policy
|
||||
from services.retention.conversation.messages_clean_service import MessagesCleanService
|
||||
from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup
|
||||
from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@click.command("clear-free-plan-tenant-expired-logs", help="Clear free plan tenant expired logs.")
|
||||
@click.option("--days", prompt=True, help="The days to clear free plan tenant expired logs.", default=30)
|
||||
@click.option("--batch", prompt=True, help="The batch size to clear free plan tenant expired logs.", default=100)
|
||||
@click.option(
|
||||
"--tenant_ids",
|
||||
prompt=True,
|
||||
multiple=True,
|
||||
help="The tenant ids to clear free plan tenant expired logs.",
|
||||
)
|
||||
def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[str]):
|
||||
"""
|
||||
Clear free plan tenant expired logs.
|
||||
"""
|
||||
click.echo(click.style("Starting clear free plan tenant expired logs.", fg="white"))
|
||||
|
||||
ClearFreePlanTenantExpiredLogs.process(days, batch, tenant_ids)
|
||||
|
||||
click.echo(click.style("Clear free plan tenant expired logs completed.", fg="green"))
|
||||
|
||||
|
||||
@click.command("clean-workflow-runs", help="Clean expired workflow runs and related data for free tenants.")
|
||||
@click.option(
|
||||
"--before-days",
|
||||
"--days",
|
||||
default=30,
|
||||
show_default=True,
|
||||
type=click.IntRange(min=0),
|
||||
help="Delete workflow runs created before N days ago.",
|
||||
)
|
||||
@click.option("--batch-size", default=200, show_default=True, help="Batch size for selecting workflow runs.")
|
||||
@click.option(
|
||||
"--from-days-ago",
|
||||
default=None,
|
||||
type=click.IntRange(min=0),
|
||||
help="Lower bound in days ago (older). Must be paired with --to-days-ago.",
|
||||
)
|
||||
@click.option(
|
||||
"--to-days-ago",
|
||||
default=None,
|
||||
type=click.IntRange(min=0),
|
||||
help="Upper bound in days ago (newer). Must be paired with --from-days-ago.",
|
||||
)
|
||||
@click.option(
|
||||
"--start-from",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
default=None,
|
||||
help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.",
|
||||
)
|
||||
@click.option(
|
||||
"--end-before",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
default=None,
|
||||
help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.",
|
||||
)
|
||||
@click.option(
|
||||
"--dry-run",
|
||||
is_flag=True,
|
||||
help="Preview cleanup results without deleting any workflow run data.",
|
||||
)
|
||||
def clean_workflow_runs(
|
||||
before_days: int,
|
||||
batch_size: int,
|
||||
from_days_ago: int | None,
|
||||
to_days_ago: int | None,
|
||||
start_from: datetime.datetime | None,
|
||||
end_before: datetime.datetime | None,
|
||||
dry_run: bool,
|
||||
):
|
||||
"""
|
||||
Clean workflow runs and related workflow data for free tenants.
|
||||
"""
|
||||
if (start_from is None) ^ (end_before is None):
|
||||
raise click.UsageError("--start-from and --end-before must be provided together.")
|
||||
|
||||
if (from_days_ago is None) ^ (to_days_ago is None):
|
||||
raise click.UsageError("--from-days-ago and --to-days-ago must be provided together.")
|
||||
|
||||
if from_days_ago is not None and to_days_ago is not None:
|
||||
if start_from or end_before:
|
||||
raise click.UsageError("Choose either day offsets or explicit dates, not both.")
|
||||
if from_days_ago <= to_days_ago:
|
||||
raise click.UsageError("--from-days-ago must be greater than --to-days-ago.")
|
||||
now = datetime.datetime.now()
|
||||
start_from = now - datetime.timedelta(days=from_days_ago)
|
||||
end_before = now - datetime.timedelta(days=to_days_ago)
|
||||
before_days = 0
|
||||
|
||||
start_time = datetime.datetime.now(datetime.UTC)
|
||||
click.echo(click.style(f"Starting workflow run cleanup at {start_time.isoformat()}.", fg="white"))
|
||||
|
||||
WorkflowRunCleanup(
|
||||
days=before_days,
|
||||
batch_size=batch_size,
|
||||
start_from=start_from,
|
||||
end_before=end_before,
|
||||
dry_run=dry_run,
|
||||
).run()
|
||||
|
||||
end_time = datetime.datetime.now(datetime.UTC)
|
||||
elapsed = end_time - start_time
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Workflow run cleanup completed. start={start_time.isoformat()} "
|
||||
f"end={end_time.isoformat()} duration={elapsed}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@click.command(
|
||||
"archive-workflow-runs",
|
||||
help="Archive workflow runs for paid plan tenants to S3-compatible storage.",
|
||||
)
|
||||
@click.option("--tenant-ids", default=None, help="Optional comma-separated tenant IDs for grayscale rollout.")
|
||||
@click.option("--before-days", default=90, show_default=True, help="Archive runs older than N days.")
|
||||
@click.option(
|
||||
"--from-days-ago",
|
||||
default=None,
|
||||
type=click.IntRange(min=0),
|
||||
help="Lower bound in days ago (older). Must be paired with --to-days-ago.",
|
||||
)
|
||||
@click.option(
|
||||
"--to-days-ago",
|
||||
default=None,
|
||||
type=click.IntRange(min=0),
|
||||
help="Upper bound in days ago (newer). Must be paired with --from-days-ago.",
|
||||
)
|
||||
@click.option(
|
||||
"--start-from",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
default=None,
|
||||
help="Archive runs created at or after this timestamp (UTC if no timezone).",
|
||||
)
|
||||
@click.option(
|
||||
"--end-before",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
default=None,
|
||||
help="Archive runs created before this timestamp (UTC if no timezone).",
|
||||
)
|
||||
@click.option("--batch-size", default=100, show_default=True, help="Batch size for processing.")
|
||||
@click.option("--workers", default=1, show_default=True, type=int, help="Concurrent workflow runs to archive.")
|
||||
@click.option("--limit", default=None, type=int, help="Maximum number of runs to archive.")
|
||||
@click.option("--dry-run", is_flag=True, help="Preview without archiving.")
|
||||
@click.option("--delete-after-archive", is_flag=True, help="Delete runs and related data after archiving.")
|
||||
def archive_workflow_runs(
|
||||
tenant_ids: str | None,
|
||||
before_days: int,
|
||||
from_days_ago: int | None,
|
||||
to_days_ago: int | None,
|
||||
start_from: datetime.datetime | None,
|
||||
end_before: datetime.datetime | None,
|
||||
batch_size: int,
|
||||
workers: int,
|
||||
limit: int | None,
|
||||
dry_run: bool,
|
||||
delete_after_archive: bool,
|
||||
):
|
||||
"""
|
||||
Archive workflow runs for paid plan tenants older than the specified days.
|
||||
|
||||
This command archives the following tables to storage:
|
||||
- workflow_node_executions
|
||||
- workflow_node_execution_offload
|
||||
- workflow_pauses
|
||||
- workflow_pause_reasons
|
||||
- workflow_trigger_logs
|
||||
|
||||
The workflow_runs and workflow_app_logs tables are preserved for UI listing.
|
||||
"""
|
||||
from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver
|
||||
|
||||
run_started_at = datetime.datetime.now(datetime.UTC)
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Starting workflow run archiving at {run_started_at.isoformat()}.",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
|
||||
if (start_from is None) ^ (end_before is None):
|
||||
click.echo(click.style("start-from and end-before must be provided together.", fg="red"))
|
||||
return
|
||||
|
||||
if (from_days_ago is None) ^ (to_days_ago is None):
|
||||
click.echo(click.style("from-days-ago and to-days-ago must be provided together.", fg="red"))
|
||||
return
|
||||
|
||||
if from_days_ago is not None and to_days_ago is not None:
|
||||
if start_from or end_before:
|
||||
click.echo(click.style("Choose either day offsets or explicit dates, not both.", fg="red"))
|
||||
return
|
||||
if from_days_ago <= to_days_ago:
|
||||
click.echo(click.style("from-days-ago must be greater than to-days-ago.", fg="red"))
|
||||
return
|
||||
now = datetime.datetime.now()
|
||||
start_from = now - datetime.timedelta(days=from_days_ago)
|
||||
end_before = now - datetime.timedelta(days=to_days_ago)
|
||||
before_days = 0
|
||||
|
||||
if start_from and end_before and start_from >= end_before:
|
||||
click.echo(click.style("start-from must be earlier than end-before.", fg="red"))
|
||||
return
|
||||
if workers < 1:
|
||||
click.echo(click.style("workers must be at least 1.", fg="red"))
|
||||
return
|
||||
|
||||
archiver = WorkflowRunArchiver(
|
||||
days=before_days,
|
||||
batch_size=batch_size,
|
||||
start_from=start_from,
|
||||
end_before=end_before,
|
||||
workers=workers,
|
||||
tenant_ids=[tid.strip() for tid in tenant_ids.split(",")] if tenant_ids else None,
|
||||
limit=limit,
|
||||
dry_run=dry_run,
|
||||
delete_after_archive=delete_after_archive,
|
||||
)
|
||||
summary = archiver.run()
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Summary: processed={summary.total_runs_processed}, archived={summary.runs_archived}, "
|
||||
f"skipped={summary.runs_skipped}, failed={summary.runs_failed}, "
|
||||
f"time={summary.total_elapsed_time:.2f}s",
|
||||
fg="cyan",
|
||||
)
|
||||
)
|
||||
|
||||
run_finished_at = datetime.datetime.now(datetime.UTC)
|
||||
elapsed = run_finished_at - run_started_at
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Workflow run archiving completed. start={run_started_at.isoformat()} "
|
||||
f"end={run_finished_at.isoformat()} duration={elapsed}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@click.command(
|
||||
"restore-workflow-runs",
|
||||
help="Restore archived workflow runs from S3-compatible storage.",
|
||||
)
|
||||
@click.option(
|
||||
"--tenant-ids",
|
||||
required=False,
|
||||
help="Tenant IDs (comma-separated).",
|
||||
)
|
||||
@click.option("--run-id", required=False, help="Workflow run ID to restore.")
|
||||
@click.option(
|
||||
"--start-from",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
default=None,
|
||||
help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.",
|
||||
)
|
||||
@click.option(
|
||||
"--end-before",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
default=None,
|
||||
help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.",
|
||||
)
|
||||
@click.option("--workers", default=1, show_default=True, type=int, help="Concurrent workflow runs to restore.")
|
||||
@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of runs to restore.")
|
||||
@click.option("--dry-run", is_flag=True, help="Preview without restoring.")
|
||||
def restore_workflow_runs(
|
||||
tenant_ids: str | None,
|
||||
run_id: str | None,
|
||||
start_from: datetime.datetime | None,
|
||||
end_before: datetime.datetime | None,
|
||||
workers: int,
|
||||
limit: int,
|
||||
dry_run: bool,
|
||||
):
|
||||
"""
|
||||
Restore an archived workflow run from storage to the database.
|
||||
|
||||
This restores the following tables:
|
||||
- workflow_node_executions
|
||||
- workflow_node_execution_offload
|
||||
- workflow_pauses
|
||||
- workflow_pause_reasons
|
||||
- workflow_trigger_logs
|
||||
"""
|
||||
from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore
|
||||
|
||||
parsed_tenant_ids = None
|
||||
if tenant_ids:
|
||||
parsed_tenant_ids = [tid.strip() for tid in tenant_ids.split(",") if tid.strip()]
|
||||
if not parsed_tenant_ids:
|
||||
raise click.BadParameter("tenant-ids must not be empty")
|
||||
|
||||
if (start_from is None) ^ (end_before is None):
|
||||
raise click.UsageError("--start-from and --end-before must be provided together.")
|
||||
if run_id is None and (start_from is None or end_before is None):
|
||||
raise click.UsageError("--start-from and --end-before are required for batch restore.")
|
||||
if workers < 1:
|
||||
raise click.BadParameter("workers must be at least 1")
|
||||
|
||||
start_time = datetime.datetime.now(datetime.UTC)
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Starting restore of workflow run {run_id} at {start_time.isoformat()}.",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
|
||||
restorer = WorkflowRunRestore(dry_run=dry_run, workers=workers)
|
||||
if run_id:
|
||||
results = [restorer.restore_by_run_id(run_id)]
|
||||
else:
|
||||
assert start_from is not None
|
||||
assert end_before is not None
|
||||
results = restorer.restore_batch(
|
||||
parsed_tenant_ids,
|
||||
start_date=start_from,
|
||||
end_date=end_before,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
end_time = datetime.datetime.now(datetime.UTC)
|
||||
elapsed = end_time - start_time
|
||||
|
||||
successes = sum(1 for result in results if result.success)
|
||||
failures = len(results) - successes
|
||||
|
||||
if failures == 0:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Restore completed successfully. success={successes} duration={elapsed}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
else:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Restore completed with failures. success={successes} failed={failures} duration={elapsed}",
|
||||
fg="red",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@click.command(
|
||||
"delete-archived-workflow-runs",
|
||||
help="Delete archived workflow runs from the database.",
|
||||
)
|
||||
@click.option(
|
||||
"--tenant-ids",
|
||||
required=False,
|
||||
help="Tenant IDs (comma-separated).",
|
||||
)
|
||||
@click.option("--run-id", required=False, help="Workflow run ID to delete.")
|
||||
@click.option(
|
||||
"--start-from",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
default=None,
|
||||
help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.",
|
||||
)
|
||||
@click.option(
|
||||
"--end-before",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
default=None,
|
||||
help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.",
|
||||
)
|
||||
@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of runs to delete.")
|
||||
@click.option("--dry-run", is_flag=True, help="Preview without deleting.")
|
||||
def delete_archived_workflow_runs(
|
||||
tenant_ids: str | None,
|
||||
run_id: str | None,
|
||||
start_from: datetime.datetime | None,
|
||||
end_before: datetime.datetime | None,
|
||||
limit: int,
|
||||
dry_run: bool,
|
||||
):
|
||||
"""
|
||||
Delete archived workflow runs from the database.
|
||||
"""
|
||||
from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
|
||||
|
||||
parsed_tenant_ids = None
|
||||
if tenant_ids:
|
||||
parsed_tenant_ids = [tid.strip() for tid in tenant_ids.split(",") if tid.strip()]
|
||||
if not parsed_tenant_ids:
|
||||
raise click.BadParameter("tenant-ids must not be empty")
|
||||
|
||||
if (start_from is None) ^ (end_before is None):
|
||||
raise click.UsageError("--start-from and --end-before must be provided together.")
|
||||
if run_id is None and (start_from is None or end_before is None):
|
||||
raise click.UsageError("--start-from and --end-before are required for batch delete.")
|
||||
|
||||
start_time = datetime.datetime.now(datetime.UTC)
|
||||
target_desc = f"workflow run {run_id}" if run_id else "workflow runs"
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Starting delete of {target_desc} at {start_time.isoformat()}.",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
|
||||
deleter = ArchivedWorkflowRunDeletion(dry_run=dry_run)
|
||||
if run_id:
|
||||
results = [deleter.delete_by_run_id(run_id)]
|
||||
else:
|
||||
assert start_from is not None
|
||||
assert end_before is not None
|
||||
results = deleter.delete_batch(
|
||||
parsed_tenant_ids,
|
||||
start_date=start_from,
|
||||
end_date=end_before,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
for result in results:
|
||||
if result.success:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"{'[DRY RUN] Would delete' if dry_run else 'Deleted'} "
|
||||
f"workflow run {result.run_id} (tenant={result.tenant_id})",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
else:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Failed to delete workflow run {result.run_id}: {result.error}",
|
||||
fg="red",
|
||||
)
|
||||
)
|
||||
|
||||
end_time = datetime.datetime.now(datetime.UTC)
|
||||
elapsed = end_time - start_time
|
||||
|
||||
successes = sum(1 for result in results if result.success)
|
||||
failures = len(results) - successes
|
||||
|
||||
if failures == 0:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Delete completed successfully. success={successes} duration={elapsed}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
else:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Delete completed with failures. success={successes} failed={failures} duration={elapsed}",
|
||||
fg="red",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
|
||||
"""
|
||||
Find draft variables that reference non-existent apps.
|
||||
|
||||
Args:
|
||||
batch_size: Maximum number of orphaned app IDs to return
|
||||
|
||||
Returns:
|
||||
List of app IDs that have draft variables but don't exist in the apps table
|
||||
"""
|
||||
query = """
|
||||
SELECT DISTINCT wdv.app_id
|
||||
FROM workflow_draft_variables AS wdv
|
||||
WHERE NOT EXISTS(
|
||||
SELECT 1 FROM apps WHERE apps.id = wdv.app_id
|
||||
)
|
||||
LIMIT :batch_size
|
||||
"""
|
||||
|
||||
with db.engine.connect() as conn:
|
||||
result = conn.execute(sa.text(query), {"batch_size": batch_size})
|
||||
return [row[0] for row in result]
|
||||
|
||||
|
||||
def _count_orphaned_draft_variables() -> dict[str, Any]:
|
||||
"""
|
||||
Count orphaned draft variables by app, including associated file counts.
|
||||
|
||||
Returns:
|
||||
Dictionary with statistics about orphaned variables and files
|
||||
"""
|
||||
# Count orphaned variables by app
|
||||
variables_query = """
|
||||
SELECT
|
||||
wdv.app_id,
|
||||
COUNT(*) as variable_count,
|
||||
COUNT(wdv.file_id) as file_count
|
||||
FROM workflow_draft_variables AS wdv
|
||||
WHERE NOT EXISTS(
|
||||
SELECT 1 FROM apps WHERE apps.id = wdv.app_id
|
||||
)
|
||||
GROUP BY wdv.app_id
|
||||
ORDER BY variable_count DESC
|
||||
"""
|
||||
|
||||
with db.engine.connect() as conn:
|
||||
result = conn.execute(sa.text(variables_query))
|
||||
orphaned_by_app = {}
|
||||
total_files = 0
|
||||
|
||||
for row in result:
|
||||
app_id, variable_count, file_count = row
|
||||
orphaned_by_app[app_id] = {"variables": variable_count, "files": file_count}
|
||||
total_files += file_count
|
||||
|
||||
total_orphaned = sum(app_data["variables"] for app_data in orphaned_by_app.values())
|
||||
app_count = len(orphaned_by_app)
|
||||
|
||||
return {
|
||||
"total_orphaned_variables": total_orphaned,
|
||||
"total_orphaned_files": total_files,
|
||||
"orphaned_app_count": app_count,
|
||||
"orphaned_by_app": orphaned_by_app,
|
||||
}
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option("--dry-run", is_flag=True, help="Show what would be deleted without actually deleting")
|
||||
@click.option("--batch-size", default=1000, help="Number of records to process per batch (default 1000)")
|
||||
@click.option("--max-apps", default=None, type=int, help="Maximum number of apps to process (default: no limit)")
|
||||
@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.")
|
||||
def cleanup_orphaned_draft_variables(
|
||||
dry_run: bool,
|
||||
batch_size: int,
|
||||
max_apps: int | None,
|
||||
force: bool = False,
|
||||
):
|
||||
"""
|
||||
Clean up orphaned draft variables from the database.
|
||||
|
||||
This script finds and removes draft variables that belong to apps
|
||||
that no longer exist in the database.
|
||||
"""
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Get statistics
|
||||
stats = _count_orphaned_draft_variables()
|
||||
|
||||
logger.info("Found %s orphaned draft variables", stats["total_orphaned_variables"])
|
||||
logger.info("Found %s associated offload files", stats["total_orphaned_files"])
|
||||
logger.info("Across %s non-existent apps", stats["orphaned_app_count"])
|
||||
|
||||
if stats["total_orphaned_variables"] == 0:
|
||||
logger.info("No orphaned draft variables found. Exiting.")
|
||||
return
|
||||
|
||||
if dry_run:
|
||||
logger.info("DRY RUN: Would delete the following:")
|
||||
for app_id, data in sorted(stats["orphaned_by_app"].items(), key=lambda x: x[1]["variables"], reverse=True)[
|
||||
:10
|
||||
]: # Show top 10
|
||||
logger.info(" App %s: %s variables, %s files", app_id, data["variables"], data["files"])
|
||||
if len(stats["orphaned_by_app"]) > 10:
|
||||
logger.info(" ... and %s more apps", len(stats["orphaned_by_app"]) - 10)
|
||||
return
|
||||
|
||||
# Confirm deletion
|
||||
if not force:
|
||||
click.confirm(
|
||||
f"Are you sure you want to delete {stats['total_orphaned_variables']} "
|
||||
f"orphaned draft variables and {stats['total_orphaned_files']} associated files "
|
||||
f"from {stats['orphaned_app_count']} apps?",
|
||||
abort=True,
|
||||
)
|
||||
|
||||
total_deleted = 0
|
||||
processed_apps = 0
|
||||
|
||||
while True:
|
||||
if max_apps and processed_apps >= max_apps:
|
||||
logger.info("Reached maximum app limit (%s). Stopping.", max_apps)
|
||||
break
|
||||
|
||||
orphaned_app_ids = _find_orphaned_draft_variables(batch_size=10)
|
||||
if not orphaned_app_ids:
|
||||
logger.info("No more orphaned draft variables found.")
|
||||
break
|
||||
|
||||
for app_id in orphaned_app_ids:
|
||||
if max_apps and processed_apps >= max_apps:
|
||||
break
|
||||
|
||||
try:
|
||||
deleted_count = delete_draft_variables_batch(app_id, batch_size)
|
||||
total_deleted += deleted_count
|
||||
processed_apps += 1
|
||||
|
||||
logger.info("Deleted %s variables for app %s", deleted_count, app_id)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error processing app %s", app_id)
|
||||
continue
|
||||
|
||||
logger.info("Cleanup completed. Total deleted: %s variables across %s apps", total_deleted, processed_apps)
|
||||
|
||||
|
||||
@click.command("clean-expired-messages", help="Clean expired messages.")
|
||||
@click.option(
|
||||
"--start-from",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
required=False,
|
||||
default=None,
|
||||
help="Lower bound (inclusive) for created_at.",
|
||||
)
|
||||
@click.option(
|
||||
"--end-before",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
required=False,
|
||||
default=None,
|
||||
help="Upper bound (exclusive) for created_at.",
|
||||
)
|
||||
@click.option(
|
||||
"--from-days-ago",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Relative lower bound in days ago (inclusive). Must be used with --before-days.",
|
||||
)
|
||||
@click.option(
|
||||
"--before-days",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Relative upper bound in days ago (exclusive). Required for relative mode.",
|
||||
)
|
||||
@click.option("--batch-size", default=1000, show_default=True, help="Batch size for selecting messages.")
|
||||
@click.option(
|
||||
"--graceful-period",
|
||||
default=21,
|
||||
show_default=True,
|
||||
help="Graceful period in days after subscription expiration, will be ignored when billing is disabled.",
|
||||
)
|
||||
@click.option("--dry-run", is_flag=True, default=False, help="Show messages logs would be cleaned without deleting")
|
||||
def clean_expired_messages(
|
||||
batch_size: int,
|
||||
graceful_period: int,
|
||||
start_from: datetime.datetime | None,
|
||||
end_before: datetime.datetime | None,
|
||||
from_days_ago: int | None,
|
||||
before_days: int | None,
|
||||
dry_run: bool,
|
||||
):
|
||||
"""
|
||||
Clean expired messages and related data for tenants based on clean policy.
|
||||
"""
|
||||
click.echo(click.style("clean_messages: start clean messages.", fg="green"))
|
||||
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
abs_mode = start_from is not None and end_before is not None
|
||||
rel_mode = before_days is not None
|
||||
|
||||
if abs_mode and rel_mode:
|
||||
raise click.UsageError(
|
||||
"Options are mutually exclusive: use either (--start-from,--end-before) "
|
||||
"or (--from-days-ago,--before-days)."
|
||||
)
|
||||
|
||||
if from_days_ago is not None and before_days is None:
|
||||
raise click.UsageError("--from-days-ago must be used together with --before-days.")
|
||||
|
||||
if (start_from is None) ^ (end_before is None):
|
||||
raise click.UsageError("Both --start-from and --end-before are required when using absolute time range.")
|
||||
|
||||
if not abs_mode and not rel_mode:
|
||||
raise click.UsageError(
|
||||
"You must provide either (--start-from,--end-before) or (--before-days [--from-days-ago])."
|
||||
)
|
||||
|
||||
if rel_mode:
|
||||
assert before_days is not None
|
||||
if before_days < 0:
|
||||
raise click.UsageError("--before-days must be >= 0.")
|
||||
if from_days_ago is not None:
|
||||
if from_days_ago < 0:
|
||||
raise click.UsageError("--from-days-ago must be >= 0.")
|
||||
if from_days_ago <= before_days:
|
||||
raise click.UsageError("--from-days-ago must be greater than --before-days.")
|
||||
|
||||
# Create policy based on billing configuration
|
||||
# NOTE: graceful_period will be ignored when billing is disabled.
|
||||
policy = create_message_clean_policy(graceful_period_days=graceful_period)
|
||||
|
||||
# Create and run the cleanup service
|
||||
if abs_mode:
|
||||
assert start_from is not None
|
||||
assert end_before is not None
|
||||
service = MessagesCleanService.from_time_range(
|
||||
policy=policy,
|
||||
start_from=start_from,
|
||||
end_before=end_before,
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
elif from_days_ago is None:
|
||||
assert before_days is not None
|
||||
service = MessagesCleanService.from_days(
|
||||
policy=policy,
|
||||
days=before_days,
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
else:
|
||||
assert before_days is not None
|
||||
assert from_days_ago is not None
|
||||
now = naive_utc_now()
|
||||
service = MessagesCleanService.from_time_range(
|
||||
policy=policy,
|
||||
start_from=now - datetime.timedelta(days=from_days_ago),
|
||||
end_before=now - datetime.timedelta(days=before_days),
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
stats = service.run()
|
||||
|
||||
end_at = time.perf_counter()
|
||||
click.echo(
|
||||
click.style(
|
||||
f"clean_messages: completed successfully\n"
|
||||
f" - Latency: {end_at - start_at:.2f}s\n"
|
||||
f" - Batches processed: {stats['batches']}\n"
|
||||
f" - Total messages scanned: {stats['total_messages']}\n"
|
||||
f" - Messages filtered: {stats['filtered_messages']}\n"
|
||||
f" - Messages deleted: {stats['total_deleted']}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
end_at = time.perf_counter()
|
||||
logger.exception("clean_messages failed")
|
||||
click.echo(
|
||||
click.style(
|
||||
f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}",
|
||||
fg="red",
|
||||
)
|
||||
)
|
||||
raise
|
||||
|
||||
click.echo(click.style("messages cleanup completed.", fg="green"))
|
||||
|
||||
|
||||
@click.command("export-app-messages", help="Export messages for an app to JSONL.GZ.")
|
||||
@click.option("--app-id", required=True, help="Application ID to export messages for.")
|
||||
@click.option(
|
||||
"--start-from",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
default=None,
|
||||
help="Optional lower bound (inclusive) for created_at.",
|
||||
)
|
||||
@click.option(
|
||||
"--end-before",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
required=True,
|
||||
help="Upper bound (exclusive) for created_at.",
|
||||
)
|
||||
@click.option(
|
||||
"--filename",
|
||||
required=True,
|
||||
help="Base filename (relative path). Do not include suffix like .jsonl.gz.",
|
||||
)
|
||||
@click.option("--use-cloud-storage", is_flag=True, default=False, help="Upload to cloud storage instead of local file.")
|
||||
@click.option("--batch-size", default=1000, show_default=True, help="Batch size for cursor pagination.")
|
||||
@click.option("--dry-run", is_flag=True, default=False, help="Scan only, print stats without writing any file.")
|
||||
def export_app_messages(
|
||||
app_id: str,
|
||||
start_from: datetime.datetime | None,
|
||||
end_before: datetime.datetime,
|
||||
filename: str,
|
||||
use_cloud_storage: bool,
|
||||
batch_size: int,
|
||||
dry_run: bool,
|
||||
):
|
||||
if start_from and start_from >= end_before:
|
||||
raise click.UsageError("--start-from must be before --end-before.")
|
||||
|
||||
from services.retention.conversation.message_export_service import AppMessageExportService
|
||||
|
||||
try:
|
||||
validated_filename = AppMessageExportService.validate_export_filename(filename)
|
||||
except ValueError as e:
|
||||
raise click.BadParameter(str(e), param_hint="--filename") from e
|
||||
|
||||
click.echo(click.style(f"export_app_messages: starting export for app {app_id}.", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
service = AppMessageExportService(
|
||||
app_id=app_id,
|
||||
end_before=end_before,
|
||||
filename=validated_filename,
|
||||
start_from=start_from,
|
||||
batch_size=batch_size,
|
||||
use_cloud_storage=use_cloud_storage,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
stats = service.run()
|
||||
|
||||
elapsed = time.perf_counter() - start_at
|
||||
click.echo(
|
||||
click.style(
|
||||
f"export_app_messages: completed in {elapsed:.2f}s\n"
|
||||
f" - Batches: {stats.batches}\n"
|
||||
f" - Total messages: {stats.total_messages}\n"
|
||||
f" - Messages with feedback: {stats.messages_with_feedback}\n"
|
||||
f" - Total feedbacks: {stats.total_feedbacks}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
elapsed = time.perf_counter() - start_at
|
||||
logger.exception("export_app_messages failed")
|
||||
click.echo(click.style(f"export_app_messages: failed after {elapsed:.2f}s - {e}", fg="red"))
|
||||
raise
|
||||
755
api/commands/storage.py
Normal file
755
api/commands/storage.py
Normal file
@@ -0,0 +1,755 @@
|
||||
import json
|
||||
|
||||
import click
|
||||
import sqlalchemy as sa
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from extensions.storage.opendal_storage import OpenDALStorage
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models.model import UploadFile
|
||||
|
||||
|
||||
@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.")
|
||||
@click.command("clear-orphaned-file-records", help="Clear orphaned file records.")
|
||||
def clear_orphaned_file_records(force: bool):
|
||||
"""
|
||||
Clear orphaned file records in the database.
|
||||
"""
|
||||
|
||||
# define tables and columns to process
|
||||
files_tables = [
|
||||
{"table": "upload_files", "id_column": "id", "key_column": "key"},
|
||||
{"table": "tool_files", "id_column": "id", "key_column": "file_key"},
|
||||
]
|
||||
ids_tables = [
|
||||
{"type": "uuid", "table": "message_files", "column": "upload_file_id"},
|
||||
{"type": "text", "table": "documents", "column": "data_source_info"},
|
||||
{"type": "text", "table": "document_segments", "column": "content"},
|
||||
{"type": "text", "table": "messages", "column": "answer"},
|
||||
{"type": "text", "table": "workflow_node_executions", "column": "inputs"},
|
||||
{"type": "text", "table": "workflow_node_executions", "column": "process_data"},
|
||||
{"type": "text", "table": "workflow_node_executions", "column": "outputs"},
|
||||
{"type": "text", "table": "conversations", "column": "introduction"},
|
||||
{"type": "text", "table": "conversations", "column": "system_instruction"},
|
||||
{"type": "text", "table": "accounts", "column": "avatar"},
|
||||
{"type": "text", "table": "apps", "column": "icon"},
|
||||
{"type": "text", "table": "sites", "column": "icon"},
|
||||
{"type": "json", "table": "messages", "column": "inputs"},
|
||||
{"type": "json", "table": "messages", "column": "message"},
|
||||
]
|
||||
|
||||
# notify user and ask for confirmation
|
||||
click.echo(
|
||||
click.style(
|
||||
"This command will first find and delete orphaned file records from the message_files table,", fg="yellow"
|
||||
)
|
||||
)
|
||||
click.echo(
|
||||
click.style(
|
||||
"and then it will find and delete orphaned file records in the following tables:",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
for files_table in files_tables:
|
||||
click.echo(click.style(f"- {files_table['table']}", fg="yellow"))
|
||||
click.echo(
|
||||
click.style("The following tables and columns will be scanned to find orphaned file records:", fg="yellow")
|
||||
)
|
||||
for ids_table in ids_tables:
|
||||
click.echo(click.style(f"- {ids_table['table']} ({ids_table['column']})", fg="yellow"))
|
||||
click.echo("")
|
||||
|
||||
click.echo(click.style("!!! USE WITH CAUTION !!!", fg="red"))
|
||||
click.echo(
|
||||
click.style(
|
||||
(
|
||||
"Since not all patterns have been fully tested, "
|
||||
"please note that this command may delete unintended file records."
|
||||
),
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
click.echo(
|
||||
click.style("This cannot be undone. Please make sure to back up your database before proceeding.", fg="yellow")
|
||||
)
|
||||
click.echo(
|
||||
click.style(
|
||||
(
|
||||
"It is also recommended to run this during the maintenance window, "
|
||||
"as this may cause high load on your instance."
|
||||
),
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
if not force:
|
||||
click.confirm("Do you want to proceed?", abort=True)
|
||||
|
||||
# start the cleanup process
|
||||
click.echo(click.style("Starting orphaned file records cleanup.", fg="white"))
|
||||
|
||||
# clean up the orphaned records in the message_files table where message_id doesn't exist in messages table
|
||||
try:
|
||||
click.echo(
|
||||
click.style("- Listing message_files records where message_id doesn't exist in messages table", fg="white")
|
||||
)
|
||||
query = (
|
||||
"SELECT mf.id, mf.message_id "
|
||||
"FROM message_files mf LEFT JOIN messages m ON mf.message_id = m.id "
|
||||
"WHERE m.id IS NULL"
|
||||
)
|
||||
orphaned_message_files = []
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
orphaned_message_files.append({"id": str(i[0]), "message_id": str(i[1])})
|
||||
|
||||
if orphaned_message_files:
|
||||
click.echo(click.style(f"Found {len(orphaned_message_files)} orphaned message_files records:", fg="white"))
|
||||
for record in orphaned_message_files:
|
||||
click.echo(click.style(f" - id: {record['id']}, message_id: {record['message_id']}", fg="black"))
|
||||
|
||||
if not force:
|
||||
click.confirm(
|
||||
(
|
||||
f"Do you want to proceed "
|
||||
f"to delete all {len(orphaned_message_files)} orphaned message_files records?"
|
||||
),
|
||||
abort=True,
|
||||
)
|
||||
|
||||
click.echo(click.style("- Deleting orphaned message_files records", fg="white"))
|
||||
query = "DELETE FROM message_files WHERE id IN :ids"
|
||||
with db.engine.begin() as conn:
|
||||
conn.execute(sa.text(query), {"ids": tuple(record["id"] for record in orphaned_message_files)})
|
||||
click.echo(
|
||||
click.style(f"Removed {len(orphaned_message_files)} orphaned message_files records.", fg="green")
|
||||
)
|
||||
else:
|
||||
click.echo(click.style("No orphaned message_files records found. There is nothing to delete.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error deleting orphaned message_files records: {str(e)}", fg="red"))
|
||||
|
||||
# clean up the orphaned records in the rest of the *_files tables
|
||||
try:
|
||||
# fetch file id and keys from each table
|
||||
all_files_in_tables = []
|
||||
for files_table in files_tables:
|
||||
click.echo(click.style(f"- Listing file records in table {files_table['table']}", fg="white"))
|
||||
query = f"SELECT {files_table['id_column']}, {files_table['key_column']} FROM {files_table['table']}"
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
all_files_in_tables.append({"table": files_table["table"], "id": str(i[0]), "key": i[1]})
|
||||
click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white"))
|
||||
|
||||
# fetch referred table and columns
|
||||
guid_regexp = "[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
|
||||
all_ids_in_tables = []
|
||||
for ids_table in ids_tables:
|
||||
query = ""
|
||||
match ids_table["type"]:
|
||||
case "uuid":
|
||||
click.echo(
|
||||
click.style(
|
||||
f"- Listing file ids in column {ids_table['column']} in table {ids_table['table']}",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
c = ids_table["column"]
|
||||
query = f"SELECT {c} FROM {ids_table['table']} WHERE {c} IS NOT NULL"
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])})
|
||||
case "text":
|
||||
t = ids_table["table"]
|
||||
click.echo(
|
||||
click.style(
|
||||
f"- Listing file-id-like strings in column {ids_table['column']} in table {t}",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
query = (
|
||||
f"SELECT regexp_matches({ids_table['column']}, '{guid_regexp}', 'g') AS extracted_id "
|
||||
f"FROM {ids_table['table']}"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
for j in i[0]:
|
||||
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
|
||||
case "json":
|
||||
click.echo(
|
||||
click.style(
|
||||
(
|
||||
f"- Listing file-id-like JSON string in column {ids_table['column']} "
|
||||
f"in table {ids_table['table']}"
|
||||
),
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
query = (
|
||||
f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') AS extracted_id "
|
||||
f"FROM {ids_table['table']}"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
for j in i[0]:
|
||||
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
|
||||
case _:
|
||||
pass
|
||||
click.echo(click.style(f"Found {len(all_ids_in_tables)} file ids in tables.", fg="white"))
|
||||
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error fetching keys: {str(e)}", fg="red"))
|
||||
return
|
||||
|
||||
# find orphaned files
|
||||
all_files = [file["id"] for file in all_files_in_tables]
|
||||
all_ids = [file["id"] for file in all_ids_in_tables]
|
||||
orphaned_files = list(set(all_files) - set(all_ids))
|
||||
if not orphaned_files:
|
||||
click.echo(click.style("No orphaned file records found. There is nothing to delete.", fg="green"))
|
||||
return
|
||||
click.echo(click.style(f"Found {len(orphaned_files)} orphaned file records.", fg="white"))
|
||||
for file in orphaned_files:
|
||||
click.echo(click.style(f"- orphaned file id: {file}", fg="black"))
|
||||
if not force:
|
||||
click.confirm(f"Do you want to proceed to delete all {len(orphaned_files)} orphaned file records?", abort=True)
|
||||
|
||||
# delete orphaned records for each file
|
||||
try:
|
||||
for files_table in files_tables:
|
||||
click.echo(click.style(f"- Deleting orphaned file records in table {files_table['table']}", fg="white"))
|
||||
query = f"DELETE FROM {files_table['table']} WHERE {files_table['id_column']} IN :ids"
|
||||
with db.engine.begin() as conn:
|
||||
conn.execute(sa.text(query), {"ids": tuple(orphaned_files)})
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error deleting orphaned file records: {str(e)}", fg="red"))
|
||||
return
|
||||
click.echo(click.style(f"Removed {len(orphaned_files)} orphaned file records.", fg="green"))
|
||||
|
||||
|
||||
@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.")
|
||||
@click.command("remove-orphaned-files-on-storage", help="Remove orphaned files on the storage.")
|
||||
def remove_orphaned_files_on_storage(force: bool):
|
||||
"""
|
||||
Remove orphaned files on the storage.
|
||||
"""
|
||||
|
||||
# define tables and columns to process
|
||||
files_tables = [
|
||||
{"table": "upload_files", "key_column": "key"},
|
||||
{"table": "tool_files", "key_column": "file_key"},
|
||||
]
|
||||
storage_paths = ["image_files", "tools", "upload_files"]
|
||||
|
||||
# notify user and ask for confirmation
|
||||
click.echo(click.style("This command will find and remove orphaned files on the storage,", fg="yellow"))
|
||||
click.echo(
|
||||
click.style("by comparing the files on the storage with the records in the following tables:", fg="yellow")
|
||||
)
|
||||
for files_table in files_tables:
|
||||
click.echo(click.style(f"- {files_table['table']}", fg="yellow"))
|
||||
click.echo(click.style("The following paths on the storage will be scanned to find orphaned files:", fg="yellow"))
|
||||
for storage_path in storage_paths:
|
||||
click.echo(click.style(f"- {storage_path}", fg="yellow"))
|
||||
click.echo("")
|
||||
|
||||
click.echo(click.style("!!! USE WITH CAUTION !!!", fg="red"))
|
||||
click.echo(
|
||||
click.style(
|
||||
"Currently, this command will work only for opendal based storage (STORAGE_TYPE=opendal).", fg="yellow"
|
||||
)
|
||||
)
|
||||
click.echo(
|
||||
click.style(
|
||||
"Since not all patterns have been fully tested, please note that this command may delete unintended files.",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
click.echo(
|
||||
click.style("This cannot be undone. Please make sure to back up your storage before proceeding.", fg="yellow")
|
||||
)
|
||||
click.echo(
|
||||
click.style(
|
||||
(
|
||||
"It is also recommended to run this during the maintenance window, "
|
||||
"as this may cause high load on your instance."
|
||||
),
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
if not force:
|
||||
click.confirm("Do you want to proceed?", abort=True)
|
||||
|
||||
# start the cleanup process
|
||||
click.echo(click.style("Starting orphaned files cleanup.", fg="white"))
|
||||
|
||||
# fetch file id and keys from each table
|
||||
all_files_in_tables = []
|
||||
try:
|
||||
for files_table in files_tables:
|
||||
click.echo(click.style(f"- Listing files from table {files_table['table']}", fg="white"))
|
||||
query = f"SELECT {files_table['key_column']} FROM {files_table['table']}"
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
all_files_in_tables.append(str(i[0]))
|
||||
click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error fetching keys: {str(e)}", fg="red"))
|
||||
return
|
||||
|
||||
all_files_on_storage = []
|
||||
for storage_path in storage_paths:
|
||||
try:
|
||||
click.echo(click.style(f"- Scanning files on storage path {storage_path}", fg="white"))
|
||||
files = storage.scan(path=storage_path, files=True, directories=False)
|
||||
all_files_on_storage.extend(files)
|
||||
except FileNotFoundError:
|
||||
click.echo(click.style(f" -> Skipping path {storage_path} as it does not exist.", fg="yellow"))
|
||||
continue
|
||||
except Exception as e:
|
||||
click.echo(click.style(f" -> Error scanning files on storage path {storage_path}: {str(e)}", fg="red"))
|
||||
continue
|
||||
click.echo(click.style(f"Found {len(all_files_on_storage)} files on storage.", fg="white"))
|
||||
|
||||
# find orphaned files
|
||||
orphaned_files = list(set(all_files_on_storage) - set(all_files_in_tables))
|
||||
if not orphaned_files:
|
||||
click.echo(click.style("No orphaned files found. There is nothing to remove.", fg="green"))
|
||||
return
|
||||
click.echo(click.style(f"Found {len(orphaned_files)} orphaned files.", fg="white"))
|
||||
for file in orphaned_files:
|
||||
click.echo(click.style(f"- orphaned file: {file}", fg="black"))
|
||||
if not force:
|
||||
click.confirm(f"Do you want to proceed to remove all {len(orphaned_files)} orphaned files?", abort=True)
|
||||
|
||||
# delete orphaned files
|
||||
removed_files = 0
|
||||
error_files = 0
|
||||
for file in orphaned_files:
|
||||
try:
|
||||
storage.delete(file)
|
||||
removed_files += 1
|
||||
click.echo(click.style(f"- Removing orphaned file: {file}", fg="white"))
|
||||
except Exception as e:
|
||||
error_files += 1
|
||||
click.echo(click.style(f"- Error deleting orphaned file {file}: {str(e)}", fg="red"))
|
||||
continue
|
||||
if error_files == 0:
|
||||
click.echo(click.style(f"Removed {removed_files} orphaned files without errors.", fg="green"))
|
||||
else:
|
||||
click.echo(click.style(f"Removed {removed_files} orphaned files, with {error_files} errors.", fg="yellow"))
|
||||
|
||||
|
||||
@click.command("file-usage", help="Query file usages and show where files are referenced.")
|
||||
@click.option("--file-id", type=str, default=None, help="Filter by file UUID.")
|
||||
@click.option("--key", type=str, default=None, help="Filter by storage key.")
|
||||
@click.option("--src", type=str, default=None, help="Filter by table.column pattern (e.g., 'documents.%' or '%.icon').")
|
||||
@click.option("--limit", type=int, default=100, help="Limit number of results (default: 100).")
|
||||
@click.option("--offset", type=int, default=0, help="Offset for pagination (default: 0).")
|
||||
@click.option("--json", "output_json", is_flag=True, help="Output results in JSON format.")
|
||||
def file_usage(
|
||||
file_id: str | None,
|
||||
key: str | None,
|
||||
src: str | None,
|
||||
limit: int,
|
||||
offset: int,
|
||||
output_json: bool,
|
||||
):
|
||||
"""
|
||||
Query file usages and show where files are referenced in the database.
|
||||
|
||||
This command reuses the same reference checking logic as clear-orphaned-file-records
|
||||
and displays detailed information about where each file is referenced.
|
||||
"""
|
||||
# define tables and columns to process
|
||||
files_tables = [
|
||||
{"table": "upload_files", "id_column": "id", "key_column": "key"},
|
||||
{"table": "tool_files", "id_column": "id", "key_column": "file_key"},
|
||||
]
|
||||
ids_tables = [
|
||||
{"type": "uuid", "table": "message_files", "column": "upload_file_id", "pk_column": "id"},
|
||||
{"type": "text", "table": "documents", "column": "data_source_info", "pk_column": "id"},
|
||||
{"type": "text", "table": "document_segments", "column": "content", "pk_column": "id"},
|
||||
{"type": "text", "table": "messages", "column": "answer", "pk_column": "id"},
|
||||
{"type": "text", "table": "workflow_node_executions", "column": "inputs", "pk_column": "id"},
|
||||
{"type": "text", "table": "workflow_node_executions", "column": "process_data", "pk_column": "id"},
|
||||
{"type": "text", "table": "workflow_node_executions", "column": "outputs", "pk_column": "id"},
|
||||
{"type": "text", "table": "conversations", "column": "introduction", "pk_column": "id"},
|
||||
{"type": "text", "table": "conversations", "column": "system_instruction", "pk_column": "id"},
|
||||
{"type": "text", "table": "accounts", "column": "avatar", "pk_column": "id"},
|
||||
{"type": "text", "table": "apps", "column": "icon", "pk_column": "id"},
|
||||
{"type": "text", "table": "sites", "column": "icon", "pk_column": "id"},
|
||||
{"type": "json", "table": "messages", "column": "inputs", "pk_column": "id"},
|
||||
{"type": "json", "table": "messages", "column": "message", "pk_column": "id"},
|
||||
]
|
||||
|
||||
# Stream file usages with pagination to avoid holding all results in memory
|
||||
paginated_usages = []
|
||||
total_count = 0
|
||||
|
||||
# First, build a mapping of file_id -> storage_key from the base tables
|
||||
file_key_map = {}
|
||||
for files_table in files_tables:
|
||||
query = f"SELECT {files_table['id_column']}, {files_table['key_column']} FROM {files_table['table']}"
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for row in rs:
|
||||
file_key_map[str(row[0])] = f"{files_table['table']}:{row[1]}"
|
||||
|
||||
# If filtering by key or file_id, verify it exists
|
||||
if file_id and file_id not in file_key_map:
|
||||
if output_json:
|
||||
click.echo(json.dumps({"error": f"File ID {file_id} not found in base tables"}))
|
||||
else:
|
||||
click.echo(click.style(f"File ID {file_id} not found in base tables.", fg="red"))
|
||||
return
|
||||
|
||||
if key:
|
||||
valid_prefixes = {f"upload_files:{key}", f"tool_files:{key}"}
|
||||
matching_file_ids = [fid for fid, fkey in file_key_map.items() if fkey in valid_prefixes]
|
||||
if not matching_file_ids:
|
||||
if output_json:
|
||||
click.echo(json.dumps({"error": f"Key {key} not found in base tables"}))
|
||||
else:
|
||||
click.echo(click.style(f"Key {key} not found in base tables.", fg="red"))
|
||||
return
|
||||
|
||||
guid_regexp = "[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
|
||||
|
||||
# For each reference table/column, find matching file IDs and record the references
|
||||
for ids_table in ids_tables:
|
||||
src_filter = f"{ids_table['table']}.{ids_table['column']}"
|
||||
|
||||
# Skip if src filter doesn't match (use fnmatch for wildcard patterns)
|
||||
if src:
|
||||
if "%" in src or "_" in src:
|
||||
import fnmatch
|
||||
|
||||
# Convert SQL LIKE wildcards to fnmatch wildcards (% -> *, _ -> ?)
|
||||
pattern = src.replace("%", "*").replace("_", "?")
|
||||
if not fnmatch.fnmatch(src_filter, pattern):
|
||||
continue
|
||||
else:
|
||||
if src_filter != src:
|
||||
continue
|
||||
|
||||
match ids_table["type"]:
|
||||
case "uuid":
|
||||
# Direct UUID match
|
||||
query = (
|
||||
f"SELECT {ids_table['pk_column']}, {ids_table['column']} "
|
||||
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for row in rs:
|
||||
record_id = str(row[0])
|
||||
ref_file_id = str(row[1])
|
||||
if ref_file_id not in file_key_map:
|
||||
continue
|
||||
storage_key = file_key_map[ref_file_id]
|
||||
|
||||
# Apply filters
|
||||
if file_id and ref_file_id != file_id:
|
||||
continue
|
||||
if key and not storage_key.endswith(key):
|
||||
continue
|
||||
|
||||
# Only collect items within the requested page range
|
||||
if offset <= total_count < offset + limit:
|
||||
paginated_usages.append(
|
||||
{
|
||||
"src": f"{ids_table['table']}.{ids_table['column']}",
|
||||
"record_id": record_id,
|
||||
"file_id": ref_file_id,
|
||||
"key": storage_key,
|
||||
}
|
||||
)
|
||||
total_count += 1
|
||||
|
||||
case "text" | "json":
|
||||
# Extract UUIDs from text/json content
|
||||
column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"]
|
||||
query = (
|
||||
f"SELECT {ids_table['pk_column']}, {column_cast} "
|
||||
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for row in rs:
|
||||
record_id = str(row[0])
|
||||
content = str(row[1])
|
||||
|
||||
# Find all UUIDs in the content
|
||||
import re
|
||||
|
||||
uuid_pattern = re.compile(guid_regexp, re.IGNORECASE)
|
||||
matches = uuid_pattern.findall(content)
|
||||
|
||||
for ref_file_id in matches:
|
||||
if ref_file_id not in file_key_map:
|
||||
continue
|
||||
storage_key = file_key_map[ref_file_id]
|
||||
|
||||
# Apply filters
|
||||
if file_id and ref_file_id != file_id:
|
||||
continue
|
||||
if key and not storage_key.endswith(key):
|
||||
continue
|
||||
|
||||
# Only collect items within the requested page range
|
||||
if offset <= total_count < offset + limit:
|
||||
paginated_usages.append(
|
||||
{
|
||||
"src": f"{ids_table['table']}.{ids_table['column']}",
|
||||
"record_id": record_id,
|
||||
"file_id": ref_file_id,
|
||||
"key": storage_key,
|
||||
}
|
||||
)
|
||||
total_count += 1
|
||||
case _:
|
||||
pass
|
||||
|
||||
# Output results
|
||||
if output_json:
|
||||
result = {
|
||||
"total": total_count,
|
||||
"offset": offset,
|
||||
"limit": limit,
|
||||
"usages": paginated_usages,
|
||||
}
|
||||
click.echo(json.dumps(result, indent=2))
|
||||
else:
|
||||
click.echo(
|
||||
click.style(f"Found {total_count} file usages (showing {len(paginated_usages)} results)", fg="white")
|
||||
)
|
||||
click.echo("")
|
||||
|
||||
if not paginated_usages:
|
||||
click.echo(click.style("No file usages found matching the specified criteria.", fg="yellow"))
|
||||
return
|
||||
|
||||
# Print table header
|
||||
click.echo(
|
||||
click.style(
|
||||
f"{'Src (Table.Column)':<50} {'Record ID':<40} {'File ID':<40} {'Storage Key':<60}",
|
||||
fg="cyan",
|
||||
)
|
||||
)
|
||||
click.echo(click.style("-" * 190, fg="white"))
|
||||
|
||||
# Print each usage
|
||||
for usage in paginated_usages:
|
||||
click.echo(f"{usage['src']:<50} {usage['record_id']:<40} {usage['file_id']:<40} {usage['key']:<60}")
|
||||
|
||||
# Show pagination info
|
||||
if offset + limit < total_count:
|
||||
click.echo("")
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Showing {offset + 1}-{offset + len(paginated_usages)} of {total_count} results", fg="white"
|
||||
)
|
||||
)
|
||||
click.echo(click.style(f"Use --offset {offset + limit} to see next page", fg="white"))
|
||||
|
||||
|
||||
@click.command(
|
||||
"migrate-oss",
|
||||
help="Migrate files from Local or OpenDAL source to a cloud OSS storage (destination must NOT be local/opendal).",
|
||||
)
|
||||
@click.option(
|
||||
"--path",
|
||||
"paths",
|
||||
multiple=True,
|
||||
help="Storage path prefixes to migrate (repeatable). Defaults: privkeys, upload_files, image_files,"
|
||||
" tools, website_files, keyword_files, ops_trace",
|
||||
)
|
||||
@click.option(
|
||||
"--source",
|
||||
type=click.Choice(["local", "opendal"], case_sensitive=False),
|
||||
default="opendal",
|
||||
show_default=True,
|
||||
help="Source storage type to read from",
|
||||
)
|
||||
@click.option("--overwrite", is_flag=True, default=False, help="Overwrite destination if file already exists")
|
||||
@click.option("--dry-run", is_flag=True, default=False, help="Show what would be migrated without uploading")
|
||||
@click.option("-f", "--force", is_flag=True, help="Skip confirmation and run without prompts")
|
||||
@click.option(
|
||||
"--update-db/--no-update-db",
|
||||
default=True,
|
||||
help="Update upload_files.storage_type from source type to current storage after migration",
|
||||
)
|
||||
def migrate_oss(
|
||||
paths: tuple[str, ...],
|
||||
source: str,
|
||||
overwrite: bool,
|
||||
dry_run: bool,
|
||||
force: bool,
|
||||
update_db: bool,
|
||||
):
|
||||
"""
|
||||
Copy all files under selected prefixes from a source storage
|
||||
(Local filesystem or OpenDAL-backed) into the currently configured
|
||||
destination storage backend, then optionally update DB records.
|
||||
|
||||
Expected usage: set STORAGE_TYPE (and its credentials) to your target backend.
|
||||
"""
|
||||
# Ensure target storage is not local/opendal
|
||||
if dify_config.STORAGE_TYPE in (StorageType.LOCAL, StorageType.OPENDAL):
|
||||
click.echo(
|
||||
click.style(
|
||||
"Target STORAGE_TYPE must be a cloud OSS (not 'local' or 'opendal').\n"
|
||||
"Please set STORAGE_TYPE to one of: s3, aliyun-oss, azure-blob, google-storage, tencent-cos, \n"
|
||||
"volcengine-tos, supabase, oci-storage, huawei-obs, baidu-obs, clickzetta-volume.",
|
||||
fg="red",
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Default paths if none specified
|
||||
default_paths = ("privkeys", "upload_files", "image_files", "tools", "website_files", "keyword_files", "ops_trace")
|
||||
path_list = list(paths) if paths else list(default_paths)
|
||||
is_source_local = source.lower() == "local"
|
||||
|
||||
click.echo(click.style("Preparing migration to target storage.", fg="yellow"))
|
||||
click.echo(click.style(f"Target storage type: {dify_config.STORAGE_TYPE}", fg="white"))
|
||||
if is_source_local:
|
||||
src_root = dify_config.STORAGE_LOCAL_PATH
|
||||
click.echo(click.style(f"Source: local fs, root: {src_root}", fg="white"))
|
||||
else:
|
||||
click.echo(click.style(f"Source: opendal scheme={dify_config.OPENDAL_SCHEME}", fg="white"))
|
||||
click.echo(click.style(f"Paths to migrate: {', '.join(path_list)}", fg="white"))
|
||||
click.echo("")
|
||||
|
||||
if not force:
|
||||
click.confirm("Proceed with migration?", abort=True)
|
||||
|
||||
# Instantiate source storage
|
||||
try:
|
||||
if is_source_local:
|
||||
src_root = dify_config.STORAGE_LOCAL_PATH
|
||||
source_storage = OpenDALStorage(scheme="fs", root=src_root)
|
||||
else:
|
||||
source_storage = OpenDALStorage(scheme=dify_config.OPENDAL_SCHEME)
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Failed to initialize source storage: {str(e)}", fg="red"))
|
||||
return
|
||||
|
||||
total_files = 0
|
||||
copied_files = 0
|
||||
skipped_files = 0
|
||||
errored_files = 0
|
||||
copied_upload_file_keys: list[str] = []
|
||||
|
||||
for prefix in path_list:
|
||||
click.echo(click.style(f"Scanning source path: {prefix}", fg="white"))
|
||||
try:
|
||||
keys = source_storage.scan(path=prefix, files=True, directories=False)
|
||||
except FileNotFoundError:
|
||||
click.echo(click.style(f" -> Skipping missing path: {prefix}", fg="yellow"))
|
||||
continue
|
||||
except NotImplementedError:
|
||||
click.echo(click.style(" -> Source storage does not support scanning.", fg="red"))
|
||||
return
|
||||
except Exception as e:
|
||||
click.echo(click.style(f" -> Error scanning '{prefix}': {str(e)}", fg="red"))
|
||||
continue
|
||||
|
||||
click.echo(click.style(f"Found {len(keys)} files under {prefix}", fg="white"))
|
||||
|
||||
for key in keys:
|
||||
total_files += 1
|
||||
|
||||
# check destination existence
|
||||
if not overwrite:
|
||||
try:
|
||||
if storage.exists(key):
|
||||
skipped_files += 1
|
||||
continue
|
||||
except Exception as e:
|
||||
# existence check failures should not block migration attempt
|
||||
# but should be surfaced to user as a warning for visibility
|
||||
click.echo(
|
||||
click.style(
|
||||
f" -> Warning: failed target existence check for {key}: {str(e)}",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
|
||||
if dry_run:
|
||||
copied_files += 1
|
||||
continue
|
||||
|
||||
# read from source and write to destination
|
||||
try:
|
||||
data = source_storage.load_once(key)
|
||||
except FileNotFoundError:
|
||||
errored_files += 1
|
||||
click.echo(click.style(f" -> Missing on source: {key}", fg="yellow"))
|
||||
continue
|
||||
except Exception as e:
|
||||
errored_files += 1
|
||||
click.echo(click.style(f" -> Error reading {key}: {str(e)}", fg="red"))
|
||||
continue
|
||||
|
||||
try:
|
||||
storage.save(key, data)
|
||||
copied_files += 1
|
||||
if prefix == "upload_files":
|
||||
copied_upload_file_keys.append(key)
|
||||
except Exception as e:
|
||||
errored_files += 1
|
||||
click.echo(click.style(f" -> Error writing {key} to target: {str(e)}", fg="red"))
|
||||
continue
|
||||
|
||||
click.echo("")
|
||||
click.echo(click.style("Migration summary:", fg="yellow"))
|
||||
click.echo(click.style(f" Total: {total_files}", fg="white"))
|
||||
click.echo(click.style(f" Copied: {copied_files}", fg="green"))
|
||||
click.echo(click.style(f" Skipped: {skipped_files}", fg="white"))
|
||||
if errored_files:
|
||||
click.echo(click.style(f" Errors: {errored_files}", fg="red"))
|
||||
|
||||
if dry_run:
|
||||
click.echo(click.style("Dry-run complete. No changes were made.", fg="green"))
|
||||
return
|
||||
|
||||
if errored_files:
|
||||
click.echo(
|
||||
click.style(
|
||||
"Some files failed to migrate. Review errors above before updating DB records.",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
if update_db and not force:
|
||||
if not click.confirm("Proceed to update DB storage_type despite errors?", default=False):
|
||||
update_db = False
|
||||
|
||||
# Optionally update DB records for upload_files.storage_type (only for successfully copied upload_files)
|
||||
if update_db:
|
||||
if not copied_upload_file_keys:
|
||||
click.echo(click.style("No upload_files copied. Skipping DB storage_type update.", fg="yellow"))
|
||||
else:
|
||||
try:
|
||||
source_storage_type = StorageType.LOCAL if is_source_local else StorageType.OPENDAL
|
||||
updated = (
|
||||
db.session.query(UploadFile)
|
||||
.where(
|
||||
UploadFile.storage_type == source_storage_type,
|
||||
UploadFile.key.in_(copied_upload_file_keys),
|
||||
)
|
||||
.update({UploadFile.storage_type: dify_config.STORAGE_TYPE}, synchronize_session=False)
|
||||
)
|
||||
db.session.commit()
|
||||
click.echo(click.style(f"Updated storage_type for {updated} upload_files records.", fg="green"))
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
click.echo(click.style(f"Failed to update DB storage_type: {str(e)}", fg="red"))
|
||||
204
api/commands/system.py
Normal file
204
api/commands/system.py
Normal file
@@ -0,0 +1,204 @@
|
||||
import logging
|
||||
|
||||
import click
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from events.app_event import app_was_created
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.db_migration_lock import DbMigrationAutoRenewLock
|
||||
from libs.rsa import generate_key_pair
|
||||
from models import Tenant
|
||||
from models.model import App, AppMode, Conversation
|
||||
from models.provider import Provider, ProviderModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DB_UPGRADE_LOCK_TTL_SECONDS = 60
|
||||
|
||||
|
||||
@click.command(
|
||||
"reset-encrypt-key-pair",
|
||||
help="Reset the asymmetric key pair of workspace for encrypt LLM credentials. "
|
||||
"After the reset, all LLM credentials will become invalid, "
|
||||
"requiring re-entry."
|
||||
"Only support SELF_HOSTED mode.",
|
||||
)
|
||||
@click.confirmation_option(
|
||||
prompt=click.style(
|
||||
"Are you sure you want to reset encrypt key pair? This operation cannot be rolled back!", fg="red"
|
||||
)
|
||||
)
|
||||
def reset_encrypt_key_pair():
|
||||
"""
|
||||
Reset the encrypted key pair of workspace for encrypt LLM credentials.
|
||||
After the reset, all LLM credentials will become invalid, requiring re-entry.
|
||||
Only support SELF_HOSTED mode.
|
||||
"""
|
||||
if dify_config.EDITION != "SELF_HOSTED":
|
||||
click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red"))
|
||||
return
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
tenants = session.query(Tenant).all()
|
||||
for tenant in tenants:
|
||||
if not tenant:
|
||||
click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
|
||||
return
|
||||
|
||||
tenant.encrypt_public_key = generate_key_pair(tenant.id)
|
||||
|
||||
session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
|
||||
session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@click.command("convert-to-agent-apps", help="Convert Agent Assistant to Agent App.")
|
||||
def convert_to_agent_apps():
|
||||
"""
|
||||
Convert Agent Assistant to Agent App.
|
||||
"""
|
||||
click.echo(click.style("Starting convert to agent apps.", fg="green"))
|
||||
|
||||
proceeded_app_ids = []
|
||||
|
||||
while True:
|
||||
# fetch first 1000 apps
|
||||
sql_query = """SELECT a.id AS id FROM apps a
|
||||
INNER JOIN app_model_configs am ON a.app_model_config_id=am.id
|
||||
WHERE a.mode = 'chat'
|
||||
AND am.agent_mode is not null
|
||||
AND (
|
||||
am.agent_mode like '%"strategy": "function_call"%'
|
||||
OR am.agent_mode like '%"strategy": "react"%'
|
||||
)
|
||||
AND (
|
||||
am.agent_mode like '{"enabled": true%'
|
||||
OR am.agent_mode like '{"max_iteration": %'
|
||||
) ORDER BY a.created_at DESC LIMIT 1000
|
||||
"""
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(sql_query))
|
||||
|
||||
apps = []
|
||||
for i in rs:
|
||||
app_id = str(i.id)
|
||||
if app_id not in proceeded_app_ids:
|
||||
proceeded_app_ids.append(app_id)
|
||||
app = db.session.query(App).where(App.id == app_id).first()
|
||||
if app is not None:
|
||||
apps.append(app)
|
||||
|
||||
if len(apps) == 0:
|
||||
break
|
||||
|
||||
for app in apps:
|
||||
click.echo(f"Converting app: {app.id}")
|
||||
|
||||
try:
|
||||
app.mode = AppMode.AGENT_CHAT
|
||||
db.session.commit()
|
||||
|
||||
# update conversation mode to agent
|
||||
db.session.query(Conversation).where(Conversation.app_id == app.id).update(
|
||||
{Conversation.mode: AppMode.AGENT_CHAT}
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
click.echo(click.style(f"Converted app: {app.id}", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Convert app error: {e.__class__.__name__} {str(e)}", fg="red"))
|
||||
|
||||
click.echo(click.style(f"Conversion complete. Converted {len(proceeded_app_ids)} agent apps.", fg="green"))
|
||||
|
||||
|
||||
@click.command("upgrade-db", help="Upgrade the database")
|
||||
def upgrade_db():
|
||||
click.echo("Preparing database migration...")
|
||||
lock = DbMigrationAutoRenewLock(
|
||||
redis_client=redis_client,
|
||||
name="db_upgrade_lock",
|
||||
ttl_seconds=DB_UPGRADE_LOCK_TTL_SECONDS,
|
||||
logger=logger,
|
||||
log_context="db_migration",
|
||||
)
|
||||
if lock.acquire(blocking=False):
|
||||
migration_succeeded = False
|
||||
try:
|
||||
click.echo(click.style("Starting database migration.", fg="green"))
|
||||
|
||||
# run db migration
|
||||
import flask_migrate
|
||||
|
||||
flask_migrate.upgrade()
|
||||
|
||||
migration_succeeded = True
|
||||
click.echo(click.style("Database migration successful!", fg="green"))
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to execute database migration")
|
||||
click.echo(click.style(f"Database migration failed: {e}", fg="red"))
|
||||
raise SystemExit(1)
|
||||
finally:
|
||||
status = "successful" if migration_succeeded else "failed"
|
||||
lock.release_safely(status=status)
|
||||
else:
|
||||
click.echo("Database migration skipped")
|
||||
|
||||
|
||||
@click.command("fix-app-site-missing", help="Fix app related site missing issue.")
|
||||
def fix_app_site_missing():
|
||||
"""
|
||||
Fix app related site missing issue.
|
||||
"""
|
||||
click.echo(click.style("Starting fix for missing app-related sites.", fg="green"))
|
||||
|
||||
failed_app_ids = []
|
||||
while True:
|
||||
sql = """select apps.id as id from apps left join sites on sites.app_id=apps.id
|
||||
where sites.id is null limit 1000"""
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(sql))
|
||||
|
||||
processed_count = 0
|
||||
for i in rs:
|
||||
processed_count += 1
|
||||
app_id = str(i.id)
|
||||
|
||||
if app_id in failed_app_ids:
|
||||
continue
|
||||
|
||||
try:
|
||||
app = db.session.query(App).where(App.id == app_id).first()
|
||||
if not app:
|
||||
logger.info("App %s not found", app_id)
|
||||
continue
|
||||
|
||||
tenant = app.tenant
|
||||
if tenant:
|
||||
accounts = tenant.get_accounts()
|
||||
if not accounts:
|
||||
logger.info("Fix failed for app %s", app.id)
|
||||
continue
|
||||
|
||||
account = accounts[0]
|
||||
logger.info("Fixing missing site for app %s", app.id)
|
||||
app_was_created.send(app, account=account)
|
||||
except Exception:
|
||||
failed_app_ids.append(app_id)
|
||||
click.echo(click.style(f"Failed to fix missing site for app {app_id}", fg="red"))
|
||||
logger.exception("Failed to fix app related site missing issue, app_id: %s", app_id)
|
||||
continue
|
||||
|
||||
if not processed_count:
|
||||
break
|
||||
|
||||
click.echo(click.style("Fix for missing app-related sites completed successfully!", fg="green"))
|
||||
466
api/commands/vector.py
Normal file
466
api/commands/vector.py
Normal file
@@ -0,0 +1,466 @@
|
||||
import json
|
||||
|
||||
import click
|
||||
from flask import current_app
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||
from core.rag.models.document import ChildDocument, Document
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.model import App, AppAnnotationSetting, MessageAnnotation
|
||||
|
||||
|
||||
@click.command("vdb-migrate", help="Migrate vector db.")
|
||||
@click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.")
|
||||
def vdb_migrate(scope: str):
|
||||
if scope in {"knowledge", "all"}:
|
||||
migrate_knowledge_vector_database()
|
||||
if scope in {"annotation", "all"}:
|
||||
migrate_annotation_vector_database()
|
||||
|
||||
|
||||
def migrate_annotation_vector_database():
|
||||
"""
|
||||
Migrate annotation datas to target vector database .
|
||||
"""
|
||||
click.echo(click.style("Starting annotation data migration.", fg="green"))
|
||||
create_count = 0
|
||||
skipped_count = 0
|
||||
total_count = 0
|
||||
page = 1
|
||||
while True:
|
||||
try:
|
||||
# get apps info
|
||||
per_page = 50
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
apps = (
|
||||
session.query(App)
|
||||
.where(App.status == "normal")
|
||||
.order_by(App.created_at.desc())
|
||||
.limit(per_page)
|
||||
.offset((page - 1) * per_page)
|
||||
.all()
|
||||
)
|
||||
if not apps:
|
||||
break
|
||||
except SQLAlchemyError:
|
||||
raise
|
||||
|
||||
page += 1
|
||||
for app in apps:
|
||||
total_count = total_count + 1
|
||||
click.echo(
|
||||
f"Processing the {total_count} app {app.id}. " + f"{create_count} created, {skipped_count} skipped."
|
||||
)
|
||||
try:
|
||||
click.echo(f"Creating app annotation index: {app.id}")
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
app_annotation_setting = (
|
||||
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
|
||||
)
|
||||
|
||||
if not app_annotation_setting:
|
||||
skipped_count = skipped_count + 1
|
||||
click.echo(f"App annotation setting disabled: {app.id}")
|
||||
continue
|
||||
# get dataset_collection_binding info
|
||||
dataset_collection_binding = (
|
||||
session.query(DatasetCollectionBinding)
|
||||
.where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
|
||||
.first()
|
||||
)
|
||||
if not dataset_collection_binding:
|
||||
click.echo(f"App annotation collection binding not found: {app.id}")
|
||||
continue
|
||||
annotations = session.scalars(
|
||||
select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
|
||||
).all()
|
||||
dataset = Dataset(
|
||||
id=app.id,
|
||||
tenant_id=app.tenant_id,
|
||||
indexing_technique="high_quality",
|
||||
embedding_model_provider=dataset_collection_binding.provider_name,
|
||||
embedding_model=dataset_collection_binding.model_name,
|
||||
collection_binding_id=dataset_collection_binding.id,
|
||||
)
|
||||
documents = []
|
||||
if annotations:
|
||||
for annotation in annotations:
|
||||
document = Document(
|
||||
page_content=annotation.question_text,
|
||||
metadata={"annotation_id": annotation.id, "app_id": app.id, "doc_id": annotation.id},
|
||||
)
|
||||
documents.append(document)
|
||||
|
||||
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
|
||||
click.echo(f"Migrating annotations for app: {app.id}.")
|
||||
|
||||
try:
|
||||
vector.delete()
|
||||
click.echo(click.style(f"Deleted vector index for app {app.id}.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Failed to delete vector index for app {app.id}.", fg="red"))
|
||||
raise e
|
||||
if documents:
|
||||
try:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Creating vector index with {len(documents)} annotations for app {app.id}.",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
vector.create(documents)
|
||||
click.echo(click.style(f"Created vector index for app {app.id}.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Failed to created vector index for app {app.id}.", fg="red"))
|
||||
raise e
|
||||
click.echo(f"Successfully migrated app annotation {app.id}.")
|
||||
create_count += 1
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style(f"Error creating app annotation index: {e.__class__.__name__} {str(e)}", fg="red")
|
||||
)
|
||||
continue
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Migration complete. Created {create_count} app annotation indexes. Skipped {skipped_count} apps.",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def migrate_knowledge_vector_database():
|
||||
"""
|
||||
Migrate vector database datas to target vector database .
|
||||
"""
|
||||
click.echo(click.style("Starting vector database migration.", fg="green"))
|
||||
create_count = 0
|
||||
skipped_count = 0
|
||||
total_count = 0
|
||||
vector_type = dify_config.VECTOR_STORE
|
||||
upper_collection_vector_types = {
|
||||
VectorType.MILVUS,
|
||||
VectorType.PGVECTOR,
|
||||
VectorType.VASTBASE,
|
||||
VectorType.RELYT,
|
||||
VectorType.WEAVIATE,
|
||||
VectorType.ORACLE,
|
||||
VectorType.ELASTICSEARCH,
|
||||
VectorType.OPENGAUSS,
|
||||
VectorType.TABLESTORE,
|
||||
VectorType.MATRIXONE,
|
||||
}
|
||||
lower_collection_vector_types = {
|
||||
VectorType.ANALYTICDB,
|
||||
VectorType.CHROMA,
|
||||
VectorType.MYSCALE,
|
||||
VectorType.PGVECTO_RS,
|
||||
VectorType.TIDB_VECTOR,
|
||||
VectorType.OPENSEARCH,
|
||||
VectorType.TENCENT,
|
||||
VectorType.BAIDU,
|
||||
VectorType.VIKINGDB,
|
||||
VectorType.UPSTASH,
|
||||
VectorType.COUCHBASE,
|
||||
VectorType.OCEANBASE,
|
||||
}
|
||||
page = 1
|
||||
while True:
|
||||
try:
|
||||
stmt = (
|
||||
select(Dataset).where(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc())
|
||||
)
|
||||
|
||||
datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
|
||||
if not datasets.items:
|
||||
break
|
||||
except SQLAlchemyError:
|
||||
raise
|
||||
|
||||
page += 1
|
||||
for dataset in datasets:
|
||||
total_count = total_count + 1
|
||||
click.echo(
|
||||
f"Processing the {total_count} dataset {dataset.id}. {create_count} created, {skipped_count} skipped."
|
||||
)
|
||||
try:
|
||||
click.echo(f"Creating dataset vector database index: {dataset.id}")
|
||||
if dataset.index_struct_dict:
|
||||
if dataset.index_struct_dict["type"] == vector_type:
|
||||
skipped_count = skipped_count + 1
|
||||
continue
|
||||
collection_name = ""
|
||||
dataset_id = dataset.id
|
||||
if vector_type in upper_collection_vector_types:
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
elif vector_type == VectorType.QDRANT:
|
||||
if dataset.collection_binding_id:
|
||||
dataset_collection_binding = (
|
||||
db.session.query(DatasetCollectionBinding)
|
||||
.where(DatasetCollectionBinding.id == dataset.collection_binding_id)
|
||||
.one_or_none()
|
||||
)
|
||||
if dataset_collection_binding:
|
||||
collection_name = dataset_collection_binding.collection_name
|
||||
else:
|
||||
raise ValueError("Dataset Collection Binding not found")
|
||||
else:
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
|
||||
elif vector_type in lower_collection_vector_types:
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
||||
else:
|
||||
raise ValueError(f"Vector store {vector_type} is not supported.")
|
||||
|
||||
index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
vector = Vector(dataset)
|
||||
click.echo(f"Migrating dataset {dataset.id}.")
|
||||
|
||||
try:
|
||||
vector.delete()
|
||||
click.echo(
|
||||
click.style(f"Deleted vector index {collection_name} for dataset {dataset.id}.", fg="green")
|
||||
)
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Failed to delete vector index {collection_name} for dataset {dataset.id}.", fg="red"
|
||||
)
|
||||
)
|
||||
raise e
|
||||
|
||||
dataset_documents = db.session.scalars(
|
||||
select(DatasetDocument).where(
|
||||
DatasetDocument.dataset_id == dataset.id,
|
||||
DatasetDocument.indexing_status == "completed",
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
).all()
|
||||
|
||||
documents = []
|
||||
segments_count = 0
|
||||
for dataset_document in dataset_documents:
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
).all()
|
||||
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
if dataset_document.doc_form == "hierarchical_model":
|
||||
child_chunks = segment.get_child_chunks()
|
||||
if child_chunks:
|
||||
child_documents = []
|
||||
for child_chunk in child_chunks:
|
||||
child_document = ChildDocument(
|
||||
page_content=child_chunk.content,
|
||||
metadata={
|
||||
"doc_id": child_chunk.index_node_id,
|
||||
"doc_hash": child_chunk.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
child_documents.append(child_document)
|
||||
document.children = child_documents
|
||||
|
||||
documents.append(document)
|
||||
segments_count = segments_count + 1
|
||||
|
||||
if documents:
|
||||
try:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Creating vector index with {len(documents)} documents of {segments_count}"
|
||||
f" segments for dataset {dataset.id}.",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
all_child_documents = []
|
||||
for doc in documents:
|
||||
if doc.children:
|
||||
all_child_documents.extend(doc.children)
|
||||
vector.create(documents)
|
||||
if all_child_documents:
|
||||
vector.create(all_child_documents)
|
||||
click.echo(click.style(f"Created vector index for dataset {dataset.id}.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Failed to created vector index for dataset {dataset.id}.", fg="red"))
|
||||
raise e
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
click.echo(f"Successfully migrated dataset {dataset.id}.")
|
||||
create_count += 1
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
click.echo(click.style(f"Error creating dataset index: {e.__class__.__name__} {str(e)}", fg="red"))
|
||||
continue
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Migration complete. Created {create_count} dataset indexes. Skipped {skipped_count} datasets.", fg="green"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@click.command("add-qdrant-index", help="Add Qdrant index.")
|
||||
@click.option("--field", default="metadata.doc_id", prompt=False, help="Index field , default is metadata.doc_id.")
|
||||
def add_qdrant_index(field: str):
|
||||
click.echo(click.style("Starting Qdrant index creation.", fg="green"))
|
||||
|
||||
create_count = 0
|
||||
|
||||
try:
|
||||
bindings = db.session.query(DatasetCollectionBinding).all()
|
||||
if not bindings:
|
||||
click.echo(click.style("No dataset collection bindings found.", fg="red"))
|
||||
return
|
||||
import qdrant_client
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
from qdrant_client.http.models import PayloadSchemaType
|
||||
|
||||
from core.rag.datasource.vdb.qdrant.qdrant_vector import PathQdrantParams, QdrantConfig
|
||||
|
||||
for binding in bindings:
|
||||
if dify_config.QDRANT_URL is None:
|
||||
raise ValueError("Qdrant URL is required.")
|
||||
qdrant_config = QdrantConfig(
|
||||
endpoint=dify_config.QDRANT_URL,
|
||||
api_key=dify_config.QDRANT_API_KEY,
|
||||
root_path=current_app.root_path,
|
||||
timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
|
||||
grpc_port=dify_config.QDRANT_GRPC_PORT,
|
||||
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,
|
||||
)
|
||||
try:
|
||||
params = qdrant_config.to_qdrant_params()
|
||||
# Check the type before using
|
||||
if isinstance(params, PathQdrantParams):
|
||||
# PathQdrantParams case
|
||||
client = qdrant_client.QdrantClient(path=params.path)
|
||||
else:
|
||||
# UrlQdrantParams case - params is UrlQdrantParams
|
||||
client = qdrant_client.QdrantClient(
|
||||
url=params.url,
|
||||
api_key=params.api_key,
|
||||
timeout=int(params.timeout),
|
||||
verify=params.verify,
|
||||
grpc_port=params.grpc_port,
|
||||
prefer_grpc=params.prefer_grpc,
|
||||
)
|
||||
# create payload index
|
||||
client.create_payload_index(binding.collection_name, field, field_schema=PayloadSchemaType.KEYWORD)
|
||||
create_count += 1
|
||||
except UnexpectedResponse as e:
|
||||
# Collection does not exist, so return
|
||||
if e.status_code == 404:
|
||||
click.echo(click.style(f"Collection not found: {binding.collection_name}.", fg="red"))
|
||||
continue
|
||||
# Some other error occurred, so re-raise the exception
|
||||
else:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Failed to create Qdrant index for collection: {binding.collection_name}.", fg="red"
|
||||
)
|
||||
)
|
||||
|
||||
except Exception:
|
||||
click.echo(click.style("Failed to create Qdrant client.", fg="red"))
|
||||
|
||||
click.echo(click.style(f"Index creation complete. Created {create_count} collection indexes.", fg="green"))
|
||||
|
||||
|
||||
@click.command("old-metadata-migration", help="Old metadata migration.")
|
||||
def old_metadata_migration():
|
||||
"""
|
||||
Old metadata migration.
|
||||
"""
|
||||
click.echo(click.style("Starting old metadata migration.", fg="green"))
|
||||
|
||||
page = 1
|
||||
while True:
|
||||
try:
|
||||
stmt = (
|
||||
select(DatasetDocument)
|
||||
.where(DatasetDocument.doc_metadata.is_not(None))
|
||||
.order_by(DatasetDocument.created_at.desc())
|
||||
)
|
||||
documents = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
|
||||
except SQLAlchemyError:
|
||||
raise
|
||||
if not documents:
|
||||
break
|
||||
for document in documents:
|
||||
if document.doc_metadata:
|
||||
doc_metadata = document.doc_metadata
|
||||
for key in doc_metadata:
|
||||
for field in BuiltInField:
|
||||
if field.value == key:
|
||||
break
|
||||
else:
|
||||
dataset_metadata = (
|
||||
db.session.query(DatasetMetadata)
|
||||
.where(DatasetMetadata.dataset_id == document.dataset_id, DatasetMetadata.name == key)
|
||||
.first()
|
||||
)
|
||||
if not dataset_metadata:
|
||||
dataset_metadata = DatasetMetadata(
|
||||
tenant_id=document.tenant_id,
|
||||
dataset_id=document.dataset_id,
|
||||
name=key,
|
||||
type="string",
|
||||
created_by=document.created_by,
|
||||
)
|
||||
db.session.add(dataset_metadata)
|
||||
db.session.flush()
|
||||
dataset_metadata_binding = DatasetMetadataBinding(
|
||||
tenant_id=document.tenant_id,
|
||||
dataset_id=document.dataset_id,
|
||||
metadata_id=dataset_metadata.id,
|
||||
document_id=document.id,
|
||||
created_by=document.created_by,
|
||||
)
|
||||
db.session.add(dataset_metadata_binding)
|
||||
else:
|
||||
dataset_metadata_binding = (
|
||||
db.session.query(DatasetMetadataBinding) # type: ignore
|
||||
.where(
|
||||
DatasetMetadataBinding.dataset_id == document.dataset_id,
|
||||
DatasetMetadataBinding.document_id == document.id,
|
||||
DatasetMetadataBinding.metadata_id == dataset_metadata.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not dataset_metadata_binding:
|
||||
dataset_metadata_binding = DatasetMetadataBinding(
|
||||
tenant_id=document.tenant_id,
|
||||
dataset_id=document.dataset_id,
|
||||
metadata_id=dataset_metadata.id,
|
||||
document_id=document.id,
|
||||
created_by=document.created_by,
|
||||
)
|
||||
db.session.add(dataset_metadata_binding)
|
||||
db.session.commit()
|
||||
page += 1
|
||||
click.echo(click.style("Old metadata migration completed.", fg="green"))
|
||||
@@ -18,3 +18,7 @@ class EnterpriseFeatureConfig(BaseSettings):
|
||||
description="Allow customization of the enterprise logo.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
ENTERPRISE_REQUEST_TIMEOUT: int = Field(
|
||||
ge=1, description="Maximum timeout in seconds for enterprise requests", default=5
|
||||
)
|
||||
|
||||
@@ -807,7 +807,7 @@ class DatasetApiKeyApi(Resource):
|
||||
console_ns.abort(
|
||||
400,
|
||||
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
|
||||
code="max_keys_exceeded",
|
||||
custom="max_keys_exceeded",
|
||||
)
|
||||
|
||||
key = ApiToken.generate_api_key(self.token_prefix, 24)
|
||||
|
||||
@@ -10,7 +10,6 @@ from controllers.common.file_response import enforce_download_for_html
|
||||
from controllers.files import files_ns
|
||||
from core.tools.signature import verify_tool_file_signature
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from extensions.ext_database import db as global_db
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
@@ -57,7 +56,7 @@ class ToolFileApi(Resource):
|
||||
raise Forbidden("Invalid request.")
|
||||
|
||||
try:
|
||||
tool_file_manager = ToolFileManager(engine=global_db.engine)
|
||||
tool_file_manager = ToolFileManager()
|
||||
stream, tool_file = tool_file_manager.get_file_generator_by_tool_file_id(
|
||||
file_id,
|
||||
)
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import json
|
||||
from contextlib import ExitStack
|
||||
from typing import Self
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask import request, send_file
|
||||
from flask_restx import marshal
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from sqlalchemy import desc, select
|
||||
@@ -100,6 +101,15 @@ class DocumentListQuery(BaseModel):
|
||||
status: str | None = Field(default=None, description="Document status filter")
|
||||
|
||||
|
||||
DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100
|
||||
|
||||
|
||||
class DocumentBatchDownloadZipPayload(BaseModel):
|
||||
"""Request payload for bulk downloading uploaded documents as a ZIP archive."""
|
||||
|
||||
document_ids: list[UUID] = Field(..., min_length=1, max_length=DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS)
|
||||
|
||||
|
||||
register_enum_models(service_api_ns, RetrievalMethod)
|
||||
|
||||
register_schema_models(
|
||||
@@ -109,6 +119,7 @@ register_schema_models(
|
||||
DocumentTextCreatePayload,
|
||||
DocumentTextUpdate,
|
||||
DocumentListQuery,
|
||||
DocumentBatchDownloadZipPayload,
|
||||
Rule,
|
||||
PreProcessingRule,
|
||||
Segmentation,
|
||||
@@ -540,6 +551,46 @@ class DocumentListApi(DatasetApiResource):
|
||||
return response
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/download-zip")
|
||||
class DocumentBatchDownloadZipApi(DatasetApiResource):
|
||||
"""Download multiple uploaded-file documents as a single ZIP archive."""
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[DocumentBatchDownloadZipPayload.__name__])
|
||||
@service_api_ns.doc("download_documents_as_zip")
|
||||
@service_api_ns.doc(description="Download selected uploaded documents as a single ZIP archive")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "ZIP archive generated successfully",
|
||||
401: "Unauthorized - invalid API token",
|
||||
403: "Forbidden - insufficient permissions",
|
||||
404: "Document or dataset not found",
|
||||
}
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id):
|
||||
payload = DocumentBatchDownloadZipPayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
upload_files, download_name = DocumentService.prepare_document_batch_download_zip(
|
||||
dataset_id=str(dataset_id),
|
||||
document_ids=[str(document_id) for document_id in payload.document_ids],
|
||||
tenant_id=str(tenant_id),
|
||||
current_user=current_user,
|
||||
)
|
||||
|
||||
with ExitStack() as stack:
|
||||
zip_path = stack.enter_context(FileService.build_upload_files_zip_tempfile(upload_files=upload_files))
|
||||
response = send_file(
|
||||
zip_path,
|
||||
mimetype="application/zip",
|
||||
as_attachment=True,
|
||||
download_name=download_name,
|
||||
)
|
||||
cleanup = stack.pop_all()
|
||||
response.call_on_close(cleanup.close)
|
||||
return response
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status")
|
||||
class DocumentIndexingStatusApi(DatasetApiResource):
|
||||
@service_api_ns.doc("get_document_indexing_status")
|
||||
@@ -600,6 +651,35 @@ class DocumentIndexingStatusApi(DatasetApiResource):
|
||||
return data
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/download")
|
||||
class DocumentDownloadApi(DatasetApiResource):
|
||||
"""Return a signed download URL for a document's original uploaded file."""
|
||||
|
||||
@service_api_ns.doc("get_document_download_url")
|
||||
@service_api_ns.doc(description="Get a signed download URL for a document's original uploaded file")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Download URL generated successfully",
|
||||
401: "Unauthorized - invalid API token",
|
||||
403: "Forbidden - insufficient permissions",
|
||||
404: "Document or upload file not found",
|
||||
}
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def get(self, tenant_id, dataset_id, document_id):
|
||||
dataset = self.get_dataset(str(dataset_id), str(tenant_id))
|
||||
document = DocumentService.get_document(dataset.id, str(document_id))
|
||||
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
if document.tenant_id != str(tenant_id):
|
||||
raise Forbidden("No permission.")
|
||||
|
||||
return {"url": DocumentService.get_document_download_url(document)}
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
|
||||
class DocumentApi(DatasetApiResource):
|
||||
METADATA_CHOICES = {"all", "only", "without"}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
@@ -46,6 +47,22 @@ def wrap_metadata(metadata, **kwargs):
|
||||
return metadata
|
||||
|
||||
|
||||
def _seed_to_uuid4(seed: str) -> str:
|
||||
"""Derive a deterministic UUID4-formatted string from an arbitrary seed.
|
||||
|
||||
uuid4_to_uuid7 requires a valid UUID v4 string, but some Dify identifiers
|
||||
are not UUIDs (e.g. a workflow_run_id with a "-root" suffix appended to
|
||||
distinguish the root span from the trace). This helper hashes the seed
|
||||
with MD5 and patches the version/variant bits so the result satisfies the
|
||||
UUID v4 contract.
|
||||
"""
|
||||
raw = hashlib.md5(seed.encode()).digest()
|
||||
ba = bytearray(raw)
|
||||
ba[6] = (ba[6] & 0x0F) | 0x40 # version 4
|
||||
ba[8] = (ba[8] & 0x3F) | 0x80 # variant 1
|
||||
return str(uuid.UUID(bytes=bytes(ba)))
|
||||
|
||||
|
||||
def prepare_opik_uuid(user_datetime: datetime | None, user_uuid: str | None):
|
||||
"""Opik needs UUIDv7 while Dify uses UUIDv4 for identifier of most
|
||||
messages and objects. The type-hints of BaseTraceInfo indicates that
|
||||
@@ -95,60 +112,52 @@ class OpikDataTrace(BaseTraceInstance):
|
||||
self.generate_name_trace(trace_info)
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
dify_trace_id = trace_info.trace_id or trace_info.workflow_run_id
|
||||
opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id)
|
||||
workflow_metadata = wrap_metadata(
|
||||
trace_info.metadata, message_id=trace_info.message_id, workflow_app_log_id=trace_info.workflow_app_log_id
|
||||
)
|
||||
root_span_id = None
|
||||
|
||||
if trace_info.message_id:
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id
|
||||
opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id)
|
||||
|
||||
trace_data = {
|
||||
"id": opik_trace_id,
|
||||
"name": TraceTaskName.MESSAGE_TRACE,
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
"metadata": workflow_metadata,
|
||||
"input": wrap_dict("input", trace_info.workflow_run_inputs),
|
||||
"output": wrap_dict("output", trace_info.workflow_run_outputs),
|
||||
"thread_id": trace_info.conversation_id,
|
||||
"tags": ["message", "workflow"],
|
||||
"project_name": self.project,
|
||||
}
|
||||
self.add_trace(trace_data)
|
||||
|
||||
root_span_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_run_id)
|
||||
span_data = {
|
||||
"id": root_span_id,
|
||||
"parent_span_id": None,
|
||||
"trace_id": opik_trace_id,
|
||||
"name": TraceTaskName.WORKFLOW_TRACE,
|
||||
"input": wrap_dict("input", trace_info.workflow_run_inputs),
|
||||
"output": wrap_dict("output", trace_info.workflow_run_outputs),
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
"metadata": workflow_metadata,
|
||||
"tags": ["workflow"],
|
||||
"project_name": self.project,
|
||||
}
|
||||
self.add_span(span_data)
|
||||
trace_name = TraceTaskName.MESSAGE_TRACE
|
||||
trace_tags = ["message", "workflow"]
|
||||
root_span_seed = trace_info.workflow_run_id
|
||||
else:
|
||||
trace_data = {
|
||||
"id": opik_trace_id,
|
||||
"name": TraceTaskName.MESSAGE_TRACE,
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
"metadata": workflow_metadata,
|
||||
"input": wrap_dict("input", trace_info.workflow_run_inputs),
|
||||
"output": wrap_dict("output", trace_info.workflow_run_outputs),
|
||||
"thread_id": trace_info.conversation_id,
|
||||
"tags": ["workflow"],
|
||||
"project_name": self.project,
|
||||
}
|
||||
self.add_trace(trace_data)
|
||||
dify_trace_id = trace_info.trace_id or trace_info.workflow_run_id
|
||||
trace_name = TraceTaskName.WORKFLOW_TRACE
|
||||
trace_tags = ["workflow"]
|
||||
root_span_seed = _seed_to_uuid4(trace_info.workflow_run_id + "-root")
|
||||
|
||||
opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id)
|
||||
|
||||
trace_data = {
|
||||
"id": opik_trace_id,
|
||||
"name": trace_name,
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
"metadata": workflow_metadata,
|
||||
"input": wrap_dict("input", trace_info.workflow_run_inputs),
|
||||
"output": wrap_dict("output", trace_info.workflow_run_outputs),
|
||||
"thread_id": trace_info.conversation_id,
|
||||
"tags": trace_tags,
|
||||
"project_name": self.project,
|
||||
}
|
||||
self.add_trace(trace_data)
|
||||
|
||||
root_span_id = prepare_opik_uuid(trace_info.start_time, root_span_seed)
|
||||
span_data = {
|
||||
"id": root_span_id,
|
||||
"parent_span_id": None,
|
||||
"trace_id": opik_trace_id,
|
||||
"name": TraceTaskName.WORKFLOW_TRACE,
|
||||
"input": wrap_dict("input", trace_info.workflow_run_inputs),
|
||||
"output": wrap_dict("output", trace_info.workflow_run_outputs),
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
"metadata": workflow_metadata,
|
||||
"tags": ["workflow"],
|
||||
"project_name": self.project,
|
||||
}
|
||||
self.add_span(span_data)
|
||||
|
||||
# through workflow_run_id get all_nodes_execution using repository
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
@@ -231,15 +240,13 @@ class OpikDataTrace(BaseTraceInstance):
|
||||
else:
|
||||
run_type = "tool"
|
||||
|
||||
parent_span_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id
|
||||
|
||||
if not total_tokens:
|
||||
total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
|
||||
|
||||
span_data = {
|
||||
"trace_id": opik_trace_id,
|
||||
"id": prepare_opik_uuid(created_at, node_execution_id),
|
||||
"parent_span_id": prepare_opik_uuid(trace_info.start_time, parent_span_id),
|
||||
"parent_span_id": root_span_id,
|
||||
"name": node_name,
|
||||
"type": run_type,
|
||||
"start_time": created_at,
|
||||
|
||||
@@ -57,7 +57,7 @@ from core.rag.retrieval.template_prompts import (
|
||||
from core.tools.signature import sign_upload_file
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from dify_graph.file import File, FileTransferMethod, FileType
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
@@ -127,11 +127,12 @@ class DatasetRetrieval:
|
||||
metadata_filter_document_ids, metadata_condition = None, None
|
||||
|
||||
if request.metadata_filtering_mode != "disabled":
|
||||
# Convert workflow layer types to app_config layer types
|
||||
if not request.metadata_model_config:
|
||||
raise ValueError("metadata_model_config is required for this method")
|
||||
app_metadata_model_config = ModelConfig(provider="", name="", mode=LLMMode.CHAT, completion_params={})
|
||||
if request.metadata_filtering_mode == "automatic":
|
||||
if not request.metadata_model_config:
|
||||
raise ValueError("metadata_model_config is required for this method")
|
||||
|
||||
app_metadata_model_config = ModelConfig.model_validate(request.metadata_model_config.model_dump())
|
||||
app_metadata_model_config = ModelConfig.model_validate(request.metadata_model_config.model_dump())
|
||||
|
||||
app_metadata_filtering_conditions = None
|
||||
if request.metadata_filtering_conditions is not None:
|
||||
|
||||
@@ -10,28 +10,19 @@ from typing import Union
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db as global_db
|
||||
from dify_graph.file.models import ToolFile as ToolFilePydanticModel
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import MessageFile
|
||||
from models.tools import ToolFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
|
||||
class ToolFileManager:
|
||||
_engine: Engine
|
||||
|
||||
def __init__(self, engine: Engine | None = None):
|
||||
if engine is None:
|
||||
engine = global_db.engine
|
||||
self._engine = engine
|
||||
|
||||
@staticmethod
|
||||
def sign_file(tool_file_id: str, extension: str) -> str:
|
||||
"""
|
||||
@@ -89,7 +80,7 @@ class ToolFileManager:
|
||||
filepath = f"tools/{tenant_id}/{unique_filename}"
|
||||
storage.save(filepath, file_binary)
|
||||
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
tool_file = ToolFile(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
@@ -132,7 +123,7 @@ class ToolFileManager:
|
||||
filename = f"{unique_name}{extension}"
|
||||
filepath = f"tools/{tenant_id}/{filename}"
|
||||
storage.save(filepath, blob)
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
tool_file = ToolFile(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
@@ -157,7 +148,7 @@ class ToolFileManager:
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
tool_file: ToolFile | None = (
|
||||
session.query(ToolFile)
|
||||
.where(
|
||||
@@ -181,7 +172,7 @@ class ToolFileManager:
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
message_file: MessageFile | None = (
|
||||
session.query(MessageFile)
|
||||
.where(
|
||||
@@ -217,7 +208,9 @@ class ToolFileManager:
|
||||
|
||||
return blob, tool_file.mimetype
|
||||
|
||||
def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, ToolFile | None]:
|
||||
def get_file_generator_by_tool_file_id(
|
||||
self, tool_file_id: str
|
||||
) -> tuple[Generator | None, ToolFilePydanticModel | None]:
|
||||
"""
|
||||
get file binary
|
||||
|
||||
@@ -225,7 +218,7 @@ class ToolFileManager:
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
tool_file: ToolFile | None = (
|
||||
session.query(ToolFile)
|
||||
.where(
|
||||
@@ -239,7 +232,7 @@ class ToolFileManager:
|
||||
|
||||
stream = storage.load_stream(tool_file.file_key)
|
||||
|
||||
return stream, tool_file
|
||||
return stream, ToolFilePydanticModel.model_validate(tool_file)
|
||||
|
||||
|
||||
# init tool_file_parser
|
||||
|
||||
@@ -50,6 +50,7 @@ from dify_graph.nodes.template_transform.template_renderer import (
|
||||
CodeExecutorJinja2TemplateRenderer,
|
||||
)
|
||||
from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from dify_graph.nodes.tool.tool_node import ToolNode
|
||||
from dify_graph.variables.segments import StringSegment
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation
|
||||
@@ -250,6 +251,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
model_factory=self._llm_model_factory,
|
||||
model_instance=model_instance,
|
||||
memory=memory,
|
||||
http_client=self._http_request_http_client,
|
||||
)
|
||||
|
||||
if node_type == NodeType.DATASOURCE:
|
||||
@@ -292,6 +294,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
model_factory=self._llm_model_factory,
|
||||
model_instance=model_instance,
|
||||
memory=memory,
|
||||
http_client=self._http_request_http_client,
|
||||
)
|
||||
|
||||
if node_type == NodeType.PARAMETER_EXTRACTOR:
|
||||
@@ -308,6 +311,15 @@ class DifyNodeFactory(NodeFactory):
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
if node_type == NodeType.TOOL:
|
||||
return ToolNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
tool_file_manager_factory=self._http_request_tool_file_manager_factory(),
|
||||
)
|
||||
|
||||
return node_class(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
@@ -43,6 +44,24 @@ class FileUploadConfig(BaseModel):
|
||||
number_limits: int = 0
|
||||
|
||||
|
||||
class ToolFile(BaseModel):
|
||||
id: UUID = Field(default_factory=uuid4, description="Unique identifier for the file")
|
||||
user_id: UUID = Field(..., description="ID of the user who owns this file")
|
||||
tenant_id: UUID = Field(..., description="ID of the tenant/organization")
|
||||
conversation_id: UUID | None = Field(None, description="ID of the associated conversation")
|
||||
file_key: str = Field(..., max_length=255, description="Storage key for the file")
|
||||
mimetype: str = Field(..., max_length=255, description="MIME type of the file")
|
||||
original_url: str | None = Field(
|
||||
None, max_length=2048, description="Original URL if file was fetched from external source"
|
||||
)
|
||||
name: str = Field(default="", max_length=255, description="Display name of the file")
|
||||
size: int = Field(default=-1, ge=-1, description="File size in bytes (-1 if unknown)")
|
||||
|
||||
class Config:
|
||||
from_attributes = True # Enable ORM mode for SQLAlchemy compatibility
|
||||
populate_by_name = True
|
||||
|
||||
|
||||
class File(BaseModel):
|
||||
# NOTE: dify_model_identity is a special identifier used to distinguish between
|
||||
# new and old data formats during serialization and deserialization.
|
||||
|
||||
@@ -14,7 +14,6 @@ from dify_graph.model_runtime.utils.encoders import jsonable_encoder
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base import LLMUsageTrackingMixin
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
|
||||
from dify_graph.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest, RAGRetrievalProtocol, Source
|
||||
from dify_graph.variables import (
|
||||
ArrayFileSegment,
|
||||
@@ -47,8 +46,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
# Output variable for file
|
||||
_file_outputs: list["File"]
|
||||
|
||||
_llm_file_saver: LLMFileSaver
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
@@ -56,8 +53,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
rag_retrieval: RAGRetrievalProtocol,
|
||||
*,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
@@ -69,14 +64,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
self._file_outputs = []
|
||||
self._rag_retrieval = rag_retrieval
|
||||
|
||||
if llm_file_saver is None:
|
||||
dify_ctx = self.require_dify_context()
|
||||
llm_file_saver = FileSaverImpl(
|
||||
user_id=dify_ctx.user_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
@classmethod
|
||||
def version(cls):
|
||||
return "1"
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
import mimetypes
|
||||
import typing as tp
|
||||
|
||||
from sqlalchemy import Engine
|
||||
|
||||
from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE
|
||||
from core.helper import ssrf_proxy
|
||||
from core.tools.signature import sign_tool_file
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from dify_graph.file import File, FileTransferMethod, FileType
|
||||
from extensions.ext_database import db as global_db
|
||||
from dify_graph.nodes.protocols import HttpClientProtocol
|
||||
|
||||
|
||||
class LLMFileSaver(tp.Protocol):
|
||||
@@ -59,30 +56,20 @@ class LLMFileSaver(tp.Protocol):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
EngineFactory: tp.TypeAlias = tp.Callable[[], Engine]
|
||||
|
||||
|
||||
class FileSaverImpl(LLMFileSaver):
|
||||
_engine_factory: EngineFactory
|
||||
_tenant_id: str
|
||||
_user_id: str
|
||||
|
||||
def __init__(self, user_id: str, tenant_id: str, engine_factory: EngineFactory | None = None):
|
||||
if engine_factory is None:
|
||||
|
||||
def _factory():
|
||||
return global_db.engine
|
||||
|
||||
engine_factory = _factory
|
||||
self._engine_factory = engine_factory
|
||||
def __init__(self, user_id: str, tenant_id: str, http_client: HttpClientProtocol):
|
||||
self._user_id = user_id
|
||||
self._tenant_id = tenant_id
|
||||
self._http_client = http_client
|
||||
|
||||
def _get_tool_file_manager(self):
|
||||
return ToolFileManager(engine=self._engine_factory())
|
||||
return ToolFileManager()
|
||||
|
||||
def save_remote_url(self, url: str, file_type: FileType) -> File:
|
||||
http_response = ssrf_proxy.get(url)
|
||||
http_response = self._http_client.get(url)
|
||||
http_response.raise_for_status()
|
||||
data = http_response.content
|
||||
mime_type_from_header = http_response.headers.get("Content-Type")
|
||||
|
||||
@@ -64,6 +64,7 @@ from dify_graph.nodes.base.entities import VariableSelector
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from dify_graph.nodes.protocols import HttpClientProtocol
|
||||
from dify_graph.runtime import VariablePool
|
||||
from dify_graph.variables import (
|
||||
ArrayFileSegment,
|
||||
@@ -127,6 +128,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
credentials_provider: CredentialsProvider,
|
||||
model_factory: ModelFactory,
|
||||
model_instance: ModelInstance,
|
||||
http_client: HttpClientProtocol,
|
||||
memory: PromptMessageMemory | None = None,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
):
|
||||
@@ -149,6 +151,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
llm_file_saver = FileSaverImpl(
|
||||
user_id=dify_ctx.user_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
http_client=http_client,
|
||||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Protocol
|
||||
|
||||
import httpx
|
||||
|
||||
from dify_graph.file import File
|
||||
from dify_graph.file.models import ToolFile
|
||||
|
||||
|
||||
class HttpClientProtocol(Protocol):
|
||||
@@ -40,3 +42,5 @@ class ToolFileManagerProtocol(Protocol):
|
||||
mimetype: str,
|
||||
filename: str | None = None,
|
||||
) -> Any: ...
|
||||
|
||||
def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, ToolFile | None]: ...
|
||||
|
||||
@@ -28,6 +28,7 @@ from dify_graph.nodes.llm import (
|
||||
)
|
||||
from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from dify_graph.nodes.protocols import HttpClientProtocol
|
||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
|
||||
from .entities import QuestionClassifierNodeData
|
||||
@@ -68,6 +69,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
credentials_provider: "CredentialsProvider",
|
||||
model_factory: "ModelFactory",
|
||||
model_instance: ModelInstance,
|
||||
http_client: HttpClientProtocol,
|
||||
memory: PromptMessageMemory | None = None,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
):
|
||||
@@ -90,6 +92,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
llm_file_saver = FileSaverImpl(
|
||||
user_id=dify_ctx.user_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
http_client=http_client,
|
||||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
@@ -21,11 +18,10 @@ from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from dify_graph.nodes.protocols import ToolFileManagerProtocol
|
||||
from dify_graph.variables.segments import ArrayAnySegment, ArrayFileSegment
|
||||
from dify_graph.variables.variables import ArrayAnyVariable
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models import ToolFile
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
from .entities import ToolNodeData
|
||||
@@ -36,7 +32,8 @@ from .exc import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dify_graph.runtime import VariablePool
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
|
||||
|
||||
class ToolNode(Node[ToolNodeData]):
|
||||
@@ -46,6 +43,23 @@ class ToolNode(Node[ToolNodeData]):
|
||||
|
||||
node_type = NodeType.TOOL
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
tool_file_manager_factory: ToolFileManagerProtocol,
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self._tool_file_manager_factory = tool_file_manager_factory
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
@@ -271,11 +285,9 @@ class ToolNode(Node[ToolNodeData]):
|
||||
|
||||
tool_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
|
||||
_, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id)
|
||||
if not tool_file:
|
||||
raise ToolFileError(f"tool file {tool_file_id} not found")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
@@ -294,11 +306,9 @@ class ToolNode(Node[ToolNodeData]):
|
||||
assert message.meta
|
||||
|
||||
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileError(f"tool file {tool_file_id} not exists")
|
||||
_, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id)
|
||||
if not tool_file:
|
||||
raise ToolFileError(f"tool file {tool_file_id} not exists")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
|
||||
114
api/extensions/otel/celery_sqlcommenter.py
Normal file
114
api/extensions/otel/celery_sqlcommenter.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""
|
||||
Celery SQL comment context for OpenTelemetry SQLCommenter.
|
||||
|
||||
Injects Celery-specific metadata (framework, task_name, traceparent, celery_retries,
|
||||
routing_key) into SQL comments for queries executed by Celery workers. This improves
|
||||
trace-to-SQL correlation and debugging in production.
|
||||
|
||||
Uses the OpenTelemetry context key SQLCOMMENTER_ORM_TAGS_AND_VALUES, which is read
|
||||
by opentelemetry.instrumentation.sqlcommenter_utils._add_framework_tags() when the
|
||||
SQLAlchemy instrumentor appends comments to SQL statements.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from celery.signals import task_postrun, task_prerun
|
||||
from opentelemetry import context
|
||||
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_TRACE_PROPAGATOR = TraceContextTextMapPropagator()
|
||||
|
||||
_SQLCOMMENTER_CONTEXT_KEY = "SQLCOMMENTER_ORM_TAGS_AND_VALUES"
|
||||
_TOKEN_ATTR = "_dify_sqlcommenter_context_token"
|
||||
|
||||
|
||||
def _build_celery_sqlcommenter_tags(task: Any) -> dict[str, str | int]:
|
||||
"""Build SQL commenter tags from the current Celery task and OpenTelemetry context."""
|
||||
tags: dict[str, str | int] = {}
|
||||
|
||||
try:
|
||||
tags["framework"] = f"celery:{_get_celery_version()}"
|
||||
except Exception:
|
||||
tags["framework"] = "celery:unknown"
|
||||
|
||||
if task and getattr(task, "name", None):
|
||||
tags["task_name"] = str(task.name)
|
||||
|
||||
traceparent = _get_traceparent()
|
||||
if traceparent:
|
||||
tags["traceparent"] = traceparent
|
||||
|
||||
if task and hasattr(task, "request"):
|
||||
request = task.request
|
||||
retries = getattr(request, "retries", None)
|
||||
if retries is not None and retries > 0:
|
||||
tags["celery_retries"] = int(retries)
|
||||
|
||||
delivery_info = getattr(request, "delivery_info", None) or {}
|
||||
if isinstance(delivery_info, dict):
|
||||
routing_key = delivery_info.get("routing_key")
|
||||
if routing_key:
|
||||
tags["routing_key"] = str(routing_key)
|
||||
|
||||
return tags
|
||||
|
||||
|
||||
def _get_celery_version() -> str:
|
||||
import celery
|
||||
|
||||
return getattr(celery, "__version__", "unknown")
|
||||
|
||||
|
||||
def _get_traceparent() -> str | None:
|
||||
"""Extract traceparent from the current OpenTelemetry context."""
|
||||
carrier: dict[str, str] = {}
|
||||
_TRACE_PROPAGATOR.inject(carrier)
|
||||
return carrier.get("traceparent")
|
||||
|
||||
|
||||
def _on_task_prerun(*args: object, **kwargs: object) -> None:
|
||||
task = kwargs.get("task")
|
||||
if not task:
|
||||
return
|
||||
|
||||
tags = _build_celery_sqlcommenter_tags(task)
|
||||
if not tags:
|
||||
return
|
||||
|
||||
current = context.get_current()
|
||||
new_ctx = context.set_value(_SQLCOMMENTER_CONTEXT_KEY, tags, current)
|
||||
token = context.attach(new_ctx)
|
||||
setattr(task, _TOKEN_ATTR, token)
|
||||
|
||||
|
||||
def _on_task_postrun(*args: object, **kwargs: object) -> None:
|
||||
task = kwargs.get("task")
|
||||
if not task:
|
||||
return
|
||||
|
||||
token = getattr(task, _TOKEN_ATTR, None)
|
||||
if token is None:
|
||||
return
|
||||
|
||||
try:
|
||||
context.detach(token)
|
||||
except Exception:
|
||||
logger.debug("Failed to detach SQL commenter context", exc_info=True)
|
||||
finally:
|
||||
try:
|
||||
delattr(task, _TOKEN_ATTR)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
def setup_celery_sqlcommenter() -> None:
|
||||
"""
|
||||
Connect Celery task_prerun and task_postrun handlers to inject SQL comment
|
||||
context for worker queries. Call this from init_celery_worker after
|
||||
CeleryInstrumentor().instrument() so our handlers run after the OTEL
|
||||
instrumentor's and the trace context is already attached.
|
||||
"""
|
||||
task_prerun.connect(_on_task_prerun, weak=False)
|
||||
task_postrun.connect(_on_task_postrun, weak=False)
|
||||
@@ -67,11 +67,14 @@ def init_celery_worker(*args, **kwargs):
|
||||
from opentelemetry.metrics import get_meter_provider
|
||||
from opentelemetry.trace import get_tracer_provider
|
||||
|
||||
from extensions.otel.celery_sqlcommenter import setup_celery_sqlcommenter
|
||||
|
||||
tracer_provider = get_tracer_provider()
|
||||
metric_provider = get_meter_provider()
|
||||
if dify_config.DEBUG:
|
||||
logger.info("Initializing OpenTelemetry for Celery worker")
|
||||
CeleryInstrumentor(tracer_provider=tracer_provider, meter_provider=metric_provider).instrument()
|
||||
setup_celery_sqlcommenter()
|
||||
|
||||
|
||||
def is_instrument_flag_enabled() -> bool:
|
||||
|
||||
@@ -824,6 +824,7 @@ class DatasourceProviderService:
|
||||
"langgenius/firecrawl_datasource",
|
||||
"langgenius/notion_datasource",
|
||||
"langgenius/jina_datasource",
|
||||
"watercrawl/watercrawl_datasource",
|
||||
]:
|
||||
datasource_provider_id = DatasourceProviderID(f"{datasource.plugin_id}/{datasource.provider}")
|
||||
credentials = self.list_datasource_credentials(
|
||||
|
||||
@@ -3,6 +3,7 @@ import logging
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from configs import dify_config
|
||||
from services.enterprise.base import EnterprisePluginManagerRequest
|
||||
from services.errors.base import BaseServiceError
|
||||
|
||||
@@ -28,6 +29,11 @@ class CheckCredentialPolicyComplianceRequest(BaseModel):
|
||||
return data
|
||||
|
||||
|
||||
class PreUninstallPluginRequest(BaseModel):
|
||||
tenant_id: str
|
||||
plugin_unique_identifier: str
|
||||
|
||||
|
||||
class CredentialPolicyViolationError(BaseServiceError):
|
||||
pass
|
||||
|
||||
@@ -55,3 +61,21 @@ class PluginManagerService:
|
||||
body.dify_credential_id,
|
||||
ret.get("result", False),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def try_pre_uninstall_plugin(cls, body: PreUninstallPluginRequest):
|
||||
try:
|
||||
# the invocation must be synchronous.
|
||||
EnterprisePluginManagerRequest.send_request(
|
||||
"POST",
|
||||
"/pre-uninstall-plugin",
|
||||
json=body.model_dump(),
|
||||
raise_for_status=True,
|
||||
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"failed to perform pre uninstall plugin hook. tenant_id: %s, plugin_unique_identifier: %s",
|
||||
body.tenant_id,
|
||||
body.plugin_unique_identifier,
|
||||
)
|
||||
|
||||
@@ -32,6 +32,10 @@ from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.provider import Provider, ProviderCredential
|
||||
from models.provider_ids import GenericProviderID
|
||||
from services.enterprise.plugin_manager_service import (
|
||||
PluginManagerService,
|
||||
PreUninstallPluginRequest,
|
||||
)
|
||||
from services.errors.plugin import PluginInstallationForbiddenError
|
||||
from services.feature_service import FeatureService, PluginInstallationScope
|
||||
|
||||
@@ -519,6 +523,13 @@ class PluginService:
|
||||
if not plugin:
|
||||
return manager.uninstall(tenant_id, plugin_installation_id)
|
||||
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
PluginManagerService.try_pre_uninstall_plugin(
|
||||
PreUninstallPluginRequest(
|
||||
tenant_id=tenant_id,
|
||||
plugin_unique_identifier=plugin.plugin_unique_identifier,
|
||||
)
|
||||
)
|
||||
with Session(db.engine) as session, session.begin():
|
||||
plugin_id = plugin.plugin_id
|
||||
logger.info("Deleting credentials for plugin: %s", plugin_id)
|
||||
|
||||
@@ -124,7 +124,7 @@ class WebsiteService:
|
||||
if provider == "firecrawl":
|
||||
plugin_id = "langgenius/firecrawl_datasource"
|
||||
elif provider == "watercrawl":
|
||||
plugin_id = "langgenius/watercrawl_datasource"
|
||||
plugin_id = "watercrawl/watercrawl_datasource"
|
||||
elif provider == "jinareader":
|
||||
plugin_id = "langgenius/jina_datasource"
|
||||
else:
|
||||
|
||||
@@ -11,6 +11,7 @@ from dify_graph.enums import WorkflowNodeExecutionStatus
|
||||
from dify_graph.node_events import StreamCompletedEvent
|
||||
from dify_graph.nodes.llm.node import LLMNode
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from dify_graph.nodes.protocols import HttpClientProtocol
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from extensions.ext_database import db
|
||||
@@ -74,6 +75,7 @@ def init_llm_node(config: dict) -> LLMNode:
|
||||
credentials_provider=MagicMock(spec=CredentialsProvider),
|
||||
model_factory=MagicMock(spec=ModelFactory),
|
||||
model_instance=MagicMock(spec=ModelInstance),
|
||||
http_client=MagicMock(spec=HttpClientProtocol),
|
||||
)
|
||||
|
||||
return node
|
||||
|
||||
@@ -8,6 +8,7 @@ from core.workflow.node_factory import DifyNodeFactory
|
||||
from dify_graph.enums import WorkflowNodeExecutionStatus
|
||||
from dify_graph.graph import Graph
|
||||
from dify_graph.node_events import StreamCompletedEvent
|
||||
from dify_graph.nodes.protocols import ToolFileManagerProtocol
|
||||
from dify_graph.nodes.tool.tool_node import ToolNode
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
@@ -55,11 +56,14 @@ def init_tool_node(config: dict):
|
||||
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol)
|
||||
|
||||
node = ToolNode(
|
||||
id=str(uuid.uuid4()),
|
||||
config=config,
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
tool_file_manager_factory=tool_file_manager_factory,
|
||||
)
|
||||
return node
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from dify_graph.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest
|
||||
from models.dataset import Dataset, Document
|
||||
from services.account_service import AccountService, TenantService
|
||||
from tests.test_containers_integration_tests.helpers import generate_valid_password
|
||||
|
||||
|
||||
class TestGetAvailableDatasetsIntegration:
|
||||
@@ -22,7 +23,7 @@ class TestGetAvailableDatasetsIntegration:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -83,7 +84,7 @@ class TestGetAvailableDatasetsIntegration:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -136,7 +137,7 @@ class TestGetAvailableDatasetsIntegration:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -189,7 +190,7 @@ class TestGetAvailableDatasetsIntegration:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -252,7 +253,7 @@ class TestGetAvailableDatasetsIntegration:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -286,7 +287,7 @@ class TestGetAvailableDatasetsIntegration:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account1, name=fake.company())
|
||||
tenant1 = account1.current_tenant
|
||||
@@ -295,7 +296,7 @@ class TestGetAvailableDatasetsIntegration:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account2, name=fake.company())
|
||||
tenant2 = account2.current_tenant
|
||||
@@ -362,7 +363,7 @@ class TestGetAvailableDatasetsIntegration:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -384,7 +385,7 @@ class TestGetAvailableDatasetsIntegration:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -445,7 +446,7 @@ class TestKnowledgeRetrievalIntegration:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -513,7 +514,7 @@ class TestKnowledgeRetrievalIntegration:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -561,7 +562,7 @@ class TestKnowledgeRetrievalIntegration:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
@@ -1 +1,24 @@
|
||||
"""Helper utilities for integration tests."""
|
||||
|
||||
import re
|
||||
|
||||
|
||||
def generate_valid_password(fake, length: int = 12) -> str:
|
||||
"""Generate a password that always satisfies the project's password validation rules.
|
||||
|
||||
The password validation rule in ``api/libs/password.py`` requires passwords to
|
||||
contain **both letters and digits** with a minimum length of 8:
|
||||
|
||||
``^(?=.*[a-zA-Z])(?=.*\\d).{8,}$``
|
||||
|
||||
``Faker.password()`` does **not** guarantee that the generated password will
|
||||
contain both character types, which can cause intermittent test failures.
|
||||
|
||||
This helper re-generates until the result is valid (typically first attempt).
|
||||
"""
|
||||
for _ in range(100):
|
||||
pwd = fake.password(length=length)
|
||||
if re.search(r"[a-zA-Z]", pwd) and re.search(r"\d", pwd):
|
||||
return pwd
|
||||
# Fallback: should never be reached in practice
|
||||
return fake.password(length=max(length - 2, 6)) + "a1"
|
||||
|
||||
@@ -20,6 +20,7 @@ from services.errors.account import (
|
||||
TenantNotFoundError,
|
||||
)
|
||||
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
|
||||
from tests.test_containers_integration_tests.helpers import generate_valid_password
|
||||
|
||||
|
||||
class TestAccountService:
|
||||
@@ -53,7 +54,7 @@ class TestAccountService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
@@ -133,7 +134,7 @@ class TestAccountService:
|
||||
email=email,
|
||||
name=name,
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
|
||||
def test_create_account_email_in_freeze(
|
||||
@@ -145,7 +146,7 @@ class TestAccountService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = True
|
||||
@@ -169,7 +170,7 @@ class TestAccountService:
|
||||
"""
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
with pytest.raises(AccountPasswordError):
|
||||
AccountService.authenticate(email, password)
|
||||
|
||||
@@ -180,7 +181,7 @@ class TestAccountService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
@@ -208,8 +209,8 @@ class TestAccountService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
correct_password = fake.password(length=12)
|
||||
wrong_password = fake.password(length=12)
|
||||
correct_password = generate_valid_password(fake)
|
||||
wrong_password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
@@ -234,7 +235,7 @@ class TestAccountService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
new_password = fake.password(length=12)
|
||||
new_password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
@@ -267,7 +268,7 @@ class TestAccountService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
@@ -297,8 +298,8 @@ class TestAccountService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
old_password = fake.password(length=12)
|
||||
new_password = fake.password(length=12)
|
||||
old_password = generate_valid_password(fake)
|
||||
new_password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
@@ -327,9 +328,9 @@ class TestAccountService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
old_password = fake.password(length=12)
|
||||
wrong_password = fake.password(length=12)
|
||||
new_password = fake.password(length=12)
|
||||
old_password = generate_valid_password(fake)
|
||||
wrong_password = generate_valid_password(fake)
|
||||
new_password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
@@ -354,7 +355,7 @@ class TestAccountService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
old_password = fake.password(length=12)
|
||||
old_password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
@@ -378,7 +379,7 @@ class TestAccountService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies[
|
||||
@@ -412,7 +413,7 @@ class TestAccountService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies[
|
||||
@@ -437,7 +438,7 @@ class TestAccountService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies[
|
||||
@@ -535,7 +536,7 @@ class TestAccountService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
@@ -563,7 +564,7 @@ class TestAccountService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
updated_name = fake.name()
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
@@ -592,7 +593,7 @@ class TestAccountService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
@@ -615,7 +616,7 @@ class TestAccountService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
ip_address = fake.ipv4()
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
@@ -645,7 +646,7 @@ class TestAccountService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
ip_address = fake.ipv4()
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
@@ -684,7 +685,7 @@ class TestAccountService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
@@ -714,7 +715,7 @@ class TestAccountService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
@@ -747,7 +748,7 @@ class TestAccountService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
tenant_name = fake.company()
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
@@ -792,7 +793,7 @@ class TestAccountService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
@@ -825,7 +826,7 @@ class TestAccountService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
tenant_name = fake.company()
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
@@ -864,7 +865,7 @@ class TestAccountService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
@@ -892,7 +893,7 @@ class TestAccountService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
@@ -926,7 +927,7 @@ class TestAccountService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
tenant_name = fake.company()
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
@@ -957,7 +958,7 @@ class TestAccountService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
@@ -997,7 +998,7 @@ class TestAccountService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
@@ -1043,7 +1044,7 @@ class TestAccountService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
@@ -1080,7 +1081,7 @@ class TestAccountService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
@@ -1110,7 +1111,7 @@ class TestAccountService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
@@ -1139,7 +1140,7 @@ class TestAccountService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
wrong_code = fake.numerify(text="######")
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
@@ -1259,7 +1260,7 @@ class TestTenantService:
|
||||
tenant_name = fake.company()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies[
|
||||
"feature_service"
|
||||
@@ -1291,10 +1292,10 @@ class TestTenantService:
|
||||
tenant_name = fake.company()
|
||||
email1 = fake.email()
|
||||
name1 = fake.name()
|
||||
password1 = fake.password(length=12)
|
||||
password1 = generate_valid_password(fake)
|
||||
email2 = fake.email()
|
||||
name2 = fake.name()
|
||||
password2 = fake.password(length=12)
|
||||
password2 = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies[
|
||||
"feature_service"
|
||||
@@ -1332,7 +1333,7 @@ class TestTenantService:
|
||||
tenant_name = fake.company()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies[
|
||||
"feature_service"
|
||||
@@ -1364,7 +1365,7 @@ class TestTenantService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
tenant1_name = fake.company()
|
||||
tenant2_name = fake.company()
|
||||
# Setup mocks
|
||||
@@ -1403,7 +1404,7 @@ class TestTenantService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
tenant_name = fake.company()
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies[
|
||||
@@ -1441,7 +1442,7 @@ class TestTenantService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies[
|
||||
"feature_service"
|
||||
@@ -1466,7 +1467,7 @@ class TestTenantService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
tenant1_name = fake.company()
|
||||
tenant2_name = fake.company()
|
||||
# Setup mocks
|
||||
@@ -1507,7 +1508,7 @@ class TestTenantService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies[
|
||||
"feature_service"
|
||||
@@ -1534,7 +1535,7 @@ class TestTenantService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
tenant_name = fake.company()
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies[
|
||||
@@ -1562,10 +1563,10 @@ class TestTenantService:
|
||||
tenant_name = fake.company()
|
||||
owner_email = fake.email()
|
||||
owner_name = fake.name()
|
||||
owner_password = fake.password(length=12)
|
||||
owner_password = generate_valid_password(fake)
|
||||
admin_email = fake.email()
|
||||
admin_name = fake.name()
|
||||
admin_password = fake.password(length=12)
|
||||
admin_password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies[
|
||||
"feature_service"
|
||||
@@ -1631,7 +1632,7 @@ class TestTenantService:
|
||||
tenant_name = fake.company()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies[
|
||||
"feature_service"
|
||||
@@ -1664,10 +1665,10 @@ class TestTenantService:
|
||||
tenant_name = fake.company()
|
||||
owner_email = fake.email()
|
||||
owner_name = fake.name()
|
||||
owner_password = fake.password(length=12)
|
||||
owner_password = generate_valid_password(fake)
|
||||
member_email = fake.email()
|
||||
member_name = fake.name()
|
||||
member_password = fake.password(length=12)
|
||||
member_password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies[
|
||||
"feature_service"
|
||||
@@ -1705,7 +1706,7 @@ class TestTenantService:
|
||||
tenant_name = fake.company()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
invalid_action = "invalid_action_that_doesnt_exist"
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies[
|
||||
@@ -1738,7 +1739,7 @@ class TestTenantService:
|
||||
tenant_name = fake.company()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies[
|
||||
"feature_service"
|
||||
@@ -1770,10 +1771,10 @@ class TestTenantService:
|
||||
tenant_name = fake.company()
|
||||
owner_email = fake.email()
|
||||
owner_name = fake.name()
|
||||
owner_password = fake.password(length=12)
|
||||
owner_password = generate_valid_password(fake)
|
||||
member_email = fake.email()
|
||||
member_name = fake.name()
|
||||
member_password = fake.password(length=12)
|
||||
member_password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies[
|
||||
"feature_service"
|
||||
@@ -1829,7 +1830,7 @@ class TestTenantService:
|
||||
tenant_name = fake.company()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies[
|
||||
"feature_service"
|
||||
@@ -1861,10 +1862,10 @@ class TestTenantService:
|
||||
tenant_name = fake.company()
|
||||
owner_email = fake.email()
|
||||
owner_name = fake.name()
|
||||
owner_password = fake.password(length=12)
|
||||
owner_password = generate_valid_password(fake)
|
||||
non_member_email = fake.email()
|
||||
non_member_name = fake.name()
|
||||
non_member_password = fake.password(length=12)
|
||||
non_member_password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies[
|
||||
"feature_service"
|
||||
@@ -1900,10 +1901,10 @@ class TestTenantService:
|
||||
tenant_name = fake.company()
|
||||
owner_email = fake.email()
|
||||
owner_name = fake.name()
|
||||
owner_password = fake.password(length=12)
|
||||
owner_password = generate_valid_password(fake)
|
||||
member_email = fake.email()
|
||||
member_name = fake.name()
|
||||
member_password = fake.password(length=12)
|
||||
member_password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies[
|
||||
"feature_service"
|
||||
@@ -1949,10 +1950,10 @@ class TestTenantService:
|
||||
tenant_name = fake.company()
|
||||
owner_email = fake.email()
|
||||
owner_name = fake.name()
|
||||
owner_password = fake.password(length=12)
|
||||
owner_password = generate_valid_password(fake)
|
||||
member_email = fake.email()
|
||||
member_name = fake.name()
|
||||
member_password = fake.password(length=12)
|
||||
member_password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies[
|
||||
"feature_service"
|
||||
@@ -2006,10 +2007,10 @@ class TestTenantService:
|
||||
tenant_name = fake.company()
|
||||
owner_email = fake.email()
|
||||
owner_name = fake.name()
|
||||
owner_password = fake.password(length=12)
|
||||
owner_password = generate_valid_password(fake)
|
||||
member_email = fake.email()
|
||||
member_name = fake.name()
|
||||
member_password = fake.password(length=12)
|
||||
member_password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies[
|
||||
"feature_service"
|
||||
@@ -2071,7 +2072,7 @@ class TestTenantService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
workspace_name = fake.company()
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies[
|
||||
@@ -2110,7 +2111,7 @@ class TestTenantService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
existing_tenant_name = fake.company()
|
||||
new_workspace_name = fake.company()
|
||||
# Setup mocks
|
||||
@@ -2151,7 +2152,7 @@ class TestTenantService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
workspace_name = fake.company()
|
||||
# Setup mocks to disable workspace creation
|
||||
mock_external_service_dependencies[
|
||||
@@ -2178,13 +2179,13 @@ class TestTenantService:
|
||||
tenant_name = fake.company()
|
||||
owner_email = fake.email()
|
||||
owner_name = fake.name()
|
||||
owner_password = fake.password(length=12)
|
||||
owner_password = generate_valid_password(fake)
|
||||
admin_email = fake.email()
|
||||
admin_name = fake.name()
|
||||
admin_password = fake.password(length=12)
|
||||
admin_password = generate_valid_password(fake)
|
||||
normal_email = fake.email()
|
||||
normal_name = fake.name()
|
||||
normal_password = fake.password(length=12)
|
||||
normal_password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies[
|
||||
"feature_service"
|
||||
@@ -2244,13 +2245,13 @@ class TestTenantService:
|
||||
tenant_name = fake.company()
|
||||
owner_email = fake.email()
|
||||
owner_name = fake.name()
|
||||
owner_password = fake.password(length=12)
|
||||
owner_password = generate_valid_password(fake)
|
||||
operator_email = fake.email()
|
||||
operator_name = fake.name()
|
||||
operator_password = fake.password(length=12)
|
||||
operator_password = generate_valid_password(fake)
|
||||
normal_email = fake.email()
|
||||
normal_name = fake.name()
|
||||
normal_password = fake.password(length=12)
|
||||
normal_password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies[
|
||||
"feature_service"
|
||||
@@ -2351,7 +2352,7 @@ class TestRegisterService:
|
||||
fake = Faker()
|
||||
admin_email = fake.email()
|
||||
admin_name = fake.name()
|
||||
admin_password = fake.password(length=12)
|
||||
admin_password = generate_valid_password(fake)
|
||||
ip_address = fake.ipv4()
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
@@ -2399,7 +2400,7 @@ class TestRegisterService:
|
||||
fake = Faker()
|
||||
admin_email = fake.email()
|
||||
admin_name = fake.name()
|
||||
admin_password = fake.password(length=12)
|
||||
admin_password = generate_valid_password(fake)
|
||||
ip_address = fake.ipv4()
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
@@ -2440,7 +2441,7 @@ class TestRegisterService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
language = fake.random_element(elements=("en-US", "zh-CN"))
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
@@ -2531,7 +2532,7 @@ class TestRegisterService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
language = fake.random_element(elements=("en-US", "zh-CN"))
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
@@ -2576,7 +2577,7 @@ class TestRegisterService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
language = fake.random_element(elements=("en-US", "zh-CN"))
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
@@ -2614,7 +2615,7 @@ class TestRegisterService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
language = fake.random_element(elements=("en-US", "zh-CN"))
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
@@ -2653,7 +2654,7 @@ class TestRegisterService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
language = fake.random_element(elements=("en-US", "zh-CN"))
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
@@ -2690,7 +2691,7 @@ class TestRegisterService:
|
||||
tenant_name = fake.company()
|
||||
inviter_email = fake.email()
|
||||
inviter_name = fake.name()
|
||||
inviter_password = fake.password(length=12)
|
||||
inviter_password = generate_valid_password(fake)
|
||||
new_member_email = fake.email()
|
||||
language = fake.random_element(elements=("en-US", "zh-CN"))
|
||||
# Setup mocks
|
||||
@@ -2760,10 +2761,10 @@ class TestRegisterService:
|
||||
tenant_name = fake.company()
|
||||
inviter_email = fake.email()
|
||||
inviter_name = fake.name()
|
||||
inviter_password = fake.password(length=12)
|
||||
inviter_password = generate_valid_password(fake)
|
||||
existing_member_email = fake.email()
|
||||
existing_member_name = fake.name()
|
||||
existing_member_password = fake.password(length=12)
|
||||
existing_member_password = generate_valid_password(fake)
|
||||
language = fake.random_element(elements=("en-US", "zh-CN"))
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
@@ -2824,10 +2825,10 @@ class TestRegisterService:
|
||||
tenant_name = fake.company()
|
||||
inviter_email = fake.email()
|
||||
inviter_name = fake.name()
|
||||
inviter_password = fake.password(length=12)
|
||||
inviter_password = generate_valid_password(fake)
|
||||
existing_pending_member_email = fake.email()
|
||||
existing_pending_member_name = fake.name()
|
||||
existing_pending_member_password = fake.password(length=12)
|
||||
existing_pending_member_password = generate_valid_password(fake)
|
||||
language = fake.random_element(elements=("en-US", "zh-CN"))
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
@@ -2914,10 +2915,10 @@ class TestRegisterService:
|
||||
tenant_name = fake.company()
|
||||
inviter_email = fake.email()
|
||||
inviter_name = fake.name()
|
||||
inviter_password = fake.password(length=12)
|
||||
inviter_password = generate_valid_password(fake)
|
||||
already_in_tenant_email = fake.email()
|
||||
already_in_tenant_name = fake.name()
|
||||
already_in_tenant_password = fake.password(length=12)
|
||||
already_in_tenant_password = generate_valid_password(fake)
|
||||
language = fake.random_element(elements=("en-US", "zh-CN"))
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
@@ -2967,7 +2968,7 @@ class TestRegisterService:
|
||||
tenant_name = fake.company()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
@@ -3011,7 +3012,7 @@ class TestRegisterService:
|
||||
tenant_name = fake.company()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
@@ -3058,7 +3059,7 @@ class TestRegisterService:
|
||||
tenant_name = fake.company()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
@@ -3101,7 +3102,7 @@ class TestRegisterService:
|
||||
tenant_name = fake.company()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
@@ -3144,7 +3145,7 @@ class TestRegisterService:
|
||||
tenant_name = fake.company()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
@@ -3212,7 +3213,7 @@ class TestRegisterService:
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
invalid_tenant_id = fake.uuid4()
|
||||
token = fake.uuid4()
|
||||
# Setup mocks
|
||||
@@ -3263,7 +3264,7 @@ class TestRegisterService:
|
||||
tenant_name = fake.company()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
token = fake.uuid4()
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
@@ -3313,7 +3314,7 @@ class TestRegisterService:
|
||||
tenant_name = fake.company()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
token = fake.uuid4()
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
|
||||
@@ -11,6 +11,7 @@ from models.model import AppModelConfig, Conversation, EndUser, Message, Message
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.agent_service import AgentService
|
||||
from services.app_service import AppService
|
||||
from tests.test_containers_integration_tests.helpers import generate_valid_password
|
||||
|
||||
|
||||
class TestAgentService:
|
||||
@@ -111,7 +112,7 @@ class TestAgentService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
@@ -9,6 +9,7 @@ from models import Account
|
||||
from models.model import MessageAnnotation
|
||||
from services.annotation_service import AppAnnotationService
|
||||
from services.app_service import AppService
|
||||
from tests.test_containers_integration_tests.helpers import generate_valid_password
|
||||
|
||||
|
||||
class TestAnnotationService:
|
||||
@@ -78,7 +79,7 @@ class TestAnnotationService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
@@ -7,6 +7,7 @@ from sqlalchemy.orm import Session
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.api_based_extension_service import APIBasedExtensionService
|
||||
from tests.test_containers_integration_tests.helpers import generate_valid_password
|
||||
|
||||
|
||||
class TestAPIBasedExtensionService:
|
||||
@@ -55,7 +56,7 @@ class TestAPIBasedExtensionService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
@@ -9,6 +9,7 @@ from models.model import App, AppModelConfig
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.app_dsl_service import AppDslService, ImportMode, ImportStatus
|
||||
from services.app_service import AppService
|
||||
from tests.test_containers_integration_tests.helpers import generate_valid_password
|
||||
|
||||
|
||||
class TestAppDslService:
|
||||
@@ -89,7 +90,7 @@ class TestAppDslService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
@@ -10,6 +10,7 @@ from models.model import EndUser
|
||||
from models.workflow import Workflow
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError
|
||||
from tests.test_containers_integration_tests.helpers import generate_valid_password
|
||||
|
||||
|
||||
class TestAppGenerateService:
|
||||
@@ -147,7 +148,7 @@ class TestAppGenerateService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
@@ -8,6 +8,7 @@ from constants.model_template import default_app_templates
|
||||
from models import Account
|
||||
from models.model import App, Site
|
||||
from services.account_service import AccountService, TenantService
|
||||
from tests.test_containers_integration_tests.helpers import generate_valid_password
|
||||
|
||||
# Delay import of AppService to avoid circular dependency
|
||||
# from services.app_service import AppService
|
||||
@@ -56,7 +57,7 @@ class TestAppService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -112,7 +113,7 @@ class TestAppService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -155,7 +156,7 @@ class TestAppService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -203,7 +204,7 @@ class TestAppService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -259,7 +260,7 @@ class TestAppService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -334,7 +335,7 @@ class TestAppService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -404,7 +405,7 @@ class TestAppService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -473,7 +474,7 @@ class TestAppService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -526,7 +527,7 @@ class TestAppService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -585,7 +586,7 @@ class TestAppService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -645,7 +646,7 @@ class TestAppService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -705,7 +706,7 @@ class TestAppService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -756,7 +757,7 @@ class TestAppService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -808,7 +809,7 @@ class TestAppService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -868,7 +869,7 @@ class TestAppService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -907,7 +908,7 @@ class TestAppService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -947,7 +948,7 @@ class TestAppService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -997,7 +998,7 @@ class TestAppService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -1039,7 +1040,7 @@ class TestAppService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
@@ -13,6 +13,7 @@ from services.errors.message import (
|
||||
SuggestedQuestionsAfterAnswerDisabledError,
|
||||
)
|
||||
from services.message_service import MessageService
|
||||
from tests.test_containers_integration_tests.helpers import generate_valid_password
|
||||
|
||||
|
||||
class TestMessageService:
|
||||
@@ -95,7 +96,7 @@ class TestMessageService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -633,7 +634,7 @@ class TestMessageService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(other_account, name=fake.company())
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from models.model import EndUser, Message
|
||||
from models.web import SavedMessage
|
||||
from services.app_service import AppService
|
||||
from services.saved_message_service import SavedMessageService
|
||||
from tests.test_containers_integration_tests.helpers import generate_valid_password
|
||||
|
||||
|
||||
class TestSavedMessageService:
|
||||
@@ -64,7 +65,7 @@ class TestSavedMessageService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
@@ -10,6 +10,7 @@ from core.trigger.entities.entities import Subscription as TriggerSubscriptionEn
|
||||
from models.provider_ids import TriggerProviderID
|
||||
from models.trigger import TriggerSubscription
|
||||
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||
from tests.test_containers_integration_tests.helpers import generate_valid_password
|
||||
|
||||
|
||||
class TestTriggerProviderService:
|
||||
@@ -75,7 +76,7 @@ class TestTriggerProviderService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
@@ -12,6 +12,7 @@ from models.web import PinnedConversation
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.app_service import AppService
|
||||
from services.web_conversation_service import WebConversationService
|
||||
from tests.test_containers_integration_tests.helpers import generate_valid_password
|
||||
|
||||
|
||||
class TestWebConversationService:
|
||||
@@ -69,7 +70,7 @@ class TestWebConversationService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
@@ -12,6 +12,7 @@ from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAcco
|
||||
from models.model import App, Site
|
||||
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
|
||||
from services.webapp_auth_service import WebAppAuthService, WebAppAuthType
|
||||
from tests.test_containers_integration_tests.helpers import generate_valid_password
|
||||
|
||||
|
||||
class TestWebAppAuthService:
|
||||
@@ -109,7 +110,7 @@ class TestWebAppAuthService:
|
||||
tuple: (account, tenant, password) - Created account, tenant and password
|
||||
"""
|
||||
fake = Faker()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
|
||||
# Create account with password
|
||||
import uuid
|
||||
@@ -272,7 +273,7 @@ class TestWebAppAuthService:
|
||||
"""
|
||||
# Arrange: Create banned account
|
||||
fake = Faker()
|
||||
password = fake.password(length=12)
|
||||
password = generate_valid_password(fake)
|
||||
unique_email = f"test_{uuid.uuid4().hex[:8]}@example.com"
|
||||
|
||||
account = Account(
|
||||
|
||||
@@ -13,6 +13,7 @@ from models.trigger import AppTrigger, WorkflowWebhookTrigger
|
||||
from models.workflow import Workflow
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.trigger.webhook_service import WebhookService
|
||||
from tests.test_containers_integration_tests.helpers import generate_valid_password
|
||||
|
||||
|
||||
class TestWebhookService:
|
||||
@@ -60,7 +61,7 @@ class TestWebhookService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
@@ -15,6 +15,7 @@ from services.account_service import AccountService, TenantService
|
||||
# Delay import of AppService to avoid circular dependency
|
||||
# from services.app_service import AppService
|
||||
from services.workflow_app_service import WorkflowAppService
|
||||
from tests.test_containers_integration_tests.helpers import generate_valid_password
|
||||
|
||||
|
||||
class TestWorkflowAppService:
|
||||
@@ -72,7 +73,7 @@ class TestWorkflowAppService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -120,7 +121,7 @@ class TestWorkflowAppService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
@@ -15,6 +15,7 @@ from models.workflow import WorkflowRun
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.app_service import AppService
|
||||
from services.workflow_run_service import WorkflowRunService
|
||||
from tests.test_containers_integration_tests.helpers import generate_valid_password
|
||||
|
||||
|
||||
class TestWorkflowRunService:
|
||||
@@ -72,7 +73,7 @@ class TestWorkflowRunService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
@@ -13,6 +13,7 @@ from models.workflow import Workflow as WorkflowModel
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.app_service import AppService
|
||||
from services.tools.workflow_tools_manage_service import WorkflowToolManageService
|
||||
from tests.test_containers_integration_tests.helpers import generate_valid_password
|
||||
|
||||
|
||||
class TestWorkflowToolManageService:
|
||||
@@ -87,7 +88,7 @@ class TestWorkflowToolManageService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
@@ -15,6 +15,7 @@ from faker import Faker
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from services.account_service import AccountService, TenantService
|
||||
from tasks.clean_notion_document_task import clean_notion_document_task
|
||||
from tests.test_containers_integration_tests.helpers import generate_valid_password
|
||||
|
||||
|
||||
class TestCleanNotionDocumentTask:
|
||||
@@ -76,7 +77,7 @@ class TestCleanNotionDocumentTask:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -208,7 +209,7 @@ class TestCleanNotionDocumentTask:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -252,7 +253,7 @@ class TestCleanNotionDocumentTask:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -345,7 +346,7 @@ class TestCleanNotionDocumentTask:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -431,7 +432,7 @@ class TestCleanNotionDocumentTask:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -546,7 +547,7 @@ class TestCleanNotionDocumentTask:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -642,7 +643,7 @@ class TestCleanNotionDocumentTask:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -724,7 +725,7 @@ class TestCleanNotionDocumentTask:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -834,7 +835,7 @@ class TestCleanNotionDocumentTask:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -951,7 +952,7 @@ class TestCleanNotionDocumentTask:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
@@ -1054,7 +1055,7 @@ class TestCleanNotionDocumentTask:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
@@ -15,6 +15,7 @@ from faker import Faker
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from services.account_service import AccountService, TenantService
|
||||
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
|
||||
from tests.test_containers_integration_tests.helpers import generate_valid_password
|
||||
|
||||
|
||||
class TestDealDatasetVectorIndexTask:
|
||||
@@ -61,7 +62,7 @@ class TestDealDatasetVectorIndexTask:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
@@ -26,9 +26,9 @@ def test_absolute_mode_calls_from_time_range():
|
||||
end_before = datetime.datetime(2024, 2, 1, 0, 0, 0)
|
||||
|
||||
with (
|
||||
patch("commands.create_message_clean_policy", return_value=policy),
|
||||
patch("commands.MessagesCleanService.from_time_range", return_value=service) as mock_from_time_range,
|
||||
patch("commands.MessagesCleanService.from_days") as mock_from_days,
|
||||
patch("commands.retention.create_message_clean_policy", return_value=policy),
|
||||
patch("commands.retention.MessagesCleanService.from_time_range", return_value=service) as mock_from_time_range,
|
||||
patch("commands.retention.MessagesCleanService.from_days") as mock_from_days,
|
||||
):
|
||||
clean_expired_messages.callback(
|
||||
batch_size=200,
|
||||
@@ -55,9 +55,9 @@ def test_relative_mode_before_days_only_calls_from_days():
|
||||
service = _mock_service()
|
||||
|
||||
with (
|
||||
patch("commands.create_message_clean_policy", return_value=policy),
|
||||
patch("commands.MessagesCleanService.from_days", return_value=service) as mock_from_days,
|
||||
patch("commands.MessagesCleanService.from_time_range") as mock_from_time_range,
|
||||
patch("commands.retention.create_message_clean_policy", return_value=policy),
|
||||
patch("commands.retention.MessagesCleanService.from_days", return_value=service) as mock_from_days,
|
||||
patch("commands.retention.MessagesCleanService.from_time_range") as mock_from_time_range,
|
||||
):
|
||||
clean_expired_messages.callback(
|
||||
batch_size=500,
|
||||
@@ -84,10 +84,10 @@ def test_relative_mode_with_from_days_ago_calls_from_time_range():
|
||||
fixed_now = datetime.datetime(2024, 8, 20, 12, 0, 0)
|
||||
|
||||
with (
|
||||
patch("commands.create_message_clean_policy", return_value=policy),
|
||||
patch("commands.MessagesCleanService.from_time_range", return_value=service) as mock_from_time_range,
|
||||
patch("commands.MessagesCleanService.from_days") as mock_from_days,
|
||||
patch("commands.naive_utc_now", return_value=fixed_now),
|
||||
patch("commands.retention.create_message_clean_policy", return_value=policy),
|
||||
patch("commands.retention.MessagesCleanService.from_time_range", return_value=service) as mock_from_time_range,
|
||||
patch("commands.retention.MessagesCleanService.from_days") as mock_from_days,
|
||||
patch("commands.retention.naive_utc_now", return_value=fixed_now),
|
||||
):
|
||||
clean_expired_messages.callback(
|
||||
batch_size=1000,
|
||||
|
||||
@@ -4,6 +4,7 @@ import types
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import commands
|
||||
from commands import system as system_commands
|
||||
from libs.db_migration_lock import LockNotOwnedError, RedisError
|
||||
|
||||
HEARTBEAT_WAIT_TIMEOUT_SECONDS = 5.0
|
||||
@@ -24,11 +25,11 @@ def _invoke_upgrade_db() -> int:
|
||||
|
||||
|
||||
def test_upgrade_db_skips_when_lock_not_acquired(monkeypatch, capsys):
|
||||
monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 1234)
|
||||
monkeypatch.setattr(system_commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 1234)
|
||||
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = False
|
||||
commands.redis_client.lock.return_value = lock
|
||||
system_commands.redis_client.lock.return_value = lock
|
||||
|
||||
exit_code = _invoke_upgrade_db()
|
||||
captured = capsys.readouterr()
|
||||
@@ -36,18 +37,18 @@ def test_upgrade_db_skips_when_lock_not_acquired(monkeypatch, capsys):
|
||||
assert exit_code == 0
|
||||
assert "Database migration skipped" in captured.out
|
||||
|
||||
commands.redis_client.lock.assert_called_once_with(name="db_upgrade_lock", timeout=1234, thread_local=False)
|
||||
system_commands.redis_client.lock.assert_called_once_with(name="db_upgrade_lock", timeout=1234, thread_local=False)
|
||||
lock.acquire.assert_called_once_with(blocking=False)
|
||||
lock.release.assert_not_called()
|
||||
|
||||
|
||||
def test_upgrade_db_failure_not_masked_by_lock_release(monkeypatch, capsys):
|
||||
monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 321)
|
||||
monkeypatch.setattr(system_commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 321)
|
||||
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
lock.release.side_effect = LockNotOwnedError("simulated")
|
||||
commands.redis_client.lock.return_value = lock
|
||||
system_commands.redis_client.lock.return_value = lock
|
||||
|
||||
def _upgrade():
|
||||
raise RuntimeError("boom")
|
||||
@@ -60,18 +61,18 @@ def test_upgrade_db_failure_not_masked_by_lock_release(monkeypatch, capsys):
|
||||
assert exit_code == 1
|
||||
assert "Database migration failed: boom" in captured.out
|
||||
|
||||
commands.redis_client.lock.assert_called_once_with(name="db_upgrade_lock", timeout=321, thread_local=False)
|
||||
system_commands.redis_client.lock.assert_called_once_with(name="db_upgrade_lock", timeout=321, thread_local=False)
|
||||
lock.acquire.assert_called_once_with(blocking=False)
|
||||
lock.release.assert_called_once()
|
||||
|
||||
|
||||
def test_upgrade_db_success_ignores_lock_not_owned_on_release(monkeypatch, capsys):
|
||||
monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 999)
|
||||
monkeypatch.setattr(system_commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 999)
|
||||
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
lock.release.side_effect = LockNotOwnedError("simulated")
|
||||
commands.redis_client.lock.return_value = lock
|
||||
system_commands.redis_client.lock.return_value = lock
|
||||
|
||||
_install_fake_flask_migrate(monkeypatch, lambda: None)
|
||||
|
||||
@@ -81,7 +82,7 @@ def test_upgrade_db_success_ignores_lock_not_owned_on_release(monkeypatch, capsy
|
||||
assert exit_code == 0
|
||||
assert "Database migration successful!" in captured.out
|
||||
|
||||
commands.redis_client.lock.assert_called_once_with(name="db_upgrade_lock", timeout=999, thread_local=False)
|
||||
system_commands.redis_client.lock.assert_called_once_with(name="db_upgrade_lock", timeout=999, thread_local=False)
|
||||
lock.acquire.assert_called_once_with(blocking=False)
|
||||
lock.release.assert_called_once()
|
||||
|
||||
@@ -92,11 +93,11 @@ def test_upgrade_db_renews_lock_during_migration(monkeypatch, capsys):
|
||||
"""
|
||||
|
||||
# Use a small TTL so the heartbeat interval triggers quickly.
|
||||
monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 0.3)
|
||||
monkeypatch.setattr(system_commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 0.3)
|
||||
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
commands.redis_client.lock.return_value = lock
|
||||
system_commands.redis_client.lock.return_value = lock
|
||||
|
||||
renewed = threading.Event()
|
||||
|
||||
@@ -120,11 +121,11 @@ def test_upgrade_db_renews_lock_during_migration(monkeypatch, capsys):
|
||||
|
||||
def test_upgrade_db_ignores_reacquire_errors(monkeypatch, capsys):
|
||||
# Use a small TTL so heartbeat runs during the upgrade call.
|
||||
monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 0.3)
|
||||
monkeypatch.setattr(system_commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 0.3)
|
||||
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
commands.redis_client.lock.return_value = lock
|
||||
system_commands.redis_client.lock.return_value = lock
|
||||
|
||||
attempted = threading.Event()
|
||||
|
||||
|
||||
@@ -0,0 +1,92 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from controllers.console.app import annotation as annotation_module
|
||||
|
||||
|
||||
def test_annotation_reply_payload_valid():
|
||||
"""Test AnnotationReplyPayload with valid data."""
|
||||
payload = annotation_module.AnnotationReplyPayload(
|
||||
score_threshold=0.5,
|
||||
embedding_provider_name="openai",
|
||||
embedding_model_name="text-embedding-3-small",
|
||||
)
|
||||
assert payload.score_threshold == 0.5
|
||||
assert payload.embedding_provider_name == "openai"
|
||||
assert payload.embedding_model_name == "text-embedding-3-small"
|
||||
|
||||
|
||||
def test_annotation_setting_update_payload_valid():
|
||||
"""Test AnnotationSettingUpdatePayload with valid data."""
|
||||
payload = annotation_module.AnnotationSettingUpdatePayload(
|
||||
score_threshold=0.75,
|
||||
)
|
||||
assert payload.score_threshold == 0.75
|
||||
|
||||
|
||||
def test_annotation_list_query_defaults():
|
||||
"""Test AnnotationListQuery with default parameters."""
|
||||
query = annotation_module.AnnotationListQuery()
|
||||
assert query.page == 1
|
||||
assert query.limit == 20
|
||||
assert query.keyword == ""
|
||||
|
||||
|
||||
def test_annotation_list_query_custom_page():
|
||||
"""Test AnnotationListQuery with custom page."""
|
||||
query = annotation_module.AnnotationListQuery(page=3, limit=50)
|
||||
assert query.page == 3
|
||||
assert query.limit == 50
|
||||
|
||||
|
||||
def test_annotation_list_query_with_keyword():
|
||||
"""Test AnnotationListQuery with keyword."""
|
||||
query = annotation_module.AnnotationListQuery(keyword="test")
|
||||
assert query.keyword == "test"
|
||||
|
||||
|
||||
def test_create_annotation_payload_with_message_id():
|
||||
"""Test CreateAnnotationPayload with message ID."""
|
||||
payload = annotation_module.CreateAnnotationPayload(
|
||||
message_id="550e8400-e29b-41d4-a716-446655440000",
|
||||
question="What is AI?",
|
||||
)
|
||||
assert payload.message_id == "550e8400-e29b-41d4-a716-446655440000"
|
||||
assert payload.question == "What is AI?"
|
||||
|
||||
|
||||
def test_create_annotation_payload_with_text():
|
||||
"""Test CreateAnnotationPayload with text content."""
|
||||
payload = annotation_module.CreateAnnotationPayload(
|
||||
question="What is ML?",
|
||||
answer="Machine learning is...",
|
||||
)
|
||||
assert payload.question == "What is ML?"
|
||||
assert payload.answer == "Machine learning is..."
|
||||
|
||||
|
||||
def test_update_annotation_payload():
|
||||
"""Test UpdateAnnotationPayload."""
|
||||
payload = annotation_module.UpdateAnnotationPayload(
|
||||
question="Updated question",
|
||||
answer="Updated answer",
|
||||
)
|
||||
assert payload.question == "Updated question"
|
||||
assert payload.answer == "Updated answer"
|
||||
|
||||
|
||||
def test_annotation_reply_status_query_enable():
|
||||
"""Test AnnotationReplyStatusQuery with enable action."""
|
||||
query = annotation_module.AnnotationReplyStatusQuery(action="enable")
|
||||
assert query.action == "enable"
|
||||
|
||||
|
||||
def test_annotation_reply_status_query_disable():
|
||||
"""Test AnnotationReplyStatusQuery with disable action."""
|
||||
query = annotation_module.AnnotationReplyStatusQuery(action="disable")
|
||||
assert query.action == "disable"
|
||||
|
||||
|
||||
def test_annotation_file_payload_valid():
|
||||
"""Test AnnotationFilePayload with valid message ID."""
|
||||
payload = annotation_module.AnnotationFilePayload(message_id="550e8400-e29b-41d4-a716-446655440000")
|
||||
assert payload.message_id == "550e8400-e29b-41d4-a716-446655440000"
|
||||
@@ -13,6 +13,9 @@ from pandas.errors import ParserError
|
||||
from werkzeug.datastructures import FileStorage
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console.wraps import annotation_import_concurrency_limit, annotation_import_rate_limit
|
||||
from services.annotation_service import AppAnnotationService
|
||||
from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task
|
||||
|
||||
|
||||
class TestAnnotationImportRateLimiting:
|
||||
@@ -33,8 +36,6 @@ class TestAnnotationImportRateLimiting:
|
||||
|
||||
def test_rate_limit_per_minute_enforced(self, mock_redis, mock_current_account):
|
||||
"""Test that per-minute rate limit is enforced."""
|
||||
from controllers.console.wraps import annotation_import_rate_limit
|
||||
|
||||
# Simulate exceeding per-minute limit
|
||||
mock_redis.zcard.side_effect = [
|
||||
dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE + 1, # Minute check
|
||||
@@ -54,7 +55,6 @@ class TestAnnotationImportRateLimiting:
|
||||
|
||||
def test_rate_limit_per_hour_enforced(self, mock_redis, mock_current_account):
|
||||
"""Test that per-hour rate limit is enforced."""
|
||||
from controllers.console.wraps import annotation_import_rate_limit
|
||||
|
||||
# Simulate exceeding per-hour limit
|
||||
mock_redis.zcard.side_effect = [
|
||||
@@ -74,7 +74,6 @@ class TestAnnotationImportRateLimiting:
|
||||
|
||||
def test_rate_limit_within_limits_passes(self, mock_redis, mock_current_account):
|
||||
"""Test that requests within limits are allowed."""
|
||||
from controllers.console.wraps import annotation_import_rate_limit
|
||||
|
||||
# Simulate being under both limits
|
||||
mock_redis.zcard.return_value = 2
|
||||
@@ -110,7 +109,6 @@ class TestAnnotationImportConcurrencyControl:
|
||||
|
||||
def test_concurrency_limit_enforced(self, mock_redis, mock_current_account):
|
||||
"""Test that concurrent task limit is enforced."""
|
||||
from controllers.console.wraps import annotation_import_concurrency_limit
|
||||
|
||||
# Simulate max concurrent tasks already running
|
||||
mock_redis.zcard.return_value = dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT
|
||||
@@ -127,7 +125,6 @@ class TestAnnotationImportConcurrencyControl:
|
||||
|
||||
def test_concurrency_within_limit_passes(self, mock_redis, mock_current_account):
|
||||
"""Test that requests within concurrency limits are allowed."""
|
||||
from controllers.console.wraps import annotation_import_concurrency_limit
|
||||
|
||||
# Simulate being under concurrent task limit
|
||||
mock_redis.zcard.return_value = 1
|
||||
@@ -142,7 +139,6 @@ class TestAnnotationImportConcurrencyControl:
|
||||
|
||||
def test_stale_jobs_are_cleaned_up(self, mock_redis, mock_current_account):
|
||||
"""Test that old/stale job entries are removed."""
|
||||
from controllers.console.wraps import annotation_import_concurrency_limit
|
||||
|
||||
mock_redis.zcard.return_value = 0
|
||||
|
||||
@@ -203,7 +199,6 @@ class TestAnnotationImportServiceValidation:
|
||||
|
||||
def test_max_records_limit_enforced(self, mock_app, mock_db_session):
|
||||
"""Test that files with too many records are rejected."""
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
# Create CSV with too many records
|
||||
max_records = dify_config.ANNOTATION_IMPORT_MAX_RECORDS
|
||||
@@ -229,7 +224,6 @@ class TestAnnotationImportServiceValidation:
|
||||
|
||||
def test_min_records_limit_enforced(self, mock_app, mock_db_session):
|
||||
"""Test that files with too few valid records are rejected."""
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
# Create CSV with only header (no data rows)
|
||||
csv_content = "question,answer\n"
|
||||
@@ -249,7 +243,6 @@ class TestAnnotationImportServiceValidation:
|
||||
|
||||
def test_invalid_csv_format_handled(self, mock_app, mock_db_session):
|
||||
"""Test that invalid CSV format is handled gracefully."""
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
# Any content is fine once we force ParserError
|
||||
csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff'
|
||||
@@ -270,7 +263,6 @@ class TestAnnotationImportServiceValidation:
|
||||
|
||||
def test_valid_import_succeeds(self, mock_app, mock_db_session):
|
||||
"""Test that valid import request succeeds."""
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
# Create valid CSV
|
||||
csv_content = "question,answer\nWhat is AI?,Artificial Intelligence\nWhat is ML?,Machine Learning\n"
|
||||
@@ -300,18 +292,10 @@ class TestAnnotationImportServiceValidation:
|
||||
class TestAnnotationImportTaskOptimization:
|
||||
"""Test optimizations in batch import task."""
|
||||
|
||||
def test_task_has_timeout_configured(self):
|
||||
"""Test that task has proper timeout configuration."""
|
||||
from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task
|
||||
|
||||
# Verify task configuration
|
||||
assert hasattr(batch_import_annotations_task, "time_limit")
|
||||
assert hasattr(batch_import_annotations_task, "soft_time_limit")
|
||||
|
||||
# Check timeout values are reasonable
|
||||
# Hard limit should be 6 minutes (360s)
|
||||
# Soft limit should be 5 minutes (300s)
|
||||
# Note: actual values depend on Celery configuration
|
||||
def test_task_is_registered_with_queue(self):
|
||||
"""Test that task is registered with the correct queue."""
|
||||
assert hasattr(batch_import_annotations_task, "apply_async")
|
||||
assert hasattr(batch_import_annotations_task, "delay")
|
||||
|
||||
|
||||
class TestConfigurationValues:
|
||||
|
||||
585
api/tests/unit_tests/controllers/console/app/test_app_apis.py
Normal file
585
api/tests/unit_tests/controllers/console/app/test_app_apis.py
Normal file
@@ -0,0 +1,585 @@
|
||||
"""
|
||||
Additional tests to improve coverage for low-coverage modules in controllers/console/app.
|
||||
Target: increase coverage for files with <75% coverage.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.console.app import (
|
||||
annotation as annotation_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
completion as completion_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
message as message_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
ops_trace as ops_trace_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
site as site_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
statistic as statistic_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
workflow_app_log as workflow_app_log_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
workflow_draft_variable as workflow_draft_variable_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
workflow_statistic as workflow_statistic_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
workflow_trigger as workflow_trigger_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
wraps as wraps_module,
|
||||
)
|
||||
from controllers.console.app.completion import ChatMessagePayload, CompletionMessagePayload
|
||||
from controllers.console.app.mcp_server import MCPServerCreatePayload, MCPServerUpdatePayload
|
||||
from controllers.console.app.ops_trace import TraceConfigPayload, TraceProviderQuery
|
||||
from controllers.console.app.site import AppSiteUpdatePayload
|
||||
from controllers.console.app.workflow import AdvancedChatWorkflowRunPayload, SyncDraftWorkflowPayload
|
||||
from controllers.console.app.workflow_app_log import WorkflowAppLogQuery
|
||||
from controllers.console.app.workflow_draft_variable import WorkflowDraftVariableUpdatePayload
|
||||
from controllers.console.app.workflow_statistic import WorkflowStatisticQuery
|
||||
from controllers.console.app.workflow_trigger import Parser, ParserEnable
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
class _ConnContext:
|
||||
def __init__(self, rows):
|
||||
self._rows = rows
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def execute(self, _query, _args):
|
||||
return self._rows
|
||||
|
||||
|
||||
# ========== Completion Tests ==========
|
||||
class TestCompletionEndpoints:
|
||||
"""Tests for completion API endpoints."""
|
||||
|
||||
def test_completion_create_payload(self):
|
||||
"""Test completion creation payload."""
|
||||
payload = CompletionMessagePayload(inputs={"prompt": "test"}, model_config={})
|
||||
assert payload.inputs == {"prompt": "test"}
|
||||
|
||||
def test_chat_message_payload_uuid_validation(self):
|
||||
payload = ChatMessagePayload(
|
||||
inputs={},
|
||||
model_config={},
|
||||
query="hi",
|
||||
conversation_id=str(uuid.uuid4()),
|
||||
parent_message_id=str(uuid.uuid4()),
|
||||
)
|
||||
assert payload.query == "hi"
|
||||
|
||||
def test_completion_api_success(self, app, monkeypatch):
|
||||
api = completion_module.CompletionMessageApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
class DummyAccount:
|
||||
pass
|
||||
|
||||
dummy_account = DummyAccount()
|
||||
|
||||
monkeypatch.setattr(completion_module, "current_user", dummy_account)
|
||||
monkeypatch.setattr(completion_module, "Account", DummyAccount)
|
||||
monkeypatch.setattr(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
lambda **_kwargs: {"text": "ok"},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
completion_module.helper,
|
||||
"compact_generate_response",
|
||||
lambda response: {"result": response},
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/",
|
||||
json={"inputs": {}, "model_config": {}, "query": "hi"},
|
||||
):
|
||||
resp = method(app_model=MagicMock(id="app-1"))
|
||||
|
||||
assert resp == {"result": {"text": "ok"}}
|
||||
|
||||
def test_completion_api_conversation_not_exists(self, app, monkeypatch):
|
||||
api = completion_module.CompletionMessageApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
class DummyAccount:
|
||||
pass
|
||||
|
||||
dummy_account = DummyAccount()
|
||||
|
||||
monkeypatch.setattr(completion_module, "current_user", dummy_account)
|
||||
monkeypatch.setattr(completion_module, "Account", DummyAccount)
|
||||
monkeypatch.setattr(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
lambda **_kwargs: (_ for _ in ()).throw(
|
||||
completion_module.services.errors.conversation.ConversationNotExistsError()
|
||||
),
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/",
|
||||
json={"inputs": {}, "model_config": {}, "query": "hi"},
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(app_model=MagicMock(id="app-1"))
|
||||
|
||||
def test_completion_api_provider_not_initialized(self, app, monkeypatch):
|
||||
api = completion_module.CompletionMessageApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
class DummyAccount:
|
||||
pass
|
||||
|
||||
dummy_account = DummyAccount()
|
||||
|
||||
monkeypatch.setattr(completion_module, "current_user", dummy_account)
|
||||
monkeypatch.setattr(completion_module, "Account", DummyAccount)
|
||||
monkeypatch.setattr(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
lambda **_kwargs: (_ for _ in ()).throw(completion_module.ProviderTokenNotInitError("x")),
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/",
|
||||
json={"inputs": {}, "model_config": {}, "query": "hi"},
|
||||
):
|
||||
with pytest.raises(completion_module.ProviderNotInitializeError):
|
||||
method(app_model=MagicMock(id="app-1"))
|
||||
|
||||
def test_completion_api_quota_exceeded(self, app, monkeypatch):
|
||||
api = completion_module.CompletionMessageApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
class DummyAccount:
|
||||
pass
|
||||
|
||||
dummy_account = DummyAccount()
|
||||
|
||||
monkeypatch.setattr(completion_module, "current_user", dummy_account)
|
||||
monkeypatch.setattr(completion_module, "Account", DummyAccount)
|
||||
monkeypatch.setattr(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
lambda **_kwargs: (_ for _ in ()).throw(completion_module.QuotaExceededError()),
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/",
|
||||
json={"inputs": {}, "model_config": {}, "query": "hi"},
|
||||
):
|
||||
with pytest.raises(completion_module.ProviderQuotaExceededError):
|
||||
method(app_model=MagicMock(id="app-1"))
|
||||
|
||||
|
||||
# ========== OpsTrace Tests ==========
|
||||
class TestOpsTraceEndpoints:
|
||||
"""Tests for ops_trace endpoint."""
|
||||
|
||||
def test_ops_trace_query_basic(self):
|
||||
"""Test ops_trace query."""
|
||||
query = TraceProviderQuery(tracing_provider="langfuse")
|
||||
assert query.tracing_provider == "langfuse"
|
||||
|
||||
def test_ops_trace_config_payload(self):
|
||||
payload = TraceConfigPayload(tracing_provider="langfuse", tracing_config={"api_key": "k"})
|
||||
assert payload.tracing_config["api_key"] == "k"
|
||||
|
||||
def test_trace_app_config_get_empty(self, app, monkeypatch):
|
||||
api = ops_trace_module.TraceAppConfigApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(
|
||||
ops_trace_module.OpsService,
|
||||
"get_tracing_app_config",
|
||||
lambda **_kwargs: None,
|
||||
)
|
||||
|
||||
with app.test_request_context("/?tracing_provider=langfuse"):
|
||||
result = method(app_id="app-1")
|
||||
|
||||
assert result == {"has_not_configured": True}
|
||||
|
||||
def test_trace_app_config_post_invalid(self, app, monkeypatch):
|
||||
api = ops_trace_module.TraceAppConfigApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(
|
||||
ops_trace_module.OpsService,
|
||||
"create_tracing_app_config",
|
||||
lambda **_kwargs: {"error": True},
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/",
|
||||
json={"tracing_provider": "langfuse", "tracing_config": {"api_key": "k"}},
|
||||
):
|
||||
with pytest.raises(BadRequest):
|
||||
method(app_id="app-1")
|
||||
|
||||
def test_trace_app_config_delete_not_found(self, app, monkeypatch):
|
||||
api = ops_trace_module.TraceAppConfigApi()
|
||||
method = _unwrap(api.delete)
|
||||
|
||||
monkeypatch.setattr(
|
||||
ops_trace_module.OpsService,
|
||||
"delete_tracing_app_config",
|
||||
lambda **_kwargs: False,
|
||||
)
|
||||
|
||||
with app.test_request_context("/?tracing_provider=langfuse"):
|
||||
with pytest.raises(BadRequest):
|
||||
method(app_id="app-1")
|
||||
|
||||
|
||||
# ========== Site Tests ==========
|
||||
class TestSiteEndpoints:
|
||||
"""Tests for site endpoint."""
|
||||
|
||||
def test_site_response_structure(self):
|
||||
"""Test site response structure."""
|
||||
payload = AppSiteUpdatePayload(title="My Site", description="Test site")
|
||||
assert payload.title == "My Site"
|
||||
|
||||
def test_site_default_language_validation(self):
|
||||
payload = AppSiteUpdatePayload(default_language="en-US")
|
||||
assert payload.default_language == "en-US"
|
||||
|
||||
def test_app_site_update_post(self, app, monkeypatch):
|
||||
api = site_module.AppSite()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
site = MagicMock()
|
||||
query = MagicMock()
|
||||
query.where.return_value.first.return_value = site
|
||||
monkeypatch.setattr(
|
||||
site_module.db,
|
||||
"session",
|
||||
MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
site_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="u1"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(site_module, "naive_utc_now", lambda: "now")
|
||||
|
||||
with app.test_request_context("/", json={"title": "My Site"}):
|
||||
result = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert result is site
|
||||
|
||||
def test_app_site_access_token_reset(self, app, monkeypatch):
|
||||
api = site_module.AppSiteAccessTokenReset()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
site = MagicMock()
|
||||
query = MagicMock()
|
||||
query.where.return_value.first.return_value = site
|
||||
monkeypatch.setattr(
|
||||
site_module.db,
|
||||
"session",
|
||||
MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None),
|
||||
)
|
||||
monkeypatch.setattr(site_module.Site, "generate_code", lambda *_args, **_kwargs: "code")
|
||||
monkeypatch.setattr(
|
||||
site_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="u1"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(site_module, "naive_utc_now", lambda: "now")
|
||||
|
||||
with app.test_request_context("/"):
|
||||
result = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert result is site
|
||||
|
||||
|
||||
# ========== Workflow Tests ==========
|
||||
class TestWorkflowEndpoints:
|
||||
"""Tests for workflow endpoints."""
|
||||
|
||||
def test_workflow_copy_payload(self):
|
||||
"""Test workflow copy payload."""
|
||||
payload = SyncDraftWorkflowPayload(graph={}, features={})
|
||||
assert payload.graph == {}
|
||||
|
||||
def test_workflow_mode_query(self):
|
||||
"""Test workflow mode query."""
|
||||
payload = AdvancedChatWorkflowRunPayload(inputs={}, query="hi")
|
||||
assert payload.query == "hi"
|
||||
|
||||
|
||||
# ========== Workflow App Log Tests ==========
|
||||
class TestWorkflowAppLogEndpoints:
|
||||
"""Tests for workflow app log endpoints."""
|
||||
|
||||
def test_workflow_app_log_query(self):
|
||||
"""Test workflow app log query."""
|
||||
query = WorkflowAppLogQuery(keyword="test", page=1, limit=20)
|
||||
assert query.keyword == "test"
|
||||
|
||||
def test_workflow_app_log_query_detail_bool(self):
|
||||
query = WorkflowAppLogQuery(detail="true")
|
||||
assert query.detail is True
|
||||
|
||||
def test_workflow_app_log_api_get(self, app, monkeypatch):
|
||||
api = workflow_app_log_module.WorkflowAppLogApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(workflow_app_log_module, "db", SimpleNamespace(engine=MagicMock()))
|
||||
|
||||
class DummySession:
|
||||
def __enter__(self):
|
||||
return "session"
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
monkeypatch.setattr(workflow_app_log_module, "Session", lambda *args, **kwargs: DummySession())
|
||||
|
||||
def fake_get_paginate(self, **_kwargs):
|
||||
return {"items": [], "total": 0}
|
||||
|
||||
monkeypatch.setattr(
|
||||
workflow_app_log_module.WorkflowAppService,
|
||||
"get_paginate_workflow_app_logs",
|
||||
fake_get_paginate,
|
||||
)
|
||||
|
||||
with app.test_request_context("/?page=1&limit=20"):
|
||||
result = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert result == {"items": [], "total": 0}
|
||||
|
||||
|
||||
# ========== Workflow Draft Variable Tests ==========
|
||||
class TestWorkflowDraftVariableEndpoints:
|
||||
"""Tests for workflow draft variable endpoints."""
|
||||
|
||||
def test_workflow_variable_creation(self):
|
||||
"""Test workflow variable creation."""
|
||||
payload = WorkflowDraftVariableUpdatePayload(name="var1", value="test")
|
||||
assert payload.name == "var1"
|
||||
|
||||
def test_workflow_variable_collection_get(self, app, monkeypatch):
|
||||
api = workflow_draft_variable_module.WorkflowVariableCollectionApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(workflow_draft_variable_module, "db", SimpleNamespace(engine=MagicMock()))
|
||||
|
||||
class DummySession:
|
||||
def __enter__(self):
|
||||
return "session"
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
class DummyDraftService:
|
||||
def __init__(self, session):
|
||||
self.session = session
|
||||
|
||||
def list_variables_without_values(self, **_kwargs):
|
||||
return {"items": [], "total": 0}
|
||||
|
||||
monkeypatch.setattr(workflow_draft_variable_module, "Session", lambda *args, **kwargs: DummySession())
|
||||
|
||||
class DummyWorkflowService:
|
||||
def is_workflow_exist(self, *args, **kwargs):
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(workflow_draft_variable_module, "WorkflowDraftVariableService", DummyDraftService)
|
||||
monkeypatch.setattr(workflow_draft_variable_module, "WorkflowService", DummyWorkflowService)
|
||||
|
||||
with app.test_request_context("/?page=1&limit=20"):
|
||||
result = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert result == {"items": [], "total": 0}
|
||||
|
||||
|
||||
# ========== Workflow Statistic Tests ==========
|
||||
class TestWorkflowStatisticEndpoints:
|
||||
"""Tests for workflow statistic endpoints."""
|
||||
|
||||
def test_workflow_statistic_time_range(self):
|
||||
"""Test workflow statistic time range query."""
|
||||
query = WorkflowStatisticQuery(start="2024-01-01", end="2024-12-31")
|
||||
assert query.start == "2024-01-01"
|
||||
|
||||
def test_workflow_statistic_blank_to_none(self):
|
||||
query = WorkflowStatisticQuery(start="", end="")
|
||||
assert query.start is None
|
||||
assert query.end is None
|
||||
|
||||
def test_workflow_daily_runs_statistic(self, app, monkeypatch):
|
||||
monkeypatch.setattr(workflow_statistic_module, "db", SimpleNamespace(engine=MagicMock()))
|
||||
monkeypatch.setattr(
|
||||
workflow_statistic_module.DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_run_repository",
|
||||
lambda *_args, **_kwargs: SimpleNamespace(get_daily_runs_statistics=lambda **_kw: [{"date": "2024-01-01"}]),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_statistic_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_statistic_module,
|
||||
"parse_time_range",
|
||||
lambda *_args, **_kwargs: (None, None),
|
||||
)
|
||||
|
||||
api = workflow_statistic_module.WorkflowDailyRunsStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
response = method(app_model=SimpleNamespace(tenant_id="t1", id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-01"}]}
|
||||
|
||||
def test_workflow_daily_terminals_statistic(self, app, monkeypatch):
|
||||
monkeypatch.setattr(workflow_statistic_module, "db", SimpleNamespace(engine=MagicMock()))
|
||||
monkeypatch.setattr(
|
||||
workflow_statistic_module.DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_run_repository",
|
||||
lambda *_args, **_kwargs: SimpleNamespace(
|
||||
get_daily_terminals_statistics=lambda **_kw: [{"date": "2024-01-02"}]
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_statistic_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_statistic_module,
|
||||
"parse_time_range",
|
||||
lambda *_args, **_kwargs: (None, None),
|
||||
)
|
||||
|
||||
api = workflow_statistic_module.WorkflowDailyTerminalsStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
response = method(app_model=SimpleNamespace(tenant_id="t1", id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-02"}]}
|
||||
|
||||
|
||||
# ========== Workflow Trigger Tests ==========
|
||||
class TestWorkflowTriggerEndpoints:
|
||||
"""Tests for workflow trigger endpoints."""
|
||||
|
||||
def test_webhook_trigger_payload(self):
|
||||
"""Test webhook trigger payload."""
|
||||
payload = Parser(node_id="node-1")
|
||||
assert payload.node_id == "node-1"
|
||||
|
||||
enable_payload = ParserEnable(trigger_id="trigger-1", enable_trigger=True)
|
||||
assert enable_payload.enable_trigger is True
|
||||
|
||||
def test_webhook_trigger_api_get(self, app, monkeypatch):
|
||||
api = workflow_trigger_module.WebhookTriggerApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(workflow_trigger_module, "db", SimpleNamespace(engine=MagicMock()))
|
||||
|
||||
trigger = MagicMock()
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = trigger
|
||||
|
||||
class DummySession:
|
||||
def __enter__(self):
|
||||
return session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
monkeypatch.setattr(workflow_trigger_module, "Session", lambda *_args, **_kwargs: DummySession())
|
||||
|
||||
with app.test_request_context("/?node_id=node-1"):
|
||||
result = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert result is trigger
|
||||
|
||||
|
||||
# ========== Wraps Tests ==========
|
||||
class TestWrapsEndpoints:
|
||||
"""Tests for wraps utility functions."""
|
||||
|
||||
def test_get_app_model_context(self):
|
||||
"""Test get_app_model wrapper context."""
|
||||
# These are decorator functions, so we test their availability
|
||||
assert hasattr(wraps_module, "get_app_model")
|
||||
|
||||
|
||||
# ========== MCP Server Tests ==========
|
||||
class TestMCPServerEndpoints:
|
||||
"""Tests for MCP server endpoints."""
|
||||
|
||||
def test_mcp_server_connection(self):
|
||||
"""Test MCP server connection."""
|
||||
payload = MCPServerCreatePayload(parameters={"url": "http://localhost:3000"})
|
||||
assert payload.parameters["url"] == "http://localhost:3000"
|
||||
|
||||
def test_mcp_server_update_payload(self):
|
||||
payload = MCPServerUpdatePayload(id="server-1", parameters={"timeout": 30}, status="active")
|
||||
assert payload.status == "active"
|
||||
|
||||
|
||||
# ========== Error Handling Tests ==========
|
||||
class TestErrorHandling:
|
||||
"""Tests for error handling in various endpoints."""
|
||||
|
||||
def test_annotation_list_query_validation(self):
|
||||
"""Test annotation list query validation."""
|
||||
with pytest.raises(ValueError):
|
||||
annotation_module.AnnotationListQuery(page=0)
|
||||
|
||||
|
||||
# ========== Integration-like Tests ==========
|
||||
class TestPayloadIntegration:
|
||||
"""Integration tests for payload handling."""
|
||||
|
||||
def test_multiple_payload_types(self):
|
||||
"""Test handling of multiple payload types."""
|
||||
payloads = [
|
||||
annotation_module.AnnotationReplyPayload(
|
||||
score_threshold=0.5, embedding_provider_name="openai", embedding_model_name="text-embedding-3-small"
|
||||
),
|
||||
message_module.MessageFeedbackPayload(message_id=str(uuid.uuid4()), rating="like"),
|
||||
statistic_module.StatisticTimeRangeQuery(start="2024-01-01"),
|
||||
]
|
||||
assert len(payloads) == 3
|
||||
assert all(p is not None for p in payloads)
|
||||
@@ -0,0 +1,157 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.app import app_import as app_import_module
|
||||
from services.app_dsl_service import ImportStatus
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
class _Result:
|
||||
def __init__(self, status: ImportStatus, app_id: str | None = "app-1"):
|
||||
self.status = status
|
||||
self.app_id = app_id
|
||||
|
||||
def model_dump(self, mode: str = "json"):
|
||||
return {"status": self.status, "app_id": self.app_id}
|
||||
|
||||
|
||||
class _SessionContext:
|
||||
def __init__(self, session):
|
||||
self._session = session
|
||||
|
||||
def __enter__(self):
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
|
||||
def _install_session(monkeypatch: pytest.MonkeyPatch, session: MagicMock) -> None:
|
||||
monkeypatch.setattr(app_import_module, "Session", lambda *_: _SessionContext(session))
|
||||
monkeypatch.setattr(app_import_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
|
||||
def _install_features(monkeypatch: pytest.MonkeyPatch, enabled: bool) -> None:
|
||||
features = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=enabled))
|
||||
monkeypatch.setattr(app_import_module.FeatureService, "get_system_features", lambda: features)
|
||||
|
||||
|
||||
def test_import_post_returns_failed_status(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
session = MagicMock()
|
||||
_install_session(monkeypatch, session)
|
||||
_install_features(monkeypatch, enabled=False)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
|
||||
session.commit.assert_called_once()
|
||||
assert status == 400
|
||||
assert response["status"] == ImportStatus.FAILED
|
||||
|
||||
|
||||
def test_import_post_returns_pending_status(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
session = MagicMock()
|
||||
_install_session(monkeypatch, session)
|
||||
_install_features(monkeypatch, enabled=False)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.PENDING),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
|
||||
session.commit.assert_called_once()
|
||||
assert status == 202
|
||||
assert response["status"] == ImportStatus.PENDING
|
||||
|
||||
|
||||
def test_import_post_updates_webapp_auth_when_enabled(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
session = MagicMock()
|
||||
_install_session(monkeypatch, session)
|
||||
_install_features(monkeypatch, enabled=True)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"),
|
||||
)
|
||||
update_access = MagicMock()
|
||||
monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
|
||||
session.commit.assert_called_once()
|
||||
update_access.assert_called_once_with("app-123", "private")
|
||||
assert status == 200
|
||||
assert response["status"] == ImportStatus.COMPLETED
|
||||
|
||||
|
||||
def test_import_confirm_returns_failed_status(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportConfirmApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
session = MagicMock()
|
||||
_install_session(monkeypatch, session)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"confirm_import",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"):
|
||||
response, status = method(import_id="import-1")
|
||||
|
||||
session.commit.assert_called_once()
|
||||
assert status == 400
|
||||
assert response["status"] == ImportStatus.FAILED
|
||||
|
||||
|
||||
def test_import_check_dependencies_returns_result(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportCheckDependenciesApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
session = MagicMock()
|
||||
_install_session(monkeypatch, session)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"check_dependencies",
|
||||
lambda *_args, **_kwargs: SimpleNamespace(model_dump=lambda mode="json": {"leaked_dependencies": []}),
|
||||
)
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports/app-1/check-dependencies", method="GET"):
|
||||
response, status = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert status == 200
|
||||
assert response["leaked_dependencies"] == []
|
||||
292
api/tests/unit_tests/controllers/console/app/test_audio.py
Normal file
292
api/tests/unit_tests/controllers/console/app/test_audio.py
Normal file
@@ -0,0 +1,292 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from controllers.console.app.audio import ChatMessageAudioApi, ChatMessageTextApi, TextModesApi
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
CompletionRequestError,
|
||||
NoAudioUploadedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderNotSupportSpeechToTextError,
|
||||
ProviderQuotaExceededError,
|
||||
UnsupportedAudioTypeError,
|
||||
)
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeError
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.app_model_config import AppModelConfigBrokenError
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
NoAudioUploadedServiceError,
|
||||
ProviderNotSupportSpeechToTextServiceError,
|
||||
ProviderNotSupportTextToSpeechLanageServiceError,
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def _file_data():
|
||||
return FileStorage(stream=io.BytesIO(b"audio"), filename="audio.wav", content_type="audio/wav")
|
||||
|
||||
|
||||
def test_console_audio_api_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "ok"})
|
||||
api = ChatMessageAudioApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
|
||||
with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}):
|
||||
response = handler(app_model=app_model)
|
||||
|
||||
assert response == {"text": "ok"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("exc", "expected"),
|
||||
[
|
||||
(AppModelConfigBrokenError(), AppUnavailableError),
|
||||
(NoAudioUploadedServiceError(), NoAudioUploadedError),
|
||||
(AudioTooLargeServiceError("too big"), AudioTooLargeError),
|
||||
(UnsupportedAudioTypeServiceError(), UnsupportedAudioTypeError),
|
||||
(ProviderNotSupportSpeechToTextServiceError(), ProviderNotSupportSpeechToTextError),
|
||||
(ProviderTokenNotInitError("token"), ProviderNotInitializeError),
|
||||
(QuotaExceededError(), ProviderQuotaExceededError),
|
||||
(ModelCurrentlyNotSupportError(), ProviderModelCurrentlyNotSupportError),
|
||||
(InvokeError("invoke"), CompletionRequestError),
|
||||
],
|
||||
)
|
||||
def test_console_audio_api_error_mapping(app, monkeypatch: pytest.MonkeyPatch, exc, expected) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(exc))
|
||||
api = ChatMessageAudioApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
|
||||
with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}):
|
||||
with pytest.raises(expected):
|
||||
handler(app_model=app_model)
|
||||
|
||||
|
||||
def test_console_audio_api_unhandled_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("boom")))
|
||||
api = ChatMessageAudioApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
|
||||
with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}):
|
||||
with pytest.raises(InternalServerError):
|
||||
handler(app_model=app_model)
|
||||
|
||||
|
||||
def test_console_text_api_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
|
||||
|
||||
api = ChatMessageTextApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app/text-to-audio",
|
||||
method="POST",
|
||||
json={"text": "hello", "voice": "v"},
|
||||
):
|
||||
response = handler(app_model=app_model)
|
||||
|
||||
assert response == {"audio": "ok"}
|
||||
|
||||
|
||||
def test_console_text_api_error_mapping(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: (_ for _ in ()).throw(QuotaExceededError()))
|
||||
|
||||
api = ChatMessageTextApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app/text-to-audio",
|
||||
method="POST",
|
||||
json={"text": "hello"},
|
||||
):
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
handler(app_model=app_model)
|
||||
|
||||
|
||||
def test_console_text_modes_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"])
|
||||
|
||||
api = TextModesApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(tenant_id="t1")
|
||||
|
||||
with app.test_request_context("/console/api/apps/app/text-to-audio/voices?language=en", method="GET"):
|
||||
response = handler(app_model=app_model)
|
||||
|
||||
assert response == ["voice-1"]
|
||||
|
||||
|
||||
def test_console_text_modes_language_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
AudioService,
|
||||
"transcript_tts_voices",
|
||||
lambda **_kwargs: (_ for _ in ()).throw(ProviderNotSupportTextToSpeechLanageServiceError()),
|
||||
)
|
||||
|
||||
api = TextModesApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(tenant_id="t1")
|
||||
|
||||
with app.test_request_context("/console/api/apps/app/text-to-audio/voices?language=en", method="GET"):
|
||||
with pytest.raises(AppUnavailableError):
|
||||
handler(app_model=app_model)
|
||||
|
||||
|
||||
def test_audio_to_text_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = ChatMessageAudioApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
response_payload = {"text": "hello"}
|
||||
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: response_payload)
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
data = {"file": (io.BytesIO(b"x"), "sample.wav")}
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/audio-to-text",
|
||||
method="POST",
|
||||
data=data,
|
||||
content_type="multipart/form-data",
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
|
||||
assert response == response_payload
|
||||
|
||||
|
||||
def test_audio_to_text_maps_audio_too_large(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = ChatMessageAudioApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(
|
||||
AudioService,
|
||||
"transcript_asr",
|
||||
lambda **_kwargs: (_ for _ in ()).throw(AudioTooLargeServiceError("too large")),
|
||||
)
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
data = {"file": (io.BytesIO(b"x"), "sample.wav")}
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/audio-to-text",
|
||||
method="POST",
|
||||
data=data,
|
||||
content_type="multipart/form-data",
|
||||
):
|
||||
with pytest.raises(AudioTooLargeError):
|
||||
method(app_model=app_model)
|
||||
|
||||
|
||||
def test_text_to_audio_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = ChatMessageTextApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/text-to-audio",
|
||||
method="POST",
|
||||
json={"text": "hello"},
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
|
||||
assert response == {"audio": "ok"}
|
||||
|
||||
|
||||
def test_text_to_audio_voices_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = TextModesApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"])
|
||||
|
||||
app_model = SimpleNamespace(tenant_id="tenant-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/text-to-audio/voices",
|
||||
method="GET",
|
||||
query_string={"language": "en-US"},
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
|
||||
assert response == ["voice-1"]
|
||||
|
||||
|
||||
def test_audio_to_text_with_invalid_file(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = ChatMessageAudioApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "test"})
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
data = {"file": (io.BytesIO(b"invalid"), "sample.xyz")}
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/audio-to-text",
|
||||
method="POST",
|
||||
data=data,
|
||||
content_type="multipart/form-data",
|
||||
):
|
||||
# Should not raise, AudioService is mocked
|
||||
response = method(app_model=app_model)
|
||||
assert response == {"text": "test"}
|
||||
|
||||
|
||||
def test_text_to_audio_with_language_param(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = ChatMessageTextApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "test"})
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/text-to-audio",
|
||||
method="POST",
|
||||
json={"text": "hello", "language": "en-US"},
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
assert response == {"audio": "test"}
|
||||
|
||||
|
||||
def test_text_to_audio_voices_with_language_filter(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = TextModesApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(
|
||||
AudioService,
|
||||
"transcript_tts_voices",
|
||||
lambda **_kwargs: [{"id": "voice-1", "name": "Voice 1"}],
|
||||
)
|
||||
|
||||
app_model = SimpleNamespace(tenant_id="tenant-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/text-to-audio/voices?language=en-US",
|
||||
method="GET",
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
assert isinstance(response, list)
|
||||
156
api/tests/unit_tests/controllers/console/app/test_audio_api.py
Normal file
156
api/tests/unit_tests/controllers/console/app/test_audio_api.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.app import audio as audio_module
|
||||
from controllers.console.app.error import AudioTooLargeError
|
||||
from services.errors.audio import AudioTooLargeServiceError
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def test_audio_to_text_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.ChatMessageAudioApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
response_payload = {"text": "hello"}
|
||||
monkeypatch.setattr(audio_module.AudioService, "transcript_asr", lambda **_kwargs: response_payload)
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
data = {"file": (io.BytesIO(b"x"), "sample.wav")}
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/audio-to-text",
|
||||
method="POST",
|
||||
data=data,
|
||||
content_type="multipart/form-data",
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
|
||||
assert response == response_payload
|
||||
|
||||
|
||||
def test_audio_to_text_maps_audio_too_large(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.ChatMessageAudioApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(
|
||||
audio_module.AudioService,
|
||||
"transcript_asr",
|
||||
lambda **_kwargs: (_ for _ in ()).throw(AudioTooLargeServiceError("too large")),
|
||||
)
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
data = {"file": (io.BytesIO(b"x"), "sample.wav")}
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/audio-to-text",
|
||||
method="POST",
|
||||
data=data,
|
||||
content_type="multipart/form-data",
|
||||
):
|
||||
with pytest.raises(AudioTooLargeError):
|
||||
method(app_model=app_model)
|
||||
|
||||
|
||||
def test_text_to_audio_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.ChatMessageTextApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(audio_module.AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/text-to-audio",
|
||||
method="POST",
|
||||
json={"text": "hello"},
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
|
||||
assert response == {"audio": "ok"}
|
||||
|
||||
|
||||
def test_text_to_audio_voices_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.TextModesApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(audio_module.AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"])
|
||||
|
||||
app_model = SimpleNamespace(tenant_id="tenant-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/text-to-audio/voices",
|
||||
method="GET",
|
||||
query_string={"language": "en-US"},
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
|
||||
assert response == ["voice-1"]
|
||||
|
||||
|
||||
def test_audio_to_text_with_invalid_file(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.ChatMessageAudioApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(audio_module.AudioService, "transcript_asr", lambda **_kwargs: {"text": "test"})
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
data = {"file": (io.BytesIO(b"invalid"), "sample.xyz")}
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/audio-to-text",
|
||||
method="POST",
|
||||
data=data,
|
||||
content_type="multipart/form-data",
|
||||
):
|
||||
# Should not raise, AudioService is mocked
|
||||
response = method(app_model=app_model)
|
||||
assert response == {"text": "test"}
|
||||
|
||||
|
||||
def test_text_to_audio_with_language_param(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.ChatMessageTextApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(audio_module.AudioService, "transcript_tts", lambda **_kwargs: {"audio": "test"})
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/text-to-audio",
|
||||
method="POST",
|
||||
json={"text": "hello", "language": "en-US"},
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
assert response == {"audio": "test"}
|
||||
|
||||
|
||||
def test_text_to_audio_voices_with_language_filter(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.TextModesApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(
|
||||
audio_module.AudioService,
|
||||
"transcript_tts_voices",
|
||||
lambda **_kwargs: [{"id": "voice-1", "name": "Voice 1"}],
|
||||
)
|
||||
|
||||
app_model = SimpleNamespace(tenant_id="tenant-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/text-to-audio/voices?language=en-US",
|
||||
method="GET",
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
assert isinstance(response, list)
|
||||
@@ -0,0 +1,130 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.console.app import conversation as conversation_module
|
||||
from models.model import AppMode
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def _make_account():
|
||||
return SimpleNamespace(timezone="UTC", id="u1")
|
||||
|
||||
|
||||
def test_completion_conversation_list_returns_paginated_result(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = conversation_module.CompletionConversationApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
account = _make_account()
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1"))
|
||||
monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None))
|
||||
|
||||
paginate_result = MagicMock()
|
||||
monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/completion-conversations", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response is paginate_result
|
||||
|
||||
|
||||
def test_completion_conversation_list_invalid_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = conversation_module.CompletionConversationApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
account = _make_account()
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1"))
|
||||
monkeypatch.setattr(
|
||||
conversation_module,
|
||||
"parse_time_range",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(ValueError("bad range")),
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/completion-conversations",
|
||||
method="GET",
|
||||
query_string={"start": "bad"},
|
||||
):
|
||||
with pytest.raises(BadRequest):
|
||||
method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
|
||||
def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = conversation_module.ChatConversationApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
account = _make_account()
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1"))
|
||||
monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None))
|
||||
|
||||
paginate_result = MagicMock()
|
||||
monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/chat-conversations", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1", mode=AppMode.ADVANCED_CHAT))
|
||||
|
||||
assert response is paginate_result
|
||||
|
||||
|
||||
def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
conversation = SimpleNamespace(id="c1", app_id="app-1")
|
||||
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.first.return_value = conversation
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
|
||||
monkeypatch.setattr(conversation_module.db, "session", session)
|
||||
|
||||
result = conversation_module._get_conversation(SimpleNamespace(id="app-1"), "c1")
|
||||
|
||||
assert result is conversation
|
||||
session.execute.assert_called_once()
|
||||
session.commit.assert_called_once()
|
||||
session.refresh.assert_called_once_with(conversation)
|
||||
|
||||
|
||||
def test_get_conversation_missing_raises_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.first.return_value = None
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
|
||||
monkeypatch.setattr(conversation_module.db, "session", session)
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
conversation_module._get_conversation(SimpleNamespace(id="app-1"), "missing")
|
||||
|
||||
|
||||
def test_completion_conversation_delete_maps_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = conversation_module.CompletionConversationDetailApi()
|
||||
method = _unwrap(api.delete)
|
||||
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
|
||||
monkeypatch.setattr(
|
||||
conversation_module.ConversationService,
|
||||
"delete",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationNotExistsError()),
|
||||
)
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
method(app_model=SimpleNamespace(id="app-1"), conversation_id="c1")
|
||||
@@ -0,0 +1,260 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.app import generator as generator_module
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
from core.errors.error import ProviderTokenNotInitError
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def _model_config_payload():
|
||||
return {"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}}
|
||||
|
||||
|
||||
def _install_workflow_service(monkeypatch: pytest.MonkeyPatch, workflow):
|
||||
class _Service:
|
||||
def get_draft_workflow(self, app_model):
|
||||
return workflow
|
||||
|
||||
monkeypatch.setattr(generator_module, "WorkflowService", lambda: _Service())
|
||||
|
||||
|
||||
def test_rule_generate_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.RuleGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
monkeypatch.setattr(generator_module.LLMGenerator, "generate_rule_config", lambda **_kwargs: {"rules": []})
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/rule-generate",
|
||||
method="POST",
|
||||
json={"instruction": "do it", "model_config": _model_config_payload()},
|
||||
):
|
||||
response = method()
|
||||
|
||||
assert response == {"rules": []}
|
||||
|
||||
|
||||
def test_rule_code_generate_maps_token_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.RuleCodeGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
def _raise(*_args, **_kwargs):
|
||||
raise ProviderTokenNotInitError("missing token")
|
||||
|
||||
monkeypatch.setattr(generator_module.LLMGenerator, "generate_code", _raise)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/rule-code-generate",
|
||||
method="POST",
|
||||
json={"instruction": "do it", "model_config": _model_config_payload()},
|
||||
):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
method()
|
||||
|
||||
|
||||
def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: None)
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate",
|
||||
method="POST",
|
||||
json={
|
||||
"flow_id": "app-1",
|
||||
"node_id": "node-1",
|
||||
"instruction": "do",
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response, status = method()
|
||||
|
||||
assert status == 400
|
||||
assert response["error"] == "app app-1 not found"
|
||||
|
||||
|
||||
def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
_install_workflow_service(monkeypatch, workflow=None)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate",
|
||||
method="POST",
|
||||
json={
|
||||
"flow_id": "app-1",
|
||||
"node_id": "node-1",
|
||||
"instruction": "do",
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response, status = method()
|
||||
|
||||
assert status == 400
|
||||
assert response["error"] == "workflow app-1 not found"
|
||||
|
||||
|
||||
def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
|
||||
workflow = SimpleNamespace(graph_dict={"nodes": []})
|
||||
_install_workflow_service(monkeypatch, workflow=workflow)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate",
|
||||
method="POST",
|
||||
json={
|
||||
"flow_id": "app-1",
|
||||
"node_id": "node-1",
|
||||
"instruction": "do",
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response, status = method()
|
||||
|
||||
assert status == 400
|
||||
assert response["error"] == "node node-1 not found"
|
||||
|
||||
|
||||
def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
|
||||
workflow = SimpleNamespace(
|
||||
graph_dict={
|
||||
"nodes": [
|
||||
{"id": "node-1", "data": {"type": "code"}},
|
||||
]
|
||||
}
|
||||
)
|
||||
_install_workflow_service(monkeypatch, workflow=workflow)
|
||||
monkeypatch.setattr(generator_module.LLMGenerator, "generate_code", lambda **_kwargs: {"code": "x"})
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate",
|
||||
method="POST",
|
||||
json={
|
||||
"flow_id": "app-1",
|
||||
"node_id": "node-1",
|
||||
"instruction": "do",
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response = method()
|
||||
|
||||
assert response == {"code": "x"}
|
||||
|
||||
|
||||
def test_instruction_generate_legacy_modify(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
monkeypatch.setattr(
|
||||
generator_module.LLMGenerator,
|
||||
"instruction_modify_legacy",
|
||||
lambda **_kwargs: {"instruction": "ok"},
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate",
|
||||
method="POST",
|
||||
json={
|
||||
"flow_id": "app-1",
|
||||
"node_id": "",
|
||||
"current": "old",
|
||||
"instruction": "do",
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response = method()
|
||||
|
||||
assert response == {"instruction": "ok"}
|
||||
|
||||
|
||||
def test_instruction_generate_incompatible_params(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate",
|
||||
method="POST",
|
||||
json={
|
||||
"flow_id": "app-1",
|
||||
"node_id": "",
|
||||
"current": "",
|
||||
"instruction": "do",
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response, status = method()
|
||||
|
||||
assert status == 400
|
||||
assert response["error"] == "incompatible parameters"
|
||||
|
||||
|
||||
def test_instruction_template_prompt(app) -> None:
|
||||
api = generator_module.InstructionGenerationTemplateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate/template",
|
||||
method="POST",
|
||||
json={"type": "prompt"},
|
||||
):
|
||||
response = method()
|
||||
|
||||
assert "data" in response
|
||||
|
||||
|
||||
def test_instruction_template_invalid_type(app) -> None:
|
||||
api = generator_module.InstructionGenerationTemplateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate/template",
|
||||
method="POST",
|
||||
json={"type": "unknown"},
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method()
|
||||
122
api/tests/unit_tests/controllers/console/app/test_message_api.py
Normal file
122
api/tests/unit_tests/controllers/console/app/test_message_api.py
Normal file
@@ -0,0 +1,122 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.app import message as message_module
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def test_chat_messages_query_valid(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test valid ChatMessagesQuery with all fields."""
|
||||
query = message_module.ChatMessagesQuery(
|
||||
conversation_id="550e8400-e29b-41d4-a716-446655440000",
|
||||
first_id="550e8400-e29b-41d4-a716-446655440001",
|
||||
limit=50,
|
||||
)
|
||||
assert query.limit == 50
|
||||
|
||||
|
||||
def test_chat_messages_query_defaults(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test ChatMessagesQuery with defaults."""
|
||||
query = message_module.ChatMessagesQuery(conversation_id="550e8400-e29b-41d4-a716-446655440000")
|
||||
assert query.first_id is None
|
||||
assert query.limit == 20
|
||||
|
||||
|
||||
def test_chat_messages_query_empty_first_id(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test ChatMessagesQuery converts empty first_id to None."""
|
||||
query = message_module.ChatMessagesQuery(
|
||||
conversation_id="550e8400-e29b-41d4-a716-446655440000",
|
||||
first_id="",
|
||||
)
|
||||
assert query.first_id is None
|
||||
|
||||
|
||||
def test_message_feedback_payload_valid_like(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test MessageFeedbackPayload with like rating."""
|
||||
payload = message_module.MessageFeedbackPayload(
|
||||
message_id="550e8400-e29b-41d4-a716-446655440000",
|
||||
rating="like",
|
||||
content="Good answer",
|
||||
)
|
||||
assert payload.rating == "like"
|
||||
assert payload.content == "Good answer"
|
||||
|
||||
|
||||
def test_message_feedback_payload_valid_dislike(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test MessageFeedbackPayload with dislike rating."""
|
||||
payload = message_module.MessageFeedbackPayload(
|
||||
message_id="550e8400-e29b-41d4-a716-446655440000",
|
||||
rating="dislike",
|
||||
)
|
||||
assert payload.rating == "dislike"
|
||||
|
||||
|
||||
def test_message_feedback_payload_no_rating(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test MessageFeedbackPayload without rating."""
|
||||
payload = message_module.MessageFeedbackPayload(message_id="550e8400-e29b-41d4-a716-446655440000")
|
||||
assert payload.rating is None
|
||||
|
||||
|
||||
def test_feedback_export_query_defaults(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with default format."""
|
||||
query = message_module.FeedbackExportQuery()
|
||||
assert query.format == "csv"
|
||||
assert query.from_source is None
|
||||
|
||||
|
||||
def test_feedback_export_query_json_format(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with JSON format."""
|
||||
query = message_module.FeedbackExportQuery(format="json")
|
||||
assert query.format == "json"
|
||||
|
||||
|
||||
def test_feedback_export_query_has_comment_true(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with has_comment as true string."""
|
||||
query = message_module.FeedbackExportQuery(has_comment="true")
|
||||
assert query.has_comment is True
|
||||
|
||||
|
||||
def test_feedback_export_query_has_comment_false(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with has_comment as false string."""
|
||||
query = message_module.FeedbackExportQuery(has_comment="false")
|
||||
assert query.has_comment is False
|
||||
|
||||
|
||||
def test_feedback_export_query_has_comment_1(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with has_comment as 1."""
|
||||
query = message_module.FeedbackExportQuery(has_comment="1")
|
||||
assert query.has_comment is True
|
||||
|
||||
|
||||
def test_feedback_export_query_has_comment_0(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with has_comment as 0."""
|
||||
query = message_module.FeedbackExportQuery(has_comment="0")
|
||||
assert query.has_comment is False
|
||||
|
||||
|
||||
def test_feedback_export_query_rating_filter(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with rating filter."""
|
||||
query = message_module.FeedbackExportQuery(rating="like")
|
||||
assert query.rating == "like"
|
||||
|
||||
|
||||
def test_annotation_count_response(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test AnnotationCountResponse creation."""
|
||||
response = message_module.AnnotationCountResponse(count=10)
|
||||
assert response.count == 10
|
||||
|
||||
|
||||
def test_suggested_questions_response(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test SuggestedQuestionsResponse creation."""
|
||||
response = message_module.SuggestedQuestionsResponse(data=["What is AI?", "How does ML work?"])
|
||||
assert len(response.data) == 2
|
||||
assert response.data[0] == "What is AI?"
|
||||
@@ -0,0 +1,151 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.app import model_config as model_config_module
|
||||
from models.model import AppMode, AppModelConfig
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def test_post_updates_app_model_config_for_chat(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = model_config_module.ModelConfigResource()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
app_model = SimpleNamespace(
|
||||
id="app-1",
|
||||
mode=AppMode.CHAT.value,
|
||||
is_agent=False,
|
||||
app_model_config_id=None,
|
||||
updated_by=None,
|
||||
updated_at=None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
model_config_module.AppModelConfigService,
|
||||
"validate_configuration",
|
||||
lambda **_kwargs: {"pre_prompt": "hi"},
|
||||
)
|
||||
monkeypatch.setattr(model_config_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
session = MagicMock()
|
||||
monkeypatch.setattr(model_config_module.db, "session", session)
|
||||
|
||||
def _from_model_config_dict(self, model_config):
|
||||
self.pre_prompt = model_config["pre_prompt"]
|
||||
self.id = "config-1"
|
||||
return self
|
||||
|
||||
monkeypatch.setattr(AppModelConfig, "from_model_config_dict", _from_model_config_dict)
|
||||
send_mock = MagicMock()
|
||||
monkeypatch.setattr(model_config_module.app_model_config_was_updated, "send", send_mock)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/model-config", method="POST", json={"pre_prompt": "hi"}):
|
||||
response = method(app_model=app_model)
|
||||
|
||||
session.add.assert_called_once()
|
||||
session.flush.assert_called_once()
|
||||
session.commit.assert_called_once()
|
||||
send_mock.assert_called_once()
|
||||
assert app_model.app_model_config_id == "config-1"
|
||||
assert response["result"] == "success"
|
||||
|
||||
|
||||
def test_post_encrypts_agent_tool_parameters(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = model_config_module.ModelConfigResource()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
app_model = SimpleNamespace(
|
||||
id="app-1",
|
||||
mode=AppMode.AGENT_CHAT.value,
|
||||
is_agent=True,
|
||||
app_model_config_id="config-0",
|
||||
updated_by=None,
|
||||
updated_at=None,
|
||||
)
|
||||
|
||||
original_config = AppModelConfig(app_id="app-1", created_by="u1", updated_by="u1")
|
||||
original_config.agent_mode = json.dumps(
|
||||
{
|
||||
"enabled": True,
|
||||
"strategy": "function-calling",
|
||||
"tools": [
|
||||
{
|
||||
"provider_id": "provider",
|
||||
"provider_type": "builtin",
|
||||
"tool_name": "tool",
|
||||
"tool_parameters": {"secret": "masked"},
|
||||
}
|
||||
],
|
||||
"prompt": None,
|
||||
}
|
||||
)
|
||||
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.first.return_value = original_config
|
||||
session.query.return_value = query
|
||||
monkeypatch.setattr(model_config_module.db, "session", session)
|
||||
|
||||
monkeypatch.setattr(
|
||||
model_config_module.AppModelConfigService,
|
||||
"validate_configuration",
|
||||
lambda **_kwargs: {
|
||||
"pre_prompt": "hi",
|
||||
"agent_mode": {
|
||||
"enabled": True,
|
||||
"strategy": "function-calling",
|
||||
"tools": [
|
||||
{
|
||||
"provider_id": "provider",
|
||||
"provider_type": "builtin",
|
||||
"tool_name": "tool",
|
||||
"tool_parameters": {"secret": "masked"},
|
||||
}
|
||||
],
|
||||
"prompt": None,
|
||||
},
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(model_config_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
monkeypatch.setattr(model_config_module.ToolManager, "get_agent_tool_runtime", lambda **_kwargs: object())
|
||||
|
||||
class _ParamManager:
|
||||
def __init__(self, **_kwargs):
|
||||
self.delete_called = False
|
||||
|
||||
def decrypt_tool_parameters(self, _value):
|
||||
return {"secret": "decrypted"}
|
||||
|
||||
def mask_tool_parameters(self, _value):
|
||||
return {"secret": "masked"}
|
||||
|
||||
def encrypt_tool_parameters(self, _value):
|
||||
return {"secret": "encrypted"}
|
||||
|
||||
def delete_tool_parameters_cache(self):
|
||||
self.delete_called = True
|
||||
|
||||
monkeypatch.setattr(model_config_module, "ToolParameterConfigurationManager", _ParamManager)
|
||||
send_mock = MagicMock()
|
||||
monkeypatch.setattr(model_config_module.app_model_config_was_updated, "send", send_mock)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/model-config", method="POST", json={"pre_prompt": "hi"}):
|
||||
response = method(app_model=app_model)
|
||||
|
||||
stored_config = session.add.call_args[0][0]
|
||||
stored_agent_mode = json.loads(stored_config.agent_mode)
|
||||
assert stored_agent_mode["tools"][0]["tool_parameters"]["secret"] == "encrypted"
|
||||
assert response["result"] == "success"
|
||||
@@ -0,0 +1,215 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.console.app import statistic as statistic_module
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
class _ConnContext:
|
||||
def __init__(self, rows):
|
||||
self._rows = rows
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def execute(self, _query, _args):
|
||||
return self._rows
|
||||
|
||||
|
||||
def _install_db(monkeypatch: pytest.MonkeyPatch, rows) -> None:
|
||||
engine = SimpleNamespace(begin=lambda: _ConnContext(rows))
|
||||
monkeypatch.setattr(statistic_module, "db", SimpleNamespace(engine=engine))
|
||||
|
||||
|
||||
def _install_common(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
statistic_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
statistic_module,
|
||||
"parse_time_range",
|
||||
lambda *_args, **_kwargs: (None, None),
|
||||
)
|
||||
monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field)
|
||||
|
||||
|
||||
def test_daily_message_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyMessageStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
rows = [SimpleNamespace(date="2024-01-01", message_count=3)]
|
||||
_install_common(monkeypatch)
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-01", "message_count": 3}]}
|
||||
|
||||
|
||||
def test_daily_conversation_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyConversationStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
rows = [SimpleNamespace(date="2024-01-02", conversation_count=5)]
|
||||
_install_common(monkeypatch)
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-conversations", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]}
|
||||
|
||||
|
||||
def test_daily_token_cost_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyTokenCostStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
rows = [SimpleNamespace(date="2024-01-03", token_count=10, total_price=0.25, currency="USD")]
|
||||
_install_common(monkeypatch)
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/token-costs", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
data = response.get_json()
|
||||
assert len(data["data"]) == 1
|
||||
assert data["data"][0]["date"] == "2024-01-03"
|
||||
assert data["data"][0]["token_count"] == 10
|
||||
assert data["data"][0]["total_price"] == 0.25
|
||||
|
||||
|
||||
def test_daily_terminals_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyTerminalsStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
rows = [SimpleNamespace(date="2024-01-04", terminal_count=7)]
|
||||
_install_common(monkeypatch)
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-end-users", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-04", "terminal_count": 7}]}
|
||||
|
||||
|
||||
def test_average_session_interaction_statistic_requires_chat_mode(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test that AverageSessionInteractionStatistic is limited to chat/agent modes."""
|
||||
# This just verifies the decorator is applied correctly
|
||||
# Actual endpoint testing would require complex JOIN mocking
|
||||
api = statistic_module.AverageSessionInteractionStatistic()
|
||||
method = _unwrap(api.get)
|
||||
assert callable(method)
|
||||
|
||||
|
||||
def test_daily_message_statistic_with_invalid_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyMessageStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
def mock_parse(*args, **kwargs):
|
||||
raise ValueError("Invalid time range")
|
||||
|
||||
_install_db(monkeypatch, [])
|
||||
monkeypatch.setattr(
|
||||
statistic_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(statistic_module, "parse_time_range", mock_parse)
|
||||
monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
|
||||
with pytest.raises(BadRequest):
|
||||
method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
|
||||
def test_daily_message_statistic_multiple_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyMessageStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
rows = [
|
||||
SimpleNamespace(date="2024-01-01", message_count=10),
|
||||
SimpleNamespace(date="2024-01-02", message_count=15),
|
||||
SimpleNamespace(date="2024-01-03", message_count=12),
|
||||
]
|
||||
_install_common(monkeypatch)
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
data = response.get_json()
|
||||
assert len(data["data"]) == 3
|
||||
|
||||
|
||||
def test_daily_message_statistic_empty_result(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyMessageStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
_install_common(monkeypatch)
|
||||
_install_db(monkeypatch, [])
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": []}
|
||||
|
||||
|
||||
def test_daily_conversation_statistic_with_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyConversationStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
rows = [SimpleNamespace(date="2024-01-02", conversation_count=5)]
|
||||
_install_db(monkeypatch, rows)
|
||||
monkeypatch.setattr(
|
||||
statistic_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
statistic_module,
|
||||
"parse_time_range",
|
||||
lambda *_args, **_kwargs: ("s", "e"),
|
||||
)
|
||||
monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-conversations", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]}
|
||||
|
||||
|
||||
def test_daily_token_cost_with_multiple_currencies(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyTokenCostStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
rows = [
|
||||
SimpleNamespace(date="2024-01-01", token_count=100, total_price=Decimal("0.50"), currency="USD"),
|
||||
SimpleNamespace(date="2024-01-02", token_count=200, total_price=Decimal("1.00"), currency="USD"),
|
||||
]
|
||||
_install_common(monkeypatch)
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/token-costs", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
data = response.get_json()
|
||||
assert len(data["data"]) == 2
|
||||
163
api/tests/unit_tests/controllers/console/app/test_workflow.py
Normal file
163
api/tests/unit_tests/controllers/console/app/test_workflow.py
Normal file
@@ -0,0 +1,163 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import HTTPException, NotFound
|
||||
|
||||
from controllers.console.app import workflow as workflow_module
|
||||
from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
from dify_graph.file.enums import FileTransferMethod, FileType
|
||||
from dify_graph.file.models import File
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
def test_parse_file_no_config(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(workflow_module.FileUploadConfigManager, "convert", lambda *_args, **_kwargs: None)
|
||||
workflow = SimpleNamespace(features_dict={}, tenant_id="t1")
|
||||
|
||||
assert workflow_module._parse_file(workflow, files=[{"id": "f"}]) == []
|
||||
|
||||
|
||||
def test_parse_file_with_config(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
config = object()
|
||||
file_list = [
|
||||
File(
|
||||
tenant_id="t1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="http://u",
|
||||
)
|
||||
]
|
||||
build_mock = Mock(return_value=file_list)
|
||||
monkeypatch.setattr(workflow_module.FileUploadConfigManager, "convert", lambda *_args, **_kwargs: config)
|
||||
monkeypatch.setattr(workflow_module.file_factory, "build_from_mappings", build_mock)
|
||||
|
||||
workflow = SimpleNamespace(features_dict={}, tenant_id="t1")
|
||||
result = workflow_module._parse_file(workflow, files=[{"id": "f"}])
|
||||
|
||||
assert result == file_list
|
||||
build_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_sync_draft_workflow_invalid_content_type(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = workflow_module.DraftWorkflowApi()
|
||||
handler = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
|
||||
|
||||
with app.test_request_context("/apps/app/workflows/draft", method="POST", data="x", content_type="text/html"):
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
handler(api, app_model=SimpleNamespace(id="app"))
|
||||
|
||||
assert exc.value.code == 415
|
||||
|
||||
|
||||
def test_sync_draft_workflow_invalid_json(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = workflow_module.DraftWorkflowApi()
|
||||
handler = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
|
||||
|
||||
with app.test_request_context(
|
||||
"/apps/app/workflows/draft",
|
||||
method="POST",
|
||||
data="[]",
|
||||
content_type="application/json",
|
||||
):
|
||||
response, status = handler(api, app_model=SimpleNamespace(id="app"))
|
||||
|
||||
assert status == 400
|
||||
assert response["message"] == "Invalid JSON data"
|
||||
|
||||
|
||||
def test_sync_draft_workflow_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow = SimpleNamespace(
|
||||
unique_hash="h",
|
||||
updated_at=None,
|
||||
created_at=datetime(2024, 1, 1),
|
||||
)
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
|
||||
monkeypatch.setattr(
|
||||
workflow_module.variable_factory, "build_environment_variable_from_mapping", lambda *_args: "env"
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_module.variable_factory, "build_conversation_variable_from_mapping", lambda *_args: "conv"
|
||||
)
|
||||
|
||||
service = SimpleNamespace(sync_draft_workflow=lambda **_kwargs: workflow)
|
||||
monkeypatch.setattr(workflow_module, "WorkflowService", lambda: service)
|
||||
|
||||
api = workflow_module.DraftWorkflowApi()
|
||||
handler = _unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/apps/app/workflows/draft",
|
||||
method="POST",
|
||||
json={"graph": {}, "features": {}, "hash": "h"},
|
||||
):
|
||||
response = handler(api, app_model=SimpleNamespace(id="app"))
|
||||
|
||||
assert response["result"] == "success"
|
||||
|
||||
|
||||
def test_sync_draft_workflow_hash_mismatch(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
|
||||
|
||||
def _raise(*_args, **_kwargs):
|
||||
raise workflow_module.WorkflowHashNotEqualError()
|
||||
|
||||
service = SimpleNamespace(sync_draft_workflow=_raise)
|
||||
monkeypatch.setattr(workflow_module, "WorkflowService", lambda: service)
|
||||
|
||||
api = workflow_module.DraftWorkflowApi()
|
||||
handler = _unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/apps/app/workflows/draft",
|
||||
method="POST",
|
||||
json={"graph": {}, "features": {}, "hash": "h"},
|
||||
):
|
||||
with pytest.raises(DraftWorkflowNotSync):
|
||||
handler(api, app_model=SimpleNamespace(id="app"))
|
||||
|
||||
|
||||
def test_draft_workflow_get_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
workflow_module, "WorkflowService", lambda: SimpleNamespace(get_draft_workflow=lambda **_k: None)
|
||||
)
|
||||
|
||||
api = workflow_module.DraftWorkflowApi()
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
with pytest.raises(DraftWorkflowNotExist):
|
||||
handler(api, app_model=SimpleNamespace(id="app"))
|
||||
|
||||
|
||||
def test_advanced_chat_run_conversation_not_exists(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
workflow_module.AppGenerateService,
|
||||
"generate",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(
|
||||
workflow_module.services.errors.conversation.ConversationNotExistsError()
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
|
||||
|
||||
api = workflow_module.AdvancedChatDraftWorkflowRunApi()
|
||||
handler = _unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/apps/app/advanced-chat/workflows/draft/run",
|
||||
method="POST",
|
||||
json={"inputs": {}},
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=SimpleNamespace(id="app"))
|
||||
47
api/tests/unit_tests/controllers/console/app/test_wraps.py
Normal file
47
api/tests/unit_tests/controllers/console/app/test_wraps.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.app import wraps as wraps_module
|
||||
from controllers.console.app.error import AppNotFoundError
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
|
||||
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
|
||||
@wraps_module.get_app_model
|
||||
def handler(app_model):
|
||||
return app_model.id
|
||||
|
||||
assert handler(app_id="app-1") == "app-1"
|
||||
|
||||
|
||||
def test_get_app_model_rejects_wrong_mode(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
|
||||
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
|
||||
@wraps_module.get_app_model(mode=[AppMode.COMPLETION])
|
||||
def handler(app_model):
|
||||
return app_model.id
|
||||
|
||||
with pytest.raises(AppNotFoundError):
|
||||
handler(app_id="app-1")
|
||||
|
||||
|
||||
def test_get_app_model_requires_app_id() -> None:
|
||||
@wraps_module.get_app_model
|
||||
def handler(app_model):
|
||||
return app_model.id
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
handler()
|
||||
@@ -0,0 +1,817 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.rag_pipeline.datasource_auth import (
|
||||
DatasourceAuth,
|
||||
DatasourceAuthDefaultApi,
|
||||
DatasourceAuthDeleteApi,
|
||||
DatasourceAuthListApi,
|
||||
DatasourceAuthOauthCustomClient,
|
||||
DatasourceAuthUpdateApi,
|
||||
DatasourceHardCodeAuthListApi,
|
||||
DatasourceOAuthCallback,
|
||||
DatasourcePluginOAuthAuthorizationUrl,
|
||||
DatasourceUpdateProviderNameApi,
|
||||
)
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestDatasourcePluginOAuthAuthorizationUrl:
|
||||
def test_get_success(self, app):
|
||||
api = DatasourcePluginOAuthAuthorizationUrl()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(id="user-1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/?credential_id=cred-1"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value={"client_id": "abc"},
|
||||
),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"create_proxy_context",
|
||||
return_value="ctx-1",
|
||||
),
|
||||
patch.object(
|
||||
OAuthHandler,
|
||||
"get_authorization_url",
|
||||
return_value={"url": "http://auth"},
|
||||
),
|
||||
):
|
||||
response = method(api, "notion")
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_get_no_oauth_config(self, app):
|
||||
api = DatasourcePluginOAuthAuthorizationUrl()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
|
||||
def test_get_without_credential_id_sets_cookie(self, app):
|
||||
api = DatasourcePluginOAuthAuthorizationUrl()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(id="user-1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value={"client_id": "abc"},
|
||||
),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"create_proxy_context",
|
||||
return_value="ctx-123",
|
||||
),
|
||||
patch.object(
|
||||
OAuthHandler,
|
||||
"get_authorization_url",
|
||||
return_value={"url": "http://auth"},
|
||||
),
|
||||
):
|
||||
response = method(api, "notion")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert "context_id" in response.headers.get("Set-Cookie")
|
||||
|
||||
|
||||
class TestDatasourceOAuthCallback:
|
||||
def test_callback_success_new_credential(self, app):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
oauth_response = MagicMock()
|
||||
oauth_response.credentials = {"token": "abc"}
|
||||
oauth_response.expires_at = None
|
||||
oauth_response.metadata = {"name": "test"}
|
||||
|
||||
context = {
|
||||
"user_id": "user-1",
|
||||
"tenant_id": "tenant-1",
|
||||
"credential_id": None,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/?context_id=ctx"),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"use_proxy_context",
|
||||
return_value=context,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value={"client_id": "abc"},
|
||||
),
|
||||
patch.object(
|
||||
OAuthHandler,
|
||||
"get_credentials",
|
||||
return_value=oauth_response,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"add_datasource_oauth_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response = method(api, "notion")
|
||||
|
||||
assert response.status_code == 302
|
||||
|
||||
def test_callback_missing_context(self, app):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "notion")
|
||||
|
||||
def test_callback_invalid_context(self, app):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?context_id=bad"),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"use_proxy_context",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "notion")
|
||||
|
||||
def test_callback_oauth_config_not_found(self, app):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
context = {"user_id": "u", "tenant_id": "t"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/?context_id=ctx"),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"use_proxy_context",
|
||||
return_value=context,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "notion")
|
||||
|
||||
def test_callback_reauthorize_existing_credential(self, app):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
oauth_response = MagicMock()
|
||||
oauth_response.credentials = {"token": "abc"}
|
||||
oauth_response.expires_at = None
|
||||
oauth_response.metadata = {} # avatar + name missing
|
||||
|
||||
context = {
|
||||
"user_id": "user-1",
|
||||
"tenant_id": "tenant-1",
|
||||
"credential_id": "cred-1",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/?context_id=ctx"),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"use_proxy_context",
|
||||
return_value=context,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value={"client_id": "abc"},
|
||||
),
|
||||
patch.object(
|
||||
OAuthHandler,
|
||||
"get_credentials",
|
||||
return_value=oauth_response,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"reauthorize_datasource_oauth_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response = method(api, "notion")
|
||||
|
||||
assert response.status_code == 302
|
||||
assert "/oauth-callback" in response.location
|
||||
|
||||
def test_callback_context_id_from_cookie(self, app):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
oauth_response = MagicMock()
|
||||
oauth_response.credentials = {"token": "abc"}
|
||||
oauth_response.expires_at = None
|
||||
oauth_response.metadata = {}
|
||||
|
||||
context = {
|
||||
"user_id": "user-1",
|
||||
"tenant_id": "tenant-1",
|
||||
"credential_id": None,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", headers={"Cookie": "context_id=ctx"}),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"use_proxy_context",
|
||||
return_value=context,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value={"client_id": "abc"},
|
||||
),
|
||||
patch.object(
|
||||
OAuthHandler,
|
||||
"get_credentials",
|
||||
return_value=oauth_response,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"add_datasource_oauth_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response = method(api, "notion")
|
||||
|
||||
assert response.status_code == 302
|
||||
|
||||
|
||||
class TestDatasourceAuth:
|
||||
def test_post_success(self, app):
|
||||
api = DatasourceAuth()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credentials": {"key": "val"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"add_datasource_api_key_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_post_invalid_credentials(self, app):
|
||||
api = DatasourceAuth()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credentials": {"key": "bad"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"add_datasource_api_key_provider",
|
||||
side_effect=CredentialsValidateFailedError("invalid"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
|
||||
def test_get_success(self, app):
|
||||
api = DatasourceAuth()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"list_datasource_credentials",
|
||||
return_value=[{"id": "1"}],
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
assert response["result"]
|
||||
|
||||
def test_post_missing_credentials(self, app):
|
||||
api = DatasourceAuth()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
|
||||
def test_get_empty_list(self, app):
|
||||
api = DatasourceAuth()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"list_datasource_credentials",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
assert response["result"] == []
|
||||
|
||||
|
||||
class TestDatasourceAuthDeleteApi:
|
||||
def test_delete_success(self, app):
|
||||
api = DatasourceAuthDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "cred-1"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"remove_datasource_credentials",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_delete_missing_credential_id(self, app):
|
||||
api = DatasourceAuthDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
|
||||
|
||||
class TestDatasourceAuthUpdateApi:
|
||||
def test_update_success(self, app):
|
||||
api = DatasourceAuthUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "id", "credentials": {"k": "v"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"update_datasource_credentials",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 201
|
||||
|
||||
def test_update_with_credentials_none(self, app):
|
||||
api = DatasourceAuthUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "id", "credentials": None}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"update_datasource_credentials",
|
||||
return_value=None,
|
||||
) as update_mock,
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
update_mock.assert_called_once()
|
||||
assert status == 201
|
||||
|
||||
def test_update_name_only(self, app):
|
||||
api = DatasourceAuthUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "id", "name": "New Name"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"update_datasource_credentials",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
_, status = method(api, "notion")
|
||||
|
||||
assert status == 201
|
||||
|
||||
def test_update_with_empty_credentials_dict(self, app):
|
||||
api = DatasourceAuthUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "id", "credentials": {}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"update_datasource_credentials",
|
||||
return_value=None,
|
||||
) as update_mock,
|
||||
):
|
||||
_, status = method(api, "notion")
|
||||
|
||||
update_mock.assert_called_once()
|
||||
assert status == 201
|
||||
|
||||
|
||||
class TestDatasourceAuthListApi:
|
||||
def test_list_success(self, app):
|
||||
api = DatasourceAuthListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_all_datasource_credentials",
|
||||
return_value=[{"id": "1"}],
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_auth_list_empty(self, app):
|
||||
api = DatasourceAuthListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_all_datasource_credentials",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert response["result"] == []
|
||||
|
||||
def test_hardcode_list_empty(self, app):
|
||||
api = DatasourceHardCodeAuthListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_hard_code_datasource_credentials",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert response["result"] == []
|
||||
|
||||
|
||||
class TestDatasourceHardCodeAuthListApi:
|
||||
def test_list_success(self, app):
|
||||
api = DatasourceHardCodeAuthListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_hard_code_datasource_credentials",
|
||||
return_value=[{"id": "1"}],
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
|
||||
|
||||
class TestDatasourceAuthOauthCustomClient:
|
||||
def test_post_success(self, app):
|
||||
api = DatasourceAuthOauthCustomClient()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"client_params": {}, "enable_oauth_custom_client": True}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"setup_oauth_custom_client_params",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_delete_success(self, app):
|
||||
api = DatasourceAuthOauthCustomClient()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"remove_oauth_custom_client_params",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_post_empty_payload(self, app):
|
||||
api = DatasourceAuthOauthCustomClient()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"setup_oauth_custom_client_params",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
_, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_post_disabled_flag(self, app):
|
||||
api = DatasourceAuthOauthCustomClient()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"client_params": {"a": 1},
|
||||
"enable_oauth_custom_client": False,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"setup_oauth_custom_client_params",
|
||||
return_value=None,
|
||||
) as setup_mock,
|
||||
):
|
||||
_, status = method(api, "notion")
|
||||
|
||||
setup_mock.assert_called_once()
|
||||
assert status == 200
|
||||
|
||||
|
||||
class TestDatasourceAuthDefaultApi:
|
||||
def test_set_default_success(self, app):
|
||||
api = DatasourceAuthDefaultApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"id": "cred-1"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"set_default_datasource_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_default_missing_id(self, app):
|
||||
api = DatasourceAuthDefaultApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
|
||||
|
||||
class TestDatasourceUpdateProviderNameApi:
|
||||
def test_update_name_success(self, app):
|
||||
api = DatasourceUpdateProviderNameApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "id", "name": "New Name"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"update_datasource_provider_name",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_update_name_too_long(self, app):
|
||||
api = DatasourceUpdateProviderNameApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"credential_id": "id",
|
||||
"name": "x" * 101,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
|
||||
def test_update_name_missing_credential_id(self, app):
|
||||
api = DatasourceUpdateProviderNameApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"name": "Valid"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
@@ -0,0 +1,143 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.rag_pipeline.datasource_content_preview import (
|
||||
DataSourceContentPreviewApi,
|
||||
)
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestDataSourceContentPreviewApi:
|
||||
def _valid_payload(self):
|
||||
return {
|
||||
"inputs": {"query": "hello"},
|
||||
"datasource_type": "notion",
|
||||
"credential_id": "cred-1",
|
||||
}
|
||||
|
||||
def test_post_success(self, app):
|
||||
api = DataSourceContentPreviewApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._valid_payload()
|
||||
|
||||
pipeline = MagicMock(spec=Pipeline)
|
||||
node_id = "node-1"
|
||||
account = MagicMock(spec=Account)
|
||||
|
||||
preview_result = {"content": "preview data"}
|
||||
|
||||
service_instance = MagicMock()
|
||||
service_instance.run_datasource_node_preview.return_value = preview_result
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
|
||||
account,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.RagPipelineService",
|
||||
return_value=service_instance,
|
||||
),
|
||||
):
|
||||
response, status = method(api, pipeline, node_id)
|
||||
|
||||
service_instance.run_datasource_node_preview.assert_called_once_with(
|
||||
pipeline=pipeline,
|
||||
node_id=node_id,
|
||||
user_inputs=payload["inputs"],
|
||||
account=account,
|
||||
datasource_type=payload["datasource_type"],
|
||||
is_published=True,
|
||||
credential_id=payload["credential_id"],
|
||||
)
|
||||
assert status == 200
|
||||
assert response == preview_result
|
||||
|
||||
def test_post_forbidden_non_account_user(self, app):
|
||||
api = DataSourceContentPreviewApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._valid_payload()
|
||||
|
||||
pipeline = MagicMock(spec=Pipeline)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
|
||||
MagicMock(), # NOT Account
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, pipeline, "node-1")
|
||||
|
||||
def test_post_invalid_payload(self, app):
|
||||
api = DataSourceContentPreviewApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"inputs": {"query": "hello"},
|
||||
# datasource_type missing
|
||||
}
|
||||
|
||||
pipeline = MagicMock(spec=Pipeline)
|
||||
account = MagicMock(spec=Account)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
|
||||
account,
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, pipeline, "node-1")
|
||||
|
||||
def test_post_without_credential_id(self, app):
|
||||
api = DataSourceContentPreviewApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"inputs": {"query": "hello"},
|
||||
"datasource_type": "notion",
|
||||
"credential_id": None,
|
||||
}
|
||||
|
||||
pipeline = MagicMock(spec=Pipeline)
|
||||
account = MagicMock(spec=Account)
|
||||
|
||||
service_instance = MagicMock()
|
||||
service_instance.run_datasource_node_preview.return_value = {"ok": True}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
|
||||
account,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.RagPipelineService",
|
||||
return_value=service_instance,
|
||||
),
|
||||
):
|
||||
response, status = method(api, pipeline, "node-1")
|
||||
|
||||
service_instance.run_datasource_node_preview.assert_called_once()
|
||||
assert status == 200
|
||||
assert response == {"ok": True}
|
||||
@@ -0,0 +1,187 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline import (
|
||||
CustomizedPipelineTemplateApi,
|
||||
PipelineTemplateDetailApi,
|
||||
PipelineTemplateListApi,
|
||||
PublishCustomizedPipelineTemplateApi,
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestPipelineTemplateListApi:
|
||||
def test_get_success(self, app):
|
||||
api = PipelineTemplateListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
templates = [{"id": "t1"}]
|
||||
|
||||
with (
|
||||
app.test_request_context("/?type=built-in&language=en-US"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.get_pipeline_templates",
|
||||
return_value=templates,
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert response == templates
|
||||
|
||||
|
||||
class TestPipelineTemplateDetailApi:
|
||||
def test_get_success(self, app):
|
||||
api = PipelineTemplateDetailApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
template = {"id": "tpl-1"}
|
||||
|
||||
service = MagicMock()
|
||||
service.get_pipeline_template_detail.return_value = template
|
||||
|
||||
with (
|
||||
app.test_request_context("/?type=built-in"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "tpl-1")
|
||||
|
||||
assert status == 200
|
||||
assert response == template
|
||||
|
||||
|
||||
class TestCustomizedPipelineTemplateApi:
|
||||
def test_patch_success(self, app):
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
payload = {
|
||||
"name": "Template",
|
||||
"description": "Desc",
|
||||
"icon_info": {"icon": "📘"},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.update_customized_pipeline_template"
|
||||
) as update_mock,
|
||||
):
|
||||
response = method(api, "tpl-1")
|
||||
|
||||
update_mock.assert_called_once()
|
||||
assert response == 200
|
||||
|
||||
def test_delete_success(self, app):
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.delete_customized_pipeline_template"
|
||||
) as delete_mock,
|
||||
):
|
||||
response = method(api, "tpl-1")
|
||||
|
||||
delete_mock.assert_called_once_with("tpl-1")
|
||||
assert response == 200
|
||||
|
||||
def test_post_success(self, app):
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
template = MagicMock()
|
||||
template.yaml_content = "yaml-data"
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = template
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "tpl-1")
|
||||
|
||||
assert status == 200
|
||||
assert response == {"data": "yaml-data"}
|
||||
|
||||
def test_post_template_not_found(self, app):
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "tpl-1")
|
||||
|
||||
|
||||
class TestPublishCustomizedPipelineTemplateApi:
|
||||
def test_post_success(self, app):
|
||||
api = PublishCustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"name": "Template",
|
||||
"description": "Desc",
|
||||
"icon_info": {"icon": "📘"},
|
||||
}
|
||||
|
||||
service = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response = method(api, "pipeline-1")
|
||||
|
||||
service.publish_customized_pipeline_template.assert_called_once()
|
||||
assert response == {"result": "success"}
|
||||
@@ -0,0 +1,187 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline_datasets import (
|
||||
CreateEmptyRagPipelineDatasetApi,
|
||||
CreateRagPipelineDatasetApi,
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestCreateRagPipelineDatasetApi:
|
||||
def _valid_payload(self):
|
||||
return {"yaml_content": "name: test"}
|
||||
|
||||
def test_post_success(self, app):
|
||||
api = CreateRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._valid_payload()
|
||||
user = MagicMock(is_dataset_editor=True)
|
||||
import_info = {"dataset_id": "ds-1"}
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service.create_rag_pipeline_dataset.return_value = import_info
|
||||
|
||||
mock_session_ctx = MagicMock()
|
||||
mock_session_ctx.__enter__.return_value = MagicMock()
|
||||
mock_session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session",
|
||||
return_value=mock_session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService",
|
||||
return_value=mock_service,
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 201
|
||||
assert response == import_info
|
||||
|
||||
def test_post_forbidden_non_editor(self, app):
|
||||
api = CreateRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._valid_payload()
|
||||
user = MagicMock(is_dataset_editor=False)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
|
||||
def test_post_dataset_name_duplicate(self, app):
|
||||
api = CreateRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._valid_payload()
|
||||
user = MagicMock(is_dataset_editor=True)
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service.create_rag_pipeline_dataset.side_effect = services.errors.dataset.DatasetNameDuplicateError()
|
||||
|
||||
mock_session_ctx = MagicMock()
|
||||
mock_session_ctx.__enter__.return_value = MagicMock()
|
||||
mock_session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session",
|
||||
return_value=mock_session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService",
|
||||
return_value=mock_service,
|
||||
),
|
||||
):
|
||||
with pytest.raises(DatasetNameDuplicateError):
|
||||
method(api)
|
||||
|
||||
def test_post_invalid_payload(self, app):
|
||||
api = CreateRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {}
|
||||
user = MagicMock(is_dataset_editor=True)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestCreateEmptyRagPipelineDatasetApi:
|
||||
def test_post_success(self, app):
|
||||
api = CreateEmptyRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock(is_dataset_editor=True)
|
||||
dataset = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.DatasetService.create_empty_rag_pipeline_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.marshal",
|
||||
return_value={"id": "ds-1"},
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 201
|
||||
assert response == {"id": "ds-1"}
|
||||
|
||||
def test_post_forbidden_non_editor(self, app):
|
||||
api = CreateEmptyRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock(is_dataset_editor=False)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
@@ -0,0 +1,324 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Response
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import DraftWorkflowNotExist
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable import (
|
||||
RagPipelineEnvironmentVariableCollectionApi,
|
||||
RagPipelineNodeVariableCollectionApi,
|
||||
RagPipelineSystemVariableCollectionApi,
|
||||
RagPipelineVariableApi,
|
||||
RagPipelineVariableCollectionApi,
|
||||
RagPipelineVariableResetApi,
|
||||
)
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from dify_graph.variables.types import SegmentType
|
||||
from models.account import Account
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_db():
|
||||
db = MagicMock()
|
||||
db.engine = MagicMock()
|
||||
db.session.return_value = MagicMock()
|
||||
return db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def editor_user():
|
||||
user = MagicMock(spec=Account)
|
||||
user.has_edit_permission = True
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def restx_config(app):
|
||||
return patch.dict(app.config, {"RESTX_MASK_HEADER": "X-Fields"})
|
||||
|
||||
|
||||
class TestRagPipelineVariableCollectionApi:
|
||||
def test_get_variables_success(self, app, fake_db, editor_user, restx_config):
|
||||
api = RagPipelineVariableCollectionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock(id="p1")
|
||||
|
||||
rag_srv = MagicMock()
|
||||
rag_srv.is_workflow_exist.return_value = True
|
||||
|
||||
# IMPORTANT: RESTX expects .variables
|
||||
var_list = MagicMock()
|
||||
var_list.variables = []
|
||||
|
||||
draft_srv = MagicMock()
|
||||
draft_srv.list_variables_without_values.return_value = var_list
|
||||
|
||||
with (
|
||||
app.test_request_context("/?page=1&limit=10"),
|
||||
restx_config,
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
|
||||
return_value=rag_srv,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=draft_srv,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
|
||||
assert result["items"] == []
|
||||
|
||||
def test_get_variables_workflow_not_exist(self, app, fake_db, editor_user):
|
||||
api = RagPipelineVariableCollectionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
|
||||
rag_srv = MagicMock()
|
||||
rag_srv.is_workflow_exist.return_value = False
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
|
||||
return_value=rag_srv,
|
||||
),
|
||||
):
|
||||
with pytest.raises(DraftWorkflowNotExist):
|
||||
method(api, pipeline)
|
||||
|
||||
def test_delete_variables_success(self, app, fake_db, editor_user):
|
||||
api = RagPipelineVariableCollectionApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
pipeline = MagicMock(id="p1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService"),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert result.status_code == 204
|
||||
|
||||
|
||||
class TestRagPipelineNodeVariableCollectionApi:
|
||||
def test_get_node_variables_success(self, app, fake_db, editor_user, restx_config):
|
||||
api = RagPipelineNodeVariableCollectionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock(id="p1")
|
||||
|
||||
var_list = MagicMock()
|
||||
var_list.variables = []
|
||||
|
||||
srv = MagicMock()
|
||||
srv.list_node_variables.return_value = var_list
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
restx_config,
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=srv,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "node1")
|
||||
|
||||
assert result["items"] == []
|
||||
|
||||
def test_get_node_variables_invalid_node(self, app, editor_user):
|
||||
api = RagPipelineNodeVariableCollectionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
):
|
||||
with pytest.raises(InvalidArgumentError):
|
||||
method(api, MagicMock(), SYSTEM_VARIABLE_NODE_ID)
|
||||
|
||||
|
||||
class TestRagPipelineVariableApi:
|
||||
def test_get_variable_not_found(self, app, fake_db, editor_user):
|
||||
api = RagPipelineVariableApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
srv = MagicMock()
|
||||
srv.get_variable.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=srv,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFoundError):
|
||||
method(api, MagicMock(), "v1")
|
||||
|
||||
def test_patch_variable_invalid_file_payload(self, app, fake_db, editor_user):
|
||||
api = RagPipelineVariableApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
pipeline = MagicMock(id="p1", tenant_id="t1")
|
||||
variable = MagicMock(app_id="p1", value_type=SegmentType.FILE)
|
||||
|
||||
srv = MagicMock()
|
||||
srv.get_variable.return_value = variable
|
||||
|
||||
payload = {"value": "invalid"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=srv,
|
||||
),
|
||||
):
|
||||
with pytest.raises(InvalidArgumentError):
|
||||
method(api, pipeline, "v1")
|
||||
|
||||
def test_delete_variable_success(self, app, fake_db, editor_user):
|
||||
api = RagPipelineVariableApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
pipeline = MagicMock(id="p1")
|
||||
variable = MagicMock(app_id="p1")
|
||||
|
||||
srv = MagicMock()
|
||||
srv.get_variable.return_value = variable
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=srv,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "v1")
|
||||
|
||||
assert result.status_code == 204
|
||||
|
||||
|
||||
class TestRagPipelineVariableResetApi:
|
||||
def test_reset_variable_success(self, app, fake_db, editor_user):
|
||||
api = RagPipelineVariableResetApi()
|
||||
method = unwrap(api.put)
|
||||
|
||||
pipeline = MagicMock(id="p1")
|
||||
workflow = MagicMock()
|
||||
variable = MagicMock(app_id="p1")
|
||||
|
||||
srv = MagicMock()
|
||||
srv.get_variable.return_value = variable
|
||||
srv.reset_variable.return_value = variable
|
||||
|
||||
rag_srv = MagicMock()
|
||||
rag_srv.get_draft_workflow.return_value = workflow
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
|
||||
return_value=rag_srv,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=srv,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.marshal",
|
||||
return_value={"id": "v1"},
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "v1")
|
||||
|
||||
assert result == {"id": "v1"}
|
||||
|
||||
|
||||
class TestSystemAndEnvironmentVariablesApi:
|
||||
def test_system_variables_success(self, app, fake_db, editor_user, restx_config):
|
||||
api = RagPipelineSystemVariableCollectionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock(id="p1")
|
||||
|
||||
var_list = MagicMock()
|
||||
var_list.variables = []
|
||||
|
||||
srv = MagicMock()
|
||||
srv.list_system_variables.return_value = var_list
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
restx_config,
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=srv,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
|
||||
assert result["items"] == []
|
||||
|
||||
def test_environment_variables_success(self, app, editor_user):
|
||||
api = RagPipelineEnvironmentVariableCollectionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
env_var = MagicMock(
|
||||
id="e1",
|
||||
name="ENV",
|
||||
description="d",
|
||||
selector="s",
|
||||
value_type=MagicMock(value="string"),
|
||||
value="x",
|
||||
)
|
||||
|
||||
workflow = MagicMock(environment_variables=[env_var])
|
||||
pipeline = MagicMock(id="p1")
|
||||
|
||||
rag_srv = MagicMock()
|
||||
rag_srv.get_draft_workflow.return_value = workflow
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
|
||||
return_value=rag_srv,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
|
||||
assert len(result["items"]) == 1
|
||||
@@ -0,0 +1,329 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline_import import (
|
||||
RagPipelineExportApi,
|
||||
RagPipelineImportApi,
|
||||
RagPipelineImportCheckDependenciesApi,
|
||||
RagPipelineImportConfirmApi,
|
||||
)
|
||||
from models.dataset import Pipeline
|
||||
from services.app_dsl_service import ImportStatus
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestRagPipelineImportApi:
|
||||
def _payload(self, mode="create"):
|
||||
return {
|
||||
"mode": mode,
|
||||
"yaml_content": "content",
|
||||
"name": "Test",
|
||||
}
|
||||
|
||||
def test_post_success_200(self, app):
|
||||
api = RagPipelineImportApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._payload()
|
||||
|
||||
user = MagicMock()
|
||||
result = MagicMock()
|
||||
result.status = "completed"
|
||||
result.model_dump.return_value = {"status": "success"}
|
||||
|
||||
service = MagicMock()
|
||||
service.import_rag_pipeline.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert response == {"status": "success"}
|
||||
|
||||
def test_post_failed_400(self, app):
|
||||
api = RagPipelineImportApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._payload()
|
||||
|
||||
user = MagicMock()
|
||||
result = MagicMock()
|
||||
result.status = ImportStatus.FAILED
|
||||
result.model_dump.return_value = {"status": "failed"}
|
||||
|
||||
service = MagicMock()
|
||||
service.import_rag_pipeline.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 400
|
||||
assert response == {"status": "failed"}
|
||||
|
||||
def test_post_pending_202(self, app):
|
||||
api = RagPipelineImportApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._payload()
|
||||
|
||||
user = MagicMock()
|
||||
result = MagicMock()
|
||||
result.status = ImportStatus.PENDING
|
||||
result.model_dump.return_value = {"status": "pending"}
|
||||
|
||||
service = MagicMock()
|
||||
service.import_rag_pipeline.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 202
|
||||
assert response == {"status": "pending"}
|
||||
|
||||
|
||||
class TestRagPipelineImportConfirmApi:
|
||||
def test_confirm_success(self, app):
|
||||
api = RagPipelineImportConfirmApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock()
|
||||
result = MagicMock()
|
||||
result.status = "completed"
|
||||
result.model_dump.return_value = {"ok": True}
|
||||
|
||||
service = MagicMock()
|
||||
service.confirm_import.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "import-1")
|
||||
|
||||
assert status == 200
|
||||
assert response == {"ok": True}
|
||||
|
||||
def test_confirm_failed(self, app):
|
||||
api = RagPipelineImportConfirmApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock()
|
||||
result = MagicMock()
|
||||
result.status = ImportStatus.FAILED
|
||||
result.model_dump.return_value = {"ok": False}
|
||||
|
||||
service = MagicMock()
|
||||
service.confirm_import.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "import-1")
|
||||
|
||||
assert status == 400
|
||||
assert response == {"ok": False}
|
||||
|
||||
|
||||
class TestRagPipelineImportCheckDependenciesApi:
|
||||
def test_get_success(self, app):
|
||||
api = RagPipelineImportCheckDependenciesApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock(spec=Pipeline)
|
||||
result = MagicMock()
|
||||
result.model_dump.return_value = {"deps": []}
|
||||
|
||||
service = MagicMock()
|
||||
service.check_dependencies.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api, pipeline)
|
||||
|
||||
assert status == 200
|
||||
assert response == {"deps": []}
|
||||
|
||||
|
||||
class TestRagPipelineExportApi:
|
||||
def test_get_with_include_secret(self, app):
|
||||
api = RagPipelineExportApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock(spec=Pipeline)
|
||||
service = MagicMock()
|
||||
service.export_rag_pipeline_dsl.return_value = {"yaml": "data"}
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/?include_secret=true"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api, pipeline)
|
||||
|
||||
assert status == 200
|
||||
assert response == {"data": {"yaml": "data"}}
|
||||
@@ -0,0 +1,688 @@
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline_workflow import (
|
||||
DefaultRagPipelineBlockConfigApi,
|
||||
DraftRagPipelineApi,
|
||||
DraftRagPipelineRunApi,
|
||||
PublishedAllRagPipelineApi,
|
||||
PublishedRagPipelineApi,
|
||||
PublishedRagPipelineRunApi,
|
||||
RagPipelineByIdApi,
|
||||
RagPipelineDatasourceVariableApi,
|
||||
RagPipelineDraftNodeRunApi,
|
||||
RagPipelineDraftRunIterationNodeApi,
|
||||
RagPipelineDraftRunLoopNodeApi,
|
||||
RagPipelineRecommendedPluginApi,
|
||||
RagPipelineTaskStopApi,
|
||||
RagPipelineTransformApi,
|
||||
RagPipelineWorkflowLastRunApi,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestDraftWorkflowApi:
|
||||
def test_get_draft_success(self, app):
|
||||
api = DraftRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
workflow = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.get_draft_workflow.return_value = workflow
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
assert result == workflow
|
||||
|
||||
def test_get_draft_not_exist(self, app):
|
||||
api = DraftRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
service = MagicMock()
|
||||
service.get_draft_workflow.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
with pytest.raises(DraftWorkflowNotExist):
|
||||
method(api, pipeline)
|
||||
|
||||
def test_sync_hash_not_match(self, app):
|
||||
api = DraftRagPipelineApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.sync_draft_workflow.side_effect = WorkflowHashNotEqualError()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"graph": {}, "features": {}}),
|
||||
patch.object(type(console_ns), "payload", {"graph": {}, "features": {}}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
with pytest.raises(DraftWorkflowNotSync):
|
||||
method(api, pipeline)
|
||||
|
||||
def test_sync_invalid_text_plain(self, app):
|
||||
api = DraftRagPipelineApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", data="bad-json", headers={"Content-Type": "text/plain"}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
):
|
||||
response, status = method(api, pipeline)
|
||||
assert status == 400
|
||||
|
||||
|
||||
class TestDraftRunNodes:
|
||||
def test_iteration_node_success(self, app):
|
||||
api = RagPipelineDraftRunIterationNodeApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"inputs": {}}),
|
||||
patch.object(type(console_ns), "payload", {"inputs": {}}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_iteration",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "node")
|
||||
assert result == {"ok": True}
|
||||
|
||||
def test_iteration_node_conversation_not_exists(self, app):
|
||||
api = RagPipelineDraftRunIterationNodeApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"inputs": {}}),
|
||||
patch.object(type(console_ns), "payload", {"inputs": {}}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_iteration",
|
||||
side_effect=services.errors.conversation.ConversationNotExistsError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, pipeline, "node")
|
||||
|
||||
def test_loop_node_success(self, app):
|
||||
api = RagPipelineDraftRunLoopNodeApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"inputs": {}}),
|
||||
patch.object(type(console_ns), "payload", {"inputs": {}}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_loop",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, pipeline, "node") == {"ok": True}
|
||||
|
||||
|
||||
class TestPipelineRunApis:
|
||||
def test_draft_run_success(self, app):
|
||||
api = DraftRagPipelineRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
payload = {
|
||||
"inputs": {},
|
||||
"datasource_type": "x",
|
||||
"datasource_info_list": [],
|
||||
"start_node_id": "n",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, pipeline) == {"ok": True}
|
||||
|
||||
def test_draft_run_rate_limit(self, app):
|
||||
api = DraftRagPipelineRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/", json={"inputs": {}, "datasource_type": "x", "datasource_info_list": [], "start_node_id": "n"}
|
||||
),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
{"inputs": {}, "datasource_type": "x", "datasource_info_list": [], "start_node_id": "n"},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
|
||||
side_effect=InvokeRateLimitError("limit"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(InvokeRateLimitHttpError):
|
||||
method(api, pipeline)
|
||||
|
||||
|
||||
class TestDraftNodeRun:
|
||||
def test_execution_not_found(self, app):
|
||||
api = RagPipelineDraftNodeRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.run_draft_workflow_node.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"inputs": {}}),
|
||||
patch.object(type(console_ns), "payload", {"inputs": {}}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, pipeline, "node")
|
||||
|
||||
|
||||
class TestPublishedPipelineApis:
|
||||
def test_publish_success(self, app):
|
||||
api = PublishedRagPipelineApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
workflow = MagicMock(
|
||||
id="w1",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
session = MagicMock()
|
||||
session.merge.return_value = pipeline
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
service = MagicMock()
|
||||
service.publish_workflow.return_value = workflow
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
|
||||
assert result["result"] == "success"
|
||||
assert "created_at" in result
|
||||
|
||||
|
||||
class TestMiscApis:
|
||||
def test_task_stop(self, app):
|
||||
api = RagPipelineTaskStopApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.AppQueueManager.set_stop_flag"
|
||||
) as stop_mock,
|
||||
):
|
||||
result = method(api, pipeline, "task-1")
|
||||
stop_mock.assert_called_once()
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_transform_forbidden(self, app):
|
||||
api = RagPipelineTransformApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock(has_edit_permission=False, is_dataset_operator=False)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "ds1")
|
||||
|
||||
def test_recommended_plugins(self, app):
|
||||
api = RagPipelineRecommendedPluginApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
service = MagicMock()
|
||||
service.get_recommended_plugins.return_value = [{"id": "p1"}]
|
||||
|
||||
with (
|
||||
app.test_request_context("/?type=all"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
assert result == [{"id": "p1"}]
|
||||
|
||||
|
||||
class TestPublishedRagPipelineRunApi:
|
||||
def test_published_run_success(self, app):
|
||||
api = PublishedRagPipelineRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
payload = {
|
||||
"inputs": {},
|
||||
"datasource_type": "x",
|
||||
"datasource_info_list": [],
|
||||
"start_node_id": "n",
|
||||
"response_mode": "blocking",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
assert result == {"ok": True}
|
||||
|
||||
def test_published_run_rate_limit(self, app):
|
||||
api = PublishedRagPipelineRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
payload = {
|
||||
"inputs": {},
|
||||
"datasource_type": "x",
|
||||
"datasource_info_list": [],
|
||||
"start_node_id": "n",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
|
||||
side_effect=InvokeRateLimitError("limit"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(InvokeRateLimitHttpError):
|
||||
method(api, pipeline)
|
||||
|
||||
|
||||
class TestDefaultBlockConfigApi:
|
||||
def test_get_block_config_success(self, app):
|
||||
api = DefaultRagPipelineBlockConfigApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.get_default_block_config.return_value = {"k": "v"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/?q={}"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "llm")
|
||||
assert result == {"k": "v"}
|
||||
|
||||
def test_get_block_config_invalid_json(self, app):
|
||||
api = DefaultRagPipelineBlockConfigApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
|
||||
with app.test_request_context("/?q=bad-json"):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, pipeline, "llm")
|
||||
|
||||
|
||||
class TestPublishedAllRagPipelineApi:
|
||||
def test_get_published_workflows_success(self, app):
|
||||
api = PublishedAllRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
service = MagicMock()
|
||||
service.get_all_published_workflow.return_value = ([{"id": "w1"}], False)
|
||||
|
||||
session = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
|
||||
assert result["items"] == [{"id": "w1"}]
|
||||
assert result["has_more"] is False
|
||||
|
||||
def test_get_published_workflows_forbidden(self, app):
|
||||
api = PublishedAllRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/?user_id=u2"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, pipeline)
|
||||
|
||||
|
||||
class TestRagPipelineByIdApi:
|
||||
def test_patch_success(self, app):
|
||||
api = RagPipelineByIdApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
pipeline = MagicMock(tenant_id="t1")
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
workflow = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.update_workflow.return_value = workflow
|
||||
|
||||
session = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
payload = {"marked_name": "test"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "w1")
|
||||
|
||||
assert result == workflow
|
||||
|
||||
def test_patch_no_fields(self, app):
|
||||
api = RagPipelineByIdApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
patch.object(type(console_ns), "payload", {}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
):
|
||||
result, status = method(api, pipeline, "w1")
|
||||
assert status == 400
|
||||
|
||||
|
||||
class TestRagPipelineWorkflowLastRunApi:
|
||||
def test_last_run_success(self, app):
|
||||
api = RagPipelineWorkflowLastRunApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
workflow = MagicMock()
|
||||
node_exec = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.get_draft_workflow.return_value = workflow
|
||||
service.get_node_last_run.return_value = node_exec
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "node1")
|
||||
assert result == node_exec
|
||||
|
||||
def test_last_run_not_found(self, app):
|
||||
api = RagPipelineWorkflowLastRunApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.get_draft_workflow.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, pipeline, "node1")
|
||||
|
||||
|
||||
class TestRagPipelineDatasourceVariableApi:
|
||||
def test_set_datasource_variables_success(self, app):
|
||||
api = RagPipelineDatasourceVariableApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
payload = {
|
||||
"datasource_type": "db",
|
||||
"datasource_info": {},
|
||||
"start_node_id": "n1",
|
||||
"start_node_title": "Node",
|
||||
}
|
||||
|
||||
service = MagicMock()
|
||||
service.set_datasource_variables.return_value = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
assert result is not None
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user