diff --git a/.gemini/config.yaml b/.gemini/config.yaml new file mode 100644 index 00000000000..15c697730ad --- /dev/null +++ b/.gemini/config.yaml @@ -0,0 +1,13 @@ +have_fun: false +memory_config: + disabled: false +code_review: + disable: true + comment_severity_threshold: MEDIUM + max_review_comments: -1 + pull_request_opened: + help: false + summary: false + code_review: false + include_drafts: false +ignore_patterns: [] diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 1bb7d062327..94e857f93a5 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -36,7 +36,6 @@ /api/core/workflow/graph/ @laipz8200 @QuantumGhost /api/core/workflow/graph_events/ @laipz8200 @QuantumGhost /api/core/workflow/node_events/ @laipz8200 @QuantumGhost -/api/dify_graph/model_runtime/ @laipz8200 @QuantumGhost # Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM) /api/core/workflow/nodes/agent/ @Nov1c444 diff --git a/.github/actions/setup-web/action.yml b/.github/actions/setup-web/action.yml index 6f3b3c08b47..673155bcf7c 100644 --- a/.github/actions/setup-web/action.yml +++ b/.github/actions/setup-web/action.yml @@ -4,10 +4,8 @@ runs: using: composite steps: - name: Setup Vite+ - uses: voidzero-dev/setup-vp@4a524139920f87f9f7080d3b8545acac019e1852 # v1.0.0 + uses: voidzero-dev/setup-vp@20553a7a7429c429a74894104a2835d7fed28a72 # v1.3.0 with: - node-version-file: web/.nvmrc + node-version-file: .nvmrc cache: true - cache-dependency-path: web/pnpm-lock.yaml - run-install: | - cwd: ./web + run-install: true diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index 6b87946221a..cd967b76cf7 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -14,18 +14,17 @@ concurrency: cancel-in-progress: true jobs: - test: - name: API Tests + api-unit: + name: API Unit Tests runs-on: ubuntu-latest env: - CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + COVERAGE_FILE: coverage-unit defaults: run: shell: bash strategy: matrix: python-version: - - "3.11" - "3.12" steps: @@ -36,7 +35,7 @@ jobs: persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 + uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 with: enable-cache: true python-version: ${{ matrix.python-version }} @@ -51,6 +50,52 @@ jobs: - name: Run dify config tests run: uv run --project api dev/pytest/pytest_config_tests.py + - name: Run Unit Tests + run: uv run --project api bash dev/pytest/pytest_unit_tests.sh + + - name: Upload unit coverage data + uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + with: + name: api-coverage-unit + path: coverage-unit + retention-days: 1 + + api-integration: + name: API Integration Tests + runs-on: ubuntu-latest + env: + COVERAGE_FILE: coverage-integration + STORAGE_TYPE: opendal + OPENDAL_SCHEME: fs + OPENDAL_FS_ROOT: /tmp/dify-storage + defaults: + run: + shell: bash + strategy: + matrix: + python-version: + - "3.12" + + steps: + - name: Checkout code + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + fetch-depth: 0 + persist-credentials: false + + - name: Setup UV and Python + uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + with: + enable-cache: true + python-version: ${{ matrix.python-version }} + cache-dependency-glob: api/uv.lock + + - name: Check UV lockfile + run: uv lock --project api --check + + - name: Install dependencies + run: uv sync --project api --dev + - name: Set up dotenvs run: | cp docker/.env.example docker/.env @@ -74,23 +119,91 @@ jobs: run: | cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env - - name: Run API Tests - env: - STORAGE_TYPE: opendal - OPENDAL_SCHEME: fs - OPENDAL_FS_ROOT: /tmp/dify-storage + - name: Run Integration Tests run: | uv run --project api pytest \ -n auto \ --timeout "${PYTEST_TIMEOUT:-180}" \ api/tests/integration_tests/workflow \ api/tests/integration_tests/tools \ - api/tests/test_containers_integration_tests \ - api/tests/unit_tests + api/tests/test_containers_integration_tests + + - name: Upload integration coverage data + uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + with: + name: api-coverage-integration + path: coverage-integration + retention-days: 1 + + api-coverage: + name: API Coverage + runs-on: ubuntu-latest + needs: + - api-unit + - api-integration + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + COVERAGE_FILE: .coverage + defaults: + run: + shell: bash + + steps: + - name: Checkout code + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + fetch-depth: 0 + persist-credentials: false + + - name: Setup UV and Python + uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + with: + enable-cache: true + python-version: "3.12" + cache-dependency-glob: api/uv.lock + + - name: Install dependencies + run: uv sync --project api --dev + + - name: Download coverage data + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 + with: + path: coverage-data + pattern: api-coverage-* + merge-multiple: true + + - name: Combine coverage + run: | + set -euo pipefail + + echo "### API Coverage" >> "$GITHUB_STEP_SUMMARY" + echo "" >> "$GITHUB_STEP_SUMMARY" + echo "Merged backend coverage report generated for Codecov project status." >> "$GITHUB_STEP_SUMMARY" + echo "" >> "$GITHUB_STEP_SUMMARY" + + unit_coverage="$(find coverage-data -type f -name coverage-unit -print -quit)" + integration_coverage="$(find coverage-data -type f -name coverage-integration -print -quit)" + : "${unit_coverage:?coverage-unit artifact not found}" + : "${integration_coverage:?coverage-integration artifact not found}" + + report_file="$(mktemp)" + uv run --project api coverage combine "$unit_coverage" "$integration_coverage" + uv run --project api coverage report --show-missing | tee "$report_file" + echo "Summary: \`$(tail -n 1 "$report_file")\`" >> "$GITHUB_STEP_SUMMARY" + { + echo "" + echo "
Coverage report" + echo "" + echo '```' + cat "$report_file" + echo '```' + echo "
" + } >> "$GITHUB_STEP_SUMMARY" + uv run --project api coverage xml -o coverage.xml - name: Report coverage - if: ${{ env.CODECOV_TOKEN != '' && matrix.python-version == '3.12' }} - uses: codecov/codecov-action@1af58845a975a7985b0beb0cbe6fbbb71a41dbad # v5.5.3 + if: ${{ env.CODECOV_TOKEN != '' }} + uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0 with: files: ./coverage.xml disable_search: true diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index 8947ae4030e..9648c34274a 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -2,6 +2,9 @@ name: autofix.ci on: pull_request: branches: ["main"] + merge_group: + branches: ["main"] + types: [checks_requested] push: branches: ["main"] permissions: @@ -12,9 +15,15 @@ jobs: if: github.repository == 'langgenius/dify' runs-on: ubuntu-latest steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + - name: Complete merge group check + if: github.event_name == 'merge_group' + run: echo "autofix.ci updates pull request branches, not merge group refs." + + - if: github.event_name != 'merge_group' + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Check Docker Compose inputs + if: github.event_name != 'merge_group' id: docker-compose-changes uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 with: @@ -24,30 +33,38 @@ jobs: docker/docker-compose-template.yaml docker/docker-compose.yaml - name: Check web inputs + if: github.event_name != 'merge_group' id: web-changes uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 with: files: | web/** + package.json + pnpm-lock.yaml + pnpm-workspace.yaml + .nvmrc - name: Check api inputs + if: github.event_name != 'merge_group' id: api-changes uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 with: files: | api/** - - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 + - if: github.event_name != 'merge_group' + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 with: python-version: "3.11" - - uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 + - if: github.event_name != 'merge_group' + uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 - name: Generate Docker Compose - if: steps.docker-compose-changes.outputs.any_changed == 'true' + if: github.event_name != 'merge_group' && steps.docker-compose-changes.outputs.any_changed == 'true' run: | cd docker ./generate_docker_compose - - if: steps.api-changes.outputs.any_changed == 'true' + - if: github.event_name != 'merge_group' && steps.api-changes.outputs.any_changed == 'true' run: | cd api uv sync --dev @@ -59,13 +76,13 @@ jobs: uv run ruff format .. - name: count migration progress - if: steps.api-changes.outputs.any_changed == 'true' + if: github.event_name != 'merge_group' && steps.api-changes.outputs.any_changed == 'true' run: | cd api ./cnt_base.sh - name: ast-grep - if: steps.api-changes.outputs.any_changed == 'true' + if: github.event_name != 'merge_group' && steps.api-changes.outputs.any_changed == 'true' run: | # ast-grep exits 1 if no matches are found; allow idempotent runs. uvx --from ast-grep-cli ast-grep --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all || true @@ -94,19 +111,15 @@ jobs: find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \; find . -name "*.py.bak" -type f -delete - # mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter. - - name: mdformat - run: | - uvx --python 3.13 mdformat . --exclude ".agents/skills/**" - - name: Setup web environment - if: steps.web-changes.outputs.any_changed == 'true' + if: github.event_name != 'merge_group' && steps.web-changes.outputs.any_changed == 'true' uses: ./.github/actions/setup-web - name: ESLint autofix - if: steps.web-changes.outputs.any_changed == 'true' + if: github.event_name != 'merge_group' && steps.web-changes.outputs.any_changed == 'true' run: | cd web vp exec eslint --concurrency=2 --prune-suppressions --quiet || true - - uses: autofix-ci/action@7a166d7532b277f34e16238930461bf77f9d7ed8 # v1.3.3 + - if: github.event_name != 'merge_group' + uses: autofix-ci/action@7a166d7532b277f34e16238930461bf77f9d7ed8 # v1.3.3 diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml index 1ae8d444829..a23edc70e53 100644 --- a/.github/workflows/build-push.yml +++ b/.github/workflows/build-push.yml @@ -24,27 +24,39 @@ env: jobs: build: - runs-on: ${{ matrix.platform == 'linux/arm64' && 'arm64_runner' || 'ubuntu-latest' }} + runs-on: ${{ matrix.runs_on }} if: github.repository == 'langgenius/dify' strategy: matrix: include: - service_name: "build-api-amd64" image_name_env: "DIFY_API_IMAGE_NAME" - context: "api" + artifact_context: "api" + build_context: "{{defaultContext}}:api" + file: "Dockerfile" platform: linux/amd64 + runs_on: ubuntu-latest - service_name: "build-api-arm64" image_name_env: "DIFY_API_IMAGE_NAME" - context: "api" + artifact_context: "api" + build_context: "{{defaultContext}}:api" + file: "Dockerfile" platform: linux/arm64 + runs_on: ubuntu-24.04-arm - service_name: "build-web-amd64" image_name_env: "DIFY_WEB_IMAGE_NAME" - context: "web" + artifact_context: "web" + build_context: "{{defaultContext}}" + file: "web/Dockerfile" platform: linux/amd64 + runs_on: ubuntu-latest - service_name: "build-web-arm64" image_name_env: "DIFY_WEB_IMAGE_NAME" - context: "web" + artifact_context: "web" + build_context: "{{defaultContext}}" + file: "web/Dockerfile" platform: linux/arm64 + runs_on: ubuntu-24.04-arm steps: - name: Prepare @@ -58,9 +70,6 @@ jobs: username: ${{ env.DOCKERHUB_USER }} password: ${{ env.DOCKERHUB_TOKEN }} - - name: Set up QEMU - uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0 - - name: Set up Docker Buildx uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0 @@ -74,7 +83,8 @@ jobs: id: build uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0 with: - context: "{{defaultContext}}:${{ matrix.context }}" + context: ${{ matrix.build_context }} + file: ${{ matrix.file }} platforms: ${{ matrix.platform }} build-args: COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }} labels: ${{ steps.meta.outputs.labels }} @@ -93,7 +103,7 @@ jobs: - name: Upload digest uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 with: - name: digests-${{ matrix.context }}-${{ env.PLATFORM_PAIR }} + name: digests-${{ matrix.artifact_context }}-${{ env.PLATFORM_PAIR }} path: /tmp/digests/* if-no-files-found: error retention-days: 1 diff --git a/.github/workflows/db-migration-test.yml b/.github/workflows/db-migration-test.yml index ffb9734e488..5991abe3ba1 100644 --- a/.github/workflows/db-migration-test.yml +++ b/.github/workflows/db-migration-test.yml @@ -19,7 +19,7 @@ jobs: persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 + uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 with: enable-cache: true python-version: "3.12" @@ -69,7 +69,7 @@ jobs: persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 + uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 with: enable-cache: true python-version: "3.12" diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index 340b380dc93..cbeb1a3bb1b 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -6,7 +6,12 @@ on: - "main" paths: - api/Dockerfile + - web/docker/** - web/Dockerfile + - package.json + - pnpm-lock.yaml + - pnpm-workspace.yaml + - .nvmrc concurrency: group: docker-build-${{ github.head_ref || github.run_id }} @@ -14,26 +19,31 @@ concurrency: jobs: build-docker: - runs-on: ubuntu-latest + runs-on: ${{ matrix.runs_on }} strategy: matrix: include: - service_name: "api-amd64" platform: linux/amd64 - context: "api" + runs_on: ubuntu-latest + context: "{{defaultContext}}:api" + file: "Dockerfile" - service_name: "api-arm64" platform: linux/arm64 - context: "api" + runs_on: ubuntu-24.04-arm + context: "{{defaultContext}}:api" + file: "Dockerfile" - service_name: "web-amd64" platform: linux/amd64 - context: "web" + runs_on: ubuntu-latest + context: "{{defaultContext}}" + file: "web/Dockerfile" - service_name: "web-arm64" platform: linux/arm64 - context: "web" + runs_on: ubuntu-24.04-arm + context: "{{defaultContext}}" + file: "web/Dockerfile" steps: - - name: Set up QEMU - uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0 - - name: Set up Docker Buildx uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0 @@ -41,8 +51,8 @@ jobs: uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0 with: push: false - context: "{{defaultContext}}:${{ matrix.context }}" - file: "${{ matrix.file }}" + context: ${{ matrix.context }} + file: ${{ matrix.file }} platforms: ${{ matrix.platform }} cache-from: type=gha cache-to: type=gha,mode=max diff --git a/.github/workflows/main-ci.yml b/.github/workflows/main-ci.yml index 69023c24cc9..104368d1929 100644 --- a/.github/workflows/main-ci.yml +++ b/.github/workflows/main-ci.yml @@ -3,10 +3,14 @@ name: Main CI Pipeline on: pull_request: branches: ["main"] + merge_group: + branches: ["main"] + types: [checks_requested] push: branches: ["main"] permissions: + actions: write contents: write pull-requests: write checks: write @@ -17,12 +21,28 @@ concurrency: cancel-in-progress: true jobs: + pre_job: + name: Skip Duplicate Checks + runs-on: ubuntu-latest + outputs: + should_skip: ${{ steps.skip_check.outputs.should_skip || 'false' }} + steps: + - id: skip_check + continue-on-error: true + uses: fkirc/skip-duplicate-actions@f75f66ce1886f00957d99748a42c724f4330bdcf # v5.3.1 + with: + cancel_others: 'true' + concurrent_skipping: same_content_newer + # Check which paths were changed to determine which tests to run check-changes: name: Check Changed Files + needs: pre_job + if: needs.pre_job.outputs.should_skip != 'true' runs-on: ubuntu-latest outputs: api-changed: ${{ steps.changes.outputs.api }} + e2e-changed: ${{ steps.changes.outputs.e2e }} web-changed: ${{ steps.changes.outputs.web }} vdb-changed: ${{ steps.changes.outputs.vdb }} migration-changed: ${{ steps.changes.outputs.migration }} @@ -34,49 +54,372 @@ jobs: filters: | api: - 'api/**' - - 'docker/**' - '.github/workflows/api-tests.yml' + - '.github/workflows/expose_service_ports.sh' + - 'docker/.env.example' + - 'docker/middleware.env.example' + - 'docker/docker-compose.middleware.yaml' + - 'docker/docker-compose-template.yaml' + - 'docker/generate_docker_compose' + - 'docker/ssrf_proxy/**' + - 'docker/volumes/sandbox/conf/**' web: - 'web/**' + - 'package.json' + - 'pnpm-lock.yaml' + - 'pnpm-workspace.yaml' + - '.nvmrc' - '.github/workflows/web-tests.yml' - '.github/actions/setup-web/**' + e2e: + - 'api/**' + - 'api/pyproject.toml' + - 'api/uv.lock' + - 'e2e/**' + - 'web/**' + - 'package.json' + - 'pnpm-lock.yaml' + - 'pnpm-workspace.yaml' + - '.nvmrc' + - 'docker/docker-compose.middleware.yaml' + - 'docker/middleware.env.example' + - '.github/workflows/web-e2e.yml' + - '.github/actions/setup-web/**' vdb: - 'api/core/rag/datasource/**' - - 'docker/**' + - 'api/tests/integration_tests/vdb/**' - '.github/workflows/vdb-tests.yml' + - '.github/workflows/expose_service_ports.sh' + - 'docker/.env.example' + - 'docker/middleware.env.example' + - 'docker/docker-compose.yaml' + - 'docker/docker-compose-template.yaml' + - 'docker/generate_docker_compose' + - 'docker/certbot/**' + - 'docker/couchbase-server/**' + - 'docker/elasticsearch/**' + - 'docker/iris/**' + - 'docker/nginx/**' + - 'docker/pgvector/**' + - 'docker/ssrf_proxy/**' + - 'docker/startupscripts/**' + - 'docker/tidb/**' + - 'docker/volumes/**' - 'api/uv.lock' - 'api/pyproject.toml' migration: - 'api/migrations/**' + - 'api/.env.example' - '.github/workflows/db-migration-test.yml' + - '.github/workflows/expose_service_ports.sh' + - 'docker/.env.example' + - 'docker/middleware.env.example' + - 'docker/docker-compose.middleware.yaml' + - 'docker/docker-compose-template.yaml' + - 'docker/generate_docker_compose' + - 'docker/ssrf_proxy/**' + - 'docker/volumes/sandbox/conf/**' - # Run tests in parallel - api-tests: - name: API Tests - needs: check-changes - if: needs.check-changes.outputs.api-changed == 'true' + # Run tests in parallel while always emitting stable required checks. + api-tests-run: + name: Run API Tests + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.api-changed == 'true' uses: ./.github/workflows/api-tests.yml secrets: inherit - web-tests: - name: Web Tests - needs: check-changes - if: needs.check-changes.outputs.web-changed == 'true' + api-tests-skip: + name: Skip API Tests + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.api-changed != 'true' + runs-on: ubuntu-latest + steps: + - name: Report skipped API tests + run: echo "No API-related changes detected; skipping API tests." + + api-tests: + name: API Tests + if: ${{ always() }} + needs: + - pre_job + - check-changes + - api-tests-run + - api-tests-skip + runs-on: ubuntu-latest + steps: + - name: Finalize API Tests status + env: + SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }} + TESTS_CHANGED: ${{ needs.check-changes.outputs.api-changed }} + RUN_RESULT: ${{ needs.api-tests-run.result }} + SKIP_RESULT: ${{ needs.api-tests-skip.result }} + run: | + if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then + echo "API tests were skipped because this workflow run duplicated a successful or newer run." + exit 0 + fi + + if [[ "$TESTS_CHANGED" == 'true' ]]; then + if [[ "$RUN_RESULT" == 'success' ]]; then + echo "API tests ran successfully." + exit 0 + fi + + echo "API tests were required but finished with result: $RUN_RESULT" >&2 + exit 1 + fi + + if [[ "$SKIP_RESULT" == 'success' ]]; then + echo "API tests were skipped because no API-related files changed." + exit 0 + fi + + echo "API tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2 + exit 1 + + web-tests-run: + name: Run Web Tests + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.web-changed == 'true' uses: ./.github/workflows/web-tests.yml secrets: inherit + web-tests-skip: + name: Skip Web Tests + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.web-changed != 'true' + runs-on: ubuntu-latest + steps: + - name: Report skipped web tests + run: echo "No web-related changes detected; skipping web tests." + + web-tests: + name: Web Tests + if: ${{ always() }} + needs: + - pre_job + - check-changes + - web-tests-run + - web-tests-skip + runs-on: ubuntu-latest + steps: + - name: Finalize Web Tests status + env: + SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }} + TESTS_CHANGED: ${{ needs.check-changes.outputs.web-changed }} + RUN_RESULT: ${{ needs.web-tests-run.result }} + SKIP_RESULT: ${{ needs.web-tests-skip.result }} + run: | + if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then + echo "Web tests were skipped because this workflow run duplicated a successful or newer run." + exit 0 + fi + + if [[ "$TESTS_CHANGED" == 'true' ]]; then + if [[ "$RUN_RESULT" == 'success' ]]; then + echo "Web tests ran successfully." + exit 0 + fi + + echo "Web tests were required but finished with result: $RUN_RESULT" >&2 + exit 1 + fi + + if [[ "$SKIP_RESULT" == 'success' ]]; then + echo "Web tests were skipped because no web-related files changed." + exit 0 + fi + + echo "Web tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2 + exit 1 + + web-e2e-run: + name: Run Web Full-Stack E2E + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.e2e-changed == 'true' + uses: ./.github/workflows/web-e2e.yml + + web-e2e-skip: + name: Skip Web Full-Stack E2E + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.e2e-changed != 'true' + runs-on: ubuntu-latest + steps: + - name: Report skipped web full-stack e2e + run: echo "No E2E-related changes detected; skipping web full-stack E2E." + + web-e2e: + name: Web Full-Stack E2E + if: ${{ always() }} + needs: + - pre_job + - check-changes + - web-e2e-run + - web-e2e-skip + runs-on: ubuntu-latest + steps: + - name: Finalize Web Full-Stack E2E status + env: + SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }} + TESTS_CHANGED: ${{ needs.check-changes.outputs.e2e-changed }} + RUN_RESULT: ${{ needs.web-e2e-run.result }} + SKIP_RESULT: ${{ needs.web-e2e-skip.result }} + run: | + if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then + echo "Web full-stack E2E was skipped because this workflow run duplicated a successful or newer run." + exit 0 + fi + + if [[ "$TESTS_CHANGED" == 'true' ]]; then + if [[ "$RUN_RESULT" == 'success' ]]; then + echo "Web full-stack E2E ran successfully." + exit 0 + fi + + echo "Web full-stack E2E was required but finished with result: $RUN_RESULT" >&2 + exit 1 + fi + + if [[ "$SKIP_RESULT" == 'success' ]]; then + echo "Web full-stack E2E was skipped because no E2E-related files changed." + exit 0 + fi + + echo "Web full-stack E2E was not required, but the skip job finished with result: $SKIP_RESULT" >&2 + exit 1 + style-check: name: Style Check + needs: pre_job + if: needs.pre_job.outputs.should_skip != 'true' uses: ./.github/workflows/style.yml + vdb-tests-run: + name: Run VDB Tests + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.vdb-changed == 'true' + uses: ./.github/workflows/vdb-tests.yml + + vdb-tests-skip: + name: Skip VDB Tests + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.vdb-changed != 'true' + runs-on: ubuntu-latest + steps: + - name: Report skipped VDB tests + run: echo "No VDB-related changes detected; skipping VDB tests." + vdb-tests: name: VDB Tests - needs: check-changes - if: needs.check-changes.outputs.vdb-changed == 'true' - uses: ./.github/workflows/vdb-tests.yml + if: ${{ always() }} + needs: + - pre_job + - check-changes + - vdb-tests-run + - vdb-tests-skip + runs-on: ubuntu-latest + steps: + - name: Finalize VDB Tests status + env: + SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }} + TESTS_CHANGED: ${{ needs.check-changes.outputs.vdb-changed }} + RUN_RESULT: ${{ needs.vdb-tests-run.result }} + SKIP_RESULT: ${{ needs.vdb-tests-skip.result }} + run: | + if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then + echo "VDB tests were skipped because this workflow run duplicated a successful or newer run." + exit 0 + fi + + if [[ "$TESTS_CHANGED" == 'true' ]]; then + if [[ "$RUN_RESULT" == 'success' ]]; then + echo "VDB tests ran successfully." + exit 0 + fi + + echo "VDB tests were required but finished with result: $RUN_RESULT" >&2 + exit 1 + fi + + if [[ "$SKIP_RESULT" == 'success' ]]; then + echo "VDB tests were skipped because no VDB-related files changed." + exit 0 + fi + + echo "VDB tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2 + exit 1 + + db-migration-test-run: + name: Run DB Migration Test + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.migration-changed == 'true' + uses: ./.github/workflows/db-migration-test.yml + + db-migration-test-skip: + name: Skip DB Migration Test + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.migration-changed != 'true' + runs-on: ubuntu-latest + steps: + - name: Report skipped DB migration tests + run: echo "No migration-related changes detected; skipping DB migration tests." db-migration-test: name: DB Migration Test - needs: check-changes - if: needs.check-changes.outputs.migration-changed == 'true' - uses: ./.github/workflows/db-migration-test.yml + if: ${{ always() }} + needs: + - pre_job + - check-changes + - db-migration-test-run + - db-migration-test-skip + runs-on: ubuntu-latest + steps: + - name: Finalize DB Migration Test status + env: + SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }} + TESTS_CHANGED: ${{ needs.check-changes.outputs.migration-changed }} + RUN_RESULT: ${{ needs.db-migration-test-run.result }} + SKIP_RESULT: ${{ needs.db-migration-test-skip.result }} + run: | + if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then + echo "DB migration tests were skipped because this workflow run duplicated a successful or newer run." + exit 0 + fi + + if [[ "$TESTS_CHANGED" == 'true' ]]; then + if [[ "$RUN_RESULT" == 'success' ]]; then + echo "DB migration tests ran successfully." + exit 0 + fi + + echo "DB migration tests were required but finished with result: $RUN_RESULT" >&2 + exit 1 + fi + + if [[ "$SKIP_RESULT" == 'success' ]]; then + echo "DB migration tests were skipped because no migration-related files changed." + exit 0 + fi + + echo "DB migration tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2 + exit 1 diff --git a/.github/workflows/pyrefly-diff.yml b/.github/workflows/pyrefly-diff.yml index a00f469bbeb..0b2a7b8e9e0 100644 --- a/.github/workflows/pyrefly-diff.yml +++ b/.github/workflows/pyrefly-diff.yml @@ -22,7 +22,7 @@ jobs: fetch-depth: 0 - name: Setup Python & UV - uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 + uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 with: enable-cache: true diff --git a/.github/workflows/semantic-pull-request.yml b/.github/workflows/semantic-pull-request.yml index c21331ec0d0..49d2e946956 100644 --- a/.github/workflows/semantic-pull-request.yml +++ b/.github/workflows/semantic-pull-request.yml @@ -7,6 +7,9 @@ on: - edited - reopened - synchronize + merge_group: + branches: ["main"] + types: [checks_requested] jobs: lint: @@ -15,7 +18,11 @@ jobs: pull-requests: read runs-on: ubuntu-latest steps: + - name: Complete merge group check + if: github.event_name == 'merge_group' + run: echo "Semantic PR title validation is handled on pull requests." - name: Check title + if: github.event_name == 'pull_request' uses: amannn/action-semantic-pull-request@48f256284bd46cdaab1048c3721360e808335d50 # v6.1.1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 657a481f743..9bc4ceaa93e 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -33,7 +33,7 @@ jobs: - name: Setup UV and Python if: steps.changed-files.outputs.any_changed == 'true' - uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 + uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 with: enable-cache: false python-version: "3.12" @@ -49,7 +49,7 @@ jobs: - name: Run Type Checks if: steps.changed-files.outputs.any_changed == 'true' - run: make type-check + run: make type-check-core - name: Dotenv check if: steps.changed-files.outputs.any_changed == 'true' @@ -77,6 +77,10 @@ jobs: with: files: | web/** + package.json + pnpm-lock.yaml + pnpm-workspace.yaml + .nvmrc .github/workflows/style.yml .github/actions/setup-web/** @@ -84,20 +88,20 @@ jobs: if: steps.changed-files.outputs.any_changed == 'true' uses: ./.github/actions/setup-web + - name: Restore ESLint cache + if: steps.changed-files.outputs.any_changed == 'true' + id: eslint-cache-restore + uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4 + with: + path: web/.eslintcache + key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }} + restore-keys: | + ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}- + - name: Web style check if: steps.changed-files.outputs.any_changed == 'true' working-directory: ./web - run: | - vp run lint:ci - # pnpm run lint:report - # continue-on-error: true - - # - name: Annotate Code - # if: steps.changed-files.outputs.any_changed == 'true' && github.event_name == 'pull_request' - # uses: DerLev/eslint-annotations@51347b3a0abfb503fc8734d5ae31c4b151297fae - # with: - # eslint-report: web/eslint_report.json - # github-token: ${{ secrets.GITHUB_TOKEN }} + run: vp run lint:ci - name: Web tsslint if: steps.changed-files.outputs.any_changed == 'true' @@ -114,6 +118,13 @@ jobs: working-directory: ./web run: vp run knip + - name: Save ESLint cache + if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true' + uses: actions/cache/save@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4 + with: + path: web/.eslintcache + key: ${{ steps.eslint-cache-restore.outputs.cache-primary-key }} + superlinter: name: SuperLinter runs-on: ubuntu-latest diff --git a/.github/workflows/tool-test-sdks.yaml b/.github/workflows/tool-test-sdks.yaml index 3fc351c0c29..536a52b5600 100644 --- a/.github/workflows/tool-test-sdks.yaml +++ b/.github/workflows/tool-test-sdks.yaml @@ -6,6 +6,9 @@ on: - main paths: - sdks/** + - package.json + - pnpm-lock.yaml + - pnpm-workspace.yaml concurrency: group: sdk-tests-${{ github.head_ref || github.run_id }} diff --git a/.github/workflows/translate-i18n-claude.yml b/.github/workflows/translate-i18n-claude.yml index 849f965c363..aaf51aa6064 100644 --- a/.github/workflows/translate-i18n-claude.yml +++ b/.github/workflows/translate-i18n-claude.yml @@ -1,26 +1,24 @@ name: Translate i18n Files with Claude Code -# Note: claude-code-action doesn't support push events directly. -# Push events are handled by trigger-i18n-sync.yml which sends repository_dispatch. -# See: https://github.com/langgenius/dify/issues/30743 - on: - repository_dispatch: - types: [i18n-sync] + push: + branches: [main] + paths: + - 'web/i18n/en-US/*.json' workflow_dispatch: inputs: files: - description: 'Specific files to translate (space-separated, e.g., "app common"). Leave empty for all files.' + description: 'Specific files to translate (space-separated, e.g., "app common"). Required for full mode; leave empty in incremental mode to use en-US files changed since HEAD~1.' required: false type: string languages: - description: 'Specific languages to translate (space-separated, e.g., "zh-Hans ja-JP"). Leave empty for all supported languages.' + description: 'Specific languages to translate (space-separated, e.g., "zh-Hans ja-JP"). Leave empty for all supported target languages except en-US.' required: false type: string mode: - description: 'Sync mode: incremental (only changes) or full (re-check all keys)' + description: 'Sync mode: incremental (compare with previous en-US revision) or full (sync all keys in scope)' required: false - default: 'incremental' + default: incremental type: choice options: - incremental @@ -30,11 +28,15 @@ permissions: contents: write pull-requests: write +concurrency: + group: translate-i18n-${{ github.event_name }}-${{ github.ref }} + cancel-in-progress: ${{ github.event_name == 'push' }} + jobs: translate: if: github.repository == 'langgenius/dify' runs-on: ubuntu-latest - timeout-minutes: 60 + timeout-minutes: 120 steps: - name: Checkout repository @@ -51,380 +53,161 @@ jobs: - name: Setup web environment uses: ./.github/actions/setup-web - - name: Detect changed files and generate diff - id: detect_changes + - name: Prepare sync context + id: context + shell: bash run: | - if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then - # Manual trigger - if [ -n "${{ github.event.inputs.files }}" ]; then - echo "CHANGED_FILES=${{ github.event.inputs.files }}" >> $GITHUB_OUTPUT - else - # Get all JSON files in en-US directory - files=$(ls web/i18n/en-US/*.json 2>/dev/null | xargs -n1 basename | sed 's/.json$//' | tr '\n' ' ') - echo "CHANGED_FILES=$files" >> $GITHUB_OUTPUT - fi - echo "TARGET_LANGS=${{ github.event.inputs.languages }}" >> $GITHUB_OUTPUT - echo "SYNC_MODE=${{ github.event.inputs.mode || 'incremental' }}" >> $GITHUB_OUTPUT + DEFAULT_TARGET_LANGS=$(awk " + /value: '/ { + value=\$2 + gsub(/[',]/, \"\", value) + } + /supported: true/ && value != \"en-US\" { + printf \"%s \", value + } + " web/i18n-config/languages.ts | sed 's/[[:space:]]*$//') - # For manual trigger with incremental mode, get diff from last commit - # For full mode, we'll do a complete check anyway - if [ "${{ github.event.inputs.mode }}" == "full" ]; then - echo "Full mode: will check all keys" > /tmp/i18n-diff.txt - echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT + if [ "${{ github.event_name }}" = "push" ]; then + BASE_SHA="${{ github.event.before }}" + if [ -z "$BASE_SHA" ] || [ "$BASE_SHA" = "0000000000000000000000000000000000000000" ]; then + BASE_SHA=$(git rev-parse HEAD~1 2>/dev/null || true) + fi + HEAD_SHA="${{ github.sha }}" + if [ -n "$BASE_SHA" ]; then + CHANGED_FILES=$(git diff --name-only "$BASE_SHA" "$HEAD_SHA" -- 'web/i18n/en-US/*.json' 2>/dev/null | sed -n 's@^.*/@@p' | sed 's/\.json$//' | tr '\n' ' ' | sed 's/[[:space:]]*$//') else - git diff HEAD~1..HEAD -- 'web/i18n/en-US/*.json' > /tmp/i18n-diff.txt 2>/dev/null || echo "" > /tmp/i18n-diff.txt - if [ -s /tmp/i18n-diff.txt ]; then - echo "DIFF_AVAILABLE=true" >> $GITHUB_OUTPUT - else - echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT - fi - fi - elif [ "${{ github.event_name }}" == "repository_dispatch" ]; then - # Triggered by push via trigger-i18n-sync.yml workflow - # Validate required payload fields - if [ -z "${{ github.event.client_payload.changed_files }}" ]; then - echo "Error: repository_dispatch payload missing required 'changed_files' field" >&2 - exit 1 - fi - echo "CHANGED_FILES=${{ github.event.client_payload.changed_files }}" >> $GITHUB_OUTPUT - echo "TARGET_LANGS=" >> $GITHUB_OUTPUT - echo "SYNC_MODE=${{ github.event.client_payload.sync_mode || 'incremental' }}" >> $GITHUB_OUTPUT - - # Decode the base64-encoded diff from the trigger workflow - if [ -n "${{ github.event.client_payload.diff_base64 }}" ]; then - if ! echo "${{ github.event.client_payload.diff_base64 }}" | base64 -d > /tmp/i18n-diff.txt 2>&1; then - echo "Warning: Failed to decode base64 diff payload" >&2 - echo "" > /tmp/i18n-diff.txt - echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT - elif [ -s /tmp/i18n-diff.txt ]; then - echo "DIFF_AVAILABLE=true" >> $GITHUB_OUTPUT - else - echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT - fi - else - echo "" > /tmp/i18n-diff.txt - echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT + CHANGED_FILES=$(find web/i18n/en-US -maxdepth 1 -type f -name '*.json' -print | sed -n 's@^.*/@@p' | sed 's/\.json$//' | sort | tr '\n' ' ' | sed 's/[[:space:]]*$//') fi + TARGET_LANGS="$DEFAULT_TARGET_LANGS" + SYNC_MODE="incremental" else - echo "Unsupported event type: ${{ github.event_name }}" - exit 1 + BASE_SHA="" + HEAD_SHA=$(git rev-parse HEAD) + if [ -n "${{ github.event.inputs.languages }}" ]; then + TARGET_LANGS="${{ github.event.inputs.languages }}" + else + TARGET_LANGS="$DEFAULT_TARGET_LANGS" + fi + SYNC_MODE="${{ github.event.inputs.mode || 'incremental' }}" + if [ -n "${{ github.event.inputs.files }}" ]; then + CHANGED_FILES="${{ github.event.inputs.files }}" + elif [ "$SYNC_MODE" = "incremental" ]; then + BASE_SHA=$(git rev-parse HEAD~1 2>/dev/null || true) + if [ -n "$BASE_SHA" ]; then + CHANGED_FILES=$(git diff --name-only "$BASE_SHA" "$HEAD_SHA" -- 'web/i18n/en-US/*.json' 2>/dev/null | sed -n 's@^.*/@@p' | sed 's/\.json$//' | tr '\n' ' ' | sed 's/[[:space:]]*$//') + else + CHANGED_FILES=$(find web/i18n/en-US -maxdepth 1 -type f -name '*.json' -print | sed -n 's@^.*/@@p' | sed 's/\.json$//' | sort | tr '\n' ' ' | sed 's/[[:space:]]*$//') + fi + elif [ "$SYNC_MODE" = "full" ]; then + echo "workflow_dispatch full mode requires the files input to stay within CI limits." >&2 + exit 1 + else + CHANGED_FILES="" + fi fi - # Truncate diff if too large (keep first 50KB) - if [ -f /tmp/i18n-diff.txt ]; then - head -c 50000 /tmp/i18n-diff.txt > /tmp/i18n-diff-truncated.txt - mv /tmp/i18n-diff-truncated.txt /tmp/i18n-diff.txt + FILE_ARGS="" + if [ -n "$CHANGED_FILES" ]; then + FILE_ARGS="--file $CHANGED_FILES" fi - echo "Detected files: $(cat $GITHUB_OUTPUT | grep CHANGED_FILES || echo 'none')" + LANG_ARGS="" + if [ -n "$TARGET_LANGS" ]; then + LANG_ARGS="--lang $TARGET_LANGS" + fi + + { + echo "DEFAULT_TARGET_LANGS=$DEFAULT_TARGET_LANGS" + echo "BASE_SHA=$BASE_SHA" + echo "HEAD_SHA=$HEAD_SHA" + echo "CHANGED_FILES=$CHANGED_FILES" + echo "TARGET_LANGS=$TARGET_LANGS" + echo "SYNC_MODE=$SYNC_MODE" + echo "FILE_ARGS=$FILE_ARGS" + echo "LANG_ARGS=$LANG_ARGS" + } >> "$GITHUB_OUTPUT" + + echo "Files: ${CHANGED_FILES:-}" + echo "Languages: ${TARGET_LANGS:-}" + echo "Mode: $SYNC_MODE" - name: Run Claude Code for Translation Sync - if: steps.detect_changes.outputs.CHANGED_FILES != '' - uses: anthropics/claude-code-action@df37d2f0760a4b5683a6e617c9325bc1a36443f6 # v1.0.75 + if: steps.context.outputs.CHANGED_FILES != '' + uses: anthropics/claude-code-action@88c168b39e7e64da0286d812b6e9fbebb6708185 # v1.0.82 with: anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} github_token: ${{ secrets.GITHUB_TOKEN }} - # Allow github-actions bot to trigger this workflow via repository_dispatch - # See: https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md allowed_bots: 'github-actions[bot]' + show_full_output: ${{ github.event_name == 'workflow_dispatch' }} prompt: | - You are a professional i18n synchronization engineer for the Dify project. - Your task is to keep all language translations in sync with the English source (en-US). + You are the i18n sync agent for the Dify repository. + Your job is to keep translations synchronized with the English source files under `${{ github.workspace }}/web/i18n/en-US/`, then open a PR with the result. - ## CRITICAL TOOL RESTRICTIONS - - Use **Read** tool to read files (NOT cat or bash) - - Use **Edit** tool to modify JSON files (NOT node, jq, or bash scripts) - - Use **Bash** ONLY for: git commands, gh commands, pnpm commands - - Run bash commands ONE BY ONE, never combine with && or || - - NEVER use `$()` command substitution - it's not supported. Split into separate commands instead. + Use absolute paths at all times: + - Repo root: `${{ github.workspace }}` + - Web directory: `${{ github.workspace }}/web` + - Language config: `${{ github.workspace }}/web/i18n-config/languages.ts` - ## WORKING DIRECTORY & ABSOLUTE PATHS - Claude Code sandbox working directory may vary. Always use absolute paths: - - For pnpm: `pnpm --dir ${{ github.workspace }}/web ` - - For git: `git -C ${{ github.workspace }} ` - - For gh: `gh --repo ${{ github.repository }} ` - - For file paths: `${{ github.workspace }}/web/i18n/` + Inputs: + - Files in scope: `${{ steps.context.outputs.CHANGED_FILES }}` + - Target languages: `${{ steps.context.outputs.TARGET_LANGS }}` + - Sync mode: `${{ steps.context.outputs.SYNC_MODE }}` + - Base SHA: `${{ steps.context.outputs.BASE_SHA }}` + - Head SHA: `${{ steps.context.outputs.HEAD_SHA }}` + - Scoped file args: `${{ steps.context.outputs.FILE_ARGS }}` + - Scoped language args: `${{ steps.context.outputs.LANG_ARGS }}` - ## EFFICIENCY RULES - - **ONE Edit per language file** - batch all key additions into a single Edit - - Insert new keys at the beginning of JSON (after `{`), lint:fix will sort them - - Translate ALL keys for a language mentally first, then do ONE Edit - - ## Context - - Changed/target files: ${{ steps.detect_changes.outputs.CHANGED_FILES }} - - Target languages (empty means all supported): ${{ steps.detect_changes.outputs.TARGET_LANGS }} - - Sync mode: ${{ steps.detect_changes.outputs.SYNC_MODE }} - - Translation files are located in: ${{ github.workspace }}/web/i18n/{locale}/{filename}.json - - Language configuration is in: ${{ github.workspace }}/web/i18n-config/languages.ts - - Git diff is available: ${{ steps.detect_changes.outputs.DIFF_AVAILABLE }} - - ## CRITICAL DESIGN: Verify First, Then Sync - - You MUST follow this three-phase approach: - - ═══════════════════════════════════════════════════════════════ - ║ PHASE 1: VERIFY - Analyze and Generate Change Report ║ - ═══════════════════════════════════════════════════════════════ - - ### Step 1.1: Analyze Git Diff (for incremental mode) - Use the Read tool to read `/tmp/i18n-diff.txt` to see the git diff. - - Parse the diff to categorize changes: - - Lines with `+` (not `+++`): Added or modified values - - Lines with `-` (not `---`): Removed or old values - - Identify specific keys for each category: - * ADD: Keys that appear only in `+` lines (new keys) - * UPDATE: Keys that appear in both `-` and `+` lines (value changed) - * DELETE: Keys that appear only in `-` lines (removed keys) - - ### Step 1.2: Read Language Configuration - Use the Read tool to read `${{ github.workspace }}/web/i18n-config/languages.ts`. - Extract all languages with `supported: true`. - - ### Step 1.3: Run i18n:check for Each Language - ```bash - pnpm --dir ${{ github.workspace }}/web install --frozen-lockfile - ``` - ```bash - pnpm --dir ${{ github.workspace }}/web run i18n:check - ``` - - This will report: - - Missing keys (need to ADD) - - Extra keys (need to DELETE) - - ### Step 1.4: Generate Change Report - - Create a structured report identifying: - ``` - ╔══════════════════════════════════════════════════════════════╗ - ║ I18N SYNC CHANGE REPORT ║ - ╠══════════════════════════════════════════════════════════════╣ - ║ Files to process: [list] ║ - ║ Languages to sync: [list] ║ - ╠══════════════════════════════════════════════════════════════╣ - ║ ADD (New Keys): ║ - ║ - [filename].[key]: "English value" ║ - ║ ... ║ - ╠══════════════════════════════════════════════════════════════╣ - ║ UPDATE (Modified Keys - MUST re-translate): ║ - ║ - [filename].[key]: "Old value" → "New value" ║ - ║ ... ║ - ╠══════════════════════════════════════════════════════════════╣ - ║ DELETE (Extra Keys): ║ - ║ - [language]/[filename].[key] ║ - ║ ... ║ - ╚══════════════════════════════════════════════════════════════╝ - ``` - - **IMPORTANT**: For UPDATE detection, compare git diff to find keys where - the English value changed. These MUST be re-translated even if target - language already has a translation (it's now stale!). - - ═══════════════════════════════════════════════════════════════ - ║ PHASE 2: SYNC - Execute Changes Based on Report ║ - ═══════════════════════════════════════════════════════════════ - - ### Step 2.1: Process ADD Operations (BATCH per language file) - - **CRITICAL WORKFLOW for efficiency:** - 1. First, translate ALL new keys for ALL languages mentally - 2. Then, for EACH language file, do ONE Edit operation: - - Read the file once - - Insert ALL new keys at the beginning (right after the opening `{`) - - Don't worry about alphabetical order - lint:fix will sort them later - - Example Edit (adding 3 keys to zh-Hans/app.json): - ``` - old_string: '{\n "accessControl"' - new_string: '{\n "newKey1": "translation1",\n "newKey2": "translation2",\n "newKey3": "translation3",\n "accessControl"' - ``` - - **IMPORTANT**: - - ONE Edit per language file (not one Edit per key!) - - Always use the Edit tool. NEVER use bash scripts, node, or jq. - - ### Step 2.2: Process UPDATE Operations - - **IMPORTANT: Special handling for zh-Hans and ja-JP** - If zh-Hans or ja-JP files were ALSO modified in the same push: - - Run: `git -C ${{ github.workspace }} diff HEAD~1 --name-only` and check for zh-Hans or ja-JP files - - If found, it means someone manually translated them. Apply these rules: - - 1. **Missing keys**: Still ADD them (completeness required) - 2. **Existing translations**: Compare with the NEW English value: - - If translation is **completely wrong** or **unrelated** → Update it - - If translation is **roughly correct** (captures the meaning) → Keep it, respect manual work - - When in doubt, **keep the manual translation** - - Example: - - English changed: "Save" → "Save Changes" - - Manual translation: "保存更改" → Keep it (correct meaning) - - Manual translation: "删除" → Update it (completely wrong) - - For other languages: - Use Edit tool to replace the old value with the new translation. - You can batch multiple updates in one Edit if they are adjacent. - - ### Step 2.3: Process DELETE Operations - For extra keys reported by i18n:check: - - Run: `pnpm --dir ${{ github.workspace }}/web run i18n:check --auto-remove` - - Or manually remove from target language JSON files - - ## Translation Guidelines - - - PRESERVE all placeholders exactly as-is: - - `{{variable}}` - Mustache interpolation - - `${variable}` - Template literal - - `content` - HTML tags - - `_one`, `_other` - Pluralization suffixes (these are KEY suffixes, not values) - - **CRITICAL: Variable names and tag names MUST stay in English - NEVER translate them** - - ✅ CORRECT examples: - - English: "{{count}} items" → Japanese: "{{count}} 個のアイテム" - - English: "{{name}} updated" → Korean: "{{name}} 업데이트됨" - - English: "{{email}}" → Chinese: "{{email}}" - - English: "Marketplace" → Japanese: "マーケットプレイス" - - ❌ WRONG examples (NEVER do this - will break the application): - - "{{count}}" → "{{カウント}}" ❌ (variable name translated to Japanese) - - "{{name}}" → "{{이름}}" ❌ (variable name translated to Korean) - - "{{email}}" → "{{邮箱}}" ❌ (variable name translated to Chinese) - - "" → "<メール>" ❌ (tag name translated) - - "" → "<自定义链接>" ❌ (component name translated) - - - Use appropriate language register (formal/informal) based on existing translations - - Match existing translation style in each language - - Technical terms: check existing conventions per language - - For CJK languages: no spaces between characters unless necessary - - For RTL languages (ar-TN, fa-IR): ensure proper text handling - - ## Output Format Requirements - - Alphabetical key ordering (if original file uses it) - - 2-space indentation - - Trailing newline at end of file - - Valid JSON (use proper escaping for special characters) - - ═══════════════════════════════════════════════════════════════ - ║ PHASE 3: RE-VERIFY - Confirm All Issues Resolved ║ - ═══════════════════════════════════════════════════════════════ - - ### Step 3.1: Run Lint Fix (IMPORTANT!) - ```bash - pnpm --dir ${{ github.workspace }}/web lint:fix --quiet -- 'i18n/**/*.json' - ``` - This ensures: - - JSON keys are sorted alphabetically (jsonc/sort-keys rule) - - Valid i18n keys (dify-i18n/valid-i18n-keys rule) - - No extra keys (dify-i18n/no-extra-keys rule) - - ### Step 3.2: Run Final i18n Check - ```bash - pnpm --dir ${{ github.workspace }}/web run i18n:check - ``` - - ### Step 3.3: Fix Any Remaining Issues - If check reports issues: - - Go back to PHASE 2 for unresolved items - - Repeat until check passes - - ### Step 3.4: Generate Final Summary - ``` - ╔══════════════════════════════════════════════════════════════╗ - ║ SYNC COMPLETED SUMMARY ║ - ╠══════════════════════════════════════════════════════════════╣ - ║ Language │ Added │ Updated │ Deleted │ Status ║ - ╠══════════════════════════════════════════════════════════════╣ - ║ zh-Hans │ 5 │ 2 │ 1 │ ✓ Complete ║ - ║ ja-JP │ 5 │ 2 │ 1 │ ✓ Complete ║ - ║ ... │ ... │ ... │ ... │ ... ║ - ╠══════════════════════════════════════════════════════════════╣ - ║ i18n:check │ PASSED - All keys in sync ║ - ╚══════════════════════════════════════════════════════════════╝ - ``` - - ## Mode-Specific Behavior - - **SYNC_MODE = "incremental"** (default): - - Focus on keys identified from git diff - - Also check i18n:check output for any missing/extra keys - - Efficient for small changes - - **SYNC_MODE = "full"**: - - Compare ALL keys between en-US and each language - - Run i18n:check to identify all discrepancies - - Use for first-time sync or fixing historical issues - - ## Important Notes - - 1. Always run i18n:check BEFORE and AFTER making changes - 2. The check script is the source of truth for missing/extra keys - 3. For UPDATE scenario: git diff is the source of truth for changed values - 4. Create a single commit with all translation changes - 5. If any translation fails, continue with others and report failures - - ═══════════════════════════════════════════════════════════════ - ║ PHASE 4: COMMIT AND CREATE PR ║ - ═══════════════════════════════════════════════════════════════ - - After all translations are complete and verified: - - ### Step 4.1: Check for changes - ```bash - git -C ${{ github.workspace }} status --porcelain - ``` - - If there are changes: - - ### Step 4.2: Create a new branch and commit - Run these git commands ONE BY ONE (not combined with &&). - **IMPORTANT**: Do NOT use `$()` command substitution. Use two separate commands: - - 1. First, get the timestamp: - ```bash - date +%Y%m%d-%H%M%S - ``` - (Note the output, e.g., "20260115-143052") - - 2. Then create branch using the timestamp value: - ```bash - git -C ${{ github.workspace }} checkout -b chore/i18n-sync-20260115-143052 - ``` - (Replace "20260115-143052" with the actual timestamp from step 1) - - 3. Stage changes: - ```bash - git -C ${{ github.workspace }} add web/i18n/ - ``` - - 4. Commit: - ```bash - git -C ${{ github.workspace }} commit -m "chore(i18n): sync translations with en-US - Mode: ${{ steps.detect_changes.outputs.SYNC_MODE }}" - ``` - - 5. Push: - ```bash - git -C ${{ github.workspace }} push origin HEAD - ``` - - ### Step 4.3: Create Pull Request - ```bash - gh pr create --repo ${{ github.repository }} --title "chore(i18n): sync translations with en-US" --body "## Summary - - This PR was automatically generated to sync i18n translation files. - - ### Changes - - Mode: ${{ steps.detect_changes.outputs.SYNC_MODE }} - - Files processed: ${{ steps.detect_changes.outputs.CHANGED_FILES }} - - ### Verification - - [x] \`i18n:check\` passed - - [x] \`lint:fix\` applied - - 🤖 Generated with Claude Code GitHub Action" --base main - ``` + Tool rules: + - Use Read for repository files. + - Use Edit for JSON updates. + - Use Bash only for `git`, `gh`, `pnpm`, and `date`. + - Run Bash commands one by one. Do not combine commands with `&&`, `||`, pipes, or command substitution. + Required execution plan: + 1. Resolve target languages. + - Use the provided `Target languages` value as the source of truth. + - If it is unexpectedly empty, read `${{ github.workspace }}/web/i18n-config/languages.ts` and use every language with `supported: true` except `en-US`. + 2. Stay strictly in scope. + - Only process the files listed in `Files in scope`. + - Only process the resolved target languages, never `en-US`. + - Do not touch unrelated i18n files. + - Do not modify `${{ github.workspace }}/web/i18n/en-US/`. + 3. Detect English changes per file. + - Read the current English JSON file for each file in scope. + - If sync mode is `incremental` and `Base SHA` is not empty, run: + `git -C ${{ github.workspace }} show :web/i18n/en-US/.json` + - If sync mode is `full` or `Base SHA` is empty, skip historical comparison and treat the current English file as the only source of truth for structural sync. + - If the file did not exist at Base SHA, treat all current keys as ADD. + - Compare previous and current English JSON to identify: + - ADD: key only in current + - UPDATE: key exists in both and the English value changed + - DELETE: key only in previous + - Do not rely on a truncated diff file. + 4. Run a scoped pre-check before editing: + - `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}` + - Use this command as the source of truth for missing and extra keys inside the current scope. + 5. Apply translations. + - For every target language and scoped file: + - If the locale file does not exist yet, create it with `Write` and then continue with `Edit` as needed. + - ADD missing keys. + - UPDATE stale translations when the English value changed. + - DELETE removed keys. Prefer `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }} --auto-remove` for extra keys so deletions stay in scope. + - For `zh-Hans` and `ja-JP`, if the locale file also changed between Base SHA and Head SHA, preserve manual translations unless they are clearly wrong for the new English value. If in doubt, keep the manual translation. + - Preserve placeholders exactly: `{{variable}}`, `${variable}`, HTML tags, component tags, and variable names. + - Match the existing terminology and register used by each locale. + - Prefer one Edit per file when stable, but prioritize correctness over batching. + 6. Verify only the edited files. + - Run `pnpm --dir ${{ github.workspace }}/web lint:fix --quiet -- ` + - Run `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}` + - If verification fails, fix the remaining problems before continuing. + 7. Create a PR only when there are changes in `web/i18n/`. + - Check `git -C ${{ github.workspace }} status --porcelain -- web/i18n/` + - Create branch `chore/i18n-sync-` + - Commit message: `chore(i18n): sync translations with en-US` + - Push the branch and open a PR against `main` + - PR title: `chore(i18n): sync translations with en-US` + - PR body: summarize files, languages, sync mode, and verification commands + 8. If there are no translation changes after verification, do not create a branch, commit, or PR. claude_args: | - --max-turns 150 + --max-turns 80 --allowedTools "Read,Write,Edit,Bash(git *),Bash(git:*),Bash(gh *),Bash(gh:*),Bash(pnpm *),Bash(pnpm:*),Bash(date *),Bash(date:*),Glob,Grep" diff --git a/.github/workflows/trigger-i18n-sync.yml b/.github/workflows/trigger-i18n-sync.yml deleted file mode 100644 index 1caaddd47a9..00000000000 --- a/.github/workflows/trigger-i18n-sync.yml +++ /dev/null @@ -1,66 +0,0 @@ -name: Trigger i18n Sync on Push - -# This workflow bridges the push event to repository_dispatch -# because claude-code-action doesn't support push events directly. -# See: https://github.com/langgenius/dify/issues/30743 - -on: - push: - branches: [main] - paths: - - 'web/i18n/en-US/*.json' - -permissions: - contents: write - -jobs: - trigger: - if: github.repository == 'langgenius/dify' - runs-on: ubuntu-latest - timeout-minutes: 5 - - steps: - - name: Checkout repository - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - with: - fetch-depth: 0 - - - name: Detect changed files and generate diff - id: detect - run: | - BEFORE_SHA="${{ github.event.before }}" - # Handle edge case: force push may have null/zero SHA - if [ -z "$BEFORE_SHA" ] || [ "$BEFORE_SHA" = "0000000000000000000000000000000000000000" ]; then - BEFORE_SHA="HEAD~1" - fi - - # Detect changed i18n files - changed=$(git diff --name-only "$BEFORE_SHA" "${{ github.sha }}" -- 'web/i18n/en-US/*.json' 2>/dev/null | xargs -n1 basename 2>/dev/null | sed 's/.json$//' | tr '\n' ' ' || echo "") - echo "changed_files=$changed" >> $GITHUB_OUTPUT - - # Generate diff for context - git diff "$BEFORE_SHA" "${{ github.sha }}" -- 'web/i18n/en-US/*.json' > /tmp/i18n-diff.txt 2>/dev/null || echo "" > /tmp/i18n-diff.txt - - # Truncate if too large (keep first 50KB to match receiving workflow) - head -c 50000 /tmp/i18n-diff.txt > /tmp/i18n-diff-truncated.txt - mv /tmp/i18n-diff-truncated.txt /tmp/i18n-diff.txt - - # Base64 encode the diff for safe JSON transport (portable, single-line) - diff_base64=$(base64 < /tmp/i18n-diff.txt | tr -d '\n') - echo "diff_base64=$diff_base64" >> $GITHUB_OUTPUT - - if [ -n "$changed" ]; then - echo "has_changes=true" >> $GITHUB_OUTPUT - echo "Detected changed files: $changed" - else - echo "has_changes=false" >> $GITHUB_OUTPUT - echo "No i18n changes detected" - fi - - - name: Trigger i18n sync workflow - if: steps.detect.outputs.has_changes == 'true' - uses: peter-evans/repository-dispatch@28959ce8df70de7be546dd1250a005dd32156697 # v4.0.1 - with: - token: ${{ secrets.GITHUB_TOKEN }} - event-type: i18n-sync - client-payload: '{"changed_files": "${{ steps.detect.outputs.changed_files }}", "diff_base64": "${{ steps.detect.outputs.diff_base64 }}", "sync_mode": "incremental", "trigger_sha": "${{ github.sha }}"}' diff --git a/.github/workflows/vdb-tests.yml b/.github/workflows/vdb-tests.yml index f45f2137d6a..026ff0fe57e 100644 --- a/.github/workflows/vdb-tests.yml +++ b/.github/workflows/vdb-tests.yml @@ -14,7 +14,6 @@ jobs: strategy: matrix: python-version: - - "3.11" - "3.12" steps: @@ -31,7 +30,7 @@ jobs: remove_tool_cache: true - name: Setup UV and Python - uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 + uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 with: enable-cache: true python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/web-e2e.yml b/.github/workflows/web-e2e.yml new file mode 100644 index 00000000000..eb752619be8 --- /dev/null +++ b/.github/workflows/web-e2e.yml @@ -0,0 +1,68 @@ +name: Web Full-Stack E2E + +on: + workflow_call: + +permissions: + contents: read + +concurrency: + group: web-e2e-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + test: + name: Web Full-Stack E2E + runs-on: ubuntu-latest + defaults: + run: + shell: bash + + steps: + - name: Checkout code + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + persist-credentials: false + + - name: Setup web dependencies + uses: ./.github/actions/setup-web + + - name: Setup UV and Python + uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + with: + enable-cache: true + python-version: "3.12" + cache-dependency-glob: api/uv.lock + + - name: Install API dependencies + run: uv sync --project api --dev + + - name: Install Playwright browser + working-directory: ./e2e + run: vp run e2e:install + + - name: Run isolated source-api and built-web Cucumber E2E tests + working-directory: ./e2e + env: + E2E_ADMIN_EMAIL: e2e-admin@example.com + E2E_ADMIN_NAME: E2E Admin + E2E_ADMIN_PASSWORD: E2eAdmin12345 + E2E_FORCE_WEB_BUILD: "1" + E2E_INIT_PASSWORD: E2eInit12345 + run: vp run e2e:full + + - name: Upload Cucumber report + if: ${{ !cancelled() }} + uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + with: + name: cucumber-report + path: e2e/cucumber-report + retention-days: 7 + + - name: Upload E2E logs + if: ${{ !cancelled() }} + uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + with: + name: e2e-logs + path: e2e/.logs + retention-days: 7 diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml index d40cd4bfebc..3c36335e791 100644 --- a/.github/workflows/web-tests.yml +++ b/.github/workflows/web-tests.yml @@ -22,8 +22,8 @@ jobs: strategy: fail-fast: false matrix: - shardIndex: [1, 2, 3, 4, 5, 6] - shardTotal: [6] + shardIndex: [1, 2, 3, 4] + shardTotal: [4] defaults: run: shell: bash @@ -66,7 +66,6 @@ jobs: - name: Checkout code uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: - fetch-depth: 0 persist-credentials: false - name: Setup web environment @@ -84,40 +83,9 @@ jobs: - name: Report coverage if: ${{ env.CODECOV_TOKEN != '' }} - uses: codecov/codecov-action@1af58845a975a7985b0beb0cbe6fbbb71a41dbad # v5.5.3 + uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0 with: directory: web/coverage flags: web env: CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }} - - web-build: - name: Web Build - runs-on: ubuntu-latest - defaults: - run: - working-directory: ./web - - steps: - - name: Checkout code - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - with: - persist-credentials: false - - - name: Check changed files - id: changed-files - uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 - with: - files: | - web/** - .github/workflows/web-tests.yml - .github/actions/setup-web/** - - - name: Setup web environment - if: steps.changed-files.outputs.any_changed == 'true' - uses: ./.github/actions/setup-web - - - name: Web build check - if: steps.changed-files.outputs.any_changed == 'true' - working-directory: ./web - run: vp run build diff --git a/.gitignore b/.gitignore index aaca9f2b0a3..d7698fe3fd9 100644 --- a/.gitignore +++ b/.gitignore @@ -212,6 +212,7 @@ api/.vscode # pnpm /.pnpm-store +/node_modules # plugin migrate plugins.jsonl @@ -239,4 +240,4 @@ scripts/stress-test/reports/ *.local.md # Code Agent Folder -.qoder/* \ No newline at end of file +.qoder/* diff --git a/web/.nvmrc b/.nvmrc similarity index 100% rename from web/.nvmrc rename to .nvmrc diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d7f007af679..775401bfa5c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -97,3 +97,8 @@ Feel free to reach out if you encounter any issues during the setup process. ## Getting Help If you ever get stuck or get a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat. + +## Automated Agent Contributions + +> [!NOTE] +> If you are an automated agent, just add 🤖🤖🤖 to the end of the PR title to opt-in. diff --git a/Makefile b/Makefile index 55871c86a72..d8c9df5208b 100644 --- a/Makefile +++ b/Makefile @@ -24,8 +24,8 @@ prepare-docker: # Step 2: Prepare web environment prepare-web: @echo "🌐 Setting up web environment..." - @cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists" - @cd web && pnpm install + @cp -n web/.env.example web/.env.local 2>/dev/null || echo "Web .env.local already exists" + @pnpm install @echo "✅ Web environment prepared (not started)" # Step 3: Prepare API environment @@ -74,6 +74,12 @@ type-check: @uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped . @echo "✅ Type checks complete" +type-check-core: + @echo "📝 Running core type checks (basedpyright + mypy)..." + @./dev/basedpyright-check $(PATH_TO_CHECK) + @uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped . + @echo "✅ Core type checks complete" + test: @echo "🧪 Running backend unit tests..." @if [ -n "$(TARGET_TESTS)" ]; then \ @@ -87,7 +93,7 @@ test: # Build Docker images build-web: @echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..." - docker build -t $(WEB_IMAGE):$(VERSION) ./web + docker build -f web/Dockerfile -t $(WEB_IMAGE):$(VERSION) . @echo "Web Docker image built successfully: $(WEB_IMAGE):$(VERSION)" build-api: @@ -133,6 +139,7 @@ help: @echo " make check - Check code with ruff" @echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)" @echo " make type-check - Run type checks (basedpyright, pyrefly, mypy)" + @echo " make type-check-core - Run core type checks (basedpyright, mypy)" @echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/)" @echo "" @echo "Docker Build Targets:" diff --git a/README.md b/README.md index bef8f6b782a..d9848a6c786 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

Dify is an open-source LLM app development platform. Its intuitive interface combines AI workflow, RAG pipeline, agent capabilities, model management, observability features (including [Opik](https://www.comet.com/docs/opik/integrations/dify), [Langfuse](https://docs.langfuse.com), and [Arize Phoenix](https://docs.arize.com/phoenix)) and more, letting you quickly go from prototype to production. Here's a list of the core features: diff --git a/api/.env.example b/api/.env.example index 40e1c2dfdfc..c6541731e64 100644 --- a/api/.env.example +++ b/api/.env.example @@ -127,7 +127,8 @@ ALIYUN_OSS_AUTH_VERSION=v1 ALIYUN_OSS_REGION=your-region # Don't start with '/'. OSS doesn't support leading slash in object names. ALIYUN_OSS_PATH=your-path -ALIYUN_CLOUDBOX_ID=your-cloudbox-id +# Optional CloudBox ID for Aliyun OSS, DO NOT enable it if you are not using CloudBox. +#ALIYUN_CLOUDBOX_ID=your-cloudbox-id # Google Storage configuration GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name @@ -353,6 +354,9 @@ BAIDU_VECTOR_DB_SHARD=1 BAIDU_VECTOR_DB_REPLICAS=3 BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER=DEFAULT_ANALYZER BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE=COARSE_MODE +BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT=500 +BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO=0.05 +BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS=300 # Upstash configuration UPSTASH_VECTOR_URL=your-server-url diff --git a/api/.importlinter b/api/.importlinter index a836d090887..5e06947d941 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -1,202 +1,14 @@ [importlinter] root_packages = core - dify_graph + constants + context configs controllers extensions + factories + libs models tasks services include_external_packages = True - -[importlinter:contract:workflow] -name = Workflow -type=layers -layers = - graph_engine - graph_events - graph - nodes - node_events - runtime - entities -containers = - dify_graph -ignore_imports = - dify_graph.nodes.base.node -> dify_graph.graph_events - dify_graph.nodes.iteration.iteration_node -> dify_graph.graph_events - dify_graph.nodes.loop.loop_node -> dify_graph.graph_events - - dify_graph.nodes.iteration.iteration_node -> dify_graph.graph_engine - dify_graph.nodes.loop.loop_node -> dify_graph.graph_engine - # TODO(QuantumGhost): fix the import violation later - dify_graph.entities.pause_reason -> dify_graph.nodes.human_input.entities - -[importlinter:contract:workflow-infrastructure-dependencies] -name = Workflow Infrastructure Dependencies -type = forbidden -source_modules = - dify_graph -forbidden_modules = - extensions.ext_database - extensions.ext_redis -allow_indirect_imports = True -ignore_imports = - dify_graph.nodes.llm.node -> extensions.ext_database - dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis - dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis - -[importlinter:contract:workflow-external-imports] -name = Workflow External Imports -type = forbidden -source_modules = - dify_graph -forbidden_modules = - configs - controllers - extensions - models - services - tasks - core.agent - core.app - core.base - core.callback_handler - core.datasource - core.db - core.entities - core.errors - core.extension - core.external_data_tool - core.file - core.helper - core.hosting_configuration - core.indexing_runner - core.llm_generator - core.logging - core.mcp - core.memory - core.moderation - core.ops - core.plugin - core.prompt - core.provider_manager - core.rag - core.repositories - core.schemas - core.tools - core.trigger - core.variables -ignore_imports = - dify_graph.nodes.llm.llm_utils -> core.model_manager - dify_graph.nodes.llm.protocols -> core.model_manager - dify_graph.nodes.llm.llm_utils -> dify_graph.model_runtime.model_providers.__base.large_language_model - dify_graph.nodes.llm.node -> core.tools.signature - dify_graph.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler - dify_graph.nodes.tool.tool_node -> core.tools.tool_engine - dify_graph.nodes.tool.tool_node -> core.tools.tool_manager - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> dify_graph.model_runtime.model_providers.__base.large_language_model - dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager - dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager - dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer - dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors - dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output - dify_graph.nodes.llm.node -> core.model_manager - dify_graph.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities - dify_graph.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities - dify_graph.nodes.llm.node -> core.prompt.utils.prompt_message_util - dify_graph.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.entities.advanced_prompt_entities - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util - dify_graph.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities - dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util - dify_graph.nodes.llm.node -> models.dataset - dify_graph.nodes.llm.file_saver -> core.tools.signature - dify_graph.nodes.llm.file_saver -> core.tools.tool_file_manager - dify_graph.nodes.tool.tool_node -> core.tools.errors - dify_graph.nodes.llm.node -> extensions.ext_database - dify_graph.nodes.llm.node -> models.model - dify_graph.nodes.tool.tool_node -> services - dify_graph.model_runtime.model_providers.__base.ai_model -> configs - dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis - dify_graph.model_runtime.model_providers.__base.large_language_model -> configs - dify_graph.model_runtime.model_providers.__base.text_embedding_model -> core.entities.embedding_type - dify_graph.model_runtime.model_providers.model_provider_factory -> configs - dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis - dify_graph.model_runtime.model_providers.model_provider_factory -> models.provider_ids - -[importlinter:contract:rsc] -name = RSC -type = layers -layers = - graph_engine - response_coordinator -containers = - dify_graph.graph_engine - -[importlinter:contract:worker] -name = Worker -type = layers -layers = - graph_engine - worker -containers = - dify_graph.graph_engine - -[importlinter:contract:graph-engine-architecture] -name = Graph Engine Architecture -type = layers -layers = - graph_engine - orchestration - command_processing - event_management - error_handler - graph_traversal - graph_state_manager - worker_management - domain -containers = - dify_graph.graph_engine - -[importlinter:contract:domain-isolation] -name = Domain Model Isolation -type = forbidden -source_modules = - dify_graph.graph_engine.domain -forbidden_modules = - dify_graph.graph_engine.worker_management - dify_graph.graph_engine.command_channels - dify_graph.graph_engine.layers - dify_graph.graph_engine.protocols - -[importlinter:contract:worker-management] -name = Worker Management -type = forbidden -source_modules = - dify_graph.graph_engine.worker_management -forbidden_modules = - dify_graph.graph_engine.orchestration - dify_graph.graph_engine.command_processing - dify_graph.graph_engine.event_management - - -[importlinter:contract:graph-traversal-components] -name = Graph Traversal Components -type = layers -layers = - edge_processor - skip_propagator -containers = - dify_graph.graph_engine.graph_traversal - -[importlinter:contract:command-channels] -name = Command Channels Independence -type = independence -modules = - dify_graph.graph_engine.command_channels.in_memory_channel - dify_graph.graph_engine.command_channels.redis_channel diff --git a/api/.ruff.toml b/api/.ruff.toml index b0947eb6190..4b1252a8613 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -100,7 +100,7 @@ ignore = [ "configs/*" = [ "N802", # invalid-function-name ] -"dify_graph/model_runtime/callbacks/base_callback.py" = ["T201"] +"graphon/model_runtime/callbacks/base_callback.py" = ["T201"] "core/workflow/callbacks/workflow_logging_callback.py" = ["T201"] "libs/gmpy2_pkcs10aep_cipher.py" = [ "N803", # invalid-argument-name diff --git a/api/README.md b/api/README.md index b6473670469..00562f3f78c 100644 --- a/api/README.md +++ b/api/README.md @@ -40,6 +40,8 @@ The scripts resolve paths relative to their location, so you can run them from a ./dev/start-web ``` + `./dev/setup` and `./dev/start-web` install JavaScript dependencies through the repository root workspace, so you do not need a separate `cd web && pnpm install` step. + 1. Set up your application by visiting `http://localhost:3000`. 1. Start the worker service (async and scheduler tasks, runs from `api`). diff --git a/api/app_factory.py b/api/app_factory.py index 066eb2ae2c8..76838f9925d 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -143,6 +143,7 @@ def initialize_extensions(app: DifyApp): ext_commands, ext_compress, ext_database, + ext_enterprise_telemetry, ext_fastopenapi, ext_forward_refs, ext_hosting_provider, @@ -193,6 +194,7 @@ def initialize_extensions(app: DifyApp): ext_commands, ext_fastopenapi, ext_otel, + ext_enterprise_telemetry, ext_request_logging, ext_session_factory, ] diff --git a/api/commands/vector.py b/api/commands/vector.py index 4cf11c9ad1a..cb7eb7c4522 100644 --- a/api/commands/vector.py +++ b/api/commands/vector.py @@ -10,6 +10,7 @@ from configs import dify_config from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_type import VectorType from core.rag.index_processor.constant.built_in_field import BuiltInField +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.models.document import ChildDocument, Document from extensions.ext_database import db from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment @@ -85,7 +86,7 @@ def migrate_annotation_vector_database(): dataset = Dataset( id=app.id, tenant_id=app.tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, collection_binding_id=dataset_collection_binding.id, @@ -177,7 +178,9 @@ def migrate_knowledge_vector_database(): while True: try: stmt = ( - select(Dataset).where(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc()) + select(Dataset) + .where(Dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY) + .order_by(Dataset.created_at.desc()) ) datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) @@ -269,7 +272,7 @@ def migrate_knowledge_vector_database(): "dataset_id": segment.dataset_id, }, ) - if dataset_document.doc_form == "hierarchical_model": + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunks = segment.get_child_chunks() if child_chunks: child_documents = [] diff --git a/api/configs/app_config.py b/api/configs/app_config.py index d3b1cf9d5b9..831f0a49e01 100644 --- a/api/configs/app_config.py +++ b/api/configs/app_config.py @@ -8,7 +8,7 @@ from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, Settings from libs.file_utils import search_file_upwards from .deploy import DeploymentConfig -from .enterprise import EnterpriseFeatureConfig +from .enterprise import EnterpriseFeatureConfig, EnterpriseTelemetryConfig from .extra import ExtraServiceConfig from .feature import FeatureConfig from .middleware import MiddlewareConfig @@ -73,6 +73,8 @@ class DifyConfig( # Enterprise feature configs # **Before using, please contact business@dify.ai by email to inquire about licensing matters.** EnterpriseFeatureConfig, + # Enterprise telemetry configs + EnterpriseTelemetryConfig, ): model_config = SettingsConfigDict( # read from dotenv format config file diff --git a/api/configs/enterprise/__init__.py b/api/configs/enterprise/__init__.py index f8447c6979a..8a6a921a4ea 100644 --- a/api/configs/enterprise/__init__.py +++ b/api/configs/enterprise/__init__.py @@ -22,3 +22,52 @@ class EnterpriseFeatureConfig(BaseSettings): ENTERPRISE_REQUEST_TIMEOUT: int = Field( ge=1, description="Maximum timeout in seconds for enterprise requests", default=5 ) + + +class EnterpriseTelemetryConfig(BaseSettings): + """ + Configuration for enterprise telemetry. + """ + + ENTERPRISE_TELEMETRY_ENABLED: bool = Field( + description="Enable enterprise telemetry collection (also requires ENTERPRISE_ENABLED=true).", + default=False, + ) + + ENTERPRISE_OTLP_ENDPOINT: str = Field( + description="Enterprise OTEL collector endpoint.", + default="", + ) + + ENTERPRISE_OTLP_HEADERS: str = Field( + description="Auth headers for OTLP export (key=value,key2=value2).", + default="", + ) + + ENTERPRISE_OTLP_PROTOCOL: str = Field( + description="OTLP protocol: 'http' or 'grpc' (default: http).", + default="http", + ) + + ENTERPRISE_OTLP_API_KEY: str = Field( + description="Bearer token for enterprise OTLP export authentication.", + default="", + ) + + ENTERPRISE_INCLUDE_CONTENT: bool = Field( + description="Include input/output content in traces (privacy toggle).", + # Setting the default value to False to avoid accidentally log PII data in traces. + default=False, + ) + + ENTERPRISE_SERVICE_NAME: str = Field( + description="Service name for OTEL resource.", + default="dify", + ) + + ENTERPRISE_OTEL_SAMPLING_RATE: float = Field( + description="Sampling rate for enterprise traces (0.0 to 1.0, default 1.0 = 100%).", + default=1.0, + ge=0.0, + le=1.0, + ) diff --git a/api/configs/middleware/vdb/baidu_vector_config.py b/api/configs/middleware/vdb/baidu_vector_config.py index 8f956745b1e..c8e4f7309f7 100644 --- a/api/configs/middleware/vdb/baidu_vector_config.py +++ b/api/configs/middleware/vdb/baidu_vector_config.py @@ -51,3 +51,18 @@ class BaiduVectorDBConfig(BaseSettings): description="Parser mode for inverted index in Baidu Vector Database (default is COARSE_MODE)", default="COARSE_MODE", ) + + BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT: int = Field( + description="Auto build row count increment threshold (default is 500)", + default=500, + ) + + BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO: float = Field( + description="Auto build row count increment ratio threshold (default is 0.05)", + default=0.05, + ) + + BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS: int = Field( + description="Timeout in seconds for rebuilding the index in Baidu Vector Database (default is 3600 seconds)", + default=300, + ) diff --git a/api/context/__init__.py b/api/context/__init__.py index 969e5f583d5..8df37138e8b 100644 --- a/api/context/__init__.py +++ b/api/context/__init__.py @@ -1,74 +1,36 @@ """ -Core Context - Framework-agnostic context management. +Application-layer context adapters. -This module provides context management that is independent of any specific -web framework. Framework-specific implementations register their context -capture functions at application initialization time. - -This ensures the workflow layer remains completely decoupled from Flask -or any other web framework. +Concrete execution-context implementations live here so `graphon` only +depends on injected context managers rather than framework state capture. """ -import contextvars -from collections.abc import Callable - -from dify_graph.context.execution_context import ( +from context.execution_context import ( + AppContext, + ContextProviderNotFoundError, ExecutionContext, + ExecutionContextBuilder, IExecutionContext, NullAppContext, + capture_current_context, + read_context, + register_context, + register_context_capturer, + reset_context_provider, ) - -# Global capturer function - set by framework-specific modules -_capturer: Callable[[], IExecutionContext] | None = None - - -def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None: - """ - Register a context capture function. - - This should be called by framework-specific modules (e.g., Flask) - during application initialization. - - Args: - capturer: Function that captures current context and returns IExecutionContext - """ - global _capturer - _capturer = capturer - - -def capture_current_context() -> IExecutionContext: - """ - Capture current execution context. - - This function uses the registered context capturer. If no capturer - is registered, it returns a minimal context with only contextvars - (suitable for non-framework environments like tests or standalone scripts). - - Returns: - IExecutionContext with captured context - """ - if _capturer is None: - # No framework registered - return minimal context - return ExecutionContext( - app_context=NullAppContext(), - context_vars=contextvars.copy_context(), - ) - - return _capturer() - - -def reset_context_provider() -> None: - """ - Reset the context capturer. - - This is primarily useful for testing to ensure a clean state. - """ - global _capturer - _capturer = None - +from context.models import SandboxContext __all__ = [ + "AppContext", + "ContextProviderNotFoundError", + "ExecutionContext", + "ExecutionContextBuilder", + "IExecutionContext", + "NullAppContext", + "SandboxContext", "capture_current_context", + "read_context", + "register_context", "register_context_capturer", "reset_context_provider", ] diff --git a/api/dify_graph/context/execution_context.py b/api/context/execution_context.py similarity index 60% rename from api/dify_graph/context/execution_context.py rename to api/context/execution_context.py index e3007530f08..ba9a24d4f33 100644 --- a/api/dify_graph/context/execution_context.py +++ b/api/context/execution_context.py @@ -1,5 +1,8 @@ """ -Execution Context - Abstracted context management for workflow execution. +Application-layer execution context adapters. + +Concrete context capture lives outside `graphon` so the graph package only +consumes injected context managers when it needs to preserve thread-local state. """ import contextvars @@ -16,33 +19,33 @@ class AppContext(ABC): """ Abstract application context interface. - This abstraction allows workflow execution to work with or without Flask - by providing a common interface for application context management. + Application adapters can implement this to restore framework-specific state + such as Flask app context around worker execution. """ @abstractmethod def get_config(self, key: str, default: Any = None) -> Any: """Get configuration value by key.""" - pass + raise NotImplementedError @abstractmethod def get_extension(self, name: str) -> Any: - """Get Flask extension by name (e.g., 'db', 'cache').""" - pass + """Get application extension by name.""" + raise NotImplementedError @abstractmethod def enter(self) -> AbstractContextManager[None]: """Enter the application context.""" - pass + raise NotImplementedError @runtime_checkable class IExecutionContext(Protocol): """ - Protocol for execution context. + Protocol for enterable execution context objects. - This protocol defines the interface that all execution contexts must implement, - allowing both ExecutionContext and FlaskExecutionContext to be used interchangeably. + Concrete implementations may carry extra framework state, but callers only + depend on standard context-manager behavior plus optional user metadata. """ def __enter__(self) -> "IExecutionContext": @@ -62,14 +65,10 @@ class IExecutionContext(Protocol): @final class ExecutionContext: """ - Execution context for workflow execution in worker threads. + Generic execution context used by application-layer adapters. - This class encapsulates all context needed for workflow execution: - - Application context (Flask app or standalone) - - Context variables for Python contextvars - - User information (optional) - - It is designed to be serializable and passable to worker threads. + It restores captured `contextvars` and optionally enters an application + context before the worker executes graph logic. """ def __init__( @@ -78,14 +77,6 @@ class ExecutionContext: context_vars: contextvars.Context | None = None, user: Any = None, ) -> None: - """ - Initialize execution context. - - Args: - app_context: Application context (Flask or standalone) - context_vars: Python contextvars to preserve - user: User object (optional) - """ self._app_context = app_context self._context_vars = context_vars self._user = user @@ -98,27 +89,21 @@ class ExecutionContext: @property def context_vars(self) -> contextvars.Context | None: - """Get context variables.""" + """Get captured context variables.""" return self._context_vars @property def user(self) -> Any: - """Get user object.""" + """Get captured user object.""" return self._user @contextmanager def enter(self) -> Generator[None, None, None]: - """ - Enter this execution context. - - This is a convenience method that creates a context manager. - """ - # Restore context variables if provided + """Enter this execution context.""" if self._context_vars: for var, val in self._context_vars.items(): var.set(val) - # Enter app context if available if self._app_context is not None: with self._app_context.enter(): yield @@ -141,18 +126,10 @@ class ExecutionContext: class NullAppContext(AppContext): """ - Null implementation of AppContext for non-Flask environments. - - This is used when running without Flask (e.g., in tests or standalone mode). + Null application context for non-framework environments. """ def __init__(self, config: dict[str, Any] | None = None) -> None: - """ - Initialize null app context. - - Args: - config: Optional configuration dictionary - """ self._config = config or {} self._extensions: dict[str, Any] = {} @@ -165,7 +142,7 @@ class NullAppContext(AppContext): return self._extensions.get(name) def set_extension(self, name: str, extension: Any) -> None: - """Set extension by name.""" + """Register an extension for tests or standalone execution.""" self._extensions[name] = extension @contextmanager @@ -176,9 +153,7 @@ class NullAppContext(AppContext): class ExecutionContextBuilder: """ - Builder for creating ExecutionContext instances. - - This provides a fluent API for building execution contexts. + Builder for creating `ExecutionContext` instances. """ def __init__(self) -> None: @@ -211,63 +186,42 @@ class ExecutionContextBuilder: _capturer: Callable[[], IExecutionContext] | None = None - -# Tenant-scoped providers using tuple keys for clarity and constant-time lookup. -# Key mapping: -# (name, tenant_id) -> provider -# - name: namespaced identifier (recommend prefixing, e.g. "workflow.sandbox") -# - tenant_id: tenant identifier string -# Value: -# provider: Callable[[], BaseModel] returning the typed context value -# Type-safety note: -# - This registry cannot enforce that all providers for a given name return the same BaseModel type. -# - Implementors SHOULD provide typed wrappers around register/read (like Go's context best practice), -# e.g. def register_sandbox_ctx(tenant_id: str, p: Callable[[], SandboxContext]) and -# def read_sandbox_ctx(tenant_id: str) -> SandboxContext. _tenant_context_providers: dict[tuple[str, str], Callable[[], BaseModel]] = {} T = TypeVar("T", bound=BaseModel) class ContextProviderNotFoundError(KeyError): - """Raised when a tenant-scoped context provider is missing for a given (name, tenant_id).""" + """Raised when a tenant-scoped context provider is missing.""" pass def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None: - """Register a single enterable execution context capturer (e.g., Flask).""" + """Register an enterable execution context capturer.""" global _capturer _capturer = capturer def register_context(name: str, tenant_id: str, provider: Callable[[], BaseModel]) -> None: - """Register a tenant-specific provider for a named context. - - Tip: use a namespaced "name" (e.g., "workflow.sandbox") to avoid key collisions. - Consider adding a typed wrapper for this registration in your feature module. - """ + """Register a tenant-specific provider for a named context.""" _tenant_context_providers[(name, tenant_id)] = provider def read_context(name: str, *, tenant_id: str) -> BaseModel: - """ - Read a context value for a specific tenant. - - Raises KeyError if the provider for (name, tenant_id) is not registered. - """ - prov = _tenant_context_providers.get((name, tenant_id)) - if prov is None: + """Read a context value for a specific tenant.""" + provider = _tenant_context_providers.get((name, tenant_id)) + if provider is None: raise ContextProviderNotFoundError(f"Context provider '{name}' not registered for tenant '{tenant_id}'") - return prov() + return provider() def capture_current_context() -> IExecutionContext: """ Capture current execution context from the calling environment. - If a capturer is registered (e.g., Flask), use it. Otherwise, return a minimal - context with NullAppContext + copy of current contextvars. + If no framework adapter is registered, return a minimal context that only + restores `contextvars`. """ if _capturer is None: return ExecutionContext( @@ -278,7 +232,22 @@ def capture_current_context() -> IExecutionContext: def reset_context_provider() -> None: - """Reset the capturer and all tenant-scoped context providers (primarily for tests).""" + """Reset the capturer and tenant-scoped providers.""" global _capturer _capturer = None _tenant_context_providers.clear() + + +__all__ = [ + "AppContext", + "ContextProviderNotFoundError", + "ExecutionContext", + "ExecutionContextBuilder", + "IExecutionContext", + "NullAppContext", + "capture_current_context", + "read_context", + "register_context", + "register_context_capturer", + "reset_context_provider", +] diff --git a/api/context/flask_app_context.py b/api/context/flask_app_context.py index 324a9ee8b44..eddd6448d83 100644 --- a/api/context/flask_app_context.py +++ b/api/context/flask_app_context.py @@ -10,11 +10,7 @@ from typing import Any, final from flask import Flask, current_app, g -from dify_graph.context import register_context_capturer -from dify_graph.context.execution_context import ( - AppContext, - IExecutionContext, -) +from context.execution_context import AppContext, IExecutionContext, register_context_capturer @final diff --git a/api/dify_graph/context/models.py b/api/context/models.py similarity index 100% rename from api/dify_graph/context/models.py rename to api/context/models.py diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py index c52dcf8a574..764f9f8ee27 100644 --- a/api/contexts/__init__.py +++ b/api/contexts/__init__.py @@ -6,7 +6,6 @@ from contexts.wrapper import RecyclableContextVar if TYPE_CHECKING: from core.datasource.__base.datasource_provider import DatasourcePluginProviderController - from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.tools.plugin_tool.provider import PluginToolProviderController from core.trigger.provider import PluginTriggerProviderController @@ -20,14 +19,6 @@ plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderControl plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock")) -plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar( - ContextVar("plugin_model_providers") -) - -plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar( - ContextVar("plugin_model_providers_lock") -) - datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = ( RecyclableContextVar(ContextVar("datasource_plugin_providers")) ) diff --git a/api/controllers/common/fields.py b/api/controllers/common/fields.py index ff5326dade6..7348ef62aad 100644 --- a/api/controllers/common/fields.py +++ b/api/controllers/common/fields.py @@ -2,9 +2,9 @@ from __future__ import annotations from typing import Any, TypeAlias +from graphon.file import helpers as file_helpers from pydantic import BaseModel, ConfigDict, computed_field -from dify_graph.file import helpers as file_helpers from models.model import IconType JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any] diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index b6d1df319e5..783cb5c444b 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -1,7 +1,7 @@ import flask_restx from flask_restx import Resource, fields, marshal_with from flask_restx._http import HTTPStatus -from sqlalchemy import select +from sqlalchemy import delete, func, select from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden @@ -9,6 +9,7 @@ from extensions.ext_database import db from libs.helper import TimestampField from libs.login import current_account_with_tenant, login_required from models.dataset import Dataset +from models.enums import ApiTokenType from models.model import ApiToken, App from services.api_token_service import ApiTokenCache @@ -33,16 +34,10 @@ api_key_list_model = console_ns.model( def _get_resource(resource_id, tenant_id, resource_model): - if resource_model == App: - with Session(db.engine) as session: - resource = session.execute( - select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id) - ).scalar_one_or_none() - else: - with Session(db.engine) as session: - resource = session.execute( - select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id) - ).scalar_one_or_none() + with Session(db.engine) as session: + resource = session.execute( + select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id) + ).scalar_one_or_none() if resource is None: flask_restx.abort(HTTPStatus.NOT_FOUND, message=f"{resource_model.__name__} not found.") @@ -53,7 +48,7 @@ def _get_resource(resource_id, tenant_id, resource_model): class BaseApiKeyListResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] - resource_type: str | None = None + resource_type: ApiTokenType | None = None resource_model: type | None = None resource_id_field: str | None = None token_prefix: str | None = None @@ -80,10 +75,13 @@ class BaseApiKeyListResource(Resource): resource_id = str(resource_id) _, current_tenant_id = current_account_with_tenant() _get_resource(resource_id, current_tenant_id, self.resource_model) - current_key_count = ( - db.session.query(ApiToken) - .where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) - .count() + current_key_count: int = ( + db.session.scalar( + select(func.count(ApiToken.id)).where( + ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id + ) + ) + or 0 ) if current_key_count >= self.max_keys: @@ -94,6 +92,7 @@ class BaseApiKeyListResource(Resource): ) key = ApiToken.generate_api_key(self.token_prefix or "", 24) + assert self.resource_type is not None, "resource_type must be set" api_token = ApiToken() setattr(api_token, self.resource_id_field, resource_id) api_token.tenant_id = current_tenant_id @@ -107,7 +106,7 @@ class BaseApiKeyListResource(Resource): class BaseApiKeyResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] - resource_type: str | None = None + resource_type: ApiTokenType | None = None resource_model: type | None = None resource_id_field: str | None = None @@ -119,14 +118,14 @@ class BaseApiKeyResource(Resource): if not current_user.is_admin_or_owner: raise Forbidden() - key = ( - db.session.query(ApiToken) + key = db.session.scalar( + select(ApiToken) .where( getattr(ApiToken, self.resource_id_field) == resource_id, ApiToken.type == self.resource_type, ApiToken.id == api_key_id, ) - .first() + .limit(1) ) if key is None: @@ -137,7 +136,7 @@ class BaseApiKeyResource(Resource): assert key is not None # nosec - for type checker only ApiTokenCache.delete(key.token, key.type) - db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() + db.session.execute(delete(ApiToken).where(ApiToken.id == api_key_id)) db.session.commit() return {"result": "success"}, 204 @@ -162,7 +161,7 @@ class AppApiKeyListResource(BaseApiKeyListResource): """Create a new API key for an app""" return super().post(resource_id) - resource_type = "app" + resource_type = ApiTokenType.APP resource_model = App resource_id_field = "app_id" token_prefix = "app-" @@ -178,7 +177,7 @@ class AppApiKeyResource(BaseApiKeyResource): """Delete an API key for an app""" return super().delete(resource_id, api_key_id) - resource_type = "app" + resource_type = ApiTokenType.APP resource_model = App resource_id_field = "app_id" @@ -202,7 +201,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource): """Create a new API key for a dataset""" return super().post(resource_id) - resource_type = "dataset" + resource_type = ApiTokenType.DATASET resource_model = Dataset resource_id_field = "dataset_id" token_prefix = "ds-" @@ -218,6 +217,6 @@ class DatasetApiKeyResource(BaseApiKeyResource): """Delete an API key for a dataset""" return super().delete(resource_id, api_key_id) - resource_type = "dataset" + resource_type = ApiTokenType.DATASET resource_model = Dataset resource_id_field = "dataset_id" diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 5ac0e342e6e..738e77b3715 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -5,6 +5,8 @@ from typing import Any, Literal, TypeAlias from flask import request from flask_restx import Resource +from graphon.enums import WorkflowExecutionStatus +from graphon.file import helpers as file_helpers from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator from sqlalchemy import select from sqlalchemy.orm import Session @@ -26,8 +28,6 @@ from controllers.console.wraps import ( from core.ops.ops_trace_manager import OpsTraceManager from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.trigger.constants import TRIGGER_NODE_TYPES -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.file import helpers as file_helpers from extensions.ext_database import db from libs.login import current_account_with_tenant, login_required from models import App, DatasetPermissionEnum, Workflow @@ -95,7 +95,7 @@ class CreateAppPayload(BaseModel): name: str = Field(..., min_length=1, description="App name") description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400) mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode") - icon_type: str | None = Field(default=None, description="Icon type") + icon_type: IconType | None = Field(default=None, description="Icon type") icon: str | None = Field(default=None, description="Icon") icon_background: str | None = Field(default=None, description="Icon background color") @@ -103,7 +103,7 @@ class CreateAppPayload(BaseModel): class UpdateAppPayload(BaseModel): name: str = Field(..., min_length=1, description="App name") description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400) - icon_type: str | None = Field(default=None, description="Icon type") + icon_type: IconType | None = Field(default=None, description="Icon type") icon: str | None = Field(default=None, description="Icon") icon_background: str | None = Field(default=None, description="Icon background color") use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon") @@ -113,7 +113,7 @@ class UpdateAppPayload(BaseModel): class CopyAppPayload(BaseModel): name: str | None = Field(default=None, description="Name for the copied app") description: str | None = Field(default=None, description="Description for the copied app", max_length=400) - icon_type: str | None = Field(default=None, description="Icon type") + icon_type: IconType | None = Field(default=None, description="Icon type") icon: str | None = Field(default=None, description="Icon") icon_background: str | None = Field(default=None, description="Icon background color") @@ -594,7 +594,7 @@ class AppApi(Resource): args_dict: AppService.ArgsDict = { "name": args.name, "description": args.description or "", - "icon_type": args.icon_type or "", + "icon_type": args.icon_type, "icon": args.icon or "", "icon_background": args.icon_background or "", "use_icon_as_answer_icon": args.use_icon_as_answer_icon or False, diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 2c5e8d29ee9..78ddb904e14 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -2,6 +2,7 @@ import logging from flask import request from flask_restx import Resource, fields +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError @@ -22,7 +23,6 @@ from controllers.console.app.error import ( from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError from libs.login import login_required from models import App, AppMode from services.audio_service import AudioService diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 4d7ddfea139..d83925d173a 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -3,6 +3,7 @@ from typing import Any, Literal from flask import request from flask_restx import Resource +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound @@ -26,7 +27,6 @@ from core.errors.error import ( QuotaExceededError, ) from core.helper.trace_id_helper import get_external_trace_id -from dify_graph.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from libs.login import current_user, login_required diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 5eb61493c38..d329d22309d 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -5,7 +5,7 @@ from flask import abort, request from flask_restx import Resource, fields, marshal_with from pydantic import BaseModel, Field, field_validator from sqlalchemy import func, or_ -from sqlalchemy.orm import joinedload +from sqlalchemy.orm import selectinload from werkzeug.exceptions import NotFound from controllers.console import console_ns @@ -376,8 +376,12 @@ class CompletionConversationApi(Resource): # FIXME, the type ignore in this file if args.annotation_status == "annotated": - query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore - MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id + query = ( + query.options(selectinload(Conversation.message_annotations)) # type: ignore[arg-type] + .join( # type: ignore + MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id + ) + .distinct() ) elif args.annotation_status == "not_annotated": query = ( @@ -454,9 +458,7 @@ class ChatConversationApi(Resource): args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore subquery = ( - db.session.query( - Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id") - ) + sa.select(Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id")) .outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id) .subquery() ) @@ -511,8 +513,12 @@ class ChatConversationApi(Resource): match args.annotation_status: case "annotated": - query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore - MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id + query = ( + query.options(selectinload(Conversation.message_annotations)) # type: ignore[arg-type] + .join( # type: ignore + MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id + ) + .distinct() ) case "not_annotated": query = ( @@ -587,10 +593,8 @@ class ChatConversationDetailApi(Resource): def _get_conversation(app_model, conversation_id): current_user, _ = current_account_with_tenant() - conversation = ( - db.session.query(Conversation) - .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id) - .first() + conversation = db.session.scalar( + sa.select(Conversation).where(Conversation.id == conversation_id, Conversation.app_id == app_model.id).limit(1) ) if not conversation: diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index af4ac450bbc..7101d5df7b4 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -1,6 +1,7 @@ from collections.abc import Sequence from flask_restx import Resource +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from controllers.console import console_ns @@ -18,7 +19,6 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload from core.llm_generator.llm_generator import LLMGenerator -from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from libs.login import current_account_with_tenant, login_required from models import App @@ -168,7 +168,7 @@ class InstructionGenerateApi(Resource): try: # Generate from nothing for a workflow node if (args.current in (code_template, "")) and args.node_id != "": - app = db.session.query(App).where(App.id == args.flow_id).first() + app = db.session.get(App, args.flow_id) if not app: return {"error": f"app {args.flow_id} not found"}, 400 workflow = WorkflowService().get_draft_workflow(app_model=app) diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index 4b20418b534..412fc8795a4 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -2,6 +2,7 @@ import json from flask_restx import Resource, marshal_with from pydantic import BaseModel, Field +from sqlalchemy import select from werkzeug.exceptions import NotFound from controllers.console import console_ns @@ -47,7 +48,7 @@ class AppMCPServerController(Resource): @get_app_model @marshal_with(app_server_model) def get(self, app_model): - server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first() + server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.app_id == app_model.id).limit(1)) return server @console_ns.doc("create_app_mcp_server") @@ -98,7 +99,7 @@ class AppMCPServerController(Resource): @edit_permission_required def put(self, app_model): payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {}) - server = db.session.query(AppMCPServer).where(AppMCPServer.id == payload.id).first() + server = db.session.get(AppMCPServer, payload.id) if not server: raise NotFound() @@ -135,11 +136,10 @@ class AppMCPServerRefreshController(Resource): @edit_permission_required def get(self, server_id): _, current_tenant_id = current_account_with_tenant() - server = ( - db.session.query(AppMCPServer) - .where(AppMCPServer.id == server_id) - .where(AppMCPServer.tenant_id == current_tenant_id) - .first() + server = db.session.scalar( + select(AppMCPServer) + .where(AppMCPServer.id == server_id, AppMCPServer.tenant_id == current_tenant_id) + .limit(1) ) if not server: raise NotFound() diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 4fb73f61f37..2afe2767427 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -3,8 +3,9 @@ from typing import Literal from flask import request from flask_restx import Resource, fields, marshal_with +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator -from sqlalchemy import exists, select +from sqlalchemy import exists, func, select from werkzeug.exceptions import InternalServerError, NotFound from controllers.common.schema import register_schema_models @@ -24,7 +25,6 @@ from controllers.console.wraps import ( ) from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from fields.raws import FilesContainedField from libs.helper import TimestampField, uuid_value @@ -244,27 +244,25 @@ class ChatMessageListApi(Resource): def get(self, app_model): args = ChatMessagesQuery.model_validate(request.args.to_dict()) - conversation = ( - db.session.query(Conversation) + conversation = db.session.scalar( + select(Conversation) .where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id) - .first() + .limit(1) ) if not conversation: raise NotFound("Conversation Not Exists.") if args.first_id: - first_message = ( - db.session.query(Message) - .where(Message.conversation_id == conversation.id, Message.id == args.first_id) - .first() + first_message = db.session.scalar( + select(Message).where(Message.conversation_id == conversation.id, Message.id == args.first_id).limit(1) ) if not first_message: raise NotFound("First message not found") - history_messages = ( - db.session.query(Message) + history_messages = db.session.scalars( + select(Message) .where( Message.conversation_id == conversation.id, Message.created_at < first_message.created_at, @@ -272,16 +270,14 @@ class ChatMessageListApi(Resource): ) .order_by(Message.created_at.desc()) .limit(args.limit) - .all() - ) + ).all() else: - history_messages = ( - db.session.query(Message) + history_messages = db.session.scalars( + select(Message) .where(Message.conversation_id == conversation.id) .order_by(Message.created_at.desc()) .limit(args.limit) - .all() - ) + ).all() # Initialize has_more based on whether we have a full page if len(history_messages) == args.limit: @@ -326,7 +322,9 @@ class MessageFeedbackApi(Resource): message_id = str(args.message_id) - message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first() + message = db.session.scalar( + select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1) + ) if not message: raise NotFound("Message Not Exists.") @@ -375,7 +373,9 @@ class MessageAnnotationCountApi(Resource): @login_required @account_initialization_required def get(self, app_model): - count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count() + count = db.session.scalar( + select(func.count(MessageAnnotation.id)).where(MessageAnnotation.app_id == app_model.id) + ) return {"count": count} @@ -479,7 +479,9 @@ class MessageApi(Resource): def get(self, app_model, message_id: str): message_id = str(message_id) - message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first() + message = db.session.scalar( + select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1) + ) if not message: raise NotFound("Message Not Exists.") diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index a85e54fb512..8bb5aa2c1b1 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -69,9 +69,7 @@ class ModelConfigResource(Resource): if app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent: # get original app model config - original_app_model_config = ( - db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first() - ) + original_app_model_config = db.session.get(AppModelConfig, app_model.app_model_config_id) if original_app_model_config is None: raise ValueError("Original app model config not found") agent_mode = original_app_model_config.agent_mode_dict @@ -90,6 +88,7 @@ class ModelConfigResource(Resource): tenant_id=current_tenant_id, app_id=app_model.id, agent_tool=agent_tool_entity, + user_id=current_user.id, ) manager = ToolParameterConfigurationManager( tenant_id=current_tenant_id, @@ -129,6 +128,7 @@ class ModelConfigResource(Resource): tenant_id=current_tenant_id, app_id=app_model.id, agent_tool=agent_tool_entity, + user_id=current_user.id, ) except Exception: continue diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index db218d8b813..7f44a99ff13 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -2,6 +2,7 @@ from typing import Literal from flask_restx import Resource, marshal_with from pydantic import BaseModel, Field, field_validator +from sqlalchemy import select from werkzeug.exceptions import NotFound from constants.languages import supported_language @@ -75,7 +76,7 @@ class AppSite(Resource): def post(self, app_model): args = AppSiteUpdatePayload.model_validate(console_ns.payload or {}) current_user, _ = current_account_with_tenant() - site = db.session.query(Site).where(Site.app_id == app_model.id).first() + site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) if not site: raise NotFound @@ -124,7 +125,7 @@ class AppSiteAccessTokenReset(Resource): @marshal_with(app_site_model) def post(self, app_model): current_user, _ = current_account_with_tenant() - site = db.session.query(Site).where(Site.app_id == app_model.id).first() + site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) if not site: raise NotFound diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 837245ecb11..1f5a84c0b2b 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -5,9 +5,13 @@ from typing import Any from flask import abort, request from flask_restx import Resource, fields, marshal_with +from graphon.enums import NodeType +from graphon.file import File +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from sqlalchemy.orm import Session -from werkzeug.exceptions import Forbidden, InternalServerError, NotFound +from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound import services from controllers.console import console_ns @@ -20,6 +24,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.file_access import DatabaseFileAccessController from core.helper.trace_id_helper import get_external_trace_id from core.plugin.impl.exc import PluginInvokeError from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE @@ -29,10 +34,6 @@ from core.trigger.debug.event_selectors import ( create_event_poller, select_trigger_debug_events, ) -from dify_graph.enums import NodeType -from dify_graph.file.models import File -from dify_graph.graph_engine.manager import GraphEngineManager -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from extensions.ext_redis import redis_client from factories import file_factory, variable_factory @@ -46,13 +47,15 @@ from models import App from models.model import AppMode from models.workflow import Workflow from services.app_generate_service import AppGenerateService -from services.errors.app import WorkflowHashNotEqualError +from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() LISTENING_RETRY_IN = 2000 DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" +RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published" # Register models for flask_restx to avoid dict type issues in Swagger # Register in dependency order: base models first, then dependent models @@ -203,6 +206,7 @@ def _parse_file(workflow: Workflow, files: list[dict] | None = None) -> Sequence mappings=files, tenant_id=workflow.tenant_id, config=file_extra_config, + access_controller=_file_access_controller, ) return file_objs @@ -284,7 +288,9 @@ class DraftWorkflowApi(Resource): workflow_service = WorkflowService() try: - environment_variables_list = args.get("environment_variables") or [] + environment_variables_list = Workflow.normalize_environment_variable_mappings( + args.get("environment_variables") or [], + ) environment_variables = [ variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list ] @@ -994,6 +1000,43 @@ class PublishedAllWorkflowApi(Resource): } +@console_ns.route("/apps//workflows//restore") +class DraftWorkflowRestoreApi(Resource): + @console_ns.doc("restore_workflow_to_draft") + @console_ns.doc(description="Restore a published workflow version into the draft workflow") + @console_ns.doc(params={"app_id": "Application ID", "workflow_id": "Published workflow ID"}) + @console_ns.response(200, "Workflow restored successfully") + @console_ns.response(400, "Source workflow must be published") + @console_ns.response(404, "Workflow not found") + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @edit_permission_required + def post(self, app_model: App, workflow_id: str): + current_user, _ = current_account_with_tenant() + workflow_service = WorkflowService() + + try: + workflow = workflow_service.restore_published_workflow_to_draft( + app_model=app_model, + workflow_id=workflow_id, + account=current_user, + ) + except IsDraftWorkflowError as exc: + raise BadRequest(RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE) from exc + except WorkflowNotFoundError as exc: + raise NotFound(str(exc)) from exc + except ValueError as exc: + raise BadRequest(str(exc)) from exc + + return { + "result": "success", + "hash": workflow.unique_hash, + "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at), + } + + @console_ns.route("/apps//workflows/") class WorkflowByIdApi(Resource): @console_ns.doc("update_workflow_by_id") diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 9b148c3f187..f0e26c86a5a 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -3,13 +3,13 @@ from datetime import datetime from dateutil.parser import isoparse from flask import request from flask_restx import Resource, marshal_with +from graphon.enums import WorkflowExecutionStatus from pydantic import BaseModel, Field, field_validator from sqlalchemy.orm import Session from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required -from dify_graph.enums import WorkflowExecutionStatus from extensions.ext_database import db from fields.workflow_app_log_fields import ( build_workflow_app_log_pagination_model, diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index b78d97a382e..4052897e9a4 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -5,6 +5,10 @@ from typing import Any, NoReturn, ParamSpec, TypeVar from flask import Response, request from flask_restx import Resource, fields, marshal, marshal_with +from graphon.file import helpers as file_helpers +from graphon.variables.segment_group import SegmentGroup +from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment +from graphon.variables.types import SegmentType from pydantic import BaseModel, Field from sqlalchemy.orm import Session @@ -15,11 +19,8 @@ from controllers.console.app.error import ( from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.web.error import InvalidArgumentError, NotFoundError -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from dify_graph.file import helpers as file_helpers -from dify_graph.variables.segment_group import SegmentGroup -from dify_graph.variables.segments import ArrayFileSegment, FileSegment, Segment -from dify_graph.variables.types import SegmentType +from core.app.file_access import DatabaseFileAccessController +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from extensions.ext_database import db from factories.file_factory import build_from_mapping, build_from_mappings from factories.variable_factory import build_segment_with_type @@ -30,6 +31,7 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList, from services.workflow_service import WorkflowService logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" @@ -389,13 +391,21 @@ class VariableApi(Resource): if variable.value_type == SegmentType.FILE: if not isinstance(raw_value, dict): raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}") - raw_value = build_from_mapping(mapping=raw_value, tenant_id=app_model.tenant_id) + raw_value = build_from_mapping( + mapping=raw_value, + tenant_id=app_model.tenant_id, + access_controller=_file_access_controller, + ) elif variable.value_type == SegmentType.ARRAY_FILE: if not isinstance(raw_value, list): raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}") if len(raw_value) > 0 and not isinstance(raw_value[0], dict): raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") - raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id) + raw_value = build_from_mappings( + mappings=raw_value, + tenant_id=app_model.tenant_id, + access_controller=_file_access_controller, + ) new_value = build_segment_with_type(variable.value_type, raw_value) draft_var_srv.update_variable(variable, name=new_name, value=new_value) db.session.commit() diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 7ac653395ee..83e8bedc110 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -1,8 +1,10 @@ from datetime import UTC, datetime, timedelta -from typing import Literal, cast +from typing import Literal, TypedDict, cast from flask import request from flask_restx import Resource, fields, marshal_with +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus from pydantic import BaseModel, Field, field_validator from sqlalchemy import select from sqlalchemy.orm import sessionmaker @@ -12,8 +14,7 @@ from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import NotFoundError -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import WorkflowExecutionStatus +from core.workflow.human_input_forms import load_form_tokens_by_form_id as _load_form_tokens_by_form_id from extensions.ext_database import db from fields.end_user_fields import simple_end_user_fields from fields.member_fields import simple_account_fields @@ -172,6 +173,23 @@ console_ns.schema_model( ) +class HumanInputPauseTypeResponse(TypedDict): + type: Literal["human_input"] + form_id: str + backstage_input_url: str | None + + +class PausedNodeResponse(TypedDict): + node_id: str + node_title: str + pause_type: HumanInputPauseTypeResponse + + +class WorkflowPauseDetailsResponse(TypedDict): + paused_at: str | None + paused_nodes: list[PausedNodeResponse] + + @console_ns.route("/apps//advanced-chat/workflow-runs") class AdvancedChatAppWorkflowRunListApi(Resource): @console_ns.doc("get_advanced_chat_workflow_runs") @@ -489,18 +507,22 @@ class ConsoleWorkflowPauseDetailsApi(Resource): # Check if workflow is suspended is_paused = workflow_run.status == WorkflowExecutionStatus.PAUSED if not is_paused: - return { + empty_response: WorkflowPauseDetailsResponse = { "paused_at": None, "paused_nodes": [], - }, 200 + } + return empty_response, 200 pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id) pause_reasons = pause_entity.get_pause_reasons() if pause_entity else [] + form_tokens_by_form_id = _load_form_tokens_by_form_id( + [reason.form_id for reason in pause_reasons if isinstance(reason, HumanInputRequired)] + ) # Build response paused_at = pause_entity.paused_at if pause_entity else None - paused_nodes = [] - response = { + paused_nodes: list[PausedNodeResponse] = [] + response: WorkflowPauseDetailsResponse = { "paused_at": paused_at.isoformat() + "Z" if paused_at else None, "paused_nodes": paused_nodes, } @@ -514,7 +536,9 @@ class ConsoleWorkflowPauseDetailsApi(Resource): "pause_type": { "type": "human_input", "form_id": reason.form_id, - "backstage_input_url": _build_backstage_input_url(reason.form_token), + "backstage_input_url": _build_backstage_input_url( + form_tokens_by_form_id.get(reason.form_id) + ), }, } ) diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index e687d980fa9..493022ffea7 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -2,6 +2,8 @@ from collections.abc import Callable from functools import wraps from typing import ParamSpec, TypeVar, Union +from sqlalchemy import select + from controllers.console.app.error import AppNotFoundError from extensions.ext_database import db from libs.login import current_account_with_tenant @@ -15,16 +17,14 @@ R1 = TypeVar("R1") def _load_app_model(app_id: str) -> App | None: _, current_tenant_id = current_account_with_tenant() - app_model = ( - db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") - .first() + app_model = db.session.scalar( + select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) return app_model def _load_app_model_with_trial(app_id: str) -> App | None: - app_model = db.session.query(App).where(App.id == app_id, App.status == "normal").first() + app_model = db.session.scalar(select(App).where(App.id == app_id, App.status == "normal").limit(1)) return app_model diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py index c2a95ddad2f..9e7faa09c58 100644 --- a/api/controllers/console/auth/email_register.py +++ b/api/controllers/console/auth/email_register.py @@ -1,7 +1,7 @@ from flask import request from flask_restx import Resource from pydantic import BaseModel, Field, field_validator -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from configs import dify_config from constants.languages import languages @@ -73,7 +73,7 @@ class EmailRegisterSendEmailApi(Resource): if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email): raise AccountInFreezeError() - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language) return {"result": "success", "data": token} @@ -145,7 +145,7 @@ class EmailRegisterResetApi(Resource): email = register_data.get("email", "") normalized_email = email.lower() - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(email, session=session) if account: diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index 1ed931b0d7d..844f3c91ff0 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -4,7 +4,7 @@ import secrets from flask import request from flask_restx import Resource from pydantic import BaseModel, Field, field_validator -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from controllers.common.schema import register_schema_models from controllers.console import console_ns @@ -102,7 +102,7 @@ class ForgotPasswordSendEmailApi(Resource): else: language = "en-US" - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) token = AccountService.send_reset_password_email( @@ -201,7 +201,7 @@ class ForgotPasswordResetApi(Resource): password_hashed = hash_password(args.new_password, salt) email = reset_data.get("email", "") - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(email, session=session) if account: @@ -215,7 +215,6 @@ class ForgotPasswordResetApi(Resource): # Update existing account credentials account.password = base64.b64encode(password_hashed).decode() account.password_salt = base64.b64encode(salt).decode() - session.commit() # Create workspace if needed if ( diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 112e1524322..5c7011fd22f 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -1,9 +1,10 @@ import logging +import urllib.parse import httpx from flask import current_app, redirect, request from flask_restx import Resource -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import Unauthorized from configs import dify_config @@ -112,6 +113,9 @@ class OAuthCallback(Resource): error_text = e.response.text logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text) return {"error": "OAuth process failed"}, 400 + except ValueError as e: + logger.warning("OAuth error with %s", provider, exc_info=True) + return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={urllib.parse.quote(str(e))}") if invite_token and RegisterService.is_valid_invite_token(invite_token): invitation = RegisterService.get_invitation_by_token(token=invite_token) @@ -176,7 +180,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> account: Account | None = Account.get_by_openid(provider, user_info.id) if not account: - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session) return account diff --git a/api/controllers/console/auth/oauth_server.py b/api/controllers/console/auth/oauth_server.py index 6e59d4203cd..686b865871d 100644 --- a/api/controllers/console/auth/oauth_server.py +++ b/api/controllers/console/auth/oauth_server.py @@ -4,11 +4,11 @@ from typing import Concatenate, ParamSpec, TypeVar from flask import jsonify, request from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel from werkzeug.exceptions import BadRequest, NotFound from controllers.console.wraps import account_initialization_required, setup_required -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from models import Account from models.model import OAuthProviderApp diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 725a8380cdc..f23c7eb4310 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -2,8 +2,9 @@ from typing import Any, cast from flask import request from flask_restx import Resource, fields, marshal, marshal_with +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field, field_validator -from sqlalchemy import select +from sqlalchemy import func, select from werkzeug.exceptions import Forbidden, NotFound import services @@ -25,12 +26,12 @@ from controllers.console.wraps import ( ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.indexing_runner import IndexingRunner -from core.provider_manager import ProviderManager +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.rag.datasource.vdb.vector_type import VectorType from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from fields.app_fields import app_detail_kernel_fields, related_app_list from fields.dataset_fields import ( @@ -54,7 +55,7 @@ from fields.document_fields import document_status_fields from libs.login import current_account_with_tenant, login_required from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile from models.dataset import DatasetPermission, DatasetPermissionEnum -from models.enums import SegmentStatus +from models.enums import ApiTokenType, SegmentStatus from models.provider_ids import ModelProviderID from services.api_token_service import ApiTokenCache from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService @@ -331,7 +332,7 @@ class DatasetListApi(Resource): ) # check embedding setting - provider_manager = ProviderManager() + provider_manager = create_plugin_provider_manager(tenant_id=current_tenant_id) configurations = provider_manager.get_configurations(tenant_id=current_tenant_id) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) @@ -355,7 +356,7 @@ class DatasetListApi(Resource): for item in data: # convert embedding_model_provider to plugin standard format - if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]: + if item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY and item["embedding_model_provider"]: item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"])) item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" if item_model in model_names: @@ -436,7 +437,7 @@ class DatasetApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields)) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: if dataset.embedding_model_provider: provider_id = ModelProviderID(dataset.embedding_model_provider) data["embedding_model_provider"] = str(provider_id) @@ -445,7 +446,7 @@ class DatasetApi(Resource): data.update({"partial_member_list": part_users_list}) # check embedding setting - provider_manager = ProviderManager() + provider_manager = create_plugin_provider_manager(tenant_id=current_tenant_id) configurations = provider_manager.get_configurations(tenant_id=current_tenant_id) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) @@ -454,7 +455,7 @@ class DatasetApi(Resource): for embedding_model in embedding_models: model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") - if data["indexing_technique"] == "high_quality": + if data["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY: item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}" if item_model in model_names: data["embedding_available"] = True @@ -485,7 +486,7 @@ class DatasetApi(Resource): current_user, current_tenant_id = current_account_with_tenant() # check embedding model setting if ( - payload.indexing_technique == "high_quality" + payload.indexing_technique == IndexTechniqueType.HIGH_QUALITY and payload.embedding_model_provider is not None and payload.embedding_model is not None ): @@ -738,20 +739,23 @@ class DatasetIndexingStatusApi(Resource): documents_status = [] for document in documents: completed_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != SegmentStatus.RE_SEGMENT, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) total_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) # Create a dictionary with document attributes and additional fields document_dict = { @@ -777,7 +781,7 @@ class DatasetIndexingStatusApi(Resource): class DatasetApiKeyApi(Resource): max_keys = 10 token_prefix = "dataset-" - resource_type = "dataset" + resource_type = ApiTokenType.DATASET @console_ns.doc("get_dataset_api_keys") @console_ns.doc(description="Get dataset API keys") @@ -802,9 +806,12 @@ class DatasetApiKeyApi(Resource): _, current_tenant_id = current_account_with_tenant() current_key_count = ( - db.session.query(ApiToken) - .where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id) - .count() + db.session.scalar( + select(func.count(ApiToken.id)).where( + ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id + ) + ) + or 0 ) if current_key_count >= self.max_keys: @@ -826,7 +833,7 @@ class DatasetApiKeyApi(Resource): @console_ns.route("/datasets/api-keys/") class DatasetApiDeleteApi(Resource): - resource_type = "dataset" + resource_type = ApiTokenType.DATASET @console_ns.doc("delete_dataset_api_key") @console_ns.doc(description="Delete dataset API key") @@ -839,14 +846,14 @@ class DatasetApiDeleteApi(Resource): def delete(self, api_key_id): _, current_tenant_id = current_account_with_tenant() api_key_id = str(api_key_id) - key = ( - db.session.query(ApiToken) + key = db.session.scalar( + select(ApiToken) .where( ApiToken.tenant_id == current_tenant_id, ApiToken.type == self.resource_type, ApiToken.id == api_key_id, ) - .first() + .limit(1) ) if key is None: @@ -857,7 +864,7 @@ class DatasetApiDeleteApi(Resource): assert key is not None # nosec - for type checker only ApiTokenCache.delete(key.token, key.type) - db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() + db.session.delete(key) db.session.commit() return {"result": "success"}, 204 diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index bc90c4ffbdf..ab367d84838 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -9,8 +9,10 @@ from uuid import UUID import sqlalchemy as sa from flask import request, send_file from flask_restx import Resource, fields, marshal, marshal_with +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import BaseModel, Field -from sqlalchemy import asc, desc, select +from sqlalchemy import asc, desc, func, select from werkzeug.exceptions import Forbidden, NotFound import services @@ -27,8 +29,7 @@ from core.model_manager import ModelManager from core.plugin.impl.exc import PluginDaemonClientSideError from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError +from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db from fields.dataset_fields import dataset_fields from fields.document_fields import ( @@ -211,12 +212,11 @@ class GetProcessRuleApi(Resource): raise Forbidden(str(e)) # get the latest process rule - dataset_process_rule = ( - db.session.query(DatasetProcessRule) + dataset_process_rule = db.session.scalar( + select(DatasetProcessRule) .where(DatasetProcessRule.dataset_id == document.dataset_id) .order_by(DatasetProcessRule.created_at.desc()) .limit(1) - .one_or_none() ) if dataset_process_rule: mode = dataset_process_rule.mode @@ -330,21 +330,23 @@ class DatasetDocumentListApi(Resource): if fetch: for document in documents: completed_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != SegmentStatus.RE_SEGMENT, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) total_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != SegmentStatus.RE_SEGMENT, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) document.completed_segments = completed_segments document.total_segments = total_segments @@ -448,11 +450,11 @@ class DatasetInitApi(Resource): raise Forbidden() knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {}) - if knowledge_config.indexing_technique == "high_quality": + if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY: if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None: raise ValueError("embedding model and embedding model provider are required for high quality indexing.") try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=knowledge_config.embedding_model_provider, @@ -462,7 +464,7 @@ class DatasetInitApi(Resource): is_multimodal = DatasetService.check_is_multimodal_model( current_tenant_id, knowledge_config.embedding_model_provider, knowledge_config.embedding_model ) - knowledge_config.is_multimodal = is_multimodal + knowledge_config.is_multimodal = is_multimodal # pyrefly: ignore[bad-assignment] except InvokeAuthorizationError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." @@ -521,10 +523,10 @@ class DocumentIndexingEstimateApi(DocumentResource): if data_source_info and "upload_file_id" in data_source_info: file_id = data_source_info["upload_file_id"] - file = ( - db.session.query(UploadFile) + file = db.session.scalar( + select(UploadFile) .where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id) - .first() + .limit(1) ) # raise error if file not found @@ -586,10 +588,10 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): if not data_source_info: continue file_id = data_source_info["upload_file_id"] - file_detail = ( - db.session.query(UploadFile) + file_detail = db.session.scalar( + select(UploadFile) .where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id) - .first() + .limit(1) ) if file_detail is None: @@ -672,20 +674,23 @@ class DocumentBatchIndexingStatusApi(DocumentResource): documents_status = [] for document in documents: completed_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != SegmentStatus.RE_SEGMENT, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) total_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) # Create a dictionary with document attributes and additional fields document_dict = { @@ -723,18 +728,23 @@ class DocumentIndexingStatusApi(DocumentResource): document = self.get_document(dataset_id, document_id) completed_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document_id), - DocumentSegment.status != SegmentStatus.RE_SEGMENT, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document_id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) total_segments = ( - db.session.query(DocumentSegment) - .where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != SegmentStatus.RE_SEGMENT) - .count() + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.document_id == str(document_id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) + ) + or 0 ) # Create a dictionary with document attributes and additional fields @@ -1258,11 +1268,11 @@ class DocumentPipelineExecutionLogApi(DocumentResource): document = DocumentService.get_document(dataset.id, document_id) if not document: raise NotFound("Document not found.") - log = ( - db.session.query(DocumentPipelineExecutionLog) - .filter_by(document_id=document_id) + log = db.session.scalar( + select(DocumentPipelineExecutionLog) + .where(DocumentPipelineExecutionLog.document_id == document_id) .order_by(DocumentPipelineExecutionLog.created_at.desc()) - .first() + .limit(1) ) if not log: return { @@ -1328,7 +1338,7 @@ class DocumentGenerateSummaryApi(Resource): raise BadRequest("document_list cannot be empty.") # Check if dataset configuration supports summary generation - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: raise ValueError( f"Summary generation is only available for 'high_quality' indexing technique. " f"Current indexing technique: {dataset.indexing_technique}" diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 3fd0f3b7124..c5f4e3a6e26 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -2,6 +2,7 @@ import uuid from flask import request from flask_restx import Resource, marshal +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field from sqlalchemy import String, cast, func, or_, select from sqlalchemy.dialects.postgresql import JSONB @@ -26,7 +27,7 @@ from controllers.console.wraps import ( ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager -from dify_graph.model_runtime.entities.model_entities import ModelType +from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.segment_fields import child_chunk_fields, segment_fields @@ -45,7 +46,7 @@ def _get_segment_with_summary(segment, dataset_id): """Helper function to marshal segment and add summary information.""" from services.summary_index_service import SummaryIndexService - segment_dict = dict(marshal(segment, segment_fields)) + segment_dict = dict(marshal(segment, segment_fields)) # type: ignore # Query summary for this segment (only enabled summaries) summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id) segment_dict["summary"] = summary.summary_content if summary else None @@ -206,7 +207,7 @@ class DatasetDocumentSegmentListApi(Resource): # Add summary to each segment segments_with_summary = [] for segment in segments.items: - segment_dict = dict(marshal(segment, segment_fields)) + segment_dict = dict(marshal(segment, segment_fields)) # type: ignore segment_dict["summary"] = summaries.get(segment.id) segments_with_summary.append(segment_dict) @@ -279,10 +280,10 @@ class DatasetDocumentSegmentApi(Resource): DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -333,9 +334,9 @@ class DatasetDocumentSegmentAddApi(Resource): if not current_user.is_dataset_editor: raise Forbidden() # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -383,10 +384,10 @@ class DatasetDocumentSegmentUpdateApi(Resource): document = DocumentService.get_document(dataset_id, document_id) if not document: raise NotFound("Document not found.") - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -401,10 +402,10 @@ class DatasetDocumentSegmentUpdateApi(Resource): raise ProviderNotInitializeError(ex.description) # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") @@ -447,10 +448,10 @@ class DatasetDocumentSegmentUpdateApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") @@ -494,7 +495,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource): payload = BatchImportPayload.model_validate(console_ns.payload or {}) upload_file_id = payload.upload_file_id - upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first() + upload_file = db.session.scalar(select(UploadFile).where(UploadFile.id == upload_file_id).limit(1)) if not upload_file: raise NotFound("UploadFile not found.") @@ -559,19 +560,19 @@ class ChildChunkAddApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") if not current_user.is_dataset_editor: raise Forbidden() # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -616,10 +617,10 @@ class ChildChunkAddApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") @@ -666,10 +667,10 @@ class ChildChunkAddApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") @@ -714,24 +715,24 @@ class ChildChunkUpdateApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") # check child chunk child_chunk_id = str(child_chunk_id) - child_chunk = ( - db.session.query(ChildChunk) + child_chunk = db.session.scalar( + select(ChildChunk) .where( ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_tenant_id, ChildChunk.segment_id == segment.id, ChildChunk.document_id == document_id, ) - .first() + .limit(1) ) if not child_chunk: raise NotFound("Child chunk not found.") @@ -771,24 +772,24 @@ class ChildChunkUpdateApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") # check child chunk child_chunk_id = str(child_chunk_id) - child_chunk = ( - db.session.query(ChildChunk) + child_chunk = db.session.scalar( + select(ChildChunk) .where( ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_tenant_id, ChildChunk.segment_id == segment.id, ChildChunk.document_id == document_id, ) - .first() + .limit(1) ) if not child_chunk: raise NotFound("Child chunk not found.") diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index 86090bcd108..fc6896f1233 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -25,7 +25,7 @@ from libs.login import current_account_with_tenant, login_required from services.dataset_service import DatasetService from services.external_knowledge_service import ExternalDatasetService from services.hit_testing_service import HitTestingService -from services.knowledge_service import ExternalDatasetTestService +from services.knowledge_service import BedrockRetrievalSetting, ExternalDatasetTestService def _build_dataset_detail_model(): @@ -86,7 +86,7 @@ class ExternalHitTestingPayload(BaseModel): class BedrockRetrievalPayload(BaseModel): - retrieval_setting: dict[str, object] + retrieval_setting: "BedrockRetrievalSetting" query: str knowledge_id: str diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index cd568cf8350..8fb3699849e 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -2,6 +2,7 @@ import logging from typing import Any from flask_restx import marshal +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -19,7 +20,6 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.model_runtime.errors.invoke import InvokeError from fields.hit_testing_fields import hit_testing_record_fields from libs.login import current_user from models.account import Account diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index a4498005d84..bdf83b991e5 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -2,6 +2,8 @@ from typing import Any from flask import make_response, redirect, request from flask_restx import Resource +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden, NotFound @@ -10,8 +12,6 @@ from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from core.plugin.impl.oauth import OAuthHandler -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from models.provider_ids import DatasourceProviderID from services.datasource_provider_service import DatasourceProviderService @@ -120,7 +120,8 @@ class DatasourceOAuthCallback(Resource): if context is None: raise Forbidden("Invalid context_id") - user_id, tenant_id = context.get("user_id"), context.get("tenant_id") + user_id: str = context["user_id"] + tenant_id: str = context["tenant_id"] datasource_provider_id = DatasourceProviderID(provider_id) plugin_id = datasource_provider_id.plugin_id datasource_provider_service = DatasourceProviderService() @@ -141,7 +142,7 @@ class DatasourceOAuthCallback(Resource): system_credentials=oauth_client_params, request=request, ) - credential_id = context.get("credential_id") + credential_id: str | None = context.get("credential_id") if credential_id: datasource_provider_service.reauthorize_datasource_oauth_provider( tenant_id=tenant_id, @@ -150,7 +151,7 @@ class DatasourceOAuthCallback(Resource): name=oauth_response.metadata.get("name") or None, expire_at=oauth_response.expires_at, credentials=dict(oauth_response.credentials), - credential_id=context.get("credential_id"), + credential_id=credential_id, ) else: datasource_provider_service.add_datasource_oauth_provider( diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py index c5dadb75f5f..f12cbd34959 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py @@ -3,6 +3,7 @@ from typing import Any, NoReturn from flask import Response, request from flask_restx import Resource, marshal, marshal_with +from graphon.variables.types import SegmentType from pydantic import BaseModel, Field from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden @@ -21,8 +22,8 @@ from controllers.console.app.workflow_draft_variable import ( from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import InvalidArgumentError, NotFoundError -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from dify_graph.variables.types import SegmentType +from core.app.file_access import DatabaseFileAccessController +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from extensions.ext_database import db from factories.file_factory import build_from_mapping, build_from_mappings from factories.variable_factory import build_segment_with_type @@ -33,6 +34,7 @@ from services.rag_pipeline.rag_pipeline import RagPipelineService from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() def _create_pagination_parser(): @@ -223,13 +225,21 @@ class RagPipelineVariableApi(Resource): if variable.value_type == SegmentType.FILE: if not isinstance(raw_value, dict): raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}") - raw_value = build_from_mapping(mapping=raw_value, tenant_id=pipeline.tenant_id) + raw_value = build_from_mapping( + mapping=raw_value, + tenant_id=pipeline.tenant_id, + access_controller=_file_access_controller, + ) elif variable.value_type == SegmentType.ARRAY_FILE: if not isinstance(raw_value, list): raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}") if len(raw_value) > 0 and not isinstance(raw_value[0], dict): raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") - raw_value = build_from_mappings(mappings=raw_value, tenant_id=pipeline.tenant_id) + raw_value = build_from_mappings( + mappings=raw_value, + tenant_id=pipeline.tenant_id, + access_controller=_file_access_controller, + ) new_value = build_segment_with_type(variable.value_type, raw_value) draft_var_srv.update_variable(variable, name=new_name, value=new_value) db.session.commit() diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 51cdcc0c7a7..8efb59a8e9c 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -4,9 +4,10 @@ from typing import Any, Literal, cast from flask import abort, request from flask_restx import Resource, marshal_with # type: ignore +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field from sqlalchemy.orm import Session -from werkzeug.exceptions import Forbidden, InternalServerError, NotFound +from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound import services from controllers.common.schema import register_schema_models @@ -16,7 +17,11 @@ from controllers.console.app.error import ( DraftWorkflowNotExist, DraftWorkflowNotSync, ) -from controllers.console.app.workflow import workflow_model, workflow_pagination_model +from controllers.console.app.workflow import ( + RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE, + workflow_model, + workflow_pagination_model, +) from controllers.console.app.workflow_run import ( workflow_run_detail_model, workflow_run_node_execution_list_model, @@ -33,7 +38,6 @@ from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpErr from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.pipeline.pipeline_generator import PipelineGenerator from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from factories import variable_factory from libs import helper @@ -42,12 +46,14 @@ from libs.login import current_account_with_tenant, current_user, login_required from models import Account from models.dataset import Pipeline from models.model import EndUser -from services.errors.app import WorkflowHashNotEqualError +from models.workflow import Workflow +from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService from services.rag_pipeline.rag_pipeline import RagPipelineService from services.rag_pipeline.rag_pipeline_manage_service import RagPipelineManageService from services.rag_pipeline.rag_pipeline_transform_service import RagPipelineTransformService +from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService logger = logging.getLogger(__name__) @@ -203,9 +209,12 @@ class DraftRagPipelineApi(Resource): abort(415) payload = DraftWorkflowSyncPayload.model_validate(payload_dict) + rag_pipeline_service = RagPipelineService() try: - environment_variables_list = payload.environment_variables or [] + environment_variables_list = Workflow.normalize_environment_variable_mappings( + payload.environment_variables or [], + ) environment_variables = [ variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list ] @@ -213,7 +222,6 @@ class DraftRagPipelineApi(Resource): conversation_variables = [ variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list ] - rag_pipeline_service = RagPipelineService() workflow = rag_pipeline_service.sync_draft_workflow( pipeline=pipeline, graph=payload.graph, @@ -705,6 +713,36 @@ class PublishedAllRagPipelineApi(Resource): } +@console_ns.route("/rag/pipelines//workflows//restore") +class RagPipelineDraftWorkflowRestoreApi(Resource): + @setup_required + @login_required + @account_initialization_required + @edit_permission_required + @get_rag_pipeline + def post(self, pipeline: Pipeline, workflow_id: str): + current_user, _ = current_account_with_tenant() + rag_pipeline_service = RagPipelineService() + + try: + workflow = rag_pipeline_service.restore_published_workflow_to_draft( + pipeline=pipeline, + workflow_id=workflow_id, + account=current_user, + ) + except IsDraftWorkflowError as exc: + # Use a stable, predefined message to keep the 400 response consistent + raise BadRequest(RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE) from exc + except WorkflowNotFoundError as exc: + raise NotFound(str(exc)) from exc + + return { + "result": "success", + "hash": workflow.unique_hash, + "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at), + } + + @console_ns.route("/rag/pipelines//workflows/") class RagPipelineByIdApi(Resource): @setup_required @@ -744,7 +782,38 @@ class RagPipelineByIdApi(Resource): # Commit the transaction in the controller session.commit() - return workflow + return workflow + + @setup_required + @login_required + @account_initialization_required + @edit_permission_required + @get_rag_pipeline + def delete(self, pipeline: Pipeline, workflow_id: str): + """ + Delete a published workflow version that is not currently active on the pipeline. + """ + if pipeline.workflow_id == workflow_id: + abort(400, description=f"Cannot delete workflow that is currently in use by pipeline '{pipeline.id}'") + + workflow_service = WorkflowService() + + with Session(db.engine) as session: + try: + workflow_service.delete_workflow( + session=session, + workflow_id=workflow_id, + tenant_id=pipeline.tenant_id, + ) + session.commit() + except WorkflowInUseError as e: + abort(400, description=str(e)) + except DraftWorkflowDeletionError as e: + abort(400, description=str(e)) + except ValueError as e: + raise NotFound(str(e)) + + return None, 204 @console_ns.route("/rag/pipelines//workflows/published/processing/parameters") diff --git a/api/controllers/console/datasets/wraps.py b/api/controllers/console/datasets/wraps.py index 3ef1341abc8..d533e6c5b14 100644 --- a/api/controllers/console/datasets/wraps.py +++ b/api/controllers/console/datasets/wraps.py @@ -2,6 +2,8 @@ from collections.abc import Callable from functools import wraps from typing import ParamSpec, TypeVar +from sqlalchemy import select + from controllers.console.datasets.error import PipelineNotFoundError from extensions.ext_database import db from libs.login import current_account_with_tenant @@ -24,10 +26,8 @@ def get_rag_pipeline(view_func: Callable[P, R]): del kwargs["pipeline_id"] - pipeline = ( - db.session.query(Pipeline) - .where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id) - .first() + pipeline = db.session.scalar( + select(Pipeline).where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id).limit(1) ) if not pipeline: diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index ffb9e5bb6ee..b1b01b5f51c 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -1,6 +1,7 @@ import logging from flask import request +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError @@ -19,7 +20,6 @@ from controllers.console.app.error import ( ) from controllers.console.explore.wraps import InstalledAppResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError from services.audio_service import AudioService from services.errors.audio import ( AudioTooLargeServiceError, diff --git a/api/controllers/console/explore/banner.py b/api/controllers/console/explore/banner.py index 5dfef6bf6af..757061d8dda 100644 --- a/api/controllers/console/explore/banner.py +++ b/api/controllers/console/explore/banner.py @@ -1,5 +1,6 @@ from flask import request from flask_restx import Resource +from sqlalchemy import select from controllers.console import api from controllers.console.explore.wraps import explore_banner_enabled @@ -17,14 +18,18 @@ class BannerApi(Resource): language = request.args.get("language", "en-US") # Build base query for enabled banners - base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == BannerStatus.ENABLED) + base_query = select(ExporleBanner).where(ExporleBanner.status == BannerStatus.ENABLED) # Try to get banners in the requested language - banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all() + banners = db.session.scalars( + base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort) + ).all() # Fallback to en-US if no banners found and language is not en-US if not banners and language != "en-US": - banners = base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort).all() + banners = db.session.scalars( + base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort) + ).all() # Convert banners to serializable format result = [] for banner in banners: diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index fcd52d28187..eacd7332fe8 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -2,6 +2,7 @@ import logging from typing import Any, Literal from uuid import UUID +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound @@ -24,7 +25,6 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from libs import helper from libs.datetime_utils import naive_utc_now diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index aca766567fd..0740dd0e24c 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -133,13 +133,15 @@ class InstalledAppsListApi(Resource): def post(self): payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {}) - recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).first() + recommended_app = db.session.scalar( + select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).limit(1) + ) if recommended_app is None: raise NotFound("Recommended app not found") _, current_tenant_id = current_account_with_tenant() - app = db.session.query(App).where(App.id == payload.app_id).first() + app = db.session.get(App, payload.app_id) if app is None: raise NotFound("App entity not found") @@ -147,10 +149,10 @@ class InstalledAppsListApi(Resource): if not app.is_public: raise Forbidden("You can't install a non-public app") - installed_app = ( - db.session.query(InstalledApp) + installed_app = db.session.scalar( + select(InstalledApp) .where(and_(InstalledApp.app_id == payload.app_id, InstalledApp.tenant_id == current_tenant_id)) - .first() + .limit(1) ) if installed_app is None: diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 15e1aea361a..fcbefcda33b 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -2,6 +2,7 @@ import logging from typing import Literal from flask import request +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, TypeAdapter from werkzeug.exceptions import InternalServerError, NotFound @@ -21,7 +22,6 @@ from controllers.console.explore.error import ( from controllers.console.explore.wraps import InstalledAppResource from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError from fields.conversation_fields import ResultResponse from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse from libs import helper diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py index 25bb8ed7fec..e432574434d 100644 --- a/api/controllers/console/explore/trial.py +++ b/api/controllers/console/explore/trial.py @@ -3,7 +3,10 @@ from typing import Any, Literal, cast from flask import request from flask_restx import Resource, fields, marshal, marshal_with +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel +from sqlalchemy import select from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services @@ -41,8 +44,6 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.graph_engine.manager import GraphEngineManager -from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.app_fields import ( @@ -476,7 +477,7 @@ class TrialSitApi(Resource): Returns the site configuration for the application including theme, icons, and text. """ - site = db.session.query(Site).where(Site.app_id == app_model.id).first() + site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) if not site: raise Forbidden() @@ -541,13 +542,7 @@ class AppWorkflowApi(Resource): if not app_model.workflow_id: raise AppUnavailableError() - workflow = ( - db.session.query(Workflow) - .where( - Workflow.id == app_model.workflow_id, - ) - .first() - ) + workflow = db.session.get(Workflow, app_model.workflow_id) return workflow diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index 7801cee4735..42cafc71932 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -1,6 +1,8 @@ import logging from typing import Any +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel from werkzeug.exceptions import InternalServerError @@ -21,8 +23,6 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.graph_engine.manager import GraphEngineManager -from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_redis import redis_client from libs import helper from libs.login import current_account_with_tenant diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index 03edb871e63..9d9337e63e9 100644 --- a/api/controllers/console/explore/wraps.py +++ b/api/controllers/console/explore/wraps.py @@ -4,6 +4,7 @@ from typing import Concatenate, ParamSpec, TypeVar from flask import abort from flask_restx import Resource +from sqlalchemy import select from werkzeug.exceptions import NotFound from controllers.console.explore.error import AppAccessDeniedError, TrialAppLimitExceeded, TrialAppNotAllowed @@ -24,10 +25,10 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non @wraps(view) def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs): _, current_tenant_id = current_account_with_tenant() - installed_app = ( - db.session.query(InstalledApp) + installed_app = db.session.scalar( + select(InstalledApp) .where(InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_tenant_id) - .first() + .limit(1) ) if installed_app is None: @@ -78,7 +79,7 @@ def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None): def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs): current_user, _ = current_account_with_tenant() - trial_app = db.session.query(TrialApp).where(TrialApp.app_id == str(app_id)).first() + trial_app = db.session.scalar(select(TrialApp).where(TrialApp.app_id == str(app_id)).limit(1)) if trial_app is None: raise TrialAppNotAllowed() @@ -87,10 +88,10 @@ def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None): if app is None: raise TrialAppNotAllowed() - account_trial_app_record = ( - db.session.query(AccountTrialAppRecord) + account_trial_app_record = db.session.scalar( + select(AccountTrialAppRecord) .where(AccountTrialAppRecord.account_id == current_user.id, AccountTrialAppRecord.app_id == app_id) - .first() + .limit(1) ) if account_trial_app_record: if account_trial_app_record.count >= trial_app.trial_limit: diff --git a/api/controllers/console/human_input_form.py b/api/controllers/console/human_input_form.py index 7207f7fd1d5..e37e78c966f 100644 --- a/api/controllers/console/human_input_form.py +++ b/api/controllers/console/human_input_form.py @@ -15,6 +15,7 @@ from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import InvalidArgumentError, NotFoundError from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator +from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.apps.message_generator import MessageGenerator from core.app.apps.workflow.app_generator import WorkflowAppGenerator @@ -166,6 +167,7 @@ class ConsoleWorkflowEventsApi(Resource): else: msg_generator = MessageGenerator() + generator: BaseAppGenerator if app.mode == AppMode.ADVANCED_CHAT: generator = AdvancedChatAppGenerator() elif app.mode == AppMode.WORKFLOW: @@ -202,7 +204,7 @@ class ConsoleWorkflowEventsApi(Resource): ) -def _retrieve_app_for_workflow_run(session: Session, workflow_run: WorkflowRun): +def _retrieve_app_for_workflow_run(session: Session, workflow_run: WorkflowRun) -> App: query = select(App).where( App.id == workflow_run.app_id, App.tenant_id == workflow_run.tenant_id, diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index 49162d4dae6..551c86fd827 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -2,6 +2,7 @@ import urllib.parse import httpx from flask_restx import Resource +from graphon.file import helpers as file_helpers from pydantic import BaseModel, Field import services @@ -13,7 +14,6 @@ from controllers.common.errors import ( ) from controllers.console import console_ns from core.helper import ssrf_proxy -from dify_graph.file import helpers as file_helpers from extensions.ext_database import db from fields.file_fields import FileWithSignedUrl, RemoteFileInfo from libs.login import current_account_with_tenant, login_required diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index e099fe0f324..279e4ec502d 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -2,6 +2,7 @@ from typing import Literal from flask import request from pydantic import BaseModel, Field, field_validator +from sqlalchemy import select from configs import dify_config from controllers.fastopenapi import console_router @@ -100,6 +101,6 @@ def setup_system(payload: SetupRequestPayload) -> SetupResponse: def get_setup_status() -> DifySetup | bool | None: if dify_config.EDITION == "SELF_HOSTED": - return db.session.query(DifySetup).first() + return db.session.scalar(select(DifySetup).limit(1)) return True diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 0d8960c9bd8..6f93ff1e70e 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -212,13 +212,13 @@ class AccountInitApi(Resource): raise ValueError("invitation_code is required") # check invitation code - invitation_code = ( - db.session.query(InvitationCode) + invitation_code = db.session.scalar( + select(InvitationCode) .where( InvitationCode.code == args.invitation_code, InvitationCode.status == InvitationCodeStatus.UNUSED, ) - .first() + .limit(1) ) if not invitation_code: diff --git a/api/controllers/console/workspace/agent_providers.py b/api/controllers/console/workspace/agent_providers.py index e2b504751ba..3fdcbc47108 100644 --- a/api/controllers/console/workspace/agent_providers.py +++ b/api/controllers/console/workspace/agent_providers.py @@ -1,8 +1,8 @@ from flask_restx import Resource, fields +from graphon.model_runtime.utils.encoders import jsonable_encoder from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from services.agent_service import AgentService diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index 538c5fb561d..b6b9deb1f92 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -2,13 +2,13 @@ from typing import Any from flask import request from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.plugin.impl.exc import PluginPermissionDeniedError -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from services.plugin.endpoint_service import EndpointService diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index 0a9e54de99a..e4cfca9fa4c 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -1,12 +1,12 @@ from flask_restx import Resource +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from pydantic import BaseModel from werkzeug.exceptions import Forbidden from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError from libs.login import current_account_with_tenant, login_required from models import TenantAccountRole from services.model_load_balancing_service import ModelLoadBalancingService diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index dd302b90d69..e3bf4c95b8e 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -171,7 +171,7 @@ class MemberCancelInviteApi(Resource): current_user, _ = current_account_with_tenant() if not current_user.current_tenant: raise ValueError("No current tenant") - member = db.session.query(Account).where(Account.id == str(member_id)).first() + member = db.session.get(Account, str(member_id)) if member is None: abort(404) else: diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index db3b02ae949..8e0aefc9e3e 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -3,13 +3,13 @@ from typing import Any, Literal from flask import request, send_file from flask_restx import Resource +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from libs.helper import uuid_value from libs.login import current_account_with_tenant, login_required from services.billing_service import BillingService diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index d7eceb656c1..2ec1a9435a2 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -3,14 +3,14 @@ from typing import Any, cast from flask import request from flask_restx import Resource +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from controllers.common.schema import register_enum_models, register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from libs.helper import uuid_value from libs.login import current_account_with_tenant, login_required from services.model_load_balancing_service import ModelLoadBalancingService @@ -282,14 +282,18 @@ class ModelProviderModelCredentialApi(Resource): ) if args.config_from == "predefined-model": - available_credentials = model_provider_service.provider_manager.get_provider_available_credentials( - tenant_id=tenant_id, provider_name=provider + available_credentials = model_provider_service.get_provider_available_credentials( + tenant_id=tenant_id, + provider=provider, ) else: # Normalize model_type to the origin value stored in DB (e.g., "text-generation" for LLM) normalized_model_type = args.model_type.to_origin_model_type() - available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials( - tenant_id=tenant_id, provider_name=provider, model_type=normalized_model_type, model_name=args.model + available_credentials = model_provider_service.get_provider_model_available_credentials( + tenant_id=tenant_id, + provider=provider, + model_type=normalized_model_type, + model=args.model, ) return jsonable_encoder( diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index ee537367c7d..aa674a63b30 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -4,6 +4,7 @@ from typing import Any, Literal from flask import request, send_file from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field from werkzeug.datastructures import FileStorage from werkzeug.exceptions import Forbidden @@ -14,7 +15,6 @@ from controllers.console import console_ns from controllers.console.workspace import plugin_permission_required from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.plugin.impl.exc import PluginDaemonClientSideError -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService @@ -200,7 +200,7 @@ class PluginDebuggingKeyApi(Resource): "port": dify_config.PLUGIN_REMOTE_INSTALL_PORT, } except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/list") @@ -215,7 +215,7 @@ class PluginListApi(Resource): try: plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total}) @@ -232,7 +232,7 @@ class PluginListLatestVersionsApi(Resource): try: versions = PluginService.list_latest_versions(args.plugin_ids) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 return jsonable_encoder({"versions": versions}) @@ -251,7 +251,7 @@ class PluginListInstallationsFromIdsApi(Resource): try: plugins = PluginService.list_installations_from_ids(tenant_id, args.plugin_ids) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 return jsonable_encoder({"plugins": plugins}) @@ -266,7 +266,7 @@ class PluginIconApi(Resource): try: icon_bytes, mimetype = PluginService.get_asset(args.tenant_id, args.filename) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age) @@ -286,7 +286,7 @@ class PluginAssetApi(Resource): binary = PluginService.extract_asset(tenant_id, args.plugin_unique_identifier, args.file_name) return send_file(io.BytesIO(binary), mimetype="application/octet-stream") except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/upload/pkg") @@ -303,7 +303,7 @@ class PluginUploadFromPkgApi(Resource): try: response = PluginService.upload_pkg(tenant_id, content) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 return jsonable_encoder(response) @@ -323,7 +323,7 @@ class PluginUploadFromGithubApi(Resource): try: response = PluginService.upload_pkg_from_github(tenant_id, args.repo, args.version, args.package) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 return jsonable_encoder(response) @@ -361,7 +361,7 @@ class PluginInstallFromPkgApi(Resource): try: response = PluginService.install_from_local_pkg(tenant_id, args.plugin_unique_identifiers) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 return jsonable_encoder(response) @@ -387,7 +387,7 @@ class PluginInstallFromGithubApi(Resource): args.package, ) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 return jsonable_encoder(response) @@ -407,7 +407,7 @@ class PluginInstallFromMarketplaceApi(Resource): try: response = PluginService.install_from_marketplace_pkg(tenant_id, args.plugin_unique_identifiers) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 return jsonable_encoder(response) @@ -433,7 +433,7 @@ class PluginFetchMarketplacePkgApi(Resource): } ) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/fetch-manifest") @@ -453,7 +453,7 @@ class PluginFetchManifestApi(Resource): {"manifest": PluginService.fetch_plugin_manifest(tenant_id, args.plugin_unique_identifier).model_dump()} ) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/tasks") @@ -471,7 +471,7 @@ class PluginFetchInstallTasksApi(Resource): try: return jsonable_encoder({"tasks": PluginService.fetch_install_tasks(tenant_id, args.page, args.page_size)}) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/tasks/") @@ -486,7 +486,7 @@ class PluginFetchInstallTaskApi(Resource): try: return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)}) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/tasks//delete") @@ -501,7 +501,7 @@ class PluginDeleteInstallTaskApi(Resource): try: return {"success": PluginService.delete_install_task(tenant_id, task_id)} except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/tasks/delete_all") @@ -516,7 +516,7 @@ class PluginDeleteAllInstallTaskItemsApi(Resource): try: return {"success": PluginService.delete_all_install_task_items(tenant_id)} except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/tasks//delete/") @@ -531,7 +531,7 @@ class PluginDeleteInstallTaskItemApi(Resource): try: return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)} except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/upgrade/marketplace") @@ -553,7 +553,7 @@ class PluginUpgradeFromMarketplaceApi(Resource): ) ) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/upgrade/github") @@ -580,7 +580,7 @@ class PluginUpgradeFromGithubApi(Resource): ) ) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/uninstall") @@ -598,7 +598,7 @@ class PluginUninstallApi(Resource): try: return {"success": PluginService.uninstall(tenant_id, args.plugin_installation_id)} except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/permission/change") @@ -674,7 +674,7 @@ class PluginFetchDynamicSelectOptionsApi(Resource): provider_type=args.provider_type, ) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 return jsonable_encoder({"options": options}) @@ -705,7 +705,7 @@ class PluginFetchDynamicSelectOptionsWithCredentialsApi(Resource): credentials=args.credentials, ) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 return jsonable_encoder({"options": options}) diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index b38f05795ab..80216915cd1 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -5,6 +5,7 @@ from urllib.parse import urlparse from flask import make_response, redirect, request, send_file from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, HttpUrl, field_validator, model_validator from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden @@ -26,7 +27,6 @@ from core.mcp.mcp_client import MCPClient from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from libs.helper import alphanumeric, uuid_value from libs.login import current_account_with_tenant, login_required @@ -832,7 +832,8 @@ class ToolOAuthCallback(Resource): tool_provider = ToolProviderID(provider) plugin_id = tool_provider.plugin_id provider_name = tool_provider.provider_name - user_id, tenant_id = context.get("user_id"), context.get("tenant_id") + user_id: str = context["user_id"] + tenant_id: str = context["tenant_id"] oauth_handler = OAuthHandler() oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id, provider) diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py index ad78d2a623c..76d64cb97c8 100644 --- a/api/controllers/console/workspace/trigger_providers.py +++ b/api/controllers/console/workspace/trigger_providers.py @@ -3,6 +3,7 @@ from typing import Any from flask import make_response, redirect, request from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, model_validator from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, Forbidden @@ -14,7 +15,6 @@ from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler from core.trigger.entities.entities import SubscriptionBuilderUpdater from core.trigger.trigger_manager import TriggerManager -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from libs.login import current_user, login_required from models.account import Account @@ -499,9 +499,9 @@ class TriggerOAuthCallbackApi(Resource): provider_id = TriggerProviderID(provider) plugin_id = provider_id.plugin_id provider_name = provider_id.provider_name - user_id = context.get("user_id") - tenant_id = context.get("tenant_id") - subscription_builder_id = context.get("subscription_builder_id") + user_id: str = context["user_id"] + tenant_id: str = context["tenant_id"] + subscription_builder_id: str = context["subscription_builder_id"] # Get OAuth client configuration oauth_client_params = TriggerProviderService.get_oauth_client( diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 94be81d94f7..88fd2c010f0 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -7,6 +7,7 @@ from sqlalchemy import select from werkzeug.exceptions import Unauthorized import services +from configs import dify_config from controllers.common.errors import ( FilenameNotExistsError, FileTooLargeError, @@ -29,6 +30,7 @@ from libs.helper import TimestampField from libs.login import current_account_with_tenant, login_required from models.account import Tenant, TenantStatus from services.account_service import TenantService +from services.billing_service import BillingService, SubscriptionPlan from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService from services.file_service import FileService @@ -108,9 +110,29 @@ class TenantListApi(Resource): current_user, current_tenant_id = current_account_with_tenant() tenants = TenantService.get_join_tenants(current_user) tenant_dicts = [] + is_enterprise_only = dify_config.ENTERPRISE_ENABLED and not dify_config.BILLING_ENABLED + is_saas = dify_config.EDITION == "CLOUD" and dify_config.BILLING_ENABLED + tenant_plans: dict[str, SubscriptionPlan] = {} + + if is_saas: + tenant_ids = [tenant.id for tenant in tenants] + if tenant_ids: + tenant_plans = BillingService.get_plan_bulk(tenant_ids) + if not tenant_plans: + logger.warning("get_plan_bulk returned empty result, falling back to legacy feature path") for tenant in tenants: - features = FeatureService.get_features(tenant.id) + plan: str = CloudPlan.SANDBOX + if is_saas: + tenant_plan = tenant_plans.get(tenant.id) + if tenant_plan: + plan = tenant_plan["plan"] or CloudPlan.SANDBOX + else: + features = FeatureService.get_features(tenant.id) + plan = features.billing.subscription.plan or CloudPlan.SANDBOX + elif not is_enterprise_only: + features = FeatureService.get_features(tenant.id) + plan = features.billing.subscription.plan or CloudPlan.SANDBOX # Create a dictionary with tenant attributes tenant_dict = { @@ -118,7 +140,7 @@ class TenantListApi(Resource): "name": tenant.name, "status": tenant.status, "created_at": tenant.created_at, - "plan": features.billing.subscription.plan if features.billing.enabled else CloudPlan.SANDBOX, + "plan": plan, "current": tenant.id == current_tenant_id if current_tenant_id else False, } @@ -198,7 +220,7 @@ class SwitchWorkspaceApi(Resource): except Exception: raise AccountNotLinkTenantError("Account not link tenant") - new_tenant = db.session.query(Tenant).get(args.tenant_id) # Get new tenant + new_tenant = db.session.get(Tenant, args.tenant_id) # Get new tenant if new_tenant is None: raise ValueError("Tenant not found") diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 014f4c41325..6785ba0c344 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -7,6 +7,7 @@ from functools import wraps from typing import ParamSpec, TypeVar from flask import abort, request +from sqlalchemy import select from configs import dify_config from controllers.console.auth.error import AuthenticationFailedError, EmailCodeError @@ -218,13 +219,9 @@ def setup_required(view: Callable[P, R]) -> Callable[P, R]: @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs) -> R: # check setup - if ( - dify_config.EDITION == "SELF_HOSTED" - and os.environ.get("INIT_PASSWORD") - and not db.session.query(DifySetup).first() - ): - raise NotInitValidateError() - elif dify_config.EDITION == "SELF_HOSTED" and not db.session.query(DifySetup).first(): + if dify_config.EDITION == "SELF_HOSTED" and not db.session.scalar(select(DifySetup).limit(1)): + if os.environ.get("INIT_PASSWORD"): + raise NotInitValidateError() raise NotSetupError() return view(*args, **kwargs) diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py index 9e3fb3a90b9..2f1e2f28bd8 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -70,22 +70,25 @@ class ToolFileApi(Resource): except Exception: raise UnsupportedFileTypeError() + mime_type = tool_file.mime_type + filename = tool_file.filename + response = Response( stream, - mimetype=tool_file.mimetype, + mimetype=mime_type, direct_passthrough=True, headers={}, ) if tool_file.size > 0: response.headers["Content-Length"] = str(tool_file.size) - if args.as_attachment: - encoded_filename = quote(tool_file.name) + if args.as_attachment and filename: + encoded_filename = quote(filename) response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}" enforce_download_for_html( response, - mime_type=tool_file.mimetype, - filename=tool_file.name, + mime_type=mime_type, + filename=filename, extension=extension, ) diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py index 52690a12e1d..ed3278a28b0 100644 --- a/api/controllers/files/upload.py +++ b/api/controllers/files/upload.py @@ -7,8 +7,8 @@ from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden import services +from core.tools.signature import verify_plugin_file_signature from core.tools.tool_file_manager import ToolFileManager -from dify_graph.file.helpers import verify_plugin_file_signature from fields.file_fields import FileResponse from ..common.errors import ( diff --git a/api/controllers/inner_api/__init__.py b/api/controllers/inner_api/__init__.py index 74005217efa..b38994f055c 100644 --- a/api/controllers/inner_api/__init__.py +++ b/api/controllers/inner_api/__init__.py @@ -16,12 +16,14 @@ api = ExternalApi( inner_api_ns = Namespace("inner_api", description="Internal API operations", path="/") from . import mail as _mail +from .app import dsl as _app_dsl from .plugin import plugin as _plugin from .workspace import workspace as _workspace api.add_namespace(inner_api_ns) __all__ = [ + "_app_dsl", "_mail", "_plugin", "_workspace", diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/__init__.py b/api/controllers/inner_api/app/__init__.py similarity index 100% rename from api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/__init__.py rename to api/controllers/inner_api/app/__init__.py diff --git a/api/controllers/inner_api/app/dsl.py b/api/controllers/inner_api/app/dsl.py new file mode 100644 index 00000000000..3b673d6e1d3 --- /dev/null +++ b/api/controllers/inner_api/app/dsl.py @@ -0,0 +1,111 @@ +"""Inner API endpoints for app DSL import/export. + +Called by the enterprise admin-api service. Import requires ``creator_email`` +to attribute the created app; workspace/membership validation is done by the +Go admin-api caller. +""" + +from flask import request +from flask_restx import Resource +from pydantic import BaseModel, Field +from sqlalchemy import select +from sqlalchemy.orm import Session + +from controllers.common.schema import register_schema_model +from controllers.console.wraps import setup_required +from controllers.inner_api import inner_api_ns +from controllers.inner_api.wraps import enterprise_inner_api_only +from extensions.ext_database import db +from models import Account, App +from models.account import AccountStatus +from services.app_dsl_service import AppDslService, ImportMode, ImportStatus + + +class InnerAppDSLImportPayload(BaseModel): + yaml_content: str = Field(description="YAML DSL content") + creator_email: str = Field(description="Email of the workspace member who will own the imported app") + name: str | None = Field(default=None, description="Override app name from DSL") + description: str | None = Field(default=None, description="Override app description from DSL") + + +register_schema_model(inner_api_ns, InnerAppDSLImportPayload) + + +@inner_api_ns.route("/enterprise/workspaces//dsl/import") +class EnterpriseAppDSLImport(Resource): + @setup_required + @enterprise_inner_api_only + @inner_api_ns.doc("enterprise_app_dsl_import") + @inner_api_ns.expect(inner_api_ns.models[InnerAppDSLImportPayload.__name__]) + @inner_api_ns.doc( + responses={ + 200: "Import completed", + 202: "Import pending (DSL version mismatch requires confirmation)", + 400: "Import failed (business error)", + 404: "Creator account not found or inactive", + } + ) + def post(self, workspace_id: str): + """Import a DSL into a workspace on behalf of a specified creator.""" + args = InnerAppDSLImportPayload.model_validate(inner_api_ns.payload or {}) + + account = _get_active_account(args.creator_email) + if account is None: + return {"message": f"account '{args.creator_email}' not found or inactive"}, 404 + + account.set_tenant_id(workspace_id) + + with Session(db.engine) as session: + dsl_service = AppDslService(session) + result = dsl_service.import_app( + account=account, + import_mode=ImportMode.YAML_CONTENT, + yaml_content=args.yaml_content, + name=args.name, + description=args.description, + ) + session.commit() + + if result.status == ImportStatus.FAILED: + return result.model_dump(mode="json"), 400 + if result.status == ImportStatus.PENDING: + return result.model_dump(mode="json"), 202 + return result.model_dump(mode="json"), 200 + + +@inner_api_ns.route("/enterprise/apps//dsl") +class EnterpriseAppDSLExport(Resource): + @setup_required + @enterprise_inner_api_only + @inner_api_ns.doc( + "enterprise_app_dsl_export", + responses={ + 200: "Export successful", + 404: "App not found", + }, + ) + def get(self, app_id: str): + """Export an app's DSL as YAML.""" + include_secret = request.args.get("include_secret", "false").lower() == "true" + + app_model = db.session.get(App, app_id) + if not app_model: + return {"message": "app not found"}, 404 + + data = AppDslService.export_dsl( + app_model=app_model, + include_secret=include_secret, + ) + + return {"data": data}, 200 + + +def _get_active_account(email: str) -> Account | None: + """Look up an active account by email. + + Workspace membership is already validated by the Go admin-api caller. + """ + account = db.session.scalar(select(Account).where(Account.email == email).limit(1)) + if account is None or account.status != AccountStatus.ACTIVE: + return None + return account diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 9b8b3950e60..83c8fa02fee 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -1,4 +1,5 @@ from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder from controllers.console.wraps import setup_required from controllers.inner_api import inner_api_ns @@ -28,8 +29,7 @@ from core.plugin.entities.request import ( RequestRequestUploadFile, ) from core.tools.entities.tool_entities import ToolProviderType -from dify_graph.file.helpers import get_signed_file_url_for_plugin -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from core.tools.signature import get_signed_file_url_for_plugin from libs.helper import length_prefixed_response from models import Account, Tenant from models.model import EndUser diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index 9ddaaa315b3..3d00f77e79f 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -2,6 +2,7 @@ from typing import Any, Union from flask import Response from flask_restx import Resource +from graphon.variables.input_entities import VariableEntity from pydantic import BaseModel, Field, ValidationError from sqlalchemy.orm import Session @@ -9,7 +10,6 @@ from controllers.common.schema import register_schema_model from controllers.mcp import mcp_ns from core.mcp import types as mcp_types from core.mcp.server.streamable_http import handle_mcp_request -from dify_graph.variables.input_entities import VariableEntity from extensions.ext_database import db from libs import helper from models.enums import AppMCPServerStatus diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 38d292d0b90..6228cfc25be 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -2,6 +2,7 @@ import logging from flask import request from flask_restx import Resource +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError @@ -21,7 +22,6 @@ from controllers.service_api.app.error import ( ) from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError from models.model import App, EndUser from services.audio_service import AudioService from services.errors.audio import ( diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 98f09c44a18..3142e5118e9 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -4,6 +4,7 @@ from uuid import UUID from flask import request from flask_restx import Resource +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import BadRequest, InternalServerError, NotFound @@ -28,7 +29,6 @@ from core.errors.error import ( QuotaExceededError, ) from core.helper.trace_id_helper import get_external_trace_id -from dify_graph.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import UUIDStrOrEmpty from models.model import App, AppMode, EndUser diff --git a/api/controllers/service_api/app/file_preview.py b/api/controllers/service_api/app/file_preview.py index f853a124efa..5e7847d784f 100644 --- a/api/controllers/service_api/app/file_preview.py +++ b/api/controllers/service_api/app/file_preview.py @@ -4,6 +4,7 @@ from urllib.parse import quote from flask import Response, request from flask_restx import Resource from pydantic import BaseModel, Field +from sqlalchemy import select from controllers.common.file_response import enforce_download_for_html from controllers.common.schema import register_schema_model @@ -102,27 +103,27 @@ class FilePreviewApi(Resource): raise FileAccessDeniedError("Invalid file or app identifier") # First, find the MessageFile that references this upload file - message_file = db.session.query(MessageFile).where(MessageFile.upload_file_id == file_id).first() + message_file = db.session.scalar(select(MessageFile).where(MessageFile.upload_file_id == file_id).limit(1)) if not message_file: raise FileNotFoundError("File not found in message context") # Get the message and verify it belongs to the requesting app - message = ( - db.session.query(Message).where(Message.id == message_file.message_id, Message.app_id == app_id).first() + message = db.session.scalar( + select(Message).where(Message.id == message_file.message_id, Message.app_id == app_id).limit(1) ) if not message: raise FileAccessDeniedError("File access denied: not owned by requesting app") # Get the actual upload file record - upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() + upload_file = db.session.get(UploadFile, file_id) if not upload_file: raise FileNotFoundError("Upload file record not found") # Additional security: verify tenant isolation - app = db.session.query(App).where(App.id == app_id).first() + app = db.session.get(App, app_id) if app and upload_file.tenant_id != app.tenant_id: raise FileAccessDeniedError("File access denied: tenant mismatch") diff --git a/api/controllers/service_api/app/site.py b/api/controllers/service_api/app/site.py index 8b47a887bbe..bc06e8f386d 100644 --- a/api/controllers/service_api/app/site.py +++ b/api/controllers/service_api/app/site.py @@ -1,4 +1,5 @@ from flask_restx import Resource +from sqlalchemy import select from werkzeug.exceptions import Forbidden from controllers.common.fields import Site as SiteResponse @@ -28,7 +29,7 @@ class AppSiteApi(Resource): Returns the site configuration for the application including theme, icons, and text. """ - site = db.session.query(Site).where(Site.app_id == app_model.id).first() + site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) if not site: raise Forbidden() diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 35dd22c8013..17590751395 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -4,6 +4,9 @@ from typing import Any, Literal from dateutil.parser import isoparse from flask import request from flask_restx import Namespace, Resource, fields +from graphon.enums import WorkflowExecutionStatus +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from sqlalchemy.orm import Session, sessionmaker from werkzeug.exceptions import BadRequest, InternalServerError, NotFound @@ -27,9 +30,6 @@ from core.errors.error import ( QuotaExceededError, ) from core.helper.trace_id_helper import get_external_trace_id -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.graph_engine.manager import GraphEngineManager -from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 83d07087ab2..80205b283bc 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -2,6 +2,7 @@ from typing import Any, Literal, cast from flask import request from flask_restx import marshal +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field, TypeAdapter, field_validator from werkzeug.exceptions import Forbidden, NotFound @@ -14,8 +15,8 @@ from controllers.service_api.wraps import ( DatasetApiResource, cloud_edition_billing_rate_limit_check, ) -from core.provider_manager import ProviderManager -from dify_graph.model_runtime.entities.model_entities import ModelType +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager +from core.rag.index_processor.constant.index_type import IndexTechniqueType from fields.dataset_fields import dataset_detail_fields from fields.tag_fields import DataSetTag from libs.login import current_user @@ -139,10 +140,10 @@ class DatasetListApi(DatasetApiResource): query.page, query.limit, tenant_id, current_user, query.keyword, query.tag_ids, query.include_all ) # check embedding setting - provider_manager = ProviderManager() assert isinstance(current_user, Account) cid = current_user.current_tenant_id assert cid is not None + provider_manager = create_plugin_provider_manager(tenant_id=cid) configurations = provider_manager.get_configurations(tenant_id=cid) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) @@ -153,15 +154,20 @@ class DatasetListApi(DatasetApiResource): data = marshal(datasets, dataset_detail_fields) for item in data: - if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]: - item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"])) - item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" + if ( + item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY # pyrefly: ignore[bad-index] + and item["embedding_model_provider"] # pyrefly: ignore[bad-index] + ): + item["embedding_model_provider"] = str( # pyrefly: ignore[unsupported-operation] + ModelProviderID(item["embedding_model_provider"]) # pyrefly: ignore[bad-index] + ) + item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" # pyrefly: ignore[bad-index] if item_model in model_names: - item["embedding_available"] = True + item["embedding_available"] = True # type: ignore else: - item["embedding_available"] = False + item["embedding_available"] = False # type: ignore else: - item["embedding_available"] = True + item["embedding_available"] = True # type: ignore response = { "data": data, "has_more": len(datasets) == query.limit, @@ -253,10 +259,10 @@ class DatasetApi(DatasetApiResource): raise Forbidden(str(e)) data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields)) # check embedding setting - provider_manager = ProviderManager() assert isinstance(current_user, Account) cid = current_user.current_tenant_id assert cid is not None + provider_manager = create_plugin_provider_manager(tenant_id=cid) configurations = provider_manager.get_configurations(tenant_id=cid) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) @@ -265,7 +271,7 @@ class DatasetApi(DatasetApiResource): for embedding_model in embedding_models: model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") - if data.get("indexing_technique") == "high_quality": + if data.get("indexing_technique") == IndexTechniqueType.HIGH_QUALITY: item_model = f"{data.get('embedding_model')}:{data.get('embedding_model_provider')}" if item_model in model_names: data["embedding_available"] = True @@ -315,7 +321,7 @@ class DatasetApi(DatasetApiResource): # check embedding model setting embedding_model_provider = payload.embedding_model_provider embedding_model = payload.embedding_model - if payload.indexing_technique == "high_quality" or embedding_model_provider: + if payload.indexing_technique == IndexTechniqueType.HIGH_QUALITY or embedding_model_provider: if embedding_model_provider and embedding_model: DatasetService.check_embedding_model_setting( dataset.tenant_id, embedding_model_provider, embedding_model diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index d34b4124aeb..2c094aa3e6e 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -6,7 +6,7 @@ from uuid import UUID from flask import request, send_file from flask_restx import marshal from pydantic import BaseModel, Field, field_validator, model_validator -from sqlalchemy import desc, select +from sqlalchemy import desc, func, select from werkzeug.exceptions import Forbidden, NotFound import services @@ -155,7 +155,9 @@ class DocumentAddByTextApi(DatasetApiResource): dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise ValueError("Dataset does not exist.") @@ -238,7 +240,9 @@ class DocumentUpdateByTextApi(DatasetApiResource): def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID): """Update document by text.""" payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {}) - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).limit(1) + ) args = payload.model_dump(exclude_none=True) if not dataset: raise ValueError("Dataset does not exist.") @@ -315,7 +319,9 @@ class DocumentAddByFileApi(DatasetApiResource): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): """Create document by upload file.""" - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise ValueError("Dataset does not exist.") @@ -425,7 +431,9 @@ class DocumentUpdateByFileApi(DatasetApiResource): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id, document_id): """Update document by upload file.""" - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise ValueError("Dataset does not exist.") @@ -515,7 +523,9 @@ class DocumentListApi(DatasetApiResource): dataset_id = str(dataset_id) tenant_id = str(tenant_id) query_params = DocumentListQuery.model_validate(request.args.to_dict()) - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") @@ -609,7 +619,9 @@ class DocumentIndexingStatusApi(DatasetApiResource): batch = str(batch) tenant_id = str(tenant_id) # get dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # get documents @@ -619,20 +631,23 @@ class DocumentIndexingStatusApi(DatasetApiResource): documents_status = [] for document in documents: completed_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != SegmentStatus.RE_SEGMENT, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) total_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) # Create a dictionary with document attributes and additional fields document_dict = { @@ -822,7 +837,9 @@ class DocumentApi(DatasetApiResource): tenant_id = str(tenant_id) # get dataset info - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise ValueError("Dataset does not exist.") diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 2e3b7fd85e8..5b16da81e08 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -2,7 +2,9 @@ from typing import Any from flask import request from flask_restx import marshal +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field +from sqlalchemy import select from werkzeug.exceptions import NotFound from configs import dify_config @@ -17,7 +19,7 @@ from controllers.service_api.wraps import ( ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager -from dify_graph.model_runtime.entities.model_entities import ModelType +from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db from fields.segment_fields import child_chunk_fields, segment_fields from libs.login import current_account_with_tenant @@ -27,6 +29,31 @@ from services.entities.knowledge_entities.knowledge_entities import SegmentUpdat from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError +from services.summary_index_service import SummaryIndexService + + +def _marshal_segment_with_summary(segment, dataset_id: str) -> dict: + """Marshal a single segment and enrich it with summary content.""" + segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type] + summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id) + segment_dict["summary"] = summary.summary_content if summary else None + return segment_dict + + +def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict]: + """Marshal multiple segments and enrich them with summary content (batch query).""" + segment_ids = [segment.id for segment in segments] + summaries: dict = {} + if segment_ids: + summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id) + summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()} + + result = [] + for segment in segments: + segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type] + segment_dict["summary"] = summaries.get(segment.id) + result.append(segment_dict) + return result class SegmentCreatePayload(BaseModel): @@ -91,7 +118,9 @@ class SegmentApi(DatasetApiResource): _, current_tenant_id = current_account_with_tenant() """Create single segment.""" # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # check document @@ -103,9 +132,9 @@ class SegmentApi(DatasetApiResource): if not document.enabled: raise NotFound("Document is disabled.") # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -128,7 +157,7 @@ class SegmentApi(DatasetApiResource): for args_item in payload.segments: SegmentService.segment_create_args_validate(args_item, document) segments = SegmentService.multi_create_segment(payload.segments, document, dataset) - return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200 + return {"data": _marshal_segments_with_summary(segments, dataset_id), "doc_form": document.doc_form}, 200 else: return {"error": "Segments is required"}, 400 @@ -149,7 +178,9 @@ class SegmentApi(DatasetApiResource): # check dataset page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # check document @@ -157,9 +188,9 @@ class SegmentApi(DatasetApiResource): if not document: raise NotFound("Document not found.") # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -190,7 +221,7 @@ class SegmentApi(DatasetApiResource): ) response = { - "data": marshal(segments, segment_fields), + "data": _marshal_segments_with_summary(segments, dataset_id), "doc_form": document.doc_form, "total": total, "has_more": len(segments) == limit, @@ -219,7 +250,9 @@ class DatasetSegmentApi(DatasetApiResource): def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): _, current_tenant_id = current_account_with_tenant() # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # check user's model setting @@ -253,7 +286,9 @@ class DatasetSegmentApi(DatasetApiResource): def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): _, current_tenant_id = current_account_with_tenant() # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # check user's model setting @@ -262,10 +297,10 @@ class DatasetSegmentApi(DatasetApiResource): document = DocumentService.get_document(dataset_id, document_id) if not document: raise NotFound("Document not found.") - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -286,7 +321,7 @@ class DatasetSegmentApi(DatasetApiResource): payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {}) updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset) - return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200 + return {"data": _marshal_segment_with_summary(updated_segment, dataset_id), "doc_form": document.doc_form}, 200 @service_api_ns.doc("get_segment") @service_api_ns.doc(description="Get a specific segment by ID") @@ -300,7 +335,9 @@ class DatasetSegmentApi(DatasetApiResource): def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): _, current_tenant_id = current_account_with_tenant() # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # check user's model setting @@ -314,7 +351,7 @@ class DatasetSegmentApi(DatasetApiResource): if not segment: raise NotFound("Segment not found.") - return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 + return {"data": _marshal_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200 @service_api_ns.route( @@ -343,7 +380,9 @@ class ChildChunkApi(DatasetApiResource): _, current_tenant_id = current_account_with_tenant() """Create child chunk.""" # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") @@ -358,9 +397,9 @@ class ChildChunkApi(DatasetApiResource): raise NotFound("Segment not found.") # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -401,7 +440,9 @@ class ChildChunkApi(DatasetApiResource): _, current_tenant_id = current_account_with_tenant() """Get child chunks.""" # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") @@ -467,7 +508,9 @@ class DatasetChildChunkApi(DatasetApiResource): _, current_tenant_id = current_account_with_tenant() """Delete child chunk.""" # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") @@ -526,7 +569,9 @@ class DatasetChildChunkApi(DatasetApiResource): _, current_tenant_id = current_account_with_tenant() """Update child chunk.""" # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") diff --git a/api/controllers/service_api/workspace/models.py b/api/controllers/service_api/workspace/models.py index 35aed40a598..c0a6cb0a763 100644 --- a/api/controllers/service_api/workspace/models.py +++ b/api/controllers/service_api/workspace/models.py @@ -1,9 +1,9 @@ from flask_login import current_user from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_dataset_token -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from services.model_provider_service import ModelProviderService diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 7aa5b2f0925..1d52b8a737a 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -9,6 +9,7 @@ from flask import current_app, request from flask_login import user_logged_in from flask_restx import Resource from pydantic import BaseModel +from sqlalchemy import select from werkzeug.exceptions import Forbidden, NotFound, Unauthorized from enums.cloud_plan import CloudPlan @@ -62,7 +63,7 @@ def validate_app_token( def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R: api_token = validate_and_get_api_token("app") - app_model = db.session.query(App).where(App.id == api_token.app_id).first() + app_model = db.session.get(App, api_token.app_id) if not app_model: raise Forbidden("The app no longer exists.") @@ -72,7 +73,7 @@ def validate_app_token( if not app_model.enable_api: raise Forbidden("The app's API service has been disabled.") - tenant = db.session.query(Tenant).where(Tenant.id == app_model.tenant_id).first() + tenant = db.session.get(Tenant, app_model.tenant_id) if tenant is None: raise ValueError("Tenant does not exist.") if tenant.status == TenantStatus.ARCHIVE: @@ -106,8 +107,8 @@ def validate_app_token( else: # For service API without end-user context, ensure an Account is logged in # so services relying on current_account_with_tenant() work correctly. - tenant_owner_info = ( - db.session.query(Tenant, Account) + tenant_owner_info = db.session.execute( + select(Tenant, Account) .join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id) .join(Account, TenantAccountJoin.account_id == Account.id) .where( @@ -115,8 +116,7 @@ def validate_app_token( TenantAccountJoin.role == "owner", Tenant.status == TenantStatus.NORMAL, ) - .one_or_none() - ) + ).one_or_none() if tenant_owner_info: tenant_model, account = tenant_owner_info @@ -277,29 +277,28 @@ def validate_dataset_token( # Validate dataset if dataset_id is provided if dataset_id: dataset_id = str(dataset_id) - dataset = ( - db.session.query(Dataset) + dataset = db.session.scalar( + select(Dataset) .where( Dataset.id == dataset_id, Dataset.tenant_id == api_token.tenant_id, ) - .first() + .limit(1) ) if not dataset: raise NotFound("Dataset not found.") if not dataset.enable_api: raise Forbidden("Dataset api access is not enabled.") - tenant_account_join = ( - db.session.query(Tenant, TenantAccountJoin) + tenant_account_join = db.session.execute( + select(Tenant, TenantAccountJoin) .where(Tenant.id == api_token.tenant_id) .where(TenantAccountJoin.tenant_id == Tenant.id) .where(TenantAccountJoin.role.in_(["owner"])) .where(Tenant.status == TenantStatus.NORMAL) - .one_or_none() - ) # TODO: only owner information is required, so only one is returned. + ).one_or_none() # TODO: only owner information is required, so only one is returned. if tenant_account_join: tenant, ta = tenant_account_join - account = db.session.query(Account).where(Account.id == ta.account_id).first() + account = db.session.get(Account, ta.account_id) # Login admin if account: account.current_tenant = tenant @@ -360,7 +359,9 @@ class DatasetApiResource(Resource): method_decorators = [validate_dataset_token] def get_dataset(self, dataset_id: str, tenant_id: str) -> Dataset: - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 2b8f7526681..9ba1dc4a3ac 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -2,6 +2,7 @@ import logging from flask import request from flask_restx import fields, marshal_with +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, field_validator from werkzeug.exceptions import InternalServerError @@ -20,7 +21,6 @@ from controllers.web.error import ( ) from controllers.web.wraps import WebApiResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError from libs.helper import uuid_value from models.model import App from services.audio_service import AudioService diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 8634c1f43c0..e37f9af5f0a 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -1,6 +1,7 @@ import logging from typing import Any, Literal +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound @@ -25,7 +26,6 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from models.model import AppMode diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index aa562926141..c5505dd60de 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -2,6 +2,7 @@ import logging from typing import Literal from flask import request +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, TypeAdapter, field_validator from werkzeug.exceptions import InternalServerError, NotFound @@ -20,7 +21,6 @@ from controllers.web.error import ( from controllers.web.wraps import WebApiResource from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError from fields.conversation_fields import ResultResponse from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem from libs import helper diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index 6a93ef67484..38aeccc642b 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -1,6 +1,7 @@ import urllib.parse import httpx +from graphon.file import helpers as file_helpers from pydantic import BaseModel, Field, HttpUrl import services @@ -11,7 +12,6 @@ from controllers.common.errors import ( UnsupportedFileTypeError, ) from core.helper import ssrf_proxy -from dify_graph.file import helpers as file_helpers from extensions.ext_database import db from fields.file_fields import FileWithSignedUrl, RemoteFileInfo from services.file_service import FileService diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index 508d1a756a7..7f5521f9f58 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -1,6 +1,8 @@ import logging from typing import Any +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError @@ -22,8 +24,6 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.graph_engine.manager import GraphEngineManager -from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_redis import redis_client from libs import helper from models.model import App, AppMode, EndUser diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 1bdc8df813f..06c746990d2 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -4,7 +4,21 @@ import uuid from decimal import Decimal from typing import Union, cast -from sqlalchemy import select +from graphon.file import file_manager +from graphon.model_runtime.entities import ( + AssistantPromptMessage, + LLMUsage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + ToolPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes +from graphon.model_runtime.entities.model_entities import ModelFeature +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from sqlalchemy import func, select from core.agent.entities import AgentEntity, AgentToolEntity from core.app.app_config.features.file_upload.manager import FileUploadConfigManager @@ -15,6 +29,7 @@ from core.app.entities.app_invoke_entities import ( AgentChatAppGenerateEntity, ModelConfigWithCredentialsEntity, ) +from core.app.file_access import DatabaseFileAccessController from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.memory.token_buffer_memory import TokenBufferMemory @@ -26,26 +41,13 @@ from core.tools.entities.tool_entities import ( ) from core.tools.tool_manager import ToolManager from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool -from dify_graph.file import file_manager -from dify_graph.model_runtime.entities import ( - AssistantPromptMessage, - LLMUsage, - PromptMessage, - PromptMessageTool, - SystemPromptMessage, - TextPromptMessageContent, - ToolPromptMessage, - UserPromptMessage, -) -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes -from dify_graph.model_runtime.entities.model_entities import ModelFeature -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from extensions.ext_database import db from factories import file_factory from models.enums import CreatorUserRole from models.model import Conversation, Message, MessageAgentThought, MessageFile logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() class BaseAgentRunner(AppRunner): @@ -102,11 +104,14 @@ class BaseAgentRunner(AppRunner): ) # get how many agent thoughts have been created self.agent_thought_count = ( - db.session.query(MessageAgentThought) - .where( - MessageAgentThought.message_id == self.message.id, + db.session.scalar( + select(func.count()) + .select_from(MessageAgentThought) + .where( + MessageAgentThought.message_id == self.message.id, + ) ) - .count() + or 0 ) db.session.close() @@ -138,6 +143,7 @@ class BaseAgentRunner(AppRunner): tenant_id=self.tenant_id, app_id=self.app_config.app_id, agent_tool=tool, + user_id=self.user_id, invoke_from=self.application_generate_entity.invoke_from, ) assert tool_entity.entity.description @@ -524,7 +530,10 @@ class BaseAgentRunner(AppRunner): image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW file_objs = file_factory.build_from_message_files( - message_files=files, tenant_id=self.tenant_id, config=file_extra_config + message_files=files, + tenant_id=self.tenant_id, + config=file_extra_config, + access_controller=_file_access_controller, ) if not file_objs: return UserPromptMessage(content=message.query) diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 9271ed10bda..11e2aa062d2 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -4,6 +4,15 @@ from abc import ABC, abstractmethod from collections.abc import Generator, Mapping, Sequence from typing import Any +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + ToolPromptMessage, + UserPromptMessage, +) + from core.agent.base_agent_runner import BaseAgentRunner from core.agent.entities import AgentScratchpadUnit from core.agent.errors import AgentMaxIterationError @@ -15,14 +24,6 @@ from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransfo from core.tools.__base.tool import Tool from core.tools.entities.tool_entities import ToolInvokeMeta from core.tools.tool_engine import ToolEngine -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from dify_graph.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessage, - PromptMessageTool, - ToolPromptMessage, - UserPromptMessage, -) from models.model import Message logger = logging.getLogger(__name__) @@ -122,7 +123,6 @@ class CotAgentRunner(BaseAgentRunner, ABC): tools=[], stop=app_generate_entity.model_conf.stop, stream=True, - user=self.user_id, callbacks=[], ) diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py index 89451a0498b..a4c438e9296 100644 --- a/api/core/agent/cot_chat_agent_runner.py +++ b/api/core/agent/cot_chat_agent_runner.py @@ -1,16 +1,17 @@ import json -from core.agent.cot_agent_runner import CotAgentRunner -from dify_graph.file import file_manager -from dify_graph.model_runtime.entities import ( +from graphon.file import file_manager +from graphon.model_runtime.entities import ( AssistantPromptMessage, PromptMessage, SystemPromptMessage, TextPromptMessageContent, UserPromptMessage, ) -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes +from graphon.model_runtime.utils.encoders import jsonable_encoder + +from core.agent.cot_agent_runner import CotAgentRunner class CotChatAgentRunner(CotAgentRunner): diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py index 3023b9bc4d4..d4c52a8eb16 100644 --- a/api/core/agent/cot_completion_agent_runner.py +++ b/api/core/agent/cot_completion_agent_runner.py @@ -1,13 +1,14 @@ import json -from core.agent.cot_agent_runner import CotAgentRunner -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, TextPromptMessageContent, UserPromptMessage, ) -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.utils.encoders import jsonable_encoder + +from core.agent.cot_agent_runner import CotAgentRunner class CotCompletionAgentRunner(CotAgentRunner): diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 5e13a13b215..fdffde85d01 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -4,15 +4,8 @@ from collections.abc import Generator from copy import deepcopy from typing import Any, Union -from core.agent.base_agent_runner import BaseAgentRunner -from core.agent.errors import AgentMaxIterationError -from core.app.apps.base_app_queue_manager import PublishFrom -from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent -from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform -from core.tools.entities.tool_entities import ToolInvokeMeta -from core.tools.tool_engine import ToolEngine -from dify_graph.file import file_manager -from dify_graph.model_runtime.entities import ( +from graphon.file import file_manager +from graphon.model_runtime.entities import ( AssistantPromptMessage, LLMResult, LLMResultChunk, @@ -25,7 +18,15 @@ from dify_graph.model_runtime.entities import ( ToolPromptMessage, UserPromptMessage, ) -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes + +from core.agent.base_agent_runner import BaseAgentRunner +from core.agent.errors import AgentMaxIterationError +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent +from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform +from core.tools.entities.tool_entities import ToolInvokeMeta +from core.tools.tool_engine import ToolEngine from models.model import Message logger = logging.getLogger(__name__) @@ -96,7 +97,6 @@ class FunctionCallAgentRunner(BaseAgentRunner): tools=prompt_messages_tools, stop=app_generate_entity.model_conf.stop, stream=self.stream_tool_call, - user=self.user_id, callbacks=[], ) diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py index 82676f1ebda..46c1f1230d0 100644 --- a/api/core/agent/output_parser/cot_output_parser.py +++ b/api/core/agent/output_parser/cot_output_parser.py @@ -3,8 +3,9 @@ import re from collections.abc import Generator from typing import Union +from graphon.model_runtime.entities.llm_entities import LLMResultChunk + from core.agent.entities import AgentScratchpadUnit -from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk class CotAgentOutputParser: diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py index 558b6e69a0e..b7dd55632e2 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py @@ -1,13 +1,14 @@ from typing import cast +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel + from core.app.app_config.entities import EasyUIBasedAppConfig from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.model_entities import ModelStatus from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.provider_manager import ProviderManager -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager class ModelConfigConverter: @@ -21,7 +22,7 @@ class ModelConfigConverter: """ model_config = app_config.model - provider_manager = ProviderManager() + provider_manager = create_plugin_provider_manager(tenant_id=app_config.tenant_id) provider_model_bundle = provider_manager.get_provider_model_bundle( tenant_id=app_config.tenant_id, provider=model_config.provider, model_type=ModelType.LLM ) diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py index 0929f52e337..5cc385c3781 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -1,10 +1,10 @@ from collections.abc import Mapping from typing import Any +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType + from core.app.app_config.entities import ModelConfigEntity -from core.provider_manager import ProviderManager -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly from models.model import AppModelConfigDict from models.provider_ids import ModelProviderID @@ -54,9 +54,12 @@ class ModelConfigManager: if not isinstance(config["model"], dict): raise ValueError("model must be of object type") + # Keep provider discovery and provider-backed model listing on the same + # request-scoped runtime so caller scope and provider caches stay aligned. + assembly = create_plugin_model_assembly(tenant_id=tenant_id) + # model.provider - model_provider_factory = ModelProviderFactory(tenant_id) - provider_entities = model_provider_factory.get_providers() + provider_entities = assembly.model_provider_factory.get_providers() model_provider_names = [provider.provider for provider in provider_entities] if "provider" not in config["model"]: raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") @@ -71,8 +74,7 @@ class ModelConfigManager: if "name" not in config["model"]: raise ValueError("model.name is required") - provider_manager = ProviderManager() - models = provider_manager.get_configurations(tenant_id).get_models( + models = assembly.provider_manager.get_configurations(tenant_id).get_models( provider=config["model"]["provider"], model_type=ModelType.LLM ) diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py index b7073898d6d..76196e7034e 100644 --- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -1,5 +1,7 @@ from typing import Any +from graphon.model_runtime.entities.message_entities import PromptMessageRole + from core.app.app_config.entities import ( AdvancedChatMessageEntity, AdvancedChatPromptTemplateEntity, @@ -7,7 +9,6 @@ from core.app.app_config.entities import ( PromptTemplateEntity, ) from core.prompt.simple_prompt_transform import ModelMode -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole from models.model import AppMode, AppModelConfigDict diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py index 8de1224a89a..f0b71c58016 100644 --- a/api/core/app/app_config/easy_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py @@ -1,9 +1,10 @@ import re from typing import cast +from graphon.variables.input_entities import VariableEntity, VariableEntityType + from core.app.app_config.entities import ExternalDataVariableEntity from core.external_data_tool.factory import ExternalDataToolFactory -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType from models.model import AppModelConfigDict _ALLOWED_VARIABLE_ENTITY_TYPE = frozenset( diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 95ea70bc40a..536617edba4 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -2,13 +2,13 @@ from collections.abc import Sequence from enum import StrEnum, auto from typing import Any, Literal +from graphon.file import FileUploadConfig +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.message_entities import PromptMessageRole +from graphon.variables.input_entities import VariableEntity as WorkflowVariableEntity from pydantic import BaseModel, Field from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict -from dify_graph.file import FileUploadConfig -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole -from dify_graph.variables.input_entities import VariableEntity as WorkflowVariableEntity from models.model import AppMode diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py index 0c4266fbebb..e96517c4264 100644 --- a/api/core/app/app_config/features/file_upload/manager.py +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -1,8 +1,9 @@ from collections.abc import Mapping from typing import Any +from graphon.file import FileUploadConfig + from constants import DEFAULT_FILE_NUMBER_LIMITS -from dify_graph.file import FileUploadConfig class FileUploadConfigManager: diff --git a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py index d2a9a73380e..62e0c31d1ae 100644 --- a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py @@ -1,7 +1,8 @@ import re +from graphon.variables.input_entities import VariableEntity + from core.app.app_config.entities import RagPipelineVariableEntity -from dify_graph.variables.input_entities import VariableEntity from models.workflow import Workflow diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 5d974335ff4..aa2b65766f8 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -5,7 +5,7 @@ import logging import threading import uuid from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload +from typing import TYPE_CHECKING, Any, Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -18,12 +18,23 @@ from constants import UUID_NIL if TYPE_CHECKING: from controllers.console.app.workflow import LoopNodeRunPayload +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.runtime import GraphRuntimeState +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader + from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter -from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline +from core.app.apps.advanced_chat.generate_task_pipeline import ( + AdvancedChatAppGenerateTaskPipeline, + ConversationSnapshot, + MessageSnapshot, + WorkflowSnapshot, +) from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.draft_variable_saver import DraftVariableSaverFactory from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager @@ -34,20 +45,11 @@ from core.helper.trace_id_helper import extract_external_trace_id_from_args from core.ops.ops_trace_manager import TraceQueueManager from core.prompt.utils.get_thread_messages_length import get_thread_messages_length from core.repositories import DifyCoreRepositoryFactory -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError -from dify_graph.repositories.draft_variable_repository import ( - DraftVariableSaverFactory, -) -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from dify_graph.runtime import GraphRuntimeState -from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository from extensions.ext_database import db from factories import file_factory from libs.flask_utils import preserve_flask_contexts from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom -from models.base import Base from models.enums import WorkflowRunTriggeredFrom from services.conversation_service import ConversationService from services.workflow_draft_variable_service import ( @@ -150,85 +152,87 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # # For implementation reference, see the `_parse_file` function and # `DraftWorkflowNodeRunApi` class which handle this properly. - files = args["files"] if args.get("files") else [] - file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) - if file_extra_config: - file_objs = file_factory.build_from_mappings( - mappings=files, - tenant_id=app_model.tenant_id, - config=file_extra_config, + with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from): + files = args["files"] if args.get("files") else [] + file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) + if file_extra_config: + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + config=file_extra_config, + access_controller=self._file_access_controller, + ) + else: + file_objs = [] + + # convert to app config + app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) + + # get tracing instance + trace_manager = TraceQueueManager( + app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id ) - else: - file_objs = [] - # convert to app config - app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) + if invoke_from == InvokeFrom.DEBUGGER: + # always enable retriever resource in debugger mode + app_config.additional_features.show_retrieve_source = True # type: ignore - # get tracing instance - trace_manager = TraceQueueManager( - app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id - ) + # init application generate entity + application_generate_entity = AdvancedChatAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + file_upload_config=file_extra_config, + conversation_id=conversation.id if conversation else None, + inputs=self._prepare_user_inputs( + user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id + ), + query=query, + files=list(file_objs), + parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, + user_id=user.id, + stream=streaming, + invoke_from=invoke_from, + extras=extras, + trace_manager=trace_manager, + workflow_run_id=str(workflow_run_id), + ) + contexts.plugin_tool_providers.set({}) + contexts.plugin_tool_providers_lock.set(threading.Lock()) - if invoke_from == InvokeFrom.DEBUGGER: - # always enable retriever resource in debugger mode - app_config.additional_features.show_retrieve_source = True # type: ignore + # Create repositories + # + # Create session factory + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + # Create workflow execution(aka workflow run) repository + if invoke_from == InvokeFrom.DEBUGGER: + workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING + else: + workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=workflow_triggered_from, + ) + # Create workflow node execution repository + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) - # init application generate entity - application_generate_entity = AdvancedChatAppGenerateEntity( - task_id=str(uuid.uuid4()), - app_config=app_config, - file_upload_config=file_extra_config, - conversation_id=conversation.id if conversation else None, - inputs=self._prepare_user_inputs( - user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id - ), - query=query, - files=list(file_objs), - parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, - user_id=user.id, - stream=streaming, - invoke_from=invoke_from, - extras=extras, - trace_manager=trace_manager, - workflow_run_id=str(workflow_run_id), - ) - contexts.plugin_tool_providers.set({}) - contexts.plugin_tool_providers_lock.set(threading.Lock()) - - # Create repositories - # - # Create session factory - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - # Create workflow execution(aka workflow run) repository - if invoke_from == InvokeFrom.DEBUGGER: - workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING - else: - workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN - workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( - session_factory=session_factory, - user=user, - app_id=application_generate_entity.app_config.app_id, - triggered_from=workflow_triggered_from, - ) - # Create workflow node execution repository - workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( - session_factory=session_factory, - user=user, - app_id=application_generate_entity.app_config.app_id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, - ) - - return self._generate( - workflow=workflow, - user=user, - invoke_from=invoke_from, - application_generate_entity=application_generate_entity, - workflow_execution_repository=workflow_execution_repository, - workflow_node_execution_repository=workflow_node_execution_repository, - conversation=conversation, - stream=streaming, - pause_state_config=pause_state_config, - ) + return self._generate( + workflow=workflow, + user=user, + invoke_from=invoke_from, + application_generate_entity=application_generate_entity, + workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, + conversation=conversation, + stream=streaming, + pause_state_config=pause_state_config, + ) def resume( self, @@ -460,94 +464,91 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): :param conversation: conversation :param stream: is stream """ - is_first_conversation = conversation is None + with self._bind_file_access_scope( + tenant_id=application_generate_entity.app_config.tenant_id, + user=user, + invoke_from=invoke_from, + ): + is_first_conversation = conversation is None - if conversation is not None and message is not None: - pass - else: - conversation, message = self._init_generate_records(application_generate_entity, conversation) + if conversation is not None and message is not None: + pass + else: + conversation, message = self._init_generate_records(application_generate_entity, conversation) - if is_first_conversation: - # update conversation features - conversation.override_model_configs = workflow.features - db.session.commit() - db.session.refresh(conversation) + if is_first_conversation: + # update conversation features + conversation.override_model_configs = workflow.features + db.session.commit() + db.session.refresh(conversation) - # get conversation dialogue count - # NOTE: dialogue_count should not start from 0, - # because during the first conversation, dialogue_count should be 1. - self._dialogue_count = get_thread_messages_length(conversation.id) + 1 + # get conversation dialogue count + # NOTE: dialogue_count should not start from 0, + # because during the first conversation, dialogue_count should be 1. + self._dialogue_count = get_thread_messages_length(conversation.id) + 1 - # init queue manager - queue_manager = MessageBasedAppQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - conversation_id=conversation.id, - app_mode=conversation.mode, - message_id=message.id, - ) - - graph_layers: list[GraphEngineLayer] = list(graph_engine_layers) - if pause_state_config is not None: - graph_layers.append( - PauseStatePersistenceLayer( - session_factory=pause_state_config.session_factory, - generate_entity=application_generate_entity, - state_owner_user_id=pause_state_config.state_owner_user_id, - ) + # init queue manager + queue_manager = MessageBasedAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, + message_id=message.id, ) - # new thread with request context and contextvars - context = contextvars.copy_context() + graph_layers: list[GraphEngineLayer] = list(graph_engine_layers) + if pause_state_config is not None: + graph_layers.append( + PauseStatePersistenceLayer( + session_factory=pause_state_config.session_factory, + generate_entity=application_generate_entity, + state_owner_user_id=pause_state_config.state_owner_user_id, + ) + ) - worker_thread = threading.Thread( - target=self._generate_worker, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "application_generate_entity": application_generate_entity, - "queue_manager": queue_manager, - "conversation_id": conversation.id, - "message_id": message.id, - "context": context, - "variable_loader": variable_loader, - "workflow_execution_repository": workflow_execution_repository, - "workflow_node_execution_repository": workflow_node_execution_repository, - "graph_engine_layers": tuple(graph_layers), - "graph_runtime_state": graph_runtime_state, - }, - ) + # new thread with request context and contextvars + context = contextvars.copy_context() - worker_thread.start() + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "conversation_id": conversation.id, + "message_id": message.id, + "context": context, + "variable_loader": variable_loader, + "workflow_execution_repository": workflow_execution_repository, + "workflow_node_execution_repository": workflow_node_execution_repository, + "graph_engine_layers": tuple(graph_layers), + "graph_runtime_state": graph_runtime_state, + }, + ) - # release database connection, because the following new thread operations may take a long time - with Session(bind=db.engine, expire_on_commit=False) as session: - workflow = _refresh_model(session, workflow) - message = _refresh_model(session, message) - # workflow_ = session.get(Workflow, workflow.id) - # assert workflow_ is not None - # workflow = workflow_ - # message_ = session.get(Message, message.id) - # assert message_ is not None - # message = message_ - # db.session.refresh(workflow) - # db.session.refresh(message) - # db.session.refresh(user) - db.session.close() + worker_thread.start() - # return response or stream generator - response = self._handle_advanced_chat_response( - application_generate_entity=application_generate_entity, - workflow=workflow, - queue_manager=queue_manager, - conversation=conversation, - message=message, - user=user, - stream=stream, - draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user), - ) + # Capture the scalar fields needed by the response pipeline before + # releasing the request-scoped SQLAlchemy session. + workflow_snapshot = WorkflowSnapshot.from_workflow(workflow) + conversation_snapshot = ConversationSnapshot.from_conversation(conversation) + message_snapshot = MessageSnapshot.from_message(message) + db.session.close() - return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + # return response or stream generator + response = self._handle_advanced_chat_response( + application_generate_entity=application_generate_entity, + workflow=workflow_snapshot, + queue_manager=queue_manager, + conversation=conversation_snapshot, + message=message_snapshot, + user=user, + stream=stream, + draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user), + ) + + return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def _generate_worker( self, @@ -648,10 +649,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): self, *, application_generate_entity: AdvancedChatAppGenerateEntity, - workflow: Workflow, + workflow: WorkflowSnapshot, queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, + conversation: ConversationSnapshot, + message: MessageSnapshot, user: Union[Account, EndUser], draft_var_saver_factory: DraftVariableSaverFactory, stream: bool = False, @@ -688,13 +689,3 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): else: logger.exception("Failed to process generate task pipeline, conversation_id: %s", conversation.id) raise e - - -_T = TypeVar("_T", bound=Base) - - -def _refresh_model(session, model: _T) -> _T: - with Session(bind=db.engine, expire_on_commit=False) as session: - detach_model = session.get(type(model), model.id) - assert detach_model is not None - return detach_model diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 66037696af3..a884a1c7f9b 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -3,6 +3,12 @@ import time from collections.abc import Mapping, Sequence from typing import Any, cast +from graphon.enums import WorkflowType +from graphon.graph_engine.command_channels import RedisChannel +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import VariableLoader +from graphon.variables.variables import Variable from sqlalchemy import select from sqlalchemy.orm import Session @@ -25,16 +31,15 @@ from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, Workfl from core.db.session_factory import session_factory from core.moderation.base import ModerationError from core.moderation.input_moderation import InputModeration +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository +from core.workflow.node_factory import get_default_root_node_id +from core.workflow.system_variables import ( + build_bootstrap_variables, + build_system_variables, + system_variables_to_mapping, +) +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry -from dify_graph.enums import WorkflowType -from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variable_loader import VariableLoader -from dify_graph.variables.variables import Variable from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.otel import WorkflowAppRunnerHandler, trace_span @@ -90,7 +95,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): app_config = self.application_generate_entity.app_config app_config = cast(AdvancedChatAppConfig, app_config) - system_inputs = SystemVariable( + system_inputs = build_system_variables( query=self.application_generate_entity.query, files=self.application_generate_entity.files, conversation_id=self.conversation.id, @@ -132,6 +137,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): workflow=self._workflow, single_iteration_run=self.application_generate_entity.single_iteration_run, single_loop_run=self.application_generate_entity.single_loop_run, + user_id=self.application_generate_entity.user_id, ) else: inputs = self.application_generate_entity.inputs @@ -150,7 +156,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): self.application_generate_entity.inputs = new_inputs self.application_generate_entity.query = new_query - system_inputs.query = new_query + system_inputs = build_system_variables( + system_variables_to_mapping(system_inputs), + query=new_query, + ) # annotation reply if self.handle_annotation_reply( @@ -166,14 +175,17 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): # Create a variable pool. # init variable pool - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=new_inputs, - environment_variables=self._workflow.environment_variables, - # Based on the definition of `Variable`, - # `VariableBase` instances can be safely used as `Variable` since they are compatible. - conversation_variables=conversation_variables, + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_inputs, + environment_variables=self._workflow.environment_variables, + conversation_variables=conversation_variables, + ), ) + root_node_id = get_default_root_node_id(self._workflow.graph_dict) + add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=new_inputs) # init graph graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.time()) @@ -185,6 +197,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): user_id=self.application_generate_entity.user_id, user_from=user_from, invoke_from=invoke_from, + root_node_id=root_node_id, ) db.session.close() diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index f7b5030d339..5203de225cc 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -4,9 +4,17 @@ import re import time from collections.abc import Callable, Generator, Mapping from contextlib import contextmanager +from dataclasses import dataclass +from datetime import datetime from threading import Thread from typing import Any, Union +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes import BuiltinNodeTypes +from graphon.runtime import GraphRuntimeState from sqlalchemy import select from sqlalchemy.orm import Session @@ -14,6 +22,7 @@ from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter +from core.app.apps.draft_variable_saver import DraftVariableSaverFactory from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, InvokeFrom, @@ -65,24 +74,66 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.ops.ops_trace_manager import TraceQueueManager from core.repositories.human_input_repository import HumanInputFormRepositoryImpl -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.nodes import BuiltinNodeTypes -from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory -from dify_graph.runtime import GraphRuntimeState -from dify_graph.system_variable import SystemVariable +from core.workflow.file_reference import resolve_file_record_id +from core.workflow.system_variables import build_system_variables from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models import Account, Conversation, EndUser, Message, MessageFile from models.enums import CreatorUserRole, MessageFileBelongsTo, MessageStatus from models.execution_extra_content import HumanInputContent +from models.model import AppMode from models.workflow import Workflow logger = logging.getLogger(__name__) +@dataclass(frozen=True, slots=True) +class WorkflowSnapshot: + id: str + tenant_id: str + features_dict: Mapping[str, Any] + + @classmethod + def from_workflow(cls, workflow: Workflow) -> "WorkflowSnapshot": + return cls( + id=workflow.id, + tenant_id=workflow.tenant_id, + features_dict=dict(workflow.features_dict), + ) + + +@dataclass(frozen=True, slots=True) +class ConversationSnapshot: + id: str + mode: AppMode + + @classmethod + def from_conversation(cls, conversation: Conversation) -> "ConversationSnapshot": + return cls( + id=conversation.id, + mode=conversation.mode, + ) + + +@dataclass(frozen=True, slots=True) +class MessageSnapshot: + id: str + query: str + created_at: datetime + status: MessageStatus + answer: str + + @classmethod + def from_message(cls, message: Message) -> "MessageSnapshot": + return cls( + id=message.id, + query=message.query, + created_at=message.created_at, + status=message.status, + answer=message.answer, + ) + + class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): """ AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application. @@ -91,10 +142,10 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): def __init__( self, application_generate_entity: AdvancedChatAppGenerateEntity, - workflow: Workflow, + workflow: WorkflowSnapshot, queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, + conversation: ConversationSnapshot, + message: MessageSnapshot, user: Union[Account, EndUser], stream: bool, dialogue_count: int, @@ -117,7 +168,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): else: raise NotImplementedError(f"User type not supported: {type(user)}") - self._workflow_system_variables = SystemVariable( + self._workflow_system_variables = build_system_variables( query=message.query, files=application_generate_entity.files, conversation_id=conversation.id, @@ -155,7 +206,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): self._message_saved_on_pause = False self._seed_graph_runtime_state_from_queue_manager() - def _seed_task_state_from_message(self, message: Message) -> None: + def _seed_task_state_from_message(self, message: MessageSnapshot) -> None: if message.status == MessageStatus.PAUSED and message.answer: self._task_state.answer = message.answer @@ -741,8 +792,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): def _load_human_input_form_id(self, *, node_id: str) -> str | None: form_repository = HumanInputFormRepositoryImpl( tenant_id=self._workflow_tenant_id, + workflow_execution_id=self._workflow_run_id, ) - form = form_repository.get_form(self._workflow_run_id, node_id) + form = form_repository.get_form(node_id) if form is None: return None return form.id @@ -933,21 +985,23 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): metadata = self._task_state.metadata.model_dump() message.message_metadata = json.dumps(jsonable_encoder(metadata)) - message_files = [ - MessageFile( - message_id=message.id, - type=file["type"], - transfer_method=file["transfer_method"], - url=file["remote_url"], - belongs_to=MessageFileBelongsTo.ASSISTANT, - upload_file_id=file["related_id"], - created_by_role=CreatorUserRole.ACCOUNT - if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} - else CreatorUserRole.END_USER, - created_by=message.from_account_id or message.from_end_user_id or "", + message_files: list[MessageFile] = [] + for file in self._recorded_files: + reference = file.get("reference") or file.get("related_id") + message_files.append( + MessageFile( + message_id=message.id, + type=file["type"], + transfer_method=file["transfer_method"], + url=file["remote_url"], + belongs_to=MessageFileBelongsTo.ASSISTANT, + upload_file_id=resolve_file_record_id(reference if isinstance(reference, str) else None), + created_by_role=CreatorUserRole.ACCOUNT + if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} + else CreatorUserRole.END_USER, + created_by=message.from_account_id or message.from_end_user_id or "", + ) ) - for file in self._recorded_files - ] session.add_all(message_files) def _seed_graph_runtime_state_from_queue_manager(self) -> None: @@ -1003,13 +1057,11 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): return message def _save_output_for_event(self, event: QueueNodeSucceededEvent | QueueNodeExceptionEvent, node_execution_id: str): - with Session(db.engine) as session, session.begin(): - saver = self._draft_var_saver_factory( - session=session, - app_id=self._application_generate_entity.app_config.app_id, - node_id=event.node_id, - node_type=event.node_type, - node_execution_id=node_execution_id, - enclosing_node_id=event.in_loop_id or event.in_iteration_id, - ) - saver.save(event.process_data, event.outputs) + saver = self._draft_var_saver_factory( + app_id=self._application_generate_entity.app_config.app_id, + node_id=event.node_id, + node_type=event.node_type, + node_execution_id=node_execution_id, + enclosing_node_id=event.in_loop_id or event.in_iteration_id, + ) + saver.save(event.process_data, event.outputs) diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 76a067d7b66..bb258af4c16 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -6,6 +6,7 @@ from collections.abc import Generator, Mapping from typing import Any, Literal, Union, overload from flask import Flask, current_app +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError from configs import dify_config @@ -21,7 +22,6 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom from core.ops.ops_trace_manager import TraceQueueManager -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError from extensions.ext_database import db from factories import file_factory from libs.flask_utils import preserve_flask_contexts @@ -129,89 +129,93 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): # # For implementation reference, see the `_parse_file` function and # `DraftWorkflowNodeRunApi` class which handle this properly. - files = args.get("files") or [] - file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) - if file_extra_config: - file_objs = file_factory.build_from_mappings( - mappings=files, - tenant_id=app_model.tenant_id, - config=file_extra_config, + with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from): + files = args.get("files") or [] + file_extra_config = FileUploadConfigManager.convert( + override_model_config_dict or app_model_config.to_dict() ) - else: - file_objs = [] + if file_extra_config: + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + config=file_extra_config, + access_controller=self._file_access_controller, + ) + else: + file_objs = [] - # convert to app config - app_config = AgentChatAppConfigManager.get_app_config( - app_model=app_model, - app_model_config=app_model_config, - conversation=conversation, - override_config_dict=override_model_config_dict, - ) + # convert to app config + app_config = AgentChatAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + conversation=conversation, + override_config_dict=override_model_config_dict, + ) - # get tracing instance - trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id) + # get tracing instance + trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id) - # init application generate entity - application_generate_entity = AgentChatAppGenerateEntity( - task_id=str(uuid.uuid4()), - app_config=app_config, - model_conf=ModelConfigConverter.convert(app_config), - file_upload_config=file_extra_config, - conversation_id=conversation.id if conversation else None, - inputs=self._prepare_user_inputs( - user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id - ), - query=query, - files=list(file_objs), - parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, - user_id=user.id, - stream=streaming, - invoke_from=invoke_from, - extras=extras, - call_depth=0, - trace_manager=trace_manager, - ) + # init application generate entity + application_generate_entity = AgentChatAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + model_conf=ModelConfigConverter.convert(app_config), + file_upload_config=file_extra_config, + conversation_id=conversation.id if conversation else None, + inputs=self._prepare_user_inputs( + user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id + ), + query=query, + files=list(file_objs), + parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, + user_id=user.id, + stream=streaming, + invoke_from=invoke_from, + extras=extras, + call_depth=0, + trace_manager=trace_manager, + ) - # init generate records - (conversation, message) = self._init_generate_records(application_generate_entity, conversation) + # init generate records + (conversation, message) = self._init_generate_records(application_generate_entity, conversation) - # init queue manager - queue_manager = MessageBasedAppQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - conversation_id=conversation.id, - app_mode=conversation.mode, - message_id=message.id, - ) + # init queue manager + queue_manager = MessageBasedAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, + message_id=message.id, + ) - # new thread with request context and contextvars - context = contextvars.copy_context() + # new thread with request context and contextvars + context = contextvars.copy_context() - worker_thread = threading.Thread( - target=self._generate_worker, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "context": context, - "application_generate_entity": application_generate_entity, - "queue_manager": queue_manager, - "conversation_id": conversation.id, - "message_id": message.id, - }, - ) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "context": context, + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "conversation_id": conversation.id, + "message_id": message.id, + }, + ) - worker_thread.start() + worker_thread.start() - # return response or stream generator - response = self._handle_response( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message, - user=user, - stream=streaming, - ) - return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + # return response or stream generator + response = self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + user=user, + stream=streaming, + ) + return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def _generate_worker( self, diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index a81da2e91c4..a20d3f3c38f 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -1,6 +1,9 @@ import logging from typing import cast +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import select from core.agent.cot_chat_agent_runner import CotChatAgentRunner @@ -15,9 +18,6 @@ from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.moderation.base import ModerationError -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from extensions.ext_database import db from models.model import App, Conversation, Message diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index a92e3dd2ea4..66390116d46 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -3,10 +3,11 @@ from abc import ABC, abstractmethod from collections.abc import Generator, Mapping from typing import Any, Union +from graphon.model_runtime.errors.invoke import InvokeError + from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 20e6ac98ea2..7eccd59d17c 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -1,27 +1,89 @@ from collections.abc import Generator, Mapping, Sequence +from contextlib import AbstractContextManager, nullcontext from typing import TYPE_CHECKING, Any, Union, final +from graphon.enums import NodeType +from graphon.file import File, FileUploadConfig +from graphon.variables.input_entities import VariableEntityType from sqlalchemy.orm import Session -from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.enums import NodeType -from dify_graph.file import File, FileUploadConfig -from dify_graph.repositories.draft_variable_repository import ( +from core.app.apps.draft_variable_saver import ( DraftVariableSaver, DraftVariableSaverFactory, NoopDraftVariableSaver, ) -from dify_graph.variables.input_entities import VariableEntityType +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.file_access import DatabaseFileAccessController, FileAccessScope, bind_file_access_scope +from extensions.ext_database import db from factories import file_factory from libs.orjson import orjson_dumps from models import Account, EndUser from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl if TYPE_CHECKING: - from dify_graph.variables.input_entities import VariableEntity + from graphon.variables.input_entities import VariableEntity + + +@final +class _DebuggerDraftVariableSaver: + """Adapter that binds SQLAlchemy session setup outside the saver port.""" + + def __init__( + self, + *, + account: Account, + app_id: str, + node_id: str, + node_type: NodeType, + node_execution_id: str, + enclosing_node_id: str | None = None, + ) -> None: + self._account = account + self._app_id = app_id + self._node_id = node_id + self._node_type = node_type + self._node_execution_id = node_execution_id + self._enclosing_node_id = enclosing_node_id + + def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None: + with Session(db.engine) as session, session.begin(): + DraftVariableSaverImpl( + session=session, + app_id=self._app_id, + node_id=self._node_id, + node_type=self._node_type, + node_execution_id=self._node_execution_id, + enclosing_node_id=self._enclosing_node_id, + user=self._account, + ).save(process_data, outputs) class BaseAppGenerator: + _file_access_controller: DatabaseFileAccessController = DatabaseFileAccessController() + + @staticmethod + def _bind_file_access_scope( + *, + tenant_id: str, + user: Account | EndUser, + invoke_from: InvokeFrom, + ) -> AbstractContextManager[None]: + """Bind request-scoped file ownership markers for downstream file lookups.""" + + user_id = getattr(user, "id", None) + if not isinstance(user_id, str) or not user_id: + return nullcontext() + + user_from = UserFrom.ACCOUNT if isinstance(user, Account) else UserFrom.END_USER + return bind_file_access_scope( + FileAccessScope( + tenant_id=tenant_id, + user_id=user_id, + user_from=user_from, + invoke_from=invoke_from, + ) + ) + def _prepare_user_inputs( self, *, @@ -50,6 +112,7 @@ class BaseAppGenerator: allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [], ), strict_type_validation=strict_type_validation, + access_controller=self._file_access_controller, ) for k, v in user_inputs.items() if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE @@ -64,6 +127,7 @@ class BaseAppGenerator: allowed_file_extensions=entity_dictionary[k].allowed_file_extensions or [], allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [], ), + access_controller=self._file_access_controller, ) for k, v in user_inputs.items() if isinstance(v, list) @@ -226,32 +290,30 @@ class BaseAppGenerator: assert isinstance(account, Account) def draft_var_saver_factory( - session: Session, app_id: str, node_id: str, node_type: NodeType, node_execution_id: str, enclosing_node_id: str | None = None, ) -> DraftVariableSaver: - return DraftVariableSaverImpl( - session=session, + return _DebuggerDraftVariableSaver( + account=account, app_id=app_id, node_id=node_id, node_type=node_type, node_execution_id=node_execution_id, enclosing_node_id=enclosing_node_id, - user=account, ) else: def draft_var_saver_factory( - session: Session, app_id: str, node_id: str, node_type: NodeType, node_execution_id: str, enclosing_node_id: str | None = None, ) -> DraftVariableSaver: + _ = app_id, node_id, node_type, node_execution_id, enclosing_node_id return NoopDraftVariableSaver() return draft_var_saver_factory diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index 5addd418158..20bf81aeecf 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -7,6 +7,7 @@ from enum import IntEnum, auto from typing import Any from cachetools import TTLCache, cachedmethod +from graphon.runtime import GraphRuntimeState from redis.exceptions import RedisError from sqlalchemy.orm import DeclarativeMeta @@ -20,7 +21,6 @@ from core.app.entities.queue_entities import ( QueueStopEvent, WorkflowQueueMessage, ) -from dify_graph.runtime import GraphRuntimeState from extensions.ext_redis import redis_client logger = logging.getLogger(__name__) @@ -61,27 +61,30 @@ class AppQueueManager(ABC): listen_timeout = dify_config.APP_MAX_EXECUTION_TIME start_time = time.time() last_ping_time: int | float = 0 - while True: - try: - message = self._q.get(timeout=1) - if message is None: - break + try: + while True: + try: + message = self._q.get(timeout=1) + if message is None: + break - yield message - except queue.Empty: - continue - finally: - elapsed_time = time.time() - start_time - if elapsed_time >= listen_timeout or self._is_stopped(): - # publish two messages to make sure the client can receive the stop signal - # and stop listening after the stop signal processed - self.publish( - QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.TASK_PIPELINE - ) + yield message + except queue.Empty: + continue + finally: + elapsed_time = time.time() - start_time + if elapsed_time >= listen_timeout or self._is_stopped(): + # publish two messages to make sure the client can receive the stop signal + # and stop listening after the stop signal processed + self.publish( + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.TASK_PIPELINE + ) - if elapsed_time // 10 > last_ping_time: - self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE) - last_ping_time = elapsed_time // 10 + if elapsed_time // 10 > last_ping_time: + self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE) + last_ping_time = elapsed_time // 10 + finally: + self._graph_runtime_state = None # Release reference once consumers finish or close the generator. def stop_listen(self): """ @@ -90,7 +93,6 @@ class AppQueueManager(ABC): """ self._clear_task_belong_cache() self._q.put(None) - self._graph_runtime_state = None # Release reference to allow GC to reclaim memory def _clear_task_belong_cache(self) -> None: """ diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 11fcbb75610..4aebc0cb30e 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -5,6 +5,17 @@ from collections.abc import Generator, Mapping, Sequence from mimetypes import guess_extension from typing import TYPE_CHECKING, Any, Union +from graphon.file import FileTransferMethod, FileType +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + TextPromptMessageContent, +) +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.errors.invoke import InvokeBadRequestError + from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import ( @@ -29,22 +40,12 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform from core.tools.tool_file_manager import ToolFileManager -from dify_graph.file.enums import FileTransferMethod, FileType -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from dify_graph.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - ImagePromptMessageContent, - PromptMessage, - TextPromptMessageContent, -) -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey -from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError from extensions.ext_database import db from models.enums import CreatorUserRole, MessageFileBelongsTo from models.model import App, AppMode, Message, MessageAnnotation, MessageFile if TYPE_CHECKING: - from dify_graph.file.models import File + from graphon.file import File _logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 91cf54c774f..b675a87382c 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -1,3 +1,4 @@ +import contextvars import logging import threading import uuid @@ -5,6 +6,7 @@ from collections.abc import Generator, Mapping from typing import Any, Literal, Union, overload from flask import Flask, copy_current_request_context, current_app +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError from configs import dify_config @@ -20,7 +22,6 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom from core.ops.ops_trace_manager import TraceQueueManager -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError from extensions.ext_database import db from factories import file_factory from models import Account @@ -120,89 +121,96 @@ class ChatAppGenerator(MessageBasedAppGenerator): # # For implementation reference, see the `_parse_file` function and # `DraftWorkflowNodeRunApi` class which handle this properly. - files = args["files"] if args.get("files") else [] - file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) - if file_extra_config: - file_objs = file_factory.build_from_mappings( - mappings=files, - tenant_id=app_model.tenant_id, - config=file_extra_config, + with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from): + files = args["files"] if args.get("files") else [] + file_extra_config = FileUploadConfigManager.convert( + override_model_config_dict or app_model_config.to_dict() ) - else: - file_objs = [] + if file_extra_config: + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + config=file_extra_config, + access_controller=self._file_access_controller, + ) + else: + file_objs = [] - # convert to app config - app_config = ChatAppConfigManager.get_app_config( - app_model=app_model, - app_model_config=app_model_config, - conversation=conversation, - override_config_dict=override_model_config_dict, - ) + # convert to app config + app_config = ChatAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + conversation=conversation, + override_config_dict=override_model_config_dict, + ) - # get tracing instance - trace_manager = TraceQueueManager( - app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id - ) + # get tracing instance + trace_manager = TraceQueueManager( + app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id + ) - # init application generate entity - application_generate_entity = ChatAppGenerateEntity( - task_id=str(uuid.uuid4()), - app_config=app_config, - model_conf=ModelConfigConverter.convert(app_config), - file_upload_config=file_extra_config, - conversation_id=conversation.id if conversation else None, - inputs=self._prepare_user_inputs( - user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id - ), - query=query, - files=list(file_objs), - parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, - user_id=user.id, - invoke_from=invoke_from, - extras=extras, - trace_manager=trace_manager, - stream=streaming, - ) + # init application generate entity + application_generate_entity = ChatAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + model_conf=ModelConfigConverter.convert(app_config), + file_upload_config=file_extra_config, + conversation_id=conversation.id if conversation else None, + inputs=self._prepare_user_inputs( + user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id + ), + query=query, + files=list(file_objs), + parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, + user_id=user.id, + invoke_from=invoke_from, + extras=extras, + trace_manager=trace_manager, + stream=streaming, + ) - # init generate records - (conversation, message) = self._init_generate_records(application_generate_entity, conversation) + # init generate records + (conversation, message) = self._init_generate_records(application_generate_entity, conversation) - # init queue manager - queue_manager = MessageBasedAppQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - conversation_id=conversation.id, - app_mode=conversation.mode, - message_id=message.id, - ) - - # new thread with request context - @copy_current_request_context - def worker_with_context(): - return self._generate_worker( - flask_app=current_app._get_current_object(), # type: ignore - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, + # init queue manager + queue_manager = MessageBasedAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, conversation_id=conversation.id, + app_mode=conversation.mode, message_id=message.id, ) - worker_thread = threading.Thread(target=worker_with_context) + context = contextvars.copy_context() - worker_thread.start() + # new thread with request context + @copy_current_request_context + def worker_with_context(): + return context.run( + self._generate_worker, + flask_app=current_app._get_current_object(), # type: ignore + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation_id=conversation.id, + message_id=message.id, + ) - # return response or stream generator - response = self._handle_response( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message, - user=user, - stream=streaming, - ) + worker_thread = threading.Thread(target=worker_with_context) - return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + worker_thread.start() + + # return response or stream generator + response = self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + user=user, + stream=streaming, + ) + + return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def _generate_worker( self, diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index f63b38fc866..050f763e958 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -1,6 +1,8 @@ import logging from typing import cast +from graphon.file import File +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from sqlalchemy import select from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -15,8 +17,6 @@ from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.moderation.base import ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from dify_graph.file import File -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent from extensions.ext_database import db from models.model import App, Conversation, Message @@ -223,7 +223,6 @@ class ChatAppRunner(AppRunner): model_parameters=application_generate_entity.model_conf.parameters, stop=stop, stream=application_generate_entity.stream, - user=application_generate_entity.user_id, ) # handle invoke result diff --git a/api/core/app/apps/common/graph_runtime_state_support.py b/api/core/app/apps/common/graph_runtime_state_support.py index 6a8e4361635..ab277857fe5 100644 --- a/api/core/app/apps/common/graph_runtime_state_support.py +++ b/api/core/app/apps/common/graph_runtime_state_support.py @@ -4,7 +4,9 @@ from __future__ import annotations from typing import TYPE_CHECKING -from dify_graph.runtime import GraphRuntimeState +from graphon.runtime import GraphRuntimeState + +from core.workflow.system_variables import SystemVariableKey, get_system_text if TYPE_CHECKING: from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline @@ -30,10 +32,10 @@ class GraphRuntimeStateSupport: return self._resolve_graph_runtime_state(graph_runtime_state) def _extract_workflow_run_id(self, graph_runtime_state: GraphRuntimeState) -> str: - system_variables = graph_runtime_state.variable_pool.system_variables - if not system_variables or not system_variables.workflow_execution_id: + workflow_run_id = get_system_text(graph_runtime_state.variable_pool, SystemVariableKey.WORKFLOW_EXECUTION_ID) + if not workflow_run_id: raise ValueError("workflow_execution_id missing from runtime state") - return str(system_variables.workflow_execution_id) + return workflow_run_id def _resolve_graph_runtime_state( self, diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 621b0d8cf35..a5155316163 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -1,3 +1,4 @@ +import json import logging import time from collections.abc import Mapping, Sequence @@ -5,6 +6,19 @@ from dataclasses import dataclass from datetime import datetime from typing import Any, NewType, TypedDict, Union +from graphon.entities import WorkflowStartReason +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import ( + BuiltinNodeTypes, + WorkflowExecutionStatus, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.file import FILE_MODEL_IDENTITY, File +from graphon.runtime import GraphRuntimeState +from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment +from graphon.variables.variables import Variable +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import select from sqlalchemy.orm import Session @@ -50,21 +64,9 @@ from core.tools.entities.tool_entities import ToolProviderType from core.tools.tool_manager import ToolManager from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.trigger.trigger_manager import TriggerManager +from core.workflow.human_input_forms import load_form_tokens_by_form_id +from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping from core.workflow.workflow_entry import WorkflowEntry -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.enums import ( - BuiltinNodeTypes, - SystemVariableKey, - WorkflowExecutionStatus, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from dify_graph.file import FILE_MODEL_IDENTITY, File -from dify_graph.runtime import GraphRuntimeState -from dify_graph.system_variable import SystemVariable -from dify_graph.variables.segments import ArrayFileSegment, FileSegment, Segment -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models import Account, EndUser @@ -111,11 +113,11 @@ class WorkflowResponseConverter: *, application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity], user: Union[Account, EndUser], - system_variables: SystemVariable, + system_variables: Sequence[Variable], ): self._application_generate_entity = application_generate_entity self._user = user - self._system_variables = system_variables + self._system_variables = system_variables_to_mapping(system_variables) self._workflow_inputs = self._prepare_workflow_inputs() # Disable truncation for SERVICE_API calls to keep backward compatibility. @@ -133,7 +135,7 @@ class WorkflowResponseConverter: # ------------------------------------------------------------------ def _prepare_workflow_inputs(self) -> Mapping[str, Any]: inputs = dict(self._application_generate_entity.inputs) - for field_name, value in self._system_variables.to_dict().items(): + for field_name, value in self._system_variables.items(): # TODO(@future-refactor): store system variables separately from user inputs so we don't # need to flatten `sys.*` entries into the input payload just for rerun/export tooling. if field_name == SystemVariableKey.CONVERSATION_ID: @@ -318,13 +320,23 @@ class WorkflowResponseConverter: pause_reasons = [reason.model_dump(mode="json") for reason in event.reasons] human_input_form_ids = [reason.form_id for reason in event.reasons if isinstance(reason, HumanInputRequired)] expiration_times_by_form_id: dict[str, datetime] = {} + display_in_ui_by_form_id: dict[str, bool] = {} + form_token_by_form_id: dict[str, str] = {} if human_input_form_ids: - stmt = select(HumanInputForm.id, HumanInputForm.expiration_time).where( - HumanInputForm.id.in_(human_input_form_ids) - ) + stmt = select( + HumanInputForm.id, + HumanInputForm.expiration_time, + HumanInputForm.form_definition, + ).where(HumanInputForm.id.in_(human_input_form_ids)) with Session(bind=db.engine) as session: - for form_id, expiration_time in session.execute(stmt): + for form_id, expiration_time, form_definition in session.execute(stmt): expiration_times_by_form_id[str(form_id)] = expiration_time + try: + definition_payload = json.loads(form_definition) if form_definition else {} + except (TypeError, json.JSONDecodeError): + definition_payload = {} + display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui")) + form_token_by_form_id = load_form_tokens_by_form_id(human_input_form_ids, session=session) responses: list[StreamResponse] = [] @@ -344,8 +356,8 @@ class WorkflowResponseConverter: form_content=reason.form_content, inputs=reason.inputs, actions=reason.actions, - display_in_ui=reason.display_in_ui, - form_token=reason.form_token, + display_in_ui=display_in_ui_by_form_id.get(reason.form_id, False), + form_token=form_token_by_form_id.get(reason.form_id), resolved_default_values=reason.resolved_default_values, expiration_time=int(expiration_time.timestamp()), ), diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 002b914ef1f..a62c5b80b51 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -1,3 +1,4 @@ +import contextvars import logging import threading import uuid @@ -5,6 +6,7 @@ from collections.abc import Generator, Mapping from typing import Any, Literal, Union, overload from flask import Flask, copy_current_request_context, current_app +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError from sqlalchemy import select @@ -20,7 +22,6 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom from core.ops.ops_trace_manager import TraceQueueManager -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError from extensions.ext_database import db from factories import file_factory from models import Account, App, EndUser, Message @@ -108,83 +109,90 @@ class CompletionAppGenerator(MessageBasedAppGenerator): # # For implementation reference, see the `_parse_file` function and # `DraftWorkflowNodeRunApi` class which handle this properly. - files = args["files"] if args.get("files") else [] - file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) - if file_extra_config: - file_objs = file_factory.build_from_mappings( - mappings=files, - tenant_id=app_model.tenant_id, - config=file_extra_config, + with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from): + files = args["files"] if args.get("files") else [] + file_extra_config = FileUploadConfigManager.convert( + override_model_config_dict or app_model_config.to_dict() ) - else: - file_objs = [] + if file_extra_config: + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + config=file_extra_config, + access_controller=self._file_access_controller, + ) + else: + file_objs = [] - # convert to app config - app_config = CompletionAppConfigManager.get_app_config( - app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict - ) + # convert to app config + app_config = CompletionAppConfigManager.get_app_config( + app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict + ) - # get tracing instance - trace_manager = TraceQueueManager( - app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id - ) + # get tracing instance + trace_manager = TraceQueueManager( + app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id + ) - # init application generate entity - application_generate_entity = CompletionAppGenerateEntity( - task_id=str(uuid.uuid4()), - app_config=app_config, - model_conf=ModelConfigConverter.convert(app_config), - file_upload_config=file_extra_config, - inputs=self._prepare_user_inputs( - user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id - ), - query=query, - files=list(file_objs), - user_id=user.id, - stream=streaming, - invoke_from=invoke_from, - extras={}, - trace_manager=trace_manager, - ) + # init application generate entity + application_generate_entity = CompletionAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + model_conf=ModelConfigConverter.convert(app_config), + file_upload_config=file_extra_config, + inputs=self._prepare_user_inputs( + user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id + ), + query=query, + files=list(file_objs), + user_id=user.id, + stream=streaming, + invoke_from=invoke_from, + extras={}, + trace_manager=trace_manager, + ) - # init generate records - (conversation, message) = self._init_generate_records(application_generate_entity) + # init generate records + (conversation, message) = self._init_generate_records(application_generate_entity) - # init queue manager - queue_manager = MessageBasedAppQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - conversation_id=conversation.id, - app_mode=conversation.mode, - message_id=message.id, - ) - - # new thread with request context - @copy_current_request_context - def worker_with_context(): - return self._generate_worker( - flask_app=current_app._get_current_object(), # type: ignore - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, + # init queue manager + queue_manager = MessageBasedAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, message_id=message.id, ) - worker_thread = threading.Thread(target=worker_with_context) + context = contextvars.copy_context() - worker_thread.start() + # new thread with request context + @copy_current_request_context + def worker_with_context(): + return context.run( + self._generate_worker, + flask_app=current_app._get_current_object(), # type: ignore + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + message_id=message.id, + ) - # return response or stream generator - response = self._handle_response( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message, - user=user, - stream=streaming, - ) + worker_thread = threading.Thread(target=worker_with_context) - return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + worker_thread.start() + + # return response or stream generator + response = self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + user=user, + stream=streaming, + ) + + return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def _generate_worker( self, @@ -280,71 +288,76 @@ class CompletionAppGenerator(MessageBasedAppGenerator): model_dict["completion_params"] = completion_params override_model_config_dict["model"] = model_dict - # parse files - file_extra_config = FileUploadConfigManager.convert(override_model_config_dict) - if file_extra_config: - file_objs = file_factory.build_from_mappings( - mappings=message.message_files, - tenant_id=app_model.tenant_id, - config=file_extra_config, + with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from): + # parse files + file_extra_config = FileUploadConfigManager.convert(override_model_config_dict) + if file_extra_config: + file_objs = file_factory.build_from_mappings( + mappings=message.message_files, + tenant_id=app_model.tenant_id, + config=file_extra_config, + access_controller=self._file_access_controller, + ) + else: + file_objs = [] + + # convert to app config + app_config = CompletionAppConfigManager.get_app_config( + app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict ) - else: - file_objs = [] - # convert to app config - app_config = CompletionAppConfigManager.get_app_config( - app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict - ) + # init application generate entity + application_generate_entity = CompletionAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + model_conf=ModelConfigConverter.convert(app_config), + inputs=message.inputs, + query=message.query, + files=list(file_objs), + user_id=user.id, + stream=stream, + invoke_from=invoke_from, + extras={}, + ) - # init application generate entity - application_generate_entity = CompletionAppGenerateEntity( - task_id=str(uuid.uuid4()), - app_config=app_config, - model_conf=ModelConfigConverter.convert(app_config), - inputs=message.inputs, - query=message.query, - files=list(file_objs), - user_id=user.id, - stream=stream, - invoke_from=invoke_from, - extras={}, - ) + # init generate records + (conversation, message) = self._init_generate_records(application_generate_entity) - # init generate records - (conversation, message) = self._init_generate_records(application_generate_entity) - - # init queue manager - queue_manager = MessageBasedAppQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - conversation_id=conversation.id, - app_mode=conversation.mode, - message_id=message.id, - ) - - # new thread with request context - @copy_current_request_context - def worker_with_context(): - return self._generate_worker( - flask_app=current_app._get_current_object(), # type: ignore - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, + # init queue manager + queue_manager = MessageBasedAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, message_id=message.id, ) - worker_thread = threading.Thread(target=worker_with_context) + context = contextvars.copy_context() - worker_thread.start() + # new thread with request context + @copy_current_request_context + def worker_with_context(): + return context.run( + self._generate_worker, + flask_app=current_app._get_current_object(), # type: ignore + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + message_id=message.id, + ) - # return response or stream generator - response = self._handle_response( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message, - user=user, - stream=stream, - ) + worker_thread = threading.Thread(target=worker_with_context) - return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + worker_thread.start() + + # return response or stream generator + response = self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + user=user, + stream=stream, + ) + + return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index 56a45198792..b216f7cf7b1 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -1,6 +1,8 @@ import logging from typing import cast +from graphon.file import File +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from sqlalchemy import select from core.app.apps.base_app_queue_manager import AppQueueManager @@ -13,8 +15,6 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa from core.model_manager import ModelInstance from core.moderation.base import ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from dify_graph.file import File -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent from extensions.ext_database import db from models.model import App, Message @@ -181,7 +181,6 @@ class CompletionAppRunner(AppRunner): model_parameters=application_generate_entity.model_conf.parameters, stop=stop, stream=application_generate_entity.stream, - user=application_generate_entity.user_id, ) # handle invoke result diff --git a/api/dify_graph/repositories/draft_variable_repository.py b/api/core/app/apps/draft_variable_saver.py similarity index 65% rename from api/dify_graph/repositories/draft_variable_repository.py rename to api/core/app/apps/draft_variable_saver.py index b2ebfacffd1..24018012c5f 100644 --- a/api/dify_graph/repositories/draft_variable_repository.py +++ b/api/core/app/apps/draft_variable_saver.py @@ -4,31 +4,30 @@ import abc from collections.abc import Mapping from typing import Any, Protocol -from sqlalchemy.orm import Session - -from dify_graph.enums import NodeType +from graphon.enums import NodeType class DraftVariableSaver(Protocol): @abc.abstractmethod - def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None): - pass + def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None: + """Persist node draft variables for a completed execution.""" + raise NotImplementedError class DraftVariableSaverFactory(Protocol): @abc.abstractmethod def __call__( self, - session: Session, app_id: str, node_id: str, node_type: NodeType, node_execution_id: str, enclosing_node_id: str | None = None, ) -> DraftVariableSaver: - pass + """Build a saver bound to a concrete node execution.""" + raise NotImplementedError class NoopDraftVariableSaver(DraftVariableSaver): - def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None): - pass + def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None: + return None diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 64c28ca60f0..fe61224ada5 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -28,12 +28,13 @@ from core.app.entities.task_entities import ( ) from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from core.workflow.file_reference import resolve_file_record_id from extensions.ext_database import db from extensions.ext_redis import get_pubsub_broadcast_channel from libs.broadcast_channel.channel import Topic from libs.datetime_utils import naive_utc_now from models import Account -from models.enums import CreatorUserRole, MessageFileBelongsTo +from models.enums import ConversationFromSource, CreatorUserRole, MessageFileBelongsTo from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.conversation import ConversationNotExistsError @@ -130,10 +131,10 @@ class MessageBasedAppGenerator(BaseAppGenerator): end_user_id = None account_id = None if application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: - from_source = "api" + from_source = ConversationFromSource.API end_user_id = application_generate_entity.user_id else: - from_source = "console" + from_source = ConversationFromSource.CONSOLE account_id = application_generate_entity.user_id if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity): @@ -227,7 +228,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): transfer_method=file.transfer_method, belongs_to=MessageFileBelongsTo.USER, url=file.remote_url, - upload_file_id=file.related_id, + upload_file_id=resolve_file_record_id(file.reference), created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER), created_by=account_id or end_user_id or "", ) diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 19d67eb108e..fa242003a25 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -10,6 +10,8 @@ from collections.abc import Generator, Mapping from typing import Any, Literal, Union, cast, overload from flask import Flask, current_app +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from pydantic import ValidationError from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -18,6 +20,7 @@ import contexts from configs import dify_config from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.draft_variable_saver import DraftVariableSaverFactory from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager @@ -34,12 +37,11 @@ from core.datasource.entities.datasource_entities import ( from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin from core.entities.knowledge_entities import PipelineDataset, PipelineDocument from core.rag.index_processor.constant.built_in_field import BuiltInField -from core.repositories.factory import DifyCoreRepositoryFactory -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError -from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader +from core.repositories.factory import ( + DifyCoreRepositoryFactory, + WorkflowExecutionRepository, + WorkflowNodeExecutionRepository, +) from extensions.ext_database import db from libs.flask_utils import preserve_flask_contexts from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index e767766bdb9..4c188dac68d 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -2,6 +2,14 @@ import logging import time from typing import cast +from graphon.entities import GraphInitParams +from graphon.enums import WorkflowType +from graphon.graph import Graph +from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import VariableLoader +from graphon.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput + from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner @@ -12,18 +20,11 @@ from core.app.entities.app_invoke_entities import ( build_dify_run_context, ) from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry -from dify_graph.entities.graph_init_params import GraphInitParams -from dify_graph.enums import WorkflowType -from dify_graph.graph import Graph -from dify_graph.graph_events import GraphEngineEvent, GraphRunFailedEvent -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variable_loader import VariableLoader -from dify_graph.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput from extensions.ext_database import db from models.dataset import Document, Pipeline from models.model import EndUser @@ -106,13 +107,14 @@ class PipelineRunner(WorkflowBasedAppRunner): workflow=workflow, single_iteration_run=self.application_generate_entity.single_iteration_run, single_loop_run=self.application_generate_entity.single_loop_run, + user_id=self.application_generate_entity.user_id, ) else: inputs = self.application_generate_entity.inputs files = self.application_generate_entity.files # Create a variable pool. - system_inputs = SystemVariable( + system_inputs = build_system_variables( files=files, user_id=user_id, app_id=app_config.app_id, @@ -142,19 +144,25 @@ class PipelineRunner(WorkflowBasedAppRunner): ) ) - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=inputs, - environment_variables=workflow.environment_variables, - conversation_variables=[], - rag_pipeline_variables=rag_pipeline_variables, + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_inputs, + environment_variables=workflow.environment_variables, + rag_pipeline_variables=rag_pipeline_variables, + ), ) + root_node_id = self.application_generate_entity.start_node_id or get_default_root_node_id( + workflow.graph_dict + ) + add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=inputs) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) # init graph graph = self._init_rag_pipeline_graph( graph_runtime_state=graph_runtime_state, - start_node_id=self.application_generate_entity.start_node_id, + start_node_id=root_node_id, workflow=workflow, user_from=user_from, invoke_from=invoke_from, diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 6fbe19a3b2e..9618ab35c62 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -8,6 +8,10 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal, Union, overload from flask import Flask, current_app +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.runtime import GraphRuntimeState +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from pydantic import ValidationError from sqlalchemy import select from sqlalchemy.orm import sessionmaker @@ -17,6 +21,7 @@ from configs import dify_config from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.draft_variable_saver import DraftVariableSaverFactory from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager @@ -30,13 +35,7 @@ from core.db.session_factory import session_factory from core.helper.trace_id_helper import extract_external_trace_id_from_args from core.ops.ops_trace_manager import TraceQueueManager from core.repositories import DifyCoreRepositoryFactory -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError -from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from dify_graph.runtime import GraphRuntimeState -from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository from extensions.ext_database import db from factories import file_factory from libs.flask_utils import preserve_flask_contexts @@ -129,107 +128,109 @@ class WorkflowAppGenerator(BaseAppGenerator): graph_engine_layers: Sequence[GraphEngineLayer] = (), pause_state_config: PauseStateLayerConfig | None = None, ) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: - files: Sequence[Mapping[str, Any]] = args.get("files") or [] + with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from): + files: Sequence[Mapping[str, Any]] = args.get("files") or [] - # parse files - # TODO(QuantumGhost): Move file parsing logic to the API controller layer - # for better separation of concerns. - # - # For implementation reference, see the `_parse_file` function and - # `DraftWorkflowNodeRunApi` class which handle this properly. - file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) - system_files = file_factory.build_from_mappings( - mappings=files, - tenant_id=app_model.tenant_id, - config=file_extra_config, - strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False, - ) - - # convert to app config - app_config = WorkflowAppConfigManager.get_app_config( - app_model=app_model, - workflow=workflow, - ) - - # get tracing instance - trace_manager = TraceQueueManager( - app_id=app_model.id, - user_id=user.id if isinstance(user, Account) else user.session_id, - ) - - inputs: Mapping[str, Any] = args["inputs"] - - extras = { - **extract_external_trace_id_from_args(args), - } - workflow_run_id = str(workflow_run_id or uuid.uuid4()) - # FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args - # trigger shouldn't prepare user inputs - if self._should_prepare_user_inputs(args): - inputs = self._prepare_user_inputs( - user_inputs=inputs, - variables=app_config.variables, + # parse files + # TODO(QuantumGhost): Move file parsing logic to the API controller layer + # for better separation of concerns. + # + # For implementation reference, see the `_parse_file` function and + # `DraftWorkflowNodeRunApi` class which handle this properly. + file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) + system_files = file_factory.build_from_mappings( + mappings=files, tenant_id=app_model.tenant_id, + config=file_extra_config, strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False, + access_controller=self._file_access_controller, ) - # init application generate entity - application_generate_entity = WorkflowAppGenerateEntity( - task_id=str(uuid.uuid4()), - app_config=app_config, - file_upload_config=file_extra_config, - inputs=inputs, - files=list(system_files), - user_id=user.id, - stream=streaming, - invoke_from=invoke_from, - call_depth=call_depth, - trace_manager=trace_manager, - workflow_execution_id=workflow_run_id, - extras=extras, - ) - contexts.plugin_tool_providers.set({}) - contexts.plugin_tool_providers_lock.set(threading.Lock()) + # convert to app config + app_config = WorkflowAppConfigManager.get_app_config( + app_model=app_model, + workflow=workflow, + ) - # Create repositories - # - # Create session factory - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - # Create workflow execution(aka workflow run) repository - if triggered_from is not None: - # Use explicitly provided triggered_from (for async triggers) - workflow_triggered_from = triggered_from - elif invoke_from == InvokeFrom.DEBUGGER: - workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING - else: - workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN - workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( - session_factory=session_factory, - user=user, - app_id=application_generate_entity.app_config.app_id, - triggered_from=workflow_triggered_from, - ) - # Create workflow node execution repository - workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( - session_factory=session_factory, - user=user, - app_id=application_generate_entity.app_config.app_id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, - ) + # get tracing instance + trace_manager = TraceQueueManager( + app_id=app_model.id, + user_id=user.id if isinstance(user, Account) else user.session_id, + ) - return self._generate( - app_model=app_model, - workflow=workflow, - user=user, - application_generate_entity=application_generate_entity, - invoke_from=invoke_from, - workflow_execution_repository=workflow_execution_repository, - workflow_node_execution_repository=workflow_node_execution_repository, - streaming=streaming, - root_node_id=root_node_id, - graph_engine_layers=graph_engine_layers, - pause_state_config=pause_state_config, - ) + inputs: Mapping[str, Any] = args["inputs"] + + extras = { + **extract_external_trace_id_from_args(args), + } + workflow_run_id = str(workflow_run_id or uuid.uuid4()) + # FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args + # trigger shouldn't prepare user inputs + if self._should_prepare_user_inputs(args): + inputs = self._prepare_user_inputs( + user_inputs=inputs, + variables=app_config.variables, + tenant_id=app_model.tenant_id, + strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False, + ) + # init application generate entity + application_generate_entity = WorkflowAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + file_upload_config=file_extra_config, + inputs=inputs, + files=list(system_files), + user_id=user.id, + stream=streaming, + invoke_from=invoke_from, + call_depth=call_depth, + trace_manager=trace_manager, + workflow_execution_id=workflow_run_id, + extras=extras, + ) + + contexts.plugin_tool_providers.set({}) + contexts.plugin_tool_providers_lock.set(threading.Lock()) + + # Create repositories + # + # Create session factory + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + # Create workflow execution(aka workflow run) repository + if triggered_from is not None: + # Use explicitly provided triggered_from (for async triggers) + workflow_triggered_from = triggered_from + elif invoke_from == InvokeFrom.DEBUGGER: + workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING + else: + workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=workflow_triggered_from, + ) + # Create workflow node execution repository + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + return self._generate( + app_model=app_model, + workflow=workflow, + user=user, + application_generate_entity=application_generate_entity, + invoke_from=invoke_from, + workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, + streaming=streaming, + root_node_id=root_node_id, + graph_engine_layers=graph_engine_layers, + pause_state_config=pause_state_config, + ) def resume( self, @@ -292,62 +293,67 @@ class WorkflowAppGenerator(BaseAppGenerator): :param workflow_node_execution_repository: repository for workflow node execution :param streaming: is stream """ - graph_layers: list[GraphEngineLayer] = list(graph_engine_layers) + with self._bind_file_access_scope( + tenant_id=application_generate_entity.app_config.tenant_id, + user=user, + invoke_from=invoke_from, + ): + graph_layers: list[GraphEngineLayer] = list(graph_engine_layers) - # init queue manager - queue_manager = WorkflowAppQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - app_mode=app_model.mode, - ) - - if pause_state_config is not None: - graph_layers.append( - PauseStatePersistenceLayer( - session_factory=pause_state_config.session_factory, - generate_entity=application_generate_entity, - state_owner_user_id=pause_state_config.state_owner_user_id, - ) + # init queue manager + queue_manager = WorkflowAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + app_mode=app_model.mode, ) - # new thread with request context and contextvars - context = contextvars.copy_context() + if pause_state_config is not None: + graph_layers.append( + PauseStatePersistenceLayer( + session_factory=pause_state_config.session_factory, + generate_entity=application_generate_entity, + state_owner_user_id=pause_state_config.state_owner_user_id, + ) + ) - # release database connection, because the following new thread operations may take a long time - db.session.close() + # new thread with request context and contextvars + context = contextvars.copy_context() - worker_thread = threading.Thread( - target=self._generate_worker, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "application_generate_entity": application_generate_entity, - "queue_manager": queue_manager, - "context": context, - "variable_loader": variable_loader, - "root_node_id": root_node_id, - "workflow_execution_repository": workflow_execution_repository, - "workflow_node_execution_repository": workflow_node_execution_repository, - "graph_engine_layers": tuple(graph_layers), - "graph_runtime_state": graph_runtime_state, - }, - ) + # release database connection, because the following new thread operations may take a long time + db.session.close() - worker_thread.start() + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "context": context, + "variable_loader": variable_loader, + "root_node_id": root_node_id, + "workflow_execution_repository": workflow_execution_repository, + "workflow_node_execution_repository": workflow_node_execution_repository, + "graph_engine_layers": tuple(graph_layers), + "graph_runtime_state": graph_runtime_state, + }, + ) - draft_var_saver_factory = self._get_draft_var_saver_factory(invoke_from, user) + worker_thread.start() - # return response or stream generator - response = self._handle_response( - application_generate_entity=application_generate_entity, - workflow=workflow, - queue_manager=queue_manager, - user=user, - draft_var_saver_factory=draft_var_saver_factory, - stream=streaming, - ) + draft_var_saver_factory = self._get_draft_var_saver_factory(invoke_from, user) - return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + # return response or stream generator + response = self._handle_response( + application_generate_entity=application_generate_entity, + workflow=workflow, + queue_manager=queue_manager, + user=user, + draft_var_saver_factory=draft_var_saver_factory, + stream=streaming, + ) + + return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def single_iteration_generate( self, diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index caea8b6b952..2cb8088971a 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -3,20 +3,22 @@ import time from collections.abc import Sequence from typing import cast +from graphon.enums import WorkflowType +from graphon.graph_engine.command_channels import RedisChannel +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import VariableLoader + from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfig from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository +from core.workflow.node_factory import get_default_root_node_id +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry -from dify_graph.enums import WorkflowType -from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variable_loader import VariableLoader from extensions.ext_redis import redis_client from extensions.otel import WorkflowAppRunnerHandler, trace_span from libs.datetime_utils import naive_utc_now @@ -91,12 +93,13 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): workflow=self._workflow, single_iteration_run=self.application_generate_entity.single_iteration_run, single_loop_run=self.application_generate_entity.single_loop_run, + user_id=self.application_generate_entity.user_id, ) else: inputs = self.application_generate_entity.inputs # Create a variable pool. - system_inputs = SystemVariable( + system_inputs = build_system_variables( files=self.application_generate_entity.files, user_id=self._sys_user_id, app_id=app_config.app_id, @@ -104,12 +107,16 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): workflow_id=app_config.workflow_id, workflow_execution_id=self.application_generate_entity.workflow_execution_id, ) - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=inputs, - environment_variables=self._workflow.environment_variables, - conversation_variables=[], + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_inputs, + environment_variables=self._workflow.environment_variables, + ), ) + root_node_id = self._root_node_id or get_default_root_node_id(self._workflow.graph_dict) + add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=inputs) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) graph = self._init_graph( @@ -120,7 +127,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): user_id=self.application_generate_entity.user_id, user_from=user_from, invoke_from=invoke_from, - root_node_id=self._root_node_id, + root_node_id=root_node_id, ) # RUN WORKFLOW diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 96dd8c5445f..49af169e88a 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -4,12 +4,16 @@ from collections.abc import Callable, Generator from contextlib import contextmanager from typing import Union +from graphon.entities import WorkflowStartReason +from graphon.enums import WorkflowExecutionStatus +from graphon.runtime import GraphRuntimeState from sqlalchemy.orm import Session from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter +from core.app.apps.draft_variable_saver import DraftVariableSaverFactory from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import ( AppQueueEvent, @@ -55,11 +59,7 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.ops.ops_trace_manager import TraceQueueManager -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory -from dify_graph.runtime import GraphRuntimeState -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables from extensions.ext_database import db from models import Account from models.enums import CreatorUserRole @@ -104,7 +104,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): self._invoke_from = queue_manager.invoke_from self._draft_var_saver_factory = draft_var_saver_factory self._workflow = workflow - self._workflow_system_variables = SystemVariable( + self._workflow_system_variables = build_system_variables( files=application_generate_entity.files, user_id=user_session_id, app_id=application_generate_entity.app_config.app_id, @@ -705,7 +705,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): app_id=self._application_generate_entity.app_config.app_id, workflow_id=self._workflow.id, workflow_run_id=workflow_run_id, - created_from=created_from.value, + created_from=created_from, created_by_role=self._created_by_role, created_by=self._user_id, ) @@ -728,13 +728,11 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): return response def _save_output_for_event(self, event: QueueNodeSucceededEvent | QueueNodeExceptionEvent, node_execution_id: str): - with Session(db.engine) as session, session.begin(): - saver = self._draft_var_saver_factory( - session=session, - app_id=self._application_generate_entity.app_config.app_id, - node_id=event.node_id, - node_type=event.node_type, - node_execution_id=node_execution_id, - enclosing_node_id=event.in_loop_id or event.in_iteration_id, - ) - saver.save(event.process_data, event.outputs) + saver = self._draft_var_saver_factory( + app_id=self._application_generate_entity.app_config.app_id, + node_id=event.node_id, + node_type=event.node_type, + node_execution_id=node_execution_id, + enclosing_node_id=event.in_loop_id or event.in_iteration_id, + ) + saver.save(event.process_data, event.outputs) diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index adc6cce9afd..f68c8e60b4f 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -3,6 +3,40 @@ import time from collections.abc import Mapping, Sequence from typing import Any, cast +from graphon.entities import GraphInitParams +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.entities.pause_reason import HumanInputRequired +from graphon.graph import Graph +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import ( + GraphEngineEvent, + GraphRunAbortedEvent, + GraphRunFailedEvent, + GraphRunPartialSucceededEvent, + GraphRunPausedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunAgentLogEvent, + NodeRunExceptionEvent, + NodeRunFailedEvent, + NodeRunHumanInputFormFilledEvent, + NodeRunHumanInputFormTimeoutEvent, + NodeRunIterationFailedEvent, + NodeRunIterationNextEvent, + NodeRunIterationStartedEvent, + NodeRunIterationSucceededEvent, + NodeRunLoopFailedEvent, + NodeRunLoopNextEvent, + NodeRunLoopStartedEvent, + NodeRunLoopSucceededEvent, + NodeRunRetrieverResourceEvent, + NodeRunRetryEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, +) +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from pydantic import ValidationError from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -34,42 +68,16 @@ from core.app.entities.queue_entities import ( ) from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class -from core.workflow.workflow_entry import WorkflowEntry -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_config import NodeConfigDictAdapter -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.graph import Graph -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events import ( - GraphEngineEvent, - GraphRunFailedEvent, - GraphRunPartialSucceededEvent, - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunAgentLogEvent, - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunHumanInputFormFilledEvent, - NodeRunHumanInputFormTimeoutEvent, - NodeRunIterationFailedEvent, - NodeRunIterationNextEvent, - NodeRunIterationStartedEvent, - NodeRunIterationSucceededEvent, - NodeRunLoopFailedEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, - NodeRunRetrieverResourceEvent, - NodeRunRetryEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, +from core.workflow.system_variables import ( + build_bootstrap_variables, + default_system_variables, + get_node_creation_preload_selectors, + inject_default_system_variable_mappings, + preload_node_creation_variables, ) -from dify_graph.graph_events.graph import GraphRunAbortedEvent -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool +from core.workflow.variable_pool_initializer import add_variables_to_pool +from core.workflow.workflow_entry import WorkflowEntry +from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run from models.workflow import Workflow from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task @@ -156,6 +164,8 @@ class WorkflowBasedAppRunner: workflow: Workflow, single_iteration_run: Any | None = None, single_loop_run: Any | None = None, + *, + user_id: str, ) -> tuple[Graph, VariablePool, GraphRuntimeState]: """ Prepare graph, variable pool, and runtime state for single node execution @@ -173,14 +183,15 @@ class WorkflowBasedAppRunner: ValueError: If neither single_iteration_run nor single_loop_run is specified """ # Create initial runtime state with variable pool containing environment variables - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool( - system_variables=SystemVariable.default(), - user_inputs={}, + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=default_system_variables(), environment_variables=workflow.environment_variables, ), - start_at=time.time(), ) + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.time()) # Determine which type of single node execution and get graph/variable_pool if single_iteration_run: @@ -191,6 +202,7 @@ class WorkflowBasedAppRunner: graph_runtime_state=graph_runtime_state, node_type_filter_key="iteration_id", node_type_label="iteration", + user_id=user_id, ) elif single_loop_run: graph, variable_pool = self._get_graph_and_variable_pool_for_single_node_run( @@ -200,6 +212,7 @@ class WorkflowBasedAppRunner: graph_runtime_state=graph_runtime_state, node_type_filter_key="loop_id", node_type_label="loop", + user_id=user_id, ) else: raise ValueError("Neither single_iteration_run nor single_loop_run is specified") @@ -216,6 +229,8 @@ class WorkflowBasedAppRunner: graph_runtime_state: GraphRuntimeState, node_type_filter_key: str, # 'iteration_id' or 'loop_id' node_type_label: str = "node", # 'iteration' or 'loop' for error messages + *, + user_id: str = "", ) -> tuple[Graph, VariablePool]: """ Get graph and variable pool for single node execution (iteration or loop). @@ -272,6 +287,8 @@ class WorkflowBasedAppRunner: graph_config["edges"] = edge_configs + typed_node_configs = [NodeConfigDictAdapter.validate_python(node) for node in node_configs] + # Create required parameters for Graph.init graph_init_params = GraphInitParams( workflow_id=workflow.id, @@ -279,7 +296,7 @@ class WorkflowBasedAppRunner: run_context=build_dify_run_context( tenant_id=workflow.tenant_id, app_id=self._app_id, - user_id="", + user_id=user_id, user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, ), @@ -291,26 +308,15 @@ class WorkflowBasedAppRunner: graph_runtime_state=graph_runtime_state, ) - # init graph - graph = Graph.init( - graph_config=graph_config, node_factory=node_factory, root_node_id=node_id, skip_validation=True - ) - - if not graph: - raise ValueError("graph not found in workflow") - - # fetch node config from node id target_node_config = None - for node in node_configs: - if node.get("id") == node_id: + for node in typed_node_configs: + if node["id"] == node_id: target_node_config = node break if not target_node_config: raise ValueError(f"{node_type_label} node id not found in workflow graph") - target_node_config = NodeConfigDictAdapter.validate_python(target_node_config) - # Get node class node_type = target_node_config["data"].type node_version = str(target_node_config["data"].version) @@ -319,12 +325,31 @@ class WorkflowBasedAppRunner: # Use the variable pool from graph_runtime_state instead of creating a new one variable_pool = graph_runtime_state.variable_pool + preload_node_creation_variables( + variable_loader=self._variable_loader, + variable_pool=variable_pool, + selectors=[ + selector + for node_config in typed_node_configs + for selector in get_node_creation_preload_selectors( + node_type=node_config["data"].type, + node_data=node_config["data"], + ) + ], + ) + try: variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( graph_config=workflow.graph_dict, config=target_node_config ) except NotImplementedError: variable_mapping = {} + variable_mapping = inject_default_system_variable_mappings( + node_id=target_node_config["id"], + node_type=node_type, + node_data=target_node_config["data"], + variable_mapping=variable_mapping, + ) load_into_variable_pool( variable_loader=self._variable_loader, @@ -340,6 +365,14 @@ class WorkflowBasedAppRunner: tenant_id=workflow.tenant_id, ) + # init graph after constructor-time context has been loaded + graph = Graph.init( + graph_config=graph_config, node_factory=node_factory, root_node_id=node_id, skip_validation=True + ) + + if not graph: + raise ValueError("graph not found in workflow") + return graph, variable_pool @staticmethod @@ -408,7 +441,11 @@ class WorkflowBasedAppRunner: node_run_result = event.node_run_result inputs = node_run_result.inputs process_data = node_run_result.process_data - outputs = node_run_result.outputs + outputs = project_node_outputs_for_workflow_run( + node_type=event.node_type, + inputs=inputs, + outputs=node_run_result.outputs, + ) execution_metadata = node_run_result.metadata self._publish_event( QueueNodeRetryEvent( @@ -448,7 +485,11 @@ class WorkflowBasedAppRunner: node_run_result = event.node_run_result inputs = node_run_result.inputs process_data = node_run_result.process_data - outputs = node_run_result.outputs + outputs = project_node_outputs_for_workflow_run( + node_type=event.node_type, + inputs=inputs, + outputs=node_run_result.outputs, + ) execution_metadata = node_run_result.metadata self._publish_event( QueueNodeSucceededEvent( @@ -466,6 +507,11 @@ class WorkflowBasedAppRunner: ) ) elif isinstance(event, NodeRunFailedEvent): + outputs = project_node_outputs_for_workflow_run( + node_type=event.node_type, + inputs=event.node_run_result.inputs, + outputs=event.node_run_result.outputs, + ) self._publish_event( QueueNodeFailedEvent( node_execution_id=event.id, @@ -475,7 +521,7 @@ class WorkflowBasedAppRunner: finished_at=event.finished_at, inputs=event.node_run_result.inputs, process_data=event.node_run_result.process_data, - outputs=event.node_run_result.outputs, + outputs=outputs, error=event.node_run_result.error or "Unknown error", execution_metadata=event.node_run_result.metadata, in_iteration_id=event.in_iteration_id, @@ -483,6 +529,11 @@ class WorkflowBasedAppRunner: ) ) elif isinstance(event, NodeRunExceptionEvent): + outputs = project_node_outputs_for_workflow_run( + node_type=event.node_type, + inputs=event.node_run_result.inputs, + outputs=event.node_run_result.outputs, + ) self._publish_event( QueueNodeExceptionEvent( node_execution_id=event.id, @@ -492,7 +543,7 @@ class WorkflowBasedAppRunner: finished_at=event.finished_at, inputs=event.node_run_result.inputs, process_data=event.node_run_result.process_data, - outputs=event.node_run_result.outputs, + outputs=outputs, error=event.node_run_result.error or "Unknown error", execution_metadata=event.node_run_result.metadata, in_iteration_id=event.in_iteration_id, diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index ecbb1cf2f3a..0cdbb5f50a1 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -2,19 +2,21 @@ from collections.abc import Mapping, Sequence from enum import StrEnum from typing import TYPE_CHECKING, Any, Optional +from graphon.file import File, FileUploadConfig +from graphon.model_runtime.entities.model_entities import AIModelEntity from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator from constants import UUID_NIL from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.entities.provider_configuration import ProviderModelBundle -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.file import File, FileUploadConfig -from dify_graph.model_runtime.entities.model_entities import AIModelEntity if TYPE_CHECKING: from core.ops.ops_trace_manager import TraceQueueManager +DIFY_RUN_CONTEXT_KEY = "_dify" + + class UserFrom(StrEnum): ACCOUNT = "account" END_USER = "end-user" diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index d2a36f2a0de..5e56341f892 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -3,14 +3,14 @@ from datetime import datetime from enum import StrEnum, auto from typing import Any +from graphon.entities import WorkflowStartReason +from graphon.entities.pause_reason import PauseReason +from graphon.enums import NodeType, WorkflowNodeExecutionMetadataKey +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from pydantic import BaseModel, ConfigDict, Field from core.app.entities.agent_strategy import AgentStrategyInfo from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from dify_graph.entities.pause_reason import PauseReason -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk class QueueEvent(StrEnum): diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 46a8ab52f2f..ba3b2e356f7 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -2,14 +2,14 @@ from collections.abc import Mapping, Sequence from enum import StrEnum from typing import Any +from graphon.entities import WorkflowStartReason +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.nodes.human_input.entities import FormInput, UserAction from pydantic import BaseModel, ConfigDict, Field from core.app.entities.agent_strategy import AgentStrategyInfo from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from dify_graph.nodes.human_input.entities import FormInput, UserAction class AnnotationReplyAccount(BaseModel): diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index 50aed37163b..0bd904811a0 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -4,9 +4,10 @@ from sqlalchemy import select from core.app.entities.app_invoke_entities import InvokeFrom from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db from models.dataset import Dataset -from models.enums import CollectionBindingType +from models.enums import CollectionBindingType, ConversationFromSource from models.model import App, AppAnnotationSetting, Message, MessageAnnotation from services.annotation_service import AppAnnotationService from services.dataset_service import DatasetCollectionBindingService @@ -50,7 +51,7 @@ class AnnotationReplyFeature: dataset = Dataset( id=app_record.id, tenant_id=app_record.tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=embedding_provider_name, embedding_model=embedding_model_name, collection_binding_id=dataset_collection_binding.id, @@ -68,9 +69,9 @@ class AnnotationReplyFeature: annotation = AppAnnotationService.get_annotation_by_id(annotation_id) if annotation: if invoke_from in {InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP}: - from_source = "api" + from_source = ConversationFromSource.API else: - from_source = "console" + from_source = ConversationFromSource.CONSOLE # insert annotation history AppAnnotationService.add_annotation_history( diff --git a/api/core/app/features/hosting_moderation/hosting_moderation.py b/api/core/app/features/hosting_moderation/hosting_moderation.py index 5ed1fadc412..d2d2fea4fb8 100644 --- a/api/core/app/features/hosting_moderation/hosting_moderation.py +++ b/api/core/app/features/hosting_moderation/hosting_moderation.py @@ -1,8 +1,9 @@ import logging +from graphon.model_runtime.entities.message_entities import PromptMessage + from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity from core.helper import moderation -from dify_graph.model_runtime.entities.message_entities import PromptMessage logger = logging.getLogger(__name__) diff --git a/api/core/app/features/rate_limiting/rate_limit.py b/api/core/app/features/rate_limiting/rate_limit.py index 2ca1275a8a4..e0f1759e5e9 100644 --- a/api/core/app/features/rate_limiting/rate_limit.py +++ b/api/core/app/features/rate_limiting/rate_limit.py @@ -19,6 +19,7 @@ class RateLimit: _REQUEST_MAX_ALIVE_TIME = 10 * 60 # 10 minutes _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes _instance_dict: dict[str, "RateLimit"] = {} + max_active_requests: int def __new__(cls, client_id: str, max_active_requests: int): if client_id not in cls._instance_dict: @@ -27,7 +28,13 @@ class RateLimit: return cls._instance_dict[client_id] def __init__(self, client_id: str, max_active_requests: int): + flush_cache = hasattr(self, "max_active_requests") and self.max_active_requests != max_active_requests self.max_active_requests = max_active_requests + # Only flush here if this instance has already been fully initialized, + # i.e. the Redis key attributes exist. Otherwise, rely on the flush at + # the end of initialization below. + if flush_cache and hasattr(self, "active_requests_key") and hasattr(self, "max_active_requests_key"): + self.flush_cache(use_local_value=True) # must be called after max_active_requests is set if self.disabled(): return @@ -41,8 +48,6 @@ class RateLimit: self.flush_cache(use_local_value=True) def flush_cache(self, use_local_value=False): - if self.disabled(): - return self.last_recalculate_time = time.time() # flush max active requests if use_local_value or not redis_client.exists(self.max_active_requests_key): @@ -50,7 +55,8 @@ class RateLimit: else: self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode("utf-8")) redis_client.expire(self.max_active_requests_key, timedelta(days=1)) - + if self.disabled(): + return # flush max active requests (in-transit request list) if not redis_client.exists(self.active_requests_key): return diff --git a/api/core/app/file_access/__init__.py b/api/core/app/file_access/__init__.py new file mode 100644 index 00000000000..a75ab9781be --- /dev/null +++ b/api/core/app/file_access/__init__.py @@ -0,0 +1,11 @@ +from .controller import DatabaseFileAccessController +from .protocols import FileAccessControllerProtocol +from .scope import FileAccessScope, bind_file_access_scope, get_current_file_access_scope + +__all__ = [ + "DatabaseFileAccessController", + "FileAccessControllerProtocol", + "FileAccessScope", + "bind_file_access_scope", + "get_current_file_access_scope", +] diff --git a/api/core/app/file_access/controller.py b/api/core/app/file_access/controller.py new file mode 100644 index 00000000000..300c187083f --- /dev/null +++ b/api/core/app/file_access/controller.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from collections.abc import Callable + +from sqlalchemy import select +from sqlalchemy.orm import Session +from sqlalchemy.sql import Select + +from models import ToolFile, UploadFile +from models.enums import CreatorUserRole + +from .protocols import FileAccessControllerProtocol +from .scope import FileAccessScope, get_current_file_access_scope + + +class DatabaseFileAccessController(FileAccessControllerProtocol): + """Workflow-layer authorization helper for database-backed file lookups. + + Tenant scoping remains mandatory. When the current execution belongs to an + end user, the lookup is additionally constrained to that end user's file + ownership markers. + """ + + _scope_getter: Callable[[], FileAccessScope | None] + + def __init__( + self, + *, + scope_getter: Callable[[], FileAccessScope | None] = get_current_file_access_scope, + ) -> None: + self._scope_getter = scope_getter + + def current_scope(self) -> FileAccessScope | None: + return self._scope_getter() + + def apply_upload_file_filters( + self, + stmt: Select[tuple[UploadFile]], + *, + scope: FileAccessScope | None = None, + ) -> Select[tuple[UploadFile]]: + resolved_scope = scope or self.current_scope() + if resolved_scope is None: + return stmt + + scoped_stmt = stmt.where(UploadFile.tenant_id == resolved_scope.tenant_id) + if not resolved_scope.requires_user_ownership: + return scoped_stmt + + return scoped_stmt.where( + UploadFile.created_by_role == CreatorUserRole.END_USER, + UploadFile.created_by == resolved_scope.user_id, + ) + + def apply_tool_file_filters( + self, + stmt: Select[tuple[ToolFile]], + *, + scope: FileAccessScope | None = None, + ) -> Select[tuple[ToolFile]]: + resolved_scope = scope or self.current_scope() + if resolved_scope is None: + return stmt + + scoped_stmt = stmt.where(ToolFile.tenant_id == resolved_scope.tenant_id) + if not resolved_scope.requires_user_ownership: + return scoped_stmt + + return scoped_stmt.where(ToolFile.user_id == resolved_scope.user_id) + + def get_upload_file( + self, + *, + session: Session, + file_id: str, + scope: FileAccessScope | None = None, + ) -> UploadFile | None: + resolved_scope = scope or self.current_scope() + if resolved_scope is None: + return session.get(UploadFile, file_id) + + stmt = self.apply_upload_file_filters( + select(UploadFile).where(UploadFile.id == file_id), + scope=resolved_scope, + ) + return session.scalar(stmt) + + def get_tool_file( + self, + *, + session: Session, + file_id: str, + scope: FileAccessScope | None = None, + ) -> ToolFile | None: + resolved_scope = scope or self.current_scope() + if resolved_scope is None: + return session.get(ToolFile, file_id) + + stmt = self.apply_tool_file_filters( + select(ToolFile).where(ToolFile.id == file_id), + scope=resolved_scope, + ) + return session.scalar(stmt) diff --git a/api/core/app/file_access/protocols.py b/api/core/app/file_access/protocols.py new file mode 100644 index 00000000000..8bb3eb99240 --- /dev/null +++ b/api/core/app/file_access/protocols.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +from typing import Protocol + +from sqlalchemy.orm import Session +from sqlalchemy.sql import Select + +from models import ToolFile, UploadFile + +from .scope import FileAccessScope + + +class FileAccessControllerProtocol(Protocol): + """Contract for applying access rules to file lookups. + + Implementations translate an optional execution scope into query constraints + and authorized record retrieval. The contract is intentionally limited to + ownership and tenancy rules for workflow-layer file access. + """ + + def current_scope(self) -> FileAccessScope | None: + """Return the scope active for the current execution, if one exists. + + Callers use this to decide whether embedded file metadata may be trusted + or whether a fresh authorized lookup is required. + """ + ... + + def apply_upload_file_filters( + self, + stmt: Select[tuple[UploadFile]], + *, + scope: FileAccessScope | None = None, + ) -> Select[tuple[UploadFile]]: + """Return an upload-file query constrained by the supplied access scope. + + The returned statement must preserve the caller's existing predicates and + append only access-control conditions. + """ + ... + + def apply_tool_file_filters( + self, + stmt: Select[tuple[ToolFile]], + *, + scope: FileAccessScope | None = None, + ) -> Select[tuple[ToolFile]]: + """Return a tool-file query constrained by the supplied access scope. + + The returned statement must preserve the caller's existing predicates and + append only access-control conditions. + """ + ... + + def get_upload_file( + self, + *, + session: Session, + file_id: str, + scope: FileAccessScope | None = None, + ) -> UploadFile | None: + """Load one authorized upload-file record for the given identifier. + + Returns ``None`` when the file does not exist or when the scope does not + permit access to that record. + """ + ... + + def get_tool_file( + self, + *, + session: Session, + file_id: str, + scope: FileAccessScope | None = None, + ) -> ToolFile | None: + """Load one authorized tool-file record for the given identifier. + + Returns ``None`` when the file does not exist or when the scope does not + permit access to that record. + """ + ... diff --git a/api/core/app/file_access/scope.py b/api/core/app/file_access/scope.py new file mode 100644 index 00000000000..80d504ef1c0 --- /dev/null +++ b/api/core/app/file_access/scope.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from collections.abc import Iterator +from contextlib import contextmanager +from contextvars import ContextVar +from dataclasses import dataclass + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom + +_current_file_access_scope: ContextVar[FileAccessScope | None] = ContextVar( + "current_file_access_scope", + default=None, +) + + +@dataclass(frozen=True, slots=True) +class FileAccessScope: + """Request-scoped ownership context used by workflow-layer file lookups.""" + + tenant_id: str + user_id: str + user_from: UserFrom + invoke_from: InvokeFrom + + @property + def requires_user_ownership(self) -> bool: + return self.user_from == UserFrom.END_USER + + +def get_current_file_access_scope() -> FileAccessScope | None: + return _current_file_access_scope.get() + + +@contextmanager +def bind_file_access_scope(scope: FileAccessScope) -> Iterator[None]: + token = _current_file_access_scope.set(scope) + try: + yield + finally: + _current_file_access_scope.reset(token) diff --git a/api/core/app/layers/conversation_variable_persist_layer.py b/api/core/app/layers/conversation_variable_persist_layer.py index d227e4e904e..e09869f5f8f 100644 --- a/api/core/app/layers/conversation_variable_persist_layer.py +++ b/api/core/app/layers/conversation_variable_persist_layer.py @@ -1,12 +1,20 @@ +""" +Persist conversation-scoped variable updates emitted by the graph engine. + +The graph package emits generic variable update events and stays unaware of +conversation identity or storage concerns. This layer lives in the application +core, listens to those generic events, and persists only the `conversation.*` +scope updates that matter to chat applications. +""" + import logging -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID -from dify_graph.conversation_variable_updater import ConversationVariableUpdater -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events import GraphEngineEvent, NodeRunSucceededEvent -from dify_graph.nodes.variable_assigner.common import helpers as common_helpers -from dify_graph.variables import VariableBase +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent, NodeRunVariableUpdatedEvent + +from core.workflow.system_variables import SystemVariableKey, get_system_text +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID +from services.conversation_variable_updater import ConversationVariableUpdater logger = logging.getLogger(__name__) @@ -20,41 +28,22 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer): pass def on_event(self, event: GraphEngineEvent) -> None: - if not isinstance(event, NodeRunSucceededEvent): - return - if event.node_type != BuiltinNodeTypes.VARIABLE_ASSIGNER: - return - if self.graph_runtime_state is None: + if not isinstance(event, NodeRunVariableUpdatedEvent): return - updated_variables = common_helpers.get_updated_variables(event.node_run_result.process_data) or [] - if not updated_variables: + selector = event.variable.selector + if len(selector) < 2: + logger.warning("Conversation variable selector invalid. selector=%s", selector) return - conversation_id = self.graph_runtime_state.system_variable.conversation_id + conversation_id = get_system_text(self.graph_runtime_state.variable_pool, SystemVariableKey.CONVERSATION_ID) if conversation_id is None: return - updated_any = False - for item in updated_variables: - selector = item.selector - if len(selector) < 2: - logger.warning("Conversation variable selector invalid. selector=%s", selector) - continue - if selector[0] != CONVERSATION_VARIABLE_NODE_ID: - continue - variable = self.graph_runtime_state.variable_pool.get(selector) - if not isinstance(variable, VariableBase): - logger.warning( - "Conversation variable not found in variable pool. selector=%s", - selector, - ) - continue - self._conversation_variable_updater.update(conversation_id=conversation_id, variable=variable) - updated_any = True + if selector[0] != CONVERSATION_VARIABLE_NODE_ID: + return - if updated_any: - self._conversation_variable_updater.flush() + self._conversation_variable_updater.update(conversation_id=conversation_id, variable=event.variable) def on_graph_end(self, error: Exception | None) -> None: pass diff --git a/api/core/app/layers/pause_state_persist_layer.py b/api/core/app/layers/pause_state_persist_layer.py index 4370c01a0bf..79a54421306 100644 --- a/api/core/app/layers/pause_state_persist_layer.py +++ b/api/core/app/layers/pause_state_persist_layer.py @@ -1,14 +1,14 @@ from dataclasses import dataclass from typing import Annotated, Literal, Self, TypeAlias +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent from pydantic import BaseModel, Field from sqlalchemy import Engine from sqlalchemy.orm import Session, sessionmaker from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events.base import GraphEngineEvent -from dify_graph.graph_events.graph import GraphRunPausedEvent +from core.workflow.system_variables import SystemVariableKey, get_system_text from models.model import AppMode from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.factory import DifyAPIRepositoryFactory @@ -119,7 +119,10 @@ class PauseStatePersistenceLayer(GraphEngineLayer): generate_entity=entity_wrapper, ) - workflow_run_id: str | None = self.graph_runtime_state.system_variable.workflow_execution_id + workflow_run_id = get_system_text( + self.graph_runtime_state.variable_pool, + SystemVariableKey.WORKFLOW_EXECUTION_ID, + ) assert workflow_run_id is not None repo = self._get_repo() repo.create_workflow_pause( diff --git a/api/core/app/layers/suspend_layer.py b/api/core/app/layers/suspend_layer.py index 2adaf14a357..1a79a9f843e 100644 --- a/api/core/app/layers/suspend_layer.py +++ b/api/core/app/layers/suspend_layer.py @@ -1,21 +1,27 @@ -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events.base import GraphEngineEvent -from dify_graph.graph_events.graph import GraphRunPausedEvent +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent class SuspendLayer(GraphEngineLayer): """ """ + def __init__(self) -> None: + super().__init__() + self._paused = False + def on_graph_start(self): - pass + self._paused = False def on_event(self, event: GraphEngineEvent): """ Handle the paused event, stash runtime state into storage and wait for resume. """ if isinstance(event, GraphRunPausedEvent): - pass + self._paused = True def on_graph_end(self, error: Exception | None): """ """ - pass + self._paused = False + + def is_paused(self) -> bool: + return self._paused diff --git a/api/core/app/layers/timeslice_layer.py b/api/core/app/layers/timeslice_layer.py index d7ca45f209a..8c8daf87122 100644 --- a/api/core/app/layers/timeslice_layer.py +++ b/api/core/app/layers/timeslice_layer.py @@ -3,10 +3,10 @@ import uuid from typing import ClassVar from apscheduler.schedulers.background import BackgroundScheduler # type: ignore +from graphon.graph_engine.entities.commands import CommandType, GraphEngineCommand +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent -from dify_graph.graph_engine.entities.commands import CommandType, GraphEngineCommand -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events.base import GraphEngineEvent from services.workflow.entities import WorkflowScheduleCFSPlanEntity from services.workflow.scheduler import CFSPlanScheduler, SchedulerCommand diff --git a/api/core/app/layers/trigger_post_layer.py b/api/core/app/layers/trigger_post_layer.py index a4019a83e14..77c7bec67e6 100644 --- a/api/core/app/layers/trigger_post_layer.py +++ b/api/core/app/layers/trigger_post_layer.py @@ -2,12 +2,12 @@ import logging from datetime import UTC, datetime from typing import Any, ClassVar +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent from pydantic import TypeAdapter from core.db.session_factory import session_factory -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events.base import GraphEngineEvent -from dify_graph.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent +from core.workflow.system_variables import SystemVariableKey, get_system_text from models.enums import WorkflowTriggerStatus from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository from tasks.workflow_cfs_scheduler.cfs_scheduler import AsyncWorkflowCFSPlanEntity @@ -59,7 +59,10 @@ class TriggerPostLayer(GraphEngineLayer): outputs = self.graph_runtime_state.outputs # BASICLY, workflow_execution_id is the same as workflow_run_id - workflow_run_id = self.graph_runtime_state.system_variable.workflow_execution_id + workflow_run_id = get_system_text( + self.graph_runtime_state.variable_pool, + SystemVariableKey.WORKFLOW_EXECUTION_ID, + ) assert workflow_run_id, "Workflow run id is not set" total_tokens = self.graph_runtime_state.total_tokens diff --git a/api/core/app/llm/model_access.py b/api/core/app/llm/model_access.py index a63ff39fa53..278d0cb30b5 100644 --- a/api/core/app/llm/model_access.py +++ b/api/core/app/llm/model_access.py @@ -2,23 +2,35 @@ from __future__ import annotations from typing import Any -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.nodes.llm.entities import ModelConfig +from graphon.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError +from graphon.nodes.llm.protocols import CredentialsProvider + +from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity from core.errors.error import ProviderTokenNotInitError from core.model_manager import ModelInstance, ModelManager +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.provider_manager import ProviderManager -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.nodes.llm.entities import ModelConfig -from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory class DifyCredentialsProvider: tenant_id: str provider_manager: ProviderManager - def __init__(self, tenant_id: str, provider_manager: ProviderManager | None = None) -> None: - self.tenant_id = tenant_id - self.provider_manager = provider_manager or ProviderManager() + def __init__( + self, + *, + run_context: DifyRunContext, + provider_manager: ProviderManager | None = None, + ) -> None: + self.tenant_id = run_context.tenant_id + if provider_manager is None: + provider_manager = create_plugin_provider_manager( + tenant_id=run_context.tenant_id, + user_id=run_context.user_id, + ) + self.provider_manager = provider_manager def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: provider_configurations = self.provider_manager.get_configurations(self.tenant_id) @@ -42,9 +54,21 @@ class DifyModelFactory: tenant_id: str model_manager: ModelManager - def __init__(self, tenant_id: str, model_manager: ModelManager | None = None) -> None: - self.tenant_id = tenant_id - self.model_manager = model_manager or ModelManager() + def __init__( + self, + *, + run_context: DifyRunContext, + model_manager: ModelManager | None = None, + ) -> None: + self.tenant_id = run_context.tenant_id + if model_manager is None: + model_manager = ModelManager( + provider_manager=create_plugin_provider_manager( + tenant_id=run_context.tenant_id, + user_id=run_context.user_id, + ) + ) + self.model_manager = model_manager def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance: return self.model_manager.get_model_instance( @@ -55,18 +79,42 @@ class DifyModelFactory: ) -def build_dify_model_access(tenant_id: str) -> tuple[CredentialsProvider, ModelFactory]: - return ( - DifyCredentialsProvider(tenant_id=tenant_id), - DifyModelFactory(tenant_id=tenant_id), +def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsProvider, DifyModelFactory]: + """Create LLM access adapters that share the same tenant-bound manager graph.""" + provider_manager = create_plugin_provider_manager( + tenant_id=run_context.tenant_id, + user_id=run_context.user_id, ) + model_manager = ModelManager(provider_manager=provider_manager) + + return ( + DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager), + DifyModelFactory(run_context=run_context, model_manager=model_manager), + ) + + +def _normalize_completion_params(completion_params: dict[str, Any]) -> tuple[dict[str, Any], list[str]]: + """ + Split node-level completion params into provider parameters and stop sequences. + + Workflow LLM-compatible nodes still consume runtime invocation settings from + ``ModelInstance.parameters`` and ``ModelInstance.stop``. Keep the + ``ModelInstance`` view and the returned config entity aligned here so callers + do not need to duplicate normalization logic. + """ + normalized_parameters = dict(completion_params) + stop = normalized_parameters.pop("stop", []) + if not isinstance(stop, list) or not all(isinstance(item, str) for item in stop): + stop = [] + + return normalized_parameters, stop def fetch_model_config( *, node_data_model: ModelConfig, credentials_provider: CredentialsProvider, - model_factory: ModelFactory, + model_factory: DifyModelFactory, ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: if not node_data_model.mode: raise LLMModeRequiredError("LLM mode is required.") @@ -80,22 +128,18 @@ def fetch_model_config( model_type=ModelType.LLM, ) if provider_model is None: - raise ModelNotExistError(f"Model {node_data_model.name} not exist.") + raise ModelNotExistError(f"Model {node_data_model.name} does not exist.") provider_model.raise_for_status() - completion_params = dict(node_data_model.completion_params) - stop = completion_params.pop("stop", []) - if not isinstance(stop, list): - stop = [] - model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials) - if not model_schema: - raise ModelNotExistError(f"Model {node_data_model.name} not exist.") + if model_schema is None: + raise ModelNotExistError(f"Model {node_data_model.name} schema does not exist.") + parameters, stop = _normalize_completion_params(node_data_model.completion_params) model_instance.provider = node_data_model.provider model_instance.model_name = node_data_model.name model_instance.credentials = credentials - model_instance.parameters = completion_params + model_instance.parameters = parameters model_instance.stop = tuple(stop) return model_instance, ModelConfigWithCredentialsEntity( @@ -103,8 +147,8 @@ def fetch_model_config( model=node_data_model.name, model_schema=model_schema, mode=node_data_model.mode, - provider_model_bundle=provider_model_bundle, credentials=credentials, - parameters=completion_params, + parameters=parameters, stop=stop, + provider_model_bundle=provider_model_bundle, ) diff --git a/api/core/app/llm/quota.py b/api/core/app/llm/quota.py index 7aa3bf15aba..63d22353588 100644 --- a/api/core/app/llm/quota.py +++ b/api/core/app/llm/quota.py @@ -1,3 +1,4 @@ +from graphon.model_runtime.entities.llm_entities import LLMUsage from sqlalchemy import update from sqlalchemy.orm import Session @@ -6,7 +7,6 @@ from core.entities.model_entities import ModelStatus from core.entities.provider_entities import ProviderQuotaType, QuotaUnit from core.errors.error import QuotaExceededError from core.model_manager import ModelInstance -from dify_graph.model_runtime.entities.llm_entities import LLMUsage from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.provider import Provider, ProviderType diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 0d5e0acec66..10b9c36d3e2 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -1,6 +1,7 @@ import logging import time +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from sqlalchemy import select from sqlalchemy.orm import Session @@ -17,7 +18,6 @@ from core.app.entities.task_entities import ( ) from core.errors.error import QuotaExceededError from core.moderation.output_moderation import ModerationRule, OutputModeration -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from models.enums import MessageStatus from models.model import Message diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index b530fe1ce4b..a410fac5580 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -4,6 +4,13 @@ from collections.abc import Generator from threading import Thread from typing import Any, Union, cast +from graphon.file import FileTransferMethod +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + TextPromptMessageContent, +) +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import select from sqlalchemy.orm import Session @@ -51,13 +58,6 @@ from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from dify_graph.file.enums import FileTransferMethod -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from dify_graph.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - TextPromptMessageContent, -) -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from events.message_event import message_was_created from extensions.ext_database import db from libs.datetime_utils import naive_utc_now diff --git a/api/core/app/task_pipeline/message_file_utils.py b/api/core/app/task_pipeline/message_file_utils.py index fc8b6c6b5aa..b23a33923b3 100644 --- a/api/core/app/task_pipeline/message_file_utils.py +++ b/api/core/app/task_pipeline/message_file_utils.py @@ -1,8 +1,9 @@ from typing import TypedDict +from graphon.file import FileTransferMethod +from graphon.file import helpers as file_helpers + from core.tools.signature import sign_tool_file -from dify_graph.file import helpers as file_helpers -from dify_graph.file.enums import FileTransferMethod from models.model import MessageFile, UploadFile MAX_TOOL_FILE_EXTENSION_LENGTH = 10 diff --git a/api/core/app/workflow/file_runtime.py b/api/core/app/workflow/file_runtime.py index e0f8d271110..8604235ef28 100644 --- a/api/core/app/workflow/file_runtime.py +++ b/api/core/app/workflow/file_runtime.py @@ -1,33 +1,43 @@ from __future__ import annotations +import base64 +import hashlib +import hmac +import os +import time +import urllib.parse from collections.abc import Generator +from typing import TYPE_CHECKING, Literal + +from graphon.file import FileTransferMethod +from graphon.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol +from graphon.file.runtime import set_workflow_file_runtime from configs import dify_config +from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol +from core.db.session_factory import session_factory from core.helper.ssrf_proxy import ssrf_proxy from core.tools.signature import sign_tool_file -from dify_graph.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol -from dify_graph.file.runtime import set_workflow_file_runtime +from core.workflow.file_reference import parse_file_reference from extensions.ext_storage import storage +if TYPE_CHECKING: + from graphon.file import File + class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol): - """Production runtime wiring for ``dify_graph.file``.""" + """Production runtime wiring for ``graphon.file``. - @property - def files_url(self) -> str: - return dify_config.FILES_URL + Opaque file references are resolved back to canonical database records before + URLs are signed or storage keys are used. When a request-scoped file access + scope is present, those lookups additionally enforce tenant and end-user + ownership filters. + """ - @property - def internal_files_url(self) -> str | None: - return dify_config.INTERNAL_FILES_URL + _file_access_controller: FileAccessControllerProtocol - @property - def secret_key(self) -> str: - return dify_config.SECRET_KEY - - @property - def files_access_timeout(self) -> int: - return dify_config.FILES_ACCESS_TIMEOUT + def __init__(self, *, file_access_controller: FileAccessControllerProtocol) -> None: + self._file_access_controller = file_access_controller @property def multimodal_send_format(self) -> str: @@ -39,9 +49,137 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol): def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: return storage.load(path, stream=stream) - def sign_tool_file(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: + def load_file_bytes(self, *, file: File) -> bytes: + storage_key = self._resolve_storage_key(file=file) + data = storage.load(storage_key, stream=False) + if not isinstance(data, bytes): + raise ValueError(f"file {storage_key} is not a bytes object") + return data + + def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None: + if file.transfer_method == FileTransferMethod.REMOTE_URL: + return file.remote_url + parsed_reference = parse_file_reference(file.reference) + if parsed_reference is None: + raise ValueError("Missing file reference") + if file.transfer_method == FileTransferMethod.LOCAL_FILE: + return self.resolve_upload_file_url( + upload_file_id=parsed_reference.record_id, + for_external=for_external, + ) + if file.transfer_method == FileTransferMethod.DATASOURCE_FILE: + if file.extension is None: + raise ValueError("Missing file extension") + self._assert_upload_file_access(upload_file_id=parsed_reference.record_id) + return sign_tool_file( + tool_file_id=parsed_reference.record_id, + extension=file.extension, + for_external=for_external, + ) + if file.transfer_method == FileTransferMethod.TOOL_FILE: + if file.extension is None: + raise ValueError("Missing file extension") + return self.resolve_tool_file_url( + tool_file_id=parsed_reference.record_id, + extension=file.extension, + for_external=for_external, + ) + return None + + def resolve_upload_file_url( + self, + *, + upload_file_id: str, + as_attachment: bool = False, + for_external: bool = True, + ) -> str: + self._assert_upload_file_access(upload_file_id=upload_file_id) + base_url = self._base_url(for_external=for_external) + url = f"{base_url}/files/{upload_file_id}/file-preview" + query = self._sign_query(payload=f"file-preview|{upload_file_id}") + if as_attachment: + query["as_attachment"] = "true" + return f"{url}?{urllib.parse.urlencode(query)}" + + def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: + self._assert_tool_file_access(tool_file_id=tool_file_id) return sign_tool_file(tool_file_id=tool_file_id, extension=extension, for_external=for_external) + def verify_preview_signature( + self, + *, + preview_kind: Literal["image", "file"], + file_id: str, + timestamp: str, + nonce: str, + sign: str, + ) -> bool: + payload = f"{preview_kind}-preview|{file_id}|{timestamp}|{nonce}" + recalculated = hmac.new(self._secret_key(), payload.encode(), hashlib.sha256).digest() + if sign != base64.urlsafe_b64encode(recalculated).decode(): + return False + return int(time.time()) - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT + + @staticmethod + def _base_url(*, for_external: bool) -> str: + if for_external: + return dify_config.FILES_URL + return dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL + + @staticmethod + def _secret_key() -> bytes: + return dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + + def _sign_query(self, *, payload: str) -> dict[str, str]: + timestamp = str(int(time.time())) + nonce = os.urandom(16).hex() + sign = hmac.new(self._secret_key(), f"{payload}|{timestamp}|{nonce}".encode(), hashlib.sha256).digest() + return { + "timestamp": timestamp, + "nonce": nonce, + "sign": base64.urlsafe_b64encode(sign).decode(), + } + + def _resolve_storage_key(self, *, file: File) -> str: + parsed_reference = parse_file_reference(file.reference) + if parsed_reference is None: + raise ValueError("Missing file reference") + + record_id = parsed_reference.record_id + with session_factory.create_session() as session: + if file.transfer_method in { + FileTransferMethod.LOCAL_FILE, + FileTransferMethod.REMOTE_URL, + FileTransferMethod.DATASOURCE_FILE, + }: + upload_file = self._file_access_controller.get_upload_file(session=session, file_id=record_id) + if upload_file is None: + raise ValueError(f"Upload file {record_id} not found") + return upload_file.key + + tool_file = self._file_access_controller.get_tool_file(session=session, file_id=record_id) + if tool_file is None: + raise ValueError(f"Tool file {record_id} not found") + return tool_file.file_key + + def _assert_upload_file_access(self, *, upload_file_id: str) -> None: + if self._file_access_controller.current_scope() is None: + return + + with session_factory.create_session() as session: + upload_file = self._file_access_controller.get_upload_file(session=session, file_id=upload_file_id) + if upload_file is None: + raise ValueError(f"Upload file {upload_file_id} not found") + + def _assert_tool_file_access(self, *, tool_file_id: str) -> None: + if self._file_access_controller.current_scope() is None: + return + + with session_factory.create_session() as session: + tool_file = self._file_access_controller.get_tool_file(session=session, file_id=tool_file_id) + if tool_file is None: + raise ValueError(f"Tool file {tool_file_id} not found") + def bind_dify_workflow_file_runtime() -> None: - set_workflow_file_runtime(DifyWorkflowFileRuntime()) + set_workflow_file_runtime(DifyWorkflowFileRuntime(file_access_controller=DatabaseFileAccessController())) diff --git a/api/core/app/workflow/layers/llm_quota.py b/api/core/app/workflow/layers/llm_quota.py index faf1516c404..48cabaf4d0f 100644 --- a/api/core/app/workflow/layers/llm_quota.py +++ b/api/core/app/workflow/layers/llm_quota.py @@ -7,22 +7,22 @@ This layer centralizes model-quota deduction outside node implementations. import logging from typing import TYPE_CHECKING, cast, final +from graphon.enums import BuiltinNodeTypes +from graphon.graph_engine.entities.commands import AbortCommand, CommandType +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent +from graphon.nodes.base.node import Node from typing_extensions import override +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext from core.app.llm import deduct_llm_quota, ensure_llm_quota_available from core.errors.error import QuotaExceededError from core.model_manager import ModelInstance -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.graph_engine.entities.commands import AbortCommand, CommandType -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase -from dify_graph.graph_events.node import NodeRunSucceededEvent -from dify_graph.nodes.base.node import Node if TYPE_CHECKING: - from dify_graph.nodes.llm.node import LLMNode - from dify_graph.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode - from dify_graph.nodes.question_classifier.question_classifier_node import QuestionClassifierNode + from graphon.nodes.llm.node import LLMNode + from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode + from graphon.nodes.question_classifier.question_classifier_node import QuestionClassifierNode logger = logging.getLogger(__name__) @@ -75,7 +75,7 @@ class LLMQuotaLayer(GraphEngineLayer): return try: - dify_ctx = node.require_dify_context() + dify_ctx = DifyRunContext.model_validate(node.require_run_context_value(DIFY_RUN_CONTEXT_KEY)) deduct_llm_quota( tenant_id=dify_ctx.tenant_id, model_instance=model_instance, @@ -114,11 +114,11 @@ class LLMQuotaLayer(GraphEngineLayer): try: match node.node_type: case BuiltinNodeTypes.LLM: - return cast("LLMNode", node).model_instance + model_instance = cast("LLMNode", node).model_instance case BuiltinNodeTypes.PARAMETER_EXTRACTOR: - return cast("ParameterExtractorNode", node).model_instance + model_instance = cast("ParameterExtractorNode", node).model_instance case BuiltinNodeTypes.QUESTION_CLASSIFIER: - return cast("QuestionClassifierNode", node).model_instance + model_instance = cast("QuestionClassifierNode", node).model_instance case _: return None except AttributeError: @@ -127,3 +127,12 @@ class LLMQuotaLayer(GraphEngineLayer): node.id, ) return None + + if isinstance(model_instance, ModelInstance): + return model_instance + + raw_model_instance = getattr(model_instance, "_model_instance", None) + if isinstance(raw_model_instance, ModelInstance): + return raw_model_instance + + return None diff --git a/api/core/app/workflow/layers/observability.py b/api/core/app/workflow/layers/observability.py index 4b20477a7ff..c4ed54a1406 100644 --- a/api/core/app/workflow/layers/observability.py +++ b/api/core/app/workflow/layers/observability.py @@ -8,18 +8,19 @@ associates with the node span. """ import logging +from contextvars import Token from dataclasses import dataclass from typing import cast, final +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node from opentelemetry import context as context_api from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_context from typing_extensions import override from configs import dify_config -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events import GraphNodeEventBase -from dify_graph.nodes.base.node import Node from extensions.otel.parser import ( DefaultNodeOTelParser, LLMNodeOTelParser, @@ -35,7 +36,7 @@ logger = logging.getLogger(__name__) @dataclass(slots=True) class _NodeSpanContext: span: "Span" - token: object + token: Token[context_api.Context] @final diff --git a/api/core/app/workflow/layers/persistence.py b/api/core/app/workflow/layers/persistence.py index 99b64b3ab57..ada065a9433 100644 --- a/api/core/app/workflow/layers/persistence.py +++ b/api/core/app/workflow/layers/persistence.py @@ -14,20 +14,15 @@ from dataclasses import dataclass from datetime import datetime from typing import Any, Union -from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity -from core.ops.entities.trace_entity import TraceTaskName -from core.ops.ops_trace_manager import TraceQueueManager, TraceTask -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from dify_graph.entities import WorkflowExecution, WorkflowNodeExecution -from dify_graph.enums import ( - SystemVariableKey, +from graphon.entities import WorkflowExecution, WorkflowNodeExecution +from graphon.enums import ( WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, WorkflowType, ) -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events import ( +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import ( GraphEngineEvent, GraphRunAbortedEvent, GraphRunFailedEvent, @@ -42,9 +37,15 @@ from dify_graph.graph_events import ( NodeRunStartedEvent, NodeRunSucceededEvent, ) -from dify_graph.node_events import NodeRunResult -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from graphon.node_events import NodeRunResult + +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity +from core.ops.entities.trace_entity import TraceTaskName +from core.ops.ops_trace_manager import TraceQueueManager, TraceTask +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository +from core.workflow.system_variables import SystemVariableKey +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run from libs.datetime_utils import naive_utc_now @@ -128,14 +129,14 @@ class WorkflowPersistenceLayer(GraphEngineLayer): self._handle_graph_run_paused(event) return - if isinstance(event, NodeRunStartedEvent): - self._handle_node_started(event) - return - if isinstance(event, NodeRunRetryEvent): self._handle_node_retry(event) return + if isinstance(event, NodeRunStartedEvent): + self._handle_node_started(event) + return + if isinstance(event, NodeRunSucceededEvent): self._handle_node_succeeded(event) return @@ -372,10 +373,15 @@ class WorkflowPersistenceLayer(GraphEngineLayer): domain_execution.error = error if update_outputs: + projected_outputs = project_node_outputs_for_workflow_run( + node_type=domain_execution.node_type, + inputs=node_result.inputs, + outputs=node_result.outputs, + ) domain_execution.update_from_mapping( inputs=node_result.inputs, process_data=node_result.process_data, - outputs=node_result.outputs, + outputs=projected_outputs, metadata=node_result.metadata, ) diff --git a/api/core/base/tts/app_generator_tts_publisher.py b/api/core/base/tts/app_generator_tts_publisher.py index beda515666f..3d8a7a54f31 100644 --- a/api/core/base/tts/app_generator_tts_publisher.py +++ b/api/core/base/tts/app_generator_tts_publisher.py @@ -6,6 +6,9 @@ import re import threading from collections.abc import Iterable +from graphon.model_runtime.entities.message_entities import TextPromptMessageContent +from graphon.model_runtime.entities.model_entities import ModelType + from core.app.entities.queue_entities import ( MessageQueueMessage, QueueAgentMessageEvent, @@ -15,8 +18,6 @@ from core.app.entities.queue_entities import ( WorkflowQueueMessage, ) from core.model_manager import ModelInstance, ModelManager -from dify_graph.model_runtime.entities.message_entities import TextPromptMessageContent -from dify_graph.model_runtime.entities.model_entities import ModelType class AudioTrunk: @@ -25,12 +26,10 @@ class AudioTrunk: self.status = status -def _invoice_tts(text_content: str, model_instance: ModelInstance, tenant_id: str, voice: str): +def _invoice_tts(text_content: str, model_instance: ModelInstance, voice: str): if not text_content or text_content.isspace(): return - return model_instance.invoke_tts( - content_text=text_content.strip(), user="responding_tts", tenant_id=tenant_id, voice=voice - ) + return model_instance.invoke_tts(content_text=text_content.strip(), voice=voice) def _process_future( @@ -62,7 +61,7 @@ class AppGeneratorTTSPublisher: self._audio_queue: queue.Queue[AudioTrunk] = queue.Queue() self._msg_queue: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue() self.match = re.compile(r"[。.!?]") - self.model_manager = ModelManager() + self.model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id, user_id="responding_tts") self.model_instance = self.model_manager.get_default_model_instance( tenant_id=self.tenant_id, model_type=ModelType.TTS ) @@ -89,7 +88,7 @@ class AppGeneratorTTSPublisher: if message is None: if self.msg_text and len(self.msg_text.strip()) > 0: futures_result = self.executor.submit( - _invoice_tts, self.msg_text, self.model_instance, self.tenant_id, self.voice + _invoice_tts, self.msg_text, self.model_instance, self.voice ) future_queue.put(futures_result) break @@ -117,9 +116,7 @@ class AppGeneratorTTSPublisher: if len(sentence_arr) >= min(self.max_sentence, 7): self.max_sentence += 1 text_content = "".join(sentence_arr) - futures_result = self.executor.submit( - _invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice - ) + futures_result = self.executor.submit(_invoice_tts, text_content, self.model_instance, self.voice) future_queue.put(futures_result) if isinstance(text_tmp, str): self.msg_text = text_tmp diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 8de5cb16900..6a071192447 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -1,7 +1,7 @@ import logging from collections.abc import Sequence -from sqlalchemy import select +from sqlalchemy import select, update from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom @@ -70,23 +70,21 @@ class DatasetIndexToolCallbackHandler: ) child_chunk = db.session.scalar(child_chunk_stmt) if child_chunk: - _ = ( - db.session.query(DocumentSegment) + db.session.execute( + update(DocumentSegment) .where(DocumentSegment.id == child_chunk.segment_id) - .update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False - ) + .values(hit_count=DocumentSegment.hit_count + 1) ) else: - query = db.session.query(DocumentSegment).where( - DocumentSegment.index_node_id == document.metadata["doc_id"] - ) + conditions = [DocumentSegment.index_node_id == document.metadata["doc_id"]] if "dataset_id" in document.metadata: - query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"]) + conditions.append(DocumentSegment.dataset_id == document.metadata["dataset_id"]) # add hit count to document segment - query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) + db.session.execute( + update(DocumentSegment).where(*conditions).values(hit_count=DocumentSegment.hit_count + 1) + ) db.session.commit() diff --git a/api/core/datasource/datasource_file_manager.py b/api/core/datasource/datasource_file_manager.py index 24243add17b..fe40d8f0e58 100644 --- a/api/core/datasource/datasource_file_manager.py +++ b/api/core/datasource/datasource_file_manager.py @@ -214,6 +214,6 @@ class DatasourceFileManager: # init tool_file_parser -# from dify_graph.file.datasource_file_parser import datasource_file_manager +# from graphon.file.datasource_file_parser import datasource_file_manager # # datasource_file_manager["manager"] = DatasourceFileManager diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py index 4fa941ae164..143d1e696bf 100644 --- a/api/core/datasource/datasource_manager.py +++ b/api/core/datasource/datasource_manager.py @@ -3,9 +3,13 @@ from collections.abc import Generator from threading import Lock from typing import Any, cast +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType, get_file_type_by_mime_type +from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent from sqlalchemy import select import contexts +from core.app.file_access import DatabaseFileAccessController from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.datasource.entities.datasource_entities import ( @@ -24,18 +28,15 @@ from core.datasource.utils.message_transformer import DatasourceFileMessageTrans from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController from core.db.session_factory import session_factory from core.plugin.impl.datasource import PluginDatasourceManager +from core.workflow.file_reference import build_file_reference from core.workflow.nodes.datasource.entities import DatasourceParameter, OnlineDriveDownloadFileParam -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import WorkflowNodeExecutionMetadataKey -from dify_graph.file import File -from dify_graph.file.enums import FileTransferMethod, FileType -from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent from factories import file_factory from models.model import UploadFile from models.tools import ToolFile from services.datasource_provider_service import DatasourceProviderService logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() class DatasourceManager: @@ -279,11 +280,15 @@ class DatasourceManager: if datasource_file is not None: mapping = { "tool_file_id": datasource_file_id, - "type": file_factory.get_file_type_by_mime_type(mime_type), + "type": get_file_type_by_mime_type(mime_type), "transfer_method": FileTransferMethod.TOOL_FILE, "url": url, } - file_out = file_factory.build_from_mapping(mapping=mapping, tenant_id=tenant_id) + file_out = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + access_controller=_file_access_controller, + ) elif mtype == DatasourceMessage.MessageType.TEXT: assert isinstance(message.message, DatasourceMessage.TextMessage) yield StreamChunkEvent(selector=[node_id, "text"], chunk=message.message.text, is_final=False) @@ -351,11 +356,10 @@ class DatasourceManager: filename=upload_file.name, extension="." + upload_file.extension, mime_type=upload_file.mime_type, - tenant_id=tenant_id, type=FileType.CUSTOM, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, - related_id=upload_file.id, + reference=build_file_reference(record_id=str(upload_file.id)), size=upload_file.size, storage_key=upload_file.key, url=upload_file.source_url, diff --git a/api/core/datasource/entities/api_entities.py b/api/core/datasource/entities/api_entities.py index 4c9ff64479b..14d1af2e8b4 100644 --- a/api/core/datasource/entities/api_entities.py +++ b/api/core/datasource/entities/api_entities.py @@ -1,10 +1,10 @@ from typing import Literal, Optional +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from core.datasource.entities.datasource_entities import DatasourceParameter from core.tools.entities.common_entities import I18nObject -from dify_graph.model_runtime.utils.encoders import jsonable_encoder class DatasourceApiEntity(BaseModel): diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py index 2881888e27b..04f15dee31d 100644 --- a/api/core/datasource/utils/message_transformer.py +++ b/api/core/datasource/utils/message_transformer.py @@ -2,9 +2,11 @@ import logging from collections.abc import Generator from mimetypes import guess_extension, guess_type +from graphon.file import File, FileTransferMethod, FileType + from core.datasource.entities.datasource_entities import DatasourceMessage from core.tools.tool_file_manager import ToolFileManager -from dify_graph.file import File, FileTransferMethod, FileType +from core.workflow.file_reference import parse_file_reference from models.tools import ToolFile logger = logging.getLogger(__name__) @@ -103,8 +105,14 @@ class DatasourceFileMessageTransformer: file: File | None = meta.get("file") if isinstance(file, File): if file.transfer_method == FileTransferMethod.TOOL_FILE: - assert file.related_id is not None - url = cls.get_datasource_file_url(datasource_file_id=file.related_id, extension=file.extension) + reference = getattr(file, "reference", None) or getattr(file, "related_id", None) + parsed_reference = parse_file_reference(reference) if isinstance(reference, str) else None + if parsed_reference is None: + raise ValueError("datasource file is missing reference") + url = cls.get_datasource_file_url( + datasource_file_id=parsed_reference.record_id, + extension=file.extension, + ) if file.type == FileType.IMAGE: yield DatasourceMessage( type=DatasourceMessage.MessageType.IMAGE_LINK, diff --git a/api/core/entities/embedding_type.py b/api/core/entities/embedding_type.py index 89b48fd2efa..f49cbf9ffe3 100644 --- a/api/core/entities/embedding_type.py +++ b/api/core/entities/embedding_type.py @@ -1,10 +1,5 @@ -from enum import StrEnum, auto +"""Compatibility wrapper for the runtime embedding input enum.""" +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType -class EmbeddingInputType(StrEnum): - """ - Enum for embedding input type. - """ - - DOCUMENT = auto() - QUERY = auto() +__all__ = ["EmbeddingInputType"] diff --git a/api/core/entities/execution_extra_content.py b/api/core/entities/execution_extra_content.py index 1343bd8e82b..72f6590e683 100644 --- a/api/core/entities/execution_extra_content.py +++ b/api/core/entities/execution_extra_content.py @@ -3,9 +3,9 @@ from __future__ import annotations from collections.abc import Mapping, Sequence from typing import Any, TypeAlias +from graphon.nodes.human_input.entities import FormInput, UserAction from pydantic import BaseModel, ConfigDict, Field -from dify_graph.nodes.human_input.entities import FormInput, UserAction from models.execution_extra_content import ExecutionContentType diff --git a/api/core/entities/mcp_provider.py b/api/core/entities/mcp_provider.py index d214652e9c4..a440829b46b 100644 --- a/api/core/entities/mcp_provider.py +++ b/api/core/entities/mcp_provider.py @@ -6,6 +6,7 @@ from enum import StrEnum from typing import TYPE_CHECKING, Any from urllib.parse import urlparse +from graphon.file import helpers as file_helpers from pydantic import BaseModel from configs import dify_config @@ -15,7 +16,6 @@ from core.helper.provider_cache import NoOpProviderCredentialCache from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType -from dify_graph.file import helpers as file_helpers if TYPE_CHECKING: from models.tools import MCPToolProvider diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 3427fc54b14..84d95c38c6c 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -1,12 +1,11 @@ from collections.abc import Sequence from enum import StrEnum, auto +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType, ProviderModel +from graphon.model_runtime.entities.provider_entities import ProviderEntity from pydantic import BaseModel, ConfigDict -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import ModelType, ProviderModel -from dify_graph.model_runtime.entities.provider_entities import ProviderEntity - class ModelStatus(StrEnum): """ diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index a9f2300ba2c..8b48aa2660e 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import logging import re @@ -5,7 +7,17 @@ from collections import defaultdict from collections.abc import Iterator, Sequence from json import JSONDecodeError -from pydantic import BaseModel, ConfigDict, Field, model_validator +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FormType, + ProviderEntity, +) +from graphon.model_runtime.model_providers.__base.ai_model import AIModel +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from graphon.model_runtime.runtime import ModelRuntime +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator from sqlalchemy import func, select from sqlalchemy.orm import Session @@ -19,15 +31,7 @@ from core.entities.provider_entities import ( ) from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from dify_graph.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FormType, - ProviderEntity, -) -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory from libs.datetime_utils import naive_utc_now from models.engine import db from models.enums import CredentialSourceType @@ -60,6 +64,10 @@ class ProviderConfiguration(BaseModel): - Load balancing configurations - Model enablement/disablement + Request flows can bind a pre-scoped runtime via ``bind_model_runtime()`` so + nested schema and model lookups reuse the caller scope that was already + resolved by the composition layer. + TODO: lots of logic in a BaseModel entity should be separated, the exceptions should be classified """ @@ -73,6 +81,7 @@ class ProviderConfiguration(BaseModel): # pydantic configs model_config = ConfigDict(protected_namespaces=()) + _bound_model_runtime: ModelRuntime | None = PrivateAttr(default=None) @model_validator(mode="after") def _(self): @@ -92,6 +101,16 @@ class ProviderConfiguration(BaseModel): self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL) return self + def bind_model_runtime(self, model_runtime: ModelRuntime) -> None: + """Attach the already-composed runtime for request-bound call chains.""" + self._bound_model_runtime = model_runtime + + def get_model_provider_factory(self) -> ModelProviderFactory: + """Return a provider factory that preserves any request-bound runtime.""" + if self._bound_model_runtime is not None: + return ModelProviderFactory(model_runtime=self._bound_model_runtime) + return create_plugin_model_provider_factory(tenant_id=self.tenant_id) + def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None: """ Get current credentials. @@ -343,7 +362,7 @@ class ProviderConfiguration(BaseModel): tenant_id=self.tenant_id, token=original_credentials[key] ) - model_provider_factory = ModelProviderFactory(self.tenant_id) + model_provider_factory = self.get_model_provider_factory() validated_credentials = model_provider_factory.provider_credentials_validate( provider=self.provider.provider, credentials=credentials ) @@ -902,7 +921,7 @@ class ProviderConfiguration(BaseModel): tenant_id=self.tenant_id, token=original_credentials[key] ) - model_provider_factory = ModelProviderFactory(self.tenant_id) + model_provider_factory = self.get_model_provider_factory() validated_credentials = model_provider_factory.model_credentials_validate( provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials ) @@ -1388,7 +1407,7 @@ class ProviderConfiguration(BaseModel): :param model_type: model type :return: """ - model_provider_factory = ModelProviderFactory(self.tenant_id) + model_provider_factory = self.get_model_provider_factory() # Get model instance of LLM return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type) @@ -1397,7 +1416,7 @@ class ProviderConfiguration(BaseModel): """ Get model schema """ - model_provider_factory = ModelProviderFactory(self.tenant_id) + model_provider_factory = self.get_model_provider_factory() return model_provider_factory.get_model_schema( provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials ) @@ -1499,7 +1518,7 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - model_provider_factory = ModelProviderFactory(self.tenant_id) + model_provider_factory = self.get_model_provider_factory() provider_schema = model_provider_factory.get_provider_schema(self.provider.provider) model_types: list[ModelType] = [] diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index a830f227a9b..2c8767a32b8 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -3,6 +3,7 @@ from __future__ import annotations from enum import StrEnum, auto from typing import Union +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, ConfigDict, Field from core.entities.parameter_entities import ( @@ -12,7 +13,6 @@ from core.entities.parameter_entities import ( ToolSelectorScope, ) from core.tools.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import ModelType class ProviderQuotaType(StrEnum): diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 4251cfd30b0..35bfcfb6a5c 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -4,6 +4,7 @@ from threading import Lock from typing import Any import httpx +from graphon.nodes.code.entities import CodeLanguage from pydantic import BaseModel from yarl import URL @@ -13,7 +14,6 @@ from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTr from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer from core.helper.code_executor.template_transformer import TemplateTransformer from core.helper.http_client_pooling import get_pooled_http_client -from dify_graph.nodes.code.entities import CodeLanguage logger = logging.getLogger(__name__) code_execution_endpoint_url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index c569e066f4e..b96a9ce3808 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -5,7 +5,7 @@ from base64 import b64encode from collections.abc import Mapping from typing import Any -from dify_graph.variables.utils import dumps_with_segments +from graphon.variables.utils import dumps_with_segments class TemplateTransformer(ABC): diff --git a/api/core/helper/encrypter.py b/api/core/helper/encrypter.py index 17345dc203b..20125ec6b30 100644 --- a/api/core/helper/encrypter.py +++ b/api/core/helper/encrypter.py @@ -19,7 +19,7 @@ def encrypt_token(tenant_id: str, token: str): from extensions.ext_database import db from models.account import Tenant - if not (tenant := db.session.query(Tenant).where(Tenant.id == tenant_id).first()): + if not (tenant := db.session.get(Tenant, tenant_id)): raise ValueError(f"Tenant with id {tenant_id} not found") assert tenant.encrypt_public_key is not None encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key) diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index 873f6a40930..a1e782a094e 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -2,12 +2,13 @@ import logging import secrets from typing import cast +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeBadRequestError +from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities import DEFAULT_PLUGIN_ID -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError -from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory from extensions.ext_hosting_provider import hosting_configuration from models.provider import ProviderType @@ -41,7 +42,7 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt text_chunk = secrets.choice(text_chunks) try: - model_provider_factory = ModelProviderFactory(tenant_id) + model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id) # Get model instance of LLM model_type_instance = model_provider_factory.get_model_type_instance( diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index 600a4443578..60f5434bc1e 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -1,10 +1,10 @@ from flask import Flask +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel from configs import dify_config from core.entities import DEFAULT_PLUGIN_ID from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, RestrictModel -from dify_graph.model_runtime.entities.model_entities import ModelType class HostingQuota(BaseModel): diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 52776ee626f..3ec17bc9864 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -9,6 +9,7 @@ from collections.abc import Mapping from typing import Any from flask import Flask, current_app +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import select from sqlalchemy.orm.exc import ObjectDeletedError @@ -21,7 +22,7 @@ from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import ChildDocument, Document @@ -31,7 +32,6 @@ from core.rag.splitter.fixed_text_splitter import ( ) from core.rag.splitter.text_splitter import TextSplitter from core.tools.utils.web_reader_tool import get_image_upload_file_ids -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage @@ -50,7 +50,10 @@ logger = logging.getLogger(__name__) class IndexingRunner: def __init__(self): self.storage = storage - self.model_manager = ModelManager() + + @staticmethod + def _get_model_manager(tenant_id: str) -> ModelManager: + return ModelManager.for_tenant(tenant_id=tenant_id) def _handle_indexing_error(self, document_id: str, error: Exception) -> None: """Handle indexing errors by updating document status.""" @@ -271,7 +274,7 @@ class IndexingRunner: doc_form: str | None = None, doc_language: str = "English", dataset_id: str | None = None, - indexing_technique: str = "economy", + indexing_technique: str = IndexTechniqueType.ECONOMY, ) -> IndexingEstimate: """ Estimate the indexing for the document. @@ -289,22 +292,22 @@ class IndexingRunner: dataset = db.session.query(Dataset).filter_by(id=dataset_id).first() if not dataset: raise ValueError("Dataset not found.") - if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality": + if IndexTechniqueType.HIGH_QUALITY in {dataset.indexing_technique, indexing_technique}: if dataset.embedding_model_provider: - embedding_model_instance = self.model_manager.get_model_instance( + embedding_model_instance = self._get_model_manager(tenant_id).get_model_instance( tenant_id=tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, ) else: - embedding_model_instance = self.model_manager.get_default_model_instance( + embedding_model_instance = self._get_model_manager(tenant_id).get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING, ) else: - if indexing_technique == "high_quality": - embedding_model_instance = self.model_manager.get_default_model_instance( + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: + embedding_model_instance = self._get_model_manager(tenant_id).get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING, ) @@ -573,8 +576,8 @@ class IndexingRunner: """ embedding_model_instance = None - if dataset.indexing_technique == "high_quality": - embedding_model_instance = self.model_manager.get_model_instance( + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: + embedding_model_instance = self._get_model_manager(dataset.tenant_id).get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, @@ -587,7 +590,7 @@ class IndexingRunner: create_keyword_thread = None if ( dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX - and dataset.indexing_technique == "economy" + and dataset.indexing_technique == IndexTechniqueType.ECONOMY ): # create keyword index create_keyword_thread = threading.Thread( @@ -597,7 +600,7 @@ class IndexingRunner: create_keyword_thread.start() max_workers = 10 - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [] @@ -628,7 +631,7 @@ class IndexingRunner: tokens += future.result() if ( dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX - and dataset.indexing_technique == "economy" + and dataset.indexing_technique == IndexTechniqueType.ECONOMY and create_keyword_thread is not None ): create_keyword_thread.join() @@ -654,7 +657,7 @@ class IndexingRunner: raise ValueError("no dataset found") keyword = Keyword(dataset) keyword.create(documents) - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: document_ids = [document.metadata["doc_id"] for document in documents] db.session.query(DocumentSegment).where( DocumentSegment.document_id == document_id, @@ -764,16 +767,16 @@ class IndexingRunner: ) -> list[Document]: # get embedding model instance embedding_model_instance = None - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: if dataset.embedding_model_provider: - embedding_model_instance = self.model_manager.get_model_instance( + embedding_model_instance = self._get_model_manager(dataset.tenant_id).get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, ) else: - embedding_model_instance = self.model_manager.get_default_model_instance( + embedding_model_instance = self._get_model_manager(dataset.tenant_id).get_default_model_instance( tenant_id=dataset.tenant_id, model_type=ModelType.TEXT_EMBEDDING, ) diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index c8848336d97..d39630ad951 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -5,6 +5,12 @@ from collections.abc import Sequence from typing import Protocol, cast import json_repair +from graphon.enums import WorkflowNodeExecutionMetadataKey +from graphon.model_runtime.entities.llm_entities import LLMResult +from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from sqlalchemy import select from core.app.app_config.entities import ModelConfig from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload @@ -27,11 +33,6 @@ from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey -from dify_graph.model_runtime.entities.llm_entities import LLMResult -from dify_graph.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from extensions.ext_database import db from extensions.ext_storage import storage from models import App, Message, WorkflowNodeExecutionModel @@ -62,7 +63,7 @@ class LLMGenerator: prompt += query + "\n" - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -120,7 +121,7 @@ class LLMGenerator: prompt = prompt_template.format({"histories": histories, "format_instructions": format_instructions}) try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -172,7 +173,7 @@ class LLMGenerator: prompt_messages = [UserPromptMessage(content=prompt_generate)] - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, @@ -219,7 +220,7 @@ class LLMGenerator: prompt_messages = [UserPromptMessage(content=prompt_generate_prompt)] # get model instance - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -306,7 +307,7 @@ class LLMGenerator: remove_template_variables=False, ) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -337,7 +338,7 @@ class LLMGenerator: def generate_qa_document(cls, tenant_id: str, query, document_language: str): prompt = GENERATOR_QA_PROMPT.format(language=document_language) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -362,7 +363,7 @@ class LLMGenerator: @classmethod def generate_structured_output(cls, tenant_id: str, args: RuleStructuredOutputPayload): - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -410,8 +411,8 @@ class LLMGenerator: model_config: ModelConfig, ideal_output: str | None, ): - last_run: Message | None = ( - db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first() + last_run: Message | None = db.session.scalar( + select(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).limit(1) ) if not last_run: return LLMGenerator.__instruction_modify_common( @@ -536,7 +537,7 @@ class LLMGenerator: injected_instruction = injected_instruction.replace(CURRENT, current or "null") if ERROR_MESSAGE in injected_instruction: injected_instruction = injected_instruction.replace(ERROR_MESSAGE, error_message or "null") - model_instance = ModelManager().get_model_instance( + model_instance = ModelManager.for_tenant(tenant_id=tenant_id).get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py index 77ea1713ea5..a1710f11ace 100644 --- a/api/core/llm_generator/output_parser/structured_output.py +++ b/api/core/llm_generator/output_parser/structured_output.py @@ -5,27 +5,27 @@ from enum import StrEnum from typing import Any, Literal, cast, overload import json_repair -from pydantic import TypeAdapter, ValidationError - -from core.llm_generator.output_parser.errors import OutputParserError -from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT -from core.model_manager import ModelInstance -from dify_graph.model_runtime.callbacks.base_callback import Callback -from dify_graph.model_runtime.entities.llm_entities import ( +from graphon.model_runtime.callbacks.base_callback import Callback +from graphon.model_runtime.entities.llm_entities import ( LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMResultChunkWithStructuredOutput, LLMResultWithStructuredOutput, ) -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, PromptMessageTool, SystemPromptMessage, TextPromptMessageContent, ) -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ParameterRule +from graphon.model_runtime.entities.model_entities import AIModelEntity, ParameterRule +from pydantic import TypeAdapter, ValidationError + +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT +from core.model_manager import ModelInstance class ResponseFormat(StrEnum): @@ -55,7 +55,6 @@ def invoke_llm_with_structured_output( tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: Literal[True], - user: str | None = None, callbacks: list[Callback] | None = None, ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ... @overload @@ -70,7 +69,6 @@ def invoke_llm_with_structured_output( tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: Literal[False], - user: str | None = None, callbacks: list[Callback] | None = None, ) -> LLMResultWithStructuredOutput: ... @overload @@ -85,7 +83,6 @@ def invoke_llm_with_structured_output( tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ... def invoke_llm_with_structured_output( @@ -99,7 +96,6 @@ def invoke_llm_with_structured_output( tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: """ @@ -113,7 +109,6 @@ def invoke_llm_with_structured_output( :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id :param callbacks: callbacks :return: full response or stream response chunk generator result """ @@ -143,7 +138,6 @@ def invoke_llm_with_structured_output( tools=tools, stop=stop, stream=stream, - user=user, callbacks=callbacks, ) diff --git a/api/core/mcp/server/streamable_http.py b/api/core/mcp/server/streamable_http.py index de68eb268b6..27000c947c1 100644 --- a/api/core/mcp/server/streamable_http.py +++ b/api/core/mcp/server/streamable_http.py @@ -3,11 +3,12 @@ import logging from collections.abc import Mapping from typing import Any, cast +from graphon.variables.input_entities import VariableEntity, VariableEntityType + from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from core.mcp import types as mcp_types -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType from models.model import App, AppMCPServer, AppMode, EndUser from services.app_generate_service import AppGenerateService diff --git a/api/core/mcp/utils.py b/api/core/mcp/utils.py index db9cb726d7c..7e350441768 100644 --- a/api/core/mcp/utils.py +++ b/api/core/mcp/utils.py @@ -4,11 +4,11 @@ from contextlib import AbstractContextManager import httpx import httpx_sse +from graphon.model_runtime.utils.encoders import jsonable_encoder from httpx_sse import connect_sse from configs import dify_config from core.mcp.types import ErrorData, JSONRPCError -from dify_graph.model_runtime.utils.encoders import jsonable_encoder HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 1156a98af17..09c84538a9a 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -1,13 +1,7 @@ from collections.abc import Sequence -from sqlalchemy import select -from sqlalchemy.orm import sessionmaker - -from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.model_manager import ModelInstance -from core.prompt.utils.extract_thread_messages import extract_thread_messages -from dify_graph.file import file_manager -from dify_graph.model_runtime.entities import ( +from graphon.file import file_manager +from graphon.model_runtime.entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessage, @@ -15,7 +9,14 @@ from dify_graph.model_runtime.entities import ( TextPromptMessageContent, UserPromptMessage, ) -from dify_graph.model_runtime.entities.message_entities import PromptMessageContentUnionTypes +from graphon.model_runtime.entities.message_entities import PromptMessageContentUnionTypes +from sqlalchemy import select +from sqlalchemy.orm import sessionmaker + +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.file_access import DatabaseFileAccessController +from core.model_manager import ModelInstance +from core.prompt.utils.extract_thread_messages import extract_thread_messages from extensions.ext_database import db from factories import file_factory from models.model import AppMode, Conversation, Message, MessageFile @@ -23,6 +24,8 @@ from models.workflow import Workflow from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.factory import DifyAPIRepositoryFactory +_file_access_controller = DatabaseFileAccessController() + class TokenBufferMemory: def __init__( @@ -85,7 +88,10 @@ class TokenBufferMemory: # Build files directly without filtering by belongs_to file_objs = [ file_factory.build_from_message_file( - message_file=message_file, tenant_id=app_record.tenant_id, config=file_extra_config + message_file=message_file, + tenant_id=app_record.tenant_id, + config=file_extra_config, + access_controller=_file_access_controller, ) for message_file in message_files ] diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 0f710a8fcf4..87d1d7fba60 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -2,25 +2,27 @@ import logging from collections.abc import Callable, Generator, Iterable, Mapping, Sequence from typing import IO, Any, Literal, Optional, Union, cast, overload +from graphon.model_runtime.callbacks.base_callback import Callback +from graphon.model_runtime.entities.llm_entities import LLMResult +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelType +from graphon.model_runtime.entities.rerank_entities import RerankResult +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel +from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel +from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel +from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from graphon.model_runtime.model_providers.__base.tts_model import TTSModel + from configs import dify_config from core.entities.embedding_type import EmbeddingInputType from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.errors.error import ProviderTokenNotInitError +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.provider_manager import ProviderManager -from dify_graph.model_runtime.callbacks.base_callback import Callback -from dify_graph.model_runtime.entities.llm_entities import LLMResult -from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType -from dify_graph.model_runtime.entities.rerank_entities import RerankResult -from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel -from dify_graph.model_runtime.model_providers.__base.rerank_model import RerankModel -from dify_graph.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from dify_graph.model_runtime.model_providers.__base.tts_model import TTSModel from extensions.ext_redis import redis_client from models.provider import ProviderType from services.enterprise.plugin_manager_service import PluginCredentialType @@ -30,7 +32,7 @@ logger = logging.getLogger(__name__) class ModelInstance: """ - Model instance class + Model instance class. """ def __init__(self, provider_model_bundle: ProviderModelBundle, model: str): @@ -49,6 +51,13 @@ class ModelInstance: credentials=self.credentials, ) + def get_model_schema(self) -> AIModelEntity: + """Return the resolved schema for the current model instance.""" + model_schema = self.model_type_instance.get_model_schema(self.model_name, self.credentials) + if model_schema is None: + raise ValueError(f"model schema not found for {self.model_name}") + return model_schema + @staticmethod def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str): """ @@ -110,7 +119,6 @@ class ModelInstance: tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: Literal[True] = True, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> Generator: ... @@ -122,7 +130,6 @@ class ModelInstance: tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: Literal[False] = False, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> LLMResult: ... @@ -134,7 +141,6 @@ class ModelInstance: tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> Union[LLMResult, Generator]: ... @@ -145,7 +151,6 @@ class ModelInstance: tools: Sequence[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> Union[LLMResult, Generator]: """ @@ -156,7 +161,6 @@ class ModelInstance: :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id :param callbacks: callbacks :return: full response or stream response chunk generator result """ @@ -173,7 +177,6 @@ class ModelInstance: tools=tools, stop=stop, stream=stream, - user=user, callbacks=callbacks, ), ) @@ -202,13 +205,12 @@ class ModelInstance: ) def invoke_text_embedding( - self, texts: list[str], user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT + self, texts: list[str], input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT ) -> EmbeddingResult: """ Invoke large language model :param texts: texts to embed - :param user: unique user id :param input_type: input type :return: embeddings result """ @@ -221,7 +223,6 @@ class ModelInstance: model=self.model_name, credentials=self.credentials, texts=texts, - user=user, input_type=input_type, ), ) @@ -229,14 +230,12 @@ class ModelInstance: def invoke_multimodal_embedding( self, multimodel_documents: list[dict], - user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, ) -> EmbeddingResult: """ Invoke large language model :param multimodel_documents: multimodel documents to embed - :param user: unique user id :param input_type: input type :return: embeddings result """ @@ -249,7 +248,6 @@ class ModelInstance: model=self.model_name, credentials=self.credentials, multimodel_documents=multimodel_documents, - user=user, input_type=input_type, ), ) @@ -279,7 +277,6 @@ class ModelInstance: docs: list[str], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, ) -> RerankResult: """ Invoke rerank model @@ -288,7 +285,6 @@ class ModelInstance: :param docs: docs for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id :return: rerank result """ if not isinstance(self.model_type_instance, RerankModel): @@ -303,7 +299,6 @@ class ModelInstance: docs=docs, score_threshold=score_threshold, top_n=top_n, - user=user, ), ) @@ -313,7 +308,6 @@ class ModelInstance: docs: list[dict], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, ) -> RerankResult: """ Invoke rerank model @@ -322,7 +316,6 @@ class ModelInstance: :param docs: docs for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id :return: rerank result """ if not isinstance(self.model_type_instance, RerankModel): @@ -337,16 +330,14 @@ class ModelInstance: docs=docs, score_threshold=score_threshold, top_n=top_n, - user=user, ), ) - def invoke_moderation(self, text: str, user: str | None = None) -> bool: + def invoke_moderation(self, text: str) -> bool: """ Invoke moderation model :param text: text to moderate - :param user: unique user id :return: false if text is safe, true otherwise """ if not isinstance(self.model_type_instance, ModerationModel): @@ -358,16 +349,14 @@ class ModelInstance: model=self.model_name, credentials=self.credentials, text=text, - user=user, ), ) - def invoke_speech2text(self, file: IO[bytes], user: str | None = None) -> str: + def invoke_speech2text(self, file: IO[bytes]) -> str: """ Invoke large language model :param file: audio file - :param user: unique user id :return: text for given audio file """ if not isinstance(self.model_type_instance, Speech2TextModel): @@ -379,18 +368,15 @@ class ModelInstance: model=self.model_name, credentials=self.credentials, file=file, - user=user, ), ) - def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: str | None = None) -> Iterable[bytes]: + def invoke_tts(self, content_text: str, voice: str = "") -> Iterable[bytes]: """ Invoke large language tts model :param content_text: text content to be translated - :param tenant_id: user tenant id :param voice: model timbre - :param user: unique user id :return: text for given audio file """ if not isinstance(self.model_type_instance, TTSModel): @@ -402,8 +388,6 @@ class ModelInstance: model=self.model_name, credentials=self.credentials, content_text=content_text, - user=user, - tenant_id=tenant_id, voice=voice, ), ) @@ -477,10 +461,20 @@ class ModelInstance: class ModelManager: - def __init__(self): - self._provider_manager = ProviderManager() + def __init__(self, provider_manager: ProviderManager): + self._provider_manager = provider_manager - def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance: + @classmethod + def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager": + return cls(provider_manager=create_plugin_provider_manager(tenant_id=tenant_id, user_id=user_id)) + + def get_model_instance( + self, + tenant_id: str, + provider: str, + model_type: ModelType, + model: str, + ) -> ModelInstance: """ Get model instance :param tenant_id: tenant id @@ -496,7 +490,8 @@ class ModelManager: tenant_id=tenant_id, provider=provider, model_type=model_type ) - return ModelInstance(provider_model_bundle, model) + model_instance = ModelInstance(provider_model_bundle, model) + return model_instance def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]: """ diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py index 06676f5cf44..dd038c77f13 100644 --- a/api/core/moderation/openai_moderation/openai_moderation.py +++ b/api/core/moderation/openai_moderation/openai_moderation.py @@ -1,6 +1,7 @@ +from graphon.model_runtime.entities.model_entities import ModelType + from core.model_manager import ModelManager from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult -from dify_graph.model_runtime.entities.model_entities import ModelType class OpenAIModeration(Moderation): @@ -50,7 +51,7 @@ class OpenAIModeration(Moderation): def _is_violated(self, inputs: dict): text = "\n".join(str(inputs.values())) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id) model_instance = model_manager.get_model_instance( tenant_id=self.tenant_id, provider="openai", model_type=ModelType.MODERATION, model="omni-moderation-latest" ) diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py index 18f35b5b9c6..70aaf2a07be 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/core/ops/aliyun_trace/aliyun_trace.py @@ -1,6 +1,8 @@ import logging from collections.abc import Sequence +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from opentelemetry.trace import SpanKind from sqlalchemy.orm import sessionmaker @@ -57,8 +59,6 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.repositories import DifyCoreRepositoryFactory -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from extensions.ext_database import db from models import WorkflowNodeExecutionTriggeredFrom @@ -296,7 +296,9 @@ class AliyunDataTrace(BaseTraceInstance): triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) - return workflow_node_execution_repository.get_by_workflow_run(workflow_run_id=trace_info.workflow_run_id) + return workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id + ) def build_workflow_node_span( self, node_execution: WorkflowNodeExecution, trace_info: WorkflowTraceInfo, trace_metadata: TraceMetadata diff --git a/api/core/ops/aliyun_trace/data_exporter/traceclient.py b/api/core/ops/aliyun_trace/data_exporter/traceclient.py index 0e00e905200..67d5163b0f4 100644 --- a/api/core/ops/aliyun_trace/data_exporter/traceclient.py +++ b/api/core/ops/aliyun_trace/data_exporter/traceclient.py @@ -16,7 +16,13 @@ from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExport from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import ReadableSpan from opentelemetry.sdk.util.instrumentation import InstrumentationScope -from opentelemetry.semconv.resource import ResourceAttributes +from opentelemetry.semconv._incubating.attributes.deployment_attributes import ( # type: ignore[import-untyped] + DEPLOYMENT_ENVIRONMENT, +) +from opentelemetry.semconv._incubating.attributes.host_attributes import ( # type: ignore[import-untyped] + HOST_NAME, +) +from opentelemetry.semconv.attributes import service_attributes from opentelemetry.trace import Link, SpanContext, TraceFlags from configs import dify_config @@ -45,10 +51,10 @@ class TraceClient: self.endpoint = endpoint self.resource = Resource( attributes={ - ResourceAttributes.SERVICE_NAME: service_name, - ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}", - ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}", - ResourceAttributes.HOST_NAME: socket.gethostname(), + service_attributes.SERVICE_NAME: service_name, + service_attributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}", + DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}", + HOST_NAME: socket.gethostname(), ACS_ARMS_SERVICE_FEATURE: "genai_app", } ) diff --git a/api/core/ops/aliyun_trace/utils.py b/api/core/ops/aliyun_trace/utils.py index 45319f24c1a..d8e105d6a32 100644 --- a/api/core/ops/aliyun_trace/utils.py +++ b/api/core/ops/aliyun_trace/utils.py @@ -2,6 +2,8 @@ import json from collections.abc import Mapping from typing import Any +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionStatus from opentelemetry.trace import Link, Status, StatusCode from core.ops.aliyun_trace.entities.semconv import ( @@ -14,8 +16,6 @@ from core.ops.aliyun_trace.entities.semconv import ( GenAISpanKind, ) from core.rag.models.document import Document -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import WorkflowNodeExecutionStatus from extensions.ext_database import db from models import EndUser @@ -27,9 +27,7 @@ DEFAULT_FRAMEWORK_NAME = "dify" def get_user_id_from_message_data(message_data) -> str: user_id = message_data.from_account_id if message_data.from_end_user_id: - end_user_data: EndUser | None = ( - db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() - ) + end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id) if end_user_data is not None: user_id = end_user_data.session_id return user_id diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index 7cb54b2c884..902f58e6b7b 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -6,6 +6,7 @@ from datetime import datetime, timedelta from typing import Any, Union, cast from urllib.parse import urlparse +from graphon.enums import WorkflowNodeExecutionStatus from openinference.semconv.trace import ( MessageAttributes, OpenInferenceMimeTypeValues, @@ -18,7 +19,7 @@ from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExport from opentelemetry.sdk import trace as trace_sdk from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace.export import SimpleSpanProcessor -from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes +from opentelemetry.semconv.attributes import exception_attributes from opentelemetry.trace import Span, Status, StatusCode, set_span_in_context, use_span from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from opentelemetry.util.types import AttributeValue @@ -133,10 +134,10 @@ def set_span_status(current_span: Span, error: Exception | str | None = None): if not exception_message: exception_message = repr(error) attributes: dict[str, AttributeValue] = { - OTELSpanAttributes.EXCEPTION_TYPE: exception_type, - OTELSpanAttributes.EXCEPTION_MESSAGE: exception_message, - OTELSpanAttributes.EXCEPTION_ESCAPED: False, - OTELSpanAttributes.EXCEPTION_STACKTRACE: error_string, + exception_attributes.EXCEPTION_TYPE: exception_type, + exception_attributes.EXCEPTION_MESSAGE: exception_message, + exception_attributes.EXCEPTION_ESCAPED: False, + exception_attributes.EXCEPTION_STACKTRACE: error_string, } current_span.add_event(name="exception", attributes=attributes) else: @@ -181,10 +182,6 @@ class ArizePhoenixDataTrace(BaseTraceInstance): arize_phoenix_config: ArizeConfig | PhoenixConfig, ): super().__init__(arize_phoenix_config) - import logging - - logging.basicConfig() - logging.getLogger().setLevel(logging.DEBUG) self.arize_phoenix_config = arize_phoenix_config self.tracer, self.processor = setup_tracer(arize_phoenix_config) self.project = arize_phoenix_config.project @@ -275,8 +272,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance): ) # Get all executions for this workflow run - workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( - workflow_run_id=trace_info.workflow_run_id + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id ) try: @@ -304,7 +301,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): "app_name": node_execution.title, "status": node_execution.status, "status_message": node_execution.error or "", - "level": "ERROR" if node_execution.status == "failed" else "DEFAULT", + "level": "ERROR" if node_execution.status == WorkflowNodeExecutionStatus.FAILED else "DEFAULT", } ) @@ -365,7 +362,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): llm_attributes.update(self._construct_llm_attributes(process_data.get("prompts", []))) node_span.set_attributes(llm_attributes) finally: - if node_execution.status == "failed": + if node_execution.status == WorkflowNodeExecutionStatus.FAILED: set_span_status(node_span, node_execution.error) else: set_span_status(node_span) @@ -413,9 +410,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): # Add end user data if available if trace_info.message_data.from_end_user_id: - end_user_data: EndUser | None = ( - db.session.query(EndUser).where(EndUser.id == trace_info.message_data.from_end_user_id).first() - ) + end_user_data: EndUser | None = db.session.get(EndUser, trace_info.message_data.from_end_user_id) if end_user_data is not None: metadata["end_user_id"] = end_user_data.session_id diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index 50a2cdea633..45b2f635bae 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -9,8 +9,8 @@ from pydantic import BaseModel, ConfigDict, field_serializer, field_validator class BaseTraceInfo(BaseModel): message_id: str | None = None message_data: Any | None = None - inputs: Union[str, dict[str, Any], list] | None = None - outputs: Union[str, dict[str, Any], list] | None = None + inputs: Union[str, dict[str, Any], list[Any]] | None = None + outputs: Union[str, dict[str, Any], list[Any]] | None = None start_time: datetime | None = None end_time: datetime | None = None metadata: dict[str, Any] @@ -18,7 +18,7 @@ class BaseTraceInfo(BaseModel): @field_validator("inputs", "outputs") @classmethod - def ensure_type(cls, v): + def ensure_type(cls, v: str | dict[str, Any] | list[Any] | None) -> str | dict[str, Any] | list[Any] | None: if v is None: return None if isinstance(v, str | dict | list): @@ -27,6 +27,48 @@ class BaseTraceInfo(BaseModel): model_config = ConfigDict(protected_namespaces=()) + @property + def resolved_trace_id(self) -> str | None: + """Get trace_id with intelligent fallback. + + Priority: + 1. External trace_id (from X-Trace-Id header) + 2. workflow_run_id (if this trace type has it) + 3. message_id (as final fallback) + """ + if self.trace_id: + return self.trace_id + + # Try workflow_run_id (only exists on workflow-related traces) + workflow_run_id = getattr(self, "workflow_run_id", None) + if workflow_run_id: + return workflow_run_id + + # Final fallback to message_id + return str(self.message_id) if self.message_id else None + + @property + def resolved_parent_context(self) -> tuple[str | None, str | None]: + """Resolve cross-workflow parent linking from metadata. + + Extracts typed parent IDs from the untyped ``parent_trace_context`` + metadata dict (set by tool_node when invoking nested workflows). + + Returns: + (trace_correlation_override, parent_span_id_source) where + trace_correlation_override is the outer workflow_run_id and + parent_span_id_source is the outer node_execution_id. + """ + parent_ctx = self.metadata.get("parent_trace_context") + if not isinstance(parent_ctx, dict): + return None, None + trace_override = parent_ctx.get("parent_workflow_run_id") + parent_span = parent_ctx.get("parent_node_execution_id") + return ( + trace_override if isinstance(trace_override, str) else None, + parent_span if isinstance(parent_span, str) else None, + ) + @field_serializer("start_time", "end_time") def serialize_datetime(self, dt: datetime | None) -> str | None: if dt is None: @@ -48,7 +90,10 @@ class WorkflowTraceInfo(BaseTraceInfo): workflow_run_version: str error: str | None = None total_tokens: int + prompt_tokens: int | None = None + completion_tokens: int | None = None file_list: list[str] + invoked_by: str | None = None query: str metadata: dict[str, Any] @@ -59,7 +104,7 @@ class MessageTraceInfo(BaseTraceInfo): answer_tokens: int total_tokens: int error: str | None = None - file_list: Union[str, dict[str, Any], list] | None = None + file_list: Union[str, dict[str, Any], list[Any]] | None = None message_file_data: Any | None = None conversation_mode: str gen_ai_server_time_to_first_token: float | None = None @@ -106,7 +151,7 @@ class ToolTraceInfo(BaseTraceInfo): tool_config: dict[str, Any] time_cost: Union[int, float] tool_parameters: dict[str, Any] - file_url: Union[str, None, list] = None + file_url: Union[str, None, list[str]] = None class GenerateNameTraceInfo(BaseTraceInfo): @@ -114,6 +159,79 @@ class GenerateNameTraceInfo(BaseTraceInfo): tenant_id: str +class PromptGenerationTraceInfo(BaseTraceInfo): + """Trace information for prompt generation operations (rule-generate, code-generate, etc.).""" + + tenant_id: str + user_id: str + app_id: str | None = None + + operation_type: str + instruction: str + + prompt_tokens: int + completion_tokens: int + total_tokens: int + + model_provider: str + model_name: str + + latency: float + + total_price: float | None = None + currency: str | None = None + + error: str | None = None + + model_config = ConfigDict(protected_namespaces=()) + + +class WorkflowNodeTraceInfo(BaseTraceInfo): + workflow_id: str + workflow_run_id: str + tenant_id: str + node_execution_id: str + node_id: str + node_type: str + title: str + + status: str + error: str | None = None + elapsed_time: float + + index: int + predecessor_node_id: str | None = None + + total_tokens: int = 0 + total_price: float = 0.0 + currency: str | None = None + + model_provider: str | None = None + model_name: str | None = None + prompt_tokens: int | None = None + completion_tokens: int | None = None + + tool_name: str | None = None + + iteration_id: str | None = None + iteration_index: int | None = None + loop_id: str | None = None + loop_index: int | None = None + parallel_id: str | None = None + + node_inputs: Mapping[str, Any] | None = None + node_outputs: Mapping[str, Any] | None = None + process_data: Mapping[str, Any] | None = None + + invoked_by: str | None = None + + model_config = ConfigDict(protected_namespaces=()) + + +class DraftNodeExecutionTrace(WorkflowNodeTraceInfo): + pass + + class TaskData(BaseModel): app_id: str trace_info_type: str @@ -128,11 +246,31 @@ trace_info_info_map = { "DatasetRetrievalTraceInfo": DatasetRetrievalTraceInfo, "ToolTraceInfo": ToolTraceInfo, "GenerateNameTraceInfo": GenerateNameTraceInfo, + "PromptGenerationTraceInfo": PromptGenerationTraceInfo, + "WorkflowNodeTraceInfo": WorkflowNodeTraceInfo, + "DraftNodeExecutionTrace": DraftNodeExecutionTrace, } +class OperationType(StrEnum): + """Operation type for token metric labels. + + Used as a metric attribute on ``dify.tokens.input`` / ``dify.tokens.output`` + counters so consumers can break down token usage by operation. + """ + + WORKFLOW = "workflow" + NODE_EXECUTION = "node_execution" + MESSAGE = "message" + RULE_GENERATE = "rule_generate" + CODE_GENERATE = "code_generate" + STRUCTURED_OUTPUT = "structured_output" + INSTRUCTION_MODIFY = "instruction_modify" + + class TraceTaskName(StrEnum): CONVERSATION_TRACE = "conversation" + DRAFT_NODE_EXECUTION_TRACE = "draft_node_execution" WORKFLOW_TRACE = "workflow" MESSAGE_TRACE = "message" MODERATION_TRACE = "moderation" @@ -140,4 +278,6 @@ class TraceTaskName(StrEnum): DATASET_RETRIEVAL_TRACE = "dataset_retrieval" TOOL_TRACE = "tool" GENERATE_NAME_TRACE = "generate_conversation_name" + PROMPT_GENERATION_TRACE = "prompt_generation" + NODE_EXECUTION_TRACE = "node_execution" DATASOURCE_TRACE = "datasource" diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index 6e62387a1ff..9be2ce1bdfc 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -1,8 +1,19 @@ import logging import os -from datetime import datetime, timedelta +import uuid +from datetime import UTC, datetime, timedelta +from graphon.enums import BuiltinNodeTypes from langfuse import Langfuse +from langfuse.api import ( + CreateGenerationBody, + CreateSpanBody, + IngestionEvent_GenerationCreate, + IngestionEvent_SpanCreate, + IngestionEvent_TraceCreate, + TraceBody, +) +from langfuse.api.commons.types.usage import Usage from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance @@ -28,7 +39,6 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( ) from core.ops.utils import filter_none_values from core.repositories import DifyCoreRepositoryFactory -from dify_graph.enums import BuiltinNodeTypes from extensions.ext_database import db from models import EndUser, WorkflowNodeExecutionTriggeredFrom from models.enums import MessageStatus @@ -130,8 +140,8 @@ class LangFuseDataTrace(BaseTraceInstance): ) # Get all executions for this workflow run - workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( - workflow_run_id=trace_info.workflow_run_id + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id ) for node_execution in workflow_node_executions: @@ -241,9 +251,7 @@ class LangFuseDataTrace(BaseTraceInstance): user_id = message_data.from_account_id if message_data.from_end_user_id: - end_user_data: EndUser | None = ( - db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() - ) + end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id) if end_user_data is not None: user_id = end_user_data.session_id metadata["user_id"] = user_id @@ -398,18 +406,61 @@ class LangFuseDataTrace(BaseTraceInstance): ) self.add_span(langfuse_span_data=name_generation_span_data) + def _make_event_id(self) -> str: + return str(uuid.uuid4()) + + def _now_iso(self) -> str: + return datetime.now(UTC).isoformat() + def add_trace(self, langfuse_trace_data: LangfuseTrace | None = None): - format_trace_data = filter_none_values(langfuse_trace_data.model_dump()) if langfuse_trace_data else {} + data = filter_none_values(langfuse_trace_data.model_dump()) if langfuse_trace_data else {} try: - self.langfuse_client.trace(**format_trace_data) + body = TraceBody( + id=data.get("id"), + name=data.get("name"), + user_id=data.get("user_id"), + input=data.get("input"), + output=data.get("output"), + metadata=data.get("metadata"), + session_id=data.get("session_id"), + version=data.get("version"), + release=data.get("release"), + tags=data.get("tags"), + public=data.get("public"), + ) + event = IngestionEvent_TraceCreate( + body=body, + id=self._make_event_id(), + timestamp=self._now_iso(), + ) + self.langfuse_client.api.ingestion.batch(batch=[event]) logger.debug("LangFuse Trace created successfully") except Exception as e: raise ValueError(f"LangFuse Failed to create trace: {str(e)}") def add_span(self, langfuse_span_data: LangfuseSpan | None = None): - format_span_data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {} + data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {} try: - self.langfuse_client.span(**format_span_data) + body = CreateSpanBody( + id=data.get("id"), + trace_id=data.get("trace_id"), + name=data.get("name"), + start_time=data.get("start_time"), + end_time=data.get("end_time"), + input=data.get("input"), + output=data.get("output"), + metadata=data.get("metadata"), + level=data.get("level"), + status_message=data.get("status_message"), + parent_observation_id=data.get("parent_observation_id"), + version=data.get("version"), + ) + event = IngestionEvent_SpanCreate( + body=body, + id=self._make_event_id(), + timestamp=self._now_iso(), + ) + self.langfuse_client.api.ingestion.batch(batch=[event]) logger.debug("LangFuse Span created successfully") except Exception as e: raise ValueError(f"LangFuse Failed to create span: {str(e)}") @@ -420,11 +471,45 @@ class LangFuseDataTrace(BaseTraceInstance): span.end(**format_span_data) def add_generation(self, langfuse_generation_data: LangfuseGeneration | None = None): - format_generation_data = ( - filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {} - ) + data = filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {} try: - self.langfuse_client.generation(**format_generation_data) + usage_data = data.pop("usage", None) + usage = None + if usage_data: + usage = Usage( + input=usage_data.get("input", 0) or 0, + output=usage_data.get("output", 0) or 0, + total=usage_data.get("total", 0) or 0, + unit=usage_data.get("unit"), + input_cost=usage_data.get("inputCost"), + output_cost=usage_data.get("outputCost"), + total_cost=usage_data.get("totalCost"), + ) + + body = CreateGenerationBody( + id=data.get("id"), + trace_id=data.get("trace_id"), + name=data.get("name"), + start_time=data.get("start_time"), + end_time=data.get("end_time"), + model=data.get("model"), + model_parameters=data.get("model_parameters"), + input=data.get("input"), + output=data.get("output"), + usage=usage, + metadata=data.get("metadata"), + level=data.get("level"), + status_message=data.get("status_message"), + parent_observation_id=data.get("parent_observation_id"), + version=data.get("version"), + completion_start_time=data.get("completion_start_time"), + ) + event = IngestionEvent_GenerationCreate( + body=body, + id=self._make_event_id(), + timestamp=self._now_iso(), + ) + self.langfuse_client.api.ingestion.batch(batch=[event]) logger.debug("LangFuse Generation created successfully") except Exception as e: raise ValueError(f"LangFuse Failed to create generation: {str(e)}") @@ -445,7 +530,7 @@ class LangFuseDataTrace(BaseTraceInstance): def get_project_key(self): try: - projects = self.langfuse_client.client.projects.get() + projects = self.langfuse_client.api.projects.get() return projects.data[0].id except Exception as e: logger.debug("LangFuse get project key failed: %s", str(e)) diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index 32a0c77fe2a..490c64af84d 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -4,6 +4,7 @@ import uuid from datetime import datetime, timedelta from typing import cast +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from langsmith import Client from langsmith.schemas import RunBase from sqlalchemy.orm import sessionmaker @@ -28,7 +29,6 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( ) from core.ops.utils import filter_none_values, generate_dotted_order from core.repositories import DifyCoreRepositoryFactory -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from extensions.ext_database import db from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom @@ -152,8 +152,8 @@ class LangSmithDataTrace(BaseTraceInstance): ) # Get all executions for this workflow run - workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( - workflow_run_id=trace_info.workflow_run_id + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id ) for node_execution in workflow_node_executions: @@ -259,9 +259,7 @@ class LangSmithDataTrace(BaseTraceInstance): metadata["user_id"] = user_id if message_data.from_end_user_id: - end_user_data: EndUser | None = ( - db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() - ) + end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id) if end_user_data is not None: end_user_id = end_user_data.session_id metadata["end_user_id"] = end_user_id diff --git a/api/core/ops/mlflow_trace/mlflow_trace.py b/api/core/ops/mlflow_trace/mlflow_trace.py index ab4a7650ec8..946d3cdd479 100644 --- a/api/core/ops/mlflow_trace/mlflow_trace.py +++ b/api/core/ops/mlflow_trace/mlflow_trace.py @@ -5,10 +5,12 @@ from datetime import datetime, timedelta from typing import Any, cast import mlflow +from graphon.enums import BuiltinNodeTypes from mlflow.entities import Document, Span, SpanEvent, SpanStatusCode, SpanType from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey, TraceMetadataKey from mlflow.tracing.fluent import start_span_no_context, update_current_trace from mlflow.tracing.provider import detach_span_from_context, set_span_in_context +from sqlalchemy import select from core.ops.base_trace_instance import BaseTraceInstance from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig @@ -23,7 +25,6 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from dify_graph.enums import BuiltinNodeTypes from extensions.ext_database import db from models import EndUser from models.workflow import WorkflowNodeExecutionModel @@ -320,7 +321,7 @@ class MLflowDataTrace(BaseTraceInstance): def _get_message_user_id(self, metadata: dict) -> str | None: if (end_user_id := metadata.get("from_end_user_id")) and ( - end_user_data := db.session.query(EndUser).where(EndUser.id == end_user_id).first() + end_user_data := db.session.get(EndUser, end_user_id) ): return end_user_data.session_id @@ -447,25 +448,11 @@ class MLflowDataTrace(BaseTraceInstance): def _get_workflow_nodes(self, workflow_run_id: str): """Helper method to get workflow nodes""" - workflow_nodes = ( - db.session.query( - WorkflowNodeExecutionModel.id, - WorkflowNodeExecutionModel.tenant_id, - WorkflowNodeExecutionModel.app_id, - WorkflowNodeExecutionModel.title, - WorkflowNodeExecutionModel.node_type, - WorkflowNodeExecutionModel.status, - WorkflowNodeExecutionModel.inputs, - WorkflowNodeExecutionModel.outputs, - WorkflowNodeExecutionModel.created_at, - WorkflowNodeExecutionModel.elapsed_time, - WorkflowNodeExecutionModel.process_data, - WorkflowNodeExecutionModel.execution_metadata, - ) - .filter(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id) + workflow_nodes = db.session.scalars( + select(WorkflowNodeExecutionModel) + .where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id) .order_by(WorkflowNodeExecutionModel.created_at) - .all() - ) + ).all() return workflow_nodes def _get_node_span_type(self, node_type: str) -> str: diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index fb72bc23814..2215bdeb33b 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -5,6 +5,7 @@ import uuid from datetime import datetime, timedelta from typing import cast +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from opik import Opik, Trace from opik.id_helpers import uuid4_to_uuid7 from sqlalchemy.orm import sessionmaker @@ -23,7 +24,6 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.repositories import DifyCoreRepositoryFactory -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from extensions.ext_database import db from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom @@ -176,8 +176,8 @@ class OpikDataTrace(BaseTraceInstance): ) # Get all executions for this workflow run - workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( - workflow_run_id=trace_info.workflow_run_id + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id ) for node_execution in workflow_node_executions: @@ -288,9 +288,7 @@ class OpikDataTrace(BaseTraceInstance): metadata["file_list"] = file_list if message_data.from_end_user_id: - end_user_data: EndUser | None = ( - db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() - ) + end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id) if end_user_data is not None: end_user_id = end_user_data.session_id metadata["end_user_id"] = end_user_id diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 9ac753240bb..9c36d57c6f5 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -15,34 +15,179 @@ from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token -from core.ops.entities.config_entity import OPS_FILE_PATH, TracingProviderEnum +from core.ops.entities.config_entity import ( + OPS_FILE_PATH, + TracingProviderEnum, +) from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, + DraftNodeExecutionTrace, GenerateNameTraceInfo, MessageTraceInfo, ModerationTraceInfo, + PromptGenerationTraceInfo, SuggestedQuestionTraceInfo, TaskData, ToolTraceInfo, TraceTaskName, + WorkflowNodeTraceInfo, WorkflowTraceInfo, ) from core.ops.utils import get_message_data +from extensions.ext_database import db from extensions.ext_storage import storage -from models.engine import db +from models.account import Tenant +from models.dataset import Dataset from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig +from models.provider import Provider, ProviderCredential, ProviderModel, ProviderModelCredential, ProviderType +from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider from models.workflow import WorkflowAppLog from tasks.ops_trace_task import process_trace_tasks if TYPE_CHECKING: - from dify_graph.entities import WorkflowExecution + from graphon.entities import WorkflowExecution logger = logging.getLogger(__name__) +def _lookup_app_and_workspace_names(app_id: str | None, tenant_id: str | None) -> tuple[str, str]: + """Return (app_name, workspace_name) for the given IDs. Falls back to empty strings.""" + app_name = "" + workspace_name = "" + if not app_id and not tenant_id: + return app_name, workspace_name + with Session(db.engine) as session: + if app_id: + name = session.scalar(select(App.name).where(App.id == app_id)) + if name: + app_name = name + if tenant_id: + name = session.scalar(select(Tenant.name).where(Tenant.id == tenant_id)) + if name: + workspace_name = name + return app_name, workspace_name + + +_PROVIDER_TYPE_TO_MODEL: dict[str, type] = { + "builtin": BuiltinToolProvider, + "plugin": BuiltinToolProvider, + "api": ApiToolProvider, + "workflow": WorkflowToolProvider, + "mcp": MCPToolProvider, +} + + +def _lookup_credential_name(credential_id: str | None, provider_type: str | None) -> str: + if not credential_id: + return "" + model_cls = _PROVIDER_TYPE_TO_MODEL.get(provider_type or "") + if not model_cls: + return "" + with Session(db.engine) as session: + name = session.scalar(select(model_cls.name).where(model_cls.id == credential_id)) # type: ignore[attr-defined] + return str(name) if name else "" + + +def _lookup_llm_credential_info( + tenant_id: str | None, provider: str | None, model: str | None, model_type: str | None = "llm" +) -> tuple[str | None, str]: + """ + Lookup LLM credential ID and name for the given provider and model. + Returns (credential_id, credential_name). + + Handles async timing issues gracefully - if credential is deleted between lookups, + returns the ID but empty name rather than failing. + """ + if not tenant_id or not provider: + return None, "" + + try: + with Session(db.engine) as session: + # Try to find provider-level or model-level configuration + provider_record = session.scalar( + select(Provider).where( + Provider.tenant_id == tenant_id, + Provider.provider_name == provider, + Provider.provider_type == ProviderType.CUSTOM, + ) + ) + + if not provider_record: + return None, "" + + # Check if there's a model-specific config + credential_id = None + credential_name = "" + is_model_level = False + + if model: + # Try model-level first + model_record = session.scalar( + select(ProviderModel).where( + ProviderModel.tenant_id == tenant_id, + ProviderModel.provider_name == provider, + ProviderModel.model_name == model, + ProviderModel.model_type == model_type, + ) + ) + + if model_record and model_record.credential_id: + credential_id = model_record.credential_id + is_model_level = True + + if not credential_id and provider_record.credential_id: + # Fall back to provider-level credential + credential_id = provider_record.credential_id + is_model_level = False + + # Lookup credential_name if we have credential_id + if credential_id: + try: + if is_model_level: + # Query ProviderModelCredential + cred_name = session.scalar( + select(ProviderModelCredential.credential_name).where( + ProviderModelCredential.id == credential_id + ) + ) + else: + # Query ProviderCredential + cred_name = session.scalar( + select(ProviderCredential.credential_name).where(ProviderCredential.id == credential_id) + ) + + if cred_name: + credential_name = str(cred_name) + except Exception as e: + # Credential might have been deleted between lookups (async timing) + # Return ID but empty name rather than failing + logger.warning( + "Failed to lookup credential name for credential_id=%s (provider=%s, model=%s): %s", + credential_id, + provider, + model, + str(e), + exc_info=True, + ) + + return credential_id, credential_name + except Exception as e: + # Database query failed or other unexpected error + # Return empty rather than propagating error to telemetry emission + logger.warning( + "Failed to lookup LLM credential info for tenant_id=%s, provider=%s, model=%s: %s", + tenant_id, + provider, + model, + str(e), + exc_info=True, + ) + return None, "" + + class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]): - def __getitem__(self, key: str) -> dict[str, Any]: - match key: + def __getitem__(self, provider: str) -> dict[str, Any]: + match provider: case TracingProviderEnum.LANGFUSE: from core.ops.entities.config_entity import LangfuseConfig from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace @@ -149,7 +294,7 @@ class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]): } case _: - raise KeyError(f"Unsupported tracing provider: {key}") + raise KeyError(f"Unsupported tracing provider: {provider}") provider_config_map = OpsTraceProviderConfigMap() @@ -275,10 +420,10 @@ class OpsTraceManager: :param tracing_provider: tracing provider :return: """ - trace_config_data: TraceAppConfig | None = ( - db.session.query(TraceAppConfig) + trace_config_data: TraceAppConfig | None = db.session.scalar( + select(TraceAppConfig) .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) - .first() + .limit(1) ) if not trace_config_data: @@ -314,7 +459,11 @@ class OpsTraceManager: if app_id is None: return None - app: App | None = db.session.query(App).where(App.id == app_id).first() + # Handle storage_id format (tenant-{uuid}) - not a real app_id + if isinstance(app_id, str) and app_id.startswith("tenant-"): + return None + + app = db.session.get(App, app_id) if app is None: return None @@ -388,7 +537,7 @@ class OpsTraceManager: except KeyError: raise ValueError(f"Invalid tracing provider: {tracing_provider}") - app_config: App | None = db.session.query(App).where(App.id == app_id).first() + app_config: App | None = db.session.get(App, app_id) if not app_config: raise ValueError("App not found") app_config.tracing = json.dumps( @@ -406,7 +555,7 @@ class OpsTraceManager: :param app_id: app id :return: """ - app: App | None = db.session.query(App).where(App.id == app_id).first() + app: App | None = db.session.get(App, app_id) if not app: raise ValueError("App not found") if not app.tracing: @@ -466,8 +615,6 @@ class TraceTask: @classmethod def _get_workflow_run_repo(cls): - from repositories.factory import DifyAPIRepositoryFactory - if cls._workflow_run_repo is None: with cls._repo_lock: if cls._workflow_run_repo is None: @@ -478,6 +625,77 @@ class TraceTask: cls._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) return cls._workflow_run_repo + @classmethod + def _calculate_workflow_token_split( + cls, session: "Session", workflow_run_id: str, tenant_id: str + ) -> tuple[int, int]: + """Sum prompt/completion tokens across all node executions for a workflow run. + + Reads from the ``outputs`` column (where LLM nodes store ``usage.prompt_tokens`` + and ``usage.completion_tokens``) rather than ``execution_metadata``, which only + carries ``total_tokens``. Projects only the ``outputs`` column to avoid loading + large JSON blobs unnecessarily. + """ + import json + + from models.workflow import WorkflowNodeExecutionModel + + rows = ( + session.execute( + select(WorkflowNodeExecutionModel.outputs).where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id, + ) + ) + .scalars() + .all() + ) + + total_prompt = 0 + total_completion = 0 + + for raw in rows: + if not raw: + continue + try: + outputs = json.loads(raw) if isinstance(raw, str) else raw + except (ValueError, TypeError): + continue + if not isinstance(outputs, dict): + continue + usage = outputs.get("usage") + if not isinstance(usage, dict): + continue + prompt = usage.get("prompt_tokens") + if isinstance(prompt, (int, float)): + total_prompt += int(prompt) + completion = usage.get("completion_tokens") + if isinstance(completion, (int, float)): + total_completion += int(completion) + + return (total_prompt, total_completion) + + @classmethod + def _get_user_id_from_metadata(cls, metadata: dict[str, Any]) -> str: + """Extract user ID from metadata, prioritizing end_user over account. + + Returns the actual user ID (end_user or account) who invoked the workflow, + regardless of invoke_from context. + """ + # Priority 1: End user (external users via API/WebApp) + if user_id := metadata.get("from_end_user_id"): + return f"end_user:{user_id}" + + # Priority 2: Account user (internal users via console/debugger) + if user_id := metadata.get("from_account_id"): + return f"account:{user_id}" + + # Priority 3: User (internal users via console/debugger) + if user_id := metadata.get("user_id"): + return f"user:{user_id}" + + return "anonymous" + def __init__( self, trace_type: Any, @@ -491,6 +709,7 @@ class TraceTask: self.trace_type = trace_type self.message_id = message_id self.workflow_run_id = workflow_execution.id_ if workflow_execution else None + self.workflow_total_tokens: int | None = workflow_execution.total_tokens if workflow_execution else None self.conversation_id = conversation_id self.user_id = user_id self.timer = timer @@ -498,6 +717,8 @@ class TraceTask: self.app_id = None self.trace_id = None self.kwargs = kwargs + if user_id is not None and "user_id" not in self.kwargs: + self.kwargs["user_id"] = user_id external_trace_id = kwargs.get("external_trace_id") if external_trace_id: self.trace_id = external_trace_id @@ -509,9 +730,12 @@ class TraceTask: preprocess_map = { TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs), TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace( - workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id + workflow_run_id=self.workflow_run_id, + conversation_id=self.conversation_id, + user_id=self.user_id, + total_tokens_override=self.workflow_total_tokens, ), - TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id), + TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id, **self.kwargs), TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace( message_id=self.message_id, timer=self.timer, **self.kwargs ), @@ -527,6 +751,9 @@ class TraceTask: TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace( conversation_id=self.conversation_id, timer=self.timer, **self.kwargs ), + TraceTaskName.PROMPT_GENERATION_TRACE: lambda: self.prompt_generation_trace(**self.kwargs), + TraceTaskName.NODE_EXECUTION_TRACE: lambda: self.node_execution_trace(**self.kwargs), + TraceTaskName.DRAFT_NODE_EXECUTION_TRACE: lambda: self.draft_node_execution_trace(**self.kwargs), } return preprocess_map.get(self.trace_type, lambda: None)() @@ -541,6 +768,7 @@ class TraceTask: workflow_run_id: str | None, conversation_id: str | None, user_id: str | None, + total_tokens_override: int | None = None, ): if not workflow_run_id: return {} @@ -560,7 +788,7 @@ class TraceTask: workflow_run_version = workflow_run.version error = workflow_run.error or "" - total_tokens = workflow_run.total_tokens + total_tokens = total_tokens_override if total_tokens_override is not None else workflow_run.total_tokens file_list = workflow_run_inputs.get("sys.file") or [] query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or "" @@ -581,8 +809,18 @@ class TraceTask: Message.workflow_run_id == workflow_run_id, ) message_id = session.scalar(message_data_stmt) + prompt_tokens, completion_tokens = self._calculate_workflow_token_split( + session, workflow_run_id=workflow_run_id, tenant_id=tenant_id + ) - metadata = { + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + if is_enterprise_telemetry_enabled(): + app_name, workspace_name = _lookup_app_and_workspace_names(workflow_run.app_id, tenant_id) + else: + app_name, workspace_name = "", "" + + metadata: dict[str, Any] = { "workflow_id": workflow_id, "conversation_id": conversation_id, "workflow_run_id": workflow_run_id, @@ -595,8 +833,14 @@ class TraceTask: "triggered_from": workflow_run.triggered_from, "user_id": user_id, "app_id": workflow_run.app_id, + "app_name": app_name, + "workspace_name": workspace_name, } + parent_trace_context = self.kwargs.get("parent_trace_context") + if parent_trace_context: + metadata["parent_trace_context"] = parent_trace_context + workflow_trace_info = WorkflowTraceInfo( trace_id=self.trace_id, workflow_data=workflow_run.to_dict(), @@ -611,6 +855,8 @@ class TraceTask: workflow_run_version=workflow_run_version, error=error, total_tokens=total_tokens, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, file_list=file_list, query=query, metadata=metadata, @@ -618,10 +864,11 @@ class TraceTask: message_id=message_id, start_time=workflow_run.created_at, end_time=workflow_run.finished_at, + invoked_by=self._get_user_id_from_metadata(metadata), ) return workflow_trace_info - def message_trace(self, message_id: str | None): + def message_trace(self, message_id: str | None, **kwargs): if not message_id: return {} message_data = get_message_data(message_id) @@ -636,7 +883,7 @@ class TraceTask: inputs = message_data.message # get message file data - message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first() + message_file_data = db.session.scalar(select(MessageFile).where(MessageFile.message_id == message_id).limit(1)) file_list = [] if message_file_data and message_file_data.url is not None: file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else "" @@ -644,6 +891,19 @@ class TraceTask: streaming_metrics = self._extract_streaming_metrics(message_data) + tenant_id = "" + with Session(db.engine) as session: + tid = session.scalar(select(App.tenant_id).where(App.id == message_data.app_id)) + if tid: + tenant_id = str(tid) + + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + if is_enterprise_telemetry_enabled(): + app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id) + else: + app_name, workspace_name = "", "" + metadata = { "conversation_id": message_data.conversation_id, "ls_provider": message_data.model_provider, @@ -655,7 +915,14 @@ class TraceTask: "workflow_run_id": message_data.workflow_run_id, "from_source": message_data.from_source, "message_id": message_id, + "tenant_id": tenant_id, + "app_id": message_data.app_id, + "user_id": message_data.from_end_user_id or message_data.from_account_id, + "app_name": app_name, + "workspace_name": workspace_name, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id message_tokens = message_data.message_tokens @@ -672,7 +939,9 @@ class TraceTask: outputs=message_data.answer, file_list=file_list, start_time=created_at, - end_time=created_at + timedelta(seconds=message_data.provider_response_latency), + end_time=message_data.updated_at + if message_data.updated_at and message_data.updated_at > created_at + else created_at + timedelta(seconds=message_data.provider_response_latency), metadata=metadata, message_file_data=message_file_data, conversation_mode=conversation_mode, @@ -697,12 +966,14 @@ class TraceTask: "preset_response": moderation_result.preset_response, "query": moderation_result.query, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id # get workflow_app_log_id workflow_app_log_id = None if message_data.workflow_run_id: - workflow_app_log_data = ( - db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first() + workflow_app_log_data = db.session.scalar( + select(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id == message_data.workflow_run_id).limit(1) ) workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None @@ -738,12 +1009,14 @@ class TraceTask: "workflow_run_id": message_data.workflow_run_id, "from_source": message_data.from_source, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id # get workflow_app_log_id workflow_app_log_id = None if message_data.workflow_run_id: - workflow_app_log_data = ( - db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first() + workflow_app_log_data = db.session.scalar( + select(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id == message_data.workflow_run_id).limit(1) ) workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None @@ -777,6 +1050,52 @@ class TraceTask: if not message_data: return {} + tenant_id = "" + with Session(db.engine) as session: + tid = session.scalar(select(App.tenant_id).where(App.id == message_data.app_id)) + if tid: + tenant_id = str(tid) + + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + if is_enterprise_telemetry_enabled(): + app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id) + else: + app_name, workspace_name = "", "" + + doc_list = [doc.model_dump() for doc in documents] if documents else [] + dataset_ids: set[str] = set() + for doc in doc_list: + doc_meta = doc.get("metadata") or {} + did = doc_meta.get("dataset_id") + if did: + dataset_ids.add(did) + + embedding_models: dict[str, dict[str, str]] = {} + if dataset_ids: + with Session(db.engine) as session: + rows = session.execute( + select(Dataset.id, Dataset.embedding_model, Dataset.embedding_model_provider).where( + Dataset.id.in_(list(dataset_ids)) + ) + ).all() + for row in rows: + embedding_models[str(row[0])] = { + "embedding_model": row[1] or "", + "embedding_model_provider": row[2] or "", + } + + # Extract rerank model info from retrieval_model kwargs + rerank_model_provider = "" + rerank_model_name = "" + if "retrieval_model" in kwargs: + retrieval_model = kwargs["retrieval_model"] + if isinstance(retrieval_model, dict): + reranking_model = retrieval_model.get("reranking_model") + if isinstance(reranking_model, dict): + rerank_model_provider = reranking_model.get("reranking_provider_name", "") + rerank_model_name = reranking_model.get("reranking_model_name", "") + metadata = { "message_id": message_id, "ls_provider": message_data.model_provider, @@ -787,13 +1106,23 @@ class TraceTask: "agent_based": message_data.agent_based, "workflow_run_id": message_data.workflow_run_id, "from_source": message_data.from_source, + "tenant_id": tenant_id, + "app_id": message_data.app_id, + "user_id": message_data.from_end_user_id or message_data.from_account_id, + "app_name": app_name, + "workspace_name": workspace_name, + "embedding_models": embedding_models, + "rerank_model_provider": rerank_model_provider, + "rerank_model_name": rerank_model_name, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id dataset_retrieval_trace_info = DatasetRetrievalTraceInfo( trace_id=self.trace_id, message_id=message_id, inputs=message_data.query or message_data.inputs, - documents=[doc.model_dump() for doc in documents] if documents else [], + documents=doc_list, start_time=timer.get("start"), end_time=timer.get("end"), metadata=metadata, @@ -836,9 +1165,13 @@ class TraceTask: "error": error, "tool_parameters": tool_parameters, } + if message_data.workflow_run_id: + metadata["workflow_run_id"] = message_data.workflow_run_id + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id file_url = "" - message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first() + message_file_data = db.session.scalar(select(MessageFile).where(MessageFile.message_id == message_id).limit(1)) if message_file_data: message_file_id = message_file_data.id if message_file_data else None type = message_file_data.type @@ -890,6 +1223,8 @@ class TraceTask: "conversation_id": conversation_id, "tenant_id": tenant_id, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id generate_name_trace_info = GenerateNameTraceInfo( trace_id=self.trace_id, @@ -904,6 +1239,182 @@ class TraceTask: return generate_name_trace_info + def prompt_generation_trace(self, **kwargs) -> PromptGenerationTraceInfo | dict: + tenant_id = kwargs.get("tenant_id", "") + user_id = kwargs.get("user_id", "") + app_id = kwargs.get("app_id") + operation_type = kwargs.get("operation_type", "") + instruction = kwargs.get("instruction", "") + generated_output = kwargs.get("generated_output", "") + + prompt_tokens = kwargs.get("prompt_tokens", 0) + completion_tokens = kwargs.get("completion_tokens", 0) + total_tokens = kwargs.get("total_tokens", 0) + + model_provider = kwargs.get("model_provider", "") + model_name = kwargs.get("model_name", "") + + latency = kwargs.get("latency", 0.0) + + timer = kwargs.get("timer") + start_time = timer.get("start") if timer else None + end_time = timer.get("end") if timer else None + + total_price = kwargs.get("total_price") + currency = kwargs.get("currency") + + error = kwargs.get("error") + + app_name = None + workspace_name = None + if app_id: + app_name, workspace_name = _lookup_app_and_workspace_names(app_id, tenant_id) + + metadata = { + "tenant_id": tenant_id, + "user_id": user_id, + "app_id": app_id or "", + "app_name": app_name, + "workspace_name": workspace_name, + "operation_type": operation_type, + "model_provider": model_provider, + "model_name": model_name, + } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id + + return PromptGenerationTraceInfo( + trace_id=self.trace_id, + inputs=instruction, + outputs=generated_output, + start_time=start_time, + end_time=end_time, + metadata=metadata, + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + operation_type=operation_type, + instruction=instruction, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + model_provider=model_provider, + model_name=model_name, + latency=latency, + total_price=total_price, + currency=currency, + error=error, + ) + + def node_execution_trace(self, **kwargs) -> WorkflowNodeTraceInfo | dict: + node_data: dict = kwargs.get("node_execution_data", {}) + if not node_data: + return {} + + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + if is_enterprise_telemetry_enabled(): + app_name, workspace_name = _lookup_app_and_workspace_names( + node_data.get("app_id"), node_data.get("tenant_id") + ) + else: + app_name, workspace_name = "", "" + + # Try tool credential lookup first + credential_id = node_data.get("credential_id") + if is_enterprise_telemetry_enabled(): + credential_name = _lookup_credential_name(credential_id, node_data.get("credential_provider_type")) + # If no credential_id found (e.g., LLM nodes), try LLM credential lookup + if not credential_id: + llm_cred_id, llm_cred_name = _lookup_llm_credential_info( + tenant_id=node_data.get("tenant_id"), + provider=node_data.get("model_provider"), + model=node_data.get("model_name"), + model_type="llm", + ) + if llm_cred_id: + credential_id = llm_cred_id + credential_name = llm_cred_name + else: + credential_name = "" + metadata: dict[str, Any] = { + "tenant_id": node_data.get("tenant_id"), + "app_id": node_data.get("app_id"), + "app_name": app_name, + "workspace_name": workspace_name, + "user_id": node_data.get("user_id"), + "invoke_from": node_data.get("invoke_from"), + "credential_id": credential_id, + "credential_name": credential_name, + "dataset_ids": node_data.get("dataset_ids"), + "dataset_names": node_data.get("dataset_names"), + "plugin_name": node_data.get("plugin_name"), + } + + parent_trace_context = node_data.get("parent_trace_context") + if parent_trace_context: + metadata["parent_trace_context"] = parent_trace_context + + message_id: str | None = None + conversation_id = node_data.get("conversation_id") + workflow_execution_id = node_data.get("workflow_execution_id") + if conversation_id and workflow_execution_id and not parent_trace_context: + with Session(db.engine) as session: + msg_id = session.scalar( + select(Message.id).where( + Message.conversation_id == conversation_id, + Message.workflow_run_id == workflow_execution_id, + ) + ) + if msg_id: + message_id = str(msg_id) + metadata["message_id"] = message_id + if conversation_id: + metadata["conversation_id"] = conversation_id + + return WorkflowNodeTraceInfo( + trace_id=self.trace_id, + message_id=message_id, + start_time=node_data.get("created_at"), + end_time=node_data.get("finished_at"), + metadata=metadata, + workflow_id=node_data.get("workflow_id", ""), + workflow_run_id=node_data.get("workflow_execution_id", ""), + tenant_id=node_data.get("tenant_id", ""), + node_execution_id=node_data.get("node_execution_id", ""), + node_id=node_data.get("node_id", ""), + node_type=node_data.get("node_type", ""), + title=node_data.get("title", ""), + status=node_data.get("status", ""), + error=node_data.get("error"), + elapsed_time=node_data.get("elapsed_time", 0.0), + index=node_data.get("index", 0), + predecessor_node_id=node_data.get("predecessor_node_id"), + total_tokens=node_data.get("total_tokens", 0), + total_price=node_data.get("total_price", 0.0), + currency=node_data.get("currency"), + model_provider=node_data.get("model_provider"), + model_name=node_data.get("model_name"), + prompt_tokens=node_data.get("prompt_tokens"), + completion_tokens=node_data.get("completion_tokens"), + tool_name=node_data.get("tool_name"), + iteration_id=node_data.get("iteration_id"), + iteration_index=node_data.get("iteration_index"), + loop_id=node_data.get("loop_id"), + loop_index=node_data.get("loop_index"), + parallel_id=node_data.get("parallel_id"), + node_inputs=node_data.get("node_inputs"), + node_outputs=node_data.get("node_outputs"), + process_data=node_data.get("process_data"), + invoked_by=self._get_user_id_from_metadata(metadata), + ) + + def draft_node_execution_trace(self, **kwargs) -> DraftNodeExecutionTrace | dict: + node_trace = self.node_execution_trace(**kwargs) + if not isinstance(node_trace, WorkflowNodeTraceInfo): + return node_trace + return DraftNodeExecutionTrace(**node_trace.model_dump()) + def _extract_streaming_metrics(self, message_data) -> dict: if not message_data.message_metadata: return {} @@ -937,13 +1448,17 @@ class TraceQueueManager: self.user_id = user_id self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id) self.flask_app = current_app._get_current_object() # type: ignore + + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + self._enterprise_telemetry_enabled = is_enterprise_telemetry_enabled() if trace_manager_timer is None: self.start_timer() def add_trace_task(self, trace_task: TraceTask): global trace_manager_timer, trace_manager_queue try: - if self.trace_instance: + if self._enterprise_telemetry_enabled or self.trace_instance: trace_task.app_id = self.app_id trace_manager_queue.put(trace_task) except Exception: @@ -979,20 +1494,27 @@ class TraceQueueManager: def send_to_celery(self, tasks: list[TraceTask]): with self.flask_app.app_context(): for task in tasks: - if task.app_id is None: - continue + storage_id = task.app_id + if storage_id is None: + tenant_id = task.kwargs.get("tenant_id") + if tenant_id: + storage_id = f"tenant-{tenant_id}" + else: + logger.warning("Skipping trace without app_id or tenant_id, trace_type: %s", task.trace_type) + continue + file_id = uuid4().hex trace_info = task.execute() task_data = TaskData( - app_id=task.app_id, + app_id=storage_id, trace_info_type=type(trace_info).__name__, trace_info=trace_info.model_dump() if trace_info else None, ) - file_path = f"{OPS_FILE_PATH}{task.app_id}/{file_id}.json" + file_path = f"{OPS_FILE_PATH}{storage_id}/{file_id}.json" storage.save(file_path, task_data.model_dump_json().encode("utf-8")) file_info = { "file_id": file_id, - "app_id": task.app_id, + "app_id": storage_id, } process_trace_tasks.delay(file_info) # type: ignore diff --git a/api/core/ops/tencent_trace/client.py b/api/core/ops/tencent_trace/client.py index c39093bf4c9..be06ab4a36a 100644 --- a/api/core/ops/tencent_trace/client.py +++ b/api/core/ops/tencent_trace/client.py @@ -26,7 +26,13 @@ from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExport from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor -from opentelemetry.semconv.resource import ResourceAttributes +from opentelemetry.semconv._incubating.attributes.deployment_attributes import ( # type: ignore[import-untyped] + DEPLOYMENT_ENVIRONMENT, +) +from opentelemetry.semconv._incubating.attributes.host_attributes import ( # type: ignore[import-untyped] + HOST_NAME, +) +from opentelemetry.semconv.attributes import service_attributes from opentelemetry.trace import SpanKind from opentelemetry.util.types import AttributeValue @@ -73,13 +79,13 @@ class TencentTraceClient: self.resource = Resource( attributes={ - ResourceAttributes.SERVICE_NAME: service_name, - ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}", - ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}", - ResourceAttributes.HOST_NAME: socket.gethostname(), - ResourceAttributes.TELEMETRY_SDK_LANGUAGE: "python", - ResourceAttributes.TELEMETRY_SDK_NAME: "opentelemetry", - ResourceAttributes.TELEMETRY_SDK_VERSION: _get_opentelemetry_sdk_version(), + service_attributes.SERVICE_NAME: service_name, + service_attributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}", + DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}", + HOST_NAME: socket.gethostname(), + "telemetry.sdk.language": "python", + "telemetry.sdk.name": "opentelemetry", + "telemetry.sdk.version": _get_opentelemetry_sdk_version(), } ) # Prepare gRPC endpoint/metadata diff --git a/api/core/ops/tencent_trace/span_builder.py b/api/core/ops/tencent_trace/span_builder.py index 0a6013e244c..f79095d9662 100644 --- a/api/core/ops/tencent_trace/span_builder.py +++ b/api/core/ops/tencent_trace/span_builder.py @@ -6,6 +6,8 @@ import json import logging from datetime import datetime +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from opentelemetry.trace import Status, StatusCode from core.ops.entities.trace_entity import ( @@ -41,11 +43,6 @@ from core.ops.tencent_trace.entities.semconv import ( from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData from core.ops.tencent_trace.utils import TencentTraceUtils from core.rag.models.document import Document -from dify_graph.entities.workflow_node_execution import ( - WorkflowNodeExecution, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) logger = logging.getLogger(__name__) diff --git a/api/core/ops/tencent_trace/tencent_trace.py b/api/core/ops/tencent_trace/tencent_trace.py index 7e56b1effa5..2bd6db22bf7 100644 --- a/api/core/ops/tencent_trace/tencent_trace.py +++ b/api/core/ops/tencent_trace/tencent_trace.py @@ -4,6 +4,10 @@ Tencent APM tracing implementation with separated concerns import logging +from graphon.entities.workflow_node_execution import ( + WorkflowNodeExecution, +) +from graphon.nodes import BuiltinNodeTypes from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -24,10 +28,6 @@ from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData from core.ops.tencent_trace.span_builder import TencentSpanBuilder from core.ops.tencent_trace.utils import TencentTraceUtils from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository -from dify_graph.entities.workflow_node_execution import ( - WorkflowNodeExecution, -) -from dify_graph.nodes import BuiltinNodeTypes from extensions.ext_database import db from models import Account, App, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom @@ -256,7 +256,7 @@ class TencentDataTrace(BaseTraceInstance): triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) - executions = repository.get_by_workflow_run(workflow_run_id=trace_info.workflow_run_id) + executions = repository.get_by_workflow_execution(workflow_execution_id=trace_info.workflow_run_id) return list(executions) except Exception: diff --git a/api/core/ops/weave_trace/entities/weave_trace_entity.py b/api/core/ops/weave_trace/entities/weave_trace_entity.py index ef1a3be45b9..ed6a7dabbb0 100644 --- a/api/core/ops/weave_trace/entities/weave_trace_entity.py +++ b/api/core/ops/weave_trace/entities/weave_trace_entity.py @@ -67,7 +67,8 @@ class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel): if field_name == "inputs": data = { "messages": [ - dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) for msg in v + dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) # type: ignore + for msg in v ] if isinstance(v, list) else v, diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index 2a657b672c6..8d9ba4694d9 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -6,6 +6,7 @@ from typing import Any, cast import wandb import weave +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from sqlalchemy.orm import sessionmaker from weave.trace_server.trace_server_interface import ( CallEndReq, @@ -31,7 +32,6 @@ from core.ops.entities.trace_entity import ( ) from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel from core.repositories import DifyCoreRepositoryFactory -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from extensions.ext_database import db from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom @@ -161,8 +161,8 @@ class WeaveDataTrace(BaseTraceInstance): ) # Get all executions for this workflow run - workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( - workflow_run_id=trace_info.workflow_run_id + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id ) # rearrange workflow_node_executions by starting time @@ -245,9 +245,7 @@ class WeaveDataTrace(BaseTraceInstance): attributes["user_id"] = user_id if message_data.from_end_user_id: - end_user_data: EndUser | None = ( - db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() - ) + end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id) if end_user_data is not None: end_user_id = end_user_data.session_id attributes["end_user_id"] = end_user_id diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py index 60d08b26c95..be11d2223ca 100644 --- a/api/core/plugin/backwards_invocation/app.py +++ b/api/core/plugin/backwards_invocation/app.py @@ -227,7 +227,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): get app """ try: - app = db.session.query(App).where(App.id == app_id).where(App.tenant_id == tenant_id).first() + app = db.session.scalar(select(App).where(App.id == app_id, App.tenant_id == tenant_id).limit(1)) except Exception: raise ValueError("app not found") diff --git a/api/core/plugin/backwards_invocation/model.py b/api/core/plugin/backwards_invocation/model.py index 11c9191bace..c715b9171c6 100644 --- a/api/core/plugin/backwards_invocation/model.py +++ b/api/core/plugin/backwards_invocation/model.py @@ -2,6 +2,20 @@ import tempfile from binascii import hexlify, unhexlify from collections.abc import Generator +from graphon.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, +) +from graphon.model_runtime.entities.message_entities import ( + PromptMessage, + SystemPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.entities.model_entities import ModelType + from core.app.llm import deduct_llm_quota from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output from core.model_manager import ModelManager @@ -18,22 +32,26 @@ from core.plugin.entities.request import ( ) from core.tools.entities.tool_entities import ToolProviderType from core.tools.utils.model_invocation_utils import ModelInvocationUtils -from dify_graph.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, - LLMResultChunkDelta, - LLMResultChunkWithStructuredOutput, - LLMResultWithStructuredOutput, -) -from dify_graph.model_runtime.entities.message_entities import ( - PromptMessage, - SystemPromptMessage, - UserPromptMessage, -) from models.account import Tenant class PluginModelBackwardsInvocation(BaseBackwardsInvocation): + @staticmethod + def _get_bound_model_instance( + *, + tenant_id: str, + user_id: str | None, + provider: str, + model_type: ModelType, + model: str, + ): + return ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance( + tenant_id=tenant_id, + provider=provider, + model_type=model_type, + model=model, + ) + @classmethod def invoke_llm( cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLM @@ -41,8 +59,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke llm """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, @@ -55,7 +74,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): tools=payload.tools, stop=payload.stop, stream=True if payload.stream is None else payload.stream, - user=user_id, ) if isinstance(response, Generator): @@ -94,8 +112,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke llm with structured output """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, @@ -115,7 +134,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): tools=payload.tools, stop=payload.stop, stream=True if payload.stream is None else payload.stream, - user=user_id, model_parameters=payload.completion_params, ) @@ -156,18 +174,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke text embedding """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, ) # invoke model - response = model_instance.invoke_text_embedding( - texts=payload.texts, - user=user_id, - ) + response = model_instance.invoke_text_embedding(texts=payload.texts) return response @@ -176,8 +192,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke rerank """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, @@ -189,7 +206,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): docs=payload.docs, score_threshold=payload.score_threshold, top_n=payload.top_n, - user=user_id, ) return response @@ -199,20 +215,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke tts """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, ) # invoke model - response = model_instance.invoke_tts( - content_text=payload.content_text, - tenant_id=tenant.id, - voice=payload.voice, - user=user_id, - ) + response = model_instance.invoke_tts(content_text=payload.content_text, voice=payload.voice) def handle() -> Generator[dict, None, None]: for chunk in response: @@ -225,8 +237,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke speech2text """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, @@ -238,10 +251,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): temp.flush() temp.seek(0) - response = model_instance.invoke_speech2text( - file=temp, - user=user_id, - ) + response = model_instance.invoke_speech2text(file=temp) return { "result": response, @@ -252,36 +262,38 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke moderation """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, ) # invoke model - response = model_instance.invoke_moderation( - text=payload.text, - user=user_id, - ) + response = model_instance.invoke_moderation(text=payload.text) return { "result": response, } @classmethod - def get_system_model_max_tokens(cls, tenant_id: str) -> int: + def get_system_model_max_tokens(cls, tenant_id: str, user_id: str | None = None) -> int: """ get system model max tokens """ - return ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id) + return ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id, user_id=user_id) @classmethod - def get_prompt_tokens(cls, tenant_id: str, prompt_messages: list[PromptMessage]) -> int: + def get_prompt_tokens(cls, tenant_id: str, prompt_messages: list[PromptMessage], user_id: str | None = None) -> int: """ get prompt tokens """ - return ModelInvocationUtils.calculate_tokens(tenant_id=tenant_id, prompt_messages=prompt_messages) + return ModelInvocationUtils.calculate_tokens( + tenant_id=tenant_id, + prompt_messages=prompt_messages, + user_id=user_id, + ) @classmethod def invoke_system_model( @@ -299,6 +311,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): tool_type=ToolProviderType.PLUGIN, tool_name="plugin", prompt_messages=prompt_messages, + caller_user_id=user_id, ) @classmethod @@ -306,7 +319,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke summary """ - max_tokens = cls.get_system_model_max_tokens(tenant_id=tenant.id) + max_tokens = cls.get_system_model_max_tokens(tenant_id=tenant.id, user_id=user_id) content = payload.text SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language @@ -325,6 +338,7 @@ Here is the extra instruction you need to follow: cls.get_prompt_tokens( tenant_id=tenant.id, prompt_messages=[UserPromptMessage(content=content)], + user_id=user_id, ) < max_tokens * 0.6 ): @@ -337,6 +351,7 @@ Here is the extra instruction you need to follow: SystemPromptMessage(content=SUMMARY_PROMPT.replace("{payload.instruction}", payload.instruction)), UserPromptMessage(content=content), ], + user_id=user_id, ) def summarize(content: str) -> str: @@ -394,6 +409,7 @@ Here is the extra instruction you need to follow: cls.get_prompt_tokens( tenant_id=tenant.id, prompt_messages=[UserPromptMessage(content=result)], + user_id=user_id, ) > max_tokens * 0.7 ): diff --git a/api/core/plugin/backwards_invocation/node.py b/api/core/plugin/backwards_invocation/node.py index d6aef93fc40..94789974942 100644 --- a/api/core/plugin/backwards_invocation/node.py +++ b/api/core/plugin/backwards_invocation/node.py @@ -1,19 +1,15 @@ -from core.plugin.backwards_invocation.base import BaseBackwardsInvocation -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.nodes.parameter_extractor.entities import ( - ModelConfig as ParameterExtractorModelConfig, -) -from dify_graph.nodes.parameter_extractor.entities import ( +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.llm.entities import ModelConfig as LLMModelConfig +from graphon.nodes.parameter_extractor.entities import ( ParameterConfig, ParameterExtractorNodeData, ) -from dify_graph.nodes.question_classifier.entities import ( +from graphon.nodes.question_classifier.entities import ( ClassConfig, QuestionClassifierNodeData, ) -from dify_graph.nodes.question_classifier.entities import ( - ModelConfig as QuestionClassifierModelConfig, -) + +from core.plugin.backwards_invocation.base import BaseBackwardsInvocation from services.workflow_service import WorkflowService @@ -24,7 +20,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation): tenant_id: str, user_id: str, parameters: list[ParameterConfig], - model_config: ParameterExtractorModelConfig, + model_config: LLMModelConfig, instruction: str, query: str, ): @@ -74,7 +70,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation): cls, tenant_id: str, user_id: str, - model_config: QuestionClassifierModelConfig, + model_config: LLMModelConfig, classes: list[ClassConfig], instruction: str, query: str, diff --git a/api/core/plugin/backwards_invocation/tool.py b/api/core/plugin/backwards_invocation/tool.py index c2d1574e67d..05854942691 100644 --- a/api/core/plugin/backwards_invocation/tool.py +++ b/api/core/plugin/backwards_invocation/tool.py @@ -31,7 +31,13 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation): # get tool runtime try: tool_runtime = ToolManager.get_tool_runtime_from_plugin( - tool_type, tenant_id, provider, tool_name, tool_parameters, credential_id + tool_type, + tenant_id, + provider, + tool_name, + tool_parameters, + user_id=user_id, + credential_id=credential_id, ) response = ToolEngine.generic_invoke( tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1 diff --git a/api/core/plugin/entities/marketplace.py b/api/core/plugin/entities/marketplace.py index 81e1e12c5f0..2177e8af908 100644 --- a/api/core/plugin/entities/marketplace.py +++ b/api/core/plugin/entities/marketplace.py @@ -1,10 +1,10 @@ +from graphon.model_runtime.entities.provider_entities import ProviderEntity from pydantic import BaseModel, Field, computed_field, model_validator from core.plugin.entities.endpoint import EndpointProviderDeclaration from core.plugin.entities.plugin import PluginResourceRequirements from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntity -from dify_graph.model_runtime.entities.provider_entities import ProviderEntity class MarketplacePluginDeclaration(BaseModel): diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index 7a3780f7de2..b095b4998d7 100644 --- a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -3,6 +3,7 @@ from collections.abc import Mapping from enum import StrEnum, auto from typing import Any +from graphon.model_runtime.entities.provider_entities import ProviderEntity from packaging.version import InvalidVersion, Version from pydantic import BaseModel, Field, field_validator, model_validator @@ -13,7 +14,6 @@ from core.plugin.entities.endpoint import EndpointProviderDeclaration from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntity from core.trigger.entities.entities import TriggerProviderEntity -from dify_graph.model_runtime.entities.provider_entities import ProviderEntity class PluginInstallationSource(StrEnum): diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index 416e0f6b4d4..94263ec44e6 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -6,6 +6,8 @@ from datetime import datetime from enum import StrEnum from typing import Any, Generic, TypeVar +from graphon.model_runtime.entities.model_entities import AIModelEntity +from graphon.model_runtime.entities.provider_entities import ProviderEntity from pydantic import BaseModel, ConfigDict, Field from core.agent.plugin_entities import AgentProviderEntityWithPlugin @@ -16,8 +18,6 @@ from core.plugin.entities.plugin import PluginDeclaration, PluginEntity from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin from core.trigger.entities.entities import TriggerProviderEntity -from dify_graph.model_runtime.entities.model_entities import AIModelEntity -from dify_graph.model_runtime.entities.provider_entities import ProviderEntity T = TypeVar("T", bound=(BaseModel | dict | list | bool | str)) diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index c15e9b03857..059f3fa9be1 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -4,11 +4,7 @@ from collections.abc import Mapping from typing import Any, Literal from flask import Response -from pydantic import BaseModel, ConfigDict, Field, field_validator - -from core.entities.provider_entities import BasicProviderConfig -from core.plugin.utils.http_parser import deserialize_response -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, PromptMessageRole, @@ -17,19 +13,18 @@ from dify_graph.model_runtime.entities.message_entities import ( ToolPromptMessage, UserPromptMessage, ) -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.nodes.parameter_extractor.entities import ( - ModelConfig as ParameterExtractorModelConfig, -) -from dify_graph.nodes.parameter_extractor.entities import ( +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.nodes.llm.entities import ModelConfig as LLMModelConfig +from graphon.nodes.parameter_extractor.entities import ( ParameterConfig, ) -from dify_graph.nodes.question_classifier.entities import ( +from graphon.nodes.question_classifier.entities import ( ClassConfig, ) -from dify_graph.nodes.question_classifier.entities import ( - ModelConfig as QuestionClassifierModelConfig, -) +from pydantic import BaseModel, ConfigDict, Field, field_validator + +from core.entities.provider_entities import BasicProviderConfig +from core.plugin.utils.http_parser import deserialize_response class InvokeCredentials(BaseModel): @@ -176,7 +171,7 @@ class RequestInvokeParameterExtractorNode(BaseModel): """ parameters: list[ParameterConfig] - model: ParameterExtractorModelConfig + model: LLMModelConfig instruction: str query: str @@ -187,7 +182,7 @@ class RequestInvokeQuestionClassifierNode(BaseModel): """ query: str - model: QuestionClassifierModelConfig + model: LLMModelConfig classes: list[ClassConfig] instruction: str diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py index 737d2041056..2d0ab3fcd73 100644 --- a/api/core/plugin/impl/base.py +++ b/api/core/plugin/impl/base.py @@ -5,6 +5,14 @@ from collections.abc import Callable, Generator from typing import Any, TypeVar, cast import httpx +from graphon.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from pydantic import BaseModel from yarl import URL @@ -13,6 +21,7 @@ from core.plugin.endpoint.exc import EndpointSetupFailedError from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse, PluginDaemonError, PluginDaemonInnerError from core.plugin.impl.exc import ( PluginDaemonBadRequestError, + PluginDaemonClientSideError, PluginDaemonInternalServerError, PluginDaemonNotFoundError, PluginDaemonUnauthorizedError, @@ -27,14 +36,6 @@ from core.trigger.errors import ( TriggerPluginInvokeError, TriggerProviderCredentialValidationError, ) -from dify_graph.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL)) _plugin_daemon_timeout_config = cast( @@ -235,7 +236,10 @@ class BasePluginClient: response.raise_for_status() except httpx.HTTPStatusError as e: logger.exception("Failed to request plugin daemon, status: %s, url: %s", e.response.status_code, path) - raise e + if e.response.status_code < 500: + raise PluginDaemonClientSideError(description=str(e)) + else: + raise PluginDaemonInternalServerError(description=str(e)) except Exception as e: msg = f"Failed to request plugin daemon, url: {path}" logger.exception("Failed to request plugin daemon, url: %s", path) diff --git a/api/core/plugin/impl/model.py b/api/core/plugin/impl/model.py index 49ee5d79cb7..1e38c24717f 100644 --- a/api/core/plugin/impl/model.py +++ b/api/core/plugin/impl/model.py @@ -1,6 +1,13 @@ import binascii from collections.abc import Generator, Sequence -from typing import IO +from typing import IO, Any + +from graphon.model_runtime.entities.llm_entities import LLMResultChunk +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity +from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult +from graphon.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin_daemon import ( PluginBasicBooleanResponse, @@ -13,15 +20,16 @@ from core.plugin.entities.plugin_daemon import ( PluginVoicesResponse, ) from core.plugin.impl.base import BasePluginClient -from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk -from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from dify_graph.model_runtime.entities.model_entities import AIModelEntity -from dify_graph.model_runtime.entities.rerank_entities import RerankResult -from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult -from dify_graph.model_runtime.utils.encoders import jsonable_encoder class PluginModelClient(BasePluginClient): + @staticmethod + def _dispatch_payload(*, user_id: str | None, data: dict[str, Any]) -> dict[str, Any]: + payload: dict[str, Any] = {"data": data} + if user_id is not None: + payload["user_id"] = user_id + return payload + def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]: """ Fetch model providers for the given tenant. @@ -37,7 +45,7 @@ class PluginModelClient(BasePluginClient): def get_model_schema( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model_type: str, @@ -51,15 +59,15 @@ class PluginModelClient(BasePluginClient): "POST", f"plugin/{tenant_id}/dispatch/model/schema", PluginModelSchemaEntity, - data={ - "user_id": user_id, - "data": { + data=self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": model_type, "model": model, "credentials": credentials, }, - }, + ), headers={ "X-Plugin-ID": plugin_id, "Content-Type": "application/json", @@ -72,7 +80,7 @@ class PluginModelClient(BasePluginClient): return None def validate_provider_credentials( - self, tenant_id: str, user_id: str, plugin_id: str, provider: str, credentials: dict + self, tenant_id: str, user_id: str | None, plugin_id: str, provider: str, credentials: dict ) -> bool: """ validate the credentials of the provider @@ -81,13 +89,13 @@ class PluginModelClient(BasePluginClient): "POST", f"plugin/{tenant_id}/dispatch/model/validate_provider_credentials", PluginBasicBooleanResponse, - data={ - "user_id": user_id, - "data": { + data=self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "credentials": credentials, }, - }, + ), headers={ "X-Plugin-ID": plugin_id, "Content-Type": "application/json", @@ -105,7 +113,7 @@ class PluginModelClient(BasePluginClient): def validate_model_credentials( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model_type: str, @@ -119,15 +127,15 @@ class PluginModelClient(BasePluginClient): "POST", f"plugin/{tenant_id}/dispatch/model/validate_model_credentials", PluginBasicBooleanResponse, - data={ - "user_id": user_id, - "data": { + data=self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": model_type, "model": model, "credentials": credentials, }, - }, + ), headers={ "X-Plugin-ID": plugin_id, "Content-Type": "application/json", @@ -145,7 +153,7 @@ class PluginModelClient(BasePluginClient): def invoke_llm( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -164,9 +172,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/llm/invoke", type_=LLMResultChunk, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "llm", "model": model, @@ -177,7 +185,7 @@ class PluginModelClient(BasePluginClient): "stop": stop, "stream": stream, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -193,7 +201,7 @@ class PluginModelClient(BasePluginClient): def get_llm_num_tokens( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model_type: str, @@ -210,9 +218,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/llm/num_tokens", type_=PluginLLMNumTokensResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": model_type, "model": model, @@ -220,7 +228,7 @@ class PluginModelClient(BasePluginClient): "prompt_messages": prompt_messages, "tools": tools, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -236,7 +244,7 @@ class PluginModelClient(BasePluginClient): def invoke_text_embedding( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -252,9 +260,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke", type_=EmbeddingResult, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "text-embedding", "model": model, @@ -262,7 +270,7 @@ class PluginModelClient(BasePluginClient): "texts": texts, "input_type": input_type, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -278,7 +286,7 @@ class PluginModelClient(BasePluginClient): def invoke_multimodal_embedding( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -294,9 +302,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/multimodal_embedding/invoke", type_=EmbeddingResult, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "text-embedding", "model": model, @@ -304,7 +312,7 @@ class PluginModelClient(BasePluginClient): "documents": documents, "input_type": input_type, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -320,7 +328,7 @@ class PluginModelClient(BasePluginClient): def get_text_embedding_num_tokens( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -335,16 +343,16 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens", type_=PluginTextEmbeddingNumTokensResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "text-embedding", "model": model, "credentials": credentials, "texts": texts, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -360,7 +368,7 @@ class PluginModelClient(BasePluginClient): def invoke_rerank( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -378,9 +386,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/rerank/invoke", type_=RerankResult, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "rerank", "model": model, @@ -390,7 +398,7 @@ class PluginModelClient(BasePluginClient): "score_threshold": score_threshold, "top_n": top_n, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -406,13 +414,13 @@ class PluginModelClient(BasePluginClient): def invoke_multimodal_rerank( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, credentials: dict, - query: dict, - docs: list[dict], + query: MultimodalRerankInput, + docs: list[MultimodalRerankInput], score_threshold: float | None = None, top_n: int | None = None, ) -> RerankResult: @@ -424,9 +432,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/multimodal_rerank/invoke", type_=RerankResult, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "rerank", "model": model, @@ -436,7 +444,7 @@ class PluginModelClient(BasePluginClient): "score_threshold": score_threshold, "top_n": top_n, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -451,7 +459,7 @@ class PluginModelClient(BasePluginClient): def invoke_tts( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -467,9 +475,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/tts/invoke", type_=PluginStringResultResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "tts", "model": model, @@ -478,7 +486,7 @@ class PluginModelClient(BasePluginClient): "content_text": content_text, "voice": voice, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -496,7 +504,7 @@ class PluginModelClient(BasePluginClient): def get_tts_model_voices( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -511,16 +519,16 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/tts/model/voices", type_=PluginVoicesResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "tts", "model": model, "credentials": credentials, "language": language, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -540,7 +548,7 @@ class PluginModelClient(BasePluginClient): def invoke_speech_to_text( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -555,16 +563,16 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/speech2text/invoke", type_=PluginStringResultResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "speech2text", "model": model, "credentials": credentials, "file": binascii.hexlify(file.read()).decode(), }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -580,7 +588,7 @@ class PluginModelClient(BasePluginClient): def invoke_moderation( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -595,16 +603,16 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/moderation/invoke", type_=PluginBasicBooleanResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "moderation", "model": model, "credentials": credentials, "text": text, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, diff --git a/api/core/plugin/impl/model_runtime.py b/api/core/plugin/impl/model_runtime.py new file mode 100644 index 00000000000..22c846b6de0 --- /dev/null +++ b/api/core/plugin/impl/model_runtime.py @@ -0,0 +1,499 @@ +from __future__ import annotations + +import hashlib +import logging +from collections.abc import Generator, Iterable, Sequence +from threading import Lock +from typing import IO, Any, Union + +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType +from graphon.model_runtime.entities.provider_entities import ProviderEntity +from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult +from graphon.model_runtime.runtime import ModelRuntime +from pydantic import ValidationError +from redis import RedisError + +from configs import dify_config +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from core.plugin.impl.asset import PluginAssetManager +from core.plugin.impl.model import PluginModelClient +from extensions.ext_redis import redis_client +from models.provider_ids import ModelProviderID + +logger = logging.getLogger(__name__) + +# `TS` means tenant scope +TENANT_SCOPE_SCHEMA_CACHE_USER_ID = "__DIFY_TS__" + + +class PluginModelRuntime(ModelRuntime): + """Plugin-backed runtime adapter bound to tenant context and optional caller scope.""" + + tenant_id: str + user_id: str | None + client: PluginModelClient + _provider_entities: tuple[ProviderEntity, ...] | None + _provider_entities_lock: Lock + + def __init__(self, tenant_id: str, user_id: str | None, client: PluginModelClient) -> None: + if client is None: + raise ValueError("client is required.") + self.tenant_id = tenant_id + self.user_id = user_id + self.client = client + self._provider_entities = None + self._provider_entities_lock = Lock() + + def fetch_model_providers(self) -> Sequence[ProviderEntity]: + if self._provider_entities is not None: + return self._provider_entities + + with self._provider_entities_lock: + if self._provider_entities is None: + self._provider_entities = tuple( + self._to_provider_entity(provider) for provider in self.client.fetch_model_providers(self.tenant_id) + ) + + return self._provider_entities + + def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: + provider_schema = self._get_provider_schema(provider) + + if icon_type.lower() == "icon_small": + if not provider_schema.icon_small: + raise ValueError(f"Provider {provider} does not have small icon.") + file_name = ( + provider_schema.icon_small.zh_Hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_US + ) + elif icon_type.lower() == "icon_small_dark": + if not provider_schema.icon_small_dark: + raise ValueError(f"Provider {provider} does not have small dark icon.") + file_name = ( + provider_schema.icon_small_dark.zh_Hans + if lang.lower() == "zh_hans" + else provider_schema.icon_small_dark.en_US + ) + else: + raise ValueError(f"Unsupported icon type: {icon_type}.") + + if not file_name: + raise ValueError(f"Provider {provider} does not have icon.") + + image_mime_types = { + "jpg": "image/jpeg", + "jpeg": "image/jpeg", + "png": "image/png", + "gif": "image/gif", + "bmp": "image/bmp", + "tiff": "image/tiff", + "tif": "image/tiff", + "webp": "image/webp", + "svg": "image/svg+xml", + "ico": "image/vnd.microsoft.icon", + "heif": "image/heif", + "heic": "image/heic", + } + + extension = file_name.split(".")[-1] + mime_type = image_mime_types.get(extension, "image/png") + return PluginAssetManager().fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type + + def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None: + plugin_id, provider_name = self._split_provider(provider) + self.client.validate_provider_credentials( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + credentials=credentials, + ) + + def validate_model_credentials( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + ) -> None: + plugin_id, provider_name = self._split_provider(provider) + self.client.validate_model_credentials( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model_type=model_type.value, + model=model, + credentials=credentials, + ) + + def get_model_schema( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + ) -> AIModelEntity | None: + cache_key = self._get_schema_cache_key( + provider=provider, + model_type=model_type, + model=model, + credentials=credentials, + ) + + cached_schema_json = None + try: + cached_schema_json = redis_client.get(cache_key) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to read plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, + ) + + if cached_schema_json: + try: + return AIModelEntity.model_validate_json(cached_schema_json) + except ValidationError: + logger.warning("Failed to validate cached plugin model schema for model %s", model, exc_info=True) + try: + redis_client.delete(cache_key) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to delete invalid plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, + ) + + plugin_id, provider_name = self._split_provider(provider) + schema = self.client.get_model_schema( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model_type=model_type.value, + model=model, + credentials=credentials, + ) + + if schema: + try: + redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json()) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to write plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, + ) + + return schema + + def invoke_llm( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + model_parameters: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + tools: list[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: bool, + ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_llm( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + model_parameters=model_parameters, + prompt_messages=list(prompt_messages), + tools=tools, + stop=list(stop) if stop else None, + stream=stream, + ) + + def get_llm_num_tokens( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + tools: Sequence[PromptMessageTool] | None, + ) -> int: + if not dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED: + return 0 + + plugin_id, provider_name = self._split_provider(provider) + return self.client.get_llm_num_tokens( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model_type=model_type.value, + model=model, + credentials=credentials, + prompt_messages=list(prompt_messages), + tools=list(tools) if tools else None, + ) + + def invoke_text_embedding( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + texts: list[str], + input_type: EmbeddingInputType, + ) -> EmbeddingResult: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_text_embedding( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + texts=texts, + input_type=input_type, + ) + + def invoke_multimodal_embedding( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + documents: list[dict[str, Any]], + input_type: EmbeddingInputType, + ) -> EmbeddingResult: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_multimodal_embedding( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + documents=documents, + input_type=input_type, + ) + + def get_text_embedding_num_tokens( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + texts: list[str], + ) -> list[int]: + plugin_id, provider_name = self._split_provider(provider) + return self.client.get_text_embedding_num_tokens( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + texts=texts, + ) + + def invoke_rerank( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + query: str, + docs: list[str], + score_threshold: float | None, + top_n: int | None, + ) -> RerankResult: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_rerank( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + query=query, + docs=docs, + score_threshold=score_threshold, + top_n=top_n, + ) + + def invoke_multimodal_rerank( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + query: MultimodalRerankInput, + docs: list[MultimodalRerankInput], + score_threshold: float | None, + top_n: int | None, + ) -> RerankResult: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_multimodal_rerank( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + query=query, + docs=docs, + score_threshold=score_threshold, + top_n=top_n, + ) + + def invoke_tts( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + content_text: str, + voice: str, + ) -> Iterable[bytes]: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_tts( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + content_text=content_text, + voice=voice, + ) + + def get_tts_model_voices( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + language: str | None, + ) -> Any: + plugin_id, provider_name = self._split_provider(provider) + return self.client.get_tts_model_voices( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + language=language, + ) + + def invoke_speech_to_text( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + file: IO[bytes], + ) -> str: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_speech_to_text( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + file=file, + ) + + def invoke_moderation( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + text: str, + ) -> bool: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_moderation( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + text=text, + ) + + def _get_provider_short_name_alias(self, provider: PluginModelProviderEntity) -> str: + """ + Expose a bare provider alias only for the canonical provider mapping. + + Multiple plugins can publish the same short provider slug. If every + provider entity keeps that slug in ``provider_name``, callers that still + resolve by short name become order-dependent. Restrict the alias to the + provider selected by ``ModelProviderID`` so legacy short-name lookups + remain deterministic while the runtime surface stays canonical. + """ + try: + canonical_provider_id = ModelProviderID(provider.provider) + except ValueError: + return "" + + if canonical_provider_id.plugin_id != provider.plugin_id: + return "" + if canonical_provider_id.provider_name != provider.provider: + return "" + + return provider.provider + + def _to_provider_entity(self, provider: PluginModelProviderEntity) -> ProviderEntity: + declaration = provider.declaration.model_copy(deep=True) + declaration.provider = f"{provider.plugin_id}/{provider.provider}" + declaration.provider_name = self._get_provider_short_name_alias(provider) + return declaration + + def _get_provider_schema(self, provider: str) -> ProviderEntity: + providers = self.fetch_model_providers() + provider_entity = next((item for item in providers if item.provider == provider), None) + if provider_entity is None: + provider_entity = next((item for item in providers if provider == item.provider_name), None) + if provider_entity is None: + raise ValueError(f"Invalid provider: {provider}") + return provider_entity + + def _get_schema_cache_key( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + ) -> str: + # The plugin daemon distinguishes ``None`` from an explicit empty-string + # caller id, so the cache must only collapse ``None`` into tenant scope. + cache_user_id = TENANT_SCOPE_SCHEMA_CACHE_USER_ID if self.user_id is None else self.user_id + cache_key = f"{self.tenant_id}:{provider}:{model_type.value}:{model}:{cache_user_id}" + sorted_credentials = sorted(credentials.items()) if credentials else [] + if not sorted_credentials: + return cache_key + hashed_credentials = ":".join( + [hashlib.md5(f"{key}:{value}".encode()).hexdigest() for key, value in sorted_credentials] + ) + return f"{cache_key}:{hashed_credentials}" + + def _split_provider(self, provider: str) -> tuple[str, str]: + provider_id = ModelProviderID(provider) + return provider_id.plugin_id, provider_id.provider_name diff --git a/api/core/plugin/impl/model_runtime_factory.py b/api/core/plugin/impl/model_runtime_factory.py new file mode 100644 index 00000000000..4b29a6fc56b --- /dev/null +++ b/api/core/plugin/impl/model_runtime_factory.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory + +from core.plugin.impl.model import PluginModelClient + +if TYPE_CHECKING: + from core.model_manager import ModelManager + from core.plugin.impl.model_runtime import PluginModelRuntime + from core.provider_manager import ProviderManager + + +class PluginModelAssembly: + """Compose request-scoped model views on top of a single plugin runtime.""" + + tenant_id: str + user_id: str | None + _model_runtime: PluginModelRuntime | None + _model_provider_factory: ModelProviderFactory | None + _provider_manager: ProviderManager | None + _model_manager: ModelManager | None + + def __init__(self, *, tenant_id: str, user_id: str | None = None) -> None: + self.tenant_id = tenant_id + self.user_id = user_id + self._model_runtime = None + self._model_provider_factory = None + self._provider_manager = None + self._model_manager = None + + @property + def model_runtime(self) -> PluginModelRuntime: + if self._model_runtime is None: + self._model_runtime = create_plugin_model_runtime(tenant_id=self.tenant_id, user_id=self.user_id) + return self._model_runtime + + @property + def model_provider_factory(self) -> ModelProviderFactory: + if self._model_provider_factory is None: + self._model_provider_factory = ModelProviderFactory(model_runtime=self.model_runtime) + return self._model_provider_factory + + @property + def provider_manager(self) -> ProviderManager: + if self._provider_manager is None: + from core.provider_manager import ProviderManager + + self._provider_manager = ProviderManager(model_runtime=self.model_runtime) + return self._provider_manager + + @property + def model_manager(self) -> ModelManager: + if self._model_manager is None: + from core.model_manager import ModelManager + + self._model_manager = ModelManager(provider_manager=self.provider_manager) + return self._model_manager + + +def create_plugin_model_assembly(*, tenant_id: str, user_id: str | None = None) -> PluginModelAssembly: + """Create a request-scoped assembly that shares one plugin runtime across model views.""" + return PluginModelAssembly(tenant_id=tenant_id, user_id=user_id) + + +def create_plugin_model_runtime(*, tenant_id: str, user_id: str | None = None) -> PluginModelRuntime: + """Create a plugin runtime with its client dependency fully composed.""" + from core.plugin.impl.model_runtime import PluginModelRuntime + + return PluginModelRuntime( + tenant_id=tenant_id, + user_id=user_id, + client=PluginModelClient(), + ) + + +def create_plugin_model_provider_factory(*, tenant_id: str, user_id: str | None = None) -> ModelProviderFactory: + """Create a tenant-bound model provider factory for service flows.""" + return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).model_provider_factory + + +def create_plugin_provider_manager(*, tenant_id: str, user_id: str | None = None) -> ProviderManager: + """Create a tenant-bound provider manager for service flows.""" + return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).provider_manager + + +def create_plugin_model_manager(*, tenant_id: str, user_id: str | None = None) -> ModelManager: + """Create a tenant-bound model manager for service flows.""" + return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).model_manager diff --git a/api/core/plugin/impl/plugin.py b/api/core/plugin/impl/plugin.py index 0bbb62af937..ec4858ae2e1 100644 --- a/api/core/plugin/impl/plugin.py +++ b/api/core/plugin/impl/plugin.py @@ -209,8 +209,7 @@ class PluginInstaller(BasePluginClient): "GET", f"plugin/{tenant_id}/management/decode/from_identifier", PluginDecodeResponse, - data={"plugin_unique_identifier": plugin_unique_identifier}, - headers={"Content-Type": "application/json"}, + params={"plugin_unique_identifier": plugin_unique_identifier}, ) def fetch_plugin_installation_by_ids( diff --git a/api/core/plugin/utils/converter.py b/api/core/plugin/utils/converter.py index 53bcd9e9c6a..90350f84000 100644 --- a/api/core/plugin/utils/converter.py +++ b/api/core/plugin/utils/converter.py @@ -1,7 +1,8 @@ from typing import Any +from graphon.file import File + from core.tools.entities.tool_entities import ToolSelector -from dify_graph.file.models import File def convert_parameters_to_plugin_format(parameters: dict[str, Any]) -> dict[str, Any]: diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index ce9f7e64b24..19b5e9223a8 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -1,6 +1,18 @@ from collections.abc import Mapping, Sequence from typing import cast +from graphon.file import File, file_manager +from graphon.model_runtime.entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageRole, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes +from graphon.runtime import VariablePool + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter from core.memory.token_buffer_memory import TokenBufferMemory @@ -8,18 +20,6 @@ from core.model_manager import ModelInstance from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.prompt.prompt_transform import PromptTransform from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from dify_graph.file import file_manager -from dify_graph.file.models import File -from dify_graph.model_runtime.entities import ( - AssistantPromptMessage, - PromptMessage, - PromptMessageRole, - SystemPromptMessage, - TextPromptMessageContent, - UserPromptMessage, -) -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes -from dify_graph.runtime import VariablePool class AdvancedPromptTransform(PromptTransform): diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py index d09a46bfde5..9be70199b7d 100644 --- a/api/core/prompt/agent_history_prompt_transform.py +++ b/api/core/prompt/agent_history_prompt_transform.py @@ -1,16 +1,17 @@ from typing import cast +from graphon.model_runtime.entities.message_entities import ( + PromptMessage, + SystemPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel + from core.app.entities.app_invoke_entities import ( ModelConfigWithCredentialsEntity, ) from core.memory.token_buffer_memory import TokenBufferMemory from core.prompt.prompt_transform import PromptTransform -from dify_graph.model_runtime.entities.message_entities import ( - PromptMessage, - SystemPromptMessage, - UserPromptMessage, -) -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel class AgentHistoryPromptTransform(PromptTransform): diff --git a/api/core/prompt/entities/advanced_prompt_entities.py b/api/core/prompt/entities/advanced_prompt_entities.py index 667f5ef0993..b98fd8c179e 100644 --- a/api/core/prompt/entities/advanced_prompt_entities.py +++ b/api/core/prompt/entities/advanced_prompt_entities.py @@ -1,50 +1,7 @@ -from typing import Literal +from graphon.prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig -from pydantic import BaseModel - -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole - - -class ChatModelMessage(BaseModel): - """ - Chat Message. - """ - - text: str - role: PromptMessageRole - edition_type: Literal["basic", "jinja2"] | None = None - - -class CompletionModelPromptTemplate(BaseModel): - """ - Completion Model Prompt Template. - """ - - text: str - edition_type: Literal["basic", "jinja2"] | None = None - - -class MemoryConfig(BaseModel): - """ - Memory Config. - """ - - class RolePrefix(BaseModel): - """ - Role Prefix. - """ - - user: str - assistant: str - - class WindowConfig(BaseModel): - """ - Window Config. - """ - - enabled: bool - size: int | None = None - - role_prefix: RolePrefix | None = None - window: WindowConfig - query_prompt_template: str | None = None +__all__ = [ + "ChatModelMessage", + "CompletionModelPromptTemplate", + "MemoryConfig", +] diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 951736831f1..4539ae9f11b 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -1,11 +1,12 @@ from typing import Any +from graphon.model_runtime.entities.message_entities import PromptMessage +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from dify_graph.model_runtime.entities.message_entities import PromptMessage -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey class PromptTransform: diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index 10c44349ae2..c706353ffeb 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -4,14 +4,8 @@ from collections.abc import Mapping, Sequence from enum import StrEnum, auto from typing import TYPE_CHECKING, Any, cast -from core.app.app_config.entities import PromptTemplateEntity -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.memory.token_buffer_memory import TokenBufferMemory -from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.prompt.prompt_transform import PromptTransform -from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from dify_graph.file import file_manager -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.file import file_manager +from graphon.model_runtime.entities.message_entities import ( ImagePromptMessageContent, PromptMessage, PromptMessageContentUnionTypes, @@ -19,10 +13,17 @@ from dify_graph.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) + +from core.app.app_config.entities import PromptTemplateEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.memory.token_buffer_memory import TokenBufferMemory +from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from core.prompt.prompt_transform import PromptTransform +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import AppMode if TYPE_CHECKING: - from dify_graph.file.models import File + from graphon.file import File class ModelMode(StrEnum): diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index 85a22013958..dbda7499255 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ b/api/core/prompt/utils/prompt_message_util.py @@ -1,8 +1,7 @@ from collections.abc import Sequence from typing import Any, cast -from core.prompt.simple_prompt_transform import ModelMode -from dify_graph.model_runtime.entities import ( +from graphon.model_runtime.entities import ( AssistantPromptMessage, AudioPromptMessageContent, ImagePromptMessageContent, @@ -12,6 +11,8 @@ from dify_graph.model_runtime.entities import ( TextPromptMessageContent, ) +from core.prompt.simple_prompt_transform import ModelMode + class PromptMessageUtil: @staticmethod diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 3c3fbd6dd23..30933239f65 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -1,10 +1,20 @@ +from __future__ import annotations + import contextlib import json from collections import defaultdict from collections.abc import Sequence from json import JSONDecodeError -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FormType, + ProviderEntity, +) +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from sqlalchemy import select from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session @@ -28,14 +38,6 @@ from core.entities.provider_entities import ( from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.helper.position_helper import is_filtered -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FormType, - ProviderEntity, -) -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from extensions import ext_hosting_provider from extensions.ext_database import db from extensions.ext_redis import redis_client @@ -53,15 +55,25 @@ from models.provider import ( from models.provider_ids import ModelProviderID from services.feature_service import FeatureService +if TYPE_CHECKING: + from graphon.model_runtime.runtime import ModelRuntime + class ProviderManager: """ - ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers. + ProviderManager manages tenant-scoped model provider configuration. + + The runtime adapter is injected by the composition layer so this class stays + focused on configuration assembly instead of constructing plugin runtimes. + Request-bound managers may carry caller identity in that runtime, and the + resulting ``ProviderConfiguration`` objects must reuse it for downstream + model-type and schema lookups. """ - def __init__(self): + def __init__(self, model_runtime: ModelRuntime): self.decoding_rsa_key = None self.decoding_cipher_rsa = None + self._model_runtime = model_runtime def get_configurations(self, tenant_id: str) -> ProviderConfigurations: """ @@ -127,7 +139,7 @@ class ProviderManager: ) # Get all provider entities - model_provider_factory = ModelProviderFactory(tenant_id) + model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime) provider_entities = model_provider_factory.get_providers() # Get All preferred provider types of the workspace @@ -255,6 +267,7 @@ class ProviderManager: custom_configuration=custom_configuration, model_settings=model_settings, ) + provider_configuration.bind_model_runtime(self._model_runtime) provider_configurations[str(provider_id_entity)] = provider_configuration @@ -321,7 +334,7 @@ class ProviderManager: if not default_model: return None - model_provider_factory = ModelProviderFactory(tenant_id) + model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime) provider_schema = model_provider_factory.get_provider_schema(provider=default_model.provider_name) return DefaultModelEntity( @@ -392,7 +405,7 @@ class ProviderManager: # create default model default_model = TenantDefaultModel( tenant_id=tenant_id, - model_type=model_type.value, + model_type=model_type.to_origin_model_type(), provider_name=provider, model_name=model, ) @@ -918,11 +931,11 @@ class ProviderManager: trail_pool = CreditPoolService.get_pool( tenant_id=tenant_id, - pool_type=ProviderQuotaType.TRIAL.value, + pool_type=ProviderQuotaType.TRIAL, ) paid_pool = CreditPoolService.get_pool( tenant_id=tenant_id, - pool_type=ProviderQuotaType.PAID.value, + pool_type=ProviderQuotaType.PAID, ) else: trail_pool = None diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index 33eb5f963ac..b872ea8a8fb 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -1,3 +1,5 @@ +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from typing_extensions import TypedDict from core.model_manager import ModelInstance, ModelManager @@ -8,8 +10,6 @@ from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights from core.rag.rerank.rerank_base import BaseRerankRunner from core.rag.rerank.rerank_factory import RerankRunnerFactory from core.rag.rerank.rerank_type import RerankMode -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError class RerankingModelDict(TypedDict): @@ -52,11 +52,10 @@ class DataPostProcessor: documents: list[Document], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: if self.rerank_runner: - documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user, query_type) + documents = self.rerank_runner.run(query, documents, score_threshold, top_n, query_type) if self.reorder_runner: documents = self.reorder_runner.run(documents) @@ -106,9 +105,9 @@ class DataPostProcessor: ) -> ModelInstance | None: if reranking_model: try: - model_manager = ModelManager() - reranking_provider_name = reranking_model["reranking_provider_name"] - reranking_model_name = reranking_model["reranking_model_name"] + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) + reranking_provider_name = reranking_model.get("reranking_provider_name") + reranking_model_name = reranking_model.get("reranking_model_name") if not reranking_provider_name or not reranking_model_name: return None rerank_model_instance = model_manager.get_model_instance( diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index b07dc108bee..b8d5db7a43b 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -97,13 +97,13 @@ class Jieba(BaseKeyword): documents = [] - segment_query_stmt = db.session.query(DocumentSegment).where( + segment_query_stmt = select(DocumentSegment).where( DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id.in_(sorted_chunk_indices) ) if document_ids_filter: segment_query_stmt = segment_query_stmt.where(DocumentSegment.document_id.in_(document_ids_filter)) - segments = db.session.execute(segment_query_stmt).scalars().all() + segments = db.session.scalars(segment_query_stmt).all() segment_map = {segment.index_node_id: segment for segment in segments} for chunk_index in sorted_chunk_indices: segment = segment_map.get(chunk_index) diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 713319ab9dc..203a8588d67 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -4,6 +4,7 @@ from concurrent.futures import ThreadPoolExecutor from typing import Any, NotRequired from flask import Flask, current_app +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import select from sqlalchemy.orm import Session, load_only from typing_extensions import TypedDict @@ -23,7 +24,6 @@ from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.signature import sign_upload_file -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from models.dataset import ( ChildChunk, @@ -328,7 +328,7 @@ class RetrievalService: str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False ) if dataset.is_multimodal: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) is_support_vision = model_manager.check_model_support_vision( tenant_id=dataset.tenant_id, provider=reranking_model["reranking_provider_name"], @@ -432,10 +432,11 @@ class RetrievalService: # Batch query dataset documents dataset_documents = { doc.id: doc - for doc in db.session.query(DatasetDocument) - .where(DatasetDocument.id.in_(document_ids)) - .options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id)) - .all() + for doc in db.session.scalars( + select(DatasetDocument) + .where(DatasetDocument.id.in_(document_ids)) + .options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id)) + ).all() } valid_dataset_documents = {} diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py index 144d8344951..9f5842e4493 100644 --- a/api/core/rag/datasource/vdb/baidu/baidu_vector.py +++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py @@ -13,6 +13,7 @@ from pymochow.exception import ServerError # type: ignore from pymochow.model.database import Database from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState # type: ignore from pymochow.model.schema import ( + AutoBuildRowCountIncrement, Field, FilteringIndex, HNSWParams, @@ -51,6 +52,9 @@ class BaiduConfig(BaseModel): replicas: int = 3 inverted_index_analyzer: str = "DEFAULT_ANALYZER" inverted_index_parser_mode: str = "COARSE_MODE" + auto_build_row_count_increment: int = 500 + auto_build_row_count_increment_ratio: float = 0.05 + rebuild_index_timeout_in_seconds: int = 300 @model_validator(mode="before") @classmethod @@ -107,18 +111,6 @@ class BaiduVector(BaseVector): rows.append(row) table.upsert(rows=rows) - # rebuild vector index after upsert finished - table.rebuild_index(self.vector_index) - timeout = 3600 # 1 hour timeout - start_time = time.time() - while True: - time.sleep(1) - index = table.describe_index(self.vector_index) - if index.state == IndexState.NORMAL: - break - if time.time() - start_time > timeout: - raise TimeoutError(f"Index rebuild timeout after {timeout} seconds") - def text_exists(self, id: str) -> bool: res = self._db.table(self._collection_name).query(primary_key={VDBField.PRIMARY_KEY: id}) if res and res.code == 0: @@ -232,8 +224,14 @@ class BaiduVector(BaseVector): return self._client.database(self._client_config.database) def _table_existed(self) -> bool: - tables = self._db.list_table() - return any(table.table_name == self._collection_name for table in tables) + try: + table = self._db.table(self._collection_name) + except ServerError as e: + if e.code == ServerErrCode.TABLE_NOT_EXIST: + return False + else: + raise + return True def _create_table(self, dimension: int): # Try to grab distributed lock and create table @@ -287,6 +285,11 @@ class BaiduVector(BaseVector): field=VDBField.VECTOR, metric_type=metric_type, params=HNSWParams(m=16, efconstruction=200), + auto_build=True, + auto_build_index_policy=AutoBuildRowCountIncrement( + row_count_increment=self._client_config.auto_build_row_count_increment, + row_count_increment_ratio=self._client_config.auto_build_row_count_increment_ratio, + ), ) ) @@ -335,7 +338,7 @@ class BaiduVector(BaseVector): ) # Wait for table created - timeout = 300 # 5 minutes timeout + timeout = self._client_config.rebuild_index_timeout_in_seconds # default 5 minutes timeout start_time = time.time() while True: time.sleep(1) @@ -345,6 +348,20 @@ class BaiduVector(BaseVector): if time.time() - start_time > timeout: raise TimeoutError(f"Table creation timeout after {timeout} seconds") redis_client.set(table_exist_cache_key, 1, ex=3600) + # rebuild vector index immediately after table created, make sure index is ready + table.rebuild_index(self.vector_index) + timeout = 3600 # 1 hour timeout + self._wait_for_index_ready(table, timeout) + + def _wait_for_index_ready(self, table, timeout: int = 3600): + start_time = time.time() + while True: + time.sleep(1) + index = table.describe_index(self.vector_index) + if index.state == IndexState.NORMAL: + break + if time.time() - start_time > timeout: + raise TimeoutError(f"Index rebuild timeout after {timeout} seconds") class BaiduVectorFactory(AbstractVectorFactory): @@ -369,5 +386,8 @@ class BaiduVectorFactory(AbstractVectorFactory): replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS, inverted_index_analyzer=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER, inverted_index_parser_mode=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE, + auto_build_row_count_increment=dify_config.BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT, + auto_build_row_count_increment_ratio=dify_config.BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO, + rebuild_index_timeout_in_seconds=dify_config.BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS, ), ) diff --git a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py index c7b6593a8f9..df02c584ede 100644 --- a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py +++ b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py @@ -124,13 +124,13 @@ class HuaweiCloudVector(BaseVector): ) ) + score_threshold = float(kwargs.get("score_threshold") or 0.0) docs = [] for doc, score in docs_and_scores: - score_threshold = float(kwargs.get("score_threshold") or 0.0) if score >= score_threshold: if doc.metadata is not None: doc.metadata["score"] = score - docs.append(doc) + docs.append(doc) return docs diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py index 56ffb36a2b2..69c81d521c9 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -33,6 +33,7 @@ from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import Dataset, TidbAuthBinding +from models.enums import TidbAuthBindingStatus if TYPE_CHECKING: from qdrant_client import grpc # noqa @@ -284,27 +285,29 @@ class TidbOnQdrantVector(BaseVector): from qdrant_client.http import models from qdrant_client.http.exceptions import UnexpectedResponse - for node_id in ids: - try: - filter = models.Filter( - must=[ - models.FieldCondition( - key="metadata.doc_id", - match=models.MatchValue(value=node_id), - ), - ], - ) - self._client.delete( - collection_name=self._collection_name, - points_selector=FilterSelector(filter=filter), - ) - except UnexpectedResponse as e: - # Collection does not exist, so return - if e.status_code == 404: - return - # Some other error occurred, so re-raise the exception - else: - raise e + if not ids: + return + + try: + filter = models.Filter( + must=[ + models.FieldCondition( + key="metadata.doc_id", + match=models.MatchAny(any=ids), + ), + ], + ) + self._client.delete( + collection_name=self._collection_name, + points_selector=FilterSelector(filter=filter), + ) + except UnexpectedResponse as e: + # Collection does not exist, so return + if e.status_code == 404: + return + # Some other error occurred, so re-raise the exception + else: + raise e def text_exists(self, id: str) -> bool: all_collection_name = [] @@ -423,11 +426,10 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory): TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" else: - idle_tidb_auth_binding = ( - db.session.query(TidbAuthBinding) + idle_tidb_auth_binding = db.session.scalar( + select(TidbAuthBinding) .where(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE") .limit(1) - .one_or_none() ) if idle_tidb_auth_binding: idle_tidb_auth_binding.active = True @@ -450,7 +452,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory): password=new_cluster["password"], tenant_id=dataset.tenant_id, active=True, - status="ACTIVE", + status=TidbAuthBindingStatus.ACTIVE, ) db.session.add(new_tidb_auth_binding) db.session.commit() diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py index 754c1492415..06b17b9e62c 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py @@ -9,6 +9,7 @@ from configs import dify_config from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import TidbAuthBinding +from models.enums import TidbAuthBindingStatus class TidbService: @@ -170,7 +171,7 @@ class TidbService: userPrefix = item["userPrefix"] if state == "ACTIVE" and len(userPrefix) > 0: cluster_info = tidb_serverless_list_map[item["clusterId"]] - cluster_info.status = "ACTIVE" + cluster_info.status = TidbAuthBindingStatus.ACTIVE cluster_info.account = f"{userPrefix}.root" db.session.add(cluster_info) db.session.commit() diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index cd12cd3fae7..26531eab886 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -4,6 +4,7 @@ import time from abc import ABC, abstractmethod from typing import Any +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import select from configs import dify_config @@ -14,7 +15,6 @@ from core.rag.embedding.cached_embedding import CacheEmbedding from core.rag.embedding.embedding_base import Embeddings from core.rag.index_processor.constant.doc_type import DocType from core.rag.models.document import Document -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage @@ -277,7 +277,7 @@ class Vector: return self._vector_processor.search_by_vector(query_vector, **kwargs) def search_by_file(self, file_id: str, **kwargs: Any) -> list[Document]: - upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == file_id).first() + upload_file: UploadFile | None = db.session.get(UploadFile, file_id) if not upload_file: return [] @@ -303,7 +303,7 @@ class Vector: redis_client.delete(collection_exist_cache_key) def _get_embeddings(self) -> Embeddings: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=self._dataset.tenant_id, diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 16a5588024d..40f45953af4 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -3,11 +3,12 @@ from __future__ import annotations from collections.abc import Sequence from typing import Any -from sqlalchemy import func, select +from graphon.model_runtime.entities.model_entities import ModelType +from sqlalchemy import delete, func, select from core.model_manager import ModelManager +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import AttachmentDocument, Document -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding @@ -62,17 +63,15 @@ class DatasetDocumentStore: return output def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False): - max_position = ( - db.session.query(func.max(DocumentSegment.position)) - .where(DocumentSegment.document_id == self._document_id) - .scalar() + max_position = db.session.scalar( + select(func.max(DocumentSegment.position)).where(DocumentSegment.document_id == self._document_id) ) if max_position is None: max_position = 0 embedding_model = None - if self._dataset.indexing_technique == "high_quality": - model_manager = ModelManager() + if self._dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: + model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=self._dataset.tenant_id, provider=self._dataset.embedding_model_provider, @@ -154,12 +153,14 @@ class DatasetDocumentStore: ) if save_child and doc.children: # delete the existing child chunks - db.session.query(ChildChunk).where( - ChildChunk.tenant_id == self._dataset.tenant_id, - ChildChunk.dataset_id == self._dataset.id, - ChildChunk.document_id == self._document_id, - ChildChunk.segment_id == segment_document.id, - ).delete() + db.session.execute( + delete(ChildChunk).where( + ChildChunk.tenant_id == self._dataset.tenant_id, + ChildChunk.dataset_id == self._dataset.id, + ChildChunk.document_id == self._document_id, + ChildChunk.segment_id == segment_document.id, + ) + ) # add new child chunks for position, child in enumerate(doc.children, start=1): child_segment = ChildChunk( diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 6d1b65a0556..8d1c0da392d 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -4,14 +4,15 @@ import pickle from typing import Any, cast import numpy as np +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from sqlalchemy import select from sqlalchemy.exc import IntegrityError from configs import dify_config from core.entities.embedding_type import EmbeddingInputType from core.model_manager import ModelInstance from core.rag.embedding.embedding_base import Embeddings -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey -from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from extensions.ext_database import db from extensions.ext_redis import redis_client from libs import helper @@ -21,9 +22,8 @@ logger = logging.getLogger(__name__) class CacheEmbedding(Embeddings): - def __init__(self, model_instance: ModelInstance, user: str | None = None): + def __init__(self, model_instance: ModelInstance): self._model_instance = model_instance - self._user = user def embed_documents(self, texts: list[str]) -> list[list[float]]: """Embed search docs in batches of 10.""" @@ -32,14 +32,14 @@ class CacheEmbedding(Embeddings): embedding_queue_indices = [] for i, text in enumerate(texts): hash = helper.generate_text_hash(text) - embedding = ( - db.session.query(Embedding) - .filter_by( - model_name=self._model_instance.model_name, - hash=hash, - provider_name=self._model_instance.provider, + embedding = db.session.scalar( + select(Embedding) + .where( + Embedding.model_name == self._model_instance.model_name, + Embedding.hash == hash, + Embedding.provider_name == self._model_instance.provider, ) - .first() + .limit(1) ) if embedding: text_embeddings[i] = embedding.get_embedding() @@ -65,7 +65,7 @@ class CacheEmbedding(Embeddings): batch_texts = embedding_queue_texts[i : i + max_chunks] embedding_result = self._model_instance.invoke_text_embedding( - texts=batch_texts, user=self._user, input_type=EmbeddingInputType.DOCUMENT + texts=batch_texts, input_type=EmbeddingInputType.DOCUMENT ) for vector in embedding_result.embeddings: @@ -113,14 +113,14 @@ class CacheEmbedding(Embeddings): embedding_queue_indices = [] for i, multimodel_document in enumerate(multimodel_documents): file_id = multimodel_document["file_id"] - embedding = ( - db.session.query(Embedding) - .filter_by( - model_name=self._model_instance.model_name, - hash=file_id, - provider_name=self._model_instance.provider, + embedding = db.session.scalar( + select(Embedding) + .where( + Embedding.model_name == self._model_instance.model_name, + Embedding.hash == file_id, + Embedding.provider_name == self._model_instance.provider, ) - .first() + .limit(1) ) if embedding: multimodel_embeddings[i] = embedding.get_embedding() @@ -147,7 +147,6 @@ class CacheEmbedding(Embeddings): embedding_result = self._model_instance.invoke_multimodal_embedding( multimodel_documents=batch_multimodel_documents, - user=self._user, input_type=EmbeddingInputType.DOCUMENT, ) @@ -202,7 +201,7 @@ class CacheEmbedding(Embeddings): return [float(x) for x in decoded_embedding] try: embedding_result = self._model_instance.invoke_text_embedding( - texts=[text], user=self._user, input_type=EmbeddingInputType.QUERY + texts=[text], input_type=EmbeddingInputType.QUERY ) embedding_results = embedding_result.embeddings[0] @@ -245,7 +244,7 @@ class CacheEmbedding(Embeddings): return [float(x) for x in decoded_embedding] try: embedding_result = self._model_instance.invoke_multimodal_embedding( - multimodel_documents=[multimodel_document], user=self._user, input_type=EmbeddingInputType.QUERY + multimodel_documents=[multimodel_document], input_type=EmbeddingInputType.QUERY ) embedding_results = embedding_result.embeddings[0] diff --git a/api/core/rag/extractor/firecrawl/firecrawl_app.py b/api/core/rag/extractor/firecrawl/firecrawl_app.py index 371f7b08652..e1ddd2dd967 100644 --- a/api/core/rag/extractor/firecrawl/firecrawl_app.py +++ b/api/core/rag/extractor/firecrawl/firecrawl_app.py @@ -95,15 +95,11 @@ class FirecrawlApp: if response.status_code == 200: crawl_status_response = response.json() if crawl_status_response.get("status") == "completed": - total = crawl_status_response.get("total", 0) - if total == 0: + # Normalize to avoid None bypassing the zero-guard when the API returns null. + total = crawl_status_response.get("total") or 0 + if total <= 0: raise Exception("Failed to check crawl status. Error: No page found") - data = crawl_status_response.get("data", []) - url_data_list: list[FirecrawlDocumentData] = [] - for item in data: - if isinstance(item, dict) and "metadata" in item and "markdown" in item: - url_data = self._extract_common_fields(item) - url_data_list.append(url_data) + url_data_list = self._collect_all_crawl_pages(crawl_status_response, headers) if url_data_list: file_key = "website_files/" + job_id + ".txt" try: @@ -120,6 +116,36 @@ class FirecrawlApp: self._handle_error(response, "check crawl status") raise RuntimeError("unreachable: _handle_error always raises") + def _collect_all_crawl_pages( + self, first_page: dict[str, Any], headers: dict[str, str] + ) -> list[FirecrawlDocumentData]: + """Collect all crawl result pages by following pagination links. + + Raises an exception if any paginated request fails, to avoid returning + partial data that is inconsistent with the reported total. + + The number of pages processed is capped at ``total`` (the + server-reported page count) to guard against infinite loops caused by + a misbehaving server that keeps returning a ``next`` URL. + """ + total: int = first_page.get("total") or 0 + url_data_list: list[FirecrawlDocumentData] = [] + current_page = first_page + pages_processed = 0 + while True: + for item in current_page.get("data", []): + if isinstance(item, dict) and "metadata" in item and "markdown" in item: + url_data_list.append(self._extract_common_fields(item)) + next_url: str | None = current_page.get("next") + pages_processed += 1 + if not next_url or pages_processed >= total: + break + response = self._get_request(next_url, headers) + if response.status_code != 200: + self._handle_error(response, "fetch next crawl page") + current_page = response.json() + return url_data_list + def _format_crawl_status_response( self, status: str, diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index 372af8fd941..aa361607110 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -4,6 +4,7 @@ import operator from typing import Any, cast import httpx +from sqlalchemy import update from configs import dify_config from core.rag.extractor.extractor_base import BaseExtractor @@ -346,9 +347,11 @@ class NotionExtractor(BaseExtractor): if data_source_info: data_source_info["last_edited_time"] = last_edited_time - db.session.query(DocumentModel).filter_by(id=document_model.id).update( - {DocumentModel.data_source_info: json.dumps(data_source_info)} - ) # type: ignore + db.session.execute( + update(DocumentModel) + .where(DocumentModel.id == document_model.id) + .values(data_source_info=json.dumps(data_source_info)) + ) db.session.commit() def get_notion_last_edited_time(self) -> str: diff --git a/api/core/rag/index_processor/index_processor.py b/api/core/rag/index_processor/index_processor.py index d9145023ac5..a6d1db214b0 100644 --- a/api/core/rag/index_processor/index_processor.py +++ b/api/core/rag/index_processor/index_processor.py @@ -9,6 +9,7 @@ from flask import current_app from sqlalchemy import delete, func, select from core.db.session_factory import session_factory +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview @@ -159,7 +160,7 @@ class IndexProcessor: tenant_id = dataset.tenant_id preview_output = self.format_preview(chunk_structure, chunks) - if indexing_technique != "high_quality": + if indexing_technique != IndexTechniqueType.HIGH_QUALITY: return preview_output if not summary_index_setting or not summary_index_setting.get("enable"): diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index a435dfc46a9..7d504fdb35e 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Any, NotRequired, Optional from urllib.parse import unquote, urlparse import httpx +from sqlalchemy import select from typing_extensions import TypedDict from configs import dify_config @@ -200,7 +201,7 @@ class BaseIndexProcessor(ABC): # Get unique IDs for database query unique_upload_file_ids = list(set(upload_file_id_list)) - upload_files = db.session.query(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids)).all() + upload_files = db.session.scalars(select(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids))).all() # Create a mapping from ID to UploadFile for quick lookup upload_file_map = {upload_file.id: upload_file for upload_file in upload_files} @@ -312,7 +313,7 @@ class BaseIndexProcessor(ABC): """ from services.file_service import FileService - tool_file = db.session.query(ToolFile).where(ToolFile.id == tool_file_id).first() + tool_file = db.session.get(ToolFile, tool_file_id) if not tool_file: return None blob = storage.load_once(tool_file.file_key) diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 80163b17070..22ab492cbf3 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -8,11 +8,24 @@ from typing import Any, cast logger = logging.getLogger(__name__) +from graphon.file import File, FileTransferMethod, FileType, file_manager +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentUnionTypes, + TextPromptMessageContent, + UserPromptMessage, +) +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType +from sqlalchemy import select + +from core.app.file_access import DatabaseFileAccessController from core.app.llm import deduct_llm_quota from core.entities.knowledge_entities import PreviewDetail from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT from core.model_manager import ModelInstance -from core.provider_manager import ProviderManager +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.data_post_processor.data_post_processor import RerankingModelDict from core.rag.datasource.keyword.keyword_factory import Keyword @@ -22,21 +35,12 @@ from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.constant.doc_type import DocType -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols -from dify_graph.file import File, FileTransferMethod, FileType, file_manager -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from dify_graph.model_runtime.entities.message_entities import ( - ImagePromptMessageContent, - PromptMessage, - PromptMessageContentUnionTypes, - TextPromptMessageContent, - UserPromptMessage, -) -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType +from core.workflow.file_reference import build_file_reference from extensions.ext_database import db from factories.file_factory import build_from_mapping from libs import helper @@ -48,6 +52,8 @@ from services.account_service import AccountService from services.entities.knowledge_entities.knowledge_entities import Rule from services.summary_index_service import SummaryIndexService +_file_access_controller = DatabaseFileAccessController() + class ParagraphIndexProcessor(BaseIndexProcessor): def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: @@ -117,7 +123,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): with_keywords: bool = True, **kwargs, ) -> None: - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) vector.create(documents) if multimodal_documents and dataset.is_multimodal: @@ -140,14 +146,12 @@ class ParagraphIndexProcessor(BaseIndexProcessor): if delete_summaries: if node_ids: # Find segments by index_node_id - segments = ( - db.session.query(DocumentSegment) - .filter( + segments = db.session.scalars( + select(DocumentSegment).where( DocumentSegment.dataset_id == dataset.id, DocumentSegment.index_node_id.in_(node_ids), ) - .all() - ) + ).all() segment_ids = [segment.id for segment in segments] if segment_ids: SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids) @@ -155,7 +159,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): # Delete all summaries for the dataset SummaryIndexService.delete_summaries_for_segments(dataset, None) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) if node_ids: vector.delete_by_ids(node_ids) @@ -253,12 +257,12 @@ class ParagraphIndexProcessor(BaseIndexProcessor): doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) # add document segments doc_store.add_documents(docs=documents, save_child=False) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) vector.create(documents) if all_multimodal_documents and dataset.is_multimodal: vector.create_multimodal(all_multimodal_documents) - elif dataset.indexing_technique == "economy": + elif dataset.indexing_technique == IndexTechniqueType.ECONOMY: keyword = Keyword(dataset) keyword.add_texts(documents) @@ -410,7 +414,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): # If default prompt doesn't have {language} placeholder, use it as-is pass - provider_manager = ProviderManager() + provider_manager = create_plugin_provider_manager(tenant_id=tenant_id) provider_model_bundle = provider_manager.get_provider_model_bundle( tenant_id, model_provider_name, ModelType.LLM ) @@ -532,11 +536,9 @@ class ParagraphIndexProcessor(BaseIndexProcessor): # Get unique IDs for database query unique_upload_file_ids = list(set(upload_file_id_list)) - upload_files = ( - db.session.query(UploadFile) - .where(UploadFile.id.in_(unique_upload_file_ids), UploadFile.tenant_id == tenant_id) - .all() - ) + upload_files = db.session.scalars( + select(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids), UploadFile.tenant_id == tenant_id) + ).all() # Create File objects from UploadFile records file_objects = [] @@ -555,6 +557,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): file_obj = build_from_mapping( mapping=mapping, tenant_id=tenant_id, + access_controller=_file_access_controller, ) file_objects.append(file_obj) except Exception as e: @@ -604,11 +607,12 @@ class ParagraphIndexProcessor(BaseIndexProcessor): filename=upload_file.name, extension="." + upload_file.extension, mime_type=upload_file.mime_type, - tenant_id=tenant_id, type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, - related_id=upload_file.id, + reference=build_file_reference( + record_id=str(upload_file.id), + ), size=upload_file.size, storage_key=upload_file.key, ) diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index df0761ca73f..1c5e02e9c8f 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -6,6 +6,8 @@ import uuid from collections.abc import Mapping from typing import Any +from sqlalchemy import delete, select + from configs import dify_config from core.db.session_factory import session_factory from core.entities.knowledge_entities import PreviewDetail @@ -18,7 +20,7 @@ from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.constant.doc_type import DocType -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -128,7 +130,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): with_keywords: bool = True, **kwargs, ) -> None: - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) for document in documents: child_documents = document.children @@ -166,7 +168,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): # Delete all summaries for the dataset SummaryIndexService.delete_summaries_for_segments(dataset, None) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: delete_child_chunks = kwargs.get("delete_child_chunks") or False precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids") vector = Vector(dataset) @@ -177,17 +179,16 @@ class ParentChildIndexProcessor(BaseIndexProcessor): child_node_ids = precomputed_child_node_ids else: # Fallback to original query (may fail if segments are already deleted) - child_node_ids = ( - db.session.query(ChildChunk.index_node_id) + rows = db.session.execute( + select(ChildChunk.index_node_id) .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id) .where( DocumentSegment.dataset_id == dataset.id, DocumentSegment.index_node_id.in_(node_ids), ChildChunk.dataset_id == dataset.id, ) - .all() - ) - child_node_ids = [child_node_id[0] for child_node_id in child_node_ids if child_node_id[0]] + ).all() + child_node_ids = [row[0] for row in rows if row[0]] # Delete from vector index if child_node_ids: @@ -195,18 +196,22 @@ class ParentChildIndexProcessor(BaseIndexProcessor): # Delete from database if delete_child_chunks and child_node_ids: - db.session.query(ChildChunk).where( - ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids) - ).delete(synchronize_session=False) + db.session.execute( + delete(ChildChunk).where( + ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids) + ) + ) db.session.commit() else: vector.delete() if delete_child_chunks: # Use existing compound index: (tenant_id, dataset_id, ...) - db.session.query(ChildChunk).where( - ChildChunk.tenant_id == dataset.tenant_id, ChildChunk.dataset_id == dataset.id - ).delete(synchronize_session=False) + db.session.execute( + delete(ChildChunk).where( + ChildChunk.tenant_id == dataset.tenant_id, ChildChunk.dataset_id == dataset.id + ) + ) db.session.commit() def retrieve( @@ -332,7 +337,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) # add document segments doc_store.add_documents(docs=documents, save_child=True) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: all_child_documents = [] all_multimodal_documents = [] for doc in documents: diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 62f88b7760b..6874603a833 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -21,7 +21,7 @@ from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -141,7 +141,7 @@ class QAIndexProcessor(BaseIndexProcessor): with_keywords: bool = True, **kwargs, ) -> None: - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) vector.create(documents) if multimodal_documents and dataset.is_multimodal: @@ -224,7 +224,7 @@ class QAIndexProcessor(BaseIndexProcessor): # save node to document segment doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) doc_store.add_documents(docs=documents, save_child=False) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) vector.create(documents) else: diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index dc3b771406e..087736d0b0a 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -2,10 +2,9 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from typing import Any +from graphon.file import File from pydantic import BaseModel, Field -from dify_graph.file import File - class ChildDocument(BaseModel): """Class for storing a piece of text and associated metadata.""" diff --git a/api/core/rag/rerank/rerank_base.py b/api/core/rag/rerank/rerank_base.py index 88acb751334..cc652625277 100644 --- a/api/core/rag/rerank/rerank_base.py +++ b/api/core/rag/rerank/rerank_base.py @@ -12,7 +12,6 @@ class BaseRerankRunner(ABC): documents: list[Document], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: """ @@ -21,7 +20,6 @@ class BaseRerankRunner(ABC): :param documents: documents for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id if needed :return: """ raise NotImplementedError diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index fcb14ffc52c..8283be19f9b 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -1,12 +1,13 @@ import base64 +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.rerank_entities import RerankResult + from core.model_manager import ModelInstance, ModelManager from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.rerank_base import BaseRerankRunner -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.entities.rerank_entities import RerankResult from extensions.ext_database import db from extensions.ext_storage import storage from models.model import UploadFile @@ -22,7 +23,6 @@ class RerankModelRunner(BaseRerankRunner): documents: list[Document], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: """ @@ -31,10 +31,11 @@ class RerankModelRunner(BaseRerankRunner): :param documents: documents for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id if needed :return: """ - model_manager = ModelManager() + model_manager = ModelManager.for_tenant( + tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id + ) is_support_vision = model_manager.check_model_support_vision( tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id, provider=self.rerank_model_instance.provider, @@ -43,12 +44,12 @@ class RerankModelRunner(BaseRerankRunner): ) if not is_support_vision: if query_type == QueryType.TEXT_QUERY: - rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user) + rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n) else: return documents else: rerank_result, unique_documents = self.fetch_multimodal_rerank( - query, documents, score_threshold, top_n, user, query_type + query, documents, score_threshold, top_n, query_type ) rerank_documents = [] @@ -73,7 +74,6 @@ class RerankModelRunner(BaseRerankRunner): documents: list[Document], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, ) -> tuple[RerankResult, list[Document]]: """ Fetch text rerank @@ -81,7 +81,6 @@ class RerankModelRunner(BaseRerankRunner): :param documents: documents for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id if needed :return: """ docs = [] @@ -103,7 +102,7 @@ class RerankModelRunner(BaseRerankRunner): unique_documents.append(document) rerank_result = self.rerank_model_instance.invoke_rerank( - query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user + query=query, docs=docs, score_threshold=score_threshold, top_n=top_n ) return rerank_result, unique_documents @@ -113,7 +112,6 @@ class RerankModelRunner(BaseRerankRunner): documents: list[Document], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, query_type: QueryType = QueryType.TEXT_QUERY, ) -> tuple[RerankResult, list[Document]]: """ @@ -122,7 +120,6 @@ class RerankModelRunner(BaseRerankRunner): :param documents: documents for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id if needed :param query_type: query type :return: rerank result """ @@ -137,9 +134,7 @@ class RerankModelRunner(BaseRerankRunner): ): if document.metadata.get("doc_type") == DocType.IMAGE: # Query file info within db.session context to ensure thread-safe access - upload_file = ( - db.session.query(UploadFile).where(UploadFile.id == document.metadata["doc_id"]).first() - ) + upload_file = db.session.get(UploadFile, document.metadata["doc_id"]) if upload_file: blob = storage.load_once(upload_file.key) document_file_base64 = base64.b64encode(blob).decode() @@ -168,11 +163,11 @@ class RerankModelRunner(BaseRerankRunner): documents = unique_documents if query_type == QueryType.TEXT_QUERY: - rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user) + rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n) return rerank_result, unique_documents elif query_type == QueryType.IMAGE_QUERY: # Query file info within db.session context to ensure thread-safe access - upload_file = db.session.query(UploadFile).where(UploadFile.id == query).first() + upload_file = db.session.get(UploadFile, query) if upload_file: blob = storage.load_once(upload_file.key) file_query = base64.b64encode(blob).decode() @@ -181,7 +176,7 @@ class RerankModelRunner(BaseRerankRunner): "content_type": DocType.IMAGE, } rerank_result = self.rerank_model_instance.invoke_multimodal_rerank( - query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user + query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n ) return rerank_result, unique_documents else: diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index 7edd05d2d16..49123e13d05 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -2,6 +2,7 @@ import math from collections import Counter import numpy as np +from graphon.model_runtime.entities.model_entities import ModelType from core.model_manager import ModelManager from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler @@ -11,7 +12,6 @@ from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.entity.weight import VectorSetting, Weights from core.rag.rerank.rerank_base import BaseRerankRunner -from dify_graph.model_runtime.entities.model_entities import ModelType class WeightRerankRunner(BaseRerankRunner): @@ -25,7 +25,6 @@ class WeightRerankRunner(BaseRerankRunner): documents: list[Document], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: """ @@ -34,7 +33,6 @@ class WeightRerankRunner(BaseRerankRunner): :param documents: documents for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id if needed :return: """ @@ -163,7 +161,7 @@ class WeightRerankRunner(BaseRerankRunner): """ query_vector_scores = [] - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=tenant_id, diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 78a97f79a52..593e1f1420a 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -9,6 +9,11 @@ from collections.abc import Generator, Mapping from typing import Any, Union, cast from flask import Flask, current_app +from graphon.file import File, FileTransferMethod, FileType +from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import and_, func, literal, or_, select from sqlalchemy.orm import Session @@ -56,6 +61,7 @@ from core.rag.retrieval.template_prompts import ( ) from core.tools.signature import sign_upload_file from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool +from core.workflow.file_reference import build_file_reference from core.workflow.nodes.knowledge_retrieval import exc from core.workflow.nodes.knowledge_retrieval.retrieval import ( KnowledgeRetrievalRequest, @@ -63,13 +69,9 @@ from core.workflow.nodes.knowledge_retrieval.retrieval import ( SourceChildChunk, SourceMetadata, ) -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage -from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from extensions.ext_database import db from extensions.ext_redis import redis_client +from libs.helper import parse_uuid_str_or_none from libs.json_in_md_parser import parse_and_check_json_markdown from models import UploadFile from models.dataset import ( @@ -160,7 +162,7 @@ class DatasetRetrieval: if request.model_provider is None or request.model_name is None or request.query is None: raise ValueError("model_provider, model_name, and query are required for single retrieval mode") - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=request.tenant_id, user_id=request.user_id) model_instance = model_manager.get_model_instance( tenant_id=request.tenant_id, model_type=ModelType.LLM, @@ -383,23 +385,27 @@ class DatasetRetrieval: return None, [] retrieve_config = config.retrieve_config - # check model is support tool calling - model_type_instance = model_config.provider_model_bundle.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model ) + model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance) - # get model schema + # Reuse the caller-bound model instance for both schema resolution and + # downstream planner/invoke calls so a single request never mixes + # tenant-scope and request-bound runtimes. model_schema = model_type_instance.get_model_schema( - model=model_config.model, credentials=model_config.credentials + model=model_instance.model_name, + credentials=model_instance.credentials, ) if not model_schema: return None, [] + model_config.provider_model_bundle = model_instance.provider_model_bundle + model_config.credentials = model_instance.credentials + model_config.model_schema = model_schema + planning_strategy = PlanningStrategy.REACT_ROUTER features = model_schema.features if features: @@ -517,11 +523,12 @@ class DatasetRetrieval: filename=upload_file.name, extension="." + upload_file.extension, mime_type=upload_file.mime_type, - tenant_id=segment.tenant_id, type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, - related_id=upload_file.id, + reference=build_file_reference( + record_id=str(upload_file.id), + ), size=upload_file.size, storage_key=upload_file.key, url=sign_upload_file(upload_file.id, upload_file.extension), @@ -675,7 +682,7 @@ class DatasetRetrieval: # get top k top_k = retrieval_model_config["top_k"] # get retrieval method - if selected_dataset.indexing_technique == "economy": + if selected_dataset.indexing_technique == IndexTechniqueType.ECONOMY: retrieval_method = RetrievalMethod.KEYWORD_SEARCH else: retrieval_method = retrieval_model_config["search_method"] @@ -752,7 +759,7 @@ class DatasetRetrieval: "The configured knowledge base list have different indexing technique, please set reranking model." ) index_type = available_datasets[0].indexing_technique - if index_type == "high_quality": + if index_type == IndexTechniqueType.HIGH_QUALITY: embedding_model_check = all( item.embedding_model == available_datasets[0].embedding_model for item in available_datasets ) @@ -986,6 +993,24 @@ class DatasetRetrieval: ) ) + @staticmethod + def _resolve_creator_user_role(user_from: str) -> CreatorUserRole | None: + """Map runtime user source values to dataset query audit roles. + + Workflow run context uses the hyphenated ``end-user`` value, while + ``DatasetQuery.created_by_role`` persists the underscore-based + ``CreatorUserRole.END_USER`` enum. Query logging is a side effect, so an + unsupported value should be skipped instead of aborting retrieval. + """ + normalized_user_from = str(user_from).strip().lower().replace("-", "_") + if normalized_user_from == CreatorUserRole.ACCOUNT.value: + return CreatorUserRole.ACCOUNT + if normalized_user_from == CreatorUserRole.END_USER.value: + return CreatorUserRole.END_USER + + logger.warning("Skipping dataset query audit log for unsupported user_from=%r", user_from) + return None + def _on_query( self, query: str | None, @@ -996,10 +1021,18 @@ class DatasetRetrieval: user_id: str, ): """ - Handle query. + Persist dataset query audit rows for retrieval requests. """ if not query and not attachment_ids: return + created_by = parse_uuid_str_or_none(user_id) + if created_by is None: + logger.debug( + "Skipping dataset query log: empty created_by user_id (user_from=%s, app_id=%s)", + user_from, + app_id, + ) + return dataset_queries = [] for dataset_id in dataset_ids: contents = [] @@ -1015,7 +1048,7 @@ class DatasetRetrieval: source=DatasetQuerySource.APP, source_app_id=app_id, created_by_role=CreatorUserRole(user_from), - created_by=user_id, + created_by=created_by, ) dataset_queries.append(dataset_query) if dataset_queries: @@ -1068,7 +1101,7 @@ class DatasetRetrieval: else default_retrieval_model ) - if dataset.indexing_technique == "economy": + if dataset.indexing_technique == IndexTechniqueType.ECONOMY: # use keyword table query documents = RetrievalService.retrieve( retrieval_method=RetrievalMethod.KEYWORD_SEARCH, @@ -1307,7 +1340,7 @@ class DatasetRetrieval: metadata_filtering_conditions: MetadataFilteringCondition | None, inputs: dict, ) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]: - document_query = db.session.query(DatasetDocument).where( + document_query = select(DatasetDocument).where( DatasetDocument.dataset_id.in_(dataset_ids), DatasetDocument.indexing_status == "completed", DatasetDocument.enabled == True, @@ -1378,7 +1411,7 @@ class DatasetRetrieval: document_query = document_query.where(and_(*filters)) else: document_query = document_query.where(or_(*filters)) - documents = document_query.all() + documents = db.session.scalars(document_query).all() # group by dataset_id metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore for document in documents: @@ -1411,7 +1444,7 @@ class DatasetRetrieval: raise ValueError("metadata_model_config is required") # get metadata model instance # fetch model config - model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config) + model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config, user_id=user_id) # fetch prompt messages prompt_messages, stop = self._get_prompt_template( @@ -1430,7 +1463,6 @@ class DatasetRetrieval: model_parameters=model_config.parameters, stop=stop, stream=True, - user=user_id, ), ) @@ -1533,7 +1565,7 @@ class DatasetRetrieval: return filters def _fetch_model_config( - self, tenant_id: str, model: ModelConfig + self, tenant_id: str, model: ModelConfig, user_id: str | None = None ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: """ Fetch model config @@ -1543,7 +1575,7 @@ class DatasetRetrieval: model_name = model.name provider_name = model.provider - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name ) diff --git a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py index 23a2ac83863..dce7b6226ce 100644 --- a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py +++ b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py @@ -1,9 +1,10 @@ from typing import Union +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.model_manager import ModelInstance -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from dify_graph.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage class FunctionCallMultiDatasetRouter: diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py index ea110fa0a70..dd280cdf6a7 100644 --- a/api/core/rag/retrieval/router/multi_dataset_react_route.py +++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py @@ -1,15 +1,17 @@ from collections.abc import Generator, Sequence from typing import Union +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool +from graphon.model_runtime.entities.model_entities import ModelType + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.llm import deduct_llm_quota -from core.model_manager import ModelInstance +from core.model_manager import ModelInstance, ModelManager from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.rag.retrieval.output_parser.react_output import ReactAction from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:""" @@ -119,6 +121,7 @@ class ReactMultiDatasetRouter: memory_config=None, memory=None, model_config=model_config, + model_instance=model_instance, ) result_text, usage = self._invoke_llm( completion_param=model_config.parameters, @@ -150,19 +153,24 @@ class ReactMultiDatasetRouter: :param stop: stop :return: """ - invoke_result: Generator[LLMResult, None, None] = model_instance.invoke_llm( + bound_model_instance = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance( + tenant_id=tenant_id, + provider=model_instance.provider, + model_type=ModelType.LLM, + model=model_instance.model_name, + ) + invoke_result: Generator[LLMResult, None, None] = bound_model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters=completion_param, stop=stop, stream=True, - user=user_id, ) # handle invoke result text, usage = self._handle_invoke_result(invoke_result=invoke_result) # deduct quota - deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage) + deduct_llm_quota(tenant_id=tenant_id, model_instance=bound_model_instance, usage=usage) return text, usage diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index 7a00e8a886a..e6aec4a3af9 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -6,6 +6,8 @@ import codecs import re from typing import Any +from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer + from core.model_manager import ModelInstance from core.rag.splitter.text_splitter import ( TS, @@ -15,7 +17,6 @@ from core.rag.splitter.text_splitter import ( Set, Union, ) -from dify_graph.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): diff --git a/api/core/rag/summary_index/summary_index.py b/api/core/rag/summary_index/summary_index.py index 31d21dbeeeb..6f120bd4711 100644 --- a/api/core/rag/summary_index/summary_index.py +++ b/api/core/rag/summary_index/summary_index.py @@ -2,6 +2,7 @@ import concurrent.futures import logging from core.db.session_factory import session_factory +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary from services.summary_index_service import SummaryIndexService @@ -21,7 +22,7 @@ class SummaryIndex: if is_preview: with session_factory.create_session() as session: dataset = session.query(Dataset).filter_by(id=dataset_id).first() - if not dataset or dataset.indexing_technique != "high_quality": + if not dataset or dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: return if summary_index_setting is None: diff --git a/api/core/repositories/__init__.py b/api/core/repositories/__init__.py index 6f2826f6341..cfa9962ea8f 100644 --- a/api/core/repositories/__init__.py +++ b/api/core/repositories/__init__.py @@ -4,7 +4,13 @@ from __future__ import annotations from .celery_workflow_execution_repository import CeleryWorkflowExecutionRepository from .celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository -from .factory import DifyCoreRepositoryFactory, RepositoryImportError +from .factory import ( + DifyCoreRepositoryFactory, + OrderConfig, + RepositoryImportError, + WorkflowExecutionRepository, + WorkflowNodeExecutionRepository, +) from .sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from .sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository @@ -12,7 +18,10 @@ __all__ = [ "CeleryWorkflowExecutionRepository", "CeleryWorkflowNodeExecutionRepository", "DifyCoreRepositoryFactory", + "OrderConfig", "RepositoryImportError", "SQLAlchemyWorkflowExecutionRepository", "SQLAlchemyWorkflowNodeExecutionRepository", + "WorkflowExecutionRepository", + "WorkflowNodeExecutionRepository", ] diff --git a/api/core/repositories/celery_workflow_execution_repository.py b/api/core/repositories/celery_workflow_execution_repository.py index 57764574d7f..465f43da739 100644 --- a/api/core/repositories/celery_workflow_execution_repository.py +++ b/api/core/repositories/celery_workflow_execution_repository.py @@ -8,11 +8,11 @@ providing improved performance by offloading database operations to background w import logging from typing import Union +from graphon.entities import WorkflowExecution from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker -from dify_graph.entities.workflow_execution import WorkflowExecution -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.repositories.factory import WorkflowExecutionRepository from libs.helper import extract_tenant_id from models import Account, CreatorUserRole, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/core/repositories/celery_workflow_node_execution_repository.py b/api/core/repositories/celery_workflow_node_execution_repository.py index 650cf79550c..22ef44b3dc4 100644 --- a/api/core/repositories/celery_workflow_node_execution_repository.py +++ b/api/core/repositories/celery_workflow_node_execution_repository.py @@ -9,11 +9,11 @@ import logging from collections.abc import Sequence from typing import Union +from graphon.entities import WorkflowNodeExecution from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecution -from dify_graph.repositories.workflow_node_execution_repository import ( +from core.repositories.factory import ( OrderConfig, WorkflowNodeExecutionRepository, ) @@ -148,24 +148,24 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): # For now, we'll re-raise the exception raise - def get_by_workflow_run( + def get_by_workflow_execution( self, - workflow_run_id: str, + workflow_execution_id: str, order_config: OrderConfig | None = None, ) -> Sequence[WorkflowNodeExecution]: """ - Retrieve all WorkflowNodeExecution instances for a specific workflow run from cache. + Retrieve all workflow node executions for a workflow execution from cache. Args: - workflow_run_id: The workflow run ID + workflow_execution_id: The workflow execution identifier order_config: Optional configuration for ordering results Returns: A sequence of WorkflowNodeExecution instances """ try: - # Get execution IDs for this workflow run from cache - execution_ids = self._workflow_execution_mapping.get(workflow_run_id, []) + # Get execution IDs for this workflow execution from cache + execution_ids = self._workflow_execution_mapping.get(workflow_execution_id, []) # Retrieve executions from cache result = [] @@ -182,9 +182,16 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): for field_name in reversed(order_config.order_by): result.sort(key=lambda x: getattr(x, field_name, 0), reverse=reverse) - logger.debug("Retrieved %d workflow node executions for run %s from cache", len(result), workflow_run_id) + logger.debug( + "Retrieved %d workflow node executions for execution %s from cache", + len(result), + workflow_execution_id, + ) return result except Exception: - logger.exception("Failed to get workflow node executions for run %s from cache", workflow_run_id) + logger.exception( + "Failed to get workflow node executions for execution %s from cache", + workflow_execution_id, + ) return [] diff --git a/api/core/repositories/factory.py b/api/core/repositories/factory.py index dc9f8c96bf8..ed6d44f4340 100644 --- a/api/core/repositories/factory.py +++ b/api/core/repositories/factory.py @@ -5,20 +5,45 @@ This module provides a Django-like settings system for repository implementation allowing users to configure different repository backends through string paths. """ -from typing import Union +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Literal, Protocol, Union +from graphon.entities import WorkflowExecution, WorkflowNodeExecution from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from configs import dify_config -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from libs.module_loading import import_string from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom from models.workflow import WorkflowNodeExecutionTriggeredFrom +@dataclass +class OrderConfig: + """Configuration for ordering node execution instances.""" + + order_by: list[str] + order_direction: Literal["asc", "desc"] | None = None + + +class WorkflowExecutionRepository(Protocol): + def save(self, execution: WorkflowExecution): ... + + +class WorkflowNodeExecutionRepository(Protocol): + def save(self, execution: WorkflowNodeExecution): ... + + def save_execution_data(self, execution: WorkflowNodeExecution): ... + + def get_by_workflow_execution( + self, + workflow_execution_id: str, + order_config: OrderConfig | None = None, + ) -> Sequence[WorkflowNodeExecution]: ... + + class RepositoryImportError(Exception): """Raised when a repository implementation cannot be imported or instantiated.""" diff --git a/api/core/repositories/human_input_repository.py b/api/core/repositories/human_input_repository.py index 6607a87032d..72d93941498 100644 --- a/api/core/repositories/human_input_repository.py +++ b/api/core/repositories/human_input_repository.py @@ -2,32 +2,22 @@ import dataclasses import json from collections.abc import Mapping, Sequence from datetime import datetime -from typing import Any +from typing import Any, Protocol +from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from sqlalchemy import select from sqlalchemy.orm import Session, selectinload from core.db.session_factory import session_factory -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( + BoundRecipient, DeliveryChannelConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - FormDefinition, - HumanInputNodeData, - MemberRecipient, - WebAppDeliveryMethod, -) -from dify_graph.nodes.human_input.enums import ( - DeliveryMethodType, - HumanInputFormKind, - HumanInputFormStatus, -) -from dify_graph.repositories.human_input_form_repository import ( - FormCreateParams, - FormNotFoundError, - HumanInputFormEntity, - HumanInputFormRecipientEntity, + InteractiveSurfaceDeliveryMethod, + is_human_input_webapp_enabled, ) from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 @@ -36,6 +26,7 @@ from models.human_input import ( BackstageRecipientPayload, ConsoleDeliveryPayload, ConsoleRecipientPayload, + DeliveryMethodType, EmailExternalRecipientPayload, EmailMemberRecipientPayload, HumanInputDelivery, @@ -58,6 +49,65 @@ class _WorkspaceMemberInfo: email: str +class FormNotFoundError(Exception): + pass + + +@dataclasses.dataclass +class FormCreateParams: + workflow_execution_id: str | None + node_id: str + form_config: HumanInputNodeData + rendered_content: str + delivery_methods: Sequence[DeliveryChannelConfig] + display_in_ui: bool + resolved_default_values: Mapping[str, Any] + form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME + + +class HumanInputFormRecipientEntity(Protocol): + @property + def id(self) -> str: ... + + @property + def token(self) -> str: ... + + +class HumanInputFormEntity(Protocol): + @property + def id(self) -> str: ... + + @property + def submission_token(self) -> str | None: ... + + @property + def recipients(self) -> list[HumanInputFormRecipientEntity]: ... + + @property + def rendered_content(self) -> str: ... + + @property + def selected_action_id(self) -> str | None: ... + + @property + def submitted_data(self) -> Mapping[str, Any] | None: ... + + @property + def submitted(self) -> bool: ... + + @property + def status(self) -> HumanInputFormStatus: ... + + @property + def expiration_time(self) -> datetime: ... + + +class HumanInputFormRepository(Protocol): + def get_form(self, node_id: str) -> HumanInputFormEntity | None: ... + + def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: ... + + class _HumanInputFormRecipientEntityImpl(HumanInputFormRecipientEntity): def __init__(self, recipient_model: HumanInputFormRecipient): self._recipient_model = recipient_model @@ -77,7 +127,7 @@ class _HumanInputFormEntityImpl(HumanInputFormEntity): def __init__(self, form_model: HumanInputForm, recipient_models: Sequence[HumanInputFormRecipient]): self._form_model = form_model self._recipients = [_HumanInputFormRecipientEntityImpl(recipient) for recipient in recipient_models] - self._web_app_recipient = next( + self._interactive_surface_recipient = next( ( recipient for recipient in recipient_models @@ -98,12 +148,12 @@ class _HumanInputFormEntityImpl(HumanInputFormEntity): return self._form_model.id @property - def web_app_token(self): + def submission_token(self) -> str | None: if self._console_recipient is not None: return self._console_recipient.access_token - if self._web_app_recipient is None: + if self._interactive_surface_recipient is None: return None - return self._web_app_recipient.access_token + return self._interactive_surface_recipient.access_token @property def recipients(self) -> list[HumanInputFormRecipientEntity]: @@ -201,8 +251,16 @@ class HumanInputFormRepositoryImpl: self, *, tenant_id: str, - ): + app_id: str | None = None, + workflow_execution_id: str | None = None, + invoke_source: str | None = None, + submission_actor_id: str | None = None, + ) -> None: self._tenant_id = tenant_id + self._app_id = app_id + self._workflow_execution_id = workflow_execution_id + self._invoke_source = invoke_source + self._submission_actor_id = submission_actor_id def _delivery_method_to_model( self, @@ -219,7 +277,7 @@ class HumanInputFormRepositoryImpl: channel_payload=delivery_method.model_dump_json(), ) recipients: list[HumanInputFormRecipient] = [] - if isinstance(delivery_method, WebAppDeliveryMethod): + if isinstance(delivery_method, InteractiveSurfaceDeliveryMethod): recipient_model = HumanInputFormRecipient( form_id=form_id, delivery_id=delivery_id, @@ -247,16 +305,16 @@ class HumanInputFormRepositoryImpl: delivery_id: str, recipients_config: EmailRecipients, ) -> list[HumanInputFormRecipient]: - member_user_ids = [ - recipient.user_id for recipient in recipients_config.items if isinstance(recipient, MemberRecipient) + bound_reference_ids = [ + recipient.reference_id for recipient in recipients_config.items if isinstance(recipient, BoundRecipient) ] external_emails = [ recipient.email for recipient in recipients_config.items if isinstance(recipient, ExternalRecipient) ] - if recipients_config.whole_workspace: + if recipients_config.include_bound_group: members = self._query_all_workspace_members(session=session) else: - members = self._query_workspace_members_by_ids(session=session, restrict_to_user_ids=member_user_ids) + members = self._query_workspace_members_by_ids(session=session, restrict_to_user_ids=bound_reference_ids) return self._create_email_recipients_from_resolved( form_id=form_id, @@ -338,8 +396,33 @@ class HumanInputFormRepositoryImpl: rows = session.execute(stmt).all() return [_WorkspaceMemberInfo(user_id=account_id, email=email) for account_id, email in rows] + def _should_create_console_recipient( + self, + *, + form_config: HumanInputNodeData, + form_kind: HumanInputFormKind, + ) -> bool: + if form_kind != HumanInputFormKind.RUNTIME: + return False + if self._invoke_source == "debugger": + return True + if self._invoke_source == "explore": + return is_human_input_webapp_enabled(form_config) + return False + + def _should_create_backstage_recipient(self, *, form_kind: HumanInputFormKind) -> bool: + return form_kind == HumanInputFormKind.RUNTIME and ( + self._invoke_source is not None or self._submission_actor_id is not None + ) + def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: form_config: HumanInputNodeData = params.form_config + app_id = self._app_id + if not app_id: + raise ValueError("app_id is required to create a human input form") + workflow_execution_id = params.workflow_execution_id or self._workflow_execution_id + if params.form_kind == HumanInputFormKind.RUNTIME and workflow_execution_id is None: + raise ValueError("workflow_execution_id is required for runtime human input forms") with session_factory.create_session() as session, session.begin(): # Generate unique form ID @@ -359,8 +442,8 @@ class HumanInputFormRepositoryImpl: form_model = HumanInputForm( id=form_id, tenant_id=self._tenant_id, - app_id=params.app_id, - workflow_run_id=params.workflow_execution_id, + app_id=app_id, + workflow_run_id=workflow_execution_id, form_kind=params.form_kind, node_id=params.node_id, form_definition=form_definition.model_dump_json(), @@ -379,7 +462,7 @@ class HumanInputFormRepositoryImpl: session.add(delivery_and_recipients.delivery) session.add_all(delivery_and_recipients.recipients) recipient_models.extend(delivery_and_recipients.recipients) - if params.console_recipient_required and not any( + if self._should_create_console_recipient(form_config=form_config, form_kind=params.form_kind) and not any( recipient.recipient_type == RecipientType.CONSOLE for recipient in recipient_models ): console_delivery_id = str(uuidv7()) @@ -395,13 +478,13 @@ class HumanInputFormRepositoryImpl: delivery_id=console_delivery_id, recipient_type=RecipientType.CONSOLE, recipient_payload=ConsoleRecipientPayload( - account_id=params.console_creator_account_id, + account_id=self._submission_actor_id, ).model_dump_json(), ) session.add(console_delivery) session.add(console_recipient) recipient_models.append(console_recipient) - if params.backstage_recipient_required and not any( + if self._should_create_backstage_recipient(form_kind=params.form_kind) and not any( recipient.recipient_type == RecipientType.BACKSTAGE for recipient in recipient_models ): backstage_delivery_id = str(uuidv7()) @@ -417,7 +500,7 @@ class HumanInputFormRepositoryImpl: delivery_id=backstage_delivery_id, recipient_type=RecipientType.BACKSTAGE, recipient_payload=BackstageRecipientPayload( - account_id=params.console_creator_account_id, + account_id=self._submission_actor_id, ).model_dump_json(), ) session.add(backstage_delivery) @@ -427,9 +510,12 @@ class HumanInputFormRepositoryImpl: return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models) - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: + def get_form(self, node_id: str) -> HumanInputFormEntity | None: + if self._workflow_execution_id is None: + raise ValueError("workflow_execution_id is required to load runtime human input forms") + form_query = select(HumanInputForm).where( - HumanInputForm.workflow_run_id == workflow_execution_id, + HumanInputForm.workflow_run_id == self._workflow_execution_id, HumanInputForm.node_id == node_id, HumanInputForm.tenant_id == self._tenant_id, ) diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index 55e96515ac7..85d20b675d2 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -6,13 +6,13 @@ import json import logging from typing import Union +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus, WorkflowType +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker -from dify_graph.entities import WorkflowExecution -from dify_graph.enums import WorkflowExecutionStatus, WorkflowType -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter +from core.repositories.factory import WorkflowExecutionRepository from libs.helper import extract_tenant_id from models import ( Account, diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 7373ebc7cc0..a72bfa378bc 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -10,6 +10,10 @@ from concurrent.futures import ThreadPoolExecutor from typing import Any, TypeVar, Union import psycopg2.errors +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import UnaryExpression, asc, desc, select from sqlalchemy.engine import Engine from sqlalchemy.exc import IntegrityError @@ -17,11 +21,7 @@ from sqlalchemy.orm import sessionmaker from tenacity import before_sleep_log, retry, retry_if_exception, stop_after_attempt from configs import dify_config -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter +from core.repositories.factory import OrderConfig, WorkflowNodeExecutionRepository from extensions.ext_storage import storage from libs.helper import extract_tenant_id from libs.uuid_utils import uuidv7 @@ -518,29 +518,28 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) return db_models - def get_by_workflow_run( + def get_by_workflow_execution( self, - workflow_run_id: str, + workflow_execution_id: str, order_config: OrderConfig | None = None, triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) -> Sequence[WorkflowNodeExecution]: """ - Retrieve all NodeExecution instances for a specific workflow run. + Retrieve all node executions for a workflow execution. This method always queries the database to ensure complete and ordered results, but updates the cache with any retrieved executions. Args: - workflow_run_id: The workflow run ID + workflow_execution_id: The workflow execution identifier order_config: Optional configuration for ordering results order_config.order_by: List of fields to order by (e.g., ["index", "created_at"]) order_config.order_direction: Direction to order ("asc" or "desc") Returns: - A list of NodeExecution instances + A list of node execution instances """ - # Get the database models using the new method - db_models = self.get_db_models_by_workflow_run(workflow_run_id, order_config, triggered_from) + db_models = self.get_db_models_by_workflow_run(workflow_execution_id, order_config, triggered_from) with ThreadPoolExecutor(max_workers=10) as executor: domain_models = executor.map(self._to_domain_model, db_models, timeout=30) diff --git a/api/core/telemetry/__init__.py b/api/core/telemetry/__init__.py new file mode 100644 index 00000000000..ae4f53f3b7c --- /dev/null +++ b/api/core/telemetry/__init__.py @@ -0,0 +1,43 @@ +"""Telemetry facade. + +Thin public API for emitting telemetry events. All routing logic +lives in ``core.telemetry.gateway`` which is shared by both CE and EE. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from core.ops.entities.trace_entity import TraceTaskName +from core.telemetry.events import TelemetryContext, TelemetryEvent +from core.telemetry.gateway import emit as gateway_emit +from core.telemetry.gateway import get_trace_task_to_case + +if TYPE_CHECKING: + from core.ops.ops_trace_manager import TraceQueueManager + + +def emit(event: TelemetryEvent, trace_manager: TraceQueueManager | None = None) -> None: + """Emit a telemetry event. + + Translates the ``TelemetryEvent`` (keyed by ``TraceTaskName``) into a + ``TelemetryCase`` and delegates to ``core.telemetry.gateway.emit()``. + """ + case = get_trace_task_to_case().get(event.name) + if case is None: + return + + context: dict[str, object] = { + "tenant_id": event.context.tenant_id, + "user_id": event.context.user_id, + "app_id": event.context.app_id, + } + gateway_emit(case, context, event.payload, trace_manager) + + +__all__ = [ + "TelemetryContext", + "TelemetryEvent", + "TraceTaskName", + "emit", +] diff --git a/api/core/telemetry/events.py b/api/core/telemetry/events.py new file mode 100644 index 00000000000..35ace47510a --- /dev/null +++ b/api/core/telemetry/events.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from core.ops.entities.trace_entity import TraceTaskName + + +@dataclass(frozen=True) +class TelemetryContext: + tenant_id: str | None = None + user_id: str | None = None + app_id: str | None = None + + +@dataclass(frozen=True) +class TelemetryEvent: + name: TraceTaskName + context: TelemetryContext + payload: dict[str, Any] diff --git a/api/core/telemetry/gateway.py b/api/core/telemetry/gateway.py new file mode 100644 index 00000000000..7b013d05638 --- /dev/null +++ b/api/core/telemetry/gateway.py @@ -0,0 +1,239 @@ +"""Telemetry gateway — single routing layer for all editions. + +Maps ``TelemetryCase`` → ``CaseRoute`` and dispatches events to either +the CE/EE trace pipeline (``TraceQueueManager``) or the enterprise-only +metric/log Celery queue. + +This module lives in ``core/`` so both CE and EE share one routing table +and one ``emit()`` entry point. No separate enterprise gateway module is +needed — enterprise-specific dispatch (Celery task, payload offloading) +is handled here behind lazy imports that no-op in CE. +""" + +from __future__ import annotations + +import json +import logging +import uuid +from typing import TYPE_CHECKING, Any + +from core.ops.entities.trace_entity import TraceTaskName +from enterprise.telemetry.contracts import CaseRoute, SignalType +from extensions.ext_storage import storage + +if TYPE_CHECKING: + from core.ops.ops_trace_manager import TraceQueueManager + from enterprise.telemetry.contracts import TelemetryCase + +logger = logging.getLogger(__name__) + +PAYLOAD_SIZE_THRESHOLD_BYTES = 1 * 1024 * 1024 + +# --------------------------------------------------------------------------- +# Routing table — authoritative mapping for all editions +# --------------------------------------------------------------------------- + +_case_to_trace_task: dict[TelemetryCase, TraceTaskName] | None = None +_case_routing: dict[TelemetryCase, CaseRoute] | None = None + + +def _get_case_to_trace_task() -> dict[TelemetryCase, TraceTaskName]: + global _case_to_trace_task + if _case_to_trace_task is None: + from enterprise.telemetry.contracts import TelemetryCase + + _case_to_trace_task = { + TelemetryCase.WORKFLOW_RUN: TraceTaskName.WORKFLOW_TRACE, + TelemetryCase.MESSAGE_RUN: TraceTaskName.MESSAGE_TRACE, + TelemetryCase.NODE_EXECUTION: TraceTaskName.NODE_EXECUTION_TRACE, + TelemetryCase.DRAFT_NODE_EXECUTION: TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + TelemetryCase.PROMPT_GENERATION: TraceTaskName.PROMPT_GENERATION_TRACE, + TelemetryCase.TOOL_EXECUTION: TraceTaskName.TOOL_TRACE, + TelemetryCase.MODERATION_CHECK: TraceTaskName.MODERATION_TRACE, + TelemetryCase.SUGGESTED_QUESTION: TraceTaskName.SUGGESTED_QUESTION_TRACE, + TelemetryCase.DATASET_RETRIEVAL: TraceTaskName.DATASET_RETRIEVAL_TRACE, + TelemetryCase.GENERATE_NAME: TraceTaskName.GENERATE_NAME_TRACE, + } + return _case_to_trace_task + + +def get_trace_task_to_case() -> dict[TraceTaskName, TelemetryCase]: + """Return TraceTaskName → TelemetryCase (inverse of _get_case_to_trace_task).""" + return {v: k for k, v in _get_case_to_trace_task().items()} + + +def _get_case_routing() -> dict[TelemetryCase, CaseRoute]: + global _case_routing + if _case_routing is None: + from enterprise.telemetry.contracts import CaseRoute, SignalType, TelemetryCase + + _case_routing = { + # TRACE — CE-eligible (flow in both CE and EE) + TelemetryCase.WORKFLOW_RUN: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True), + TelemetryCase.MESSAGE_RUN: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True), + TelemetryCase.TOOL_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True), + TelemetryCase.MODERATION_CHECK: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True), + TelemetryCase.SUGGESTED_QUESTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True), + TelemetryCase.DATASET_RETRIEVAL: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True), + TelemetryCase.GENERATE_NAME: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True), + # TRACE — enterprise-only + TelemetryCase.NODE_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False), + TelemetryCase.DRAFT_NODE_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False), + TelemetryCase.PROMPT_GENERATION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False), + # METRIC_LOG — enterprise-only (signal-driven, not trace) + TelemetryCase.APP_CREATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False), + TelemetryCase.APP_UPDATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False), + TelemetryCase.APP_DELETED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False), + TelemetryCase.FEEDBACK_CREATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False), + } + return _case_routing + + +def __getattr__(name: str) -> dict: + """Lazy module-level access to routing tables.""" + if name == "CASE_ROUTING": + return _get_case_routing() + if name == "CASE_TO_TRACE_TASK": + return _get_case_to_trace_task() + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def is_enterprise_telemetry_enabled() -> bool: + try: + from enterprise.telemetry.exporter import is_enterprise_telemetry_enabled + + return is_enterprise_telemetry_enabled() + except Exception: + return False + + +def _handle_payload_sizing( + payload: dict[str, Any], + tenant_id: str, + event_id: str, +) -> tuple[dict[str, Any], str | None]: + """Inline or offload payload based on size. + + Returns ``(payload_for_envelope, storage_key | None)``. Payloads + exceeding ``PAYLOAD_SIZE_THRESHOLD_BYTES`` are written to object + storage and replaced with an empty dict in the envelope. + """ + try: + payload_json = json.dumps(payload) + payload_size = len(payload_json.encode("utf-8")) + except (TypeError, ValueError): + logger.warning("Failed to serialize payload for sizing: event_id=%s", event_id) + return payload, None + + if payload_size <= PAYLOAD_SIZE_THRESHOLD_BYTES: + return payload, None + + storage_key = f"telemetry/{tenant_id}/{event_id}.json" + try: + storage.save(storage_key, payload_json.encode("utf-8")) + logger.debug("Stored large payload to storage: key=%s, size=%d", storage_key, payload_size) + return {}, storage_key + except Exception: + logger.warning("Failed to store large payload, inlining instead: event_id=%s", event_id, exc_info=True) + return payload, None + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def emit( + case: TelemetryCase, + context: dict[str, Any], + payload: dict[str, Any], + trace_manager: TraceQueueManager | None = None, +) -> None: + """Route a telemetry event to the correct pipeline. + + TRACE events are enqueued into ``TraceQueueManager`` (works in both CE + and EE). Enterprise-only traces are silently dropped when EE is + disabled. + + METRIC_LOG events are dispatched to the enterprise Celery queue; + silently dropped when enterprise telemetry is unavailable. + """ + route = _get_case_routing().get(case) + if route is None: + logger.warning("Unknown telemetry case: %s, dropping event", case) + return + + if not route.ce_eligible and not is_enterprise_telemetry_enabled(): + logger.debug("Dropping EE-only event: case=%s (EE disabled)", case) + return + + if route.signal_type == SignalType.TRACE: + _emit_trace(case, context, payload, trace_manager) + else: + _emit_metric_log(case, context, payload) + + +def _emit_trace( + case: TelemetryCase, + context: dict[str, Any], + payload: dict[str, Any], + trace_manager: TraceQueueManager | None, +) -> None: + from core.ops.ops_trace_manager import TraceQueueManager as LocalTraceQueueManager + from core.ops.ops_trace_manager import TraceTask + + trace_task_name = _get_case_to_trace_task().get(case) + if trace_task_name is None: + logger.warning("No TraceTaskName mapping for case: %s", case) + return + + queue_manager = trace_manager or LocalTraceQueueManager( + app_id=context.get("app_id"), + user_id=context.get("user_id"), + ) + queue_manager.add_trace_task(TraceTask(trace_task_name, user_id=context.get("user_id"), **payload)) + logger.debug("Enqueued trace task: case=%s, app_id=%s", case, context.get("app_id")) + + +def _emit_metric_log( + case: TelemetryCase, + context: dict[str, Any], + payload: dict[str, Any], +) -> None: + """Build envelope and dispatch to enterprise Celery queue. + + No-ops when the enterprise telemetry task is not importable (CE mode). + """ + try: + from tasks.enterprise_telemetry_task import process_enterprise_telemetry + except ImportError: + logger.debug("Enterprise metric/log dispatch unavailable, dropping: case=%s", case) + return + + tenant_id = context.get("tenant_id") or "" + event_id = str(uuid.uuid4()) + + payload_for_envelope, payload_ref = _handle_payload_sizing(payload, tenant_id, event_id) + + from enterprise.telemetry.contracts import TelemetryEnvelope + + envelope = TelemetryEnvelope( + case=case, + tenant_id=tenant_id, + event_id=event_id, + payload=payload_for_envelope, + metadata={"payload_ref": payload_ref} if payload_ref else None, + ) + + process_enterprise_telemetry.delay(envelope.model_dump_json()) + logger.debug( + "Enqueued metric/log event: case=%s, tenant_id=%s, event_id=%s", + case, + tenant_id, + event_id, + ) diff --git a/api/core/tools/__base/tool_runtime.py b/api/core/tools/__base/tool_runtime.py index 961d13f90a0..5154bc9805e 100644 --- a/api/core/tools/__base/tool_runtime.py +++ b/api/core/tools/__base/tool_runtime.py @@ -9,10 +9,14 @@ from core.tools.entities.tool_entities import ToolInvokeFrom class ToolRuntime(BaseModel): """ - Meta data of a tool call processing + Meta data of a tool call processing. + + ``user_id`` is optional so read-only tooling flows can stay tenant-scoped, + while execution paths may bind caller identity for model runtime lookups. """ tenant_id: str + user_id: str | None = None tool_id: str | None = None invoke_from: InvokeFrom | None = None tool_invoke_from: ToolInvokeFrom | None = None diff --git a/api/core/tools/builtin_tool/providers/audio/tools/asr.py b/api/core/tools/builtin_tool/providers/audio/tools/asr.py index dacc49c7464..e5390743036 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/asr.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/asr.py @@ -2,14 +2,15 @@ import io from collections.abc import Generator from typing import Any +from graphon.file import FileType +from graphon.file.file_manager import download +from graphon.model_runtime.entities.model_entities import ModelType + from core.model_manager import ModelManager from core.plugin.entities.parameters import PluginParameterOption from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter -from dify_graph.file.enums import FileType -from dify_graph.file.file_manager import download -from dify_graph.model_runtime.entities.model_entities import ModelType from services.model_provider_service import ModelProviderService @@ -22,6 +23,9 @@ class ASRTool(BuiltinTool): app_id: str | None = None, message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: + if not self.runtime: + raise ValueError("Runtime is required") + runtime = self.runtime file = tool_parameters.get("audio_file") if file.type != FileType.AUDIO: # type: ignore yield self.create_text_message("not a valid audio file") @@ -29,20 +33,19 @@ class ASRTool(BuiltinTool): audio_binary = io.BytesIO(download(file)) # type: ignore audio_binary.name = "temp.mp3" provider, model = tool_parameters.get("model").split("#") # type: ignore - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=runtime.tenant_id, user_id=user_id) model_instance = model_manager.get_model_instance( - tenant_id=self.runtime.tenant_id, + tenant_id=runtime.tenant_id, provider=provider, model_type=ModelType.SPEECH2TEXT, model=model, ) - text = model_instance.invoke_speech2text( - file=audio_binary, - user=user_id, - ) + text = model_instance.invoke_speech2text(file=audio_binary) yield self.create_text_message(text) def get_available_models(self) -> list[tuple[str, str]]: + if not self.runtime: + raise ValueError("Runtime is required") model_provider_service = ModelProviderService() models = model_provider_service.get_models_by_model_type( tenant_id=self.runtime.tenant_id, model_type="speech2text" diff --git a/api/core/tools/builtin_tool/providers/audio/tools/tts.py b/api/core/tools/builtin_tool/providers/audio/tools/tts.py index 7818bff0ab0..f49c669fe09 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/tts.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/tts.py @@ -2,12 +2,13 @@ import io from collections.abc import Generator from typing import Any +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType + from core.model_manager import ModelManager from core.plugin.entities.parameters import PluginParameterOption from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from services.model_provider_service import ModelProviderService @@ -20,13 +21,14 @@ class TTSTool(BuiltinTool): app_id: str | None = None, message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: - provider, model = tool_parameters.get("model").split("#") # type: ignore - voice = tool_parameters.get(f"voice#{provider}#{model}") - model_manager = ModelManager() if not self.runtime: raise ValueError("Runtime is required") + runtime = self.runtime + provider, model = tool_parameters.get("model").split("#") # type: ignore + voice = tool_parameters.get(f"voice#{provider}#{model}") + model_manager = ModelManager.for_tenant(tenant_id=runtime.tenant_id, user_id=user_id) model_instance = model_manager.get_model_instance( - tenant_id=self.runtime.tenant_id or "", + tenant_id=runtime.tenant_id or "", provider=provider, model_type=ModelType.TTS, model=model, @@ -39,12 +41,7 @@ class TTSTool(BuiltinTool): raise ValueError("Sorry, no voice available.") else: raise ValueError("Sorry, no voice available.") - tts = model_instance.invoke_tts( - content_text=tool_parameters.get("text"), # type: ignore - user=user_id, - tenant_id=self.runtime.tenant_id, - voice=voice, - ) + tts = model_instance.invoke_tts(content_text=tool_parameters.get("text"), voice=voice) # type: ignore[arg-type] buffer = io.BytesIO() for chunk in tts: buffer.write(chunk) diff --git a/api/core/tools/builtin_tool/providers/time/tools/current_time.py b/api/core/tools/builtin_tool/providers/time/tools/current_time.py index 44f94c27235..e07ca0d9199 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/current_time.py +++ b/api/core/tools/builtin_tool/providers/time/tools/current_time.py @@ -2,7 +2,7 @@ from collections.abc import Generator from datetime import UTC, datetime from typing import Any -from pytz import timezone as pytz_timezone +from pytz import timezone as pytz_timezone # type: ignore[import-untyped] from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py index d0a41b940f6..dc49b64dd82 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py +++ b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py @@ -2,7 +2,7 @@ from collections.abc import Generator from datetime import datetime from typing import Any -import pytz +import pytz # type: ignore[import-untyped] from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py b/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py index 462e4be5ce1..8045e4b980a 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py +++ b/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py @@ -2,7 +2,7 @@ from collections.abc import Generator from datetime import datetime from typing import Any -import pytz +import pytz # type: ignore[import-untyped] from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py b/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py index e23ae3b0019..e2570811d6b 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py +++ b/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py @@ -2,7 +2,7 @@ from collections.abc import Generator from datetime import datetime from typing import Any -import pytz +import pytz # type: ignore[import-untyped] from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/builtin_tool/tool.py b/api/core/tools/builtin_tool/tool.py index 00f59310886..14af63a962e 100644 --- a/api/core/tools/builtin_tool/tool.py +++ b/api/core/tools/builtin_tool/tool.py @@ -1,11 +1,12 @@ from __future__ import annotations +from graphon.model_runtime.entities.llm_entities import LLMResult +from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage + from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ToolProviderType from core.tools.utils.model_invocation_utils import ModelInvocationUtils -from dify_graph.model_runtime.entities.llm_entities import LLMResult -from dify_graph.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage _SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language and you can quickly aimed at the main point of an webpage and reproduce it in your own words but @@ -50,9 +51,10 @@ class BuiltinTool(Tool): return ModelInvocationUtils.invoke( user_id=user_id, tenant_id=self.runtime.tenant_id or "", - tool_type="builtin", + tool_type=ToolProviderType.BUILT_IN, tool_name=self.entity.identity.name, prompt_messages=prompt_messages, + caller_user_id=self.runtime.user_id, ) def tool_provider_type(self) -> ToolProviderType: @@ -69,6 +71,7 @@ class BuiltinTool(Tool): return ModelInvocationUtils.get_max_llm_context_tokens( tenant_id=self.runtime.tenant_id or "", + user_id=self.runtime.user_id, ) def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int: @@ -82,7 +85,9 @@ class BuiltinTool(Tool): raise ValueError("runtime is required") return ModelInvocationUtils.calculate_tokens( - tenant_id=self.runtime.tenant_id or "", prompt_messages=prompt_messages + tenant_id=self.runtime.tenant_id or "", + prompt_messages=prompt_messages, + user_id=self.runtime.user_id, ) def summary(self, user_id: str, content: str) -> str: diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py index c6a84e27c61..0a2c37c5632 100644 --- a/api/core/tools/custom_tool/tool.py +++ b/api/core/tools/custom_tool/tool.py @@ -6,6 +6,7 @@ from typing import Any, Union from urllib.parse import urlencode import httpx +from graphon.file.file_manager import download from core.helper import ssrf_proxy from core.tools.__base.tool import Tool @@ -13,7 +14,6 @@ from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError -from dify_graph.file.file_manager import download API_TOOL_DEFAULT_TIMEOUT = ( int(getenv("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10")), diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 2545290b57d..d5d3d1b1d95 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -2,6 +2,7 @@ from collections.abc import Mapping from datetime import datetime from typing import Any, Literal +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration @@ -9,7 +10,6 @@ from core.plugin.entities.plugin_daemon import CredentialType from core.tools.__base.tool import ToolParameter from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType -from dify_graph.model_runtime.utils.encoders import jsonable_encoder class ToolApiEntity(BaseModel): diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index 9025ff6ef16..f6d09472b3c 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -6,6 +6,8 @@ import logging from collections.abc import Generator, Mapping from typing import Any, cast +from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata + from core.mcp.auth_client import MCPClientWithAuthRetry from core.mcp.error import MCPConnectionError from core.mcp.types import ( @@ -21,7 +23,6 @@ from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType from core.tools.errors import ToolInvokeError -from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata logger = logging.getLogger(__name__) diff --git a/api/core/tools/signature.py b/api/core/tools/signature.py index 22e099debad..1807226924d 100644 --- a/api/core/tools/signature.py +++ b/api/core/tools/signature.py @@ -3,6 +3,7 @@ import hashlib import hmac import os import time +import urllib.parse from configs import dify_config @@ -58,3 +59,43 @@ def verify_tool_file_signature(file_id: str, timestamp: str, nonce: str, sign: s current_time = int(time.time()) return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT + + +def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str, user_id: str) -> str: + """Build the signed upload URL used by the plugin-facing file upload endpoint.""" + + base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL + upload_url = f"{base_url}/files/upload/for-plugin" + timestamp = str(int(time.time())) + nonce = os.urandom(16).hex() + data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + query = urllib.parse.urlencode( + { + "timestamp": timestamp, + "nonce": nonce, + "sign": encoded_sign, + "user_id": user_id, + "tenant_id": tenant_id, + } + ) + return f"{upload_url}?{query}" + + +def verify_plugin_file_signature( + *, filename: str, mimetype: str, tenant_id: str, user_id: str, timestamp: str, nonce: str, sign: str +) -> bool: + """Verify the signature used by the plugin-facing file upload endpoint.""" + + data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() + + if sign != recalculated_encoded_sign: + return False + + current_time = int(time.time()) + return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 64212a26360..685d687d8c4 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -7,6 +7,7 @@ from datetime import UTC, datetime from mimetypes import guess_type from typing import Any, Union, cast +from graphon.file import FileTransferMethod, FileType from yarl import URL from core.app.entities.app_invoke_entities import InvokeFrom @@ -31,8 +32,6 @@ from core.tools.errors import ( ) from core.tools.utils.message_transformer import ToolFileMessageTransformer, safe_json_value from core.tools.workflow_as_tool.tool import WorkflowTool -from dify_graph.file import FileType -from dify_graph.file.models import FileTransferMethod from extensions.ext_database import db from models.enums import CreatorUserRole, MessageFileBelongsTo from models.model import Message, MessageFile diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index 210f488afcb..7ac29cf0698 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -10,11 +10,12 @@ from typing import Union from uuid import uuid4 import httpx +from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type from configs import dify_config from core.db.session_factory import session_factory from core.helper import ssrf_proxy -from dify_graph.file.models import ToolFile as ToolFilePydanticModel +from core.workflow.file_reference import build_file_reference from extensions.ext_storage import storage from models.model import MessageFile from models.tools import ToolFile @@ -23,6 +24,21 @@ logger = logging.getLogger(__name__) class ToolFileManager: + @staticmethod + def _build_graph_file_reference(tool_file: ToolFile) -> File: + extension = guess_extension(tool_file.mimetype) or ".bin" + return File( + type=get_file_type_by_mime_type(tool_file.mimetype), + transfer_method=FileTransferMethod.TOOL_FILE, + remote_url=tool_file.original_url, + reference=build_file_reference(record_id=str(tool_file.id)), + filename=tool_file.name, + extension=extension, + mime_type=tool_file.mimetype, + size=tool_file.size, + storage_key=tool_file.file_key, + ) + @staticmethod def sign_file(tool_file_id: str, extension: str) -> str: """ @@ -209,9 +225,7 @@ class ToolFileManager: return blob, tool_file.mimetype - def get_file_generator_by_tool_file_id( - self, tool_file_id: str - ) -> tuple[Generator | None, ToolFilePydanticModel | None]: + def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, File | None]: """ get file binary @@ -233,11 +247,11 @@ class ToolFileManager: stream = storage.load_stream(tool_file.file_key) - return stream, ToolFilePydanticModel.model_validate(tool_file) + return stream, self._build_graph_file_reference(tool_file) # init tool_file_parser -from dify_graph.file.tool_file_parser import set_tool_file_manager_factory +from graphon.file.tool_file_parser import set_tool_file_manager_factory def _factory() -> ToolFileManager: diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py index 90d5a647e92..58190d10894 100644 --- a/api/core/tools/tool_label_manager.py +++ b/api/core/tools/tool_label_manager.py @@ -1,4 +1,4 @@ -from sqlalchemy import select +from sqlalchemy import delete, select from core.tools.__base.tool_provider import ToolProviderController from core.tools.builtin_tool.provider import BuiltinToolProviderController @@ -31,14 +31,14 @@ class ToolLabelManager: raise ValueError("Unsupported tool type") # delete old labels - db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id).delete() + db.session.execute(delete(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id)) # insert new labels for label in labels: db.session.add( ToolLabelBinding( tool_id=provider_id, - tool_type=controller.provider_type.value, + tool_type=controller.provider_type, label_name=label, ) ) @@ -58,7 +58,7 @@ class ToolLabelManager: raise ValueError("Unsupported tool type") stmt = select(ToolLabelBinding.label_name).where( ToolLabelBinding.tool_id == provider_id, - ToolLabelBinding.tool_type == controller.provider_type.value, + ToolLabelBinding.tool_type == controller.provider_type, ) labels = db.session.scalars(stmt).all() diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 23a877b7e39..a58d3103137 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -5,9 +5,10 @@ import time from collections.abc import Generator, Mapping from os import listdir, path from threading import Lock -from typing import TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union, cast +from typing import TYPE_CHECKING, Any, Literal, Optional, Protocol, TypedDict, Union, cast import sqlalchemy as sa +from graphon.runtime import VariablePool from sqlalchemy import select from sqlalchemy.orm import Session from yarl import URL @@ -24,20 +25,20 @@ from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.plugin_tool.tool import PluginTool from core.tools.utils.uuid_utils import is_valid_uuid from core.tools.workflow_as_tool.provider import WorkflowToolProviderController -from dify_graph.runtime.variable_pool import VariablePool from extensions.ext_database import db from models.provider_ids import ToolProviderID from services.enterprise.plugin_manager_service import PluginCredentialType from services.tools.mcp_tools_manage_service import MCPToolManageService if TYPE_CHECKING: - from dify_graph.nodes.tool.entities import ToolEntity + pass + +from graphon.model_runtime.utils.encoders import jsonable_encoder from core.agent.entities import AgentToolEntity from core.app.entities.app_invoke_entities import InvokeFrom from core.helper.module_import_helper import load_single_subclass_from_source from core.helper.position_helper import is_filtered -from core.plugin.entities.plugin_daemon import CredentialType from core.tools.__base.tool import Tool from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort @@ -57,12 +58,11 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter from core.tools.workflow_as_tool.tool import WorkflowTool -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider from services.tools.tools_transform_service import ToolTransformService if TYPE_CHECKING: - from dify_graph.nodes.tool.entities import ToolEntity + pass logger = logging.getLogger(__name__) @@ -77,6 +77,23 @@ class EmojiIconDict(TypedDict): content: str +class WorkflowToolRuntimeSpec(Protocol): + @property + def provider_type(self) -> ToolProviderType: ... + + @property + def provider_id(self) -> str: ... + + @property + def tool_name(self) -> str: ... + + @property + def tool_configurations(self) -> Mapping[str, Any]: ... + + @property + def credential_id(self) -> str | None: ... + + class ToolManager: _builtin_provider_lock = Lock() _hardcoded_providers: dict[str, BuiltinToolProviderController] = {} @@ -167,6 +184,7 @@ class ToolManager: provider_id: str, tool_name: str, tenant_id: str, + user_id: str | None = None, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, credential_id: str | None = None, @@ -178,6 +196,7 @@ class ToolManager: :param provider_id: the id of the provider :param tool_name: the name of the tool :param tenant_id: the tenant id + :param user_id: the caller id bound to runtime-scoped model/tool lookups :param invoke_from: invoke from :param tool_invoke_from: the tool invoke from :param credential_id: the credential id @@ -196,6 +215,7 @@ class ToolManager: return builtin_tool.fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, + user_id=user_id, credentials={}, invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, @@ -235,11 +255,11 @@ class ToolManager: if builtin_provider is None: raise ToolProviderNotFoundError(f"no default provider for {provider_id}") else: - builtin_provider = ( - db.session.query(BuiltinToolProvider) + builtin_provider = db.session.scalar( + select(BuiltinToolProvider) .where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)) .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) - .first() + .limit(1) ) if builtin_provider is None: @@ -304,8 +324,9 @@ class ToolManager: return builtin_tool.fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, + user_id=user_id, credentials=dict(decrypted_credentials), - credential_type=CredentialType.of(builtin_provider.credential_type), + credential_type=builtin_provider.credential_type, runtime_parameters={}, invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, @@ -321,6 +342,7 @@ class ToolManager: return api_provider.get_tool(tool_name).fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, + user_id=user_id, credentials=dict(encrypter.decrypt(credentials)), invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, @@ -344,6 +366,7 @@ class ToolManager: return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, + user_id=user_id, credentials={}, invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, @@ -352,9 +375,21 @@ class ToolManager: elif provider_type == ToolProviderType.APP: raise NotImplementedError("app provider not implemented") elif provider_type == ToolProviderType.PLUGIN: - return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name) + plugin_tool = cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name) + runtime = getattr(plugin_tool, "runtime", None) + if runtime is not None: + runtime.user_id = user_id + runtime.invoke_from = invoke_from + runtime.tool_invoke_from = tool_invoke_from + return plugin_tool elif provider_type == ToolProviderType.MCP: - return cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name) + mcp_tool = cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name) + runtime = getattr(mcp_tool, "runtime", None) + if runtime is not None: + runtime.user_id = user_id + runtime.invoke_from = invoke_from + runtime.tool_invoke_from = tool_invoke_from + return mcp_tool else: raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found") @@ -364,6 +399,7 @@ class ToolManager: tenant_id: str, app_id: str, agent_tool: AgentToolEntity, + user_id: str | None = None, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, variable_pool: Optional["VariablePool"] = None, ) -> Tool: @@ -375,6 +411,7 @@ class ToolManager: provider_id=agent_tool.provider_id, tool_name=agent_tool.tool_name, tenant_id=tenant_id, + user_id=user_id, invoke_from=invoke_from, tool_invoke_from=ToolInvokeFrom.AGENT, credential_id=agent_tool.credential_id, @@ -405,7 +442,8 @@ class ToolManager: tenant_id: str, app_id: str, node_id: str, - workflow_tool: "ToolEntity", + workflow_tool: WorkflowToolRuntimeSpec, + user_id: str | None = None, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, variable_pool: Optional["VariablePool"] = None, ) -> Tool: @@ -418,6 +456,7 @@ class ToolManager: provider_id=workflow_tool.provider_id, tool_name=workflow_tool.tool_name, tenant_id=tenant_id, + user_id=user_id, invoke_from=invoke_from, tool_invoke_from=ToolInvokeFrom.WORKFLOW, credential_id=workflow_tool.credential_id, @@ -450,6 +489,7 @@ class ToolManager: provider: str, tool_name: str, tool_parameters: dict[str, Any], + user_id: str | None = None, credential_id: str | None = None, ) -> Tool: """ @@ -460,6 +500,7 @@ class ToolManager: provider_id=provider, tool_name=tool_name, tenant_id=tenant_id, + user_id=user_id, invoke_from=InvokeFrom.SERVICE_API, tool_invoke_from=ToolInvokeFrom.PLUGIN, credential_id=credential_id, @@ -777,13 +818,13 @@ class ToolManager: :return: the provider controller, the credentials """ - provider: ApiToolProvider | None = ( - db.session.query(ApiToolProvider) + provider: ApiToolProvider | None = db.session.scalar( + select(ApiToolProvider) .where( ApiToolProvider.id == provider_id, ApiToolProvider.tenant_id == tenant_id, ) - .first() + .limit(1) ) if provider is None: @@ -831,13 +872,13 @@ class ToolManager: get api provider """ provider_name = provider - provider_obj: ApiToolProvider | None = ( - db.session.query(ApiToolProvider) + provider_obj: ApiToolProvider | None = db.session.scalar( + select(ApiToolProvider) .where( ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == provider, ) - .first() + .limit(1) ) if provider_obj is None: @@ -923,10 +964,10 @@ class ToolManager: @classmethod def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict: try: - workflow_provider: WorkflowToolProvider | None = ( - db.session.query(WorkflowToolProvider) + workflow_provider: WorkflowToolProvider | None = db.session.scalar( + select(WorkflowToolProvider) .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) - .first() + .limit(1) ) if workflow_provider is None: @@ -940,10 +981,10 @@ class ToolManager: @classmethod def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict: try: - api_provider: ApiToolProvider | None = ( - db.session.query(ApiToolProvider) + api_provider: ApiToolProvider | None = db.session.scalar( + select(ApiToolProvider) .where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id) - .first() + .limit(1) ) if api_provider is None: @@ -1015,14 +1056,14 @@ class ToolManager: cls, parameters: list[ToolParameter], variable_pool: Optional["VariablePool"], - tool_configurations: dict[str, Any], + tool_configurations: Mapping[str, Any], typ: Literal["agent", "workflow", "tool"] = "workflow", ) -> dict[str, Any]: """ Convert tool parameters type """ - from dify_graph.nodes.tool.entities import ToolNodeData - from dify_graph.nodes.tool.exc import ToolParameterError + from graphon.nodes.tool.entities import ToolNodeData + from graphon.nodes.tool.exc import ToolParameterError runtime_parameters = {} for parameter in parameters: diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index c2b520fa991..e63435db988 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -1,6 +1,7 @@ import threading from flask import Flask, current_app +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field from sqlalchemy import select @@ -8,12 +9,12 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa from core.model_manager import ModelManager from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.citation_metadata import RetrievalSourceMetadata +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document as RagDocument from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from core.tools.utils.dataset_retriever.dataset_retriever_tool import DefaultRetrievalModelDict -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment @@ -65,7 +66,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): for thread in threads: thread.join() # do rerank for searched documents - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id) rerank_model_instance = model_manager.get_model_instance( tenant_id=self.tenant_id, provider=self.reranking_provider_name, @@ -109,7 +110,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): context_list: list[RetrievalSourceMetadata] = [] resource_number = 1 for segment in sorted_segments: - dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() + dataset = db.session.get(Dataset, segment.dataset_id) document_stmt = select(Document).where( Document.id == segment.document_id, Document.enabled == True, @@ -169,7 +170,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): # get retrieval model , if the model is not setting , using default retrieval_model = dataset.retrieval_model or default_retrieval_model - if dataset.indexing_technique == "economy": + if dataset.indexing_technique == IndexTechniqueType.ECONOMY: # use keyword table query documents = RetrievalService.retrieve( retrieval_method=RetrievalMethod.KEYWORD_SEARCH, diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index 429b7e66227..cbd8bdb36cf 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -8,6 +8,7 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict, from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities.context_entities import DocumentContext +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document as RetrievalDocument from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -140,7 +141,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): # get retrieval model , if the model is not setting , using default retrieval_model = dataset.retrieval_model or default_retrieval_model retrieval_resource_list: list[RetrievalSourceMetadata] = [] - if dataset.indexing_technique == "economy": + if dataset.indexing_technique == IndexTechniqueType.ECONOMY: # use keyword table query documents = RetrievalService.retrieve( retrieval_method=RetrievalMethod.KEYWORD_SEARCH, @@ -173,7 +174,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): for hit_callback in self.hit_callbacks: hit_callback.on_tool_end(documents) document_score_list = {} - if dataset.indexing_technique != "economy": + if dataset.indexing_technique != IndexTechniqueType.ECONOMY: for item in documents: if item.metadata is not None and item.metadata.get("score"): document_score_list[item.metadata["doc_id"]] = item.metadata["score"] @@ -204,7 +205,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): if self.return_resource: for record in records: segment = record.segment - dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() + dataset = db.session.get(Dataset, segment.dataset_id) dataset_document_stmt = select(DatasetDocument).where( DatasetDocument.id == segment.document_id, DatasetDocument.enabled == True, diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 6fc5fead2de..bb5b3ba76e9 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -1,4 +1,5 @@ import logging +import re from collections.abc import Generator from datetime import date, datetime from decimal import Decimal @@ -7,15 +8,18 @@ from uuid import UUID import numpy as np import pytz +from graphon.file import File, FileTransferMethod, FileType from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool_file_manager import ToolFileManager -from dify_graph.file import File, FileTransferMethod, FileType +from core.workflow.file_reference import parse_file_reference from libs.login import current_user from models import Account logger = logging.getLogger(__name__) +_TOOL_FILE_URL_PATTERN = re.compile(r"(?:^|/+)files/tools/(?P[^/?#.]+)") + def safe_json_value(v): if isinstance(v, datetime): @@ -82,11 +86,15 @@ class ToolFileMessageTransformer: ) url = f"/files/tools/{tool_file.id}{guess_extension(tool_file.mimetype) or '.png'}" + meta = cls._with_tool_file_meta( + message.meta, + tool_file_id=str(tool_file.id), + ) yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=ToolInvokeMessage.TextMessage(text=url), - meta=message.meta.copy() if message.meta is not None else {}, + meta=meta, ) except Exception as e: yield ToolInvokeMessage( @@ -122,38 +130,45 @@ class ToolFileMessageTransformer: ) url = cls.get_tool_file_url(tool_file_id=tool_file.id, extension=guess_extension(tool_file.mimetype)) + meta = cls._with_tool_file_meta(meta, tool_file_id=str(tool_file.id)) # check if file is image if "image" in mimetype: yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=ToolInvokeMessage.TextMessage(text=url), - meta=meta.copy() if meta is not None else {}, + meta=meta, ) else: yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.BINARY_LINK, message=ToolInvokeMessage.TextMessage(text=url), - meta=meta.copy() if meta is not None else {}, + meta=meta, ) elif message.type == ToolInvokeMessage.MessageType.FILE: meta = message.meta or {} file = meta.get("file", None) if isinstance(file, File): if file.transfer_method == FileTransferMethod.TOOL_FILE: - assert file.related_id is not None - url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension) + parsed_reference = parse_file_reference(file.reference) + if parsed_reference is None: + raise ValueError("tool file is missing reference") + url = cls.get_tool_file_url( + tool_file_id=parsed_reference.record_id, + extension=file.extension, + ) + tool_file_meta = cls._with_tool_file_meta(meta, tool_file_id=parsed_reference.record_id) if file.type == FileType.IMAGE: yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=ToolInvokeMessage.TextMessage(text=url), - meta=meta.copy() if meta is not None else {}, + meta=tool_file_meta, ) else: yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text=url), - meta=meta.copy() if meta is not None else {}, + meta=tool_file_meta, ) else: yield message @@ -162,9 +177,40 @@ class ToolFileMessageTransformer: if isinstance(message.message, ToolInvokeMessage.JsonMessage): message.message.json_object = safe_json_value(message.message.json_object) yield message + elif message.type in { + ToolInvokeMessage.MessageType.IMAGE_LINK, + ToolInvokeMessage.MessageType.BINARY_LINK, + } and isinstance(message.message, ToolInvokeMessage.TextMessage): + yield ToolInvokeMessage( + type=message.type, + message=message.message, + meta=cls._with_tool_file_meta(message.meta, url=message.message.text), + ) else: yield message @classmethod def get_tool_file_url(cls, tool_file_id: str, extension: str | None) -> str: return f"/files/tools/{tool_file_id}{extension or '.bin'}" + + @staticmethod + def _with_tool_file_meta( + meta: dict | None, + *, + tool_file_id: str | None = None, + url: str | None = None, + ) -> dict: + normalized_meta = meta.copy() if meta is not None else {} + resolved_tool_file_id = tool_file_id or ToolFileMessageTransformer._extract_tool_file_id(url) + if resolved_tool_file_id and "tool_file_id" not in normalized_meta: + normalized_meta["tool_file_id"] = resolved_tool_file_id + return normalized_meta + + @staticmethod + def _extract_tool_file_id(url: str | None) -> str | None: + if not url: + return None + match = _TOOL_FILE_URL_PATTERN.search(url) + if match is None: + return None + return match.group("tool_file_id") diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index 8f958563bd3..8d6f83dc07c 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -8,19 +8,21 @@ import json from decimal import Decimal from typing import cast -from core.model_manager import ModelManager -from dify_graph.model_runtime.entities.llm_entities import LLMResult -from dify_graph.model_runtime.entities.message_entities import PromptMessage -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from dify_graph.model_runtime.errors.invoke import ( +from graphon.model_runtime.entities.llm_entities import LLMResult +from graphon.model_runtime.entities.message_entities import PromptMessage +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from graphon.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, InvokeRateLimitError, InvokeServerUnavailableError, ) -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.utils.encoders import jsonable_encoder + +from core.model_manager import ModelManager +from core.tools.entities.tool_entities import ToolProviderType from extensions.ext_database import db from models.tools import ToolModelInvoke @@ -33,11 +35,12 @@ class ModelInvocationUtils: @staticmethod def get_max_llm_context_tokens( tenant_id: str, + user_id: str | None = None, ) -> int: """ get max llm context tokens of the model """ - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id) model_instance = model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -59,13 +62,13 @@ class ModelInvocationUtils: return max_tokens @staticmethod - def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage]) -> int: + def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage], user_id: str | None = None) -> int: """ calculate tokens from prompt messages and model parameters """ # get model instance - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id) model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM) if not model_instance: @@ -78,7 +81,12 @@ class ModelInvocationUtils: @staticmethod def invoke( - user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage] + user_id: str, + tenant_id: str, + tool_type: ToolProviderType, + tool_name: str, + prompt_messages: list[PromptMessage], + caller_user_id: str | None = None, ) -> LLMResult: """ invoke model with parameters in user's own context @@ -92,7 +100,7 @@ class ModelInvocationUtils: """ # get model manager - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=caller_user_id or user_id) # get model instance model_instance = model_manager.get_default_model_instance( tenant_id=tenant_id, @@ -136,7 +144,6 @@ class ModelInvocationUtils: tools=[], stop=[], stream=False, - user=user_id, callbacks=[], ) except InvokeRateLimitError as e: diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py index 28f13766555..c4b7d574493 100644 --- a/api/core/tools/utils/workflow_configuration_sync.py +++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -1,11 +1,12 @@ from collections.abc import Mapping, Sequence from typing import Any +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.base.entities import OutputVariableEntity +from graphon.variables.input_entities import VariableEntity + from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration from core.tools.errors import WorkflowToolHumanInputNotSupportedError -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.nodes.base.entities import OutputVariableEntity -from dify_graph.variables.input_entities import VariableEntity class WorkflowToolConfigurationUtils: diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index aef8b3f7798..f48b24be30e 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -2,6 +2,7 @@ from __future__ import annotations from collections.abc import Mapping +from graphon.variables.input_entities import VariableEntity, VariableEntityType from pydantic import Field from sqlalchemy.orm import Session @@ -22,7 +23,6 @@ from core.tools.entities.tool_entities import ( ) from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from core.tools.workflow_as_tool.tool import WorkflowTool -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType from extensions.ext_database import db from models.account import Account from models.model import App, AppMode diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 9b9aa7a7419..a3fb4eda928 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -5,8 +5,11 @@ import logging from collections.abc import Generator, Mapping, Sequence from typing import Any, cast +from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod +from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from sqlalchemy import select +from core.app.file_access import DatabaseFileAccessController from core.db.session_factory import session_factory from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime @@ -17,14 +20,15 @@ from core.tools.entities.tool_entities import ( ToolProviderType, ) from core.tools.errors import ToolInvokeError -from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod -from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata +from core.workflow.file_reference import resolve_file_record_id from factories.file_factory import build_from_mapping from models import Account, Tenant from models.model import App, EndUser +from models.utils.file_input_compat import build_file_from_stored_mapping from models.workflow import Workflow logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() class WorkflowTool(Tool): @@ -288,16 +292,25 @@ class WorkflowTool(Tool): file = tool_parameters.get(parameter.name) if file: try: - file_var_list = [File.model_validate(f) for f in file] + file_var_list = [ + build_file_from_stored_mapping( + file_mapping=cast(Mapping[str, Any], f), + tenant_id=str(self.runtime.tenant_id), + ) + for f in file + if isinstance(f, Mapping) + ] for file in file_var_list: file_dict: dict[str, str | None] = { "transfer_method": file.transfer_method.value, "type": file.type.value, } if file.transfer_method == FileTransferMethod.TOOL_FILE: - file_dict["tool_file_id"] = file.related_id + file_dict["tool_file_id"] = resolve_file_record_id(file.reference) elif file.transfer_method == FileTransferMethod.LOCAL_FILE: - file_dict["upload_file_id"] = file.related_id + file_dict["upload_file_id"] = resolve_file_record_id(file.reference) + elif file.transfer_method == FileTransferMethod.DATASOURCE_FILE: + file_dict["datasource_file_id"] = resolve_file_record_id(file.reference) elif file.transfer_method == FileTransferMethod.REMOTE_URL: file_dict["url"] = file.generate_url() @@ -325,6 +338,7 @@ class WorkflowTool(Tool): file = build_from_mapping( mapping=item, tenant_id=str(self.runtime.tenant_id), + access_controller=_file_access_controller, ) files.append(file) elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: @@ -332,6 +346,7 @@ class WorkflowTool(Tool): file = build_from_mapping( mapping=value, tenant_id=str(self.runtime.tenant_id), + access_controller=_file_access_controller, ) files.append(file) @@ -340,9 +355,10 @@ class WorkflowTool(Tool): return result, files def _update_file_mapping(self, file_dict: dict): + file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id")) transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method")) if transfer_method == FileTransferMethod.TOOL_FILE: - file_dict["tool_file_id"] = file_dict.get("related_id") + file_dict["tool_file_id"] = file_id elif transfer_method == FileTransferMethod.LOCAL_FILE: - file_dict["upload_file_id"] = file_dict.get("related_id") + file_dict["upload_file_id"] = file_id return file_dict diff --git a/api/core/trigger/debug/event_selectors.py b/api/core/trigger/debug/event_selectors.py index 2a133b2b944..61d1cd85402 100644 --- a/api/core/trigger/debug/event_selectors.py +++ b/api/core/trigger/debug/event_selectors.py @@ -8,6 +8,7 @@ from collections.abc import Mapping from datetime import datetime from typing import Any +from graphon.entities.graph_config import NodeConfigDict from pydantic import BaseModel from core.plugin.entities.request import TriggerInvokeEventResponse @@ -26,7 +27,6 @@ from core.trigger.debug.events import ( ) from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig -from dify_graph.entities.graph_config import NodeConfigDict from extensions.ext_redis import redis_client from libs.datetime_utils import ensure_naive_utc, naive_utc_now from libs.schedule_utils import calculate_next_run_at diff --git a/api/core/workflow/file_reference.py b/api/core/workflow/file_reference.py new file mode 100644 index 00000000000..c80acb37830 --- /dev/null +++ b/api/core/workflow/file_reference.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import base64 +import json +from dataclasses import dataclass + +_FILE_REFERENCE_PREFIX = "dify-file-ref:" + + +@dataclass(frozen=True) +class FileReference: + record_id: str + storage_key: str | None = None + + +def build_file_reference(*, record_id: str, storage_key: str | None = None) -> str: + payload = {"record_id": record_id} + if storage_key is not None: + payload["storage_key"] = storage_key + encoded_payload = base64.urlsafe_b64encode(json.dumps(payload, separators=(",", ":")).encode()).decode() + return f"{_FILE_REFERENCE_PREFIX}{encoded_payload}" + + +def parse_file_reference(reference: str | None) -> FileReference | None: + if not reference: + return None + + if not reference.startswith(_FILE_REFERENCE_PREFIX): + return FileReference(record_id=reference) + + encoded_payload = reference.removeprefix(_FILE_REFERENCE_PREFIX) + try: + payload = json.loads(base64.urlsafe_b64decode(encoded_payload.encode())) + except (ValueError, json.JSONDecodeError): + return FileReference(record_id=reference) + + record_id = payload.get("record_id") + if not isinstance(record_id, str) or not record_id: + return FileReference(record_id=reference) + + storage_key = payload.get("storage_key") + if storage_key is not None and not isinstance(storage_key, str): + storage_key = None + + return FileReference(record_id=record_id, storage_key=storage_key) + + +def resolve_file_record_id(reference: str | None) -> str | None: + parsed_reference = parse_file_reference(reference) + if parsed_reference is None: + return None + return parsed_reference.record_id diff --git a/api/core/workflow/human_input_compat.py b/api/core/workflow/human_input_compat.py new file mode 100644 index 00000000000..c95516a240b --- /dev/null +++ b/api/core/workflow/human_input_compat.py @@ -0,0 +1,298 @@ +"""Workflow-layer adapters for legacy human-input payload keys. + +Stored workflow graphs and editor payloads may still use Dify-specific human +input recipient keys. Normalize them here before handing configs to +`graphon` so graph-owned models only see graph-neutral field names. +""" + +from __future__ import annotations + +import enum +import uuid +from collections.abc import Mapping, Sequence +from typing import Annotated, Any, ClassVar, Literal + +import bleach +import markdown +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.base.variable_template_parser import VariableTemplateParser +from graphon.runtime import VariablePool +from graphon.variables.consts import SELECTORS_LENGTH +from markdown.extensions.tables import TableExtension +from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter + + +class DeliveryMethodType(enum.StrEnum): + WEBAPP = enum.auto() + EMAIL = enum.auto() + + +class EmailRecipientType(enum.StrEnum): + BOUND = "member" + MEMBER = BOUND + EXTERNAL = "external" + + +class _InteractiveSurfaceDeliveryConfig(BaseModel): + pass + + +class BoundRecipient(BaseModel): + model_config = ConfigDict(extra="forbid") + + type: Literal[EmailRecipientType.BOUND] = EmailRecipientType.BOUND + reference_id: str + + +class ExternalRecipient(BaseModel): + model_config = ConfigDict(extra="forbid") + + type: Literal[EmailRecipientType.EXTERNAL] = EmailRecipientType.EXTERNAL + email: str + + +MemberRecipient = BoundRecipient +EmailRecipient = Annotated[BoundRecipient | ExternalRecipient, Field(discriminator="type")] + + +class EmailRecipients(BaseModel): + model_config = ConfigDict(extra="forbid") + + include_bound_group: bool = Field( + default=False, + validation_alias=AliasChoices("include_bound_group", "whole_workspace"), + ) + items: list[EmailRecipient] = Field(default_factory=list) + + +class EmailDeliveryConfig(BaseModel): + URL_PLACEHOLDER: ClassVar[str] = "{{#url#}}" + _ALLOWED_HTML_TAGS: ClassVar[list[str]] = [ + "a", + "br", + "code", + "em", + "li", + "ol", + "p", + "pre", + "strong", + "table", + "tbody", + "td", + "th", + "thead", + "tr", + "ul", + ] + _ALLOWED_HTML_ATTRIBUTES: ClassVar[dict[str, list[str]]] = { + "a": ["href", "title"], + "td": ["align"], + "th": ["align"], + } + _ALLOWED_PROTOCOLS: ClassVar[set[str]] = set(bleach.sanitizer.ALLOWED_PROTOCOLS) | {"mailto"} + + recipients: EmailRecipients + subject: str + body: str + debug_mode: bool = False + + def with_recipients(self, recipients: EmailRecipients) -> EmailDeliveryConfig: + return self.model_copy(update={"recipients": recipients}) + + @classmethod + def replace_url_placeholder(cls, body: str, url: str | None) -> str: + return body.replace(cls.URL_PLACEHOLDER, url or "") + + @classmethod + def render_body_template( + cls, + *, + body: str, + url: str | None, + variable_pool: VariablePool | None = None, + ) -> str: + templated_body = cls.replace_url_placeholder(body, url) + if variable_pool is None: + return templated_body + return variable_pool.convert_template(templated_body).text + + @classmethod + def render_markdown_body(cls, body: str) -> str: + stripped_body = bleach.clean(body, tags=[], attributes={}, strip=True) + rendered = markdown.markdown( + stripped_body, + extensions=[TableExtension(use_align_attribute=True)], + output_format="html", + ) + return bleach.clean( + rendered, + tags=cls._ALLOWED_HTML_TAGS, + attributes=cls._ALLOWED_HTML_ATTRIBUTES, + protocols=cls._ALLOWED_PROTOCOLS, + strip=True, + ) + + @staticmethod + def sanitize_subject(subject: str) -> str: + sanitized = subject.replace("\r", " ").replace("\n", " ") + sanitized = bleach.clean(sanitized, tags=[], strip=True) + return " ".join(sanitized.split()) + + +class _DeliveryMethodBase(BaseModel): + enabled: bool = True + id: uuid.UUID = Field(default_factory=uuid.uuid4) + + def extract_variable_selectors(self) -> Sequence[Sequence[str]]: + return () + + +class InteractiveSurfaceDeliveryMethod(_DeliveryMethodBase): + type: Literal[DeliveryMethodType.WEBAPP] = DeliveryMethodType.WEBAPP + config: _InteractiveSurfaceDeliveryConfig = Field(default_factory=_InteractiveSurfaceDeliveryConfig) + + +class EmailDeliveryMethod(_DeliveryMethodBase): + type: Literal[DeliveryMethodType.EMAIL] = DeliveryMethodType.EMAIL + config: EmailDeliveryConfig + + def extract_variable_selectors(self) -> Sequence[Sequence[str]]: + variable_template_parser = VariableTemplateParser(template=self.config.body) + selectors: list[Sequence[str]] = [] + for variable_selector in variable_template_parser.extract_variable_selectors(): + value_selector = list(variable_selector.value_selector) + if len(value_selector) < SELECTORS_LENGTH: + continue + selectors.append(value_selector[:SELECTORS_LENGTH]) + return selectors + + +WebAppDeliveryMethod = InteractiveSurfaceDeliveryMethod +_WebAppDeliveryConfig = _InteractiveSurfaceDeliveryConfig + +DeliveryChannelConfig = Annotated[InteractiveSurfaceDeliveryMethod | EmailDeliveryMethod, Field(discriminator="type")] + +_DELIVERY_METHODS_ADAPTER = TypeAdapter(list[DeliveryChannelConfig]) + + +def _copy_mapping(value: object) -> dict[str, Any] | None: + if isinstance(value, BaseModel): + return value.model_dump(mode="python") + if isinstance(value, Mapping): + return dict(value) + return None + + +def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]: + normalized = _copy_mapping(node_data) + if normalized is None: + raise TypeError(f"human-input node data must be a mapping, got {type(node_data).__name__}") + + delivery_methods = normalized.get("delivery_methods") + if not isinstance(delivery_methods, list): + return normalized + + normalized_methods: list[Any] = [] + for method in delivery_methods: + method_mapping = _copy_mapping(method) + if method_mapping is None: + normalized_methods.append(method) + continue + + config_mapping = _copy_mapping(method_mapping.get("config")) + if config_mapping is not None: + recipients_mapping = _copy_mapping(config_mapping.get("recipients")) + if recipients_mapping is not None: + config_mapping["recipients"] = _normalize_email_recipients(recipients_mapping) + method_mapping["config"] = config_mapping + + normalized_methods.append(method_mapping) + + normalized["delivery_methods"] = normalized_methods + return normalized + + +def parse_human_input_delivery_methods(node_data: Mapping[str, Any] | BaseModel) -> list[DeliveryChannelConfig]: + normalized = normalize_human_input_node_data_for_graph(node_data) + raw_delivery_methods = normalized.get("delivery_methods") + if not isinstance(raw_delivery_methods, list): + return [] + return list(_DELIVERY_METHODS_ADAPTER.validate_python(raw_delivery_methods)) + + +def is_human_input_webapp_enabled(node_data: Mapping[str, Any] | BaseModel) -> bool: + for method in parse_human_input_delivery_methods(node_data): + if method.enabled and method.type == DeliveryMethodType.WEBAPP: + return True + return False + + +def normalize_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]: + normalized = _copy_mapping(node_data) + if normalized is None: + raise TypeError(f"node data must be a mapping, got {type(node_data).__name__}") + + if normalized.get("type") != BuiltinNodeTypes.HUMAN_INPUT: + return normalized + return normalize_human_input_node_data_for_graph(normalized) + + +def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]: + normalized = _copy_mapping(node_config) + if normalized is None: + raise TypeError(f"node config must be a mapping, got {type(node_config).__name__}") + + data_mapping = _copy_mapping(normalized.get("data")) + if data_mapping is None: + return normalized + + normalized["data"] = normalize_node_data_for_graph(data_mapping) + return normalized + + +def _normalize_email_recipients(recipients: Mapping[str, Any]) -> dict[str, Any]: + normalized = dict(recipients) + + legacy_include_bound_group = normalized.pop("whole_workspace", None) + if "include_bound_group" not in normalized and legacy_include_bound_group is not None: + normalized["include_bound_group"] = legacy_include_bound_group + + items = normalized.get("items") + if not isinstance(items, list): + return normalized + + normalized_items: list[Any] = [] + for item in items: + item_mapping = _copy_mapping(item) + if item_mapping is None: + normalized_items.append(item) + continue + + legacy_reference_id = item_mapping.pop("user_id", None) + if "reference_id" not in item_mapping and legacy_reference_id is not None: + item_mapping["reference_id"] = legacy_reference_id + normalized_items.append(item_mapping) + + normalized["items"] = normalized_items + return normalized + + +__all__ = [ + "BoundRecipient", + "DeliveryChannelConfig", + "DeliveryMethodType", + "EmailDeliveryConfig", + "EmailDeliveryMethod", + "EmailRecipientType", + "EmailRecipients", + "ExternalRecipient", + "MemberRecipient", + "WebAppDeliveryMethod", + "_WebAppDeliveryConfig", + "is_human_input_webapp_enabled", + "normalize_human_input_node_data_for_graph", + "normalize_node_config_for_graph", + "normalize_node_data_for_graph", + "parse_human_input_delivery_methods", +] diff --git a/api/core/workflow/human_input_forms.py b/api/core/workflow/human_input_forms.py new file mode 100644 index 00000000000..f124b321d4c --- /dev/null +++ b/api/core/workflow/human_input_forms.py @@ -0,0 +1,55 @@ +"""Shared helpers for workflow pause-time human input form lookups. + +Both controllers and streaming response converters need the same recipient +priority when exposing resume links for paused human input forms. Keep that +selection logic here so all API surfaces stay consistent. +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from extensions.ext_database import db +from models.human_input import HumanInputFormRecipient, RecipientType + +_FORM_TOKEN_PRIORITY = { + RecipientType.BACKSTAGE: 0, + RecipientType.CONSOLE: 1, + RecipientType.STANDALONE_WEB_APP: 2, +} + + +def load_form_tokens_by_form_id( + form_ids: Sequence[str], + *, + session: Session | None = None, +) -> dict[str, str]: + """Load the preferred access token for each human input form.""" + unique_form_ids = list(dict.fromkeys(form_ids)) + if not unique_form_ids: + return {} + + if session is not None: + return _load_form_tokens_by_form_id(session, unique_form_ids) + + with Session(bind=db.engine, expire_on_commit=False) as new_session: + return _load_form_tokens_by_form_id(new_session, unique_form_ids) + + +def _load_form_tokens_by_form_id(session: Session, form_ids: Sequence[str]) -> dict[str, str]: + tokens_by_form_id: dict[str, tuple[int, str]] = {} + stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids)) + for recipient in session.scalars(stmt): + priority = _FORM_TOKEN_PRIORITY.get(recipient.recipient_type) + if priority is None or not recipient.access_token: + continue + + candidate = (priority, recipient.access_token) + current = tokens_by_form_id.get(recipient.form_id) + if current is None or candidate[0] < current[0]: + tokens_by_form_id[recipient.form_id] = candidate + + return {form_id: token for form_id, (_, token) in tokens_by_form_id.items()} diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index ab34263a791..8cc21d2cd96 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -4,13 +4,29 @@ from collections.abc import Callable, Iterator, Mapping, MutableMapping from functools import lru_cache from typing import TYPE_CHECKING, Any, TypeAlias, cast, final +from graphon.entities.base_node_data import BaseNodeData +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.file.file_manager import file_manager +from graphon.graph.graph import NodeFactory +from graphon.model_runtime.memory import PromptMessageMemory +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.nodes.base.node import Node +from graphon.nodes.code.code_node import WorkflowCodeExecutor +from graphon.nodes.code.entities import CodeLanguage +from graphon.nodes.code.limits import CodeNodeLimits +from graphon.nodes.document_extractor import UnstructuredApiConfig +from graphon.nodes.http_request import build_http_request_config +from graphon.nodes.llm.entities import LLMNodeData +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData from sqlalchemy import select from sqlalchemy.orm import Session from typing_extensions import override from configs import dify_config -from core.app.entities.app_invoke_entities import DifyRunContext -from core.app.llm.model_access import build_dify_model_access +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.app.llm.model_access import build_dify_model_access, fetch_model_config from core.helper.code_executor.code_executor import ( CodeExecutionError, CodeExecutor, @@ -19,45 +35,32 @@ from core.helper.ssrf_proxy import ssrf_proxy from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.repositories.human_input_repository import HumanInputFormRepositoryImpl -from core.tools.tool_file_manager import ToolFileManager from core.trigger.constants import TRIGGER_NODE_TYPES +from core.workflow.human_input_compat import normalize_node_config_for_graph +from core.workflow.node_runtime import ( + DifyFileReferenceFactory, + DifyHumanInputNodeRuntime, + DifyPreparedLLM, + DifyPromptMessageSerializer, + DifyRetrieverAttachmentLoader, + DifyToolFileManager, + DifyToolNodeRuntime, + build_dify_llm_file_saver, +) from core.workflow.nodes.agent.message_transformer import AgentMessageTransformer from core.workflow.nodes.agent.plugin_strategy_adapter import ( PluginAgentStrategyPresentationProvider, PluginAgentStrategyResolver, ) from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import BuiltinNodeTypes, NodeType, SystemVariableKey -from dify_graph.file.file_manager import file_manager -from dify_graph.graph.graph import NodeFactory -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.memory import PromptMessageMemory -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.code.code_node import WorkflowCodeExecutor -from dify_graph.nodes.code.entities import CodeLanguage -from dify_graph.nodes.code.limits import CodeNodeLimits -from dify_graph.nodes.document_extractor import UnstructuredApiConfig -from dify_graph.nodes.http_request import build_http_request_config -from dify_graph.nodes.llm.entities import LLMNodeData -from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError -from dify_graph.nodes.llm.protocols import TemplateRenderer -from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData -from dify_graph.nodes.template_transform.template_renderer import ( - CodeExecutorJinja2TemplateRenderer, -) -from dify_graph.variables.segments import StringSegment +from core.workflow.system_variables import SystemVariableKey, get_system_text, system_variable_selector +from core.workflow.template_rendering import CodeExecutorJinja2TemplateRenderer from extensions.ext_database import db from models.model import Conversation if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState LATEST_VERSION = "latest" _START_NODE_TYPES: frozenset[NodeType] = frozenset( @@ -76,7 +79,7 @@ def _import_node_package(package_name: str, *, excluded_modules: frozenset[str] @lru_cache(maxsize=1) def register_nodes() -> None: """Import production node modules so they self-register with ``Node``.""" - _import_node_package("dify_graph.nodes") + _import_node_package("graphon.nodes") _import_node_package("core.workflow.nodes") @@ -84,7 +87,7 @@ def get_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node] """Return a read-only snapshot of the current production node registry. The workflow layer owns node bootstrap because it must compose built-in - `dify_graph.nodes.*` implementations with workflow-local nodes under + `graphon.nodes.*` implementations with workflow-local nodes under `core.workflow.nodes.*`. Keeping this import side effect here avoids reintroducing registry bootstrapping into lower-level graph primitives. """ @@ -115,7 +118,7 @@ def get_default_root_node_id(graph_config: Mapping[str, Any]) -> str: This workflow-layer helper depends on start-node semantics defined by `is_start_node_type`, so it intentionally lives next to the node registry - instead of in the raw `dify_graph.entities.graph_config` schema module. + instead of in the raw `graphon.entities.graph_config` schema module. """ nodes = graph_config.get("nodes") if not isinstance(nodes, list): @@ -229,16 +232,6 @@ class DefaultWorkflowCodeExecutor: return isinstance(error, CodeExecutionError) -class DefaultLLMTemplateRenderer(TemplateRenderer): - def render_jinja2(self, *, template: str, inputs: Mapping[str, Any]) -> str: - result = CodeExecutor.execute_workflow_code_template( - language=CodeLanguage.JINJA2, - code=template, - inputs=inputs, - ) - return str(result.get("result", "")) - - @final class DifyNodeFactory(NodeFactory): """ @@ -264,11 +257,31 @@ class DifyNodeFactory(NodeFactory): max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH, max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH, ) - self._template_renderer = CodeExecutorJinja2TemplateRenderer(code_executor=self._code_executor) - self._llm_template_renderer: TemplateRenderer = DefaultLLMTemplateRenderer() + self._jinja2_template_renderer = CodeExecutorJinja2TemplateRenderer() self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH self._http_request_http_client = ssrf_proxy - self._http_request_tool_file_manager_factory = ToolFileManager + self._bound_tool_file_manager_factory = lambda: DifyToolFileManager( + self._dify_context, + conversation_id_getter=self._conversation_id, + ) + self._file_reference_factory = DifyFileReferenceFactory(self._dify_context) + self._prompt_message_serializer = DifyPromptMessageSerializer() + self._retriever_attachment_loader = DifyRetrieverAttachmentLoader( + file_reference_factory=self._file_reference_factory, + ) + self._llm_file_saver = build_dify_llm_file_saver( + run_context=self._dify_context, + http_client=self._http_request_http_client, + conversation_id_getter=self._conversation_id, + ) + self._human_input_runtime = DifyHumanInputNodeRuntime( + self._dify_context, + workflow_execution_id_getter=lambda: get_system_text( + self.graph_runtime_state.variable_pool, + SystemVariableKey.WORKFLOW_EXECUTION_ID, + ), + ) + self._tool_runtime = DifyToolNodeRuntime(self._dify_context) self._http_request_file_manager = file_manager self._document_extractor_unstructured_api_config = UnstructuredApiConfig( api_url=dify_config.UNSTRUCTURED_API_URL, @@ -284,7 +297,7 @@ class DifyNodeFactory(NodeFactory): ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES, ) - self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context.tenant_id) + self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context) self._agent_strategy_resolver = PluginAgentStrategyResolver() self._agent_strategy_presentation_provider = PluginAgentStrategyPresentationProvider() self._agent_runtime_support = AgentRuntimeSupport() @@ -299,6 +312,9 @@ class DifyNodeFactory(NodeFactory): return raw_ctx return DifyRunContext.model_validate(raw_ctx) + def _conversation_id(self) -> str | None: + return get_system_text(self.graph_runtime_state.variable_pool, SystemVariableKey.CONVERSATION_ID) + @override def create_node(self, node_config: dict[str, Any] | NodeConfigDict) -> Node: """ @@ -310,7 +326,7 @@ class DifyNodeFactory(NodeFactory): (including pydantic ValidationError, which subclasses ValueError), if node type is unknown, or if no implementation exists for the resolved version """ - typed_node_config = NodeConfigDictAdapter.validate_python(node_config) + typed_node_config = NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config)) node_id = typed_node_config["id"] node_data = typed_node_config["data"] node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version)) @@ -321,22 +337,29 @@ class DifyNodeFactory(NodeFactory): "code_limits": self._code_limits, }, BuiltinNodeTypes.TEMPLATE_TRANSFORM: lambda: { - "template_renderer": self._template_renderer, + "jinja2_template_renderer": self._jinja2_template_renderer, "max_output_length": self._template_transform_max_output_length, }, BuiltinNodeTypes.HTTP_REQUEST: lambda: { "http_request_config": self._http_request_config, "http_client": self._http_request_http_client, - "tool_file_manager_factory": self._http_request_tool_file_manager_factory, + "tool_file_manager_factory": self._bound_tool_file_manager_factory, "file_manager": self._http_request_file_manager, + "file_reference_factory": self._file_reference_factory, }, BuiltinNodeTypes.HUMAN_INPUT: lambda: { - "form_repository": HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id), + "runtime": self._human_input_runtime, + "form_repository": self._human_input_runtime.build_form_repository(), }, BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, node_data=node_data, + wrap_model_instance=True, include_http_client=True, + include_llm_file_saver=True, + include_prompt_message_serializer=True, + include_retriever_attachment_loader=True, + include_jinja2_template_renderer=True, ), BuiltinNodeTypes.DOCUMENT_EXTRACTOR: lambda: { "unstructured_api_config": self._document_extractor_unstructured_api_config, @@ -345,15 +368,26 @@ class DifyNodeFactory(NodeFactory): BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, node_data=node_data, + wrap_model_instance=True, include_http_client=True, + include_llm_file_saver=True, + include_prompt_message_serializer=True, + include_retriever_attachment_loader=False, + include_jinja2_template_renderer=False, ), BuiltinNodeTypes.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, node_data=node_data, + wrap_model_instance=True, include_http_client=False, + include_llm_file_saver=False, + include_prompt_message_serializer=True, + include_retriever_attachment_loader=False, + include_jinja2_template_renderer=False, ), BuiltinNodeTypes.TOOL: lambda: { - "tool_file_manager_factory": self._http_request_tool_file_manager_factory(), + "tool_file_manager_factory": self._bound_tool_file_manager_factory(), + "runtime": self._tool_runtime, }, BuiltinNodeTypes.AGENT: lambda: { "strategy_resolver": self._agent_strategy_resolver, @@ -387,7 +421,12 @@ class DifyNodeFactory(NodeFactory): *, node_class: type[Node], node_data: BaseNodeData, + wrap_model_instance: bool, include_http_client: bool, + include_llm_file_saver: bool, + include_prompt_message_serializer: bool, + include_retriever_attachment_loader: bool, + include_jinja2_template_renderer: bool, ) -> dict[str, object]: validated_node_data = cast( LLMCompatibleNodeData, @@ -397,49 +436,35 @@ class DifyNodeFactory(NodeFactory): node_init_kwargs: dict[str, object] = { "credentials_provider": self._llm_credentials_provider, "model_factory": self._llm_model_factory, - "model_instance": model_instance, + "model_instance": DifyPreparedLLM(model_instance) if wrap_model_instance else model_instance, "memory": self._build_memory_for_llm_node( node_data=validated_node_data, model_instance=model_instance, ), } - if validated_node_data.type in {BuiltinNodeTypes.LLM, BuiltinNodeTypes.QUESTION_CLASSIFIER}: - node_init_kwargs["template_renderer"] = self._llm_template_renderer + if validated_node_data.type == BuiltinNodeTypes.QUESTION_CLASSIFIER: + node_init_kwargs["template_renderer"] = self._jinja2_template_renderer if include_http_client: node_init_kwargs["http_client"] = self._http_request_http_client + if include_llm_file_saver: + node_init_kwargs["llm_file_saver"] = self._llm_file_saver + if include_prompt_message_serializer: + node_init_kwargs["prompt_message_serializer"] = self._prompt_message_serializer + if include_retriever_attachment_loader: + node_init_kwargs["retriever_attachment_loader"] = self._retriever_attachment_loader + if include_jinja2_template_renderer: + node_init_kwargs["jinja2_template_renderer"] = self._jinja2_template_renderer + if validated_node_data.type == BuiltinNodeTypes.LLM: + node_init_kwargs["default_query_selector"] = system_variable_selector(SystemVariableKey.QUERY) return node_init_kwargs def _build_model_instance_for_llm_node(self, node_data: LLMCompatibleNodeData) -> ModelInstance: node_data_model = node_data.model - if not node_data_model.mode: - raise LLMModeRequiredError("LLM mode is required.") - - credentials = self._llm_credentials_provider.fetch(node_data_model.provider, node_data_model.name) - model_instance = self._llm_model_factory.init_model_instance(node_data_model.provider, node_data_model.name) - provider_model_bundle = model_instance.provider_model_bundle - - provider_model = provider_model_bundle.configuration.get_provider_model( - model=node_data_model.name, - model_type=ModelType.LLM, + model_instance, _ = fetch_model_config( + node_data_model=node_data_model, + credentials_provider=self._llm_credentials_provider, + model_factory=self._llm_model_factory, ) - if provider_model is None: - raise ModelNotExistError(f"Model {node_data_model.name} not exist.") - provider_model.raise_for_status() - - completion_params = dict(node_data_model.completion_params) - stop = completion_params.pop("stop", []) - if not isinstance(stop, list): - stop = [] - - model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials) - if not model_schema: - raise ModelNotExistError(f"Model {node_data_model.name} not exist.") - - model_instance.provider = node_data_model.provider - model_instance.model_name = node_data_model.name - model_instance.credentials = credentials - model_instance.parameters = completion_params - model_instance.stop = tuple(stop) model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance) return model_instance @@ -452,12 +477,7 @@ class DifyNodeFactory(NodeFactory): if node_data.memory is None: return None - conversation_id_variable = self.graph_runtime_state.variable_pool.get( - ["sys", SystemVariableKey.CONVERSATION_ID] - ) - conversation_id = ( - conversation_id_variable.value if isinstance(conversation_id_variable, StringSegment) else None - ) + conversation_id = get_system_text(self.graph_runtime_state.variable_pool, SystemVariableKey.CONVERSATION_ID) return fetch_memory( conversation_id=conversation_id, app_id=self._dify_context.app_id, diff --git a/api/core/workflow/node_runtime.py b/api/core/workflow/node_runtime.py new file mode 100644 index 00000000000..19cb3a7b0ab --- /dev/null +++ b/api/core/workflow/node_runtime.py @@ -0,0 +1,671 @@ +from __future__ import annotations + +from collections.abc import Callable, Generator, Mapping, Sequence +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, cast + +from graphon.file import FileTransferMethod, FileType +from graphon.model_runtime.entities import LLMMode +from graphon.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, + LLMUsage, +) +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.nodes.human_input.entities import HumanInputNodeData +from graphon.nodes.llm.runtime_protocols import ( + PreparedLLMProtocol, + PromptMessageSerializerProtocol, + RetrieverAttachmentLoaderProtocol, +) +from graphon.nodes.protocols import FileReferenceFactoryProtocol, HttpClientProtocol, ToolFileManagerProtocol +from graphon.nodes.runtime import ( + HumanInputFormStateProtocol, + HumanInputNodeRuntimeProtocol, + ToolNodeRuntimeProtocol, +) +from graphon.nodes.tool.exc import ToolNodeError, ToolRuntimeInvocationError, ToolRuntimeResolutionError +from graphon.nodes.tool_runtime_entities import ( + ToolRuntimeHandle, + ToolRuntimeMessage, + ToolRuntimeParameter, +) +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.app.file_access import DatabaseFileAccessController +from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output +from core.model_manager import ModelInstance +from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError +from core.plugin.impl.plugin import PluginInstaller +from core.prompt.utils.prompt_message_util import PromptMessageUtil +from core.repositories.human_input_repository import ( + FormCreateParams, + HumanInputFormRepository, + HumanInputFormRepositoryImpl, +) +from core.tools.entities.tool_entities import ToolProviderType as CoreToolProviderType +from core.tools.errors import ToolInvokeError +from core.tools.tool_engine import ToolEngine +from core.tools.tool_file_manager import ToolFileManager +from core.tools.tool_manager import ToolManager +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.workflow.file_reference import build_file_reference +from extensions.ext_database import db +from factories import file_factory +from models.dataset import SegmentAttachmentBinding +from models.model import UploadFile +from services.tools.builtin_tools_manage_service import BuiltinToolManageService + +from .human_input_compat import ( + BoundRecipient, + DeliveryChannelConfig, + DeliveryMethodType, + EmailDeliveryMethod, + EmailRecipients, + is_human_input_webapp_enabled, + parse_human_input_delivery_methods, +) +from .system_variables import SystemVariableKey, get_system_text + +if TYPE_CHECKING: + from graphon.file import File + from graphon.nodes.llm.file_saver import LLMFileSaver + from graphon.nodes.tool.entities import ToolNodeData + + from core.tools.__base.tool import Tool + from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage + + +_file_access_controller = DatabaseFileAccessController() + + +def resolve_dify_run_context(run_context: Mapping[str, Any] | DifyRunContext) -> DifyRunContext: + if isinstance(run_context, DifyRunContext): + return run_context + + raw_ctx = run_context.get(DIFY_RUN_CONTEXT_KEY) + if raw_ctx is None: + raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}") + if isinstance(raw_ctx, DifyRunContext): + return raw_ctx + return DifyRunContext.model_validate(raw_ctx) + + +def apply_dify_debug_email_recipient( + method: DeliveryChannelConfig, + *, + enabled: bool, + actor_id: str | None, +) -> DeliveryChannelConfig: + """Apply the Dify debugger-specific email recipient override outside `graphon`.""" + if not enabled: + return method + if not isinstance(method, EmailDeliveryMethod): + return method + if not method.config.debug_mode: + return method + + if actor_id is None: + debug_recipients = EmailRecipients(include_bound_group=False, items=[]) + else: + debug_recipients = EmailRecipients( + include_bound_group=False, + items=[BoundRecipient(reference_id=actor_id)], + ) + debug_config = method.config.with_recipients(debug_recipients) + return method.model_copy(update={"config": debug_config}) + + +class DifyFileReferenceFactory(FileReferenceFactoryProtocol): + def __init__(self, run_context: Mapping[str, Any] | DifyRunContext) -> None: + self._run_context = resolve_dify_run_context(run_context) + + def build_from_mapping(self, *, mapping: Mapping[str, Any]): + return file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self._run_context.tenant_id, + access_controller=_file_access_controller, + ) + + +class DifyPreparedLLM(PreparedLLMProtocol): + """Workflow-layer adapter that hides the full `ModelInstance` API from `graphon` nodes.""" + + def __init__(self, model_instance: ModelInstance) -> None: + self._model_instance = model_instance + + @property + def provider(self) -> str: + return self._model_instance.provider + + @property + def model_name(self) -> str: + return self._model_instance.model_name + + @property + def parameters(self) -> Mapping[str, Any]: + return self._model_instance.parameters + + @parameters.setter + def parameters(self, value: Mapping[str, Any]) -> None: + self._model_instance.parameters = value + + @property + def stop(self) -> Sequence[str] | None: + return self._model_instance.stop + + def get_model_schema(self) -> AIModelEntity: + model_schema = cast(LargeLanguageModel, self._model_instance.model_type_instance).get_model_schema( + self._model_instance.model_name, + self._model_instance.credentials, + ) + if model_schema is None: + raise ValueError(f"Model schema not found for {self._model_instance.model_name}") + return model_schema + + def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int: + return self._model_instance.get_llm_num_tokens(prompt_messages) + + def invoke_llm( + self, + *, + prompt_messages: Sequence[PromptMessage], + model_parameters: Mapping[str, Any], + tools: Sequence[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: bool, + ) -> LLMResult | Generator[LLMResultChunk, None, None]: + return self._model_instance.invoke_llm( + prompt_messages=list(prompt_messages), + model_parameters=dict(model_parameters), + tools=list(tools or []), + stop=list(stop or []), + stream=stream, + ) + + def invoke_llm_with_structured_output( + self, + *, + prompt_messages: Sequence[PromptMessage], + json_schema: Mapping[str, Any], + model_parameters: Mapping[str, Any], + stop: Sequence[str] | None, + stream: bool, + ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: + return invoke_llm_with_structured_output( + provider=self.provider, + model_schema=self.get_model_schema(), + model_instance=self._model_instance, + prompt_messages=prompt_messages, + json_schema=json_schema, + model_parameters=model_parameters, + stop=list(stop or []), + stream=stream, + ) + + def is_structured_output_parse_error(self, error: Exception) -> bool: + return isinstance(error, OutputParserError) + + +class DifyPromptMessageSerializer(PromptMessageSerializerProtocol): + def serialize( + self, + *, + model_mode: LLMMode, + prompt_messages: Sequence[PromptMessage], + ) -> Any: + return PromptMessageUtil.prompt_messages_to_prompt_for_saving( + model_mode=model_mode, + prompt_messages=prompt_messages, + ) + + +class DifyRetrieverAttachmentLoader(RetrieverAttachmentLoaderProtocol): + """Resolve retriever attachments through Dify persistence and return graph file references.""" + + def __init__(self, *, file_reference_factory: FileReferenceFactoryProtocol) -> None: + self._file_reference_factory = file_reference_factory + + def load(self, *, segment_id: str) -> Sequence[File]: + with Session(db.engine, expire_on_commit=False) as session: + attachments_with_bindings = session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where(SegmentAttachmentBinding.segment_id == segment_id) + ).all() + + return [ + self._file_reference_factory.build_from_mapping( + mapping={ + "id": upload_file.id, + "filename": upload_file.name, + "extension": "." + upload_file.extension, + "mime_type": upload_file.mime_type, + "type": FileType.IMAGE, + "transfer_method": FileTransferMethod.LOCAL_FILE, + "remote_url": upload_file.source_url, + "reference": build_file_reference(record_id=str(upload_file.id)), + "size": upload_file.size, + } + ) + for _, upload_file in attachments_with_bindings + ] + + +class DifyToolFileManager(ToolFileManagerProtocol): + """Workflow adapter that resolves conversation scope outside `graphon`.""" + + _conversation_id_getter: Callable[[], str | None] | None + + def __init__( + self, + run_context: Mapping[str, Any] | DifyRunContext, + *, + conversation_id_getter: Callable[[], str | None] | None = None, + ) -> None: + self._run_context = resolve_dify_run_context(run_context) + self._manager = ToolFileManager() + self._conversation_id_getter = conversation_id_getter + + def create_file_by_raw( + self, + *, + file_binary: bytes, + mimetype: str, + filename: str | None = None, + ) -> Any: + conversation_id = self._conversation_id_getter() if self._conversation_id_getter is not None else None + return self._manager.create_file_by_raw( + user_id=self._run_context.user_id, + tenant_id=self._run_context.tenant_id, + conversation_id=conversation_id, + file_binary=file_binary, + mimetype=mimetype, + filename=filename, + ) + + def get_file_generator_by_tool_file_id(self, tool_file_id: str): + return self._manager.get_file_generator_by_tool_file_id(tool_file_id) + + +@dataclass(frozen=True, slots=True) +class _WorkflowToolRuntimeSpec: + provider_type: CoreToolProviderType + provider_id: str + tool_name: str + tool_configurations: dict[str, Any] + credential_id: str | None = None + + +@dataclass(frozen=True, slots=True) +class _WorkflowToolRuntimeBinding: + """Workflow-private runtime state stored inside the opaque graph handle. + + The binding keeps conversation scope in `core.workflow` while `graphon` + continues to treat the handle as an opaque token. + """ + + tool: Tool + conversation_id: str | None = None + + +class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): + def __init__(self, run_context: Mapping[str, Any] | DifyRunContext) -> None: + self._run_context = resolve_dify_run_context(run_context) + self._file_reference_factory = DifyFileReferenceFactory(self._run_context) + + @property + def file_reference_factory(self) -> FileReferenceFactoryProtocol: + return self._file_reference_factory + + def build_file_reference(self, *, mapping: Mapping[str, Any]): + return self._file_reference_factory.build_from_mapping(mapping=mapping) + + def get_runtime( + self, + *, + node_id: str, + node_data: ToolNodeData, + variable_pool, + ) -> ToolRuntimeHandle: + try: + tool_runtime = ToolManager.get_workflow_tool_runtime( + self._run_context.tenant_id, + self._run_context.app_id, + node_id, + self._build_tool_runtime_spec(node_data), + self._run_context.user_id, + self._run_context.invoke_from, + variable_pool, + ) + except ToolNodeError: + raise + except Exception as exc: + raise ToolRuntimeResolutionError(str(exc)) from exc + + conversation_id = ( + None if variable_pool is None else get_system_text(variable_pool, SystemVariableKey.CONVERSATION_ID) + ) + return ToolRuntimeHandle(raw=_WorkflowToolRuntimeBinding(tool=tool_runtime, conversation_id=conversation_id)) + + def get_runtime_parameters( + self, + *, + tool_runtime: ToolRuntimeHandle, + ) -> Sequence[ToolRuntimeParameter]: + tool = self._tool_from_handle(tool_runtime) + return [ + ToolRuntimeParameter(name=parameter.name, required=parameter.required) + for parameter in (tool.get_merged_runtime_parameters() or []) + ] + + def invoke( + self, + *, + tool_runtime: ToolRuntimeHandle, + tool_parameters: Mapping[str, Any], + workflow_call_depth: int, + provider_name: str, + ) -> Generator[ToolRuntimeMessage, None, None]: + runtime_binding = self._binding_from_handle(tool_runtime) + tool = runtime_binding.tool + callback = DifyWorkflowCallbackHandler() + + try: + messages = ToolEngine.generic_invoke( + tool=tool, + tool_parameters=dict(tool_parameters), + user_id=self._run_context.user_id, + workflow_tool_callback=callback, + workflow_call_depth=workflow_call_depth, + app_id=self._run_context.app_id, + conversation_id=runtime_binding.conversation_id, + ) + except Exception as exc: + raise self._map_invocation_exception(exc, provider_name=provider_name) from exc + + transformed_messages = ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=messages, + user_id=self._run_context.user_id, + tenant_id=self._run_context.tenant_id, + conversation_id=runtime_binding.conversation_id, + ) + + return self._adapt_messages(transformed_messages, provider_name=provider_name) + + def get_usage( + self, + *, + tool_runtime: ToolRuntimeHandle, + ) -> LLMUsage: + latest = getattr(self._binding_from_handle(tool_runtime).tool, "latest_usage", None) + if isinstance(latest, LLMUsage): + return latest + if isinstance(latest, dict): + return LLMUsage.model_validate(latest) + return LLMUsage.empty_usage() + + def resolve_provider_icons( + self, + *, + provider_name: str, + default_icon: str | None = None, + ) -> tuple[str | Mapping[str, str] | None, str | Mapping[str, str] | None]: + icon: str | Mapping[str, str] | None = default_icon + icon_dark: str | Mapping[str, str] | None = None + + manager = PluginInstaller() + plugins = manager.list_plugins(self._run_context.tenant_id) + try: + current_plugin = next(plugin for plugin in plugins if f"{plugin.plugin_id}/{plugin.name}" == provider_name) + icon = current_plugin.declaration.icon + except StopIteration: + pass + + try: + builtin_tool = next( + provider + for provider in BuiltinToolManageService.list_builtin_tools( + self._run_context.user_id, + self._run_context.tenant_id, + ) + if provider.name == provider_name + ) + icon = builtin_tool.icon + icon_dark = builtin_tool.icon_dark + except StopIteration: + pass + + return icon, icon_dark + + @staticmethod + def _tool_from_handle(tool_runtime: ToolRuntimeHandle) -> Tool: + return DifyToolNodeRuntime._binding_from_handle(tool_runtime).tool + + @staticmethod + def _binding_from_handle(tool_runtime: ToolRuntimeHandle) -> _WorkflowToolRuntimeBinding: + if isinstance(tool_runtime.raw, _WorkflowToolRuntimeBinding): + return tool_runtime.raw + return _WorkflowToolRuntimeBinding(tool=cast("Tool", tool_runtime.raw)) + + @staticmethod + def _build_tool_runtime_spec(node_data: ToolNodeData) -> _WorkflowToolRuntimeSpec: + return _WorkflowToolRuntimeSpec( + provider_type=CoreToolProviderType(node_data.provider_type.value), + provider_id=node_data.provider_id, + tool_name=node_data.tool_name, + tool_configurations=dict(node_data.tool_configurations), + credential_id=node_data.credential_id, + ) + + def _adapt_messages( + self, + messages: Generator[CoreToolInvokeMessage, None, None], + *, + provider_name: str, + ) -> Generator[ToolRuntimeMessage, None, None]: + try: + for message in messages: + yield self._convert_message(message) + except Exception as exc: + raise self._map_invocation_exception(exc, provider_name=provider_name) from exc + + def _convert_message(self, message: CoreToolInvokeMessage) -> ToolRuntimeMessage: + graph_message_type = ToolRuntimeMessage.MessageType(message.type.value) + graph_message = self._convert_message_payload(message.message) + graph_meta = message.meta.copy() if message.meta is not None else None + return ToolRuntimeMessage(type=graph_message_type, message=graph_message, meta=graph_meta) + + def _convert_message_payload( + self, + message: CoreToolInvokeMessage.TextMessage + | CoreToolInvokeMessage.JsonMessage + | CoreToolInvokeMessage.BlobChunkMessage + | CoreToolInvokeMessage.BlobMessage + | CoreToolInvokeMessage.LogMessage + | CoreToolInvokeMessage.FileMessage + | CoreToolInvokeMessage.VariableMessage + | CoreToolInvokeMessage.RetrieverResourceMessage + | None, + ) -> ( + ToolRuntimeMessage.TextMessage + | ToolRuntimeMessage.JsonMessage + | ToolRuntimeMessage.BlobChunkMessage + | ToolRuntimeMessage.BlobMessage + | ToolRuntimeMessage.LogMessage + | ToolRuntimeMessage.FileMessage + | ToolRuntimeMessage.VariableMessage + | ToolRuntimeMessage.RetrieverResourceMessage + | None + ): + if message is None: + return None + + from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage + + if isinstance(message, CoreToolInvokeMessage.TextMessage): + return ToolRuntimeMessage.TextMessage(text=message.text) + if isinstance(message, CoreToolInvokeMessage.JsonMessage): + return ToolRuntimeMessage.JsonMessage( + json_object=message.json_object, + suppress_output=message.suppress_output, + ) + if isinstance(message, CoreToolInvokeMessage.BlobMessage): + return ToolRuntimeMessage.BlobMessage(blob=message.blob) + if isinstance(message, CoreToolInvokeMessage.BlobChunkMessage): + return ToolRuntimeMessage.BlobChunkMessage( + id=message.id, + sequence=message.sequence, + total_length=message.total_length, + blob=message.blob, + end=message.end, + ) + if isinstance(message, CoreToolInvokeMessage.FileMessage): + return ToolRuntimeMessage.FileMessage(file_marker=message.file_marker) + if isinstance(message, CoreToolInvokeMessage.VariableMessage): + return ToolRuntimeMessage.VariableMessage( + variable_name=message.variable_name, + variable_value=message.variable_value, + stream=message.stream, + ) + if isinstance(message, CoreToolInvokeMessage.LogMessage): + return ToolRuntimeMessage.LogMessage( + id=message.id, + label=message.label, + parent_id=message.parent_id, + error=message.error, + status=ToolRuntimeMessage.LogMessage.LogStatus(message.status.value), + data=dict(message.data), + metadata=dict(message.metadata), + ) + if isinstance(message, CoreToolInvokeMessage.RetrieverResourceMessage): + retriever_resources = [ + resource.model_dump() if hasattr(resource, "model_dump") else dict(resource) + for resource in message.retriever_resources + ] + return ToolRuntimeMessage.RetrieverResourceMessage( + retriever_resources=retriever_resources, + context=message.context, + ) + + raise TypeError(f"unsupported tool message payload: {type(message).__name__}") + + @staticmethod + def _map_invocation_exception(exc: Exception, *, provider_name: str) -> ToolNodeError: + if isinstance(exc, ToolNodeError): + return exc + if isinstance(exc, PluginInvokeError): + return ToolRuntimeInvocationError(exc.to_user_friendly_error(plugin_name=provider_name)) + if isinstance(exc, PluginDaemonClientSideError): + return ToolRuntimeInvocationError(f"Failed to invoke tool, error: {exc.description}") + if isinstance(exc, ToolInvokeError): + return ToolRuntimeInvocationError(f"Failed to invoke tool {provider_name}: {exc}") + return ToolRuntimeInvocationError(str(exc)) + + +class DifyHumanInputNodeRuntime(HumanInputNodeRuntimeProtocol): + def __init__( + self, + run_context: Mapping[str, Any] | DifyRunContext, + *, + workflow_execution_id_getter: Callable[[], str | None] | None = None, + form_repository: HumanInputFormRepository | None = None, + ) -> None: + self._run_context = resolve_dify_run_context(run_context) + self._workflow_execution_id_getter = workflow_execution_id_getter + self._form_repository = form_repository + + def _invoke_source(self) -> str: + invoke_from = self._run_context.invoke_from + if isinstance(invoke_from, str): + return invoke_from + return str(getattr(invoke_from, "value", invoke_from)) + + def _resolve_delivery_methods(self, *, node_data: HumanInputNodeData) -> Sequence[DeliveryChannelConfig]: + invoke_source = self._invoke_source() + methods = [method for method in parse_human_input_delivery_methods(node_data) if method.enabled] + if invoke_source in {"debugger", "explore"}: + methods = [method for method in methods if method.type != DeliveryMethodType.WEBAPP] + return [ + apply_dify_debug_email_recipient( + method, + enabled=invoke_source == "debugger", + actor_id=self._run_context.user_id, + ) + for method in methods + ] + + def _display_in_ui(self, *, node_data: HumanInputNodeData) -> bool: + if self._invoke_source() == "debugger": + return True + return is_human_input_webapp_enabled(node_data) + + def build_form_repository(self) -> HumanInputFormRepository: + if self._form_repository is not None: + return self._form_repository + + return self._build_form_repository() + + def _build_form_repository(self) -> HumanInputFormRepository: + invoke_source = self._invoke_source() + return HumanInputFormRepositoryImpl( + tenant_id=self._run_context.tenant_id, + app_id=self._run_context.app_id, + workflow_execution_id=self._workflow_execution_id_getter() if self._workflow_execution_id_getter else None, + invoke_source=invoke_source, + submission_actor_id=self._run_context.user_id if invoke_source in {"debugger", "explore"} else None, + ) + + def with_form_repository(self, form_repository: HumanInputFormRepository) -> DifyHumanInputNodeRuntime: + return DifyHumanInputNodeRuntime( + self._run_context, + workflow_execution_id_getter=self._workflow_execution_id_getter, + form_repository=form_repository, + ) + + def get_form(self, *, node_id: str) -> HumanInputFormStateProtocol | None: + repo = self.build_form_repository() + return repo.get_form(node_id) + + def create_form( + self, + *, + node_id: str, + node_data: HumanInputNodeData, + rendered_content: str, + resolved_default_values: Mapping[str, Any], + ) -> HumanInputFormStateProtocol: + repo = self.build_form_repository() + params = FormCreateParams( + workflow_execution_id=self._workflow_execution_id_getter() if self._workflow_execution_id_getter else None, + node_id=node_id, + form_config=node_data, + rendered_content=rendered_content, + delivery_methods=self._resolve_delivery_methods(node_data=node_data), + display_in_ui=self._display_in_ui(node_data=node_data), + resolved_default_values=resolved_default_values, + ) + return repo.create_form(params) + + +def build_dify_llm_file_saver( + *, + run_context: Mapping[str, Any] | DifyRunContext, + http_client: HttpClientProtocol, + conversation_id_getter: Callable[[], str | None] | None = None, +) -> LLMFileSaver: + from graphon.nodes.llm.file_saver import FileSaverImpl + + return FileSaverImpl( + tool_file_manager=DifyToolFileManager(run_context, conversation_id_getter=conversation_id_getter), + file_reference_factory=DifyFileReferenceFactory(run_context), + http_client=http_client, + ) diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 5699ccf404a..bfd5536e4a7 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -3,11 +3,14 @@ from __future__ import annotations from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import BuiltinNodeTypes, SystemVariableKey, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent +from graphon.nodes.base.node import Node +from graphon.nodes.base.variable_template_parser import VariableTemplateParser + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.workflow.system_variables import SystemVariableKey, get_system_text from .entities import AgentNodeData from .exceptions import ( @@ -19,8 +22,8 @@ from .runtime_support import AgentRuntimeSupport from .strategy_protocols import AgentStrategyPresentationProvider, AgentStrategyResolver if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState class AgentNode(Node[AgentNodeData]): @@ -59,7 +62,7 @@ class AgentNode(Node[AgentNodeData]): return "1" def populate_start_event(self, event) -> None: - dify_ctx = self.require_dify_context() + dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY)) event.extras["agent_strategy"] = { "name": self.node_data.agent_strategy_name, "icon": self._presentation_provider.get_icon( @@ -71,7 +74,7 @@ class AgentNode(Node[AgentNodeData]): def _run(self) -> Generator[NodeEventBase, None, None]: from core.plugin.impl.exc import PluginDaemonClientSideError - dify_ctx = self.require_dify_context() + dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY)) try: strategy = self._strategy_resolver.resolve( @@ -97,6 +100,7 @@ class AgentNode(Node[AgentNodeData]): node_data=self.node_data, strategy=strategy, tenant_id=dify_ctx.tenant_id, + user_id=dify_ctx.user_id, app_id=dify_ctx.app_id, invoke_from=dify_ctx.invoke_from, ) @@ -106,20 +110,21 @@ class AgentNode(Node[AgentNodeData]): node_data=self.node_data, strategy=strategy, tenant_id=dify_ctx.tenant_id, + user_id=dify_ctx.user_id, app_id=dify_ctx.app_id, invoke_from=dify_ctx.invoke_from, for_log=True, ) credentials = self._runtime_support.build_credentials(parameters=parameters) - conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) + conversation_id = get_system_text(self.graph_runtime_state.variable_pool, SystemVariableKey.CONVERSATION_ID) try: message_stream = strategy.invoke( params=parameters, user_id=dify_ctx.user_id, app_id=dify_ctx.app_id, - conversation_id=conversation_id.text if conversation_id else None, + conversation_id=conversation_id, credentials=credentials, ) except Exception as e: @@ -146,6 +151,7 @@ class AgentNode(Node[AgentNodeData]): parameters_for_log=parameters_for_log, user_id=dify_ctx.user_id, tenant_id=dify_ctx.tenant_id, + conversation_id=conversation_id, node_type=self.node_type, node_id=self._node_id, node_execution_id=self.id, diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py index 91fed39795b..c52aad150bb 100644 --- a/api/core/workflow/nodes/agent/entities.py +++ b/api/core/workflow/nodes/agent/entities.py @@ -1,12 +1,12 @@ from enum import IntEnum, StrEnum, auto from typing import Any, Literal, Union +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType from pydantic import BaseModel from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.tools.entities.tool_entities import ToolSelector -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType class AgentNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/agent/message_transformer.py b/api/core/workflow/nodes/agent/message_transformer.py index f58a5665f4f..db74590ed76 100644 --- a/api/core/workflow/nodes/agent/message_transformer.py +++ b/api/core/workflow/nodes/agent/message_transformer.py @@ -3,23 +3,24 @@ from __future__ import annotations from collections.abc import Generator, Mapping from typing import Any, cast -from sqlalchemy import select -from sqlalchemy.orm import Session - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.utils.message_transformer import ToolFileMessageTransformer -from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from dify_graph.file import File, FileTransferMethod -from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.node_events import ( +from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type +from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.node_events import ( AgentLogEvent, NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent, ) -from dify_graph.variables.segments import ArrayFileSegment +from graphon.variables.segments import ArrayFileSegment +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.app.file_access import DatabaseFileAccessController +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.utils.message_transformer import ToolFileMessageTransformer from extensions.ext_database import db from factories import file_factory from models import ToolFile @@ -27,6 +28,8 @@ from services.tools.builtin_tools_manage_service import BuiltinToolManageService from .exceptions import AgentNodeError, AgentVariableTypeError, ToolFileNotFoundError +_file_access_controller = DatabaseFileAccessController() + class AgentMessageTransformer: def transform( @@ -37,6 +40,7 @@ class AgentMessageTransformer: parameters_for_log: dict[str, Any], user_id: str, tenant_id: str, + conversation_id: str | None, node_type: NodeType, node_id: str, node_execution_id: str, @@ -47,7 +51,7 @@ class AgentMessageTransformer: messages=messages, user_id=user_id, tenant_id=tenant_id, - conversation_id=None, + conversation_id=conversation_id, ) text = "" @@ -70,10 +74,12 @@ class AgentMessageTransformer: url = message.message.text if message.meta: transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) + tool_file_id = message.meta.get("tool_file_id") else: transfer_method = FileTransferMethod.TOOL_FILE - - tool_file_id = str(url).split("/")[-1].split(".")[0] + tool_file_id = None + if not isinstance(tool_file_id, str) or not tool_file_id: + raise ToolFileNotFoundError("missing tool_file_id metadata") with Session(db.engine) as session: stmt = select(ToolFile).where(ToolFile.id == tool_file_id) @@ -83,20 +89,23 @@ class AgentMessageTransformer: mapping = { "tool_file_id": tool_file_id, - "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype), + "type": get_file_type_by_mime_type(tool_file.mimetype), "transfer_method": transfer_method, "url": url, } file = file_factory.build_from_mapping( mapping=mapping, tenant_id=tenant_id, + access_controller=_file_access_controller, ) files.append(file) elif message.type == ToolInvokeMessage.MessageType.BLOB: assert isinstance(message.message, ToolInvokeMessage.TextMessage) assert message.meta - tool_file_id = message.message.text.split("/")[-1].split(".")[0] + tool_file_id = message.meta.get("tool_file_id") + if not isinstance(tool_file_id, str) or not tool_file_id: + raise ToolFileNotFoundError("missing tool_file_id metadata") with Session(db.engine) as session: stmt = select(ToolFile).where(ToolFile.id == tool_file_id) tool_file = session.scalar(stmt) @@ -111,6 +120,7 @@ class AgentMessageTransformer: file_factory.build_from_mapping( mapping=mapping, tenant_id=tenant_id, + access_controller=_file_access_controller, ) ) elif message.type == ToolInvokeMessage.MessageType.TEXT: diff --git a/api/core/workflow/nodes/agent/runtime_support.py b/api/core/workflow/nodes/agent/runtime_support.py index 2ff7c964b9d..be50edbc4d4 100644 --- a/api/core/workflow/nodes/agent/runtime_support.py +++ b/api/core/workflow/nodes/agent/runtime_support.py @@ -4,6 +4,8 @@ import json from collections.abc import Sequence from typing import Any, cast +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType +from graphon.runtime import VariablePool from packaging.version import Version from pydantic import ValidationError from sqlalchemy import select @@ -12,15 +14,12 @@ from sqlalchemy.orm import Session from core.agent.entities import AgentToolEntity from core.agent.plugin_entities import AgentStrategyParameter from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_manager import ModelInstance, ModelManager +from core.model_manager import ModelInstance from core.plugin.entities.request import InvokeCredentials -from core.provider_manager import ProviderManager +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolProviderType from core.tools.tool_manager import ToolManager -from dify_graph.enums import SystemVariableKey -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType -from dify_graph.runtime import VariablePool -from dify_graph.variables.segments import StringSegment +from core.workflow.system_variables import SystemVariableKey, get_system_text from extensions.ext_database import db from models.model import Conversation @@ -38,6 +37,7 @@ class AgentRuntimeSupport: node_data: AgentNodeData, strategy: ResolvedAgentStrategy, tenant_id: str, + user_id: str, app_id: str, invoke_from: Any, for_log: bool = False, @@ -141,6 +141,7 @@ class AgentRuntimeSupport: tenant_id, app_id, entity, + user_id, invoke_from, runtime_variable_pool, ) @@ -174,7 +175,11 @@ class AgentRuntimeSupport: value = tool_value if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR: value = cast(dict[str, Any], value) - model_instance, model_schema = self.fetch_model(tenant_id=tenant_id, value=value) + model_instance, model_schema = self.fetch_model( + tenant_id=tenant_id, + user_id=user_id, + value=value, + ) history_prompt_messages = [] if node_data.memory: memory = self.fetch_memory( @@ -219,10 +224,9 @@ class AgentRuntimeSupport: app_id: str, model_instance: ModelInstance, ) -> TokenBufferMemory | None: - conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) - if not isinstance(conversation_id_variable, StringSegment): + conversation_id = get_system_text(variable_pool, SystemVariableKey.CONVERSATION_ID) + if conversation_id is None: return None - conversation_id = conversation_id_variable.value with Session(db.engine, expire_on_commit=False) as session: stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id) @@ -232,9 +236,15 @@ class AgentRuntimeSupport: return TokenBufferMemory(conversation=conversation, model_instance=model_instance) - def fetch_model(self, *, tenant_id: str, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]: - provider_manager = ProviderManager() - provider_model_bundle = provider_manager.get_provider_model_bundle( + def fetch_model( + self, + *, + tenant_id: str, + user_id: str, + value: dict[str, Any], + ) -> tuple[ModelInstance, AIModelEntity | None]: + assembly = create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id) + provider_model_bundle = assembly.provider_manager.get_provider_model_bundle( tenant_id=tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM, @@ -246,7 +256,7 @@ class AgentRuntimeSupport: ) provider_name = provider_model_bundle.configuration.provider.provider model_type_instance = provider_model_bundle.model_type_instance - model_instance = ModelManager().get_model_instance( + model_instance = assembly.model_manager.get_model_instance( tenant_id=tenant_id, provider=provider_name, model_type=ModelType(value.get("model_type", "")), diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 44f4a23a5a3..d9247b25932 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -1,22 +1,30 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import ( + BuiltinNodeTypes, + NodeExecutionType, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.node_events import NodeRunResult, StreamCompletedEvent +from graphon.nodes.base.node import Node +from graphon.nodes.base.variable_template_parser import VariableTemplateParser + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceProviderType from core.plugin.impl.exc import PluginDaemonClientSideError -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, SystemVariableKey, WorkflowNodeExecutionMetadataKey -from dify_graph.node_events import NodeRunResult, StreamCompletedEvent -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser +from core.workflow.file_reference import resolve_file_record_id +from core.workflow.system_variables import SystemVariableKey, get_system_segment from .entities import DatasourceNodeData, DatasourceParameter, OnlineDriveDownloadFileParam from .exc import DatasourceNodeError if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState class DatasourceNode(Node[DatasourceNodeData]): @@ -50,15 +58,14 @@ class DatasourceNode(Node[DatasourceNodeData]): """ Run the datasource node """ - - dify_ctx = self.require_dify_context() + dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY)) node_data = self.node_data variable_pool = self.graph_runtime_state.variable_pool - datasource_type_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE]) + datasource_type_segment = get_system_segment(variable_pool, SystemVariableKey.DATASOURCE_TYPE) if not datasource_type_segment: raise DatasourceNodeError("Datasource type is not set") datasource_type = str(datasource_type_segment.value) if datasource_type_segment.value else None - datasource_info_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO]) + datasource_info_segment = get_system_segment(variable_pool, SystemVariableKey.DATASOURCE_INFO) if not datasource_info_segment: raise DatasourceNodeError("Datasource info is not set") datasource_info_value = datasource_info_segment.value @@ -131,12 +138,14 @@ class DatasourceNode(Node[DatasourceNodeData]): ) ) case DatasourceProviderType.LOCAL_FILE: - related_id = datasource_info.get("related_id") - if not related_id: + file_id = resolve_file_record_id( + datasource_info.get("reference") or datasource_info.get("related_id") + ) + if not file_id: raise DatasourceNodeError("File is not exist") file_info = self.datasource_manager.get_upload_file_by_id( - file_id=related_id, tenant_id=dify_ctx.tenant_id + file_id=file_id, tenant_id=dify_ctx.tenant_id ) variable_pool.add([self._node_id, "file"], file_info) # variable_pool.add([self.node_id, "file"], file_info.to_dict()) diff --git a/api/core/workflow/nodes/datasource/entities.py b/api/core/workflow/nodes/datasource/entities.py index 65864474b08..cad32f8d5bd 100644 --- a/api/core/workflow/nodes/datasource/entities.py +++ b/api/core/workflow/nodes/datasource/entities.py @@ -1,11 +1,10 @@ from typing import Any, Literal, Union +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType from pydantic import BaseModel, field_validator from pydantic_core.core_schema import ValidationInfo -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType - class DatasourceEntity(BaseModel): plugin_id: str diff --git a/api/core/workflow/nodes/datasource/protocols.py b/api/core/workflow/nodes/datasource/protocols.py index c006e0885c1..776e2673172 100644 --- a/api/core/workflow/nodes/datasource/protocols.py +++ b/api/core/workflow/nodes/datasource/protocols.py @@ -1,8 +1,8 @@ from collections.abc import Generator from typing import Any, Protocol -from dify_graph.file import File -from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent +from graphon.file import File +from graphon.node_events import StreamChunkEvent, StreamCompletedEvent from .entities import DatasourceParameter, OnlineDriveDownloadFileParam diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py index 8d2e9bf3cb6..cba6c12dca0 100644 --- a/api/core/workflow/nodes/knowledge_index/entities.py +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -1,12 +1,12 @@ from typing import Literal, Union +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType from pydantic import BaseModel from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType class RerankingModelConfig(BaseModel): diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 4ea9091c5bb..bb72fe38816 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -2,16 +2,17 @@ import logging from collections.abc import Mapping from typing import TYPE_CHECKING, Any +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.base.template import Template + from core.rag.index_processor.index_processor import IndexProcessor from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.summary_index.summary_index import SummaryIndex from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import NodeExecutionType, SystemVariableKey -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.base.template import Template +from core.workflow.system_variables import SystemVariableKey, get_system_segment, get_system_text from .entities import KnowledgeIndexNodeData from .exc import ( @@ -19,8 +20,8 @@ from .exc import ( ) if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState logger = logging.getLogger(__name__) _INVOKE_FROM_DEBUGGER = "debugger" @@ -46,21 +47,20 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): variable_pool = self.graph_runtime_state.variable_pool # get dataset id as string - dataset_id_segment = variable_pool.get(["sys", SystemVariableKey.DATASET_ID]) + dataset_id_segment = get_system_segment(variable_pool, SystemVariableKey.DATASET_ID) if not dataset_id_segment: raise KnowledgeIndexNodeError("Dataset ID is required.") dataset_id: str = dataset_id_segment.value # get document id as string (may be empty when not provided) - document_id_segment = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) + document_id_segment = get_system_segment(variable_pool, SystemVariableKey.DOCUMENT_ID) document_id: str = document_id_segment.value if document_id_segment else "" # extract variables variable = variable_pool.get(node_data.index_chunk_variable_selector) if not variable: raise KnowledgeIndexNodeError("Index chunk variable is required.") - invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) - invoke_from_value = str(invoke_from.value) if invoke_from else None + invoke_from_value = get_system_text(variable_pool, SystemVariableKey.INVOKE_FROM) is_preview = invoke_from_value == _INVOKE_FROM_DEBUGGER chunks = variable.value @@ -87,8 +87,8 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): outputs=outputs.model_dump(exclude_none=True), ) - original_document_id_segment = variable_pool.get(["sys", SystemVariableKey.ORIGINAL_DOCUMENT_ID]) - batch = variable_pool.get(["sys", SystemVariableKey.BATCH]) + original_document_id_segment = get_system_segment(variable_pool, SystemVariableKey.ORIGINAL_DOCUMENT_ID) + batch = get_system_segment(variable_pool, SystemVariableKey.BATCH) if not batch: raise KnowledgeIndexNodeError("Batch is required.") diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index bc5618685a3..b1fa8593efe 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -1,12 +1,11 @@ from collections.abc import Sequence from typing import Literal +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.llm.entities import ModelConfig, VisionConfig from pydantic import BaseModel, Field -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig - class RerankingModelConfig(BaseModel): """ diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 80f59140bea..13624b27b37 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -8,27 +8,30 @@ import logging from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal -from core.app.app_config.entities import DatasetRetrieveConfigEntity -from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict -from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import ( +from graphon.entities import GraphInitParams +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import ( BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base import LLMUsageTrackingMixin -from dify_graph.nodes.base.node import Node -from dify_graph.variables import ( +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.node_events import NodeRunResult +from graphon.nodes.base import LLMUsageTrackingMixin +from graphon.nodes.base.node import Node +from graphon.variables import ( ArrayFileSegment, FileSegment, StringSegment, ) -from dify_graph.variables.segments import ArrayObjectSegment +from graphon.variables.segments import ArrayObjectSegment + +from core.app.app_config.entities import DatasetRetrieveConfigEntity +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval +from core.workflow.file_reference import parse_file_reference from .entities import ( Condition, @@ -42,8 +45,8 @@ from .exc import ( from .retrieval import KnowledgeRetrievalRequest, Source if TYPE_CHECKING: - from dify_graph.file.models import File - from dify_graph.runtime import GraphRuntimeState + from graphon.file import File + from graphon.runtime import GraphRuntimeState logger = logging.getLogger(__name__) @@ -160,7 +163,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD def _fetch_dataset_retriever( self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any] ) -> tuple[list[Source], LLMUsage]: - dify_ctx = self.require_dify_context() + dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY)) dataset_ids = node_data.dataset_ids query = variables.get("query") attachments = variables.get("attachments") @@ -254,7 +257,13 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD metadata_model_config=node_data.metadata_model_config, metadata_filtering_conditions=resolved_metadata_conditions, metadata_filtering_mode=metadata_filtering_mode, - attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None, + attachment_ids=[ + parsed_reference.record_id + for attachment in attachments + if (parsed_reference := parse_file_reference(attachment.reference)) is not None + ] + if attachments + else None, ) ) diff --git a/api/core/workflow/nodes/knowledge_retrieval/retrieval.py b/api/core/workflow/nodes/knowledge_retrieval/retrieval.py index e1311ab9621..39e2008a2ca 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/retrieval.py +++ b/api/core/workflow/nodes/knowledge_retrieval/retrieval.py @@ -1,10 +1,10 @@ from typing import Any, Literal, Protocol +from graphon.model_runtime.entities import LLMUsage +from graphon.nodes.llm.entities import ModelConfig from pydantic import BaseModel, Field from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict -from dify_graph.model_runtime.entities import LLMUsage -from dify_graph.nodes.llm.entities import ModelConfig from .entities import MetadataFilteringCondition @@ -54,7 +54,7 @@ class KnowledgeRetrievalRequest(BaseModel): tenant_id: str = Field(description="Tenant unique identifier") user_id: str = Field(description="User unique identifier") app_id: str = Field(description="Application unique identifier") - user_from: str = Field(description="Source of the user request (e.g., 'workflow', 'api')") + user_from: str = Field(description="User identity source for audit logging (e.g., 'account', 'end-user')") dataset_ids: list[str] = Field(description="List of dataset IDs to retrieve from") query: str | None = Field(default=None, description="Query text for knowledge retrieval") retrieval_mode: str = Field(description="Retrieval strategy: 'single' or 'multiple'") diff --git a/api/core/workflow/nodes/trigger_plugin/entities.py b/api/core/workflow/nodes/trigger_plugin/entities.py index ea7d20befe8..bf5be2379af 100644 --- a/api/core/workflow/nodes/trigger_plugin/entities.py +++ b/api/core/workflow/nodes/trigger_plugin/entities.py @@ -1,12 +1,12 @@ from collections.abc import Mapping from typing import Any, Literal, Union +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType from pydantic import BaseModel, Field, ValidationInfo, field_validator from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.trigger.entities.entities import EventParameter -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType from .exc import TriggerEventParameterError diff --git a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py index 118c2f2668b..e50de11bb90 100644 --- a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py +++ b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py @@ -1,12 +1,12 @@ from collections.abc import Mapping from typing import Any +from graphon.enums import NodeExecutionType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node + from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import NodeExecutionType, WorkflowNodeExecutionMetadataKey -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID from .entities import TriggerEventNodeData @@ -53,13 +53,11 @@ class TriggerEventNode(Node[TriggerEventNodeData]): "plugin_unique_identifier": self.node_data.plugin_unique_identifier, }, } - node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) - system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict() + node_inputs = dict(self.graph_runtime_state.variable_pool.get_by_prefix(self.id)) + system_inputs = self.graph_runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID) - # TODO: System variables should be directly accessible, no need for special handling - # Set system variables as node outputs. - for var in system_inputs: - node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var] + for variable_name, value in system_inputs.items(): + node_inputs[f"{SYSTEM_VARIABLE_NODE_ID}.{variable_name}"] = value outputs = dict(node_inputs) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, diff --git a/api/core/workflow/nodes/trigger_schedule/entities.py b/api/core/workflow/nodes/trigger_schedule/entities.py index 95a25486786..f14ca893c9e 100644 --- a/api/core/workflow/nodes/trigger_schedule/entities.py +++ b/api/core/workflow/nodes/trigger_schedule/entities.py @@ -1,10 +1,10 @@ from typing import Literal, Union +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType from pydantic import BaseModel, Field from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType class TriggerScheduleNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/trigger_schedule/exc.py b/api/core/workflow/nodes/trigger_schedule/exc.py index 336d64d58f3..10962c3de44 100644 --- a/api/core/workflow/nodes/trigger_schedule/exc.py +++ b/api/core/workflow/nodes/trigger_schedule/exc.py @@ -1,4 +1,4 @@ -from dify_graph.entities.exc import BaseNodeError +from graphon.entities.exc import BaseNodeError class ScheduleNodeError(BaseNodeError): diff --git a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py index b9580e6ab15..a9753ab387d 100644 --- a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py +++ b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py @@ -1,11 +1,11 @@ from collections.abc import Mapping +from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node + from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import NodeExecutionType -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID from .entities import TriggerScheduleNodeData @@ -31,13 +31,11 @@ class TriggerScheduleNode(Node[TriggerScheduleNodeData]): } def _run(self) -> NodeRunResult: - node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) - system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict() + node_inputs = dict(self.graph_runtime_state.variable_pool.get_by_prefix(self.id)) + system_inputs = self.graph_runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID) - # TODO: System variables should be directly accessible, no need for special handling - # Set system variables as node outputs. - for var in system_inputs: - node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var] + for variable_name, value in system_inputs.items(): + node_inputs[f"{SYSTEM_VARIABLE_NODE_ID}.{variable_name}"] = value outputs = dict(node_inputs) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, diff --git a/api/core/workflow/nodes/trigger_webhook/entities.py b/api/core/workflow/nodes/trigger_webhook/entities.py index 242bf5ef6a3..4d5ad72154b 100644 --- a/api/core/workflow/nodes/trigger_webhook/entities.py +++ b/api/core/workflow/nodes/trigger_webhook/entities.py @@ -1,12 +1,12 @@ from collections.abc import Sequence from enum import StrEnum +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType +from graphon.variables.types import SegmentType from pydantic import BaseModel, Field, field_validator from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType -from dify_graph.variables.types import SegmentType _WEBHOOK_HEADER_ALLOWED_TYPES = frozenset( { diff --git a/api/core/workflow/nodes/trigger_webhook/exc.py b/api/core/workflow/nodes/trigger_webhook/exc.py index 4d87f2a069b..00b0b3baad1 100644 --- a/api/core/workflow/nodes/trigger_webhook/exc.py +++ b/api/core/workflow/nodes/trigger_webhook/exc.py @@ -1,4 +1,4 @@ -from dify_graph.entities.exc import BaseNodeError +from graphon.entities.exc import BaseNodeError class WebhookNodeError(BaseNodeError): diff --git a/api/core/workflow/nodes/trigger_webhook/node.py b/api/core/workflow/nodes/trigger_webhook/node.py index 317844cbdaa..ebaac939345 100644 --- a/api/core/workflow/nodes/trigger_webhook/node.py +++ b/api/core/workflow/nodes/trigger_webhook/node.py @@ -2,16 +2,17 @@ import logging from collections.abc import Mapping from typing import Any +from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus +from graphon.file import FileTransferMethod +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.protocols import FileReferenceFactoryProtocol +from graphon.variables.types import SegmentType +from graphon.variables.variables import FileVariable + from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import NodeExecutionType -from dify_graph.file import FileTransferMethod -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.variables.types import SegmentType -from dify_graph.variables.variables import FileVariable -from factories import file_factory +from core.workflow.file_reference import resolve_file_record_id +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID from factories.variable_factory import build_segment_with_type from .entities import ContentType, WebhookData @@ -23,6 +24,13 @@ class TriggerWebhookNode(Node[WebhookData]): node_type = TRIGGER_WEBHOOK_NODE_TYPE execution_type = NodeExecutionType.ROOT + _file_reference_factory: FileReferenceFactoryProtocol + + def post_init(self) -> None: + from core.workflow.node_runtime import DifyFileReferenceFactory + + self._file_reference_factory = DifyFileReferenceFactory(self.graph_init_params.run_context) + @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: return { @@ -53,16 +61,14 @@ class TriggerWebhookNode(Node[WebhookData]): happens in the trigger controller. """ # Get webhook data from variable pool (injected by Celery task) - webhook_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) + webhook_inputs = dict(self.graph_runtime_state.variable_pool.get_by_prefix(self.id)) # Extract webhook-specific outputs based on node configuration outputs = self._extract_configured_outputs(webhook_inputs) - system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict() + system_inputs = self.graph_runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID) - # TODO: System variables should be directly accessible, no need for special handling - # Set system variables as node outputs. - for var in system_inputs: - outputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var] + for variable_name, value in system_inputs.items(): + outputs[f"{SYSTEM_VARIABLE_NODE_ID}.{variable_name}"] = value return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=webhook_inputs, @@ -70,24 +76,20 @@ class TriggerWebhookNode(Node[WebhookData]): ) def generate_file_var(self, param_name: str, file: dict): - dify_ctx = self.require_dify_context() - related_id = file.get("related_id") + file_id = resolve_file_record_id(file.get("reference") or file.get("related_id")) transfer_method_value = file.get("transfer_method") if transfer_method_value: transfer_method = FileTransferMethod.value_of(transfer_method_value) match transfer_method: case FileTransferMethod.LOCAL_FILE | FileTransferMethod.REMOTE_URL: - file["upload_file_id"] = related_id + file["upload_file_id"] = file_id case FileTransferMethod.TOOL_FILE: - file["tool_file_id"] = related_id + file["tool_file_id"] = file_id case FileTransferMethod.DATASOURCE_FILE: - file["datasource_file_id"] = related_id + file["datasource_file_id"] = file_id try: - file_obj = file_factory.build_from_mapping( - mapping=file, - tenant_id=dify_ctx.tenant_id, - ) + file_obj = self._file_reference_factory.build_from_mapping(mapping=file) file_segment = build_segment_with_type(SegmentType.FILE, file_obj) return FileVariable(name=param_name, value=file_segment.value, selector=[self.id, param_name]) except ValueError: diff --git a/api/core/workflow/system_variables.py b/api/core/workflow/system_variables.py new file mode 100644 index 00000000000..9d15a3fcea8 --- /dev/null +++ b/api/core/workflow/system_variables.py @@ -0,0 +1,236 @@ +from __future__ import annotations + +from collections import defaultdict +from collections.abc import Mapping, Sequence +from enum import StrEnum +from typing import Any, Protocol, cast +from uuid import uuid4 + +from graphon.enums import BuiltinNodeTypes +from graphon.variables import build_segment, segment_to_variable +from graphon.variables.segments import Segment +from graphon.variables.variables import RAGPipelineVariableInput, Variable + +from .variable_prefixes import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, + RAG_PIPELINE_VARIABLE_NODE_ID, + SYSTEM_VARIABLE_NODE_ID, +) + + +class SystemVariableKey(StrEnum): + QUERY = "query" + FILES = "files" + CONVERSATION_ID = "conversation_id" + USER_ID = "user_id" + DIALOGUE_COUNT = "dialogue_count" + APP_ID = "app_id" + WORKFLOW_ID = "workflow_id" + WORKFLOW_EXECUTION_ID = "workflow_run_id" + TIMESTAMP = "timestamp" + DOCUMENT_ID = "document_id" + ORIGINAL_DOCUMENT_ID = "original_document_id" + BATCH = "batch" + DATASET_ID = "dataset_id" + DATASOURCE_TYPE = "datasource_type" + DATASOURCE_INFO = "datasource_info" + INVOKE_FROM = "invoke_from" + + +class _VariablePoolReader(Protocol): + def get(self, selector: Sequence[str], /) -> Segment | None: ... + + def get_by_prefix(self, prefix: str, /) -> Mapping[str, object]: ... + + +class _VariablePoolWriter(_VariablePoolReader, Protocol): + def add(self, selector: Sequence[str], value: object, /) -> None: ... + + +class _VariableLoader(Protocol): + def load_variables(self, selectors: list[list[str]]) -> Sequence[object]: ... + + +def system_variable_name(key: str | SystemVariableKey) -> str: + return key.value if isinstance(key, SystemVariableKey) else key + + +def system_variable_selector(key: str | SystemVariableKey) -> tuple[str, str]: + return SYSTEM_VARIABLE_NODE_ID, system_variable_name(key) + + +def _normalize_system_variable_values(values: Mapping[str, Any] | None = None, /, **kwargs: Any) -> dict[str, Any]: + raw_values = dict(values or {}) + raw_values.update(kwargs) + + workflow_execution_id = raw_values.pop("workflow_execution_id", None) + if workflow_execution_id is not None and SystemVariableKey.WORKFLOW_EXECUTION_ID.value not in raw_values: + raw_values[SystemVariableKey.WORKFLOW_EXECUTION_ID.value] = workflow_execution_id + + normalized: dict[str, Any] = {} + for key, value in raw_values.items(): + if value is None: + continue + normalized[system_variable_name(key)] = value + + normalized.setdefault(SystemVariableKey.FILES.value, []) + return normalized + + +def build_system_variables(values: Mapping[str, Any] | None = None, /, **kwargs: Any) -> list[Variable]: + normalized = _normalize_system_variable_values(values, **kwargs) + + return [ + cast( + Variable, + segment_to_variable( + segment=build_segment(value), + selector=system_variable_selector(key), + name=key, + ), + ) + for key, value in normalized.items() + ] + + +def default_system_variables() -> list[Variable]: + return build_system_variables(workflow_run_id=str(uuid4())) + + +def system_variables_to_mapping(system_variables: Sequence[Variable]) -> dict[str, Any]: + return {variable.name: variable.value for variable in system_variables} + + +def _with_selector(variable: Variable, node_id: str) -> Variable: + selector = [node_id, variable.name] + if list(variable.selector) == selector: + return variable + return variable.model_copy(update={"selector": selector}) + + +def build_bootstrap_variables( + *, + system_variables: Sequence[Variable] = (), + environment_variables: Sequence[Variable] = (), + conversation_variables: Sequence[Variable] = (), + rag_pipeline_variables: Sequence[RAGPipelineVariableInput] = (), +) -> list[Variable]: + variables = [ + *(_with_selector(variable, SYSTEM_VARIABLE_NODE_ID) for variable in system_variables), + *(_with_selector(variable, ENVIRONMENT_VARIABLE_NODE_ID) for variable in environment_variables), + *(_with_selector(variable, CONVERSATION_VARIABLE_NODE_ID) for variable in conversation_variables), + ] + + rag_pipeline_variables_map: defaultdict[str, dict[str, Any]] = defaultdict(dict) + for rag_var in rag_pipeline_variables: + node_id = rag_var.variable.belong_to_node_id + key = rag_var.variable.variable + rag_pipeline_variables_map[node_id][key] = rag_var.value + + for node_id, value in rag_pipeline_variables_map.items(): + variables.append( + cast( + Variable, + segment_to_variable( + segment=build_segment(value), + selector=(RAG_PIPELINE_VARIABLE_NODE_ID, node_id), + name=node_id, + ), + ) + ) + + return variables + + +def get_system_segment(variable_pool: _VariablePoolReader, key: str | SystemVariableKey) -> Segment | None: + return variable_pool.get(system_variable_selector(key)) + + +def get_system_value(variable_pool: _VariablePoolReader, key: str | SystemVariableKey) -> Any: + segment = get_system_segment(variable_pool, key) + return None if segment is None else segment.value + + +def get_system_text(variable_pool: _VariablePoolReader, key: str | SystemVariableKey) -> str | None: + segment = get_system_segment(variable_pool, key) + if segment is None: + return None + text = getattr(segment, "text", None) + return text if isinstance(text, str) else None + + +def get_all_system_variables(variable_pool: _VariablePoolReader) -> Mapping[str, object]: + return variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID) + + +_MEMORY_BOOTSTRAP_NODE_TYPES = frozenset( + ( + BuiltinNodeTypes.LLM, + BuiltinNodeTypes.QUESTION_CLASSIFIER, + BuiltinNodeTypes.PARAMETER_EXTRACTOR, + ) +) + + +def get_node_creation_preload_selectors( + *, + node_type: str, + node_data: object, +) -> tuple[tuple[str, str], ...]: + """Return selectors that must exist before node construction begins.""" + + if node_type not in _MEMORY_BOOTSTRAP_NODE_TYPES or getattr(node_data, "memory", None) is None: + return () + + return (system_variable_selector(SystemVariableKey.CONVERSATION_ID),) + + +def preload_node_creation_variables( + *, + variable_loader: _VariableLoader, + variable_pool: _VariablePoolWriter, + selectors: Sequence[Sequence[str]], +) -> None: + """Load constructor-time variables before node or graph creation.""" + + seen_selectors: set[tuple[str, ...]] = set() + selectors_to_load: list[list[str]] = [] + for selector in selectors: + normalized_selector = tuple(selector) + if len(normalized_selector) < 2: + raise ValueError(f"Invalid preload selector: {selector}") + if normalized_selector in seen_selectors: + continue + seen_selectors.add(normalized_selector) + if variable_pool.get(normalized_selector) is None: + selectors_to_load.append(list(normalized_selector)) + + loaded_variables = variable_loader.load_variables(selectors_to_load) + for variable in loaded_variables: + raw_selector = getattr(variable, "selector", ()) + loaded_selector = list(raw_selector) + if len(loaded_selector) < 2: + raise ValueError(f"Invalid loaded variable selector: {raw_selector}") + variable_pool.add(loaded_selector[:2], variable) + + +def inject_default_system_variable_mappings( + *, + node_id: str, + node_type: str, + node_data: object, + variable_mapping: Mapping[str, Sequence[str]], +) -> Mapping[str, Sequence[str]]: + """Add workflow-owned implicit sys mappings that `graphon` should not know about.""" + + if node_type != BuiltinNodeTypes.LLM or getattr(node_data, "memory", None) is None: + return variable_mapping + + query_mapping_key = f"{node_id}.#sys.query#" + if query_mapping_key in variable_mapping: + return variable_mapping + + augmented_mapping = dict(variable_mapping) + augmented_mapping[query_mapping_key] = system_variable_selector(SystemVariableKey.QUERY) + return augmented_mapping diff --git a/api/core/workflow/template_rendering.py b/api/core/workflow/template_rendering.py new file mode 100644 index 00000000000..d51cfadd098 --- /dev/null +++ b/api/core/workflow/template_rendering.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +from graphon.nodes.code.entities import CodeLanguage +from graphon.template_rendering import Jinja2TemplateRenderer, TemplateRenderError + +from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor + + +class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer): + """Sandbox-backed Jinja2 renderer for workflow-owned node composition.""" + + def render_template(self, template: str, variables: Mapping[str, Any]) -> str: + try: + result = CodeExecutor.execute_workflow_code_template( + language=CodeLanguage.JINJA2, + code=template, + inputs=variables, + ) + except Exception as exc: + if isinstance(exc, CodeExecutionError): + raise TemplateRenderError(str(exc)) from exc + raise + + rendered = result.get("result") + if not isinstance(rendered, str): + raise TemplateRenderError("Template render result must be a string.") + return rendered diff --git a/api/core/workflow/variable_pool_initializer.py b/api/core/workflow/variable_pool_initializer.py new file mode 100644 index 00000000000..43523e01b28 --- /dev/null +++ b/api/core/workflow/variable_pool_initializer.py @@ -0,0 +1,15 @@ +from collections.abc import Mapping, Sequence +from typing import Any + +from graphon.runtime import VariablePool +from graphon.variables.variables import Variable + + +def add_variables_to_pool(variable_pool: VariablePool, variables: Sequence[Variable]) -> None: + for variable in variables: + variable_pool.add(variable.selector, variable) + + +def add_node_inputs_to_pool(variable_pool: VariablePool, *, node_id: str, inputs: Mapping[str, Any]) -> None: + for key, value in inputs.items(): + variable_pool.add((node_id, key), value) diff --git a/api/dify_graph/constants.py b/api/core/workflow/variable_prefixes.py similarity index 100% rename from api/dify_graph/constants.py rename to api/core/workflow/variable_prefixes.py diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 2e51a06babc..2346a95d6a8 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -1,36 +1,44 @@ import logging import time from collections.abc import Generator, Mapping, Sequence -from typing import Any, cast +from typing import Any + +from graphon.entities import GraphInitParams +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.errors import WorkflowNodeRunFailedError +from graphon.file import File +from graphon.graph import Graph +from graphon.graph_engine import GraphEngine, GraphEngineConfig +from graphon.graph_engine.command_channels import CommandChannel, InMemoryChannel +from graphon.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer +from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.base.node import Node +from graphon.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from configs import dify_config +from context import capture_current_context from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context +from core.app.file_access import DatabaseFileAccessController from core.app.workflow.layers.llm_quota import LLMQuotaLayer from core.app.workflow.layers.observability import ObservabilityLayer -from core.workflow.node_factory import DifyNodeFactory, resolve_workflow_node_class -from dify_graph.constants import ENVIRONMENT_VARIABLE_NODE_ID -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_config import NodeConfigDictAdapter -from dify_graph.errors import WorkflowNodeRunFailedError -from dify_graph.file.models import File -from dify_graph.graph import Graph -from dify_graph.graph_engine import GraphEngine, GraphEngineConfig -from dify_graph.graph_engine.command_channels import InMemoryChannel -from dify_graph.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_engine.protocols.command_channel import CommandChannel -from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent -from dify_graph.nodes import BuiltinNodeTypes -from dify_graph.nodes.base.node import Node -from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool +from core.workflow.node_factory import DifyNodeFactory, is_start_node_type, resolve_workflow_node_class +from core.workflow.system_variables import ( + default_system_variables, + get_node_creation_preload_selectors, + inject_default_system_variable_mappings, + preload_node_creation_variables, +) +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool +from core.workflow.variable_prefixes import ENVIRONMENT_VARIABLE_NODE_ID from extensions.otel.runtime import is_instrument_flag_enabled from factories import file_factory from models.workflow import Workflow logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() class _WorkflowChildEngineBuilder: @@ -59,16 +67,22 @@ class _WorkflowChildEngineBuilder: *, workflow_id: str, graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - graph_config: Mapping[str, Any], + parent_graph_runtime_state: GraphRuntimeState, root_node_id: str, - layers: Sequence[object] = (), + variable_pool: VariablePool | None = None, ) -> GraphEngine: + """Build a child engine with a fresh runtime state and only child-safe layers.""" + child_graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool if variable_pool is not None else parent_graph_runtime_state.variable_pool, + start_at=time.perf_counter(), + execution_context=parent_graph_runtime_state.execution_context, + ) node_factory = DifyNodeFactory( graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, + graph_runtime_state=child_graph_runtime_state, ) + graph_config = graph_init_params.graph_config has_root_node = self._has_node_id(graph_config=graph_config, node_id=root_node_id) if has_root_node is False: raise ChildGraphNotFoundError(f"child graph root node '{root_node_id}' not found") @@ -79,17 +93,17 @@ class _WorkflowChildEngineBuilder: root_node_id=root_node_id, ) + command_channel = InMemoryChannel() + config = GraphEngineConfig() child_engine = GraphEngine( workflow_id=workflow_id, graph=child_graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), + graph_runtime_state=child_graph_runtime_state, + command_channel=command_channel, + config=config, child_engine_builder=self, ) child_engine.layer(LLMQuotaLayer()) - for layer in layers: - child_engine.layer(cast(GraphEngineLayer, layer)) return child_engine @@ -136,6 +150,8 @@ class WorkflowEntry: command_channel = InMemoryChannel() self.command_channel = command_channel + execution_context = capture_current_context() + graph_runtime_state.execution_context = execution_context self._child_engine_builder = _WorkflowChildEngineBuilder() self.graph_engine = GraphEngine( workflow_id=workflow_id, @@ -212,6 +228,8 @@ class WorkflowEntry: # Get node type node_type = node_config_data.type + node_version = str(node_config_data.version) + node_cls = resolve_workflow_node_class(node_type=node_type, node_version=node_version) # init graph init params and runtime state graph_init_params = GraphInitParams( @@ -226,15 +244,23 @@ class WorkflowEntry: ), call_depth=0, ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - # init workflow run state - node_factory = DifyNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, + graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool, + start_at=time.perf_counter(), + execution_context=capture_current_context(), + ) + + if is_start_node_type(node_type): + add_node_inputs_to_pool(variable_pool, node_id=node_id, inputs=user_inputs) + + preload_node_creation_variables( + variable_loader=variable_loader, + variable_pool=variable_pool, + selectors=get_node_creation_preload_selectors( + node_type=node_type, + node_data=node_config_data, + ), ) - node = node_factory.create_node(node_config) - node_cls = type(node) try: # variable selector to variable mapping @@ -243,6 +269,12 @@ class WorkflowEntry: ) except NotImplementedError: variable_mapping = {} + variable_mapping = inject_default_system_variable_mappings( + node_id=node_id, + node_type=node_type, + node_data=node_config_data, + variable_mapping=variable_mapping, + ) # Loading missing variable from draft var here, and set it into # variable_pool. @@ -260,6 +292,13 @@ class WorkflowEntry: tenant_id=workflow.tenant_id, ) + # init workflow run state + node_factory = DifyNodeFactory( + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + node = node_factory.create_node(node_config) + try: generator = cls._traced_node_run(node) except Exception as e: @@ -347,11 +386,8 @@ class WorkflowEntry: raise ValueError(f"Node class not found for node type {node_type}") # init variable pool - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={}, - environment_variables=[], - ) + variable_pool = VariablePool() + add_variables_to_pool(variable_pool, default_system_variables()) # init graph init params and runtime state graph_init_params = GraphInitParams( @@ -366,7 +402,11 @@ class WorkflowEntry: ), call_depth=0, ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool, + start_at=time.perf_counter(), + execution_context=capture_current_context(), + ) # init workflow run state node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data}) @@ -384,6 +424,12 @@ class WorkflowEntry: ) except NotImplementedError: variable_mapping = {} + variable_mapping = inject_default_system_variable_mappings( + node_id=node_id, + node_type=node_type, + node_data=node_data, + variable_mapping=variable_mapping, + ) cls.mapping_user_inputs_to_variable_pool( variable_mapping=variable_mapping, @@ -477,13 +523,21 @@ class WorkflowEntry: continue if isinstance(input_value, dict) and "type" in input_value and "transfer_method" in input_value: - input_value = file_factory.build_from_mapping(mapping=input_value, tenant_id=tenant_id) + input_value = file_factory.build_from_mapping( + mapping=input_value, + tenant_id=tenant_id, + access_controller=_file_access_controller, + ) if ( isinstance(input_value, list) and all(isinstance(item, dict) for item in input_value) and all("type" in item and "transfer_method" in item for item in input_value) ): - input_value = file_factory.build_from_mappings(mappings=input_value, tenant_id=tenant_id) + input_value = file_factory.build_from_mappings( + mappings=input_value, + tenant_id=tenant_id, + access_controller=_file_access_controller, + ) # append variable and value to variable pool if variable_node_id != ENVIRONMENT_VARIABLE_NODE_ID: diff --git a/api/core/workflow/workflow_run_outputs.py b/api/core/workflow/workflow_run_outputs.py new file mode 100644 index 00000000000..bd89f7c441c --- /dev/null +++ b/api/core/workflow/workflow_run_outputs.py @@ -0,0 +1,18 @@ +from collections.abc import Mapping +from typing import Any + +from graphon.enums import BuiltinNodeTypes, NodeType + + +def project_node_outputs_for_workflow_run( + *, + node_type: NodeType, + inputs: Mapping[str, Any], + outputs: Mapping[str, Any], +) -> dict[str, Any]: + """Project internal node outputs onto the workflow-run public contract.""" + + if node_type == BuiltinNodeTypes.START: + return dict(inputs) + + return dict(outputs) diff --git a/api/dify_graph/README.md b/api/dify_graph/README.md deleted file mode 100644 index 2fc5b8b8904..00000000000 --- a/api/dify_graph/README.md +++ /dev/null @@ -1,135 +0,0 @@ -# Workflow - -## Project Overview - -This is the workflow graph engine module of Dify, implementing a queue-based distributed workflow execution system. The engine handles agentic AI workflows with support for parallel execution, node iteration, conditional logic, and external command control. - -## Architecture - -### Core Components - -The graph engine follows a layered architecture with strict dependency rules: - -1. **Graph Engine** (`graph_engine/`) - Orchestrates workflow execution - - - **Manager** - External control interface for stop/pause/resume commands - - **Worker** - Node execution runtime - - **Command Processing** - Handles control commands (abort, pause, resume) - - **Event Management** - Event propagation and layer notifications - - **Graph Traversal** - Edge processing and skip propagation - - **Response Coordinator** - Path tracking and session management - - **Layers** - Pluggable middleware (debug logging, execution limits) - - **Command Channels** - Communication channels (InMemory, Redis) - -1. **Graph** (`graph/`) - Graph structure and runtime state - - - **Graph Template** - Workflow definition - - **Edge** - Node connections with conditions - - **Runtime State Protocol** - State management interface - -1. **Nodes** (`nodes/`) - Node implementations - - - **Base** - Abstract node classes and variable parsing - - **Specific Nodes** - LLM, Agent, Code, HTTP Request, Iteration, Loop, etc. - -1. **Events** (`node_events/`) - Event system - - - **Base** - Event protocols - - **Node Events** - Node lifecycle events - -1. **Entities** (`entities/`) - Domain models - - - **Variable Pool** - Variable storage - - **Graph Init Params** - Initialization configuration - -## Key Design Patterns - -### Command Channel Pattern - -External workflow control via Redis or in-memory channels: - -```python -# Send stop command to running workflow -channel = RedisChannel(redis_client, f"workflow:{task_id}:commands") -channel.send_command(AbortCommand(reason="User requested")) -``` - -### Layer System - -Extensible middleware for cross-cutting concerns: - -```python -engine = GraphEngine(graph) -engine.layer(DebugLoggingLayer(level="INFO")) -engine.layer(ExecutionLimitsLayer(max_nodes=100)) -``` - -`engine.layer()` binds the read-only runtime state before execution, so layer hooks -can assume `graph_runtime_state` is available. - -### Event-Driven Architecture - -All node executions emit events for monitoring and integration: - -- `NodeRunStartedEvent` - Node execution begins -- `NodeRunSucceededEvent` - Node completes successfully -- `NodeRunFailedEvent` - Node encounters error -- `GraphRunStartedEvent/GraphRunCompletedEvent` - Workflow lifecycle - -### Variable Pool - -Centralized variable storage with namespace isolation: - -```python -# Variables scoped by node_id -pool.add(["node1", "output"], value) -result = pool.get(["node1", "output"]) -``` - -## Import Architecture Rules - -The codebase enforces strict layering via import-linter: - -1. **Workflow Layers** (top to bottom): - - - graph_engine → graph_events → graph → nodes → node_events → entities - -1. **Graph Engine Internal Layers**: - - - orchestration → command_processing → event_management → graph_traversal → domain - -1. **Domain Isolation**: - - - Domain models cannot import from infrastructure layers - -1. **Command Channel Independence**: - - - InMemory and Redis channels must remain independent - -## Common Tasks - -### Adding a New Node Type - -1. Create node class in `nodes//` -1. Inherit from `BaseNode` or appropriate base class -1. Implement `_run()` method -1. Ensure the node module is importable under `nodes//` -1. Add tests in `tests/unit_tests/dify_graph/nodes/` - -### Implementing a Custom Layer - -1. Create class inheriting from `Layer` base -1. Override lifecycle methods: `on_graph_start()`, `on_event()`, `on_graph_end()` -1. Add to engine via `engine.layer()` - -### Debugging Workflow Execution - -Enable debug logging layer: - -```python -debug_layer = DebugLoggingLayer( - level="DEBUG", - include_inputs=True, - include_outputs=True -) -``` diff --git a/api/dify_graph/context/__init__.py b/api/dify_graph/context/__init__.py deleted file mode 100644 index 103f526becb..00000000000 --- a/api/dify_graph/context/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -""" -Execution Context - Context management for workflow execution. - -This package provides Flask-independent context management for workflow -execution in multi-threaded environments. -""" - -from dify_graph.context.execution_context import ( - AppContext, - ContextProviderNotFoundError, - ExecutionContext, - IExecutionContext, - NullAppContext, - capture_current_context, - read_context, - register_context, - register_context_capturer, - reset_context_provider, -) -from dify_graph.context.models import SandboxContext - -__all__ = [ - "AppContext", - "ContextProviderNotFoundError", - "ExecutionContext", - "IExecutionContext", - "NullAppContext", - "SandboxContext", - "capture_current_context", - "read_context", - "register_context", - "register_context_capturer", - "reset_context_provider", -] diff --git a/api/dify_graph/conversation_variable_updater.py b/api/dify_graph/conversation_variable_updater.py deleted file mode 100644 index 17b19f25021..00000000000 --- a/api/dify_graph/conversation_variable_updater.py +++ /dev/null @@ -1,39 +0,0 @@ -import abc -from typing import Protocol - -from dify_graph.variables import VariableBase - - -class ConversationVariableUpdater(Protocol): - """ - ConversationVariableUpdater defines an abstraction for updating conversation variable values. - - It is intended for use by `v1.VariableAssignerNode` and `v2.VariableAssignerNode` when updating - conversation variables. - - Implementations may choose to batch updates. If batching is used, the `flush` method - should be implemented to persist buffered changes, and `update` - should handle buffering accordingly. - - Note: Since implementations may buffer updates, instances of ConversationVariableUpdater - are not thread-safe. Each VariableAssignerNode should create its own instance during execution. - """ - - @abc.abstractmethod - def update(self, conversation_id: str, variable: "VariableBase"): - """ - Updates the value of the specified conversation variable in the underlying storage. - - :param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`. - :param variable: The `VariableBase` instance containing the updated value. - """ - pass - - @abc.abstractmethod - def flush(self): - """ - Flushes all pending updates to the underlying storage system. - - If the implementation does not buffer updates, this method can be a no-op. - """ - pass diff --git a/api/dify_graph/entities/__init__.py b/api/dify_graph/entities/__init__.py deleted file mode 100644 index ef7789c49c8..00000000000 --- a/api/dify_graph/entities/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from .graph_init_params import GraphInitParams -from .workflow_execution import WorkflowExecution -from .workflow_node_execution import WorkflowNodeExecution -from .workflow_start_reason import WorkflowStartReason - -__all__ = [ - "GraphInitParams", - "WorkflowExecution", - "WorkflowNodeExecution", - "WorkflowStartReason", -] diff --git a/api/dify_graph/entities/base_node_data.py b/api/dify_graph/entities/base_node_data.py deleted file mode 100644 index 47b37c9dafd..00000000000 --- a/api/dify_graph/entities/base_node_data.py +++ /dev/null @@ -1,178 +0,0 @@ -from __future__ import annotations - -import json -from abc import ABC -from builtins import type as type_ -from enum import StrEnum -from typing import Any, Union - -from pydantic import BaseModel, ConfigDict, Field, model_validator - -from dify_graph.entities.exc import DefaultValueTypeError -from dify_graph.enums import ErrorStrategy, NodeType - -# Project supports Python 3.11+, where `typing.Union[...]` is valid in `isinstance`. -_NumberType = Union[int, float] - - -class RetryConfig(BaseModel): - """node retry config""" - - max_retries: int = 0 # max retry times - retry_interval: int = 0 # retry interval in milliseconds - retry_enabled: bool = False # whether retry is enabled - - @property - def retry_interval_seconds(self) -> float: - return self.retry_interval / 1000 - - -class DefaultValueType(StrEnum): - STRING = "string" - NUMBER = "number" - OBJECT = "object" - ARRAY_NUMBER = "array[number]" - ARRAY_STRING = "array[string]" - ARRAY_OBJECT = "array[object]" - ARRAY_FILES = "array[file]" - - -class DefaultValue(BaseModel): - value: Any = None - type: DefaultValueType - key: str - - @staticmethod - def _parse_json(value: str): - """Unified JSON parsing handler""" - try: - return json.loads(value) - except json.JSONDecodeError: - raise DefaultValueTypeError(f"Invalid JSON format for value: {value}") - - @staticmethod - def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool: - """Unified array type validation""" - return isinstance(value, list) and all(isinstance(x, element_type) for x in value) - - @staticmethod - def _convert_number(value: str) -> float: - """Unified number conversion handler""" - try: - return float(value) - except ValueError: - raise DefaultValueTypeError(f"Cannot convert to number: {value}") - - @model_validator(mode="after") - def validate_value_type(self) -> DefaultValue: - # Type validation configuration - type_validators: dict[DefaultValueType, dict[str, Any]] = { - DefaultValueType.STRING: { - "type": str, - "converter": lambda x: x, - }, - DefaultValueType.NUMBER: { - "type": _NumberType, - "converter": self._convert_number, - }, - DefaultValueType.OBJECT: { - "type": dict, - "converter": self._parse_json, - }, - DefaultValueType.ARRAY_NUMBER: { - "type": list, - "element_type": _NumberType, - "converter": self._parse_json, - }, - DefaultValueType.ARRAY_STRING: { - "type": list, - "element_type": str, - "converter": self._parse_json, - }, - DefaultValueType.ARRAY_OBJECT: { - "type": list, - "element_type": dict, - "converter": self._parse_json, - }, - } - - validator: dict[str, Any] = type_validators.get(self.type, {}) - if not validator: - if self.type == DefaultValueType.ARRAY_FILES: - # Handle files type - return self - raise DefaultValueTypeError(f"Unsupported type: {self.type}") - - # Handle string input cases - if isinstance(self.value, str) and self.type != DefaultValueType.STRING: - self.value = validator["converter"](self.value) - - # Validate base type - if not isinstance(self.value, validator["type"]): - raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}") - - # Validate array element types - if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]): - raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}") - - return self - - -class BaseNodeData(ABC, BaseModel): - # Raw graph payloads are first validated through `NodeConfigDictAdapter`, where - # `node["data"]` is typed as `BaseNodeData` before the concrete node class is known. - # `type` therefore accepts downstream string node kinds; unknown node implementations - # are rejected later when the node factory resolves the node registry. - # At that boundary, node-specific fields are still "extra" relative to this shared DTO, - # and persisted templates/workflows also carry undeclared compatibility keys such as - # `selected`, `params`, `paramSchemas`, and `datasource_label`. Keep extras permissive - # here until graph parsing becomes discriminated by node type or those legacy payloads - # are normalized. - model_config = ConfigDict(extra="allow") - - type: NodeType - title: str = "" - desc: str | None = None - version: str = "1" - error_strategy: ErrorStrategy | None = None - default_value: list[DefaultValue] | None = None - retry_config: RetryConfig = Field(default_factory=RetryConfig) - - @property - def default_value_dict(self) -> dict[str, Any]: - if self.default_value: - return {item.key: item.value for item in self.default_value} - return {} - - def __getitem__(self, key: str) -> Any: - """ - Dict-style access without calling model_dump() on every lookup. - Prefer using model fields and Pydantic's extra storage. - """ - # First, check declared model fields - if key in self.__class__.model_fields: - return getattr(self, key) - - # Then, check undeclared compatibility fields stored in Pydantic's extra dict. - extras = getattr(self, "__pydantic_extra__", None) - if extras is None: - extras = getattr(self, "model_extra", None) - if extras is not None and key in extras: - return extras[key] - - raise KeyError(key) - - def get(self, key: str, default: Any = None) -> Any: - """ - Dict-style .get() without calling model_dump() on every lookup. - """ - if key in self.__class__.model_fields: - return getattr(self, key) - - extras = getattr(self, "__pydantic_extra__", None) - if extras is None: - extras = getattr(self, "model_extra", None) - if extras is not None and key in extras: - return extras.get(key, default) - - return default diff --git a/api/dify_graph/entities/exc.py b/api/dify_graph/entities/exc.py deleted file mode 100644 index aeecf406403..00000000000 --- a/api/dify_graph/entities/exc.py +++ /dev/null @@ -1,10 +0,0 @@ -class BaseNodeError(ValueError): - """Base class for node errors.""" - - pass - - -class DefaultValueTypeError(BaseNodeError): - """Raised when the default value type is invalid.""" - - pass diff --git a/api/dify_graph/entities/graph_config.py b/api/dify_graph/entities/graph_config.py deleted file mode 100644 index 36f7b94e824..00000000000 --- a/api/dify_graph/entities/graph_config.py +++ /dev/null @@ -1,23 +0,0 @@ -from __future__ import annotations - -import sys - -from pydantic import TypeAdapter, with_config - -from dify_graph.entities.base_node_data import BaseNodeData - -if sys.version_info >= (3, 12): - from typing import TypedDict -else: - from typing_extensions import TypedDict - - -@with_config(extra="allow") -class NodeConfigDict(TypedDict): - id: str - # This is the permissive raw graph boundary. Node factories re-validate `data` - # with the concrete `NodeData` subtype after resolving the node implementation. - data: BaseNodeData - - -NodeConfigDictAdapter = TypeAdapter(NodeConfigDict) diff --git a/api/dify_graph/entities/graph_init_params.py b/api/dify_graph/entities/graph_init_params.py deleted file mode 100644 index f785d58a528..00000000000 --- a/api/dify_graph/entities/graph_init_params.py +++ /dev/null @@ -1,24 +0,0 @@ -from collections.abc import Mapping -from typing import Any - -from pydantic import BaseModel, Field - -DIFY_RUN_CONTEXT_KEY = "_dify" - - -class GraphInitParams(BaseModel): - """GraphInitParams encapsulates the configurations and contextual information - that remain constant throughout a single execution of the graph engine. - - A single execution is defined as follows: as long as the execution has not reached - its conclusion, it is considered one execution. For instance, if a workflow is suspended - and later resumed, it is still regarded as a single execution, not two. - - For the state diagram of workflow execution, refer to `WorkflowExecutionStatus`. - """ - - # init params - workflow_id: str = Field(..., description="workflow id") - graph_config: Mapping[str, Any] = Field(..., description="graph config") - run_context: Mapping[str, Any] = Field(..., description="runtime context") - call_depth: int = Field(..., description="call depth") diff --git a/api/dify_graph/entities/pause_reason.py b/api/dify_graph/entities/pause_reason.py deleted file mode 100644 index 86d8c8ca162..00000000000 --- a/api/dify_graph/entities/pause_reason.py +++ /dev/null @@ -1,50 +0,0 @@ -from collections.abc import Mapping -from enum import StrEnum, auto -from typing import Annotated, Any, Literal, TypeAlias - -from pydantic import BaseModel, Field - -from dify_graph.nodes.human_input.entities import FormInput, UserAction - - -class PauseReasonType(StrEnum): - HUMAN_INPUT_REQUIRED = auto() - SCHEDULED_PAUSE = auto() - - -class HumanInputRequired(BaseModel): - TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED - form_id: str - form_content: str - inputs: list[FormInput] = Field(default_factory=list) - actions: list[UserAction] = Field(default_factory=list) - display_in_ui: bool = False - node_id: str - node_title: str - - # The `resolved_default_values` stores the resolved values of variable defaults. It's a mapping from - # `output_variable_name` to their resolved values. - # - # For example, The form contains a input with output variable name `name` and placeholder type `VARIABLE`, its - # selector is ["start", "name"]. While the HumanInputNode is executed, the correspond value of variable - # `start.name` in variable pool is `John`. Thus, the resolved value of the output variable `name` is `John`. The - # `resolved_default_values` is `{"name": "John"}`. - # - # Only form inputs with default value type `VARIABLE` will be resolved and stored in `resolved_default_values`. - resolved_default_values: Mapping[str, Any] = Field(default_factory=dict) - - # The `form_token` is the token used to submit the form via UI surfaces. It corresponds to - # `HumanInputFormRecipient.access_token`. - # - # This field is `None` if webapp delivery is not set and not - # in orchestrating mode. - form_token: str | None = None - - -class SchedulingPause(BaseModel): - TYPE: Literal[PauseReasonType.SCHEDULED_PAUSE] = PauseReasonType.SCHEDULED_PAUSE - - message: str - - -PauseReason: TypeAlias = Annotated[HumanInputRequired | SchedulingPause, Field(discriminator="TYPE")] diff --git a/api/dify_graph/entities/workflow_execution.py b/api/dify_graph/entities/workflow_execution.py deleted file mode 100644 index 459ac46415d..00000000000 --- a/api/dify_graph/entities/workflow_execution.py +++ /dev/null @@ -1,74 +0,0 @@ -""" -Domain entities for workflow execution. - -Models are independent of the storage mechanism and don't contain -implementation details like tenant_id, app_id, etc. -""" - -from __future__ import annotations - -from collections.abc import Mapping -from datetime import datetime -from typing import Any - -from pydantic import BaseModel, Field - -from dify_graph.enums import WorkflowExecutionStatus, WorkflowType -from libs.datetime_utils import naive_utc_now - - -class WorkflowExecution(BaseModel): - """ - Domain model for workflow execution based on WorkflowRun but without - user, tenant, and app attributes. - """ - - id_: str = Field(...) - workflow_id: str = Field(...) - workflow_version: str = Field(...) - workflow_type: WorkflowType = Field(...) - graph: Mapping[str, Any] = Field(...) - - inputs: Mapping[str, Any] = Field(...) - outputs: Mapping[str, Any] | None = None - - status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING - error_message: str = Field(default="") - total_tokens: int = Field(default=0) - total_steps: int = Field(default=0) - exceptions_count: int = Field(default=0) - - started_at: datetime = Field(...) - finished_at: datetime | None = None - - @property - def elapsed_time(self) -> float: - """ - Calculate elapsed time in seconds. - If workflow is not finished, use current time. - """ - end_time = self.finished_at or naive_utc_now() - return (end_time - self.started_at).total_seconds() - - @classmethod - def new( - cls, - *, - id_: str, - workflow_id: str, - workflow_type: WorkflowType, - workflow_version: str, - graph: Mapping[str, Any], - inputs: Mapping[str, Any], - started_at: datetime, - ) -> WorkflowExecution: - return WorkflowExecution( - id_=id_, - workflow_id=workflow_id, - workflow_type=workflow_type, - workflow_version=workflow_version, - graph=graph, - inputs=inputs, - status=WorkflowExecutionStatus.RUNNING, - started_at=started_at, - ) diff --git a/api/dify_graph/entities/workflow_node_execution.py b/api/dify_graph/entities/workflow_node_execution.py deleted file mode 100644 index bc7e0d02e57..00000000000 --- a/api/dify_graph/entities/workflow_node_execution.py +++ /dev/null @@ -1,147 +0,0 @@ -""" -Domain entities for workflow node execution. - -This module contains the domain model for workflow node execution, which is used -by the core workflow module. These models are independent of the storage mechanism -and don't contain implementation details like tenant_id, app_id, etc. -""" - -from collections.abc import Mapping -from datetime import datetime -from typing import Any - -from pydantic import BaseModel, Field, PrivateAttr - -from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus - - -class WorkflowNodeExecution(BaseModel): - """ - Domain model for workflow node execution. - - This model represents the core business entity of a node execution, - without implementation details like tenant_id, app_id, etc. - - Note: User/context-specific fields (triggered_from, created_by, created_by_role) - have been moved to the repository implementation to keep the domain model clean. - These fields are still accepted in the constructor for backward compatibility, - but they are not stored in the model. - """ - - # --------- Core identification fields --------- - - # Unique identifier for this execution record, used when persisting to storage. - # Value is a UUID string (e.g., '09b3e04c-f9ae-404c-ad82-290b8d7bd382'). - id: str - - # Optional secondary ID for cross-referencing purposes. - # - # NOTE: For referencing the persisted record, use `id` rather than `node_execution_id`. - # While `node_execution_id` may sometimes be a UUID string, this is not guaranteed. - # In most scenarios, `id` should be used as the primary identifier. - node_execution_id: str | None = None - workflow_id: str # ID of the workflow this node belongs to - workflow_execution_id: str | None = None # ID of the specific workflow run (null for single-step debugging) - # --------- Core identification fields ends --------- - - # Execution positioning and flow - index: int # Sequence number for ordering in trace visualization - predecessor_node_id: str | None = None # ID of the node that executed before this one - node_id: str # ID of the node being executed - node_type: NodeType # Type of node (e.g., start, llm, downstream response node) - title: str # Display title of the node - - # Execution data - # The `inputs` and `outputs` fields hold the full content - inputs: Mapping[str, Any] | None = None # Input variables used by this node - process_data: Mapping[str, Any] | None = None # Intermediate processing data - outputs: Mapping[str, Any] | None = None # Output variables produced by this node - - # Execution state - status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING # Current execution status - error: str | None = None # Error message if execution failed - elapsed_time: float = Field(default=0.0) # Time taken for execution in seconds - - # Additional metadata - metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None # Execution metadata (tokens, cost, etc.) - - # Timing information - created_at: datetime # When execution started - finished_at: datetime | None = None # When execution completed - - _truncated_inputs: Mapping[str, Any] | None = PrivateAttr(None) - _truncated_outputs: Mapping[str, Any] | None = PrivateAttr(None) - _truncated_process_data: Mapping[str, Any] | None = PrivateAttr(None) - - def get_truncated_inputs(self) -> Mapping[str, Any] | None: - return self._truncated_inputs - - def get_truncated_outputs(self) -> Mapping[str, Any] | None: - return self._truncated_outputs - - def get_truncated_process_data(self) -> Mapping[str, Any] | None: - return self._truncated_process_data - - def set_truncated_inputs(self, truncated_inputs: Mapping[str, Any] | None): - self._truncated_inputs = truncated_inputs - - def set_truncated_outputs(self, truncated_outputs: Mapping[str, Any] | None): - self._truncated_outputs = truncated_outputs - - def set_truncated_process_data(self, truncated_process_data: Mapping[str, Any] | None): - self._truncated_process_data = truncated_process_data - - def get_response_inputs(self) -> Mapping[str, Any] | None: - inputs = self.get_truncated_inputs() - if inputs: - return inputs - return self.inputs - - @property - def inputs_truncated(self): - return self._truncated_inputs is not None - - @property - def outputs_truncated(self): - return self._truncated_outputs is not None - - @property - def process_data_truncated(self): - return self._truncated_process_data is not None - - def get_response_outputs(self) -> Mapping[str, Any] | None: - outputs = self.get_truncated_outputs() - if outputs is not None: - return outputs - return self.outputs - - def get_response_process_data(self) -> Mapping[str, Any] | None: - process_data = self.get_truncated_process_data() - if process_data is not None: - return process_data - return self.process_data - - def update_from_mapping( - self, - inputs: Mapping[str, Any] | None = None, - process_data: Mapping[str, Any] | None = None, - outputs: Mapping[str, Any] | None = None, - metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None, - ): - """ - Update the model from mappings. - - Args: - inputs: The inputs to update - process_data: The process data to update - outputs: The outputs to update - metadata: The metadata to update - """ - if inputs is not None: - self.inputs = dict(inputs) - if process_data is not None: - self.process_data = dict(process_data) - if outputs is not None: - self.outputs = dict(outputs) - if metadata is not None: - self.metadata = dict(metadata) diff --git a/api/dify_graph/entities/workflow_start_reason.py b/api/dify_graph/entities/workflow_start_reason.py deleted file mode 100644 index df0f75383b0..00000000000 --- a/api/dify_graph/entities/workflow_start_reason.py +++ /dev/null @@ -1,8 +0,0 @@ -from enum import StrEnum - - -class WorkflowStartReason(StrEnum): - """Reason for workflow start events across graph/queue/SSE layers.""" - - INITIAL = "initial" # First start of a workflow run. - RESUMPTION = "resumption" # Start triggered after resuming a paused run. diff --git a/api/dify_graph/enums.py b/api/dify_graph/enums.py deleted file mode 100644 index cfb135cbb0f..00000000000 --- a/api/dify_graph/enums.py +++ /dev/null @@ -1,286 +0,0 @@ -from enum import StrEnum -from typing import ClassVar, TypeAlias - - -class NodeState(StrEnum): - """State of a node or edge during workflow execution.""" - - UNKNOWN = "unknown" - TAKEN = "taken" - SKIPPED = "skipped" - - -class SystemVariableKey(StrEnum): - """ - System Variables. - """ - - QUERY = "query" - FILES = "files" - CONVERSATION_ID = "conversation_id" - USER_ID = "user_id" - DIALOGUE_COUNT = "dialogue_count" - APP_ID = "app_id" - WORKFLOW_ID = "workflow_id" - WORKFLOW_EXECUTION_ID = "workflow_run_id" - TIMESTAMP = "timestamp" - # RAG Pipeline - DOCUMENT_ID = "document_id" - ORIGINAL_DOCUMENT_ID = "original_document_id" - BATCH = "batch" - DATASET_ID = "dataset_id" - DATASOURCE_TYPE = "datasource_type" - DATASOURCE_INFO = "datasource_info" - INVOKE_FROM = "invoke_from" - - -NodeType: TypeAlias = str - - -class BuiltinNodeTypes: - """Built-in node type string constants. - - `node_type` values are plain strings throughout the graph runtime. This namespace - only exposes the built-in values shipped by `dify_graph`; downstream packages can - use additional strings without extending this class. - """ - - START: ClassVar[NodeType] = "start" - END: ClassVar[NodeType] = "end" - ANSWER: ClassVar[NodeType] = "answer" - LLM: ClassVar[NodeType] = "llm" - KNOWLEDGE_RETRIEVAL: ClassVar[NodeType] = "knowledge-retrieval" - IF_ELSE: ClassVar[NodeType] = "if-else" - CODE: ClassVar[NodeType] = "code" - TEMPLATE_TRANSFORM: ClassVar[NodeType] = "template-transform" - QUESTION_CLASSIFIER: ClassVar[NodeType] = "question-classifier" - HTTP_REQUEST: ClassVar[NodeType] = "http-request" - TOOL: ClassVar[NodeType] = "tool" - DATASOURCE: ClassVar[NodeType] = "datasource" - VARIABLE_AGGREGATOR: ClassVar[NodeType] = "variable-aggregator" - LEGACY_VARIABLE_AGGREGATOR: ClassVar[NodeType] = "variable-assigner" - LOOP: ClassVar[NodeType] = "loop" - LOOP_START: ClassVar[NodeType] = "loop-start" - LOOP_END: ClassVar[NodeType] = "loop-end" - ITERATION: ClassVar[NodeType] = "iteration" - ITERATION_START: ClassVar[NodeType] = "iteration-start" - PARAMETER_EXTRACTOR: ClassVar[NodeType] = "parameter-extractor" - VARIABLE_ASSIGNER: ClassVar[NodeType] = "assigner" - DOCUMENT_EXTRACTOR: ClassVar[NodeType] = "document-extractor" - LIST_OPERATOR: ClassVar[NodeType] = "list-operator" - AGENT: ClassVar[NodeType] = "agent" - HUMAN_INPUT: ClassVar[NodeType] = "human-input" - - -BUILT_IN_NODE_TYPES: tuple[NodeType, ...] = ( - BuiltinNodeTypes.START, - BuiltinNodeTypes.END, - BuiltinNodeTypes.ANSWER, - BuiltinNodeTypes.LLM, - BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL, - BuiltinNodeTypes.IF_ELSE, - BuiltinNodeTypes.CODE, - BuiltinNodeTypes.TEMPLATE_TRANSFORM, - BuiltinNodeTypes.QUESTION_CLASSIFIER, - BuiltinNodeTypes.HTTP_REQUEST, - BuiltinNodeTypes.TOOL, - BuiltinNodeTypes.DATASOURCE, - BuiltinNodeTypes.VARIABLE_AGGREGATOR, - BuiltinNodeTypes.LEGACY_VARIABLE_AGGREGATOR, - BuiltinNodeTypes.LOOP, - BuiltinNodeTypes.LOOP_START, - BuiltinNodeTypes.LOOP_END, - BuiltinNodeTypes.ITERATION, - BuiltinNodeTypes.ITERATION_START, - BuiltinNodeTypes.PARAMETER_EXTRACTOR, - BuiltinNodeTypes.VARIABLE_ASSIGNER, - BuiltinNodeTypes.DOCUMENT_EXTRACTOR, - BuiltinNodeTypes.LIST_OPERATOR, - BuiltinNodeTypes.AGENT, - BuiltinNodeTypes.HUMAN_INPUT, -) - - -class NodeExecutionType(StrEnum): - """Node execution type classification.""" - - EXECUTABLE = "executable" # Regular nodes that execute and produce outputs - RESPONSE = "response" # Response nodes that stream outputs (Answer, End) - BRANCH = "branch" # Nodes that can choose different branches (if-else, question-classifier) - CONTAINER = "container" # Container nodes that manage subgraphs (iteration, loop, graph) - ROOT = "root" # Nodes that can serve as execution entry points - - -class ErrorStrategy(StrEnum): - FAIL_BRANCH = "fail-branch" - DEFAULT_VALUE = "default-value" - - -class FailBranchSourceHandle(StrEnum): - FAILED = "fail-branch" - SUCCESS = "success-branch" - - -class WorkflowType(StrEnum): - """ - Workflow Type Enum for domain layer - """ - - WORKFLOW = "workflow" - CHAT = "chat" - RAG_PIPELINE = "rag-pipeline" - - -class WorkflowExecutionStatus(StrEnum): - # State diagram for the workflw status: - # (@) means start, (*) means end - # - # ┌------------------>------------------------->------------------->--------------┐ - # | | - # | ┌-----------------------<--------------------┐ | - # ^ | | | - # | | ^ | - # | V | | - # ┌-----------┐ ┌-----------------------┐ ┌-----------┐ V - # | Scheduled |------->| Running |---------------------->| paused | | - # └-----------┘ └-----------------------┘ └-----------┘ | - # | | | | | | | - # | | | | | | | - # ^ | | | V V | - # | | | | | ┌---------┐ | - # (@) | | | └------------------------>| Stopped |<----┘ - # | | | └---------┘ - # | | | | - # | | V V - # | | ┌-----------┐ | - # | | | Succeeded |------------->--------------┤ - # | | └-----------┘ | - # | V V - # | +--------┐ | - # | | Failed |---------------------->----------------┤ - # | └--------┘ | - # V V - # ┌---------------------┐ | - # | Partially Succeeded |---------------------->-----------------┘--------> (*) - # └---------------------┘ - # - # Mermaid diagram: - # - # --- - # title: State diagram for Workflow run state - # --- - # stateDiagram-v2 - # scheduled: Scheduled - # running: Running - # succeeded: Succeeded - # failed: Failed - # partial_succeeded: Partial Succeeded - # paused: Paused - # stopped: Stopped - # - # [*] --> scheduled: - # scheduled --> running: Start Execution - # running --> paused: Human input required - # paused --> running: human input added - # paused --> stopped: User stops execution - # running --> succeeded: Execution finishes without any error - # running --> failed: Execution finishes with errors - # running --> stopped: User stops execution - # running --> partial_succeeded: some execution occurred and handled during execution - # - # scheduled --> stopped: User stops execution - # - # succeeded --> [*] - # failed --> [*] - # partial_succeeded --> [*] - # stopped --> [*] - - # `SCHEDULED` means that the workflow is scheduled to run, but has not - # started running yet. (maybe due to possible worker saturation.) - # - # This enum value is currently unused. - SCHEDULED = "scheduled" - - # `RUNNING` means the workflow is exeuting. - RUNNING = "running" - - # `SUCCEEDED` means the execution of workflow succeed without any error. - SUCCEEDED = "succeeded" - - # `FAILED` means the execution of workflow failed without some errors. - FAILED = "failed" - - # `STOPPED` means the execution of workflow was stopped, either manually - # by the user, or automatically by the Dify application (E.G. the moderation - # mechanism.) - STOPPED = "stopped" - - # `PARTIAL_SUCCEEDED` indicates that some errors occurred during the workflow - # execution, but they were successfully handled (e.g., by using an error - # strategy such as "fail branch" or "default value"). - PARTIAL_SUCCEEDED = "partial-succeeded" - - # `PAUSED` indicates that the workflow execution is temporarily paused - # (e.g., awaiting human input) and is expected to resume later. - PAUSED = "paused" - - def is_ended(self) -> bool: - return self in _END_STATE - - @classmethod - def ended_values(cls) -> list[str]: - return [status.value for status in _END_STATE] - - -_END_STATE = frozenset( - [ - WorkflowExecutionStatus.SUCCEEDED, - WorkflowExecutionStatus.FAILED, - WorkflowExecutionStatus.PARTIAL_SUCCEEDED, - WorkflowExecutionStatus.STOPPED, - ] -) - - -class WorkflowNodeExecutionMetadataKey(StrEnum): - """ - Node Run Metadata Key. - - Values in this enum are persisted as execution metadata and must stay in sync - with every node that writes `NodeRunResult.metadata`. - """ - - TOTAL_TOKENS = "total_tokens" - TOTAL_PRICE = "total_price" - CURRENCY = "currency" - TOOL_INFO = "tool_info" - AGENT_LOG = "agent_log" - ITERATION_ID = "iteration_id" - ITERATION_INDEX = "iteration_index" - LOOP_ID = "loop_id" - LOOP_INDEX = "loop_index" - PARALLEL_ID = "parallel_id" - PARALLEL_START_NODE_ID = "parallel_start_node_id" - PARENT_PARALLEL_ID = "parent_parallel_id" - PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id" - PARALLEL_MODE_RUN_ID = "parallel_mode_run_id" - ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs - LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs - ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field - LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output - DATASOURCE_INFO = "datasource_info" - TRIGGER_INFO = "trigger_info" - COMPLETED_REASON = "completed_reason" # completed reason for loop node - - -class WorkflowNodeExecutionStatus(StrEnum): - PENDING = "pending" # Node is scheduled but not yet executing - RUNNING = "running" - SUCCEEDED = "succeeded" - FAILED = "failed" - EXCEPTION = "exception" - STOPPED = "stopped" - PAUSED = "paused" - - # Legacy statuses - kept for backward compatibility - RETRY = "retry" # Legacy: replaced by retry mechanism in error handling diff --git a/api/dify_graph/errors.py b/api/dify_graph/errors.py deleted file mode 100644 index 463d17713e4..00000000000 --- a/api/dify_graph/errors.py +++ /dev/null @@ -1,16 +0,0 @@ -from dify_graph.nodes.base.node import Node - - -class WorkflowNodeRunFailedError(Exception): - def __init__(self, node: Node, err_msg: str): - self._node = node - self._error = err_msg - super().__init__(f"Node {node.title} run failed: {err_msg}") - - @property - def node(self) -> Node: - return self._node - - @property - def error(self) -> str: - return self._error diff --git a/api/dify_graph/file/__init__.py b/api/dify_graph/file/__init__.py deleted file mode 100644 index 44749ebec35..00000000000 --- a/api/dify_graph/file/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -from .constants import FILE_MODEL_IDENTITY -from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType -from .models import ( - File, - FileUploadConfig, - ImageConfig, -) - -__all__ = [ - "FILE_MODEL_IDENTITY", - "ArrayFileAttribute", - "File", - "FileAttribute", - "FileBelongsTo", - "FileTransferMethod", - "FileType", - "FileUploadConfig", - "ImageConfig", -] diff --git a/api/dify_graph/file/constants.py b/api/dify_graph/file/constants.py deleted file mode 100644 index 0665ed7e0de..00000000000 --- a/api/dify_graph/file/constants.py +++ /dev/null @@ -1,11 +0,0 @@ -from typing import Any - -# TODO(QuantumGhost): Refactor variable type identification. Instead of directly -# comparing `dify_model_identity` with constants throughout the codebase, extract -# this logic into a dedicated function. This would encapsulate the implementation -# details of how different variable types are identified. -FILE_MODEL_IDENTITY = "__dify__file__" - - -def maybe_file_object(o: Any) -> bool: - return isinstance(o, dict) and o.get("dify_model_identity") == FILE_MODEL_IDENTITY diff --git a/api/dify_graph/file/enums.py b/api/dify_graph/file/enums.py deleted file mode 100644 index 170eb4fc233..00000000000 --- a/api/dify_graph/file/enums.py +++ /dev/null @@ -1,57 +0,0 @@ -from enum import StrEnum - - -class FileType(StrEnum): - IMAGE = "image" - DOCUMENT = "document" - AUDIO = "audio" - VIDEO = "video" - CUSTOM = "custom" - - @staticmethod - def value_of(value): - for member in FileType: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class FileTransferMethod(StrEnum): - REMOTE_URL = "remote_url" - LOCAL_FILE = "local_file" - TOOL_FILE = "tool_file" - DATASOURCE_FILE = "datasource_file" - - @staticmethod - def value_of(value): - for member in FileTransferMethod: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class FileBelongsTo(StrEnum): - USER = "user" - ASSISTANT = "assistant" - - @staticmethod - def value_of(value): - for member in FileBelongsTo: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class FileAttribute(StrEnum): - TYPE = "type" - SIZE = "size" - NAME = "name" - MIME_TYPE = "mime_type" - TRANSFER_METHOD = "transfer_method" - URL = "url" - EXTENSION = "extension" - RELATED_ID = "related_id" - - -class ArrayFileAttribute(StrEnum): - LENGTH = "length" diff --git a/api/dify_graph/file/file_manager.py b/api/dify_graph/file/file_manager.py deleted file mode 100644 index 8d998054db7..00000000000 --- a/api/dify_graph/file/file_manager.py +++ /dev/null @@ -1,143 +0,0 @@ -from __future__ import annotations - -import base64 -from collections.abc import Mapping - -from dify_graph.model_runtime.entities import ( - AudioPromptMessageContent, - DocumentPromptMessageContent, - ImagePromptMessageContent, - TextPromptMessageContent, - VideoPromptMessageContent, -) -from dify_graph.model_runtime.entities.message_entities import PromptMessageContentUnionTypes - -from . import helpers -from .enums import FileAttribute -from .models import File, FileTransferMethod, FileType -from .runtime import get_workflow_file_runtime - - -def get_attr(*, file: File, attr: FileAttribute): - match attr: - case FileAttribute.TYPE: - return file.type.value - case FileAttribute.SIZE: - return file.size - case FileAttribute.NAME: - return file.filename - case FileAttribute.MIME_TYPE: - return file.mime_type - case FileAttribute.TRANSFER_METHOD: - return file.transfer_method.value - case FileAttribute.URL: - return _to_url(file) - case FileAttribute.EXTENSION: - return file.extension - case FileAttribute.RELATED_ID: - return file.related_id - - -def to_prompt_message_content( - f: File, - /, - *, - image_detail_config: ImagePromptMessageContent.DETAIL | None = None, -) -> PromptMessageContentUnionTypes: - """Convert a file to prompt message content.""" - if f.extension is None: - raise ValueError("Missing file extension") - if f.mime_type is None: - raise ValueError("Missing file mime_type") - - prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = { - FileType.IMAGE: ImagePromptMessageContent, - FileType.AUDIO: AudioPromptMessageContent, - FileType.VIDEO: VideoPromptMessageContent, - FileType.DOCUMENT: DocumentPromptMessageContent, - } - - if f.type not in prompt_class_map: - return TextPromptMessageContent(data=f"[Unsupported file type: {f.filename} ({f.type.value})]") - - send_format = get_workflow_file_runtime().multimodal_send_format - params = { - "base64_data": _get_encoded_string(f) if send_format == "base64" else "", - "url": _to_url(f) if send_format == "url" else "", - "format": f.extension.removeprefix("."), - "mime_type": f.mime_type, - "filename": f.filename or "", - } - if f.type == FileType.IMAGE: - params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW - - return prompt_class_map[f.type].model_validate(params) - - -def download(f: File, /) -> bytes: - if f.transfer_method in ( - FileTransferMethod.TOOL_FILE, - FileTransferMethod.LOCAL_FILE, - FileTransferMethod.DATASOURCE_FILE, - ): - return _download_file_content(f.storage_key) - elif f.transfer_method == FileTransferMethod.REMOTE_URL: - if f.remote_url is None: - raise ValueError("Missing file remote_url") - response = get_workflow_file_runtime().http_get(f.remote_url, follow_redirects=True) - response.raise_for_status() - return response.content - raise ValueError(f"unsupported transfer method: {f.transfer_method}") - - -def _download_file_content(path: str, /) -> bytes: - """Download and return a file from storage as bytes.""" - data = get_workflow_file_runtime().storage_load(path, stream=False) - if not isinstance(data, bytes): - raise ValueError(f"file {path} is not a bytes object") - return data - - -def _get_encoded_string(f: File, /) -> str: - match f.transfer_method: - case FileTransferMethod.REMOTE_URL: - if f.remote_url is None: - raise ValueError("Missing file remote_url") - response = get_workflow_file_runtime().http_get(f.remote_url, follow_redirects=True) - response.raise_for_status() - data = response.content - case FileTransferMethod.LOCAL_FILE: - data = _download_file_content(f.storage_key) - case FileTransferMethod.TOOL_FILE: - data = _download_file_content(f.storage_key) - case FileTransferMethod.DATASOURCE_FILE: - data = _download_file_content(f.storage_key) - - return base64.b64encode(data).decode("utf-8") - - -def _to_url(f: File, /): - if f.transfer_method == FileTransferMethod.REMOTE_URL: - if f.remote_url is None: - raise ValueError("Missing file remote_url") - return f.remote_url - elif f.transfer_method == FileTransferMethod.LOCAL_FILE: - if f.related_id is None: - raise ValueError("Missing file related_id") - return f.remote_url or helpers.get_signed_file_url(upload_file_id=f.related_id) - elif f.transfer_method == FileTransferMethod.TOOL_FILE: - if f.related_id is None or f.extension is None: - raise ValueError("Missing file related_id or extension") - return helpers.get_signed_tool_file_url(tool_file_id=f.related_id, extension=f.extension) - else: - raise ValueError(f"Unsupported transfer method: {f.transfer_method}") - - -class FileManager: - """Adapter exposing file manager helpers behind FileManagerProtocol.""" - - def download(self, f: File, /) -> bytes: - return download(f) - - -file_manager = FileManager() diff --git a/api/dify_graph/file/helpers.py b/api/dify_graph/file/helpers.py deleted file mode 100644 index 310cb1310b4..00000000000 --- a/api/dify_graph/file/helpers.py +++ /dev/null @@ -1,92 +0,0 @@ -from __future__ import annotations - -import base64 -import hashlib -import hmac -import os -import time -import urllib.parse - -from .runtime import get_workflow_file_runtime - - -def get_signed_file_url(upload_file_id: str, as_attachment: bool = False, for_external: bool = True) -> str: - runtime = get_workflow_file_runtime() - base_url = runtime.files_url if for_external else (runtime.internal_files_url or runtime.files_url) - url = f"{base_url}/files/{upload_file_id}/file-preview" - - timestamp = str(int(time.time())) - nonce = os.urandom(16).hex() - key = runtime.secret_key.encode() - msg = f"file-preview|{upload_file_id}|{timestamp}|{nonce}" - sign = hmac.new(key, msg.encode(), hashlib.sha256).digest() - encoded_sign = base64.urlsafe_b64encode(sign).decode() - query: dict[str, str] = {"timestamp": timestamp, "nonce": nonce, "sign": encoded_sign} - if as_attachment: - query["as_attachment"] = "true" - query_string = urllib.parse.urlencode(query) - - return f"{url}?{query_string}" - - -def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str, user_id: str) -> str: - runtime = get_workflow_file_runtime() - # Plugin access should use internal URL for Docker network communication. - base_url = runtime.internal_files_url or runtime.files_url - url = f"{base_url}/files/upload/for-plugin" - timestamp = str(int(time.time())) - nonce = os.urandom(16).hex() - key = runtime.secret_key.encode() - msg = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}" - sign = hmac.new(key, msg.encode(), hashlib.sha256).digest() - encoded_sign = base64.urlsafe_b64encode(sign).decode() - return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}&user_id={user_id}&tenant_id={tenant_id}" - - -def get_signed_tool_file_url(tool_file_id: str, extension: str, for_external: bool = True) -> str: - runtime = get_workflow_file_runtime() - return runtime.sign_tool_file(tool_file_id=tool_file_id, extension=extension, for_external=for_external) - - -def verify_plugin_file_signature( - *, filename: str, mimetype: str, tenant_id: str, user_id: str, timestamp: str, nonce: str, sign: str -) -> bool: - runtime = get_workflow_file_runtime() - data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}" - secret_key = runtime.secret_key.encode() - recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() - recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() - - if sign != recalculated_encoded_sign: - return False - - current_time = int(time.time()) - return current_time - int(timestamp) <= runtime.files_access_timeout - - -def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: - runtime = get_workflow_file_runtime() - data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" - secret_key = runtime.secret_key.encode() - recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() - recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() - - if sign != recalculated_encoded_sign: - return False - - current_time = int(time.time()) - return current_time - int(timestamp) <= runtime.files_access_timeout - - -def verify_file_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: - runtime = get_workflow_file_runtime() - data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}" - secret_key = runtime.secret_key.encode() - recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() - recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() - - if sign != recalculated_encoded_sign: - return False - - current_time = int(time.time()) - return current_time - int(timestamp) <= runtime.files_access_timeout diff --git a/api/dify_graph/file/models.py b/api/dify_graph/file/models.py deleted file mode 100644 index dcba00978e0..00000000000 --- a/api/dify_graph/file/models.py +++ /dev/null @@ -1,197 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping, Sequence -from typing import Any -from uuid import UUID, uuid4 - -from pydantic import BaseModel, Field, model_validator - -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent - -from . import helpers -from .constants import FILE_MODEL_IDENTITY -from .enums import FileTransferMethod, FileType - - -def sign_tool_file(*, tool_file_id: str, extension: str, for_external: bool = True) -> str: - """Compatibility shim for tests and legacy callers patching ``models.sign_tool_file``.""" - return helpers.get_signed_tool_file_url( - tool_file_id=tool_file_id, - extension=extension, - for_external=for_external, - ) - - -class ImageConfig(BaseModel): - """ - NOTE: This part of validation is deprecated, but still used in app features "Image Upload". - """ - - number_limits: int = 0 - transfer_methods: Sequence[FileTransferMethod] = Field(default_factory=list) - detail: ImagePromptMessageContent.DETAIL | None = None - - -class FileUploadConfig(BaseModel): - """ - File Upload Entity. - """ - - image_config: ImageConfig | None = None - allowed_file_types: Sequence[FileType] = Field(default_factory=list) - allowed_file_extensions: Sequence[str] = Field(default_factory=list) - allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list) - number_limits: int = 0 - - -class ToolFile(BaseModel): - id: UUID = Field(default_factory=uuid4, description="Unique identifier for the file") - user_id: UUID = Field(..., description="ID of the user who owns this file") - tenant_id: UUID = Field(..., description="ID of the tenant/organization") - conversation_id: UUID | None = Field(None, description="ID of the associated conversation") - file_key: str = Field(..., max_length=255, description="Storage key for the file") - mimetype: str = Field(..., max_length=255, description="MIME type of the file") - original_url: str | None = Field( - None, max_length=2048, description="Original URL if file was fetched from external source" - ) - name: str = Field(default="", max_length=255, description="Display name of the file") - size: int = Field(default=-1, ge=-1, description="File size in bytes (-1 if unknown)") - - class Config: - from_attributes = True # Enable ORM mode for SQLAlchemy compatibility - populate_by_name = True - - -class File(BaseModel): - # NOTE: dify_model_identity is a special identifier used to distinguish between - # new and old data formats during serialization and deserialization. - dify_model_identity: str = FILE_MODEL_IDENTITY - - id: str | None = None # message file id - tenant_id: str - type: FileType - transfer_method: FileTransferMethod - # If `transfer_method` is `FileTransferMethod.remote_url`, the - # `remote_url` attribute must not be `None`. - remote_url: str | None = None # remote url - # If `transfer_method` is `FileTransferMethod.local_file` or - # `FileTransferMethod.tool_file`, the `related_id` attribute must not be `None`. - # - # It should be set to `ToolFile.id` when `transfer_method` is `tool_file`. - related_id: str | None = None - filename: str | None = None - extension: str | None = Field(default=None, description="File extension, should contain dot") - mime_type: str | None = None - size: int = -1 - - # Those properties are private, should not be exposed to the outside. - _storage_key: str - - def __init__( - self, - *, - id: str | None = None, - tenant_id: str, - type: FileType, - transfer_method: FileTransferMethod, - remote_url: str | None = None, - related_id: str | None = None, - filename: str | None = None, - extension: str | None = None, - mime_type: str | None = None, - size: int = -1, - storage_key: str | None = None, - dify_model_identity: str | None = FILE_MODEL_IDENTITY, - url: str | None = None, - # Legacy compatibility fields - explicitly handle known extra fields - tool_file_id: str | None = None, - upload_file_id: str | None = None, - datasource_file_id: str | None = None, - ): - super().__init__( - id=id, - tenant_id=tenant_id, - type=type, - transfer_method=transfer_method, - remote_url=remote_url, - related_id=related_id, - filename=filename, - extension=extension, - mime_type=mime_type, - size=size, - dify_model_identity=dify_model_identity, - url=url, - ) - self._storage_key = str(storage_key) - - def to_dict(self) -> Mapping[str, str | int | None]: - data = self.model_dump(mode="json") - return { - **data, - "url": self.generate_url(), - } - - @property - def markdown(self) -> str: - url = self.generate_url() - if self.type == FileType.IMAGE: - text = f"![{self.filename or ''}]({url})" - else: - text = f"[{self.filename or url}]({url})" - - return text - - def generate_url(self, for_external: bool = True) -> str | None: - if self.transfer_method == FileTransferMethod.REMOTE_URL: - return self.remote_url - elif self.transfer_method == FileTransferMethod.LOCAL_FILE: - if self.related_id is None: - raise ValueError("Missing file related_id") - return helpers.get_signed_file_url(upload_file_id=self.related_id, for_external=for_external) - elif self.transfer_method in [FileTransferMethod.TOOL_FILE, FileTransferMethod.DATASOURCE_FILE]: - assert self.related_id is not None - assert self.extension is not None - return sign_tool_file( - tool_file_id=self.related_id, - extension=self.extension, - for_external=for_external, - ) - return None - - def to_plugin_parameter(self) -> dict[str, Any]: - return { - "dify_model_identity": FILE_MODEL_IDENTITY, - "mime_type": self.mime_type, - "filename": self.filename, - "extension": self.extension, - "size": self.size, - "type": self.type, - "url": self.generate_url(for_external=False), - } - - @model_validator(mode="after") - def validate_after(self) -> File: - match self.transfer_method: - case FileTransferMethod.REMOTE_URL: - if not self.remote_url: - raise ValueError("Missing file url") - if not isinstance(self.remote_url, str) or not self.remote_url.startswith("http"): - raise ValueError("Invalid file url") - case FileTransferMethod.LOCAL_FILE: - if not self.related_id: - raise ValueError("Missing file related_id") - case FileTransferMethod.TOOL_FILE: - if not self.related_id: - raise ValueError("Missing file related_id") - case FileTransferMethod.DATASOURCE_FILE: - if not self.related_id: - raise ValueError("Missing file related_id") - return self - - @property - def storage_key(self) -> str: - return self._storage_key - - @storage_key.setter - def storage_key(self, value: str) -> None: - self._storage_key = value diff --git a/api/dify_graph/file/protocols.py b/api/dify_graph/file/protocols.py deleted file mode 100644 index 24cbb42735a..00000000000 --- a/api/dify_graph/file/protocols.py +++ /dev/null @@ -1,43 +0,0 @@ -from __future__ import annotations - -from collections.abc import Generator -from typing import Protocol - - -class HttpResponseProtocol(Protocol): - """Subset of response behavior needed by workflow file helpers.""" - - @property - def content(self) -> bytes: ... - - def raise_for_status(self) -> object: ... - - -class WorkflowFileRuntimeProtocol(Protocol): - """Runtime dependencies required by ``dify_graph.file``. - - Implementations are expected to be provided by integration layers (for example, - ``core.app.workflow.file_runtime``) so the workflow package avoids importing - application infrastructure modules directly. - """ - - @property - def files_url(self) -> str: ... - - @property - def internal_files_url(self) -> str | None: ... - - @property - def secret_key(self) -> str: ... - - @property - def files_access_timeout(self) -> int: ... - - @property - def multimodal_send_format(self) -> str: ... - - def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol: ... - - def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: ... - - def sign_tool_file(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: ... diff --git a/api/dify_graph/file/runtime.py b/api/dify_graph/file/runtime.py deleted file mode 100644 index 94253e0255a..00000000000 --- a/api/dify_graph/file/runtime.py +++ /dev/null @@ -1,58 +0,0 @@ -from __future__ import annotations - -from collections.abc import Generator -from typing import NoReturn - -from .protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol - - -class WorkflowFileRuntimeNotConfiguredError(RuntimeError): - """Raised when workflow file runtime dependencies were not configured.""" - - -class _UnconfiguredWorkflowFileRuntime(WorkflowFileRuntimeProtocol): - def _raise(self) -> NoReturn: - raise WorkflowFileRuntimeNotConfiguredError( - "workflow file runtime is not configured, call set_workflow_file_runtime(...) first" - ) - - @property - def files_url(self) -> str: - self._raise() - - @property - def internal_files_url(self) -> str | None: - self._raise() - - @property - def secret_key(self) -> str: - self._raise() - - @property - def files_access_timeout(self) -> int: - self._raise() - - @property - def multimodal_send_format(self) -> str: - self._raise() - - def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol: - self._raise() - - def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: - self._raise() - - def sign_tool_file(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: - self._raise() - - -_runtime: WorkflowFileRuntimeProtocol = _UnconfiguredWorkflowFileRuntime() - - -def set_workflow_file_runtime(runtime: WorkflowFileRuntimeProtocol) -> None: - global _runtime - _runtime = runtime - - -def get_workflow_file_runtime() -> WorkflowFileRuntimeProtocol: - return _runtime diff --git a/api/dify_graph/file/tool_file_parser.py b/api/dify_graph/file/tool_file_parser.py deleted file mode 100644 index 2d7a3d43df4..00000000000 --- a/api/dify_graph/file/tool_file_parser.py +++ /dev/null @@ -1,9 +0,0 @@ -from collections.abc import Callable -from typing import Any - -_tool_file_manager_factory: Callable[[], Any] | None = None - - -def set_tool_file_manager_factory(factory: Callable[[], Any]): - global _tool_file_manager_factory - _tool_file_manager_factory = factory diff --git a/api/dify_graph/graph/__init__.py b/api/dify_graph/graph/__init__.py deleted file mode 100644 index 4830ea83d3d..00000000000 --- a/api/dify_graph/graph/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from .edge import Edge -from .graph import Graph, GraphBuilder, NodeFactory -from .graph_template import GraphTemplate - -__all__ = [ - "Edge", - "Graph", - "GraphBuilder", - "GraphTemplate", - "NodeFactory", -] diff --git a/api/dify_graph/graph/edge.py b/api/dify_graph/graph/edge.py deleted file mode 100644 index f4f67ea6be8..00000000000 --- a/api/dify_graph/graph/edge.py +++ /dev/null @@ -1,15 +0,0 @@ -import uuid -from dataclasses import dataclass, field - -from dify_graph.enums import NodeState - - -@dataclass -class Edge: - """Edge connecting two nodes in a workflow graph.""" - - id: str = field(default_factory=lambda: str(uuid.uuid4())) - tail: str = "" # tail node id (source) - head: str = "" # head node id (target) - source_handle: str = "source" # source handle for conditional branching - state: NodeState = field(default=NodeState.UNKNOWN) # edge execution state diff --git a/api/dify_graph/graph/graph.py b/api/dify_graph/graph/graph.py deleted file mode 100644 index 85117583e07..00000000000 --- a/api/dify_graph/graph/graph.py +++ /dev/null @@ -1,439 +0,0 @@ -from __future__ import annotations - -import logging -from collections import defaultdict -from collections.abc import Mapping, Sequence -from typing import Protocol, cast, final - -from pydantic import TypeAdapter - -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeState -from dify_graph.nodes.base.node import Node -from libs.typing import is_str - -from .edge import Edge -from .validation import get_graph_validator - -logger = logging.getLogger(__name__) - -_ListNodeConfigDict = TypeAdapter(list[NodeConfigDict]) - - -class NodeFactory(Protocol): - """ - Protocol for creating Node instances from node data dictionaries. - - This protocol decouples the Graph class from specific node mapping implementations, - allowing for different node creation strategies while maintaining type safety. - """ - - def create_node(self, node_config: NodeConfigDict) -> Node: - """ - Create a Node instance from node configuration data. - - :param node_config: node configuration dictionary containing type and other data - :return: initialized Node instance - :raises ValueError: if node type is unknown or no implementation exists for the resolved version - :raises ValidationError: if node_config does not satisfy NodeConfigDict/BaseNodeData validation - """ - ... - - -@final -class Graph: - """Graph representation with nodes and edges for workflow execution.""" - - def __init__( - self, - *, - nodes: dict[str, Node] | None = None, - edges: dict[str, Edge] | None = None, - in_edges: dict[str, list[str]] | None = None, - out_edges: dict[str, list[str]] | None = None, - root_node: Node, - ): - """ - Initialize Graph instance. - - :param nodes: graph nodes mapping (node id: node object) - :param edges: graph edges mapping (edge id: edge object) - :param in_edges: incoming edges mapping (node id: list of edge ids) - :param out_edges: outgoing edges mapping (node id: list of edge ids) - :param root_node: root node object - """ - self.nodes = nodes or {} - self.edges = edges or {} - self.in_edges = in_edges or {} - self.out_edges = out_edges or {} - self.root_node = root_node - - @classmethod - def _parse_node_configs(cls, node_configs: list[NodeConfigDict]) -> dict[str, NodeConfigDict]: - """ - Parse node configurations and build a mapping of node IDs to configs. - - :param node_configs: list of node configuration dictionaries - :return: mapping of node ID to node config - """ - node_configs_map: dict[str, NodeConfigDict] = {} - - for node_config in node_configs: - node_configs_map[node_config["id"]] = node_config - - return node_configs_map - - @classmethod - def _build_edges( - cls, edge_configs: list[dict[str, object]] - ) -> tuple[dict[str, Edge], dict[str, list[str]], dict[str, list[str]]]: - """ - Build edge objects and mappings from edge configurations. - - :param edge_configs: list of edge configurations - :return: tuple of (edges dict, in_edges dict, out_edges dict) - """ - edges: dict[str, Edge] = {} - in_edges: dict[str, list[str]] = defaultdict(list) - out_edges: dict[str, list[str]] = defaultdict(list) - - edge_counter = 0 - for edge_config in edge_configs: - source = edge_config.get("source") - target = edge_config.get("target") - - if not is_str(source) or not is_str(target): - continue - - # Create edge - edge_id = f"edge_{edge_counter}" - edge_counter += 1 - - source_handle = edge_config.get("sourceHandle", "source") - if not is_str(source_handle): - continue - - edge = Edge( - id=edge_id, - tail=source, - head=target, - source_handle=source_handle, - ) - - edges[edge_id] = edge - out_edges[source].append(edge_id) - in_edges[target].append(edge_id) - - return edges, dict(in_edges), dict(out_edges) - - @classmethod - def _create_node_instances( - cls, - node_configs_map: dict[str, NodeConfigDict], - node_factory: NodeFactory, - ) -> dict[str, Node]: - """ - Create node instances from configurations using the node factory. - - :param node_configs_map: mapping of node ID to node config - :param node_factory: factory for creating node instances - :return: mapping of node ID to node instance - """ - nodes: dict[str, Node] = {} - - for node_id, node_config in node_configs_map.items(): - try: - node_instance = node_factory.create_node(node_config) - except Exception: - logger.exception("Failed to create node instance for node_id %s", node_id) - raise - nodes[node_id] = node_instance - - return nodes - - @classmethod - def new(cls) -> GraphBuilder: - """Create a fluent builder for assembling a graph programmatically.""" - - return GraphBuilder(graph_cls=cls) - - @staticmethod - def _filter_canvas_only_nodes(node_configs: Sequence[Mapping[str, object]]) -> list[dict[str, object]]: - """ - Remove editor-only nodes before `NodeConfigDict` validation. - - Persisted note widgets use a top-level `type == "custom-note"` but leave - `data.type` empty because they are never executable graph nodes. Filter - them while configs are still raw dicts so Pydantic does not validate - their placeholder payloads against `BaseNodeData.type: NodeType`. - """ - filtered_node_configs: list[dict[str, object]] = [] - for node_config in node_configs: - if node_config.get("type", "") == "custom-note": - continue - filtered_node_configs.append(dict(node_config)) - return filtered_node_configs - - @classmethod - def _promote_fail_branch_nodes(cls, nodes: dict[str, Node]) -> None: - """ - Promote nodes configured with FAIL_BRANCH error strategy to branch execution type. - - :param nodes: mapping of node ID to node instance - """ - for node in nodes.values(): - if node.error_strategy == ErrorStrategy.FAIL_BRANCH: - node.execution_type = NodeExecutionType.BRANCH - - @classmethod - def _mark_inactive_root_branches( - cls, - nodes: dict[str, Node], - edges: dict[str, Edge], - in_edges: dict[str, list[str]], - out_edges: dict[str, list[str]], - active_root_id: str, - ) -> None: - """ - Mark nodes and edges from inactive root branches as skipped. - - Algorithm: - 1. Mark inactive root nodes as skipped - 2. For skipped nodes, mark all their outgoing edges as skipped - 3. For each edge marked as skipped, check its target node: - - If ALL incoming edges are skipped, mark the node as skipped - - Otherwise, leave the node state unchanged - - :param nodes: mapping of node ID to node instance - :param edges: mapping of edge ID to edge instance - :param in_edges: mapping of node ID to incoming edge IDs - :param out_edges: mapping of node ID to outgoing edge IDs - :param active_root_id: ID of the active root node - """ - # Find all top-level root nodes (nodes with ROOT execution type and no incoming edges) - top_level_roots: list[str] = [ - node.id for node in nodes.values() if node.execution_type == NodeExecutionType.ROOT - ] - - # If there's only one root or the active root is not a top-level root, no marking needed - if len(top_level_roots) <= 1 or active_root_id not in top_level_roots: - return - - # Mark inactive root nodes as skipped - inactive_roots: list[str] = [root_id for root_id in top_level_roots if root_id != active_root_id] - for root_id in inactive_roots: - if root_id in nodes: - nodes[root_id].state = NodeState.SKIPPED - - # Recursively mark downstream nodes and edges - def mark_downstream(node_id: str) -> None: - """Recursively mark downstream nodes and edges as skipped.""" - if nodes[node_id].state != NodeState.SKIPPED: - return - # If this node is skipped, mark all its outgoing edges as skipped - out_edge_ids = out_edges.get(node_id, []) - for edge_id in out_edge_ids: - edge = edges[edge_id] - edge.state = NodeState.SKIPPED - - # Check the target node of this edge - target_node = nodes[edge.head] - in_edge_ids = in_edges.get(target_node.id, []) - in_edge_states = [edges[eid].state for eid in in_edge_ids] - - # If all incoming edges are skipped, mark the node as skipped - if all(state == NodeState.SKIPPED for state in in_edge_states): - target_node.state = NodeState.SKIPPED - # Recursively process downstream nodes - mark_downstream(target_node.id) - - # Process each inactive root and its downstream nodes - for root_id in inactive_roots: - mark_downstream(root_id) - - @classmethod - def init( - cls, - *, - graph_config: Mapping[str, object], - node_factory: NodeFactory, - root_node_id: str, - skip_validation: bool = False, - ) -> Graph: - """ - Initialize a graph with an explicit execution entry point. - - :param graph_config: graph config containing nodes and edges - :param node_factory: factory for creating node instances from config data - :param root_node_id: active root node id - :return: graph instance - """ - # Parse configs - edge_configs = graph_config.get("edges", []) - node_configs = graph_config.get("nodes", []) - - edge_configs = cast(list[dict[str, object]], edge_configs) - node_configs = cast(list[dict[str, object]], node_configs) - node_configs = cls._filter_canvas_only_nodes(node_configs) - node_configs = _ListNodeConfigDict.validate_python(node_configs) - - if not node_configs: - raise ValueError("Graph must have at least one node") - - # Parse node configurations - node_configs_map = cls._parse_node_configs(node_configs) - - if root_node_id not in node_configs_map: - raise ValueError(f"Root node id {root_node_id} not found in the graph") - - # Build edges - edges, in_edges, out_edges = cls._build_edges(edge_configs) - - # Create node instances - nodes = cls._create_node_instances(node_configs_map, node_factory) - - # Promote fail-branch nodes to branch execution type at graph level - cls._promote_fail_branch_nodes(nodes) - - # Get root node instance - root_node = nodes[root_node_id] - - # Mark inactive root branches as skipped - cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id) - - # Create and return the graph - graph = cls( - nodes=nodes, - edges=edges, - in_edges=in_edges, - out_edges=out_edges, - root_node=root_node, - ) - - if not skip_validation: - # Validate the graph structure using built-in validators - get_graph_validator().validate(graph) - - return graph - - @property - def node_ids(self) -> list[str]: - """ - Get list of node IDs (compatibility property for existing code) - - :return: list of node IDs - """ - return list(self.nodes.keys()) - - def get_outgoing_edges(self, node_id: str) -> list[Edge]: - """ - Get all outgoing edges from a node (V2 method) - - :param node_id: node id - :return: list of outgoing edges - """ - edge_ids = self.out_edges.get(node_id, []) - return [self.edges[eid] for eid in edge_ids if eid in self.edges] - - def get_incoming_edges(self, node_id: str) -> list[Edge]: - """ - Get all incoming edges to a node (V2 method) - - :param node_id: node id - :return: list of incoming edges - """ - edge_ids = self.in_edges.get(node_id, []) - return [self.edges[eid] for eid in edge_ids if eid in self.edges] - - -@final -class GraphBuilder: - """Fluent helper for constructing simple graphs, primarily for tests.""" - - def __init__(self, *, graph_cls: type[Graph]): - self._graph_cls = graph_cls - self._nodes: list[Node] = [] - self._nodes_by_id: dict[str, Node] = {} - self._edges: list[Edge] = [] - self._edge_counter = 0 - - def add_root(self, node: Node) -> GraphBuilder: - """Register the root node. Must be called exactly once.""" - - if self._nodes: - raise ValueError("Root node has already been added") - self._register_node(node) - self._nodes.append(node) - return self - - def add_node( - self, - node: Node, - *, - from_node_id: str | None = None, - source_handle: str = "source", - ) -> GraphBuilder: - """Append a node and connect it from the specified predecessor.""" - - if not self._nodes: - raise ValueError("Root node must be added before adding other nodes") - - predecessor_id = from_node_id or self._nodes[-1].id - if predecessor_id not in self._nodes_by_id: - raise ValueError(f"Predecessor node '{predecessor_id}' not found") - - predecessor = self._nodes_by_id[predecessor_id] - self._register_node(node) - self._nodes.append(node) - - edge_id = f"edge_{self._edge_counter}" - self._edge_counter += 1 - edge = Edge(id=edge_id, tail=predecessor.id, head=node.id, source_handle=source_handle) - self._edges.append(edge) - - return self - - def connect(self, *, tail: str, head: str, source_handle: str = "source") -> GraphBuilder: - """Connect two existing nodes without adding a new node.""" - - if tail not in self._nodes_by_id: - raise ValueError(f"Tail node '{tail}' not found") - if head not in self._nodes_by_id: - raise ValueError(f"Head node '{head}' not found") - - edge_id = f"edge_{self._edge_counter}" - self._edge_counter += 1 - edge = Edge(id=edge_id, tail=tail, head=head, source_handle=source_handle) - self._edges.append(edge) - - return self - - def build(self) -> Graph: - """Materialize the graph instance from the accumulated nodes and edges.""" - - if not self._nodes: - raise ValueError("Cannot build an empty graph") - - nodes = {node.id: node for node in self._nodes} - edges = {edge.id: edge for edge in self._edges} - in_edges: dict[str, list[str]] = defaultdict(list) - out_edges: dict[str, list[str]] = defaultdict(list) - - for edge in self._edges: - out_edges[edge.tail].append(edge.id) - in_edges[edge.head].append(edge.id) - - return self._graph_cls( - nodes=nodes, - edges=edges, - in_edges=dict(in_edges), - out_edges=dict(out_edges), - root_node=self._nodes[0], - ) - - def _register_node(self, node: Node) -> None: - if not node.id: - raise ValueError("Node must have a non-empty id") - if node.id in self._nodes_by_id: - raise ValueError(f"Duplicate node id detected: {node.id}") - self._nodes_by_id[node.id] = node diff --git a/api/dify_graph/graph/graph_template.py b/api/dify_graph/graph/graph_template.py deleted file mode 100644 index 34e2dc19e60..00000000000 --- a/api/dify_graph/graph/graph_template.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import Any - -from pydantic import BaseModel, Field - - -class GraphTemplate(BaseModel): - """ - Graph Template for container nodes and subgraph expansion - - According to GraphEngine V2 spec, GraphTemplate contains: - - nodes: mapping of node definitions - - edges: mapping of edge definitions - - root_ids: list of root node IDs - - output_selectors: list of output selectors for the template - """ - - nodes: dict[str, dict[str, Any]] = Field(default_factory=dict, description="node definitions mapping") - edges: dict[str, dict[str, Any]] = Field(default_factory=dict, description="edge definitions mapping") - root_ids: list[str] = Field(default_factory=list, description="root node IDs") - output_selectors: list[str] = Field(default_factory=list, description="output selectors") diff --git a/api/dify_graph/graph/validation.py b/api/dify_graph/graph/validation.py deleted file mode 100644 index 50d1440b044..00000000000 --- a/api/dify_graph/graph/validation.py +++ /dev/null @@ -1,125 +0,0 @@ -from __future__ import annotations - -from collections.abc import Sequence -from dataclasses import dataclass -from typing import TYPE_CHECKING, Protocol - -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeType - -if TYPE_CHECKING: - from .graph import Graph - - -@dataclass(frozen=True, slots=True) -class GraphValidationIssue: - """Immutable value object describing a single validation issue.""" - - code: str - message: str - node_id: str | None = None - - -class GraphValidationError(ValueError): - """Raised when graph validation fails.""" - - def __init__(self, issues: Sequence[GraphValidationIssue]) -> None: - if not issues: - raise ValueError("GraphValidationError requires at least one issue.") - self.issues: tuple[GraphValidationIssue, ...] = tuple(issues) - message = "; ".join(f"[{issue.code}] {issue.message}" for issue in self.issues) - super().__init__(message) - - -class GraphValidationRule(Protocol): - """Protocol that individual validation rules must satisfy.""" - - def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]: - """Validate the provided graph and return any discovered issues.""" - ... - - -@dataclass(frozen=True, slots=True) -class _EdgeEndpointValidator: - """Ensures all edges reference existing nodes.""" - - missing_node_code: str = "MISSING_NODE" - - def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]: - issues: list[GraphValidationIssue] = [] - for edge in graph.edges.values(): - if edge.tail not in graph.nodes: - issues.append( - GraphValidationIssue( - code=self.missing_node_code, - message=f"Edge {edge.id} references unknown source node '{edge.tail}'.", - node_id=edge.tail, - ) - ) - if edge.head not in graph.nodes: - issues.append( - GraphValidationIssue( - code=self.missing_node_code, - message=f"Edge {edge.id} references unknown target node '{edge.head}'.", - node_id=edge.head, - ) - ) - return issues - - -@dataclass(frozen=True, slots=True) -class _RootNodeValidator: - """Validates root node invariants.""" - - invalid_root_code: str = "INVALID_ROOT" - container_entry_types: tuple[NodeType, ...] = (BuiltinNodeTypes.ITERATION_START, BuiltinNodeTypes.LOOP_START) - - def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]: - root_node = graph.root_node - issues: list[GraphValidationIssue] = [] - if root_node.id not in graph.nodes: - issues.append( - GraphValidationIssue( - code=self.invalid_root_code, - message=f"Root node '{root_node.id}' is missing from the node registry.", - node_id=root_node.id, - ) - ) - return issues - - node_type = root_node.node_type - if root_node.execution_type != NodeExecutionType.ROOT and node_type not in self.container_entry_types: - issues.append( - GraphValidationIssue( - code=self.invalid_root_code, - message=f"Root node '{root_node.id}' must declare execution type 'root'.", - node_id=root_node.id, - ) - ) - return issues - - -@dataclass(frozen=True, slots=True) -class GraphValidator: - """Coordinates execution of graph validation rules.""" - - rules: tuple[GraphValidationRule, ...] - - def validate(self, graph: Graph) -> None: - """Validate the graph against all configured rules.""" - issues: list[GraphValidationIssue] = [] - for rule in self.rules: - issues.extend(rule.validate(graph)) - - if issues: - raise GraphValidationError(issues) - - -_DEFAULT_RULES: tuple[GraphValidationRule, ...] = ( - _EdgeEndpointValidator(), - _RootNodeValidator(), -) - - -def get_graph_validator() -> GraphValidator: - """Construct the validator composed of default rules.""" - return GraphValidator(_DEFAULT_RULES) diff --git a/api/dify_graph/graph_engine/__init__.py b/api/dify_graph/graph_engine/__init__.py deleted file mode 100644 index 0e1c7dd60a7..00000000000 --- a/api/dify_graph/graph_engine/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .config import GraphEngineConfig -from .graph_engine import GraphEngine - -__all__ = ["GraphEngine", "GraphEngineConfig"] diff --git a/api/dify_graph/graph_engine/_engine_utils.py b/api/dify_graph/graph_engine/_engine_utils.py deleted file mode 100644 index 28898268fea..00000000000 --- a/api/dify_graph/graph_engine/_engine_utils.py +++ /dev/null @@ -1,15 +0,0 @@ -import time - - -def get_timestamp() -> float: - """Retrieve a timestamp as a float point numer representing the number of seconds - since the Unix epoch. - - This function is primarily used to measure the execution time of the workflow engine. - Since workflow execution may be paused and resumed on a different machine, - `time.perf_counter` cannot be used as it is inconsistent across machines. - - To address this, the function uses the wall clock as the time source. - However, it assumes that the clocks of all servers are properly synchronized. - """ - return round(time.time()) diff --git a/api/dify_graph/graph_engine/command_channels/README.md b/api/dify_graph/graph_engine/command_channels/README.md deleted file mode 100644 index e35e12054ae..00000000000 --- a/api/dify_graph/graph_engine/command_channels/README.md +++ /dev/null @@ -1,33 +0,0 @@ -# Command Channels - -Channel implementations for external workflow control. - -## Components - -### InMemoryChannel - -Thread-safe in-memory queue for single-process deployments. - -- `fetch_commands()` - Get pending commands -- `send_command()` - Add command to queue - -### RedisChannel - -Redis-based queue for distributed deployments. - -- `fetch_commands()` - Get commands with JSON deserialization -- `send_command()` - Store commands with TTL - -## Usage - -```python -# Local execution -channel = InMemoryChannel() -channel.send_command(AbortCommand(graph_id="workflow-123")) - -# Distributed execution -redis_channel = RedisChannel( - redis_client=redis_client, - channel_key="workflow:123:commands" -) -``` diff --git a/api/dify_graph/graph_engine/command_channels/__init__.py b/api/dify_graph/graph_engine/command_channels/__init__.py deleted file mode 100644 index 863e6032d60..00000000000 --- a/api/dify_graph/graph_engine/command_channels/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Command channel implementations for GraphEngine.""" - -from .in_memory_channel import InMemoryChannel -from .redis_channel import RedisChannel - -__all__ = ["InMemoryChannel", "RedisChannel"] diff --git a/api/dify_graph/graph_engine/command_channels/in_memory_channel.py b/api/dify_graph/graph_engine/command_channels/in_memory_channel.py deleted file mode 100644 index bdaf2367967..00000000000 --- a/api/dify_graph/graph_engine/command_channels/in_memory_channel.py +++ /dev/null @@ -1,53 +0,0 @@ -""" -In-memory implementation of CommandChannel for local/testing scenarios. - -This implementation uses a thread-safe queue for command communication -within a single process. Each instance handles commands for one workflow execution. -""" - -from queue import Queue -from typing import final - -from ..entities.commands import GraphEngineCommand - - -@final -class InMemoryChannel: - """ - In-memory command channel implementation using a thread-safe queue. - - Each instance is dedicated to a single GraphEngine/workflow execution. - Suitable for local development, testing, and single-instance deployments. - """ - - def __init__(self) -> None: - """Initialize the in-memory channel with a single queue.""" - self._queue: Queue[GraphEngineCommand] = Queue() - - def fetch_commands(self) -> list[GraphEngineCommand]: - """ - Fetch all pending commands from the queue. - - Returns: - List of pending commands (drains the queue) - """ - commands: list[GraphEngineCommand] = [] - - # Drain all available commands from the queue - while not self._queue.empty(): - try: - command = self._queue.get_nowait() - commands.append(command) - except Exception: - break - - return commands - - def send_command(self, command: GraphEngineCommand) -> None: - """ - Send a command to this channel's queue. - - Args: - command: The command to send - """ - self._queue.put(command) diff --git a/api/dify_graph/graph_engine/command_channels/redis_channel.py b/api/dify_graph/graph_engine/command_channels/redis_channel.py deleted file mode 100644 index 77cf884c67a..00000000000 --- a/api/dify_graph/graph_engine/command_channels/redis_channel.py +++ /dev/null @@ -1,153 +0,0 @@ -""" -Redis-based implementation of CommandChannel for distributed scenarios. - -This implementation uses Redis lists for command queuing, supporting -multi-instance deployments and cross-server communication. -Each instance uses a unique key for its command queue. -""" - -import json -from contextlib import AbstractContextManager -from typing import Any, Protocol, final - -from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand, UpdateVariablesCommand - - -class RedisPipelineProtocol(Protocol): - """Minimal Redis pipeline contract used by the command channel.""" - - def lrange(self, name: str, start: int, end: int) -> Any: ... - def delete(self, *names: str) -> Any: ... - def execute(self) -> list[Any]: ... - def rpush(self, name: str, *values: str) -> Any: ... - def expire(self, name: str, time: int) -> Any: ... - def set(self, name: str, value: str, ex: int | None = None) -> Any: ... - def get(self, name: str) -> Any: ... - - -class RedisClientProtocol(Protocol): - """Redis client contract required by the command channel.""" - - def pipeline(self) -> AbstractContextManager[RedisPipelineProtocol]: ... - - -@final -class RedisChannel: - """ - Redis-based command channel implementation for distributed systems. - - Each instance uses a unique Redis key for its command queue. - Commands are JSON-serialized for transport. - """ - - def __init__( - self, - redis_client: RedisClientProtocol, - channel_key: str, - command_ttl: int = 3600, - ) -> None: - """ - Initialize the Redis channel. - - Args: - redis_client: Redis client instance - channel_key: Unique key for this channel's command queue - command_ttl: TTL for command keys in seconds (default: 3600) - """ - self._redis = redis_client - self._key = channel_key - self._command_ttl = command_ttl - self._pending_key = f"{channel_key}:pending" - - def fetch_commands(self) -> list[GraphEngineCommand]: - """ - Fetch all pending commands from Redis. - - Returns: - List of pending commands (drains the Redis list) - """ - if not self._has_pending_commands(): - return [] - - commands: list[GraphEngineCommand] = [] - - # Use pipeline for atomic operations - with self._redis.pipeline() as pipe: - # Get all commands and clear the list atomically - pipe.lrange(self._key, 0, -1) - pipe.delete(self._key) - results = pipe.execute() - - # Parse commands from JSON - if results[0]: - for command_json in results[0]: - try: - command_data = json.loads(command_json) - command = self._deserialize_command(command_data) - if command: - commands.append(command) - except (json.JSONDecodeError, ValueError): - # Skip invalid commands - continue - - return commands - - def send_command(self, command: GraphEngineCommand) -> None: - """ - Send a command to Redis. - - Args: - command: The command to send - """ - command_json = json.dumps(command.model_dump()) - - # Push to list and set expiry - with self._redis.pipeline() as pipe: - pipe.rpush(self._key, command_json) - pipe.expire(self._key, self._command_ttl) - pipe.set(self._pending_key, "1", ex=self._command_ttl) - pipe.execute() - - def _deserialize_command(self, data: dict[str, Any]) -> GraphEngineCommand | None: - """ - Deserialize a command from dictionary data. - - Args: - data: Command data dictionary - - Returns: - Deserialized command or None if invalid - """ - command_type_value = data.get("command_type") - if not isinstance(command_type_value, str): - return None - - try: - command_type = CommandType(command_type_value) - - if command_type == CommandType.ABORT: - return AbortCommand.model_validate(data) - if command_type == CommandType.PAUSE: - return PauseCommand.model_validate(data) - if command_type == CommandType.UPDATE_VARIABLES: - return UpdateVariablesCommand.model_validate(data) - - # For other command types, use base class - return GraphEngineCommand.model_validate(data) - - except (ValueError, TypeError): - return None - - def _has_pending_commands(self) -> bool: - """ - Check and consume the pending marker to avoid unnecessary list reads. - - Returns: - True if commands should be fetched from Redis. - """ - with self._redis.pipeline() as pipe: - pipe.get(self._pending_key) - pipe.delete(self._pending_key) - pending_value, _ = pipe.execute() - - return pending_value is not None diff --git a/api/dify_graph/graph_engine/command_processing/__init__.py b/api/dify_graph/graph_engine/command_processing/__init__.py deleted file mode 100644 index 7b4f0dfff79..00000000000 --- a/api/dify_graph/graph_engine/command_processing/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -Command processing subsystem for graph engine. - -This package handles external commands sent to the engine -during execution. -""" - -from .command_handlers import AbortCommandHandler, PauseCommandHandler, UpdateVariablesCommandHandler -from .command_processor import CommandProcessor - -__all__ = [ - "AbortCommandHandler", - "CommandProcessor", - "PauseCommandHandler", - "UpdateVariablesCommandHandler", -] diff --git a/api/dify_graph/graph_engine/command_processing/command_handlers.py b/api/dify_graph/graph_engine/command_processing/command_handlers.py deleted file mode 100644 index eefd0c366b4..00000000000 --- a/api/dify_graph/graph_engine/command_processing/command_handlers.py +++ /dev/null @@ -1,56 +0,0 @@ -import logging -from typing import final - -from typing_extensions import override - -from dify_graph.entities.pause_reason import SchedulingPause -from dify_graph.runtime import VariablePool - -from ..domain.graph_execution import GraphExecution -from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand, UpdateVariablesCommand -from .command_processor import CommandHandler - -logger = logging.getLogger(__name__) - - -@final -class AbortCommandHandler(CommandHandler): - @override - def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: - assert isinstance(command, AbortCommand) - logger.debug("Aborting workflow %s: %s", execution.workflow_id, command.reason) - execution.abort(command.reason or "User requested abort") - - -@final -class PauseCommandHandler(CommandHandler): - @override - def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: - assert isinstance(command, PauseCommand) - logger.debug("Pausing workflow %s: %s", execution.workflow_id, command.reason) - # Convert string reason to PauseReason if needed - reason = command.reason - pause_reason = SchedulingPause(message=reason) - execution.pause(pause_reason) - - -@final -class UpdateVariablesCommandHandler(CommandHandler): - def __init__(self, variable_pool: VariablePool) -> None: - self._variable_pool = variable_pool - - @override - def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: - assert isinstance(command, UpdateVariablesCommand) - for update in command.updates: - try: - variable = update.value - self._variable_pool.add(variable.selector, variable) - logger.debug("Updated variable %s for workflow %s", variable.selector, execution.workflow_id) - except ValueError as exc: - logger.warning( - "Skipping invalid variable selector %s for workflow %s: %s", - getattr(update.value, "selector", None), - execution.workflow_id, - exc, - ) diff --git a/api/dify_graph/graph_engine/command_processing/command_processor.py b/api/dify_graph/graph_engine/command_processing/command_processor.py deleted file mode 100644 index 942c2d77a5a..00000000000 --- a/api/dify_graph/graph_engine/command_processing/command_processor.py +++ /dev/null @@ -1,79 +0,0 @@ -""" -Main command processor for handling external commands. -""" - -import logging -from typing import Protocol, final - -from ..domain.graph_execution import GraphExecution -from ..entities.commands import GraphEngineCommand -from ..protocols.command_channel import CommandChannel - -logger = logging.getLogger(__name__) - - -class CommandHandler(Protocol): - """Protocol for command handlers.""" - - def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: ... - - -@final -class CommandProcessor: - """ - Processes external commands sent to the engine. - - This polls the command channel and dispatches commands to - appropriate handlers. - """ - - def __init__( - self, - command_channel: CommandChannel, - graph_execution: GraphExecution, - ) -> None: - """ - Initialize the command processor. - - Args: - command_channel: Channel for receiving commands - graph_execution: Graph execution aggregate - """ - self._command_channel = command_channel - self._graph_execution = graph_execution - self._handlers: dict[type[GraphEngineCommand], CommandHandler] = {} - - def register_handler(self, command_type: type[GraphEngineCommand], handler: CommandHandler) -> None: - """ - Register a handler for a command type. - - Args: - command_type: Type of command to handle - handler: Handler for the command - """ - self._handlers[command_type] = handler - - def process_commands(self) -> None: - """Check for and process any pending commands.""" - try: - commands = self._command_channel.fetch_commands() - for command in commands: - self._handle_command(command) - except Exception as e: - logger.warning("Error processing commands: %s", e) - - def _handle_command(self, command: GraphEngineCommand) -> None: - """ - Handle a single command. - - Args: - command: The command to handle - """ - handler = self._handlers.get(type(command)) - if handler: - try: - handler.handle(command, self._graph_execution) - except Exception: - logger.exception("Error handling command %s", command.__class__.__name__) - else: - logger.warning("No handler registered for command: %s", command.__class__.__name__) diff --git a/api/dify_graph/graph_engine/config.py b/api/dify_graph/graph_engine/config.py deleted file mode 100644 index d56a69cee03..00000000000 --- a/api/dify_graph/graph_engine/config.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -GraphEngine configuration models. -""" - -from pydantic import BaseModel, ConfigDict - - -class GraphEngineConfig(BaseModel): - """Configuration for GraphEngine worker pool scaling.""" - - model_config = ConfigDict(frozen=True) - - min_workers: int = 1 - max_workers: int = 5 - scale_up_threshold: int = 3 - scale_down_idle_time: float = 5.0 diff --git a/api/dify_graph/graph_engine/domain/__init__.py b/api/dify_graph/graph_engine/domain/__init__.py deleted file mode 100644 index 9e9afe4c219..00000000000 --- a/api/dify_graph/graph_engine/domain/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -""" -Domain models for graph engine. - -This package contains the core domain entities, value objects, and aggregates -that represent the business concepts of workflow graph execution. -""" - -from .graph_execution import GraphExecution -from .node_execution import NodeExecution - -__all__ = [ - "GraphExecution", - "NodeExecution", -] diff --git a/api/dify_graph/graph_engine/domain/graph_execution.py b/api/dify_graph/graph_engine/domain/graph_execution.py deleted file mode 100644 index 0ee4a9f9a7c..00000000000 --- a/api/dify_graph/graph_engine/domain/graph_execution.py +++ /dev/null @@ -1,242 +0,0 @@ -"""GraphExecution aggregate root managing the overall graph execution state.""" - -from __future__ import annotations - -from dataclasses import dataclass, field -from importlib import import_module -from typing import Literal - -from pydantic import BaseModel, Field - -from dify_graph.entities.pause_reason import PauseReason -from dify_graph.enums import NodeState -from dify_graph.runtime.graph_runtime_state import GraphExecutionProtocol - -from .node_execution import NodeExecution - - -class GraphExecutionErrorState(BaseModel): - """Serializable representation of an execution error.""" - - module: str = Field(description="Module containing the exception class") - qualname: str = Field(description="Qualified name of the exception class") - message: str | None = Field(default=None, description="Exception message string") - - -class NodeExecutionState(BaseModel): - """Serializable representation of a node execution entity.""" - - node_id: str - state: NodeState = Field(default=NodeState.UNKNOWN) - retry_count: int = Field(default=0) - execution_id: str | None = Field(default=None) - error: str | None = Field(default=None) - - -class GraphExecutionState(BaseModel): - """Pydantic model describing serialized GraphExecution state.""" - - type: Literal["GraphExecution"] = Field(default="GraphExecution") - version: str = Field(default="1.0") - workflow_id: str - started: bool = Field(default=False) - completed: bool = Field(default=False) - aborted: bool = Field(default=False) - paused: bool = Field(default=False) - pause_reasons: list[PauseReason] = Field(default_factory=list) - error: GraphExecutionErrorState | None = Field(default=None) - exceptions_count: int = Field(default=0) - node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState]) - - -def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None: - """Convert an exception into its serializable representation.""" - - if error is None: - return None - - return GraphExecutionErrorState( - module=error.__class__.__module__, - qualname=error.__class__.__qualname__, - message=str(error), - ) - - -def _resolve_exception_class(module_name: str, qualname: str) -> type[Exception]: - """Locate an exception class from its module and qualified name.""" - - module = import_module(module_name) - attr: object = module - for part in qualname.split("."): - attr = getattr(attr, part) - - if isinstance(attr, type) and issubclass(attr, Exception): - return attr - - raise TypeError(f"{qualname} in {module_name} is not an Exception subclass") - - -def _deserialize_error(state: GraphExecutionErrorState | None) -> Exception | None: - """Reconstruct an exception instance from serialized data.""" - - if state is None: - return None - - try: - exception_class = _resolve_exception_class(state.module, state.qualname) - if state.message is None: - return exception_class() - return exception_class(state.message) - except Exception: - # Fallback to RuntimeError when reconstruction fails - if state.message is None: - return RuntimeError(state.qualname) - return RuntimeError(state.message) - - -@dataclass -class GraphExecution: - """ - Aggregate root for graph execution. - - This manages the overall execution state of a workflow graph, - coordinating between multiple node executions. - """ - - workflow_id: str - started: bool = False - completed: bool = False - aborted: bool = False - paused: bool = False - pause_reasons: list[PauseReason] = field(default_factory=list) - error: Exception | None = None - node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution]) - exceptions_count: int = 0 - - def start(self) -> None: - """Mark the graph execution as started.""" - if self.started: - raise RuntimeError("Graph execution already started") - self.started = True - - def complete(self) -> None: - """Mark the graph execution as completed.""" - if not self.started: - raise RuntimeError("Cannot complete execution that hasn't started") - if self.completed: - raise RuntimeError("Graph execution already completed") - self.completed = True - - def abort(self, reason: str) -> None: - """Abort the graph execution.""" - self.aborted = True - self.error = RuntimeError(f"Aborted: {reason}") - - def pause(self, reason: PauseReason) -> None: - """Pause the graph execution without marking it complete.""" - if self.completed: - raise RuntimeError("Cannot pause execution that has completed") - if self.aborted: - raise RuntimeError("Cannot pause execution that has been aborted") - self.paused = True - self.pause_reasons.append(reason) - - def fail(self, error: Exception) -> None: - """Mark the graph execution as failed.""" - self.error = error - self.completed = True - - def get_or_create_node_execution(self, node_id: str) -> NodeExecution: - """Get or create a node execution entity.""" - if node_id not in self.node_executions: - self.node_executions[node_id] = NodeExecution(node_id=node_id) - return self.node_executions[node_id] - - @property - def is_running(self) -> bool: - """Check if the execution is currently running.""" - return self.started and not self.completed and not self.aborted and not self.paused - - @property - def is_paused(self) -> bool: - """Check if the execution is currently paused.""" - return self.paused - - @property - def has_error(self) -> bool: - """Check if the execution has encountered an error.""" - return self.error is not None - - @property - def error_message(self) -> str | None: - """Get the error message if an error exists.""" - if not self.error: - return None - return str(self.error) - - def dumps(self) -> str: - """Serialize the aggregate state into a JSON string.""" - - node_states = [ - NodeExecutionState( - node_id=node_id, - state=node_execution.state, - retry_count=node_execution.retry_count, - execution_id=node_execution.execution_id, - error=node_execution.error, - ) - for node_id, node_execution in sorted(self.node_executions.items()) - ] - - state = GraphExecutionState( - workflow_id=self.workflow_id, - started=self.started, - completed=self.completed, - aborted=self.aborted, - paused=self.paused, - pause_reasons=self.pause_reasons, - error=_serialize_error(self.error), - exceptions_count=self.exceptions_count, - node_executions=node_states, - ) - - return state.model_dump_json() - - def loads(self, data: str) -> None: - """Restore aggregate state from a serialized JSON string.""" - - state = GraphExecutionState.model_validate_json(data) - - if state.type != "GraphExecution": - raise ValueError(f"Invalid serialized data type: {state.type}") - - if state.version != "1.0": - raise ValueError(f"Unsupported serialized version: {state.version}") - - if self.workflow_id != state.workflow_id: - raise ValueError("Serialized workflow_id does not match aggregate identity") - - self.started = state.started - self.completed = state.completed - self.aborted = state.aborted - self.paused = state.paused - self.pause_reasons = state.pause_reasons - self.error = _deserialize_error(state.error) - self.exceptions_count = state.exceptions_count - self.node_executions = { - item.node_id: NodeExecution( - node_id=item.node_id, - state=item.state, - retry_count=item.retry_count, - execution_id=item.execution_id, - error=item.error, - ) - for item in state.node_executions - } - - def record_node_failure(self) -> None: - """Increment the count of node failures encountered during execution.""" - self.exceptions_count += 1 - - -_: GraphExecutionProtocol = GraphExecution(workflow_id="") diff --git a/api/dify_graph/graph_engine/domain/node_execution.py b/api/dify_graph/graph_engine/domain/node_execution.py deleted file mode 100644 index ae8f9a5e50c..00000000000 --- a/api/dify_graph/graph_engine/domain/node_execution.py +++ /dev/null @@ -1,45 +0,0 @@ -""" -NodeExecution entity representing a node's execution state. -""" - -from dataclasses import dataclass - -from dify_graph.enums import NodeState - - -@dataclass -class NodeExecution: - """ - Entity representing the execution state of a single node. - - This is a mutable entity that tracks the runtime state of a node - during graph execution. - """ - - node_id: str - state: NodeState = NodeState.UNKNOWN - retry_count: int = 0 - execution_id: str | None = None - error: str | None = None - - def mark_started(self, execution_id: str) -> None: - """Mark the node as started with an execution ID.""" - self.state = NodeState.TAKEN - self.execution_id = execution_id - - def mark_taken(self) -> None: - """Mark the node as successfully completed.""" - self.state = NodeState.TAKEN - self.error = None - - def mark_failed(self, error: str) -> None: - """Mark the node as failed with an error.""" - self.error = error - - def mark_skipped(self) -> None: - """Mark the node as skipped.""" - self.state = NodeState.SKIPPED - - def increment_retry(self) -> None: - """Increment the retry count for this node.""" - self.retry_count += 1 diff --git a/api/dify_graph/graph_engine/entities/commands.py b/api/dify_graph/graph_engine/entities/commands.py deleted file mode 100644 index c56845cfc4e..00000000000 --- a/api/dify_graph/graph_engine/entities/commands.py +++ /dev/null @@ -1,56 +0,0 @@ -""" -GraphEngine command entities for external control. - -This module defines command types that can be sent to a running GraphEngine -instance to control its execution flow. -""" - -from collections.abc import Sequence -from enum import StrEnum, auto -from typing import Any - -from pydantic import BaseModel, Field - -from dify_graph.variables.variables import Variable - - -class CommandType(StrEnum): - """Types of commands that can be sent to GraphEngine.""" - - ABORT = auto() - PAUSE = auto() - UPDATE_VARIABLES = auto() - - -class GraphEngineCommand(BaseModel): - """Base class for all GraphEngine commands.""" - - command_type: CommandType = Field(..., description="Type of command") - payload: dict[str, Any] | None = Field(default=None, description="Optional command payload") - - -class AbortCommand(GraphEngineCommand): - """Command to abort a running workflow execution.""" - - command_type: CommandType = Field(default=CommandType.ABORT, description="Type of command") - reason: str | None = Field(default=None, description="Optional reason for abort") - - -class PauseCommand(GraphEngineCommand): - """Command to pause a running workflow execution.""" - - command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command") - reason: str = Field(default="unknown reason", description="reason for pause") - - -class VariableUpdate(BaseModel): - """Represents a single variable update instruction.""" - - value: Variable = Field(description="New variable value") - - -class UpdateVariablesCommand(GraphEngineCommand): - """Command to update a group of variables in the variable pool.""" - - command_type: CommandType = Field(default=CommandType.UPDATE_VARIABLES, description="Type of command") - updates: Sequence[VariableUpdate] = Field(default_factory=list, description="Variable updates") diff --git a/api/dify_graph/graph_engine/error_handler.py b/api/dify_graph/graph_engine/error_handler.py deleted file mode 100644 index e206f215922..00000000000 --- a/api/dify_graph/graph_engine/error_handler.py +++ /dev/null @@ -1,213 +0,0 @@ -""" -Main error handler that coordinates error strategies. -""" - -import logging -import time -from typing import TYPE_CHECKING, final - -from dify_graph.enums import ( - ErrorStrategy as ErrorStrategyEnum, -) -from dify_graph.enums import ( - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from dify_graph.graph import Graph -from dify_graph.graph_events import ( - GraphNodeEventBase, - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunRetryEvent, -) -from dify_graph.node_events import NodeRunResult - -if TYPE_CHECKING: - from .domain import GraphExecution - -logger = logging.getLogger(__name__) - - -@final -class ErrorHandler: - """ - Coordinates error handling strategies for node failures. - - This acts as a facade for the various error strategies, - selecting and applying the appropriate strategy based on - node configuration. - """ - - def __init__(self, graph: Graph, graph_execution: "GraphExecution") -> None: - """ - Initialize the error handler. - - Args: - graph: The workflow graph - graph_execution: The graph execution state - """ - self._graph = graph - self._graph_execution = graph_execution - - def handle_node_failure(self, event: NodeRunFailedEvent) -> GraphNodeEventBase | None: - """ - Handle a node failure event. - - Selects and applies the appropriate error strategy based on - the node's configuration. - - Args: - event: The node failure event - - Returns: - Optional new event to process, or None to abort - """ - node = self._graph.nodes[event.node_id] - # Get retry count from NodeExecution - node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) - retry_count = node_execution.retry_count - - # First check if retry is configured and not exhausted - if node.retry and retry_count < node.retry_config.max_retries: - result = self._handle_retry(event, retry_count) - if result: - # Retry count will be incremented when NodeRunRetryEvent is handled - return result - - # Apply configured error strategy - strategy = node.error_strategy - - match strategy: - case None: - return self._handle_abort(event) - case ErrorStrategyEnum.FAIL_BRANCH: - return self._handle_fail_branch(event) - case ErrorStrategyEnum.DEFAULT_VALUE: - return self._handle_default_value(event) - - def _handle_abort(self, event: NodeRunFailedEvent): - """ - Handle error by aborting execution. - - This is the default strategy when no other strategy is specified. - It stops the entire graph execution when a node fails. - - Args: - event: The failure event - - Returns: - None - signals abortion - """ - logger.error("Node %s failed with ABORT strategy: %s", event.node_id, event.error) - # Return None to signal that execution should stop - - def _handle_retry(self, event: NodeRunFailedEvent, retry_count: int): - """ - Handle error by retrying the node. - - This strategy re-attempts node execution up to a configured - maximum number of retries with configurable intervals. - - Args: - event: The failure event - retry_count: Current retry attempt count - - Returns: - NodeRunRetryEvent if retry should occur, None otherwise - """ - node = self._graph.nodes[event.node_id] - - # Check if we've exceeded max retries - if not node.retry or retry_count >= node.retry_config.max_retries: - return None - - # Wait for retry interval - time.sleep(node.retry_config.retry_interval_seconds) - - # Create retry event - return NodeRunRetryEvent( - id=event.id, - node_title=node.title, - node_id=event.node_id, - node_type=event.node_type, - node_run_result=event.node_run_result, - start_at=event.start_at, - error=event.error, - retry_index=retry_count + 1, - ) - - def _handle_fail_branch(self, event: NodeRunFailedEvent): - """ - Handle error by taking the fail branch. - - This strategy converts failures to exceptions and routes execution - through a designated fail-branch edge. - - Args: - event: The failure event - - Returns: - NodeRunExceptionEvent to continue via fail branch - """ - outputs = { - "error_message": event.node_run_result.error, - "error_type": event.node_run_result.error_type, - } - - return NodeRunExceptionEvent( - id=event.id, - node_id=event.node_id, - node_type=event.node_type, - start_at=event.start_at, - finished_at=event.finished_at, - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.EXCEPTION, - inputs=event.node_run_result.inputs, - process_data=event.node_run_result.process_data, - outputs=outputs, - edge_source_handle="fail-branch", - metadata={ - WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.FAIL_BRANCH, - }, - ), - error=event.error, - ) - - def _handle_default_value(self, event: NodeRunFailedEvent): - """ - Handle error by using default values. - - This strategy allows nodes to fail gracefully by providing - predefined default output values. - - Args: - event: The failure event - - Returns: - NodeRunExceptionEvent with default values - """ - node = self._graph.nodes[event.node_id] - - outputs = { - **node.default_value_dict, - "error_message": event.node_run_result.error, - "error_type": event.node_run_result.error_type, - } - - return NodeRunExceptionEvent( - id=event.id, - node_id=event.node_id, - node_type=event.node_type, - start_at=event.start_at, - finished_at=event.finished_at, - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.EXCEPTION, - inputs=event.node_run_result.inputs, - process_data=event.node_run_result.process_data, - outputs=outputs, - metadata={ - WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.DEFAULT_VALUE, - }, - ), - error=event.error, - ) diff --git a/api/dify_graph/graph_engine/event_management/__init__.py b/api/dify_graph/graph_engine/event_management/__init__.py deleted file mode 100644 index f6c3c0f753f..00000000000 --- a/api/dify_graph/graph_engine/event_management/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -""" -Event management subsystem for graph engine. - -This package handles event routing, collection, and emission for -workflow graph execution events. -""" - -from .event_handlers import EventHandler -from .event_manager import EventManager - -__all__ = [ - "EventHandler", - "EventManager", -] diff --git a/api/dify_graph/graph_engine/event_management/event_handlers.py b/api/dify_graph/graph_engine/event_management/event_handlers.py deleted file mode 100644 index 7f5ad40e0eb..00000000000 --- a/api/dify_graph/graph_engine/event_management/event_handlers.py +++ /dev/null @@ -1,351 +0,0 @@ -""" -Event handler implementations for different event types. -""" - -import logging -from collections.abc import Mapping -from functools import singledispatchmethod -from typing import TYPE_CHECKING, final - -from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeState -from dify_graph.graph import Graph -from dify_graph.graph_events import ( - GraphNodeEventBase, - NodeRunAgentLogEvent, - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunIterationFailedEvent, - NodeRunIterationNextEvent, - NodeRunIterationStartedEvent, - NodeRunIterationSucceededEvent, - NodeRunLoopFailedEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, - NodeRunPauseRequestedEvent, - NodeRunRetrieverResourceEvent, - NodeRunRetryEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.runtime import GraphRuntimeState - -from ..domain.graph_execution import GraphExecution -from ..response_coordinator import ResponseStreamCoordinator - -if TYPE_CHECKING: - from ..error_handler import ErrorHandler - from ..graph_state_manager import GraphStateManager - from ..graph_traversal import EdgeProcessor - from .event_manager import EventManager - -logger = logging.getLogger(__name__) - - -@final -class EventHandler: - """ - Registry of event handlers for different event types. - - This centralizes the business logic for handling specific events, - keeping it separate from the routing and collection infrastructure. - """ - - def __init__( - self, - graph: Graph, - graph_runtime_state: GraphRuntimeState, - graph_execution: GraphExecution, - response_coordinator: ResponseStreamCoordinator, - event_collector: "EventManager", - edge_processor: "EdgeProcessor", - state_manager: "GraphStateManager", - error_handler: "ErrorHandler", - ) -> None: - """ - Initialize the event handler registry. - - Args: - graph: The workflow graph - graph_runtime_state: Runtime state with variable pool - graph_execution: Graph execution aggregate - response_coordinator: Response stream coordinator - event_collector: Event manager for collecting events - edge_processor: Edge processor for edge traversal - state_manager: Unified state manager - error_handler: Error handler - """ - self._graph = graph - self._graph_runtime_state = graph_runtime_state - self._graph_execution = graph_execution - self._response_coordinator = response_coordinator - self._event_collector = event_collector - self._edge_processor = edge_processor - self._state_manager = state_manager - self._error_handler = error_handler - - def dispatch(self, event: GraphNodeEventBase) -> None: - """ - Handle any node event by dispatching to the appropriate handler. - - Args: - event: The event to handle - """ - # Events in loops or iterations are always collected - if event.in_loop_id or event.in_iteration_id: - self._event_collector.collect(event) - return - return self._dispatch(event) - - @singledispatchmethod - def _dispatch(self, event: GraphNodeEventBase) -> None: - self._event_collector.collect(event) - logger.warning("Unhandled event type: %s", type(event).__name__) - - @_dispatch.register(NodeRunIterationStartedEvent) - @_dispatch.register(NodeRunIterationNextEvent) - @_dispatch.register(NodeRunIterationSucceededEvent) - @_dispatch.register(NodeRunIterationFailedEvent) - @_dispatch.register(NodeRunLoopStartedEvent) - @_dispatch.register(NodeRunLoopNextEvent) - @_dispatch.register(NodeRunLoopSucceededEvent) - @_dispatch.register(NodeRunLoopFailedEvent) - @_dispatch.register(NodeRunAgentLogEvent) - @_dispatch.register(NodeRunRetrieverResourceEvent) - def _(self, event: GraphNodeEventBase) -> None: - self._event_collector.collect(event) - - @_dispatch.register - def _(self, event: NodeRunStartedEvent) -> None: - """ - Handle node started event. - - Args: - event: The node started event - """ - # Track execution in domain model - node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) - is_initial_attempt = node_execution.retry_count == 0 - node_execution.mark_started(event.id) - self._graph_runtime_state.increment_node_run_steps() - - # Track in response coordinator for stream ordering - self._response_coordinator.track_node_execution(event.node_id, event.id) - - # Collect the event only for the first attempt; retries remain silent - if is_initial_attempt: - self._event_collector.collect(event) - - @_dispatch.register - def _(self, event: NodeRunStreamChunkEvent) -> None: - """ - Handle stream chunk event with full processing. - - Args: - event: The stream chunk event - """ - # Process with response coordinator - streaming_events = list(self._response_coordinator.intercept_event(event)) - - # Collect all events - for stream_event in streaming_events: - self._event_collector.collect(stream_event) - - @_dispatch.register - def _(self, event: NodeRunSucceededEvent) -> None: - """ - Handle node success by coordinating subsystems. - - This method coordinates between different subsystems to process - node completion, handle edges, and trigger downstream execution. - - Args: - event: The node succeeded event - """ - # Update domain model - node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) - node_execution.mark_taken() - - self._accumulate_node_usage(event.node_run_result.llm_usage) - - # Store outputs in variable pool - self._store_node_outputs(event.node_id, event.node_run_result.outputs) - - # Forward to response coordinator and emit streaming events - streaming_events = self._response_coordinator.intercept_event(event) - for stream_event in streaming_events: - self._event_collector.collect(stream_event) - - # Process edges and get ready nodes - node = self._graph.nodes[event.node_id] - if node.execution_type == NodeExecutionType.BRANCH: - ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion( - event.node_id, event.node_run_result.edge_source_handle - ) - else: - ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id) - - # Collect streaming events from edge processing - for edge_event in edge_streaming_events: - self._event_collector.collect(edge_event) - - # Enqueue ready nodes - if self._graph_execution.is_paused: - for node_id in ready_nodes: - self._graph_runtime_state.register_deferred_node(node_id) - else: - for node_id in ready_nodes: - self._state_manager.enqueue_node(node_id) - self._state_manager.start_execution(node_id) - - # Update execution tracking - self._state_manager.finish_execution(event.node_id) - - # Handle response node outputs - if node.execution_type == NodeExecutionType.RESPONSE: - self._update_response_outputs(event.node_run_result.outputs) - - # Collect the event - self._event_collector.collect(event) - - @_dispatch.register - def _(self, event: NodeRunPauseRequestedEvent) -> None: - """Handle pause requests emitted by nodes.""" - - pause_reason = event.reason - self._graph_execution.pause(pause_reason) - self._state_manager.finish_execution(event.node_id) - if event.node_id in self._graph.nodes: - self._graph.nodes[event.node_id].state = NodeState.UNKNOWN - self._graph_runtime_state.register_paused_node(event.node_id) - self._event_collector.collect(event) - - @_dispatch.register - def _(self, event: NodeRunFailedEvent) -> None: - """ - Handle node failure using error handler. - - Args: - event: The node failed event - """ - # Update domain model - node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) - node_execution.mark_failed(event.error) - self._graph_execution.record_node_failure() - - self._accumulate_node_usage(event.node_run_result.llm_usage) - - result = self._error_handler.handle_node_failure(event) - - if result: - # Process the resulting event (retry, exception, etc.) - self.dispatch(result) - else: - # Abort execution - self._graph_execution.fail(RuntimeError(event.error)) - self._event_collector.collect(event) - self._state_manager.finish_execution(event.node_id) - - @_dispatch.register - def _(self, event: NodeRunExceptionEvent) -> None: - """ - Handle node exception event (fail-branch strategy). - - Args: - event: The node exception event - """ - # Node continues via fail-branch/default-value, treat as completion - node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) - node_execution.mark_taken() - - self._accumulate_node_usage(event.node_run_result.llm_usage) - - # Persist outputs produced by the exception strategy (e.g. default values) - self._store_node_outputs(event.node_id, event.node_run_result.outputs) - - node = self._graph.nodes[event.node_id] - - if node.error_strategy == ErrorStrategy.DEFAULT_VALUE: - ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id) - elif node.error_strategy == ErrorStrategy.FAIL_BRANCH: - ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion( - event.node_id, event.node_run_result.edge_source_handle - ) - else: - raise NotImplementedError(f"Unsupported error strategy: {node.error_strategy}") - - for edge_event in edge_streaming_events: - self._event_collector.collect(edge_event) - - for node_id in ready_nodes: - self._state_manager.enqueue_node(node_id) - self._state_manager.start_execution(node_id) - - # Update response outputs if applicable - if node.execution_type == NodeExecutionType.RESPONSE: - self._update_response_outputs(event.node_run_result.outputs) - - self._state_manager.finish_execution(event.node_id) - - # Collect the exception event for observers - self._event_collector.collect(event) - - @_dispatch.register - def _(self, event: NodeRunRetryEvent) -> None: - """ - Handle node retry event. - - Args: - event: The node retry event - """ - node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) - node_execution.increment_retry() - - # Finish the previous attempt before re-queuing the node - self._state_manager.finish_execution(event.node_id) - - # Emit retry event for observers - self._event_collector.collect(event) - - # Re-queue node for execution - self._state_manager.enqueue_node(event.node_id) - self._state_manager.start_execution(event.node_id) - - def _accumulate_node_usage(self, usage: LLMUsage) -> None: - """Accumulate token usage into the shared runtime state.""" - if usage.total_tokens <= 0: - return - - self._graph_runtime_state.add_tokens(usage.total_tokens) - - current_usage = self._graph_runtime_state.llm_usage - if current_usage.total_tokens == 0: - self._graph_runtime_state.llm_usage = usage - else: - self._graph_runtime_state.llm_usage = current_usage.plus(usage) - - def _store_node_outputs(self, node_id: str, outputs: Mapping[str, object]) -> None: - """ - Store node outputs in the variable pool. - - Args: - event: The node succeeded event containing outputs - """ - for variable_name, variable_value in outputs.items(): - self._graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value) - - def _update_response_outputs(self, outputs: Mapping[str, object]) -> None: - """Update response outputs for response nodes.""" - # TODO: Design a mechanism for nodes to notify the engine about how to update outputs - # in runtime state, rather than allowing nodes to directly access runtime state. - for key, value in outputs.items(): - if key == "answer": - existing = self._graph_runtime_state.get_output("answer", "") - if existing: - self._graph_runtime_state.set_output("answer", f"{existing}{value}") - else: - self._graph_runtime_state.set_output("answer", value) - else: - self._graph_runtime_state.set_output(key, value) diff --git a/api/dify_graph/graph_engine/event_management/event_manager.py b/api/dify_graph/graph_engine/event_management/event_manager.py deleted file mode 100644 index 616f621c3e4..00000000000 --- a/api/dify_graph/graph_engine/event_management/event_manager.py +++ /dev/null @@ -1,186 +0,0 @@ -""" -Unified event manager for collecting and emitting events. -""" - -import logging -import threading -import time -from collections.abc import Generator -from contextlib import contextmanager -from typing import final - -from dify_graph.graph_events import GraphEngineEvent - -from ..layers.base import GraphEngineLayer - -_logger = logging.getLogger(__name__) - - -@final -class ReadWriteLock: - """ - A read-write lock implementation that allows multiple concurrent readers - but only one writer at a time. - """ - - def __init__(self) -> None: - self._read_ready = threading.Condition(threading.RLock()) - self._readers = 0 - - def acquire_read(self) -> None: - """Acquire a read lock.""" - _ = self._read_ready.acquire() - try: - self._readers += 1 - finally: - self._read_ready.release() - - def release_read(self) -> None: - """Release a read lock.""" - _ = self._read_ready.acquire() - try: - self._readers -= 1 - if self._readers == 0: - self._read_ready.notify_all() - finally: - self._read_ready.release() - - def acquire_write(self) -> None: - """Acquire a write lock.""" - _ = self._read_ready.acquire() - while self._readers > 0: - _ = self._read_ready.wait() - - def release_write(self) -> None: - """Release a write lock.""" - self._read_ready.release() - - @contextmanager - def read_lock(self): - """Return a context manager for read locking.""" - self.acquire_read() - try: - yield - finally: - self.release_read() - - @contextmanager - def write_lock(self): - """Return a context manager for write locking.""" - self.acquire_write() - try: - yield - finally: - self.release_write() - - -@final -class EventManager: - """ - Unified event manager that collects, buffers, and emits events. - - This class combines event collection with event emission, providing - thread-safe event management with support for notifying layers and - streaming events to external consumers. - """ - - def __init__(self) -> None: - """Initialize the event manager.""" - self._events: list[GraphEngineEvent] = [] - self._lock = ReadWriteLock() - self._layers: list[GraphEngineLayer] = [] - self._execution_complete = threading.Event() - - def set_layers(self, layers: list[GraphEngineLayer]) -> None: - """ - Set the layers to notify on event collection. - - Args: - layers: List of layers to notify - """ - self._layers = layers - - def notify_layers(self, event: GraphEngineEvent) -> None: - """Notify registered layers about an event without buffering it.""" - self._notify_layers(event) - - def collect(self, event: GraphEngineEvent) -> None: - """ - Thread-safe method to collect an event. - - Args: - event: The event to collect - """ - with self._lock.write_lock(): - self._events.append(event) - - # NOTE: `_notify_layers` is intentionally called outside the critical section - # to minimize lock contention and avoid blocking other readers or writers. - # - # The public `notify_layers` method also does not use a write lock, - # so protecting `_notify_layers` with a lock here is unnecessary. - self._notify_layers(event) - - def _get_new_events(self, start_index: int) -> list[GraphEngineEvent]: - """ - Get new events starting from a specific index. - - Args: - start_index: The index to start from - - Returns: - List of new events - """ - with self._lock.read_lock(): - return list(self._events[start_index:]) - - def _event_count(self) -> int: - """ - Get the current count of collected events. - - Returns: - Number of collected events - """ - with self._lock.read_lock(): - return len(self._events) - - def mark_complete(self) -> None: - """Mark execution as complete to stop the event emission generator.""" - self._execution_complete.set() - - def emit_events(self) -> Generator[GraphEngineEvent, None, None]: - """ - Generator that yields events as they're collected. - - Yields: - GraphEngineEvent instances as they're processed - """ - yielded_count = 0 - - while not self._execution_complete.is_set() or yielded_count < self._event_count(): - # Get new events since last yield - new_events = self._get_new_events(yielded_count) - - # Yield any new events - for event in new_events: - yield event - yielded_count += 1 - - # Small sleep to avoid busy waiting - if not self._execution_complete.is_set() and not new_events: - time.sleep(0.001) - - def _notify_layers(self, event: GraphEngineEvent) -> None: - """ - Notify all layers of an event. - - Layer exceptions are caught and logged to prevent disrupting collection. - - Args: - event: The event to send to layers - """ - for layer in self._layers: - try: - layer.on_event(event) - except Exception: - _logger.exception("Error in layer on_event, layer_type=%s", type(layer)) diff --git a/api/dify_graph/graph_engine/graph_engine.py b/api/dify_graph/graph_engine/graph_engine.py deleted file mode 100644 index ea98a46b063..00000000000 --- a/api/dify_graph/graph_engine/graph_engine.py +++ /dev/null @@ -1,384 +0,0 @@ -""" -QueueBasedGraphEngine - Main orchestrator for queue-based workflow execution. - -This engine uses a modular architecture with separated packages following -Domain-Driven Design principles for improved maintainability and testability. -""" - -from __future__ import annotations - -import logging -import queue -from collections.abc import Generator, Mapping -from typing import TYPE_CHECKING, cast, final - -from dify_graph.context import capture_current_context -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.enums import NodeExecutionType -from dify_graph.graph import Graph -from dify_graph.graph_events import ( - GraphEngineEvent, - GraphNodeEventBase, - GraphRunAbortedEvent, - GraphRunFailedEvent, - GraphRunPartialSucceededEvent, - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, -) -from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper -from dify_graph.runtime.graph_runtime_state import ChildGraphEngineBuilderProtocol - -if TYPE_CHECKING: # pragma: no cover - used only for static analysis - from dify_graph.runtime.graph_runtime_state import GraphProtocol - -from .command_processing import ( - AbortCommandHandler, - CommandProcessor, - PauseCommandHandler, - UpdateVariablesCommandHandler, -) -from .config import GraphEngineConfig -from .entities.commands import AbortCommand, PauseCommand, UpdateVariablesCommand -from .error_handler import ErrorHandler -from .event_management import EventHandler, EventManager -from .graph_state_manager import GraphStateManager -from .graph_traversal import EdgeProcessor, SkipPropagator -from .layers.base import GraphEngineLayer -from .orchestration import Dispatcher, ExecutionCoordinator -from .protocols.command_channel import CommandChannel -from .worker_management import WorkerPool - -if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.graph_engine.domain.graph_execution import GraphExecution - from dify_graph.graph_engine.response_coordinator import ResponseStreamCoordinator - -logger = logging.getLogger(__name__) - - -_DEFAULT_CONFIG = GraphEngineConfig() - - -@final -class GraphEngine: - """ - Queue-based graph execution engine. - - Uses a modular architecture that delegates responsibilities to specialized - subsystems, following Domain-Driven Design and SOLID principles. - """ - - def __init__( - self, - workflow_id: str, - graph: Graph, - graph_runtime_state: GraphRuntimeState, - command_channel: CommandChannel, - config: GraphEngineConfig = _DEFAULT_CONFIG, - child_engine_builder: ChildGraphEngineBuilderProtocol | None = None, - ) -> None: - """Initialize the graph engine with all subsystems and dependencies.""" - - # Bind runtime state to current workflow context - self._graph = graph - self._graph_runtime_state = graph_runtime_state - self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph)) - self._command_channel = command_channel - self._config = config - self._child_engine_builder = child_engine_builder - if child_engine_builder is not None: - self._graph_runtime_state.bind_child_engine_builder(child_engine_builder) - - # Graph execution tracks the overall execution state - self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution) - self._graph_execution.workflow_id = workflow_id - - # === Execution Queues === - self._ready_queue = self._graph_runtime_state.ready_queue - - # Queue for events generated during execution - self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue() - - # === State Management === - # Unified state manager handles all node state transitions and queue operations - self._state_manager = GraphStateManager(self._graph, self._ready_queue) - - # === Response Coordination === - # Coordinates response streaming from response nodes - self._response_coordinator = cast("ResponseStreamCoordinator", self._graph_runtime_state.response_coordinator) - - # === Event Management === - # Event manager handles both collection and emission of events - self._event_manager = EventManager() - - # === Error Handling === - # Centralized error handler for graph execution errors - self._error_handler = ErrorHandler(self._graph, self._graph_execution) - - # === Graph Traversal Components === - # Propagates skip status through the graph when conditions aren't met - self._skip_propagator = SkipPropagator( - graph=self._graph, - state_manager=self._state_manager, - ) - - # Processes edges to determine next nodes after execution - # Also handles conditional branching and route selection - self._edge_processor = EdgeProcessor( - graph=self._graph, - state_manager=self._state_manager, - response_coordinator=self._response_coordinator, - skip_propagator=self._skip_propagator, - ) - - # === Command Processing === - # Processes external commands (e.g., abort requests) - self._command_processor = CommandProcessor( - command_channel=self._command_channel, - graph_execution=self._graph_execution, - ) - - # Register command handlers - abort_handler = AbortCommandHandler() - self._command_processor.register_handler(AbortCommand, abort_handler) - - pause_handler = PauseCommandHandler() - self._command_processor.register_handler(PauseCommand, pause_handler) - - update_variables_handler = UpdateVariablesCommandHandler(self._graph_runtime_state.variable_pool) - self._command_processor.register_handler(UpdateVariablesCommand, update_variables_handler) - - # === Extensibility === - # Layers allow plugins to extend engine functionality - self._layers: list[GraphEngineLayer] = [] - - # === Worker Pool Setup === - # Capture execution context for worker threads - execution_context = capture_current_context() - - # Create worker pool for parallel node execution - self._worker_pool = WorkerPool( - ready_queue=self._ready_queue, - event_queue=self._event_queue, - graph=self._graph, - layers=self._layers, - execution_context=execution_context, - config=self._config, - ) - - # === Orchestration === - # Coordinates the overall execution lifecycle - self._execution_coordinator = ExecutionCoordinator( - graph_execution=self._graph_execution, - state_manager=self._state_manager, - command_processor=self._command_processor, - worker_pool=self._worker_pool, - ) - - # === Event Handler Registry === - # Central registry for handling all node execution events - self._event_handler_registry = EventHandler( - graph=self._graph, - graph_runtime_state=self._graph_runtime_state, - graph_execution=self._graph_execution, - response_coordinator=self._response_coordinator, - event_collector=self._event_manager, - edge_processor=self._edge_processor, - state_manager=self._state_manager, - error_handler=self._error_handler, - ) - - # Dispatches events and manages execution flow - self._dispatcher = Dispatcher( - event_queue=self._event_queue, - event_handler=self._event_handler_registry, - execution_coordinator=self._execution_coordinator, - event_emitter=self._event_manager, - ) - - # === Validation === - # Ensure all nodes share the same GraphRuntimeState instance - self._validate_graph_state_consistency() - - def _validate_graph_state_consistency(self) -> None: - """Validate that all nodes share the same GraphRuntimeState.""" - expected_state_id = id(self._graph_runtime_state) - for node in self._graph.nodes.values(): - if id(node.graph_runtime_state) != expected_state_id: - raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance") - - def _bind_layer_context( - self, - layer: GraphEngineLayer, - ) -> None: - layer.initialize(ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state), self._command_channel) - - def layer(self, layer: GraphEngineLayer) -> GraphEngine: - """Add a layer for extending functionality.""" - self._layers.append(layer) - self._bind_layer_context(layer) - return self - - def create_child_engine( - self, - *, - workflow_id: str, - graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - graph_config: dict[str, object] | Mapping[str, object], - root_node_id: str, - layers: list[GraphEngineLayer] | tuple[GraphEngineLayer, ...] = (), - ) -> GraphEngine: - return self._graph_runtime_state.create_child_engine( - workflow_id=workflow_id, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - graph_config=graph_config, - root_node_id=root_node_id, - layers=layers, - ) - - def run(self) -> Generator[GraphEngineEvent, None, None]: - """ - Execute the graph using the modular architecture. - - Returns: - Generator yielding GraphEngineEvent instances - """ - try: - # Initialize layers - self._initialize_layers() - - is_resume = self._graph_execution.started - if not is_resume: - self._graph_execution.start() - else: - self._graph_execution.paused = False - self._graph_execution.pause_reasons = [] - - start_event = GraphRunStartedEvent( - reason=WorkflowStartReason.RESUMPTION if is_resume else WorkflowStartReason.INITIAL, - ) - self._event_manager.notify_layers(start_event) - yield start_event - - # Start subsystems - self._start_execution(resume=is_resume) - - # Yield events as they occur - yield from self._event_manager.emit_events() - - # Handle completion - if self._graph_execution.is_paused: - pause_reasons = self._graph_execution.pause_reasons - assert pause_reasons, "pause_reasons should not be empty when execution is paused." - # Ensure we have a valid PauseReason for the event - paused_event = GraphRunPausedEvent( - reasons=pause_reasons, - outputs=self._graph_runtime_state.outputs, - ) - self._event_manager.notify_layers(paused_event) - yield paused_event - elif self._graph_execution.aborted: - abort_reason = "Workflow execution aborted by user command" - if self._graph_execution.error: - abort_reason = str(self._graph_execution.error) - aborted_event = GraphRunAbortedEvent( - reason=abort_reason, - outputs=self._graph_runtime_state.outputs, - ) - self._event_manager.notify_layers(aborted_event) - yield aborted_event - elif self._graph_execution.has_error: - if self._graph_execution.error: - raise self._graph_execution.error - else: - outputs = self._graph_runtime_state.outputs - exceptions_count = self._graph_execution.exceptions_count - if exceptions_count > 0: - partial_event = GraphRunPartialSucceededEvent( - exceptions_count=exceptions_count, - outputs=outputs, - ) - self._event_manager.notify_layers(partial_event) - yield partial_event - else: - succeeded_event = GraphRunSucceededEvent( - outputs=outputs, - ) - self._event_manager.notify_layers(succeeded_event) - yield succeeded_event - - except Exception as e: - failed_event = GraphRunFailedEvent( - error=str(e), - exceptions_count=self._graph_execution.exceptions_count, - ) - self._event_manager.notify_layers(failed_event) - yield failed_event - raise - - finally: - self._stop_execution() - - def _initialize_layers(self) -> None: - """Initialize layers with context.""" - self._event_manager.set_layers(self._layers) - for layer in self._layers: - try: - layer.on_graph_start() - except Exception: - logger.exception("Layer %s failed on_graph_start", layer.__class__.__name__) - - def _start_execution(self, *, resume: bool = False) -> None: - """Start execution subsystems.""" - paused_nodes: list[str] = [] - deferred_nodes: list[str] = [] - if resume: - paused_nodes = self._graph_runtime_state.consume_paused_nodes() - deferred_nodes = self._graph_runtime_state.consume_deferred_nodes() - - # Start worker pool (it calculates initial workers internally) - self._worker_pool.start() - - # Register response nodes - for node in self._graph.nodes.values(): - if node.execution_type == NodeExecutionType.RESPONSE: - self._response_coordinator.register(node.id) - - if not resume: - # Enqueue root node - root_node = self._graph.root_node - self._state_manager.enqueue_node(root_node.id) - self._state_manager.start_execution(root_node.id) - else: - seen_nodes: set[str] = set() - for node_id in paused_nodes + deferred_nodes: - if node_id in seen_nodes: - continue - seen_nodes.add(node_id) - self._state_manager.enqueue_node(node_id) - self._state_manager.start_execution(node_id) - - # Start dispatcher - self._dispatcher.start() - - def _stop_execution(self) -> None: - """Stop execution subsystems.""" - self._dispatcher.stop() - self._worker_pool.stop() - # Don't mark complete here as the dispatcher already does it - - # Notify layers - for layer in self._layers: - try: - layer.on_graph_end(self._graph_execution.error) - except Exception: - logger.exception("Layer %s failed on_graph_end", layer.__class__.__name__) - - # Public property accessors for attributes that need external access - @property - def graph_runtime_state(self) -> GraphRuntimeState: - """Get the graph runtime state.""" - return self._graph_runtime_state diff --git a/api/dify_graph/graph_engine/graph_state_manager.py b/api/dify_graph/graph_engine/graph_state_manager.py deleted file mode 100644 index 922a9684355..00000000000 --- a/api/dify_graph/graph_engine/graph_state_manager.py +++ /dev/null @@ -1,290 +0,0 @@ -""" -Graph state manager that combines node, edge, and execution tracking. -""" - -import threading -from collections.abc import Sequence -from typing import TypedDict, final - -from dify_graph.enums import NodeState -from dify_graph.graph import Edge, Graph - -from .ready_queue import ReadyQueue - - -class EdgeStateAnalysis(TypedDict): - """Analysis result for edge states.""" - - has_unknown: bool - has_taken: bool - all_skipped: bool - - -@final -class GraphStateManager: - def __init__(self, graph: Graph, ready_queue: ReadyQueue) -> None: - """ - Initialize the state manager. - - Args: - graph: The workflow graph - ready_queue: Queue for nodes ready to execute - """ - self._graph = graph - self._ready_queue = ready_queue - self._lock = threading.RLock() - - # Execution tracking state - self._executing_nodes: set[str] = set() - - # ============= Node State Operations ============= - - def enqueue_node(self, node_id: str) -> None: - """ - Mark a node as TAKEN and add it to the ready queue. - - This combines the state transition and enqueueing operations - that always occur together when preparing a node for execution. - - Args: - node_id: The ID of the node to enqueue - """ - with self._lock: - self._graph.nodes[node_id].state = NodeState.TAKEN - self._ready_queue.put(node_id) - - def mark_node_skipped(self, node_id: str) -> None: - """ - Mark a node as SKIPPED. - - Args: - node_id: The ID of the node to skip - """ - with self._lock: - self._graph.nodes[node_id].state = NodeState.SKIPPED - - def is_node_ready(self, node_id: str) -> bool: - """ - Check if a node is ready to be executed. - - A node is ready when all its incoming edges from taken branches - have been satisfied. - - Args: - node_id: The ID of the node to check - - Returns: - True if the node is ready for execution - """ - with self._lock: - # Get all incoming edges to this node - incoming_edges = self._graph.get_incoming_edges(node_id) - - # If no incoming edges, node is always ready - if not incoming_edges: - return True - - # If any edge is UNKNOWN, node is not ready - if any(edge.state == NodeState.UNKNOWN for edge in incoming_edges): - return False - - # Node is ready if at least one edge is TAKEN - return any(edge.state == NodeState.TAKEN for edge in incoming_edges) - - def get_node_state(self, node_id: str) -> NodeState: - """ - Get the current state of a node. - - Args: - node_id: The ID of the node - - Returns: - The current node state - """ - with self._lock: - return self._graph.nodes[node_id].state - - # ============= Edge State Operations ============= - - def mark_edge_taken(self, edge_id: str) -> None: - """ - Mark an edge as TAKEN. - - Args: - edge_id: The ID of the edge to mark - """ - with self._lock: - self._graph.edges[edge_id].state = NodeState.TAKEN - - def mark_edge_skipped(self, edge_id: str) -> None: - """ - Mark an edge as SKIPPED. - - Args: - edge_id: The ID of the edge to mark - """ - with self._lock: - self._graph.edges[edge_id].state = NodeState.SKIPPED - - def analyze_edge_states(self, edges: list[Edge]) -> EdgeStateAnalysis: - """ - Analyze the states of edges and return summary flags. - - Args: - edges: List of edges to analyze - - Returns: - Analysis result with state flags - """ - with self._lock: - states = {edge.state for edge in edges} - - return EdgeStateAnalysis( - has_unknown=NodeState.UNKNOWN in states, - has_taken=NodeState.TAKEN in states, - all_skipped=states == {NodeState.SKIPPED} if states else True, - ) - - def get_edge_state(self, edge_id: str) -> NodeState: - """ - Get the current state of an edge. - - Args: - edge_id: The ID of the edge - - Returns: - The current edge state - """ - with self._lock: - return self._graph.edges[edge_id].state - - def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]: - """ - Categorize branch edges into selected and unselected. - - Args: - node_id: The ID of the branch node - selected_handle: The handle of the selected edge - - Returns: - A tuple of (selected_edges, unselected_edges) - """ - with self._lock: - outgoing_edges = self._graph.get_outgoing_edges(node_id) - selected_edges: list[Edge] = [] - unselected_edges: list[Edge] = [] - - for edge in outgoing_edges: - if edge.source_handle == selected_handle: - selected_edges.append(edge) - else: - unselected_edges.append(edge) - - return selected_edges, unselected_edges - - # ============= Execution Tracking Operations ============= - - def start_execution(self, node_id: str) -> None: - """ - Mark a node as executing. - - Args: - node_id: The ID of the node starting execution - """ - with self._lock: - self._executing_nodes.add(node_id) - - def finish_execution(self, node_id: str) -> None: - """ - Mark a node as no longer executing. - - Args: - node_id: The ID of the node finishing execution - """ - with self._lock: - self._executing_nodes.discard(node_id) - - def is_executing(self, node_id: str) -> bool: - """ - Check if a node is currently executing. - - Args: - node_id: The ID of the node to check - - Returns: - True if the node is executing - """ - with self._lock: - return node_id in self._executing_nodes - - def get_executing_count(self) -> int: - """ - Get the count of currently executing nodes. - - Returns: - Number of executing nodes - """ - # This count is a best-effort snapshot and can change concurrently. - # Only use it for pause-drain checks where scheduling is already frozen. - with self._lock: - return len(self._executing_nodes) - - def get_executing_nodes(self) -> set[str]: - """ - Get a copy of the set of executing node IDs. - - Returns: - Set of node IDs currently executing - """ - with self._lock: - return self._executing_nodes.copy() - - def clear_executing(self) -> None: - """Clear all executing nodes.""" - with self._lock: - self._executing_nodes.clear() - - # ============= Composite Operations ============= - - def is_execution_complete(self) -> bool: - """ - Check if graph execution is complete. - - Execution is complete when: - - Ready queue is empty - - No nodes are executing - - Returns: - True if execution is complete - """ - with self._lock: - return self._ready_queue.empty() and len(self._executing_nodes) == 0 - - def get_queue_depth(self) -> int: - """ - Get the current depth of the ready queue. - - Returns: - Number of nodes in the ready queue - """ - return self._ready_queue.qsize() - - def get_execution_stats(self) -> dict[str, int]: - """ - Get execution statistics. - - Returns: - Dictionary with execution statistics - """ - with self._lock: - taken_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.TAKEN) - skipped_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.SKIPPED) - unknown_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.UNKNOWN) - - return { - "queue_depth": self._ready_queue.qsize(), - "executing": len(self._executing_nodes), - "taken_nodes": taken_nodes, - "skipped_nodes": skipped_nodes, - "unknown_nodes": unknown_nodes, - } diff --git a/api/dify_graph/graph_engine/graph_traversal/__init__.py b/api/dify_graph/graph_engine/graph_traversal/__init__.py deleted file mode 100644 index d629140d066..00000000000 --- a/api/dify_graph/graph_engine/graph_traversal/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -""" -Graph traversal subsystem for graph engine. - -This package handles graph navigation, edge processing, -and skip propagation logic. -""" - -from .edge_processor import EdgeProcessor -from .skip_propagator import SkipPropagator - -__all__ = [ - "EdgeProcessor", - "SkipPropagator", -] diff --git a/api/dify_graph/graph_engine/graph_traversal/edge_processor.py b/api/dify_graph/graph_engine/graph_traversal/edge_processor.py deleted file mode 100644 index c4625a8ff75..00000000000 --- a/api/dify_graph/graph_engine/graph_traversal/edge_processor.py +++ /dev/null @@ -1,201 +0,0 @@ -""" -Edge processing logic for graph traversal. -""" - -from collections.abc import Sequence -from typing import TYPE_CHECKING, final - -from dify_graph.enums import NodeExecutionType -from dify_graph.graph import Edge, Graph -from dify_graph.graph_events import NodeRunStreamChunkEvent - -from ..graph_state_manager import GraphStateManager -from ..response_coordinator import ResponseStreamCoordinator - -if TYPE_CHECKING: - from .skip_propagator import SkipPropagator - - -@final -class EdgeProcessor: - """ - Processes edges during graph execution. - - This handles marking edges as taken or skipped, notifying - the response coordinator, triggering downstream node execution, - and managing branch node logic. - """ - - def __init__( - self, - graph: Graph, - state_manager: GraphStateManager, - response_coordinator: ResponseStreamCoordinator, - skip_propagator: "SkipPropagator", - ) -> None: - """ - Initialize the edge processor. - - Args: - graph: The workflow graph - state_manager: Unified state manager - response_coordinator: Response stream coordinator - skip_propagator: Propagator for skip states - """ - self._graph = graph - self._state_manager = state_manager - self._response_coordinator = response_coordinator - self._skip_propagator = skip_propagator - - def process_node_success( - self, node_id: str, selected_handle: str | None = None - ) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: - """ - Process edges after a node succeeds. - - Args: - node_id: The ID of the succeeded node - selected_handle: For branch nodes, the selected edge handle - - Returns: - Tuple of (list of downstream node IDs that are now ready, list of streaming events) - """ - node = self._graph.nodes[node_id] - - if node.execution_type == NodeExecutionType.BRANCH: - return self._process_branch_node_edges(node_id, selected_handle) - else: - return self._process_non_branch_node_edges(node_id) - - def _process_non_branch_node_edges(self, node_id: str) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: - """ - Process edges for non-branch nodes (mark all as TAKEN). - - Args: - node_id: The ID of the succeeded node - - Returns: - Tuple of (list of downstream nodes ready for execution, list of streaming events) - """ - ready_nodes: list[str] = [] - all_streaming_events: list[NodeRunStreamChunkEvent] = [] - outgoing_edges = self._graph.get_outgoing_edges(node_id) - - for edge in outgoing_edges: - nodes, events = self._process_taken_edge(edge) - ready_nodes.extend(nodes) - all_streaming_events.extend(events) - - return ready_nodes, all_streaming_events - - def _process_branch_node_edges( - self, node_id: str, selected_handle: str | None - ) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: - """ - Process edges for branch nodes. - - Args: - node_id: The ID of the branch node - selected_handle: The handle of the selected edge - - Returns: - Tuple of (list of downstream nodes ready for execution, list of streaming events) - - Raises: - ValueError: If no edge was selected - """ - if not selected_handle: - raise ValueError(f"Branch node {node_id} did not select any edge") - - ready_nodes: list[str] = [] - all_streaming_events: list[NodeRunStreamChunkEvent] = [] - - # Categorize edges - selected_edges, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle) - - # Process unselected edges first (mark as skipped) - for edge in unselected_edges: - self._process_skipped_edge(edge) - - # Process selected edges - for edge in selected_edges: - nodes, events = self._process_taken_edge(edge) - ready_nodes.extend(nodes) - all_streaming_events.extend(events) - - return ready_nodes, all_streaming_events - - def _process_taken_edge(self, edge: Edge) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: - """ - Mark edge as taken and check downstream node. - - Args: - edge: The edge to process - - Returns: - Tuple of (list containing downstream node ID if it's ready, list of streaming events) - """ - # Mark edge as taken - self._state_manager.mark_edge_taken(edge.id) - - # Notify response coordinator and get streaming events - streaming_events = self._response_coordinator.on_edge_taken(edge.id) - - # Check if downstream node is ready - ready_nodes: list[str] = [] - if self._state_manager.is_node_ready(edge.head): - ready_nodes.append(edge.head) - - return ready_nodes, streaming_events - - def _process_skipped_edge(self, edge: Edge) -> None: - """ - Mark edge as skipped. - - Args: - edge: The edge to skip - """ - self._state_manager.mark_edge_skipped(edge.id) - - def handle_branch_completion( - self, node_id: str, selected_handle: str | None - ) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: - """ - Handle completion of a branch node. - - Args: - node_id: The ID of the branch node - selected_handle: The handle of the selected branch - - Returns: - Tuple of (list of downstream nodes ready for execution, list of streaming events) - - Raises: - ValueError: If no branch was selected - """ - if not selected_handle: - raise ValueError(f"Branch node {node_id} completed without selecting a branch") - - # Categorize edges into selected and unselected - _, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle) - - # Skip all unselected paths - self._skip_propagator.skip_branch_paths(unselected_edges) - - # Process selected edges and get ready nodes and streaming events - return self.process_node_success(node_id, selected_handle) - - def validate_branch_selection(self, node_id: str, selected_handle: str) -> bool: - """ - Validate that a branch selection is valid. - - Args: - node_id: The ID of the branch node - selected_handle: The handle to validate - - Returns: - True if the selection is valid - """ - outgoing_edges = self._graph.get_outgoing_edges(node_id) - valid_handles = {edge.source_handle for edge in outgoing_edges} - return selected_handle in valid_handles diff --git a/api/dify_graph/graph_engine/graph_traversal/skip_propagator.py b/api/dify_graph/graph_engine/graph_traversal/skip_propagator.py deleted file mode 100644 index 76445bccd21..00000000000 --- a/api/dify_graph/graph_engine/graph_traversal/skip_propagator.py +++ /dev/null @@ -1,96 +0,0 @@ -""" -Skip state propagation through the graph. -""" - -from collections.abc import Sequence -from typing import final - -from dify_graph.graph import Edge, Graph - -from ..graph_state_manager import GraphStateManager - - -@final -class SkipPropagator: - """ - Propagates skip states through the graph. - - When a node is skipped, this ensures all downstream nodes - that depend solely on it are also skipped. - """ - - def __init__( - self, - graph: Graph, - state_manager: GraphStateManager, - ) -> None: - """ - Initialize the skip propagator. - - Args: - graph: The workflow graph - state_manager: Unified state manager - """ - self._graph = graph - self._state_manager = state_manager - - def propagate_skip_from_edge(self, edge_id: str) -> None: - """ - Recursively propagate skip state from a skipped edge. - - Rules: - - If a node has any UNKNOWN incoming edges, stop processing - - If all incoming edges are SKIPPED, skip the node and its edges - - If any incoming edge is TAKEN, the node may still execute - - Args: - edge_id: The ID of the skipped edge to start from - """ - downstream_node_id = self._graph.edges[edge_id].head - incoming_edges = self._graph.get_incoming_edges(downstream_node_id) - - # Analyze edge states - edge_states = self._state_manager.analyze_edge_states(incoming_edges) - - # Stop if there are unknown edges (not yet processed) - if edge_states["has_unknown"]: - return - - # If any edge is taken, node may still execute - if edge_states["has_taken"]: - # Enqueue node - self._state_manager.enqueue_node(downstream_node_id) - self._state_manager.start_execution(downstream_node_id) - return - - # All edges are skipped, propagate skip to this node - if edge_states["all_skipped"]: - self._propagate_skip_to_node(downstream_node_id) - - def _propagate_skip_to_node(self, node_id: str) -> None: - """ - Mark a node and all its outgoing edges as skipped. - - Args: - node_id: The ID of the node to skip - """ - # Mark node as skipped - self._state_manager.mark_node_skipped(node_id) - - # Mark all outgoing edges as skipped and propagate - outgoing_edges = self._graph.get_outgoing_edges(node_id) - for edge in outgoing_edges: - self._state_manager.mark_edge_skipped(edge.id) - # Recursively propagate skip - self.propagate_skip_from_edge(edge.id) - - def skip_branch_paths(self, unselected_edges: Sequence[Edge]) -> None: - """ - Skip all paths from unselected branch edges. - - Args: - unselected_edges: List of edges not taken by the branch - """ - for edge in unselected_edges: - self._state_manager.mark_edge_skipped(edge.id) - self.propagate_skip_from_edge(edge.id) diff --git a/api/dify_graph/graph_engine/layers/README.md b/api/dify_graph/graph_engine/layers/README.md deleted file mode 100644 index b0f295037c0..00000000000 --- a/api/dify_graph/graph_engine/layers/README.md +++ /dev/null @@ -1,55 +0,0 @@ -# Layers - -Pluggable middleware for engine extensions. - -## Components - -### Layer (base) - -Abstract base class for layers. - -- `initialize()` - Receive runtime context (runtime state is bound here and always available to hooks) -- `on_graph_start()` - Execution start hook -- `on_event()` - Process all events -- `on_graph_end()` - Execution end hook - -### DebugLoggingLayer - -Comprehensive execution logging. - -- Configurable detail levels -- Tracks execution statistics -- Truncates long values - -## Usage - -```python -debug_layer = DebugLoggingLayer( - level="INFO", - include_outputs=True -) - -engine = GraphEngine(graph) -engine.layer(debug_layer) -engine.run() -``` - -`engine.layer()` binds the read-only runtime state before execution, so -`graph_runtime_state` is always available inside layer hooks. - -## Custom Layers - -```python -class MetricsLayer(Layer): - def on_event(self, event): - if isinstance(event, NodeRunSucceededEvent): - self.metrics[event.node_id] = event.elapsed_time -``` - -## Configuration - -**DebugLoggingLayer Options:** - -- `level` - Log level (INFO, DEBUG, ERROR) -- `include_inputs/outputs` - Log data values -- `max_value_length` - Truncate long values diff --git a/api/dify_graph/graph_engine/layers/__init__.py b/api/dify_graph/graph_engine/layers/__init__.py deleted file mode 100644 index 0a29a529936..00000000000 --- a/api/dify_graph/graph_engine/layers/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -Layer system for GraphEngine extensibility. - -This module provides the layer infrastructure for extending GraphEngine functionality -with middleware-like components that can observe events and interact with execution. -""" - -from .base import GraphEngineLayer -from .debug_logging import DebugLoggingLayer -from .execution_limits import ExecutionLimitsLayer - -__all__ = [ - "DebugLoggingLayer", - "ExecutionLimitsLayer", - "GraphEngineLayer", -] diff --git a/api/dify_graph/graph_engine/layers/base.py b/api/dify_graph/graph_engine/layers/base.py deleted file mode 100644 index 890336c1cac..00000000000 --- a/api/dify_graph/graph_engine/layers/base.py +++ /dev/null @@ -1,128 +0,0 @@ -""" -Base layer class for GraphEngine extensions. - -This module provides the abstract base class for implementing layers that can -intercept and respond to GraphEngine events. -""" - -from abc import ABC, abstractmethod - -from dify_graph.graph_engine.protocols.command_channel import CommandChannel -from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase -from dify_graph.nodes.base.node import Node -from dify_graph.runtime import ReadOnlyGraphRuntimeState - - -class GraphEngineLayerNotInitializedError(Exception): - """Raised when a layer's runtime state is accessed before initialization.""" - - def __init__(self, layer_name: str | None = None) -> None: - name = layer_name or "GraphEngineLayer" - super().__init__(f"{name} runtime state is not initialized. Bind the layer to a GraphEngine before access.") - - -class GraphEngineLayer(ABC): - """ - Abstract base class for GraphEngine layers. - - Layers are middleware-like components that can: - - Observe all events emitted by the GraphEngine - - Access the graph runtime state - - Send commands to control execution - - Subclasses should override the constructor to accept configuration parameters, - then implement the three lifecycle methods. - """ - - def __init__(self) -> None: - """Initialize the layer. Subclasses can override with custom parameters.""" - self._graph_runtime_state: ReadOnlyGraphRuntimeState | None = None - self.command_channel: CommandChannel | None = None - - @property - def graph_runtime_state(self) -> ReadOnlyGraphRuntimeState: - if self._graph_runtime_state is None: - raise GraphEngineLayerNotInitializedError(type(self).__name__) - return self._graph_runtime_state - - def initialize(self, graph_runtime_state: ReadOnlyGraphRuntimeState, command_channel: CommandChannel) -> None: - """ - Initialize the layer with engine dependencies. - - Called by GraphEngine to inject the read-only runtime state and command channel. - This is invoked when the layer is registered with a `GraphEngine` instance. - Implementations should be idempotent. - Args: - graph_runtime_state: Read-only view of the runtime state - command_channel: Channel for sending commands to the engine - """ - self._graph_runtime_state = graph_runtime_state - self.command_channel = command_channel - - @abstractmethod - def on_graph_start(self) -> None: - """ - Called when graph execution starts. - - This is called after the engine has been initialized but before any nodes - are executed. Layers can use this to set up resources or log start information. - """ - pass - - @abstractmethod - def on_event(self, event: GraphEngineEvent) -> None: - """ - Called for every event emitted by the engine. - - This method receives all events generated during graph execution, including: - - Graph lifecycle events (start, success, failure) - - Node execution events (start, success, failure, retry) - - Stream events for response nodes - - Container events (iteration, loop) - - Args: - event: The event emitted by the engine - """ - pass - - @abstractmethod - def on_graph_end(self, error: Exception | None) -> None: - """ - Called when graph execution ends. - - This is called after all nodes have been executed or when execution is - aborted. Layers can use this to clean up resources or log final state. - - Args: - error: The exception that caused execution to fail, or None if successful - """ - pass - - def on_node_run_start(self, node: Node) -> None: - """ - Called immediately before a node begins execution. - - Layers can override to inject behavior (e.g., start spans) prior to node execution. - The node's execution ID is available via `node._node_execution_id` and will be - consistent with all events emitted by this node execution. - - Args: - node: The node instance about to be executed - """ - return - - def on_node_run_end( - self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None - ) -> None: - """ - Called after a node finishes execution. - - The node's execution ID is available via `node._node_execution_id` and matches - the `id` field in all events emitted by this node execution. - - Args: - node: The node instance that just finished execution - error: Exception instance if the node failed, otherwise None - result_event: The final result event from node execution (succeeded/failed/paused), if any - """ - return diff --git a/api/dify_graph/graph_engine/layers/debug_logging.py b/api/dify_graph/graph_engine/layers/debug_logging.py deleted file mode 100644 index 1af2e2db9ed..00000000000 --- a/api/dify_graph/graph_engine/layers/debug_logging.py +++ /dev/null @@ -1,247 +0,0 @@ -""" -Debug logging layer for GraphEngine. - -This module provides a layer that logs all events and state changes during -graph execution for debugging purposes. -""" - -import logging -from collections.abc import Mapping -from typing import Any, final - -from typing_extensions import override - -from dify_graph.graph_events import ( - GraphEngineEvent, - GraphRunAbortedEvent, - GraphRunFailedEvent, - GraphRunPartialSucceededEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunIterationFailedEvent, - NodeRunIterationNextEvent, - NodeRunIterationStartedEvent, - NodeRunIterationSucceededEvent, - NodeRunLoopFailedEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, - NodeRunRetryEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) - -from .base import GraphEngineLayer - - -@final -class DebugLoggingLayer(GraphEngineLayer): - """ - A layer that provides comprehensive logging of GraphEngine execution. - - This layer logs all events with configurable detail levels, helping developers - debug workflow execution and understand the flow of events. - """ - - def __init__( - self, - level: str = "INFO", - include_inputs: bool = False, - include_outputs: bool = True, - include_process_data: bool = False, - logger_name: str = "GraphEngine.Debug", - max_value_length: int = 500, - ) -> None: - """ - Initialize the debug logging layer. - - Args: - level: Logging level (DEBUG, INFO, WARNING, ERROR) - include_inputs: Whether to log node input values - include_outputs: Whether to log node output values - include_process_data: Whether to log node process data - logger_name: Name of the logger to use - max_value_length: Maximum length of logged values (truncated if longer) - """ - super().__init__() - self.level = level - self.include_inputs = include_inputs - self.include_outputs = include_outputs - self.include_process_data = include_process_data - self.max_value_length = max_value_length - - # Set up logger - self.logger = logging.getLogger(logger_name) - log_level = getattr(logging, level.upper(), logging.INFO) - self.logger.setLevel(log_level) - - # Track execution stats - self.node_count = 0 - self.success_count = 0 - self.failure_count = 0 - self.retry_count = 0 - - def _truncate_value(self, value: Any) -> str: - """Truncate long values for logging.""" - str_value = str(value) - if len(str_value) > self.max_value_length: - return str_value[: self.max_value_length] + "... (truncated)" - return str_value - - def _format_dict(self, data: dict[str, Any] | Mapping[str, Any]) -> str: - """Format a dictionary or mapping for logging with truncation.""" - if not data: - return "{}" - - formatted_items: list[str] = [] - for key, value in data.items(): - formatted_value = self._truncate_value(value) - formatted_items.append(f" {key}: {formatted_value}") - - return "{\n" + ",\n".join(formatted_items) + "\n}" - - @override - def on_graph_start(self) -> None: - """Log graph execution start.""" - self.logger.info("=" * 80) - self.logger.info("🚀 GRAPH EXECUTION STARTED") - self.logger.info("=" * 80) - # Log initial state - self.logger.info("Initial State:") - - @override - def on_event(self, event: GraphEngineEvent) -> None: - """Log individual events based on their type.""" - event_class = event.__class__.__name__ - - # Graph-level events - if isinstance(event, GraphRunStartedEvent): - self.logger.debug("Graph run started event") - - elif isinstance(event, GraphRunSucceededEvent): - self.logger.info("✅ Graph run succeeded") - if self.include_outputs and event.outputs: - self.logger.info(" Final outputs: %s", self._format_dict(event.outputs)) - - elif isinstance(event, GraphRunPartialSucceededEvent): - self.logger.warning("⚠️ Graph run partially succeeded") - if event.exceptions_count > 0: - self.logger.warning(" Total exceptions: %s", event.exceptions_count) - if self.include_outputs and event.outputs: - self.logger.info(" Final outputs: %s", self._format_dict(event.outputs)) - - elif isinstance(event, GraphRunFailedEvent): - self.logger.error("❌ Graph run failed: %s", event.error) - if event.exceptions_count > 0: - self.logger.error(" Total exceptions: %s", event.exceptions_count) - - elif isinstance(event, GraphRunAbortedEvent): - self.logger.warning("⚠️ Graph run aborted: %s", event.reason) - if event.outputs: - self.logger.info(" Partial outputs: %s", self._format_dict(event.outputs)) - - # Node-level events - # Retry before Started because Retry subclasses Started; - elif isinstance(event, NodeRunRetryEvent): - self.retry_count += 1 - self.logger.warning("🔄 Node retry: %s (attempt %s)", event.node_id, event.retry_index) - self.logger.warning(" Previous error: %s", event.error) - - elif isinstance(event, NodeRunStartedEvent): - self.node_count += 1 - self.logger.info('▶️ Node started: %s - "%s" (type: %s)', event.node_id, event.node_title, event.node_type) - - if self.include_inputs and event.node_run_result.inputs: - self.logger.debug(" Inputs: %s", self._format_dict(event.node_run_result.inputs)) - - elif isinstance(event, NodeRunSucceededEvent): - self.success_count += 1 - self.logger.info("✅ Node succeeded: %s", event.node_id) - - if self.include_outputs and event.node_run_result.outputs: - self.logger.debug(" Outputs: %s", self._format_dict(event.node_run_result.outputs)) - - if self.include_process_data and event.node_run_result.process_data: - self.logger.debug(" Process data: %s", self._format_dict(event.node_run_result.process_data)) - - elif isinstance(event, NodeRunFailedEvent): - self.failure_count += 1 - self.logger.error("❌ Node failed: %s", event.node_id) - self.logger.error(" Error: %s", event.error) - - if event.node_run_result.error: - self.logger.error(" Details: %s", event.node_run_result.error) - - elif isinstance(event, NodeRunExceptionEvent): - self.logger.warning("⚠️ Node exception handled: %s", event.node_id) - self.logger.warning(" Error: %s", event.error) - - elif isinstance(event, NodeRunStreamChunkEvent): - # Log stream chunks at debug level to avoid spam - final_indicator = " (FINAL)" if event.is_final else "" - self.logger.debug( - "📝 Stream chunk from %s%s: %s", event.node_id, final_indicator, self._truncate_value(event.chunk) - ) - - # Iteration events - elif isinstance(event, NodeRunIterationStartedEvent): - self.logger.info("🔁 Iteration started: %s", event.node_id) - - elif isinstance(event, NodeRunIterationNextEvent): - self.logger.debug(" Iteration next: %s (index: %s)", event.node_id, event.index) - - elif isinstance(event, NodeRunIterationSucceededEvent): - self.logger.info("✅ Iteration succeeded: %s", event.node_id) - if self.include_outputs and event.outputs: - self.logger.debug(" Outputs: %s", self._format_dict(event.outputs)) - - elif isinstance(event, NodeRunIterationFailedEvent): - self.logger.error("❌ Iteration failed: %s", event.node_id) - self.logger.error(" Error: %s", event.error) - - # Loop events - elif isinstance(event, NodeRunLoopStartedEvent): - self.logger.info("🔄 Loop started: %s", event.node_id) - - elif isinstance(event, NodeRunLoopNextEvent): - self.logger.debug(" Loop iteration: %s (index: %s)", event.node_id, event.index) - - elif isinstance(event, NodeRunLoopSucceededEvent): - self.logger.info("✅ Loop succeeded: %s", event.node_id) - if self.include_outputs and event.outputs: - self.logger.debug(" Outputs: %s", self._format_dict(event.outputs)) - - elif isinstance(event, NodeRunLoopFailedEvent): - self.logger.error("❌ Loop failed: %s", event.node_id) - self.logger.error(" Error: %s", event.error) - - else: - # Log unknown events at debug level - self.logger.debug("Event: %s", event_class) - - @override - def on_graph_end(self, error: Exception | None) -> None: - """Log graph execution end with summary statistics.""" - self.logger.info("=" * 80) - - if error: - self.logger.error("🔴 GRAPH EXECUTION FAILED") - self.logger.error(" Error: %s", error) - else: - self.logger.info("🎉 GRAPH EXECUTION COMPLETED SUCCESSFULLY") - - # Log execution statistics - self.logger.info("Execution Statistics:") - self.logger.info(" Total nodes executed: %s", self.node_count) - self.logger.info(" Successful nodes: %s", self.success_count) - self.logger.info(" Failed nodes: %s", self.failure_count) - self.logger.info(" Node retries: %s", self.retry_count) - - # Log final state if available - if self.include_outputs and self.graph_runtime_state.outputs: - self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs)) - - self.logger.info("=" * 80) diff --git a/api/dify_graph/graph_engine/layers/execution_limits.py b/api/dify_graph/graph_engine/layers/execution_limits.py deleted file mode 100644 index 48ba5608d94..00000000000 --- a/api/dify_graph/graph_engine/layers/execution_limits.py +++ /dev/null @@ -1,150 +0,0 @@ -""" -Execution limits layer for GraphEngine. - -This layer monitors workflow execution to enforce limits on: -- Maximum execution steps -- Maximum execution time - -When limits are exceeded, the layer automatically aborts execution. -""" - -import logging -import time -from enum import StrEnum -from typing import final - -from typing_extensions import override - -from dify_graph.graph_engine.entities.commands import AbortCommand, CommandType -from dify_graph.graph_engine.layers import GraphEngineLayer -from dify_graph.graph_events import ( - GraphEngineEvent, - NodeRunStartedEvent, -) -from dify_graph.graph_events.node import NodeRunFailedEvent, NodeRunSucceededEvent - - -class LimitType(StrEnum): - """Types of execution limits that can be exceeded.""" - - STEP_LIMIT = "step_limit" - TIME_LIMIT = "time_limit" - - -@final -class ExecutionLimitsLayer(GraphEngineLayer): - """ - Layer that enforces execution limits for workflows. - - Monitors: - - Step count: Tracks number of node executions - - Time limit: Monitors total execution time - - Automatically aborts execution when limits are exceeded. - """ - - def __init__(self, max_steps: int, max_time: int) -> None: - """ - Initialize the execution limits layer. - - Args: - max_steps: Maximum number of execution steps allowed - max_time: Maximum execution time in seconds allowed - """ - super().__init__() - self.max_steps = max_steps - self.max_time = max_time - - # Runtime tracking - self.start_time: float | None = None - self.step_count = 0 - self.logger = logging.getLogger(__name__) - - # State tracking - self._execution_started = False - self._execution_ended = False - self._abort_sent = False # Track if abort command has been sent - - @override - def on_graph_start(self) -> None: - """Called when graph execution starts.""" - self.start_time = time.time() - self.step_count = 0 - self._execution_started = True - self._execution_ended = False - self._abort_sent = False - - self.logger.debug("Execution limits monitoring started") - - @override - def on_event(self, event: GraphEngineEvent) -> None: - """ - Called for every event emitted by the engine. - - Monitors execution progress and enforces limits. - """ - if not self._execution_started or self._execution_ended or self._abort_sent: - return - - # Track step count for node execution events - if isinstance(event, NodeRunStartedEvent): - self.step_count += 1 - self.logger.debug("Step %d started: %s", self.step_count, event.node_id) - - # Check step limit when node execution completes - if isinstance(event, NodeRunSucceededEvent | NodeRunFailedEvent): - if self._reached_step_limitation(): - self._send_abort_command(LimitType.STEP_LIMIT) - - if self._reached_time_limitation(): - self._send_abort_command(LimitType.TIME_LIMIT) - - @override - def on_graph_end(self, error: Exception | None) -> None: - """Called when graph execution ends.""" - if self._execution_started and not self._execution_ended: - self._execution_ended = True - - if self.start_time: - total_time = time.time() - self.start_time - self.logger.debug("Execution completed: %d steps in %.2f seconds", self.step_count, total_time) - - def _reached_step_limitation(self) -> bool: - """Check if step count limit has been exceeded.""" - return self.step_count > self.max_steps - - def _reached_time_limitation(self) -> bool: - """Check if time limit has been exceeded.""" - return self.start_time is not None and (time.time() - self.start_time) > self.max_time - - def _send_abort_command(self, limit_type: LimitType) -> None: - """ - Send abort command due to limit violation. - - Args: - limit_type: Type of limit exceeded - """ - if not self.command_channel or not self._execution_started or self._execution_ended or self._abort_sent: - return - - # Format detailed reason message - if limit_type == LimitType.STEP_LIMIT: - reason = f"Maximum execution steps exceeded: {self.step_count} > {self.max_steps}" - elif limit_type == LimitType.TIME_LIMIT: - elapsed_time = time.time() - self.start_time if self.start_time else 0 - reason = f"Maximum execution time exceeded: {elapsed_time:.2f}s > {self.max_time}s" - - self.logger.warning("Execution limit exceeded: %s", reason) - - try: - # Send abort command to the engine - abort_command = AbortCommand(command_type=CommandType.ABORT, reason=reason) - self.command_channel.send_command(abort_command) - - # Mark that abort has been sent to prevent duplicate commands - self._abort_sent = True - - self.logger.debug("Abort command sent to engine") - - except Exception: - self.logger.exception("Failed to send abort command") diff --git a/api/dify_graph/graph_engine/manager.py b/api/dify_graph/graph_engine/manager.py deleted file mode 100644 index 955c1490694..00000000000 --- a/api/dify_graph/graph_engine/manager.py +++ /dev/null @@ -1,79 +0,0 @@ -""" -GraphEngine Manager for sending control commands via Redis channel. - -This module provides a simplified interface for controlling workflow executions -using the new Redis command channel, without requiring user permission checks. -Callers must provide a Redis client dependency from outside the workflow package. -""" - -import logging -from collections.abc import Sequence -from typing import final - -from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel, RedisClientProtocol -from dify_graph.graph_engine.entities.commands import ( - AbortCommand, - GraphEngineCommand, - PauseCommand, - UpdateVariablesCommand, - VariableUpdate, -) - -logger = logging.getLogger(__name__) - - -@final -class GraphEngineManager: - """ - Manager for sending control commands to GraphEngine instances. - - This class provides a simple interface for controlling workflow executions - by sending commands through Redis channels, without user validation. - """ - - _redis_client: RedisClientProtocol - - def __init__(self, redis_client: RedisClientProtocol) -> None: - self._redis_client = redis_client - - def send_stop_command(self, task_id: str, reason: str | None = None) -> None: - """ - Send a stop command to a running workflow. - - Args: - task_id: The task ID of the workflow to stop - reason: Optional reason for stopping (defaults to "User requested stop") - """ - abort_command = AbortCommand(reason=reason or "User requested stop") - self._send_command(task_id, abort_command) - - def send_pause_command(self, task_id: str, reason: str | None = None) -> None: - """Send a pause command to a running workflow.""" - - pause_command = PauseCommand(reason=reason or "User requested pause") - self._send_command(task_id, pause_command) - - def send_update_variables_command(self, task_id: str, updates: Sequence[VariableUpdate]) -> None: - """Send a command to update variables in a running workflow.""" - - if not updates: - return - - update_command = UpdateVariablesCommand(updates=updates) - self._send_command(task_id, update_command) - - def _send_command(self, task_id: str, command: GraphEngineCommand) -> None: - """Send a command to the workflow-specific Redis channel.""" - - if not task_id: - return - - channel_key = f"workflow:{task_id}:commands" - channel = RedisChannel(self._redis_client, channel_key) - - try: - channel.send_command(command) - except Exception: - # Silently fail if Redis is unavailable - # The legacy control mechanisms will still work - logger.exception("Failed to send graph engine command %s for task %s", command.__class__.__name__, task_id) diff --git a/api/dify_graph/graph_engine/orchestration/__init__.py b/api/dify_graph/graph_engine/orchestration/__init__.py deleted file mode 100644 index de08e942fb3..00000000000 --- a/api/dify_graph/graph_engine/orchestration/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -""" -Orchestration subsystem for graph engine. - -This package coordinates the overall execution flow between -different subsystems. -""" - -from .dispatcher import Dispatcher -from .execution_coordinator import ExecutionCoordinator - -__all__ = [ - "Dispatcher", - "ExecutionCoordinator", -] diff --git a/api/dify_graph/graph_engine/orchestration/dispatcher.py b/api/dify_graph/graph_engine/orchestration/dispatcher.py deleted file mode 100644 index f8aaf20b2f3..00000000000 --- a/api/dify_graph/graph_engine/orchestration/dispatcher.py +++ /dev/null @@ -1,143 +0,0 @@ -""" -Main dispatcher for processing events from workers. -""" - -import logging -import queue -import threading -import time -from typing import TYPE_CHECKING, final - -from dify_graph.graph_events import ( - GraphNodeEventBase, - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunSucceededEvent, -) - -from ..event_management import EventManager -from .execution_coordinator import ExecutionCoordinator - -if TYPE_CHECKING: - from ..event_management import EventHandler - -logger = logging.getLogger(__name__) - - -@final -class Dispatcher: - """ - Main dispatcher that processes events from the event queue. - - This runs in a separate thread and coordinates event processing - with timeout and completion detection. - """ - - _COMMAND_TRIGGER_EVENTS = ( - NodeRunSucceededEvent, - NodeRunFailedEvent, - NodeRunExceptionEvent, - ) - - def __init__( - self, - event_queue: queue.Queue[GraphNodeEventBase], - event_handler: "EventHandler", - execution_coordinator: ExecutionCoordinator, - event_emitter: EventManager | None = None, - ) -> None: - """ - Initialize the dispatcher. - - Args: - event_queue: Queue of events from workers - event_handler: Event handler registry for processing events - execution_coordinator: Coordinator for execution flow - event_emitter: Optional event manager to signal completion - """ - self._event_queue = event_queue - self._event_handler = event_handler - self._execution_coordinator = execution_coordinator - self._event_emitter = event_emitter - - self._thread: threading.Thread | None = None - self._stop_event = threading.Event() - self._start_time: float | None = None - - def start(self) -> None: - """Start the dispatcher thread.""" - if self._thread and self._thread.is_alive(): - return - - self._stop_event.clear() - self._start_time = time.time() - self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True) - self._thread.start() - - def stop(self) -> None: - """Stop the dispatcher thread.""" - self._stop_event.set() - if self._thread and self._thread.is_alive(): - self._thread.join(timeout=2.0) - - def _dispatcher_loop(self) -> None: - """Main dispatcher loop.""" - try: - self._process_commands() - paused = False - while not self._stop_event.is_set(): - if self._execution_coordinator.aborted or self._execution_coordinator.execution_complete: - break - if self._execution_coordinator.paused: - paused = True - break - - self._execution_coordinator.check_scaling() - try: - event = self._event_queue.get(timeout=0.1) - self._event_handler.dispatch(event) - self._event_queue.task_done() - self._process_commands(event) - except queue.Empty: - time.sleep(0.1) - - self._process_commands() - if paused: - self._drain_events_until_idle() - else: - self._drain_event_queue() - - except Exception as e: - logger.exception("Dispatcher error") - self._execution_coordinator.mark_failed(e) - - finally: - self._execution_coordinator.mark_complete() - # Signal the event emitter that execution is complete - if self._event_emitter: - self._event_emitter.mark_complete() - - def _process_commands(self, event: GraphNodeEventBase | None = None): - if event is None or isinstance(event, self._COMMAND_TRIGGER_EVENTS): - self._execution_coordinator.process_commands() - - def _drain_event_queue(self) -> None: - while True: - try: - event = self._event_queue.get(block=False) - self._event_handler.dispatch(event) - self._event_queue.task_done() - except queue.Empty: - break - - def _drain_events_until_idle(self) -> None: - while not self._stop_event.is_set(): - try: - event = self._event_queue.get(timeout=0.1) - self._event_handler.dispatch(event) - self._event_queue.task_done() - self._process_commands(event) - except queue.Empty: - if not self._execution_coordinator.has_executing_nodes(): - break - self._drain_event_queue() diff --git a/api/dify_graph/graph_engine/orchestration/execution_coordinator.py b/api/dify_graph/graph_engine/orchestration/execution_coordinator.py deleted file mode 100644 index 0f8550eb123..00000000000 --- a/api/dify_graph/graph_engine/orchestration/execution_coordinator.py +++ /dev/null @@ -1,104 +0,0 @@ -""" -Execution coordinator for managing overall workflow execution. -""" - -from typing import final - -from ..command_processing import CommandProcessor -from ..domain import GraphExecution -from ..graph_state_manager import GraphStateManager -from ..worker_management import WorkerPool - - -@final -class ExecutionCoordinator: - """ - Coordinates overall execution flow between subsystems. - - This provides high-level coordination methods used by the - dispatcher to manage execution state. - """ - - def __init__( - self, - graph_execution: GraphExecution, - state_manager: GraphStateManager, - command_processor: CommandProcessor, - worker_pool: WorkerPool, - ) -> None: - """ - Initialize the execution coordinator. - - Args: - graph_execution: Graph execution aggregate - state_manager: Unified state manager - command_processor: Processor for commands - worker_pool: Pool of workers - """ - self._graph_execution = graph_execution - self._state_manager = state_manager - self._command_processor = command_processor - self._worker_pool = worker_pool - - def process_commands(self) -> None: - """Process any pending commands.""" - self._command_processor.process_commands() - - def check_scaling(self) -> None: - """Check and perform worker scaling if needed.""" - self._worker_pool.check_and_scale() - - @property - def execution_complete(self): - return self._state_manager.is_execution_complete() - - @property - def aborted(self): - return self._graph_execution.aborted or self._graph_execution.has_error - - @property - def paused(self) -> bool: - """Expose whether the underlying graph execution is paused.""" - return self._graph_execution.is_paused - - def mark_complete(self) -> None: - """Mark execution as complete.""" - if self._graph_execution.is_paused: - return - if not self._graph_execution.completed: - self._graph_execution.complete() - - def mark_failed(self, error: Exception) -> None: - """ - Mark execution as failed. - - Args: - error: The error that caused failure - """ - self._graph_execution.fail(error) - - def handle_pause_if_needed(self) -> None: - """If the execution has been paused, stop workers immediately.""" - - if not self._graph_execution.is_paused: - return - - self._worker_pool.stop() - self._state_manager.clear_executing() - - def handle_abort_if_needed(self) -> None: - """If the execution has been aborted, stop workers immediately.""" - - if not self._graph_execution.aborted: - return - - self._worker_pool.stop() - self._state_manager.clear_executing() - - def has_executing_nodes(self) -> bool: - """Return True if any nodes are currently marked as executing.""" - # This check is only safe once execution has already paused. - # Before pause, executing state can change concurrently, which makes the result unreliable. - if not self._graph_execution.is_paused: - raise AssertionError("has_executing_nodes should only be called after execution is paused") - return self._state_manager.get_executing_count() > 0 diff --git a/api/dify_graph/graph_engine/protocols/command_channel.py b/api/dify_graph/graph_engine/protocols/command_channel.py deleted file mode 100644 index fabd8634c8b..00000000000 --- a/api/dify_graph/graph_engine/protocols/command_channel.py +++ /dev/null @@ -1,41 +0,0 @@ -""" -CommandChannel protocol for GraphEngine command communication. - -This protocol defines the interface for sending and receiving commands -to/from a GraphEngine instance, supporting both local and distributed scenarios. -""" - -from typing import Protocol - -from ..entities.commands import GraphEngineCommand - - -class CommandChannel(Protocol): - """ - Protocol for bidirectional command communication with GraphEngine. - - Since each GraphEngine instance processes only one workflow execution, - this channel is dedicated to that single execution. - """ - - def fetch_commands(self) -> list[GraphEngineCommand]: - """ - Fetch pending commands for this GraphEngine instance. - - Called by GraphEngine to poll for commands that need to be processed. - - Returns: - List of pending commands (may be empty) - """ - ... - - def send_command(self, command: GraphEngineCommand) -> None: - """ - Send a command to be processed by this GraphEngine instance. - - Called by external systems to send control commands to the running workflow. - - Args: - command: The command to send - """ - ... diff --git a/api/dify_graph/graph_engine/ready_queue/__init__.py b/api/dify_graph/graph_engine/ready_queue/__init__.py deleted file mode 100644 index acba0e961c8..00000000000 --- a/api/dify_graph/graph_engine/ready_queue/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -""" -Ready queue implementations for GraphEngine. - -This package contains the protocol and implementations for managing -the queue of nodes ready for execution. -""" - -from .factory import create_ready_queue_from_state -from .in_memory import InMemoryReadyQueue -from .protocol import ReadyQueue, ReadyQueueState - -__all__ = ["InMemoryReadyQueue", "ReadyQueue", "ReadyQueueState", "create_ready_queue_from_state"] diff --git a/api/dify_graph/graph_engine/ready_queue/factory.py b/api/dify_graph/graph_engine/ready_queue/factory.py deleted file mode 100644 index a9d4f470e53..00000000000 --- a/api/dify_graph/graph_engine/ready_queue/factory.py +++ /dev/null @@ -1,37 +0,0 @@ -""" -Factory for creating ReadyQueue instances from serialized state. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -from .in_memory import InMemoryReadyQueue -from .protocol import ReadyQueueState - -if TYPE_CHECKING: - from .protocol import ReadyQueue - - -def create_ready_queue_from_state(state: ReadyQueueState) -> ReadyQueue: - """ - Create a ReadyQueue instance from a serialized state. - - Args: - state: The serialized queue state (Pydantic model, dict, or JSON string), or None for a new empty queue - - Returns: - A ReadyQueue instance initialized with the given state - - Raises: - ValueError: If the queue type is unknown or version is unsupported - """ - if state.type == "InMemoryReadyQueue": - if state.version != "1.0": - raise ValueError(f"Unsupported InMemoryReadyQueue version: {state.version}") - queue = InMemoryReadyQueue() - # Always pass as JSON string to loads() - queue.loads(state.model_dump_json()) - return queue - else: - raise ValueError(f"Unknown ready queue type: {state.type}") diff --git a/api/dify_graph/graph_engine/ready_queue/in_memory.py b/api/dify_graph/graph_engine/ready_queue/in_memory.py deleted file mode 100644 index f2c265ece09..00000000000 --- a/api/dify_graph/graph_engine/ready_queue/in_memory.py +++ /dev/null @@ -1,140 +0,0 @@ -""" -In-memory implementation of the ReadyQueue protocol. - -This implementation wraps Python's standard queue.Queue and adds -serialization capabilities for state storage. -""" - -import queue -from typing import final - -from .protocol import ReadyQueue, ReadyQueueState - - -@final -class InMemoryReadyQueue(ReadyQueue): - """ - In-memory ready queue implementation with serialization support. - - This implementation uses Python's queue.Queue internally and provides - methods to serialize and restore the queue state. - """ - - def __init__(self, maxsize: int = 0) -> None: - """ - Initialize the in-memory ready queue. - - Args: - maxsize: Maximum size of the queue (0 for unlimited) - """ - self._queue: queue.Queue[str] = queue.Queue(maxsize=maxsize) - - def put(self, item: str) -> None: - """ - Add a node ID to the ready queue. - - Args: - item: The node ID to add to the queue - """ - self._queue.put(item) - - def get(self, timeout: float | None = None) -> str: - """ - Retrieve and remove a node ID from the queue. - - Args: - timeout: Maximum time to wait for an item (None for blocking) - - Returns: - The node ID retrieved from the queue - - Raises: - queue.Empty: If timeout expires and no item is available - """ - if timeout is None: - return self._queue.get(block=True) - return self._queue.get(timeout=timeout) - - def task_done(self) -> None: - """ - Indicate that a previously retrieved task is complete. - - Used by worker threads to signal task completion for - join() synchronization. - """ - self._queue.task_done() - - def empty(self) -> bool: - """ - Check if the queue is empty. - - Returns: - True if the queue has no items, False otherwise - """ - return self._queue.empty() - - def qsize(self) -> int: - """ - Get the approximate size of the queue. - - Returns: - The approximate number of items in the queue - """ - return self._queue.qsize() - - def dumps(self) -> str: - """ - Serialize the queue state to a JSON string for storage. - - Returns: - A JSON string containing the serialized queue state - """ - # Extract all items from the queue without removing them - items: list[str] = [] - temp_items: list[str] = [] - - # Drain the queue temporarily to get all items - while not self._queue.empty(): - try: - item = self._queue.get_nowait() - temp_items.append(item) - items.append(item) - except queue.Empty: - break - - # Put items back in the same order - for item in temp_items: - self._queue.put(item) - - state = ReadyQueueState( - type="InMemoryReadyQueue", - version="1.0", - items=items, - ) - return state.model_dump_json() - - def loads(self, data: str) -> None: - """ - Restore the queue state from a JSON string. - - Args: - data: The JSON string containing the serialized queue state to restore - """ - state = ReadyQueueState.model_validate_json(data) - - if state.type != "InMemoryReadyQueue": - raise ValueError(f"Invalid serialized data type: {state.type}") - - if state.version != "1.0": - raise ValueError(f"Unsupported version: {state.version}") - - # Clear the current queue - while not self._queue.empty(): - try: - self._queue.get_nowait() - except queue.Empty: - break - - # Restore items - for item in state.items: - self._queue.put(item) diff --git a/api/dify_graph/graph_engine/ready_queue/protocol.py b/api/dify_graph/graph_engine/ready_queue/protocol.py deleted file mode 100644 index 97d3ea6dd2c..00000000000 --- a/api/dify_graph/graph_engine/ready_queue/protocol.py +++ /dev/null @@ -1,104 +0,0 @@ -""" -ReadyQueue protocol for GraphEngine node execution queue. - -This protocol defines the interface for managing the queue of nodes ready -for execution, supporting both in-memory and persistent storage scenarios. -""" - -from collections.abc import Sequence -from typing import Protocol - -from pydantic import BaseModel, Field - - -class ReadyQueueState(BaseModel): - """ - Pydantic model for serialized ready queue state. - - This defines the structure of the data returned by dumps() - and expected by loads() for ready queue serialization. - """ - - type: str = Field(description="Queue implementation type (e.g., 'InMemoryReadyQueue')") - version: str = Field(description="Serialization format version") - items: Sequence[str] = Field(default_factory=list, description="List of node IDs in the queue") - - -class ReadyQueue(Protocol): - """ - Protocol for managing nodes ready for execution in GraphEngine. - - This protocol defines the interface that any ready queue implementation - must provide, enabling both in-memory queues and persistent queues - that can be serialized for state storage. - """ - - def put(self, item: str) -> None: - """ - Add a node ID to the ready queue. - - Args: - item: The node ID to add to the queue - """ - ... - - def get(self, timeout: float | None = None) -> str: - """ - Retrieve and remove a node ID from the queue. - - Args: - timeout: Maximum time to wait for an item (None for blocking) - - Returns: - The node ID retrieved from the queue - - Raises: - queue.Empty: If timeout expires and no item is available - """ - ... - - def task_done(self) -> None: - """ - Indicate that a previously retrieved task is complete. - - Used by worker threads to signal task completion for - join() synchronization. - """ - ... - - def empty(self) -> bool: - """ - Check if the queue is empty. - - Returns: - True if the queue has no items, False otherwise - """ - ... - - def qsize(self) -> int: - """ - Get the approximate size of the queue. - - Returns: - The approximate number of items in the queue - """ - ... - - def dumps(self) -> str: - """ - Serialize the queue state to a JSON string for storage. - - Returns: - A JSON string containing the serialized queue state - that can be persisted and later restored - """ - ... - - def loads(self, data: str) -> None: - """ - Restore the queue state from a JSON string. - - Args: - data: The JSON string containing the serialized queue state to restore - """ - ... diff --git a/api/dify_graph/graph_engine/response_coordinator/__init__.py b/api/dify_graph/graph_engine/response_coordinator/__init__.py deleted file mode 100644 index e11d31199c2..00000000000 --- a/api/dify_graph/graph_engine/response_coordinator/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -""" -ResponseStreamCoordinator - Coordinates streaming output from response nodes - -This component manages response streaming sessions and ensures ordered streaming -of responses based on upstream node outputs and constants. -""" - -from .coordinator import ResponseStreamCoordinator - -__all__ = ["ResponseStreamCoordinator"] diff --git a/api/dify_graph/graph_engine/response_coordinator/coordinator.py b/api/dify_graph/graph_engine/response_coordinator/coordinator.py deleted file mode 100644 index 941a8a496b4..00000000000 --- a/api/dify_graph/graph_engine/response_coordinator/coordinator.py +++ /dev/null @@ -1,697 +0,0 @@ -""" -Main ResponseStreamCoordinator implementation. - -This module contains the public ResponseStreamCoordinator class that manages -response streaming sessions and ensures ordered streaming of responses. -""" - -import logging -from collections import deque -from collections.abc import Sequence -from threading import RLock -from typing import Literal, TypeAlias, final -from uuid import uuid4 - -from pydantic import BaseModel, Field - -from dify_graph.enums import NodeExecutionType, NodeState -from dify_graph.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent -from dify_graph.nodes.base.template import TextSegment, VariableSegment -from dify_graph.runtime import VariablePool -from dify_graph.runtime.graph_runtime_state import GraphProtocol - -from .path import Path -from .session import ResponseSession - -logger = logging.getLogger(__name__) - -# Type definitions -NodeID: TypeAlias = str -EdgeID: TypeAlias = str - - -class ResponseSessionState(BaseModel): - """Serializable representation of a response session.""" - - node_id: str - index: int = Field(default=0, ge=0) - - -class StreamBufferState(BaseModel): - """Serializable representation of buffered stream chunks.""" - - selector: tuple[str, ...] - events: list[NodeRunStreamChunkEvent] = Field(default_factory=list) - - -class StreamPositionState(BaseModel): - """Serializable representation for stream read positions.""" - - selector: tuple[str, ...] - position: int = Field(default=0, ge=0) - - -class ResponseStreamCoordinatorState(BaseModel): - """Serialized snapshot of ResponseStreamCoordinator.""" - - type: Literal["ResponseStreamCoordinator"] = Field(default="ResponseStreamCoordinator") - version: str = Field(default="1.0") - response_nodes: Sequence[str] = Field(default_factory=list) - active_session: ResponseSessionState | None = None - waiting_sessions: Sequence[ResponseSessionState] = Field(default_factory=list) - pending_sessions: Sequence[ResponseSessionState] = Field(default_factory=list) - node_execution_ids: dict[str, str] = Field(default_factory=dict) - paths_map: dict[str, list[list[str]]] = Field(default_factory=dict) - stream_buffers: Sequence[StreamBufferState] = Field(default_factory=list) - stream_positions: Sequence[StreamPositionState] = Field(default_factory=list) - closed_streams: Sequence[tuple[str, ...]] = Field(default_factory=list) - - -@final -class ResponseStreamCoordinator: - """ - Manages response streaming sessions without relying on global state. - - Ensures ordered streaming of responses based on upstream node outputs and constants. - """ - - def __init__(self, variable_pool: "VariablePool", graph: GraphProtocol) -> None: - """ - Initialize coordinator with variable pool. - - Args: - variable_pool: VariablePool instance for accessing node variables - graph: Graph instance for looking up node information - """ - self._variable_pool = variable_pool - self._graph = graph - self._active_session: ResponseSession | None = None - self._waiting_sessions: deque[ResponseSession] = deque() - self._lock = RLock() - - # Internal stream management (replacing OutputRegistry) - self._stream_buffers: dict[tuple[str, ...], list[NodeRunStreamChunkEvent]] = {} - self._stream_positions: dict[tuple[str, ...], int] = {} - self._closed_streams: set[tuple[str, ...]] = set() - - # Track response nodes - self._response_nodes: set[NodeID] = set() - - # Store paths for each response node - self._paths_maps: dict[NodeID, list[Path]] = {} - - # Track node execution IDs and types for proper event forwarding - self._node_execution_ids: dict[NodeID, str] = {} # node_id -> execution_id - - # Track response sessions to ensure only one per node - self._response_sessions: dict[NodeID, ResponseSession] = {} # node_id -> session - - def register(self, response_node_id: NodeID) -> None: - with self._lock: - if response_node_id in self._response_nodes: - return - self._response_nodes.add(response_node_id) - - # Build and save paths map for this response node - paths_map = self._build_paths_map(response_node_id) - self._paths_maps[response_node_id] = paths_map - - # Create and store response session for this node - response_node = self._graph.nodes[response_node_id] - session = ResponseSession.from_node(response_node) - self._response_sessions[response_node_id] = session - - def track_node_execution(self, node_id: NodeID, execution_id: str) -> None: - """Track the execution ID for a node when it starts executing. - - Args: - node_id: The ID of the node - execution_id: The execution ID from NodeRunStartedEvent - """ - with self._lock: - self._node_execution_ids[node_id] = execution_id - - def _get_or_create_execution_id(self, node_id: NodeID) -> str: - """Get the execution ID for a node, creating one if it doesn't exist. - - Args: - node_id: The ID of the node - - Returns: - The execution ID for the node - """ - with self._lock: - if node_id not in self._node_execution_ids: - self._node_execution_ids[node_id] = str(uuid4()) - return self._node_execution_ids[node_id] - - def _build_paths_map(self, response_node_id: NodeID) -> list[Path]: - """ - Build a paths map for a response node by finding all paths from root node - to the response node, recording branch edges along each path. - - Args: - response_node_id: ID of the response node to analyze - - Returns: - List of Path objects, where each path contains branch edge IDs - """ - # Get root node ID - root_node_id = self._graph.root_node.id - - # If root is the response node, return empty path - if root_node_id == response_node_id: - return [Path()] - - # Extract variable selectors from the response node's template - response_node = self._graph.nodes[response_node_id] - response_session = ResponseSession.from_node(response_node) - template = response_session.template - - # Collect all variable selectors from the template - variable_selectors: set[tuple[str, ...]] = set() - for segment in template.segments: - if isinstance(segment, VariableSegment): - variable_selectors.add(tuple(segment.selector[:2])) - - # Step 1: Find all complete paths from root to response node - all_complete_paths: list[list[EdgeID]] = [] - - def find_paths( - current_node_id: NodeID, target_node_id: NodeID, current_path: list[EdgeID], visited: set[NodeID] - ) -> None: - """Recursively find all paths from current node to target node.""" - if current_node_id == target_node_id: - # Found a complete path, store it - all_complete_paths.append(current_path.copy()) - return - - # Mark as visited to avoid cycles - visited.add(current_node_id) - - # Explore outgoing edges - outgoing_edges = self._graph.get_outgoing_edges(current_node_id) - for edge in outgoing_edges: - edge_id = edge.id - next_node_id = edge.head - - # Skip if already visited in this path - if next_node_id not in visited: - # Add edge to path and recurse - new_path = current_path + [edge_id] - find_paths(next_node_id, target_node_id, new_path, visited.copy()) - - # Start searching from root node - find_paths(root_node_id, response_node_id, [], set()) - - # Step 2: For each complete path, filter edges based on node blocking behavior - filtered_paths: list[Path] = [] - for path in all_complete_paths: - blocking_edges: list[str] = [] - for edge_id in path: - edge = self._graph.edges[edge_id] - source_node = self._graph.nodes[edge.tail] - - # Check if node is a branch, container, or response node - if source_node.execution_type in { - NodeExecutionType.BRANCH, - NodeExecutionType.CONTAINER, - NodeExecutionType.RESPONSE, - } or source_node.blocks_variable_output(variable_selectors): - blocking_edges.append(edge_id) - - # Keep the path even if it's empty - filtered_paths.append(Path(edges=blocking_edges)) - - return filtered_paths - - def on_edge_taken(self, edge_id: str) -> Sequence[NodeRunStreamChunkEvent]: - """ - Handle when an edge is taken (selected by a branch node). - - This method updates the paths for all response nodes by removing - the taken edge. If any response node has an empty path after removal, - it means the node is now deterministically reachable and should start. - - Args: - edge_id: The ID of the edge that was taken - - Returns: - List of events to emit from starting new sessions - """ - events: list[NodeRunStreamChunkEvent] = [] - - with self._lock: - # Check each response node in order - for response_node_id in self._response_nodes: - if response_node_id not in self._paths_maps: - continue - - paths = self._paths_maps[response_node_id] - has_reachable_path = False - - # Update each path by removing the taken edge - for path in paths: - # Remove the taken edge from this path - path.remove_edge(edge_id) - - # Check if this path is now empty (node is reachable) - if path.is_empty(): - has_reachable_path = True - - # If node is now reachable (has empty path), start/queue session - if has_reachable_path: - # Pass the node_id to the activation method - # The method will handle checking and removing from map - events.extend(self._active_or_queue_session(response_node_id)) - return events - - def _active_or_queue_session(self, node_id: str) -> Sequence[NodeRunStreamChunkEvent]: - """ - Start a session immediately if no active session, otherwise queue it. - Only activates sessions that exist in the _response_sessions map. - - Args: - node_id: The ID of the response node to activate - - Returns: - List of events from flush attempt if session started immediately - """ - events: list[NodeRunStreamChunkEvent] = [] - - # Get the session from our map (only activate if it exists) - session = self._response_sessions.get(node_id) - if not session: - return events - - # Remove from map to ensure it won't be activated again - del self._response_sessions[node_id] - - if self._active_session is None: - self._active_session = session - - # Try to flush immediately - events.extend(self.try_flush()) - else: - # Queue the session if another is active - self._waiting_sessions.append(session) - - return events - - def intercept_event( - self, event: NodeRunStreamChunkEvent | NodeRunSucceededEvent - ) -> Sequence[NodeRunStreamChunkEvent]: - with self._lock: - if isinstance(event, NodeRunStreamChunkEvent): - self._append_stream_chunk(event.selector, event) - if event.is_final: - self._close_stream(event.selector) - return self.try_flush() - else: - # Skip cause we share the same variable pool. - # - # for variable_name, variable_value in event.node_run_result.outputs.items(): - # self._variable_pool.add((event.node_id, variable_name), variable_value) - return self.try_flush() - - def _create_stream_chunk_event( - self, - node_id: str, - execution_id: str, - selector: Sequence[str], - chunk: str, - is_final: bool = False, - ) -> NodeRunStreamChunkEvent: - """Create a stream chunk event with consistent structure. - - For selectors with special prefixes (sys, env, conversation), we use the - active response node's information since these are not actual node IDs. - """ - # Check if this is a special selector that doesn't correspond to a node - if selector and selector[0] not in self._graph.nodes and self._active_session: - # Use the active response node for special selectors - response_node = self._graph.nodes[self._active_session.node_id] - return NodeRunStreamChunkEvent( - id=execution_id, - node_id=response_node.id, - node_type=response_node.node_type, - selector=selector, - chunk=chunk, - is_final=is_final, - ) - - # Standard case: selector refers to an actual node - node = self._graph.nodes[node_id] - return NodeRunStreamChunkEvent( - id=execution_id, - node_id=node.id, - node_type=node.node_type, - selector=selector, - chunk=chunk, - is_final=is_final, - ) - - def _process_variable_segment(self, segment: VariableSegment) -> tuple[Sequence[NodeRunStreamChunkEvent], bool]: - """Process a variable segment. Returns (events, is_complete). - - Handles both regular node selectors and special system selectors (sys, env, conversation). - For special selectors, we attribute the output to the active response node. - """ - events: list[NodeRunStreamChunkEvent] = [] - source_selector_prefix = segment.selector[0] if segment.selector else "" - is_complete = False - - # Determine which node to attribute the output to - # For special selectors (sys, env, conversation), use the active response node - # For regular selectors, use the source node - if self._active_session and source_selector_prefix not in self._graph.nodes: - # Special selector - use active response node - output_node_id = self._active_session.node_id - else: - # Regular node selector - output_node_id = source_selector_prefix - execution_id = self._get_or_create_execution_id(output_node_id) - - # Stream all available chunks - while self._has_unread_stream(segment.selector): - if event := self._pop_stream_chunk(segment.selector): - # For special selectors, we need to update the event to use - # the active response node's information - if self._active_session and source_selector_prefix not in self._graph.nodes: - response_node = self._graph.nodes[self._active_session.node_id] - # Create a new event with the response node's information - # but keep the original selector - updated_event = NodeRunStreamChunkEvent( - id=execution_id, - node_id=response_node.id, - node_type=response_node.node_type, - selector=event.selector, # Keep original selector - chunk=event.chunk, - is_final=event.is_final, - ) - events.append(updated_event) - else: - # Regular node selector - use event as is - events.append(event) - - # Check if this is the last chunk by looking ahead - stream_closed = self._is_stream_closed(segment.selector) - # Check if stream is closed to determine if segment is complete - if stream_closed: - is_complete = True - - elif value := self._variable_pool.get(segment.selector): - # Process scalar value - is_last_segment = bool( - self._active_session and self._active_session.index == len(self._active_session.template.segments) - 1 - ) - events.append( - self._create_stream_chunk_event( - node_id=output_node_id, - execution_id=execution_id, - selector=segment.selector, - chunk=value.markdown, - is_final=is_last_segment, - ) - ) - is_complete = True - - return events, is_complete - - def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]: - """Process a text segment. Returns (events, is_complete).""" - assert self._active_session is not None - current_response_node = self._graph.nodes[self._active_session.node_id] - - # Use get_or_create_execution_id to ensure we have a consistent ID - execution_id = self._get_or_create_execution_id(current_response_node.id) - - is_last_segment = self._active_session.index == len(self._active_session.template.segments) - 1 - event = self._create_stream_chunk_event( - node_id=current_response_node.id, - execution_id=execution_id, - selector=[current_response_node.id, "answer"], # FIXME(-LAN-) - chunk=segment.text, - is_final=is_last_segment, - ) - return [event] - - def try_flush(self) -> list[NodeRunStreamChunkEvent]: - with self._lock: - if not self._active_session: - return [] - - template = self._active_session.template - response_node_id = self._active_session.node_id - - events: list[NodeRunStreamChunkEvent] = [] - - # Process segments sequentially from current index - while self._active_session.index < len(template.segments): - segment = template.segments[self._active_session.index] - - if isinstance(segment, VariableSegment): - # Check if the source node for this variable is skipped - # Only check for actual nodes, not special selectors (sys, env, conversation) - source_selector_prefix = segment.selector[0] if segment.selector else "" - if source_selector_prefix in self._graph.nodes: - source_node = self._graph.nodes[source_selector_prefix] - - if source_node.state == NodeState.SKIPPED: - # Skip this variable segment if the source node is skipped - self._active_session.index += 1 - continue - - segment_events, is_complete = self._process_variable_segment(segment) - events.extend(segment_events) - - # Only advance index if this variable segment is complete - if is_complete: - self._active_session.index += 1 - else: - # Wait for more data - break - - else: - segment_events = self._process_text_segment(segment) - events.extend(segment_events) - self._active_session.index += 1 - - if self._active_session.is_complete(): - # End current session and get events from starting next session - next_session_events = self.end_session(response_node_id) - events.extend(next_session_events) - - return events - - def end_session(self, node_id: str) -> list[NodeRunStreamChunkEvent]: - """ - End the active session for a response node. - Automatically starts the next waiting session if available. - - Args: - node_id: ID of the response node ending its session - - Returns: - List of events from starting the next session - """ - with self._lock: - events: list[NodeRunStreamChunkEvent] = [] - - if self._active_session and self._active_session.node_id == node_id: - self._active_session = None - - # Try to start next waiting session - if self._waiting_sessions: - next_session = self._waiting_sessions.popleft() - self._active_session = next_session - - # Immediately try to flush any available segments - events = self.try_flush() - - return events - - # ============= Internal Stream Management Methods ============= - - def _append_stream_chunk(self, selector: Sequence[str], event: NodeRunStreamChunkEvent) -> None: - """ - Append a stream chunk to the internal buffer. - - Args: - selector: List of strings identifying the stream location - event: The NodeRunStreamChunkEvent to append - - Raises: - ValueError: If the stream is already closed - """ - key = tuple(selector) - - if key in self._closed_streams: - raise ValueError(f"Stream {'.'.join(selector)} is already closed") - - if key not in self._stream_buffers: - self._stream_buffers[key] = [] - self._stream_positions[key] = 0 - - self._stream_buffers[key].append(event) - - def _pop_stream_chunk(self, selector: Sequence[str]) -> NodeRunStreamChunkEvent | None: - """ - Pop the next unread stream chunk from the buffer. - - Args: - selector: List of strings identifying the stream location - - Returns: - The next event, or None if no unread events available - """ - key = tuple(selector) - - if key not in self._stream_buffers: - return None - - position = self._stream_positions.get(key, 0) - buffer = self._stream_buffers[key] - - if position >= len(buffer): - return None - - event = buffer[position] - self._stream_positions[key] = position + 1 - return event - - def _has_unread_stream(self, selector: Sequence[str]) -> bool: - """ - Check if the stream has unread events. - - Args: - selector: List of strings identifying the stream location - - Returns: - True if there are unread events, False otherwise - """ - key = tuple(selector) - - if key not in self._stream_buffers: - return False - - position = self._stream_positions.get(key, 0) - return position < len(self._stream_buffers[key]) - - def _close_stream(self, selector: Sequence[str]) -> None: - """ - Mark a stream as closed (no more chunks can be appended). - - Args: - selector: List of strings identifying the stream location - """ - key = tuple(selector) - self._closed_streams.add(key) - - def _is_stream_closed(self, selector: Sequence[str]) -> bool: - """ - Check if a stream is closed. - - Args: - selector: List of strings identifying the stream location - - Returns: - True if the stream is closed, False otherwise - """ - key = tuple(selector) - return key in self._closed_streams - - def _serialize_session(self, session: ResponseSession | None) -> ResponseSessionState | None: - """Convert an in-memory session into its serializable form.""" - - if session is None: - return None - return ResponseSessionState(node_id=session.node_id, index=session.index) - - def _session_from_state(self, session_state: ResponseSessionState) -> ResponseSession: - """Rebuild a response session from serialized data.""" - - node = self._graph.nodes.get(session_state.node_id) - if node is None: - raise ValueError(f"Unknown response node '{session_state.node_id}' in serialized state") - - session = ResponseSession.from_node(node) - session.index = session_state.index - return session - - def dumps(self) -> str: - """Serialize coordinator state to JSON.""" - - with self._lock: - state = ResponseStreamCoordinatorState( - response_nodes=sorted(self._response_nodes), - active_session=self._serialize_session(self._active_session), - waiting_sessions=[ - session_state - for session in list(self._waiting_sessions) - if (session_state := self._serialize_session(session)) is not None - ], - pending_sessions=[ - session_state - for _, session in sorted(self._response_sessions.items()) - if (session_state := self._serialize_session(session)) is not None - ], - node_execution_ids=dict(sorted(self._node_execution_ids.items())), - paths_map={ - node_id: [path.edges.copy() for path in paths] - for node_id, paths in sorted(self._paths_maps.items()) - }, - stream_buffers=[ - StreamBufferState( - selector=selector, - events=[event.model_copy(deep=True) for event in events], - ) - for selector, events in sorted(self._stream_buffers.items()) - ], - stream_positions=[ - StreamPositionState(selector=selector, position=position) - for selector, position in sorted(self._stream_positions.items()) - ], - closed_streams=sorted(self._closed_streams), - ) - return state.model_dump_json() - - def loads(self, data: str) -> None: - """Restore coordinator state from JSON.""" - - state = ResponseStreamCoordinatorState.model_validate_json(data) - - if state.type != "ResponseStreamCoordinator": - raise ValueError(f"Invalid serialized data type: {state.type}") - - if state.version != "1.0": - raise ValueError(f"Unsupported serialized version: {state.version}") - - with self._lock: - self._response_nodes = set(state.response_nodes) - self._paths_maps = { - node_id: [Path(edges=list(path_edges)) for path_edges in paths] - for node_id, paths in state.paths_map.items() - } - self._node_execution_ids = dict(state.node_execution_ids) - - self._stream_buffers = { - tuple(buffer.selector): [event.model_copy(deep=True) for event in buffer.events] - for buffer in state.stream_buffers - } - self._stream_positions = { - tuple(position.selector): position.position for position in state.stream_positions - } - for selector in self._stream_buffers: - self._stream_positions.setdefault(selector, 0) - - self._closed_streams = {tuple(selector) for selector in state.closed_streams} - - self._waiting_sessions = deque( - self._session_from_state(session_state) for session_state in state.waiting_sessions - ) - self._response_sessions = { - session_state.node_id: self._session_from_state(session_state) - for session_state in state.pending_sessions - } - self._active_session = self._session_from_state(state.active_session) if state.active_session else None diff --git a/api/dify_graph/graph_engine/response_coordinator/path.py b/api/dify_graph/graph_engine/response_coordinator/path.py deleted file mode 100644 index 50f2f4eb217..00000000000 --- a/api/dify_graph/graph_engine/response_coordinator/path.py +++ /dev/null @@ -1,35 +0,0 @@ -""" -Internal path representation for response coordinator. - -This module contains the private Path class used internally by ResponseStreamCoordinator -to track execution paths to response nodes. -""" - -from dataclasses import dataclass, field -from typing import TypeAlias - -EdgeID: TypeAlias = str - - -@dataclass -class Path: - """ - Represents a path of branch edges that must be taken to reach a response node. - - Note: This is an internal class not exposed in the public API. - """ - - edges: list[EdgeID] = field(default_factory=list[EdgeID]) - - def contains_edge(self, edge_id: EdgeID) -> bool: - """Check if this path contains the given edge.""" - return edge_id in self.edges - - def remove_edge(self, edge_id: EdgeID) -> None: - """Remove the given edge from this path in place.""" - if self.contains_edge(edge_id): - self.edges.remove(edge_id) - - def is_empty(self) -> bool: - """Check if the path has no edges (node is reachable).""" - return len(self.edges) == 0 diff --git a/api/dify_graph/graph_engine/response_coordinator/session.py b/api/dify_graph/graph_engine/response_coordinator/session.py deleted file mode 100644 index 11a9f5dac56..00000000000 --- a/api/dify_graph/graph_engine/response_coordinator/session.py +++ /dev/null @@ -1,66 +0,0 @@ -""" -Internal response session management for response coordinator. - -This module contains the private ResponseSession class used internally -by ResponseStreamCoordinator to manage streaming sessions. -""" - -from __future__ import annotations - -from dataclasses import dataclass -from typing import Protocol, cast - -from dify_graph.nodes.base.template import Template -from dify_graph.runtime.graph_runtime_state import NodeProtocol - - -class _ResponseSessionNodeProtocol(NodeProtocol, Protocol): - """Structural contract required from nodes that can open a response session.""" - - def get_streaming_template(self) -> Template: ... - - -@dataclass -class ResponseSession: - """ - Represents an active response streaming session. - - Note: This is an internal class not exposed in the public API. - """ - - node_id: str - template: Template # Template object from the response node - index: int = 0 # Current position in the template segments - - @classmethod - def from_node(cls, node: NodeProtocol) -> ResponseSession: - """ - Create a ResponseSession from a response-capable node. - - The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer. - At runtime this must be a node that implements `get_streaming_template()`. The coordinator decides which - graph nodes should be treated as response-capable before they reach this factory. - - Args: - node: Node from the materialized workflow graph. - - Returns: - ResponseSession configured with the node's streaming template - - Raises: - TypeError: If node does not implement the response-session streaming contract. - """ - response_node = cast(_ResponseSessionNodeProtocol, node) - try: - template = response_node.get_streaming_template() - except AttributeError as exc: - raise TypeError("ResponseSession.from_node requires get_streaming_template() on response nodes") from exc - - return cls( - node_id=node.id, - template=template, - ) - - def is_complete(self) -> bool: - """Check if all segments in the template have been processed.""" - return self.index >= len(self.template.segments) diff --git a/api/dify_graph/graph_engine/worker.py b/api/dify_graph/graph_engine/worker.py deleted file mode 100644 index 988c20d72a2..00000000000 --- a/api/dify_graph/graph_engine/worker.py +++ /dev/null @@ -1,205 +0,0 @@ -""" -Worker - Thread implementation for queue-based node execution - -Workers pull node IDs from the ready_queue, execute nodes, and push events -to the event_queue for the dispatcher to process. -""" - -import queue -import threading -import time -from collections.abc import Sequence -from datetime import datetime -from typing import TYPE_CHECKING, final - -from typing_extensions import override - -from dify_graph.context import IExecutionContext -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunStartedEvent, is_node_result_event -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from libs.datetime_utils import naive_utc_now - -from .ready_queue import ReadyQueue - -if TYPE_CHECKING: - pass - - -@final -class Worker(threading.Thread): - """ - Worker thread that executes nodes from the ready queue. - - Workers continuously pull node IDs from the ready_queue, execute the - corresponding nodes, and push the resulting events to the event_queue - for the dispatcher to process. - """ - - def __init__( - self, - ready_queue: ReadyQueue, - event_queue: queue.Queue[GraphNodeEventBase], - graph: Graph, - layers: Sequence[GraphEngineLayer], - worker_id: int = 0, - execution_context: IExecutionContext | None = None, - ) -> None: - """ - Initialize worker thread. - - Args: - ready_queue: Ready queue containing node IDs ready for execution - event_queue: Queue for pushing execution events - graph: Graph containing nodes to execute - layers: Graph engine layers for node execution hooks - worker_id: Unique identifier for this worker - execution_context: Optional execution context for context preservation - """ - super().__init__(name=f"GraphWorker-{worker_id}", daemon=True) - self._ready_queue = ready_queue - self._event_queue = event_queue - self._graph = graph - self._worker_id = worker_id - self._execution_context = execution_context - self._stop_event = threading.Event() - self._layers = layers if layers is not None else [] - self._last_task_time = time.time() - self._current_node_started_at: datetime | None = None - - def stop(self) -> None: - """Signal the worker to stop processing.""" - self._stop_event.set() - - @property - def is_idle(self) -> bool: - """Check if the worker is currently idle.""" - # Worker is idle if it hasn't processed a task recently (within 0.2 seconds) - return (time.time() - self._last_task_time) > 0.2 - - @property - def idle_duration(self) -> float: - """Get the duration in seconds since the worker last processed a task.""" - return time.time() - self._last_task_time - - @property - def worker_id(self) -> int: - """Get the worker's ID.""" - return self._worker_id - - @override - def run(self) -> None: - """ - Main worker loop. - - Continuously pulls node IDs from ready_queue, executes them, - and pushes events to event_queue until stopped. - """ - while not self._stop_event.is_set(): - # Try to get a node ID from the ready queue (with timeout) - try: - node_id = self._ready_queue.get(timeout=0.1) - except queue.Empty: - continue - - self._last_task_time = time.time() - node = self._graph.nodes[node_id] - try: - self._current_node_started_at = None - self._execute_node(node) - self._ready_queue.task_done() - except Exception as e: - self._event_queue.put( - self._build_fallback_failure_event(node, e, started_at=self._current_node_started_at) - ) - finally: - self._current_node_started_at = None - - def _execute_node(self, node: Node) -> None: - """ - Execute a single node and handle its events. - - Args: - node: The node instance to execute - """ - node.ensure_execution_id() - - error: Exception | None = None - result_event: GraphNodeEventBase | None = None - - # Execute the node with preserved context if execution context is provided - if self._execution_context is not None: - with self._execution_context: - self._invoke_node_run_start_hooks(node) - try: - node_events = node.run() - for event in node_events: - if isinstance(event, NodeRunStartedEvent) and event.id == node.execution_id: - self._current_node_started_at = event.start_at - self._event_queue.put(event) - if is_node_result_event(event): - result_event = event - except Exception as exc: - error = exc - raise - finally: - self._invoke_node_run_end_hooks(node, error, result_event) - else: - self._invoke_node_run_start_hooks(node) - try: - node_events = node.run() - for event in node_events: - if isinstance(event, NodeRunStartedEvent) and event.id == node.execution_id: - self._current_node_started_at = event.start_at - self._event_queue.put(event) - if is_node_result_event(event): - result_event = event - except Exception as exc: - error = exc - raise - finally: - self._invoke_node_run_end_hooks(node, error, result_event) - - def _invoke_node_run_start_hooks(self, node: Node) -> None: - """Invoke on_node_run_start hooks for all layers.""" - for layer in self._layers: - try: - layer.on_node_run_start(node) - except Exception: - # Silently ignore layer errors to prevent disrupting node execution - continue - - def _invoke_node_run_end_hooks( - self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None - ) -> None: - """Invoke on_node_run_end hooks for all layers.""" - for layer in self._layers: - try: - layer.on_node_run_end(node, error, result_event) - except Exception: - # Silently ignore layer errors to prevent disrupting node execution - continue - - def _build_fallback_failure_event( - self, node: Node, error: Exception, *, started_at: datetime | None = None - ) -> NodeRunFailedEvent: - """Build a failed event when worker-level execution aborts before a node emits its own result event.""" - failure_time = naive_utc_now() - error_message = str(error) - return NodeRunFailedEvent( - id=node.execution_id, - node_id=node.id, - node_type=node.node_type, - in_iteration_id=None, - error=error_message, - start_at=started_at or failure_time, - finished_at=failure_time, - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=error_message, - error_type=type(error).__name__, - ), - ) diff --git a/api/dify_graph/graph_engine/worker_management/__init__.py b/api/dify_graph/graph_engine/worker_management/__init__.py deleted file mode 100644 index 03de1f6daa7..00000000000 --- a/api/dify_graph/graph_engine/worker_management/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -""" -Worker management subsystem for graph engine. - -This package manages the worker pool, including creation, -scaling, and activity tracking. -""" - -from .worker_pool import WorkerPool - -__all__ = [ - "WorkerPool", -] diff --git a/api/dify_graph/graph_engine/worker_management/worker_pool.py b/api/dify_graph/graph_engine/worker_management/worker_pool.py deleted file mode 100644 index cc930877836..00000000000 --- a/api/dify_graph/graph_engine/worker_management/worker_pool.py +++ /dev/null @@ -1,277 +0,0 @@ -""" -Simple worker pool that consolidates functionality. - -This is a simpler implementation that merges WorkerPool, ActivityTracker, -DynamicScaler, and WorkerFactory into a single class. -""" - -import logging -import queue -import threading -from typing import final - -from dify_graph.context import IExecutionContext -from dify_graph.graph import Graph -from dify_graph.graph_events import GraphNodeEventBase - -from ..config import GraphEngineConfig -from ..layers.base import GraphEngineLayer -from ..ready_queue import ReadyQueue -from ..worker import Worker - -logger = logging.getLogger(__name__) - - -@final -class WorkerPool: - """ - Simple worker pool with integrated management. - - This class consolidates all worker management functionality into - a single, simpler implementation without excessive abstraction. - """ - - def __init__( - self, - ready_queue: ReadyQueue, - event_queue: queue.Queue[GraphNodeEventBase], - graph: Graph, - layers: list[GraphEngineLayer], - config: GraphEngineConfig, - execution_context: IExecutionContext | None = None, - ) -> None: - """ - Initialize the simple worker pool. - - Args: - ready_queue: Ready queue for nodes ready for execution - event_queue: Queue for worker events - graph: The workflow graph - layers: Graph engine layers for node execution hooks - config: GraphEngine worker pool configuration - execution_context: Optional execution context for context preservation - """ - self._ready_queue = ready_queue - self._event_queue = event_queue - self._graph = graph - self._execution_context = execution_context - self._layers = layers - self._config = config - - # Worker management - self._workers: list[Worker] = [] - self._worker_counter = 0 - self._lock = threading.RLock() - self._running = False - - # No longer tracking worker states with callbacks to avoid lock contention - - def start(self, initial_count: int | None = None) -> None: - """ - Start the worker pool. - - Args: - initial_count: Number of workers to start with (auto-calculated if None) - """ - with self._lock: - if self._running: - return - - self._running = True - - # Calculate initial worker count - if initial_count is None: - node_count = len(self._graph.nodes) - if node_count < 10: - initial_count = self._config.min_workers - elif node_count < 50: - initial_count = min(self._config.min_workers + 1, self._config.max_workers) - else: - initial_count = min(self._config.min_workers + 2, self._config.max_workers) - - logger.debug( - "Starting worker pool: %d workers (nodes=%d, min=%d, max=%d)", - initial_count, - node_count, - self._config.min_workers, - self._config.max_workers, - ) - - # Create initial workers - for _ in range(initial_count): - self._create_worker() - - def stop(self) -> None: - """Stop all workers in the pool.""" - with self._lock: - self._running = False - worker_count = len(self._workers) - - if worker_count > 0: - logger.debug("Stopping worker pool: %d workers", worker_count) - - # Stop all workers - for worker in self._workers: - worker.stop() - - # Wait for workers to finish - for worker in self._workers: - if worker.is_alive(): - worker.join(timeout=2.0) - - self._workers.clear() - - def _create_worker(self) -> None: - """Create and start a new worker.""" - worker_id = self._worker_counter - self._worker_counter += 1 - - worker = Worker( - ready_queue=self._ready_queue, - event_queue=self._event_queue, - graph=self._graph, - layers=self._layers, - worker_id=worker_id, - execution_context=self._execution_context, - ) - - worker.start() - self._workers.append(worker) - - def _remove_worker(self, worker: Worker, worker_id: int) -> None: - """Remove a specific worker from the pool.""" - # Stop the worker - worker.stop() - - # Wait for it to finish - if worker.is_alive(): - worker.join(timeout=2.0) - - # Remove from list - if worker in self._workers: - self._workers.remove(worker) - - def _try_scale_up(self, queue_depth: int, current_count: int) -> bool: - """ - Try to scale up workers if needed. - - Args: - queue_depth: Current queue depth - current_count: Current number of workers - - Returns: - True if scaled up, False otherwise - """ - if queue_depth > self._config.scale_up_threshold and current_count < self._config.max_workers: - old_count = current_count - self._create_worker() - - logger.debug( - "Scaled up workers: %d -> %d (queue_depth=%d exceeded threshold=%d)", - old_count, - len(self._workers), - queue_depth, - self._config.scale_up_threshold, - ) - return True - return False - - def _try_scale_down(self, queue_depth: int, current_count: int, active_count: int, idle_count: int) -> bool: - """ - Try to scale down workers if we have excess capacity. - - Args: - queue_depth: Current queue depth - current_count: Current number of workers - active_count: Number of active workers - idle_count: Number of idle workers - - Returns: - True if scaled down, False otherwise - """ - # Skip if we're at minimum or have no idle workers - if current_count <= self._config.min_workers or idle_count == 0: - return False - - # Check if we have excess capacity - has_excess_capacity = ( - queue_depth <= active_count # Active workers can handle current queue - or idle_count > active_count # More idle than active workers - or (queue_depth == 0 and idle_count > 0) # No work and have idle workers - ) - - if not has_excess_capacity: - return False - - # Find and remove idle workers that have been idle long enough - workers_to_remove: list[tuple[Worker, int]] = [] - - for worker in self._workers: - # Check if worker is idle and has exceeded idle time threshold - if worker.is_idle and worker.idle_duration >= self._config.scale_down_idle_time: - # Don't remove if it would leave us unable to handle the queue - remaining_workers = current_count - len(workers_to_remove) - 1 - if remaining_workers >= self._config.min_workers and remaining_workers >= max(1, queue_depth // 2): - workers_to_remove.append((worker, worker.worker_id)) - # Only remove one worker per check to avoid aggressive scaling - break - - # Remove idle workers if any found - if workers_to_remove: - old_count = current_count - for worker, worker_id in workers_to_remove: - self._remove_worker(worker, worker_id) - - logger.debug( - "Scaled down workers: %d -> %d (removed %d idle workers after %.1fs, " - "queue_depth=%d, active=%d, idle=%d)", - old_count, - len(self._workers), - len(workers_to_remove), - self._config.scale_down_idle_time, - queue_depth, - active_count, - idle_count - len(workers_to_remove), - ) - return True - - return False - - def check_and_scale(self) -> None: - """Check and perform scaling if needed.""" - with self._lock: - if not self._running: - return - - current_count = len(self._workers) - queue_depth = self._ready_queue.qsize() - - # Count active vs idle workers by querying their state directly - idle_count = sum(1 for worker in self._workers if worker.is_idle) - active_count = current_count - idle_count - - # Try to scale up if queue is backing up - self._try_scale_up(queue_depth, current_count) - - # Try to scale down if we have excess capacity - self._try_scale_down(queue_depth, current_count, active_count, idle_count) - - def get_worker_count(self) -> int: - """Get current number of workers.""" - with self._lock: - return len(self._workers) - - def get_status(self) -> dict[str, int]: - """ - Get pool status information. - - Returns: - Dictionary with status information - """ - with self._lock: - return { - "total_workers": len(self._workers), - "queue_depth": self._ready_queue.qsize(), - "min_workers": self._config.min_workers, - "max_workers": self._config.max_workers, - } diff --git a/api/dify_graph/graph_events/__init__.py b/api/dify_graph/graph_events/__init__.py deleted file mode 100644 index 56ea6420925..00000000000 --- a/api/dify_graph/graph_events/__init__.py +++ /dev/null @@ -1,82 +0,0 @@ -# Agent events -from .agent import NodeRunAgentLogEvent - -# Base events -from .base import ( - BaseGraphEvent, - GraphEngineEvent, - GraphNodeEventBase, -) - -# Graph events -from .graph import ( - GraphRunAbortedEvent, - GraphRunFailedEvent, - GraphRunPartialSucceededEvent, - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, -) - -# Iteration events -from .iteration import ( - NodeRunIterationFailedEvent, - NodeRunIterationNextEvent, - NodeRunIterationStartedEvent, - NodeRunIterationSucceededEvent, -) - -# Loop events -from .loop import ( - NodeRunLoopFailedEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, -) - -# Node events -from .node import ( - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunHumanInputFormFilledEvent, - NodeRunHumanInputFormTimeoutEvent, - NodeRunPauseRequestedEvent, - NodeRunRetrieverResourceEvent, - NodeRunRetryEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - is_node_result_event, -) - -__all__ = [ - "BaseGraphEvent", - "GraphEngineEvent", - "GraphNodeEventBase", - "GraphRunAbortedEvent", - "GraphRunFailedEvent", - "GraphRunPartialSucceededEvent", - "GraphRunPausedEvent", - "GraphRunStartedEvent", - "GraphRunSucceededEvent", - "NodeRunAgentLogEvent", - "NodeRunExceptionEvent", - "NodeRunFailedEvent", - "NodeRunHumanInputFormFilledEvent", - "NodeRunHumanInputFormTimeoutEvent", - "NodeRunIterationFailedEvent", - "NodeRunIterationNextEvent", - "NodeRunIterationStartedEvent", - "NodeRunIterationSucceededEvent", - "NodeRunLoopFailedEvent", - "NodeRunLoopNextEvent", - "NodeRunLoopStartedEvent", - "NodeRunLoopSucceededEvent", - "NodeRunPauseRequestedEvent", - "NodeRunRetrieverResourceEvent", - "NodeRunRetryEvent", - "NodeRunStartedEvent", - "NodeRunStreamChunkEvent", - "NodeRunSucceededEvent", - "is_node_result_event", -] diff --git a/api/dify_graph/graph_events/agent.py b/api/dify_graph/graph_events/agent.py deleted file mode 100644 index 759fe3a71c7..00000000000 --- a/api/dify_graph/graph_events/agent.py +++ /dev/null @@ -1,17 +0,0 @@ -from collections.abc import Mapping -from typing import Any - -from pydantic import Field - -from .base import GraphAgentNodeEventBase - - -class NodeRunAgentLogEvent(GraphAgentNodeEventBase): - message_id: str = Field(..., description="message id") - label: str = Field(..., description="label") - node_execution_id: str = Field(..., description="node execution id") - parent_id: str | None = Field(..., description="parent id") - error: str | None = Field(..., description="error") - status: str = Field(..., description="status") - data: Mapping[str, Any] = Field(..., description="data") - metadata: Mapping[str, object] = Field(default_factory=dict) diff --git a/api/dify_graph/graph_events/base.py b/api/dify_graph/graph_events/base.py deleted file mode 100644 index 4560cf50854..00000000000 --- a/api/dify_graph/graph_events/base.py +++ /dev/null @@ -1,31 +0,0 @@ -from pydantic import BaseModel, Field - -from dify_graph.enums import NodeType -from dify_graph.node_events import NodeRunResult - - -class GraphEngineEvent(BaseModel): - pass - - -class BaseGraphEvent(GraphEngineEvent): - pass - - -class GraphNodeEventBase(GraphEngineEvent): - id: str = Field(..., description="node execution id") - node_id: str - node_type: NodeType - - in_iteration_id: str | None = None - """iteration id if node is in iteration""" - in_loop_id: str | None = None - """loop id if node is in loop""" - - # The version of the node, or "1" if not specified. - node_version: str = "1" - node_run_result: NodeRunResult = Field(default_factory=NodeRunResult) - - -class GraphAgentNodeEventBase(GraphNodeEventBase): - pass diff --git a/api/dify_graph/graph_events/graph.py b/api/dify_graph/graph_events/graph.py deleted file mode 100644 index f4aaba64d66..00000000000 --- a/api/dify_graph/graph_events/graph.py +++ /dev/null @@ -1,57 +0,0 @@ -from pydantic import Field - -from dify_graph.entities.pause_reason import PauseReason -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.graph_events import BaseGraphEvent - - -class GraphRunStartedEvent(BaseGraphEvent): - # Reason is emitted for workflow start events and is always set. - reason: WorkflowStartReason = Field( - default=WorkflowStartReason.INITIAL, - description="reason for workflow start", - ) - - -class GraphRunSucceededEvent(BaseGraphEvent): - """Event emitted when a run completes successfully with final outputs.""" - - outputs: dict[str, object] = Field( - default_factory=dict, - description="Final workflow outputs keyed by output selector.", - ) - - -class GraphRunFailedEvent(BaseGraphEvent): - error: str = Field(..., description="failed reason") - exceptions_count: int = Field(description="exception count", default=0) - - -class GraphRunPartialSucceededEvent(BaseGraphEvent): - """Event emitted when a run finishes with partial success and failures.""" - - exceptions_count: int = Field(..., description="exception count") - outputs: dict[str, object] = Field( - default_factory=dict, - description="Outputs that were materialised before failures occurred.", - ) - - -class GraphRunAbortedEvent(BaseGraphEvent): - """Event emitted when a graph run is aborted by user command.""" - - reason: str | None = Field(default=None, description="reason for abort") - outputs: dict[str, object] = Field( - default_factory=dict, - description="Outputs produced before the abort was requested.", - ) - - -class GraphRunPausedEvent(BaseGraphEvent): - """Event emitted when a graph run is paused by user command.""" - - reasons: list[PauseReason] = Field(description="reason for pause", default_factory=list) - outputs: dict[str, object] = Field( - default_factory=dict, - description="Outputs available to the client while the run is paused.", - ) diff --git a/api/dify_graph/graph_events/human_input.py b/api/dify_graph/graph_events/human_input.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/dify_graph/graph_events/iteration.py b/api/dify_graph/graph_events/iteration.py deleted file mode 100644 index 28627395fd8..00000000000 --- a/api/dify_graph/graph_events/iteration.py +++ /dev/null @@ -1,40 +0,0 @@ -from collections.abc import Mapping -from datetime import datetime -from typing import Any - -from pydantic import Field - -from .base import GraphNodeEventBase - - -class NodeRunIterationStartedEvent(GraphNodeEventBase): - node_title: str - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - predecessor_node_id: str | None = None - - -class NodeRunIterationNextEvent(GraphNodeEventBase): - node_title: str - index: int = Field(..., description="index") - pre_iteration_output: Any = None - - -class NodeRunIterationSucceededEvent(GraphNodeEventBase): - node_title: str - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - - -class NodeRunIterationFailedEvent(GraphNodeEventBase): - node_title: str - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - error: str = Field(..., description="failed reason") diff --git a/api/dify_graph/graph_events/loop.py b/api/dify_graph/graph_events/loop.py deleted file mode 100644 index 7cdc5427e2b..00000000000 --- a/api/dify_graph/graph_events/loop.py +++ /dev/null @@ -1,40 +0,0 @@ -from collections.abc import Mapping -from datetime import datetime -from typing import Any - -from pydantic import Field - -from .base import GraphNodeEventBase - - -class NodeRunLoopStartedEvent(GraphNodeEventBase): - node_title: str - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - predecessor_node_id: str | None = None - - -class NodeRunLoopNextEvent(GraphNodeEventBase): - node_title: str - index: int = Field(..., description="index") - pre_loop_output: Any = None - - -class NodeRunLoopSucceededEvent(GraphNodeEventBase): - node_title: str - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - - -class NodeRunLoopFailedEvent(GraphNodeEventBase): - node_title: str - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - error: str = Field(..., description="failed reason") diff --git a/api/dify_graph/graph_events/node.py b/api/dify_graph/graph_events/node.py deleted file mode 100644 index df19d6c03b1..00000000000 --- a/api/dify_graph/graph_events/node.py +++ /dev/null @@ -1,99 +0,0 @@ -from collections.abc import Sequence -from datetime import datetime - -from pydantic import Field - -from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from dify_graph.entities.pause_reason import PauseReason - -from .base import GraphNodeEventBase - - -class NodeRunStartedEvent(GraphNodeEventBase): - node_title: str - predecessor_node_id: str | None = None - start_at: datetime = Field(..., description="node start time") - extras: dict[str, object] = Field(default_factory=dict) - - # FIXME(-LAN-): only for ToolNode - provider_type: str = "" - provider_id: str = "" - - -class NodeRunStreamChunkEvent(GraphNodeEventBase): - # Spec-compliant fields - selector: Sequence[str] = Field( - ..., description="selector identifying the output location (e.g., ['nodeA', 'text'])" - ) - chunk: str = Field(..., description="the actual chunk content") - is_final: bool = Field(default=False, description="indicates if this is the last chunk") - - -class NodeRunRetrieverResourceEvent(GraphNodeEventBase): - retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources") - context: str = Field(..., description="context") - - -class NodeRunSucceededEvent(GraphNodeEventBase): - start_at: datetime = Field(..., description="node start time") - finished_at: datetime | None = Field(default=None, description="node finish time") - - -class NodeRunFailedEvent(GraphNodeEventBase): - error: str = Field(..., description="error") - start_at: datetime = Field(..., description="node start time") - finished_at: datetime | None = Field(default=None, description="node finish time") - - -class NodeRunExceptionEvent(GraphNodeEventBase): - error: str = Field(..., description="error") - start_at: datetime = Field(..., description="node start time") - finished_at: datetime | None = Field(default=None, description="node finish time") - - -class NodeRunRetryEvent(NodeRunStartedEvent): - error: str = Field(..., description="error") - retry_index: int = Field(..., description="which retry attempt is about to be performed") - - -class NodeRunHumanInputFormFilledEvent(GraphNodeEventBase): - """Emitted when a HumanInput form is submitted and before the node finishes.""" - - node_title: str = Field(..., description="HumanInput node title") - rendered_content: str = Field(..., description="Markdown content rendered with user inputs.") - action_id: str = Field(..., description="User action identifier chosen in the form.") - action_text: str = Field(..., description="Display text of the chosen action button.") - - -class NodeRunHumanInputFormTimeoutEvent(GraphNodeEventBase): - """Emitted when a HumanInput form times out.""" - - node_title: str = Field(..., description="HumanInput node title") - expiration_time: datetime = Field(..., description="Form expiration time") - - -class NodeRunPauseRequestedEvent(GraphNodeEventBase): - reason: PauseReason = Field(..., description="pause reason") - - -def is_node_result_event(event: GraphNodeEventBase) -> bool: - """ - Check if an event is a final result event from node execution. - - A result event indicates the completion of a node execution and contains - runtime information such as inputs, outputs, or error details. - - Args: - event: The event to check - - Returns: - True if the event is a node result event (succeeded/failed/paused), False otherwise - """ - return isinstance( - event, - ( - NodeRunSucceededEvent, - NodeRunFailedEvent, - NodeRunPauseRequestedEvent, - ), - ) diff --git a/api/dify_graph/model_runtime/README.md b/api/dify_graph/model_runtime/README.md deleted file mode 100644 index b9d2c552105..00000000000 --- a/api/dify_graph/model_runtime/README.md +++ /dev/null @@ -1,51 +0,0 @@ -# Model Runtime - -This module provides the interface for invoking and authenticating various models, and offers Dify a unified information and credentials form rule for model providers. - -- On one hand, it decouples models from upstream and downstream processes, facilitating horizontal expansion for developers, -- On the other hand, it allows for direct display of providers and models in the frontend interface by simply defining them in the backend, eliminating the need to modify frontend logic. - -## Features - -- Supports capability invocation for 6 types of models - - - `LLM` - LLM text completion, dialogue, pre-computed tokens capability - - `Text Embedding Model` - Text Embedding, pre-computed tokens capability - - `Rerank Model` - Segment Rerank capability - - `Speech-to-text Model` - Speech to text capability - - `Text-to-speech Model` - Text to speech capability - - `Moderation` - Moderation capability - -- Model provider display - - Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc. - -- Selectable model list display - - After configuring provider/model credentials, the dropdown (application orchestration interface/default model) allows viewing of the available LLM list. Greyed out items represent predefined model lists from providers without configured credentials, facilitating user review of supported models. - - In addition, this list also returns configurable parameter information and rules for LLM. These parameters are all defined in the backend, allowing different settings for various parameters supported by different models. - -- Provider/model credential authentication - - The provider list returns configuration information for the credentials form, which can be authenticated through Runtime's interface. - -## Structure - -Model Runtime is divided into three layers: - -- The outermost layer is the factory method - - It provides methods for obtaining all providers, all model lists, getting provider instances, and authenticating provider/model credentials. - -- The second layer is the provider layer - - It provides the current provider's model list, model instance obtaining, provider credential authentication, and provider configuration rule information, **allowing horizontal expansion** to support different providers. - -- The bottom layer is the model layer - - It offers direct invocation of various model types, predefined model configuration information, getting predefined/remote model lists, model credential authentication methods. Different models provide additional special methods, like LLM's pre-computed tokens method, cost information obtaining method, etc., **allowing horizontal expansion** for different models under the same provider (within supported model types). - -## Documentation - -For detailed documentation on how to add new providers or models, please refer to the [Dify documentation](https://docs.dify.ai/). diff --git a/api/dify_graph/model_runtime/README_CN.md b/api/dify_graph/model_runtime/README_CN.md deleted file mode 100644 index 0a8b56b3fed..00000000000 --- a/api/dify_graph/model_runtime/README_CN.md +++ /dev/null @@ -1,64 +0,0 @@ -# Model Runtime - -该模块提供了各模型的调用、鉴权接口,并为 Dify 提供了统一的模型供应商的信息和凭据表单规则。 - -- 一方面将模型和上下游解耦,方便开发者对模型横向扩展, -- 另一方面提供了只需在后端定义供应商和模型,即可在前端页面直接展示,无需修改前端逻辑。 - -## 功能介绍 - -- 支持 6 种模型类型的能力调用 - - - `LLM` - LLM 文本补全、对话,预计算 tokens 能力 - - `Text Embedding Model` - 文本 Embedding,预计算 tokens 能力 - - `Rerank Model` - 分段 Rerank 能力 - - `Speech-to-text Model` - 语音转文本能力 - - `Text-to-speech Model` - 文本转语音能力 - - `Moderation` - Moderation 能力 - -- 模型供应商展示 - - 展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等。 - -- 可选择的模型列表展示 - - 配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。 - - 除此之外,该列表还返回了 LLM 可配置的参数信息和规则。这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数。 - -- 供应商/模型凭据鉴权 - - 供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权。 - -## 结构 - -Model Runtime 分三层: - -- 最外层为工厂方法 - - 提供获取所有供应商、所有模型列表、获取供应商实例、供应商/模型凭据鉴权方法。 - -- 第二层为供应商层 - - 提供获取当前供应商模型列表、获取模型实例、供应商凭据鉴权、供应商配置规则信息,**可横向扩展**以支持不同的供应商。 - - 对于供应商/模型凭据,有两种情况 - - - 如 OpenAI 这类中心化供应商,需要定义如**api_key**这类的鉴权凭据 - - 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据。当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。 - - 当配置好凭据后,就可以通过 DifyRuntime 的外部接口直接获取到对应供应商所需要的**Schema**(凭据表单规则),从而在可以在不修改前端逻辑的情况下,提供新的供应商/模型的支持。 - -- 最底层为模型层 - - 提供各种模型类型的直接调用、预定义模型配置信息、获取预定义/远程模型列表、模型凭据鉴权方法,不同模型额外提供了特殊方法,如 LLM 提供预计算 tokens 方法、获取费用信息方法等,**可横向扩展**同供应商下不同的模型(支持的模型类型下)。 - - 在这里我们需要先区分模型参数与模型凭据。 - - - 模型参数 (**在本层定义**):这是一类经常需要变动,随时调整的参数,如 LLM 的 **max_tokens**、**temperature** 等,这些参数是由用户在前端页面上进行调整的,因此需要在后端定义参数的规则,以便前端页面进行展示和调整。在 DifyRuntime 中,他们的参数名一般为**model_parameters: dict[str, any]**。 - - - 模型凭据 (**在供应商层定义**):这是一类不经常变动,一般在配置好后就不会再变动的参数,如 **api_key**、**server_url** 等。在 DifyRuntime 中,他们的参数名一般为**credentials: dict[str, any]**,Provider 层的 credentials 会直接被传递到这一层,不需要再单独定义。 - -## 文档 - -有关如何添加新供应商或模型的详细文档,请参阅 [Dify 文档](https://docs.dify.ai/)。 diff --git a/api/dify_graph/model_runtime/callbacks/base_callback.py b/api/dify_graph/model_runtime/callbacks/base_callback.py deleted file mode 100644 index 20faf3d6cdf..00000000000 --- a/api/dify_graph/model_runtime/callbacks/base_callback.py +++ /dev/null @@ -1,151 +0,0 @@ -from abc import ABC, abstractmethod -from collections.abc import Sequence - -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel - -_TEXT_COLOR_MAPPING = { - "blue": "36;1", - "yellow": "33;1", - "pink": "38;5;200", - "green": "32;1", - "red": "31;1", -} - - -class Callback(ABC): - """ - Base class for callbacks. - Only for LLM. - """ - - raise_error: bool = False - - @abstractmethod - def on_before_invoke( - self, - llm_instance: AIModel, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - ): - """ - Before invoke callback - - :param llm_instance: LLM instance - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - """ - raise NotImplementedError() - - @abstractmethod - def on_new_chunk( - self, - llm_instance: AIModel, - chunk: LLMResultChunk, - model: str, - credentials: dict, - prompt_messages: Sequence[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - ): - """ - On new chunk callback - - :param llm_instance: LLM instance - :param chunk: chunk - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - """ - raise NotImplementedError() - - @abstractmethod - def on_after_invoke( - self, - llm_instance: AIModel, - result: LLMResult, - model: str, - credentials: dict, - prompt_messages: Sequence[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - ): - """ - After invoke callback - - :param llm_instance: LLM instance - :param result: result - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - """ - raise NotImplementedError() - - @abstractmethod - def on_invoke_error( - self, - llm_instance: AIModel, - ex: Exception, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - ): - """ - Invoke error callback - - :param llm_instance: LLM instance - :param ex: exception - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - """ - raise NotImplementedError() - - def print_text(self, text: str, color: str | None = None, end: str = ""): - """Print text with highlighting and no end characters.""" - text_to_print = self._get_colored_text(text, color) if color else text - print(text_to_print, end=end) - - def _get_colored_text(self, text: str, color: str) -> str: - """Get colored text.""" - color_str = _TEXT_COLOR_MAPPING[color] - return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m" diff --git a/api/dify_graph/model_runtime/callbacks/logging_callback.py b/api/dify_graph/model_runtime/callbacks/logging_callback.py deleted file mode 100644 index 49b9ab27ebd..00000000000 --- a/api/dify_graph/model_runtime/callbacks/logging_callback.py +++ /dev/null @@ -1,170 +0,0 @@ -import json -import logging -import sys -from collections.abc import Sequence -from typing import cast - -from dify_graph.model_runtime.callbacks.base_callback import Callback -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel - -logger = logging.getLogger(__name__) - - -class LoggingCallback(Callback): - def on_before_invoke( - self, - llm_instance: AIModel, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - ): - """ - Before invoke callback - - :param llm_instance: LLM instance - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - """ - self.print_text("\n[on_llm_before_invoke]\n", color="blue") - self.print_text(f"Model: {model}\n", color="blue") - self.print_text("Parameters:\n", color="blue") - for key, value in model_parameters.items(): - self.print_text(f"\t{key}: {value}\n", color="blue") - - if stop: - self.print_text(f"\tstop: {stop}\n", color="blue") - - if tools: - self.print_text("\tTools:\n", color="blue") - for tool in tools: - self.print_text(f"\t\t{tool.name}\n", color="blue") - - self.print_text(f"Stream: {stream}\n", color="blue") - - if user: - self.print_text(f"User: {user}\n", color="blue") - - self.print_text("Prompt messages:\n", color="blue") - for prompt_message in prompt_messages: - if prompt_message.name: - self.print_text(f"\tname: {prompt_message.name}\n", color="blue") - - self.print_text(f"\trole: {prompt_message.role.value}\n", color="blue") - self.print_text(f"\tcontent: {prompt_message.content}\n", color="blue") - - if stream: - self.print_text("\n[on_llm_new_chunk]") - - def on_new_chunk( - self, - llm_instance: AIModel, - chunk: LLMResultChunk, - model: str, - credentials: dict, - prompt_messages: Sequence[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - ): - """ - On new chunk callback - - :param llm_instance: LLM instance - :param chunk: chunk - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - """ - sys.stdout.write(cast(str, chunk.delta.message.content)) - sys.stdout.flush() - - def on_after_invoke( - self, - llm_instance: AIModel, - result: LLMResult, - model: str, - credentials: dict, - prompt_messages: Sequence[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - ): - """ - After invoke callback - - :param llm_instance: LLM instance - :param result: result - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - """ - self.print_text("\n[on_llm_after_invoke]\n", color="yellow") - self.print_text(f"Content: {result.message.content}\n", color="yellow") - - if result.message.tool_calls: - self.print_text("Tool calls:\n", color="yellow") - for tool_call in result.message.tool_calls: - self.print_text(f"\t{tool_call.id}\n", color="yellow") - self.print_text(f"\t{tool_call.function.name}\n", color="yellow") - self.print_text(f"\t{json.dumps(tool_call.function.arguments)}\n", color="yellow") - - self.print_text(f"Model: {result.model}\n", color="yellow") - self.print_text(f"Usage: {result.usage}\n", color="yellow") - self.print_text(f"System Fingerprint: {result.system_fingerprint}\n", color="yellow") - - def on_invoke_error( - self, - llm_instance: AIModel, - ex: Exception, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - ): - """ - Invoke error callback - - :param llm_instance: LLM instance - :param ex: exception - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - """ - self.print_text("\n[on_llm_invoke_error]\n", color="red") - logger.exception(ex) diff --git a/api/dify_graph/model_runtime/entities/__init__.py b/api/dify_graph/model_runtime/entities/__init__.py deleted file mode 100644 index a24e437d48e..00000000000 --- a/api/dify_graph/model_runtime/entities/__init__.py +++ /dev/null @@ -1,43 +0,0 @@ -from .llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from .message_entities import ( - AssistantPromptMessage, - AudioPromptMessageContent, - DocumentPromptMessageContent, - ImagePromptMessageContent, - MultiModalPromptMessageContent, - PromptMessage, - PromptMessageContent, - PromptMessageContentType, - PromptMessageRole, - PromptMessageTool, - SystemPromptMessage, - TextPromptMessageContent, - ToolPromptMessage, - UserPromptMessage, - VideoPromptMessageContent, -) -from .model_entities import ModelPropertyKey - -__all__ = [ - "AssistantPromptMessage", - "AudioPromptMessageContent", - "DocumentPromptMessageContent", - "ImagePromptMessageContent", - "LLMMode", - "LLMResult", - "LLMResultChunk", - "LLMResultChunkDelta", - "LLMUsage", - "ModelPropertyKey", - "MultiModalPromptMessageContent", - "PromptMessage", - "PromptMessageContent", - "PromptMessageContentType", - "PromptMessageRole", - "PromptMessageTool", - "SystemPromptMessage", - "TextPromptMessageContent", - "ToolPromptMessage", - "UserPromptMessage", - "VideoPromptMessageContent", -] diff --git a/api/dify_graph/model_runtime/entities/common_entities.py b/api/dify_graph/model_runtime/entities/common_entities.py deleted file mode 100644 index b673efae228..00000000000 --- a/api/dify_graph/model_runtime/entities/common_entities.py +++ /dev/null @@ -1,16 +0,0 @@ -from pydantic import BaseModel, model_validator - - -class I18nObject(BaseModel): - """ - Model class for i18n object. - """ - - zh_Hans: str | None = None - en_US: str - - @model_validator(mode="after") - def _(self): - if not self.zh_Hans: - self.zh_Hans = self.en_US - return self diff --git a/api/dify_graph/model_runtime/entities/defaults.py b/api/dify_graph/model_runtime/entities/defaults.py deleted file mode 100644 index 53b732e5c6c..00000000000 --- a/api/dify_graph/model_runtime/entities/defaults.py +++ /dev/null @@ -1,130 +0,0 @@ -from dify_graph.model_runtime.entities.model_entities import DefaultParameterName - -PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = { - DefaultParameterName.TEMPERATURE: { - "label": { - "en_US": "Temperature", - "zh_Hans": "温度", - }, - "type": "float", - "help": { - "en_US": "Controls randomness. Lower temperature results in less random completions." - " As the temperature approaches zero, the model will become deterministic and repetitive." - " Higher temperature results in more random completions.", - "zh_Hans": "温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。" - "较高的温度会导致更多的随机完成。", - }, - "required": False, - "default": 0.0, - "min": 0.0, - "max": 1.0, - "precision": 2, - }, - DefaultParameterName.TOP_P: { - "label": { - "en_US": "Top P", - "zh_Hans": "Top P", - }, - "type": "float", - "help": { - "en_US": "Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options" - " are considered.", - "zh_Hans": "通过核心采样控制多样性:0.5 表示考虑了一半的所有可能性加权选项。", - }, - "required": False, - "default": 1.0, - "min": 0.0, - "max": 1.0, - "precision": 2, - }, - DefaultParameterName.TOP_K: { - "label": { - "en_US": "Top K", - "zh_Hans": "Top K", - }, - "type": "int", - "help": { - "en_US": "Limits the number of tokens to consider for each step by keeping only the k most likely tokens.", - "zh_Hans": "通过只保留每一步中最可能的 k 个标记来限制要考虑的标记数量。", - }, - "required": False, - "default": 50, - "min": 1, - "max": 100, - "precision": 0, - }, - DefaultParameterName.PRESENCE_PENALTY: { - "label": { - "en_US": "Presence Penalty", - "zh_Hans": "存在惩罚", - }, - "type": "float", - "help": { - "en_US": "Applies a penalty to the log-probability of tokens already in the text.", - "zh_Hans": "对文本中已有的标记的对数概率施加惩罚。", - }, - "required": False, - "default": 0.0, - "min": 0.0, - "max": 1.0, - "precision": 2, - }, - DefaultParameterName.FREQUENCY_PENALTY: { - "label": { - "en_US": "Frequency Penalty", - "zh_Hans": "频率惩罚", - }, - "type": "float", - "help": { - "en_US": "Applies a penalty to the log-probability of tokens that appear in the text.", - "zh_Hans": "对文本中出现的标记的对数概率施加惩罚。", - }, - "required": False, - "default": 0.0, - "min": 0.0, - "max": 1.0, - "precision": 2, - }, - DefaultParameterName.MAX_TOKENS: { - "label": { - "en_US": "Max Tokens", - "zh_Hans": "最大 Token 数", - }, - "type": "int", - "help": { - "en_US": "Specifies the upper limit on the length of generated results." - " If the generated results are truncated, you can increase this parameter.", - "zh_Hans": "指定生成结果长度的上限。如果生成结果截断,可以调大该参数。", - }, - "required": False, - "default": 64, - "min": 1, - "max": 2048, - "precision": 0, - }, - DefaultParameterName.RESPONSE_FORMAT: { - "label": { - "en_US": "Response Format", - "zh_Hans": "回复格式", - }, - "type": "string", - "help": { - "en_US": "Set a response format, ensure the output from llm is a valid code block as possible," - " such as JSON, XML, etc.", - "zh_Hans": "设置一个返回格式,确保 llm 的输出尽可能是有效的代码块,如 JSON、XML 等", - }, - "required": False, - "options": ["JSON", "XML"], - }, - DefaultParameterName.JSON_SCHEMA: { - "label": { - "en_US": "JSON Schema", - }, - "type": "text", - "help": { - "en_US": "Set a response json schema will ensure LLM to adhere it.", - "zh_Hans": "设置返回的 json schema,llm 将按照它返回", - }, - "required": False, - }, -} diff --git a/api/dify_graph/model_runtime/entities/llm_entities.py b/api/dify_graph/model_runtime/entities/llm_entities.py deleted file mode 100644 index eec682a2ae5..00000000000 --- a/api/dify_graph/model_runtime/entities/llm_entities.py +++ /dev/null @@ -1,219 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping, Sequence -from decimal import Decimal -from enum import StrEnum -from typing import Any, TypedDict, Union - -from pydantic import BaseModel, Field - -from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage -from dify_graph.model_runtime.entities.model_entities import ModelUsage, PriceInfo - - -class LLMMode(StrEnum): - """ - Enum class for large language model mode. - """ - - COMPLETION = "completion" - CHAT = "chat" - - -class LLMUsageMetadata(TypedDict, total=False): - """ - TypedDict for LLM usage metadata. - All fields are optional. - """ - - prompt_tokens: int - completion_tokens: int - total_tokens: int - prompt_unit_price: Union[float, str] - completion_unit_price: Union[float, str] - total_price: Union[float, str] - currency: str - prompt_price_unit: Union[float, str] - completion_price_unit: Union[float, str] - prompt_price: Union[float, str] - completion_price: Union[float, str] - latency: float - time_to_first_token: float - time_to_generate: float - - -class LLMUsage(ModelUsage): - """ - Model class for llm usage. - """ - - prompt_tokens: int - prompt_unit_price: Decimal - prompt_price_unit: Decimal - prompt_price: Decimal - completion_tokens: int - completion_unit_price: Decimal - completion_price_unit: Decimal - completion_price: Decimal - total_tokens: int - total_price: Decimal - currency: str - latency: float - time_to_first_token: float | None = None - time_to_generate: float | None = None - - @classmethod - def empty_usage(cls): - return cls( - prompt_tokens=0, - prompt_unit_price=Decimal("0.0"), - prompt_price_unit=Decimal("0.0"), - prompt_price=Decimal("0.0"), - completion_tokens=0, - completion_unit_price=Decimal("0.0"), - completion_price_unit=Decimal("0.0"), - completion_price=Decimal("0.0"), - total_tokens=0, - total_price=Decimal("0.0"), - currency="USD", - latency=0.0, - time_to_first_token=None, - time_to_generate=None, - ) - - @classmethod - def from_metadata(cls, metadata: LLMUsageMetadata) -> LLMUsage: - """ - Create LLMUsage instance from metadata dictionary with default values. - - Args: - metadata: TypedDict containing usage metadata - - Returns: - LLMUsage instance with values from metadata or defaults - """ - prompt_tokens = metadata.get("prompt_tokens", 0) - completion_tokens = metadata.get("completion_tokens", 0) - total_tokens = metadata.get("total_tokens", 0) - - # If total_tokens is not provided but prompt and completion tokens are, - # calculate total_tokens - if total_tokens == 0 and (prompt_tokens > 0 or completion_tokens > 0): - total_tokens = prompt_tokens + completion_tokens - - return cls( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=total_tokens, - prompt_unit_price=Decimal(str(metadata.get("prompt_unit_price", 0))), - completion_unit_price=Decimal(str(metadata.get("completion_unit_price", 0))), - total_price=Decimal(str(metadata.get("total_price", 0))), - currency=metadata.get("currency", "USD"), - prompt_price_unit=Decimal(str(metadata.get("prompt_price_unit", 0))), - completion_price_unit=Decimal(str(metadata.get("completion_price_unit", 0))), - prompt_price=Decimal(str(metadata.get("prompt_price", 0))), - completion_price=Decimal(str(metadata.get("completion_price", 0))), - latency=metadata.get("latency", 0.0), - time_to_first_token=metadata.get("time_to_first_token"), - time_to_generate=metadata.get("time_to_generate"), - ) - - def plus(self, other: LLMUsage) -> LLMUsage: - """ - Add two LLMUsage instances together. - - :param other: Another LLMUsage instance to add - :return: A new LLMUsage instance with summed values - """ - if self.total_tokens == 0: - return other - else: - return LLMUsage( - prompt_tokens=self.prompt_tokens + other.prompt_tokens, - prompt_unit_price=other.prompt_unit_price, - prompt_price_unit=other.prompt_price_unit, - prompt_price=self.prompt_price + other.prompt_price, - completion_tokens=self.completion_tokens + other.completion_tokens, - completion_unit_price=other.completion_unit_price, - completion_price_unit=other.completion_price_unit, - completion_price=self.completion_price + other.completion_price, - total_tokens=self.total_tokens + other.total_tokens, - total_price=self.total_price + other.total_price, - currency=other.currency, - latency=self.latency + other.latency, - time_to_first_token=other.time_to_first_token, - time_to_generate=other.time_to_generate, - ) - - def __add__(self, other: LLMUsage) -> LLMUsage: - """ - Overload the + operator to add two LLMUsage instances. - - :param other: Another LLMUsage instance to add - :return: A new LLMUsage instance with summed values - """ - return self.plus(other) - - -class LLMResult(BaseModel): - """ - Model class for llm result. - """ - - id: str | None = None - model: str - prompt_messages: Sequence[PromptMessage] = Field(default_factory=list) - message: AssistantPromptMessage - usage: LLMUsage - system_fingerprint: str | None = None - reasoning_content: str | None = None - - -class LLMStructuredOutput(BaseModel): - """ - Model class for llm structured output. - """ - - structured_output: Mapping[str, Any] | None = None - - -class LLMResultWithStructuredOutput(LLMResult, LLMStructuredOutput): - """ - Model class for llm result with structured output. - """ - - -class LLMResultChunkDelta(BaseModel): - """ - Model class for llm result chunk delta. - """ - - index: int - message: AssistantPromptMessage - usage: LLMUsage | None = None - finish_reason: str | None = None - - -class LLMResultChunk(BaseModel): - """ - Model class for llm result chunk. - """ - - model: str - prompt_messages: Sequence[PromptMessage] = Field(default_factory=list) - system_fingerprint: str | None = None - delta: LLMResultChunkDelta - - -class LLMResultChunkWithStructuredOutput(LLMResultChunk, LLMStructuredOutput): - """ - Model class for llm result chunk with structured output. - """ - - -class NumTokensResult(PriceInfo): - """ - Model class for number of tokens result. - """ - - tokens: int diff --git a/api/dify_graph/model_runtime/entities/message_entities.py b/api/dify_graph/model_runtime/entities/message_entities.py deleted file mode 100644 index 402bfdc6065..00000000000 --- a/api/dify_graph/model_runtime/entities/message_entities.py +++ /dev/null @@ -1,279 +0,0 @@ -from __future__ import annotations - -from abc import ABC -from collections.abc import Mapping, Sequence -from enum import StrEnum, auto -from typing import Annotated, Any, Literal, Union - -from pydantic import BaseModel, Field, field_serializer, field_validator - - -class PromptMessageRole(StrEnum): - """ - Enum class for prompt message. - """ - - SYSTEM = auto() - USER = auto() - ASSISTANT = auto() - TOOL = auto() - - @classmethod - def value_of(cls, value: str) -> PromptMessageRole: - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f"invalid prompt message type value {value}") - - -class PromptMessageTool(BaseModel): - """ - Model class for prompt message tool. - """ - - name: str - description: str - parameters: dict - - -class PromptMessageFunction(BaseModel): - """ - Model class for prompt message function. - """ - - type: str = "function" - function: PromptMessageTool - - -class PromptMessageContentType(StrEnum): - """ - Enum class for prompt message content type. - """ - - TEXT = auto() - IMAGE = auto() - AUDIO = auto() - VIDEO = auto() - DOCUMENT = auto() - - -class PromptMessageContent(ABC, BaseModel): - """ - Model class for prompt message content. - """ - - type: PromptMessageContentType - - -class TextPromptMessageContent(PromptMessageContent): - """ - Model class for text prompt message content. - """ - - type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT # type: ignore - data: str - - -class MultiModalPromptMessageContent(PromptMessageContent): - """ - Model class for multi-modal prompt message content. - """ - - format: str = Field(default=..., description="the format of multi-modal file") - base64_data: str = Field(default="", description="the base64 data of multi-modal file") - url: str = Field(default="", description="the url of multi-modal file") - mime_type: str = Field(default=..., description="the mime type of multi-modal file") - filename: str = Field(default="", description="the filename of multi-modal file") - - @property - def data(self): - return self.url or f"data:{self.mime_type};base64,{self.base64_data}" - - -class VideoPromptMessageContent(MultiModalPromptMessageContent): - type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO # type: ignore - - -class AudioPromptMessageContent(MultiModalPromptMessageContent): - type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO # type: ignore - - -class ImagePromptMessageContent(MultiModalPromptMessageContent): - """ - Model class for image prompt message content. - """ - - class DETAIL(StrEnum): - LOW = auto() - HIGH = auto() - - type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE # type: ignore - detail: DETAIL = DETAIL.LOW - - -class DocumentPromptMessageContent(MultiModalPromptMessageContent): - type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT # type: ignore - - -PromptMessageContentUnionTypes = Annotated[ - Union[ - TextPromptMessageContent, - ImagePromptMessageContent, - DocumentPromptMessageContent, - AudioPromptMessageContent, - VideoPromptMessageContent, - ], - Field(discriminator="type"), -] - - -CONTENT_TYPE_MAPPING: Mapping[PromptMessageContentType, type[PromptMessageContent]] = { - PromptMessageContentType.TEXT: TextPromptMessageContent, - PromptMessageContentType.IMAGE: ImagePromptMessageContent, - PromptMessageContentType.AUDIO: AudioPromptMessageContent, - PromptMessageContentType.VIDEO: VideoPromptMessageContent, - PromptMessageContentType.DOCUMENT: DocumentPromptMessageContent, -} - - -class PromptMessage(ABC, BaseModel): - """ - Model class for prompt message. - """ - - role: PromptMessageRole - content: str | list[PromptMessageContentUnionTypes] | None = None - name: str | None = None - - def is_empty(self) -> bool: - """ - Check if prompt message is empty. - - :return: True if prompt message is empty, False otherwise - """ - return not self.content - - def get_text_content(self) -> str: - """ - Get text content from prompt message. - - :return: Text content as string, empty string if no text content - """ - if isinstance(self.content, str): - return self.content - elif isinstance(self.content, list): - text_parts = [] - for item in self.content: - if isinstance(item, TextPromptMessageContent): - text_parts.append(item.data) - return "".join(text_parts) - else: - return "" - - @field_validator("content", mode="before") - @classmethod - def validate_content(cls, v): - if isinstance(v, list): - prompts = [] - for prompt in v: - if isinstance(prompt, PromptMessageContent): - if not isinstance(prompt, TextPromptMessageContent | MultiModalPromptMessageContent): - prompt = CONTENT_TYPE_MAPPING[prompt.type].model_validate(prompt.model_dump()) - elif isinstance(prompt, dict): - prompt = CONTENT_TYPE_MAPPING[prompt["type"]].model_validate(prompt) - else: - raise ValueError(f"invalid prompt message {prompt}") - prompts.append(prompt) - return prompts - return v - - @field_serializer("content") - def serialize_content( - self, content: Union[str, Sequence[PromptMessageContent]] | None - ) -> str | list[dict[str, Any] | PromptMessageContent] | Sequence[PromptMessageContent] | None: - if content is None or isinstance(content, str): - return content - if isinstance(content, list): - return [item.model_dump() if hasattr(item, "model_dump") else item for item in content] - return content - - -class UserPromptMessage(PromptMessage): - """ - Model class for user prompt message. - """ - - role: PromptMessageRole = PromptMessageRole.USER - - -class AssistantPromptMessage(PromptMessage): - """ - Model class for assistant prompt message. - """ - - class ToolCall(BaseModel): - """ - Model class for assistant prompt message tool call. - """ - - class ToolCallFunction(BaseModel): - """ - Model class for assistant prompt message tool call function. - """ - - name: str - arguments: str - - id: str - type: str - function: ToolCallFunction - - @field_validator("id", mode="before") - @classmethod - def transform_id_to_str(cls, value) -> str: - if not isinstance(value, str): - return str(value) - else: - return value - - role: PromptMessageRole = PromptMessageRole.ASSISTANT - tool_calls: list[ToolCall] = [] - - def is_empty(self) -> bool: - """ - Check if prompt message is empty. - - :return: True if prompt message is empty, False otherwise - """ - return super().is_empty() and not self.tool_calls - - -class SystemPromptMessage(PromptMessage): - """ - Model class for system prompt message. - """ - - role: PromptMessageRole = PromptMessageRole.SYSTEM - - -class ToolPromptMessage(PromptMessage): - """ - Model class for tool prompt message. - """ - - role: PromptMessageRole = PromptMessageRole.TOOL - tool_call_id: str - - def is_empty(self) -> bool: - """ - Check if prompt message is empty. - - :return: True if prompt message is empty, False otherwise - """ - return super().is_empty() and not self.tool_call_id diff --git a/api/dify_graph/model_runtime/entities/model_entities.py b/api/dify_graph/model_runtime/entities/model_entities.py deleted file mode 100644 index fbcde6740a4..00000000000 --- a/api/dify_graph/model_runtime/entities/model_entities.py +++ /dev/null @@ -1,242 +0,0 @@ -from __future__ import annotations - -from decimal import Decimal -from enum import StrEnum, auto -from typing import Any - -from pydantic import BaseModel, ConfigDict, model_validator - -from dify_graph.model_runtime.entities.common_entities import I18nObject - - -class ModelType(StrEnum): - """ - Enum class for model type. - """ - - LLM = auto() - TEXT_EMBEDDING = "text-embedding" - RERANK = auto() - SPEECH2TEXT = auto() - MODERATION = auto() - TTS = auto() - - @classmethod - def value_of(cls, origin_model_type: str) -> ModelType: - """ - Get model type from origin model type. - - :return: model type - """ - if origin_model_type in {"text-generation", cls.LLM}: - return cls.LLM - elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING}: - return cls.TEXT_EMBEDDING - elif origin_model_type in {"reranking", cls.RERANK}: - return cls.RERANK - elif origin_model_type in {"speech2text", cls.SPEECH2TEXT}: - return cls.SPEECH2TEXT - elif origin_model_type in {"tts", cls.TTS}: - return cls.TTS - elif origin_model_type == cls.MODERATION: - return cls.MODERATION - else: - raise ValueError(f"invalid origin model type {origin_model_type}") - - def to_origin_model_type(self) -> str: - """ - Get origin model type from model type. - - :return: origin model type - """ - if self == self.LLM: - return "text-generation" - elif self == self.TEXT_EMBEDDING: - return "embeddings" - elif self == self.RERANK: - return "reranking" - elif self == self.SPEECH2TEXT: - return "speech2text" - elif self == self.TTS: - return "tts" - elif self == self.MODERATION: - return "moderation" - else: - raise ValueError(f"invalid model type {self}") - - -class FetchFrom(StrEnum): - """ - Enum class for fetch from. - """ - - PREDEFINED_MODEL = "predefined-model" - CUSTOMIZABLE_MODEL = "customizable-model" - - -class ModelFeature(StrEnum): - """ - Enum class for llm feature. - """ - - TOOL_CALL = "tool-call" - MULTI_TOOL_CALL = "multi-tool-call" - AGENT_THOUGHT = "agent-thought" - VISION = auto() - STREAM_TOOL_CALL = "stream-tool-call" - DOCUMENT = auto() - VIDEO = auto() - AUDIO = auto() - STRUCTURED_OUTPUT = "structured-output" - - -class DefaultParameterName(StrEnum): - """ - Enum class for parameter template variable. - """ - - TEMPERATURE = auto() - TOP_P = auto() - TOP_K = auto() - PRESENCE_PENALTY = auto() - FREQUENCY_PENALTY = auto() - MAX_TOKENS = auto() - RESPONSE_FORMAT = auto() - JSON_SCHEMA = auto() - - @classmethod - def value_of(cls, value: Any) -> DefaultParameterName: - """ - Get parameter name from value. - - :param value: parameter value - :return: parameter name - """ - for name in cls: - if name.value == value: - return name - raise ValueError(f"invalid parameter name {value}") - - -class ParameterType(StrEnum): - """ - Enum class for parameter type. - """ - - FLOAT = auto() - INT = auto() - STRING = auto() - BOOLEAN = auto() - TEXT = auto() - - -class ModelPropertyKey(StrEnum): - """ - Enum class for model property key. - """ - - MODE = auto() - CONTEXT_SIZE = auto() - MAX_CHUNKS = auto() - FILE_UPLOAD_LIMIT = auto() - SUPPORTED_FILE_EXTENSIONS = auto() - MAX_CHARACTERS_PER_CHUNK = auto() - DEFAULT_VOICE = auto() - VOICES = auto() - WORD_LIMIT = auto() - AUDIO_TYPE = auto() - MAX_WORKERS = auto() - - -class ProviderModel(BaseModel): - """ - Model class for provider model. - """ - - model: str - label: I18nObject - model_type: ModelType - features: list[ModelFeature] | None = None - fetch_from: FetchFrom - model_properties: dict[ModelPropertyKey, Any] - deprecated: bool = False - model_config = ConfigDict(protected_namespaces=()) - - @property - def support_structure_output(self) -> bool: - return self.features is not None and ModelFeature.STRUCTURED_OUTPUT in self.features - - -class ParameterRule(BaseModel): - """ - Model class for parameter rule. - """ - - name: str - use_template: str | None = None - label: I18nObject - type: ParameterType - help: I18nObject | None = None - required: bool = False - default: Any | None = None - min: float | None = None - max: float | None = None - precision: int | None = None - options: list[str] = [] - - -class PriceConfig(BaseModel): - """ - Model class for pricing info. - """ - - input: Decimal - output: Decimal | None = None - unit: Decimal - currency: str - - -class AIModelEntity(ProviderModel): - """ - Model class for AI model. - """ - - parameter_rules: list[ParameterRule] = [] - pricing: PriceConfig | None = None - - @model_validator(mode="after") - def validate_model(self): - supported_schema_keys = ["json_schema"] - schema_key = next((rule.name for rule in self.parameter_rules if rule.name in supported_schema_keys), None) - if not schema_key: - return self - if self.features is None: - self.features = [ModelFeature.STRUCTURED_OUTPUT] - else: - if ModelFeature.STRUCTURED_OUTPUT not in self.features: - self.features.append(ModelFeature.STRUCTURED_OUTPUT) - return self - - -class ModelUsage(BaseModel): - pass - - -class PriceType(StrEnum): - """ - Enum class for price type. - """ - - INPUT = auto() - OUTPUT = auto() - - -class PriceInfo(BaseModel): - """ - Model class for price info. - """ - - unit_price: Decimal - unit: Decimal - total_amount: Decimal - currency: str diff --git a/api/dify_graph/model_runtime/entities/provider_entities.py b/api/dify_graph/model_runtime/entities/provider_entities.py deleted file mode 100644 index 97a99ea7ceb..00000000000 --- a/api/dify_graph/model_runtime/entities/provider_entities.py +++ /dev/null @@ -1,169 +0,0 @@ -from collections.abc import Sequence -from enum import StrEnum, auto - -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator - -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType - - -class ConfigurateMethod(StrEnum): - """ - Enum class for configurate method of provider model. - """ - - PREDEFINED_MODEL = "predefined-model" - CUSTOMIZABLE_MODEL = "customizable-model" - - -class FormType(StrEnum): - """ - Enum class for form type. - """ - - TEXT_INPUT = "text-input" - SECRET_INPUT = "secret-input" - SELECT = auto() - RADIO = auto() - SWITCH = auto() - - -class FormShowOnObject(BaseModel): - """ - Model class for form show on. - """ - - variable: str - value: str - - -class FormOption(BaseModel): - """ - Model class for form option. - """ - - label: I18nObject - value: str - show_on: list[FormShowOnObject] = [] - - @model_validator(mode="after") - def _(self): - if not self.label: - self.label = I18nObject(en_US=self.value) - return self - - -class CredentialFormSchema(BaseModel): - """ - Model class for credential form schema. - """ - - variable: str - label: I18nObject - type: FormType - required: bool = True - default: str | None = None - options: list[FormOption] | None = None - placeholder: I18nObject | None = None - max_length: int = 0 - show_on: list[FormShowOnObject] = [] - - -class ProviderCredentialSchema(BaseModel): - """ - Model class for provider credential schema. - """ - - credential_form_schemas: list[CredentialFormSchema] - - -class FieldModelSchema(BaseModel): - label: I18nObject - placeholder: I18nObject | None = None - - -class ModelCredentialSchema(BaseModel): - """ - Model class for model credential schema. - """ - - model: FieldModelSchema - credential_form_schemas: list[CredentialFormSchema] - - -class SimpleProviderEntity(BaseModel): - """ - Simple model class for provider. - """ - - provider: str - label: I18nObject - icon_small: I18nObject | None = None - icon_small_dark: I18nObject | None = None - supported_model_types: Sequence[ModelType] - models: list[AIModelEntity] = [] - - -class ProviderHelpEntity(BaseModel): - """ - Model class for provider help. - """ - - title: I18nObject - url: I18nObject - - -class ProviderEntity(BaseModel): - """ - Model class for provider. - """ - - provider: str - label: I18nObject - description: I18nObject | None = None - icon_small: I18nObject | None = None - icon_small_dark: I18nObject | None = None - background: str | None = None - help: ProviderHelpEntity | None = None - supported_model_types: Sequence[ModelType] - configurate_methods: list[ConfigurateMethod] - models: list[AIModelEntity] = Field(default_factory=list) - provider_credential_schema: ProviderCredentialSchema | None = None - model_credential_schema: ModelCredentialSchema | None = None - - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - - # position from plugin _position.yaml - position: dict[str, list[str]] | None = {} - - @field_validator("models", mode="before") - @classmethod - def validate_models(cls, v): - # returns EmptyList if v is empty - if not v: - return [] - return v - - def to_simple_provider(self) -> SimpleProviderEntity: - """ - Convert to simple provider. - - :return: simple provider - """ - return SimpleProviderEntity( - provider=self.provider, - label=self.label, - icon_small=self.icon_small, - supported_model_types=self.supported_model_types, - models=self.models, - ) - - -class ProviderConfig(BaseModel): - """ - Model class for provider config. - """ - - provider: str - credentials: dict diff --git a/api/dify_graph/model_runtime/entities/rerank_entities.py b/api/dify_graph/model_runtime/entities/rerank_entities.py deleted file mode 100644 index 99709e1bcd2..00000000000 --- a/api/dify_graph/model_runtime/entities/rerank_entities.py +++ /dev/null @@ -1,20 +0,0 @@ -from pydantic import BaseModel - - -class RerankDocument(BaseModel): - """ - Model class for rerank document. - """ - - index: int - text: str - score: float - - -class RerankResult(BaseModel): - """ - Model class for rerank result. - """ - - model: str - docs: list[RerankDocument] diff --git a/api/dify_graph/model_runtime/entities/text_embedding_entities.py b/api/dify_graph/model_runtime/entities/text_embedding_entities.py deleted file mode 100644 index a0210c169d5..00000000000 --- a/api/dify_graph/model_runtime/entities/text_embedding_entities.py +++ /dev/null @@ -1,39 +0,0 @@ -from decimal import Decimal - -from pydantic import BaseModel - -from dify_graph.model_runtime.entities.model_entities import ModelUsage - - -class EmbeddingUsage(ModelUsage): - """ - Model class for embedding usage. - """ - - tokens: int - total_tokens: int - unit_price: Decimal - price_unit: Decimal - total_price: Decimal - currency: str - latency: float - - -class EmbeddingResult(BaseModel): - """ - Model class for text embedding result. - """ - - model: str - embeddings: list[list[float]] - usage: EmbeddingUsage - - -class FileEmbeddingResult(BaseModel): - """ - Model class for file embedding result. - """ - - model: str - embeddings: list[list[float]] - usage: EmbeddingUsage diff --git a/api/dify_graph/model_runtime/errors/invoke.py b/api/dify_graph/model_runtime/errors/invoke.py deleted file mode 100644 index 1a57078b988..00000000000 --- a/api/dify_graph/model_runtime/errors/invoke.py +++ /dev/null @@ -1,41 +0,0 @@ -class InvokeError(ValueError): - """Base class for all LLM exceptions.""" - - description: str | None = None - - def __init__(self, description: str | None = None): - if description is not None: - self.description = description - - def __str__(self): - return self.description or self.__class__.__name__ - - -class InvokeConnectionError(InvokeError): - """Raised when the Invoke returns connection error.""" - - description = "Connection Error" - - -class InvokeServerUnavailableError(InvokeError): - """Raised when the Invoke returns server unavailable error.""" - - description = "Server Unavailable Error" - - -class InvokeRateLimitError(InvokeError): - """Raised when the Invoke returns rate limit error.""" - - description = "Rate Limit Error" - - -class InvokeAuthorizationError(InvokeError): - """Raised when the Invoke returns authorization error.""" - - description = "Incorrect model credentials provided, please check and try again. " - - -class InvokeBadRequestError(InvokeError): - """Raised when the Invoke returns bad request.""" - - description = "Bad Request Error" diff --git a/api/dify_graph/model_runtime/errors/validate.py b/api/dify_graph/model_runtime/errors/validate.py deleted file mode 100644 index 16bebcc67db..00000000000 --- a/api/dify_graph/model_runtime/errors/validate.py +++ /dev/null @@ -1,6 +0,0 @@ -class CredentialsValidateFailedError(ValueError): - """ - Credentials validate failed error - """ - - pass diff --git a/api/dify_graph/model_runtime/memory/__init__.py b/api/dify_graph/model_runtime/memory/__init__.py deleted file mode 100644 index 2d954486c30..00000000000 --- a/api/dify_graph/model_runtime/memory/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .prompt_message_memory import DEFAULT_MEMORY_MAX_TOKEN_LIMIT, PromptMessageMemory - -__all__ = ["DEFAULT_MEMORY_MAX_TOKEN_LIMIT", "PromptMessageMemory"] diff --git a/api/dify_graph/model_runtime/memory/prompt_message_memory.py b/api/dify_graph/model_runtime/memory/prompt_message_memory.py deleted file mode 100644 index a76a7faf71b..00000000000 --- a/api/dify_graph/model_runtime/memory/prompt_message_memory.py +++ /dev/null @@ -1,18 +0,0 @@ -from __future__ import annotations - -from collections.abc import Sequence -from typing import Protocol - -from dify_graph.model_runtime.entities import PromptMessage - -DEFAULT_MEMORY_MAX_TOKEN_LIMIT = 2000 - - -class PromptMessageMemory(Protocol): - """Port for loading memory as prompt messages.""" - - def get_history_prompt_messages( - self, max_token_limit: int = DEFAULT_MEMORY_MAX_TOKEN_LIMIT, message_limit: int | None = None - ) -> Sequence[PromptMessage]: - """Return historical prompt messages constrained by token/message limits.""" - ... diff --git a/api/dify_graph/model_runtime/model_providers/__base/ai_model.py b/api/dify_graph/model_runtime/model_providers/__base/ai_model.py deleted file mode 100644 index ac7ae9925b5..00000000000 --- a/api/dify_graph/model_runtime/model_providers/__base/ai_model.py +++ /dev/null @@ -1,286 +0,0 @@ -import decimal -import hashlib -import logging - -from pydantic import BaseModel, ConfigDict, Field, ValidationError -from redis import RedisError - -from configs import dify_config -from core.plugin.entities.plugin_daemon import PluginModelProviderEntity -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE -from dify_graph.model_runtime.entities.model_entities import ( - AIModelEntity, - DefaultParameterName, - ModelType, - PriceConfig, - PriceInfo, - PriceType, -) -from dify_graph.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from extensions.ext_redis import redis_client - -logger = logging.getLogger(__name__) - - -class AIModel(BaseModel): - """ - Base class for all models. - """ - - tenant_id: str = Field(description="Tenant ID") - model_type: ModelType = Field(description="Model type") - plugin_id: str = Field(description="Plugin ID") - provider_name: str = Field(description="Provider") - plugin_model_provider: PluginModelProviderEntity = Field(description="Plugin model provider") - started_at: float = Field(description="Invoke start time", default=0) - - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - - @property - def _invoke_error_mapping(self) -> dict[type[Exception], list[type[Exception]]]: - """ - Map model invoke error to unified error - The key is the error type thrown to the caller - The value is the error type thrown by the model, - which needs to be converted into a unified error type for the caller. - - :return: Invoke error mapping - """ - from core.plugin.entities.plugin_daemon import PluginDaemonInnerError - - return { - InvokeConnectionError: [InvokeConnectionError], - InvokeServerUnavailableError: [InvokeServerUnavailableError], - InvokeRateLimitError: [InvokeRateLimitError], - InvokeAuthorizationError: [InvokeAuthorizationError], - InvokeBadRequestError: [InvokeBadRequestError], - PluginDaemonInnerError: [PluginDaemonInnerError], - ValueError: [ValueError], - } - - def _transform_invoke_error(self, error: Exception) -> Exception: - """ - Transform invoke error to unified error - - :param error: model invoke error - :return: unified error - """ - for invoke_error, model_errors in self._invoke_error_mapping.items(): - if isinstance(error, tuple(model_errors)): - if invoke_error == InvokeAuthorizationError: - return InvokeAuthorizationError( - description=( - f"[{self.provider_name}] Incorrect model credentials provided, please check and try again." - ) - ) - elif isinstance(invoke_error, InvokeError): - return InvokeError(description=f"[{self.provider_name}] {invoke_error.description}, {str(error)}") - else: - return error - - return InvokeError(description=f"[{self.provider_name}] Error: {str(error)}") - - def get_price(self, model: str, credentials: dict, price_type: PriceType, tokens: int) -> PriceInfo: - """ - Get price for given model and tokens - - :param model: model name - :param credentials: model credentials - :param price_type: price type - :param tokens: number of tokens - :return: price info - """ - # get model schema - model_schema = self.get_model_schema(model, credentials) - - # get price info from predefined model schema - price_config: PriceConfig | None = None - if model_schema and model_schema.pricing: - price_config = model_schema.pricing - - # get unit price - unit_price = None - if price_config: - if price_type == PriceType.INPUT: - unit_price = price_config.input - elif price_type == PriceType.OUTPUT and price_config.output is not None: - unit_price = price_config.output - - if unit_price is None: - return PriceInfo( - unit_price=decimal.Decimal("0.0"), - unit=decimal.Decimal("0.0"), - total_amount=decimal.Decimal("0.0"), - currency="USD", - ) - - # calculate total amount - if not price_config: - raise ValueError(f"Price config not found for model {model}") - total_amount = tokens * unit_price * price_config.unit - total_amount = total_amount.quantize(decimal.Decimal("0.0000001"), rounding=decimal.ROUND_HALF_UP) - - return PriceInfo( - unit_price=unit_price, - unit=price_config.unit, - total_amount=total_amount, - currency=price_config.currency, - ) - - def get_model_schema(self, model: str, credentials: dict | None = None) -> AIModelEntity | None: - """ - Get model schema by model name and credentials - - :param model: model name - :param credentials: model credentials - :return: model schema - """ - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}" - sorted_credentials = sorted(credentials.items()) if credentials else [] - cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials]) - - cached_schema_json = None - try: - cached_schema_json = redis_client.get(cache_key) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to read plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - if cached_schema_json: - try: - return AIModelEntity.model_validate_json(cached_schema_json) - except ValidationError: - logger.warning( - "Failed to validate cached plugin model schema for model %s", - model, - exc_info=True, - ) - try: - redis_client.delete(cache_key) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to delete invalid plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - - schema = plugin_model_manager.get_model_schema( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model_type=self.model_type.value, - model=model, - credentials=credentials or {}, - ) - - if schema: - try: - redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json()) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to write plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - - return schema - - def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> AIModelEntity | None: - """ - Get customizable model schema from credentials - - :param model: model name - :param credentials: model credentials - :return: model schema - """ - - # get customizable model schema - schema = self.get_customizable_model_schema(model, credentials) - if not schema: - return None - - # fill in the template - new_parameter_rules = [] - for parameter_rule in schema.parameter_rules: - if parameter_rule.use_template: - try: - default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template) - default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name) - if not parameter_rule.max and "max" in default_parameter_rule: - parameter_rule.max = default_parameter_rule["max"] - if not parameter_rule.min and "min" in default_parameter_rule: - parameter_rule.min = default_parameter_rule["min"] - if not parameter_rule.default and "default" in default_parameter_rule: - parameter_rule.default = default_parameter_rule["default"] - if not parameter_rule.precision and "precision" in default_parameter_rule: - parameter_rule.precision = default_parameter_rule["precision"] - if not parameter_rule.required and "required" in default_parameter_rule: - parameter_rule.required = default_parameter_rule["required"] - if not parameter_rule.help and "help" in default_parameter_rule: - parameter_rule.help = I18nObject( - en_US=default_parameter_rule["help"]["en_US"], - ) - if ( - parameter_rule.help - and not parameter_rule.help.en_US - and ("help" in default_parameter_rule and "en_US" in default_parameter_rule["help"]) - ): - parameter_rule.help.en_US = default_parameter_rule["help"]["en_US"] - if ( - parameter_rule.help - and not parameter_rule.help.zh_Hans - and ("help" in default_parameter_rule and "zh_Hans" in default_parameter_rule["help"]) - ): - parameter_rule.help.zh_Hans = default_parameter_rule["help"].get( - "zh_Hans", default_parameter_rule["help"]["en_US"] - ) - except ValueError: - pass - - new_parameter_rules.append(parameter_rule) - - schema.parameter_rules = new_parameter_rules - - return schema - - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: - """ - Get customizable model schema - - :param model: model name - :param credentials: model credentials - :return: model schema - """ - return None - - def _get_default_parameter_rule_variable_map(self, name: DefaultParameterName): - """ - Get default parameter rule for given name - - :param name: parameter name - :return: parameter rule - """ - default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(name) - - if not default_parameter_rule: - raise Exception(f"Invalid model parameter rule name {name}") - - return default_parameter_rule diff --git a/api/dify_graph/model_runtime/model_providers/__base/large_language_model.py b/api/dify_graph/model_runtime/model_providers/__base/large_language_model.py deleted file mode 100644 index bf864ca227d..00000000000 --- a/api/dify_graph/model_runtime/model_providers/__base/large_language_model.py +++ /dev/null @@ -1,668 +0,0 @@ -import logging -import time -import uuid -from collections.abc import Callable, Generator, Iterator, Sequence -from typing import Union - -from pydantic import ConfigDict - -from configs import dify_config -from dify_graph.model_runtime.callbacks.base_callback import Callback -from dify_graph.model_runtime.callbacks.logging_callback import LoggingCallback -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage -from dify_graph.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessage, - PromptMessageContentUnionTypes, - PromptMessageTool, - TextPromptMessageContent, -) -from dify_graph.model_runtime.entities.model_entities import ( - ModelType, - PriceType, -) -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel - -logger = logging.getLogger(__name__) - - -def _gen_tool_call_id() -> str: - return f"chatcmpl-tool-{str(uuid.uuid4().hex)}" - - -def _run_callbacks(callbacks: Sequence[Callback] | None, *, event: str, invoke: Callable[[Callback], None]) -> None: - if not callbacks: - return - - for callback in callbacks: - try: - invoke(callback) - except Exception as e: - if callback.raise_error: - raise - logger.warning("Callback %s %s failed with error %s", callback.__class__.__name__, event, e) - - -def _get_or_create_tool_call( - existing_tools_calls: list[AssistantPromptMessage.ToolCall], - tool_call_id: str, -) -> AssistantPromptMessage.ToolCall: - """ - Get or create a tool call by ID. - - If `tool_call_id` is empty, returns the most recently created tool call. - """ - if not tool_call_id: - if not existing_tools_calls: - raise ValueError("tool_call_id is empty but no existing tool call is available to apply the delta") - return existing_tools_calls[-1] - - tool_call = next((tool_call for tool_call in existing_tools_calls if tool_call.id == tool_call_id), None) - if tool_call is None: - tool_call = AssistantPromptMessage.ToolCall( - id=tool_call_id, - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""), - ) - existing_tools_calls.append(tool_call) - - return tool_call - - -def _merge_tool_call_delta( - tool_call: AssistantPromptMessage.ToolCall, - delta: AssistantPromptMessage.ToolCall, -) -> None: - if delta.id: - tool_call.id = delta.id - if delta.type: - tool_call.type = delta.type - if delta.function.name: - tool_call.function.name = delta.function.name - if delta.function.arguments: - tool_call.function.arguments += delta.function.arguments - - -def _build_llm_result_from_chunks( - model: str, - prompt_messages: Sequence[PromptMessage], - chunks: Iterator[LLMResultChunk], -) -> LLMResult: - """ - Build a single `LLMResult` by accumulating all returned chunks. - - Some models only support streaming output (e.g. Qwen3 open-source edition) - and the plugin side may still implement the response via a chunked stream, - so all chunks must be consumed and concatenated into a single ``LLMResult``. - - The ``usage`` is taken from the last chunk that carries it, which is the - typical convention for streaming responses (the final chunk contains the - aggregated token counts). - """ - content = "" - content_list: list[PromptMessageContentUnionTypes] = [] - usage = LLMUsage.empty_usage() - system_fingerprint: str | None = None - tools_calls: list[AssistantPromptMessage.ToolCall] = [] - - try: - for chunk in chunks: - if isinstance(chunk.delta.message.content, str): - content += chunk.delta.message.content - elif isinstance(chunk.delta.message.content, list): - content_list.extend(chunk.delta.message.content) - - if chunk.delta.message.tool_calls: - _increase_tool_call(chunk.delta.message.tool_calls, tools_calls) - - if chunk.delta.usage: - usage = chunk.delta.usage - if chunk.system_fingerprint: - system_fingerprint = chunk.system_fingerprint - except Exception: - logger.exception("Error while consuming non-stream plugin chunk iterator.") - raise - finally: - # Drain any remaining chunks to release underlying streaming resources (e.g. HTTP connections). - close = getattr(chunks, "close", None) - if callable(close): - close() - - return LLMResult( - model=model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=content or content_list, - tool_calls=tools_calls, - ), - usage=usage, - system_fingerprint=system_fingerprint, - ) - - -def _invoke_llm_via_plugin( - *, - tenant_id: str, - user_id: str, - plugin_id: str, - provider: str, - model: str, - credentials: dict, - model_parameters: dict, - prompt_messages: Sequence[PromptMessage], - tools: list[PromptMessageTool] | None, - stop: Sequence[str] | None, - stream: bool, -) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_llm( - tenant_id=tenant_id, - user_id=user_id, - plugin_id=plugin_id, - provider=provider, - model=model, - credentials=credentials, - model_parameters=model_parameters, - prompt_messages=list(prompt_messages), - tools=tools, - stop=list(stop) if stop else None, - stream=stream, - ) - - -def _normalize_non_stream_plugin_result( - model: str, - prompt_messages: Sequence[PromptMessage], - result: Union[LLMResult, Iterator[LLMResultChunk]], -) -> LLMResult: - if isinstance(result, LLMResult): - return result - return _build_llm_result_from_chunks(model=model, prompt_messages=prompt_messages, chunks=result) - - -def _increase_tool_call( - new_tool_calls: list[AssistantPromptMessage.ToolCall], existing_tools_calls: list[AssistantPromptMessage.ToolCall] -): - """ - Merge incremental tool call updates into existing tool calls. - - :param new_tool_calls: List of new tool call deltas to be merged. - :param existing_tools_calls: List of existing tool calls to be modified IN-PLACE. - """ - - for new_tool_call in new_tool_calls: - # generate ID for tool calls with function name but no ID to track them - if new_tool_call.function.name and not new_tool_call.id: - new_tool_call.id = _gen_tool_call_id() - - tool_call = _get_or_create_tool_call(existing_tools_calls, new_tool_call.id) - _merge_tool_call_delta(tool_call, new_tool_call) - - -class LargeLanguageModel(AIModel): - """ - Model class for large language model. - """ - - model_type: ModelType = ModelType.LLM - - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - - def invoke( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict | None = None, - tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, - stream: bool = True, - user: str | None = None, - callbacks: list[Callback] | None = None, - ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: - """ - Invoke large language model - - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - :param callbacks: callbacks - :return: full response or stream response chunk generator result - """ - # validate and filter model parameters - if model_parameters is None: - model_parameters = {} - - self.started_at = time.perf_counter() - - callbacks = callbacks or [] - - if dify_config.DEBUG: - callbacks.append(LoggingCallback()) - - # trigger before invoke callbacks - self._trigger_before_invoke_callbacks( - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - callbacks=callbacks, - ) - - result: Union[LLMResult, Generator[LLMResultChunk, None, None]] - - try: - result = _invoke_llm_via_plugin( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model=model, - credentials=credentials, - model_parameters=model_parameters, - prompt_messages=prompt_messages, - tools=tools, - stop=stop, - stream=stream, - ) - - if not stream: - result = _normalize_non_stream_plugin_result( - model=model, prompt_messages=prompt_messages, result=result - ) - except Exception as e: - self._trigger_invoke_error_callbacks( - model=model, - ex=e, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - callbacks=callbacks, - ) - - # TODO - raise self._transform_invoke_error(e) - - if stream and not isinstance(result, LLMResult): - return self._invoke_result_generator( - model=model, - result=result, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - callbacks=callbacks, - ) - elif isinstance(result, LLMResult): - self._trigger_after_invoke_callbacks( - model=model, - result=result, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - callbacks=callbacks, - ) - # Following https://github.com/langgenius/dify/issues/17799, - # we removed the prompt_messages from the chunk on the plugin daemon side. - # To ensure compatibility, we add the prompt_messages back here. - result.prompt_messages = prompt_messages - return result - raise NotImplementedError("unsupported invoke result type", type(result)) - - def _invoke_result_generator( - self, - model: str, - result: Generator[LLMResultChunk, None, None], - credentials: dict, - prompt_messages: Sequence[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - callbacks: list[Callback] | None = None, - ) -> Generator[LLMResultChunk, None, None]: - """ - Invoke result generator - - :param result: result generator - :return: result generator - """ - callbacks = callbacks or [] - message_content: list[PromptMessageContentUnionTypes] = [] - usage = None - system_fingerprint = None - real_model = model - - def _update_message_content(content: str | list[PromptMessageContentUnionTypes] | None): - if not content: - return - if isinstance(content, list): - message_content.extend(content) - return - if isinstance(content, str): - message_content.append(TextPromptMessageContent(data=content)) - return - - try: - for chunk in result: - # Following https://github.com/langgenius/dify/issues/17799, - # we removed the prompt_messages from the chunk on the plugin daemon side. - # To ensure compatibility, we add the prompt_messages back here. - chunk.prompt_messages = prompt_messages - yield chunk - - self._trigger_new_chunk_callbacks( - chunk=chunk, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - callbacks=callbacks, - ) - - _update_message_content(chunk.delta.message.content) - - real_model = chunk.model - if chunk.delta.usage: - usage = chunk.delta.usage - - if chunk.system_fingerprint: - system_fingerprint = chunk.system_fingerprint - except Exception as e: - raise self._transform_invoke_error(e) - - assistant_message = AssistantPromptMessage(content=message_content) - self._trigger_after_invoke_callbacks( - model=model, - result=LLMResult( - model=real_model, - prompt_messages=prompt_messages, - message=assistant_message, - usage=usage or LLMUsage.empty_usage(), - system_fingerprint=system_fingerprint, - ), - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - callbacks=callbacks, - ) - - def get_num_tokens( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None, - ) -> int: - """ - Get number of tokens for given prompt messages - - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param tools: tools for tool calling - :return: - """ - if dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.get_llm_num_tokens( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model_type=self.model_type.value, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - tools=tools, - ) - return 0 - - def calc_response_usage( - self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int - ) -> LLMUsage: - """ - Calculate response usage - - :param model: model name - :param credentials: model credentials - :param prompt_tokens: prompt tokens - :param completion_tokens: completion tokens - :return: usage - """ - # get prompt price info - prompt_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=prompt_tokens, - ) - - # get completion price info - completion_price_info = self.get_price( - model=model, credentials=credentials, price_type=PriceType.OUTPUT, tokens=completion_tokens - ) - - # transform usage - usage = LLMUsage( - prompt_tokens=prompt_tokens, - prompt_unit_price=prompt_price_info.unit_price, - prompt_price_unit=prompt_price_info.unit, - prompt_price=prompt_price_info.total_amount, - completion_tokens=completion_tokens, - completion_unit_price=completion_price_info.unit_price, - completion_price_unit=completion_price_info.unit, - completion_price=completion_price_info.total_amount, - total_tokens=prompt_tokens + completion_tokens, - total_price=prompt_price_info.total_amount + completion_price_info.total_amount, - currency=prompt_price_info.currency, - latency=time.perf_counter() - self.started_at, - ) - - return usage - - def _trigger_before_invoke_callbacks( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - callbacks: list[Callback] | None = None, - ): - """ - Trigger before invoke callbacks - - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - :param callbacks: callbacks - """ - _run_callbacks( - callbacks, - event="on_before_invoke", - invoke=lambda callback: callback.on_before_invoke( - llm_instance=self, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - ), - ) - - def _trigger_new_chunk_callbacks( - self, - chunk: LLMResultChunk, - model: str, - credentials: dict, - prompt_messages: Sequence[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - callbacks: list[Callback] | None = None, - ): - """ - Trigger new chunk callbacks - - :param chunk: chunk - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - """ - _run_callbacks( - callbacks, - event="on_new_chunk", - invoke=lambda callback: callback.on_new_chunk( - llm_instance=self, - chunk=chunk, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - ), - ) - - def _trigger_after_invoke_callbacks( - self, - model: str, - result: LLMResult, - credentials: dict, - prompt_messages: Sequence[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - callbacks: list[Callback] | None = None, - ): - """ - Trigger after invoke callbacks - - :param model: model name - :param result: result - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - :param callbacks: callbacks - """ - _run_callbacks( - callbacks, - event="on_after_invoke", - invoke=lambda callback: callback.on_after_invoke( - llm_instance=self, - result=result, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - ), - ) - - def _trigger_invoke_error_callbacks( - self, - model: str, - ex: Exception, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - callbacks: list[Callback] | None = None, - ): - """ - Trigger invoke error callbacks - - :param model: model name - :param ex: exception - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - :param callbacks: callbacks - """ - _run_callbacks( - callbacks, - event="on_invoke_error", - invoke=lambda callback: callback.on_invoke_error( - llm_instance=self, - ex=ex, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - ), - ) diff --git a/api/dify_graph/model_runtime/model_providers/__base/moderation_model.py b/api/dify_graph/model_runtime/model_providers/__base/moderation_model.py deleted file mode 100644 index 5fa3d1634b3..00000000000 --- a/api/dify_graph/model_runtime/model_providers/__base/moderation_model.py +++ /dev/null @@ -1,45 +0,0 @@ -import time - -from pydantic import ConfigDict - -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel - - -class ModerationModel(AIModel): - """ - Model class for moderation model. - """ - - model_type: ModelType = ModelType.MODERATION - - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - - def invoke(self, model: str, credentials: dict, text: str, user: str | None = None) -> bool: - """ - Invoke moderation model - - :param model: model name - :param credentials: model credentials - :param text: text to moderate - :param user: unique user id - :return: false if text is safe, true otherwise - """ - self.started_at = time.perf_counter() - - try: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_moderation( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model=model, - credentials=credentials, - text=text, - ) - except Exception as e: - raise self._transform_invoke_error(e) diff --git a/api/dify_graph/model_runtime/model_providers/__base/rerank_model.py b/api/dify_graph/model_runtime/model_providers/__base/rerank_model.py deleted file mode 100644 index 5da2b84b951..00000000000 --- a/api/dify_graph/model_runtime/model_providers/__base/rerank_model.py +++ /dev/null @@ -1,92 +0,0 @@ -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.entities.rerank_entities import RerankResult -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel - - -class RerankModel(AIModel): - """ - Base Model class for rerank model. - """ - - model_type: ModelType = ModelType.RERANK - - def invoke( - self, - model: str, - credentials: dict, - query: str, - docs: list[str], - score_threshold: float | None = None, - top_n: int | None = None, - user: str | None = None, - ) -> RerankResult: - """ - Invoke rerank model - - :param model: model name - :param credentials: model credentials - :param query: search query - :param docs: docs for reranking - :param score_threshold: score threshold - :param top_n: top n - :param user: unique user id - :return: rerank result - """ - try: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_rerank( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model=model, - credentials=credentials, - query=query, - docs=docs, - score_threshold=score_threshold, - top_n=top_n, - ) - except Exception as e: - raise self._transform_invoke_error(e) - - def invoke_multimodal_rerank( - self, - model: str, - credentials: dict, - query: dict, - docs: list[dict], - score_threshold: float | None = None, - top_n: int | None = None, - user: str | None = None, - ) -> RerankResult: - """ - Invoke multimodal rerank model - :param model: model name - :param credentials: model credentials - :param query: search query - :param docs: docs for reranking - :param score_threshold: score threshold - :param top_n: top n - :param user: unique user id - :return: rerank result - """ - try: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_multimodal_rerank( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model=model, - credentials=credentials, - query=query, - docs=docs, - score_threshold=score_threshold, - top_n=top_n, - ) - except Exception as e: - raise self._transform_invoke_error(e) diff --git a/api/dify_graph/model_runtime/model_providers/__base/speech2text_model.py b/api/dify_graph/model_runtime/model_providers/__base/speech2text_model.py deleted file mode 100644 index e69069a85df..00000000000 --- a/api/dify_graph/model_runtime/model_providers/__base/speech2text_model.py +++ /dev/null @@ -1,43 +0,0 @@ -from typing import IO - -from pydantic import ConfigDict - -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel - - -class Speech2TextModel(AIModel): - """ - Model class for speech2text model. - """ - - model_type: ModelType = ModelType.SPEECH2TEXT - - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - - def invoke(self, model: str, credentials: dict, file: IO[bytes], user: str | None = None) -> str: - """ - Invoke speech to text model - - :param model: model name - :param credentials: model credentials - :param file: audio file - :param user: unique user id - :return: text for given audio file - """ - try: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_speech_to_text( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model=model, - credentials=credentials, - file=file, - ) - except Exception as e: - raise self._transform_invoke_error(e) diff --git a/api/dify_graph/model_runtime/model_providers/__base/text_embedding_model.py b/api/dify_graph/model_runtime/model_providers/__base/text_embedding_model.py deleted file mode 100644 index 3438da2ada6..00000000000 --- a/api/dify_graph/model_runtime/model_providers/__base/text_embedding_model.py +++ /dev/null @@ -1,121 +0,0 @@ -from pydantic import ConfigDict - -from core.entities.embedding_type import EmbeddingInputType -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel - - -class TextEmbeddingModel(AIModel): - """ - Model class for text embedding model. - """ - - model_type: ModelType = ModelType.TEXT_EMBEDDING - - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - - def invoke( - self, - model: str, - credentials: dict, - texts: list[str] | None = None, - multimodel_documents: list[dict] | None = None, - user: str | None = None, - input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, - ) -> EmbeddingResult: - """ - Invoke text embedding model - - :param model: model name - :param credentials: model credentials - :param texts: texts to embed - :param files: files to embed - :param user: unique user id - :param input_type: input type - :return: embeddings result - """ - from core.plugin.impl.model import PluginModelClient - - try: - plugin_model_manager = PluginModelClient() - if texts: - return plugin_model_manager.invoke_text_embedding( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model=model, - credentials=credentials, - texts=texts, - input_type=input_type, - ) - if multimodel_documents: - return plugin_model_manager.invoke_multimodal_embedding( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model=model, - credentials=credentials, - documents=multimodel_documents, - input_type=input_type, - ) - raise ValueError("No texts or files provided") - except Exception as e: - raise self._transform_invoke_error(e) - - def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> list[int]: - """ - Get number of tokens for given prompt messages - - :param model: model name - :param credentials: model credentials - :param texts: texts to embed - :return: - """ - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.get_text_embedding_num_tokens( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model=model, - credentials=credentials, - texts=texts, - ) - - def _get_context_size(self, model: str, credentials: dict) -> int: - """ - Get context size for given embedding model - - :param model: model name - :param credentials: model credentials - :return: context size - """ - model_schema = self.get_model_schema(model, credentials) - - if model_schema and ModelPropertyKey.CONTEXT_SIZE in model_schema.model_properties: - content_size: int = model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE] - return content_size - - return 1000 - - def _get_max_chunks(self, model: str, credentials: dict) -> int: - """ - Get max chunks for given embedding model - - :param model: model name - :param credentials: model credentials - :return: max chunks - """ - model_schema = self.get_model_schema(model, credentials) - - if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties: - max_chunks: int = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] - return max_chunks - - return 1 diff --git a/api/dify_graph/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py b/api/dify_graph/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py deleted file mode 100644 index 3967acf07ba..00000000000 --- a/api/dify_graph/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py +++ /dev/null @@ -1,53 +0,0 @@ -import logging -from threading import Lock -from typing import Any - -logger = logging.getLogger(__name__) - -_tokenizer: Any | None = None -_lock = Lock() - - -class GPT2Tokenizer: - @staticmethod - def _get_num_tokens_by_gpt2(text: str) -> int: - """ - use gpt2 tokenizer to get num tokens - """ - _tokenizer = GPT2Tokenizer.get_encoder() - tokens = _tokenizer.encode(text) # type: ignore - return len(tokens) - - @staticmethod - def get_num_tokens(text: str) -> int: - # Because this process needs more cpu resource, we turn this back before we find a better way to handle it. - # - # future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text) - # result = future.result() - # return cast(int, result) - return GPT2Tokenizer._get_num_tokens_by_gpt2(text) - - @staticmethod - def get_encoder(): - global _tokenizer, _lock - if _tokenizer is not None: - return _tokenizer - with _lock: - if _tokenizer is None: - # Try to use tiktoken to get the tokenizer because it is faster - # - try: - import tiktoken - - _tokenizer = tiktoken.get_encoding("gpt2") - except Exception: - from os.path import abspath, dirname, join - - from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer - - base_path = abspath(__file__) - gpt2_tokenizer_path = join(dirname(base_path), "gpt2") - _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path) - logger.info("Fallback to Transformers' GPT-2 tokenizer from tiktoken") - - return _tokenizer diff --git a/api/dify_graph/model_runtime/model_providers/__base/tts_model.py b/api/dify_graph/model_runtime/model_providers/__base/tts_model.py deleted file mode 100644 index 0656529f22f..00000000000 --- a/api/dify_graph/model_runtime/model_providers/__base/tts_model.py +++ /dev/null @@ -1,79 +0,0 @@ -import logging -from collections.abc import Iterable - -from pydantic import ConfigDict - -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel - -logger = logging.getLogger(__name__) - - -class TTSModel(AIModel): - """ - Model class for TTS model. - """ - - model_type: ModelType = ModelType.TTS - - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - - def invoke( - self, - model: str, - tenant_id: str, - credentials: dict, - content_text: str, - voice: str, - user: str | None = None, - ) -> Iterable[bytes]: - """ - Invoke large language model - - :param model: model name - :param tenant_id: user tenant id - :param credentials: model credentials - :param voice: model timbre - :param content_text: text content to be translated - :param user: unique user id - :return: translated audio file - """ - try: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_tts( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model=model, - credentials=credentials, - content_text=content_text, - voice=voice, - ) - except Exception as e: - raise self._transform_invoke_error(e) - - def get_tts_model_voices(self, model: str, credentials: dict, language: str | None = None): - """ - Retrieves the list of voices supported by a given text-to-speech (TTS) model. - - :param language: The language for which the voices are requested. - :param model: The name of the TTS model. - :param credentials: The credentials required to access the TTS model. - :return: A list of voices supported by the TTS model. - """ - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.get_tts_model_voices( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model=model, - credentials=credentials, - language=language, - ) diff --git a/api/dify_graph/model_runtime/model_providers/_position.yaml b/api/dify_graph/model_runtime/model_providers/_position.yaml deleted file mode 100644 index fb02de3a67c..00000000000 --- a/api/dify_graph/model_runtime/model_providers/_position.yaml +++ /dev/null @@ -1,43 +0,0 @@ -- openai -- deepseek -- anthropic -- azure_openai -- google -- vertex_ai -- nvidia -- nvidia_nim -- cohere -- upstage -- bedrock -- togetherai -- openrouter -- ollama -- mistralai -- groq -- replicate -- huggingface_hub -- xinference -- triton_inference_server -- zhipuai -- baichuan -- spark -- minimax -- tongyi -- wenxin -- moonshot -- tencent -- jina -- chatglm -- yi -- openllm -- localai -- volcengine_maas -- openai_api_compatible -- hunyuan -- siliconflow -- perfxcloud -- zhinao -- fireworks -- mixedbread -- nomic -- voyage diff --git a/api/dify_graph/model_runtime/model_providers/model_provider_factory.py b/api/dify_graph/model_runtime/model_providers/model_provider_factory.py deleted file mode 100644 index de0677a3481..00000000000 --- a/api/dify_graph/model_runtime/model_providers/model_provider_factory.py +++ /dev/null @@ -1,387 +0,0 @@ -from __future__ import annotations - -import hashlib -import logging -from collections.abc import Sequence -from threading import Lock - -from pydantic import ValidationError -from redis import RedisError - -import contexts -from configs import dify_config -from core.plugin.entities.plugin_daemon import PluginModelProviderEntity -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType -from dify_graph.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel -from dify_graph.model_runtime.model_providers.__base.rerank_model import RerankModel -from dify_graph.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from dify_graph.model_runtime.model_providers.__base.tts_model import TTSModel -from dify_graph.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator -from dify_graph.model_runtime.schema_validators.provider_credential_schema_validator import ( - ProviderCredentialSchemaValidator, -) -from extensions.ext_redis import redis_client -from models.provider_ids import ModelProviderID - -logger = logging.getLogger(__name__) - - -class ModelProviderFactory: - def __init__(self, tenant_id: str): - from core.plugin.impl.model import PluginModelClient - - self.tenant_id = tenant_id - self.plugin_model_manager = PluginModelClient() - - def get_providers(self) -> Sequence[ProviderEntity]: - """ - Get all providers - :return: list of providers - """ - # FIXME(-LAN-): Removed position map sorting since providers are fetched from plugin server - # The plugin server should return providers in the desired order - plugin_providers = self.get_plugin_model_providers() - return [provider.declaration for provider in plugin_providers] - - def get_plugin_model_providers(self) -> Sequence[PluginModelProviderEntity]: - """ - Get all plugin model providers - :return: list of plugin model providers - """ - # check if context is set - try: - contexts.plugin_model_providers.get() - except LookupError: - contexts.plugin_model_providers.set(None) - contexts.plugin_model_providers_lock.set(Lock()) - - with contexts.plugin_model_providers_lock.get(): - plugin_model_providers = contexts.plugin_model_providers.get() - if plugin_model_providers is not None: - return plugin_model_providers - - plugin_model_providers = [] - contexts.plugin_model_providers.set(plugin_model_providers) - - # Fetch plugin model providers - plugin_providers = self.plugin_model_manager.fetch_model_providers(self.tenant_id) - - for provider in plugin_providers: - provider.declaration.provider = provider.plugin_id + "/" + provider.declaration.provider - plugin_model_providers.append(provider) - - return plugin_model_providers - - def get_provider_schema(self, provider: str) -> ProviderEntity: - """ - Get provider schema - :param provider: provider name - :return: provider schema - """ - plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider) - return plugin_model_provider_entity.declaration - - def get_plugin_model_provider(self, provider: str) -> PluginModelProviderEntity: - """ - Get plugin model provider - :param provider: provider name - :return: provider schema - """ - if "/" not in provider: - provider = str(ModelProviderID(provider)) - - # fetch plugin model providers - plugin_model_provider_entities = self.get_plugin_model_providers() - - # get the provider - plugin_model_provider_entity = next( - (p for p in plugin_model_provider_entities if p.declaration.provider == provider), - None, - ) - - if not plugin_model_provider_entity: - raise ValueError(f"Invalid provider: {provider}") - - return plugin_model_provider_entity - - def provider_credentials_validate(self, *, provider: str, credentials: dict): - """ - Validate provider credentials - - :param provider: provider name - :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. - :return: - """ - # fetch plugin model provider - plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider) - - # get provider_credential_schema and validate credentials according to the rules - provider_credential_schema = plugin_model_provider_entity.declaration.provider_credential_schema - if not provider_credential_schema: - raise ValueError(f"Provider {provider} does not have provider_credential_schema") - - # validate provider credential schema - validator = ProviderCredentialSchemaValidator(provider_credential_schema) - filtered_credentials = validator.validate_and_filter(credentials) - - # validate the credentials, raise exception if validation failed - self.plugin_model_manager.validate_provider_credentials( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=plugin_model_provider_entity.plugin_id, - provider=plugin_model_provider_entity.provider, - credentials=filtered_credentials, - ) - - return filtered_credentials - - def model_credentials_validate(self, *, provider: str, model_type: ModelType, model: str, credentials: dict): - """ - Validate model credentials - - :param provider: provider name - :param model_type: model type - :param model: model name - :param credentials: model credentials, credentials form defined in `model_credential_schema`. - :return: - """ - # fetch plugin model provider - plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider) - - # get model_credential_schema and validate credentials according to the rules - model_credential_schema = plugin_model_provider_entity.declaration.model_credential_schema - if not model_credential_schema: - raise ValueError(f"Provider {provider} does not have model_credential_schema") - - # validate model credential schema - validator = ModelCredentialSchemaValidator(model_type, model_credential_schema) - filtered_credentials = validator.validate_and_filter(credentials) - - # call validate_credentials method of model type to validate credentials, raise exception if validation failed - self.plugin_model_manager.validate_model_credentials( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=plugin_model_provider_entity.plugin_id, - provider=plugin_model_provider_entity.provider, - model_type=model_type.value, - model=model, - credentials=filtered_credentials, - ) - - return filtered_credentials - - def get_model_schema( - self, *, provider: str, model_type: ModelType, model: str, credentials: dict | None - ) -> AIModelEntity | None: - """ - Get model schema - """ - plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider) - cache_key = f"{self.tenant_id}:{plugin_id}:{provider_name}:{model_type.value}:{model}" - sorted_credentials = sorted(credentials.items()) if credentials else [] - cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials]) - - cached_schema_json = None - try: - cached_schema_json = redis_client.get(cache_key) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to read plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - if cached_schema_json: - try: - return AIModelEntity.model_validate_json(cached_schema_json) - except ValidationError: - logger.warning( - "Failed to validate cached plugin model schema for model %s", - model, - exc_info=True, - ) - try: - redis_client.delete(cache_key) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to delete invalid plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - - schema = self.plugin_model_manager.get_model_schema( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=plugin_id, - provider=provider_name, - model_type=model_type.value, - model=model, - credentials=credentials or {}, - ) - - if schema: - try: - redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json()) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to write plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - - return schema - - def get_models( - self, - *, - provider: str | None = None, - model_type: ModelType | None = None, - provider_configs: list[ProviderConfig] | None = None, - ) -> list[SimpleProviderEntity]: - """ - Get all models for given model type - - :param provider: provider name - :param model_type: model type - :param provider_configs: list of provider configs - :return: list of models - """ - provider_configs = provider_configs or [] - - # scan all providers - plugin_model_provider_entities = self.get_plugin_model_providers() - - # traverse all model_provider_extensions - providers = [] - for plugin_model_provider_entity in plugin_model_provider_entities: - # filter by provider if provider is present - if provider and plugin_model_provider_entity.declaration.provider != provider: - continue - - # get provider schema - provider_schema = plugin_model_provider_entity.declaration - - model_types = provider_schema.supported_model_types - if model_type: - if model_type not in model_types: - continue - - model_types = [model_type] - - all_model_type_models = [] - for model_schema in provider_schema.models: - if model_schema.model_type != model_type: - continue - - all_model_type_models.append(model_schema) - - simple_provider_schema = provider_schema.to_simple_provider() - if model_type: - simple_provider_schema.models = all_model_type_models - - providers.append(simple_provider_schema) - - return providers - - def get_model_type_instance(self, provider: str, model_type: ModelType) -> AIModel: - """ - Get model type instance by provider name and model type - :param provider: provider name - :param model_type: model type - :return: model type instance - """ - plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider) - init_params = { - "tenant_id": self.tenant_id, - "plugin_id": plugin_id, - "provider_name": provider_name, - "plugin_model_provider": self.get_plugin_model_provider(provider), - } - - if model_type == ModelType.LLM: - return LargeLanguageModel.model_validate(init_params) - elif model_type == ModelType.TEXT_EMBEDDING: - return TextEmbeddingModel.model_validate(init_params) - elif model_type == ModelType.RERANK: - return RerankModel.model_validate(init_params) - elif model_type == ModelType.SPEECH2TEXT: - return Speech2TextModel.model_validate(init_params) - elif model_type == ModelType.MODERATION: - return ModerationModel.model_validate(init_params) - elif model_type == ModelType.TTS: - return TTSModel.model_validate(init_params) - - raise ValueError(f"Unsupported model type: {model_type}") - - def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: - """ - Get provider icon - :param provider: provider name - :param icon_type: icon type (icon_small or icon_small_dark) - :param lang: language (zh_Hans or en_US) - :return: provider icon - """ - # get the provider schema - provider_schema = self.get_provider_schema(provider) - - if icon_type.lower() == "icon_small": - if not provider_schema.icon_small: - raise ValueError(f"Provider {provider} does not have small icon.") - - if lang.lower() == "zh_hans": - file_name = provider_schema.icon_small.zh_Hans - else: - file_name = provider_schema.icon_small.en_US - elif icon_type.lower() == "icon_small_dark": - if not provider_schema.icon_small_dark: - raise ValueError(f"Provider {provider} does not have small dark icon.") - - if lang.lower() == "zh_hans": - file_name = provider_schema.icon_small_dark.zh_Hans - else: - file_name = provider_schema.icon_small_dark.en_US - else: - raise ValueError(f"Unsupported icon type: {icon_type}.") - - if not file_name: - raise ValueError(f"Provider {provider} does not have icon.") - - image_mime_types = { - "jpg": "image/jpeg", - "jpeg": "image/jpeg", - "png": "image/png", - "gif": "image/gif", - "bmp": "image/bmp", - "tiff": "image/tiff", - "tif": "image/tiff", - "webp": "image/webp", - "svg": "image/svg+xml", - "ico": "image/vnd.microsoft.icon", - "heif": "image/heif", - "heic": "image/heic", - } - - extension = file_name.split(".")[-1] - mime_type = image_mime_types.get(extension, "image/png") - - # get icon bytes from plugin asset manager - from core.plugin.impl.asset import PluginAssetManager - - plugin_asset_manager = PluginAssetManager() - return plugin_asset_manager.fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type - - def get_plugin_id_and_provider_name_from_provider(self, provider: str) -> tuple[str, str]: - """ - Get plugin id and provider name from provider name - :param provider: provider name - :return: plugin id and provider name - """ - - provider_id = ModelProviderID(provider) - return provider_id.plugin_id, provider_id.provider_name diff --git a/api/dify_graph/model_runtime/schema_validators/common_validator.py b/api/dify_graph/model_runtime/schema_validators/common_validator.py deleted file mode 100644 index 04cdb8e4f78..00000000000 --- a/api/dify_graph/model_runtime/schema_validators/common_validator.py +++ /dev/null @@ -1,92 +0,0 @@ -from typing import Union, cast - -from dify_graph.model_runtime.entities.provider_entities import CredentialFormSchema, FormType - - -class CommonValidator: - def _validate_and_filter_credential_form_schemas( - self, credential_form_schemas: list[CredentialFormSchema], credentials: dict - ): - need_validate_credential_form_schema_map = {} - for credential_form_schema in credential_form_schemas: - if not credential_form_schema.show_on: - need_validate_credential_form_schema_map[credential_form_schema.variable] = credential_form_schema - continue - - all_show_on_match = True - for show_on_object in credential_form_schema.show_on: - if show_on_object.variable not in credentials: - all_show_on_match = False - break - - if credentials[show_on_object.variable] != show_on_object.value: - all_show_on_match = False - break - - if all_show_on_match: - need_validate_credential_form_schema_map[credential_form_schema.variable] = credential_form_schema - - # Iterate over the remaining credential_form_schemas, verify each credential_form_schema - validated_credentials = {} - for credential_form_schema in need_validate_credential_form_schema_map.values(): - # add the value of the credential_form_schema corresponding to it to validated_credentials - result = self._validate_credential_form_schema(credential_form_schema, credentials) - if result: - validated_credentials[credential_form_schema.variable] = result - - return validated_credentials - - def _validate_credential_form_schema( - self, credential_form_schema: CredentialFormSchema, credentials: dict - ) -> Union[str, bool, None]: - """ - Validate credential form schema - - :param credential_form_schema: credential form schema - :param credentials: credentials - :return: validated credential form schema value - """ - # If the variable does not exist in credentials - value: Union[str, bool, None] = None - if credential_form_schema.variable not in credentials or not credentials[credential_form_schema.variable]: - # If required is True, an exception is thrown - if credential_form_schema.required: - raise ValueError(f"Variable {credential_form_schema.variable} is required") - else: - # Get the value of default - if credential_form_schema.default: - # If it exists, add it to validated_credentials - return credential_form_schema.default - else: - # If default does not exist, skip - return None - - # Get the value corresponding to the variable from credentials - value = cast(str, credentials[credential_form_schema.variable]) - - # If max_length=0, no validation is performed - if credential_form_schema.max_length: - if len(value) > credential_form_schema.max_length: - raise ValueError( - f"Variable {credential_form_schema.variable} length should not be" - f" greater than {credential_form_schema.max_length}" - ) - - # check the type of value - if not isinstance(value, str): - raise ValueError(f"Variable {credential_form_schema.variable} should be string") - - if credential_form_schema.type in {FormType.SELECT, FormType.RADIO}: - # If the value is in options, no validation is performed - if credential_form_schema.options: - if value not in [option.value for option in credential_form_schema.options]: - raise ValueError(f"Variable {credential_form_schema.variable} is not in options") - - if credential_form_schema.type == FormType.SWITCH: - # If the value is not in ['true', 'false'], an exception is thrown - if value.lower() not in {"true", "false"}: - raise ValueError(f"Variable {credential_form_schema.variable} should be true or false") - - value = value.lower() == "true" - - return value diff --git a/api/dify_graph/model_runtime/schema_validators/model_credential_schema_validator.py b/api/dify_graph/model_runtime/schema_validators/model_credential_schema_validator.py deleted file mode 100644 index a97796e98f8..00000000000 --- a/api/dify_graph/model_runtime/schema_validators/model_credential_schema_validator.py +++ /dev/null @@ -1,27 +0,0 @@ -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.entities.provider_entities import ModelCredentialSchema -from dify_graph.model_runtime.schema_validators.common_validator import CommonValidator - - -class ModelCredentialSchemaValidator(CommonValidator): - def __init__(self, model_type: ModelType, model_credential_schema: ModelCredentialSchema): - self.model_type = model_type - self.model_credential_schema = model_credential_schema - - def validate_and_filter(self, credentials: dict): - """ - Validate model credentials - - :param credentials: model credentials - :return: filtered credentials - """ - - if self.model_credential_schema is None: - raise ValueError("Model credential schema is None") - - # get the credential_form_schemas in provider_credential_schema - credential_form_schemas = self.model_credential_schema.credential_form_schemas - - credentials["__model_type"] = self.model_type.value - - return self._validate_and_filter_credential_form_schemas(credential_form_schemas, credentials) diff --git a/api/dify_graph/model_runtime/schema_validators/provider_credential_schema_validator.py b/api/dify_graph/model_runtime/schema_validators/provider_credential_schema_validator.py deleted file mode 100644 index 2fed75a76cf..00000000000 --- a/api/dify_graph/model_runtime/schema_validators/provider_credential_schema_validator.py +++ /dev/null @@ -1,19 +0,0 @@ -from dify_graph.model_runtime.entities.provider_entities import ProviderCredentialSchema -from dify_graph.model_runtime.schema_validators.common_validator import CommonValidator - - -class ProviderCredentialSchemaValidator(CommonValidator): - def __init__(self, provider_credential_schema: ProviderCredentialSchema): - self.provider_credential_schema = provider_credential_schema - - def validate_and_filter(self, credentials: dict): - """ - Validate provider credentials - - :param credentials: provider credentials - :return: validated provider credentials - """ - # get the credential_form_schemas in provider_credential_schema - credential_form_schemas = self.provider_credential_schema.credential_form_schemas - - return self._validate_and_filter_credential_form_schemas(credential_form_schemas, credentials) diff --git a/api/dify_graph/model_runtime/utils/encoders.py b/api/dify_graph/model_runtime/utils/encoders.py deleted file mode 100644 index c85152463e3..00000000000 --- a/api/dify_graph/model_runtime/utils/encoders.py +++ /dev/null @@ -1,216 +0,0 @@ -import dataclasses -import datetime -from collections import defaultdict, deque -from collections.abc import Callable -from decimal import Decimal -from enum import Enum -from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network -from pathlib import Path, PurePath -from re import Pattern -from types import GeneratorType -from typing import Any, Literal, Union -from uuid import UUID - -from pydantic import BaseModel -from pydantic.networks import AnyUrl, NameEmail -from pydantic.types import SecretBytes, SecretStr -from pydantic_core import Url -from pydantic_extra_types.color import Color - - -def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any: - return model.model_dump(mode=mode, **kwargs) - - -# Taken from Pydantic v1 as is -def isoformat(o: Union[datetime.date, datetime.time]) -> str: - return o.isoformat() - - -# Taken from Pydantic v1 as is -# TODO: pv2 should this return strings instead? -def decimal_encoder(dec_value: Decimal) -> Union[int, float]: - """ - Encodes a Decimal as int of there's no exponent, otherwise float - - This is useful when we use ConstrainedDecimal to represent Numeric(x,0) - where a integer (but not int typed) is used. Encoding this as a float - results in failed round-tripping between encode and parse. - Our Id type is a prime example of this. - - >>> decimal_encoder(Decimal("1.0")) - 1.0 - - >>> decimal_encoder(Decimal("1")) - 1 - """ - if dec_value.as_tuple().exponent >= 0: # type: ignore[operator] - return int(dec_value) - else: - return float(dec_value) - - -ENCODERS_BY_TYPE: dict[type[Any], Callable[[Any], Any]] = { - bytes: lambda o: o.decode(), - Color: str, - datetime.date: isoformat, - datetime.datetime: isoformat, - datetime.time: isoformat, - datetime.timedelta: lambda td: td.total_seconds(), - Decimal: decimal_encoder, - Enum: lambda o: o.value, - frozenset: list, - deque: list, - GeneratorType: list, - IPv4Address: str, - IPv4Interface: str, - IPv4Network: str, - IPv6Address: str, - IPv6Interface: str, - IPv6Network: str, - NameEmail: str, - Path: str, - Pattern: lambda o: o.pattern, - SecretBytes: str, - SecretStr: str, - set: list, - UUID: str, - Url: str, - AnyUrl: str, -} - - -def generate_encoders_by_class_tuples( - type_encoder_map: dict[Any, Callable[[Any], Any]], -) -> dict[Callable[[Any], Any], tuple[Any, ...]]: - encoders_by_class_tuples: dict[Callable[[Any], Any], tuple[Any, ...]] = defaultdict(tuple) - for type_, encoder in type_encoder_map.items(): - encoders_by_class_tuples[encoder] += (type_,) - return encoders_by_class_tuples - - -encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE) - - -def jsonable_encoder( - obj: Any, - by_alias: bool = True, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - custom_encoder: dict[Any, Callable[[Any], Any]] | None = None, - sqlalchemy_safe: bool = True, -) -> Any: - custom_encoder = custom_encoder or {} - if custom_encoder: - if type(obj) in custom_encoder: - return custom_encoder[type(obj)](obj) - else: - for encoder_type, encoder_instance in custom_encoder.items(): - if isinstance(obj, encoder_type): - return encoder_instance(obj) - if isinstance(obj, BaseModel): - obj_dict = _model_dump( - obj, - mode="json", - include=None, - exclude=None, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_none=exclude_none, - exclude_defaults=exclude_defaults, - ) - if "__root__" in obj_dict: - obj_dict = obj_dict["__root__"] - return jsonable_encoder( - obj_dict, - exclude_none=exclude_none, - exclude_defaults=exclude_defaults, - sqlalchemy_safe=sqlalchemy_safe, - ) - if dataclasses.is_dataclass(obj): - # Ensure obj is a dataclass instance, not a dataclass type - if not isinstance(obj, type): - obj_dict = dataclasses.asdict(obj) - return jsonable_encoder( - obj_dict, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, - ) - if isinstance(obj, Enum): - return obj.value - if isinstance(obj, PurePath): - return str(obj) - if isinstance(obj, str | int | float | type(None)): - return obj - if isinstance(obj, Decimal): - return format(obj, "f") - if isinstance(obj, dict): - encoded_dict = {} - for key, value in obj.items(): - if (not sqlalchemy_safe or (not isinstance(key, str)) or (not key.startswith("_sa"))) and ( - value is not None or not exclude_none - ): - encoded_key = jsonable_encoder( - key, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, - ) - encoded_value = jsonable_encoder( - value, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, - ) - encoded_dict[encoded_key] = encoded_value - return encoded_dict - if isinstance(obj, list | set | frozenset | GeneratorType | tuple | deque): - encoded_list = [] - for item in obj: - encoded_list.append( - jsonable_encoder( - item, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, - ) - ) - return encoded_list - - if type(obj) in ENCODERS_BY_TYPE: - return ENCODERS_BY_TYPE[type(obj)](obj) - for encoder, classes_tuple in encoders_by_class_tuples.items(): - if isinstance(obj, classes_tuple): - return encoder(obj) - - try: - data = dict(obj) # type: ignore - except Exception as e: - errors: list[Exception] = [] - errors.append(e) - try: - data = vars(obj) # type: ignore - except Exception as e: - errors.append(e) - raise ValueError(str(errors)) from e - return jsonable_encoder( - data, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, - ) diff --git a/api/dify_graph/node_events/__init__.py b/api/dify_graph/node_events/__init__.py deleted file mode 100644 index a9bef8f9a2d..00000000000 --- a/api/dify_graph/node_events/__init__.py +++ /dev/null @@ -1,46 +0,0 @@ -from .agent import AgentLogEvent -from .base import NodeEventBase, NodeRunResult -from .iteration import ( - IterationFailedEvent, - IterationNextEvent, - IterationStartedEvent, - IterationSucceededEvent, -) -from .loop import ( - LoopFailedEvent, - LoopNextEvent, - LoopStartedEvent, - LoopSucceededEvent, -) -from .node import ( - HumanInputFormFilledEvent, - HumanInputFormTimeoutEvent, - ModelInvokeCompletedEvent, - PauseRequestedEvent, - RunRetrieverResourceEvent, - RunRetryEvent, - StreamChunkEvent, - StreamCompletedEvent, -) - -__all__ = [ - "AgentLogEvent", - "HumanInputFormFilledEvent", - "HumanInputFormTimeoutEvent", - "IterationFailedEvent", - "IterationNextEvent", - "IterationStartedEvent", - "IterationSucceededEvent", - "LoopFailedEvent", - "LoopNextEvent", - "LoopStartedEvent", - "LoopSucceededEvent", - "ModelInvokeCompletedEvent", - "NodeEventBase", - "NodeRunResult", - "PauseRequestedEvent", - "RunRetrieverResourceEvent", - "RunRetryEvent", - "StreamChunkEvent", - "StreamCompletedEvent", -] diff --git a/api/dify_graph/node_events/agent.py b/api/dify_graph/node_events/agent.py deleted file mode 100644 index bf295ec7742..00000000000 --- a/api/dify_graph/node_events/agent.py +++ /dev/null @@ -1,18 +0,0 @@ -from collections.abc import Mapping -from typing import Any - -from pydantic import Field - -from .base import NodeEventBase - - -class AgentLogEvent(NodeEventBase): - message_id: str = Field(..., description="id") - label: str = Field(..., description="label") - node_execution_id: str = Field(..., description="node execution id") - parent_id: str | None = Field(..., description="parent id") - error: str | None = Field(..., description="error") - status: str = Field(..., description="status") - data: Mapping[str, Any] = Field(..., description="data") - metadata: Mapping[str, Any] = Field(default_factory=dict, description="metadata") - node_id: str = Field(..., description="node id") diff --git a/api/dify_graph/node_events/base.py b/api/dify_graph/node_events/base.py deleted file mode 100644 index 2f6259ae7d8..00000000000 --- a/api/dify_graph/node_events/base.py +++ /dev/null @@ -1,40 +0,0 @@ -from collections.abc import Mapping -from typing import Any - -from pydantic import BaseModel, Field - -from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from dify_graph.model_runtime.entities.llm_entities import LLMUsage - - -class NodeEventBase(BaseModel): - """Base class for all node events""" - - pass - - -def _default_metadata(): - v: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {} - return v - - -class NodeRunResult(BaseModel): - """ - Node Run Result. - """ - - status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.PENDING - - inputs: Mapping[str, Any] = Field(default_factory=dict) - process_data: Mapping[str, Any] = Field(default_factory=dict) - outputs: Mapping[str, Any] = Field(default_factory=dict) - metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = Field(default_factory=_default_metadata) - llm_usage: LLMUsage = Field(default_factory=LLMUsage.empty_usage) - - edge_source_handle: str = "source" # source handle id of node with multiple branches - - error: str = "" - error_type: str = "" - - # single step node run retry - retry_index: int = 0 diff --git a/api/dify_graph/node_events/iteration.py b/api/dify_graph/node_events/iteration.py deleted file mode 100644 index 744ddea628b..00000000000 --- a/api/dify_graph/node_events/iteration.py +++ /dev/null @@ -1,36 +0,0 @@ -from collections.abc import Mapping -from datetime import datetime -from typing import Any - -from pydantic import Field - -from .base import NodeEventBase - - -class IterationStartedEvent(NodeEventBase): - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - predecessor_node_id: str | None = None - - -class IterationNextEvent(NodeEventBase): - index: int = Field(..., description="index") - pre_iteration_output: Any = None - - -class IterationSucceededEvent(NodeEventBase): - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - - -class IterationFailedEvent(NodeEventBase): - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - error: str = Field(..., description="failed reason") diff --git a/api/dify_graph/node_events/loop.py b/api/dify_graph/node_events/loop.py deleted file mode 100644 index 3ae230f9f66..00000000000 --- a/api/dify_graph/node_events/loop.py +++ /dev/null @@ -1,36 +0,0 @@ -from collections.abc import Mapping -from datetime import datetime -from typing import Any - -from pydantic import Field - -from .base import NodeEventBase - - -class LoopStartedEvent(NodeEventBase): - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - predecessor_node_id: str | None = None - - -class LoopNextEvent(NodeEventBase): - index: int = Field(..., description="index") - pre_loop_output: Any = None - - -class LoopSucceededEvent(NodeEventBase): - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - - -class LoopFailedEvent(NodeEventBase): - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - error: str = Field(..., description="failed reason") diff --git a/api/dify_graph/node_events/node.py b/api/dify_graph/node_events/node.py deleted file mode 100644 index 2e3973b8fa7..00000000000 --- a/api/dify_graph/node_events/node.py +++ /dev/null @@ -1,65 +0,0 @@ -from collections.abc import Mapping, Sequence -from datetime import datetime -from typing import Any - -from pydantic import Field - -from dify_graph.entities.pause_reason import PauseReason -from dify_graph.file import File -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.node_events import NodeRunResult - -from .base import NodeEventBase - - -class RunRetrieverResourceEvent(NodeEventBase): - retriever_resources: Sequence[Mapping[str, Any]] = Field(..., description="retriever resources") - context: str = Field(..., description="context") - context_files: list[File] | None = Field(default=None, description="context files") - - -class ModelInvokeCompletedEvent(NodeEventBase): - text: str - usage: LLMUsage - finish_reason: str | None = None - reasoning_content: str | None = None - structured_output: dict | None = None - - -class RunRetryEvent(NodeEventBase): - error: str = Field(..., description="error") - retry_index: int = Field(..., description="Retry attempt number") - start_at: datetime = Field(..., description="Retry start time") - - -class StreamChunkEvent(NodeEventBase): - # Spec-compliant fields - selector: Sequence[str] = Field( - ..., description="selector identifying the output location (e.g., ['nodeA', 'text'])" - ) - chunk: str = Field(..., description="the actual chunk content") - is_final: bool = Field(default=False, description="indicates if this is the last chunk") - - -class StreamCompletedEvent(NodeEventBase): - node_run_result: NodeRunResult = Field(..., description="run result") - - -class PauseRequestedEvent(NodeEventBase): - reason: PauseReason = Field(..., description="pause reason") - - -class HumanInputFormFilledEvent(NodeEventBase): - """Event emitted when a human input form is submitted.""" - - node_title: str - rendered_content: str - action_id: str - action_text: str - - -class HumanInputFormTimeoutEvent(NodeEventBase): - """Event emitted when a human input form times out.""" - - node_title: str - expiration_time: datetime diff --git a/api/dify_graph/nodes/__init__.py b/api/dify_graph/nodes/__init__.py deleted file mode 100644 index 0223149bb84..00000000000 --- a/api/dify_graph/nodes/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from dify_graph.enums import BuiltinNodeTypes - -__all__ = ["BuiltinNodeTypes"] diff --git a/api/dify_graph/nodes/answer/answer_node.py b/api/dify_graph/nodes/answer/answer_node.py deleted file mode 100644 index 4286e1a4920..00000000000 --- a/api/dify_graph/nodes/answer/answer_node.py +++ /dev/null @@ -1,70 +0,0 @@ -from collections.abc import Mapping, Sequence -from typing import Any - -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.answer.entities import AnswerNodeData -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.base.template import Template -from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser -from dify_graph.variables import ArrayFileSegment, FileSegment, Segment - - -class AnswerNode(Node[AnswerNodeData]): - node_type = BuiltinNodeTypes.ANSWER - execution_type = NodeExecutionType.RESPONSE - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - segments = self.graph_runtime_state.variable_pool.convert_template(self.node_data.answer) - files = self._extract_files_from_segments(segments.value) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"answer": segments.markdown, "files": ArrayFileSegment(value=files)}, - ) - - def _extract_files_from_segments(self, segments: Sequence[Segment]): - """Extract all files from segments containing FileSegment or ArrayFileSegment instances. - - FileSegment contains a single file, while ArrayFileSegment contains multiple files. - This method flattens all files into a single list. - """ - files = [] - for segment in segments: - if isinstance(segment, FileSegment): - # Single file - wrap in list for consistency - files.append(segment.value) - elif isinstance(segment, ArrayFileSegment): - # Multiple files - extend the list - files.extend(segment.value) - return files - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: AnswerNodeData, - ) -> Mapping[str, Sequence[str]]: - _ = graph_config # Explicitly mark as unused - variable_template_parser = VariableTemplateParser(template=node_data.answer) - variable_selectors = variable_template_parser.extract_variable_selectors() - - variable_mapping = {} - for variable_selector in variable_selectors: - variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector - - return variable_mapping - - def get_streaming_template(self) -> Template: - """ - Get the template for streaming. - - Returns: - Template instance for this Answer node - """ - return Template.from_answer_template(self.node_data.answer) diff --git a/api/dify_graph/nodes/answer/entities.py b/api/dify_graph/nodes/answer/entities.py deleted file mode 100644 index cd82df1ac45..00000000000 --- a/api/dify_graph/nodes/answer/entities.py +++ /dev/null @@ -1,67 +0,0 @@ -from collections.abc import Sequence -from enum import StrEnum, auto - -from pydantic import BaseModel, Field - -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType - - -class AnswerNodeData(BaseNodeData): - """ - Answer Node Data. - """ - - type: NodeType = BuiltinNodeTypes.ANSWER - answer: str = Field(..., description="answer template string") - - -class GenerateRouteChunk(BaseModel): - """ - Generate Route Chunk. - """ - - class ChunkType(StrEnum): - VAR = auto() - TEXT = auto() - - type: ChunkType = Field(..., description="generate route chunk type") - - -class VarGenerateRouteChunk(GenerateRouteChunk): - """ - Var Generate Route Chunk. - """ - - type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR - """generate route chunk type""" - value_selector: Sequence[str] = Field(..., description="value selector") - - -class TextGenerateRouteChunk(GenerateRouteChunk): - """ - Text Generate Route Chunk. - """ - - type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.TEXT - """generate route chunk type""" - text: str = Field(..., description="text") - - -class AnswerNodeDoubleLink(BaseModel): - node_id: str = Field(..., description="node id") - source_node_ids: list[str] = Field(..., description="source node ids") - target_node_ids: list[str] = Field(..., description="target node ids") - - -class AnswerStreamGenerateRoute(BaseModel): - """ - AnswerStreamGenerateRoute entity - """ - - answer_dependencies: dict[str, list[str]] = Field( - ..., description="answer dependencies (answer node id -> dependent answer node ids)" - ) - answer_generate_route: dict[str, list[GenerateRouteChunk]] = Field( - ..., description="answer generate route (answer node id -> generate route chunks)" - ) diff --git a/api/dify_graph/nodes/base/__init__.py b/api/dify_graph/nodes/base/__init__.py deleted file mode 100644 index 036e25895d2..00000000000 --- a/api/dify_graph/nodes/base/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState -from .usage_tracking_mixin import LLMUsageTrackingMixin - -__all__ = [ - "BaseIterationNodeData", - "BaseIterationState", - "BaseLoopNodeData", - "BaseLoopState", - "LLMUsageTrackingMixin", -] diff --git a/api/dify_graph/nodes/base/entities.py b/api/dify_graph/nodes/base/entities.py deleted file mode 100644 index 4f8b2682e1a..00000000000 --- a/api/dify_graph/nodes/base/entities.py +++ /dev/null @@ -1,87 +0,0 @@ -from __future__ import annotations - -from collections.abc import Sequence -from enum import StrEnum -from typing import Any - -from pydantic import BaseModel, field_validator - -from dify_graph.entities.base_node_data import BaseNodeData - - -class VariableSelector(BaseModel): - """ - Variable Selector. - """ - - variable: str - value_selector: Sequence[str] - - -class OutputVariableType(StrEnum): - STRING = "string" - NUMBER = "number" - INTEGER = "integer" - SECRET = "secret" - BOOLEAN = "boolean" - OBJECT = "object" - FILE = "file" - ARRAY = "array" - ARRAY_STRING = "array[string]" - ARRAY_NUMBER = "array[number]" - ARRAY_OBJECT = "array[object]" - ARRAY_BOOLEAN = "array[boolean]" - ARRAY_FILE = "array[file]" - ANY = "any" - ARRAY_ANY = "array[any]" - - -class OutputVariableEntity(BaseModel): - """ - Output Variable Entity. - """ - - variable: str - value_type: OutputVariableType = OutputVariableType.ANY - value_selector: Sequence[str] - - @field_validator("value_type", mode="before") - @classmethod - def normalize_value_type(cls, v: Any) -> Any: - """ - Normalize value_type to handle case-insensitive array types. - Converts 'Array[...]' to 'array[...]' for backward compatibility. - """ - if isinstance(v, str) and v.startswith("Array["): - return v.lower() - return v - - -class BaseIterationNodeData(BaseNodeData): - start_node_id: str | None = None - - -class BaseIterationState(BaseModel): - iteration_node_id: str - index: int - inputs: dict - - class MetaData(BaseModel): - pass - - metadata: MetaData - - -class BaseLoopNodeData(BaseNodeData): - start_node_id: str | None = None - - -class BaseLoopState(BaseModel): - loop_node_id: str - index: int - inputs: dict - - class MetaData(BaseModel): - pass - - metadata: MetaData diff --git a/api/dify_graph/nodes/base/node.py b/api/dify_graph/nodes/base/node.py deleted file mode 100644 index 56b46a58941..00000000000 --- a/api/dify_graph/nodes/base/node.py +++ /dev/null @@ -1,808 +0,0 @@ -from __future__ import annotations - -import logging -import operator -from abc import abstractmethod -from collections.abc import Generator, Mapping, Sequence -from functools import singledispatchmethod -from types import MappingProxyType -from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, get_args, get_origin -from uuid import uuid4 - -from dify_graph.entities import GraphInitParams -from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import ( - ErrorStrategy, - NodeExecutionType, - NodeState, - NodeType, - WorkflowNodeExecutionStatus, -) -from dify_graph.graph_events import ( - GraphNodeEventBase, - NodeRunAgentLogEvent, - NodeRunFailedEvent, - NodeRunHumanInputFormFilledEvent, - NodeRunHumanInputFormTimeoutEvent, - NodeRunIterationFailedEvent, - NodeRunIterationNextEvent, - NodeRunIterationStartedEvent, - NodeRunIterationSucceededEvent, - NodeRunLoopFailedEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, - NodeRunPauseRequestedEvent, - NodeRunRetrieverResourceEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) -from dify_graph.node_events import ( - AgentLogEvent, - HumanInputFormFilledEvent, - HumanInputFormTimeoutEvent, - IterationFailedEvent, - IterationNextEvent, - IterationStartedEvent, - IterationSucceededEvent, - LoopFailedEvent, - LoopNextEvent, - LoopStartedEvent, - LoopSucceededEvent, - NodeEventBase, - NodeRunResult, - PauseRequestedEvent, - RunRetrieverResourceEvent, - StreamChunkEvent, - StreamCompletedEvent, -) -from dify_graph.runtime import GraphRuntimeState -from libs.datetime_utils import naive_utc_now - -NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData) -_MISSING_RUN_CONTEXT_VALUE = object() - -logger = logging.getLogger(__name__) - - -class DifyRunContextProtocol(Protocol): - tenant_id: str - app_id: str - user_id: str - user_from: Any - invoke_from: Any - - -class _MappingDifyRunContext: - def __init__(self, mapping: Mapping[str, Any]) -> None: - self.tenant_id = str(mapping["tenant_id"]) - self.app_id = str(mapping["app_id"]) - self.user_id = str(mapping["user_id"]) - self.user_from = mapping["user_from"] - self.invoke_from = mapping["invoke_from"] - - -class Node(Generic[NodeDataT]): - """BaseNode serves as the foundational class for all node implementations. - - Nodes are allowed to maintain transient states (e.g., `LLMNode` uses the `_file_output` - attribute to track files generated by the LLM). However, these states are not persisted - when the workflow is suspended or resumed. If a node needs its state to be preserved - across workflow suspension and resumption, it should include the relevant state data - in its output. - """ - - node_type: ClassVar[NodeType] - execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE - _node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData - - def __init_subclass__(cls, **kwargs: Any) -> None: - """ - Automatically extract and validate the node data type from the generic parameter. - - When a subclass is defined as `class MyNode(Node[MyNodeData])`, this method: - 1. Inspects `__orig_bases__` to find the `Node[T]` parameterization - 2. Extracts `T` (e.g., `MyNodeData`) from the generic argument - 3. Validates that `T` is a proper `BaseNodeData` subclass - 4. Stores it in `_node_data_type` for automatic hydration in `__init__` - - This eliminates the need for subclasses to manually implement boilerplate - accessor methods like `_get_title()`, `_get_error_strategy()`, etc. - - How it works: - :: - - class CodeNode(Node[CodeNodeData]): - │ │ - │ └─────────────────────────────────┐ - │ │ - ▼ ▼ - ┌─────────────────────────────┐ ┌─────────────────────────────────┐ - │ __orig_bases__ = ( │ │ CodeNodeData(BaseNodeData) │ - │ Node[CodeNodeData], │ │ title: str │ - │ ) │ │ desc: str | None │ - └──────────────┬──────────────┘ │ ... │ - │ └─────────────────────────────────┘ - ▼ ▲ - ┌─────────────────────────────┐ │ - │ get_origin(base) -> Node │ │ - │ get_args(base) -> ( │ │ - │ CodeNodeData, │ ──────────────────────┘ - │ ) │ - └──────────────┬──────────────┘ - │ - ▼ - ┌─────────────────────────────┐ - │ Validate: │ - │ - Is it a type? │ - │ - Is it a BaseNodeData │ - │ subclass? │ - └──────────────┬──────────────┘ - │ - ▼ - ┌─────────────────────────────┐ - │ cls._node_data_type = │ - │ CodeNodeData │ - └─────────────────────────────┘ - - Later, in __init__: - :: - - config["data"] ──► _node_data_type.model_validate(..., from_attributes=True) - │ - ▼ - CodeNodeData instance - (stored in self._node_data) - - Example: - class CodeNode(Node[CodeNodeData]): # CodeNodeData is auto-extracted - node_type = BuiltinNodeTypes.CODE - # No need to implement _get_title, _get_error_strategy, etc. - """ - super().__init_subclass__(**kwargs) - - if cls is Node: - return - - node_data_type = cls._extract_node_data_type_from_generic() - - if node_data_type is None: - raise TypeError(f"{cls.__name__} must inherit from Node[T] with a BaseNodeData subtype") - - cls._node_data_type = node_data_type - - # Skip base class itself - if cls is Node: - return - # Only register production node implementations defined under the - # canonical workflow namespaces. - # This prevents test helper subclasses from polluting the global registry and - # accidentally overriding real node types (e.g., a test Answer node). - module_name = getattr(cls, "__module__", "") - # Only register concrete subclasses that define node_type and version() - node_type = cls.node_type - version = cls.version() - bucket = Node._registry.setdefault(node_type, {}) - if module_name.startswith(("dify_graph.nodes.", "core.workflow.nodes.")): - # Production node definitions take precedence and may override - bucket[version] = cls # type: ignore[index] - else: - # External/test subclasses may register but must not override production - bucket.setdefault(version, cls) # type: ignore[index] - # Maintain a "latest" pointer preferring numeric versions; fallback to lexicographic - version_keys = [v for v in bucket if v != "latest"] - numeric_pairs: list[tuple[str, int]] = [] - for v in version_keys: - numeric_pairs.append((v, int(v))) - if numeric_pairs: - latest_key = max(numeric_pairs, key=operator.itemgetter(1))[0] - else: - latest_key = max(version_keys) if version_keys else version - bucket["latest"] = bucket[latest_key] - Node._registry_version += 1 - - @classmethod - def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None: - """ - Extract the node data type from the generic parameter `Node[T]`. - - Inspects `__orig_bases__` to find the `Node[T]` parameterization and extracts `T`. - - Returns: - The extracted BaseNodeData subtype, or None if not found. - - Raises: - TypeError: If the generic argument is invalid (not exactly one argument, - or not a BaseNodeData subtype). - """ - # __orig_bases__ contains the original generic bases before type erasure. - # For `class CodeNode(Node[CodeNodeData])`, this would be `(Node[CodeNodeData],)`. - for base in getattr(cls, "__orig_bases__", ()): # type: ignore[attr-defined] - origin = get_origin(base) # Returns `Node` for `Node[CodeNodeData]` - if origin is Node: - args = get_args(base) # Returns `(CodeNodeData,)` for `Node[CodeNodeData]` - if len(args) != 1: - raise TypeError(f"{cls.__name__} must specify exactly one node data generic argument") - - candidate = args[0] - if not isinstance(candidate, type) or not issubclass(candidate, BaseNodeData): - raise TypeError(f"{cls.__name__} must parameterize Node with a BaseNodeData subtype") - - return candidate - - return None - - # Global registry populated via __init_subclass__ - _registry: ClassVar[dict[NodeType, dict[str, type[Node]]]] = {} - _registry_version: ClassVar[int] = 0 - - @classmethod - def get_registry_version(cls) -> int: - return cls._registry_version - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - ) -> None: - self._graph_init_params = graph_init_params - self._run_context = MappingProxyType(dict(graph_init_params.run_context)) - self.id = id - self.workflow_id = graph_init_params.workflow_id - self.graph_config = graph_init_params.graph_config - self.workflow_call_depth = graph_init_params.call_depth - self.graph_runtime_state = graph_runtime_state - self.state: NodeState = NodeState.UNKNOWN # node execution state - - node_id = config["id"] - - self._node_id = node_id - self._node_execution_id: str = "" - self._start_at = naive_utc_now() - - self._node_data = self.validate_node_data(config["data"]) - - self.post_init() - - @classmethod - def validate_node_data(cls, node_data: BaseNodeData) -> NodeDataT: - """Validate shared graph node payloads against the subclass-declared NodeData model.""" - return cast(NodeDataT, cls._node_data_type.model_validate(node_data, from_attributes=True)) - - def init_node_data(self, data: BaseNodeData | Mapping[str, Any]) -> None: - """Hydrate `_node_data` for legacy callers that bypass `__init__`.""" - self._node_data = self.validate_node_data(cast(BaseNodeData, data)) - - def post_init(self) -> None: - """Optional hook for subclasses requiring extra initialization.""" - return - - @property - def graph_init_params(self) -> GraphInitParams: - return self._graph_init_params - - @property - def run_context(self) -> Mapping[str, Any]: - return self._run_context - - def get_run_context_value(self, key: str, default: Any = None) -> Any: - return self._run_context.get(key, default) - - def require_run_context_value(self, key: str) -> Any: - value = self.get_run_context_value(key, _MISSING_RUN_CONTEXT_VALUE) - if value is _MISSING_RUN_CONTEXT_VALUE: - raise ValueError(f"run_context missing required key: {key}") - return value - - def require_dify_context(self) -> DifyRunContextProtocol: - raw_ctx = self.require_run_context_value(DIFY_RUN_CONTEXT_KEY) - if raw_ctx is None: - raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}") - - if isinstance(raw_ctx, Mapping): - missing_keys = [ - key for key in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from") if key not in raw_ctx - ] - if missing_keys: - raise ValueError(f"dify context missing required keys: {', '.join(missing_keys)}") - return _MappingDifyRunContext(raw_ctx) - - for attr in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from"): - if not hasattr(raw_ctx, attr): - raise TypeError(f"invalid dify context object, missing attribute: {attr}") - - return cast(DifyRunContextProtocol, raw_ctx) - - @property - def execution_id(self) -> str: - return self._node_execution_id - - def ensure_execution_id(self) -> str: - if self._node_execution_id: - return self._node_execution_id - - resumed_execution_id = self._restore_execution_id_from_runtime_state() - if resumed_execution_id: - self._node_execution_id = resumed_execution_id - return self._node_execution_id - - self._node_execution_id = str(uuid4()) - return self._node_execution_id - - def _restore_execution_id_from_runtime_state(self) -> str | None: - graph_execution = self.graph_runtime_state.graph_execution - try: - node_executions = graph_execution.node_executions - except AttributeError: - return None - if not isinstance(node_executions, dict): - return None - node_execution = node_executions.get(self._node_id) - if node_execution is None: - return None - execution_id = node_execution.execution_id - if not execution_id: - return None - return str(execution_id) - - @abstractmethod - def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]: - """ - Run node - :return: - """ - raise NotImplementedError - - def populate_start_event(self, event: NodeRunStartedEvent) -> None: - """Allow subclasses to enrich the started event without cross-node imports in the base class.""" - _ = event - - def run(self) -> Generator[GraphNodeEventBase, None, None]: - execution_id = self.ensure_execution_id() - self._start_at = naive_utc_now() - - # Create and push start event with required fields - start_event = NodeRunStartedEvent( - id=execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.title, - in_iteration_id=None, - start_at=self._start_at, - ) - try: - self.populate_start_event(start_event) - except Exception: - logger.warning("Failed to populate start event for node %s", self._node_id, exc_info=True) - yield start_event - - try: - result = self._run() - - # Handle NodeRunResult - if isinstance(result, NodeRunResult): - yield self._convert_node_run_result_to_graph_node_event(result) - return - - # Handle event stream - for event in result: - # NOTE: this is necessary because iteration and loop nodes yield GraphNodeEventBase - if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance] - yield self._dispatch(event) - elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance] - event.id = self.execution_id - yield event - else: - yield event - except Exception as e: - logger.exception("Node %s failed to run", self._node_id) - result = NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - error_type="WorkflowNodeError", - ) - finished_at = naive_utc_now() - yield NodeRunFailedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - start_at=self._start_at, - finished_at=finished_at, - node_run_result=result, - error=str(e), - ) - - @classmethod - def extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - config: NodeConfigDict, - ) -> Mapping[str, Sequence[str]]: - """Extracts references variable selectors from node configuration. - - The `config` parameter represents the configuration for a specific node type and corresponds - to the `data` field in the node definition object. - - The returned mapping has the following structure: - - {'1747829548239.#1747829667553.result#': ['1747829667553', 'result']} - - For loop and iteration nodes, the mapping may look like this: - - { - "1748332301644.input_selector": ["1748332363630", "result"], - "1748332325079.1748332325079.#sys.workflow_id#": ["sys", "workflow_id"], - } - - where `1748332301644` is the ID of the loop / iteration node, - and `1748332325079` is the ID of the node inside the loop or iteration node. - - Here, the key consists of two parts: the current node ID (provided as the `node_id` - parameter to `_extract_variable_selector_to_variable_mapping`) and the variable selector, - enclosed in `#` symbols. These two parts are separated by a dot (`.`). - - The value is a list of string representing the variable selector, where the first element is the node ID - of the referenced variable, and the second element is the variable name within that node. - - The meaning of the above response is: - - The node with ID `1747829548239` references the variable `result` from the node with - ID `1747829667553`. For example, if `1747829548239` is a LLM node, its prompt may contain a - reference to the `result` output variable of node `1747829667553`. - - :param graph_config: graph config - :param config: node config - :return: - """ - node_id = config["id"] - node_data = cls.validate_node_data(config["data"]) - data = cls._extract_variable_selector_to_variable_mapping( - graph_config=graph_config, - node_id=node_id, - node_data=node_data, - ) - return data - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: NodeDataT, - ) -> Mapping[str, Sequence[str]]: - return {} - - def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: - """ - Check if this node blocks the output of specific variables. - - This method is used to determine if a node must complete execution before - the specified variables can be used in streaming output. - - :param variable_selectors: Set of variable selectors, each as a tuple (e.g., ('conversation', 'str')) - :return: True if this node blocks output of any of the specified variables, False otherwise - """ - return False - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - return {} - - @classmethod - @abstractmethod - def version(cls) -> str: - """`node_version` returns the version of current node type.""" - # NOTE(QuantumGhost): Node versions must remain unique per `NodeType` so - # registry lookups can resolve numeric versions and `latest`. - raise NotImplementedError("subclasses of BaseNode must implement `version` method.") - - @classmethod - def get_node_type_classes_mapping(cls) -> Mapping[NodeType, Mapping[str, type[Node]]]: - """Return a read-only view of the currently registered node classes. - - This accessor intentionally performs no imports. The embedding layer that - owns bootstrap (for example `core.workflow.node_factory`) must import any - extension node packages before calling it so their subclasses register via - `__init_subclass__`. - """ - return {node_type: MappingProxyType(version_map) for node_type, version_map in cls._registry.items()} - - @property - def retry(self) -> bool: - return False - - def _get_error_strategy(self) -> ErrorStrategy | None: - """Get the error strategy for this node.""" - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - """Get the retry configuration for this node.""" - return self._node_data.retry_config - - def _get_title(self) -> str: - """Get the node title.""" - return self._node_data.title - - def _get_description(self) -> str | None: - """Get the node description.""" - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - """Get the default values dictionary for this node.""" - return self._node_data.default_value_dict - - # Public interface properties that delegate to abstract methods - @property - def error_strategy(self) -> ErrorStrategy | None: - """Get the error strategy for this node.""" - return self._get_error_strategy() - - @property - def retry_config(self) -> RetryConfig: - """Get the retry configuration for this node.""" - return self._get_retry_config() - - @property - def title(self) -> str: - """Get the node title.""" - return self._get_title() - - @property - def description(self) -> str | None: - """Get the node description.""" - return self._get_description() - - @property - def default_value_dict(self) -> dict[str, Any]: - """Get the default values dictionary for this node.""" - return self._get_default_value_dict() - - @property - def node_data(self) -> NodeDataT: - """Typed access to this node's configuration data.""" - return self._node_data - - def _convert_node_run_result_to_graph_node_event(self, result: NodeRunResult) -> GraphNodeEventBase: - finished_at = naive_utc_now() - match result.status: - case WorkflowNodeExecutionStatus.FAILED: - return NodeRunFailedEvent( - id=self.execution_id, - node_id=self.id, - node_type=self.node_type, - start_at=self._start_at, - finished_at=finished_at, - node_run_result=result, - error=result.error, - ) - case WorkflowNodeExecutionStatus.SUCCEEDED: - return NodeRunSucceededEvent( - id=self.execution_id, - node_id=self.id, - node_type=self.node_type, - start_at=self._start_at, - finished_at=finished_at, - node_run_result=result, - ) - case _: - raise Exception(f"result status {result.status} not supported") - - @singledispatchmethod - def _dispatch(self, event: NodeEventBase) -> GraphNodeEventBase: - raise NotImplementedError(f"Node {self._node_id} does not support event type {type(event)}") - - @_dispatch.register - def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent: - return NodeRunStreamChunkEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - selector=event.selector, - chunk=event.chunk, - is_final=event.is_final, - ) - - @_dispatch.register - def _(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent: - finished_at = naive_utc_now() - match event.node_run_result.status: - case WorkflowNodeExecutionStatus.SUCCEEDED: - return NodeRunSucceededEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - start_at=self._start_at, - finished_at=finished_at, - node_run_result=event.node_run_result, - ) - case WorkflowNodeExecutionStatus.FAILED: - return NodeRunFailedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - start_at=self._start_at, - finished_at=finished_at, - node_run_result=event.node_run_result, - error=event.node_run_result.error, - ) - case _: - raise NotImplementedError( - f"Node {self._node_id} does not support status {event.node_run_result.status}" - ) - - @_dispatch.register - def _(self, event: PauseRequestedEvent) -> NodeRunPauseRequestedEvent: - return NodeRunPauseRequestedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.PAUSED), - reason=event.reason, - ) - - @_dispatch.register - def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent: - return NodeRunAgentLogEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - message_id=event.message_id, - label=event.label, - node_execution_id=event.node_execution_id, - parent_id=event.parent_id, - error=event.error, - status=event.status, - data=event.data, - metadata=event.metadata, - ) - - @_dispatch.register - def _(self, event: HumanInputFormFilledEvent): - return NodeRunHumanInputFormFilledEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=event.node_title, - rendered_content=event.rendered_content, - action_id=event.action_id, - action_text=event.action_text, - ) - - @_dispatch.register - def _(self, event: HumanInputFormTimeoutEvent): - return NodeRunHumanInputFormTimeoutEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=event.node_title, - expiration_time=event.expiration_time, - ) - - @_dispatch.register - def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent: - return NodeRunLoopStartedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - start_at=event.start_at, - inputs=event.inputs, - metadata=event.metadata, - predecessor_node_id=event.predecessor_node_id, - ) - - @_dispatch.register - def _(self, event: LoopNextEvent) -> NodeRunLoopNextEvent: - return NodeRunLoopNextEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - index=event.index, - pre_loop_output=event.pre_loop_output, - ) - - @_dispatch.register - def _(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent: - return NodeRunLoopSucceededEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - start_at=event.start_at, - inputs=event.inputs, - outputs=event.outputs, - metadata=event.metadata, - steps=event.steps, - ) - - @_dispatch.register - def _(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent: - return NodeRunLoopFailedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - start_at=event.start_at, - inputs=event.inputs, - outputs=event.outputs, - metadata=event.metadata, - steps=event.steps, - error=event.error, - ) - - @_dispatch.register - def _(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent: - return NodeRunIterationStartedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - start_at=event.start_at, - inputs=event.inputs, - metadata=event.metadata, - predecessor_node_id=event.predecessor_node_id, - ) - - @_dispatch.register - def _(self, event: IterationNextEvent) -> NodeRunIterationNextEvent: - return NodeRunIterationNextEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - index=event.index, - pre_iteration_output=event.pre_iteration_output, - ) - - @_dispatch.register - def _(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent: - return NodeRunIterationSucceededEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - start_at=event.start_at, - inputs=event.inputs, - outputs=event.outputs, - metadata=event.metadata, - steps=event.steps, - ) - - @_dispatch.register - def _(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent: - return NodeRunIterationFailedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - start_at=event.start_at, - inputs=event.inputs, - outputs=event.outputs, - metadata=event.metadata, - steps=event.steps, - error=event.error, - ) - - @_dispatch.register - def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent: - from core.rag.entities.citation_metadata import RetrievalSourceMetadata - - retriever_resources = [ - RetrievalSourceMetadata.model_validate(resource) for resource in event.retriever_resources - ] - return NodeRunRetrieverResourceEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - retriever_resources=retriever_resources, - context=event.context, - node_version=self.version(), - ) diff --git a/api/dify_graph/nodes/base/template.py b/api/dify_graph/nodes/base/template.py deleted file mode 100644 index 5976e808e3e..00000000000 --- a/api/dify_graph/nodes/base/template.py +++ /dev/null @@ -1,150 +0,0 @@ -"""Template structures for Response nodes (Answer and End). - -This module provides a unified template structure for both Answer and End nodes, -similar to SegmentGroup but focused on template representation without values. -""" - -from __future__ import annotations - -from abc import ABC, abstractmethod -from collections.abc import Sequence -from dataclasses import dataclass -from typing import Any, Union - -from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser - - -@dataclass(frozen=True) -class TemplateSegment(ABC): - """Base class for template segments.""" - - @abstractmethod - def __str__(self) -> str: - """String representation of the segment.""" - pass - - -@dataclass(frozen=True) -class TextSegment(TemplateSegment): - """A text segment in a template.""" - - text: str - - def __str__(self) -> str: - return self.text - - -@dataclass(frozen=True) -class VariableSegment(TemplateSegment): - """A variable reference segment in a template.""" - - selector: Sequence[str] - variable_name: str | None = None # Optional variable name for End nodes - - def __str__(self) -> str: - return "{{#" + ".".join(self.selector) + "#}}" - - -# Type alias for segments -TemplateSegmentUnion = Union[TextSegment, VariableSegment] - - -@dataclass(frozen=True) -class Template: - """Unified template structure for Response nodes. - - Similar to SegmentGroup, but represents the template structure - without variable values - only marking variable selectors. - """ - - segments: list[TemplateSegmentUnion] - - @classmethod - def from_answer_template(cls, template_str: str) -> Template: - """Create a Template from an Answer node template string. - - Example: - "Hello, {{#node1.name#}}" -> [TextSegment("Hello, "), VariableSegment(["node1", "name"])] - - Args: - template_str: The answer template string - - Returns: - Template instance - """ - parser = VariableTemplateParser(template_str) - segments: list[TemplateSegmentUnion] = [] - - # Extract variable selectors to find all variables - variable_selectors = parser.extract_variable_selectors() - var_map = {var.variable: var.value_selector for var in variable_selectors} - - # Parse template to get ordered segments - # We need to split the template by variable placeholders while preserving order - import re - - # Create a regex pattern that matches variable placeholders - pattern = r"\{\{(#[a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}" - - # Split template while keeping the delimiters (variable placeholders) - parts = re.split(pattern, template_str) - - for i, part in enumerate(parts): - if not part: - continue - - # Check if this part is a variable reference (odd indices after split) - if i % 2 == 1: # Odd indices are variable keys - # Remove the # symbols from the variable key - var_key = part - if var_key in var_map: - segments.append(VariableSegment(selector=list(var_map[var_key]))) - else: - # This shouldn't happen with valid templates - segments.append(TextSegment(text="{{" + part + "}}")) - else: - # Even indices are text segments - segments.append(TextSegment(text=part)) - - return cls(segments=segments) - - @classmethod - def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> Template: - """Create a Template from an End node outputs configuration. - - End nodes are treated as templates of concatenated variables with newlines. - - Example: - [{"variable": "text", "value_selector": ["node1", "text"]}, - {"variable": "result", "value_selector": ["node2", "result"]}] - -> - [VariableSegment(["node1", "text"]), - TextSegment("\n"), - VariableSegment(["node2", "result"])] - - Args: - outputs_config: List of output configurations with variable and value_selector - - Returns: - Template instance - """ - segments: list[TemplateSegmentUnion] = [] - - for i, output in enumerate(outputs_config): - if i > 0: - # Add newline separator between variables - segments.append(TextSegment(text="\n")) - - value_selector = output.get("value_selector", []) - variable_name = output.get("variable", "") - if value_selector: - segments.append(VariableSegment(selector=list(value_selector), variable_name=variable_name)) - - if len(segments) > 0 and isinstance(segments[-1], TextSegment): - segments = segments[:-1] - - return cls(segments=segments) - - def __str__(self) -> str: - """String representation of the template.""" - return "".join(str(segment) for segment in self.segments) diff --git a/api/dify_graph/nodes/base/usage_tracking_mixin.py b/api/dify_graph/nodes/base/usage_tracking_mixin.py deleted file mode 100644 index bd49419fd34..00000000000 --- a/api/dify_graph/nodes/base/usage_tracking_mixin.py +++ /dev/null @@ -1,28 +0,0 @@ -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.runtime import GraphRuntimeState - - -class LLMUsageTrackingMixin: - """Provides shared helpers for merging and recording LLM usage within workflow nodes.""" - - graph_runtime_state: GraphRuntimeState - - @staticmethod - def _merge_usage(current: LLMUsage, new_usage: LLMUsage | None) -> LLMUsage: - """Return a combined usage snapshot, preserving zero-value inputs.""" - if new_usage is None or new_usage.total_tokens <= 0: - return current - if current.total_tokens == 0: - return new_usage - return current.plus(new_usage) - - def _accumulate_usage(self, usage: LLMUsage) -> None: - """Push usage into the graph runtime accumulator for downstream reporting.""" - if usage.total_tokens <= 0: - return - - current_usage = self.graph_runtime_state.llm_usage - if current_usage.total_tokens == 0: - self.graph_runtime_state.llm_usage = usage.model_copy() - else: - self.graph_runtime_state.llm_usage = current_usage.plus(usage) diff --git a/api/dify_graph/nodes/base/variable_template_parser.py b/api/dify_graph/nodes/base/variable_template_parser.py deleted file mode 100644 index de5e619e8c4..00000000000 --- a/api/dify_graph/nodes/base/variable_template_parser.py +++ /dev/null @@ -1,130 +0,0 @@ -import re -from collections.abc import Mapping, Sequence -from typing import Any - -from .entities import VariableSelector - -REGEX = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}") - -SELECTOR_PATTERN = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}") - - -def extract_selectors_from_template(template: str, /) -> Sequence[VariableSelector]: - parts = SELECTOR_PATTERN.split(template) - selectors = [] - for part in filter(lambda x: x, parts): - if "." in part and part[0] == "#" and part[-1] == "#": - selectors.append(VariableSelector(variable=f"{part}", value_selector=part[1:-1].split("."))) - return selectors - - -class VariableTemplateParser: - """ - !NOTE: Consider to use the new `segments` module instead of this class. - - A class for parsing and manipulating template variables in a string. - - Rules: - - 1. Template variables must be enclosed in `{{}}`. - 2. The template variable Key can only be: #node_id.var1.var2#. - 3. The template variable Key cannot contain new lines or spaces, and must comply with rule 2. - - Example usage: - - template = "Hello, {{#node_id.query.name#}}! Your age is {{#node_id.query.age#}}." - parser = VariableTemplateParser(template) - - # Extract template variable keys - variable_keys = parser.extract() - print(variable_keys) - # Output: ['#node_id.query.name#', '#node_id.query.age#'] - - # Extract variable selectors - variable_selectors = parser.extract_variable_selectors() - print(variable_selectors) - # Output: [VariableSelector(variable='#node_id.query.name#', value_selector=['node_id', 'query', 'name']), - # VariableSelector(variable='#node_id.query.age#', value_selector=['node_id', 'query', 'age'])] - - # Format the template string - inputs = {'#node_id.query.name#': 'John', '#node_id.query.age#': 25}} - formatted_string = parser.format(inputs) - print(formatted_string) - # Output: "Hello, John! Your age is 25." - """ - - def __init__(self, template: str): - self.template = template - self.variable_keys = self.extract() - - def extract(self): - """ - Extracts all the template variable keys from the template string. - - Returns: - A list of template variable keys. - """ - # Regular expression to match the template rules - matches = re.findall(REGEX, self.template) - - first_group_matches = [match[0] for match in matches] - - return list(set(first_group_matches)) - - def extract_variable_selectors(self) -> list[VariableSelector]: - """ - Extracts the variable selectors from the template variable keys. - - Returns: - A list of VariableSelector objects representing the variable selectors. - """ - variable_selectors = [] - for variable_key in self.variable_keys: - remove_hash = variable_key.replace("#", "") - split_result = remove_hash.split(".") - if len(split_result) < 2: - continue - - variable_selectors.append(VariableSelector(variable=variable_key, value_selector=split_result)) - - return variable_selectors - - def format(self, inputs: Mapping[str, Any]) -> str: - """ - Formats the template string by replacing the template variables with their corresponding values. - - Args: - inputs: A dictionary containing the values for the template variables. - - Returns: - The formatted string with template variables replaced by their values. - """ - - def replacer(match): - key = match.group(1) - value = inputs.get(key, match.group(0)) # return original matched string if key not found - - if value is None: - value = "" - # convert the value to string - if isinstance(value, list | dict | bool | int | float): - value = str(value) - - # remove template variables if required - return VariableTemplateParser.remove_template_variables(value) - - prompt = re.sub(REGEX, replacer, self.template) - return re.sub(r"<\|.*?\|>", "", prompt) - - @classmethod - def remove_template_variables(cls, text: str): - """ - Removes the template variables from the given text. - - Args: - text: The text from which to remove the template variables. - - Returns: - The text with template variables removed. - """ - return re.sub(REGEX, r"{\1}", text) diff --git a/api/dify_graph/nodes/code/__init__.py b/api/dify_graph/nodes/code/__init__.py deleted file mode 100644 index 8c6dcc7fccb..00000000000 --- a/api/dify_graph/nodes/code/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .code_node import CodeNode - -__all__ = ["CodeNode"] diff --git a/api/dify_graph/nodes/code/code_node.py b/api/dify_graph/nodes/code/code_node.py deleted file mode 100644 index 82d5fced620..00000000000 --- a/api/dify_graph/nodes/code/code_node.py +++ /dev/null @@ -1,493 +0,0 @@ -from collections.abc import Mapping, Sequence -from decimal import Decimal -from textwrap import dedent -from typing import TYPE_CHECKING, Any, Protocol, cast - -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.code.entities import CodeLanguage, CodeNodeData -from dify_graph.nodes.code.limits import CodeNodeLimits -from dify_graph.variables.segments import ArrayFileSegment -from dify_graph.variables.types import SegmentType - -from .exc import ( - CodeNodeError, - DepthLimitError, - OutputValidationError, -) - -if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState - - -class WorkflowCodeExecutor(Protocol): - def execute( - self, - *, - language: CodeLanguage, - code: str, - inputs: Mapping[str, Any], - ) -> Mapping[str, Any]: ... - - def is_execution_error(self, error: Exception) -> bool: ... - - -def _build_default_config(*, language: CodeLanguage, code: str) -> Mapping[str, object]: - return { - "type": "code", - "config": { - "variables": [ - {"variable": "arg1", "value_selector": []}, - {"variable": "arg2", "value_selector": []}, - ], - "code_language": language, - "code": code, - "outputs": {"result": {"type": "string", "children": None}}, - }, - } - - -_DEFAULT_CODE_BY_LANGUAGE: Mapping[CodeLanguage, str] = { - CodeLanguage.PYTHON3: dedent( - """ - def main(arg1: str, arg2: str): - return { - "result": arg1 + arg2, - } - """ - ), - CodeLanguage.JAVASCRIPT: dedent( - """ - function main({arg1, arg2}) { - return { - result: arg1 + arg2 - } - } - """ - ), -} - - -class CodeNode(Node[CodeNodeData]): - node_type = BuiltinNodeTypes.CODE - _limits: CodeNodeLimits - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - code_executor: WorkflowCodeExecutor, - code_limits: CodeNodeLimits, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - self._code_executor: WorkflowCodeExecutor = code_executor - self._limits = code_limits - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - """ - Get default config of node. - :param filters: filter by node config parameters. - :return: - """ - code_language = CodeLanguage.PYTHON3 - if filters: - code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3)) - - default_code = _DEFAULT_CODE_BY_LANGUAGE.get(code_language) - if default_code is None: - raise CodeNodeError(f"Unsupported code language: {code_language}") - return _build_default_config(language=code_language, code=default_code) - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - # Get code language - code_language = self.node_data.code_language - code = self.node_data.code - - # Get variables - variables = {} - for variable_selector in self.node_data.variables: - variable_name = variable_selector.variable - variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) - if isinstance(variable, ArrayFileSegment): - variables[variable_name] = [v.to_dict() for v in variable.value] if variable.value else None - else: - variables[variable_name] = variable.to_object() if variable else None - # Run code - try: - result = self._code_executor.execute( - language=code_language, - code=code, - inputs=variables, - ) - - # Transform result - result = self._transform_result(result=result, output_schema=self.node_data.outputs) - except CodeNodeError as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__ - ) - except Exception as e: - if not self._code_executor.is_execution_error(e): - raise - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__ - ) - - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result) - - def _check_string(self, value: str | None, variable: str) -> str | None: - """ - Check string - :param value: value - :param variable: variable - :return: - """ - if value is None: - return None - - if len(value) > self._limits.max_string_length: - raise OutputValidationError( - f"The length of output variable `{variable}` must be" - f" less than {self._limits.max_string_length} characters" - ) - - return value.replace("\x00", "") - - def _check_boolean(self, value: bool | None, variable: str) -> bool | None: - if value is None: - return None - - return value - - def _check_number(self, value: int | float | None, variable: str) -> int | float | None: - """ - Check number - :param value: value - :param variable: variable - :return: - """ - if value is None: - return None - - if value > self._limits.max_number or value < self._limits.min_number: - raise OutputValidationError( - f"Output variable `{variable}` is out of range," - f" it must be between {self._limits.min_number} and {self._limits.max_number}." - ) - - if isinstance(value, float): - decimal_value = Decimal(str(value)).normalize() - precision = -decimal_value.as_tuple().exponent if decimal_value.as_tuple().exponent < 0 else 0 # type: ignore[operator] - # raise error if precision is too high - if precision > self._limits.max_precision: - raise OutputValidationError( - f"Output variable `{variable}` has too high precision," - f" it must be less than {self._limits.max_precision} digits." - ) - - return value - - def _transform_result( - self, - result: Mapping[str, Any], - output_schema: dict[str, CodeNodeData.Output] | None, - prefix: str = "", - depth: int = 1, - ): - # TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes. - # Note that `_transform_result` may produce lists containing `None` values, - # which don't conform to the type requirements of `Array*Segment` classes. - if depth > self._limits.max_depth: - raise DepthLimitError(f"Depth limit {self._limits.max_depth} reached, object too deep.") - - transformed_result: dict[str, Any] = {} - if output_schema is None: - # validate output thought instance type - for output_name, output_value in result.items(): - if isinstance(output_value, dict): - self._transform_result( - result=output_value, - output_schema=None, - prefix=f"{prefix}.{output_name}" if prefix else output_name, - depth=depth + 1, - ) - elif isinstance(output_value, bool): - self._check_boolean(output_value, variable=f"{prefix}.{output_name}" if prefix else output_name) - elif isinstance(output_value, int | float): - self._check_number( - value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name - ) - elif isinstance(output_value, str): - self._check_string( - value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name - ) - elif isinstance(output_value, list): - first_element = output_value[0] if len(output_value) > 0 else None - if first_element is not None: - if isinstance(first_element, int | float) and all( - value is None or isinstance(value, int | float) for value in output_value - ): - for i, value in enumerate(output_value): - self._check_number( - value=value, - variable=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]", - ) - elif isinstance(first_element, str) and all( - value is None or isinstance(value, str) for value in output_value - ): - for i, value in enumerate(output_value): - self._check_string( - value=value, - variable=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]", - ) - elif ( - isinstance(first_element, dict) - and all(value is None or isinstance(value, dict) for value in output_value) - or isinstance(first_element, list) - and all(value is None or isinstance(value, list) for value in output_value) - ): - for i, value in enumerate(output_value): - if value is not None: - self._transform_result( - result=value, - output_schema=None, - prefix=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]", - depth=depth + 1, - ) - else: - raise OutputValidationError( - f"Output {prefix}.{output_name} is not a valid array." - f" make sure all elements are of the same type." - ) - elif output_value is None: - pass - else: - raise OutputValidationError(f"Output {prefix}.{output_name} is not a valid type.") - - return result - - parameters_validated = {} - for output_name, output_config in output_schema.items(): - dot = "." if prefix else "" - if output_name not in result: - raise OutputValidationError(f"Output {prefix}{dot}{output_name} is missing.") - - if output_config.type == SegmentType.OBJECT: - # check if output is object - if not isinstance(result.get(output_name), dict): - if result[output_name] is None: - transformed_result[output_name] = None - else: - raise OutputValidationError( - f"Output {prefix}{dot}{output_name} is not an object," - f" got {type(result.get(output_name))} instead." - ) - else: - transformed_result[output_name] = self._transform_result( - result=result[output_name], - output_schema=output_config.children, - prefix=f"{prefix}.{output_name}", - depth=depth + 1, - ) - elif output_config.type == SegmentType.NUMBER: - # check if number available - value = result.get(output_name) - if value is not None and not isinstance(value, (int, float)): - raise OutputValidationError( - f"Output {prefix}{dot}{output_name} is not a number," - f" got {type(result.get(output_name))} instead." - ) - checked = self._check_number(value=value, variable=f"{prefix}{dot}{output_name}") - # If the output is a boolean and the output schema specifies a NUMBER type, - # convert the boolean value to an integer. - # - # This ensures compatibility with existing workflows that may use - # `True` and `False` as values for NUMBER type outputs. - transformed_result[output_name] = self._convert_boolean_to_int(checked) - - elif output_config.type == SegmentType.STRING: - # check if string available - value = result.get(output_name) - if value is not None and not isinstance(value, str): - raise OutputValidationError( - f"Output {prefix}{dot}{output_name} must be a string, got {type(value).__name__} instead" - ) - transformed_result[output_name] = self._check_string( - value=value, - variable=f"{prefix}{dot}{output_name}", - ) - elif output_config.type == SegmentType.BOOLEAN: - transformed_result[output_name] = self._check_boolean( - value=result[output_name], - variable=f"{prefix}{dot}{output_name}", - ) - elif output_config.type == SegmentType.ARRAY_NUMBER: - # check if array of number available - value = result[output_name] - if not isinstance(value, list): - if value is None: - transformed_result[output_name] = None - else: - raise OutputValidationError( - f"Output {prefix}{dot}{output_name} is not an array, got {type(value)} instead." - ) - else: - if len(value) > self._limits.max_number_array_length: - raise OutputValidationError( - f"The length of output variable `{prefix}{dot}{output_name}` must be" - f" less than {self._limits.max_number_array_length} elements." - ) - - for i, inner_value in enumerate(value): - if not isinstance(inner_value, (int, float)): - raise OutputValidationError( - f"The element at index {i} of output variable `{prefix}{dot}{output_name}` must be" - f" a number." - ) - _ = self._check_number(value=inner_value, variable=f"{prefix}{dot}{output_name}[{i}]") - transformed_result[output_name] = [ - # If the element is a boolean and the output schema specifies a `array[number]` type, - # convert the boolean value to an integer. - # - # This ensures compatibility with existing workflows that may use - # `True` and `False` as values for NUMBER type outputs. - self._convert_boolean_to_int(v) - for v in value - ] - elif output_config.type == SegmentType.ARRAY_STRING: - # check if array of string available - if not isinstance(result[output_name], list): - if result[output_name] is None: - transformed_result[output_name] = None - else: - raise OutputValidationError( - f"Output {prefix}{dot}{output_name} is not an array," - f" got {type(result.get(output_name))} instead." - ) - else: - if len(result[output_name]) > self._limits.max_string_array_length: - raise OutputValidationError( - f"The length of output variable `{prefix}{dot}{output_name}` must be" - f" less than {self._limits.max_string_array_length} elements." - ) - - transformed_result[output_name] = [ - self._check_string(value=value, variable=f"{prefix}{dot}{output_name}[{i}]") - for i, value in enumerate(result[output_name]) - ] - elif output_config.type == SegmentType.ARRAY_OBJECT: - # check if array of object available - if not isinstance(result[output_name], list): - if result[output_name] is None: - transformed_result[output_name] = None - else: - raise OutputValidationError( - f"Output {prefix}{dot}{output_name} is not an array," - f" got {type(result.get(output_name))} instead." - ) - else: - if len(result[output_name]) > self._limits.max_object_array_length: - raise OutputValidationError( - f"The length of output variable `{prefix}{dot}{output_name}` must be" - f" less than {self._limits.max_object_array_length} elements." - ) - - for i, value in enumerate(result[output_name]): - if not isinstance(value, dict): - if value is None: - pass - else: - raise OutputValidationError( - f"Output {prefix}{dot}{output_name}[{i}] is not an object," - f" got {type(value)} instead at index {i}." - ) - - transformed_result[output_name] = [ - None - if value is None - else self._transform_result( - result=value, - output_schema=output_config.children, - prefix=f"{prefix}{dot}{output_name}[{i}]", - depth=depth + 1, - ) - for i, value in enumerate(result[output_name]) - ] - elif output_config.type == SegmentType.ARRAY_BOOLEAN: - # check if array of object available - value = result[output_name] - if not isinstance(value, list): - if value is None: - transformed_result[output_name] = None - else: - raise OutputValidationError( - f"Output {prefix}{dot}{output_name} is not an array," - f" got {type(result.get(output_name))} instead." - ) - else: - for i, inner_value in enumerate(value): - if inner_value is not None and not isinstance(inner_value, bool): - raise OutputValidationError( - f"Output {prefix}{dot}{output_name}[{i}] is not a boolean," - f" got {type(inner_value)} instead." - ) - _ = self._check_boolean(value=inner_value, variable=f"{prefix}{dot}{output_name}[{i}]") - transformed_result[output_name] = value - - else: - raise OutputValidationError(f"Output type {output_config.type} is not supported.") - - parameters_validated[output_name] = True - - # check if all output parameters are validated - if len(parameters_validated) != len(result): - raise CodeNodeError("Not all output parameters are validated.") - - return transformed_result - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: CodeNodeData, - ) -> Mapping[str, Sequence[str]]: - _ = graph_config # Explicitly mark as unused - return { - node_id + "." + variable_selector.variable: variable_selector.value_selector - for variable_selector in node_data.variables - } - - @property - def retry(self) -> bool: - return self.node_data.retry_config.retry_enabled - - @staticmethod - def _convert_boolean_to_int(value: bool | int | float | None) -> int | float | None: - """This function convert boolean to integers when the output schema specifies a NUMBER type. - - This ensures compatibility with existing workflows that may use - `True` and `False` as values for NUMBER type outputs. - """ - if value is None: - return None - if isinstance(value, bool): - return int(value) - return value diff --git a/api/dify_graph/nodes/code/entities.py b/api/dify_graph/nodes/code/entities.py deleted file mode 100644 index 55b4ee48623..00000000000 --- a/api/dify_graph/nodes/code/entities.py +++ /dev/null @@ -1,57 +0,0 @@ -from enum import StrEnum -from typing import Annotated, Literal - -from pydantic import AfterValidator, BaseModel - -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base.entities import VariableSelector -from dify_graph.variables.types import SegmentType - - -class CodeLanguage(StrEnum): - PYTHON3 = "python3" - JINJA2 = "jinja2" - JAVASCRIPT = "javascript" - - -_ALLOWED_OUTPUT_FROM_CODE = frozenset( - [ - SegmentType.STRING, - SegmentType.NUMBER, - SegmentType.OBJECT, - SegmentType.BOOLEAN, - SegmentType.ARRAY_STRING, - SegmentType.ARRAY_NUMBER, - SegmentType.ARRAY_OBJECT, - SegmentType.ARRAY_BOOLEAN, - ] -) - - -def _validate_type(segment_type: SegmentType) -> SegmentType: - if segment_type not in _ALLOWED_OUTPUT_FROM_CODE: - raise ValueError(f"invalid type for code output, expected {_ALLOWED_OUTPUT_FROM_CODE}, actual {segment_type}") - return segment_type - - -class CodeNodeData(BaseNodeData): - """ - Code Node Data. - """ - - type: NodeType = BuiltinNodeTypes.CODE - - class Output(BaseModel): - type: Annotated[SegmentType, AfterValidator(_validate_type)] - children: dict[str, "CodeNodeData.Output"] | None = None - - class Dependency(BaseModel): - name: str - version: str - - variables: list[VariableSelector] - code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT] - code: str - outputs: dict[str, Output] - dependencies: list[Dependency] | None = None diff --git a/api/dify_graph/nodes/code/exc.py b/api/dify_graph/nodes/code/exc.py deleted file mode 100644 index d6334fd554c..00000000000 --- a/api/dify_graph/nodes/code/exc.py +++ /dev/null @@ -1,16 +0,0 @@ -class CodeNodeError(ValueError): - """Base class for code node errors.""" - - pass - - -class OutputValidationError(CodeNodeError): - """Raised when there is an output validation error.""" - - pass - - -class DepthLimitError(CodeNodeError): - """Raised when the depth limit is reached.""" - - pass diff --git a/api/dify_graph/nodes/code/limits.py b/api/dify_graph/nodes/code/limits.py deleted file mode 100644 index a6b9e9e68ee..00000000000 --- a/api/dify_graph/nodes/code/limits.py +++ /dev/null @@ -1,13 +0,0 @@ -from dataclasses import dataclass - - -@dataclass(frozen=True) -class CodeNodeLimits: - max_string_length: int - max_number: int | float - min_number: int | float - max_precision: int - max_depth: int - max_number_array_length: int - max_string_array_length: int - max_object_array_length: int diff --git a/api/dify_graph/nodes/document_extractor/__init__.py b/api/dify_graph/nodes/document_extractor/__init__.py deleted file mode 100644 index 9922e3949da..00000000000 --- a/api/dify_graph/nodes/document_extractor/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .entities import DocumentExtractorNodeData, UnstructuredApiConfig -from .node import DocumentExtractorNode - -__all__ = ["DocumentExtractorNode", "DocumentExtractorNodeData", "UnstructuredApiConfig"] diff --git a/api/dify_graph/nodes/document_extractor/entities.py b/api/dify_graph/nodes/document_extractor/entities.py deleted file mode 100644 index 1110cc2710f..00000000000 --- a/api/dify_graph/nodes/document_extractor/entities.py +++ /dev/null @@ -1,16 +0,0 @@ -from collections.abc import Sequence -from dataclasses import dataclass - -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType - - -class DocumentExtractorNodeData(BaseNodeData): - type: NodeType = BuiltinNodeTypes.DOCUMENT_EXTRACTOR - variable_selector: Sequence[str] - - -@dataclass(frozen=True) -class UnstructuredApiConfig: - api_url: str | None = None - api_key: str = "" diff --git a/api/dify_graph/nodes/document_extractor/exc.py b/api/dify_graph/nodes/document_extractor/exc.py deleted file mode 100644 index 5caf00ebc5f..00000000000 --- a/api/dify_graph/nodes/document_extractor/exc.py +++ /dev/null @@ -1,14 +0,0 @@ -class DocumentExtractorError(ValueError): - """Base exception for errors related to the DocumentExtractorNode.""" - - -class FileDownloadError(DocumentExtractorError): - """Exception raised when there's an error downloading a file.""" - - -class UnsupportedFileTypeError(DocumentExtractorError): - """Exception raised when trying to extract text from an unsupported file type.""" - - -class TextExtractionError(DocumentExtractorError): - """Exception raised when there's an error during text extraction from a file.""" diff --git a/api/dify_graph/nodes/document_extractor/node.py b/api/dify_graph/nodes/document_extractor/node.py deleted file mode 100644 index 27196f1aca9..00000000000 --- a/api/dify_graph/nodes/document_extractor/node.py +++ /dev/null @@ -1,782 +0,0 @@ -import csv -import io -import json -import logging -import os -import tempfile -import zipfile -from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any - -import charset_normalizer -import docx -import pandas as pd -import pypandoc -import pypdfium2 -import webvtt -import yaml -from docx.document import Document -from docx.oxml.table import CT_Tbl -from docx.oxml.text.paragraph import CT_P -from docx.table import Table -from docx.text.paragraph import Paragraph - -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.file import File, FileTransferMethod, file_manager -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.protocols import HttpClientProtocol -from dify_graph.variables import ArrayFileSegment -from dify_graph.variables.segments import ArrayStringSegment, FileSegment - -from .entities import DocumentExtractorNodeData, UnstructuredApiConfig -from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError - -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState - - -class DocumentExtractorNode(Node[DocumentExtractorNodeData]): - """ - Extracts text content from various file types. - Supports plain text, PDF, and DOC/DOCX files. - """ - - node_type = BuiltinNodeTypes.DOCUMENT_EXTRACTOR - - @classmethod - def version(cls) -> str: - return "1" - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - unstructured_api_config: UnstructuredApiConfig | None = None, - http_client: HttpClientProtocol, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - self._unstructured_api_config = unstructured_api_config or UnstructuredApiConfig() - self._http_client = http_client - - def _run(self): - variable_selector = self.node_data.variable_selector - variable = self.graph_runtime_state.variable_pool.get(variable_selector) - - if variable is None: - error_message = f"File variable not found for selector: {variable_selector}" - return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=error_message) - if variable.value and not isinstance(variable, ArrayFileSegment | FileSegment): - error_message = f"Variable {variable_selector} is not an ArrayFileSegment" - return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=error_message) - - value = variable.value - inputs = {"variable_selector": variable_selector} - if isinstance(value, list): - value = list(filter(lambda x: x, value)) - process_data = {"documents": value if isinstance(value, list) else [value]} - - if not value: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs={"text": ArrayStringSegment(value=[])}, - ) - - try: - if isinstance(value, list): - extracted_text_list = [ - _extract_text_from_file( - self._http_client, file, unstructured_api_config=self._unstructured_api_config - ) - for file in value - ] - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs={"text": ArrayStringSegment(value=extracted_text_list)}, - ) - elif isinstance(value, File): - extracted_text = _extract_text_from_file( - self._http_client, value, unstructured_api_config=self._unstructured_api_config - ) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs={"text": extracted_text}, - ) - else: - raise DocumentExtractorError(f"Unsupported variable type: {type(value)}") - except DocumentExtractorError as e: - logger.warning(e, exc_info=True) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - inputs=inputs, - process_data=process_data, - ) - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: DocumentExtractorNodeData, - ) -> Mapping[str, Sequence[str]]: - _ = graph_config # Explicitly mark as unused - return {node_id + ".files": node_data.variable_selector} - - -def _extract_text_by_mime_type( - *, - file_content: bytes, - mime_type: str, - unstructured_api_config: UnstructuredApiConfig, -) -> str: - """Extract text from a file based on its MIME type.""" - match mime_type: - case "text/plain" | "text/html" | "text/htm" | "text/markdown" | "text/xml": - return _extract_text_from_plain_text(file_content) - case "application/pdf": - return _extract_text_from_pdf(file_content) - case "application/msword": - return _extract_text_from_doc(file_content, unstructured_api_config=unstructured_api_config) - case "application/vnd.openxmlformats-officedocument.wordprocessingml.document": - return _extract_text_from_docx(file_content) - case "text/csv": - return _extract_text_from_csv(file_content) - case "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" | "application/vnd.ms-excel": - return _extract_text_from_excel(file_content) - case "application/vnd.ms-powerpoint": - return _extract_text_from_ppt(file_content, unstructured_api_config=unstructured_api_config) - case "application/vnd.openxmlformats-officedocument.presentationml.presentation": - return _extract_text_from_pptx(file_content, unstructured_api_config=unstructured_api_config) - case "application/epub+zip": - return _extract_text_from_epub(file_content, unstructured_api_config=unstructured_api_config) - case "message/rfc822": - return _extract_text_from_eml(file_content) - case "application/vnd.ms-outlook": - return _extract_text_from_msg(file_content) - case "application/json": - return _extract_text_from_json(file_content) - case "application/x-yaml" | "text/yaml": - return _extract_text_from_yaml(file_content) - case "text/vtt": - return _extract_text_from_vtt(file_content) - case "text/properties": - return _extract_text_from_properties(file_content) - case _: - raise UnsupportedFileTypeError(f"Unsupported MIME type: {mime_type}") - - -def _extract_text_by_file_extension( - *, - file_content: bytes, - file_extension: str, - unstructured_api_config: UnstructuredApiConfig, -) -> str: - """Extract text from a file based on its file extension.""" - match file_extension: - case ( - ".txt" - | ".markdown" - | ".md" - | ".mdx" - | ".html" - | ".htm" - | ".xml" - | ".c" - | ".h" - | ".cpp" - | ".hpp" - | ".cc" - | ".cxx" - | ".c++" - | ".py" - | ".js" - | ".ts" - | ".jsx" - | ".tsx" - | ".java" - | ".php" - | ".rb" - | ".go" - | ".rs" - | ".swift" - | ".kt" - | ".scala" - | ".sh" - | ".bash" - | ".bat" - | ".ps1" - | ".sql" - | ".r" - | ".m" - | ".pl" - | ".lua" - | ".vim" - | ".asm" - | ".s" - | ".css" - | ".scss" - | ".less" - | ".sass" - | ".ini" - | ".cfg" - | ".conf" - | ".toml" - | ".env" - | ".log" - | ".vtt" - ): - return _extract_text_from_plain_text(file_content) - case ".json": - return _extract_text_from_json(file_content) - case ".yaml" | ".yml": - return _extract_text_from_yaml(file_content) - case ".pdf": - return _extract_text_from_pdf(file_content) - case ".doc": - return _extract_text_from_doc(file_content, unstructured_api_config=unstructured_api_config) - case ".docx": - return _extract_text_from_docx(file_content) - case ".csv": - return _extract_text_from_csv(file_content) - case ".xls" | ".xlsx": - return _extract_text_from_excel(file_content) - case ".ppt": - return _extract_text_from_ppt(file_content, unstructured_api_config=unstructured_api_config) - case ".pptx": - return _extract_text_from_pptx(file_content, unstructured_api_config=unstructured_api_config) - case ".epub": - return _extract_text_from_epub(file_content, unstructured_api_config=unstructured_api_config) - case ".eml": - return _extract_text_from_eml(file_content) - case ".msg": - return _extract_text_from_msg(file_content) - case ".properties": - return _extract_text_from_properties(file_content) - case _: - raise UnsupportedFileTypeError(f"Unsupported Extension Type: {file_extension}") - - -def _extract_text_from_plain_text(file_content: bytes) -> str: - try: - # Detect encoding using charset_normalizer - result = charset_normalizer.from_bytes(file_content, cp_isolation=["utf_8", "latin_1", "cp1252"]).best() - if result: - encoding = result.encoding - else: - encoding = "utf-8" - - # Fallback to utf-8 if detection fails - if not encoding: - encoding = "utf-8" - - return file_content.decode(encoding, errors="ignore") - except (UnicodeDecodeError, LookupError) as e: - # If decoding fails, try with utf-8 as last resort - try: - return file_content.decode("utf-8", errors="ignore") - except UnicodeDecodeError: - raise TextExtractionError(f"Failed to decode plain text file: {e}") from e - - -def _extract_text_from_json(file_content: bytes) -> str: - try: - # Detect encoding using charset_normalizer - result = charset_normalizer.from_bytes(file_content).best() - if result: - encoding = result.encoding - else: - encoding = "utf-8" - - # Fallback to utf-8 if detection fails - if not encoding: - encoding = "utf-8" - - json_data = json.loads(file_content.decode(encoding, errors="ignore")) - return json.dumps(json_data, indent=2, ensure_ascii=False) - except (UnicodeDecodeError, LookupError, json.JSONDecodeError) as e: - # If decoding fails, try with utf-8 as last resort - try: - json_data = json.loads(file_content.decode("utf-8", errors="ignore")) - return json.dumps(json_data, indent=2, ensure_ascii=False) - except (UnicodeDecodeError, json.JSONDecodeError): - raise TextExtractionError(f"Failed to decode or parse JSON file: {e}") from e - - -def _extract_text_from_yaml(file_content: bytes) -> str: - """Extract the content from yaml file""" - try: - # Detect encoding using charset_normalizer - result = charset_normalizer.from_bytes(file_content).best() - if result: - encoding = result.encoding - else: - encoding = "utf-8" - - # Fallback to utf-8 if detection fails - if not encoding: - encoding = "utf-8" - - yaml_data = yaml.safe_load_all(file_content.decode(encoding, errors="ignore")) - return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False) - except (UnicodeDecodeError, LookupError, yaml.YAMLError) as e: - # If decoding fails, try with utf-8 as last resort - try: - yaml_data = yaml.safe_load_all(file_content.decode("utf-8", errors="ignore")) - return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False) - except (UnicodeDecodeError, yaml.YAMLError): - raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e - - -def _extract_text_from_pdf(file_content: bytes) -> str: - try: - pdf_file = io.BytesIO(file_content) - pdf_document = pypdfium2.PdfDocument(pdf_file, autoclose=True) - text = "" - for page in pdf_document: - text_page = page.get_textpage() - text += text_page.get_text_range() - text_page.close() - page.close() - return text - except Exception as e: - raise TextExtractionError(f"Failed to extract text from PDF: {str(e)}") from e - - -def _extract_text_from_doc(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str: - """ - Extract text from a DOC file. - """ - from unstructured.partition.api import partition_via_api - - if not unstructured_api_config.api_url: - raise TextExtractionError("Unstructured API URL is not configured for DOC file processing.") - api_key = unstructured_api_config.api_key or "" - - try: - with tempfile.NamedTemporaryFile(suffix=".doc", delete=False) as temp_file: - temp_file.write(file_content) - temp_file.flush() - with open(temp_file.name, "rb") as file: - elements = partition_via_api( - file=file, - metadata_filename=temp_file.name, - api_url=unstructured_api_config.api_url, - api_key=api_key, - ) - os.unlink(temp_file.name) - return "\n".join([getattr(element, "text", "") for element in elements]) - except Exception as e: - raise TextExtractionError(f"Failed to extract text from DOC: {str(e)}") from e - - -def parser_docx_part(block, doc: Document, content_items, i): - if isinstance(block, CT_P): - content_items.append((i, "paragraph", Paragraph(block, doc))) - elif isinstance(block, CT_Tbl): - content_items.append((i, "table", Table(block, doc))) - - -def _normalize_docx_zip(file_content: bytes) -> bytes: - """ - Some DOCX files (e.g. exported by Evernote on Windows) are malformed: - ZIP entry names use backslash (\\) as path separator instead of the forward - slash (/) required by both the ZIP spec and OOXML. On Linux/Mac the entry - "word\\document.xml" is never found when python-docx looks for - "word/document.xml", which triggers a KeyError about a missing relationship. - - This function rewrites the ZIP in-memory, normalizing all entry names to - use forward slashes without touching any actual document content. - """ - try: - with zipfile.ZipFile(io.BytesIO(file_content), "r") as zin: - out_buf = io.BytesIO() - with zipfile.ZipFile(out_buf, "w", compression=zipfile.ZIP_DEFLATED) as zout: - for item in zin.infolist(): - data = zin.read(item.filename) - # Normalize backslash path separators to forward slash - item.filename = item.filename.replace("\\", "/") - zout.writestr(item, data) - return out_buf.getvalue() - except zipfile.BadZipFile: - # Not a valid zip — return as-is and let python-docx report the real error - return file_content - - -def _extract_text_from_docx(file_content: bytes) -> str: - """ - Extract text from a DOCX file. - For now support only paragraph and table add more if needed - """ - try: - doc_file = io.BytesIO(file_content) - try: - doc = docx.Document(doc_file) - except Exception as e: - logger.warning("Failed to parse DOCX, attempting to normalize ZIP entry paths: %s", e) - # Some DOCX files exported by tools like Evernote on Windows use - # backslash path separators in ZIP entries and/or single-quoted XML - # attributes, both of which break python-docx on Linux. Normalize and retry. - file_content = _normalize_docx_zip(file_content) - doc = docx.Document(io.BytesIO(file_content)) - text = [] - - # Keep track of paragraph and table positions - content_items: list[tuple[int, str, Table | Paragraph]] = [] - - it = iter(doc.element.body) - part = next(it, None) - i = 0 - while part is not None: - parser_docx_part(part, doc, content_items, i) - i = i + 1 - part = next(it, None) - - # Process sorted content - for _, item_type, item in content_items: - if item_type == "paragraph": - if isinstance(item, Table): - continue - text.append(item.text) - elif item_type == "table": - # Process tables - if not isinstance(item, Table): - continue - try: - # Check if any cell in the table has text - has_content = False - for row in item.rows: - if any(cell.text.strip() for cell in row.cells): - has_content = True - break - - if has_content: - cell_texts = [cell.text.replace("\n", "
") for cell in item.rows[0].cells] - markdown_table = f"| {' | '.join(cell_texts)} |\n" - markdown_table += f"| {' | '.join(['---'] * len(item.rows[0].cells))} |\n" - - for row in item.rows[1:]: - # Replace newlines with
in each cell - row_cells = [cell.text.replace("\n", "
") for cell in row.cells] - markdown_table += "| " + " | ".join(row_cells) + " |\n" - - text.append(markdown_table) - except Exception as e: - logger.warning("Failed to extract table from DOC: %s", e) - continue - - return "\n".join(text) - - except Exception as e: - raise TextExtractionError(f"Failed to extract text from DOCX: {str(e)}") from e - - -def _download_file_content(http_client: HttpClientProtocol, file: File) -> bytes: - """Download the content of a file based on its transfer method.""" - try: - if file.transfer_method == FileTransferMethod.REMOTE_URL: - if file.remote_url is None: - raise FileDownloadError("Missing URL for remote file") - response = http_client.get(file.remote_url) - response.raise_for_status() - return response.content - else: - return file_manager.download(file) - except Exception as e: - raise FileDownloadError(f"Error downloading file: {str(e)}") from e - - -def _extract_text_from_file( - http_client: HttpClientProtocol, file: File, *, unstructured_api_config: UnstructuredApiConfig -) -> str: - file_content = _download_file_content(http_client, file) - if file.extension: - extracted_text = _extract_text_by_file_extension( - file_content=file_content, - file_extension=file.extension, - unstructured_api_config=unstructured_api_config, - ) - elif file.mime_type: - extracted_text = _extract_text_by_mime_type( - file_content=file_content, - mime_type=file.mime_type, - unstructured_api_config=unstructured_api_config, - ) - else: - raise UnsupportedFileTypeError("Unable to determine file type: MIME type or file extension is missing") - return extracted_text - - -def _extract_text_from_csv(file_content: bytes) -> str: - try: - # Detect encoding using charset_normalizer - result = charset_normalizer.from_bytes(file_content).best() - if result: - encoding = result.encoding - else: - encoding = "utf-8" - - # Fallback to utf-8 if detection fails - if not encoding: - encoding = "utf-8" - - try: - csv_file = io.StringIO(file_content.decode(encoding, errors="ignore")) - except (UnicodeDecodeError, LookupError): - # If decoding fails, try with utf-8 as last resort - csv_file = io.StringIO(file_content.decode("utf-8", errors="ignore")) - - csv_reader = csv.reader(csv_file) - rows = list(csv_reader) - - if not rows: - return "" - - # Combine multi-line text in the header row - header_row = [cell.replace("\n", " ").replace("\r", "") for cell in rows[0]] - - # Create Markdown table - markdown_table = "| " + " | ".join(header_row) + " |\n" - markdown_table += "| " + " | ".join(["-" * len(col) for col in rows[0]]) + " |\n" - - # Process each data row and combine multi-line text in each cell - for row in rows[1:]: - processed_row = [cell.replace("\n", " ").replace("\r", "") for cell in row] - markdown_table += "| " + " | ".join(processed_row) + " |\n" - - return markdown_table - except Exception as e: - raise TextExtractionError(f"Failed to extract text from CSV: {str(e)}") from e - - -def _extract_text_from_excel(file_content: bytes) -> str: - """Extract text from an Excel file using pandas.""" - - def _construct_markdown_table(df: pd.DataFrame) -> str: - """Manually construct a Markdown table from a DataFrame.""" - # Construct the header row - header_row = "| " + " | ".join(df.columns) + " |" - - # Construct the separator row - separator_row = "| " + " | ".join(["-" * len(col) for col in df.columns]) + " |" - - # Construct the data rows - data_rows = [] - for _, row in df.iterrows(): - data_row = "| " + " | ".join(map(str, row)) + " |" - data_rows.append(data_row) - - # Combine all rows into a single string - markdown_table = "\n".join([header_row, separator_row] + data_rows) - return markdown_table - - try: - excel_file = pd.ExcelFile(io.BytesIO(file_content)) - markdown_table = "" - for sheet_name in excel_file.sheet_names: - try: - df = excel_file.parse(sheet_name=sheet_name) - df.dropna(how="all", inplace=True) - - # Combine multi-line text in each cell into a single line - df = df.map(lambda x: " ".join(str(x).splitlines()) if isinstance(x, str) else x) - - # Combine multi-line text in column names into a single line - df.columns = pd.Index([" ".join(str(col).splitlines()) for col in df.columns]) - - # Manually construct the Markdown table - markdown_table += _construct_markdown_table(df) + "\n\n" - except Exception: - continue - return markdown_table - except Exception as e: - raise TextExtractionError(f"Failed to extract text from Excel file: {str(e)}") from e - - -def _extract_text_from_ppt(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str: - from unstructured.partition.api import partition_via_api - from unstructured.partition.ppt import partition_ppt - - api_key = unstructured_api_config.api_key or "" - - try: - if unstructured_api_config.api_url: - with tempfile.NamedTemporaryFile(suffix=".ppt", delete=False) as temp_file: - temp_file.write(file_content) - temp_file.flush() - with open(temp_file.name, "rb") as file: - elements = partition_via_api( - file=file, - metadata_filename=temp_file.name, - api_url=unstructured_api_config.api_url, - api_key=api_key, - ) - os.unlink(temp_file.name) - else: - with io.BytesIO(file_content) as file: - elements = partition_ppt(file=file) - return "\n".join([getattr(element, "text", "") for element in elements]) - - except Exception as e: - raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e - - -def _extract_text_from_pptx(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str: - from unstructured.partition.api import partition_via_api - from unstructured.partition.pptx import partition_pptx - - api_key = unstructured_api_config.api_key or "" - - try: - if unstructured_api_config.api_url: - with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file: - temp_file.write(file_content) - temp_file.flush() - with open(temp_file.name, "rb") as file: - elements = partition_via_api( - file=file, - metadata_filename=temp_file.name, - api_url=unstructured_api_config.api_url, - api_key=api_key, - ) - os.unlink(temp_file.name) - else: - with io.BytesIO(file_content) as file: - elements = partition_pptx(file=file) - return "\n".join([getattr(element, "text", "") for element in elements]) - except Exception as e: - raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e - - -def _extract_text_from_epub(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str: - from unstructured.partition.api import partition_via_api - from unstructured.partition.epub import partition_epub - - api_key = unstructured_api_config.api_key or "" - - try: - if unstructured_api_config.api_url: - with tempfile.NamedTemporaryFile(suffix=".epub", delete=False) as temp_file: - temp_file.write(file_content) - temp_file.flush() - with open(temp_file.name, "rb") as file: - elements = partition_via_api( - file=file, - metadata_filename=temp_file.name, - api_url=unstructured_api_config.api_url, - api_key=api_key, - ) - os.unlink(temp_file.name) - else: - pypandoc.download_pandoc() - with io.BytesIO(file_content) as file: - elements = partition_epub(file=file) - return "\n".join([str(element) for element in elements]) - except Exception as e: - raise TextExtractionError(f"Failed to extract text from EPUB: {str(e)}") from e - - -def _extract_text_from_eml(file_content: bytes) -> str: - from unstructured.partition.email import partition_email - - try: - with io.BytesIO(file_content) as file: - elements = partition_email(file=file) - return "\n".join([str(element) for element in elements]) - except Exception as e: - raise TextExtractionError(f"Failed to extract text from EML: {str(e)}") from e - - -def _extract_text_from_msg(file_content: bytes) -> str: - from unstructured.partition.msg import partition_msg - - try: - with io.BytesIO(file_content) as file: - elements = partition_msg(file=file) - return "\n".join([str(element) for element in elements]) - except Exception as e: - raise TextExtractionError(f"Failed to extract text from MSG: {str(e)}") from e - - -def _extract_text_from_vtt(vtt_bytes: bytes) -> str: - text = _extract_text_from_plain_text(vtt_bytes) - - # remove bom - text = text.lstrip("\ufeff") - - raw_results = [] - for caption in webvtt.from_string(text): - raw_results.append((caption.voice, caption.text)) - - # Merge consecutive utterances by the same speaker - merged_results = [] - if raw_results: - current_speaker, current_text = raw_results[0] - - for i in range(1, len(raw_results)): - spk, txt = raw_results[i] - if spk is None: - merged_results.append((None, current_text)) - continue - - if spk == current_speaker: - # If it is the same speaker, merge the utterances (joined by space) - current_text += " " + txt - else: - # If the speaker changes, register the utterance so far and move on - merged_results.append((current_speaker, current_text)) - current_speaker, current_text = spk, txt - - # Add the last element - merged_results.append((current_speaker, current_text)) - else: - merged_results = raw_results - - # Return the result in the specified format: Speaker "text" style - formatted = [f'{spk or ""} "{txt}"' for spk, txt in merged_results] - return "\n".join(formatted) - - -def _extract_text_from_properties(file_content: bytes) -> str: - try: - text = _extract_text_from_plain_text(file_content) - lines = text.splitlines() - result = [] - for line in lines: - line = line.strip() - # Preserve comments and empty lines - if not line or line.startswith("#") or line.startswith("!"): - result.append(line) - continue - - if "=" in line: - key, value = line.split("=", 1) - elif ":" in line: - key, value = line.split(":", 1) - else: - key, value = line, "" - - result.append(f"{key.strip()}: {value.strip()}") - - return "\n".join(result) - except Exception as e: - raise TextExtractionError(f"Failed to extract text from properties file: {str(e)}") from e diff --git a/api/dify_graph/nodes/end/end_node.py b/api/dify_graph/nodes/end/end_node.py deleted file mode 100644 index 1f5cfab22bc..00000000000 --- a/api/dify_graph/nodes/end/end_node.py +++ /dev/null @@ -1,47 +0,0 @@ -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.base.template import Template -from dify_graph.nodes.end.entities import EndNodeData - - -class EndNode(Node[EndNodeData]): - node_type = BuiltinNodeTypes.END - execution_type = NodeExecutionType.RESPONSE - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - """ - Run node - collect all outputs at once. - - This method runs after streaming is complete (if streaming was enabled). - It collects all output variables and returns them. - """ - output_variables = self.node_data.outputs - - outputs = {} - for variable_selector in output_variables: - variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) - value = variable.to_object() if variable is not None else None - outputs[variable_selector.variable] = value - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=outputs, - outputs=outputs, - ) - - def get_streaming_template(self) -> Template: - """ - Get the template for streaming. - - Returns: - Template instance for this End node - """ - outputs_config = [ - {"variable": output.variable, "value_selector": output.value_selector} for output in self.node_data.outputs - ] - return Template.from_end_outputs(outputs_config) diff --git a/api/dify_graph/nodes/end/entities.py b/api/dify_graph/nodes/end/entities.py deleted file mode 100644 index be7f0c8de8a..00000000000 --- a/api/dify_graph/nodes/end/entities.py +++ /dev/null @@ -1,27 +0,0 @@ -from pydantic import BaseModel, Field - -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base.entities import OutputVariableEntity - - -class EndNodeData(BaseNodeData): - """ - END Node Data. - """ - - type: NodeType = BuiltinNodeTypes.END - outputs: list[OutputVariableEntity] - - -class EndStreamParam(BaseModel): - """ - EndStreamParam entity - """ - - end_dependencies: dict[str, list[str]] = Field( - ..., description="end dependencies (end node id -> dependent node ids)" - ) - end_stream_variable_selector_mapping: dict[str, list[list[str]]] = Field( - ..., description="end stream variable selector mapping (end node id -> stream variable selectors)" - ) diff --git a/api/dify_graph/nodes/http_request/__init__.py b/api/dify_graph/nodes/http_request/__init__.py deleted file mode 100644 index b29099db230..00000000000 --- a/api/dify_graph/nodes/http_request/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -from .config import build_http_request_config, resolve_http_request_config -from .entities import ( - HTTP_REQUEST_CONFIG_FILTER_KEY, - BodyData, - HttpRequestNodeAuthorization, - HttpRequestNodeBody, - HttpRequestNodeConfig, - HttpRequestNodeData, -) -from .node import HttpRequestNode - -__all__ = [ - "HTTP_REQUEST_CONFIG_FILTER_KEY", - "BodyData", - "HttpRequestNode", - "HttpRequestNodeAuthorization", - "HttpRequestNodeBody", - "HttpRequestNodeConfig", - "HttpRequestNodeData", - "build_http_request_config", - "resolve_http_request_config", -] diff --git a/api/dify_graph/nodes/http_request/config.py b/api/dify_graph/nodes/http_request/config.py deleted file mode 100644 index 53bf6c7ae4c..00000000000 --- a/api/dify_graph/nodes/http_request/config.py +++ /dev/null @@ -1,33 +0,0 @@ -from collections.abc import Mapping - -from .entities import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNodeConfig - - -def build_http_request_config( - *, - max_connect_timeout: int = 10, - max_read_timeout: int = 600, - max_write_timeout: int = 600, - max_binary_size: int = 10 * 1024 * 1024, - max_text_size: int = 1 * 1024 * 1024, - ssl_verify: bool = True, - ssrf_default_max_retries: int = 3, -) -> HttpRequestNodeConfig: - return HttpRequestNodeConfig( - max_connect_timeout=max_connect_timeout, - max_read_timeout=max_read_timeout, - max_write_timeout=max_write_timeout, - max_binary_size=max_binary_size, - max_text_size=max_text_size, - ssl_verify=ssl_verify, - ssrf_default_max_retries=ssrf_default_max_retries, - ) - - -def resolve_http_request_config(filters: Mapping[str, object] | None) -> HttpRequestNodeConfig: - if not filters: - raise ValueError("http_request_config is required to build HTTP request default config") - config = filters.get(HTTP_REQUEST_CONFIG_FILTER_KEY) - if not isinstance(config, HttpRequestNodeConfig): - raise ValueError("http_request_config must be an HttpRequestNodeConfig instance") - return config diff --git a/api/dify_graph/nodes/http_request/entities.py b/api/dify_graph/nodes/http_request/entities.py deleted file mode 100644 index f594d58ae64..00000000000 --- a/api/dify_graph/nodes/http_request/entities.py +++ /dev/null @@ -1,241 +0,0 @@ -import mimetypes -from collections.abc import Sequence -from dataclasses import dataclass -from email.message import Message -from typing import Any, Literal - -import charset_normalizer -import httpx -from pydantic import BaseModel, Field, ValidationInfo, field_validator - -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType - -HTTP_REQUEST_CONFIG_FILTER_KEY = "http_request_config" - - -class HttpRequestNodeAuthorizationConfig(BaseModel): - type: Literal["basic", "bearer", "custom"] - api_key: str - header: str = "" - - -class HttpRequestNodeAuthorization(BaseModel): - type: Literal["no-auth", "api-key"] - config: HttpRequestNodeAuthorizationConfig | None = None - - @field_validator("config", mode="before") - @classmethod - def check_config(cls, v: HttpRequestNodeAuthorizationConfig, values: ValidationInfo): - """ - Check config, if type is no-auth, config should be None, otherwise it should be a dict. - """ - if values.data["type"] == "no-auth": - return None - else: - if not v or not isinstance(v, dict): - raise ValueError("config should be a dict") - - return v - - -class BodyData(BaseModel): - key: str = "" - type: Literal["file", "text"] - value: str = "" - file: Sequence[str] = Field(default_factory=list) - - -class HttpRequestNodeBody(BaseModel): - type: Literal["none", "form-data", "x-www-form-urlencoded", "raw-text", "json", "binary"] - data: Sequence[BodyData] = Field(default_factory=list) - - @field_validator("data", mode="before") - @classmethod - def check_data(cls, v: Any): - """For compatibility, if body is not set, return empty list.""" - if not v: - return [] - if isinstance(v, str): - return [BodyData(key="", type="text", value=v)] - return v - - -class HttpRequestNodeTimeout(BaseModel): - connect: int | None = None - read: int | None = None - write: int | None = None - - -@dataclass(frozen=True, slots=True) -class HttpRequestNodeConfig: - max_connect_timeout: int - max_read_timeout: int - max_write_timeout: int - max_binary_size: int - max_text_size: int - ssl_verify: bool - ssrf_default_max_retries: int - - def default_timeout(self) -> "HttpRequestNodeTimeout": - return HttpRequestNodeTimeout( - connect=self.max_connect_timeout, - read=self.max_read_timeout, - write=self.max_write_timeout, - ) - - -class HttpRequestNodeData(BaseNodeData): - """ - Code Node Data. - """ - - type: NodeType = BuiltinNodeTypes.HTTP_REQUEST - method: Literal[ - "get", - "post", - "put", - "patch", - "delete", - "head", - "options", - "GET", - "POST", - "PUT", - "PATCH", - "DELETE", - "HEAD", - "OPTIONS", - ] - url: str - authorization: HttpRequestNodeAuthorization - headers: str - params: str - body: HttpRequestNodeBody | None = None - timeout: HttpRequestNodeTimeout | None = None - ssl_verify: bool | None = None - - -class Response: - headers: dict[str, str] - response: httpx.Response - _cached_text: str | None - - def __init__(self, response: httpx.Response): - self.response = response - self.headers = dict(response.headers) - self._cached_text = None - - @property - def is_file(self): - """ - Determine if the response contains a file by checking: - 1. Content-Disposition header (RFC 6266) - 2. Content characteristics - 3. MIME type analysis - """ - content_type = self.content_type.split(";")[0].strip().lower() - parsed_content_disposition = self.parsed_content_disposition - - # Check if it's explicitly marked as an attachment - if parsed_content_disposition: - disp_type = parsed_content_disposition.get_content_disposition() # Returns 'attachment', 'inline', or None - filename = parsed_content_disposition.get_filename() # Returns filename if present, None otherwise - if disp_type == "attachment" or filename is not None: - return True - - # For 'text/' types, only 'csv' should be downloaded as file - if content_type.startswith("text/") and "csv" not in content_type: - return False - - # For application types, try to detect if it's a text-based format - if content_type.startswith("application/"): - # Common text-based application types - if any( - text_type in content_type - for text_type in ("json", "xml", "javascript", "x-www-form-urlencoded", "yaml", "graphql") - ): - return False - - # Try to detect if content is text-based by sampling first few bytes - try: - # Sample first 1024 bytes for text detection - content_sample = self.response.content[:1024] - content_sample.decode("utf-8") - # If we can decode as UTF-8 and find common text patterns, likely not a file - text_markers = (b"{", b"[", b"<", b"function", b"var ", b"const ", b"let ") - if any(marker in content_sample for marker in text_markers): - return False - except UnicodeDecodeError: - # If we can't decode as UTF-8, likely a binary file - return True - - # For other types, use MIME type analysis - main_type, _ = mimetypes.guess_type("dummy" + (mimetypes.guess_extension(content_type) or "")) - if main_type: - return main_type.split("/")[0] in ("application", "image", "audio", "video") - - # For unknown types, check if it's a media type - return any(media_type in content_type for media_type in ("image/", "audio/", "video/")) - - @property - def content_type(self) -> str: - return self.headers.get("content-type", "") - - @property - def text(self) -> str: - """ - Get response text with robust encoding detection. - - Uses charset_normalizer for better encoding detection than httpx's default, - which helps handle Chinese and other non-ASCII characters properly. - """ - # Check cache first - if hasattr(self, "_cached_text") and self._cached_text is not None: - return self._cached_text - - # Try charset_normalizer for robust encoding detection first - detected_encoding = charset_normalizer.from_bytes(self.response.content).best() - if detected_encoding and detected_encoding.encoding: - try: - text = self.response.content.decode(detected_encoding.encoding) - self._cached_text = text - return text - except (UnicodeDecodeError, TypeError, LookupError): - # Fallback to httpx's encoding detection if charset_normalizer fails - pass - - # Fallback to httpx's built-in encoding detection - text = self.response.text - self._cached_text = text - return text - - @property - def content(self) -> bytes: - return self.response.content - - @property - def status_code(self) -> int: - return self.response.status_code - - @property - def size(self) -> int: - return len(self.content) - - @property - def readable_size(self) -> str: - if self.size < 1024: - return f"{self.size} bytes" - elif self.size < 1024 * 1024: - return f"{(self.size / 1024):.2f} KB" - else: - return f"{(self.size / 1024 / 1024):.2f} MB" - - @property - def parsed_content_disposition(self) -> Message | None: - content_disposition = self.headers.get("content-disposition", "") - if content_disposition: - msg = Message() - msg["content-disposition"] = content_disposition - return msg - return None diff --git a/api/dify_graph/nodes/http_request/exc.py b/api/dify_graph/nodes/http_request/exc.py deleted file mode 100644 index 46613c9e861..00000000000 --- a/api/dify_graph/nodes/http_request/exc.py +++ /dev/null @@ -1,26 +0,0 @@ -class HttpRequestNodeError(ValueError): - """Custom error for HTTP request node.""" - - -class AuthorizationConfigError(HttpRequestNodeError): - """Raised when authorization config is missing or invalid.""" - - -class FileFetchError(HttpRequestNodeError): - """Raised when a file cannot be fetched.""" - - -class InvalidHttpMethodError(HttpRequestNodeError): - """Raised when an invalid HTTP method is used.""" - - -class ResponseSizeError(HttpRequestNodeError): - """Raised when the response size exceeds the allowed threshold.""" - - -class RequestBodyError(HttpRequestNodeError): - """Raised when the request body is invalid.""" - - -class InvalidURLError(HttpRequestNodeError): - """Raised when the URL is invalid.""" diff --git a/api/dify_graph/nodes/http_request/executor.py b/api/dify_graph/nodes/http_request/executor.py deleted file mode 100644 index 892b0fc688a..00000000000 --- a/api/dify_graph/nodes/http_request/executor.py +++ /dev/null @@ -1,488 +0,0 @@ -import base64 -import json -import secrets -import string -from collections.abc import Callable, Mapping -from copy import deepcopy -from typing import Any, Literal -from urllib.parse import urlencode, urlparse - -import httpx -from json_repair import repair_json - -from dify_graph.file.enums import FileTransferMethod -from dify_graph.runtime import VariablePool -from dify_graph.variables.segments import ArrayFileSegment, FileSegment - -from ..protocols import FileManagerProtocol, HttpClientProtocol -from .entities import ( - HttpRequestNodeAuthorization, - HttpRequestNodeConfig, - HttpRequestNodeData, - HttpRequestNodeTimeout, - Response, -) -from .exc import ( - AuthorizationConfigError, - FileFetchError, - HttpRequestNodeError, - InvalidHttpMethodError, - InvalidURLError, - RequestBodyError, - ResponseSizeError, -) - -BODY_TYPE_TO_CONTENT_TYPE = { - "json": "application/json", - "x-www-form-urlencoded": "application/x-www-form-urlencoded", - "form-data": "multipart/form-data", - "raw-text": "text/plain", -} - - -class Executor: - method: Literal[ - "get", - "head", - "post", - "put", - "delete", - "patch", - "options", - "GET", - "POST", - "PUT", - "PATCH", - "DELETE", - "HEAD", - "OPTIONS", - ] - url: str - params: list[tuple[str, str]] | None - content: str | bytes | None - data: Mapping[str, Any] | None - files: list[tuple[str, tuple[str | None, bytes, str]]] | None - json: Any - headers: dict[str, str] - auth: HttpRequestNodeAuthorization - timeout: HttpRequestNodeTimeout - max_retries: int - - boundary: str - - def __init__( - self, - *, - node_data: HttpRequestNodeData, - timeout: HttpRequestNodeTimeout, - variable_pool: VariablePool, - http_request_config: HttpRequestNodeConfig, - max_retries: int | None = None, - ssl_verify: bool | None = None, - http_client: HttpClientProtocol, - file_manager: FileManagerProtocol, - ): - self._http_request_config = http_request_config - # If authorization API key is present, convert the API key using the variable pool - if node_data.authorization.type == "api-key": - if node_data.authorization.config is None: - raise AuthorizationConfigError("authorization config is required") - node_data.authorization.config.api_key = variable_pool.convert_template( - node_data.authorization.config.api_key - ).text - # Validate that API key is not empty after template conversion - if not node_data.authorization.config.api_key or not node_data.authorization.config.api_key.strip(): - raise AuthorizationConfigError( - "API key is required for authorization but was empty. Please provide a valid API key." - ) - - self.url = node_data.url - self.method = node_data.method - self.auth = node_data.authorization - self.timeout = timeout - self.ssl_verify = ssl_verify if ssl_verify is not None else node_data.ssl_verify - if self.ssl_verify is None: - self.ssl_verify = self._http_request_config.ssl_verify - if not isinstance(self.ssl_verify, bool): - raise ValueError("ssl_verify must be a boolean") - self.params = None - self.headers = {} - self.content = None - self.files = None - self.data = None - self.json = None - self.max_retries = ( - max_retries if max_retries is not None else self._http_request_config.ssrf_default_max_retries - ) - self._http_client = http_client - self._file_manager = file_manager - - # init template - self.variable_pool = variable_pool - self.node_data = node_data - self._initialize() - - def _initialize(self): - self._init_url() - self._init_params() - self._init_headers() - self._init_body() - - def _init_url(self): - self.url = self.variable_pool.convert_template(self.node_data.url).text - - # check if url is a valid URL - if not self.url: - raise InvalidURLError("url is required") - if not self.url.startswith(("http://", "https://")): - raise InvalidURLError("url should start with http:// or https://") - - def _init_params(self): - """ - Almost same as _init_headers(), difference: - 1. response a list tuple to support same key, like 'aa=1&aa=2' - 2. param value may have '\n', we need to splitlines then extract the variable value. - """ - result = [] - for line in self.node_data.params.splitlines(): - if not (line := line.strip()): - continue - - key, *value = line.split(":", 1) - if not (key := key.strip()): - continue - - value_str = value[0].strip() if value else "" - result.append( - (self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value_str).text) - ) - - if result: - self.params = result - - def _init_headers(self): - """ - Convert the header string of frontend to a dictionary. - - Each line in the header string represents a key-value pair. - Keys and values are separated by ':'. - Empty values are allowed. - - Examples: - 'aa:bb\n cc:dd' -> {'aa': 'bb', 'cc': 'dd'} - 'aa:\n cc:dd\n' -> {'aa': '', 'cc': 'dd'} - 'aa\n cc : dd' -> {'aa': '', 'cc': 'dd'} - - """ - headers = self.variable_pool.convert_template(self.node_data.headers).text - self.headers = { - key.strip(): (value[0].strip() if value else "") - for line in headers.splitlines() - if line.strip() - for key, *value in [line.split(":", 1)] - } - - def _init_body(self): - body = self.node_data.body - if body is not None: - data = body.data - match body.type: - case "none": - self.content = "" - case "raw-text": - if len(data) != 1: - raise RequestBodyError("raw-text body type should have exactly one item") - self.content = self.variable_pool.convert_template(data[0].value).text - case "json": - if len(data) != 1: - raise RequestBodyError("json body type should have exactly one item") - json_string = self.variable_pool.convert_template(data[0].value).text - try: - repaired = repair_json(json_string) - json_object = json.loads(repaired, strict=False) - except json.JSONDecodeError as e: - raise RequestBodyError(f"Failed to parse JSON: {json_string}") from e - self.json = json_object - # self.json = self._parse_object_contains_variables(json_object) - case "binary": - if len(data) != 1: - raise RequestBodyError("binary body type should have exactly one item") - file_selector = data[0].file - file_variable = self.variable_pool.get_file(file_selector) - if file_variable is None: - raise FileFetchError(f"cannot fetch file with selector {file_selector}") - file = file_variable.value - self.content = self._file_manager.download(file) - case "x-www-form-urlencoded": - form_data = { - self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template( - item.value - ).text - for item in data - } - self.data = form_data - case "form-data": - form_data = { - self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template( - item.value - ).text - for item in filter(lambda item: item.type == "text", data) - } - file_selectors = { - self.variable_pool.convert_template(item.key).text: item.file - for item in filter(lambda item: item.type == "file", data) - } - - # get files from file_selectors, add support for array file variables - files_list = [] - for key, selector in file_selectors.items(): - segment = self.variable_pool.get(selector) - if isinstance(segment, FileSegment): - files_list.append((key, [segment.value])) - elif isinstance(segment, ArrayFileSegment): - files_list.append((key, list(segment.value))) - - # get files from file_manager - files: dict[str, list[tuple[str | None, bytes, str]]] = {} - for key, files_in_segment in files_list: - for file in files_in_segment: - if file.related_id is not None or ( - file.transfer_method == FileTransferMethod.REMOTE_URL and file.remote_url is not None - ): - file_tuple = ( - file.filename, - self._file_manager.download(file), - file.mime_type or "application/octet-stream", - ) - if key not in files: - files[key] = [] - files[key].append(file_tuple) - - # convert files to list for httpx request - # If there are no actual files, we still need to force httpx to use `multipart/form-data`. - # This is achieved by inserting a harmless placeholder file that will be ignored by the server. - if not files: - self.files = [("__multipart_placeholder__", ("", b"", "application/octet-stream"))] - if files: - self.files = [] - for key, file_tuples in files.items(): - for file_tuple in file_tuples: - self.files.append((key, file_tuple)) - - self.data = form_data - - def _assembling_headers(self) -> dict[str, Any]: - authorization = deepcopy(self.auth) - headers = deepcopy(self.headers) or {} - if self.auth.type == "api-key": - if self.auth.config is None: - raise AuthorizationConfigError("self.authorization config is required") - if authorization.config is None: - raise AuthorizationConfigError("authorization config is required") - - if not authorization.config.header: - authorization.config.header = "Authorization" - - if self.auth.config.type == "bearer" and authorization.config.api_key: - headers[authorization.config.header] = f"Bearer {authorization.config.api_key}" - elif self.auth.config.type == "basic" and authorization.config.api_key: - credentials = authorization.config.api_key - if ":" in credentials: - encoded_credentials = base64.b64encode(credentials.encode("utf-8")).decode("utf-8") - else: - encoded_credentials = credentials - headers[authorization.config.header] = f"Basic {encoded_credentials}" - elif self.auth.config.type == "custom": - if authorization.config.header and authorization.config.api_key: - headers[authorization.config.header] = authorization.config.api_key - - # Handle Content-Type for multipart/form-data requests - # Fix for issue #23829: Missing boundary when using multipart/form-data - body = self.node_data.body - if body and body.type == "form-data": - # For multipart/form-data with files (including placeholder files), - # remove any manually set Content-Type header to let httpx handle - # For multipart/form-data, if any files are present (including placeholder files), - # we must remove any manually set Content-Type header. This is because httpx needs to - # automatically set the Content-Type and boundary for multipart encoding whenever files - # are included, even if they are placeholders, to avoid boundary issues and ensure correct - # file upload behaviour. Manually setting Content-Type can cause httpx to fail to set the - # boundary, resulting in invalid requests. - if self.files: - # Remove Content-Type if it was manually set to avoid boundary issues - headers = {k: v for k, v in headers.items() if k.lower() != "content-type"} - else: - # No files at all, set Content-Type manually - if "content-type" not in (k.lower() for k in headers): - headers["Content-Type"] = "multipart/form-data" - elif body and body.type in BODY_TYPE_TO_CONTENT_TYPE: - # Set Content-Type for other body types - if "content-type" not in (k.lower() for k in headers): - headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type] - - return headers - - def _validate_and_parse_response(self, response: httpx.Response) -> Response: - executor_response = Response(response) - - threshold_size = ( - self._http_request_config.max_binary_size - if executor_response.is_file - else self._http_request_config.max_text_size - ) - if executor_response.size > threshold_size: - raise ResponseSizeError( - f"{'File' if executor_response.is_file else 'Text'} size is too large," - f" max size is {threshold_size / 1024 / 1024:.2f} MB," - f" but current size is {executor_response.readable_size}." - ) - - return executor_response - - def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response: - """ - do http request depending on api bundle - """ - _METHOD_MAP: dict[str, Callable[..., httpx.Response]] = { - "get": self._http_client.get, - "head": self._http_client.head, - "post": self._http_client.post, - "put": self._http_client.put, - "delete": self._http_client.delete, - "patch": self._http_client.patch, - } - method_lc = self.method.lower() - if method_lc not in _METHOD_MAP: - raise InvalidHttpMethodError(f"Invalid http method {self.method}") - - request_args: dict[str, Any] = { - "data": self.data, - "files": self.files, - "json": self.json, - "content": self.content, - "headers": headers, - "params": self.params, - "timeout": (self.timeout.connect, self.timeout.read, self.timeout.write), - "ssl_verify": self.ssl_verify, - "follow_redirects": True, - } - # request_args = {k: v for k, v in request_args.items() if v is not None} - try: - response = _METHOD_MAP[method_lc]( - url=self.url, - **request_args, - max_retries=self.max_retries, - ) - except self._http_client.max_retries_exceeded_error as e: - raise HttpRequestNodeError(f"Reached maximum retries for URL {self.url}") from e - except self._http_client.request_error as e: - raise HttpRequestNodeError(str(e)) from e - return response - - def invoke(self) -> Response: - # assemble headers - headers = self._assembling_headers() - # do http request - response = self._do_http_request(headers) - # validate response - return self._validate_and_parse_response(response) - - def to_log(self): - url_parts = urlparse(self.url) - path = url_parts.path or "/" - - # Add query parameters - if self.params: - query_string = urlencode(self.params) - path += f"?{query_string}" - elif url_parts.query: - path += f"?{url_parts.query}" - - raw = f"{self.method.upper()} {path} HTTP/1.1\r\n" - raw += f"Host: {url_parts.netloc}\r\n" - - headers = self._assembling_headers() - body = self.node_data.body - boundary = f"----WebKitFormBoundary{_generate_random_string(16)}" - if body: - if "content-type" not in (k.lower() for k in self.headers) and body.type in BODY_TYPE_TO_CONTENT_TYPE: - headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type] - if body.type == "form-data": - headers["Content-Type"] = f"multipart/form-data; boundary={boundary}" - for k, v in headers.items(): - if self.auth.type == "api-key": - authorization_header = "Authorization" - if self.auth.config and self.auth.config.header: - authorization_header = self.auth.config.header - if k.lower() == authorization_header.lower(): - raw += f"{k}: {'*' * len(v)}\r\n" - continue - raw += f"{k}: {v}\r\n" - - body_string = "" - # Only log actual files if present. - # '__multipart_placeholder__' is inserted to force multipart encoding but is not a real file. - # This prevents logging meaningless placeholder entries. - if self.files and not all(f[0] == "__multipart_placeholder__" for f in self.files): - for file_entry in self.files: - # file_entry should be (key, (filename, content, mime_type)), but handle edge cases - if len(file_entry) != 2 or len(file_entry[1]) < 2: - continue # skip malformed entries - key = file_entry[0] - content = file_entry[1][1] - body_string += f"--{boundary}\r\n" - body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n' - # decode content safely - # Do not decode binary content; use a placeholder with file metadata instead. - # Includes filename, size, and MIME type for better logging context. - body_string += ( - f"\r\n" - ) - body_string += f"--{boundary}--\r\n" - elif self.node_data.body: - if self.content: - # If content is bytes, do not decode it; show a placeholder with size. - # Provides content size information for binary data without exposing the raw bytes. - if isinstance(self.content, bytes): - body_string = f"" - else: - body_string = self.content - elif self.data and self.node_data.body.type == "x-www-form-urlencoded": - body_string = urlencode(self.data) - elif self.data and self.node_data.body.type == "form-data": - for key, value in self.data.items(): - body_string += f"--{boundary}\r\n" - body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n' - body_string += f"{value}\r\n" - body_string += f"--{boundary}--\r\n" - elif self.json: - body_string = json.dumps(self.json) - elif self.node_data.body.type == "raw-text": - if len(self.node_data.body.data) != 1: - raise RequestBodyError("raw-text body type should have exactly one item") - body_string = self.node_data.body.data[0].value - if body_string: - raw += f"Content-Length: {len(body_string)}\r\n" - raw += "\r\n" # Empty line between headers and body - raw += body_string - - return raw - - -def _generate_random_string(n: int) -> str: - """ - Generate a random string of lowercase ASCII letters. - - Args: - n (int): The length of the random string to generate. - - Returns: - str: A random string of lowercase ASCII letters with length n. - - Example: - >>> _generate_random_string(5) - 'abcde' - """ - return "".join(secrets.choice(string.ascii_lowercase) for _ in range(n)) diff --git a/api/dify_graph/nodes/http_request/node.py b/api/dify_graph/nodes/http_request/node.py deleted file mode 100644 index 486ae241ee0..00000000000 --- a/api/dify_graph/nodes/http_request/node.py +++ /dev/null @@ -1,258 +0,0 @@ -import logging -import mimetypes -from collections.abc import Callable, Mapping, Sequence -from typing import TYPE_CHECKING, Any - -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.file import File, FileTransferMethod -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base import variable_template_parser -from dify_graph.nodes.base.entities import VariableSelector -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.http_request.executor import Executor -from dify_graph.nodes.protocols import FileManagerProtocol, HttpClientProtocol, ToolFileManagerProtocol -from dify_graph.variables.segments import ArrayFileSegment -from factories import file_factory - -from .config import build_http_request_config, resolve_http_request_config -from .entities import ( - HTTP_REQUEST_CONFIG_FILTER_KEY, - HttpRequestNodeConfig, - HttpRequestNodeData, - HttpRequestNodeTimeout, - Response, -) -from .exc import HttpRequestNodeError, RequestBodyError - -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState - - -class HttpRequestNode(Node[HttpRequestNodeData]): - node_type = BuiltinNodeTypes.HTTP_REQUEST - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - http_request_config: HttpRequestNodeConfig, - http_client: HttpClientProtocol, - tool_file_manager_factory: Callable[[], ToolFileManagerProtocol], - file_manager: FileManagerProtocol, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - self._http_request_config = http_request_config - self._http_client = http_client - self._tool_file_manager_factory = tool_file_manager_factory - self._file_manager = file_manager - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - if not filters or HTTP_REQUEST_CONFIG_FILTER_KEY not in filters: - http_request_config = build_http_request_config() - else: - http_request_config = resolve_http_request_config(filters) - default_timeout = http_request_config.default_timeout() - return { - "type": "http-request", - "config": { - "method": "get", - "authorization": { - "type": "no-auth", - }, - "body": {"type": "none"}, - "timeout": { - **default_timeout.model_dump(), - "max_connect_timeout": http_request_config.max_connect_timeout, - "max_read_timeout": http_request_config.max_read_timeout, - "max_write_timeout": http_request_config.max_write_timeout, - }, - "ssl_verify": http_request_config.ssl_verify, - }, - "retry_config": { - "max_retries": http_request_config.ssrf_default_max_retries, - "retry_interval": 0.5 * (2**2), - "retry_enabled": True, - }, - } - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - process_data = {} - try: - http_executor = Executor( - node_data=self.node_data, - timeout=self._get_request_timeout(self.node_data), - variable_pool=self.graph_runtime_state.variable_pool, - http_request_config=self._http_request_config, - ssl_verify=self.node_data.ssl_verify, - http_client=self._http_client, - file_manager=self._file_manager, - ) - process_data["request"] = http_executor.to_log() - - response = http_executor.invoke() - files = self.extract_files(url=http_executor.url, response=response) - if not response.response.is_success and (self.error_strategy or self.retry): - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - outputs={ - "status_code": response.status_code, - "body": response.text if not files.value else "", - "headers": response.headers, - "files": files, - }, - process_data={ - "request": http_executor.to_log(), - }, - error=f"Request failed with status code {response.status_code}", - error_type="HTTPResponseCodeError", - ) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={ - "status_code": response.status_code, - "body": response.text if not files.value else "", - "headers": response.headers, - "files": files, - }, - process_data={ - "request": http_executor.to_log(), - }, - ) - except HttpRequestNodeError as e: - logger.warning("http request node %s failed to run: %s", self._node_id, e) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - process_data=process_data, - error_type=type(e).__name__, - ) - - def _get_request_timeout(self, node_data: HttpRequestNodeData) -> HttpRequestNodeTimeout: - default_timeout = self._http_request_config.default_timeout() - timeout = node_data.timeout - if timeout is None: - return default_timeout - - return HttpRequestNodeTimeout( - connect=timeout.connect or default_timeout.connect, - read=timeout.read or default_timeout.read, - write=timeout.write or default_timeout.write, - ) - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: HttpRequestNodeData, - ) -> Mapping[str, Sequence[str]]: - selectors: list[VariableSelector] = [] - selectors += variable_template_parser.extract_selectors_from_template(node_data.url) - selectors += variable_template_parser.extract_selectors_from_template(node_data.headers) - selectors += variable_template_parser.extract_selectors_from_template(node_data.params) - if node_data.body: - body_type = node_data.body.type - data = node_data.body.data - match body_type: - case "none": - pass - case "binary": - if len(data) != 1: - raise RequestBodyError("invalid body data, should have only one item") - selector = data[0].file - selectors.append(VariableSelector(variable="#" + ".".join(selector) + "#", value_selector=selector)) - case "json" | "raw-text": - if len(data) != 1: - raise RequestBodyError("invalid body data, should have only one item") - selectors += variable_template_parser.extract_selectors_from_template(data[0].key) - selectors += variable_template_parser.extract_selectors_from_template(data[0].value) - case "x-www-form-urlencoded": - for item in data: - selectors += variable_template_parser.extract_selectors_from_template(item.key) - selectors += variable_template_parser.extract_selectors_from_template(item.value) - case "form-data": - for item in data: - selectors += variable_template_parser.extract_selectors_from_template(item.key) - if item.type == "text": - selectors += variable_template_parser.extract_selectors_from_template(item.value) - elif item.type == "file": - selectors.append( - VariableSelector(variable="#" + ".".join(item.file) + "#", value_selector=item.file) - ) - - mapping = {} - for selector_iter in selectors: - mapping[node_id + "." + selector_iter.variable] = selector_iter.value_selector - - return mapping - - def extract_files(self, url: str, response: Response) -> ArrayFileSegment: - """ - Extract files from response by checking both Content-Type header and URL - """ - dify_ctx = self.require_dify_context() - files: list[File] = [] - is_file = response.is_file - content_type = response.content_type - content = response.content - parsed_content_disposition = response.parsed_content_disposition - content_disposition_type = None - - if not is_file: - return ArrayFileSegment(value=[]) - - if parsed_content_disposition: - content_disposition_filename = parsed_content_disposition.get_filename() - if content_disposition_filename: - # If filename is available from content-disposition, use it to guess the content type - content_disposition_type = mimetypes.guess_type(content_disposition_filename)[0] - - # Guess file extension from URL or Content-Type header - filename = url.split("?")[0].split("/")[-1] or "" - mime_type = ( - content_disposition_type or content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream" - ) - tool_file_manager = self._tool_file_manager_factory() - - tool_file = tool_file_manager.create_file_by_raw( - user_id=dify_ctx.user_id, - tenant_id=dify_ctx.tenant_id, - conversation_id=None, - file_binary=content, - mimetype=mime_type, - ) - - mapping = { - "tool_file_id": tool_file.id, - "transfer_method": FileTransferMethod.TOOL_FILE, - } - file = file_factory.build_from_mapping( - mapping=mapping, - tenant_id=dify_ctx.tenant_id, - ) - files.append(file) - - return ArrayFileSegment(value=files) - - @property - def retry(self) -> bool: - return self.node_data.retry_config.retry_enabled diff --git a/api/dify_graph/nodes/human_input/__init__.py b/api/dify_graph/nodes/human_input/__init__.py deleted file mode 100644 index 17896045779..00000000000 --- a/api/dify_graph/nodes/human_input/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -Human Input node implementation. -""" diff --git a/api/dify_graph/nodes/human_input/entities.py b/api/dify_graph/nodes/human_input/entities.py deleted file mode 100644 index 2a33b4a0a8c..00000000000 --- a/api/dify_graph/nodes/human_input/entities.py +++ /dev/null @@ -1,424 +0,0 @@ -""" -Human Input node entities. -""" - -import re -import uuid -from collections.abc import Mapping, Sequence -from datetime import datetime, timedelta -from typing import Annotated, Any, ClassVar, Literal, Self - -import bleach -import markdown -from pydantic import BaseModel, Field, field_validator, model_validator - -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser -from dify_graph.runtime import VariablePool -from dify_graph.variables.consts import SELECTORS_LENGTH - -from .enums import ButtonStyle, DeliveryMethodType, EmailRecipientType, FormInputType, PlaceholderType, TimeoutUnit - -_OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$output\.(?P[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}") - - -class _WebAppDeliveryConfig(BaseModel): - """Configuration for webapp delivery method.""" - - pass # Empty for webapp delivery - - -class MemberRecipient(BaseModel): - """Member recipient for email delivery.""" - - type: Literal[EmailRecipientType.MEMBER] = EmailRecipientType.MEMBER - user_id: str - - -class ExternalRecipient(BaseModel): - """External recipient for email delivery.""" - - type: Literal[EmailRecipientType.EXTERNAL] = EmailRecipientType.EXTERNAL - email: str - - -EmailRecipient = Annotated[MemberRecipient | ExternalRecipient, Field(discriminator="type")] - - -class EmailRecipients(BaseModel): - """Email recipients configuration.""" - - # When true, recipients are the union of all workspace members and external items. - # Member items are ignored because they are already covered by the workspace scope. - # De-duplication is applied by email, with member recipients taking precedence. - whole_workspace: bool = False - items: list[EmailRecipient] = Field(default_factory=list) - - -class EmailDeliveryConfig(BaseModel): - """Configuration for email delivery method.""" - - URL_PLACEHOLDER: ClassVar[str] = "{{#url#}}" - _SUBJECT_NEWLINE_PATTERN: ClassVar[re.Pattern[str]] = re.compile(r"[\r\n]+") - _ALLOWED_HTML_TAGS: ClassVar[list[str]] = [ - "a", - "blockquote", - "br", - "code", - "em", - "h1", - "h2", - "h3", - "h4", - "h5", - "h6", - "hr", - "li", - "ol", - "p", - "pre", - "strong", - "table", - "tbody", - "td", - "th", - "thead", - "tr", - "ul", - ] - _ALLOWED_HTML_ATTRIBUTES: ClassVar[dict[str, list[str]]] = { - "a": ["href", "title"], - "td": ["align"], - "th": ["align"], - } - _ALLOWED_PROTOCOLS: ClassVar[list[str]] = ["http", "https", "mailto"] - - recipients: EmailRecipients - - # the subject of email - subject: str - - # Body is the content of email.It may contain the speical placeholder `{{#url#}}`, which - # represent the url to submit the form. - # - # It may also reference the output variable of the previous node with the syntax - # `{{#.#}}`. - body: str - debug_mode: bool = False - - def with_debug_recipient(self, user_id: str | None) -> "EmailDeliveryConfig": - if user_id is None: - debug_recipients = EmailRecipients(whole_workspace=False, items=[]) - return self.model_copy(update={"recipients": debug_recipients}) - debug_recipients = EmailRecipients(whole_workspace=False, items=[MemberRecipient(user_id=user_id)]) - return self.model_copy(update={"recipients": debug_recipients}) - - @classmethod - def replace_url_placeholder(cls, body: str, url: str | None) -> str: - """Replace the url placeholder with provided value.""" - return body.replace(cls.URL_PLACEHOLDER, url or "") - - @classmethod - def render_body_template( - cls, - *, - body: str, - url: str | None, - variable_pool: VariablePool | None = None, - ) -> str: - """Render email body by replacing placeholders with runtime values.""" - templated_body = cls.replace_url_placeholder(body, url) - if variable_pool is None: - return templated_body - return variable_pool.convert_template(templated_body).text - - @classmethod - def render_markdown_body(cls, body: str) -> str: - """Render markdown to safe HTML for email delivery.""" - sanitized_markdown = bleach.clean( - body, - tags=[], - attributes={}, - strip=True, - strip_comments=True, - ) - rendered_html = markdown.markdown( - sanitized_markdown, - extensions=["nl2br", "tables"], - extension_configs={"tables": {"use_align_attribute": True}}, - ) - return bleach.clean( - rendered_html, - tags=cls._ALLOWED_HTML_TAGS, - attributes=cls._ALLOWED_HTML_ATTRIBUTES, - protocols=cls._ALLOWED_PROTOCOLS, - strip=True, - strip_comments=True, - ) - - @classmethod - def sanitize_subject(cls, subject: str) -> str: - """Sanitize email subject to plain text and prevent CRLF injection.""" - sanitized_subject = bleach.clean( - subject, - tags=[], - attributes={}, - strip=True, - strip_comments=True, - ) - sanitized_subject = cls._SUBJECT_NEWLINE_PATTERN.sub(" ", sanitized_subject) - return " ".join(sanitized_subject.split()) - - -class _DeliveryMethodBase(BaseModel): - """Base delivery method configuration.""" - - enabled: bool = True - id: uuid.UUID = Field(default_factory=uuid.uuid4) - - def extract_variable_selectors(self) -> Sequence[Sequence[str]]: - return () - - -class WebAppDeliveryMethod(_DeliveryMethodBase): - """Webapp delivery method configuration.""" - - type: Literal[DeliveryMethodType.WEBAPP] = DeliveryMethodType.WEBAPP - # The config field is not used currently. - config: _WebAppDeliveryConfig = Field(default_factory=_WebAppDeliveryConfig) - - -class EmailDeliveryMethod(_DeliveryMethodBase): - """Email delivery method configuration.""" - - type: Literal[DeliveryMethodType.EMAIL] = DeliveryMethodType.EMAIL - config: EmailDeliveryConfig - - def extract_variable_selectors(self) -> Sequence[Sequence[str]]: - variable_template_parser = VariableTemplateParser(template=self.config.body) - selectors: list[Sequence[str]] = [] - for variable_selector in variable_template_parser.extract_variable_selectors(): - value_selector = list(variable_selector.value_selector) - if len(value_selector) < SELECTORS_LENGTH: - continue - selectors.append(value_selector[:SELECTORS_LENGTH]) - return selectors - - -DeliveryChannelConfig = Annotated[WebAppDeliveryMethod | EmailDeliveryMethod, Field(discriminator="type")] - - -def apply_debug_email_recipient( - method: DeliveryChannelConfig, - *, - enabled: bool, - user_id: str | None, -) -> DeliveryChannelConfig: - if not enabled: - return method - if not isinstance(method, EmailDeliveryMethod): - return method - if not method.config.debug_mode: - return method - debug_config = method.config.with_debug_recipient(user_id) - return method.model_copy(update={"config": debug_config}) - - -class FormInputDefault(BaseModel): - """Default configuration for form inputs.""" - - # NOTE: Ideally, a discriminated union would be used to model - # FormInputDefault. However, the UI requires preserving the previous - # value when switching between `VARIABLE` and `CONSTANT` types. This - # necessitates retaining all fields, making a discriminated union unsuitable. - - type: PlaceholderType - - # The selector of default variable, used when `type` is `VARIABLE`. - selector: Sequence[str] = Field(default_factory=tuple) # - - # The value of the default, used when `type` is `CONSTANT`. - # TODO: How should we express JSON values? - value: str = "" - - @model_validator(mode="after") - def _validate_selector(self) -> Self: - if self.type == PlaceholderType.CONSTANT: - return self - if len(self.selector) < SELECTORS_LENGTH: - raise ValueError(f"the length of selector should be at least {SELECTORS_LENGTH}, selector={self.selector}") - return self - - -class FormInput(BaseModel): - """Form input definition.""" - - type: FormInputType - output_variable_name: str - default: FormInputDefault | None = None - - -_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") - - -class UserAction(BaseModel): - """User action configuration.""" - - # id is the identifier for this action. - # It also serves as the identifiers of output handle. - # - # The id must be a valid identifier (satisfy the _IDENTIFIER_PATTERN above.) - id: str = Field(max_length=20) - title: str = Field(max_length=20) - button_style: ButtonStyle = ButtonStyle.DEFAULT - - @field_validator("id") - @classmethod - def _validate_id(cls, value: str) -> str: - if not _IDENTIFIER_PATTERN.match(value): - raise ValueError( - f"'{value}' is not a valid identifier. It must start with a letter or underscore, " - f"and contain only letters, numbers, or underscores." - ) - return value - - -class HumanInputNodeData(BaseNodeData): - """Human Input node data.""" - - type: NodeType = BuiltinNodeTypes.HUMAN_INPUT - delivery_methods: list[DeliveryChannelConfig] = Field(default_factory=list) - form_content: str = "" - inputs: list[FormInput] = Field(default_factory=list) - user_actions: list[UserAction] = Field(default_factory=list) - timeout: int = 36 - timeout_unit: TimeoutUnit = TimeoutUnit.HOUR - - @field_validator("inputs") - @classmethod - def _validate_inputs(cls, inputs: list[FormInput]) -> list[FormInput]: - seen_names: set[str] = set() - for form_input in inputs: - name = form_input.output_variable_name - if name in seen_names: - raise ValueError(f"duplicated output_variable_name '{name}' in inputs") - seen_names.add(name) - return inputs - - @field_validator("user_actions") - @classmethod - def _validate_user_actions(cls, user_actions: list[UserAction]) -> list[UserAction]: - seen_ids: set[str] = set() - for action in user_actions: - action_id = action.id - if action_id in seen_ids: - raise ValueError(f"duplicated user action id '{action_id}'") - seen_ids.add(action_id) - return user_actions - - def is_webapp_enabled(self) -> bool: - for dm in self.delivery_methods: - if not dm.enabled: - continue - if dm.type == DeliveryMethodType.WEBAPP: - return True - return False - - def expiration_time(self, start_time: datetime) -> datetime: - if self.timeout_unit == TimeoutUnit.HOUR: - return start_time + timedelta(hours=self.timeout) - elif self.timeout_unit == TimeoutUnit.DAY: - return start_time + timedelta(days=self.timeout) - else: - raise AssertionError("unknown timeout unit.") - - def outputs_field_names(self) -> Sequence[str]: - field_names = [] - for match in _OUTPUT_VARIABLE_PATTERN.finditer(self.form_content): - field_names.append(match.group("field_name")) - return field_names - - def extract_variable_selector_to_variable_mapping(self, node_id: str) -> Mapping[str, Sequence[str]]: - variable_mappings: dict[str, Sequence[str]] = {} - - def _add_variable_selectors(selectors: Sequence[Sequence[str]]) -> None: - for selector in selectors: - if len(selector) < SELECTORS_LENGTH: - continue - qualified_variable_mapping_key = f"{node_id}.#{'.'.join(selector[:SELECTORS_LENGTH])}#" - variable_mappings[qualified_variable_mapping_key] = list(selector[:SELECTORS_LENGTH]) - - form_template_parser = VariableTemplateParser(template=self.form_content) - _add_variable_selectors( - [selector.value_selector for selector in form_template_parser.extract_variable_selectors()] - ) - for delivery_method in self.delivery_methods: - if not delivery_method.enabled: - continue - _add_variable_selectors(delivery_method.extract_variable_selectors()) - - for input in self.inputs: - default_value = input.default - if default_value is None: - continue - if default_value.type == PlaceholderType.CONSTANT: - continue - default_value_key = ".".join(default_value.selector) - qualified_variable_mapping_key = f"{node_id}.#{default_value_key}#" - variable_mappings[qualified_variable_mapping_key] = default_value.selector - - return variable_mappings - - def find_action_text(self, action_id: str) -> str: - """ - Resolve action display text by id. - """ - for action in self.user_actions: - if action.id == action_id: - return action.title - return action_id - - -class FormDefinition(BaseModel): - form_content: str - inputs: list[FormInput] = Field(default_factory=list) - user_actions: list[UserAction] = Field(default_factory=list) - rendered_content: str - expiration_time: datetime - - # this is used to store the resolved default values - default_values: dict[str, Any] = Field(default_factory=dict) - - # node_title records the title of the HumanInput node. - node_title: str | None = None - - # display_in_ui controls whether the form should be displayed in UI surfaces. - display_in_ui: bool | None = None - - -class HumanInputSubmissionValidationError(ValueError): - pass - - -def validate_human_input_submission( - *, - inputs: Sequence[FormInput], - user_actions: Sequence[UserAction], - selected_action_id: str, - form_data: Mapping[str, Any], -) -> None: - available_actions = {action.id for action in user_actions} - if selected_action_id not in available_actions: - raise HumanInputSubmissionValidationError(f"Invalid action: {selected_action_id}") - - provided_inputs = set(form_data.keys()) - missing_inputs = [ - form_input.output_variable_name - for form_input in inputs - if form_input.output_variable_name not in provided_inputs - ] - - if missing_inputs: - missing_list = ", ".join(missing_inputs) - raise HumanInputSubmissionValidationError(f"Missing required inputs: {missing_list}") diff --git a/api/dify_graph/nodes/human_input/enums.py b/api/dify_graph/nodes/human_input/enums.py deleted file mode 100644 index da85728828f..00000000000 --- a/api/dify_graph/nodes/human_input/enums.py +++ /dev/null @@ -1,72 +0,0 @@ -import enum - - -class HumanInputFormStatus(enum.StrEnum): - """Status of a human input form.""" - - # Awaiting submission from any recipient. Forms stay in this state until - # submitted or a timeout rule applies. - WAITING = enum.auto() - # Global timeout reached. The workflow run is stopped and will not resume. - # This is distinct from node-level timeout. - EXPIRED = enum.auto() - # Submitted by a recipient; form data is available and execution resumes - # along the selected action edge. - SUBMITTED = enum.auto() - # Node-level timeout reached. The human input node should emit a timeout - # event and the workflow should resume along the timeout edge. - TIMEOUT = enum.auto() - - -class HumanInputFormKind(enum.StrEnum): - """Kind of a human input form.""" - - RUNTIME = enum.auto() # Form created during workflow execution. - DELIVERY_TEST = enum.auto() # Form created for delivery tests. - - -class DeliveryMethodType(enum.StrEnum): - """Delivery method types for human input forms.""" - - # WEBAPP controls whether the form is delivered to the web app. It not only controls - # the standalone web app, but also controls the installed apps in the console. - WEBAPP = enum.auto() - - EMAIL = enum.auto() - - -class ButtonStyle(enum.StrEnum): - """Button styles for user actions.""" - - PRIMARY = enum.auto() - DEFAULT = enum.auto() - ACCENT = enum.auto() - GHOST = enum.auto() - - -class TimeoutUnit(enum.StrEnum): - """Timeout unit for form expiration.""" - - HOUR = enum.auto() - DAY = enum.auto() - - -class FormInputType(enum.StrEnum): - """Form input types.""" - - TEXT_INPUT = enum.auto() - PARAGRAPH = enum.auto() - - -class PlaceholderType(enum.StrEnum): - """Default value types for form inputs.""" - - VARIABLE = enum.auto() - CONSTANT = enum.auto() - - -class EmailRecipientType(enum.StrEnum): - """Email recipient types.""" - - MEMBER = enum.auto() - EXTERNAL = enum.auto() diff --git a/api/dify_graph/nodes/human_input/human_input_node.py b/api/dify_graph/nodes/human_input/human_input_node.py deleted file mode 100644 index 794e33d92e0..00000000000 --- a/api/dify_graph/nodes/human_input/human_input_node.py +++ /dev/null @@ -1,361 +0,0 @@ -import json -import logging -from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any - -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from dify_graph.node_events import ( - HumanInputFormFilledEvent, - HumanInputFormTimeoutEvent, - NodeRunResult, - PauseRequestedEvent, -) -from dify_graph.node_events.base import NodeEventBase -from dify_graph.node_events.node import StreamCompletedEvent -from dify_graph.nodes.base.node import Node -from dify_graph.repositories.human_input_form_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter -from libs.datetime_utils import naive_utc_now - -from .entities import DeliveryChannelConfig, HumanInputNodeData, apply_debug_email_recipient -from .enums import DeliveryMethodType, HumanInputFormStatus, PlaceholderType - -if TYPE_CHECKING: - from dify_graph.entities.graph_init_params import GraphInitParams - from dify_graph.runtime.graph_runtime_state import GraphRuntimeState - - -_SELECTED_BRANCH_KEY = "selected_branch" -_INVOKE_FROM_DEBUGGER = "debugger" -_INVOKE_FROM_EXPLORE = "explore" - - -logger = logging.getLogger(__name__) - - -class HumanInputNode(Node[HumanInputNodeData]): - node_type = BuiltinNodeTypes.HUMAN_INPUT - execution_type = NodeExecutionType.BRANCH - - _BRANCH_SELECTION_KEYS: tuple[str, ...] = ( - "edge_source_handle", - "edgeSourceHandle", - "source_handle", - _SELECTED_BRANCH_KEY, - "selectedBranch", - "branch", - "branch_id", - "branchId", - "handle", - ) - - _node_data: HumanInputNodeData - _form_repository: HumanInputFormRepository - _OUTPUT_FIELD_ACTION_ID = "__action_id" - _OUTPUT_FIELD_RENDERED_CONTENT = "__rendered_content" - _TIMEOUT_HANDLE = _TIMEOUT_ACTION_ID = "__timeout" - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - form_repository: HumanInputFormRepository, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - self._form_repository = form_repository - - @classmethod - def version(cls) -> str: - return "1" - - def _resolve_branch_selection(self) -> str | None: - """Determine the branch handle selected by human input if available.""" - - variable_pool = self.graph_runtime_state.variable_pool - - for key in self._BRANCH_SELECTION_KEYS: - handle = self._extract_branch_handle(variable_pool.get((self.id, key))) - if handle: - return handle - - default_values = self.node_data.default_value_dict - for key in self._BRANCH_SELECTION_KEYS: - handle = self._normalize_branch_value(default_values.get(key)) - if handle: - return handle - - return None - - @staticmethod - def _extract_branch_handle(segment: Any) -> str | None: - if segment is None: - return None - - candidate = getattr(segment, "to_object", None) - raw_value = candidate() if callable(candidate) else getattr(segment, "value", None) - if raw_value is None: - return None - - return HumanInputNode._normalize_branch_value(raw_value) - - @staticmethod - def _normalize_branch_value(value: Any) -> str | None: - if value is None: - return None - - if isinstance(value, str): - stripped = value.strip() - return stripped or None - - if isinstance(value, Mapping): - for key in ("handle", "edge_source_handle", "edgeSourceHandle", "branch", "id", "value"): - candidate = value.get(key) - if isinstance(candidate, str) and candidate: - return candidate - - return None - - @property - def _workflow_execution_id(self) -> str: - workflow_exec_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id - assert workflow_exec_id is not None - return workflow_exec_id - - def _form_to_pause_event(self, form_entity: HumanInputFormEntity): - required_event = self._human_input_required_event(form_entity) - pause_requested_event = PauseRequestedEvent(reason=required_event) - return pause_requested_event - - def resolve_default_values(self) -> Mapping[str, Any]: - variable_pool = self.graph_runtime_state.variable_pool - resolved_defaults: dict[str, Any] = {} - for input in self._node_data.inputs: - if (default_value := input.default) is None: - continue - if default_value.type == PlaceholderType.CONSTANT: - continue - resolved_value = variable_pool.get(default_value.selector) - if resolved_value is None: - # TODO: How should we handle this? - continue - resolved_defaults[input.output_variable_name] = ( - WorkflowRuntimeTypeConverter().value_to_json_encodable_recursive(resolved_value.value) - ) - - return resolved_defaults - - def _should_require_console_recipient(self) -> bool: - invoke_from = self._invoke_from_value() - if invoke_from == _INVOKE_FROM_DEBUGGER: - return True - if invoke_from == _INVOKE_FROM_EXPLORE: - return self._node_data.is_webapp_enabled() - return False - - def _display_in_ui(self) -> bool: - if self._invoke_from_value() == _INVOKE_FROM_DEBUGGER: - return True - return self._node_data.is_webapp_enabled() - - def _effective_delivery_methods(self) -> Sequence[DeliveryChannelConfig]: - dify_ctx = self.require_dify_context() - invoke_from = self._invoke_from_value() - enabled_methods = [method for method in self._node_data.delivery_methods if method.enabled] - if invoke_from in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE}: - enabled_methods = [method for method in enabled_methods if method.type != DeliveryMethodType.WEBAPP] - return [ - apply_debug_email_recipient( - method, - enabled=invoke_from == _INVOKE_FROM_DEBUGGER, - user_id=dify_ctx.user_id, - ) - for method in enabled_methods - ] - - def _invoke_from_value(self) -> str: - invoke_from = self.require_dify_context().invoke_from - if isinstance(invoke_from, str): - return invoke_from - return str(getattr(invoke_from, "value", invoke_from)) - - def _human_input_required_event(self, form_entity: HumanInputFormEntity) -> HumanInputRequired: - node_data = self._node_data - resolved_default_values = self.resolve_default_values() - display_in_ui = self._display_in_ui() - form_token = form_entity.web_app_token - if display_in_ui and form_token is None: - raise AssertionError("Form token should be available for UI execution.") - return HumanInputRequired( - form_id=form_entity.id, - form_content=form_entity.rendered_content, - inputs=node_data.inputs, - actions=node_data.user_actions, - display_in_ui=display_in_ui, - node_id=self.id, - node_title=node_data.title, - form_token=form_token, - resolved_default_values=resolved_default_values, - ) - - def _run(self) -> Generator[NodeEventBase, None, None]: - """ - Execute the human input node. - - This method will: - 1. Generate a unique form ID - 2. Create form content with variable substitution - 3. Create form in database - 4. Send form via configured delivery methods - 5. Suspend workflow execution - 6. Wait for form submission to resume - """ - repo = self._form_repository - form = repo.get_form(self._workflow_execution_id, self.id) - dify_ctx = self.require_dify_context() - if form is None: - display_in_ui = self._display_in_ui() - params = FormCreateParams( - app_id=dify_ctx.app_id, - workflow_execution_id=self._workflow_execution_id, - node_id=self.id, - form_config=self._node_data, - rendered_content=self.render_form_content_before_submission(), - delivery_methods=self._effective_delivery_methods(), - display_in_ui=display_in_ui, - resolved_default_values=self.resolve_default_values(), - console_recipient_required=self._should_require_console_recipient(), - console_creator_account_id=( - dify_ctx.user_id - if self._invoke_from_value() in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE} - else None - ), - backstage_recipient_required=True, - ) - form_entity = self._form_repository.create_form(params) - # Create human input required event - - logger.info( - "Human Input node suspended workflow for form. workflow_run_id=%s, node_id=%s, form_id=%s", - self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id, - self.id, - form_entity.id, - ) - yield self._form_to_pause_event(form_entity) - return - - if ( - form.status in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED} - or form.expiration_time <= naive_utc_now() - ): - yield HumanInputFormTimeoutEvent( - node_title=self._node_data.title, - expiration_time=form.expiration_time, - ) - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={self._OUTPUT_FIELD_ACTION_ID: ""}, - edge_source_handle=self._TIMEOUT_HANDLE, - ) - ) - return - - if not form.submitted: - yield self._form_to_pause_event(form) - return - - selected_action_id = form.selected_action_id - if selected_action_id is None: - raise AssertionError(f"selected_action_id should not be None when form submitted, form_id={form.id}") - submitted_data = form.submitted_data or {} - outputs: dict[str, Any] = dict(submitted_data) - outputs[self._OUTPUT_FIELD_ACTION_ID] = selected_action_id - rendered_content = self.render_form_content_with_outputs( - form.rendered_content, - outputs, - self._node_data.outputs_field_names(), - ) - outputs[self._OUTPUT_FIELD_RENDERED_CONTENT] = rendered_content - - action_text = self._node_data.find_action_text(selected_action_id) - - yield HumanInputFormFilledEvent( - node_title=self._node_data.title, - rendered_content=rendered_content, - action_id=selected_action_id, - action_text=action_text, - ) - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs=outputs, - edge_source_handle=selected_action_id, - ) - ) - - def render_form_content_before_submission(self) -> str: - """ - Process form content by substituting variables. - - This method should: - 1. Parse the form_content markdown - 2. Substitute {{#node_name.var_name#}} with actual values - 3. Keep {{#$output.field_name#}} placeholders for form inputs - """ - rendered_form_content = self.graph_runtime_state.variable_pool.convert_template( - self._node_data.form_content, - ) - return rendered_form_content.markdown - - @staticmethod - def render_form_content_with_outputs( - form_content: str, - outputs: Mapping[str, Any], - field_names: Sequence[str], - ) -> str: - """ - Replace {{#$output.xxx#}} placeholders with submitted values. - """ - rendered_content = form_content - for field_name in field_names: - placeholder = "{{#$output." + field_name + "#}}" - value = outputs.get(field_name) - if value is None: - replacement = "" - elif isinstance(value, (dict, list)): - replacement = json.dumps(value, ensure_ascii=False) - else: - replacement = str(value) - rendered_content = rendered_content.replace(placeholder, replacement) - return rendered_content - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: HumanInputNodeData, - ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selectors referenced in form content and input default values. - - This method should parse: - 1. Variables referenced in form_content ({{#node_name.var_name#}}) - 2. Variables referenced in input default values - """ - return node_data.extract_variable_selector_to_variable_mapping(node_id) diff --git a/api/dify_graph/nodes/if_else/__init__.py b/api/dify_graph/nodes/if_else/__init__.py deleted file mode 100644 index afa0e8112c5..00000000000 --- a/api/dify_graph/nodes/if_else/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .if_else_node import IfElseNode - -__all__ = ["IfElseNode"] diff --git a/api/dify_graph/nodes/if_else/entities.py b/api/dify_graph/nodes/if_else/entities.py deleted file mode 100644 index ff09f3c023a..00000000000 --- a/api/dify_graph/nodes/if_else/entities.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import Literal - -from pydantic import BaseModel, Field - -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.utils.condition.entities import Condition - - -class IfElseNodeData(BaseNodeData): - """ - If Else Node Data. - """ - - type: NodeType = BuiltinNodeTypes.IF_ELSE - - class Case(BaseModel): - """ - Case entity representing a single logical condition group - """ - - case_id: str - logical_operator: Literal["and", "or"] - conditions: list[Condition] - - logical_operator: Literal["and", "or"] | None = "and" - conditions: list[Condition] | None = Field(default=None, deprecated=True) - - cases: list[Case] | None = None diff --git a/api/dify_graph/nodes/if_else/if_else_node.py b/api/dify_graph/nodes/if_else/if_else_node.py deleted file mode 100644 index 7c0370e48c6..00000000000 --- a/api/dify_graph/nodes/if_else/if_else_node.py +++ /dev/null @@ -1,124 +0,0 @@ -from collections.abc import Mapping, Sequence -from typing import Any, Literal - -from typing_extensions import deprecated - -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.if_else.entities import IfElseNodeData -from dify_graph.runtime import VariablePool -from dify_graph.utils.condition.entities import Condition -from dify_graph.utils.condition.processor import ConditionProcessor - - -class IfElseNode(Node[IfElseNodeData]): - node_type = BuiltinNodeTypes.IF_ELSE - execution_type = NodeExecutionType.BRANCH - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - """ - Run node - :return: - """ - node_inputs: dict[str, Sequence[Mapping[str, Any]]] = {"conditions": []} - - process_data: dict[str, list] = {"condition_results": []} - - input_conditions: Sequence[Mapping[str, Any]] = [] - final_result = False - selected_case_id = "false" - condition_processor = ConditionProcessor() - try: - # Check if the new cases structure is used - if self.node_data.cases: - for case in self.node_data.cases: - input_conditions, group_result, final_result = condition_processor.process_conditions( - variable_pool=self.graph_runtime_state.variable_pool, - conditions=case.conditions, - operator=case.logical_operator, - ) - - process_data["condition_results"].append( - { - "group": case.model_dump(), - "results": group_result, - "final_result": final_result, - } - ) - - # Break if a case passes (logical short-circuit) - if final_result: - selected_case_id = case.case_id # Capture the ID of the passing case - break - - else: - # TODO: Update database then remove this - # Fallback to old structure if cases are not defined - input_conditions, group_result, final_result = _should_not_use_old_function( # pyright: ignore [reportDeprecated] - condition_processor=condition_processor, - variable_pool=self.graph_runtime_state.variable_pool, - conditions=self.node_data.conditions or [], - operator=self.node_data.logical_operator or "and", - ) - - selected_case_id = "true" if final_result else "false" - - process_data["condition_results"].append( - {"group": "default", "results": group_result, "final_result": final_result} - ) - - node_inputs["conditions"] = input_conditions - - except Exception as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, inputs=node_inputs, process_data=process_data, error=str(e) - ) - - outputs = {"result": final_result, "selected_case_id": selected_case_id} - - data = NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=node_inputs, - process_data=process_data, - edge_source_handle=selected_case_id or "false", # Use case ID or 'default' - outputs=outputs, - ) - - return data - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: IfElseNodeData, - ) -> Mapping[str, Sequence[str]]: - var_mapping: dict[str, list[str]] = {} - _ = graph_config # Explicitly mark as unused - for case in node_data.cases or []: - for condition in case.conditions: - key = f"{node_id}.#{'.'.join(condition.variable_selector)}#" - var_mapping[key] = condition.variable_selector - - return var_mapping - - -@deprecated("This function is deprecated. You should use the new cases structure.") -def _should_not_use_old_function( - *, - condition_processor: ConditionProcessor, - variable_pool: VariablePool, - conditions: list[Condition], - operator: Literal["and", "or"], -): - return condition_processor.process_conditions( - variable_pool=variable_pool, - conditions=conditions, - operator=operator, - ) diff --git a/api/dify_graph/nodes/iteration/__init__.py b/api/dify_graph/nodes/iteration/__init__.py deleted file mode 100644 index 5bb87aaffa9..00000000000 --- a/api/dify_graph/nodes/iteration/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .entities import IterationNodeData -from .iteration_node import IterationNode -from .iteration_start_node import IterationStartNode - -__all__ = ["IterationNode", "IterationNodeData", "IterationStartNode"] diff --git a/api/dify_graph/nodes/iteration/entities.py b/api/dify_graph/nodes/iteration/entities.py deleted file mode 100644 index 58fd112b121..00000000000 --- a/api/dify_graph/nodes/iteration/entities.py +++ /dev/null @@ -1,67 +0,0 @@ -from enum import StrEnum -from typing import Any - -from pydantic import Field - -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base import BaseIterationNodeData, BaseIterationState - - -class ErrorHandleMode(StrEnum): - TERMINATED = "terminated" - CONTINUE_ON_ERROR = "continue-on-error" - REMOVE_ABNORMAL_OUTPUT = "remove-abnormal-output" - - -class IterationNodeData(BaseIterationNodeData): - """ - Iteration Node Data. - """ - - type: NodeType = BuiltinNodeTypes.ITERATION - parent_loop_id: str | None = None # redundant field, not used currently - iterator_selector: list[str] # variable selector - output_selector: list[str] # output selector - is_parallel: bool = False # open the parallel mode or not - parallel_nums: int = 10 # the numbers of parallel - error_handle_mode: ErrorHandleMode = ErrorHandleMode.TERMINATED # how to handle the error - flatten_output: bool = True # whether to flatten the output array if all elements are lists - - -class IterationStartNodeData(BaseNodeData): - """ - Iteration Start Node Data. - """ - - type: NodeType = BuiltinNodeTypes.ITERATION_START - - -class IterationState(BaseIterationState): - """ - Iteration State. - """ - - outputs: list[Any] = Field(default_factory=list) - current_output: Any = None - - class MetaData(BaseIterationState.MetaData): - """ - Data. - """ - - iterator_length: int - - def get_last_output(self) -> Any: - """ - Get last output. - """ - if self.outputs: - return self.outputs[-1] - return None - - def get_current_output(self) -> Any: - """ - Get current output. - """ - return self.current_output diff --git a/api/dify_graph/nodes/iteration/exc.py b/api/dify_graph/nodes/iteration/exc.py deleted file mode 100644 index d9947e09bc1..00000000000 --- a/api/dify_graph/nodes/iteration/exc.py +++ /dev/null @@ -1,22 +0,0 @@ -class IterationNodeError(ValueError): - """Base class for iteration node errors.""" - - -class IteratorVariableNotFoundError(IterationNodeError): - """Raised when the iterator variable is not found.""" - - -class InvalidIteratorValueError(IterationNodeError): - """Raised when the iterator value is invalid.""" - - -class StartNodeIdNotFoundError(IterationNodeError): - """Raised when the start node ID is not found.""" - - -class IterationGraphNotFoundError(IterationNodeError): - """Raised when the iteration graph is not found.""" - - -class IterationIndexNotFoundError(IterationNodeError): - """Raised when the iteration index is not found.""" diff --git a/api/dify_graph/nodes/iteration/iteration_node.py b/api/dify_graph/nodes/iteration/iteration_node.py deleted file mode 100644 index 033ec8672fc..00000000000 --- a/api/dify_graph/nodes/iteration/iteration_node.py +++ /dev/null @@ -1,626 +0,0 @@ -import logging -from collections.abc import Generator, Mapping, Sequence -from concurrent.futures import Future, ThreadPoolExecutor, as_completed -from datetime import UTC, datetime -from typing import TYPE_CHECKING, Any, NewType, cast - -from typing_extensions import TypeIs - -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID -from dify_graph.entities.graph_config import NodeConfigDictAdapter -from dify_graph.enums import ( - BuiltinNodeTypes, - NodeExecutionType, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from dify_graph.graph_events import ( - GraphNodeEventBase, - GraphRunFailedEvent, - GraphRunPartialSucceededEvent, - GraphRunSucceededEvent, -) -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.node_events import ( - IterationFailedEvent, - IterationNextEvent, - IterationStartedEvent, - IterationSucceededEvent, - NodeEventBase, - NodeRunResult, - StreamCompletedEvent, -) -from dify_graph.nodes.base import LLMUsageTrackingMixin -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData -from dify_graph.runtime import VariablePool -from dify_graph.variables import IntegerVariable, NoneSegment -from dify_graph.variables.segments import ArrayAnySegment, ArraySegment -from dify_graph.variables.variables import Variable -from libs.datetime_utils import naive_utc_now - -from .exc import ( - InvalidIteratorValueError, - IterationGraphNotFoundError, - IterationIndexNotFoundError, - IterationNodeError, - IteratorVariableNotFoundError, - StartNodeIdNotFoundError, -) - -if TYPE_CHECKING: - from dify_graph.context import IExecutionContext - from dify_graph.graph_engine import GraphEngine - -logger = logging.getLogger(__name__) - -EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment) - - -class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): - """ - Iteration Node. - """ - - node_type = BuiltinNodeTypes.ITERATION - execution_type = NodeExecutionType.CONTAINER - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - return { - "type": "iteration", - "config": { - "is_parallel": False, - "parallel_nums": 10, - "error_handle_mode": ErrorHandleMode.TERMINATED, - "flatten_output": True, - }, - } - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: # type: ignore - variable = self._get_iterator_variable() - - if self._is_empty_iteration(variable): - yield from self._handle_empty_iteration(variable) - return - - iterator_list_value = self._validate_and_get_iterator_list(variable) - inputs = {"iterator_selector": iterator_list_value} - - self._validate_start_node() - - started_at = naive_utc_now() - iter_run_map: dict[str, float] = {} - outputs: list[object] = [] - usage_accumulator = [LLMUsage.empty_usage()] - - yield IterationStartedEvent( - start_at=started_at, - inputs=inputs, - metadata={"iteration_length": len(iterator_list_value)}, - ) - - try: - yield from self._execute_iterations( - iterator_list_value=iterator_list_value, - outputs=outputs, - iter_run_map=iter_run_map, - usage_accumulator=usage_accumulator, - ) - - self._accumulate_usage(usage_accumulator[0]) - yield from self._handle_iteration_success( - started_at=started_at, - inputs=inputs, - outputs=outputs, - iterator_list_value=iterator_list_value, - iter_run_map=iter_run_map, - usage=usage_accumulator[0], - ) - except IterationNodeError as e: - self._accumulate_usage(usage_accumulator[0]) - yield from self._handle_iteration_failure( - started_at=started_at, - inputs=inputs, - outputs=outputs, - iterator_list_value=iterator_list_value, - iter_run_map=iter_run_map, - usage=usage_accumulator[0], - error=e, - ) - - def _get_iterator_variable(self) -> ArraySegment | NoneSegment: - variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector) - - if not variable: - raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found") - - if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment): - raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.") - - return variable - - def _is_empty_iteration(self, variable: ArraySegment | NoneSegment) -> TypeIs[NoneSegment | EmptyArraySegment]: - return isinstance(variable, NoneSegment) or len(variable.value) == 0 - - def _handle_empty_iteration(self, variable: ArraySegment | NoneSegment) -> Generator[NodeEventBase, None, None]: - # Try our best to preserve the type information. - if isinstance(variable, ArraySegment): - output = variable.model_copy(update={"value": []}) - else: - output = ArrayAnySegment(value=[]) - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - # TODO(QuantumGhost): is it possible to compute the type of `output` - # from graph definition? - outputs={"output": output}, - ) - ) - - def _validate_and_get_iterator_list(self, variable: ArraySegment) -> Sequence[object]: - iterator_list_value = variable.to_object() - - if not isinstance(iterator_list_value, list): - raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.") - - return cast(list[object], iterator_list_value) - - def _validate_start_node(self) -> None: - if not self.node_data.start_node_id: - raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found") - - def _execute_iterations( - self, - iterator_list_value: Sequence[object], - outputs: list[object], - iter_run_map: dict[str, float], - usage_accumulator: list[LLMUsage], - ) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: - if self.node_data.is_parallel: - # Parallel mode execution - yield from self._execute_parallel_iterations( - iterator_list_value=iterator_list_value, - outputs=outputs, - iter_run_map=iter_run_map, - usage_accumulator=usage_accumulator, - ) - else: - # Sequential mode execution - for index, item in enumerate(iterator_list_value): - iter_start_at = datetime.now(UTC).replace(tzinfo=None) - yield IterationNextEvent(index=index) - - graph_engine = self._create_graph_engine(index, item) - - # Run the iteration - yield from self._run_single_iter( - variable_pool=graph_engine.graph_runtime_state.variable_pool, - outputs=outputs, - graph_engine=graph_engine, - ) - - # Sync conversation variables after each iteration completes - self._sync_conversation_variables_from_snapshot( - self._extract_conversation_variable_snapshot( - variable_pool=graph_engine.graph_runtime_state.variable_pool - ) - ) - - # Accumulate usage from this iteration - usage_accumulator[0] = self._merge_usage( - usage_accumulator[0], graph_engine.graph_runtime_state.llm_usage - ) - iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() - - def _execute_parallel_iterations( - self, - iterator_list_value: Sequence[object], - outputs: list[object], - iter_run_map: dict[str, float], - usage_accumulator: list[LLMUsage], - ) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: - # Initialize outputs list with None values to maintain order - outputs.extend([None] * len(iterator_list_value)) - - # Determine the number of parallel workers - max_workers = min(self.node_data.parallel_nums, len(iterator_list_value)) - - with ThreadPoolExecutor(max_workers=max_workers) as executor: - # Submit all iteration tasks - future_to_index: dict[ - Future[ - tuple[ - float, - list[GraphNodeEventBase], - object | None, - dict[str, Variable], - LLMUsage, - ] - ], - int, - ] = {} - for index, item in enumerate(iterator_list_value): - yield IterationNextEvent(index=index) - future = executor.submit( - self._execute_single_iteration_parallel, - index=index, - item=item, - execution_context=self._capture_execution_context(), - ) - future_to_index[future] = index - - # Process completed iterations as they finish - for future in as_completed(future_to_index): - index = future_to_index[future] - try: - result = future.result() - ( - iteration_duration, - events, - output_value, - conversation_snapshot, - iteration_usage, - ) = result - - # Update outputs at the correct index - outputs[index] = output_value - - # Yield all events from this iteration - yield from events - - # The worker computes duration before we replay buffered events here, - # so slow downstream consumers don't inflate per-iteration timing. - iter_run_map[str(index)] = iteration_duration - - usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage) - - # Sync conversation variables after iteration completion - self._sync_conversation_variables_from_snapshot(conversation_snapshot) - - except Exception as e: - # Handle errors based on error_handle_mode - match self.node_data.error_handle_mode: - case ErrorHandleMode.TERMINATED: - # Cancel remaining futures and re-raise - for f in future_to_index: - if f != future: - f.cancel() - raise IterationNodeError(str(e)) - case ErrorHandleMode.CONTINUE_ON_ERROR: - outputs[index] = None - case ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: - outputs[index] = None # Will be filtered later - - # Remove None values if in REMOVE_ABNORMAL_OUTPUT mode - if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: - outputs[:] = [output for output in outputs if output is not None] - - def _execute_single_iteration_parallel( - self, - index: int, - item: object, - execution_context: "IExecutionContext", - ) -> tuple[float, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]: - """Execute a single iteration in parallel mode and return results.""" - with execution_context: - iter_start_at = datetime.now(UTC).replace(tzinfo=None) - events: list[GraphNodeEventBase] = [] - outputs_temp: list[object] = [] - - graph_engine = self._create_graph_engine(index, item) - - # Collect events instead of yielding them directly - for event in self._run_single_iter( - variable_pool=graph_engine.graph_runtime_state.variable_pool, - outputs=outputs_temp, - graph_engine=graph_engine, - ): - events.append(event) - - # Get the output value from the temporary outputs list - output_value = outputs_temp[0] if outputs_temp else None - conversation_snapshot = self._extract_conversation_variable_snapshot( - variable_pool=graph_engine.graph_runtime_state.variable_pool - ) - iteration_duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() - - return ( - iteration_duration, - events, - output_value, - conversation_snapshot, - graph_engine.graph_runtime_state.llm_usage, - ) - - def _capture_execution_context(self) -> "IExecutionContext": - """Capture current execution context for parallel iterations.""" - from dify_graph.context import capture_current_context - - return capture_current_context() - - def _handle_iteration_success( - self, - started_at: datetime, - inputs: dict[str, Sequence[object]], - outputs: list[object], - iterator_list_value: Sequence[object], - iter_run_map: dict[str, float], - *, - usage: LLMUsage, - ) -> Generator[NodeEventBase, None, None]: - # Flatten the list of lists if all outputs are lists - flattened_outputs = self._flatten_outputs_if_needed(outputs) - - yield IterationSucceededEvent( - start_at=started_at, - inputs=inputs, - outputs={"output": flattened_outputs}, - steps=len(iterator_list_value), - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, - }, - ) - - # Yield final success event - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"output": flattened_outputs}, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - }, - llm_usage=usage, - ) - ) - - def _flatten_outputs_if_needed(self, outputs: list[object]) -> list[object]: - """ - Flatten the outputs list if all elements are lists. - This maintains backward compatibility with version 1.8.1 behavior. - - If flatten_output is False, returns outputs as-is (nested structure). - If flatten_output is True (default), flattens the list if all elements are lists. - """ - # If flatten_output is disabled, return outputs as-is - if not self.node_data.flatten_output: - return outputs - - if not outputs: - return outputs - - # Check if all non-None outputs are lists - non_none_outputs: list[object] = [output for output in outputs if output is not None] - if not non_none_outputs: - return outputs - - if all(isinstance(output, list) for output in non_none_outputs): - # Flatten the list of lists - flattened: list[Any] = [] - for output in outputs: - if isinstance(output, list): - flattened.extend(output) - elif output is not None: - # This shouldn't happen based on our check, but handle it gracefully - flattened.append(output) - return flattened - - return outputs - - def _handle_iteration_failure( - self, - started_at: datetime, - inputs: dict[str, Sequence[object]], - outputs: list[object], - iterator_list_value: Sequence[object], - iter_run_map: dict[str, float], - *, - usage: LLMUsage, - error: IterationNodeError, - ) -> Generator[NodeEventBase, None, None]: - # Flatten the list of lists if all outputs are lists (even in failure case) - flattened_outputs = self._flatten_outputs_if_needed(outputs) - - yield IterationFailedEvent( - start_at=started_at, - inputs=inputs, - outputs={"output": flattened_outputs}, - steps=len(iterator_list_value), - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, - }, - error=str(error), - ) - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(error), - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - }, - llm_usage=usage, - ) - ) - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: IterationNodeData, - ) -> Mapping[str, Sequence[str]]: - variable_mapping: dict[str, Sequence[str]] = { - f"{node_id}.input_selector": node_data.iterator_selector, - } - iteration_node_ids = set() - - # Find all nodes that belong to this loop - nodes = graph_config.get("nodes", []) - for node in nodes: - node_config_data = node.get("data", {}) - if node_config_data.get("iteration_id") == node_id: - in_iteration_node_id = node.get("id") - if in_iteration_node_id: - iteration_node_ids.add(in_iteration_node_id) - - # Get node configs from graph_config instead of non-existent node_id_config_mapping - node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node} - for sub_node_id, sub_node_config in node_configs.items(): - if sub_node_config.get("data", {}).get("iteration_id") != node_id: - continue - - # variable selector to variable mapping - try: - typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config) - node_type = typed_sub_node_config["data"].type - node_mapping = Node.get_node_type_classes_mapping() - if node_type not in node_mapping: - continue - node_version = str(typed_sub_node_config["data"].version) - node_cls = node_mapping[node_type][node_version] - - sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( - graph_config=graph_config, config=typed_sub_node_config - ) - sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping) - except NotImplementedError: - sub_node_variable_mapping = {} - - # remove iteration variables - sub_node_variable_mapping = { - sub_node_id + "." + key: value - for key, value in sub_node_variable_mapping.items() - if value[0] != node_id - } - - variable_mapping.update(sub_node_variable_mapping) - - # remove variable out from iteration - variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in iteration_node_ids} - - return variable_mapping - - def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, Variable]: - conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {}) - return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()} - - def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, Variable]) -> None: - parent_pool = self.graph_runtime_state.variable_pool - parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {}) - - current_keys = set(parent_conversations.keys()) - snapshot_keys = set(snapshot.keys()) - - for removed_key in current_keys - snapshot_keys: - parent_pool.remove((CONVERSATION_VARIABLE_NODE_ID, removed_key)) - - for name, variable in snapshot.items(): - parent_pool.add((CONVERSATION_VARIABLE_NODE_ID, name), variable) - - def _append_iteration_info_to_event( - self, - event: GraphNodeEventBase, - iter_run_index: int, - ): - event.in_iteration_id = self._node_id - iter_metadata = { - WorkflowNodeExecutionMetadataKey.ITERATION_ID: self._node_id, - WorkflowNodeExecutionMetadataKey.ITERATION_INDEX: iter_run_index, - } - - current_metadata = event.node_run_result.metadata - if WorkflowNodeExecutionMetadataKey.ITERATION_ID not in current_metadata: - event.node_run_result.metadata = {**current_metadata, **iter_metadata} - - def _run_single_iter( - self, - *, - variable_pool: VariablePool, - outputs: list[object], - graph_engine: "GraphEngine", - ) -> Generator[GraphNodeEventBase, None, None]: - rst = graph_engine.run() - # get current iteration index - index_variable = variable_pool.get([self._node_id, "index"]) - if not isinstance(index_variable, IntegerVariable): - raise IterationIndexNotFoundError(f"iteration {self._node_id} current index not found") - current_index = index_variable.value - for event in rst: - if isinstance(event, GraphNodeEventBase) and event.node_type == BuiltinNodeTypes.ITERATION_START: - continue - - if isinstance(event, GraphNodeEventBase): - self._append_iteration_info_to_event(event=event, iter_run_index=current_index) - yield event - elif isinstance(event, (GraphRunSucceededEvent, GraphRunPartialSucceededEvent)): - result = variable_pool.get(self.node_data.output_selector) - if result is None: - outputs.append(None) - else: - outputs.append(result.to_object()) - return - elif isinstance(event, GraphRunFailedEvent): - match self.node_data.error_handle_mode: - case ErrorHandleMode.TERMINATED: - raise IterationNodeError(event.error) - case ErrorHandleMode.CONTINUE_ON_ERROR: - outputs.append(None) - return - case ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: - return - - def _create_graph_engine(self, index: int, item: object): - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState - - # Create GraphInitParams for child graph execution. - graph_init_params = GraphInitParams( - workflow_id=self.workflow_id, - graph_config=self.graph_config, - run_context=self.run_context, - call_depth=self.workflow_call_depth, - ) - # Create a deep copy of the variable pool for each iteration - variable_pool_copy = self.graph_runtime_state.variable_pool.model_copy(deep=True) - - # append iteration variable (item, index) to variable pool - variable_pool_copy.add([self._node_id, "index"], index) - variable_pool_copy.add([self._node_id, "item"], item) - - # Create a new GraphRuntimeState for this iteration - graph_runtime_state_copy = GraphRuntimeState( - variable_pool=variable_pool_copy, - start_at=self.graph_runtime_state.start_at, - total_tokens=0, - node_run_steps=0, - ) - root_node_id = self.node_data.start_node_id - if root_node_id is None: - raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found") - - try: - return self.graph_runtime_state.create_child_engine( - workflow_id=self.workflow_id, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state_copy, - graph_config=self.graph_config, - root_node_id=root_node_id, - ) - except ChildGraphNotFoundError as exc: - raise IterationGraphNotFoundError("iteration graph not found") from exc diff --git a/api/dify_graph/nodes/iteration/iteration_start_node.py b/api/dify_graph/nodes/iteration/iteration_start_node.py deleted file mode 100644 index a8ecf3d83bc..00000000000 --- a/api/dify_graph/nodes/iteration/iteration_start_node.py +++ /dev/null @@ -1,22 +0,0 @@ -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.iteration.entities import IterationStartNodeData - - -class IterationStartNode(Node[IterationStartNodeData]): - """ - Iteration Start Node. - """ - - node_type = BuiltinNodeTypes.ITERATION_START - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - """ - Run the node. - """ - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED) diff --git a/api/dify_graph/nodes/list_operator/__init__.py b/api/dify_graph/nodes/list_operator/__init__.py deleted file mode 100644 index 1877586ef41..00000000000 --- a/api/dify_graph/nodes/list_operator/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .node import ListOperatorNode - -__all__ = ["ListOperatorNode"] diff --git a/api/dify_graph/nodes/list_operator/entities.py b/api/dify_graph/nodes/list_operator/entities.py deleted file mode 100644 index 41b3a40b786..00000000000 --- a/api/dify_graph/nodes/list_operator/entities.py +++ /dev/null @@ -1,71 +0,0 @@ -from collections.abc import Sequence -from enum import StrEnum - -from pydantic import BaseModel, Field - -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType - - -class FilterOperator(StrEnum): - # string conditions - CONTAINS = "contains" - START_WITH = "start with" - END_WITH = "end with" - IS = "is" - IN = "in" - EMPTY = "empty" - NOT_CONTAINS = "not contains" - IS_NOT = "is not" - NOT_IN = "not in" - NOT_EMPTY = "not empty" - # number conditions - EQUAL = "=" - NOT_EQUAL = "≠" - LESS_THAN = "<" - GREATER_THAN = ">" - GREATER_THAN_OR_EQUAL = "≥" - LESS_THAN_OR_EQUAL = "≤" - - -class Order(StrEnum): - ASC = "asc" - DESC = "desc" - - -class FilterCondition(BaseModel): - key: str = "" - comparison_operator: FilterOperator = FilterOperator.CONTAINS - # the value is bool if the filter operator is comparing with - # a boolean constant. - value: str | Sequence[str] | bool = "" - - -class FilterBy(BaseModel): - enabled: bool = False - conditions: Sequence[FilterCondition] = Field(default_factory=list) - - -class OrderByConfig(BaseModel): - enabled: bool = False - key: str = "" - value: Order = Order.ASC - - -class Limit(BaseModel): - enabled: bool = False - size: int = -1 - - -class ExtractConfig(BaseModel): - enabled: bool = False - serial: str = "1" - - -class ListOperatorNodeData(BaseNodeData): - type: NodeType = BuiltinNodeTypes.LIST_OPERATOR - variable: Sequence[str] = Field(default_factory=list) - filter_by: FilterBy - order_by: OrderByConfig - limit: Limit - extract_by: ExtractConfig = Field(default_factory=ExtractConfig) diff --git a/api/dify_graph/nodes/list_operator/exc.py b/api/dify_graph/nodes/list_operator/exc.py deleted file mode 100644 index f88aa0be29c..00000000000 --- a/api/dify_graph/nodes/list_operator/exc.py +++ /dev/null @@ -1,16 +0,0 @@ -class ListOperatorError(ValueError): - """Base class for all ListOperator errors.""" - - pass - - -class InvalidFilterValueError(ListOperatorError): - pass - - -class InvalidKeyError(ListOperatorError): - pass - - -class InvalidConditionError(ListOperatorError): - pass diff --git a/api/dify_graph/nodes/list_operator/node.py b/api/dify_graph/nodes/list_operator/node.py deleted file mode 100644 index dc8b8904f7f..00000000000 --- a/api/dify_graph/nodes/list_operator/node.py +++ /dev/null @@ -1,345 +0,0 @@ -from collections.abc import Callable, Sequence -from typing import Any, TypeAlias, TypeVar - -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.file import File -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment -from dify_graph.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment - -from .entities import FilterOperator, ListOperatorNodeData, Order -from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError - -_SUPPORTED_TYPES_TUPLE = ( - ArrayFileSegment, - ArrayNumberSegment, - ArrayStringSegment, - ArrayBooleanSegment, -) -_SUPPORTED_TYPES_ALIAS: TypeAlias = ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment | ArrayBooleanSegment - - -_T = TypeVar("_T") - - -def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]: - """Returns the negation of a given filter function. If the original filter - returns `True` for a value, the negated filter will return `False`, and vice versa. - """ - - def wrapper(value: _T) -> bool: - return not filter_(value) - - return wrapper - - -class ListOperatorNode(Node[ListOperatorNodeData]): - node_type = BuiltinNodeTypes.LIST_OPERATOR - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self): - inputs: dict[str, Sequence[object]] = {} - process_data: dict[str, Sequence[object]] = {} - outputs: dict[str, Any] = {} - - variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable) - if variable is None: - error_message = f"Variable not found for selector: {self.node_data.variable}" - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs - ) - if not variable.value: - inputs = {"variable": []} - process_data = {"variable": []} - if isinstance(variable, ArraySegment): - result = variable.model_copy(update={"value": []}) - else: - result = ArrayAnySegment(value=[]) - outputs = {"result": result, "first_record": None, "last_record": None} - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs=outputs, - ) - if not isinstance(variable, _SUPPORTED_TYPES_TUPLE): - error_message = f"Variable {self.node_data.variable} is not an array type, actual type: {type(variable)}" - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs - ) - - if isinstance(variable, ArrayFileSegment): - inputs = {"variable": [item.to_dict() for item in variable.value]} - process_data["variable"] = [item.to_dict() for item in variable.value] - else: - inputs = {"variable": variable.value} - process_data["variable"] = variable.value - - try: - # Filter - if self.node_data.filter_by.enabled: - variable = self._apply_filter(variable) - - # Extract - if self.node_data.extract_by.enabled: - variable = self._extract_slice(variable) - - # Order - if self.node_data.order_by.enabled: - variable = self._apply_order(variable) - - # Slice - if self.node_data.limit.enabled: - variable = self._apply_slice(variable) - - outputs = { - "result": variable, - "first_record": variable.value[0] if variable.value else None, - "last_record": variable.value[-1] if variable.value else None, - } - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs=outputs, - ) - except ListOperatorError as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - inputs=inputs, - process_data=process_data, - outputs=outputs, - ) - - def _apply_filter(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS: - filter_func: Callable[[Any], bool] - result: list[Any] = [] - for condition in self.node_data.filter_by.conditions: - if isinstance(variable, ArrayStringSegment): - if not isinstance(condition.value, str): - raise InvalidFilterValueError(f"Invalid filter value: {condition.value}") - value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text - filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value) - result = list(filter(filter_func, variable.value)) - variable = variable.model_copy(update={"value": result}) - elif isinstance(variable, ArrayNumberSegment): - if not isinstance(condition.value, str): - raise InvalidFilterValueError(f"Invalid filter value: {condition.value}") - value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text - filter_func = _get_number_filter_func(condition=condition.comparison_operator, value=float(value)) - result = list(filter(filter_func, variable.value)) - variable = variable.model_copy(update={"value": result}) - elif isinstance(variable, ArrayFileSegment): - if isinstance(condition.value, str): - value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text - elif isinstance(condition.value, bool): - raise ValueError(f"File filter expects a string value, got {type(condition.value)}") - else: - value = condition.value - filter_func = _get_file_filter_func( - key=condition.key, - condition=condition.comparison_operator, - value=value, - ) - result = list(filter(filter_func, variable.value)) - variable = variable.model_copy(update={"value": result}) - else: - if not isinstance(condition.value, bool): - raise ValueError(f"Boolean filter expects a boolean value, got {type(condition.value)}") - filter_func = _get_boolean_filter_func(condition=condition.comparison_operator, value=condition.value) - result = list(filter(filter_func, variable.value)) - variable = variable.model_copy(update={"value": result}) - return variable - - def _apply_order(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS: - if isinstance(variable, (ArrayStringSegment, ArrayNumberSegment, ArrayBooleanSegment)): - result = sorted(variable.value, reverse=self.node_data.order_by.value == Order.DESC) - variable = variable.model_copy(update={"value": result}) - else: - result = _order_file( - order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value - ) - variable = variable.model_copy(update={"value": result}) - - return variable - - def _apply_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS: - result = variable.value[: self.node_data.limit.size] - return variable.model_copy(update={"value": result}) - - def _extract_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS: - value = int(self.graph_runtime_state.variable_pool.convert_template(self.node_data.extract_by.serial).text) - if value < 1: - raise ValueError(f"Invalid serial index: must be >= 1, got {value}") - if value > len(variable.value): - raise InvalidKeyError(f"Invalid serial index: must be <= {len(variable.value)}, got {value}") - value -= 1 - result = variable.value[value] - return variable.model_copy(update={"value": [result]}) - - -def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]: - match key: - case "size": - return lambda x: x.size - case _: - raise InvalidKeyError(f"Invalid key: {key}") - - -def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]: - match key: - case "name": - return lambda x: x.filename or "" - case "type": - return lambda x: str(x.type) - case "extension": - return lambda x: x.extension or "" - case "mime_type": - return lambda x: x.mime_type or "" - case "transfer_method": - return lambda x: str(x.transfer_method) - case "url": - return lambda x: x.remote_url or "" - case "related_id": - return lambda x: x.related_id or "" - case _: - raise InvalidKeyError(f"Invalid key: {key}") - - -def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bool]: - match condition: - case "contains": - return _contains(value) - case "start with": - return _startswith(value) - case "end with": - return _endswith(value) - case "is": - return _is(value) - case "in": - return _in(value) - case "empty": - return lambda x: x == "" - case "not contains": - return _negation(_contains(value)) - case "is not": - return _negation(_is(value)) - case "not in": - return _negation(_in(value)) - case "not empty": - return lambda x: x != "" - case _: - raise InvalidConditionError(f"Invalid condition: {condition}") - - -def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callable[[str], bool]: - match condition: - case "in": - return _in(value) - case "not in": - return _negation(_in(value)) - case _: - raise InvalidConditionError(f"Invalid condition: {condition}") - - -def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[int | float], bool]: - match condition: - case "=": - return _eq(value) - case "≠": - return _ne(value) - case "<": - return _lt(value) - case "≤": - return _le(value) - case ">": - return _gt(value) - case "≥": - return _ge(value) - case _: - raise InvalidConditionError(f"Invalid condition: {condition}") - - -def _get_boolean_filter_func(*, condition: FilterOperator, value: bool) -> Callable[[bool], bool]: - match condition: - case FilterOperator.IS: - return _is(value) - case FilterOperator.IS_NOT: - return _negation(_is(value)) - case _: - raise InvalidConditionError(f"Invalid condition: {condition}") - - -def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]: - if key in {"name", "extension", "mime_type", "url", "related_id"} and isinstance(value, str): - extract_func = _get_file_extract_string_func(key=key) - return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x)) - if key in {"type", "transfer_method"}: - extract_func = _get_file_extract_string_func(key=key) - return lambda x: _get_sequence_filter_func(condition=condition, value=value)(extract_func(x)) - elif key == "size" and isinstance(value, str): - extract_number = _get_file_extract_number_func(key=key) - return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_number(x)) - else: - raise InvalidKeyError(f"Invalid key: {key}") - - -def _contains(value: str) -> Callable[[str], bool]: - return lambda x: value in x - - -def _startswith(value: str) -> Callable[[str], bool]: - return lambda x: x.startswith(value) - - -def _endswith(value: str) -> Callable[[str], bool]: - return lambda x: x.endswith(value) - - -def _is(value: _T) -> Callable[[_T], bool]: - return lambda x: x == value - - -def _in(value: str | Sequence[str]) -> Callable[[str], bool]: - return lambda x: x in value - - -def _eq(value: int | float) -> Callable[[int | float], bool]: - return lambda x: x == value - - -def _ne(value: int | float) -> Callable[[int | float], bool]: - return lambda x: x != value - - -def _lt(value: int | float) -> Callable[[int | float], bool]: - return lambda x: x < value - - -def _le(value: int | float) -> Callable[[int | float], bool]: - return lambda x: x <= value - - -def _gt(value: int | float) -> Callable[[int | float], bool]: - return lambda x: x > value - - -def _ge(value: int | float) -> Callable[[int | float], bool]: - return lambda x: x >= value - - -def _order_file(*, order: Order, order_by: str = "", array: Sequence[File]): - extract_func: Callable[[File], Any] - if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url", "related_id"}: - extract_func = _get_file_extract_string_func(key=order_by) - return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC) - elif order_by == "size": - extract_func = _get_file_extract_number_func(key=order_by) - return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC) - else: - raise InvalidKeyError(f"Invalid order key: {order_by}") diff --git a/api/dify_graph/nodes/llm/__init__.py b/api/dify_graph/nodes/llm/__init__.py deleted file mode 100644 index f7bc713f631..00000000000 --- a/api/dify_graph/nodes/llm/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -from .entities import ( - LLMNodeChatModelMessage, - LLMNodeCompletionModelPromptTemplate, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from .node import LLMNode - -__all__ = [ - "LLMNode", - "LLMNodeChatModelMessage", - "LLMNodeCompletionModelPromptTemplate", - "LLMNodeData", - "ModelConfig", - "VisionConfig", -] diff --git a/api/dify_graph/nodes/llm/entities.py b/api/dify_graph/nodes/llm/entities.py deleted file mode 100644 index 6ca01a21da3..00000000000 --- a/api/dify_graph/nodes/llm/entities.py +++ /dev/null @@ -1,100 +0,0 @@ -from collections.abc import Mapping, Sequence -from typing import Any, Literal - -from pydantic import BaseModel, Field, field_validator - -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode -from dify_graph.nodes.base.entities import VariableSelector - - -class ModelConfig(BaseModel): - provider: str - name: str - mode: LLMMode - completion_params: dict[str, Any] = Field(default_factory=dict) - - -class ContextConfig(BaseModel): - enabled: bool - variable_selector: list[str] | None = None - - -class VisionConfigOptions(BaseModel): - variable_selector: Sequence[str] = Field(default_factory=lambda: ["sys", "files"]) - detail: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.HIGH - - -class VisionConfig(BaseModel): - enabled: bool = False - configs: VisionConfigOptions = Field(default_factory=VisionConfigOptions) - - @field_validator("configs", mode="before") - @classmethod - def convert_none_configs(cls, v: Any): - if v is None: - return VisionConfigOptions() - return v - - -class PromptConfig(BaseModel): - jinja2_variables: Sequence[VariableSelector] = Field(default_factory=list) - - @field_validator("jinja2_variables", mode="before") - @classmethod - def convert_none_jinja2_variables(cls, v: Any): - if v is None: - return [] - return v - - -class LLMNodeChatModelMessage(ChatModelMessage): - text: str = "" - jinja2_text: str | None = None - - -class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate): - jinja2_text: str | None = None - - -class LLMNodeData(BaseNodeData): - type: NodeType = BuiltinNodeTypes.LLM - model: ModelConfig - prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate - prompt_config: PromptConfig = Field(default_factory=PromptConfig) - memory: MemoryConfig | None = None - context: ContextConfig - vision: VisionConfig = Field(default_factory=VisionConfig) - structured_output: Mapping[str, Any] | None = None - # We used 'structured_output_enabled' in the past, but it's not a good name. - structured_output_switch_on: bool = Field(False, alias="structured_output_enabled") - reasoning_format: Literal["separated", "tagged"] = Field( - # Keep tagged as default for backward compatibility - default="tagged", - description=( - """ - Strategy for handling model reasoning output. - - separated: Return clean text (without tags) + reasoning_content field. - Recommended for new workflows. Enables safe downstream parsing and - workflow variable access: {{#node_id.reasoning_content#}} - - tagged : Return original text (with tags) + reasoning_content field. - Maintains full backward compatibility while still providing reasoning_content - for workflow automation. Frontend thinking panels work as before. - """ - ), - ) - - @field_validator("prompt_config", mode="before") - @classmethod - def convert_none_prompt_config(cls, v: Any): - if v is None: - return PromptConfig() - return v - - @property - def structured_output_enabled(self) -> bool: - return self.structured_output_switch_on and self.structured_output is not None diff --git a/api/dify_graph/nodes/llm/exc.py b/api/dify_graph/nodes/llm/exc.py deleted file mode 100644 index 4d160952963..00000000000 --- a/api/dify_graph/nodes/llm/exc.py +++ /dev/null @@ -1,45 +0,0 @@ -class LLMNodeError(ValueError): - """Base class for LLM Node errors.""" - - -class VariableNotFoundError(LLMNodeError): - """Raised when a required variable is not found.""" - - -class InvalidContextStructureError(LLMNodeError): - """Raised when the context structure is invalid.""" - - -class InvalidVariableTypeError(LLMNodeError): - """Raised when the variable type is invalid.""" - - -class ModelNotExistError(LLMNodeError): - """Raised when the specified model does not exist.""" - - -class LLMModeRequiredError(LLMNodeError): - """Raised when LLM mode is required but not provided.""" - - -class NoPromptFoundError(LLMNodeError): - """Raised when no prompt is found in the LLM configuration.""" - - -class TemplateTypeNotSupportError(LLMNodeError): - def __init__(self, *, type_name: str): - super().__init__(f"Prompt type {type_name} is not supported.") - - -class MemoryRolePrefixRequiredError(LLMNodeError): - """Raised when memory role prefix is required for completion model.""" - - -class FileTypeNotSupportError(LLMNodeError): - def __init__(self, *, type_name: str): - super().__init__(f"{type_name} type is not supported by this model") - - -class UnsupportedPromptContentTypeError(LLMNodeError): - def __init__(self, *, type_name: str): - super().__init__(f"Prompt content type {type_name} is not supported.") diff --git a/api/dify_graph/nodes/llm/file_saver.py b/api/dify_graph/nodes/llm/file_saver.py deleted file mode 100644 index 50e52a3b6f2..00000000000 --- a/api/dify_graph/nodes/llm/file_saver.py +++ /dev/null @@ -1,144 +0,0 @@ -import mimetypes -import typing as tp - -from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE -from core.tools.signature import sign_tool_file -from core.tools.tool_file_manager import ToolFileManager -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.nodes.protocols import HttpClientProtocol - - -class LLMFileSaver(tp.Protocol): - """LLMFileSaver is responsible for save multimodal output returned by - LLM. - """ - - def save_binary_string( - self, - data: bytes, - mime_type: str, - file_type: FileType, - extension_override: str | None = None, - ) -> File: - """save_binary_string saves the inline file data returned by LLM. - - Currently (2025-04-30), only some of Google Gemini models will return - multimodal output as inline data. - - :param data: the contents of the file - :param mime_type: the media type of the file, specified by rfc6838 - (https://datatracker.ietf.org/doc/html/rfc6838) - :param file_type: The file type of the inline file. - :param extension_override: Override the auto-detected file extension while saving this file. - - The default value is `None`, which means do not override the file extension and guessing it - from the `mime_type` attribute while saving the file. - - Setting it to values other than `None` means override the file's extension, and - will bypass the extension guessing saving the file. - - Specially, setting it to empty string (`""`) will leave the file extension empty. - - When it is not `None` or empty string (`""`), it should be a string beginning with a - dot (`.`). For example, `.py` and `.tar.gz` are both valid values, while `py` - and `tar.gz` are not. - """ - raise NotImplementedError() - - def save_remote_url(self, url: str, file_type: FileType) -> File: - """save_remote_url saves the file from a remote url returned by LLM. - - Currently (2025-04-30), no model returns multimodel output as a url. - - :param url: the url of the file. - :param file_type: the file type of the file, check `FileType` enum for reference. - """ - raise NotImplementedError() - - -class FileSaverImpl(LLMFileSaver): - _tenant_id: str - _user_id: str - - def __init__(self, user_id: str, tenant_id: str, http_client: HttpClientProtocol): - self._user_id = user_id - self._tenant_id = tenant_id - self._http_client = http_client - - def _get_tool_file_manager(self): - return ToolFileManager() - - def save_remote_url(self, url: str, file_type: FileType) -> File: - http_response = self._http_client.get(url) - http_response.raise_for_status() - data = http_response.content - mime_type_from_header = http_response.headers.get("Content-Type") - mime_type, extension = _extract_content_type_and_extension(url, mime_type_from_header) - return self.save_binary_string(data, mime_type, file_type, extension_override=extension) - - def save_binary_string( - self, - data: bytes, - mime_type: str, - file_type: FileType, - extension_override: str | None = None, - ) -> File: - tool_file_manager = self._get_tool_file_manager() - tool_file = tool_file_manager.create_file_by_raw( - user_id=self._user_id, - tenant_id=self._tenant_id, - # TODO(QuantumGhost): what is conversation id? - conversation_id=None, - file_binary=data, - mimetype=mime_type, - ) - extension_override = _validate_extension_override(extension_override) - extension = _get_extension(mime_type, extension_override) - url = sign_tool_file(tool_file.id, extension) - - return File( - tenant_id=self._tenant_id, - type=file_type, - transfer_method=FileTransferMethod.TOOL_FILE, - filename=tool_file.name, - extension=extension, - mime_type=mime_type, - size=len(data), - related_id=tool_file.id, - url=url, - storage_key=tool_file.file_key, - ) - - -def _get_extension(mime_type: str, extension_override: str | None = None) -> str: - """get_extension return the extension of file. - - If the `extension_override` parameter is set, this function should honor it and - return its value. - """ - if extension_override is not None: - return extension_override - return mimetypes.guess_extension(mime_type) or DEFAULT_EXTENSION - - -def _extract_content_type_and_extension(url: str, content_type_header: str | None) -> tuple[str, str]: - """_extract_content_type_and_extension tries to - guess content type of file from url and `Content-Type` header in response. - """ - if content_type_header: - extension = mimetypes.guess_extension(content_type_header) or DEFAULT_EXTENSION - return content_type_header, extension - content_type = mimetypes.guess_type(url)[0] or DEFAULT_MIME_TYPE - extension = mimetypes.guess_extension(content_type) or DEFAULT_EXTENSION - return content_type, extension - - -def _validate_extension_override(extension_override: str | None) -> str | None: - # `extension_override` is allow to be `None or `""`. - if extension_override is None: - return None - if extension_override == "": - return "" - if not extension_override.startswith("."): - raise ValueError("extension_override should start with '.' if not None or empty.", extension_override) - return extension_override diff --git a/api/dify_graph/nodes/llm/llm_utils.py b/api/dify_graph/nodes/llm/llm_utils.py deleted file mode 100644 index 2be391a4240..00000000000 --- a/api/dify_graph/nodes/llm/llm_utils.py +++ /dev/null @@ -1,477 +0,0 @@ -from __future__ import annotations - -from collections.abc import Sequence -from typing import Any, cast - -from core.model_manager import ModelInstance -from dify_graph.file import FileType, file_manager -from dify_graph.file.models import File -from dify_graph.model_runtime.entities import ( - ImagePromptMessageContent, - PromptMessage, - PromptMessageContentType, - PromptMessageRole, - TextPromptMessageContent, -) -from dify_graph.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessageContentUnionTypes, - SystemPromptMessage, - UserPromptMessage, -) -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelPropertyKey -from dify_graph.model_runtime.memory import PromptMessageMemory -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from dify_graph.nodes.base.entities import VariableSelector -from dify_graph.runtime import VariablePool -from dify_graph.variables import ArrayFileSegment, FileSegment -from dify_graph.variables.segments import ArrayAnySegment, NoneSegment - -from .entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, MemoryConfig -from .exc import ( - InvalidVariableTypeError, - MemoryRolePrefixRequiredError, - NoPromptFoundError, - TemplateTypeNotSupportError, -) -from .protocols import TemplateRenderer - - -def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity: - model_schema = cast(LargeLanguageModel, model_instance.model_type_instance).get_model_schema( - model_instance.model_name, - dict(model_instance.credentials), - ) - if not model_schema: - raise ValueError(f"Model schema not found for {model_instance.model_name}") - return model_schema - - -def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence[File]: - variable = variable_pool.get(selector) - if variable is None: - return [] - elif isinstance(variable, FileSegment): - return [variable.value] - elif isinstance(variable, ArrayFileSegment): - return variable.value - elif isinstance(variable, NoneSegment | ArrayAnySegment): - return [] - raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}") - - -def convert_history_messages_to_text( - *, - history_messages: Sequence[PromptMessage], - human_prefix: str, - ai_prefix: str, -) -> str: - string_messages: list[str] = [] - for message in history_messages: - if message.role == PromptMessageRole.USER: - role = human_prefix - elif message.role == PromptMessageRole.ASSISTANT: - role = ai_prefix - else: - continue - - if isinstance(message.content, list): - content_parts = [] - for content in message.content: - if isinstance(content, TextPromptMessageContent): - content_parts.append(content.data) - elif isinstance(content, ImagePromptMessageContent): - content_parts.append("[image]") - - inner_msg = "\n".join(content_parts) - string_messages.append(f"{role}: {inner_msg}") - else: - string_messages.append(f"{role}: {message.content}") - - return "\n".join(string_messages) - - -def fetch_memory_text( - *, - memory: PromptMessageMemory, - max_token_limit: int, - message_limit: int | None = None, - human_prefix: str = "Human", - ai_prefix: str = "Assistant", -) -> str: - history_messages = memory.get_history_prompt_messages( - max_token_limit=max_token_limit, - message_limit=message_limit, - ) - return convert_history_messages_to_text( - history_messages=history_messages, - human_prefix=human_prefix, - ai_prefix=ai_prefix, - ) - - -def fetch_prompt_messages( - *, - sys_query: str | None = None, - sys_files: Sequence[File], - context: str | None = None, - memory: PromptMessageMemory | None = None, - model_instance: ModelInstance, - prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, - stop: Sequence[str] | None = None, - memory_config: MemoryConfig | None = None, - vision_enabled: bool = False, - vision_detail: ImagePromptMessageContent.DETAIL, - variable_pool: VariablePool, - jinja2_variables: Sequence[VariableSelector], - context_files: list[File] | None = None, - template_renderer: TemplateRenderer | None = None, -) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: - prompt_messages: list[PromptMessage] = [] - model_schema = fetch_model_schema(model_instance=model_instance) - - if isinstance(prompt_template, list): - prompt_messages.extend( - handle_list_messages( - messages=prompt_template, - context=context, - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - vision_detail_config=vision_detail, - template_renderer=template_renderer, - ) - ) - - prompt_messages.extend( - handle_memory_chat_mode( - memory=memory, - memory_config=memory_config, - model_instance=model_instance, - ) - ) - - if sys_query: - prompt_messages.extend( - handle_list_messages( - messages=[ - LLMNodeChatModelMessage( - text=sys_query, - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context="", - jinja2_variables=[], - variable_pool=variable_pool, - vision_detail_config=vision_detail, - template_renderer=template_renderer, - ) - ) - elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): - prompt_messages.extend( - handle_completion_template( - template=prompt_template, - context=context, - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - template_renderer=template_renderer, - ) - ) - - memory_text = handle_memory_completion_mode( - memory=memory, - memory_config=memory_config, - model_instance=model_instance, - ) - prompt_content = prompt_messages[0].content - if isinstance(prompt_content, str): - prompt_content = str(prompt_content) - if "#histories#" in prompt_content: - prompt_content = prompt_content.replace("#histories#", memory_text) - else: - prompt_content = memory_text + "\n" + prompt_content - prompt_messages[0].content = prompt_content - elif isinstance(prompt_content, list): - for content_item in prompt_content: - if isinstance(content_item, TextPromptMessageContent): - if "#histories#" in content_item.data: - content_item.data = content_item.data.replace("#histories#", memory_text) - else: - content_item.data = memory_text + "\n" + content_item.data - else: - raise ValueError("Invalid prompt content type") - - if sys_query: - if isinstance(prompt_content, str): - prompt_messages[0].content = str(prompt_messages[0].content).replace("#sys.query#", sys_query) - elif isinstance(prompt_content, list): - for content_item in prompt_content: - if isinstance(content_item, TextPromptMessageContent): - content_item.data = sys_query + "\n" + content_item.data - else: - raise ValueError("Invalid prompt content type") - else: - raise TemplateTypeNotSupportError(type_name=str(type(prompt_template))) - - _append_file_prompts( - prompt_messages=prompt_messages, - files=sys_files, - vision_enabled=vision_enabled, - vision_detail=vision_detail, - ) - _append_file_prompts( - prompt_messages=prompt_messages, - files=context_files or [], - vision_enabled=vision_enabled, - vision_detail=vision_detail, - ) - - filtered_prompt_messages: list[PromptMessage] = [] - for prompt_message in prompt_messages: - if isinstance(prompt_message.content, list): - prompt_message_content: list[PromptMessageContentUnionTypes] = [] - for content_item in prompt_message.content: - if not model_schema.features: - if content_item.type == PromptMessageContentType.TEXT: - prompt_message_content.append(content_item) - continue - - if ( - ( - content_item.type == PromptMessageContentType.IMAGE - and ModelFeature.VISION not in model_schema.features - ) - or ( - content_item.type == PromptMessageContentType.DOCUMENT - and ModelFeature.DOCUMENT not in model_schema.features - ) - or ( - content_item.type == PromptMessageContentType.VIDEO - and ModelFeature.VIDEO not in model_schema.features - ) - or ( - content_item.type == PromptMessageContentType.AUDIO - and ModelFeature.AUDIO not in model_schema.features - ) - ): - continue - prompt_message_content.append(content_item) - if not prompt_message_content: - continue - if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT: - prompt_message.content = prompt_message_content[0].data - else: - prompt_message.content = prompt_message_content - filtered_prompt_messages.append(prompt_message) - elif not prompt_message.is_empty(): - filtered_prompt_messages.append(prompt_message) - - if len(filtered_prompt_messages) == 0: - raise NoPromptFoundError( - "No prompt found in the LLM configuration. Please ensure a prompt is properly configured before proceeding." - ) - - return filtered_prompt_messages, stop - - -def handle_list_messages( - *, - messages: Sequence[LLMNodeChatModelMessage], - context: str | None, - jinja2_variables: Sequence[VariableSelector], - variable_pool: VariablePool, - vision_detail_config: ImagePromptMessageContent.DETAIL, - template_renderer: TemplateRenderer | None = None, -) -> Sequence[PromptMessage]: - prompt_messages: list[PromptMessage] = [] - for message in messages: - if message.edition_type == "jinja2": - result_text = render_jinja2_message( - template=message.jinja2_text or "", - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - template_renderer=template_renderer, - ) - prompt_messages.append( - combine_message_content_with_role( - contents=[TextPromptMessageContent(data=result_text)], - role=message.role, - ) - ) - continue - - template = message.text.replace("{#context#}", context) if context else message.text - segment_group = variable_pool.convert_template(template) - file_contents: list[PromptMessageContentUnionTypes] = [] - for segment in segment_group.value: - if isinstance(segment, ArrayFileSegment): - for file in segment.value: - if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: - file_contents.append( - file_manager.to_prompt_message_content(file, image_detail_config=vision_detail_config) - ) - elif isinstance(segment, FileSegment): - file = segment.value - if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: - file_contents.append( - file_manager.to_prompt_message_content(file, image_detail_config=vision_detail_config) - ) - - if segment_group.text: - prompt_messages.append( - combine_message_content_with_role( - contents=[TextPromptMessageContent(data=segment_group.text)], - role=message.role, - ) - ) - if file_contents: - prompt_messages.append(combine_message_content_with_role(contents=file_contents, role=message.role)) - - return prompt_messages - - -def render_jinja2_message( - *, - template: str, - jinja2_variables: Sequence[VariableSelector], - variable_pool: VariablePool, - template_renderer: TemplateRenderer | None = None, -) -> str: - if not template: - return "" - if template_renderer is None: - raise ValueError("template_renderer is required for jinja2 prompt rendering") - - jinja2_inputs: dict[str, Any] = {} - for jinja2_variable in jinja2_variables: - variable = variable_pool.get(jinja2_variable.value_selector) - jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else "" - return template_renderer.render_jinja2(template=template, inputs=jinja2_inputs) - - -def handle_completion_template( - *, - template: LLMNodeCompletionModelPromptTemplate, - context: str | None, - jinja2_variables: Sequence[VariableSelector], - variable_pool: VariablePool, - template_renderer: TemplateRenderer | None = None, -) -> Sequence[PromptMessage]: - if template.edition_type == "jinja2": - result_text = render_jinja2_message( - template=template.jinja2_text or "", - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - template_renderer=template_renderer, - ) - else: - template_text = template.text.replace("{#context#}", context) if context else template.text - result_text = variable_pool.convert_template(template_text).text - return [ - combine_message_content_with_role( - contents=[TextPromptMessageContent(data=result_text)], - role=PromptMessageRole.USER, - ) - ] - - -def combine_message_content_with_role( - *, - contents: str | list[PromptMessageContentUnionTypes] | None = None, - role: PromptMessageRole, -) -> PromptMessage: - match role: - case PromptMessageRole.USER: - return UserPromptMessage(content=contents) - case PromptMessageRole.ASSISTANT: - return AssistantPromptMessage(content=contents) - case PromptMessageRole.SYSTEM: - return SystemPromptMessage(content=contents) - case _: - raise NotImplementedError(f"Role {role} is not supported") - - -def calculate_rest_token(*, prompt_messages: list[PromptMessage], model_instance: ModelInstance) -> int: - rest_tokens = 2000 - runtime_model_schema = fetch_model_schema(model_instance=model_instance) - runtime_model_parameters = model_instance.parameters - - model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) - if model_context_tokens: - curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) - - max_tokens = 0 - for parameter_rule in runtime_model_schema.parameter_rules: - if parameter_rule.name == "max_tokens" or ( - parameter_rule.use_template and parameter_rule.use_template == "max_tokens" - ): - max_tokens = ( - runtime_model_parameters.get(parameter_rule.name) - or runtime_model_parameters.get(str(parameter_rule.use_template)) - or 0 - ) - - rest_tokens = model_context_tokens - max_tokens - curr_message_tokens - rest_tokens = max(rest_tokens, 0) - - return rest_tokens - - -def handle_memory_chat_mode( - *, - memory: PromptMessageMemory | None, - memory_config: MemoryConfig | None, - model_instance: ModelInstance, -) -> Sequence[PromptMessage]: - if not memory or not memory_config: - return [] - rest_tokens = calculate_rest_token(prompt_messages=[], model_instance=model_instance) - return memory.get_history_prompt_messages( - max_token_limit=rest_tokens, - message_limit=memory_config.window.size if memory_config.window.enabled else None, - ) - - -def handle_memory_completion_mode( - *, - memory: PromptMessageMemory | None, - memory_config: MemoryConfig | None, - model_instance: ModelInstance, -) -> str: - if not memory or not memory_config: - return "" - - rest_tokens = calculate_rest_token(prompt_messages=[], model_instance=model_instance) - if not memory_config.role_prefix: - raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.") - - return fetch_memory_text( - memory=memory, - max_token_limit=rest_tokens, - message_limit=memory_config.window.size if memory_config.window.enabled else None, - human_prefix=memory_config.role_prefix.user, - ai_prefix=memory_config.role_prefix.assistant, - ) - - -def _append_file_prompts( - *, - prompt_messages: list[PromptMessage], - files: Sequence[File], - vision_enabled: bool, - vision_detail: ImagePromptMessageContent.DETAIL, -) -> None: - if not vision_enabled or not files: - return - - file_prompts = [file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) for file in files] - if ( - prompt_messages - and isinstance(prompt_messages[-1], UserPromptMessage) - and isinstance(prompt_messages[-1].content, list) - ): - existing_contents = prompt_messages[-1].content - assert isinstance(existing_contents, list) - prompt_messages[-1] = UserPromptMessage(content=file_prompts + existing_contents) - else: - prompt_messages.append(UserPromptMessage(content=file_prompts)) diff --git a/api/dify_graph/nodes/llm/node.py b/api/dify_graph/nodes/llm/node.py deleted file mode 100644 index 5ed90ed7e36..00000000000 --- a/api/dify_graph/nodes/llm/node.py +++ /dev/null @@ -1,1031 +0,0 @@ -from __future__ import annotations - -import base64 -import io -import json -import logging -import re -import time -from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Literal - -from sqlalchemy import select - -from core.llm_generator.output_parser.errors import OutputParserError -from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output -from core.model_manager import ModelInstance -from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig -from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.tools.signature import sign_upload_file -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import ( - BuiltinNodeTypes, - NodeType, - SystemVariableKey, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.model_runtime.entities import ( - ImagePromptMessageContent, - PromptMessage, - TextPromptMessageContent, -) -from dify_graph.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, - LLMResultChunkWithStructuredOutput, - LLMResultWithStructuredOutput, - LLMStructuredOutput, - LLMUsage, -) -from dify_graph.model_runtime.entities.message_entities import PromptMessageContentUnionTypes -from dify_graph.model_runtime.memory import PromptMessageMemory -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.node_events import ( - ModelInvokeCompletedEvent, - NodeEventBase, - NodeRunResult, - RunRetrieverResourceEvent, - StreamChunkEvent, - StreamCompletedEvent, -) -from dify_graph.nodes.base.entities import VariableSelector -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer -from dify_graph.nodes.protocols import HttpClientProtocol -from dify_graph.runtime import VariablePool -from dify_graph.variables import ( - ArrayFileSegment, - ArraySegment, - NoneSegment, - ObjectSegment, - StringSegment, -) -from extensions.ext_database import db -from models.dataset import SegmentAttachmentBinding -from models.model import UploadFile - -from . import llm_utils -from .entities import ( - LLMNodeChatModelMessage, - LLMNodeCompletionModelPromptTemplate, - LLMNodeData, -) -from .exc import ( - InvalidContextStructureError, - InvalidVariableTypeError, - LLMNodeError, - VariableNotFoundError, -) -from .file_saver import FileSaverImpl, LLMFileSaver - -if TYPE_CHECKING: - from dify_graph.file.models import File - from dify_graph.runtime import GraphRuntimeState - -logger = logging.getLogger(__name__) - - -class LLMNode(Node[LLMNodeData]): - node_type = BuiltinNodeTypes.LLM - - # Compiled regex for extracting blocks (with compatibility for attributes) - _THINK_PATTERN = re.compile(r"]*>(.*?)", re.IGNORECASE | re.DOTALL) - - # Instance attributes specific to LLMNode. - # Output variable for file - _file_outputs: list[File] - - _llm_file_saver: LLMFileSaver - _credentials_provider: CredentialsProvider - _model_factory: ModelFactory - _model_instance: ModelInstance - _memory: PromptMessageMemory | None - _template_renderer: TemplateRenderer - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - *, - credentials_provider: CredentialsProvider, - model_factory: ModelFactory, - model_instance: ModelInstance, - http_client: HttpClientProtocol, - template_renderer: TemplateRenderer, - memory: PromptMessageMemory | None = None, - llm_file_saver: LLMFileSaver | None = None, - ): - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - # LLM file outputs, used for MultiModal outputs. - self._file_outputs = [] - - self._credentials_provider = credentials_provider - self._model_factory = model_factory - self._model_instance = model_instance - self._memory = memory - self._template_renderer = template_renderer - - if llm_file_saver is None: - dify_ctx = self.require_dify_context() - llm_file_saver = FileSaverImpl( - user_id=dify_ctx.user_id, - tenant_id=dify_ctx.tenant_id, - http_client=http_client, - ) - self._llm_file_saver = llm_file_saver - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> Generator: - node_inputs: dict[str, Any] = {} - process_data: dict[str, Any] = {} - result_text = "" - clean_text = "" - usage = LLMUsage.empty_usage() - finish_reason = None - reasoning_content = None - variable_pool = self.graph_runtime_state.variable_pool - - try: - # init messages template - self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template) - - # fetch variables and fetch values from variable pool - inputs = self._fetch_inputs(node_data=self.node_data) - - # fetch jinja2 inputs - jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data) - - # merge inputs - inputs.update(jinja_inputs) - - # fetch files - files = ( - llm_utils.fetch_files( - variable_pool=variable_pool, - selector=self.node_data.vision.configs.variable_selector, - ) - if self.node_data.vision.enabled - else [] - ) - - if files: - node_inputs["#files#"] = [file.to_dict() for file in files] - - # fetch context value - generator = self._fetch_context(node_data=self.node_data) - context = None - context_files: list[File] = [] - for event in generator: - context = event.context - context_files = event.context_files or [] - yield event - if context: - node_inputs["#context#"] = context - - if context_files: - node_inputs["#context_files#"] = [file.model_dump() for file in context_files] - - # fetch model config - model_instance = self._model_instance - model_name = model_instance.model_name - model_provider = model_instance.provider - model_stop = model_instance.stop - - memory = self._memory - - query: str | None = None - if self.node_data.memory: - query = self.node_data.memory.query_prompt_template - if not query and ( - query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) - ): - query = query_variable.text - - prompt_messages, stop = LLMNode.fetch_prompt_messages( - sys_query=query, - sys_files=files, - context=context, - memory=memory, - model_instance=model_instance, - stop=model_stop, - prompt_template=self.node_data.prompt_template, - memory_config=self.node_data.memory, - vision_enabled=self.node_data.vision.enabled, - vision_detail=self.node_data.vision.configs.detail, - variable_pool=variable_pool, - jinja2_variables=self.node_data.prompt_config.jinja2_variables, - context_files=context_files, - template_renderer=self._template_renderer, - ) - - # handle invoke result - generator = LLMNode.invoke_llm( - model_instance=model_instance, - prompt_messages=prompt_messages, - stop=stop, - user_id=self.require_dify_context().user_id, - structured_output_enabled=self.node_data.structured_output_enabled, - structured_output=self.node_data.structured_output, - file_saver=self._llm_file_saver, - file_outputs=self._file_outputs, - node_id=self._node_id, - node_type=self.node_type, - reasoning_format=self.node_data.reasoning_format, - ) - - structured_output: LLMStructuredOutput | None = None - - for event in generator: - if isinstance(event, StreamChunkEvent): - yield event - elif isinstance(event, ModelInvokeCompletedEvent): - # Raw text - result_text = event.text - usage = event.usage - finish_reason = event.finish_reason - reasoning_content = event.reasoning_content or "" - - # For downstream nodes, determine clean text based on reasoning_format - if self.node_data.reasoning_format == "tagged": - # Keep tags for backward compatibility - clean_text = result_text - else: - # Extract clean text from tags - clean_text, _ = LLMNode._split_reasoning(result_text, self.node_data.reasoning_format) - - # Process structured output if available from the event. - structured_output = ( - LLMStructuredOutput(structured_output=event.structured_output) - if event.structured_output - else None - ) - - break - elif isinstance(event, LLMStructuredOutput): - structured_output = event - - process_data = { - "model_mode": self.node_data.model.mode, - "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( - model_mode=self.node_data.model.mode, prompt_messages=prompt_messages - ), - "usage": jsonable_encoder(usage), - "finish_reason": finish_reason, - "model_provider": model_provider, - "model_name": model_name, - } - - outputs = { - "text": clean_text, - "reasoning_content": reasoning_content, - "usage": jsonable_encoder(usage), - "finish_reason": finish_reason, - } - if structured_output: - outputs["structured_output"] = structured_output.structured_output - if self._file_outputs: - outputs["files"] = ArrayFileSegment(value=self._file_outputs) - - # Send final chunk event to indicate streaming is complete - yield StreamChunkEvent( - selector=[self._node_id, "text"], - chunk="", - is_final=True, - ) - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=node_inputs, - process_data=process_data, - outputs=outputs, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - }, - llm_usage=usage, - ) - ) - except ValueError as e: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - inputs=node_inputs, - process_data=process_data, - error_type=type(e).__name__, - llm_usage=usage, - ) - ) - except Exception as e: - logger.exception("error while executing llm node") - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - inputs=node_inputs, - process_data=process_data, - error_type=type(e).__name__, - llm_usage=usage, - ) - ) - - @staticmethod - def invoke_llm( - *, - model_instance: ModelInstance, - prompt_messages: Sequence[PromptMessage], - stop: Sequence[str] | None = None, - user_id: str, - structured_output_enabled: bool, - structured_output: Mapping[str, Any] | None = None, - file_saver: LLMFileSaver, - file_outputs: list[File], - node_id: str, - node_type: NodeType, - reasoning_format: Literal["separated", "tagged"] = "tagged", - ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]: - model_parameters = model_instance.parameters - invoke_model_parameters = dict(model_parameters) - - model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) - - if structured_output_enabled: - output_schema = LLMNode.fetch_structured_output_schema( - structured_output=structured_output or {}, - ) - request_start_time = time.perf_counter() - - invoke_result = invoke_llm_with_structured_output( - provider=model_instance.provider, - model_schema=model_schema, - model_instance=model_instance, - prompt_messages=prompt_messages, - json_schema=output_schema, - model_parameters=invoke_model_parameters, - stop=list(stop or []), - stream=True, - user=user_id, - ) - else: - request_start_time = time.perf_counter() - - invoke_result = model_instance.invoke_llm( - prompt_messages=list(prompt_messages), - model_parameters=invoke_model_parameters, - stop=list(stop or []), - stream=True, - user=user_id, - ) - - return LLMNode.handle_invoke_result( - invoke_result=invoke_result, - file_saver=file_saver, - file_outputs=file_outputs, - node_id=node_id, - node_type=node_type, - reasoning_format=reasoning_format, - request_start_time=request_start_time, - ) - - @staticmethod - def handle_invoke_result( - *, - invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None], - file_saver: LLMFileSaver, - file_outputs: list[File], - node_id: str, - node_type: NodeType, - reasoning_format: Literal["separated", "tagged"] = "tagged", - request_start_time: float | None = None, - ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]: - # For blocking mode - if isinstance(invoke_result, LLMResult): - duration = None - if request_start_time is not None: - duration = time.perf_counter() - request_start_time - invoke_result.usage.latency = round(duration, 3) - event = LLMNode.handle_blocking_result( - invoke_result=invoke_result, - saver=file_saver, - file_outputs=file_outputs, - reasoning_format=reasoning_format, - request_latency=duration, - ) - yield event - return - - # For streaming mode - model = "" - prompt_messages: list[PromptMessage] = [] - - usage = LLMUsage.empty_usage() - finish_reason = None - full_text_buffer = io.StringIO() - - # Initialize streaming metrics tracking - start_time = request_start_time if request_start_time is not None else time.perf_counter() - first_token_time = None - has_content = False - - collected_structured_output = None # Collect structured_output from streaming chunks - # Consume the invoke result and handle generator exception - try: - for result in invoke_result: - if isinstance(result, LLMResultChunkWithStructuredOutput): - # Collect structured_output from the chunk - if result.structured_output is not None: - collected_structured_output = dict(result.structured_output) - yield result - if isinstance(result, LLMResultChunk): - contents = result.delta.message.content - for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown( - contents=contents, - file_saver=file_saver, - file_outputs=file_outputs, - ): - # Detect first token for TTFT calculation - if text_part and not has_content: - first_token_time = time.perf_counter() - has_content = True - - full_text_buffer.write(text_part) - yield StreamChunkEvent( - selector=[node_id, "text"], - chunk=text_part, - is_final=False, - ) - - # Update the whole metadata - if not model and result.model: - model = result.model - if len(prompt_messages) == 0: - # TODO(QuantumGhost): it seems that this update has no visable effect. - # What's the purpose of the line below? - prompt_messages = list(result.prompt_messages) - if usage.prompt_tokens == 0 and result.delta.usage: - usage = result.delta.usage - if finish_reason is None and result.delta.finish_reason: - finish_reason = result.delta.finish_reason - except OutputParserError as e: - raise LLMNodeError(f"Failed to parse structured output: {e}") - - # Extract reasoning content from tags in the main text - full_text = full_text_buffer.getvalue() - - if reasoning_format == "tagged": - # Keep tags in text for backward compatibility - clean_text = full_text - reasoning_content = "" - else: - # Extract clean text and reasoning from tags - clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format) - - # Calculate streaming metrics - end_time = time.perf_counter() - total_duration = end_time - start_time - usage.latency = round(total_duration, 3) - if has_content and first_token_time: - gen_ai_server_time_to_first_token = first_token_time - start_time - llm_streaming_time_to_generate = end_time - first_token_time - usage.time_to_first_token = round(gen_ai_server_time_to_first_token, 3) - usage.time_to_generate = round(llm_streaming_time_to_generate, 3) - - yield ModelInvokeCompletedEvent( - # Use clean_text for separated mode, full_text for tagged mode - text=clean_text if reasoning_format == "separated" else full_text, - usage=usage, - finish_reason=finish_reason, - # Reasoning content for workflow variables and downstream nodes - reasoning_content=reasoning_content, - # Pass structured output if collected from streaming chunks - structured_output=collected_structured_output, - ) - - @staticmethod - def _image_file_to_markdown(file: File, /): - text_chunk = f"![]({file.generate_url()})" - return text_chunk - - @classmethod - def _split_reasoning( - cls, text: str, reasoning_format: Literal["separated", "tagged"] = "tagged" - ) -> tuple[str, str]: - """ - Split reasoning content from text based on reasoning_format strategy. - - Args: - text: Full text that may contain blocks - reasoning_format: Strategy for handling reasoning content - - "separated": Remove tags and return clean text + reasoning_content field - - "tagged": Keep tags in text, return empty reasoning_content - - Returns: - tuple of (clean_text, reasoning_content) - """ - - if reasoning_format == "tagged": - return text, "" - - # Find all ... blocks (case-insensitive) - matches = cls._THINK_PATTERN.findall(text) - - # Extract reasoning content from all blocks - reasoning_content = "\n".join(match.strip() for match in matches) if matches else "" - - # Remove all ... blocks from original text - clean_text = cls._THINK_PATTERN.sub("", text) - - # Clean up extra whitespace - clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip() - - # Separated mode: always return clean text and reasoning_content - return clean_text, reasoning_content or "" - - def _transform_chat_messages( - self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, / - ) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: - if isinstance(messages, LLMNodeCompletionModelPromptTemplate): - if messages.edition_type == "jinja2" and messages.jinja2_text: - messages.text = messages.jinja2_text - - return messages - - for message in messages: - if message.edition_type == "jinja2" and message.jinja2_text: - message.text = message.jinja2_text - - return messages - - def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]: - variables: dict[str, Any] = {} - - if not node_data.prompt_config: - return variables - - for variable_selector in node_data.prompt_config.jinja2_variables or []: - variable_name = variable_selector.variable - variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) - if variable is None: - raise VariableNotFoundError(f"Variable {variable_selector.variable} not found") - - def parse_dict(input_dict: Mapping[str, Any]) -> str: - """ - Parse dict into string - """ - # check if it's a context structure - if "metadata" in input_dict and "_source" in input_dict["metadata"] and "content" in input_dict: - return str(input_dict["content"]) - - # else, parse the dict - try: - return json.dumps(input_dict, ensure_ascii=False) - except Exception: - return str(input_dict) - - if isinstance(variable, ArraySegment): - result = "" - for item in variable.value: - if isinstance(item, dict): - result += parse_dict(item) - else: - result += str(item) - result += "\n" - value = result.strip() - elif isinstance(variable, ObjectSegment): - value = parse_dict(variable.value) - else: - value = variable.text - - variables[variable_name] = value - - return variables - - def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, Any]: - inputs = {} - prompt_template = node_data.prompt_template - - variable_selectors = [] - if isinstance(prompt_template, list): - for prompt in prompt_template: - variable_template_parser = VariableTemplateParser(template=prompt.text) - variable_selectors.extend(variable_template_parser.extract_variable_selectors()) - elif isinstance(prompt_template, CompletionModelPromptTemplate): - variable_template_parser = VariableTemplateParser(template=prompt_template.text) - variable_selectors = variable_template_parser.extract_variable_selectors() - - for variable_selector in variable_selectors: - variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) - if variable is None: - raise VariableNotFoundError(f"Variable {variable_selector.variable} not found") - if isinstance(variable, NoneSegment): - inputs[variable_selector.variable] = "" - inputs[variable_selector.variable] = variable.to_object() - - memory = node_data.memory - if memory and memory.query_prompt_template: - query_variable_selectors = VariableTemplateParser( - template=memory.query_prompt_template - ).extract_variable_selectors() - for variable_selector in query_variable_selectors: - variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) - if variable is None: - raise VariableNotFoundError(f"Variable {variable_selector.variable} not found") - if isinstance(variable, NoneSegment): - continue - inputs[variable_selector.variable] = variable.to_object() - - return inputs - - def _fetch_context(self, node_data: LLMNodeData): - if not node_data.context.enabled: - return - - if not node_data.context.variable_selector: - return - - context_value_variable = self.graph_runtime_state.variable_pool.get(node_data.context.variable_selector) - if context_value_variable: - if isinstance(context_value_variable, StringSegment): - yield RunRetrieverResourceEvent( - retriever_resources=[], context=context_value_variable.value, context_files=[] - ) - elif isinstance(context_value_variable, ArraySegment): - context_str = "" - original_retriever_resource: list[dict[str, Any]] = [] - context_files: list[File] = [] - for item in context_value_variable.value: - if isinstance(item, str): - context_str += item + "\n" - else: - if "content" not in item: - raise InvalidContextStructureError(f"Invalid context structure: {item}") - - if item.get("summary"): - context_str += item["summary"] + "\n" - context_str += item["content"] + "\n" - - retriever_resource = self._convert_to_original_retriever_resource(item) - if retriever_resource: - original_retriever_resource.append(retriever_resource) - segment_id = retriever_resource.get("segment_id") - if not segment_id: - continue - attachments_with_bindings = db.session.execute( - select(SegmentAttachmentBinding, UploadFile) - .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) - .where( - SegmentAttachmentBinding.segment_id == segment_id, - ) - ).all() - if attachments_with_bindings: - for _, upload_file in attachments_with_bindings: - attachment_info = File( - id=upload_file.id, - filename=upload_file.name, - extension="." + upload_file.extension, - mime_type=upload_file.mime_type, - tenant_id=self.require_dify_context().tenant_id, - type=FileType.IMAGE, - transfer_method=FileTransferMethod.LOCAL_FILE, - remote_url=upload_file.source_url, - related_id=upload_file.id, - size=upload_file.size, - storage_key=upload_file.key, - url=sign_upload_file(upload_file.id, upload_file.extension), - ) - context_files.append(attachment_info) - yield RunRetrieverResourceEvent( - retriever_resources=original_retriever_resource, - context=context_str.strip(), - context_files=context_files, - ) - - def _convert_to_original_retriever_resource(self, context_dict: dict) -> dict[str, Any] | None: - if ( - "metadata" in context_dict - and "_source" in context_dict["metadata"] - and context_dict["metadata"]["_source"] == "knowledge" - ): - metadata = context_dict.get("metadata", {}) - - return { - "position": metadata.get("position"), - "dataset_id": metadata.get("dataset_id"), - "dataset_name": metadata.get("dataset_name"), - "document_id": metadata.get("document_id"), - "document_name": metadata.get("document_name"), - "data_source_type": metadata.get("data_source_type"), - "segment_id": metadata.get("segment_id"), - "retriever_from": metadata.get("retriever_from"), - "score": metadata.get("score"), - "hit_count": metadata.get("segment_hit_count"), - "word_count": metadata.get("segment_word_count"), - "segment_position": metadata.get("segment_position"), - "index_node_hash": metadata.get("segment_index_node_hash"), - "content": context_dict.get("content"), - "page": metadata.get("page"), - "doc_metadata": metadata.get("doc_metadata"), - "files": context_dict.get("files"), - "summary": context_dict.get("summary"), - } - - return None - - @staticmethod - def fetch_prompt_messages( - *, - sys_query: str | None = None, - sys_files: Sequence[File], - context: str | None = None, - memory: PromptMessageMemory | None = None, - model_instance: ModelInstance, - prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, - stop: Sequence[str] | None = None, - memory_config: MemoryConfig | None = None, - vision_enabled: bool = False, - vision_detail: ImagePromptMessageContent.DETAIL, - variable_pool: VariablePool, - jinja2_variables: Sequence[VariableSelector], - context_files: list[File] | None = None, - template_renderer: TemplateRenderer | None = None, - ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: - return llm_utils.fetch_prompt_messages( - sys_query=sys_query, - sys_files=sys_files, - context=context, - memory=memory, - model_instance=model_instance, - prompt_template=prompt_template, - stop=stop, - memory_config=memory_config, - vision_enabled=vision_enabled, - vision_detail=vision_detail, - variable_pool=variable_pool, - jinja2_variables=jinja2_variables, - context_files=context_files, - template_renderer=template_renderer, - ) - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: LLMNodeData, - ) -> Mapping[str, Sequence[str]]: - # graph_config is not used in this node type - _ = graph_config # Explicitly mark as unused - prompt_template = node_data.prompt_template - variable_selectors = [] - if isinstance(prompt_template, list): - for prompt in prompt_template: - if prompt.edition_type != "jinja2": - variable_template_parser = VariableTemplateParser(template=prompt.text) - variable_selectors.extend(variable_template_parser.extract_variable_selectors()) - elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): - if prompt_template.edition_type != "jinja2": - variable_template_parser = VariableTemplateParser(template=prompt_template.text) - variable_selectors = variable_template_parser.extract_variable_selectors() - else: - raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}") - - variable_mapping: dict[str, Any] = {} - for variable_selector in variable_selectors: - variable_mapping[variable_selector.variable] = variable_selector.value_selector - - memory = node_data.memory - if memory and memory.query_prompt_template: - query_variable_selectors = VariableTemplateParser( - template=memory.query_prompt_template - ).extract_variable_selectors() - for variable_selector in query_variable_selectors: - variable_mapping[variable_selector.variable] = variable_selector.value_selector - - if node_data.context.enabled: - variable_mapping["#context#"] = node_data.context.variable_selector - - if node_data.vision.enabled: - variable_mapping["#files#"] = node_data.vision.configs.variable_selector - - if node_data.memory: - variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY] - - if node_data.prompt_config: - enable_jinja = False - - if isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): - if prompt_template.edition_type == "jinja2": - enable_jinja = True - else: - for prompt in prompt_template: - if prompt.edition_type == "jinja2": - enable_jinja = True - break - - if enable_jinja: - for variable_selector in node_data.prompt_config.jinja2_variables or []: - variable_mapping[variable_selector.variable] = variable_selector.value_selector - - variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} - - return variable_mapping - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - return { - "type": "llm", - "config": { - "prompt_templates": { - "chat_model": { - "prompts": [ - {"role": "system", "text": "You are a helpful AI assistant.", "edition_type": "basic"} - ] - }, - "completion_model": { - "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"}, - "prompt": { - "text": "Here are the chat histories between human and assistant, inside " - " XML tags.\n\n\n{{" - "#histories#}}\n\n\n\nHuman: {{#sys.query#}}\n\nAssistant:", - "edition_type": "basic", - }, - "stop": ["Human:"], - }, - } - }, - } - - @staticmethod - def handle_list_messages( - *, - messages: Sequence[LLMNodeChatModelMessage], - context: str | None, - jinja2_variables: Sequence[VariableSelector], - variable_pool: VariablePool, - vision_detail_config: ImagePromptMessageContent.DETAIL, - template_renderer: TemplateRenderer | None = None, - ) -> Sequence[PromptMessage]: - return llm_utils.handle_list_messages( - messages=messages, - context=context, - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - vision_detail_config=vision_detail_config, - template_renderer=template_renderer, - ) - - @staticmethod - def handle_blocking_result( - *, - invoke_result: LLMResult | LLMResultWithStructuredOutput, - saver: LLMFileSaver, - file_outputs: list[File], - reasoning_format: Literal["separated", "tagged"] = "tagged", - request_latency: float | None = None, - ) -> ModelInvokeCompletedEvent: - buffer = io.StringIO() - for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown( - contents=invoke_result.message.content, - file_saver=saver, - file_outputs=file_outputs, - ): - buffer.write(text_part) - - # Extract reasoning content from tags in the main text - full_text = buffer.getvalue() - - if reasoning_format == "tagged": - # Keep tags in text for backward compatibility - clean_text = full_text - reasoning_content = "" - else: - # Extract clean text and reasoning from tags - clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format) - - event = ModelInvokeCompletedEvent( - # Use clean_text for separated mode, full_text for tagged mode - text=clean_text if reasoning_format == "separated" else full_text, - usage=invoke_result.usage, - finish_reason=None, - # Reasoning content for workflow variables and downstream nodes - reasoning_content=reasoning_content, - # Pass structured output if enabled - structured_output=getattr(invoke_result, "structured_output", None), - ) - if request_latency is not None: - event.usage.latency = round(request_latency, 3) - return event - - @staticmethod - def save_multimodal_image_output( - *, - content: ImagePromptMessageContent, - file_saver: LLMFileSaver, - ) -> File: - """_save_multimodal_output saves multi-modal contents generated by LLM plugins. - - There are two kinds of multimodal outputs: - - - Inlined data encoded in base64, which would be saved to storage directly. - - Remote files referenced by an url, which would be downloaded and then saved to storage. - - Currently, only image files are supported. - """ - if content.url != "": - saved_file = file_saver.save_remote_url(content.url, FileType.IMAGE) - else: - saved_file = file_saver.save_binary_string( - data=base64.b64decode(content.base64_data), - mime_type=content.mime_type, - file_type=FileType.IMAGE, - ) - return saved_file - - @staticmethod - def fetch_structured_output_schema( - *, - structured_output: Mapping[str, Any], - ) -> dict[str, Any]: - """ - Fetch the structured output schema from the node data. - - Returns: - dict[str, Any]: The structured output schema - """ - if not structured_output: - raise LLMNodeError("Please provide a valid structured output schema") - structured_output_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False) - if not structured_output_schema: - raise LLMNodeError("Please provide a valid structured output schema") - - try: - schema = json.loads(structured_output_schema) - if not isinstance(schema, dict): - raise LLMNodeError("structured_output_schema must be a JSON object") - return schema - except json.JSONDecodeError: - raise LLMNodeError("structured_output_schema is not valid JSON format") - - @staticmethod - def _save_multimodal_output_and_convert_result_to_markdown( - *, - contents: str | list[PromptMessageContentUnionTypes] | None, - file_saver: LLMFileSaver, - file_outputs: list[File], - ) -> Generator[str, None, None]: - """Convert intermediate prompt messages into strings and yield them to the caller. - - If the messages contain non-textual content (e.g., multimedia like images or videos), - it will be saved separately, and the corresponding Markdown representation will - be yielded to the caller. - """ - - # NOTE(QuantumGhost): This function should yield results to the caller immediately - # whenever new content or partial content is available. Avoid any intermediate buffering - # of results. Additionally, do not yield empty strings; instead, yield from an empty list - # if necessary. - if contents is None: - yield from [] - return - if isinstance(contents, str): - yield contents - else: - for item in contents: - if isinstance(item, TextPromptMessageContent): - yield item.data - elif isinstance(item, ImagePromptMessageContent): - file = LLMNode.save_multimodal_image_output( - content=item, - file_saver=file_saver, - ) - file_outputs.append(file) - yield LLMNode._image_file_to_markdown(file) - else: - logger.warning("unknown item type encountered, type=%s", type(item)) - yield str(item) - - @property - def retry(self) -> bool: - return self.node_data.retry_config.retry_enabled - - @property - def model_instance(self) -> ModelInstance: - return self._model_instance diff --git a/api/dify_graph/nodes/llm/protocols.py b/api/dify_graph/nodes/llm/protocols.py deleted file mode 100644 index 9e95d341c93..00000000000 --- a/api/dify_graph/nodes/llm/protocols.py +++ /dev/null @@ -1,30 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping -from typing import Any, Protocol - -from core.model_manager import ModelInstance - - -class CredentialsProvider(Protocol): - """Port for loading runtime credentials for a provider/model pair.""" - - def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: - """Return credentials for the target provider/model or raise a domain error.""" - ... - - -class ModelFactory(Protocol): - """Port for creating initialized LLM model instances for execution.""" - - def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance: - """Create a model instance that is ready for schema lookup and invocation.""" - ... - - -class TemplateRenderer(Protocol): - """Port for rendering prompt templates used by LLM-compatible nodes.""" - - def render_jinja2(self, *, template: str, inputs: Mapping[str, Any]) -> str: - """Render the given Jinja2 template into plain text.""" - ... diff --git a/api/dify_graph/nodes/loop/__init__.py b/api/dify_graph/nodes/loop/__init__.py deleted file mode 100644 index 9fe695607b9..00000000000 --- a/api/dify_graph/nodes/loop/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .entities import LoopNodeData -from .loop_end_node import LoopEndNode -from .loop_node import LoopNode -from .loop_start_node import LoopStartNode - -__all__ = ["LoopEndNode", "LoopNode", "LoopNodeData", "LoopStartNode"] diff --git a/api/dify_graph/nodes/loop/entities.py b/api/dify_graph/nodes/loop/entities.py deleted file mode 100644 index f0bfad5a0f5..00000000000 --- a/api/dify_graph/nodes/loop/entities.py +++ /dev/null @@ -1,107 +0,0 @@ -from enum import StrEnum -from typing import Annotated, Any, Literal - -from pydantic import AfterValidator, BaseModel, Field, field_validator - -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState -from dify_graph.utils.condition.entities import Condition -from dify_graph.variables.types import SegmentType - -_VALID_VAR_TYPE = frozenset( - [ - SegmentType.STRING, - SegmentType.NUMBER, - SegmentType.OBJECT, - SegmentType.BOOLEAN, - SegmentType.ARRAY_STRING, - SegmentType.ARRAY_NUMBER, - SegmentType.ARRAY_OBJECT, - SegmentType.ARRAY_BOOLEAN, - ] -) - - -def _is_valid_var_type(seg_type: SegmentType) -> SegmentType: - if seg_type not in _VALID_VAR_TYPE: - raise ValueError(...) - return seg_type - - -class LoopVariableData(BaseModel): - """ - Loop Variable Data. - """ - - label: str - var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)] - value_type: Literal["variable", "constant"] - value: Any | list[str] | None = None - - -class LoopNodeData(BaseLoopNodeData): - type: NodeType = BuiltinNodeTypes.LOOP - loop_count: int # Maximum number of loops - break_conditions: list[Condition] # Conditions to break the loop - logical_operator: Literal["and", "or"] - loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData]) - outputs: dict[str, Any] = Field(default_factory=dict) - - @field_validator("outputs", mode="before") - @classmethod - def validate_outputs(cls, v): - if v is None: - return {} - return v - - -class LoopStartNodeData(BaseNodeData): - """ - Loop Start Node Data. - """ - - type: NodeType = BuiltinNodeTypes.LOOP_START - - -class LoopEndNodeData(BaseNodeData): - """ - Loop End Node Data. - """ - - type: NodeType = BuiltinNodeTypes.LOOP_END - - -class LoopState(BaseLoopState): - """ - Loop State. - """ - - outputs: list[Any] = Field(default_factory=list) - current_output: Any = None - - class MetaData(BaseLoopState.MetaData): - """ - Data. - """ - - loop_length: int - - def get_last_output(self) -> Any: - """ - Get last output. - """ - if self.outputs: - return self.outputs[-1] - return None - - def get_current_output(self) -> Any: - """ - Get current output. - """ - return self.current_output - - -class LoopCompletedReason(StrEnum): - LOOP_BREAK = "loop_break" - LOOP_COMPLETED = "loop_completed" diff --git a/api/dify_graph/nodes/loop/loop_end_node.py b/api/dify_graph/nodes/loop/loop_end_node.py deleted file mode 100644 index 0287708fb36..00000000000 --- a/api/dify_graph/nodes/loop/loop_end_node.py +++ /dev/null @@ -1,22 +0,0 @@ -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.loop.entities import LoopEndNodeData - - -class LoopEndNode(Node[LoopEndNodeData]): - """ - Loop End Node. - """ - - node_type = BuiltinNodeTypes.LOOP_END - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - """ - Run the node. - """ - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED) diff --git a/api/dify_graph/nodes/loop/loop_node.py b/api/dify_graph/nodes/loop/loop_node.py deleted file mode 100644 index 3c546ffa234..00000000000 --- a/api/dify_graph/nodes/loop/loop_node.py +++ /dev/null @@ -1,435 +0,0 @@ -import contextlib -import json -import logging -from collections.abc import Callable, Generator, Mapping, Sequence -from datetime import datetime -from typing import TYPE_CHECKING, Any, Literal, cast - -from dify_graph.entities.graph_config import NodeConfigDictAdapter -from dify_graph.enums import ( - BuiltinNodeTypes, - NodeExecutionType, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from dify_graph.graph_events import ( - GraphNodeEventBase, - GraphRunFailedEvent, - NodeRunSucceededEvent, -) -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.node_events import ( - LoopFailedEvent, - LoopNextEvent, - LoopStartedEvent, - LoopSucceededEvent, - NodeEventBase, - NodeRunResult, - StreamCompletedEvent, -) -from dify_graph.nodes.base import LLMUsageTrackingMixin -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData -from dify_graph.utils.condition.processor import ConditionProcessor -from dify_graph.variables import Segment, SegmentType -from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable -from libs.datetime_utils import naive_utc_now - -if TYPE_CHECKING: - from dify_graph.graph_engine import GraphEngine - -logger = logging.getLogger(__name__) - - -class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): - """ - Loop Node. - """ - - node_type = BuiltinNodeTypes.LOOP - execution_type = NodeExecutionType.CONTAINER - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> Generator: - """Run the node.""" - # Get inputs - loop_count = self.node_data.loop_count - break_conditions = self.node_data.break_conditions - logical_operator = self.node_data.logical_operator - - inputs = {"loop_count": loop_count} - - if not self.node_data.start_node_id: - raise ValueError(f"field start_node_id in loop {self._node_id} not found") - - root_node_id = self.node_data.start_node_id - - # Initialize loop variables in the original variable pool - loop_variable_selectors = {} - if self.node_data.loop_variables: - value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = { - "constant": lambda var: self._get_segment_for_constant(var.var_type, var.value), - "variable": lambda var: ( - self.graph_runtime_state.variable_pool.get(var.value) if isinstance(var.value, list) else None - ), - } - for loop_variable in self.node_data.loop_variables: - if loop_variable.value_type not in value_processor: - raise ValueError( - f"Invalid value type '{loop_variable.value_type}' for loop variable {loop_variable.label}" - ) - - processed_segment = value_processor[loop_variable.value_type](loop_variable) - if not processed_segment: - raise ValueError(f"Invalid value for loop variable {loop_variable.label}") - variable_selector = [self._node_id, loop_variable.label] - variable = segment_to_variable(segment=processed_segment, selector=variable_selector) - self.graph_runtime_state.variable_pool.add(variable_selector, variable.value) - loop_variable_selectors[loop_variable.label] = variable_selector - inputs[loop_variable.label] = processed_segment.value - - start_at = naive_utc_now() - condition_processor = ConditionProcessor() - - loop_duration_map: dict[str, float] = {} - single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output - loop_usage = LLMUsage.empty_usage() - loop_node_ids = self._extract_loop_node_ids_from_config(self.graph_config, self._node_id) - - # Start Loop event - yield LoopStartedEvent( - start_at=start_at, - inputs=inputs, - metadata={"loop_length": loop_count}, - ) - - try: - reach_break_condition = False - if break_conditions: - with contextlib.suppress(ValueError): - _, _, reach_break_condition = condition_processor.process_conditions( - variable_pool=self.graph_runtime_state.variable_pool, - conditions=break_conditions, - operator=logical_operator, - ) - - if reach_break_condition: - loop_count = 0 - - for i in range(loop_count): - # Clear stale variables from previous loop iterations to avoid streaming old values - self._clear_loop_subgraph_variables(loop_node_ids) - graph_engine = self._create_graph_engine(start_at=start_at, root_node_id=root_node_id) - - loop_start_time = naive_utc_now() - reach_break_node = yield from self._run_single_loop(graph_engine=graph_engine, current_index=i) - # Track loop duration - loop_duration_map[str(i)] = (naive_utc_now() - loop_start_time).total_seconds() - - # Accumulate outputs from the sub-graph's response nodes - for key, value in graph_engine.graph_runtime_state.outputs.items(): - if key == "answer": - # Concatenate answer outputs with newline - existing_answer = self.graph_runtime_state.get_output("answer", "") - if existing_answer: - self.graph_runtime_state.set_output("answer", f"{existing_answer}{value}") - else: - self.graph_runtime_state.set_output("answer", value) - else: - # For other outputs, just update - self.graph_runtime_state.set_output(key, value) - - # Accumulate usage from the sub-graph execution - loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage) - - # Collect loop variable values after iteration - single_loop_variable = {} - for key, selector in loop_variable_selectors.items(): - segment = self.graph_runtime_state.variable_pool.get(selector) - single_loop_variable[key] = segment.value if segment else None - - single_loop_variable_map[str(i)] = single_loop_variable - - if reach_break_node: - break - - if break_conditions: - _, _, reach_break_condition = condition_processor.process_conditions( - variable_pool=self.graph_runtime_state.variable_pool, - conditions=break_conditions, - operator=logical_operator, - ) - if reach_break_condition: - break - - yield LoopNextEvent( - index=i + 1, - pre_loop_output=self.node_data.outputs, - ) - - self._accumulate_usage(loop_usage) - # Loop completed successfully - yield LoopSucceededEvent( - start_at=start_at, - inputs=inputs, - outputs=self.node_data.outputs, - steps=loop_count, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency, - WorkflowNodeExecutionMetadataKey.COMPLETED_REASON: ( - LoopCompletedReason.LOOP_BREAK - if reach_break_condition - else LoopCompletedReason.LOOP_COMPLETED.value - ), - WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, - WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, - }, - ) - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency, - WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, - WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, - }, - outputs=self.node_data.outputs, - inputs=inputs, - llm_usage=loop_usage, - ) - ) - - except Exception as e: - self._accumulate_usage(loop_usage) - yield LoopFailedEvent( - start_at=start_at, - inputs=inputs, - steps=loop_count, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency, - "completed_reason": "error", - WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, - WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, - }, - error=str(e), - ) - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency, - WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, - WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, - }, - llm_usage=loop_usage, - ) - ) - - def _run_single_loop( - self, - *, - graph_engine: "GraphEngine", - current_index: int, - ) -> Generator[NodeEventBase | GraphNodeEventBase, None, bool]: - reach_break_node = False - for event in graph_engine.run(): - if isinstance(event, GraphNodeEventBase): - self._append_loop_info_to_event(event=event, loop_run_index=current_index) - - if isinstance(event, GraphNodeEventBase) and event.node_type == BuiltinNodeTypes.LOOP_START: - continue - if isinstance(event, GraphNodeEventBase): - yield event - if isinstance(event, NodeRunSucceededEvent) and event.node_type == BuiltinNodeTypes.LOOP_END: - reach_break_node = True - if isinstance(event, GraphRunFailedEvent): - raise Exception(event.error) - - for loop_var in self.node_data.loop_variables or []: - key, sel = loop_var.label, [self._node_id, loop_var.label] - segment = self.graph_runtime_state.variable_pool.get(sel) - self.node_data.outputs[key] = segment.value if segment else None - self.node_data.outputs["loop_round"] = current_index + 1 - - return reach_break_node - - def _append_loop_info_to_event( - self, - event: GraphNodeEventBase, - loop_run_index: int, - ): - event.in_loop_id = self._node_id - loop_metadata = { - WorkflowNodeExecutionMetadataKey.LOOP_ID: self._node_id, - WorkflowNodeExecutionMetadataKey.LOOP_INDEX: loop_run_index, - } - - current_metadata = event.node_run_result.metadata - if WorkflowNodeExecutionMetadataKey.LOOP_ID not in current_metadata: - event.node_run_result.metadata = {**current_metadata, **loop_metadata} - - def _clear_loop_subgraph_variables(self, loop_node_ids: set[str]) -> None: - """ - Remove variables produced by loop sub-graph nodes from previous iterations. - - Keeping stale variables causes a freshly created response coordinator in the - next iteration to fall back to outdated values when no stream chunks exist. - """ - variable_pool = self.graph_runtime_state.variable_pool - for node_id in loop_node_ids: - variable_pool.remove([node_id]) - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: LoopNodeData, - ) -> Mapping[str, Sequence[str]]: - variable_mapping = {} - - # Extract loop node IDs statically from graph_config - - loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id) - - # Get node configs from graph_config - node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node} - for sub_node_id, sub_node_config in node_configs.items(): - if sub_node_config.get("data", {}).get("loop_id") != node_id: - continue - - # variable selector to variable mapping - try: - typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config) - node_type = typed_sub_node_config["data"].type - node_mapping = Node.get_node_type_classes_mapping() - if node_type not in node_mapping: - continue - node_version = str(typed_sub_node_config["data"].version) - node_cls = node_mapping[node_type][node_version] - - sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( - graph_config=graph_config, config=typed_sub_node_config - ) - sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping) - except NotImplementedError: - sub_node_variable_mapping = {} - - # remove loop variables - sub_node_variable_mapping = { - sub_node_id + "." + key: value - for key, value in sub_node_variable_mapping.items() - if value[0] != node_id - } - - variable_mapping.update(sub_node_variable_mapping) - - for loop_variable in node_data.loop_variables or []: - if loop_variable.value_type == "variable": - assert loop_variable.value is not None, "Loop variable value must be provided for variable type" - # add loop variable to variable mapping - selector = loop_variable.value - variable_mapping[f"{node_id}.{loop_variable.label}"] = selector - - # remove variable out from loop - variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in loop_node_ids} - - return variable_mapping - - @classmethod - def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, Any], loop_node_id: str) -> set[str]: - """ - Extract node IDs that belong to a specific loop from graph configuration. - - This method statically analyzes the graph configuration to find all nodes - that are part of the specified loop, without creating actual node instances. - - :param graph_config: the complete graph configuration - :param loop_node_id: the ID of the loop node - :return: set of node IDs that belong to the loop - """ - loop_node_ids = set() - - # Find all nodes that belong to this loop - nodes = graph_config.get("nodes", []) - for node in nodes: - node_data = node.get("data", {}) - if node_data.get("loop_id") == loop_node_id: - node_id = node.get("id") - if node_id: - loop_node_ids.add(node_id) - - return loop_node_ids - - @staticmethod - def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment: - """Get the appropriate segment type for a constant value.""" - # TODO: Refactor for maintainability: - # 1. Ensure type handling logic stays synchronized with _VALID_VAR_TYPE (entities.py) - # 2. Consider moving this method to LoopVariableData class for better encapsulation - if not var_type.is_array_type() or var_type == SegmentType.ARRAY_BOOLEAN: - value = original_value - elif var_type in [ - SegmentType.ARRAY_NUMBER, - SegmentType.ARRAY_OBJECT, - SegmentType.ARRAY_STRING, - ]: - if original_value and isinstance(original_value, str): - value = json.loads(original_value) - else: - logger.warning("unexpected value for LoopNode, value_type=%s, value=%s", original_value, var_type) - value = [] - else: - raise AssertionError("this statement should be unreachable.") - try: - return build_segment_with_type(var_type, value=value) - except TypeMismatchError as type_exc: - # Attempt to parse the value as a JSON-encoded string, if applicable. - if not isinstance(original_value, str): - raise - try: - value = json.loads(original_value) - except ValueError: - raise type_exc - return build_segment_with_type(var_type, value) - - def _create_graph_engine(self, start_at: datetime, root_node_id: str): - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState - - # Create GraphInitParams for child graph execution. - graph_init_params = GraphInitParams( - workflow_id=self.workflow_id, - graph_config=self.graph_config, - run_context=self.run_context, - call_depth=self.workflow_call_depth, - ) - - # Create a new GraphRuntimeState for this iteration - graph_runtime_state_copy = GraphRuntimeState( - variable_pool=self.graph_runtime_state.variable_pool, - start_at=start_at.timestamp(), - ) - - return self.graph_runtime_state.create_child_engine( - workflow_id=self.workflow_id, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state_copy, - graph_config=self.graph_config, - root_node_id=root_node_id, - ) diff --git a/api/dify_graph/nodes/loop/loop_start_node.py b/api/dify_graph/nodes/loop/loop_start_node.py deleted file mode 100644 index e171b4df2fd..00000000000 --- a/api/dify_graph/nodes/loop/loop_start_node.py +++ /dev/null @@ -1,22 +0,0 @@ -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.loop.entities import LoopStartNodeData - - -class LoopStartNode(Node[LoopStartNodeData]): - """ - Loop Start Node. - """ - - node_type = BuiltinNodeTypes.LOOP_START - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - """ - Run the node. - """ - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED) diff --git a/api/dify_graph/nodes/parameter_extractor/__init__.py b/api/dify_graph/nodes/parameter_extractor/__init__.py deleted file mode 100644 index bdbf19a7d36..00000000000 --- a/api/dify_graph/nodes/parameter_extractor/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .parameter_extractor_node import ParameterExtractorNode - -__all__ = ["ParameterExtractorNode"] diff --git a/api/dify_graph/nodes/parameter_extractor/entities.py b/api/dify_graph/nodes/parameter_extractor/entities.py deleted file mode 100644 index 2fb042c16c3..00000000000 --- a/api/dify_graph/nodes/parameter_extractor/entities.py +++ /dev/null @@ -1,131 +0,0 @@ -from typing import Annotated, Any, Literal - -from pydantic import ( - BaseModel, - BeforeValidator, - Field, - field_validator, -) - -from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig -from dify_graph.variables.types import SegmentType - -_OLD_BOOL_TYPE_NAME = "bool" -_OLD_SELECT_TYPE_NAME = "select" - -_VALID_PARAMETER_TYPES = frozenset( - [ - SegmentType.STRING, # "string", - SegmentType.NUMBER, # "number", - SegmentType.BOOLEAN, - SegmentType.ARRAY_STRING, - SegmentType.ARRAY_NUMBER, - SegmentType.ARRAY_OBJECT, - SegmentType.ARRAY_BOOLEAN, - _OLD_BOOL_TYPE_NAME, # old boolean type used by Parameter Extractor node - _OLD_SELECT_TYPE_NAME, # string type with enumeration choices. - ] -) - - -def _validate_type(parameter_type: str) -> SegmentType: - if parameter_type not in _VALID_PARAMETER_TYPES: - raise ValueError(f"type {parameter_type} is not allowd to use in Parameter Extractor node.") - - if parameter_type == _OLD_BOOL_TYPE_NAME: - return SegmentType.BOOLEAN - elif parameter_type == _OLD_SELECT_TYPE_NAME: - return SegmentType.STRING - return SegmentType(parameter_type) - - -class ParameterConfig(BaseModel): - """ - Parameter Config. - """ - - name: str - type: Annotated[SegmentType, BeforeValidator(_validate_type)] - options: list[str] | None = None - description: str - required: bool - - @field_validator("name", mode="before") - @classmethod - def validate_name(cls, value) -> str: - if not value: - raise ValueError("Parameter name is required") - if value in {"__reason", "__is_success"}: - raise ValueError("Invalid parameter name, __reason and __is_success are reserved") - return str(value) - - def is_array_type(self) -> bool: - return self.type.is_array_type() - - def element_type(self) -> SegmentType: - """Return the element type of the parameter. - - Raises a ValueError if the parameter's type is not an array type. - """ - element_type = self.type.element_type() - # At this point, self.type is guaranteed to be one of `ARRAY_STRING`, - # `ARRAY_NUMBER`, `ARRAY_OBJECT`, or `ARRAY_BOOLEAN`. - # - # See: _VALID_PARAMETER_TYPES for reference. - assert element_type is not None, f"the element type should not be None, {self.type=}" - return element_type - - -class ParameterExtractorNodeData(BaseNodeData): - """ - Parameter Extractor Node Data. - """ - - type: NodeType = BuiltinNodeTypes.PARAMETER_EXTRACTOR - model: ModelConfig - query: list[str] - parameters: list[ParameterConfig] - instruction: str | None = None - memory: MemoryConfig | None = None - reasoning_mode: Literal["function_call", "prompt"] - vision: VisionConfig = Field(default_factory=VisionConfig) - - @field_validator("reasoning_mode", mode="before") - @classmethod - def set_reasoning_mode(cls, v) -> str: - return v or "function_call" - - def get_parameter_json_schema(self): - """ - Get parameter json schema. - - :return: parameter json schema - """ - parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []} - - for parameter in self.parameters: - parameter_schema: dict[str, Any] = {"description": parameter.description} - - if parameter.type == SegmentType.STRING: - parameter_schema["type"] = "string" - elif parameter.type.is_array_type(): - parameter_schema["type"] = "array" - element_type = parameter.type.element_type() - if element_type is None: - raise AssertionError("element type should not be None.") - parameter_schema["items"] = {"type": element_type.value} - else: - parameter_schema["type"] = parameter.type - - if parameter.options: - parameter_schema["enum"] = parameter.options - - parameters["properties"][parameter.name] = parameter_schema - - if parameter.required: - parameters["required"].append(parameter.name) - - return parameters diff --git a/api/dify_graph/nodes/parameter_extractor/exc.py b/api/dify_graph/nodes/parameter_extractor/exc.py deleted file mode 100644 index c25b809d1cf..00000000000 --- a/api/dify_graph/nodes/parameter_extractor/exc.py +++ /dev/null @@ -1,75 +0,0 @@ -from typing import Any - -from dify_graph.variables.types import SegmentType - - -class ParameterExtractorNodeError(ValueError): - """Base error for ParameterExtractorNode.""" - - -class InvalidModelTypeError(ParameterExtractorNodeError): - """Raised when the model is not a Large Language Model.""" - - -class ModelSchemaNotFoundError(ParameterExtractorNodeError): - """Raised when the model schema is not found.""" - - -class InvalidInvokeResultError(ParameterExtractorNodeError): - """Raised when the invoke result is invalid.""" - - -class InvalidTextContentTypeError(ParameterExtractorNodeError): - """Raised when the text content type is invalid.""" - - -class InvalidNumberOfParametersError(ParameterExtractorNodeError): - """Raised when the number of parameters is invalid.""" - - -class RequiredParameterMissingError(ParameterExtractorNodeError): - """Raised when a required parameter is missing.""" - - -class InvalidSelectValueError(ParameterExtractorNodeError): - """Raised when a select value is invalid.""" - - -class InvalidNumberValueError(ParameterExtractorNodeError): - """Raised when a number value is invalid.""" - - -class InvalidBoolValueError(ParameterExtractorNodeError): - """Raised when a bool value is invalid.""" - - -class InvalidStringValueError(ParameterExtractorNodeError): - """Raised when a string value is invalid.""" - - -class InvalidArrayValueError(ParameterExtractorNodeError): - """Raised when an array value is invalid.""" - - -class InvalidModelModeError(ParameterExtractorNodeError): - """Raised when the model mode is invalid.""" - - -class InvalidValueTypeError(ParameterExtractorNodeError): - def __init__( - self, - /, - parameter_name: str, - expected_type: SegmentType, - actual_type: SegmentType | None, - value: Any, - ): - message = ( - f"Invalid value for parameter {parameter_name}, expected segment type: {expected_type}, " - f"actual_type: {actual_type}, python_type: {type(value)}, value: {value}" - ) - super().__init__(message) - self.parameter_name = parameter_name - self.expected_type = expected_type - self.actual_type = actual_type - self.value = value diff --git a/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py b/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py deleted file mode 100644 index 3913a276971..00000000000 --- a/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py +++ /dev/null @@ -1,853 +0,0 @@ -import contextlib -import json -import logging -import uuid -from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any, cast - -from core.model_manager import ModelInstance -from core.prompt.advanced_prompt_transform import AdvancedPromptTransform -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate -from core.prompt.simple_prompt_transform import ModelMode -from core.prompt.utils.prompt_message_util import PromptMessageUtil -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import ( - BuiltinNodeTypes, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from dify_graph.file import File -from dify_graph.model_runtime.entities import ImagePromptMessageContent -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessage, - PromptMessageRole, - PromptMessageTool, - ToolPromptMessage, - UserPromptMessage, -) -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey -from dify_graph.model_runtime.memory import PromptMessageMemory -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base import variable_template_parser -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.llm import llm_utils -from dify_graph.runtime import VariablePool -from dify_graph.variables.types import ArrayValidation, SegmentType -from factories.variable_factory import build_segment_with_type - -from .entities import ParameterExtractorNodeData -from .exc import ( - InvalidModelModeError, - InvalidModelTypeError, - InvalidNumberOfParametersError, - InvalidSelectValueError, - InvalidTextContentTypeError, - InvalidValueTypeError, - ModelSchemaNotFoundError, - ParameterExtractorNodeError, - RequiredParameterMissingError, -) -from .prompts import ( - CHAT_EXAMPLE, - CHAT_GENERATE_JSON_PROMPT, - CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE, - COMPLETION_GENERATE_JSON_PROMPT, - FUNCTION_CALLING_EXTRACTOR_EXAMPLE, - FUNCTION_CALLING_EXTRACTOR_NAME, - FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT, - FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE, -) - -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory - from dify_graph.runtime import GraphRuntimeState - - -def extract_json(text): - """ - From a given JSON started from '{' or '[' extract the complete JSON object. - """ - stack = [] - for i, c in enumerate(text): - if c in {"{", "["}: - stack.append(c) - elif c in {"}", "]"}: - # check if stack is empty - if not stack: - return text[:i] - # check if the last element in stack is matching - if (c == "}" and stack[-1] == "{") or (c == "]" and stack[-1] == "["): - stack.pop() - if not stack: - return text[: i + 1] - else: - return text[:i] - return None - - -class ParameterExtractorNode(Node[ParameterExtractorNodeData]): - """ - Parameter Extractor Node. - """ - - node_type = BuiltinNodeTypes.PARAMETER_EXTRACTOR - - _model_instance: ModelInstance - _credentials_provider: "CredentialsProvider" - _model_factory: "ModelFactory" - _memory: PromptMessageMemory | None - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - credentials_provider: "CredentialsProvider", - model_factory: "ModelFactory", - model_instance: ModelInstance, - memory: PromptMessageMemory | None = None, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - self._credentials_provider = credentials_provider - self._model_factory = model_factory - self._model_instance = model_instance - self._memory = memory - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - return { - "model": { - "prompt_templates": { - "completion_model": { - "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"}, - "stop": ["Human:"], - } - } - } - } - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self): - """ - Run the node. - """ - node_data = self.node_data - variable = self.graph_runtime_state.variable_pool.get(node_data.query) - query = variable.text if variable else "" - - variable_pool = self.graph_runtime_state.variable_pool - - files = ( - llm_utils.fetch_files( - variable_pool=variable_pool, - selector=node_data.vision.configs.variable_selector, - ) - if node_data.vision.enabled - else [] - ) - - model_instance = self._model_instance - if not isinstance(model_instance.model_type_instance, LargeLanguageModel): - raise InvalidModelTypeError("Model is not a Large Language Model") - - try: - model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) - except ValueError as exc: - raise ModelSchemaNotFoundError("Model schema not found") from exc - memory = self._memory - - if ( - set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} - and node_data.reasoning_mode == "function_call" - ): - # use function call - prompt_messages, prompt_message_tools = self._generate_function_call_prompt( - node_data=node_data, - query=query, - variable_pool=self.graph_runtime_state.variable_pool, - model_instance=model_instance, - memory=memory, - files=files, - vision_detail=node_data.vision.configs.detail, - ) - else: - # use prompt engineering - prompt_messages = self._generate_prompt_engineering_prompt( - data=node_data, - query=query, - variable_pool=self.graph_runtime_state.variable_pool, - model_instance=model_instance, - memory=memory, - files=files, - vision_detail=node_data.vision.configs.detail, - ) - - prompt_message_tools = [] - - inputs = { - "query": query, - "files": [f.to_dict() for f in files], - "parameters": jsonable_encoder(node_data.parameters), - "instruction": jsonable_encoder(node_data.instruction), - } - - process_data = { - "model_mode": node_data.model.mode, - "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( - model_mode=node_data.model.mode, prompt_messages=prompt_messages - ), - "usage": None, - "function": {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]), - "tool_call": None, - "model_provider": model_instance.provider, - "model_name": model_instance.model_name, - } - - try: - text, usage, tool_call = self._invoke( - model_instance=model_instance, - prompt_messages=prompt_messages, - tools=prompt_message_tools, - stop=model_instance.stop, - ) - process_data["usage"] = jsonable_encoder(usage) - process_data["tool_call"] = jsonable_encoder(tool_call) - process_data["llm_text"] = text - except ParameterExtractorNodeError as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=inputs, - process_data=process_data, - outputs={"__is_success": 0, "__reason": str(e)}, - error=str(e), - metadata={}, - ) - except Exception as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=inputs, - process_data=process_data, - outputs={"__is_success": 0, "__reason": "Failed to invoke model", "__error": str(e)}, - error=str(e), - metadata={}, - ) - - error = None - - if tool_call: - result = self._extract_json_from_tool_call(tool_call) - else: - result = self._extract_complete_json_response(text) - if not result: - result = self._generate_default_result(node_data) - error = "Failed to extract result from function call or text response, using empty result." - - try: - result = self._validate_result(data=node_data, result=result or {}) - except ParameterExtractorNodeError as e: - error = str(e) - - # transform result into standard format - result = self._transform_result(data=node_data, result=result or {}) - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs={ - "__is_success": 1 if not error else 0, - "__reason": error, - "__usage": jsonable_encoder(usage), - **result, - }, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - }, - llm_usage=usage, - ) - - def _invoke( - self, - model_instance: ModelInstance, - prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - stop: Sequence[str], - ) -> tuple[str, LLMUsage, AssistantPromptMessage.ToolCall | None]: - invoke_result = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=dict(model_instance.parameters), - tools=tools, - stop=list(stop), - stream=False, - user=self.require_dify_context().user_id, - ) - - # handle invoke result - - text = invoke_result.message.get_text_content() - if not isinstance(text, str): - raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. Expected str.") - - usage = invoke_result.usage - tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None - - return text, usage, tool_call - - def _generate_function_call_prompt( - self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_instance: ModelInstance, - memory: PromptMessageMemory | None, - files: Sequence[File], - vision_detail: ImagePromptMessageContent.DETAIL | None = None, - ) -> tuple[list[PromptMessage], list[PromptMessageTool]]: - """ - Generate function call prompt. - """ - query = FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE.format( - content=query, structure=json.dumps(node_data.get_parameter_json_schema()) - ) - - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - rest_token = self._calculate_rest_token( - node_data=node_data, - query=query, - variable_pool=variable_pool, - model_instance=model_instance, - context="", - ) - prompt_template = self._get_function_calling_prompt_template( - node_data, query, variable_pool, memory, rest_token - ) - prompt_messages = prompt_transform.get_prompt( - prompt_template=prompt_template, - inputs={}, - query="", - files=files, - context="", - memory_config=node_data.memory, - memory=None, - model_instance=model_instance, - image_detail_config=vision_detail, - ) - - # find last user message - last_user_message_idx = -1 - for i, prompt_message in enumerate(prompt_messages): - if prompt_message.role == PromptMessageRole.USER: - last_user_message_idx = i - - # add function call messages before last user message - example_messages = [] - for example in FUNCTION_CALLING_EXTRACTOR_EXAMPLE: - id = uuid.uuid4().hex - example_messages.extend( - [ - UserPromptMessage(content=example["user"]["query"]), - AssistantPromptMessage( - content=example["assistant"]["text"], - tool_calls=[ - AssistantPromptMessage.ToolCall( - id=id, - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=example["assistant"]["function_call"]["name"], - arguments=json.dumps(example["assistant"]["function_call"]["parameters"]), - ), - ) - ], - ), - ToolPromptMessage( - content="Great! You have called the function with the correct parameters.", tool_call_id=id - ), - AssistantPromptMessage( - content="I have extracted the parameters, let's move on.", - ), - ] - ) - - prompt_messages = ( - prompt_messages[:last_user_message_idx] + example_messages + prompt_messages[last_user_message_idx:] - ) - - # generate tool - tool = PromptMessageTool( - name=FUNCTION_CALLING_EXTRACTOR_NAME, - description="Extract parameters from the natural language text", - parameters=node_data.get_parameter_json_schema(), - ) - - return prompt_messages, [tool] - - def _generate_prompt_engineering_prompt( - self, - data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_instance: ModelInstance, - memory: PromptMessageMemory | None, - files: Sequence[File], - vision_detail: ImagePromptMessageContent.DETAIL | None = None, - ) -> list[PromptMessage]: - """ - Generate prompt engineering prompt. - """ - model_mode = ModelMode(data.model.mode) - - if model_mode == ModelMode.COMPLETION: - return self._generate_prompt_engineering_completion_prompt( - node_data=data, - query=query, - variable_pool=variable_pool, - model_instance=model_instance, - memory=memory, - files=files, - vision_detail=vision_detail, - ) - elif model_mode == ModelMode.CHAT: - return self._generate_prompt_engineering_chat_prompt( - node_data=data, - query=query, - variable_pool=variable_pool, - model_instance=model_instance, - memory=memory, - files=files, - vision_detail=vision_detail, - ) - else: - raise InvalidModelModeError(f"Invalid model mode: {model_mode}") - - def _generate_prompt_engineering_completion_prompt( - self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_instance: ModelInstance, - memory: PromptMessageMemory | None, - files: Sequence[File], - vision_detail: ImagePromptMessageContent.DETAIL | None = None, - ) -> list[PromptMessage]: - """ - Generate completion prompt. - """ - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - rest_token = self._calculate_rest_token( - node_data=node_data, - query=query, - variable_pool=variable_pool, - model_instance=model_instance, - context="", - ) - prompt_template = self._get_prompt_engineering_prompt_template( - node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token - ) - prompt_messages = prompt_transform.get_prompt( - prompt_template=prompt_template, - inputs={"structure": json.dumps(node_data.get_parameter_json_schema())}, - query="", - files=files, - context="", - memory_config=node_data.memory, - # AdvancedPromptTransform is still typed against TokenBufferMemory. - memory=cast(Any, memory), - model_instance=model_instance, - image_detail_config=vision_detail, - ) - - return prompt_messages - - def _generate_prompt_engineering_chat_prompt( - self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_instance: ModelInstance, - memory: PromptMessageMemory | None, - files: Sequence[File], - vision_detail: ImagePromptMessageContent.DETAIL | None = None, - ) -> list[PromptMessage]: - """ - Generate chat prompt. - """ - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - rest_token = self._calculate_rest_token( - node_data=node_data, - query=query, - variable_pool=variable_pool, - model_instance=model_instance, - context="", - ) - prompt_template = self._get_prompt_engineering_prompt_template( - node_data=node_data, - query=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format( - structure=json.dumps(node_data.get_parameter_json_schema()), text=query - ), - variable_pool=variable_pool, - memory=memory, - max_token_limit=rest_token, - ) - - prompt_messages = prompt_transform.get_prompt( - prompt_template=prompt_template, - inputs={}, - query="", - files=files, - context="", - memory_config=node_data.memory, - memory=None, - model_instance=model_instance, - image_detail_config=vision_detail, - ) - - # find last user message - last_user_message_idx = -1 - for i, prompt_message in enumerate(prompt_messages): - if prompt_message.role == PromptMessageRole.USER: - last_user_message_idx = i - - # add example messages before last user message - example_messages = [] - for example in CHAT_EXAMPLE: - example_messages.extend( - [ - UserPromptMessage( - content=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format( - structure=json.dumps(example["user"]["json"]), - text=example["user"]["query"], - ) - ), - AssistantPromptMessage( - content=json.dumps(example["assistant"]["json"]), - ), - ] - ) - - prompt_messages = ( - prompt_messages[:last_user_message_idx] + example_messages + prompt_messages[last_user_message_idx:] - ) - - return prompt_messages - - def _validate_result(self, data: ParameterExtractorNodeData, result: dict): - if len(data.parameters) != len(result): - raise InvalidNumberOfParametersError("Invalid number of parameters") - - for parameter in data.parameters: - if parameter.required and parameter.name not in result: - raise RequiredParameterMissingError(f"Parameter {parameter.name} is required") - - param_value = result.get(parameter.name) - if not parameter.type.is_valid(param_value, array_validation=ArrayValidation.ALL): - inferred_type = SegmentType.infer_segment_type(param_value) - raise InvalidValueTypeError( - parameter_name=parameter.name, - expected_type=parameter.type, - actual_type=inferred_type, - value=param_value, - ) - if parameter.type == SegmentType.STRING and parameter.options: - if param_value not in parameter.options: - raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}") - return result - - @staticmethod - def _transform_number(value: int | float | str | bool) -> int | float | None: - """ - Attempts to transform the input into an integer or float. - - Returns: - int or float: The transformed number if the conversion is successful. - None: If the transformation fails. - - Note: - Boolean values `True` and `False` are converted to integers `1` and `0`, respectively. - This behavior ensures compatibility with existing workflows that may use boolean types as integers. - """ - if isinstance(value, bool): - return int(value) - elif isinstance(value, (int, float)): - return value - elif isinstance(value, str): - if "." in value: - try: - return float(value) - except ValueError: - return None - else: - try: - return int(value) - except ValueError: - return None - else: - return None - - def _transform_result(self, data: ParameterExtractorNodeData, result: dict): - """ - Transform result into standard format. - """ - transformed_result: dict[str, Any] = {} - for parameter in data.parameters: - if parameter.name in result: - param_value = result[parameter.name] - # transform value - if parameter.type == SegmentType.NUMBER: - transformed = self._transform_number(param_value) - if transformed is not None: - transformed_result[parameter.name] = transformed - elif parameter.type == SegmentType.BOOLEAN: - if isinstance(result[parameter.name], (bool, int)): - transformed_result[parameter.name] = bool(result[parameter.name]) - # elif isinstance(result[parameter.name], str): - # if result[parameter.name].lower() in ["true", "false"]: - # transformed_result[parameter.name] = bool(result[parameter.name].lower() == "true") - elif parameter.type == SegmentType.STRING: - if isinstance(param_value, str): - transformed_result[parameter.name] = param_value - elif parameter.is_array_type(): - if isinstance(param_value, list): - nested_type = parameter.element_type() - assert nested_type is not None - segment_value = build_segment_with_type(segment_type=SegmentType(parameter.type), value=[]) - transformed_result[parameter.name] = segment_value - for item in param_value: - if nested_type == SegmentType.NUMBER: - transformed = self._transform_number(item) - if transformed is not None: - segment_value.value.append(transformed) - elif nested_type == SegmentType.STRING: - if isinstance(item, str): - segment_value.value.append(item) - elif nested_type == SegmentType.OBJECT: - if isinstance(item, dict): - segment_value.value.append(item) - elif nested_type == SegmentType.BOOLEAN: - if isinstance(item, bool): - segment_value.value.append(item) - - if parameter.name not in transformed_result: - if parameter.type.is_array_type(): - transformed_result[parameter.name] = build_segment_with_type( - segment_type=SegmentType(parameter.type), value=[] - ) - elif parameter.type in (SegmentType.STRING, SegmentType.SECRET): - transformed_result[parameter.name] = "" - elif parameter.type == SegmentType.NUMBER: - transformed_result[parameter.name] = 0 - elif parameter.type == SegmentType.BOOLEAN: - transformed_result[parameter.name] = False - else: - raise AssertionError("this statement should be unreachable.") - - return transformed_result - - def _extract_complete_json_response(self, result: str) -> dict | None: - """ - Extract complete json response. - """ - - # extract json from the text - for idx in range(len(result)): - if result[idx] == "{" or result[idx] == "[": - json_str = extract_json(result[idx:]) - if json_str: - with contextlib.suppress(Exception): - return cast(dict, json.loads(json_str)) - logger.info("extra error: %s", result) - return None - - def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict | None: - """ - Extract json from tool call. - """ - if not tool_call or not tool_call.function.arguments: - return None - - result = tool_call.function.arguments - # extract json from the arguments - for idx in range(len(result)): - if result[idx] == "{" or result[idx] == "[": - json_str = extract_json(result[idx:]) - if json_str: - with contextlib.suppress(Exception): - return cast(dict, json.loads(json_str)) - - logger.info("extra error: %s", result) - return None - - def _generate_default_result(self, data: ParameterExtractorNodeData): - """ - Generate default result. - """ - result: dict[str, Any] = {} - for parameter in data.parameters: - if parameter.type == "number": - result[parameter.name] = 0 - elif parameter.type == "boolean": - result[parameter.name] = False - elif parameter.type in {"string", "select"}: - result[parameter.name] = "" - - return result - - def _get_function_calling_prompt_template( - self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - memory: PromptMessageMemory | None, - max_token_limit: int = 2000, - ) -> list[ChatModelMessage]: - model_mode = ModelMode(node_data.model.mode) - input_text = query - memory_str = "" - instruction = variable_pool.convert_template(node_data.instruction or "").text - - if memory and node_data.memory and node_data.memory.window: - memory_str = llm_utils.fetch_memory_text( - memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size - ) - if model_mode == ModelMode.CHAT: - system_prompt_messages = ChatModelMessage( - role=PromptMessageRole.SYSTEM, - text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction), - ) - user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) - return [system_prompt_messages, user_prompt_message] - else: - raise InvalidModelModeError(f"Model mode {model_mode} not support.") - - def _get_prompt_engineering_prompt_template( - self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - memory: PromptMessageMemory | None, - max_token_limit: int = 2000, - ): - model_mode = ModelMode(node_data.model.mode) - input_text = query - memory_str = "" - instruction = variable_pool.convert_template(node_data.instruction or "").text - - if memory and node_data.memory and node_data.memory.window: - memory_str = llm_utils.fetch_memory_text( - memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size - ) - if model_mode == ModelMode.CHAT: - system_prompt_messages = ChatModelMessage( - role=PromptMessageRole.SYSTEM, - text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str, instructions=instruction), - ) - user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) - return [system_prompt_messages, user_prompt_message] - elif model_mode == ModelMode.COMPLETION: - return CompletionModelPromptTemplate( - text=COMPLETION_GENERATE_JSON_PROMPT.format( - histories=memory_str, text=input_text, instruction=instruction - ) - .replace("{γγγ", "") - .replace("}γγγ", "") - ) - else: - raise InvalidModelModeError(f"Model mode {model_mode} not support.") - - def _calculate_rest_token( - self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_instance: ModelInstance, - context: str | None, - ) -> int: - try: - model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) - except ValueError as exc: - raise ModelSchemaNotFoundError("Model schema not found") from exc - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - - if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}: - prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000) - else: - prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000) - - prompt_messages = prompt_transform.get_prompt( - prompt_template=prompt_template, - inputs={}, - query="", - files=[], - context=context, - memory_config=node_data.memory, - memory=None, - model_instance=model_instance, - ) - rest_tokens = 2000 - - model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) - if model_context_tokens: - model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance) - curr_message_tokens = ( - model_type_instance.get_num_tokens( - model_instance.model_name, model_instance.credentials, prompt_messages - ) - + 1000 - ) # add 1000 to ensure tool call messages - - max_tokens = 0 - for parameter_rule in model_schema.parameter_rules: - if parameter_rule.name == "max_tokens" or ( - parameter_rule.use_template and parameter_rule.use_template == "max_tokens" - ): - max_tokens = ( - model_instance.parameters.get(parameter_rule.name) - or model_instance.parameters.get(parameter_rule.use_template or "") - ) or 0 - - rest_tokens = model_context_tokens - max_tokens - curr_message_tokens - rest_tokens = max(rest_tokens, 0) - - return rest_tokens - - @property - def model_instance(self) -> ModelInstance: - return self._model_instance - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: ParameterExtractorNodeData, - ) -> Mapping[str, Sequence[str]]: - _ = graph_config # Explicitly mark as unused - variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query} - - if node_data.instruction: - selectors = variable_template_parser.extract_selectors_from_template(node_data.instruction) - for selector in selectors: - variable_mapping[selector.variable] = selector.value_selector - - variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} - - return variable_mapping diff --git a/api/dify_graph/nodes/parameter_extractor/prompts.py b/api/dify_graph/nodes/parameter_extractor/prompts.py deleted file mode 100644 index 1b29be4418d..00000000000 --- a/api/dify_graph/nodes/parameter_extractor/prompts.py +++ /dev/null @@ -1,184 +0,0 @@ -from typing import Any - -FUNCTION_CALLING_EXTRACTOR_NAME = "extract_parameters" - -FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT = f"""You are a helpful assistant tasked with extracting structured information based on specific criteria provided. Follow the guidelines below to ensure consistency and accuracy. -### Task -Always call the `{FUNCTION_CALLING_EXTRACTOR_NAME}` function with the correct parameters. Ensure that the information extraction is contextual and aligns with the provided criteria. -### Memory -Here is the chat history between the human and assistant, provided within tags: - -\x7bhistories\x7d - -### Instructions: -Some additional information is provided below. Always adhere to these instructions as closely as possible: - -\x7binstruction\x7d - -Steps: -1. Review the chat history provided within the tags. -2. Extract the relevant information based on the criteria given, output multiple values if there is multiple relevant information that match the criteria in the given text. -3. Generate a well-formatted output using the defined functions and arguments. -4. Use the `extract_parameter` function to create structured outputs with appropriate parameters. -5. Do not include any XML tags in your output. -### Example -To illustrate, if the task involves extracting a user's name and their request, your function call might look like this: Ensure your output follows a similar structure to examples. -### Final Output -Produce well-formatted function calls in json without XML tags, as shown in the example. -""" # noqa: E501 - -FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE = f"""extract structured information from context inside XML tags by calling the function {FUNCTION_CALLING_EXTRACTOR_NAME} with the correct parameters with structure inside XML tags. - -\x7bcontent\x7d - - - -\x7bstructure\x7d - -""" # noqa: E501 - -FUNCTION_CALLING_EXTRACTOR_EXAMPLE: list[dict[str, Any]] = [ - { - "user": { - "query": "What is the weather today in SF?", - "function": { - "name": FUNCTION_CALLING_EXTRACTOR_NAME, - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The location to get the weather information", - "required": True, - }, - }, - "required": ["location"], - }, - }, - }, - "assistant": { - "text": "I need always call the function with the correct parameters." - " in this case, I need to call the function with the location parameter.", - "function_call": {"name": FUNCTION_CALLING_EXTRACTOR_NAME, "parameters": {"location": "San Francisco"}}, - }, - }, - { - "user": { - "query": "I want to eat some apple pie.", - "function": { - "name": FUNCTION_CALLING_EXTRACTOR_NAME, - "parameters": { - "type": "object", - "properties": {"food": {"type": "string", "description": "The food to eat", "required": True}}, - "required": ["food"], - }, - }, - }, - "assistant": { - "text": "I need always call the function with the correct parameters." - " in this case, I need to call the function with the food parameter.", - "function_call": {"name": FUNCTION_CALLING_EXTRACTOR_NAME, "parameters": {"food": "apple pie"}}, - }, - }, -] - -COMPLETION_GENERATE_JSON_PROMPT = """### Instructions: -Some extra information are provided below, I should always follow the instructions as possible as I can. - -{instruction} - - -### Extract parameter Workflow -I need to extract the following information from the input text. The tag specifies the 'type', 'description' and 'required' of the information to be extracted. - -{{ structure }} - - -Step 1: Carefully read the input and understand the structure of the expected output. -Step 2: Extract relevant parameters from the provided text based on the name and description of object. -Step 3: Structure the extracted parameters to JSON object as specified in . -Step 4: Ensure that the JSON object is properly formatted and valid. The output should not contain any XML tags. Only the JSON object should be outputted. - -### Memory -Here are the chat histories between human and assistant, inside XML tags. - -{histories} - - -### Structure -Here is the structure of the expected output, I should always follow the output structure. -{{γγγ - 'properties1': 'relevant text extracted from input', - 'properties2': 'relevant text extracted from input', -}}γγγ - -### Input Text -Inside XML tags, there is a text that I should extract parameters and convert to a JSON object. - -{text} - - -### Answer -I should always output a valid JSON object. Output nothing other than the JSON object. -```JSON -""" # noqa: E501 - -CHAT_GENERATE_JSON_PROMPT = """You should always follow the instructions and output a valid JSON object. -The structure of the JSON object you can found in the instructions. - -### Memory -Here are the chat histories between human and assistant, inside XML tags. - -{histories} - - -### Instructions: -Some extra information are provided below, you should always follow the instructions as possible as you can. - -{instructions} - -""" - -CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE = """### Structure -Here is the structure of the JSON object, you should always follow the structure. - -{structure} - - -### Text to be converted to JSON -Inside XML tags, there is a text that you should convert to a JSON object. - -{text} - -""" - -CHAT_EXAMPLE = [ - { - "user": { - "query": "What is the weather today in SF?", - "json": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The location to get the weather information", - "required": True, - } - }, - "required": ["location"], - }, - }, - "assistant": {"text": "I need to output a valid JSON object.", "json": {"location": "San Francisco"}}, - }, - { - "user": { - "query": "I want to eat some apple pie.", - "json": { - "type": "object", - "properties": {"food": {"type": "string", "description": "The food to eat", "required": True}}, - "required": ["food"], - }, - }, - "assistant": {"text": "I need to output a valid JSON object.", "json": {"food": "apple pie"}}, - }, -] diff --git a/api/dify_graph/nodes/protocols.py b/api/dify_graph/nodes/protocols.py deleted file mode 100644 index 62d3bcdca1e..00000000000 --- a/api/dify_graph/nodes/protocols.py +++ /dev/null @@ -1,46 +0,0 @@ -from collections.abc import Generator -from typing import Any, Protocol - -import httpx - -from dify_graph.file import File -from dify_graph.file.models import ToolFile - - -class HttpClientProtocol(Protocol): - @property - def max_retries_exceeded_error(self) -> type[Exception]: ... - - @property - def request_error(self) -> type[Exception]: ... - - def get(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - - def head(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - - def post(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - - def put(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - - def delete(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - - def patch(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - - -class FileManagerProtocol(Protocol): - def download(self, f: File, /) -> bytes: ... - - -class ToolFileManagerProtocol(Protocol): - def create_file_by_raw( - self, - *, - user_id: str, - tenant_id: str, - conversation_id: str | None, - file_binary: bytes, - mimetype: str, - filename: str | None = None, - ) -> Any: ... - - def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, ToolFile | None]: ... diff --git a/api/dify_graph/nodes/question_classifier/__init__.py b/api/dify_graph/nodes/question_classifier/__init__.py deleted file mode 100644 index 4d06b6bea36..00000000000 --- a/api/dify_graph/nodes/question_classifier/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .entities import QuestionClassifierNodeData -from .question_classifier_node import QuestionClassifierNode - -__all__ = ["QuestionClassifierNode", "QuestionClassifierNodeData"] diff --git a/api/dify_graph/nodes/question_classifier/entities.py b/api/dify_graph/nodes/question_classifier/entities.py deleted file mode 100644 index 0c1601d4393..00000000000 --- a/api/dify_graph/nodes/question_classifier/entities.py +++ /dev/null @@ -1,30 +0,0 @@ -from pydantic import BaseModel, Field - -from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.llm import ModelConfig, VisionConfig - - -class ClassConfig(BaseModel): - id: str - name: str - - -class QuestionClassifierNodeData(BaseNodeData): - type: NodeType = BuiltinNodeTypes.QUESTION_CLASSIFIER - query_variable_selector: list[str] - model: ModelConfig - classes: list[ClassConfig] - instruction: str | None = None - memory: MemoryConfig | None = None - vision: VisionConfig = Field(default_factory=VisionConfig) - - @property - def structured_output_enabled(self) -> bool: - # NOTE(QuantumGhost): Temporary workaround for issue #20725 - # (https://github.com/langgenius/dify/issues/20725). - # - # The proper fix would be to make `QuestionClassifierNode` inherit - # from `BaseNode` instead of `LLMNode`. - return False diff --git a/api/dify_graph/nodes/question_classifier/exc.py b/api/dify_graph/nodes/question_classifier/exc.py deleted file mode 100644 index 2c6354e2a70..00000000000 --- a/api/dify_graph/nodes/question_classifier/exc.py +++ /dev/null @@ -1,6 +0,0 @@ -class QuestionClassifierNodeError(ValueError): - """Base class for QuestionClassifierNode errors.""" - - -class InvalidModelTypeError(QuestionClassifierNodeError): - """Raised when the model is not a Large Language Model.""" diff --git a/api/dify_graph/nodes/question_classifier/question_classifier_node.py b/api/dify_graph/nodes/question_classifier/question_classifier_node.py deleted file mode 100644 index 59d0a2a4d80..00000000000 --- a/api/dify_graph/nodes/question_classifier/question_classifier_node.py +++ /dev/null @@ -1,395 +0,0 @@ -import json -import re -from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any - -from core.model_manager import ModelInstance -from core.prompt.simple_prompt_transform import ModelMode -from core.prompt.utils.prompt_message_util import PromptMessageUtil -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import ( - BuiltinNodeTypes, - NodeExecutionType, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from dify_graph.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole -from dify_graph.model_runtime.memory import PromptMessageMemory -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.node_events import ModelInvokeCompletedEvent, NodeRunResult -from dify_graph.nodes.base.entities import VariableSelector -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser -from dify_graph.nodes.llm import ( - LLMNode, - LLMNodeChatModelMessage, - LLMNodeCompletionModelPromptTemplate, - llm_utils, -) -from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer -from dify_graph.nodes.protocols import HttpClientProtocol -from libs.json_in_md_parser import parse_and_check_json_markdown - -from .entities import QuestionClassifierNodeData -from .exc import InvalidModelTypeError -from .template_prompts import ( - QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1, - QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2, - QUESTION_CLASSIFIER_COMPLETION_PROMPT, - QUESTION_CLASSIFIER_SYSTEM_PROMPT, - QUESTION_CLASSIFIER_USER_PROMPT_1, - QUESTION_CLASSIFIER_USER_PROMPT_2, - QUESTION_CLASSIFIER_USER_PROMPT_3, -) - -if TYPE_CHECKING: - from dify_graph.file.models import File - from dify_graph.runtime import GraphRuntimeState - - -class QuestionClassifierNode(Node[QuestionClassifierNodeData]): - node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER - execution_type = NodeExecutionType.BRANCH - - _file_outputs: list["File"] - _llm_file_saver: LLMFileSaver - _credentials_provider: "CredentialsProvider" - _model_factory: "ModelFactory" - _model_instance: ModelInstance - _memory: PromptMessageMemory | None - _template_renderer: TemplateRenderer - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - credentials_provider: "CredentialsProvider", - model_factory: "ModelFactory", - model_instance: ModelInstance, - http_client: HttpClientProtocol, - template_renderer: TemplateRenderer, - memory: PromptMessageMemory | None = None, - llm_file_saver: LLMFileSaver | None = None, - ): - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - # LLM file outputs, used for MultiModal outputs. - self._file_outputs = [] - - self._credentials_provider = credentials_provider - self._model_factory = model_factory - self._model_instance = model_instance - self._memory = memory - self._template_renderer = template_renderer - - if llm_file_saver is None: - dify_ctx = self.require_dify_context() - llm_file_saver = FileSaverImpl( - user_id=dify_ctx.user_id, - tenant_id=dify_ctx.tenant_id, - http_client=http_client, - ) - self._llm_file_saver = llm_file_saver - - @classmethod - def version(cls): - return "1" - - def _run(self): - node_data = self.node_data - variable_pool = self.graph_runtime_state.variable_pool - - # extract variables - variable = variable_pool.get(node_data.query_variable_selector) if node_data.query_variable_selector else None - query = variable.value if variable else None - variables = {"query": query} - # fetch model instance - model_instance = self._model_instance - memory = self._memory - # fetch instruction - node_data.instruction = node_data.instruction or "" - node_data.instruction = variable_pool.convert_template(node_data.instruction).text - - files = ( - llm_utils.fetch_files( - variable_pool=variable_pool, - selector=node_data.vision.configs.variable_selector, - ) - if node_data.vision.enabled - else [] - ) - - # fetch prompt messages - rest_token = self._calculate_rest_token( - node_data=node_data, - query=query or "", - model_instance=model_instance, - context="", - ) - prompt_template = self._get_prompt_template( - node_data=node_data, - query=query or "", - memory=memory, - max_token_limit=rest_token, - ) - # Some models (e.g. Gemma, Mistral) force roles alternation (user/assistant/user/assistant...). - # If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt, - # two consecutive user prompts will be generated, causing model's error. - # To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end. - prompt_messages, stop = llm_utils.fetch_prompt_messages( - prompt_template=prompt_template, - sys_query="", - memory=memory, - model_instance=model_instance, - stop=model_instance.stop, - sys_files=files, - vision_enabled=node_data.vision.enabled, - vision_detail=node_data.vision.configs.detail, - variable_pool=variable_pool, - jinja2_variables=[], - template_renderer=self._template_renderer, - ) - - result_text = "" - usage = LLMUsage.empty_usage() - finish_reason = None - - try: - # handle invoke result - generator = LLMNode.invoke_llm( - model_instance=model_instance, - prompt_messages=prompt_messages, - stop=stop, - user_id=self.require_dify_context().user_id, - structured_output_enabled=False, - structured_output=None, - file_saver=self._llm_file_saver, - file_outputs=self._file_outputs, - node_id=self._node_id, - node_type=self.node_type, - ) - - for event in generator: - if isinstance(event, ModelInvokeCompletedEvent): - result_text = event.text - usage = event.usage - finish_reason = event.finish_reason - break - - rendered_classes = [ - c.model_copy(update={"name": variable_pool.convert_template(c.name).text}) for c in node_data.classes - ] - - category_name = rendered_classes[0].name - category_id = rendered_classes[0].id - if "" in result_text: - result_text = re.sub(r"]*>[\s\S]*?", "", result_text, flags=re.IGNORECASE) - result_text_json = parse_and_check_json_markdown(result_text, []) - # result_text_json = json.loads(result_text.strip('```JSON\n')) - if "category_name" in result_text_json and "category_id" in result_text_json: - category_id_result = result_text_json["category_id"] - classes = rendered_classes - classes_map = {class_.id: class_.name for class_ in classes} - category_ids = [_class.id for _class in classes] - if category_id_result in category_ids: - category_name = classes_map[category_id_result] - category_id = category_id_result - process_data = { - "model_mode": node_data.model.mode, - "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( - model_mode=node_data.model.mode, prompt_messages=prompt_messages - ), - "usage": jsonable_encoder(usage), - "finish_reason": finish_reason, - "model_provider": model_instance.provider, - "model_name": model_instance.model_name, - } - outputs = { - "class_name": category_name, - "class_id": category_id, - "usage": jsonable_encoder(usage), - } - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=variables, - process_data=process_data, - outputs=outputs, - edge_source_handle=category_id, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - }, - llm_usage=usage, - ) - except ValueError as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=variables, - error=str(e), - error_type=type(e).__name__, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - }, - llm_usage=usage, - ) - - @property - def model_instance(self) -> ModelInstance: - return self._model_instance - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: QuestionClassifierNodeData, - ) -> Mapping[str, Sequence[str]]: - # graph_config is not used in this node type - variable_mapping = {"query": node_data.query_variable_selector} - variable_selectors: list[VariableSelector] = [] - if node_data.instruction: - variable_template_parser = VariableTemplateParser(template=node_data.instruction) - variable_selectors.extend(variable_template_parser.extract_variable_selectors()) - for variable_selector in variable_selectors: - variable_mapping[variable_selector.variable] = list(variable_selector.value_selector) - - variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} - - return variable_mapping - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - """ - Get default config of node. - :param filters: filter by node config parameters (not used in this implementation). - :return: - """ - # filters parameter is not used in this node type - return {"type": "question-classifier", "config": {"instructions": ""}} - - def _calculate_rest_token( - self, - node_data: QuestionClassifierNodeData, - query: str, - model_instance: ModelInstance, - context: str | None, - ) -> int: - model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) - - prompt_template = self._get_prompt_template(node_data, query, None, 2000) - prompt_messages, _ = llm_utils.fetch_prompt_messages( - prompt_template=prompt_template, - sys_query="", - sys_files=[], - context=context, - memory=None, - model_instance=model_instance, - stop=model_instance.stop, - memory_config=node_data.memory, - vision_enabled=False, - vision_detail=node_data.vision.configs.detail, - variable_pool=self.graph_runtime_state.variable_pool, - jinja2_variables=[], - template_renderer=self._template_renderer, - ) - rest_tokens = 2000 - - model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) - if model_context_tokens: - curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) - - max_tokens = 0 - for parameter_rule in model_schema.parameter_rules: - if parameter_rule.name == "max_tokens" or ( - parameter_rule.use_template and parameter_rule.use_template == "max_tokens" - ): - max_tokens = ( - model_instance.parameters.get(parameter_rule.name) - or model_instance.parameters.get(parameter_rule.use_template or "") - ) or 0 - - rest_tokens = model_context_tokens - max_tokens - curr_message_tokens - rest_tokens = max(rest_tokens, 0) - - return rest_tokens - - def _get_prompt_template( - self, - node_data: QuestionClassifierNodeData, - query: str, - memory: PromptMessageMemory | None, - max_token_limit: int = 2000, - ): - model_mode = ModelMode(node_data.model.mode) - classes = node_data.classes - categories = [] - for class_ in classes: - category = {"category_id": class_.id, "category_name": class_.name} - categories.append(category) - instruction = node_data.instruction or "" - input_text = query - memory_str = "" - if memory: - memory_str = llm_utils.fetch_memory_text( - memory=memory, - max_token_limit=max_token_limit, - message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None, - ) - prompt_messages: list[LLMNodeChatModelMessage] = [] - if model_mode == ModelMode.CHAT: - system_prompt_messages = LLMNodeChatModelMessage( - role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str) - ) - prompt_messages.append(system_prompt_messages) - user_prompt_message_1 = LLMNodeChatModelMessage( - role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_1 - ) - prompt_messages.append(user_prompt_message_1) - assistant_prompt_message_1 = LLMNodeChatModelMessage( - role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 - ) - prompt_messages.append(assistant_prompt_message_1) - user_prompt_message_2 = LLMNodeChatModelMessage( - role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_2 - ) - prompt_messages.append(user_prompt_message_2) - assistant_prompt_message_2 = LLMNodeChatModelMessage( - role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 - ) - prompt_messages.append(assistant_prompt_message_2) - user_prompt_message_3 = LLMNodeChatModelMessage( - role=PromptMessageRole.USER, - text=QUESTION_CLASSIFIER_USER_PROMPT_3.format( - input_text=input_text, - categories=json.dumps(categories, ensure_ascii=False), - classification_instructions=instruction, - ), - ) - prompt_messages.append(user_prompt_message_3) - return prompt_messages - elif model_mode == ModelMode.COMPLETION: - return LLMNodeCompletionModelPromptTemplate( - text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format( - histories=memory_str, - input_text=input_text, - categories=json.dumps(categories, ensure_ascii=False), - classification_instructions=instruction, - ) - ) - - else: - raise InvalidModelTypeError(f"Model mode {model_mode} not support.") diff --git a/api/dify_graph/nodes/question_classifier/template_prompts.py b/api/dify_graph/nodes/question_classifier/template_prompts.py deleted file mode 100644 index a615c323836..00000000000 --- a/api/dify_graph/nodes/question_classifier/template_prompts.py +++ /dev/null @@ -1,76 +0,0 @@ -QUESTION_CLASSIFIER_SYSTEM_PROMPT = """ -### Job Description', -You are a text classification engine that analyzes text data and assigns categories based on user input or automatically determined categories. -### Task -Your task is to assign one categories ONLY to the input text and only one category may be assigned returned in the output. Additionally, you need to extract the key words from the text that are related to the classification. -### Format -The input text is in the variable input_text. Categories are specified as a category list with two filed category_id and category_name in the variable categories. Classification instructions may be included to improve the classification accuracy. -### Constraint -DO NOT include anything other than the JSON array in your response. -### Memory -Here are the chat histories between human and assistant, inside XML tags. - -{histories} - -""" # noqa: E501 - -QUESTION_CLASSIFIER_USER_PROMPT_1 = """ - {"input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."], - "categories": [{"category_id":"f5660049-284f-41a7-b301-fd24176a711c","category_name":"Customer Service"},{"category_id":"8d007d06-f2c9-4be5-8ff6-cd4381c13c60","category_name":"Satisfaction"},{"category_id":"5fbbbb18-9843-466d-9b8e-b9bfbb9482c8","category_name":"Sales"},{"category_id":"23623c75-7184-4a2e-8226-466c2e4631e4","category_name":"Product"}], - "classification_instructions": ["classify the text based on the feedback provided by customer"]} -""" # noqa: E501 - -QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 = """ -```json - {"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"], - "category_id": "f5660049-284f-41a7-b301-fd24176a711c", - "category_name": "Customer Service"} -``` -""" - -QUESTION_CLASSIFIER_USER_PROMPT_2 = """ - {"input_text": ["bad service, slow to bring the food"], - "categories": [{"category_id":"80fb86a0-4454-4bf5-924c-f253fdd83c02","category_name":"Food Quality"},{"category_id":"f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name":"Experience"},{"category_id":"cc771f63-74e7-4c61-882e-3eda9d8ba5d7","category_name":"Price"}], - "classification_instructions": []} -""" # noqa: E501 - -QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 = """ -```json - {"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"], - "category_id": "f6ff5bc3-aca0-4e4a-8627-e760d0aca78f", - "category_name": "Experience"} -``` -""" - -QUESTION_CLASSIFIER_USER_PROMPT_3 = """ - {{"input_text": ["{input_text}"], - "categories": {categories}, - "classification_instructions": ["{classification_instructions}"]}} -""" - -QUESTION_CLASSIFIER_COMPLETION_PROMPT = """ -### Job Description -You are a text classification engine that analyzes text data and assigns categories based on user input or automatically determined categories. -### Task -Your task is to assign one categories ONLY to the input text and only one category may be assigned returned in the output. Additionally, you need to extract the key words from the text that are related to the classification. -### Format -The input text is in the variable input_text. Categories are specified as a category list with two filed category_id and category_name in the variable categories. Classification instructions may be included to improve the classification accuracy. -### Constraint -DO NOT include anything other than the JSON array in your response. -### Example -Here is the chat example between human and assistant, inside XML tags. - -User:{{"input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."], "categories": [{{"category_id":"f5660049-284f-41a7-b301-fd24176a711c","category_name":"Customer Service"}},{{"category_id":"8d007d06-f2c9-4be5-8ff6-cd4381c13c60","category_name":"Satisfaction"}},{{"category_id":"5fbbbb18-9843-466d-9b8e-b9bfbb9482c8","category_name":"Sales"}},{{"category_id":"23623c75-7184-4a2e-8226-466c2e4631e4","category_name":"Product"}}], "classification_instructions": ["classify the text based on the feedback provided by customer"]}} -Assistant:{{"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"],"category_id": "f5660049-284f-41a7-b301-fd24176a711c","category_name": "Customer Service"}} -User:{{"input_text": ["bad service, slow to bring the food"], "categories": [{{"category_id":"80fb86a0-4454-4bf5-924c-f253fdd83c02","category_name":"Food Quality"}},{{"category_id":"f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name":"Experience"}},{{"category_id":"cc771f63-74e7-4c61-882e-3eda9d8ba5d7","category_name":"Price"}}], "classification_instructions": []}} -Assistant:{{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],"category_id": "f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name": "Experience"}} - -### Memory -Here are the chat histories between human and assistant, inside XML tags. - -{histories} - -### User Input -{{"input_text" : ["{input_text}"], "categories" : {categories},"classification_instruction" : ["{classification_instructions}"]}} -### Assistant Output -""" # noqa: E501 diff --git a/api/dify_graph/nodes/start/__init__.py b/api/dify_graph/nodes/start/__init__.py deleted file mode 100644 index 54117804231..00000000000 --- a/api/dify_graph/nodes/start/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .start_node import StartNode - -__all__ = ["StartNode"] diff --git a/api/dify_graph/nodes/start/entities.py b/api/dify_graph/nodes/start/entities.py deleted file mode 100644 index 92ebd1a2ec5..00000000000 --- a/api/dify_graph/nodes/start/entities.py +++ /dev/null @@ -1,16 +0,0 @@ -from collections.abc import Sequence - -from pydantic import Field - -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.variables.input_entities import VariableEntity - - -class StartNodeData(BaseNodeData): - """ - Start Node Data - """ - - type: NodeType = BuiltinNodeTypes.START - variables: Sequence[VariableEntity] = Field(default_factory=list) diff --git a/api/dify_graph/nodes/start/start_node.py b/api/dify_graph/nodes/start/start_node.py deleted file mode 100644 index 5e6055ea345..00000000000 --- a/api/dify_graph/nodes/start/start_node.py +++ /dev/null @@ -1,63 +0,0 @@ -from typing import Any - -from jsonschema import Draft7Validator, ValidationError - -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.variables.input_entities import VariableEntityType - - -class StartNode(Node[StartNodeData]): - node_type = BuiltinNodeTypes.START - execution_type = NodeExecutionType.ROOT - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) - self._validate_and_normalize_json_object_inputs(node_inputs) - system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict() - - # TODO: System variables should be directly accessible, no need for special handling - # Set system variables as node outputs. - for var in system_inputs: - node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var] - outputs = dict(node_inputs) - - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=outputs) - - def _validate_and_normalize_json_object_inputs(self, node_inputs: dict[str, Any]) -> None: - for variable in self.node_data.variables: - if variable.type != VariableEntityType.JSON_OBJECT: - continue - - key = variable.variable - value = node_inputs.get(key) - - if value is None and variable.required: - raise ValueError(f"{key} is required in input form") - - # If no value provided, skip further processing for this key - if not value: - continue - - if not isinstance(value, dict): - raise ValueError(f"JSON object for '{key}' must be an object") - - # Overwrite with normalized dict to ensure downstream consistency - node_inputs[key] = value - - # If schema exists, then validate against it - schema = variable.json_schema - if not schema: - continue - - try: - Draft7Validator(schema).validate(value) - except ValidationError as e: - raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}") diff --git a/api/dify_graph/nodes/template_transform/__init__.py b/api/dify_graph/nodes/template_transform/__init__.py deleted file mode 100644 index 43863b9d59a..00000000000 --- a/api/dify_graph/nodes/template_transform/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .template_transform_node import TemplateTransformNode - -__all__ = ["TemplateTransformNode"] diff --git a/api/dify_graph/nodes/template_transform/entities.py b/api/dify_graph/nodes/template_transform/entities.py deleted file mode 100644 index ac292399587..00000000000 --- a/api/dify_graph/nodes/template_transform/entities.py +++ /dev/null @@ -1,13 +0,0 @@ -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base.entities import VariableSelector - - -class TemplateTransformNodeData(BaseNodeData): - """ - Template Transform Node Data. - """ - - type: NodeType = BuiltinNodeTypes.TEMPLATE_TRANSFORM - variables: list[VariableSelector] - template: str diff --git a/api/dify_graph/nodes/template_transform/template_renderer.py b/api/dify_graph/nodes/template_transform/template_renderer.py deleted file mode 100644 index 9b679d4497d..00000000000 --- a/api/dify_graph/nodes/template_transform/template_renderer.py +++ /dev/null @@ -1,41 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping -from typing import Any, Protocol - -from dify_graph.nodes.code.code_node import WorkflowCodeExecutor -from dify_graph.nodes.code.entities import CodeLanguage - - -class TemplateRenderError(ValueError): - """Raised when rendering a Jinja2 template fails.""" - - -class Jinja2TemplateRenderer(Protocol): - """Render Jinja2 templates for template transform nodes.""" - - def render_template(self, template: str, variables: Mapping[str, Any]) -> str: - """Render a Jinja2 template with provided variables.""" - raise NotImplementedError - - -class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer): - """Adapter that renders Jinja2 templates via CodeExecutor.""" - - _code_executor: WorkflowCodeExecutor - - def __init__(self, code_executor: WorkflowCodeExecutor) -> None: - self._code_executor = code_executor - - def render_template(self, template: str, variables: Mapping[str, Any]) -> str: - try: - result = self._code_executor.execute(language=CodeLanguage.JINJA2, code=template, inputs=variables) - except Exception as exc: - if self._code_executor.is_execution_error(exc): - raise TemplateRenderError(str(exc)) from exc - raise - - rendered = result.get("result") - if not isinstance(rendered, str): - raise TemplateRenderError("Template render result must be a string.") - return rendered diff --git a/api/dify_graph/nodes/template_transform/template_transform_node.py b/api/dify_graph/nodes/template_transform/template_transform_node.py deleted file mode 100644 index dc6fce2b0aa..00000000000 --- a/api/dify_graph/nodes/template_transform/template_transform_node.py +++ /dev/null @@ -1,95 +0,0 @@ -from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any - -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.template_transform.entities import TemplateTransformNodeData -from dify_graph.nodes.template_transform.template_renderer import ( - Jinja2TemplateRenderer, - TemplateRenderError, -) - -if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState - -DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH = 400_000 - - -class TemplateTransformNode(Node[TemplateTransformNodeData]): - node_type = BuiltinNodeTypes.TEMPLATE_TRANSFORM - _template_renderer: Jinja2TemplateRenderer - _max_output_length: int - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - template_renderer: Jinja2TemplateRenderer, - max_output_length: int | None = None, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - self._template_renderer = template_renderer - - if max_output_length is not None and max_output_length <= 0: - raise ValueError("max_output_length must be a positive integer") - self._max_output_length = max_output_length or DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - """ - Get default config of node. - :param filters: filter by node config parameters. - :return: - """ - return { - "type": "template-transform", - "config": {"variables": [{"variable": "arg1", "value_selector": []}], "template": "{{ arg1 }}"}, - } - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - # Get variables - variables: dict[str, Any] = {} - for variable_selector in self.node_data.variables: - variable_name = variable_selector.variable - value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) - variables[variable_name] = value.to_object() if value else None - # Run code - try: - rendered = self._template_renderer.render_template(self.node_data.template, variables) - except TemplateRenderError as e: - return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e)) - - if len(rendered) > self._max_output_length: - return NodeRunResult( - inputs=variables, - status=WorkflowNodeExecutionStatus.FAILED, - error=f"Output length exceeds {self._max_output_length} characters", - ) - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": rendered} - ) - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData - ) -> Mapping[str, Sequence[str]]: - return { - node_id + "." + variable_selector.variable: variable_selector.value_selector - for variable_selector in node_data.variables - } diff --git a/api/dify_graph/nodes/tool/__init__.py b/api/dify_graph/nodes/tool/__init__.py deleted file mode 100644 index f4982e655d1..00000000000 --- a/api/dify_graph/nodes/tool/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .tool_node import ToolNode - -__all__ = ["ToolNode"] diff --git a/api/dify_graph/nodes/tool/entities.py b/api/dify_graph/nodes/tool/entities.py deleted file mode 100644 index b041ee66fda..00000000000 --- a/api/dify_graph/nodes/tool/entities.py +++ /dev/null @@ -1,87 +0,0 @@ -from typing import Any, Literal, Union - -from pydantic import BaseModel, field_validator -from pydantic_core.core_schema import ValidationInfo - -from core.tools.entities.tool_entities import ToolProviderType -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType - - -class ToolEntity(BaseModel): - provider_id: str - provider_type: ToolProviderType - provider_name: str # redundancy - tool_name: str - tool_label: str # redundancy - tool_configurations: dict[str, Any] - credential_id: str | None = None - plugin_unique_identifier: str | None = None # redundancy - - @field_validator("tool_configurations", mode="before") - @classmethod - def validate_tool_configurations(cls, value, values: ValidationInfo): - if not isinstance(value, dict): - raise ValueError("tool_configurations must be a dictionary") - - for key in values.data.get("tool_configurations", {}): - value = values.data.get("tool_configurations", {}).get(key) - if not isinstance(value, str | int | float | bool): - raise ValueError(f"{key} must be a string") - - return value - - -class ToolNodeData(BaseNodeData, ToolEntity): - type: NodeType = BuiltinNodeTypes.TOOL - - class ToolInput(BaseModel): - # TODO: check this type - value: Union[Any, list[str]] - type: Literal["mixed", "variable", "constant"] - - @field_validator("type", mode="before") - @classmethod - def check_type(cls, value, validation_info: ValidationInfo): - typ = value - value = validation_info.data.get("value") - - if value is None: - return typ - - if typ == "mixed" and not isinstance(value, str): - raise ValueError("value must be a string") - elif typ == "variable": - if not isinstance(value, list): - raise ValueError("value must be a list") - for val in value: - if not isinstance(val, str): - raise ValueError("value must be a list of strings") - elif typ == "constant" and not isinstance(value, (allowed_types := (str, int, float, bool, dict, list))): - raise ValueError(f"value must be one of: {', '.join(t.__name__ for t in allowed_types)}") - return typ - - tool_parameters: dict[str, ToolInput] - # The version of the tool parameter. - # If this value is None, it indicates this is a previous version - # and requires using the legacy parameter parsing rules. - tool_node_version: str | None = None - - @field_validator("tool_parameters", mode="before") - @classmethod - def filter_none_tool_inputs(cls, value): - if not isinstance(value, dict): - return value - - return { - key: tool_input - for key, tool_input in value.items() - if tool_input is not None and cls._has_valid_value(tool_input) - } - - @staticmethod - def _has_valid_value(tool_input): - """Check if the value is valid""" - if isinstance(tool_input, dict): - return tool_input.get("value") is not None - return getattr(tool_input, "value", None) is not None diff --git a/api/dify_graph/nodes/tool/exc.py b/api/dify_graph/nodes/tool/exc.py deleted file mode 100644 index 7212e8bfc07..00000000000 --- a/api/dify_graph/nodes/tool/exc.py +++ /dev/null @@ -1,16 +0,0 @@ -class ToolNodeError(ValueError): - """Base exception for tool node errors.""" - - pass - - -class ToolParameterError(ToolNodeError): - """Exception raised for errors in tool parameters.""" - - pass - - -class ToolFileError(ToolNodeError): - """Exception raised for errors related to tool files.""" - - pass diff --git a/api/dify_graph/nodes/tool/tool_node.py b/api/dify_graph/nodes/tool/tool_node.py deleted file mode 100644 index 598f0da92ef..00000000000 --- a/api/dify_graph/nodes/tool/tool_node.py +++ /dev/null @@ -1,524 +0,0 @@ -from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any - -from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler -from core.tools.__base.tool import Tool -from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter -from core.tools.errors import ToolInvokeError -from core.tools.tool_engine import ToolEngine -from core.tools.utils.message_transformer import ToolFileMessageTransformer -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import ( - BuiltinNodeTypes, - SystemVariableKey, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from dify_graph.file import File, FileTransferMethod -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser -from dify_graph.nodes.protocols import ToolFileManagerProtocol -from dify_graph.variables.segments import ArrayAnySegment, ArrayFileSegment -from dify_graph.variables.variables import ArrayAnyVariable -from factories import file_factory -from services.tools.builtin_tools_manage_service import BuiltinToolManageService - -from .entities import ToolNodeData -from .exc import ( - ToolFileError, - ToolNodeError, - ToolParameterError, -) - -if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool - - -class ToolNode(Node[ToolNodeData]): - """ - Tool Node - """ - - node_type = BuiltinNodeTypes.TOOL - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - tool_file_manager_factory: ToolFileManagerProtocol, - ): - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - self._tool_file_manager_factory = tool_file_manager_factory - - @classmethod - def version(cls) -> str: - return "1" - - def populate_start_event(self, event) -> None: - event.provider_id = self.node_data.provider_id - event.provider_type = self.node_data.provider_type - - def _run(self) -> Generator[NodeEventBase, None, None]: - """ - Run the tool node - """ - from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError - - dify_ctx = self.require_dify_context() - - # fetch tool icon - tool_info = { - "provider_type": self.node_data.provider_type.value, - "provider_id": self.node_data.provider_id, - "plugin_unique_identifier": self.node_data.plugin_unique_identifier, - } - - # get tool runtime - try: - from core.tools.tool_manager import ToolManager - - # This is an issue that caused problems before. - # Logically, we shouldn't use the node_data.version field for judgment - # But for backward compatibility with historical data - # this version field judgment is still preserved here. - variable_pool: VariablePool | None = None - if self.node_data.version != "1" or self.node_data.tool_node_version is not None: - variable_pool = self.graph_runtime_state.variable_pool - tool_runtime = ToolManager.get_workflow_tool_runtime( - dify_ctx.tenant_id, - dify_ctx.app_id, - self._node_id, - self.node_data, - dify_ctx.invoke_from, - variable_pool, - ) - except ToolNodeError as e: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs={}, - metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, - error=f"Failed to get tool runtime: {str(e)}", - error_type=type(e).__name__, - ) - ) - return - - # get parameters - tool_parameters = tool_runtime.get_merged_runtime_parameters() or [] - parameters = self._generate_parameters( - tool_parameters=tool_parameters, - variable_pool=self.graph_runtime_state.variable_pool, - node_data=self.node_data, - ) - parameters_for_log = self._generate_parameters( - tool_parameters=tool_parameters, - variable_pool=self.graph_runtime_state.variable_pool, - node_data=self.node_data, - for_log=True, - ) - # get conversation id - conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) - - try: - message_stream = ToolEngine.generic_invoke( - tool=tool_runtime, - tool_parameters=parameters, - user_id=dify_ctx.user_id, - workflow_tool_callback=DifyWorkflowCallbackHandler(), - workflow_call_depth=self.workflow_call_depth, - app_id=dify_ctx.app_id, - conversation_id=conversation_id.text if conversation_id else None, - ) - except ToolNodeError as e: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=parameters_for_log, - metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, - error=f"Failed to invoke tool: {str(e)}", - error_type=type(e).__name__, - ) - ) - return - - try: - # convert tool messages - _ = yield from self._transform_message( - messages=message_stream, - tool_info=tool_info, - parameters_for_log=parameters_for_log, - user_id=dify_ctx.user_id, - tenant_id=dify_ctx.tenant_id, - node_id=self._node_id, - tool_runtime=tool_runtime, - ) - except ToolInvokeError as e: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=parameters_for_log, - metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, - error=f"Failed to invoke tool {self.node_data.provider_name}: {str(e)}", - error_type=type(e).__name__, - ) - ) - except PluginInvokeError as e: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=parameters_for_log, - metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, - error=e.to_user_friendly_error(plugin_name=self.node_data.provider_name), - error_type=type(e).__name__, - ) - ) - except PluginDaemonClientSideError as e: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=parameters_for_log, - metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, - error=f"Failed to invoke tool, error: {e.description}", - error_type=type(e).__name__, - ) - ) - - def _generate_parameters( - self, - *, - tool_parameters: Sequence[ToolParameter], - variable_pool: "VariablePool", - node_data: ToolNodeData, - for_log: bool = False, - ) -> dict[str, Any]: - """ - Generate parameters based on the given tool parameters, variable pool, and node data. - - Args: - tool_parameters (Sequence[ToolParameter]): The list of tool parameters. - variable_pool (VariablePool): The variable pool containing the variables. - node_data (ToolNodeData): The data associated with the tool node. - - Returns: - Mapping[str, Any]: A dictionary containing the generated parameters. - - """ - tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters} - - result: dict[str, Any] = {} - for parameter_name in node_data.tool_parameters: - parameter = tool_parameters_dictionary.get(parameter_name) - if not parameter: - result[parameter_name] = None - continue - tool_input = node_data.tool_parameters[parameter_name] - if tool_input.type == "variable": - variable = variable_pool.get(tool_input.value) - if variable is None: - if parameter.required: - raise ToolParameterError(f"Variable {tool_input.value} does not exist") - continue - parameter_value = variable.value - elif tool_input.type in {"mixed", "constant"}: - segment_group = variable_pool.convert_template(str(tool_input.value)) - parameter_value = segment_group.log if for_log else segment_group.text - else: - raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'") - result[parameter_name] = parameter_value - - return result - - def _fetch_files(self, variable_pool: "VariablePool") -> list[File]: - variable = variable_pool.get(["sys", SystemVariableKey.FILES]) - assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) - return list(variable.value) if variable else [] - - def _transform_message( - self, - messages: Generator[ToolInvokeMessage, None, None], - tool_info: Mapping[str, Any], - parameters_for_log: dict[str, Any], - user_id: str, - tenant_id: str, - node_id: str, - tool_runtime: Tool, - ) -> Generator[NodeEventBase, None, LLMUsage]: - """ - Convert ToolInvokeMessages into tuple[plain_text, files] - """ - # transform message and handle file storage - from core.plugin.impl.plugin import PluginInstaller - - message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages( - messages=messages, - user_id=user_id, - tenant_id=tenant_id, - conversation_id=None, - ) - - text = "" - files: list[File] = [] - json: list[dict | list] = [] - - variables: dict[str, Any] = {} - - for message in message_stream: - if message.type in { - ToolInvokeMessage.MessageType.IMAGE_LINK, - ToolInvokeMessage.MessageType.BINARY_LINK, - ToolInvokeMessage.MessageType.IMAGE, - }: - assert isinstance(message.message, ToolInvokeMessage.TextMessage) - - url = message.message.text - if message.meta: - transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) - else: - transfer_method = FileTransferMethod.TOOL_FILE - - tool_file_id = str(url).split("/")[-1].split(".")[0] - - _, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id) - if not tool_file: - raise ToolFileError(f"tool file {tool_file_id} not found") - - mapping = { - "tool_file_id": tool_file_id, - "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype), - "transfer_method": transfer_method, - "url": url, - } - file = file_factory.build_from_mapping( - mapping=mapping, - tenant_id=tenant_id, - ) - files.append(file) - elif message.type == ToolInvokeMessage.MessageType.BLOB: - # get tool file id - assert isinstance(message.message, ToolInvokeMessage.TextMessage) - assert message.meta - - tool_file_id = message.message.text.split("/")[-1].split(".")[0] - _, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id) - if not tool_file: - raise ToolFileError(f"tool file {tool_file_id} not exists") - - mapping = { - "tool_file_id": tool_file_id, - "transfer_method": FileTransferMethod.TOOL_FILE, - } - - files.append( - file_factory.build_from_mapping( - mapping=mapping, - tenant_id=tenant_id, - ) - ) - elif message.type == ToolInvokeMessage.MessageType.TEXT: - assert isinstance(message.message, ToolInvokeMessage.TextMessage) - text += message.message.text - yield StreamChunkEvent( - selector=[node_id, "text"], - chunk=message.message.text, - is_final=False, - ) - elif message.type == ToolInvokeMessage.MessageType.JSON: - assert isinstance(message.message, ToolInvokeMessage.JsonMessage) - # JSON message handling for tool node - if message.message.json_object: - json.append(message.message.json_object) - elif message.type == ToolInvokeMessage.MessageType.LINK: - assert isinstance(message.message, ToolInvokeMessage.TextMessage) - - # Check if this LINK message is a file link - file_obj = (message.meta or {}).get("file") - if isinstance(file_obj, File): - files.append(file_obj) - stream_text = f"File: {message.message.text}\n" - else: - stream_text = f"Link: {message.message.text}\n" - - text += stream_text - yield StreamChunkEvent( - selector=[node_id, "text"], - chunk=stream_text, - is_final=False, - ) - elif message.type == ToolInvokeMessage.MessageType.VARIABLE: - assert isinstance(message.message, ToolInvokeMessage.VariableMessage) - variable_name = message.message.variable_name - variable_value = message.message.variable_value - if message.message.stream: - if not isinstance(variable_value, str): - raise ToolNodeError("When 'stream' is True, 'variable_value' must be a string.") - if variable_name not in variables: - variables[variable_name] = "" - variables[variable_name] += variable_value - - yield StreamChunkEvent( - selector=[node_id, variable_name], - chunk=variable_value, - is_final=False, - ) - else: - variables[variable_name] = variable_value - elif message.type == ToolInvokeMessage.MessageType.FILE: - assert message.meta is not None - assert isinstance(message.meta, dict) - # Validate that meta contains a 'file' key - if "file" not in message.meta: - raise ToolNodeError("File message is missing 'file' key in meta") - - # Validate that the file is an instance of File - if not isinstance(message.meta["file"], File): - raise ToolNodeError(f"Expected File object but got {type(message.meta['file']).__name__}") - files.append(message.meta["file"]) - elif message.type == ToolInvokeMessage.MessageType.LOG: - assert isinstance(message.message, ToolInvokeMessage.LogMessage) - if message.message.metadata: - icon = tool_info.get("icon", "") - dict_metadata = dict(message.message.metadata) - if dict_metadata.get("provider"): - manager = PluginInstaller() - plugins = manager.list_plugins(tenant_id) - try: - current_plugin = next( - plugin - for plugin in plugins - if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"] - ) - icon = current_plugin.declaration.icon - except StopIteration: - pass - icon_dark = None - try: - builtin_tool = next( - provider - for provider in BuiltinToolManageService.list_builtin_tools( - user_id, - tenant_id, - ) - if provider.name == dict_metadata["provider"] - ) - icon = builtin_tool.icon - icon_dark = builtin_tool.icon_dark - except StopIteration: - pass - - dict_metadata["icon"] = icon - dict_metadata["icon_dark"] = icon_dark - message.message.metadata = dict_metadata - - # Add agent_logs to outputs['json'] to ensure frontend can access thinking process - json_output: list[dict[str, Any] | list[Any]] = [] - - # Step 2: normalize JSON into {"data": [...]}.change json to list[dict] - if json: - json_output.extend(json) - else: - json_output.append({"data": []}) - - # Send final chunk events for all streamed outputs - # Final chunk for text stream - yield StreamChunkEvent( - selector=[self._node_id, "text"], - chunk="", - is_final=True, - ) - - # Final chunks for any streamed variables - for var_name in variables: - yield StreamChunkEvent( - selector=[self._node_id, var_name], - chunk="", - is_final=True, - ) - - usage = self._extract_tool_usage(tool_runtime) - - metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = { - WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, - } - if isinstance(usage.total_tokens, int) and usage.total_tokens > 0: - metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens - metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price - metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables}, - metadata=metadata, - inputs=parameters_for_log, - llm_usage=usage, - ) - ) - - return usage - - @staticmethod - def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage: - # Avoid importing WorkflowTool at module import time; rely on duck typing - # Some runtimes expose `latest_usage`; mocks may synthesize arbitrary attributes. - latest = getattr(tool_runtime, "latest_usage", None) - # Normalize into a concrete LLMUsage. MagicMock returns truthy attribute objects - # for any name, so we must type-check here. - if isinstance(latest, LLMUsage): - return latest - if isinstance(latest, dict): - # Allow dict payloads from external runtimes - return LLMUsage.model_validate(latest) - # Fallback to empty usage when attribute is missing or not a valid payload - return LLMUsage.empty_usage() - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: ToolNodeData, - ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ - _ = graph_config # Explicitly mark as unused - typed_node_data = node_data - result = {} - for parameter_name in typed_node_data.tool_parameters: - input = typed_node_data.tool_parameters[parameter_name] - match input.type: - case "mixed": - assert isinstance(input.value, str) - selectors = VariableTemplateParser(input.value).extract_variable_selectors() - for selector in selectors: - result[selector.variable] = selector.value_selector - case "variable": - selector_key = ".".join(input.value) - result[f"#{selector_key}#"] = input.value - case "constant": - pass - - result = {node_id + "." + key: value for key, value in result.items()} - - return result - - @property - def retry(self) -> bool: - return self.node_data.retry_config.retry_enabled diff --git a/api/dify_graph/nodes/variable_aggregator/__init__.py b/api/dify_graph/nodes/variable_aggregator/__init__.py deleted file mode 100644 index 0b6bf2a5b62..00000000000 --- a/api/dify_graph/nodes/variable_aggregator/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .variable_aggregator_node import VariableAggregatorNode - -__all__ = ["VariableAggregatorNode"] diff --git a/api/dify_graph/nodes/variable_aggregator/entities.py b/api/dify_graph/nodes/variable_aggregator/entities.py deleted file mode 100644 index 4779ebd9a9f..00000000000 --- a/api/dify_graph/nodes/variable_aggregator/entities.py +++ /dev/null @@ -1,35 +0,0 @@ -from pydantic import BaseModel - -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.variables.types import SegmentType - - -class AdvancedSettings(BaseModel): - """ - Advanced setting. - """ - - group_enabled: bool - - class Group(BaseModel): - """ - Group. - """ - - output_type: SegmentType - variables: list[list[str]] - group_name: str - - groups: list[Group] - - -class VariableAggregatorNodeData(BaseNodeData): - """ - Variable Aggregator Node Data. - """ - - type: NodeType = BuiltinNodeTypes.VARIABLE_AGGREGATOR - output_type: str - variables: list[list[str]] - advanced_settings: AdvancedSettings | None = None diff --git a/api/dify_graph/nodes/variable_aggregator/variable_aggregator_node.py b/api/dify_graph/nodes/variable_aggregator/variable_aggregator_node.py deleted file mode 100644 index 7d26de62322..00000000000 --- a/api/dify_graph/nodes/variable_aggregator/variable_aggregator_node.py +++ /dev/null @@ -1,40 +0,0 @@ -from collections.abc import Mapping - -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.variable_aggregator.entities import VariableAggregatorNodeData -from dify_graph.variables.segments import Segment - - -class VariableAggregatorNode(Node[VariableAggregatorNodeData]): - node_type = BuiltinNodeTypes.VARIABLE_AGGREGATOR - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - # Get variables - outputs: dict[str, Segment | Mapping[str, Segment]] = {} - inputs = {} - - if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled: - for selector in self.node_data.variables: - variable = self.graph_runtime_state.variable_pool.get(selector) - if variable is not None: - outputs = {"output": variable} - - inputs = {".".join(selector[1:]): variable.to_object()} - break - else: - for group in self.node_data.advanced_settings.groups: - for selector in group.variables: - variable = self.graph_runtime_state.variable_pool.get(selector) - - if variable is not None: - outputs[group.group_name] = {"output": variable} - inputs[".".join(selector[1:])] = variable.to_object() - break - - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs=outputs, inputs=inputs) diff --git a/api/dify_graph/nodes/variable_assigner/__init__.py b/api/dify_graph/nodes/variable_assigner/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/dify_graph/nodes/variable_assigner/common/__init__.py b/api/dify_graph/nodes/variable_assigner/common/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/dify_graph/nodes/variable_assigner/common/exc.py b/api/dify_graph/nodes/variable_assigner/common/exc.py deleted file mode 100644 index f8dbedc2901..00000000000 --- a/api/dify_graph/nodes/variable_assigner/common/exc.py +++ /dev/null @@ -1,4 +0,0 @@ -class VariableOperatorNodeError(ValueError): - """Base error type, don't use directly.""" - - pass diff --git a/api/dify_graph/nodes/variable_assigner/common/helpers.py b/api/dify_graph/nodes/variable_assigner/common/helpers.py deleted file mode 100644 index f0b22904a93..00000000000 --- a/api/dify_graph/nodes/variable_assigner/common/helpers.py +++ /dev/null @@ -1,55 +0,0 @@ -from collections.abc import Mapping, MutableMapping, Sequence -from typing import Any, TypeVar - -from pydantic import BaseModel - -from dify_graph.variables import Segment -from dify_graph.variables.consts import SELECTORS_LENGTH -from dify_graph.variables.types import SegmentType - -# Use double underscore (`__`) prefix for internal variables -# to minimize risk of collision with user-defined variable names. -_UPDATED_VARIABLES_KEY = "__updated_variables" - - -class UpdatedVariable(BaseModel): - name: str - selector: Sequence[str] - value_type: SegmentType - new_value: Any = None - - -_T = TypeVar("_T", bound=MutableMapping[str, Any]) - - -def variable_to_processed_data(selector: Sequence[str], seg: Segment) -> UpdatedVariable: - if len(selector) < SELECTORS_LENGTH: - raise Exception("selector too short") - _, var_name = selector[:2] - return UpdatedVariable( - name=var_name, - selector=list(selector[:2]), - value_type=seg.value_type, - new_value=seg.value, - ) - - -def set_updated_variables(m: _T, updates: Sequence[UpdatedVariable]) -> _T: - m[_UPDATED_VARIABLES_KEY] = updates - return m - - -def get_updated_variables(m: Mapping[str, Any]) -> Sequence[UpdatedVariable] | None: - updated_values = m.get(_UPDATED_VARIABLES_KEY, None) - if updated_values is None: - return None - result = [] - for items in updated_values: - if isinstance(items, UpdatedVariable): - result.append(items) - elif isinstance(items, dict): - items = UpdatedVariable.model_validate(items) - result.append(items) - else: - raise TypeError(f"Invalid updated variable: {items}, type={type(items)}") - return result diff --git a/api/dify_graph/nodes/variable_assigner/v1/__init__.py b/api/dify_graph/nodes/variable_assigner/v1/__init__.py deleted file mode 100644 index 7eb1428e503..00000000000 --- a/api/dify_graph/nodes/variable_assigner/v1/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .node import VariableAssignerNode - -__all__ = ["VariableAssignerNode"] diff --git a/api/dify_graph/nodes/variable_assigner/v1/node.py b/api/dify_graph/nodes/variable_assigner/v1/node.py deleted file mode 100644 index f9b261b191e..00000000000 --- a/api/dify_graph/nodes/variable_assigner/v1/node.py +++ /dev/null @@ -1,109 +0,0 @@ -from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any - -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.variable_assigner.common import helpers as common_helpers -from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from dify_graph.variables import SegmentType, VariableBase - -from .node_data import VariableAssignerData, WriteMode - -if TYPE_CHECKING: - from dify_graph.runtime import GraphRuntimeState - - -class VariableAssignerNode(Node[VariableAssignerData]): - node_type = BuiltinNodeTypes.VARIABLE_ASSIGNER - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - ): - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: - """ - Check if this Variable Assigner node blocks the output of specific variables. - - Returns True if this node updates any of the requested conversation variables. - """ - assigned_selector = tuple(self.node_data.assigned_variable_selector) - return assigned_selector in variable_selectors - - @classmethod - def version(cls) -> str: - return "1" - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: VariableAssignerData, - ) -> Mapping[str, Sequence[str]]: - mapping = {} - assigned_variable_node_id = node_data.assigned_variable_selector[0] - if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID: - selector_key = ".".join(node_data.assigned_variable_selector) - key = f"{node_id}.#{selector_key}#" - mapping[key] = node_data.assigned_variable_selector - - selector_key = ".".join(node_data.input_variable_selector) - key = f"{node_id}.#{selector_key}#" - mapping[key] = node_data.input_variable_selector - return mapping - - def _run(self) -> NodeRunResult: - assigned_variable_selector = self.node_data.assigned_variable_selector - # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject - original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector) - if not isinstance(original_variable, VariableBase): - raise VariableOperatorNodeError("assigned variable not found") - - match self.node_data.write_mode: - case WriteMode.OVER_WRITE: - income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector) - if not income_value: - raise VariableOperatorNodeError("input value not found") - updated_variable = original_variable.model_copy(update={"value": income_value.value}) - - case WriteMode.APPEND: - income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector) - if not income_value: - raise VariableOperatorNodeError("input value not found") - updated_value = original_variable.value + [income_value.value] - updated_variable = original_variable.model_copy(update={"value": updated_value}) - - case WriteMode.CLEAR: - income_value = SegmentType.get_zero_value(original_variable.value_type) - updated_variable = original_variable.model_copy(update={"value": income_value.to_object()}) - - # Over write the variable. - self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable) - - updated_variables = [common_helpers.variable_to_processed_data(assigned_variable_selector, updated_variable)] - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs={ - "value": income_value.to_object(), - }, - # NOTE(QuantumGhost): although only one variable is updated in `v1.VariableAssignerNode`, - # we still set `output_variables` as a list to ensure the schema of output is - # compatible with `v2.VariableAssignerNode`. - process_data=common_helpers.set_updated_variables({}, updated_variables), - outputs={}, - ) diff --git a/api/dify_graph/nodes/variable_assigner/v1/node_data.py b/api/dify_graph/nodes/variable_assigner/v1/node_data.py deleted file mode 100644 index 57acb29535a..00000000000 --- a/api/dify_graph/nodes/variable_assigner/v1/node_data.py +++ /dev/null @@ -1,18 +0,0 @@ -from collections.abc import Sequence -from enum import StrEnum - -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType - - -class WriteMode(StrEnum): - OVER_WRITE = "over-write" - APPEND = "append" - CLEAR = "clear" - - -class VariableAssignerData(BaseNodeData): - type: NodeType = BuiltinNodeTypes.VARIABLE_ASSIGNER - assigned_variable_selector: Sequence[str] - write_mode: WriteMode - input_variable_selector: Sequence[str] diff --git a/api/dify_graph/nodes/variable_assigner/v2/__init__.py b/api/dify_graph/nodes/variable_assigner/v2/__init__.py deleted file mode 100644 index 7eb1428e503..00000000000 --- a/api/dify_graph/nodes/variable_assigner/v2/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .node import VariableAssignerNode - -__all__ = ["VariableAssignerNode"] diff --git a/api/dify_graph/nodes/variable_assigner/v2/entities.py b/api/dify_graph/nodes/variable_assigner/v2/entities.py deleted file mode 100644 index 2b2bbe85deb..00000000000 --- a/api/dify_graph/nodes/variable_assigner/v2/entities.py +++ /dev/null @@ -1,28 +0,0 @@ -from collections.abc import Sequence -from typing import Any - -from pydantic import BaseModel, Field - -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType - -from .enums import InputType, Operation - - -class VariableOperationItem(BaseModel): - variable_selector: Sequence[str] - input_type: InputType - operation: Operation - # NOTE(QuantumGhost): The `value` field serves multiple purposes depending on context: - # - # 1. For CONSTANT input_type: Contains the literal value to be used in the operation. - # 2. For VARIABLE input_type: Initially contains the selector of the source variable. - # 3. During the variable updating procedure: The `value` field is reassigned to hold - # the resolved actual value that will be applied to the target variable. - value: Any = None - - -class VariableAssignerNodeData(BaseNodeData): - type: NodeType = BuiltinNodeTypes.VARIABLE_ASSIGNER - version: str = "2" - items: Sequence[VariableOperationItem] = Field(default_factory=list) diff --git a/api/dify_graph/nodes/variable_assigner/v2/enums.py b/api/dify_graph/nodes/variable_assigner/v2/enums.py deleted file mode 100644 index 291b1208d46..00000000000 --- a/api/dify_graph/nodes/variable_assigner/v2/enums.py +++ /dev/null @@ -1,20 +0,0 @@ -from enum import StrEnum - - -class Operation(StrEnum): - OVER_WRITE = "over-write" - CLEAR = "clear" - APPEND = "append" - EXTEND = "extend" - SET = "set" - ADD = "+=" - SUBTRACT = "-=" - MULTIPLY = "*=" - DIVIDE = "/=" - REMOVE_FIRST = "remove-first" - REMOVE_LAST = "remove-last" - - -class InputType(StrEnum): - VARIABLE = "variable" - CONSTANT = "constant" diff --git a/api/dify_graph/nodes/variable_assigner/v2/exc.py b/api/dify_graph/nodes/variable_assigner/v2/exc.py deleted file mode 100644 index c50aab86687..00000000000 --- a/api/dify_graph/nodes/variable_assigner/v2/exc.py +++ /dev/null @@ -1,36 +0,0 @@ -from collections.abc import Sequence -from typing import Any - -from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError - -from .enums import InputType, Operation - - -class OperationNotSupportedError(VariableOperatorNodeError): - def __init__(self, *, operation: Operation, variable_type: str): - super().__init__(f"Operation {operation} is not supported for type {variable_type}") - - -class InputTypeNotSupportedError(VariableOperatorNodeError): - def __init__(self, *, input_type: InputType, operation: Operation): - super().__init__(f"Input type {input_type} is not supported for operation {operation}") - - -class VariableNotFoundError(VariableOperatorNodeError): - def __init__(self, *, variable_selector: Sequence[str]): - super().__init__(f"Variable {variable_selector} not found") - - -class InvalidInputValueError(VariableOperatorNodeError): - def __init__(self, *, value: Any): - super().__init__(f"Invalid input value {value}") - - -class ConversationIDNotFoundError(VariableOperatorNodeError): - def __init__(self): - super().__init__("conversation_id not found") - - -class InvalidDataError(VariableOperatorNodeError): - def __init__(self, message: str): - super().__init__(message) diff --git a/api/dify_graph/nodes/variable_assigner/v2/helpers.py b/api/dify_graph/nodes/variable_assigner/v2/helpers.py deleted file mode 100644 index 38c69cbe3c0..00000000000 --- a/api/dify_graph/nodes/variable_assigner/v2/helpers.py +++ /dev/null @@ -1,98 +0,0 @@ -from typing import Any - -from dify_graph.variables import SegmentType - -from .enums import Operation - - -def is_operation_supported(*, variable_type: SegmentType, operation: Operation): - match operation: - case Operation.OVER_WRITE | Operation.CLEAR: - return True - case Operation.SET: - return variable_type in { - SegmentType.OBJECT, - SegmentType.STRING, - SegmentType.NUMBER, - SegmentType.INTEGER, - SegmentType.FLOAT, - SegmentType.BOOLEAN, - } - case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE: - # Only number variable can be added, subtracted, multiplied or divided - return variable_type in {SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT} - case Operation.APPEND | Operation.EXTEND | Operation.REMOVE_FIRST | Operation.REMOVE_LAST: - # Only array variable can be appended or extended - # Only array variable can have elements removed - return variable_type.is_array_type() - - -def is_variable_input_supported(*, operation: Operation): - if operation in {Operation.SET, Operation.ADD, Operation.SUBTRACT, Operation.MULTIPLY, Operation.DIVIDE}: - return False - return True - - -def is_constant_input_supported(*, variable_type: SegmentType, operation: Operation): - match variable_type: - case SegmentType.STRING | SegmentType.OBJECT | SegmentType.BOOLEAN: - return operation in {Operation.OVER_WRITE, Operation.SET} - case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: - return operation in { - Operation.OVER_WRITE, - Operation.SET, - Operation.ADD, - Operation.SUBTRACT, - Operation.MULTIPLY, - Operation.DIVIDE, - } - case _: - return False - - -def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, value: Any): - if operation in {Operation.CLEAR, Operation.REMOVE_FIRST, Operation.REMOVE_LAST}: - return True - match variable_type: - case SegmentType.STRING: - return isinstance(value, str) - - case SegmentType.BOOLEAN: - return isinstance(value, bool) - - case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: - if not isinstance(value, int | float): - return False - if operation == Operation.DIVIDE and value == 0: - return False - return True - - case SegmentType.OBJECT: - return isinstance(value, dict) - - # Array & Append - case SegmentType.ARRAY_ANY if operation == Operation.APPEND: - return isinstance(value, str | float | int | dict) - case SegmentType.ARRAY_STRING if operation == Operation.APPEND: - return isinstance(value, str) - case SegmentType.ARRAY_NUMBER if operation == Operation.APPEND: - return isinstance(value, int | float) - case SegmentType.ARRAY_OBJECT if operation == Operation.APPEND: - return isinstance(value, dict) - case SegmentType.ARRAY_BOOLEAN if operation == Operation.APPEND: - return isinstance(value, bool) - - # Array & Extend / Overwrite - case SegmentType.ARRAY_ANY if operation in {Operation.EXTEND, Operation.OVER_WRITE}: - return isinstance(value, list) and all(isinstance(item, str | float | int | dict) for item in value) - case SegmentType.ARRAY_STRING if operation in {Operation.EXTEND, Operation.OVER_WRITE}: - return isinstance(value, list) and all(isinstance(item, str) for item in value) - case SegmentType.ARRAY_NUMBER if operation in {Operation.EXTEND, Operation.OVER_WRITE}: - return isinstance(value, list) and all(isinstance(item, int | float) for item in value) - case SegmentType.ARRAY_OBJECT if operation in {Operation.EXTEND, Operation.OVER_WRITE}: - return isinstance(value, list) and all(isinstance(item, dict) for item in value) - case SegmentType.ARRAY_BOOLEAN if operation in {Operation.EXTEND, Operation.OVER_WRITE}: - return isinstance(value, list) and all(isinstance(item, bool) for item in value) - - case _: - return False diff --git a/api/dify_graph/nodes/variable_assigner/v2/node.py b/api/dify_graph/nodes/variable_assigner/v2/node.py deleted file mode 100644 index f04a6b3b804..00000000000 --- a/api/dify_graph/nodes/variable_assigner/v2/node.py +++ /dev/null @@ -1,246 +0,0 @@ -import json -from collections.abc import Mapping, MutableMapping, Sequence -from typing import TYPE_CHECKING, Any - -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.variable_assigner.common import helpers as common_helpers -from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from dify_graph.variables import SegmentType, VariableBase -from dify_graph.variables.consts import SELECTORS_LENGTH - -from . import helpers -from .entities import VariableAssignerNodeData, VariableOperationItem -from .enums import InputType, Operation -from .exc import ( - InputTypeNotSupportedError, - InvalidDataError, - InvalidInputValueError, - OperationNotSupportedError, - VariableNotFoundError, -) - -if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState - - -def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): - selector_node_id = item.variable_selector[0] - if selector_node_id != CONVERSATION_VARIABLE_NODE_ID: - return - selector_str = ".".join(item.variable_selector) - key = f"{node_id}.#{selector_str}#" - mapping[key] = item.variable_selector - - -def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): - # Keep this in sync with the logic in _run methods... - if item.input_type != InputType.VARIABLE: - return - selector = item.value - if not isinstance(selector, list): - raise InvalidDataError(f"selector is not a list, {node_id=}, {item=}") - if len(selector) < SELECTORS_LENGTH: - raise InvalidDataError(f"selector too short, {node_id=}, {item=}") - selector_str = ".".join(selector) - key = f"{node_id}.#{selector_str}#" - mapping[key] = selector - - -class VariableAssignerNode(Node[VariableAssignerNodeData]): - node_type = BuiltinNodeTypes.VARIABLE_ASSIGNER - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - ): - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: - """ - Check if this Variable Assigner node blocks the output of specific variables. - - Returns True if this node updates any of the requested conversation variables. - """ - # Check each item in this Variable Assigner node - for item in self.node_data.items: - # Convert the item's variable_selector to tuple for comparison - item_selector_tuple = tuple(item.variable_selector) - - # Check if this item updates any of the requested variables - if item_selector_tuple in variable_selectors: - return True - - return False - - @classmethod - def version(cls) -> str: - return "2" - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: VariableAssignerNodeData, - ) -> Mapping[str, Sequence[str]]: - var_mapping: dict[str, Sequence[str]] = {} - for item in node_data.items: - _target_mapping_from_item(var_mapping, node_id, item) - _source_mapping_from_item(var_mapping, node_id, item) - return var_mapping - - def _run(self) -> NodeRunResult: - inputs = self.node_data.model_dump() - process_data: dict[str, Any] = {} - # NOTE: This node has no outputs - updated_variable_selectors: list[Sequence[str]] = [] - - try: - for item in self.node_data.items: - variable = self.graph_runtime_state.variable_pool.get(item.variable_selector) - - # ==================== Validation Part - - # Check if variable exists - if not isinstance(variable, VariableBase): - raise VariableNotFoundError(variable_selector=item.variable_selector) - - # Check if operation is supported - if not helpers.is_operation_supported(variable_type=variable.value_type, operation=item.operation): - raise OperationNotSupportedError(operation=item.operation, variable_type=variable.value_type) - - # Check if variable input is supported - if item.input_type == InputType.VARIABLE and not helpers.is_variable_input_supported( - operation=item.operation - ): - raise InputTypeNotSupportedError(input_type=InputType.VARIABLE, operation=item.operation) - - # Check if constant input is supported - if item.input_type == InputType.CONSTANT and not helpers.is_constant_input_supported( - variable_type=variable.value_type, operation=item.operation - ): - raise InputTypeNotSupportedError(input_type=InputType.CONSTANT, operation=item.operation) - - # Get value from variable pool - if ( - item.input_type == InputType.VARIABLE - and item.operation not in {Operation.CLEAR, Operation.REMOVE_FIRST, Operation.REMOVE_LAST} - and item.value is not None - ): - value = self.graph_runtime_state.variable_pool.get(item.value) - if value is None: - raise VariableNotFoundError(variable_selector=item.value) - # Skip if value is NoneSegment - if value.value_type == SegmentType.NONE: - continue - item.value = value.value - - # If set string / bytes / bytearray to object, try convert string to object. - if ( - item.operation == Operation.SET - and variable.value_type == SegmentType.OBJECT - and isinstance(item.value, str | bytes | bytearray) - ): - try: - item.value = json.loads(item.value) - except json.JSONDecodeError: - raise InvalidInputValueError(value=item.value) - - # Check if input value is valid - if not helpers.is_input_value_valid( - variable_type=variable.value_type, operation=item.operation, value=item.value - ): - raise InvalidInputValueError(value=item.value) - - # ==================== Execution Part - - updated_value = self._handle_item( - variable=variable, - operation=item.operation, - value=item.value, - ) - variable = variable.model_copy(update={"value": updated_value}) - self.graph_runtime_state.variable_pool.add(variable.selector, variable) - updated_variable_selectors.append(variable.selector) - except VariableOperatorNodeError as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=inputs, - process_data=process_data, - error=str(e), - ) - - # The `updated_variable_selectors` is a list contains list[str] which not hashable, - # remove the duplicated items first. - updated_variable_selectors = list(set(map(tuple, updated_variable_selectors))) - - for selector in updated_variable_selectors: - variable = self.graph_runtime_state.variable_pool.get(selector) - if not isinstance(variable, VariableBase): - raise VariableNotFoundError(variable_selector=selector) - process_data[variable.name] = variable.value - - updated_variables = [ - common_helpers.variable_to_processed_data(selector, seg) - for selector in updated_variable_selectors - if (seg := self.graph_runtime_state.variable_pool.get(selector)) is not None - ] - - process_data = common_helpers.set_updated_variables(process_data, updated_variables) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs={}, - ) - - def _handle_item( - self, - *, - variable: VariableBase, - operation: Operation, - value: Any, - ): - match operation: - case Operation.OVER_WRITE: - return value - case Operation.CLEAR: - return SegmentType.get_zero_value(variable.value_type).to_object() - case Operation.APPEND: - return variable.value + [value] - case Operation.EXTEND: - return variable.value + value - case Operation.SET: - return value - case Operation.ADD: - return variable.value + value - case Operation.SUBTRACT: - return variable.value - value - case Operation.MULTIPLY: - return variable.value * value - case Operation.DIVIDE: - return variable.value / value - case Operation.REMOVE_FIRST: - # If array is empty, do nothing - if not variable.value: - return variable.value - return variable.value[1:] - case Operation.REMOVE_LAST: - # If array is empty, do nothing - if not variable.value: - return variable.value - return variable.value[:-1] diff --git a/api/dify_graph/repositories/__init__.py b/api/dify_graph/repositories/__init__.py deleted file mode 100644 index ef70eb09cc4..00000000000 --- a/api/dify_graph/repositories/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -""" -Repository interfaces for data access. - -This package contains repository interfaces that define the contract -for accessing and manipulating data, regardless of the underlying -storage mechanism. -""" - -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository - -__all__ = [ - "OrderConfig", - "WorkflowNodeExecutionRepository", -] diff --git a/api/dify_graph/repositories/human_input_form_repository.py b/api/dify_graph/repositories/human_input_form_repository.py deleted file mode 100644 index 88966831cba..00000000000 --- a/api/dify_graph/repositories/human_input_form_repository.py +++ /dev/null @@ -1,152 +0,0 @@ -import abc -import dataclasses -from collections.abc import Mapping, Sequence -from datetime import datetime -from typing import Any, Protocol - -from dify_graph.nodes.human_input.entities import DeliveryChannelConfig, HumanInputNodeData -from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus - - -class HumanInputError(Exception): - pass - - -class FormNotFoundError(HumanInputError): - pass - - -@dataclasses.dataclass -class FormCreateParams: - # app_id is the identifier for the app that the form belongs to. - # It is a string with uuid format. - app_id: str - # None when creating a delivery test form; set for runtime forms. - workflow_execution_id: str | None - - # node_id is the identifier for a specific - # node in the graph. - # - # TODO: for node inside loop / iteration, this would - # cause problems, as a single node may be executed multiple times. - node_id: str - - form_config: HumanInputNodeData - rendered_content: str - # Delivery methods already filtered by runtime context (invoke_from). - delivery_methods: Sequence[DeliveryChannelConfig] - # UI display flag computed by runtime context. - display_in_ui: bool - - # resolved_default_values saves the values for defaults with - # type = VARIABLE. - # - # For type = CONSTANT, the value is not stored inside `resolved_default_values` - resolved_default_values: Mapping[str, Any] - form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME - - # Force creating a console-only recipient for submission in Console. - console_recipient_required: bool = False - console_creator_account_id: str | None = None - # Force creating a backstage recipient for submission in Console. - backstage_recipient_required: bool = False - - -class HumanInputFormEntity(abc.ABC): - @property - @abc.abstractmethod - def id(self) -> str: - """id returns the identifer of the form.""" - pass - - @property - @abc.abstractmethod - def web_app_token(self) -> str | None: - """web_app_token returns the token for submission inside webapp. - - For console/debug execution, this may point to the console submission token - if the form is configured to require console delivery. - """ - - # TODO: what if the users are allowed to add multiple - # webapp delivery? - pass - - @property - @abc.abstractmethod - def recipients(self) -> list["HumanInputFormRecipientEntity"]: ... - - @property - @abc.abstractmethod - def rendered_content(self) -> str: - """Rendered markdown content associated with the form.""" - ... - - @property - @abc.abstractmethod - def selected_action_id(self) -> str | None: - """Identifier of the selected user action if the form has been submitted.""" - ... - - @property - @abc.abstractmethod - def submitted_data(self) -> Mapping[str, Any] | None: - """Submitted form data if available.""" - ... - - @property - @abc.abstractmethod - def submitted(self) -> bool: - """Whether the form has been submitted.""" - ... - - @property - @abc.abstractmethod - def status(self) -> HumanInputFormStatus: - """Current status of the form.""" - ... - - @property - @abc.abstractmethod - def expiration_time(self) -> datetime: - """When the form expires.""" - ... - - -class HumanInputFormRecipientEntity(abc.ABC): - @property - @abc.abstractmethod - def id(self) -> str: - """id returns the identifer of this recipient.""" - ... - - @property - @abc.abstractmethod - def token(self) -> str: - """token returns a random string used to submit form""" - ... - - -class HumanInputFormRepository(Protocol): - """ - Repository interface for HumanInputForm. - - This interface defines the contract for accessing and manipulating - HumanInputForm data, regardless of the underlying storage mechanism. - - Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id), - and other implementation details should be handled at the implementation level, not in - the core interface. This keeps the core domain model clean and independent of specific - application domains or deployment scenarios. - """ - - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: - """Get the form created for a given human input node in a workflow execution. Returns - `None` if the form has not been created yet.""" - ... - - def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: - """ - Create a human input form from form definition. - """ - ... diff --git a/api/dify_graph/repositories/workflow_execution_repository.py b/api/dify_graph/repositories/workflow_execution_repository.py deleted file mode 100644 index ef83f076492..00000000000 --- a/api/dify_graph/repositories/workflow_execution_repository.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import Protocol - -from dify_graph.entities import WorkflowExecution - - -class WorkflowExecutionRepository(Protocol): - """ - Repository interface for WorkflowExecution. - - This interface defines the contract for accessing and manipulating - WorkflowExecution data, regardless of the underlying storage mechanism. - - Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id), - and other implementation details should be handled at the implementation level, not in - the core interface. This keeps the core domain model clean and independent of specific - application domains or deployment scenarios. - """ - - def save(self, execution: WorkflowExecution): - """ - Save or update a WorkflowExecution instance. - - This method handles both creating new records and updating existing ones. - The implementation should determine whether to create or update based on - the execution's ID or other identifying fields. - - Args: - execution: The WorkflowExecution instance to save or update - """ - ... diff --git a/api/dify_graph/repositories/workflow_node_execution_repository.py b/api/dify_graph/repositories/workflow_node_execution_repository.py deleted file mode 100644 index e6c1c3e497b..00000000000 --- a/api/dify_graph/repositories/workflow_node_execution_repository.py +++ /dev/null @@ -1,73 +0,0 @@ -from collections.abc import Sequence -from dataclasses import dataclass -from typing import Literal, Protocol - -from dify_graph.entities import WorkflowNodeExecution - - -@dataclass -class OrderConfig: - """Configuration for ordering NodeExecution instances.""" - - order_by: list[str] - order_direction: Literal["asc", "desc"] | None = None - - -class WorkflowNodeExecutionRepository(Protocol): - """ - Repository interface for NodeExecution. - - This interface defines the contract for accessing and manipulating - NodeExecution data, regardless of the underlying storage mechanism. - - Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id), - and trigger sources (triggered_from) should be handled at the implementation level, not in - the core interface. This keeps the core domain model clean and independent of specific - application domains or deployment scenarios. - """ - - def save(self, execution: WorkflowNodeExecution): - """ - Save or update a NodeExecution instance. - - This method saves all data on the `WorkflowNodeExecution` object, except for `inputs`, `process_data`, - and `outputs`. Its primary purpose is to persist the status and various metadata, such as execution time - and execution-related details. - - It's main purpose is to save the status and various metadata (execution time, execution metadata etc.) - - This method handles both creating new records and updating existing ones. - The implementation should determine whether to create or update based on - the execution's ID or other identifying fields. - - Args: - execution: The NodeExecution instance to save or update - """ - ... - - def save_execution_data(self, execution: WorkflowNodeExecution): - """Save or update the inputs, process_data, or outputs associated with a specific - node_execution record. - - If any of the inputs, process_data, or outputs are None, those fields will not be updated. - """ - ... - - def get_by_workflow_run( - self, - workflow_run_id: str, - order_config: OrderConfig | None = None, - ) -> Sequence[WorkflowNodeExecution]: - """ - Retrieve all NodeExecution instances for a specific workflow run. - - Args: - workflow_run_id: The workflow run ID - order_config: Optional configuration for ordering results - order_config.order_by: List of fields to order by (e.g., ["index", "created_at"]) - order_config.order_direction: Direction to order ("asc" or "desc") - - Returns: - A list of NodeExecution instances - """ - ... diff --git a/api/dify_graph/runtime/__init__.py b/api/dify_graph/runtime/__init__.py deleted file mode 100644 index adca07e59a7..00000000000 --- a/api/dify_graph/runtime/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -from .graph_runtime_state import ( - ChildEngineBuilderNotConfiguredError, - ChildEngineError, - ChildGraphNotFoundError, - GraphRuntimeState, -) -from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool -from .read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper -from .variable_pool import VariablePool, VariableValue - -__all__ = [ - "ChildEngineBuilderNotConfiguredError", - "ChildEngineError", - "ChildGraphNotFoundError", - "GraphRuntimeState", - "ReadOnlyGraphRuntimeState", - "ReadOnlyGraphRuntimeStateWrapper", - "ReadOnlyVariablePool", - "ReadOnlyVariablePoolWrapper", - "VariablePool", - "VariableValue", -] diff --git a/api/dify_graph/runtime/graph_runtime_state.py b/api/dify_graph/runtime/graph_runtime_state.py deleted file mode 100644 index 41acc6db35f..00000000000 --- a/api/dify_graph/runtime/graph_runtime_state.py +++ /dev/null @@ -1,683 +0,0 @@ -from __future__ import annotations - -import importlib -import json -from collections.abc import Mapping, Sequence -from copy import deepcopy -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, ClassVar, Protocol - -from pydantic import BaseModel, Field -from pydantic.json import pydantic_encoder - -from dify_graph.enums import NodeExecutionType, NodeState, NodeType -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.runtime.variable_pool import VariablePool - -if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.entities.pause_reason import PauseReason - - -class ReadyQueueProtocol(Protocol): - """Structural interface required from ready queue implementations.""" - - def put(self, item: str) -> None: - """Enqueue the identifier of a node that is ready to run.""" - ... - - def get(self, timeout: float | None = None) -> str: - """Return the next node identifier, blocking until available or timeout expires.""" - ... - - def task_done(self) -> None: - """Signal that the most recently dequeued node has completed processing.""" - ... - - def empty(self) -> bool: - """Return True when the queue contains no pending nodes.""" - ... - - def qsize(self) -> int: - """Approximate the number of pending nodes awaiting execution.""" - ... - - def dumps(self) -> str: - """Serialize the queue contents for persistence.""" - ... - - def loads(self, data: str) -> None: - """Restore the queue contents from a serialized payload.""" - ... - - -class GraphExecutionProtocol(Protocol): - """Structural interface for graph execution aggregate. - - Defines the minimal set of attributes and methods required from a GraphExecution entity - for runtime orchestration and state management. - """ - - workflow_id: str - started: bool - completed: bool - aborted: bool - error: Exception | None - exceptions_count: int - pause_reasons: list[PauseReason] - - def start(self) -> None: - """Transition execution into the running state.""" - ... - - def complete(self) -> None: - """Mark execution as successfully completed.""" - ... - - def abort(self, reason: str) -> None: - """Abort execution in response to an external stop request.""" - ... - - def fail(self, error: Exception) -> None: - """Record an unrecoverable error and end execution.""" - ... - - def dumps(self) -> str: - """Serialize execution state into a JSON payload.""" - ... - - def loads(self, data: str) -> None: - """Restore execution state from a previously serialized payload.""" - ... - - -class ResponseStreamCoordinatorProtocol(Protocol): - """Structural interface for response stream coordinator.""" - - def register(self, response_node_id: str) -> None: - """Register a response node so its outputs can be streamed.""" - ... - - def loads(self, data: str) -> None: - """Restore coordinator state from a serialized payload.""" - ... - - def dumps(self) -> str: - """Serialize coordinator state for persistence.""" - ... - - -class NodeProtocol(Protocol): - """Structural interface for graph nodes.""" - - id: str - state: NodeState - execution_type: NodeExecutionType - node_type: ClassVar[NodeType] - - def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: ... - - -class EdgeProtocol(Protocol): - id: str - state: NodeState - tail: str - head: str - source_handle: str - - -class GraphProtocol(Protocol): - """Structural interface required from graph instances attached to the runtime state.""" - - nodes: Mapping[str, NodeProtocol] - edges: Mapping[str, EdgeProtocol] - root_node: NodeProtocol - - def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ... - - -class ChildGraphEngineBuilderProtocol(Protocol): - def build_child_engine( - self, - *, - workflow_id: str, - graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - graph_config: Mapping[str, Any], - root_node_id: str, - layers: Sequence[object] = (), - ) -> Any: ... - - -class ChildEngineError(ValueError): - """Base error type for child-engine creation failures.""" - - -class ChildEngineBuilderNotConfiguredError(ChildEngineError): - """Raised when child-engine creation is requested without a bound builder.""" - - -class ChildGraphNotFoundError(ChildEngineError): - """Raised when the requested child graph entry point cannot be resolved.""" - - -class _GraphStateSnapshot(BaseModel): - """Serializable graph state snapshot for node/edge states.""" - - nodes: dict[str, NodeState] = Field(default_factory=dict) - edges: dict[str, NodeState] = Field(default_factory=dict) - - -@dataclass(slots=True) -class _GraphRuntimeStateSnapshot: - """Immutable view of a serialized runtime state snapshot.""" - - start_at: float - total_tokens: int - node_run_steps: int - llm_usage: LLMUsage - outputs: dict[str, Any] - variable_pool: VariablePool - has_variable_pool: bool - ready_queue_dump: str | None - graph_execution_dump: str | None - response_coordinator_dump: str | None - paused_nodes: tuple[str, ...] - deferred_nodes: tuple[str, ...] - graph_node_states: dict[str, NodeState] - graph_edge_states: dict[str, NodeState] - - -class GraphRuntimeState: - """Mutable runtime state shared across graph execution components. - - `GraphRuntimeState` encapsulates the runtime state of workflow execution, - including scheduling details, variable values, and timing information. - - Values that are initialized prior to workflow execution and remain constant - throughout the execution should be part of `GraphInitParams` instead. - """ - - def __init__( - self, - *, - variable_pool: VariablePool, - start_at: float, - total_tokens: int = 0, - llm_usage: LLMUsage | None = None, - outputs: dict[str, object] | None = None, - node_run_steps: int = 0, - ready_queue: ReadyQueueProtocol | None = None, - graph_execution: GraphExecutionProtocol | None = None, - response_coordinator: ResponseStreamCoordinatorProtocol | None = None, - graph: GraphProtocol | None = None, - ) -> None: - self._variable_pool = variable_pool - self._start_at = start_at - - if total_tokens < 0: - raise ValueError("total_tokens must be non-negative") - self._total_tokens = total_tokens - - self._llm_usage = (llm_usage or LLMUsage.empty_usage()).model_copy() - self._outputs = deepcopy(outputs) if outputs is not None else {} - - if node_run_steps < 0: - raise ValueError("node_run_steps must be non-negative") - self._node_run_steps = node_run_steps - - self._graph: GraphProtocol | None = None - - self._ready_queue = ready_queue - self._graph_execution = graph_execution - self._response_coordinator = response_coordinator - self._pending_response_coordinator_dump: str | None = None - self._pending_graph_execution_workflow_id: str | None = None - self._paused_nodes: set[str] = set() - self._deferred_nodes: set[str] = set() - self._child_engine_builder: ChildGraphEngineBuilderProtocol | None = None - - # Node and edges states needed to be restored into - # graph object. - # - # These two fields are non-None only when resuming from a snapshot. - # Once the graph is attached, these two fields will be set to None. - self._pending_graph_node_states: dict[str, NodeState] | None = None - self._pending_graph_edge_states: dict[str, NodeState] | None = None - - if graph is not None: - self.attach_graph(graph) - - # ------------------------------------------------------------------ - # Context binding helpers - # ------------------------------------------------------------------ - def attach_graph(self, graph: GraphProtocol) -> None: - """Attach the materialized graph to the runtime state.""" - if self._graph is not None and self._graph is not graph: - raise ValueError("GraphRuntimeState already attached to a different graph instance") - - self._graph = graph - - if self._response_coordinator is None: - self._response_coordinator = self._build_response_coordinator(graph) - - if self._pending_response_coordinator_dump is not None and self._response_coordinator is not None: - self._response_coordinator.loads(self._pending_response_coordinator_dump) - self._pending_response_coordinator_dump = None - self._apply_pending_graph_state() - - def configure(self, *, graph: GraphProtocol | None = None) -> None: - """Ensure core collaborators are initialized with the provided context.""" - if graph is not None: - self.attach_graph(graph) - - # Ensure collaborators are instantiated - _ = self.ready_queue - _ = self.graph_execution - if self._graph is not None: - _ = self.response_coordinator - - def bind_child_engine_builder(self, builder: ChildGraphEngineBuilderProtocol) -> None: - self._child_engine_builder = builder - - def create_child_engine( - self, - *, - workflow_id: str, - graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - graph_config: Mapping[str, Any], - root_node_id: str, - layers: Sequence[object] = (), - ) -> Any: - if self._child_engine_builder is None: - raise ChildEngineBuilderNotConfiguredError("Child engine builder is not configured.") - - return self._child_engine_builder.build_child_engine( - workflow_id=workflow_id, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - graph_config=graph_config, - root_node_id=root_node_id, - layers=layers, - ) - - # ------------------------------------------------------------------ - # Primary collaborators - # ------------------------------------------------------------------ - @property - def variable_pool(self) -> VariablePool: - return self._variable_pool - - @property - def ready_queue(self) -> ReadyQueueProtocol: - if self._ready_queue is None: - self._ready_queue = self._build_ready_queue() - return self._ready_queue - - @property - def graph_execution(self) -> GraphExecutionProtocol: - if self._graph_execution is None: - self._graph_execution = self._build_graph_execution() - return self._graph_execution - - @property - def response_coordinator(self) -> ResponseStreamCoordinatorProtocol: - if self._response_coordinator is None: - if self._graph is None: - raise ValueError("Graph must be attached before accessing response coordinator") - self._response_coordinator = self._build_response_coordinator(self._graph) - return self._response_coordinator - - # ------------------------------------------------------------------ - # Scalar state - # ------------------------------------------------------------------ - @property - def start_at(self) -> float: - return self._start_at - - @start_at.setter - def start_at(self, value: float) -> None: - self._start_at = value - - @property - def total_tokens(self) -> int: - return self._total_tokens - - @total_tokens.setter - def total_tokens(self, value: int) -> None: - if value < 0: - raise ValueError("total_tokens must be non-negative") - self._total_tokens = value - - @property - def llm_usage(self) -> LLMUsage: - return self._llm_usage.model_copy() - - @llm_usage.setter - def llm_usage(self, value: LLMUsage) -> None: - self._llm_usage = value.model_copy() - - @property - def outputs(self) -> dict[str, Any]: - return deepcopy(self._outputs) - - @outputs.setter - def outputs(self, value: dict[str, Any]) -> None: - self._outputs = deepcopy(value) - - def set_output(self, key: str, value: object) -> None: - self._outputs[key] = deepcopy(value) - - def get_output(self, key: str, default: object = None) -> object: - return deepcopy(self._outputs.get(key, default)) - - def update_outputs(self, updates: dict[str, object]) -> None: - for key, value in updates.items(): - self._outputs[key] = deepcopy(value) - - @property - def node_run_steps(self) -> int: - return self._node_run_steps - - @node_run_steps.setter - def node_run_steps(self, value: int) -> None: - if value < 0: - raise ValueError("node_run_steps must be non-negative") - self._node_run_steps = value - - def increment_node_run_steps(self) -> None: - self._node_run_steps += 1 - - def add_tokens(self, tokens: int) -> None: - if tokens < 0: - raise ValueError("tokens must be non-negative") - self._total_tokens += tokens - - # ------------------------------------------------------------------ - # Serialization - # ------------------------------------------------------------------ - def dumps(self) -> str: - """Serialize runtime state into a JSON string.""" - - snapshot: dict[str, Any] = { - "version": "1.0", - "start_at": self._start_at, - "total_tokens": self._total_tokens, - "node_run_steps": self._node_run_steps, - "llm_usage": self._llm_usage.model_dump(mode="json"), - "outputs": self.outputs, - "variable_pool": self.variable_pool.model_dump(mode="json"), - "ready_queue": self.ready_queue.dumps(), - "graph_execution": self.graph_execution.dumps(), - "paused_nodes": list(self._paused_nodes), - "deferred_nodes": list(self._deferred_nodes), - } - - graph_state = self._snapshot_graph_state() - if graph_state is not None: - snapshot["graph_state"] = graph_state - - if self._response_coordinator is not None and self._graph is not None: - snapshot["response_coordinator"] = self._response_coordinator.dumps() - - return json.dumps(snapshot, default=pydantic_encoder) - - @classmethod - def from_snapshot(cls, data: str | Mapping[str, Any]) -> GraphRuntimeState: - """Restore runtime state from a serialized snapshot.""" - - snapshot = cls._parse_snapshot_payload(data) - - state = cls( - variable_pool=snapshot.variable_pool, - start_at=snapshot.start_at, - total_tokens=snapshot.total_tokens, - llm_usage=snapshot.llm_usage, - outputs=snapshot.outputs, - node_run_steps=snapshot.node_run_steps, - ) - state._apply_snapshot(snapshot) - return state - - def loads(self, data: str | Mapping[str, Any]) -> None: - """Restore runtime state from a serialized snapshot (legacy API).""" - - snapshot = self._parse_snapshot_payload(data) - self._apply_snapshot(snapshot) - - def register_paused_node(self, node_id: str) -> None: - """Record a node that should resume when execution is continued.""" - - self._paused_nodes.add(node_id) - - def get_paused_nodes(self) -> list[str]: - """Retrieve the list of paused nodes without mutating internal state.""" - - return list(self._paused_nodes) - - def consume_paused_nodes(self) -> list[str]: - """Retrieve and clear the list of paused nodes awaiting resume.""" - - nodes = list(self._paused_nodes) - self._paused_nodes.clear() - return nodes - - def register_deferred_node(self, node_id: str) -> None: - """Record a node that became ready during pause and should resume later.""" - - self._deferred_nodes.add(node_id) - - def get_deferred_nodes(self) -> list[str]: - """Retrieve deferred nodes without mutating internal state.""" - - return list(self._deferred_nodes) - - def consume_deferred_nodes(self) -> list[str]: - """Retrieve and clear deferred nodes awaiting resume.""" - - nodes = list(self._deferred_nodes) - self._deferred_nodes.clear() - return nodes - - # ------------------------------------------------------------------ - # Builders - # ------------------------------------------------------------------ - def _build_ready_queue(self) -> ReadyQueueProtocol: - # Import lazily to avoid breaching architecture boundaries enforced by import-linter. - module = importlib.import_module("dify_graph.graph_engine.ready_queue") - in_memory_cls = module.InMemoryReadyQueue - return in_memory_cls() - - def _build_graph_execution(self) -> GraphExecutionProtocol: - # Lazily import to keep the runtime domain decoupled from graph_engine modules. - module = importlib.import_module("dify_graph.graph_engine.domain.graph_execution") - graph_execution_cls = module.GraphExecution - workflow_id = self._pending_graph_execution_workflow_id or "" - self._pending_graph_execution_workflow_id = None - return graph_execution_cls(workflow_id=workflow_id) # type: ignore[invalid-return-type] - - def _build_response_coordinator(self, graph: GraphProtocol) -> ResponseStreamCoordinatorProtocol: - # Lazily import to keep the runtime domain decoupled from graph_engine modules. - module = importlib.import_module("dify_graph.graph_engine.response_coordinator") - coordinator_cls = module.ResponseStreamCoordinator - return coordinator_cls(variable_pool=self.variable_pool, graph=graph) - - # ------------------------------------------------------------------ - # Snapshot helpers - # ------------------------------------------------------------------ - @classmethod - def _parse_snapshot_payload(cls, data: str | Mapping[str, Any]) -> _GraphRuntimeStateSnapshot: - payload: dict[str, Any] - if isinstance(data, str): - payload = json.loads(data) - else: - payload = dict(data) - - version = payload.get("version") - if version != "1.0": - raise ValueError(f"Unsupported GraphRuntimeState snapshot version: {version}") - - start_at = float(payload.get("start_at", 0.0)) - - total_tokens = int(payload.get("total_tokens", 0)) - if total_tokens < 0: - raise ValueError("total_tokens must be non-negative") - - node_run_steps = int(payload.get("node_run_steps", 0)) - if node_run_steps < 0: - raise ValueError("node_run_steps must be non-negative") - - llm_usage_payload = payload.get("llm_usage", {}) - llm_usage = LLMUsage.model_validate(llm_usage_payload) - - outputs_payload = deepcopy(payload.get("outputs", {})) - - variable_pool_payload = payload.get("variable_pool") - has_variable_pool = variable_pool_payload is not None - variable_pool = VariablePool.model_validate(variable_pool_payload) if has_variable_pool else VariablePool() - - ready_queue_payload = payload.get("ready_queue") - graph_execution_payload = payload.get("graph_execution") - response_payload = payload.get("response_coordinator") - paused_nodes_payload = payload.get("paused_nodes", []) - deferred_nodes_payload = payload.get("deferred_nodes", []) - graph_state_payload = payload.get("graph_state", {}) or {} - graph_node_states = _coerce_graph_state_map(graph_state_payload, "nodes") - graph_edge_states = _coerce_graph_state_map(graph_state_payload, "edges") - - return _GraphRuntimeStateSnapshot( - start_at=start_at, - total_tokens=total_tokens, - node_run_steps=node_run_steps, - llm_usage=llm_usage, - outputs=outputs_payload, - variable_pool=variable_pool, - has_variable_pool=has_variable_pool, - ready_queue_dump=ready_queue_payload, - graph_execution_dump=graph_execution_payload, - response_coordinator_dump=response_payload, - paused_nodes=tuple(map(str, paused_nodes_payload)), - deferred_nodes=tuple(map(str, deferred_nodes_payload)), - graph_node_states=graph_node_states, - graph_edge_states=graph_edge_states, - ) - - def _apply_snapshot(self, snapshot: _GraphRuntimeStateSnapshot) -> None: - self._start_at = snapshot.start_at - self._total_tokens = snapshot.total_tokens - self._node_run_steps = snapshot.node_run_steps - self._llm_usage = snapshot.llm_usage.model_copy() - self._outputs = deepcopy(snapshot.outputs) - if snapshot.has_variable_pool or self._variable_pool is None: - self._variable_pool = snapshot.variable_pool - - self._restore_ready_queue(snapshot.ready_queue_dump) - self._restore_graph_execution(snapshot.graph_execution_dump) - self._restore_response_coordinator(snapshot.response_coordinator_dump) - self._paused_nodes = set(snapshot.paused_nodes) - self._deferred_nodes = set(snapshot.deferred_nodes) - self._pending_graph_node_states = snapshot.graph_node_states or None - self._pending_graph_edge_states = snapshot.graph_edge_states or None - self._apply_pending_graph_state() - - def _restore_ready_queue(self, payload: str | None) -> None: - if payload is not None: - self._ready_queue = self._build_ready_queue() - self._ready_queue.loads(payload) - else: - self._ready_queue = None - - def _restore_graph_execution(self, payload: str | None) -> None: - self._graph_execution = None - self._pending_graph_execution_workflow_id = None - - if payload is None: - return - - try: - execution_payload = json.loads(payload) - self._pending_graph_execution_workflow_id = execution_payload.get("workflow_id") - except (json.JSONDecodeError, TypeError, AttributeError): - self._pending_graph_execution_workflow_id = None - - self.graph_execution.loads(payload) - - def _restore_response_coordinator(self, payload: str | None) -> None: - if payload is None: - self._pending_response_coordinator_dump = None - self._response_coordinator = None - return - - if self._graph is not None: - self.response_coordinator.loads(payload) - self._pending_response_coordinator_dump = None - return - - self._pending_response_coordinator_dump = payload - self._response_coordinator = None - - def _snapshot_graph_state(self) -> _GraphStateSnapshot: - graph = self._graph - if graph is None: - if self._pending_graph_node_states is None and self._pending_graph_edge_states is None: - return _GraphStateSnapshot() - return _GraphStateSnapshot( - nodes=self._pending_graph_node_states or {}, - edges=self._pending_graph_edge_states or {}, - ) - - nodes = graph.nodes - edges = graph.edges - if not isinstance(nodes, Mapping) or not isinstance(edges, Mapping): - return _GraphStateSnapshot() - - node_states = {} - for node_id, node in nodes.items(): - if not isinstance(node_id, str): - continue - node_states[node_id] = node.state - - edge_states = {} - for edge_id, edge in edges.items(): - if not isinstance(edge_id, str): - continue - edge_states[edge_id] = edge.state - - return _GraphStateSnapshot(nodes=node_states, edges=edge_states) - - def _apply_pending_graph_state(self) -> None: - if self._graph is None: - return - if self._pending_graph_node_states: - for node_id, state in self._pending_graph_node_states.items(): - node = self._graph.nodes.get(node_id) - if node is None: - continue - node.state = state - if self._pending_graph_edge_states: - for edge_id, state in self._pending_graph_edge_states.items(): - edge = self._graph.edges.get(edge_id) - if edge is None: - continue - edge.state = state - - self._pending_graph_node_states = None - self._pending_graph_edge_states = None - - -def _coerce_graph_state_map(payload: Any, key: str) -> dict[str, NodeState]: - if not isinstance(payload, Mapping): - return {} - raw_map = payload.get(key, {}) - if not isinstance(raw_map, Mapping): - return {} - result: dict[str, NodeState] = {} - for node_id, raw_state in raw_map.items(): - if not isinstance(node_id, str): - continue - try: - result[node_id] = NodeState(str(raw_state)) - except ValueError: - continue - return result diff --git a/api/dify_graph/runtime/graph_runtime_state_protocol.py b/api/dify_graph/runtime/graph_runtime_state_protocol.py deleted file mode 100644 index 7e55ece3f14..00000000000 --- a/api/dify_graph/runtime/graph_runtime_state_protocol.py +++ /dev/null @@ -1,83 +0,0 @@ -from collections.abc import Mapping, Sequence -from typing import Any, Protocol - -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.system_variable import SystemVariableReadOnlyView -from dify_graph.variables.segments import Segment - - -class ReadOnlyVariablePool(Protocol): - """Read-only interface for VariablePool.""" - - def get(self, selector: Sequence[str], /) -> Segment | None: - """Get a variable value (read-only).""" - ... - - def get_all_by_node(self, node_id: str) -> Mapping[str, object]: - """Get all variables for a node (read-only).""" - ... - - def get_by_prefix(self, prefix: str) -> Mapping[str, object]: - """Get all variables stored under a given node prefix (read-only).""" - ... - - -class ReadOnlyGraphRuntimeState(Protocol): - """ - Read-only view of GraphRuntimeState for layers. - - This protocol defines a read-only interface that prevents layers from - modifying the graph runtime state while still allowing observation. - All methods return defensive copies to ensure immutability. - """ - - @property - def system_variable(self) -> SystemVariableReadOnlyView: ... - - @property - def variable_pool(self) -> ReadOnlyVariablePool: - """Get read-only access to the variable pool.""" - ... - - @property - def start_at(self) -> float: - """Get the start time (read-only).""" - ... - - @property - def total_tokens(self) -> int: - """Get the total tokens count (read-only).""" - ... - - @property - def llm_usage(self) -> LLMUsage: - """Get a copy of LLM usage info (read-only).""" - ... - - @property - def outputs(self) -> dict[str, Any]: - """Get a defensive copy of outputs (read-only).""" - ... - - @property - def node_run_steps(self) -> int: - """Get the node run steps count (read-only).""" - ... - - @property - def ready_queue_size(self) -> int: - """Get the number of nodes currently in the ready queue.""" - ... - - @property - def exceptions_count(self) -> int: - """Get the number of node execution exceptions recorded.""" - ... - - def get_output(self, key: str, default: Any = None) -> Any: - """Get a single output value (returns a copy).""" - ... - - def dumps(self) -> str: - """Serialize the runtime state into a JSON snapshot (read-only).""" - ... diff --git a/api/dify_graph/runtime/read_only_wrappers.py b/api/dify_graph/runtime/read_only_wrappers.py deleted file mode 100644 index ca06d88c3d4..00000000000 --- a/api/dify_graph/runtime/read_only_wrappers.py +++ /dev/null @@ -1,87 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping, Sequence -from copy import deepcopy -from typing import Any - -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.system_variable import SystemVariableReadOnlyView -from dify_graph.variables.segments import Segment - -from .graph_runtime_state import GraphRuntimeState -from .variable_pool import VariablePool - - -class ReadOnlyVariablePoolWrapper: - """Provide defensive, read-only access to ``VariablePool``.""" - - def __init__(self, variable_pool: VariablePool) -> None: - self._variable_pool = variable_pool - - def get(self, selector: Sequence[str], /) -> Segment | None: - """Return a copy of a variable value if present.""" - value = self._variable_pool.get(selector) - return deepcopy(value) if value is not None else None - - def get_all_by_node(self, node_id: str) -> Mapping[str, object]: - """Return a copy of all variables for the specified node.""" - variables: dict[str, object] = {} - if node_id in self._variable_pool.variable_dictionary: - for key, variable in self._variable_pool.variable_dictionary[node_id].items(): - variables[key] = deepcopy(variable.value) - return variables - - def get_by_prefix(self, prefix: str) -> Mapping[str, object]: - """Return a copy of all variables stored under the given prefix.""" - return self._variable_pool.get_by_prefix(prefix) - - -class ReadOnlyGraphRuntimeStateWrapper: - """Expose a defensive, read-only view of ``GraphRuntimeState``.""" - - def __init__(self, state: GraphRuntimeState) -> None: - self._state = state - self._variable_pool_wrapper = ReadOnlyVariablePoolWrapper(state.variable_pool) - - @property - def system_variable(self) -> SystemVariableReadOnlyView: - return self._state.variable_pool.system_variables.as_view() - - @property - def variable_pool(self) -> ReadOnlyVariablePoolWrapper: - return self._variable_pool_wrapper - - @property - def start_at(self) -> float: - return self._state.start_at - - @property - def total_tokens(self) -> int: - return self._state.total_tokens - - @property - def llm_usage(self) -> LLMUsage: - return self._state.llm_usage.model_copy() - - @property - def outputs(self) -> dict[str, Any]: - return deepcopy(self._state.outputs) - - @property - def node_run_steps(self) -> int: - return self._state.node_run_steps - - @property - def ready_queue_size(self) -> int: - return self._state.ready_queue.qsize() - - @property - def exceptions_count(self) -> int: - return self._state.graph_execution.exceptions_count - - def get_output(self, key: str, default: Any = None) -> Any: - return self._state.get_output(key, default) - - def dumps(self) -> str: - """Serialize the underlying runtime state for external persistence.""" - return self._state.dumps() diff --git a/api/dify_graph/runtime/variable_pool.py b/api/dify_graph/runtime/variable_pool.py deleted file mode 100644 index e3ef6a2897c..00000000000 --- a/api/dify_graph/runtime/variable_pool.py +++ /dev/null @@ -1,280 +0,0 @@ -from __future__ import annotations - -import re -from collections import defaultdict -from collections.abc import Mapping, Sequence -from copy import deepcopy -from typing import Annotated, Any, Union, cast - -from pydantic import BaseModel, Field - -from dify_graph.constants import ( - CONVERSATION_VARIABLE_NODE_ID, - ENVIRONMENT_VARIABLE_NODE_ID, - RAG_PIPELINE_VARIABLE_NODE_ID, - SYSTEM_VARIABLE_NODE_ID, -) -from dify_graph.file import File, FileAttribute, file_manager -from dify_graph.system_variable import SystemVariable -from dify_graph.variables import Segment, SegmentGroup, VariableBase -from dify_graph.variables.consts import SELECTORS_LENGTH -from dify_graph.variables.segments import FileSegment, ObjectSegment -from dify_graph.variables.variables import RAGPipelineVariableInput, Variable -from factories import variable_factory - -VariableValue = Union[str, int, float, dict[str, object], list[object], File] - -VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}") - - -class VariablePool(BaseModel): - # Variable dictionary is a dictionary for looking up variables by their selector. - # The first element of the selector is the node id, it's the first-level key in the dictionary. - # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the - # elements of the selector except the first one. - variable_dictionary: defaultdict[str, Annotated[dict[str, Variable], Field(default_factory=dict)]] = Field( - description="Variables mapping", - default=defaultdict(dict), - ) - - # The `user_inputs` is used only when constructing the inputs for the `StartNode`. It's not used elsewhere. - user_inputs: Mapping[str, Any] = Field( - description="User inputs", - default_factory=dict, - ) - system_variables: SystemVariable = Field( - description="System variables", - default_factory=SystemVariable.default, - ) - environment_variables: Sequence[Variable] = Field( - description="Environment variables.", - default_factory=list[Variable], - ) - conversation_variables: Sequence[Variable] = Field( - description="Conversation variables.", - default_factory=list[Variable], - ) - rag_pipeline_variables: list[RAGPipelineVariableInput] = Field( - description="RAG pipeline variables.", - default_factory=list, - ) - - def model_post_init(self, context: Any, /): - # Create a mapping from field names to SystemVariableKey enum values - self._add_system_variables(self.system_variables) - # Add environment variables to the variable pool - for var in self.environment_variables: - self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var) - # Add conversation variables to the variable pool. When restoring from a serialized - # snapshot, `variable_dictionary` already carries the latest runtime values. - # In that case, keep existing entries instead of overwriting them with the - # bootstrap list. - for var in self.conversation_variables: - selector = (CONVERSATION_VARIABLE_NODE_ID, var.name) - if self._has(selector): - continue - self.add(selector, var) - # Add rag pipeline variables to the variable pool - if self.rag_pipeline_variables: - rag_pipeline_variables_map: defaultdict[Any, dict[Any, Any]] = defaultdict(dict) - for rag_var in self.rag_pipeline_variables: - node_id = rag_var.variable.belong_to_node_id - key = rag_var.variable.variable - value = rag_var.value - rag_pipeline_variables_map[node_id][key] = value - for key, value in rag_pipeline_variables_map.items(): - self.add((RAG_PIPELINE_VARIABLE_NODE_ID, key), value) - - def add(self, selector: Sequence[str], value: Any, /): - """ - Add a variable to the variable pool. - - This method accepts a selector path and a value, converting the value - to a Variable object if necessary before storing it in the pool. - - Args: - selector: A two-element sequence containing [node_id, variable_name]. - The selector must have exactly 2 elements to be valid. - value: The value to store. Can be a Variable, Segment, or any value - that can be converted to a Segment (str, int, float, dict, list, File). - - Raises: - ValueError: If selector length is not exactly 2 elements. - - Note: - While non-Segment values are currently accepted and automatically - converted, it's recommended to pass Segment or Variable objects directly. - """ - if len(selector) != SELECTORS_LENGTH: - raise ValueError( - f"Invalid selector: expected {SELECTORS_LENGTH} elements (node_id, variable_name), " - f"got {len(selector)} elements" - ) - - if isinstance(value, VariableBase): - variable = value - elif isinstance(value, Segment): - variable = variable_factory.segment_to_variable(segment=value, selector=selector) - else: - segment = variable_factory.build_segment(value) - variable = variable_factory.segment_to_variable(segment=segment, selector=selector) - - node_id, name = self._selector_to_keys(selector) - # Based on the definition of `Variable`, - # `VariableBase` instances can be safely used as `Variable` since they are compatible. - self.variable_dictionary[node_id][name] = cast(Variable, variable) - - @classmethod - def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]: - return selector[0], selector[1] - - def _has(self, selector: Sequence[str]) -> bool: - node_id, name = self._selector_to_keys(selector) - if node_id not in self.variable_dictionary: - return False - if name not in self.variable_dictionary[node_id]: - return False - return True - - def get(self, selector: Sequence[str], /) -> Segment | None: - """ - Retrieve a variable's value from the pool as a Segment. - - This method supports both simple selectors [node_id, variable_name] and - extended selectors that include attribute access for FileSegment and - ObjectSegment types. - - Args: - selector: A sequence with at least 2 elements: - - [node_id, variable_name]: Returns the full segment - - [node_id, variable_name, attr, ...]: Returns a nested value - from FileSegment (e.g., 'url', 'name') or ObjectSegment - - Returns: - The Segment associated with the selector, or None if not found. - Returns None if selector has fewer than 2 elements. - - Raises: - ValueError: If attempting to access an invalid FileAttribute. - """ - if len(selector) < SELECTORS_LENGTH: - return None - - node_id, name = self._selector_to_keys(selector) - node_map = self.variable_dictionary.get(node_id) - if node_map is None: - return None - - segment: Segment | None = node_map.get(name) - - if segment is None: - return None - - if len(selector) == 2: - return segment - - if isinstance(segment, FileSegment): - attr = selector[2] - # Python support `attr in FileAttribute` after 3.12 - if attr not in {item.value for item in FileAttribute}: - return None - attr = FileAttribute(attr) - attr_value = file_manager.get_attr(file=segment.value, attr=attr) - return variable_factory.build_segment(attr_value) - - # Navigate through nested attributes - result: Any = segment - for attr in selector[2:]: - result = self._extract_value(result) - result = self._get_nested_attribute(result, attr) - if result is None: - return None - - # Return result as Segment - return result if isinstance(result, Segment) else variable_factory.build_segment(result) - - def _extract_value(self, obj: Any): - """Extract the actual value from an ObjectSegment.""" - return obj.value if isinstance(obj, ObjectSegment) else obj - - def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str) -> Segment | None: - """ - Get a nested attribute from a dictionary-like object. - - Args: - obj: The dictionary-like object to search. - attr: The key to look up. - - Returns: - Segment | None: - The corresponding Segment built from the attribute value if the key exists, - otherwise None. - """ - if not isinstance(obj, dict) or attr not in obj: - return None - return variable_factory.build_segment(obj.get(attr)) - - def remove(self, selector: Sequence[str], /): - """ - Remove variables from the variable pool based on the given selector. - - Args: - selector (Sequence[str]): A sequence of strings representing the selector. - - Returns: - None - """ - if not selector: - return - if len(selector) == 1: - self.variable_dictionary[selector[0]] = {} - return - key, hash_key = self._selector_to_keys(selector) - self.variable_dictionary[key].pop(hash_key, None) - - def convert_template(self, template: str, /): - parts = VARIABLE_PATTERN.split(template) - segments: list[Segment] = [] - for part in filter(lambda x: x, parts): - if "." in part and (variable := self.get(part.split("."))): - segments.append(variable) - else: - segments.append(variable_factory.build_segment(part)) - return SegmentGroup(value=segments) - - def get_file(self, selector: Sequence[str], /) -> FileSegment | None: - segment = self.get(selector) - if isinstance(segment, FileSegment): - return segment - return None - - def get_by_prefix(self, prefix: str, /) -> Mapping[str, object]: - """Return a copy of all variables stored under the given node prefix.""" - - nodes = self.variable_dictionary.get(prefix) - if not nodes: - return {} - - result: dict[str, object] = {} - for key, variable in nodes.items(): - value = variable.value - result[key] = deepcopy(value) - - return result - - def _add_system_variables(self, system_variable: SystemVariable): - sys_var_mapping = system_variable.to_dict() - for key, value in sys_var_mapping.items(): - if value is None: - continue - selector = (SYSTEM_VARIABLE_NODE_ID, key) - # If the system variable already exists, do not add it again. - # This ensures that we can keep the id of the system variables intact. - if self._has(selector): - continue - self.add(selector, value) - - @classmethod - def empty(cls) -> VariablePool: - """Create an empty variable pool.""" - return cls(system_variables=SystemVariable.default()) diff --git a/api/dify_graph/system_variable.py b/api/dify_graph/system_variable.py deleted file mode 100644 index cc5deda892c..00000000000 --- a/api/dify_graph/system_variable.py +++ /dev/null @@ -1,217 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping, Sequence -from types import MappingProxyType -from typing import Any -from uuid import uuid4 - -from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator - -from dify_graph.enums import SystemVariableKey -from dify_graph.file.models import File - - -class SystemVariable(BaseModel): - """A model for managing system variables. - - Fields with a value of `None` are treated as absent and will not be included - in the variable pool. - """ - - model_config = ConfigDict( - extra="forbid", - serialize_by_alias=True, - validate_by_alias=True, - ) - - user_id: str | None = None - - # Ideally, `app_id` and `workflow_id` should be required and not `None`. - # However, there are scenarios in the codebase where these fields are not set. - # To maintain compatibility, they are marked as optional here. - app_id: str | None = None - workflow_id: str | None = None - - timestamp: int | None = None - - files: Sequence[File] = Field(default_factory=list) - - # NOTE: The `workflow_execution_id` field was previously named `workflow_run_id`. - # To maintain compatibility with existing workflows, it must be serialized - # as `workflow_run_id` in dictionaries or JSON objects, and also referenced - # as `workflow_run_id` in the variable pool. - workflow_execution_id: str | None = Field( - validation_alias=AliasChoices("workflow_execution_id", "workflow_run_id"), - serialization_alias="workflow_run_id", - default=None, - ) - # Chatflow related fields. - query: str | None = None - conversation_id: str | None = None - dialogue_count: int | None = None - document_id: str | None = None - original_document_id: str | None = None - dataset_id: str | None = None - batch: str | None = None - datasource_type: str | None = None - datasource_info: Mapping[str, Any] | None = None - invoke_from: str | None = None - - @model_validator(mode="before") - @classmethod - def validate_json_fields(cls, data): - if isinstance(data, dict): - # For JSON validation, only allow workflow_run_id - if "workflow_execution_id" in data and "workflow_run_id" not in data: - # This is likely from direct instantiation, allow it - return data - elif "workflow_execution_id" in data and "workflow_run_id" in data: - # Both present, remove workflow_execution_id - data = data.copy() - data.pop("workflow_execution_id") - return data - return data - - @classmethod - def default(cls) -> SystemVariable: - return cls(workflow_execution_id=str(uuid4())) - - def to_dict(self) -> dict[SystemVariableKey, Any]: - # NOTE: This method is provided for compatibility with legacy code. - # New code should use the `SystemVariable` object directly instead of converting - # it to a dictionary, as this conversion results in the loss of type information - # for each key, making static analysis more difficult. - - d: dict[SystemVariableKey, Any] = { - SystemVariableKey.FILES: self.files, - } - if self.user_id is not None: - d[SystemVariableKey.USER_ID] = self.user_id - if self.app_id is not None: - d[SystemVariableKey.APP_ID] = self.app_id - if self.workflow_id is not None: - d[SystemVariableKey.WORKFLOW_ID] = self.workflow_id - if self.workflow_execution_id is not None: - d[SystemVariableKey.WORKFLOW_EXECUTION_ID] = self.workflow_execution_id - if self.query is not None: - d[SystemVariableKey.QUERY] = self.query - if self.conversation_id is not None: - d[SystemVariableKey.CONVERSATION_ID] = self.conversation_id - if self.dialogue_count is not None: - d[SystemVariableKey.DIALOGUE_COUNT] = self.dialogue_count - if self.document_id is not None: - d[SystemVariableKey.DOCUMENT_ID] = self.document_id - if self.original_document_id is not None: - d[SystemVariableKey.ORIGINAL_DOCUMENT_ID] = self.original_document_id - if self.dataset_id is not None: - d[SystemVariableKey.DATASET_ID] = self.dataset_id - if self.batch is not None: - d[SystemVariableKey.BATCH] = self.batch - if self.datasource_type is not None: - d[SystemVariableKey.DATASOURCE_TYPE] = self.datasource_type - if self.datasource_info is not None: - d[SystemVariableKey.DATASOURCE_INFO] = self.datasource_info - if self.invoke_from is not None: - d[SystemVariableKey.INVOKE_FROM] = self.invoke_from - if self.timestamp is not None: - d[SystemVariableKey.TIMESTAMP] = self.timestamp - return d - - def as_view(self) -> SystemVariableReadOnlyView: - return SystemVariableReadOnlyView(self) - - -class SystemVariableReadOnlyView: - """ - A read-only view of a SystemVariable that implements the ReadOnlySystemVariable protocol. - - This class wraps a SystemVariable instance and provides read-only access to all its fields. - It always reads the latest data from the wrapped instance and prevents any write operations. - """ - - def __init__(self, system_variable: SystemVariable) -> None: - """ - Initialize the read-only view with a SystemVariable instance. - - Args: - system_variable: The SystemVariable instance to wrap - """ - self._system_variable = system_variable - - @property - def user_id(self) -> str | None: - return self._system_variable.user_id - - @property - def app_id(self) -> str | None: - return self._system_variable.app_id - - @property - def workflow_id(self) -> str | None: - return self._system_variable.workflow_id - - @property - def workflow_execution_id(self) -> str | None: - return self._system_variable.workflow_execution_id - - @property - def query(self) -> str | None: - return self._system_variable.query - - @property - def conversation_id(self) -> str | None: - return self._system_variable.conversation_id - - @property - def dialogue_count(self) -> int | None: - return self._system_variable.dialogue_count - - @property - def document_id(self) -> str | None: - return self._system_variable.document_id - - @property - def original_document_id(self) -> str | None: - return self._system_variable.original_document_id - - @property - def dataset_id(self) -> str | None: - return self._system_variable.dataset_id - - @property - def batch(self) -> str | None: - return self._system_variable.batch - - @property - def datasource_type(self) -> str | None: - return self._system_variable.datasource_type - - @property - def invoke_from(self) -> str | None: - return self._system_variable.invoke_from - - @property - def files(self) -> Sequence[File]: - """ - Get a copy of the files from the wrapped SystemVariable. - - Returns: - A defensive copy of the files sequence to prevent modification - """ - return tuple(self._system_variable.files) # Convert to immutable tuple - - @property - def datasource_info(self) -> Mapping[str, Any] | None: - """ - Get a copy of the datasource info from the wrapped SystemVariable. - - Returns: - A view of the datasource info mapping to prevent modification - """ - if self._system_variable.datasource_info is None: - return None - return MappingProxyType(self._system_variable.datasource_info) - - def __repr__(self) -> str: - """Return a string representation of the read-only view.""" - return f"SystemVariableReadOnlyView(system_variable={self._system_variable!r})" diff --git a/api/dify_graph/utils/__init__.py b/api/dify_graph/utils/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/dify_graph/utils/condition/__init__.py b/api/dify_graph/utils/condition/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/dify_graph/utils/condition/entities.py b/api/dify_graph/utils/condition/entities.py deleted file mode 100644 index 77a214571a1..00000000000 --- a/api/dify_graph/utils/condition/entities.py +++ /dev/null @@ -1,49 +0,0 @@ -from collections.abc import Sequence -from typing import Literal - -from pydantic import BaseModel, Field - -SupportedComparisonOperator = Literal[ - # for string or array - "contains", - "not contains", - "start with", - "end with", - "is", - "is not", - "empty", - "not empty", - "in", - "not in", - "all of", - # for number - "=", - "≠", - ">", - "<", - "≥", - "≤", - "null", - "not null", - # for file - "exists", - "not exists", -] - - -class SubCondition(BaseModel): - key: str - comparison_operator: SupportedComparisonOperator - value: str | Sequence[str] | None = None - - -class SubVariableCondition(BaseModel): - logical_operator: Literal["and", "or"] - conditions: list[SubCondition] = Field(default_factory=list) - - -class Condition(BaseModel): - variable_selector: list[str] - comparison_operator: SupportedComparisonOperator - value: str | Sequence[str] | bool | None = None - sub_variable_condition: SubVariableCondition | None = None diff --git a/api/dify_graph/utils/condition/processor.py b/api/dify_graph/utils/condition/processor.py deleted file mode 100644 index dea72d96c2c..00000000000 --- a/api/dify_graph/utils/condition/processor.py +++ /dev/null @@ -1,504 +0,0 @@ -import json -from collections.abc import Mapping, Sequence -from typing import Literal, NamedTuple - -from dify_graph.file import FileAttribute, file_manager -from dify_graph.runtime import VariablePool -from dify_graph.variables import ArrayFileSegment -from dify_graph.variables.segments import ArrayBooleanSegment, BooleanSegment - -from .entities import Condition, SubCondition, SupportedComparisonOperator - - -def _convert_to_bool(value: object) -> bool: - if isinstance(value, int): - return bool(value) - - if isinstance(value, str): - loaded = json.loads(value) - if isinstance(loaded, (int, bool)): - return bool(loaded) - - raise TypeError(f"unexpected value: type={type(value)}, value={value}") - - -class ConditionCheckResult(NamedTuple): - inputs: Sequence[Mapping[str, object]] - group_results: Sequence[bool] - final_result: bool - - -class ConditionProcessor: - def process_conditions( - self, - *, - variable_pool: VariablePool, - conditions: Sequence[Condition], - operator: Literal["and", "or"], - ) -> ConditionCheckResult: - input_conditions: list[Mapping[str, object]] = [] - group_results: list[bool] = [] - - for condition in conditions: - variable = variable_pool.get(condition.variable_selector) - if variable is None: - raise ValueError(f"Variable {condition.variable_selector} not found") - - if isinstance(variable, ArrayFileSegment) and condition.comparison_operator in { - "contains", - "not contains", - "all of", - }: - # check sub conditions - if not condition.sub_variable_condition: - raise ValueError("Sub variable is required") - result = _process_sub_conditions( - variable=variable, - sub_conditions=condition.sub_variable_condition.conditions, - operator=condition.sub_variable_condition.logical_operator, - ) - elif condition.comparison_operator in { - "exists", - "not exists", - }: - result = _evaluate_condition( - value=variable.value, - operator=condition.comparison_operator, - expected=None, - ) - else: - actual_value = variable.value if variable else None - expected_value: str | Sequence[str] | bool | list[bool] | None = condition.value - if isinstance(expected_value, str): - expected_value = variable_pool.convert_template(expected_value).text - # Here we need to explicit convet the input string to boolean. - if isinstance(variable, (BooleanSegment, ArrayBooleanSegment)) and expected_value is not None: - # The following two lines is for compatibility with existing workflows. - if isinstance(expected_value, list): - expected_value = [_convert_to_bool(i) for i in expected_value] - else: - expected_value = _convert_to_bool(expected_value) - input_conditions.append( - { - "actual_value": actual_value, - "expected_value": expected_value, - "comparison_operator": condition.comparison_operator, - } - ) - result = _evaluate_condition( - value=actual_value, - operator=condition.comparison_operator, - expected=expected_value, - ) - group_results.append(result) - # Implemented short-circuit evaluation for logical conditions - if (operator == "and" and not result) or (operator == "or" and result): - final_result = result - return ConditionCheckResult(input_conditions, group_results, final_result) - - final_result = all(group_results) if operator == "and" else any(group_results) - return ConditionCheckResult(input_conditions, group_results, final_result) - - -def _evaluate_condition( - *, - operator: SupportedComparisonOperator, - value: object, - expected: str | Sequence[str] | bool | Sequence[bool] | None, -) -> bool: - match operator: - case "contains": - return _assert_contains(value=value, expected=expected) - case "not contains": - return _assert_not_contains(value=value, expected=expected) - case "start with": - return _assert_start_with(value=value, expected=expected) - case "end with": - return _assert_end_with(value=value, expected=expected) - case "is": - return _assert_is(value=value, expected=expected) - case "is not": - return _assert_is_not(value=value, expected=expected) - case "empty": - return _assert_empty(value=value) - case "not empty": - return _assert_not_empty(value=value) - case "=": - return _assert_equal(value=value, expected=expected) - case "≠": - return _assert_not_equal(value=value, expected=expected) - case ">": - return _assert_greater_than(value=value, expected=expected) - case "<": - return _assert_less_than(value=value, expected=expected) - case "≥": - return _assert_greater_than_or_equal(value=value, expected=expected) - case "≤": - return _assert_less_than_or_equal(value=value, expected=expected) - case "null": - return _assert_null(value=value) - case "not null": - return _assert_not_null(value=value) - case "in": - return _assert_in(value=value, expected=expected) - case "not in": - return _assert_not_in(value=value, expected=expected) - case "all of" if isinstance(expected, list): - # Type narrowing: at this point expected is a list, could be list[str] or list[bool] - if all(isinstance(item, str) for item in expected): - # Create a new typed list to satisfy type checker - str_list: list[str] = [item for item in expected if isinstance(item, str)] - return _assert_all_of(value=value, expected=str_list) - elif all(isinstance(item, bool) for item in expected): - # Create a new typed list to satisfy type checker - bool_list: list[bool] = [item for item in expected if isinstance(item, bool)] - return _assert_all_of_bool(value=value, expected=bool_list) - else: - raise ValueError("all of operator expects homogeneous list of strings or booleans") - case "exists": - return _assert_exists(value=value) - case "not exists": - return _assert_not_exists(value=value) - case _: - raise ValueError(f"Unsupported operator: {operator}") - - -def _assert_contains(*, value: object, expected: object) -> bool: - if not value: - return False - - if not isinstance(value, (str, list)): - raise ValueError("Invalid actual value type: string or array") - - # Type checking ensures value is str or list at this point - if isinstance(value, str): - if not isinstance(expected, str): - expected = str(expected) - if expected not in value: - return False - else: # value is list - if expected not in value: - return False - return True - - -def _assert_not_contains(*, value: object, expected: object) -> bool: - if not value: - return True - - if not isinstance(value, (str, list)): - raise ValueError("Invalid actual value type: string or array") - - # Type checking ensures value is str or list at this point - if isinstance(value, str): - if not isinstance(expected, str): - expected = str(expected) - if expected in value: - return False - else: # value is list - if expected in value: - return False - return True - - -def _assert_start_with(*, value: object, expected: object) -> bool: - if not value: - return False - - if not isinstance(value, str): - raise ValueError("Invalid actual value type: string") - - if not isinstance(expected, str): - raise ValueError("Expected value must be a string for startswith") - if not value.startswith(expected): - return False - return True - - -def _assert_end_with(*, value: object, expected: object) -> bool: - if not value: - return False - - if not isinstance(value, str): - raise ValueError("Invalid actual value type: string") - - if not isinstance(expected, str): - raise ValueError("Expected value must be a string for endswith") - if not value.endswith(expected): - return False - return True - - -def _assert_is(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (str, bool)): - raise ValueError("Invalid actual value type: string or boolean") - - if value != expected: - return False - return True - - -def _assert_is_not(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (str, bool)): - raise ValueError("Invalid actual value type: string or boolean") - - if value == expected: - return False - return True - - -def _assert_empty(*, value: object) -> bool: - if not value: - return True - return False - - -def _assert_not_empty(*, value: object) -> bool: - if value: - return True - return False - - -def _normalize_numeric_values(value: int | float, expected: object) -> tuple[int | float, int | float]: - """ - Normalize value and expected to compatible numeric types for comparison. - - Args: - value: The actual numeric value (int or float) - expected: The expected value (int, float, or str) - - Returns: - A tuple of (normalized_value, normalized_expected) with compatible types - - Raises: - ValueError: If expected cannot be converted to a number - """ - if not isinstance(expected, (int, float, str)): - raise ValueError(f"Cannot convert {type(expected)} to number") - - # Convert expected to appropriate numeric type - if isinstance(expected, str): - # Try to convert to float first to handle decimal strings - try: - expected_float = float(expected) - except ValueError as e: - raise ValueError(f"Cannot convert '{expected}' to number") from e - - # If value is int and expected is a whole number, keep as int comparison - if isinstance(value, int) and expected_float.is_integer(): - return value, int(expected_float) - else: - # Otherwise convert value to float for comparison - return float(value) if isinstance(value, int) else value, expected_float - elif isinstance(expected, float): - # If expected is already float, convert int value to float - return float(value) if isinstance(value, int) else value, expected - else: - # expected is int - return value, expected - - -def _assert_equal(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (int, float, bool)): - raise ValueError("Invalid actual value type: number or boolean") - - # Handle boolean comparison - if isinstance(value, bool): - if not isinstance(expected, (bool, int, str)): - raise ValueError(f"Cannot convert {type(expected)} to bool") - expected = bool(expected) - elif isinstance(value, int): - if not isinstance(expected, (int, float, str)): - raise ValueError(f"Cannot convert {type(expected)} to int") - expected = int(expected) - else: - if not isinstance(expected, (int, float, str)): - raise ValueError(f"Cannot convert {type(expected)} to float") - expected = float(expected) - - if value != expected: - return False - return True - - -def _assert_not_equal(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (int, float, bool)): - raise ValueError("Invalid actual value type: number or boolean") - - # Handle boolean comparison - if isinstance(value, bool): - if not isinstance(expected, (bool, int, str)): - raise ValueError(f"Cannot convert {type(expected)} to bool") - expected = bool(expected) - elif isinstance(value, int): - if not isinstance(expected, (int, float, str)): - raise ValueError(f"Cannot convert {type(expected)} to int") - expected = int(expected) - else: - if not isinstance(expected, (int, float, str)): - raise ValueError(f"Cannot convert {type(expected)} to float") - expected = float(expected) - - if value == expected: - return False - return True - - -def _assert_greater_than(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (int, float)): - raise ValueError("Invalid actual value type: number") - - value, expected = _normalize_numeric_values(value, expected) - return value > expected - - -def _assert_less_than(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (int, float)): - raise ValueError("Invalid actual value type: number") - - value, expected = _normalize_numeric_values(value, expected) - return value < expected - - -def _assert_greater_than_or_equal(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (int, float)): - raise ValueError("Invalid actual value type: number") - - value, expected = _normalize_numeric_values(value, expected) - return value >= expected - - -def _assert_less_than_or_equal(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (int, float)): - raise ValueError("Invalid actual value type: number") - - value, expected = _normalize_numeric_values(value, expected) - return value <= expected - - -def _assert_null(*, value: object) -> bool: - if value is None: - return True - return False - - -def _assert_not_null(*, value: object) -> bool: - if value is not None: - return True - return False - - -def _assert_in(*, value: object, expected: object) -> bool: - if not value: - return False - - if not isinstance(expected, list): - raise ValueError("Invalid expected value type: array") - - if value not in expected: - return False - return True - - -def _assert_not_in(*, value: object, expected: object) -> bool: - if not value: - return True - - if not isinstance(expected, list): - raise ValueError("Invalid expected value type: array") - - if value in expected: - return False - return True - - -def _assert_all_of(*, value: object, expected: Sequence[str]) -> bool: - if not value: - return False - - # Ensure value is a container that supports 'in' operator - if not isinstance(value, (list, tuple, set, str)): - return False - - return all(item in value for item in expected) - - -def _assert_all_of_bool(*, value: object, expected: Sequence[bool]) -> bool: - if not value: - return False - - # Ensure value is a container that supports 'in' operator - if not isinstance(value, (list, tuple, set)): - return False - - return all(item in value for item in expected) - - -def _assert_exists(*, value: object) -> bool: - return value is not None - - -def _assert_not_exists(*, value: object) -> bool: - return value is None - - -def _process_sub_conditions( - variable: ArrayFileSegment, - sub_conditions: Sequence[SubCondition], - operator: Literal["and", "or"], -) -> bool: - files = variable.value - group_results: list[bool] = [] - for condition in sub_conditions: - key = FileAttribute(condition.key) - values = [file_manager.get_attr(file=file, attr=key) for file in files] - expected_value = condition.value - if key == FileAttribute.EXTENSION: - if not isinstance(expected_value, str): - raise TypeError("Expected value must be a string when key is FileAttribute.EXTENSION") - if expected_value and not expected_value.startswith("."): - expected_value = "." + expected_value - - normalized_values: list[object] = [] - for value in values: - if value and isinstance(value, str): - if not value.startswith("."): - value = "." + value - normalized_values.append(value) - values = normalized_values - sub_group_results: list[bool] = [ - _evaluate_condition( - value=value, - operator=condition.comparison_operator, - expected=expected_value, - ) - for value in values - ] - # Determine the result based on the presence of "not" in the comparison operator - result = all(sub_group_results) if "not" in condition.comparison_operator else any(sub_group_results) - group_results.append(result) - return all(group_results) if operator == "and" else any(group_results) diff --git a/api/dify_graph/variable_loader.py b/api/dify_graph/variable_loader.py deleted file mode 100644 index d263450334e..00000000000 --- a/api/dify_graph/variable_loader.py +++ /dev/null @@ -1,83 +0,0 @@ -import abc -from collections.abc import Mapping, Sequence -from typing import Any, Protocol - -from dify_graph.runtime import VariablePool -from dify_graph.variables import VariableBase -from dify_graph.variables.consts import SELECTORS_LENGTH - - -class VariableLoader(Protocol): - """Interface for loading variables based on selectors. - - A `VariableLoader` is responsible for retrieving additional variables required during the execution - of a single node, which are not provided as user inputs. - - NOTE(QuantumGhost): Typically, all variables loaded by a `VariableLoader` should belong to the same - application and share the same `app_id`. However, this interface does not enforce that constraint, - and the `app_id` parameter is intentionally omitted from `load_variables` to achieve separation of - concern and allow for flexible implementations. - - Implementations of `VariableLoader` should almost always have an `app_id` parameter in - their constructor. - - TODO(QuantumGhost): this is a temporally workaround. If we can move the creation of node instance into - `WorkflowService.single_step_run`, we may get rid of this interface. - """ - - @abc.abstractmethod - def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]: - """Load variables based on the provided selectors. If the selectors are empty, - this method should return an empty list. - - The order of the returned variables is not guaranteed. If the caller wants to ensure - a specific order, they should sort the returned list themselves. - - :param: selectors: a list of string list, each inner list should have at least two elements: - - the first element is the node ID, - - the second element is the variable name. - :return: a list of VariableBase objects that match the provided selectors. - """ - pass - - -class _DummyVariableLoader(VariableLoader): - """A dummy implementation of VariableLoader that does not load any variables. - Serves as a placeholder when no variable loading is needed. - """ - - def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]: - return [] - - -DUMMY_VARIABLE_LOADER = _DummyVariableLoader() - - -def load_into_variable_pool( - variable_loader: VariableLoader, - variable_pool: VariablePool, - variable_mapping: Mapping[str, Sequence[str]], - user_inputs: Mapping[str, Any], -): - # Loading missing variable from draft var here, and set it into - # variable_pool. - variables_to_load: list[list[str]] = [] - for key, selector in variable_mapping.items(): - # NOTE(QuantumGhost): this logic needs to be in sync with - # `WorkflowEntry.mapping_user_inputs_to_variable_pool`. - node_variable_list = key.split(".") - if len(node_variable_list) < 2: - raise ValueError(f"Invalid variable key: {key}. It should have at least two elements.") - if key in user_inputs: - continue - node_variable_key = ".".join(node_variable_list[1:]) - if node_variable_key in user_inputs: - continue - if variable_pool.get(selector) is None: - variables_to_load.append(list(selector)) - loaded = variable_loader.load_variables(variables_to_load) - for var in loaded: - assert len(var.selector) >= SELECTORS_LENGTH, f"Invalid variable {var}" - # Add variable directly to the pool - # The variable pool expects 2-element selectors [node_id, variable_name] - variable_pool.add([var.selector[0], var.selector[1]], var) diff --git a/api/dify_graph/variables/__init__.py b/api/dify_graph/variables/__init__.py deleted file mode 100644 index be3fc8d97a6..00000000000 --- a/api/dify_graph/variables/__init__.py +++ /dev/null @@ -1,70 +0,0 @@ -from .input_entities import VariableEntity, VariableEntityType -from .segment_group import SegmentGroup -from .segments import ( - ArrayAnySegment, - ArrayFileSegment, - ArrayNumberSegment, - ArrayObjectSegment, - ArraySegment, - ArrayStringSegment, - FileSegment, - FloatSegment, - IntegerSegment, - NoneSegment, - ObjectSegment, - Segment, - StringSegment, -) -from .types import SegmentType -from .variables import ( - ArrayAnyVariable, - ArrayFileVariable, - ArrayNumberVariable, - ArrayObjectVariable, - ArrayStringVariable, - ArrayVariable, - FileVariable, - FloatVariable, - IntegerVariable, - NoneVariable, - ObjectVariable, - SecretVariable, - StringVariable, - Variable, - VariableBase, -) - -__all__ = [ - "ArrayAnySegment", - "ArrayAnyVariable", - "ArrayFileSegment", - "ArrayFileVariable", - "ArrayNumberSegment", - "ArrayNumberVariable", - "ArrayObjectSegment", - "ArrayObjectVariable", - "ArraySegment", - "ArrayStringSegment", - "ArrayStringVariable", - "ArrayVariable", - "FileSegment", - "FileVariable", - "FloatSegment", - "FloatVariable", - "IntegerSegment", - "IntegerVariable", - "NoneSegment", - "NoneVariable", - "ObjectSegment", - "ObjectVariable", - "SecretVariable", - "Segment", - "SegmentGroup", - "SegmentType", - "StringSegment", - "StringVariable", - "Variable", - "VariableBase", - "VariableEntity", - "VariableEntityType", -] diff --git a/api/dify_graph/variables/consts.py b/api/dify_graph/variables/consts.py deleted file mode 100644 index 8f3f78f740f..00000000000 --- a/api/dify_graph/variables/consts.py +++ /dev/null @@ -1,7 +0,0 @@ -# The minimal selector length for valid variables. -# -# The first element of the selector is the node id, and the second element is the variable name. -# -# If the selector length is more than 2, the remaining parts are the keys / indexes paths used -# to extract part of the variable value. -SELECTORS_LENGTH = 2 diff --git a/api/dify_graph/variables/exc.py b/api/dify_graph/variables/exc.py deleted file mode 100644 index 5cf67c3bacc..00000000000 --- a/api/dify_graph/variables/exc.py +++ /dev/null @@ -1,2 +0,0 @@ -class VariableError(ValueError): - pass diff --git a/api/dify_graph/variables/input_entities.py b/api/dify_graph/variables/input_entities.py deleted file mode 100644 index e6a68ea3594..00000000000 --- a/api/dify_graph/variables/input_entities.py +++ /dev/null @@ -1,62 +0,0 @@ -from collections.abc import Sequence -from enum import StrEnum -from typing import Any - -from jsonschema import Draft7Validator, SchemaError -from pydantic import BaseModel, Field, field_validator - -from dify_graph.file import FileTransferMethod, FileType - - -class VariableEntityType(StrEnum): - TEXT_INPUT = "text-input" - SELECT = "select" - PARAGRAPH = "paragraph" - NUMBER = "number" - EXTERNAL_DATA_TOOL = "external_data_tool" - FILE = "file" - FILE_LIST = "file-list" - CHECKBOX = "checkbox" - JSON_OBJECT = "json_object" - - -class VariableEntity(BaseModel): - """ - Shared variable entity used by workflow runtime and app configuration. - """ - - # `variable` records the name of the variable in user inputs. - variable: str - label: str - description: str = "" - type: VariableEntityType - required: bool = False - hide: bool = False - default: Any = None - max_length: int | None = None - options: Sequence[str] = Field(default_factory=list) - allowed_file_types: Sequence[FileType] | None = Field(default_factory=list) - allowed_file_extensions: Sequence[str] | None = Field(default_factory=list) - allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list) - json_schema: dict[str, Any] | None = Field(default=None) - - @field_validator("description", mode="before") - @classmethod - def convert_none_description(cls, value: Any) -> str: - return value or "" - - @field_validator("options", mode="before") - @classmethod - def convert_none_options(cls, value: Any) -> Sequence[str]: - return value or [] - - @field_validator("json_schema") - @classmethod - def validate_json_schema(cls, schema: dict[str, Any] | None) -> dict[str, Any] | None: - if schema is None: - return None - try: - Draft7Validator.check_schema(schema) - except SchemaError as error: - raise ValueError(f"Invalid JSON schema: {error.message}") - return schema diff --git a/api/dify_graph/variables/segment_group.py b/api/dify_graph/variables/segment_group.py deleted file mode 100644 index b363255b2ca..00000000000 --- a/api/dify_graph/variables/segment_group.py +++ /dev/null @@ -1,22 +0,0 @@ -from .segments import Segment -from .types import SegmentType - - -class SegmentGroup(Segment): - value_type: SegmentType = SegmentType.GROUP - value: list[Segment] - - @property - def text(self): - return "".join([segment.text for segment in self.value]) - - @property - def log(self): - return "".join([segment.log for segment in self.value]) - - @property - def markdown(self): - return "".join([segment.markdown for segment in self.value]) - - def to_object(self): - return [segment.to_object() for segment in self.value] diff --git a/api/dify_graph/variables/segments.py b/api/dify_graph/variables/segments.py deleted file mode 100644 index bdb213ed48f..00000000000 --- a/api/dify_graph/variables/segments.py +++ /dev/null @@ -1,253 +0,0 @@ -import json -import sys -from collections.abc import Mapping, Sequence -from typing import Annotated, Any, TypeAlias - -from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator - -from dify_graph.file import File - -from .types import SegmentType - - -class Segment(BaseModel): - """Segment is runtime type used during the execution of workflow. - - Note: this class is abstract, you should use subclasses of this class instead. - """ - - model_config = ConfigDict(frozen=True) - - value_type: SegmentType - value: Any - - @field_validator("value_type") - @classmethod - def validate_value_type(cls, value): - """ - This validator checks if the provided value is equal to the default value of the 'value_type' field. - If the value is different, a ValueError is raised. - """ - if value != cls.model_fields["value_type"].default: - raise ValueError("Cannot modify 'value_type'") - return value - - @property - def text(self) -> str: - return str(self.value) - - @property - def log(self) -> str: - return str(self.value) - - @property - def markdown(self) -> str: - return str(self.value) - - @property - def size(self) -> int: - """ - Return the size of the value in bytes. - """ - return sys.getsizeof(self.value) - - def to_object(self): - return self.value - - -class NoneSegment(Segment): - value_type: SegmentType = SegmentType.NONE - value: None = None - - @property - def text(self) -> str: - return "" - - @property - def log(self) -> str: - return "" - - @property - def markdown(self) -> str: - return "" - - -class StringSegment(Segment): - value_type: SegmentType = SegmentType.STRING - value: str - - -class FloatSegment(Segment): - value_type: SegmentType = SegmentType.FLOAT - value: float - # NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems. - # The following tests cannot pass. - # - # def test_float_segment_and_nan(): - # nan = float("nan") - # assert nan != nan - # - # f1 = FloatSegment(value=float("nan")) - # f2 = FloatSegment(value=float("nan")) - # assert f1 != f2 - # - # f3 = FloatSegment(value=nan) - # f4 = FloatSegment(value=nan) - # assert f3 != f4 - - -class IntegerSegment(Segment): - value_type: SegmentType = SegmentType.INTEGER - value: int - - -class ObjectSegment(Segment): - value_type: SegmentType = SegmentType.OBJECT - value: Mapping[str, Any] - - @property - def text(self) -> str: - return json.dumps(self.model_dump()["value"], ensure_ascii=False) - - @property - def log(self) -> str: - return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2) - - @property - def markdown(self) -> str: - return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2) - - -class ArraySegment(Segment): - @property - def text(self) -> str: - # Return empty string for empty arrays instead of "[]" - if not self.value: - return "" - return super().text - - @property - def markdown(self) -> str: - items = [] - for item in self.value: - items.append(f"- {item}") - return "\n".join(items) - - -class FileSegment(Segment): - value_type: SegmentType = SegmentType.FILE - value: File - - @property - def markdown(self) -> str: - return self.value.markdown - - @property - def log(self) -> str: - return "" - - @property - def text(self) -> str: - return "" - - -class BooleanSegment(Segment): - value_type: SegmentType = SegmentType.BOOLEAN - value: bool - - -class ArrayAnySegment(ArraySegment): - value_type: SegmentType = SegmentType.ARRAY_ANY - value: Sequence[Any] - - -class ArrayStringSegment(ArraySegment): - value_type: SegmentType = SegmentType.ARRAY_STRING - value: Sequence[str] - - @property - def text(self) -> str: - # Return empty string for empty arrays instead of "[]" - if not self.value: - return "" - return json.dumps(self.value, ensure_ascii=False) - - -class ArrayNumberSegment(ArraySegment): - value_type: SegmentType = SegmentType.ARRAY_NUMBER - value: Sequence[float | int] - - -class ArrayObjectSegment(ArraySegment): - value_type: SegmentType = SegmentType.ARRAY_OBJECT - value: Sequence[Mapping[str, Any]] - - -class ArrayFileSegment(ArraySegment): - value_type: SegmentType = SegmentType.ARRAY_FILE - value: Sequence[File] - - @property - def markdown(self) -> str: - items = [] - for item in self.value: - items.append(item.markdown) - return "\n".join(items) - - @property - def log(self) -> str: - return "" - - @property - def text(self) -> str: - return "" - - -class ArrayBooleanSegment(ArraySegment): - value_type: SegmentType = SegmentType.ARRAY_BOOLEAN - value: Sequence[bool] - - -def get_segment_discriminator(v: Any) -> SegmentType | None: - if isinstance(v, Segment): - return v.value_type - elif isinstance(v, dict): - value_type = v.get("value_type") - if value_type is None: - return None - try: - seg_type = SegmentType(value_type) - except ValueError: - return None - return seg_type - else: - # return None if the discriminator value isn't found - return None - - -# The `SegmentUnion`` type is used to enable serialization and deserialization with Pydantic. -# Use `Segment` for type hinting when serialization is not required. -# -# Note: -# - All variants in `SegmentUnion` must inherit from the `Segment` class. -# - The union must include all non-abstract subclasses of `Segment`, except: -# - `SegmentGroup`, which is not added to the variable pool. -# - `VariableBase` and its subclasses, which are handled by `Variable`. -SegmentUnion: TypeAlias = Annotated[ - ( - Annotated[NoneSegment, Tag(SegmentType.NONE)] - | Annotated[StringSegment, Tag(SegmentType.STRING)] - | Annotated[FloatSegment, Tag(SegmentType.FLOAT)] - | Annotated[IntegerSegment, Tag(SegmentType.INTEGER)] - | Annotated[ObjectSegment, Tag(SegmentType.OBJECT)] - | Annotated[FileSegment, Tag(SegmentType.FILE)] - | Annotated[BooleanSegment, Tag(SegmentType.BOOLEAN)] - | Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)] - | Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)] - | Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)] - | Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)] - | Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)] - | Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)] - ), - Discriminator(get_segment_discriminator), -] diff --git a/api/dify_graph/variables/types.py b/api/dify_graph/variables/types.py deleted file mode 100644 index 53bf495a270..00000000000 --- a/api/dify_graph/variables/types.py +++ /dev/null @@ -1,273 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping -from enum import StrEnum -from typing import TYPE_CHECKING, Any - -from dify_graph.file.models import File - -if TYPE_CHECKING: - from dify_graph.variables.segments import Segment - - -class ArrayValidation(StrEnum): - """Strategy for validating array elements. - - Note: - The `NONE` and `FIRST` strategies are primarily for compatibility purposes. - Avoid using them in new code whenever possible. - """ - - # Skip element validation (only check array container) - NONE = "none" - - # Validate the first element (if array is non-empty) - FIRST = "first" - - # Validate all elements in the array. - ALL = "all" - - -class SegmentType(StrEnum): - NUMBER = "number" - INTEGER = "integer" - FLOAT = "float" - STRING = "string" - OBJECT = "object" - SECRET = "secret" - - FILE = "file" - BOOLEAN = "boolean" - - ARRAY_ANY = "array[any]" - ARRAY_STRING = "array[string]" - ARRAY_NUMBER = "array[number]" - ARRAY_OBJECT = "array[object]" - ARRAY_FILE = "array[file]" - ARRAY_BOOLEAN = "array[boolean]" - - NONE = "none" - - GROUP = "group" - - def is_array_type(self) -> bool: - return self in _ARRAY_TYPES - - @classmethod - def infer_segment_type(cls, value: Any) -> SegmentType | None: - """ - Attempt to infer the `SegmentType` based on the Python type of the `value` parameter. - - Returns `None` if no appropriate `SegmentType` can be determined for the given `value`. - For example, this may occur if the input is a generic Python object of type `object`. - """ - - if isinstance(value, list): - elem_types: set[SegmentType] = set() - for i in value: - segment_type = cls.infer_segment_type(i) - if segment_type is None: - return None - - elem_types.add(segment_type) - - if len(elem_types) != 1: - if elem_types.issubset(_NUMERICAL_TYPES): - return SegmentType.ARRAY_NUMBER - return SegmentType.ARRAY_ANY - elif all(i.is_array_type() for i in elem_types): - return SegmentType.ARRAY_ANY - match elem_types.pop(): - case SegmentType.STRING: - return SegmentType.ARRAY_STRING - case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: - return SegmentType.ARRAY_NUMBER - case SegmentType.OBJECT: - return SegmentType.ARRAY_OBJECT - case SegmentType.FILE: - return SegmentType.ARRAY_FILE - case SegmentType.NONE: - return SegmentType.ARRAY_ANY - case SegmentType.BOOLEAN: - return SegmentType.ARRAY_BOOLEAN - case _: - # This should be unreachable. - raise ValueError(f"not supported value {value}") - if value is None: - return SegmentType.NONE - # Important: The check for `bool` must precede the check for `int`, - # as `bool` is a subclass of `int` in Python's type hierarchy. - elif isinstance(value, bool): - return SegmentType.BOOLEAN - elif isinstance(value, int): - return SegmentType.INTEGER - elif isinstance(value, float): - return SegmentType.FLOAT - elif isinstance(value, str): - return SegmentType.STRING - elif isinstance(value, dict): - return SegmentType.OBJECT - elif isinstance(value, File): - return SegmentType.FILE - else: - return None - - def _validate_array(self, value: Any, array_validation: ArrayValidation) -> bool: - if not isinstance(value, list): - return False - # Skip element validation if array is empty - if len(value) == 0: - return True - if self == SegmentType.ARRAY_ANY: - return True - element_type = _ARRAY_ELEMENT_TYPES_MAPPING[self] - - if array_validation == ArrayValidation.NONE: - return True - elif array_validation == ArrayValidation.FIRST: - return element_type.is_valid(value[0]) - else: - return all(element_type.is_valid(i, array_validation=ArrayValidation.NONE) for i in value) - - def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.ALL) -> bool: - """ - Check if a value matches the segment type. - Users of `SegmentType` should call this method, instead of using - `isinstance` manually. - - Args: - value: The value to validate - array_validation: Validation strategy for array types (ignored for non-array types) - - Returns: - True if the value matches the type under the given validation strategy - """ - if self.is_array_type(): - return self._validate_array(value, array_validation) - # Important: The check for `bool` must precede the check for `int`, - # as `bool` is a subclass of `int` in Python's type hierarchy. - elif self == SegmentType.BOOLEAN: - return isinstance(value, bool) - elif self in [SegmentType.INTEGER, SegmentType.FLOAT, SegmentType.NUMBER]: - return isinstance(value, (int, float)) - elif self == SegmentType.STRING: - return isinstance(value, str) - elif self == SegmentType.OBJECT: - return isinstance(value, dict) - elif self == SegmentType.SECRET: - return isinstance(value, str) - elif self == SegmentType.FILE: - return isinstance(value, File) - elif self == SegmentType.NONE: - return value is None - elif self == SegmentType.GROUP: - from .segment_group import SegmentGroup - from .segments import Segment - - if isinstance(value, SegmentGroup): - return all(isinstance(item, Segment) for item in value.value) - - if isinstance(value, list): - return all(isinstance(item, Segment) for item in value) - - return False - else: - raise AssertionError("this statement should be unreachable.") - - @staticmethod - def cast_value(value: Any, type_: SegmentType): - # Cast Python's `bool` type to `int` when the runtime type requires - # an integer or number. - # - # This ensures compatibility with existing workflows that may use `bool` as - # `int`, since in Python's type system, `bool` is a subtype of `int`. - # - # This function exists solely to maintain compatibility with existing workflows. - # It should not be used to compromise the integrity of the runtime type system. - # No additional casting rules should be introduced to this function. - - if type_ in ( - SegmentType.INTEGER, - SegmentType.NUMBER, - ) and isinstance(value, bool): - return int(value) - if type_ == SegmentType.ARRAY_NUMBER and all(isinstance(i, bool) for i in value): - return [int(i) for i in value] - return value - - def exposed_type(self) -> SegmentType: - """Returns the type exposed to the frontend. - - The frontend treats `INTEGER` and `FLOAT` as `NUMBER`, so these are returned as `NUMBER` here. - """ - if self in (SegmentType.INTEGER, SegmentType.FLOAT): - return SegmentType.NUMBER - return self - - def element_type(self) -> SegmentType | None: - """Return the element type of the current segment type, or `None` if the element type is undefined. - - Raises: - ValueError: If the current segment type is not an array type. - - Note: - For certain array types, such as `SegmentType.ARRAY_ANY`, their element types are not defined - by the runtime system. In such cases, this method will return `None`. - """ - if not self.is_array_type(): - raise ValueError(f"element_type is only supported by array type, got {self}") - return _ARRAY_ELEMENT_TYPES_MAPPING.get(self) - - @staticmethod - def get_zero_value(t: SegmentType) -> Segment: - # Lazy import to avoid circular dependency - from factories import variable_factory - - match t: - case ( - SegmentType.ARRAY_OBJECT - | SegmentType.ARRAY_ANY - | SegmentType.ARRAY_STRING - | SegmentType.ARRAY_NUMBER - | SegmentType.ARRAY_BOOLEAN - ): - return variable_factory.build_segment_with_type(t, []) - case SegmentType.OBJECT: - return variable_factory.build_segment({}) - case SegmentType.STRING: - return variable_factory.build_segment("") - case SegmentType.INTEGER: - return variable_factory.build_segment(0) - case SegmentType.FLOAT: - return variable_factory.build_segment(0.0) - case SegmentType.NUMBER: - return variable_factory.build_segment(0) - case SegmentType.BOOLEAN: - return variable_factory.build_segment(False) - case _: - raise ValueError(f"unsupported variable type: {t}") - - -_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = { - # ARRAY_ANY does not have corresponding element type. - SegmentType.ARRAY_STRING: SegmentType.STRING, - SegmentType.ARRAY_NUMBER: SegmentType.NUMBER, - SegmentType.ARRAY_OBJECT: SegmentType.OBJECT, - SegmentType.ARRAY_FILE: SegmentType.FILE, - SegmentType.ARRAY_BOOLEAN: SegmentType.BOOLEAN, -} - -_ARRAY_TYPES = frozenset( - list(_ARRAY_ELEMENT_TYPES_MAPPING.keys()) - + [ - SegmentType.ARRAY_ANY, - ] -) - -_NUMERICAL_TYPES = frozenset( - [ - SegmentType.NUMBER, - SegmentType.INTEGER, - SegmentType.FLOAT, - ] -) diff --git a/api/dify_graph/variables/utils.py b/api/dify_graph/variables/utils.py deleted file mode 100644 index 8e738f8fd5f..00000000000 --- a/api/dify_graph/variables/utils.py +++ /dev/null @@ -1,33 +0,0 @@ -from collections.abc import Iterable, Sequence -from typing import Any - -import orjson - -from .segment_group import SegmentGroup -from .segments import ArrayFileSegment, FileSegment, Segment - - -def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]: - selectors = [node_id, name] - if paths: - selectors.extend(paths) - return selectors - - -def segment_orjson_default(o: Any): - """Default function for orjson serialization of Segment types""" - if isinstance(o, ArrayFileSegment): - return [v.model_dump() for v in o.value] - elif isinstance(o, FileSegment): - return o.value.model_dump() - elif isinstance(o, SegmentGroup): - return [segment_orjson_default(seg) for seg in o.value] - elif isinstance(o, Segment): - return o.value - raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable") - - -def dumps_with_segments(obj: Any, ensure_ascii: bool = False) -> str: - """JSON dumps with segment support using orjson""" - option = orjson.OPT_NON_STR_KEYS - return orjson.dumps(obj, default=segment_orjson_default, option=option).decode("utf-8") diff --git a/api/dify_graph/variables/variables.py b/api/dify_graph/variables/variables.py deleted file mode 100644 index af866283dac..00000000000 --- a/api/dify_graph/variables/variables.py +++ /dev/null @@ -1,172 +0,0 @@ -from collections.abc import Sequence -from typing import Annotated, Any, TypeAlias -from uuid import uuid4 - -from pydantic import BaseModel, Discriminator, Field, Tag - -from .segments import ( - ArrayAnySegment, - ArrayBooleanSegment, - ArrayFileSegment, - ArrayNumberSegment, - ArrayObjectSegment, - ArraySegment, - ArrayStringSegment, - BooleanSegment, - FileSegment, - FloatSegment, - IntegerSegment, - NoneSegment, - ObjectSegment, - Segment, - StringSegment, - get_segment_discriminator, -) -from .types import SegmentType - - -def _obfuscated_token(token: str) -> str: - if not token: - return token - if len(token) <= 8: - return "*" * 20 - return token[:6] + "*" * 12 + token[-2:] - - -class VariableBase(Segment): - """ - A variable is a segment that has a name. - - It is mainly used to store segments and their selector in VariablePool. - - Note: this class is abstract, you should use subclasses of this class instead. - """ - - id: str = Field( - default_factory=lambda: str(uuid4()), - description="Unique identity for variable.", - ) - name: str - description: str = Field(default="", description="Description of the variable.") - selector: Sequence[str] = Field(default_factory=list) - - -class StringVariable(StringSegment, VariableBase): - pass - - -class FloatVariable(FloatSegment, VariableBase): - pass - - -class IntegerVariable(IntegerSegment, VariableBase): - pass - - -class ObjectVariable(ObjectSegment, VariableBase): - pass - - -class ArrayVariable(ArraySegment, VariableBase): - pass - - -class ArrayAnyVariable(ArrayAnySegment, ArrayVariable): - pass - - -class ArrayStringVariable(ArrayStringSegment, ArrayVariable): - pass - - -class ArrayNumberVariable(ArrayNumberSegment, ArrayVariable): - pass - - -class ArrayObjectVariable(ArrayObjectSegment, ArrayVariable): - pass - - -class SecretVariable(StringVariable): - value_type: SegmentType = SegmentType.SECRET - - @property - def log(self) -> str: - return _obfuscated_token(self.value) - - -class NoneVariable(NoneSegment, VariableBase): - value_type: SegmentType = SegmentType.NONE - value: None = None - - -class FileVariable(FileSegment, VariableBase): - pass - - -class BooleanVariable(BooleanSegment, VariableBase): - pass - - -class ArrayFileVariable(ArrayFileSegment, ArrayVariable): - pass - - -class ArrayBooleanVariable(ArrayBooleanSegment, ArrayVariable): - pass - - -class RAGPipelineVariable(BaseModel): - belong_to_node_id: str = Field(description="belong to which node id, shared means public") - type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list") - label: str = Field(description="label") - description: str | None = Field(description="description", default="") - variable: str = Field(description="variable key", default="") - max_length: int | None = Field( - description="max length, applicable to text-input, paragraph, and file-list", default=0 - ) - default_value: Any = Field(description="default value", default="") - placeholder: str | None = Field(description="placeholder", default="") - unit: str | None = Field(description="unit, applicable to Number", default="") - tooltips: str | None = Field(description="helpful text", default="") - allowed_file_types: list[str] | None = Field( - description="image, document, audio, video, custom.", default_factory=list - ) - allowed_file_extensions: list[str] | None = Field(description="e.g. ['.jpg', '.mp3']", default_factory=list) - allowed_file_upload_methods: list[str] | None = Field( - description="remote_url, local_file, tool_file.", default_factory=list - ) - required: bool = Field(description="optional, default false", default=False) - options: list[str] | None = Field(default_factory=list) - - -class RAGPipelineVariableInput(BaseModel): - variable: RAGPipelineVariable - value: Any - - -# The `Variable` type is used to enable serialization and deserialization with Pydantic. -# Use `VariableBase` for type hinting when serialization is not required. -# -# Note: -# - All variants in `Variable` must inherit from the `VariableBase` class. -# - The union must include all non-abstract subclasses of `VariableBase`. -Variable: TypeAlias = Annotated[ - ( - Annotated[NoneVariable, Tag(SegmentType.NONE)] - | Annotated[StringVariable, Tag(SegmentType.STRING)] - | Annotated[FloatVariable, Tag(SegmentType.FLOAT)] - | Annotated[IntegerVariable, Tag(SegmentType.INTEGER)] - | Annotated[ObjectVariable, Tag(SegmentType.OBJECT)] - | Annotated[FileVariable, Tag(SegmentType.FILE)] - | Annotated[BooleanVariable, Tag(SegmentType.BOOLEAN)] - | Annotated[ArrayAnyVariable, Tag(SegmentType.ARRAY_ANY)] - | Annotated[ArrayStringVariable, Tag(SegmentType.ARRAY_STRING)] - | Annotated[ArrayNumberVariable, Tag(SegmentType.ARRAY_NUMBER)] - | Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)] - | Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)] - | Annotated[ArrayBooleanVariable, Tag(SegmentType.ARRAY_BOOLEAN)] - | Annotated[SecretVariable, Tag(SegmentType.SECRET)] - ), - Discriminator(get_segment_discriminator), -] diff --git a/api/dify_graph/workflow_type_encoder.py b/api/dify_graph/workflow_type_encoder.py deleted file mode 100644 index 3dd846b3cb9..00000000000 --- a/api/dify_graph/workflow_type_encoder.py +++ /dev/null @@ -1,49 +0,0 @@ -from collections.abc import Mapping -from decimal import Decimal -from typing import Any, overload - -from pydantic import BaseModel - -from dify_graph.file.models import File -from dify_graph.variables import Segment - - -class WorkflowRuntimeTypeConverter: - @overload - def to_json_encodable(self, value: Mapping[str, Any]) -> Mapping[str, Any]: ... - @overload - def to_json_encodable(self, value: None) -> None: ... - - def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None: - """Convert runtime values to JSON-serializable structures.""" - - result = self.value_to_json_encodable_recursive(value) - if isinstance(result, Mapping) or result is None: - return result - return {} - - def value_to_json_encodable_recursive(self, value: Any): - if value is None: - return value - if isinstance(value, (bool, int, str, float)): - return value - if isinstance(value, Decimal): - # Convert Decimal to float for JSON serialization - return float(value) - if isinstance(value, Segment): - return self.value_to_json_encodable_recursive(value.value) - if isinstance(value, File): - return value.to_dict() - if isinstance(value, BaseModel): - return value.model_dump(mode="json") - if isinstance(value, dict): - res = {} - for k, v in value.items(): - res[k] = self.value_to_json_encodable_recursive(v) - return res - if isinstance(value, list): - res_list = [] - for item in value: - res_list.append(self.value_to_json_encodable_recursive(item)) - return res_list - return value diff --git a/api/dify_graph/__init__.py b/api/enterprise/__init__.py similarity index 100% rename from api/dify_graph/__init__.py rename to api/enterprise/__init__.py diff --git a/api/enterprise/telemetry/DATA_DICTIONARY.md b/api/enterprise/telemetry/DATA_DICTIONARY.md new file mode 100644 index 00000000000..60d482cd1c8 --- /dev/null +++ b/api/enterprise/telemetry/DATA_DICTIONARY.md @@ -0,0 +1,525 @@ +# Dify Enterprise Telemetry Data Dictionary + +Quick reference for all telemetry signals emitted by Dify Enterprise. For configuration and architecture details, see [README.md](./README.md). + +## Resource Attributes + +Attached to every signal (Span, Metric, Log). + +| Attribute | Type | Example | +|-----------|------|---------| +| `service.name` | string | `dify` | +| `host.name` | string | `dify-api-7f8b` | + +## Traces (Spans) + +### `dify.workflow.run` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.trace_id` | string | Business trace ID (Workflow Run ID) | +| `dify.tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.workflow.id` | string | Workflow definition ID | +| `dify.workflow.run_id` | string | Unique ID for this run | +| `dify.workflow.status` | string | `succeeded`, `failed`, `stopped`, etc. | +| `dify.workflow.error` | string | Error message if failed | +| `dify.workflow.elapsed_time` | float | Total execution time (seconds) | +| `dify.invoke_from` | string | `api`, `webapp`, `debug` | +| `dify.conversation.id` | string | Conversation ID (optional) | +| `dify.message.id` | string | Message ID (optional) | +| `dify.invoked_by` | string | User ID who triggered the run | +| `gen_ai.usage.total_tokens` | int | Total tokens across all nodes (optional) | +| `gen_ai.user.id` | string | End-user identifier (optional) | +| `dify.parent.trace_id` | string | Parent workflow trace ID (optional) | +| `dify.parent.workflow.run_id` | string | Parent workflow run ID (optional) | +| `dify.parent.node.execution_id` | string | Parent node execution ID (optional) | +| `dify.parent.app.id` | string | Parent app ID (optional) | + +### `dify.node.execution` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.trace_id` | string | Business trace ID | +| `dify.tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.workflow.id` | string | Workflow definition ID | +| `dify.workflow.run_id` | string | Workflow Run ID | +| `dify.message.id` | string | Message ID (optional) | +| `dify.conversation.id` | string | Conversation ID (optional) | +| `dify.node.execution_id` | string | Unique node execution ID | +| `dify.node.id` | string | Node ID in workflow graph | +| `dify.node.type` | string | Node type (see appendix) | +| `dify.node.title` | string | Display title | +| `dify.node.status` | string | `succeeded`, `failed` | +| `dify.node.error` | string | Error message if failed | +| `dify.node.elapsed_time` | float | Execution time (seconds) | +| `dify.node.index` | int | Execution order index | +| `dify.node.predecessor_node_id` | string | Triggering node ID | +| `dify.node.iteration_id` | string | Iteration ID (optional) | +| `dify.node.loop_id` | string | Loop ID (optional) | +| `dify.node.parallel_id` | string | Parallel branch ID (optional) | +| `dify.node.invoked_by` | string | User ID who triggered execution | +| `gen_ai.usage.input_tokens` | int | Prompt tokens (LLM nodes only) | +| `gen_ai.usage.output_tokens` | int | Completion tokens (LLM nodes only) | +| `gen_ai.usage.total_tokens` | int | Total tokens (LLM nodes only) | +| `gen_ai.request.model` | string | LLM model name (LLM nodes only) | +| `gen_ai.provider.name` | string | LLM provider name (LLM nodes only) | +| `gen_ai.user.id` | string | End-user identifier (optional) | + +### `dify.node.execution.draft` + +Same attributes as `dify.node.execution`. Emitted during Preview/Debug runs. + +## Counters + +All counters are cumulative and emitted at 100% accuracy. + +### Token Counters + +| Metric | Unit | Description | +|--------|------|-------------| +| `dify.tokens.total` | `{token}` | Total tokens consumed | +| `dify.tokens.input` | `{token}` | Input (prompt) tokens | +| `dify.tokens.output` | `{token}` | Output (completion) tokens | + +**Labels:** + +- `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `node_type` (if node_execution) + +⚠️ **Warning:** `dify.tokens.total` at workflow level includes all node tokens. Filter by `operation_type` to avoid double-counting. + +#### Token Hierarchy & Query Patterns + +Token metrics are emitted at multiple layers. Understanding the hierarchy prevents double-counting: + +``` +App-level total +├── workflow ← sum of all node_execution tokens (DO NOT add both) +│ └── node_execution ← per-node breakdown +├── message ← independent (non-workflow chat apps only) +├── rule_generate ← independent helper LLM call +├── code_generate ← independent helper LLM call +├── structured_output ← independent helper LLM call +└── instruction_modify← independent helper LLM call +``` + +**Key rule:** `workflow` tokens already include all `node_execution` tokens. Never sum both. + +**Available labels on token metrics:** `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `node_type`. +App name is only available on span attributes (`dify.app.name`), not metric labels — use `app_id` for metric queries. + +**Common queries** (PromQL): + +```promql +# ── Totals ────────────────────────────────────────────────── +# App-level total (exclude node_execution to avoid double-counting) +sum by (app_id) (dify_tokens_total{operation_type!="node_execution"}) + +# Single app total +sum (dify_tokens_total{app_id="", operation_type!="node_execution"}) + +# Per-tenant totals +sum by (tenant_id) (dify_tokens_total{operation_type!="node_execution"}) + +# ── Drill-down ────────────────────────────────────────────── +# Workflow-level tokens for an app +sum (dify_tokens_total{app_id="", operation_type="workflow"}) + +# Node-level breakdown within an app +sum by (node_type) (dify_tokens_total{app_id="", operation_type="node_execution"}) + +# Model breakdown for an app +sum by (model_provider, model_name) (dify_tokens_total{app_id=""}) + +# Input vs output per model +sum by (model_name) (dify_tokens_input_total{app_id=""}) +sum by (model_name) (dify_tokens_output_total{app_id=""}) + +# ── Rates ─────────────────────────────────────────────────── +# Token consumption rate (per hour) +sum(rate(dify_tokens_total{operation_type!="node_execution"}[1h])) + +# Per-app consumption rate +sum by (app_id) (rate(dify_tokens_total{operation_type!="node_execution"}[1h])) +``` + +**Finding `app_id` from app name** (trace query — Tempo / Jaeger): + +``` +{ resource.dify.app.name = "My Chatbot" } | select(resource.dify.app.id) +``` + +### Request Counters + +| Metric | Unit | Description | +|--------|------|-------------| +| `dify.requests.total` | `{request}` | Total operations count | + +**Labels by type:** + +| `type` | Additional Labels | +|--------|-------------------| +| `workflow` | `tenant_id`, `app_id`, `status`, `invoke_from` | +| `node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `status` | +| `draft_node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `status` | +| `message` | `tenant_id`, `app_id`, `model_provider`, `model_name`, `status`, `invoke_from` | +| `tool` | `tenant_id`, `app_id`, `tool_name` | +| `moderation` | `tenant_id`, `app_id` | +| `suggested_question` | `tenant_id`, `app_id`, `model_provider`, `model_name` | +| `dataset_retrieval` | `tenant_id`, `app_id` | +| `generate_name` | `tenant_id`, `app_id` | +| `prompt_generation` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `status` | + +### Error Counters + +| Metric | Unit | Description | +|--------|------|-------------| +| `dify.errors.total` | `{error}` | Total failed operations | + +**Labels by type:** + +| `type` | Additional Labels | +|--------|-------------------| +| `workflow` | `tenant_id`, `app_id` | +| `node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name` | +| `draft_node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name` | +| `message` | `tenant_id`, `app_id`, `model_provider`, `model_name` | +| `tool` | `tenant_id`, `app_id`, `tool_name` | +| `prompt_generation` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name` | + +### Other Counters + +| Metric | Unit | Labels | +|--------|------|--------| +| `dify.feedback.total` | `{feedback}` | `tenant_id`, `app_id`, `rating` | +| `dify.dataset.retrievals.total` | `{retrieval}` | `tenant_id`, `app_id`, `dataset_id`, `embedding_model_provider`, `embedding_model`, `rerank_model_provider`, `rerank_model` | +| `dify.app.created.total` | `{app}` | `tenant_id`, `app_id`, `mode` | +| `dify.app.updated.total` | `{app}` | `tenant_id`, `app_id` | +| `dify.app.deleted.total` | `{app}` | `tenant_id`, `app_id` | + +## Histograms + +| Metric | Unit | Labels | +|--------|------|--------| +| `dify.workflow.duration` | `s` | `tenant_id`, `app_id`, `status` | +| `dify.node.duration` | `s` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `plugin_name` | +| `dify.message.duration` | `s` | `tenant_id`, `app_id`, `model_provider`, `model_name` | +| `dify.message.time_to_first_token` | `s` | `tenant_id`, `app_id`, `model_provider`, `model_name` | +| `dify.tool.duration` | `s` | `tenant_id`, `app_id`, `tool_name` | +| `dify.prompt_generation.duration` | `s` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name` | + +## Structured Logs + +### Span Companion Logs + +Logs that accompany spans. Signal type: `span_detail` + +#### `dify.workflow.run` Companion Log + +**Common attributes:** All span attributes (see Traces section) plus: + +| Additional Attribute | Type | Always Present | Description | +|---------------------|------|----------------|-------------| +| `dify.app.name` | string | No | Application display name | +| `dify.workspace.name` | string | No | Workspace display name | +| `dify.workflow.version` | string | Yes | Workflow definition version | +| `dify.workflow.inputs` | string/JSON | Yes | Input parameters (content-gated) | +| `dify.workflow.outputs` | string/JSON | Yes | Output results (content-gated) | +| `dify.workflow.query` | string | No | User query text (content-gated) | + +**Event attributes:** + +- `dify.event.name`: `"dify.workflow.run"` +- `dify.event.signal`: `"span_detail"` +- `trace_id`, `span_id`, `tenant_id`, `user_id` + +#### `dify.node.execution` and `dify.node.execution.draft` Companion Logs + +**Common attributes:** All span attributes (see Traces section) plus: + +| Additional Attribute | Type | Always Present | Description | +|---------------------|------|----------------|-------------| +| `dify.app.name` | string | No | Application display name | +| `dify.workspace.name` | string | No | Workspace display name | +| `dify.invoke_from` | string | No | Invocation source | +| `gen_ai.tool.name` | string | No | Tool name (tool nodes only) | +| `dify.node.total_price` | float | No | Cost (LLM nodes only) | +| `dify.node.currency` | string | No | Currency code (LLM nodes only) | +| `dify.node.iteration_index` | int | No | Iteration index (iteration nodes) | +| `dify.node.loop_index` | int | No | Loop index (loop nodes) | +| `dify.plugin.name` | string | No | Plugin name (tool/knowledge nodes) | +| `dify.credential.name` | string | No | Credential name (plugin nodes) | +| `dify.credential.id` | string | No | Credential ID (plugin nodes) | +| `dify.dataset.ids` | JSON array | No | Dataset IDs (knowledge nodes) | +| `dify.dataset.names` | JSON array | No | Dataset names (knowledge nodes) | +| `dify.node.inputs` | string/JSON | Yes | Node inputs (content-gated) | +| `dify.node.outputs` | string/JSON | Yes | Node outputs (content-gated) | +| `dify.node.process_data` | string/JSON | No | Processing data (content-gated) | + +**Event attributes:** + +- `dify.event.name`: `"dify.node.execution"` or `"dify.node.execution.draft"` +- `dify.event.signal`: `"span_detail"` +- `trace_id`, `span_id`, `tenant_id`, `user_id` + +### Standalone Logs + +Logs without structural spans. Signal type: `metric_only` + +#### `dify.message.run` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.message.run"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID (32-char hex) | +| `span_id` | string | OTEL span ID (16-char hex) | +| `tenant_id` | string | Tenant identifier | +| `user_id` | string | User identifier (optional) | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.conversation.id` | string | Conversation ID (optional) | +| `dify.workflow.run_id` | string | Workflow run ID (optional) | +| `dify.invoke_from` | string | `service-api`, `web-app`, `debugger`, `explore` | +| `gen_ai.provider.name` | string | LLM provider | +| `gen_ai.request.model` | string | LLM model | +| `gen_ai.usage.input_tokens` | int | Input tokens | +| `gen_ai.usage.output_tokens` | int | Output tokens | +| `gen_ai.usage.total_tokens` | int | Total tokens | +| `dify.message.status` | string | `succeeded`, `failed` | +| `dify.message.error` | string | Error message (if failed) | +| `dify.message.duration` | float | Duration (seconds) | +| `dify.message.time_to_first_token` | float | TTFT (seconds) | +| `dify.message.inputs` | string/JSON | Inputs (content-gated) | +| `dify.message.outputs` | string/JSON | Outputs (content-gated) | + +#### `dify.tool.execution` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.tool.execution"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.tool.name` | string | Tool name | +| `dify.tool.duration` | float | Duration (seconds) | +| `dify.tool.status` | string | `succeeded`, `failed` | +| `dify.tool.error` | string | Error message (if failed) | +| `dify.tool.inputs` | string/JSON | Inputs (content-gated) | +| `dify.tool.outputs` | string/JSON | Outputs (content-gated) | +| `dify.tool.parameters` | string/JSON | Parameters (content-gated) | +| `dify.tool.config` | string/JSON | Configuration (content-gated) | + +#### `dify.moderation.check` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.moderation.check"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.moderation.type` | string | `input`, `output` | +| `dify.moderation.action` | string | `pass`, `block`, `flag` | +| `dify.moderation.flagged` | boolean | Whether flagged | +| `dify.moderation.categories` | JSON array | Flagged categories | +| `dify.moderation.query` | string | Content (content-gated) | + +#### `dify.suggested_question.generation` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.suggested_question.generation"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.suggested_question.count` | int | Number of questions | +| `dify.suggested_question.duration` | float | Duration (seconds) | +| `dify.suggested_question.status` | string | `succeeded`, `failed` | +| `dify.suggested_question.error` | string | Error message (if failed) | +| `dify.suggested_question.questions` | JSON array | Questions (content-gated) | + +#### `dify.dataset.retrieval` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.dataset.retrieval"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.dataset.id` | string | Dataset identifier | +| `dify.dataset.name` | string | Dataset name | +| `dify.dataset.embedding_providers` | JSON array | Embedding model providers (one per dataset) | +| `dify.dataset.embedding_models` | JSON array | Embedding models (one per dataset) | +| `dify.retrieval.rerank_provider` | string | Rerank model provider | +| `dify.retrieval.rerank_model` | string | Rerank model name | +| `dify.retrieval.query` | string | Search query (content-gated) | +| `dify.retrieval.document_count` | int | Documents retrieved | +| `dify.retrieval.duration` | float | Duration (seconds) | +| `dify.retrieval.status` | string | `succeeded`, `failed` | +| `dify.retrieval.error` | string | Error message (if failed) | +| `dify.dataset.documents` | JSON array | Documents (content-gated) | + +#### `dify.generate_name.execution` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.generate_name.execution"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.conversation.id` | string | Conversation identifier | +| `dify.generate_name.duration` | float | Duration (seconds) | +| `dify.generate_name.status` | string | `succeeded`, `failed` | +| `dify.generate_name.error` | string | Error message (if failed) | +| `dify.generate_name.inputs` | string/JSON | Inputs (content-gated) | +| `dify.generate_name.outputs` | string | Generated name (content-gated) | + +#### `dify.prompt_generation.execution` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.prompt_generation.execution"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.prompt_generation.operation_type` | string | Operation type (see appendix) | +| `gen_ai.provider.name` | string | LLM provider | +| `gen_ai.request.model` | string | LLM model | +| `gen_ai.usage.input_tokens` | int | Input tokens | +| `gen_ai.usage.output_tokens` | int | Output tokens | +| `gen_ai.usage.total_tokens` | int | Total tokens | +| `dify.prompt_generation.duration` | float | Duration (seconds) | +| `dify.prompt_generation.status` | string | `succeeded`, `failed` | +| `dify.prompt_generation.error` | string | Error message (if failed) | +| `dify.prompt_generation.instruction` | string | Instruction (content-gated) | +| `dify.prompt_generation.output` | string/JSON | Output (content-gated) | + +#### `dify.app.created` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.app.created"` | +| `dify.event.signal` | string | `"metric_only"` | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.app.mode` | string | `chat`, `completion`, `agent-chat`, `workflow` | +| `dify.app.created_at` | string | Timestamp (ISO 8601) | + +#### `dify.app.updated` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.app.updated"` | +| `dify.event.signal` | string | `"metric_only"` | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.app.updated_at` | string | Timestamp (ISO 8601) | + +#### `dify.app.deleted` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.app.deleted"` | +| `dify.event.signal` | string | `"metric_only"` | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.app.deleted_at` | string | Timestamp (ISO 8601) | + +#### `dify.feedback.created` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.feedback.created"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.feedback.rating` | string | `like`, `dislike`, `null` | +| `dify.feedback.content` | string | Feedback text (content-gated) | +| `dify.feedback.created_at` | string | Timestamp (ISO 8601) | + +#### `dify.telemetry.rehydration_failed` + +Diagnostic event for telemetry system health monitoring. + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.telemetry.rehydration_failed"` | +| `dify.event.signal` | string | `"metric_only"` | +| `tenant_id` | string | Tenant identifier | +| `dify.telemetry.error` | string | Error message | +| `dify.telemetry.payload_type` | string | Payload type (see appendix) | +| `dify.telemetry.correlation_id` | string | Correlation ID | + +## Content-Gated Attributes + +When `ENTERPRISE_INCLUDE_CONTENT=false`, these attributes are replaced with reference strings (`ref:{id_type}={uuid}`). + +| Attribute | Signal | +|-----------|--------| +| `dify.workflow.inputs` | `dify.workflow.run` | +| `dify.workflow.outputs` | `dify.workflow.run` | +| `dify.workflow.query` | `dify.workflow.run` | +| `dify.node.inputs` | `dify.node.execution` | +| `dify.node.outputs` | `dify.node.execution` | +| `dify.node.process_data` | `dify.node.execution` | +| `dify.message.inputs` | `dify.message.run` | +| `dify.message.outputs` | `dify.message.run` | +| `dify.tool.inputs` | `dify.tool.execution` | +| `dify.tool.outputs` | `dify.tool.execution` | +| `dify.tool.parameters` | `dify.tool.execution` | +| `dify.tool.config` | `dify.tool.execution` | +| `dify.moderation.query` | `dify.moderation.check` | +| `dify.suggested_question.questions` | `dify.suggested_question.generation` | +| `dify.retrieval.query` | `dify.dataset.retrieval` | +| `dify.dataset.documents` | `dify.dataset.retrieval` | +| `dify.generate_name.inputs` | `dify.generate_name.execution` | +| `dify.generate_name.outputs` | `dify.generate_name.execution` | +| `dify.prompt_generation.instruction` | `dify.prompt_generation.execution` | +| `dify.prompt_generation.output` | `dify.prompt_generation.execution` | +| `dify.feedback.content` | `dify.feedback.created` | + +## Appendix + +### Operation Types + +- `workflow`, `node_execution`, `message`, `rule_generate`, `code_generate`, `structured_output`, `instruction_modify` + +### Node Types + +- `start`, `end`, `answer`, `llm`, `knowledge-retrieval`, `knowledge-index`, `if-else`, `code`, `template-transform`, `question-classifier`, `http-request`, `tool`, `datasource`, `variable-aggregator`, `loop`, `iteration`, `parameter-extractor`, `assigner`, `document-extractor`, `list-operator`, `agent`, `trigger-webhook`, `trigger-schedule`, `trigger-plugin`, `human-input` + +### Workflow Statuses + +- `running`, `succeeded`, `failed`, `stopped`, `partial-succeeded`, `paused` + +### Payload Types + +- `workflow`, `node`, `message`, `tool`, `moderation`, `suggested_question`, `dataset_retrieval`, `generate_name`, `prompt_generation`, `app`, `feedback` + +### Null Value Behavior + +**Spans:** Attributes with `null` values are omitted. + +**Logs:** Attributes with `null` values appear as `null` in JSON. + +**Content-Gated:** Replaced with reference strings, not set to `null`. diff --git a/api/enterprise/telemetry/README.md b/api/enterprise/telemetry/README.md new file mode 100644 index 00000000000..e43c0b1ea29 --- /dev/null +++ b/api/enterprise/telemetry/README.md @@ -0,0 +1,121 @@ +# Dify Enterprise Telemetry + +This document provides an overview of the Dify Enterprise OpenTelemetry (OTEL) exporter and how to configure it for integration with observability stacks like Prometheus, Grafana, Jaeger, or Honeycomb. + +## Overview + +Dify Enterprise uses a "slim span + rich companion log" architecture to provide high-fidelity observability without overwhelming trace storage. + +- **Traces (Spans)**: Capture the structure, identity, and timing of high-level operations (Workflows and Nodes). +- **Structured Logs**: Provide deep context (inputs, outputs, metadata) for every event, correlated to spans via `trace_id` and `span_id`. +- **Metrics**: Provide 100% accurate counters and histograms for usage, performance, and error tracking. + +### Signal Architecture + +```mermaid +graph TD + A[Workflow Run] -->|Span| B(dify.workflow.run) + A -->|Log| C(dify.workflow.run detail) + B ---|trace_id| C + + D[Node Execution] -->|Span| E(dify.node.execution) + D -->|Log| F(dify.node.execution detail) + E ---|span_id| F + + G[Message/Tool/etc] -->|Log| H(dify.* event) + G -->|Metric| I(dify.* counter/histogram) +``` + +## Configuration + +The Enterprise OTEL exporter is configured via environment variables. + +| Variable | Description | Default | +|----------|-------------|---------| +| `ENTERPRISE_ENABLED` | Master switch for all enterprise features. | `false` | +| `ENTERPRISE_TELEMETRY_ENABLED` | Master switch for enterprise telemetry. | `false` | +| `ENTERPRISE_OTLP_ENDPOINT` | OTLP collector endpoint (e.g., `http://otel-collector:4318`). | - | +| `ENTERPRISE_OTLP_HEADERS` | Custom headers for OTLP requests (e.g., `x-scope-orgid=tenant1`). | - | +| `ENTERPRISE_OTLP_PROTOCOL` | OTLP transport protocol (`http` or `grpc`). | `http` | +| `ENTERPRISE_OTLP_API_KEY` | Bearer token for authentication. | - | +| `ENTERPRISE_INCLUDE_CONTENT` | Whether to include sensitive content (inputs/outputs) in logs. | `false` | +| `ENTERPRISE_SERVICE_NAME` | Service name reported to OTEL. | `dify` | +| `ENTERPRISE_OTEL_SAMPLING_RATE` | Sampling rate for traces (0.0 to 1.0). Metrics are always 100%. | `1.0` | + +## Correlation Model + +Dify uses deterministic ID generation to ensure signals are correlated across different services and asynchronous tasks. + +### ID Generation Rules + +- `trace_id`: Derived from the correlation ID (workflow_run_id or node_execution_id for drafts) using `int(UUID(correlation_id))` +- `span_id`: Derived from the source ID using the lower 64 bits of `UUID(source_id)` + +### Scenario A: Simple Workflow + +A single workflow run with multiple nodes. All spans and logs share the same `trace_id` (derived from `workflow_run_id`). + +``` +trace_id = UUID(workflow_run_id) +├── [root span] dify.workflow.run (span_id = hash(workflow_run_id)) +│ ├── [child] dify.node.execution - "Start" (span_id = hash(node_exec_id_1)) +│ ├── [child] dify.node.execution - "LLM" (span_id = hash(node_exec_id_2)) +│ └── [child] dify.node.execution - "End" (span_id = hash(node_exec_id_3)) +``` + +### Scenario B: Nested Sub-Workflow + +A workflow calling another workflow via a Tool or Sub-workflow node. The child workflow's spans are linked to the parent via `parent_span_id`. Both workflows share the same trace_id. + +``` +trace_id = UUID(outer_workflow_run_id) ← shared across both workflows +├── [root] dify.workflow.run (outer) (span_id = hash(outer_workflow_run_id)) +│ ├── dify.node.execution - "Start Node" +│ ├── dify.node.execution - "Tool Node" (triggers sub-workflow) +│ │ └── [child] dify.workflow.run (inner) (span_id = hash(inner_workflow_run_id)) +│ │ ├── dify.node.execution - "Inner Start" +│ │ └── dify.node.execution - "Inner End" +│ └── dify.node.execution - "End Node" +``` + +**Key attributes for nested workflows:** + +- Inner workflow's `dify.parent.trace_id` = outer `workflow_run_id` +- Inner workflow's `dify.parent.node.execution_id` = tool node's `execution_id` +- Inner workflow's `dify.parent.workflow.run_id` = outer `workflow_run_id` +- Inner workflow's `dify.parent.app.id` = outer `app_id` + +### Scenario C: Draft Node Execution + +A single node run in isolation (debugger/preview mode). It creates its own trace where the node span is the root. + +``` +trace_id = UUID(node_execution_id) ← own trace, NOT part of any workflow +└── dify.node.execution.draft (span_id = hash(node_execution_id)) +``` + +**Key difference:** Draft executions use `node_execution_id` as the correlation_id, so they are NOT children of any workflow trace. + +## Content Gating + +When `ENTERPRISE_INCLUDE_CONTENT` is set to `false`, sensitive content attributes (inputs, outputs, queries) are replaced with reference strings (e.g., `ref:workflow_run_id=...`) to prevent data leakage to the OTEL collector. + +**Reference String Format:** + +``` +ref:{id_type}={uuid} +``` + +**Examples:** + +``` +ref:workflow_run_id=550e8400-e29b-41d4-a716-446655440000 +ref:node_execution_id=660e8400-e29b-41d4-a716-446655440001 +ref:message_id=770e8400-e29b-41d4-a716-446655440002 +``` + +To retrieve actual content when gating is enabled, query the Dify database using the provided UUID. + +## Reference + +For a complete list of telemetry signals, attributes, and data structures, see [DATA_DICTIONARY.md](./DATA_DICTIONARY.md). diff --git a/api/dify_graph/graph_engine/entities/__init__.py b/api/enterprise/telemetry/__init__.py similarity index 100% rename from api/dify_graph/graph_engine/entities/__init__.py rename to api/enterprise/telemetry/__init__.py diff --git a/api/enterprise/telemetry/contracts.py b/api/enterprise/telemetry/contracts.py new file mode 100644 index 00000000000..91398cb8cb1 --- /dev/null +++ b/api/enterprise/telemetry/contracts.py @@ -0,0 +1,73 @@ +"""Telemetry gateway contracts and data structures. + +This module defines the envelope format for telemetry events and the routing +configuration that determines how each event type is processed. +""" + +from __future__ import annotations + +from enum import StrEnum +from typing import Any + +from pydantic import BaseModel, ConfigDict + + +class TelemetryCase(StrEnum): + """Enumeration of all known telemetry event cases.""" + + WORKFLOW_RUN = "workflow_run" + NODE_EXECUTION = "node_execution" + DRAFT_NODE_EXECUTION = "draft_node_execution" + MESSAGE_RUN = "message_run" + TOOL_EXECUTION = "tool_execution" + MODERATION_CHECK = "moderation_check" + SUGGESTED_QUESTION = "suggested_question" + DATASET_RETRIEVAL = "dataset_retrieval" + GENERATE_NAME = "generate_name" + PROMPT_GENERATION = "prompt_generation" + APP_CREATED = "app_created" + APP_UPDATED = "app_updated" + APP_DELETED = "app_deleted" + FEEDBACK_CREATED = "feedback_created" + + +class SignalType(StrEnum): + """Signal routing type for telemetry cases.""" + + TRACE = "trace" + METRIC_LOG = "metric_log" + + +class CaseRoute(BaseModel): + """Routing configuration for a telemetry case. + + Attributes: + signal_type: The type of signal (trace or metric_log). + ce_eligible: Whether this case is eligible for community edition tracing. + """ + + signal_type: SignalType + ce_eligible: bool + + +class TelemetryEnvelope(BaseModel): + """Envelope for telemetry events. + + Attributes: + case: The telemetry case type. + tenant_id: The tenant identifier. + event_id: Unique event identifier for deduplication. + payload: The main event payload (inline for small payloads, + empty when offloaded to storage via ``payload_ref``). + metadata: Optional metadata dictionary. When the gateway + offloads a large payload to object storage, this contains + ``{"payload_ref": ""}``. + """ + + model_config = ConfigDict(extra="forbid", use_enum_values=False) + + case: TelemetryCase + tenant_id: str + event_id: str + payload: dict[str, Any] + metadata: dict[str, Any] | None = None diff --git a/api/enterprise/telemetry/draft_trace.py b/api/enterprise/telemetry/draft_trace.py new file mode 100644 index 00000000000..5a8d0ee6f49 --- /dev/null +++ b/api/enterprise/telemetry/draft_trace.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +from graphon.enums import WorkflowNodeExecutionMetadataKey + +from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName +from core.telemetry import emit as telemetry_emit +from models.workflow import WorkflowNodeExecutionModel + + +def enqueue_draft_node_execution_trace( + *, + execution: WorkflowNodeExecutionModel, + outputs: Mapping[str, Any] | None, + workflow_execution_id: str | None, + user_id: str, +) -> None: + node_data = _build_node_execution_data( + execution=execution, + outputs=outputs, + workflow_execution_id=workflow_execution_id, + ) + telemetry_emit( + TelemetryEvent( + name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id=execution.tenant_id, + user_id=user_id, + app_id=execution.app_id, + ), + payload={"node_execution_data": node_data}, + ) + ) + + +def _build_node_execution_data( + *, + execution: WorkflowNodeExecutionModel, + outputs: Mapping[str, Any] | None, + workflow_execution_id: str | None, +) -> dict[str, Any]: + metadata = execution.execution_metadata_dict + node_outputs = outputs if outputs is not None else execution.outputs_dict + execution_id = workflow_execution_id or execution.workflow_run_id or execution.id + process_data = execution.process_data_dict or {} + + # Extract token breakdown from outputs.usage (set by LLM node) + usage: Mapping[str, Any] = {} + if isinstance(node_outputs, Mapping): + raw_usage = node_outputs.get("usage") + if isinstance(raw_usage, Mapping): + usage = raw_usage + + return { + "workflow_id": execution.workflow_id, + "workflow_execution_id": execution_id, + "tenant_id": execution.tenant_id, + "app_id": execution.app_id, + "node_execution_id": execution.id, + "node_id": execution.node_id, + "node_type": execution.node_type, + "title": execution.title, + "status": execution.status, + "error": execution.error, + "elapsed_time": execution.elapsed_time, + "index": execution.index, + "predecessor_node_id": execution.predecessor_node_id, + "created_at": execution.created_at, + "finished_at": execution.finished_at, + "total_tokens": metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS, 0), + "total_price": metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_PRICE, 0.0), + "currency": metadata.get(WorkflowNodeExecutionMetadataKey.CURRENCY), + "model_provider": process_data.get("model_provider"), + "model_name": process_data.get("model_name"), + "prompt_tokens": usage.get("prompt_tokens"), + "completion_tokens": usage.get("completion_tokens"), + "tool_name": (metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO) or {}).get("tool_name") + if isinstance(metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO), dict) + else None, + "iteration_id": metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_ID), + "iteration_index": metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_INDEX), + "loop_id": metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_ID), + "loop_index": metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_INDEX), + "parallel_id": metadata.get(WorkflowNodeExecutionMetadataKey.PARALLEL_ID), + "node_inputs": execution.inputs_dict, + "node_outputs": node_outputs, + "process_data": execution.process_data_dict, + } diff --git a/api/enterprise/telemetry/enterprise_trace.py b/api/enterprise/telemetry/enterprise_trace.py new file mode 100644 index 00000000000..fc17d9d93ee --- /dev/null +++ b/api/enterprise/telemetry/enterprise_trace.py @@ -0,0 +1,966 @@ +"""Enterprise trace handler — duck-typed, NOT a BaseTraceInstance subclass. + +Invoked directly in the Celery task, not through OpsTraceManager dispatch. +Only requires a matching ``trace(trace_info)`` method signature. + +Signal strategy: +- **Traces (spans)**: workflow run, node execution, draft node execution only. +- **Metrics + structured logs**: all other event types. + +Token metric labels (unified structure): +All token metrics (dify.tokens.input, dify.tokens.output, dify.tokens.total) use the +same label set for consistent filtering and aggregation: +- tenant_id: Tenant identifier +- app_id: Application identifier +- operation_type: Source of token usage (workflow | node_execution | message | rule_generate | etc.) +- model_provider: LLM provider name (empty string if not applicable) +- model_name: LLM model name (empty string if not applicable) +- node_type: Workflow node type (empty string if not node_execution) + +This unified structure allows filtering by operation_type to separate: +- Workflow-level aggregates (operation_type=workflow) +- Individual node executions (operation_type=node_execution) +- Direct message calls (operation_type=message) +- Prompt generation operations (operation_type=rule_generate, code_generate, etc.) + +Without this, tokens are double-counted when querying totals (workflow totals include +node totals, since workflow.total_tokens is the sum of all node tokens). +""" + +from __future__ import annotations + +import json +import logging +from typing import Any, cast + +from opentelemetry.util.types import AttributeValue + +from core.ops.entities.trace_entity import ( + BaseTraceInfo, + DatasetRetrievalTraceInfo, + DraftNodeExecutionTrace, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + OperationType, + PromptGenerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + WorkflowNodeTraceInfo, + WorkflowTraceInfo, +) +from enterprise.telemetry.entities import ( + EnterpriseTelemetryCounter, + EnterpriseTelemetryEvent, + EnterpriseTelemetryHistogram, + EnterpriseTelemetrySpan, + TokenMetricLabels, +) +from enterprise.telemetry.telemetry_log import emit_metric_only_event, emit_telemetry_log + +logger = logging.getLogger(__name__) + + +class EnterpriseOtelTrace: + """Duck-typed enterprise trace handler. + + ``*_trace`` methods emit spans (workflow/node only) or structured logs + (all other events), plus metrics at 100 % accuracy. + """ + + def __init__(self) -> None: + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if exporter is None: + raise RuntimeError("EnterpriseOtelTrace instantiated but exporter is not initialized") + self._exporter = exporter + + def trace(self, trace_info: BaseTraceInfo) -> None: + if isinstance(trace_info, WorkflowTraceInfo): + self._workflow_trace(trace_info) + elif isinstance(trace_info, MessageTraceInfo): + self._message_trace(trace_info) + elif isinstance(trace_info, ToolTraceInfo): + self._tool_trace(trace_info) + elif isinstance(trace_info, DraftNodeExecutionTrace): + self._draft_node_execution_trace(trace_info) + elif isinstance(trace_info, WorkflowNodeTraceInfo): + self._node_execution_trace(trace_info) + elif isinstance(trace_info, ModerationTraceInfo): + self._moderation_trace(trace_info) + elif isinstance(trace_info, SuggestedQuestionTraceInfo): + self._suggested_question_trace(trace_info) + elif isinstance(trace_info, DatasetRetrievalTraceInfo): + self._dataset_retrieval_trace(trace_info) + elif isinstance(trace_info, GenerateNameTraceInfo): + self._generate_name_trace(trace_info) + elif isinstance(trace_info, PromptGenerationTraceInfo): + self._prompt_generation_trace(trace_info) + else: + raise AssertionError("this statment should be unreachable") + + def _common_attrs(self, trace_info: BaseTraceInfo) -> dict[str, Any]: + metadata = self._metadata(trace_info) + tenant_id, app_id, user_id = self._context_ids(trace_info, metadata) + return { + "dify.trace_id": trace_info.resolved_trace_id, + "dify.tenant_id": tenant_id, + "dify.app_id": app_id, + "dify.app.name": metadata.get("app_name"), + "dify.workspace.name": metadata.get("workspace_name"), + "gen_ai.user.id": user_id, + "dify.message.id": trace_info.message_id, + } + + def _metadata(self, trace_info: BaseTraceInfo) -> dict[str, Any]: + return trace_info.metadata + + def _context_ids( + self, + trace_info: BaseTraceInfo, + metadata: dict[str, Any], + ) -> tuple[str | None, str | None, str | None]: + tenant_id = getattr(trace_info, "tenant_id", None) or metadata.get("tenant_id") + app_id = getattr(trace_info, "app_id", None) or metadata.get("app_id") + user_id = getattr(trace_info, "user_id", None) or metadata.get("user_id") + return tenant_id, app_id, user_id + + def _labels(self, **values: AttributeValue) -> dict[str, AttributeValue]: + return dict(values) + + def _safe_payload_value(self, value: Any) -> str | dict[str, Any] | list[object] | None: + if isinstance(value, str): + return value + if isinstance(value, dict): + return cast(dict[str, Any], value) + if isinstance(value, list): + items: list[object] = [] + for item in cast(list[object], value): + items.append(item) + return items + return None + + def _content_or_ref(self, value: Any, ref: str) -> Any: + if self._exporter.include_content: + return self._maybe_json(value) + return ref + + def _maybe_json(self, value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + try: + return json.dumps(value, default=str) + except (TypeError, ValueError): + return str(value) + + # ------------------------------------------------------------------ + # SPAN-emitting handlers (workflow, node execution, draft node) + # ------------------------------------------------------------------ + + def _workflow_trace(self, info: WorkflowTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + # -- Span attrs: identity + structure + status + timing + gen_ai scalars -- + span_attrs: dict[str, Any] = { + "dify.trace_id": info.resolved_trace_id, + "dify.tenant_id": tenant_id, + "dify.app_id": app_id, + "dify.workflow.id": info.workflow_id, + "dify.workflow.run_id": info.workflow_run_id, + "dify.workflow.status": info.workflow_run_status, + "dify.workflow.error": info.error, + "dify.workflow.elapsed_time": info.workflow_run_elapsed_time, + "dify.invoke_from": metadata.get("triggered_from"), + "dify.conversation.id": info.conversation_id, + "dify.message.id": info.message_id, + "dify.invoked_by": info.invoked_by, + "gen_ai.usage.total_tokens": info.total_tokens, + "gen_ai.user.id": user_id, + } + + trace_correlation_override, parent_span_id_source = info.resolved_parent_context + + parent_ctx = metadata.get("parent_trace_context") + if isinstance(parent_ctx, dict): + parent_ctx_dict = cast(dict[str, Any], parent_ctx) + span_attrs["dify.parent.trace_id"] = parent_ctx_dict.get("trace_id") + span_attrs["dify.parent.node.execution_id"] = parent_ctx_dict.get("parent_node_execution_id") + span_attrs["dify.parent.workflow.run_id"] = parent_ctx_dict.get("parent_workflow_run_id") + span_attrs["dify.parent.app.id"] = parent_ctx_dict.get("parent_app_id") + + self._exporter.export_span( + EnterpriseTelemetrySpan.WORKFLOW_RUN, + span_attrs, + correlation_id=info.workflow_run_id, + span_id_source=info.workflow_run_id, + start_time=info.start_time, + end_time=info.end_time, + trace_correlation_override=trace_correlation_override, + parent_span_id_source=parent_span_id_source, + ) + + # -- Companion log: ALL attrs (span + detail) for full picture -- + log_attrs: dict[str, Any] = {**span_attrs} + log_attrs.update( + { + "dify.app.name": metadata.get("app_name"), + "dify.workspace.name": metadata.get("workspace_name"), + "gen_ai.user.id": user_id, + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.workflow.version": info.workflow_run_version, + } + ) + + ref = f"ref:workflow_run_id={info.workflow_run_id}" + log_attrs["dify.workflow.inputs"] = self._content_or_ref(info.workflow_run_inputs, ref) + log_attrs["dify.workflow.outputs"] = self._content_or_ref(info.workflow_run_outputs, ref) + log_attrs["dify.workflow.query"] = self._content_or_ref(info.query, ref) + + emit_telemetry_log( + event_name=EnterpriseTelemetryEvent.WORKFLOW_RUN, + attributes=log_attrs, + signal="span_detail", + trace_id_source=info.workflow_run_id, + span_id_source=info.workflow_run_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + # -- Metrics -- + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + token_labels = TokenMetricLabels( + tenant_id=tenant_id or "", + app_id=app_id or "", + operation_type=OperationType.WORKFLOW, + model_provider="", + model_name="", + node_type="", + ).to_dict() + self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels) + if info.prompt_tokens is not None and info.prompt_tokens > 0: + self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels) + if info.completion_tokens is not None and info.completion_tokens > 0: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels + ) + invoke_from = metadata.get("triggered_from", "") + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="workflow", + status=info.workflow_run_status, + invoke_from=invoke_from, + ), + ) + # Prefer wall-clock timestamps over the elapsed_time field: elapsed_time defaults + # to 0 in the DB and can be stale if the Celery write races with the trace task. + # start_time = workflow_run.created_at, end_time = workflow_run.finished_at. + if info.start_time and info.end_time: + workflow_duration = (info.end_time - info.start_time).total_seconds() + elif info.workflow_run_elapsed_time: + workflow_duration = float(info.workflow_run_elapsed_time) + else: + workflow_duration = 0.0 + self._exporter.record_histogram( + EnterpriseTelemetryHistogram.WORKFLOW_DURATION, + workflow_duration, + self._labels( + **labels, + status=info.workflow_run_status, + ), + ) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + **labels, + type="workflow", + ), + ) + + def _node_execution_trace(self, info: WorkflowNodeTraceInfo) -> None: + self._emit_node_execution_trace(info, EnterpriseTelemetrySpan.NODE_EXECUTION, "node") + + def _draft_node_execution_trace(self, info: DraftNodeExecutionTrace) -> None: + self._emit_node_execution_trace( + info, + EnterpriseTelemetrySpan.DRAFT_NODE_EXECUTION, + "draft_node", + correlation_id_override=info.node_execution_id, + trace_correlation_override_param=info.workflow_run_id, + ) + + def _emit_node_execution_trace( + self, + info: WorkflowNodeTraceInfo, + span_name: EnterpriseTelemetrySpan, + request_type: str, + correlation_id_override: str | None = None, + trace_correlation_override_param: str | None = None, + ) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + # -- Span attrs: identity + structure + status + timing + gen_ai scalars -- + span_attrs: dict[str, Any] = { + "dify.trace_id": info.resolved_trace_id, + "dify.tenant_id": tenant_id, + "dify.app_id": app_id, + "dify.workflow.id": info.workflow_id, + "dify.workflow.run_id": info.workflow_run_id, + "dify.message.id": info.message_id, + "dify.conversation.id": metadata.get("conversation_id"), + "dify.node.execution_id": info.node_execution_id, + "dify.node.id": info.node_id, + "dify.node.type": info.node_type, + "dify.node.title": info.title, + "dify.node.status": info.status, + "dify.node.error": info.error, + "dify.node.elapsed_time": info.elapsed_time, + "dify.node.index": info.index, + "dify.node.predecessor_node_id": info.predecessor_node_id, + "dify.node.iteration_id": info.iteration_id, + "dify.node.loop_id": info.loop_id, + "dify.node.parallel_id": info.parallel_id, + "dify.node.invoked_by": info.invoked_by, + "gen_ai.usage.input_tokens": info.prompt_tokens, + "gen_ai.usage.output_tokens": info.completion_tokens, + "gen_ai.usage.total_tokens": info.total_tokens, + "gen_ai.request.model": info.model_name, + "gen_ai.provider.name": info.model_provider, + "gen_ai.user.id": user_id, + } + + resolved_override, _ = info.resolved_parent_context + trace_correlation_override = trace_correlation_override_param or resolved_override + + effective_correlation_id = correlation_id_override or info.workflow_run_id + self._exporter.export_span( + span_name, + span_attrs, + correlation_id=effective_correlation_id, + span_id_source=info.node_execution_id, + start_time=info.start_time, + end_time=info.end_time, + trace_correlation_override=trace_correlation_override, + ) + + # -- Companion log: ALL attrs (span + detail) -- + log_attrs: dict[str, Any] = {**span_attrs} + log_attrs.update( + { + "dify.app.name": metadata.get("app_name"), + "dify.workspace.name": metadata.get("workspace_name"), + "dify.invoke_from": metadata.get("invoke_from"), + "gen_ai.user.id": user_id, + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.node.total_price": info.total_price, + "dify.node.currency": info.currency, + "gen_ai.provider.name": info.model_provider, + "gen_ai.request.model": info.model_name, + "gen_ai.tool.name": info.tool_name, + "dify.node.iteration_index": info.iteration_index, + "dify.node.loop_index": info.loop_index, + "dify.plugin.name": metadata.get("plugin_name"), + "dify.credential.name": metadata.get("credential_name"), + "dify.credential.id": metadata.get("credential_id"), + "dify.dataset.ids": self._maybe_json(metadata.get("dataset_ids")), + "dify.dataset.names": self._maybe_json(metadata.get("dataset_names")), + } + ) + + ref = f"ref:node_execution_id={info.node_execution_id}" + log_attrs["dify.node.inputs"] = self._content_or_ref(info.node_inputs, ref) + log_attrs["dify.node.outputs"] = self._content_or_ref(info.node_outputs, ref) + log_attrs["dify.node.process_data"] = self._content_or_ref(info.process_data, ref) + + emit_telemetry_log( + event_name=span_name.value, + attributes=log_attrs, + signal="span_detail", + trace_id_source=info.workflow_run_id, + span_id_source=info.node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + # -- Metrics -- + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + node_type=info.node_type, + model_provider=info.model_provider or "", + ) + if info.total_tokens: + token_labels = TokenMetricLabels( + tenant_id=tenant_id or "", + app_id=app_id or "", + operation_type=OperationType.NODE_EXECUTION, + model_provider=info.model_provider or "", + model_name=info.model_name or "", + node_type=info.node_type, + ).to_dict() + self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels) + if info.prompt_tokens is not None and info.prompt_tokens > 0: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels + ) + if info.completion_tokens is not None and info.completion_tokens > 0: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type=request_type, + status=info.status, + model_name=info.model_name or "", + ), + ) + duration_labels = dict(labels) + duration_labels["model_name"] = info.model_name or "" + plugin_name = metadata.get("plugin_name") + if plugin_name and info.node_type in {"tool", "knowledge-retrieval"}: + duration_labels["plugin_name"] = plugin_name + self._exporter.record_histogram(EnterpriseTelemetryHistogram.NODE_DURATION, info.elapsed_time, duration_labels) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + **labels, + type=request_type, + model_name=info.model_name or "", + ), + ) + + # ------------------------------------------------------------------ + # METRIC-ONLY handlers (structured log + counters/histograms) + # ------------------------------------------------------------------ + + def _message_trace(self, info: MessageTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs.update( + { + "dify.invoke_from": metadata.get("from_source"), + "dify.conversation.id": metadata.get("conversation_id"), + "dify.conversation.mode": info.conversation_mode, + "gen_ai.provider.name": metadata.get("ls_provider"), + "gen_ai.request.model": metadata.get("ls_model_name"), + "gen_ai.usage.input_tokens": info.message_tokens, + "gen_ai.usage.output_tokens": info.answer_tokens, + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.message.status": metadata.get("status"), + "dify.message.error": info.error, + "dify.message.from_source": metadata.get("from_source"), + "dify.message.from_end_user_id": metadata.get("from_end_user_id"), + "dify.message.from_account_id": metadata.get("from_account_id"), + "dify.streaming": info.is_streaming_request, + "dify.message.time_to_first_token": info.gen_ai_server_time_to_first_token, + "dify.message.streaming_duration": info.llm_streaming_time_to_generate, + "dify.workflow.run_id": metadata.get("workflow_run_id"), + } + ) + + if info.start_time and info.end_time: + attrs["dify.message.duration"] = (info.end_time - info.start_time).total_seconds() + + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + ref = f"ref:message_id={info.message_id}" + inputs = self._safe_payload_value(info.inputs) + outputs = self._safe_payload_value(info.outputs) + attrs["dify.message.inputs"] = self._content_or_ref(inputs, ref) + attrs["dify.message.outputs"] = self._content_or_ref(outputs, ref) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.MESSAGE_RUN, + attributes=attrs, + trace_id_source=metadata.get("workflow_run_id") or (str(info.message_id) if info.message_id else None), + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + model_provider=metadata.get("ls_provider") or "", + model_name=metadata.get("ls_model_name") or "", + ) + token_labels = TokenMetricLabels( + tenant_id=tenant_id or "", + app_id=app_id or "", + operation_type=OperationType.MESSAGE, + model_provider=metadata.get("ls_provider") or "", + model_name=metadata.get("ls_model_name") or "", + node_type="", + ).to_dict() + self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels) + if info.message_tokens > 0: + self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.message_tokens, token_labels) + if info.answer_tokens > 0: + self._exporter.increment_counter(EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.answer_tokens, token_labels) + invoke_from = metadata.get("from_source", "") + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="message", + status=metadata.get("status", ""), + invoke_from=invoke_from, + ), + ) + + if info.start_time and info.end_time: + duration = (info.end_time - info.start_time).total_seconds() + self._exporter.record_histogram(EnterpriseTelemetryHistogram.MESSAGE_DURATION, duration, labels) + + if info.gen_ai_server_time_to_first_token is not None: + self._exporter.record_histogram( + EnterpriseTelemetryHistogram.MESSAGE_TTFT, info.gen_ai_server_time_to_first_token, labels + ) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + **labels, + type="message", + ), + ) + + def _tool_trace(self, info: ToolTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs.update( + { + "dify.tool.name": info.tool_name, + "dify.tool.duration": float(info.time_cost), + "dify.tool.status": "failed" if info.error else "succeeded", + "dify.tool.error": info.error, + "dify.workflow.run_id": metadata.get("workflow_run_id"), + } + ) + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + ref = f"ref:message_id={info.message_id}" + attrs["dify.tool.inputs"] = self._content_or_ref(info.tool_inputs, ref) + attrs["dify.tool.outputs"] = self._content_or_ref(info.tool_outputs, ref) + attrs["dify.tool.parameters"] = self._content_or_ref(info.tool_parameters, ref) + attrs["dify.tool.config"] = self._content_or_ref(info.tool_config, ref) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.TOOL_EXECUTION, + attributes=attrs, + trace_id_source=info.resolved_trace_id, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + tool_name=info.tool_name, + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="tool", + ), + ) + self._exporter.record_histogram(EnterpriseTelemetryHistogram.TOOL_DURATION, float(info.time_cost), labels) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + **labels, + type="tool", + ), + ) + + def _moderation_trace(self, info: ModerationTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs.update( + { + "dify.moderation.flagged": info.flagged, + "dify.moderation.action": info.action, + "dify.moderation.preset_response": info.preset_response, + "dify.moderation.type": metadata.get("moderation_type", "input"), + "dify.moderation.categories": self._maybe_json(metadata.get("moderation_categories", [])), + "dify.workflow.run_id": metadata.get("workflow_run_id"), + } + ) + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + attrs["dify.moderation.query"] = self._content_or_ref( + info.query, + f"ref:message_id={info.message_id}", + ) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.MODERATION_CHECK, + attributes=attrs, + trace_id_source=info.resolved_trace_id, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="moderation", + ), + ) + + def _suggested_question_trace(self, info: SuggestedQuestionTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + duration: float | None = None + if info.start_time is not None and info.end_time is not None: + duration = (info.end_time - info.start_time).total_seconds() + error = info.error or (info.metadata.get("error") if info.metadata else None) + status = "failed" if error else (info.status or "succeeded") + attrs.update( + { + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.suggested_question.status": status, + "dify.suggested_question.error": error, + "dify.suggested_question.duration": duration, + "gen_ai.provider.name": info.model_provider, + "gen_ai.request.model": info.model_id, + "dify.suggested_question.count": len(info.suggested_question), + "dify.workflow.run_id": metadata.get("workflow_run_id"), + } + ) + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + attrs["dify.suggested_question.questions"] = self._content_or_ref( + info.suggested_question, + f"ref:message_id={info.message_id}", + ) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.SUGGESTED_QUESTION_GENERATION, + attributes=attrs, + trace_id_source=info.resolved_trace_id, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="suggested_question", + model_provider=info.model_provider or "", + model_name=info.model_id or "", + ), + ) + + def _dataset_retrieval_trace(self, info: DatasetRetrievalTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs["dify.retrieval.error"] = info.error + attrs["dify.retrieval.status"] = "failed" if info.error else "succeeded" + if info.start_time and info.end_time: + attrs["dify.retrieval.duration"] = (info.end_time - info.start_time).total_seconds() + attrs["dify.workflow.run_id"] = metadata.get("workflow_run_id") + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + docs: list[dict[str, Any]] = [] + documents_any: Any = info.documents + documents_list: list[Any] = cast(list[Any], documents_any) if isinstance(documents_any, list) else [] + for entry in documents_list: + if isinstance(entry, dict): + entry_dict: dict[str, Any] = cast(dict[str, Any], entry) + docs.append(entry_dict) + dataset_ids: list[str] = [] + dataset_names: list[str] = [] + structured_docs: list[dict[str, Any]] = [] + for doc in docs: + meta_raw = doc.get("metadata") + meta: dict[str, Any] = cast(dict[str, Any], meta_raw) if isinstance(meta_raw, dict) else {} + did = meta.get("dataset_id") + dname = meta.get("dataset_name") + if did and did not in dataset_ids: + dataset_ids.append(did) + if dname and dname not in dataset_names: + dataset_names.append(dname) + structured_docs.append( + { + "dataset_id": did, + "document_id": meta.get("document_id"), + "segment_id": meta.get("segment_id"), + "score": meta.get("score"), + } + ) + + attrs["dify.dataset.id"] = self._maybe_json(dataset_ids) + attrs["dify.dataset.name"] = self._maybe_json(dataset_names) + attrs["dify.retrieval.document_count"] = len(docs) + + embedding_models_raw: Any = metadata.get("embedding_models") + embedding_models: dict[str, Any] = ( + cast(dict[str, Any], embedding_models_raw) if isinstance(embedding_models_raw, dict) else {} + ) + if embedding_models: + providers: list[str] = [] + models: list[str] = [] + for ds_info in embedding_models.values(): + if isinstance(ds_info, dict): + ds_info_dict: dict[str, Any] = cast(dict[str, Any], ds_info) + p = ds_info_dict.get("embedding_model_provider", "") + m = ds_info_dict.get("embedding_model", "") + if p and p not in providers: + providers.append(p) + if m and m not in models: + models.append(m) + attrs["dify.dataset.embedding_providers"] = self._maybe_json(providers) + attrs["dify.dataset.embedding_models"] = self._maybe_json(models) + + # Add rerank model to logs + rerank_provider = metadata.get("rerank_model_provider", "") + rerank_model = metadata.get("rerank_model_name", "") + if rerank_provider or rerank_model: + attrs["dify.retrieval.rerank_provider"] = rerank_provider + attrs["dify.retrieval.rerank_model"] = rerank_model + + ref = f"ref:message_id={info.message_id}" + retrieval_inputs = self._safe_payload_value(info.inputs) + attrs["dify.retrieval.query"] = self._content_or_ref(retrieval_inputs, ref) + attrs["dify.dataset.documents"] = self._content_or_ref(structured_docs, ref) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.DATASET_RETRIEVAL, + attributes=attrs, + trace_id_source=metadata.get("workflow_run_id") or (str(info.message_id) if info.message_id else None), + span_id_source=node_execution_id or (str(info.message_id) if info.message_id else None), + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="dataset_retrieval", + ), + ) + + for did in dataset_ids: + # Get embedding model for this specific dataset + ds_embedding_info = embedding_models.get(did, {}) + embedding_provider = ds_embedding_info.get("embedding_model_provider", "") + embedding_model = ds_embedding_info.get("embedding_model", "") + + # Get rerank model (same for all datasets in this retrieval) + rerank_provider = metadata.get("rerank_model_provider", "") + rerank_model = metadata.get("rerank_model_name", "") + + self._exporter.increment_counter( + EnterpriseTelemetryCounter.DATASET_RETRIEVALS, + 1, + self._labels( + **labels, + dataset_id=did, + embedding_model_provider=embedding_provider, + embedding_model=embedding_model, + rerank_model_provider=rerank_provider, + rerank_model=rerank_model, + ), + ) + + def _generate_name_trace(self, info: GenerateNameTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs["dify.conversation.id"] = info.conversation_id + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + duration: float | None = None + if info.start_time is not None and info.end_time is not None: + duration = (info.end_time - info.start_time).total_seconds() + error: str | None = metadata.get("error") if metadata else None + status = "failed" if error else "succeeded" + attrs["dify.generate_name.duration"] = duration + attrs["dify.generate_name.status"] = status + attrs["dify.generate_name.error"] = error + + ref = f"ref:conversation_id={info.conversation_id}" + inputs = self._safe_payload_value(info.inputs) + outputs = self._safe_payload_value(info.outputs) + attrs["dify.generate_name.inputs"] = self._content_or_ref(inputs, ref) + attrs["dify.generate_name.outputs"] = self._content_or_ref(outputs, ref) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.GENERATE_NAME_EXECUTION, + attributes=attrs, + trace_id_source=info.resolved_trace_id, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="generate_name", + ), + ) + + def _prompt_generation_trace(self, info: PromptGenerationTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = { + "dify.trace_id": info.resolved_trace_id, + "dify.tenant_id": tenant_id, + "gen_ai.user.id": user_id, + "dify.app_id": app_id or "", + "dify.app.name": metadata.get("app_name"), + "dify.workspace.name": metadata.get("workspace_name"), + "dify.prompt_generation.operation_type": info.operation_type, + "gen_ai.provider.name": info.model_provider, + "gen_ai.request.model": info.model_name, + "gen_ai.usage.input_tokens": info.prompt_tokens, + "gen_ai.usage.output_tokens": info.completion_tokens, + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.prompt_generation.duration": info.latency, + "dify.prompt_generation.status": "failed" if info.error else "succeeded", + "dify.prompt_generation.error": info.error, + } + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + if info.total_price is not None: + attrs["dify.prompt_generation.total_price"] = info.total_price + attrs["dify.prompt_generation.currency"] = info.currency + + ref = f"ref:trace_id={info.trace_id}" + outputs = self._safe_payload_value(info.outputs) + attrs["dify.prompt_generation.instruction"] = self._content_or_ref(info.instruction, ref) + attrs["dify.prompt_generation.output"] = self._content_or_ref(outputs, ref) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.PROMPT_GENERATION_EXECUTION, + attributes=attrs, + trace_id_source=info.resolved_trace_id, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + token_labels = TokenMetricLabels( + tenant_id=tenant_id or "", + app_id=app_id or "", + operation_type=info.operation_type, + model_provider=info.model_provider, + model_name=info.model_name, + node_type="", + ).to_dict() + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + operation_type=info.operation_type, + model_provider=info.model_provider, + model_name=info.model_name, + ) + + self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels) + if info.prompt_tokens > 0: + self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels) + if info.completion_tokens > 0: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels + ) + + prompt_status = "failed" if info.error else "succeeded" + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="prompt_generation", + status=prompt_status, + ), + ) + + self._exporter.record_histogram( + EnterpriseTelemetryHistogram.PROMPT_GENERATION_DURATION, + info.latency, + labels, + ) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + **labels, + type="prompt_generation", + ), + ) diff --git a/api/enterprise/telemetry/entities/__init__.py b/api/enterprise/telemetry/entities/__init__.py new file mode 100644 index 00000000000..4a9bd3dbf80 --- /dev/null +++ b/api/enterprise/telemetry/entities/__init__.py @@ -0,0 +1,121 @@ +from enum import StrEnum +from typing import cast + +from opentelemetry.util.types import AttributeValue +from pydantic import BaseModel, ConfigDict + + +class EnterpriseTelemetrySpan(StrEnum): + WORKFLOW_RUN = "dify.workflow.run" + NODE_EXECUTION = "dify.node.execution" + DRAFT_NODE_EXECUTION = "dify.node.execution.draft" + + +class EnterpriseTelemetryEvent(StrEnum): + """Event names for enterprise telemetry logs.""" + + APP_CREATED = "dify.app.created" + APP_UPDATED = "dify.app.updated" + APP_DELETED = "dify.app.deleted" + FEEDBACK_CREATED = "dify.feedback.created" + WORKFLOW_RUN = "dify.workflow.run" + MESSAGE_RUN = "dify.message.run" + TOOL_EXECUTION = "dify.tool.execution" + MODERATION_CHECK = "dify.moderation.check" + SUGGESTED_QUESTION_GENERATION = "dify.suggested_question.generation" + DATASET_RETRIEVAL = "dify.dataset.retrieval" + GENERATE_NAME_EXECUTION = "dify.generate_name.execution" + PROMPT_GENERATION_EXECUTION = "dify.prompt_generation.execution" + REHYDRATION_FAILED = "dify.telemetry.rehydration_failed" + + +class EnterpriseTelemetryCounter(StrEnum): + TOKENS = "tokens" + INPUT_TOKENS = "input_tokens" + OUTPUT_TOKENS = "output_tokens" + REQUESTS = "requests" + ERRORS = "errors" + FEEDBACK = "feedback" + DATASET_RETRIEVALS = "dataset_retrievals" + APP_CREATED = "app_created" + APP_UPDATED = "app_updated" + APP_DELETED = "app_deleted" + + +class EnterpriseTelemetryHistogram(StrEnum): + WORKFLOW_DURATION = "workflow_duration" + NODE_DURATION = "node_duration" + MESSAGE_DURATION = "message_duration" + MESSAGE_TTFT = "message_ttft" + TOOL_DURATION = "tool_duration" + PROMPT_GENERATION_DURATION = "prompt_generation_duration" + + +class TokenMetricLabels(BaseModel): + """Unified label structure for all dify.token.* metrics. + + All token counters (dify.tokens.input, dify.tokens.output, dify.tokens.total) MUST + use this exact label set to ensure consistent filtering and aggregation across + different operation types. + + Attributes: + tenant_id: Tenant identifier. + app_id: Application identifier. + operation_type: Source of token usage (workflow | node_execution | message | + rule_generate | code_generate | structured_output | instruction_modify). + model_provider: LLM provider name. Empty string if not applicable (e.g., workflow-level). + model_name: LLM model name. Empty string if not applicable (e.g., workflow-level). + node_type: Workflow node type. Empty string unless operation_type=node_execution. + + Usage: + labels = TokenMetricLabels( + tenant_id="tenant-123", + app_id="app-456", + operation_type=OperationType.WORKFLOW, + model_provider="", + model_name="", + node_type="", + ) + exporter.increment_counter( + EnterpriseTelemetryCounter.INPUT_TOKENS, + 100, + labels.to_dict() + ) + + Design rationale: + Without this unified structure, tokens get double-counted when querying totals + because workflow.total_tokens is already the sum of all node tokens. The + operation_type label allows filtering to separate workflow-level aggregates from + node-level detail, while keeping the same label cardinality for consistent queries. + """ + + tenant_id: str + app_id: str + operation_type: str + model_provider: str + model_name: str + node_type: str + + model_config = ConfigDict(extra="forbid", frozen=True) + + def to_dict(self) -> dict[str, AttributeValue]: + return cast( + dict[str, AttributeValue], + { + "tenant_id": self.tenant_id, + "app_id": self.app_id, + "operation_type": self.operation_type, + "model_provider": self.model_provider, + "model_name": self.model_name, + "node_type": self.node_type, + }, + ) + + +__all__ = [ + "EnterpriseTelemetryCounter", + "EnterpriseTelemetryEvent", + "EnterpriseTelemetryHistogram", + "EnterpriseTelemetrySpan", + "TokenMetricLabels", +] diff --git a/api/enterprise/telemetry/event_handlers.py b/api/enterprise/telemetry/event_handlers.py new file mode 100644 index 00000000000..d8b4208c697 --- /dev/null +++ b/api/enterprise/telemetry/event_handlers.py @@ -0,0 +1,72 @@ +"""Blinker signal handlers for enterprise telemetry. + +Registered at import time via ``@signal.connect`` decorators. +Import must happen during ``ext_enterprise_telemetry.init_app()`` to +ensure handlers fire. Each handler delegates to ``core.telemetry.gateway`` +which handles routing, EE-gating, and dispatch. + +All handlers are best-effort: exceptions are caught and logged so that +telemetry failures never break user-facing operations. +""" + +from __future__ import annotations + +import logging + +from events.app_event import app_was_created, app_was_deleted, app_was_updated + +logger = logging.getLogger(__name__) + +__all__ = [ + "_handle_app_created", + "_handle_app_deleted", + "_handle_app_updated", +] + + +@app_was_created.connect +def _handle_app_created(sender: object, **kwargs: object) -> None: + try: + from core.telemetry.gateway import emit as gateway_emit + from enterprise.telemetry.contracts import TelemetryCase + + gateway_emit( + case=TelemetryCase.APP_CREATED, + context={"tenant_id": str(getattr(sender, "tenant_id", "") or "")}, + payload={ + "app_id": getattr(sender, "id", None), + "mode": getattr(sender, "mode", None), + }, + ) + except Exception: + logger.warning("Failed to emit app_created telemetry", exc_info=True) + + +@app_was_updated.connect +def _handle_app_updated(sender: object, **kwargs: object) -> None: + try: + from core.telemetry.gateway import emit as gateway_emit + from enterprise.telemetry.contracts import TelemetryCase + + gateway_emit( + case=TelemetryCase.APP_UPDATED, + context={"tenant_id": str(getattr(sender, "tenant_id", "") or "")}, + payload={"app_id": getattr(sender, "id", None)}, + ) + except Exception: + logger.warning("Failed to emit app_updated telemetry", exc_info=True) + + +@app_was_deleted.connect +def _handle_app_deleted(sender: object, **kwargs: object) -> None: + try: + from core.telemetry.gateway import emit as gateway_emit + from enterprise.telemetry.contracts import TelemetryCase + + gateway_emit( + case=TelemetryCase.APP_DELETED, + context={"tenant_id": str(getattr(sender, "tenant_id", "") or "")}, + payload={"app_id": getattr(sender, "id", None)}, + ) + except Exception: + logger.warning("Failed to emit app_deleted telemetry", exc_info=True) diff --git a/api/enterprise/telemetry/exporter.py b/api/enterprise/telemetry/exporter.py new file mode 100644 index 00000000000..80959514f28 --- /dev/null +++ b/api/enterprise/telemetry/exporter.py @@ -0,0 +1,286 @@ +"""Enterprise OTEL exporter — shared by EnterpriseOtelTrace, event handlers, and direct instrumentation. + +Uses dedicated TracerProvider and MeterProvider instances (configurable sampling, +independent from ext_otel.py infrastructure). + +Initialized once during Flask extension init (single-threaded via ext_enterprise_telemetry.py). +Accessed via ``ext_enterprise_telemetry.get_enterprise_exporter()`` from any thread/process. +""" + +import logging +import socket +import uuid +from datetime import UTC, datetime +from typing import Any, cast + +from opentelemetry import trace +from opentelemetry.baggage import get_all +from opentelemetry.baggage.propagation import W3CBaggagePropagator +from opentelemetry.context import Context +from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter as GRPCMetricExporter +from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter +from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as HTTPMetricExporter +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter +from opentelemetry.sdk.metrics import MeterProvider +from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor +from opentelemetry.sdk.trace.sampling import ParentBasedTraceIdRatio +from opentelemetry.semconv._incubating.attributes.host_attributes import ( # type: ignore[import-untyped] + HOST_NAME, +) +from opentelemetry.semconv.attributes import service_attributes +from opentelemetry.trace import SpanContext, TraceFlags +from opentelemetry.util.types import Attributes, AttributeValue + +from configs import dify_config +from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryHistogram +from enterprise.telemetry.id_generator import ( + CorrelationIdGenerator, + compute_deterministic_span_id, + set_correlation_id, + set_span_id_source, +) + +logger = logging.getLogger(__name__) + + +def is_enterprise_telemetry_enabled() -> bool: + return bool(dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED) + + +def _parse_otlp_headers(raw: str) -> dict[str, str]: + ctx = W3CBaggagePropagator().extract({"baggage": raw}) + return {k: v for k, v in get_all(ctx).items() if isinstance(v, str)} + + +def _datetime_to_ns(dt: datetime) -> int: + """Convert a datetime to nanoseconds since epoch (OTEL convention).""" + # Ensure we always interpret naive datetimes as UTC instead of local time. + if dt.tzinfo is None: + dt = dt.replace(tzinfo=UTC) + else: + dt = dt.astimezone(UTC) + return int(dt.timestamp() * 1_000_000_000) + + +class _ExporterFactory: + def __init__(self, protocol: str, endpoint: str, headers: dict[str, str], insecure: bool): + self._protocol = protocol + self._endpoint = endpoint + self._headers = headers + self._grpc_headers = tuple(headers.items()) if headers else None + self._http_headers = headers or None + self._insecure = insecure + + def create_trace_exporter(self) -> HTTPSpanExporter | GRPCSpanExporter: + if self._protocol == "grpc": + return GRPCSpanExporter( + endpoint=self._endpoint or None, + headers=self._grpc_headers, + insecure=self._insecure, + ) + trace_endpoint = f"{self._endpoint}/v1/traces" if self._endpoint else "" + return HTTPSpanExporter(endpoint=trace_endpoint or None, headers=self._http_headers) + + def create_metric_exporter(self) -> HTTPMetricExporter | GRPCMetricExporter: + if self._protocol == "grpc": + return GRPCMetricExporter( + endpoint=self._endpoint or None, + headers=self._grpc_headers, + insecure=self._insecure, + ) + metric_endpoint = f"{self._endpoint}/v1/metrics" if self._endpoint else "" + return HTTPMetricExporter(endpoint=metric_endpoint or None, headers=self._http_headers) + + +class EnterpriseExporter: + """Shared OTEL exporter for all enterprise telemetry. + + ``export_span`` creates spans with optional real timestamps, deterministic + span/trace IDs, and cross-workflow parent linking. + ``increment_counter`` / ``record_histogram`` emit OTEL metrics at 100% accuracy. + """ + + def __init__(self, config: object) -> None: + endpoint: str = getattr(config, "ENTERPRISE_OTLP_ENDPOINT", "") + headers_raw: str = getattr(config, "ENTERPRISE_OTLP_HEADERS", "") + protocol: str = (getattr(config, "ENTERPRISE_OTLP_PROTOCOL", "http") or "http").lower() + service_name: str = getattr(config, "ENTERPRISE_SERVICE_NAME", "dify") + sampling_rate: float = getattr(config, "ENTERPRISE_OTEL_SAMPLING_RATE", 1.0) + self.include_content: bool = getattr(config, "ENTERPRISE_INCLUDE_CONTENT", True) + api_key: str = getattr(config, "ENTERPRISE_OTLP_API_KEY", "") + + # Auto-detect TLS: https:// uses secure, everything else is insecure + insecure = not endpoint.startswith("https://") + + resource = Resource( + attributes={ + service_attributes.SERVICE_NAME: service_name, + HOST_NAME: socket.gethostname(), + } + ) + sampler = ParentBasedTraceIdRatio(sampling_rate) + id_generator = CorrelationIdGenerator() + self._tracer_provider = TracerProvider(resource=resource, sampler=sampler, id_generator=id_generator) + + headers = _parse_otlp_headers(headers_raw) + if api_key: + if "authorization" in headers: + logger.warning( + "ENTERPRISE_OTLP_API_KEY is set but ENTERPRISE_OTLP_HEADERS also contains " + "'authorization'; the API key will take precedence." + ) + headers["authorization"] = f"Bearer {api_key}" + factory = _ExporterFactory(protocol, endpoint, headers, insecure=insecure) + + trace_exporter = factory.create_trace_exporter() + self._tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter)) + self._tracer = self._tracer_provider.get_tracer("dify.enterprise") + + metric_exporter = factory.create_metric_exporter() + self._meter_provider = MeterProvider( + resource=resource, + metric_readers=[PeriodicExportingMetricReader(metric_exporter)], + ) + meter = self._meter_provider.get_meter("dify.enterprise") + self._counters = { + EnterpriseTelemetryCounter.TOKENS: meter.create_counter("dify.tokens.total", unit="{token}"), + EnterpriseTelemetryCounter.INPUT_TOKENS: meter.create_counter("dify.tokens.input", unit="{token}"), + EnterpriseTelemetryCounter.OUTPUT_TOKENS: meter.create_counter("dify.tokens.output", unit="{token}"), + EnterpriseTelemetryCounter.REQUESTS: meter.create_counter("dify.requests.total", unit="{request}"), + EnterpriseTelemetryCounter.ERRORS: meter.create_counter("dify.errors.total", unit="{error}"), + EnterpriseTelemetryCounter.FEEDBACK: meter.create_counter("dify.feedback.total", unit="{feedback}"), + EnterpriseTelemetryCounter.DATASET_RETRIEVALS: meter.create_counter( + "dify.dataset.retrievals.total", unit="{retrieval}" + ), + EnterpriseTelemetryCounter.APP_CREATED: meter.create_counter("dify.app.created.total", unit="{app}"), + EnterpriseTelemetryCounter.APP_UPDATED: meter.create_counter("dify.app.updated.total", unit="{app}"), + EnterpriseTelemetryCounter.APP_DELETED: meter.create_counter("dify.app.deleted.total", unit="{app}"), + } + self._histograms = { + EnterpriseTelemetryHistogram.WORKFLOW_DURATION: meter.create_histogram("dify.workflow.duration", unit="s"), + EnterpriseTelemetryHistogram.NODE_DURATION: meter.create_histogram("dify.node.duration", unit="s"), + EnterpriseTelemetryHistogram.MESSAGE_DURATION: meter.create_histogram("dify.message.duration", unit="s"), + EnterpriseTelemetryHistogram.MESSAGE_TTFT: meter.create_histogram( + "dify.message.time_to_first_token", unit="s" + ), + EnterpriseTelemetryHistogram.TOOL_DURATION: meter.create_histogram("dify.tool.duration", unit="s"), + EnterpriseTelemetryHistogram.PROMPT_GENERATION_DURATION: meter.create_histogram( + "dify.prompt_generation.duration", unit="s" + ), + } + + def export_span( + self, + name: str, + attributes: dict[str, Any], + correlation_id: str | None = None, + span_id_source: str | None = None, + start_time: datetime | None = None, + end_time: datetime | None = None, + trace_correlation_override: str | None = None, + parent_span_id_source: str | None = None, + ) -> None: + """Export an OTEL span with optional deterministic IDs and real timestamps. + + Args: + name: Span operation name. + attributes: Span attributes dict. + correlation_id: Source for trace_id derivation (groups spans in one trace). + span_id_source: Source for deterministic span_id (e.g. workflow_run_id or node_execution_id). + start_time: Real span start time. When None, uses current time. + end_time: Real span end time. When None, span ends immediately. + trace_correlation_override: Override trace_id source (for cross-workflow linking). + When set, trace_id is derived from this instead of ``correlation_id``. + parent_span_id_source: Override parent span_id source (for cross-workflow linking). + When set, parent span_id is derived from this value. When None and + ``correlation_id`` is set, parent is the workflow root span. + """ + effective_trace_correlation = trace_correlation_override or correlation_id + set_correlation_id(effective_trace_correlation) + set_span_id_source(span_id_source) + + try: + parent_context: Context | None = None + # A span is the "root" of its correlation group when span_id_source == correlation_id + # (i.e. a workflow root span). All other spans are children. + if parent_span_id_source: + # Cross-workflow linking: parent is an explicit span (e.g. tool node in outer workflow) + parent_span_id = compute_deterministic_span_id(parent_span_id_source) + try: + parent_trace_id = int(uuid.UUID(effective_trace_correlation)) if effective_trace_correlation else 0 + except (ValueError, AttributeError): + logger.warning( + "Invalid trace correlation UUID for cross-workflow link: %s, span=%s", + effective_trace_correlation, + name, + ) + parent_trace_id = 0 + if parent_trace_id: + parent_span_context = SpanContext( + trace_id=parent_trace_id, + span_id=parent_span_id, + is_remote=True, + trace_flags=TraceFlags(TraceFlags.SAMPLED), + ) + parent_context = trace.set_span_in_context(trace.NonRecordingSpan(parent_span_context)) + elif correlation_id and correlation_id != span_id_source: + # Child span: parent is the correlation-group root (workflow root span) + parent_span_id = compute_deterministic_span_id(correlation_id) + try: + parent_trace_id = int(uuid.UUID(effective_trace_correlation or correlation_id)) + except (ValueError, AttributeError): + logger.warning( + "Invalid trace correlation UUID for child span link: %s, span=%s", + effective_trace_correlation or correlation_id, + name, + ) + parent_trace_id = 0 + if parent_trace_id: + parent_span_context = SpanContext( + trace_id=parent_trace_id, + span_id=parent_span_id, + is_remote=True, + trace_flags=TraceFlags(TraceFlags.SAMPLED), + ) + parent_context = trace.set_span_in_context(trace.NonRecordingSpan(parent_span_context)) + + span_start_time = _datetime_to_ns(start_time) if start_time is not None else None + span_end_on_exit = end_time is None + + with self._tracer.start_as_current_span( + name, + context=parent_context, + start_time=span_start_time, + end_on_exit=span_end_on_exit, + ) as span: + for key, value in attributes.items(): + if value is not None: + span.set_attribute(key, value) + if end_time is not None: + span.end(end_time=_datetime_to_ns(end_time)) + except Exception: + logger.exception("Failed to export span %s", name) + finally: + set_correlation_id(None) + set_span_id_source(None) + + def increment_counter( + self, name: EnterpriseTelemetryCounter, value: int, labels: dict[str, AttributeValue] + ) -> None: + counter = self._counters.get(name) + if counter: + counter.add(value, cast(Attributes, labels)) + + def record_histogram( + self, name: EnterpriseTelemetryHistogram, value: float, labels: dict[str, AttributeValue] + ) -> None: + histogram = self._histograms.get(name) + if histogram: + histogram.record(value, cast(Attributes, labels)) + + def shutdown(self) -> None: + self._tracer_provider.shutdown() + self._meter_provider.shutdown() diff --git a/api/enterprise/telemetry/id_generator.py b/api/enterprise/telemetry/id_generator.py new file mode 100644 index 00000000000..f3e5d6d0d66 --- /dev/null +++ b/api/enterprise/telemetry/id_generator.py @@ -0,0 +1,75 @@ +"""Custom OTEL ID Generator for correlation-based trace/span ID derivation. + +Uses contextvars for thread-safe correlation_id -> trace_id mapping. +When a span_id_source is set, the span_id is derived deterministically +from that value, enabling any span to reference another as parent +without depending on span creation order. +""" + +import random +import uuid +from contextvars import ContextVar + +from opentelemetry.sdk.trace.id_generator import IdGenerator + +_correlation_id_context: ContextVar[str | None] = ContextVar("correlation_id", default=None) +_span_id_source_context: ContextVar[str | None] = ContextVar("span_id_source", default=None) + + +def set_correlation_id(correlation_id: str | None) -> None: + _correlation_id_context.set(correlation_id) + + +def get_correlation_id() -> str | None: + return _correlation_id_context.get() + + +def set_span_id_source(source_id: str | None) -> None: + """Set the source for deterministic span_id generation. + + When set, ``generate_span_id()`` derives the span_id from this value + (lower 64 bits of the UUID). Pass the ``workflow_run_id`` for workflow + root spans or ``node_execution_id`` for node spans. + """ + _span_id_source_context.set(source_id) + + +def compute_deterministic_span_id(source_id: str) -> int: + """Derive a deterministic span_id from any UUID string. + + Uses the lower 64 bits of the UUID, guaranteeing non-zero output + (OTEL requires span_id != 0). + """ + span_id = uuid.UUID(source_id).int & ((1 << 64) - 1) + return span_id if span_id != 0 else 1 + + +class CorrelationIdGenerator(IdGenerator): + """ID generator that derives trace_id and optionally span_id from context. + + - trace_id: always derived from correlation_id (groups all spans in one trace) + - span_id: derived from span_id_source when set (enables deterministic + parent-child linking), otherwise random + """ + + def generate_trace_id(self) -> int: + correlation_id = _correlation_id_context.get() + if correlation_id: + try: + return uuid.UUID(correlation_id).int + except (ValueError, AttributeError): + pass + return random.getrandbits(128) + + def generate_span_id(self) -> int: + source = _span_id_source_context.get() + if source: + try: + return compute_deterministic_span_id(source) + except (ValueError, AttributeError): + pass + + span_id = random.getrandbits(64) + while span_id == 0: + span_id = random.getrandbits(64) + return span_id diff --git a/api/enterprise/telemetry/metric_handler.py b/api/enterprise/telemetry/metric_handler.py new file mode 100644 index 00000000000..ffd9a7e2b58 --- /dev/null +++ b/api/enterprise/telemetry/metric_handler.py @@ -0,0 +1,421 @@ +"""Enterprise metric/log event handler. + +This module processes metric and log telemetry events after they've been +dequeued from the enterprise_telemetry Celery queue. It handles case routing, +idempotency checking, and payload rehydration. +""" + +from __future__ import annotations + +import json +import logging +from datetime import UTC, datetime +from typing import Any + +from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope +from extensions.ext_redis import redis_client +from extensions.ext_storage import storage + +logger = logging.getLogger(__name__) + + +class EnterpriseMetricHandler: + """Handler for enterprise metric and log telemetry events. + + Processes envelopes from the enterprise_telemetry queue, routing each + case to the appropriate handler method. Implements idempotency checking + and payload rehydration with fallback. + """ + + def _increment_diagnostic_counter(self, counter_name: str, labels: dict[str, str] | None = None) -> None: + """Increment a diagnostic counter for operational monitoring. + + Args: + counter_name: Name of the counter (e.g., 'processed_total', 'deduped_total'). + labels: Optional labels for the counter. + """ + try: + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if not exporter: + return + + full_counter_name = f"enterprise_telemetry.handler.{counter_name}" + logger.debug( + "Diagnostic counter: %s, labels=%s", + full_counter_name, + labels or {}, + ) + except Exception: + logger.debug("Failed to increment diagnostic counter: %s", counter_name, exc_info=True) + + def handle(self, envelope: TelemetryEnvelope) -> None: + """Main entry point for processing telemetry envelopes. + + Args: + envelope: The telemetry envelope to process. + """ + # Check for duplicate events + if self._is_duplicate(envelope): + logger.debug( + "Skipping duplicate event: tenant_id=%s, event_id=%s", + envelope.tenant_id, + envelope.event_id, + ) + self._increment_diagnostic_counter("deduped_total") + return + + # Route to appropriate handler based on case + case = envelope.case + if case == TelemetryCase.APP_CREATED: + self._on_app_created(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "app_created"}) + elif case == TelemetryCase.APP_UPDATED: + self._on_app_updated(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "app_updated"}) + elif case == TelemetryCase.APP_DELETED: + self._on_app_deleted(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "app_deleted"}) + elif case == TelemetryCase.FEEDBACK_CREATED: + self._on_feedback_created(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "feedback_created"}) + elif case == TelemetryCase.MESSAGE_RUN: + self._on_message_run(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "message_run"}) + elif case == TelemetryCase.TOOL_EXECUTION: + self._on_tool_execution(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "tool_execution"}) + elif case == TelemetryCase.MODERATION_CHECK: + self._on_moderation_check(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "moderation_check"}) + elif case == TelemetryCase.SUGGESTED_QUESTION: + self._on_suggested_question(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "suggested_question"}) + elif case == TelemetryCase.DATASET_RETRIEVAL: + self._on_dataset_retrieval(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "dataset_retrieval"}) + elif case == TelemetryCase.GENERATE_NAME: + self._on_generate_name(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "generate_name"}) + elif case == TelemetryCase.PROMPT_GENERATION: + self._on_prompt_generation(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "prompt_generation"}) + else: + logger.warning( + "Unknown telemetry case: %s (tenant_id=%s, event_id=%s)", + case, + envelope.tenant_id, + envelope.event_id, + ) + + def _is_duplicate(self, envelope: TelemetryEnvelope) -> bool: + """Check if this event has already been processed. + + Uses Redis with TTL for deduplication. Returns True if duplicate, + False if first time seeing this event. + + Args: + envelope: The telemetry envelope to check. + + Returns: + True if this event_id has been seen before, False otherwise. + """ + dedup_key = f"telemetry:dedup:{envelope.tenant_id}:{envelope.event_id}" + + try: + # Atomic set-if-not-exists with 1h TTL + # Returns True if key was set (first time), None if already exists (duplicate) + was_set = redis_client.set(dedup_key, b"1", nx=True, ex=3600) + return was_set is None + except Exception: + # Fail open: if Redis is unavailable, process the event + # (prefer occasional duplicate over lost data) + logger.warning( + "Redis unavailable for deduplication check, processing event anyway: %s", + envelope.event_id, + exc_info=True, + ) + return False + + def _rehydrate(self, envelope: TelemetryEnvelope) -> dict[str, Any]: + """Rehydrate payload from storage reference or inline data. + + If the envelope payload is empty and metadata contains a + ``payload_ref``, the full payload is loaded from object storage + (where the gateway wrote it as JSON). When both the inline + payload and storage resolution fail, a degraded-event marker + is emitted so the gap is observable. + + Args: + envelope: The telemetry envelope containing payload data. + + Returns: + The rehydrated payload dictionary, or ``{}`` on total failure. + """ + payload = envelope.payload + + # Resolve from object storage when the gateway offloaded a large payload. + if not payload and envelope.metadata: + payload_ref = envelope.metadata.get("payload_ref") + if payload_ref: + try: + payload_bytes = storage.load(payload_ref) + payload = json.loads(payload_bytes.decode("utf-8")) + logger.debug("Loaded payload from storage: key=%s", payload_ref) + except Exception: + logger.warning( + "Failed to load payload from storage: key=%s, event_id=%s", + payload_ref, + envelope.event_id, + exc_info=True, + ) + + if not payload: + # Storage resolution failed or no data available — emit degraded event. + logger.error( + "Payload rehydration failed for event_id=%s, tenant_id=%s, case=%s", + envelope.event_id, + envelope.tenant_id, + envelope.case, + ) + from enterprise.telemetry.entities import EnterpriseTelemetryEvent + from enterprise.telemetry.telemetry_log import emit_metric_only_event + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.REHYDRATION_FAILED, + attributes={ + "tenant_id": envelope.tenant_id, + "dify.telemetry.error": f"Payload rehydration failed for event_id={envelope.event_id}", + "dify.telemetry.payload_type": envelope.case, + "dify.telemetry.correlation_id": envelope.event_id, + }, + tenant_id=envelope.tenant_id, + ) + self._increment_diagnostic_counter("rehydration_failed_total") + return {} + + return payload + + # Stub methods for each metric/log case + # These will be implemented in later tasks with actual emission logic + + def _on_app_created(self, envelope: TelemetryEnvelope) -> None: + """Handle app created event.""" + from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent + from enterprise.telemetry.telemetry_log import emit_metric_only_event + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if not exporter: + logger.debug("No exporter available for APP_CREATED: event_id=%s", envelope.event_id) + return + + payload = self._rehydrate(envelope) + if not payload: + return + + attrs = { + "dify.app_id": payload.get("app_id"), + "dify.tenant_id": envelope.tenant_id, + "dify.event.id": envelope.event_id, + "dify.app.mode": payload.get("mode"), + "dify.app.created_at": datetime.now(UTC).isoformat(), + } + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.APP_CREATED, + attributes=attrs, + tenant_id=envelope.tenant_id, + ) + exporter.increment_counter( + EnterpriseTelemetryCounter.APP_CREATED, + 1, + { + "tenant_id": envelope.tenant_id, + "app_id": str(payload.get("app_id", "")), + "mode": str(payload.get("mode", "")), + }, + ) + + def _on_app_updated(self, envelope: TelemetryEnvelope) -> None: + """Handle app updated event.""" + from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent + from enterprise.telemetry.telemetry_log import emit_metric_only_event + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if not exporter: + logger.debug("No exporter available for APP_UPDATED: event_id=%s", envelope.event_id) + return + + payload = self._rehydrate(envelope) + if not payload: + return + + attrs = { + "dify.app_id": payload.get("app_id"), + "dify.tenant_id": envelope.tenant_id, + "dify.event.id": envelope.event_id, + "dify.app.updated_at": datetime.now(UTC).isoformat(), + } + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.APP_UPDATED, + attributes=attrs, + tenant_id=envelope.tenant_id, + ) + exporter.increment_counter( + EnterpriseTelemetryCounter.APP_UPDATED, + 1, + { + "tenant_id": envelope.tenant_id, + "app_id": str(payload.get("app_id", "")), + }, + ) + + def _on_app_deleted(self, envelope: TelemetryEnvelope) -> None: + """Handle app deleted event.""" + from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent + from enterprise.telemetry.telemetry_log import emit_metric_only_event + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if not exporter: + logger.debug("No exporter available for APP_DELETED: event_id=%s", envelope.event_id) + return + + payload = self._rehydrate(envelope) + if not payload: + return + + attrs = { + "dify.app_id": payload.get("app_id"), + "dify.tenant_id": envelope.tenant_id, + "dify.event.id": envelope.event_id, + "dify.app.deleted_at": datetime.now(UTC).isoformat(), + } + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.APP_DELETED, + attributes=attrs, + tenant_id=envelope.tenant_id, + ) + exporter.increment_counter( + EnterpriseTelemetryCounter.APP_DELETED, + 1, + { + "tenant_id": envelope.tenant_id, + "app_id": str(payload.get("app_id", "")), + }, + ) + + def _on_feedback_created(self, envelope: TelemetryEnvelope) -> None: + """Handle feedback created event.""" + from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent + from enterprise.telemetry.telemetry_log import emit_metric_only_event + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if not exporter: + logger.debug("No exporter available for FEEDBACK_CREATED: event_id=%s", envelope.event_id) + return + + payload = self._rehydrate(envelope) + if not payload: + return + + include_content = exporter.include_content + attrs: dict = { + "dify.message.id": payload.get("message_id"), + "dify.tenant_id": envelope.tenant_id, + "dify.event.id": envelope.event_id, + "dify.app_id": payload.get("app_id"), + "dify.conversation.id": payload.get("conversation_id"), + "gen_ai.user.id": payload.get("from_end_user_id") or payload.get("from_account_id"), + "dify.feedback.rating": payload.get("rating"), + "dify.feedback.from_source": payload.get("from_source"), + "dify.feedback.created_at": datetime.now(UTC).isoformat(), + } + if include_content: + attrs["dify.feedback.content"] = payload.get("content") + + user_id = payload.get("from_end_user_id") or payload.get("from_account_id") + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.FEEDBACK_CREATED, + attributes=attrs, + tenant_id=envelope.tenant_id, + user_id=str(user_id or ""), + ) + exporter.increment_counter( + EnterpriseTelemetryCounter.FEEDBACK, + 1, + { + "tenant_id": envelope.tenant_id, + "app_id": str(payload.get("app_id", "")), + "rating": str(payload.get("rating", "")), + }, + ) + + def _on_message_run(self, envelope: TelemetryEnvelope) -> None: + """Handle message run event. + + Intentionally a no-op: metrics and structured logs for message runs are + emitted directly by EnterpriseOtelTrace._message_trace at trace time, + not through the metric handler queue path. + """ + logger.debug("Processing MESSAGE_RUN: event_id=%s", envelope.event_id) + + def _on_tool_execution(self, envelope: TelemetryEnvelope) -> None: + """Handle tool execution event. + + Intentionally a no-op: metrics and structured logs for tool executions + are emitted directly by EnterpriseOtelTrace._tool_trace at trace time, + not through the metric handler queue path. + """ + logger.debug("Processing TOOL_EXECUTION: event_id=%s", envelope.event_id) + + def _on_moderation_check(self, envelope: TelemetryEnvelope) -> None: + """Handle moderation check event. + + Intentionally a no-op: metrics and structured logs for moderation checks + are emitted directly by EnterpriseOtelTrace._moderation_trace at trace time, + not through the metric handler queue path. + """ + logger.debug("Processing MODERATION_CHECK: event_id=%s", envelope.event_id) + + def _on_suggested_question(self, envelope: TelemetryEnvelope) -> None: + """Handle suggested question event. + + Intentionally a no-op: metrics and structured logs for suggested questions + are emitted directly by EnterpriseOtelTrace._suggested_question_trace at + trace time, not through the metric handler queue path. + """ + logger.debug("Processing SUGGESTED_QUESTION: event_id=%s", envelope.event_id) + + def _on_dataset_retrieval(self, envelope: TelemetryEnvelope) -> None: + """Handle dataset retrieval event. + + Intentionally a no-op: metrics and structured logs for dataset retrievals + are emitted directly by EnterpriseOtelTrace._dataset_retrieval_trace at + trace time, not through the metric handler queue path. + """ + logger.debug("Processing DATASET_RETRIEVAL: event_id=%s", envelope.event_id) + + def _on_generate_name(self, envelope: TelemetryEnvelope) -> None: + """Handle generate name event. + + Intentionally a no-op: metrics and structured logs for generate name + operations are emitted directly by EnterpriseOtelTrace._generate_name_trace + at trace time, not through the metric handler queue path. + """ + logger.debug("Processing GENERATE_NAME: event_id=%s", envelope.event_id) + + def _on_prompt_generation(self, envelope: TelemetryEnvelope) -> None: + """Handle prompt generation event. + + Intentionally a no-op: metrics and structured logs for prompt generation + operations are emitted directly by EnterpriseOtelTrace._prompt_generation_trace + at trace time, not through the metric handler queue path. + """ + logger.debug("Processing PROMPT_GENERATION: event_id=%s", envelope.event_id) diff --git a/api/enterprise/telemetry/telemetry_log.py b/api/enterprise/telemetry/telemetry_log.py new file mode 100644 index 00000000000..8cce4a9fcdd --- /dev/null +++ b/api/enterprise/telemetry/telemetry_log.py @@ -0,0 +1,122 @@ +"""Structured-log emitter for enterprise telemetry events. + +Emits structured JSON log lines correlated with OTEL traces via trace_id. +Picked up by ``StructuredJSONFormatter`` → stdout/Loki/Elastic. +""" + +from __future__ import annotations + +import logging +import uuid +from functools import lru_cache +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from enterprise.telemetry.entities import EnterpriseTelemetryEvent + +logger = logging.getLogger("dify.telemetry") + + +@lru_cache(maxsize=4096) +def compute_trace_id_hex(uuid_str: str | None) -> str: + """Convert a business UUID string to a 32-hex OTEL-compatible trace_id. + + Returns empty string when *uuid_str* is ``None`` or invalid. + """ + if not uuid_str: + return "" + normalized = uuid_str.strip().lower() + if len(normalized) == 32 and all(ch in "0123456789abcdef" for ch in normalized): + return normalized + try: + return f"{uuid.UUID(normalized).int:032x}" + except (ValueError, AttributeError): + return "" + + +@lru_cache(maxsize=4096) +def compute_span_id_hex(uuid_str: str | None) -> str: + if not uuid_str: + return "" + normalized = uuid_str.strip().lower() + if len(normalized) == 16 and all(ch in "0123456789abcdef" for ch in normalized): + return normalized + try: + from enterprise.telemetry.id_generator import compute_deterministic_span_id + + return f"{compute_deterministic_span_id(normalized):016x}" + except (ValueError, AttributeError): + return "" + + +def emit_telemetry_log( + *, + event_name: str | EnterpriseTelemetryEvent, + attributes: dict[str, Any], + signal: str = "metric_only", + trace_id_source: str | None = None, + span_id_source: str | None = None, + tenant_id: str | None = None, + user_id: str | None = None, +) -> None: + """Emit a structured log line for a telemetry event. + + Parameters + ---------- + event_name: + Canonical event name, e.g. ``"dify.workflow.run"``. + attributes: + All event-specific attributes (already built by the caller). + signal: + ``"metric_only"`` for events with no span, ``"span_detail"`` + for detail logs accompanying a slim span. + trace_id_source: + A UUID string (e.g. ``workflow_run_id``) used to derive a 32-hex + trace_id for cross-signal correlation. + tenant_id: + Tenant identifier (for the ``IdentityContextFilter``). + user_id: + User identifier (for the ``IdentityContextFilter``). + """ + if not logger.isEnabledFor(logging.INFO): + return + attrs = { + "dify.event.name": event_name, + "dify.event.signal": signal, + **attributes, + } + + extra: dict[str, Any] = {"attributes": attrs} + + trace_id_hex = compute_trace_id_hex(trace_id_source) + if trace_id_hex: + extra["trace_id"] = trace_id_hex + span_id_hex = compute_span_id_hex(span_id_source) + if span_id_hex: + extra["span_id"] = span_id_hex + if tenant_id: + extra["tenant_id"] = tenant_id + if user_id: + extra["user_id"] = user_id + + logger.info("telemetry.%s", signal, extra=extra) + + +def emit_metric_only_event( + *, + event_name: str | EnterpriseTelemetryEvent, + attributes: dict[str, Any], + trace_id_source: str | None = None, + span_id_source: str | None = None, + tenant_id: str | None = None, + user_id: str | None = None, +) -> None: + emit_telemetry_log( + event_name=event_name, + attributes=attributes, + signal="metric_only", + trace_id_source=trace_id_source, + span_id_source=span_id_source, + tenant_id=tenant_id, + user_id=user_id, + ) diff --git a/api/events/app_event.py b/api/events/app_event.py index f2ce71bbbb3..2fba0028f91 100644 --- a/api/events/app_event.py +++ b/api/events/app_event.py @@ -11,3 +11,9 @@ app_published_workflow_was_updated = signal("app-published-workflow-was-updated" # sender: app, kwargs: synced_draft_workflow app_draft_workflow_was_synced = signal("app-draft-workflow-was-synced") + +# sender: app +app_was_updated = signal("app-was-updated") + +# sender: app +app_was_deleted = signal("app-was-deleted") diff --git a/api/events/event_handlers/create_site_record_when_app_created.py b/api/events/event_handlers/create_site_record_when_app_created.py index 5e7caf8cbed..84be592b1a9 100644 --- a/api/events/event_handlers/create_site_record_when_app_created.py +++ b/api/events/event_handlers/create_site_record_when_app_created.py @@ -1,5 +1,6 @@ from events.app_event import app_was_created from extensions.ext_database import db +from models.enums import CustomizeTokenStrategy from models.model import Site @@ -16,7 +17,7 @@ def handle(sender, **kwargs): icon=app.icon, icon_background=app.icon_background, default_language=account.interface_language, - customize_token_strategy="not_allow", + customize_token_strategy=CustomizeTokenStrategy.NOT_ALLOW, code=Site.generate_code(16), created_by=app.created_by, updated_by=app.updated_by, diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index c43e99f0f4b..7bd8e88231a 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -1,9 +1,11 @@ import logging +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.tool.entities import ToolEntity + +from core.tools.entities.tool_entities import ToolProviderType from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager -from dify_graph.nodes import BuiltinNodeTypes -from dify_graph.nodes.tool.entities import ToolEntity from events.app_event import app_draft_workflow_was_synced logger = logging.getLogger(__name__) @@ -19,8 +21,9 @@ def handle(sender, **kwargs): if node_data.get("data", {}).get("type") == BuiltinNodeTypes.TOOL: try: tool_entity = ToolEntity.model_validate(node_data["data"]) + provider_type = ToolProviderType(tool_entity.provider_type.value) tool_runtime = ToolManager.get_tool_runtime( - provider_type=tool_entity.provider_type, + provider_type=provider_type, provider_id=tool_entity.provider_id, tool_name=tool_entity.tool_name, tenant_id=app.tenant_id, @@ -30,7 +33,7 @@ def handle(sender, **kwargs): tenant_id=app.tenant_id, tool_runtime=tool_runtime, provider_name=tool_entity.provider_name, - provider_type=tool_entity.provider_type, + provider_type=provider_type, identity_id=f"WORKFLOW.{app.id}.{node_data.get('id')}", ) manager.delete_tool_parameters_cache() diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index 20852b818e3..86b5b2bbf05 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -1,9 +1,9 @@ from typing import cast +from graphon.nodes import BuiltinNodeTypes from sqlalchemy import delete, select from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData -from dify_graph.nodes import BuiltinNodeTypes from events.app_event import app_published_workflow_was_updated from extensions.ext_database import db from models.dataset import AppDatasetJoin diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 7b6a73af527..4eed34436a4 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -2,7 +2,7 @@ import ssl from datetime import timedelta from typing import Any -import pytz +import pytz # type: ignore[import-untyped] from celery import Celery, Task from celery.schedules import crontab @@ -204,6 +204,8 @@ def init_app(app: DifyApp) -> Celery: "schedule": timedelta(minutes=dify_config.API_TOKEN_LAST_USED_UPDATE_INTERVAL), } + if dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED: + imports.append("tasks.enterprise_telemetry_task") celery_app.conf.update(beat_schedule=beat_schedule, imports=imports) return celery_app diff --git a/api/extensions/ext_enterprise_telemetry.py b/api/extensions/ext_enterprise_telemetry.py new file mode 100644 index 00000000000..b3cfa01aee6 --- /dev/null +++ b/api/extensions/ext_enterprise_telemetry.py @@ -0,0 +1,50 @@ +"""Flask extension for enterprise telemetry lifecycle management. + +Initializes the EnterpriseExporter singleton during ``create_app()`` +(single-threaded), registers blinker event handlers, and hooks atexit +for graceful shutdown. + +Skipped entirely when either ``ENTERPRISE_ENABLED`` or ``ENTERPRISE_TELEMETRY_ENABLED`` +is false (``is_enabled()`` gate). +""" + +from __future__ import annotations + +import atexit +import logging +from typing import TYPE_CHECKING + +from configs import dify_config + +if TYPE_CHECKING: + from dify_app import DifyApp + from enterprise.telemetry.exporter import EnterpriseExporter + +logger = logging.getLogger(__name__) + +_exporter: EnterpriseExporter | None = None + + +def is_enabled() -> bool: + return bool(dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED) + + +def init_app(app: DifyApp) -> None: + global _exporter + + if not is_enabled(): + return + + from enterprise.telemetry.exporter import EnterpriseExporter + + _exporter = EnterpriseExporter(dify_config) + atexit.register(_exporter.shutdown) + + # Import to trigger @signal.connect decorator registration + import enterprise.telemetry.event_handlers # noqa: F401 # type: ignore[reportUnusedImport] + + logger.info("Enterprise telemetry initialized") + + +def get_enterprise_exporter() -> EnterpriseExporter | None: + return _exporter diff --git a/api/extensions/ext_otel.py b/api/extensions/ext_otel.py index a5baa21018d..63edbe93e79 100644 --- a/api/extensions/ext_otel.py +++ b/api/extensions/ext_otel.py @@ -78,16 +78,24 @@ def init_app(app: DifyApp): protocol = (dify_config.OTEL_EXPORTER_OTLP_PROTOCOL or "").lower() if dify_config.OTEL_EXPORTER_TYPE == "otlp": if protocol == "grpc": + # Auto-detect TLS: https:// uses secure, everything else is insecure + endpoint = dify_config.OTLP_BASE_ENDPOINT + insecure = not endpoint.startswith("https://") + + # Header field names must consist of lowercase letters, check RFC7540 + grpc_headers = ( + (("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),) if dify_config.OTLP_API_KEY else () + ) + exporter = GRPCSpanExporter( - endpoint=dify_config.OTLP_BASE_ENDPOINT, - # Header field names must consist of lowercase letters, check RFC7540 - headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),), - insecure=True, + endpoint=endpoint, + headers=grpc_headers, + insecure=insecure, ) metric_exporter = GRPCMetricExporter( - endpoint=dify_config.OTLP_BASE_ENDPOINT, - headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),), - insecure=True, + endpoint=endpoint, + headers=grpc_headers, + insecure=insecure, ) else: headers = {"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"} if dify_config.OTLP_API_KEY else None diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index 9a34acb0c19..5cc58f27c4d 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -5,17 +5,25 @@ from dify_app import DifyApp def init_app(app: DifyApp): if dify_config.SENTRY_DSN: import sentry_sdk - from langfuse import parse_error + from graphon.model_runtime.errors.invoke import InvokeRateLimitError from sentry_sdk.integrations.celery import CeleryIntegration from sentry_sdk.integrations.flask import FlaskIntegration from werkzeug.exceptions import HTTPException - from dify_graph.model_runtime.errors.invoke import InvokeRateLimitError + try: + from langfuse._utils import parse_error + + _langfuse_error_response = parse_error.defaultErrorResponse + except (ImportError, AttributeError): + _langfuse_error_response = ( + "Unexpected error occurred. Please check your request" + " and contact support: https://langfuse.com/support." + ) def before_send(event, hint): if "exc_info" in hint: _, exc_value, _ = hint["exc_info"] - if parse_error.defaultErrorResponse in str(exc_value): + if _langfuse_error_response in str(exc_value): return None return event @@ -28,7 +36,7 @@ def init_app(app: DifyApp): ValueError, FileNotFoundError, InvokeRateLimitError, - parse_error.defaultErrorResponse, + _langfuse_error_response, ], traces_sample_rate=dify_config.SENTRY_TRACES_SAMPLE_RATE, profiles_sample_rate=dify_config.SENTRY_PROFILES_SAMPLE_RATE, diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py index a94d75ec760..db599c5d495 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py @@ -11,9 +11,9 @@ from collections.abc import Sequence from datetime import datetime from typing import Any +from graphon.enums import WorkflowNodeExecutionStatus from sqlalchemy.orm import sessionmaker -from dify_graph.enums import WorkflowNodeExecutionStatus from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value @@ -60,7 +60,7 @@ def _dict_to_workflow_node_execution_model(data: dict[str, Any]) -> WorkflowNode model.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN model.node_id = data.get("node_id") or "" model.node_type = data.get("node_type") or "" - model.status = data.get("status") or "running" # Default status if missing + model.status = WorkflowNodeExecutionStatus(data.get("status") or "running") model.title = data.get("title") or "" created_by_role_val = data.get("created_by_role") try: diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py index bdfc81bd1c1..3c83ab4f84e 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py @@ -20,9 +20,9 @@ from collections.abc import Sequence from datetime import datetime from typing import Any, cast +from graphon.enums import WorkflowExecutionStatus from sqlalchemy.orm import sessionmaker -from dify_graph.enums import WorkflowExecutionStatus from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string diff --git a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py index c58aa6adbb4..f71b2fa1df9 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py @@ -4,13 +4,13 @@ import os import time from typing import Union +from graphon.entities import WorkflowExecution +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker +from core.repositories.factory import WorkflowExecutionRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository -from dify_graph.entities import WorkflowExecution -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from extensions.logstore.aliyun_logstore import AliyunLogStore from libs.helper import extract_tenant_id from models import ( diff --git a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py index d84c0bc432d..b7254366817 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py @@ -13,15 +13,15 @@ from collections.abc import Sequence from datetime import datetime from typing import Any, Union +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter +from core.repositories.factory import OrderConfig, WorkflowNodeExecutionRepository from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier @@ -304,35 +304,39 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): logger.exception("Failed to dual-write node execution data to SQL database: id=%s", execution.id) # Don't raise - LogStore write succeeded, SQL is just a backup - def get_by_workflow_run( + def get_by_workflow_execution( self, - workflow_run_id: str, + workflow_execution_id: str, order_config: OrderConfig | None = None, ) -> Sequence[WorkflowNodeExecution]: """ - Retrieve all NodeExecution instances for a specific workflow run. + Retrieve all node executions for a workflow execution. Uses LogStore SQL query with window function to get the latest version of each node execution. This ensures we only get the most recent version of each node execution record. Args: - workflow_run_id: The workflow run ID + workflow_execution_id: The workflow execution identifier order_config: Optional configuration for ordering results order_config.order_by: List of fields to order by (e.g., ["index", "created_at"]) order_config.order_direction: Direction to order ("asc" or "desc") Returns: - A list of NodeExecution instances + A list of workflow node execution instances Note: This method uses ROW_NUMBER() window function partitioned by node_execution_id to get the latest version (highest log_version) of each node execution. """ - logger.debug("get_by_workflow_run: workflow_run_id=%s, order_config=%s", workflow_run_id, order_config) + logger.debug( + "get_by_workflow_execution: workflow_execution_id=%s, order_config=%s", + workflow_execution_id, + order_config, + ) # Build SQL query with deduplication using window function # ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC) # ensures we get the latest version of each node execution # Escape parameters to prevent SQL injection - escaped_workflow_run_id = escape_identifier(workflow_run_id) + escaped_workflow_execution_id = escape_identifier(workflow_execution_id) escaped_tenant_id = escape_identifier(self._tenant_id) # Build ORDER BY clause for outer query @@ -360,7 +364,7 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): SELECT * FROM ( SELECT *, ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC) AS rn FROM {AliyunLogStore.workflow_node_execution_logstore} - WHERE workflow_run_id='{escaped_workflow_run_id}' + WHERE workflow_run_id='{escaped_workflow_execution_id}' AND tenant_id='{escaped_tenant_id}' {app_id_filter} ) t @@ -391,5 +395,8 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): return executions except Exception: - logger.exception("Failed to retrieve node executions from LogStore: workflow_run_id=%s", workflow_run_id) + logger.exception( + "Failed to retrieve node executions from LogStore: workflow_execution_id=%s", + workflow_execution_id, + ) raise diff --git a/api/extensions/otel/parser/__init__.py b/api/extensions/otel/parser/__init__.py index 164db7c2756..c671e8b4096 100644 --- a/api/extensions/otel/parser/__init__.py +++ b/api/extensions/otel/parser/__init__.py @@ -5,7 +5,7 @@ This module provides parsers that extract node-specific metadata and set OpenTelemetry span attributes according to semantic conventions. """ -from extensions.otel.parser.base import DefaultNodeOTelParser, NodeOTelParser, safe_json_dumps +from extensions.otel.parser.base import DefaultNodeOTelParser, NodeOTelParser, safe_json_dumps, should_include_content from extensions.otel.parser.llm import LLMNodeOTelParser from extensions.otel.parser.retrieval import RetrievalNodeOTelParser from extensions.otel.parser.tool import ToolNodeOTelParser @@ -17,4 +17,5 @@ __all__ = [ "RetrievalNodeOTelParser", "ToolNodeOTelParser", "safe_json_dumps", + "should_include_content", ] diff --git a/api/extensions/otel/parser/base.py b/api/extensions/otel/parser/base.py index 544ef3fe18a..23d324f9ead 100644 --- a/api/extensions/otel/parser/base.py +++ b/api/extensions/otel/parser/base.py @@ -1,22 +1,38 @@ """ Base parser interface and utilities for OpenTelemetry node parsers. + +Content gating: ``should_include_content()`` controls whether content-bearing +span attributes (inputs, outputs, prompts, completions, documents) are written. +Gate is only active in EE (``ENTERPRISE_ENABLED=True``) when +``ENTERPRISE_INCLUDE_CONTENT=False``; CE behaviour is unchanged. """ import json from typing import Any, Protocol +from graphon.enums import BuiltinNodeTypes +from graphon.file import File +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node +from graphon.variables import Segment from opentelemetry.trace import Span from opentelemetry.trace.status import Status, StatusCode from pydantic import BaseModel -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.file.models import File -from dify_graph.graph_events import GraphNodeEventBase -from dify_graph.nodes.base.node import Node -from dify_graph.variables import Segment +from configs import dify_config from extensions.otel.semconv.gen_ai import ChainAttributes, GenAIAttributes +def should_include_content() -> bool: + """Return True if content should be written to spans. + + CE (ENTERPRISE_ENABLED=False): always True — no behaviour change. + """ + if not dify_config.ENTERPRISE_ENABLED: + return True + return dify_config.ENTERPRISE_INCLUDE_CONTENT + + def safe_json_dumps(obj: Any, ensure_ascii: bool = False) -> str: """ Safely serialize objects to JSON, handling non-serializable types. @@ -101,10 +117,11 @@ class DefaultNodeOTelParser: # Extract inputs and outputs from result_event if result_event and result_event.node_run_result: node_run_result = result_event.node_run_result - if node_run_result.inputs: - span.set_attribute(ChainAttributes.INPUT_VALUE, safe_json_dumps(node_run_result.inputs)) - if node_run_result.outputs: - span.set_attribute(ChainAttributes.OUTPUT_VALUE, safe_json_dumps(node_run_result.outputs)) + if should_include_content(): + if node_run_result.inputs: + span.set_attribute(ChainAttributes.INPUT_VALUE, safe_json_dumps(node_run_result.inputs)) + if node_run_result.outputs: + span.set_attribute(ChainAttributes.OUTPUT_VALUE, safe_json_dumps(node_run_result.outputs)) if error: span.record_exception(error) diff --git a/api/extensions/otel/parser/llm.py b/api/extensions/otel/parser/llm.py index 3da9a9e97d1..335c5cc29e2 100644 --- a/api/extensions/otel/parser/llm.py +++ b/api/extensions/otel/parser/llm.py @@ -6,10 +6,10 @@ import logging from collections.abc import Mapping from typing import Any +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node from opentelemetry.trace import Span -from dify_graph.graph_events import GraphNodeEventBase -from dify_graph.nodes.base.node import Node from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import LLMAttributes diff --git a/api/extensions/otel/parser/retrieval.py b/api/extensions/otel/parser/retrieval.py index dd658b250bb..6df5f62c155 100644 --- a/api/extensions/otel/parser/retrieval.py +++ b/api/extensions/otel/parser/retrieval.py @@ -6,11 +6,11 @@ import logging from collections.abc import Sequence from typing import Any +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node +from graphon.variables import Segment from opentelemetry.trace import Span -from dify_graph.graph_events import GraphNodeEventBase -from dify_graph.nodes.base.node import Node -from dify_graph.variables import Segment from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import RetrieverAttributes diff --git a/api/extensions/otel/parser/tool.py b/api/extensions/otel/parser/tool.py index f4e6a18b4dd..b9fdd9e1caa 100644 --- a/api/extensions/otel/parser/tool.py +++ b/api/extensions/otel/parser/tool.py @@ -2,12 +2,12 @@ Parser for tool nodes that captures tool-specific metadata. """ +from graphon.enums import WorkflowNodeExecutionMetadataKey +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node +from graphon.nodes.tool.entities import ToolNodeData from opentelemetry.trace import Span -from dify_graph.enums import WorkflowNodeExecutionMetadataKey -from dify_graph.graph_events import GraphNodeEventBase -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.tool.entities import ToolNodeData from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import ToolAttributes diff --git a/api/extensions/otel/semconv/dify.py b/api/extensions/otel/semconv/dify.py index a20b9b358df..301ddd11aaa 100644 --- a/api/extensions/otel/semconv/dify.py +++ b/api/extensions/otel/semconv/dify.py @@ -21,3 +21,15 @@ class DifySpanAttributes: INVOKE_FROM = "dify.invoke_from" """Invocation source, e.g. SERVICE_API, WEB_APP, DEBUGGER.""" + + INVOKED_BY = "dify.invoked_by" + """Invoked by, e.g. end_user, account, user.""" + + USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens" + """Number of input tokens (prompt tokens) used.""" + + USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens" + """Number of output tokens (completion tokens) generated.""" + + USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens" + """Total number of tokens used.""" diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py deleted file mode 100644 index cb07ba58ae6..00000000000 --- a/api/factories/file_factory.py +++ /dev/null @@ -1,618 +0,0 @@ -import logging -import mimetypes -import os -import re -import urllib.parse -import uuid -from collections.abc import Callable, Mapping, Sequence -from typing import Any - -import httpx -from sqlalchemy import select -from sqlalchemy.orm import Session -from werkzeug.http import parse_options_header - -from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS -from core.helper import ssrf_proxy -from dify_graph.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig, helpers -from extensions.ext_database import db -from models import MessageFile, ToolFile, UploadFile - -logger = logging.getLogger(__name__) - - -def build_from_message_files( - *, - message_files: Sequence["MessageFile"], - tenant_id: str, - config: FileUploadConfig | None = None, -) -> Sequence[File]: - results = [ - build_from_message_file(message_file=file, tenant_id=tenant_id, config=config) - for file in message_files - if file.belongs_to != FileBelongsTo.ASSISTANT - ] - return results - - -def build_from_message_file( - *, - message_file: "MessageFile", - tenant_id: str, - config: FileUploadConfig | None, -): - mapping = { - "transfer_method": message_file.transfer_method, - "url": message_file.url, - "type": message_file.type, - } - - # Only include id if it exists (message_file has been committed to DB) - if message_file.id: - mapping["id"] = message_file.id - - # Set the correct ID field based on transfer method - if message_file.transfer_method == FileTransferMethod.TOOL_FILE: - mapping["tool_file_id"] = message_file.upload_file_id - else: - mapping["upload_file_id"] = message_file.upload_file_id - - return build_from_mapping( - mapping=mapping, - tenant_id=tenant_id, - config=config, - ) - - -def build_from_mapping( - *, - mapping: Mapping[str, Any], - tenant_id: str, - config: FileUploadConfig | None = None, - strict_type_validation: bool = False, -) -> File: - transfer_method_value = mapping.get("transfer_method") - if not transfer_method_value: - raise ValueError("transfer_method is required in file mapping") - transfer_method = FileTransferMethod.value_of(transfer_method_value) - - build_functions: dict[FileTransferMethod, Callable] = { - FileTransferMethod.LOCAL_FILE: _build_from_local_file, - FileTransferMethod.REMOTE_URL: _build_from_remote_url, - FileTransferMethod.TOOL_FILE: _build_from_tool_file, - FileTransferMethod.DATASOURCE_FILE: _build_from_datasource_file, - } - - build_func = build_functions.get(transfer_method) - if not build_func: - raise ValueError(f"Invalid file transfer method: {transfer_method}") - - file: File = build_func( - mapping=mapping, - tenant_id=tenant_id, - transfer_method=transfer_method, - strict_type_validation=strict_type_validation, - ) - - if config and not _is_file_valid_with_config( - input_file_type=mapping.get("type", FileType.CUSTOM), - file_extension=file.extension or "", - file_transfer_method=file.transfer_method, - config=config, - ): - raise ValueError(f"File validation failed for file: {file.filename}") - - return file - - -def build_from_mappings( - *, - mappings: Sequence[Mapping[str, Any]], - config: FileUploadConfig | None = None, - tenant_id: str, - strict_type_validation: bool = False, -) -> Sequence[File]: - # TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query. - # Implement batch processing to reduce database load when handling multiple files. - # Filter out None/empty mappings to avoid errors - def is_valid_mapping(m: Mapping[str, Any]) -> bool: - if not m or not m.get("transfer_method"): - return False - # For REMOTE_URL transfer method, ensure url or remote_url is provided and not None - transfer_method = m.get("transfer_method") - if transfer_method == FileTransferMethod.REMOTE_URL: - url = m.get("url") or m.get("remote_url") - if not url: - return False - return True - - valid_mappings = [m for m in mappings if is_valid_mapping(m)] - files = [ - build_from_mapping( - mapping=mapping, - tenant_id=tenant_id, - config=config, - strict_type_validation=strict_type_validation, - ) - for mapping in valid_mappings - ] - - if ( - config - # If image config is set. - and config.image_config - # And the number of image files exceeds the maximum limit - and sum(1 for _ in (filter(lambda x: x.type == FileType.IMAGE, files))) > config.image_config.number_limits - ): - raise ValueError(f"Number of image files exceeds the maximum limit {config.image_config.number_limits}") - if config and config.number_limits and len(files) > config.number_limits: - raise ValueError(f"Number of files exceeds the maximum limit {config.number_limits}") - - return files - - -def _build_from_local_file( - *, - mapping: Mapping[str, Any], - tenant_id: str, - transfer_method: FileTransferMethod, - strict_type_validation: bool = False, -) -> File: - upload_file_id = mapping.get("upload_file_id") - if not upload_file_id: - raise ValueError("Invalid upload file id") - # check if upload_file_id is a valid uuid - try: - uuid.UUID(upload_file_id) - except ValueError: - raise ValueError("Invalid upload file id format") - stmt = select(UploadFile).where( - UploadFile.id == upload_file_id, - UploadFile.tenant_id == tenant_id, - ) - - row = db.session.scalar(stmt) - if row is None: - raise ValueError("Invalid upload file") - - detected_file_type = _standardize_file_type(extension="." + row.extension, mime_type=row.mime_type) - specified_type = mapping.get("type", "custom") - - if strict_type_validation and detected_file_type.value != specified_type: - raise ValueError("Detected file type does not match the specified type. Please verify the file.") - - if specified_type and specified_type != "custom": - file_type = FileType(specified_type) - else: - file_type = detected_file_type - - return File( - id=mapping.get("id"), - filename=row.name, - extension="." + row.extension, - mime_type=row.mime_type, - tenant_id=tenant_id, - type=file_type, - transfer_method=transfer_method, - remote_url=row.source_url, - related_id=mapping.get("upload_file_id"), - size=row.size, - storage_key=row.key, - ) - - -def _build_from_remote_url( - *, - mapping: Mapping[str, Any], - tenant_id: str, - transfer_method: FileTransferMethod, - strict_type_validation: bool = False, -) -> File: - upload_file_id = mapping.get("upload_file_id") - if upload_file_id: - try: - uuid.UUID(upload_file_id) - except ValueError: - raise ValueError("Invalid upload file id format") - stmt = select(UploadFile).where( - UploadFile.id == upload_file_id, - UploadFile.tenant_id == tenant_id, - ) - - upload_file = db.session.scalar(stmt) - if upload_file is None: - raise ValueError("Invalid upload file") - - detected_file_type = _standardize_file_type( - extension="." + upload_file.extension, mime_type=upload_file.mime_type - ) - - specified_type = mapping.get("type") - - if strict_type_validation and specified_type and detected_file_type.value != specified_type: - raise ValueError("Detected file type does not match the specified type. Please verify the file.") - - if specified_type and specified_type != "custom": - file_type = FileType(specified_type) - else: - file_type = detected_file_type - - return File( - id=mapping.get("id"), - filename=upload_file.name, - extension="." + upload_file.extension, - mime_type=upload_file.mime_type, - tenant_id=tenant_id, - type=file_type, - transfer_method=transfer_method, - remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)), - related_id=mapping.get("upload_file_id"), - size=upload_file.size, - storage_key=upload_file.key, - ) - url = mapping.get("url") or mapping.get("remote_url") - if not url: - raise ValueError("Invalid file url") - - mime_type, filename, file_size = _get_remote_file_info(url) - extension = mimetypes.guess_extension(mime_type) or ("." + filename.split(".")[-1] if "." in filename else ".bin") - - detected_file_type = _standardize_file_type(extension=extension, mime_type=mime_type) - specified_type = mapping.get("type") - - if strict_type_validation and specified_type and detected_file_type.value != specified_type: - raise ValueError("Detected file type does not match the specified type. Please verify the file.") - - if specified_type and specified_type != "custom": - file_type = FileType(specified_type) - else: - file_type = detected_file_type - - return File( - id=mapping.get("id"), - filename=filename, - tenant_id=tenant_id, - type=file_type, - transfer_method=transfer_method, - remote_url=url, - mime_type=mime_type, - extension=extension, - size=file_size, - storage_key="", - ) - - -def _extract_filename(url_path: str, content_disposition: str | None) -> str | None: - filename: str | None = None - # Try to extract from Content-Disposition header first - if content_disposition: - # Manually extract filename* parameter since parse_options_header doesn't support it - filename_star_match = re.search(r"filename\*=([^;]+)", content_disposition) - if filename_star_match: - raw_star = filename_star_match.group(1).strip() - # Remove trailing quotes if present - raw_star = raw_star.removesuffix('"') - # format: charset'lang'value - try: - parts = raw_star.split("'", 2) - charset = (parts[0] or "utf-8").lower() if len(parts) >= 1 else "utf-8" - value = parts[2] if len(parts) == 3 else parts[-1] - filename = urllib.parse.unquote(value, encoding=charset, errors="replace") - except Exception: - # Fallback: try to extract value after the last single quote - if "''" in raw_star: - filename = urllib.parse.unquote(raw_star.split("''")[-1]) - else: - filename = urllib.parse.unquote(raw_star) - - if not filename: - # Fallback to regular filename parameter - _, params = parse_options_header(content_disposition) - raw = params.get("filename") - if raw: - # Strip surrounding quotes and percent-decode if present - if len(raw) >= 2 and raw[0] == raw[-1] == '"': - raw = raw[1:-1] - filename = urllib.parse.unquote(raw) - # Fallback to URL path if no filename from header - if not filename: - candidate = os.path.basename(url_path) - filename = urllib.parse.unquote(candidate) if candidate else None - # Defense-in-depth: ensure basename only - if filename: - filename = os.path.basename(filename) - # Return None if filename is empty or only whitespace - if not filename or not filename.strip(): - filename = None - return filename or None - - -def _guess_mime_type(filename: str) -> str: - """Guess MIME type from filename, returning empty string if None.""" - guessed_mime, _ = mimetypes.guess_type(filename) - return guessed_mime or "" - - -def _get_remote_file_info(url: str): - file_size = -1 - parsed_url = urllib.parse.urlparse(url) - url_path = parsed_url.path - filename = os.path.basename(url_path) - - # Initialize mime_type from filename as fallback - mime_type = _guess_mime_type(filename) - - resp = ssrf_proxy.head(url, follow_redirects=True) - if resp.status_code == httpx.codes.OK: - content_disposition = resp.headers.get("Content-Disposition") - extracted_filename = _extract_filename(url_path, content_disposition) - if extracted_filename: - filename = extracted_filename - mime_type = _guess_mime_type(filename) - file_size = int(resp.headers.get("Content-Length", file_size)) - # Fallback to Content-Type header if mime_type is still empty - if not mime_type: - mime_type = resp.headers.get("Content-Type", "").split(";")[0].strip() - - if not filename: - extension = mimetypes.guess_extension(mime_type) or ".bin" - filename = f"{uuid.uuid4().hex}{extension}" - if not mime_type: - mime_type = _guess_mime_type(filename) - - return mime_type, filename, file_size - - -def _build_from_tool_file( - *, - mapping: Mapping[str, Any], - tenant_id: str, - transfer_method: FileTransferMethod, - strict_type_validation: bool = False, -) -> File: - # Backward/interop compatibility: allow tool_file_id to come from related_id or URL - tool_file_id = mapping.get("tool_file_id") - - if not tool_file_id: - raise ValueError(f"ToolFile {tool_file_id} not found") - tool_file = db.session.scalar( - select(ToolFile).where( - ToolFile.id == tool_file_id, - ToolFile.tenant_id == tenant_id, - ) - ) - - if tool_file is None: - raise ValueError(f"ToolFile {tool_file_id} not found") - - extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" - - detected_file_type = _standardize_file_type(extension=extension, mime_type=tool_file.mimetype) - - specified_type = mapping.get("type") - - if strict_type_validation and specified_type and detected_file_type.value != specified_type: - raise ValueError("Detected file type does not match the specified type. Please verify the file.") - - if specified_type and specified_type != "custom": - file_type = FileType(specified_type) - else: - file_type = detected_file_type - - return File( - id=mapping.get("id"), - tenant_id=tenant_id, - filename=tool_file.name, - type=file_type, - transfer_method=transfer_method, - remote_url=tool_file.original_url, - related_id=tool_file.id, - extension=extension, - mime_type=tool_file.mimetype, - size=tool_file.size, - storage_key=tool_file.file_key, - ) - - -def _build_from_datasource_file( - *, - mapping: Mapping[str, Any], - tenant_id: str, - transfer_method: FileTransferMethod, - strict_type_validation: bool = False, -) -> File: - datasource_file_id = mapping.get("datasource_file_id") - if not datasource_file_id: - raise ValueError(f"DatasourceFile {datasource_file_id} not found") - datasource_file = db.session.scalar( - select(UploadFile).where( - UploadFile.id == datasource_file_id, - UploadFile.tenant_id == tenant_id, - ) - ) - - if datasource_file is None: - raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found") - - extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin" - - detected_file_type = _standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type) - - specified_type = mapping.get("type") - - if strict_type_validation and specified_type and detected_file_type.value != specified_type: - raise ValueError("Detected file type does not match the specified type. Please verify the file.") - - if specified_type and specified_type != "custom": - file_type = FileType(specified_type) - else: - file_type = detected_file_type - - return File( - id=mapping.get("datasource_file_id"), - tenant_id=tenant_id, - filename=datasource_file.name, - type=file_type, - transfer_method=FileTransferMethod.TOOL_FILE, - remote_url=datasource_file.source_url, - related_id=datasource_file.id, - extension=extension, - mime_type=datasource_file.mime_type, - size=datasource_file.size, - storage_key=datasource_file.key, - url=datasource_file.source_url, - ) - - -def _is_file_valid_with_config( - *, - input_file_type: str, - file_extension: str, - file_transfer_method: FileTransferMethod, - config: FileUploadConfig, -) -> bool: - # FIXME(QIN2DIM): Always allow tool files (files generated by the assistant/model) - # These are internally generated and should bypass user upload restrictions - if file_transfer_method == FileTransferMethod.TOOL_FILE: - return True - - if ( - config.allowed_file_types - and input_file_type not in config.allowed_file_types - and input_file_type != FileType.CUSTOM - ): - return False - - if ( - input_file_type == FileType.CUSTOM - and config.allowed_file_extensions is not None - and file_extension not in config.allowed_file_extensions - ): - return False - - if input_file_type == FileType.IMAGE: - if ( - config.image_config - and config.image_config.transfer_methods - and file_transfer_method not in config.image_config.transfer_methods - ): - return False - elif config.allowed_file_upload_methods and file_transfer_method not in config.allowed_file_upload_methods: - return False - - return True - - -def _standardize_file_type(*, extension: str = "", mime_type: str = "") -> FileType: - """ - Infer the possible actual type of the file based on the extension and mime_type - """ - guessed_type = None - if extension: - guessed_type = _get_file_type_by_extension(extension) - if guessed_type is None and mime_type: - guessed_type = _get_file_type_by_mimetype(mime_type) - return guessed_type or FileType.CUSTOM - - -def _get_file_type_by_extension(extension: str) -> FileType | None: - extension = extension.lstrip(".") - if extension in IMAGE_EXTENSIONS: - return FileType.IMAGE - elif extension in VIDEO_EXTENSIONS: - return FileType.VIDEO - elif extension in AUDIO_EXTENSIONS: - return FileType.AUDIO - elif extension in DOCUMENT_EXTENSIONS: - return FileType.DOCUMENT - return None - - -def _get_file_type_by_mimetype(mime_type: str) -> FileType | None: - if "image" in mime_type: - file_type = FileType.IMAGE - elif "video" in mime_type: - file_type = FileType.VIDEO - elif "audio" in mime_type: - file_type = FileType.AUDIO - elif "text" in mime_type or "pdf" in mime_type: - file_type = FileType.DOCUMENT - else: - file_type = FileType.CUSTOM - return file_type - - -def get_file_type_by_mime_type(mime_type: str) -> FileType: - return _get_file_type_by_mimetype(mime_type) or FileType.CUSTOM - - -class StorageKeyLoader: - """FileKeyLoader load the storage key from database for a list of files. - This loader is batched, the database query count is constant regardless of the input size. - """ - - def __init__(self, session: Session, tenant_id: str): - self._session = session - self._tenant_id = tenant_id - - def _load_upload_files(self, upload_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, UploadFile]: - stmt = select(UploadFile).where( - UploadFile.id.in_(upload_file_ids), - UploadFile.tenant_id == self._tenant_id, - ) - - return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)} - - def _load_tool_files(self, tool_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, ToolFile]: - stmt = select(ToolFile).where( - ToolFile.id.in_(tool_file_ids), - ToolFile.tenant_id == self._tenant_id, - ) - return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)} - - def load_storage_keys(self, files: Sequence[File]): - """Loads storage keys for a sequence of files by retrieving the corresponding - `UploadFile` or `ToolFile` records from the database based on their transfer method. - - This method doesn't modify the input sequence structure but updates the `_storage_key` - property of each file object by extracting the relevant key from its database record. - - Performance note: This is a batched operation where database query count remains constant - regardless of input size. However, for optimal performance, input sequences should contain - fewer than 1000 files. For larger collections, split into smaller batches and process each - batch separately. - """ - - upload_file_ids: list[uuid.UUID] = [] - tool_file_ids: list[uuid.UUID] = [] - for file in files: - related_model_id = file.related_id - if file.related_id is None: - raise ValueError("file id should not be None.") - if file.tenant_id != self._tenant_id: - err_msg = ( - f"invalid file, expected tenant_id={self._tenant_id}, " - f"got tenant_id={file.tenant_id}, file_id={file.id}, related_model_id={related_model_id}" - ) - raise ValueError(err_msg) - model_id = uuid.UUID(related_model_id) - - if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL): - upload_file_ids.append(model_id) - elif file.transfer_method == FileTransferMethod.TOOL_FILE: - tool_file_ids.append(model_id) - - tool_files = self._load_tool_files(tool_file_ids) - upload_files = self._load_upload_files(upload_file_ids) - for file in files: - model_id = uuid.UUID(file.related_id) - if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL): - upload_file_row = upload_files.get(model_id) - if upload_file_row is None: - raise ValueError(f"Upload file not found for id: {model_id}") - file.storage_key = upload_file_row.key - elif file.transfer_method == FileTransferMethod.TOOL_FILE: - tool_file_row = tool_files.get(model_id) - if tool_file_row is None: - raise ValueError(f"Tool file not found for id: {model_id}") - file.storage_key = tool_file_row.file_key diff --git a/api/factories/file_factory/__init__.py b/api/factories/file_factory/__init__.py new file mode 100644 index 00000000000..ae0cd972ec5 --- /dev/null +++ b/api/factories/file_factory/__init__.py @@ -0,0 +1,18 @@ +"""Workflow file factory package. + +This package normalizes workflow-layer file payloads into graph-layer ``File`` +values. It keeps tenancy and ownership checks in the application layer and +exports the workflow-facing file builders for callers. +""" + +from .builders import build_from_mapping, build_from_mappings +from .message_files import build_from_message_file, build_from_message_files +from .storage_keys import StorageKeyLoader + +__all__ = [ + "StorageKeyLoader", + "build_from_mapping", + "build_from_mappings", + "build_from_message_file", + "build_from_message_files", +] diff --git a/api/factories/file_factory/builders.py b/api/factories/file_factory/builders.py new file mode 100644 index 00000000000..7516d18c8e0 --- /dev/null +++ b/api/factories/file_factory/builders.py @@ -0,0 +1,328 @@ +"""Core builders for workflow file mappings.""" + +from __future__ import annotations + +import mimetypes +import uuid +from collections.abc import Mapping, Sequence +from typing import Any + +from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig, helpers, standardize_file_type +from sqlalchemy import select + +from core.app.file_access import FileAccessControllerProtocol +from core.workflow.file_reference import build_file_reference +from extensions.ext_database import db +from models import ToolFile, UploadFile + +from .common import resolve_mapping_file_id +from .remote import get_remote_file_info +from .validation import is_file_valid_with_config + + +def build_from_mapping( + *, + mapping: Mapping[str, Any], + tenant_id: str, + config: FileUploadConfig | None = None, + strict_type_validation: bool = False, + access_controller: FileAccessControllerProtocol, +) -> File: + transfer_method_value = mapping.get("transfer_method") + if not transfer_method_value: + raise ValueError("transfer_method is required in file mapping") + + transfer_method = FileTransferMethod.value_of(transfer_method_value) + build_func = _get_build_function(transfer_method) + file = build_func( + mapping=mapping, + tenant_id=tenant_id, + transfer_method=transfer_method, + strict_type_validation=strict_type_validation, + access_controller=access_controller, + ) + + if config and not is_file_valid_with_config( + input_file_type=mapping.get("type", FileType.CUSTOM), + file_extension=file.extension or "", + file_transfer_method=file.transfer_method, + config=config, + ): + raise ValueError(f"File validation failed for file: {file.filename}") + + return file + + +def build_from_mappings( + *, + mappings: Sequence[Mapping[str, Any]], + config: FileUploadConfig | None = None, + tenant_id: str, + strict_type_validation: bool = False, + access_controller: FileAccessControllerProtocol, +) -> Sequence[File]: + # TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query. + # Implement batch processing to reduce database load when handling multiple files. + valid_mappings = [mapping for mapping in mappings if _is_valid_mapping(mapping)] + files = [ + build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + config=config, + strict_type_validation=strict_type_validation, + access_controller=access_controller, + ) + for mapping in valid_mappings + ] + + if ( + config + and config.image_config + and sum(1 for file in files if file.type == FileType.IMAGE) > config.image_config.number_limits + ): + raise ValueError(f"Number of image files exceeds the maximum limit {config.image_config.number_limits}") + if config and config.number_limits and len(files) > config.number_limits: + raise ValueError(f"Number of files exceeds the maximum limit {config.number_limits}") + + return files + + +def _get_build_function(transfer_method: FileTransferMethod): + build_functions = { + FileTransferMethod.LOCAL_FILE: _build_from_local_file, + FileTransferMethod.REMOTE_URL: _build_from_remote_url, + FileTransferMethod.TOOL_FILE: _build_from_tool_file, + FileTransferMethod.DATASOURCE_FILE: _build_from_datasource_file, + } + build_func = build_functions.get(transfer_method) + if build_func is None: + raise ValueError(f"Invalid file transfer method: {transfer_method}") + return build_func + + +def _resolve_file_type( + *, + detected_file_type: FileType, + specified_type: str | None, + strict_type_validation: bool, +) -> FileType: + if strict_type_validation and specified_type and detected_file_type.value != specified_type: + raise ValueError("Detected file type does not match the specified type. Please verify the file.") + + if specified_type and specified_type != "custom": + return FileType(specified_type) + return detected_file_type + + +def _build_from_local_file( + *, + mapping: Mapping[str, Any], + tenant_id: str, + transfer_method: FileTransferMethod, + strict_type_validation: bool = False, + access_controller: FileAccessControllerProtocol, +) -> File: + upload_file_id = resolve_mapping_file_id(mapping, "upload_file_id") + if not upload_file_id: + raise ValueError("Invalid upload file id") + + try: + uuid.UUID(upload_file_id) + except ValueError as exc: + raise ValueError("Invalid upload file id format") from exc + + stmt = select(UploadFile).where( + UploadFile.id == upload_file_id, + UploadFile.tenant_id == tenant_id, + ) + row = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) + if row is None: + raise ValueError("Invalid upload file") + + detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type", "custom"), + strict_type_validation=strict_type_validation, + ) + + return File( + id=mapping.get("id"), + filename=row.name, + extension="." + row.extension, + mime_type=row.mime_type, + type=file_type, + transfer_method=transfer_method, + remote_url=row.source_url, + reference=build_file_reference(record_id=str(row.id)), + size=row.size, + storage_key=row.key, + ) + + +def _build_from_remote_url( + *, + mapping: Mapping[str, Any], + tenant_id: str, + transfer_method: FileTransferMethod, + strict_type_validation: bool = False, + access_controller: FileAccessControllerProtocol, +) -> File: + upload_file_id = resolve_mapping_file_id(mapping, "upload_file_id") + if upload_file_id: + try: + uuid.UUID(upload_file_id) + except ValueError as exc: + raise ValueError("Invalid upload file id format") from exc + + stmt = select(UploadFile).where( + UploadFile.id == upload_file_id, + UploadFile.tenant_id == tenant_id, + ) + upload_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) + if upload_file is None: + raise ValueError("Invalid upload file") + + detected_file_type = standardize_file_type( + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + ) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) + + return File( + id=mapping.get("id"), + filename=upload_file.name, + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + type=file_type, + transfer_method=transfer_method, + remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)), + reference=build_file_reference(record_id=str(upload_file.id)), + size=upload_file.size, + storage_key=upload_file.key, + ) + + url = mapping.get("url") or mapping.get("remote_url") + if not url: + raise ValueError("Invalid file url") + + mime_type, filename, file_size = get_remote_file_info(url) + extension = mimetypes.guess_extension(mime_type) or ("." + filename.split(".")[-1] if "." in filename else ".bin") + detected_file_type = standardize_file_type(extension=extension, mime_type=mime_type) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) + + return File( + id=mapping.get("id"), + filename=filename, + type=file_type, + transfer_method=transfer_method, + remote_url=url, + mime_type=mime_type, + extension=extension, + size=file_size, + ) + + +def _build_from_tool_file( + *, + mapping: Mapping[str, Any], + tenant_id: str, + transfer_method: FileTransferMethod, + strict_type_validation: bool = False, + access_controller: FileAccessControllerProtocol, +) -> File: + tool_file_id = resolve_mapping_file_id(mapping, "tool_file_id") + if not tool_file_id: + raise ValueError(f"ToolFile {tool_file_id} not found") + + stmt = select(ToolFile).where( + ToolFile.id == tool_file_id, + ToolFile.tenant_id == tenant_id, + ) + tool_file = db.session.scalar(access_controller.apply_tool_file_filters(stmt)) + if tool_file is None: + raise ValueError(f"ToolFile {tool_file_id} not found") + + extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" + detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) + + return File( + id=mapping.get("id"), + filename=tool_file.name, + type=file_type, + transfer_method=transfer_method, + remote_url=tool_file.original_url, + reference=build_file_reference(record_id=str(tool_file.id)), + extension=extension, + mime_type=tool_file.mimetype, + size=tool_file.size, + storage_key=tool_file.file_key, + ) + + +def _build_from_datasource_file( + *, + mapping: Mapping[str, Any], + tenant_id: str, + transfer_method: FileTransferMethod, + strict_type_validation: bool = False, + access_controller: FileAccessControllerProtocol, +) -> File: + datasource_file_id = resolve_mapping_file_id(mapping, "datasource_file_id") + if not datasource_file_id: + raise ValueError(f"DatasourceFile {datasource_file_id} not found") + + stmt = select(UploadFile).where( + UploadFile.id == datasource_file_id, + UploadFile.tenant_id == tenant_id, + ) + datasource_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) + if datasource_file is None: + raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found") + + extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin" + detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) + + return File( + id=mapping.get("datasource_file_id"), + filename=datasource_file.name, + type=file_type, + transfer_method=FileTransferMethod.TOOL_FILE, + remote_url=datasource_file.source_url, + reference=build_file_reference(record_id=str(datasource_file.id)), + extension=extension, + mime_type=datasource_file.mime_type, + size=datasource_file.size, + storage_key=datasource_file.key, + url=datasource_file.source_url, + ) + + +def _is_valid_mapping(mapping: Mapping[str, Any]) -> bool: + if not mapping or not mapping.get("transfer_method"): + return False + + if mapping.get("transfer_method") == FileTransferMethod.REMOTE_URL: + url = mapping.get("url") or mapping.get("remote_url") + if not url: + return False + + return True diff --git a/api/factories/file_factory/common.py b/api/factories/file_factory/common.py new file mode 100644 index 00000000000..2e1c95ab3fe --- /dev/null +++ b/api/factories/file_factory/common.py @@ -0,0 +1,27 @@ +"""Shared helpers for workflow file factory modules.""" + +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +from core.workflow.file_reference import resolve_file_record_id + + +def resolve_mapping_file_id(mapping: Mapping[str, Any], *keys: str) -> str | None: + """Resolve historical file identifiers from persisted mapping payloads. + + Workflow and model payloads can outlive file schema changes. Older rows may + still carry concrete identifiers in legacy fields such as ``related_id``, + while newer payloads use opaque references. Keep this compatibility lookup in + the factory layer so historical data remains readable without reintroducing + storage details into graph-layer ``File`` values. + """ + + for key in (*keys, "reference", "related_id"): + raw_value = mapping.get(key) + if isinstance(raw_value, str) and raw_value: + resolved_value = resolve_file_record_id(raw_value) + if resolved_value: + return resolved_value + return None diff --git a/api/factories/file_factory/message_files.py b/api/factories/file_factory/message_files.py new file mode 100644 index 00000000000..5582b85c956 --- /dev/null +++ b/api/factories/file_factory/message_files.py @@ -0,0 +1,60 @@ +"""Adapters from persisted message files to graph-layer file values.""" + +from __future__ import annotations + +from collections.abc import Sequence + +from graphon.file import File, FileBelongsTo, FileTransferMethod, FileUploadConfig + +from core.app.file_access import FileAccessControllerProtocol +from models import MessageFile + +from .builders import build_from_mapping + + +def build_from_message_files( + *, + message_files: Sequence[MessageFile], + tenant_id: str, + config: FileUploadConfig | None = None, + access_controller: FileAccessControllerProtocol, +) -> Sequence[File]: + return [ + build_from_message_file( + message_file=message_file, + tenant_id=tenant_id, + config=config, + access_controller=access_controller, + ) + for message_file in message_files + if message_file.belongs_to != FileBelongsTo.ASSISTANT + ] + + +def build_from_message_file( + *, + message_file: MessageFile, + tenant_id: str, + config: FileUploadConfig | None, + access_controller: FileAccessControllerProtocol, +) -> File: + mapping = { + "transfer_method": message_file.transfer_method, + "url": message_file.url, + "type": message_file.type, + } + + if message_file.id: + mapping["id"] = message_file.id + + if message_file.transfer_method == FileTransferMethod.TOOL_FILE: + mapping["tool_file_id"] = message_file.upload_file_id + else: + mapping["upload_file_id"] = message_file.upload_file_id + + return build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + config=config, + access_controller=access_controller, + ) diff --git a/api/factories/file_factory/remote.py b/api/factories/file_factory/remote.py new file mode 100644 index 00000000000..e5a71860077 --- /dev/null +++ b/api/factories/file_factory/remote.py @@ -0,0 +1,91 @@ +"""Remote file metadata helpers used by workflow file normalization. + +These helpers are part of the ``factories.file_factory`` package surface +because both workflow builders and tests rely on the same RFC5987 filename +parsing and HEAD-response normalization rules. +""" + +from __future__ import annotations + +import mimetypes +import os +import re +import urllib.parse +import uuid + +import httpx +from werkzeug.http import parse_options_header + +from core.helper import ssrf_proxy + + +def extract_filename(url_path: str, content_disposition: str | None) -> str | None: + """Extract a safe filename from Content-Disposition or the request URL path.""" + filename: str | None = None + if content_disposition: + filename_star_match = re.search(r"filename\*=([^;]+)", content_disposition) + if filename_star_match: + raw_star = filename_star_match.group(1).strip() + raw_star = raw_star.removesuffix('"') + try: + parts = raw_star.split("'", 2) + charset = (parts[0] or "utf-8").lower() if len(parts) >= 1 else "utf-8" + value = parts[2] if len(parts) == 3 else parts[-1] + filename = urllib.parse.unquote(value, encoding=charset, errors="replace") + except Exception: + if "''" in raw_star: + filename = urllib.parse.unquote(raw_star.split("''")[-1]) + else: + filename = urllib.parse.unquote(raw_star) + + if not filename: + _, params = parse_options_header(content_disposition) + raw = params.get("filename") + if raw: + if len(raw) >= 2 and raw[0] == raw[-1] == '"': + raw = raw[1:-1] + filename = urllib.parse.unquote(raw) + + if not filename: + candidate = os.path.basename(url_path) + filename = urllib.parse.unquote(candidate) if candidate else None + + if filename: + filename = os.path.basename(filename) + if not filename or not filename.strip(): + filename = None + + return filename or None + + +def _guess_mime_type(filename: str) -> str: + guessed_mime, _ = mimetypes.guess_type(filename) + return guessed_mime or "" + + +def get_remote_file_info(url: str) -> tuple[str, str, int]: + """Resolve remote file metadata with SSRF-safe HEAD probing.""" + file_size = -1 + parsed_url = urllib.parse.urlparse(url) + url_path = parsed_url.path + filename = os.path.basename(url_path) + mime_type = _guess_mime_type(filename) + + resp = ssrf_proxy.head(url, follow_redirects=True) + if resp.status_code == httpx.codes.OK: + content_disposition = resp.headers.get("Content-Disposition") + extracted_filename = extract_filename(url_path, content_disposition) + if extracted_filename: + filename = extracted_filename + mime_type = _guess_mime_type(filename) + file_size = int(resp.headers.get("Content-Length", file_size)) + if not mime_type: + mime_type = resp.headers.get("Content-Type", "").split(";")[0].strip() + + if not filename: + extension = mimetypes.guess_extension(mime_type) or ".bin" + filename = f"{uuid.uuid4().hex}{extension}" + if not mime_type: + mime_type = _guess_mime_type(filename) + + return mime_type, filename, file_size diff --git a/api/factories/file_factory/storage_keys.py b/api/factories/file_factory/storage_keys.py new file mode 100644 index 00000000000..db3a7f30159 --- /dev/null +++ b/api/factories/file_factory/storage_keys.py @@ -0,0 +1,106 @@ +"""Batched storage-key hydration for workflow files.""" + +from __future__ import annotations + +import uuid +from collections.abc import Mapping, Sequence + +from graphon.file import File, FileTransferMethod +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.app.file_access import FileAccessControllerProtocol +from core.workflow.file_reference import build_file_reference, parse_file_reference +from models import ToolFile, UploadFile + + +class StorageKeyLoader: + """Load storage keys for files with a constant number of database queries.""" + + _session: Session + _tenant_id: str + _access_controller: FileAccessControllerProtocol + + def __init__( + self, + session: Session, + tenant_id: str, + access_controller: FileAccessControllerProtocol, + ) -> None: + self._session = session + self._tenant_id = tenant_id + self._access_controller = access_controller + + def _load_upload_files(self, upload_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, UploadFile]: + stmt = select(UploadFile).where( + UploadFile.id.in_(upload_file_ids), + UploadFile.tenant_id == self._tenant_id, + ) + scoped_stmt = self._access_controller.apply_upload_file_filters(stmt) + return {uuid.UUID(upload_file.id): upload_file for upload_file in self._session.scalars(scoped_stmt)} + + def _load_tool_files(self, tool_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, ToolFile]: + stmt = select(ToolFile).where( + ToolFile.id.in_(tool_file_ids), + ToolFile.tenant_id == self._tenant_id, + ) + scoped_stmt = self._access_controller.apply_tool_file_filters(stmt) + return {uuid.UUID(tool_file.id): tool_file for tool_file in self._session.scalars(scoped_stmt)} + + def load_storage_keys(self, files: Sequence[File]) -> None: + """Hydrate storage keys by loading their backing file rows in batches. + + The sequence shape is preserved. Each file is updated in place with a + canonical record reference and storage key loaded from an authorized + database row. Tenant scoping is enforced by this loader's context + rather than by embedding tenant identity or storage paths inside + graph-layer ``File`` values. + + For best performance, prefer batches smaller than 1000 files. + """ + + upload_file_ids: list[uuid.UUID] = [] + tool_file_ids: list[uuid.UUID] = [] + for file in files: + parsed_reference = parse_file_reference(file.reference) + if parsed_reference is None: + raise ValueError("file id should not be None.") + + model_id = uuid.UUID(parsed_reference.record_id) + if file.transfer_method in ( + FileTransferMethod.LOCAL_FILE, + FileTransferMethod.REMOTE_URL, + FileTransferMethod.DATASOURCE_FILE, + ): + upload_file_ids.append(model_id) + elif file.transfer_method == FileTransferMethod.TOOL_FILE: + tool_file_ids.append(model_id) + + tool_files = self._load_tool_files(tool_file_ids) + upload_files = self._load_upload_files(upload_file_ids) + for file in files: + parsed_reference = parse_file_reference(file.reference) + if parsed_reference is None: + raise ValueError("file id should not be None.") + + model_id = uuid.UUID(parsed_reference.record_id) + if file.transfer_method in ( + FileTransferMethod.LOCAL_FILE, + FileTransferMethod.REMOTE_URL, + FileTransferMethod.DATASOURCE_FILE, + ): + upload_file_row = upload_files.get(model_id) + if upload_file_row is None: + raise ValueError(f"Upload file not found for id: {model_id}") + file.reference = build_file_reference( + record_id=str(upload_file_row.id), + ) + file.storage_key = upload_file_row.key + elif file.transfer_method == FileTransferMethod.TOOL_FILE: + tool_file_row = tool_files.get(model_id) + if tool_file_row is None: + raise ValueError(f"Tool file not found for id: {model_id}") + file.reference = build_file_reference( + record_id=str(tool_file_row.id), + ) + file.storage_key = tool_file_row.file_key diff --git a/api/factories/file_factory/validation.py b/api/factories/file_factory/validation.py new file mode 100644 index 00000000000..4c4f6150e40 --- /dev/null +++ b/api/factories/file_factory/validation.py @@ -0,0 +1,44 @@ +"""Validation helpers for workflow file inputs.""" + +from __future__ import annotations + +from graphon.file import FileTransferMethod, FileType, FileUploadConfig + + +def is_file_valid_with_config( + *, + input_file_type: str, + file_extension: str, + file_transfer_method: FileTransferMethod, + config: FileUploadConfig, +) -> bool: + # FIXME(QIN2DIM): Always allow tool files (files generated by the assistant/model) + # These are internally generated and should bypass user upload restrictions + if file_transfer_method == FileTransferMethod.TOOL_FILE: + return True + + if ( + config.allowed_file_types + and input_file_type not in config.allowed_file_types + and input_file_type != FileType.CUSTOM + ): + return False + + if ( + input_file_type == FileType.CUSTOM + and config.allowed_file_extensions is not None + and file_extension not in config.allowed_file_extensions + ): + return False + + if input_file_type == FileType.IMAGE: + if ( + config.image_config + and config.image_config.transfer_methods + and file_transfer_method not in config.image_config.transfer_methods + ): + return False + elif config.allowed_file_upload_methods and file_transfer_method not in config.allowed_file_upload_methods: + return False + + return True diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 14a56bf4a27..57205b5739f 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -1,75 +1,52 @@ +"""Compatibility factory for non-graph variable bootstrapping. + +Graph runtime segment/variable conversions live under `graphon.variables`. +This module keeps the application-layer mapping helpers and re-exports the +shared conversion functions for legacy callers and tests. +""" + from collections.abc import Mapping, Sequence from typing import Any, cast -from uuid import uuid4 -from configs import dify_config -from dify_graph.constants import ( - CONVERSATION_VARIABLE_NODE_ID, - ENVIRONMENT_VARIABLE_NODE_ID, +from graphon.variables.exc import VariableError +from graphon.variables.factory import ( + TypeMismatchError, + UnsupportedSegmentTypeError, + build_segment, + build_segment_with_type, + segment_to_variable, ) -from dify_graph.file import File -from dify_graph.variables.exc import VariableError -from dify_graph.variables.segments import ( - ArrayAnySegment, - ArrayBooleanSegment, - ArrayFileSegment, - ArrayNumberSegment, - ArrayObjectSegment, - ArraySegment, - ArrayStringSegment, - BooleanSegment, - FileSegment, - FloatSegment, - IntegerSegment, - NoneSegment, - ObjectSegment, - Segment, - StringSegment, -) -from dify_graph.variables.types import SegmentType -from dify_graph.variables.variables import ( - ArrayAnyVariable, +from graphon.variables.types import SegmentType +from graphon.variables.variables import ( ArrayBooleanVariable, - ArrayFileVariable, ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, BooleanVariable, - FileVariable, FloatVariable, IntegerVariable, - NoneVariable, ObjectVariable, SecretVariable, StringVariable, VariableBase, ) +from configs import dify_config +from core.workflow.variable_prefixes import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, +) -class UnsupportedSegmentTypeError(Exception): - pass - - -class TypeMismatchError(Exception): - pass - - -# Define the constant -SEGMENT_TO_VARIABLE_MAP: Mapping[type[Segment], type[VariableBase]] = { - ArrayAnySegment: ArrayAnyVariable, - ArrayBooleanSegment: ArrayBooleanVariable, - ArrayFileSegment: ArrayFileVariable, - ArrayNumberSegment: ArrayNumberVariable, - ArrayObjectSegment: ArrayObjectVariable, - ArrayStringSegment: ArrayStringVariable, - BooleanSegment: BooleanVariable, - FileSegment: FileVariable, - FloatSegment: FloatVariable, - IntegerSegment: IntegerVariable, - NoneSegment: NoneVariable, - ObjectSegment: ObjectVariable, - StringSegment: StringVariable, -} +__all__ = [ + "TypeMismatchError", + "UnsupportedSegmentTypeError", + "build_conversation_variable_from_mapping", + "build_environment_variable_from_mapping", + "build_pipeline_variable_from_mapping", + "build_segment", + "build_segment_with_type", + "segment_to_variable", +] def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase: @@ -135,172 +112,3 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen if not result.selector: result = result.model_copy(update={"selector": selector}) return cast(VariableBase, result) - - -def build_segment(value: Any, /) -> Segment: - # NOTE: If you have runtime type information available, consider using the `build_segment_with_type` - # below - if value is None: - return NoneSegment() - if isinstance(value, Segment): - return value - if isinstance(value, str): - return StringSegment(value=value) - if isinstance(value, bool): - return BooleanSegment(value=value) - if isinstance(value, int): - return IntegerSegment(value=value) - if isinstance(value, float): - return FloatSegment(value=value) - if isinstance(value, dict): - return ObjectSegment(value=value) - if isinstance(value, File): - return FileSegment(value=value) - if isinstance(value, list): - items = [build_segment(item) for item in value] - types = {item.value_type for item in items} - if all(isinstance(item, ArraySegment) for item in items): - return ArrayAnySegment(value=value) - elif len(types) != 1: - if types.issubset({SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}): - return ArrayNumberSegment(value=value) - return ArrayAnySegment(value=value) - - match types.pop(): - case SegmentType.STRING: - return ArrayStringSegment(value=value) - case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: - return ArrayNumberSegment(value=value) - case SegmentType.BOOLEAN: - return ArrayBooleanSegment(value=value) - case SegmentType.OBJECT: - return ArrayObjectSegment(value=value) - case SegmentType.FILE: - return ArrayFileSegment(value=value) - case SegmentType.NONE: - return ArrayAnySegment(value=value) - case _: - # This should be unreachable. - raise ValueError(f"not supported value {value}") - raise ValueError(f"not supported value {value}") - - -_segment_factory: Mapping[SegmentType, type[Segment]] = { - SegmentType.NONE: NoneSegment, - SegmentType.STRING: StringSegment, - SegmentType.INTEGER: IntegerSegment, - SegmentType.FLOAT: FloatSegment, - SegmentType.FILE: FileSegment, - SegmentType.BOOLEAN: BooleanSegment, - SegmentType.OBJECT: ObjectSegment, - # Array types - SegmentType.ARRAY_ANY: ArrayAnySegment, - SegmentType.ARRAY_STRING: ArrayStringSegment, - SegmentType.ARRAY_NUMBER: ArrayNumberSegment, - SegmentType.ARRAY_OBJECT: ArrayObjectSegment, - SegmentType.ARRAY_FILE: ArrayFileSegment, - SegmentType.ARRAY_BOOLEAN: ArrayBooleanSegment, -} - - -def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: - """ - Build a segment with explicit type checking. - - This function creates a segment from a value while enforcing type compatibility - with the specified segment_type. It provides stricter type validation compared - to the standard build_segment function. - - Args: - segment_type: The expected SegmentType for the resulting segment - value: The value to be converted into a segment - - Returns: - Segment: A segment instance of the appropriate type - - Raises: - TypeMismatchError: If the value type doesn't match the expected segment_type - - Special Cases: - - For empty list [] values, if segment_type is array[*], returns the corresponding array type - - Type validation is performed before segment creation - - Examples: - >>> build_segment_with_type(SegmentType.STRING, "hello") - StringSegment(value="hello") - - >>> build_segment_with_type(SegmentType.ARRAY_STRING, []) - ArrayStringSegment(value=[]) - - >>> build_segment_with_type(SegmentType.STRING, 123) - # Raises TypeMismatchError - """ - # Handle None values - if value is None: - if segment_type == SegmentType.NONE: - return NoneSegment() - else: - raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got None") - - # Handle empty list special case for array types - if isinstance(value, list) and len(value) == 0: - if segment_type == SegmentType.ARRAY_ANY: - return ArrayAnySegment(value=value) - elif segment_type == SegmentType.ARRAY_STRING: - return ArrayStringSegment(value=value) - elif segment_type == SegmentType.ARRAY_BOOLEAN: - return ArrayBooleanSegment(value=value) - elif segment_type == SegmentType.ARRAY_NUMBER: - return ArrayNumberSegment(value=value) - elif segment_type == SegmentType.ARRAY_OBJECT: - return ArrayObjectSegment(value=value) - elif segment_type == SegmentType.ARRAY_FILE: - return ArrayFileSegment(value=value) - else: - raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got empty list") - - inferred_type = SegmentType.infer_segment_type(value) - # Type compatibility checking - if inferred_type is None: - raise TypeMismatchError( - f"Type mismatch: expected {segment_type}, but got python object, type={type(value)}, value={value}" - ) - if inferred_type == segment_type: - segment_class = _segment_factory[segment_type] - return segment_class(value_type=segment_type, value=value) - elif segment_type == SegmentType.NUMBER and inferred_type in ( - SegmentType.INTEGER, - SegmentType.FLOAT, - ): - segment_class = _segment_factory[inferred_type] - return segment_class(value_type=inferred_type, value=value) - else: - raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}") - - -def segment_to_variable( - *, - segment: Segment, - selector: Sequence[str], - id: str | None = None, - name: str | None = None, - description: str = "", -) -> VariableBase: - if isinstance(segment, VariableBase): - return segment - name = name or selector[-1] - id = id or str(uuid4()) - - segment_type = type(segment) - if segment_type not in SEGMENT_TO_VARIABLE_MAP: - raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}") - - variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type] - return variable_class( - id=id, - name=name, - description=description, - value_type=segment.value_type, - value=segment.value, - selector=list(selector), - ) diff --git a/api/fields/_value_type_serializer.py b/api/fields/_value_type_serializer.py index ac7c5376fb7..b5acbbbcb4d 100644 --- a/api/fields/_value_type_serializer.py +++ b/api/fields/_value_type_serializer.py @@ -1,7 +1,7 @@ from typing import TypedDict -from dify_graph.variables.segments import Segment -from dify_graph.variables.types import SegmentType +from graphon.variables.segments import Segment +from graphon.variables.types import SegmentType class _VarTypedDict(TypedDict, total=False): diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index a5c7ddbb110..30d02aeedc2 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -3,10 +3,9 @@ from __future__ import annotations from datetime import datetime from typing import Any, TypeAlias +from graphon.file import File from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator -from dify_graph.file import File - JSONValue: TypeAlias = Any @@ -311,7 +310,9 @@ def to_timestamp(value: datetime | None) -> int | None: def format_files_contained(value: JSONValue) -> JSONValue: if isinstance(value, File): - return value.model_dump() + # Response payloads must preserve legacy file keys like `related_id`/`url` + # while still exposing the new graph-layer `reference` field. + return value.to_dict() if isinstance(value, dict): return {k: format_files_contained(v) for k, v in value.items()} if isinstance(value, list): diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index 7ee628726b4..b8daa5af303 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -3,10 +3,9 @@ from __future__ import annotations from datetime import datetime from flask_restx import fields +from graphon.file import helpers as file_helpers from pydantic import BaseModel, ConfigDict, computed_field, field_validator -from dify_graph.file import helpers as file_helpers - simple_account_fields = { "id": fields.String, "name": fields.String, diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index 428f92ed337..d982c31aeeb 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -4,10 +4,10 @@ from datetime import datetime from typing import TypeAlias from uuid import uuid4 +from graphon.file import File from pydantic import BaseModel, ConfigDict, Field, field_validator from core.entities.execution_extra_content import ExecutionExtraContentDomainModel -from dify_graph.file import File from fields.conversation_fields import AgentThought, JSONValue, MessageFile JSONValueType: TypeAlias = JSONValue @@ -133,7 +133,9 @@ def to_timestamp(value: datetime | None) -> int | None: def format_files_contained(value: JSONValueType) -> JSONValueType: if isinstance(value, File): - return value.model_dump() + # Response payloads must preserve legacy file keys like `related_id`/`url` + # while still exposing the new graph-layer `reference` field. + return value.to_dict() if isinstance(value, dict): return {k: format_files_contained(v) for k, v in value.items()} if isinstance(value, list): diff --git a/api/fields/raws.py b/api/fields/raws.py index 318dedc25cf..4c65cdab7af 100644 --- a/api/fields/raws.py +++ b/api/fields/raws.py @@ -1,6 +1,5 @@ from flask_restx import fields - -from dify_graph.file import File +from graphon.file import File class FilesContainedField(fields.Raw): diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 7ce21396879..b0b6cc0b483 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -1,7 +1,7 @@ from flask_restx import fields +from graphon.variables import SecretVariable, SegmentType, VariableBase from core.helper import encrypter -from dify_graph.variables import SecretVariable, SegmentType, VariableBase from fields.member_fields import simple_account_fields from libs.helper import TimestampField diff --git a/api/libs/broadcast_channel/channel.py b/api/libs/broadcast_channel/channel.py index d4cb3e9971f..8eeac372325 100644 --- a/api/libs/broadcast_channel/channel.py +++ b/api/libs/broadcast_channel/channel.py @@ -125,7 +125,8 @@ class BroadcastChannel(Protocol): a specific topic, all subscription should receive the published message. There are no restriction for the persistence of messages. Once a subscription is created, it - should receive all subsequent messages published. + should receive all subsequent messages published. However, a subscription should not receive + any message published before the subscription is established. `BroadcastChannel` implementations must be thread-safe and support concurrent use by multiple threads. """ diff --git a/api/libs/broadcast_channel/redis/streams_channel.py b/api/libs/broadcast_channel/redis/streams_channel.py index d6ec5504ca8..983f785027a 100644 --- a/api/libs/broadcast_channel/redis/streams_channel.py +++ b/api/libs/broadcast_channel/redis/streams_channel.py @@ -63,21 +63,45 @@ class _StreamsSubscription(Subscription): def __init__(self, client: Redis | RedisCluster, key: str): self._client = client self._key = key - self._closed = threading.Event() - self._last_id = "0-0" + self._queue: queue.Queue[object] = queue.Queue() - self._start_lock = threading.Lock() + + # The `_lock` lock is used to + # + # 1. protect the _listener attribute + # 2. prevent repeated releases of underlying resoueces. (The _closed flag.) + # + # INVARIANT: the implementation must hold the lock while + # reading and writing the _listener / `_closed` attribute. + self._lock = threading.Lock() + self._closed: bool = False + # self._closed = threading.Event() self._listener: threading.Thread | None = None def _listen(self) -> None: - try: - while not self._closed.is_set(): - streams = self._client.xread({self._key: self._last_id}, block=1000, count=100) + """The `_listen` method handles the message retrieval loop. It requires a dedicated thread + and is not intended for direct invocation. + The thread is started by `_start_if_needed`. + """ + + # since this method runs in a dedicated thread, acquiring `_lock` inside this method won't cause + # deadlock. + + # Setting initial last id to `$` to signal redis that we only want new messages. + # + # ref: https://redis.io/docs/latest/commands/xread/#the-special--id + last_id = "$" + try: + while True: + with self._lock: + if self._closed: + break + streams = self._client.xread({self._key: last_id}, block=1000, count=100) if not streams: continue - for _key, entries in streams: + for _, entries in streams: for entry_id, fields in entries: data = None if isinstance(fields, dict): @@ -89,37 +113,48 @@ class _StreamsSubscription(Subscription): data_bytes = bytes(data) if data_bytes is not None: self._queue.put_nowait(data_bytes) - self._last_id = entry_id + last_id = entry_id finally: self._queue.put_nowait(self._SENTINEL) - self._listener = None + with self._lock: + self._listener = None + self._closed = True def _start_if_needed(self) -> None: + """This method must be called with `_lock` held.""" if self._listener is not None: return # Ensure only one listener thread is created under concurrent calls - with self._start_lock: - if self._listener is not None or self._closed.is_set(): - return - self._listener = threading.Thread( - target=self._listen, - name=f"redis-streams-sub-{self._key}", - daemon=True, - ) - self._listener.start() + if self._listener is not None or self._closed: + return + self._listener = threading.Thread( + target=self._listen, + name=f"redis-streams-sub-{self._key}", + daemon=True, + ) + self._listener.start() def __iter__(self) -> Iterator[bytes]: # Iterator delegates to receive with timeout; stops on closure. - self._start_if_needed() - while not self._closed.is_set(): - item = self.receive(timeout=1) + with self._lock: + self._start_if_needed() + + while True: + with self._lock: + if self._closed: + return + try: + item = self.receive(timeout=1) + except SubscriptionClosedError: + return if item is not None: yield item def receive(self, timeout: float | None = 0.1) -> bytes | None: - if self._closed.is_set(): - raise SubscriptionClosedError("The Redis streams subscription is closed") - self._start_if_needed() + with self._lock: + if self._closed: + raise SubscriptionClosedError("The Redis streams subscription is closed") + self._start_if_needed() try: if timeout is None: @@ -129,29 +164,33 @@ class _StreamsSubscription(Subscription): except queue.Empty: return None - if item is self._SENTINEL or self._closed.is_set(): + if item is self._SENTINEL: raise SubscriptionClosedError("The Redis streams subscription is closed") assert isinstance(item, (bytes, bytearray)), "Unexpected item type in stream queue" return bytes(item) def close(self) -> None: - if self._closed.is_set(): - return - self._closed.set() - listener = self._listener - if listener is not None: + with self._lock: + if self._closed: + return + self._closed = True + listener = self._listener + if listener is not None: + self._listener = None + # We close the listener outside of the with block to avoid holding the + # lock for a long time. + if listener is not None and listener.is_alive(): listener.join(timeout=2.0) if listener.is_alive(): logger.warning( "Streams subscription listener for key %s did not stop within timeout; keeping reference.", self._key, ) - else: - self._listener = None # Context manager helpers def __enter__(self) -> Self: - self._start_if_needed() + with self._lock: + self._start_if_needed() return self def __exit__(self, exc_type, exc_value, traceback) -> bool | None: diff --git a/api/libs/datetime_utils.py b/api/libs/datetime_utils.py index c08578981b4..e0a6ec2cacd 100644 --- a/api/libs/datetime_utils.py +++ b/api/libs/datetime_utils.py @@ -2,7 +2,7 @@ import abc import datetime from typing import Protocol -import pytz +import pytz # type: ignore[import-untyped] class _NowFunction(Protocol): diff --git a/api/libs/helper.py b/api/libs/helper.py index e7572cc025a..a7b3da77ff8 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -16,13 +16,13 @@ from zoneinfo import available_timezones from flask import Response, stream_with_context from flask_restx import fields +from graphon.file import helpers as file_helpers +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel from pydantic.functional_validators import AfterValidator from configs import dify_config from core.app.features.rate_limiting.rate_limit import RateLimitGenerator -from dify_graph.file import helpers as file_helpers -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_redis import redis_client if TYPE_CHECKING: @@ -174,6 +174,18 @@ def normalize_uuid(value: str | UUID) -> str: raise ValueError("must be a valid UUID") from exc +def parse_uuid_str_or_none(value: str | None) -> str | None: + """ + Return None for missing/empty UUID-like values. + + Keep non-empty values unchanged to avoid changing behavior in paths that + currently pass placeholder IDs in tests/mocks. + """ + if value is None or not str(value).strip(): + return None + return str(value) + + UUIDStrOrEmpty = Annotated[str, AfterValidator(normalize_uuid)] diff --git a/api/libs/login.py b/api/libs/login.py index bd5cb5f30d1..dce332b01d1 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -18,15 +18,23 @@ if TYPE_CHECKING: from models.model import EndUser +def _resolve_current_user() -> EndUser | Account | None: + """ + Resolve the current user proxy to its underlying user object. + This keeps unit tests working when they patch `current_user` directly + instead of bootstrapping a full Flask-Login manager. + """ + user_proxy = current_user + get_current_object = getattr(user_proxy, "_get_current_object", None) + return get_current_object() if callable(get_current_object) else user_proxy # type: ignore + + def current_account_with_tenant(): """ Resolve the underlying account for the current user proxy and ensure tenant context exists. Allows tests to supply plain Account mocks without the LocalProxy helper. """ - user_proxy = current_user - - get_current_object = getattr(user_proxy, "_get_current_object", None) - user = get_current_object() if callable(get_current_object) else user_proxy # type: ignore + user = _resolve_current_user() if not isinstance(user, Account): raise ValueError("current_user must be an Account instance") @@ -79,9 +87,10 @@ def login_required(func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue] if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED: return current_app.ensure_sync(func)(*args, **kwargs) - user = _get_user() + user = _resolve_current_user() if user is None or not user.is_authenticated: return current_app.login_manager.unauthorized() # type: ignore + g._login_user = user # we put csrf validation here for less conflicts # TODO: maybe find a better place for it. check_csrf_token(request, user.id) diff --git a/api/libs/oauth.py b/api/libs/oauth.py index efce13f6f1a..76e741301cd 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -1,16 +1,19 @@ +import logging import sys import urllib.parse from dataclasses import dataclass from typing import NotRequired import httpx -from pydantic import TypeAdapter +from pydantic import TypeAdapter, ValidationError if sys.version_info >= (3, 12): from typing import TypedDict else: from typing_extensions import TypedDict +logger = logging.getLogger(__name__) + JsonObject = dict[str, object] JsonObjectList = list[JsonObject] @@ -25,13 +28,14 @@ class AccessTokenResponse(TypedDict, total=False): class GitHubEmailRecord(TypedDict, total=False): email: str primary: bool + verified: bool class GitHubRawUserInfo(TypedDict): id: int | str login: str - name: NotRequired[str] - email: NotRequired[str] + name: NotRequired[str | None] + email: NotRequired[str | None] class GoogleRawUserInfo(TypedDict): @@ -127,18 +131,52 @@ class GitHubOAuth(OAuth): response.raise_for_status() user_info = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(_json_object(response)) - email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers) - email_info = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response)) - primary_email = next((email for email in email_info if email.get("primary") is True), None) + # Only call the /user/emails endpoint when the profile email is absent, + # i.e. the user has "Keep my email addresses private" enabled. + resolved_email = user_info.get("email") or "" + if not resolved_email: + resolved_email = self._get_email_from_emails_endpoint(headers) - return {**user_info, "email": primary_email.get("email", "") if primary_email else ""} + return {**user_info, "email": resolved_email} + + @staticmethod + def _get_email_from_emails_endpoint(headers: dict[str, str]) -> str: + """Fetch the best available email from GitHub's /user/emails endpoint. + + Prefers the primary email, then falls back to any verified email. + Returns an empty string when no usable email is found. + """ + try: + email_response = httpx.get(GitHubOAuth._EMAIL_INFO_URL, headers=headers) + email_response.raise_for_status() + email_records = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response)) + except (httpx.HTTPStatusError, ValidationError): + logger.warning("Failed to retrieve email from GitHub /user/emails endpoint", exc_info=True) + return "" + + primary = next((r for r in email_records if r.get("primary") is True), None) + if primary: + return primary.get("email", "") + + # No primary email; try any verified email as a fallback. + verified = next((r for r in email_records if r.get("verified") is True), None) + if verified: + return verified.get("email", "") + + return "" def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo: payload = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(raw_info) - email = payload.get("email") + email = payload.get("email") or "" if not email: - email = f"{payload['id']}+{payload['login']}@users.noreply.github.com" - return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name", "")), email=email) + # When no email is available from the profile or /user/emails endpoint, + # fall back to GitHub's noreply address so sign-in can still proceed. + # Use only the numeric ID (not the login) so the address stays stable + # even if the user renames their GitHub account. + github_id = payload["id"] + email = f"{github_id}@users.noreply.github.com" + logger.info("GitHub user %s has no public email; using noreply address", payload["login"]) + return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name") or ""), email=email) class GoogleOAuth(OAuth): diff --git a/api/libs/schedule_utils.py b/api/libs/schedule_utils.py index 1ab5f499e99..b80a5ea7227 100644 --- a/api/libs/schedule_utils.py +++ b/api/libs/schedule_utils.py @@ -1,6 +1,6 @@ from datetime import UTC, datetime -import pytz +import pytz # type: ignore[import-untyped] from croniter import croniter diff --git a/api/models/dataset.py b/api/models/dataset.py index d0163e69848..e323ccfd7f6 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -20,7 +20,7 @@ from sqlalchemy.orm import Mapped, Session, mapped_column from configs import dify_config from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.constant.query_type import QueryType from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.signature import sign_upload_file @@ -43,7 +43,9 @@ from .enums import ( IndexingStatus, ProcessRuleMode, SegmentStatus, + SegmentType, SummaryStatus, + TidbAuthBindingStatus, ) from .model import App, Tag, TagBinding, UploadFile from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adjusted_json_index @@ -135,7 +137,7 @@ class Dataset(Base): default=DatasetPermissionEnum.ONLY_ME, ) data_source_type = mapped_column(EnumText(DataSourceType, length=255)) - indexing_technique: Mapped[str | None] = mapped_column(String(255)) + indexing_technique: Mapped[IndexTechniqueType | None] = mapped_column(EnumText(IndexTechniqueType, length=255)) index_struct = mapped_column(LongText, nullable=True) created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @@ -494,7 +496,9 @@ class Document(Base): ) doc_type = mapped_column(EnumText(DocumentDocType, length=40), nullable=True) doc_metadata = mapped_column(AdjustedJSON, nullable=True) - doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'")) + doc_form: Mapped[IndexStructureType] = mapped_column( + EnumText(IndexStructureType, length=255), nullable=False, server_default=sa.text("'text_model'") + ) doc_language = mapped_column(String(255), nullable=True) need_summary: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) @@ -998,7 +1002,9 @@ class ChildChunk(Base): # indexing fields index_node_id = mapped_column(String(255), nullable=True) index_node_hash = mapped_column(String(255), nullable=True) - type = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'")) + type: Mapped[SegmentType] = mapped_column( + EnumText(SegmentType, length=255), nullable=False, server_default=sa.text("'automatic'") + ) created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) @@ -1239,7 +1245,9 @@ class TidbAuthBinding(TypeBase): cluster_id: Mapped[str] = mapped_column(String(255), nullable=False) cluster_name: Mapped[str] = mapped_column(String(255), nullable=False) active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) - status: Mapped[str] = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'CREATING'")) + status: Mapped[TidbAuthBindingStatus] = mapped_column( + EnumText(TidbAuthBindingStatus, length=255), nullable=False, server_default=sa.text("'CREATING'") + ) account: Mapped[str] = mapped_column(String(255), nullable=False) password: Mapped[str] = mapped_column(String(255), nullable=False) created_at: Mapped[datetime] = mapped_column( diff --git a/api/models/enums.py b/api/models/enums.py index 4849099d303..bf2e927f002 100644 --- a/api/models/enums.py +++ b/api/models/enums.py @@ -158,6 +158,15 @@ class FeedbackFromSource(StrEnum): ADMIN = "admin" +class CustomizeTokenStrategy(StrEnum): + """Site token customization strategy""" + + MUST = "must" + ALLOW = "allow" + NOT_ALLOW = "not_allow" + UUID = "uuid" + + class FeedbackRating(StrEnum): """MessageFeedback rating""" @@ -222,6 +231,13 @@ class DatasetMetadataType(StrEnum): TIME = "time" +class SegmentType(StrEnum): + """Document segment type""" + + AUTOMATIC = "automatic" + CUSTOMIZED = "customized" + + class SegmentStatus(StrEnum): """Document segment status""" @@ -307,6 +323,13 @@ class MessageChainType(StrEnum): SYSTEM = "system" +class PromptType(StrEnum): + """Prompt configuration type""" + + SIMPLE = "simple" + ADVANCED = "advanced" + + class ProviderQuotaType(StrEnum): PAID = "paid" """hosted paid quota""" @@ -323,3 +346,10 @@ class ProviderQuotaType(StrEnum): if member.value == value: return member raise ValueError(f"No matching enum found for value '{value}'") + + +class ApiTokenType(StrEnum): + """API Token type""" + + APP = "app" + DATASET = "dataset" diff --git a/api/models/execution_extra_content.py b/api/models/execution_extra_content.py index d0bd34efeca..b2d09a77327 100644 --- a/api/models/execution_extra_content.py +++ b/api/models/execution_extra_content.py @@ -66,8 +66,8 @@ class HumanInputContent(ExecutionExtraContent): form_id: Mapped[str] = mapped_column(StringUUID, nullable=True) @classmethod - def new(cls, form_id: str, message_id: str | None) -> "HumanInputContent": - return cls(form_id=form_id, message_id=message_id) + def new(cls, *, workflow_run_id: str, form_id: str, message_id: str | None) -> "HumanInputContent": + return cls(workflow_run_id=workflow_run_id, form_id=form_id, message_id=message_id) form: Mapped["HumanInputForm"] = relationship( "HumanInputForm", diff --git a/api/models/human_input.py b/api/models/human_input.py index 48e7fbb9eaf..79c5d62f6a8 100644 --- a/api/models/human_input.py +++ b/api/models/human_input.py @@ -3,14 +3,11 @@ from enum import StrEnum from typing import Annotated, Literal, Self, final import sqlalchemy as sa +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from pydantic import BaseModel, Field from sqlalchemy.orm import Mapped, mapped_column, relationship -from dify_graph.nodes.human_input.enums import ( - DeliveryMethodType, - HumanInputFormKind, - HumanInputFormStatus, -) +from core.workflow.human_input_compat import DeliveryMethodType from libs.helper import generate_string from .base import Base, DefaultFieldsMixin diff --git a/api/models/model.py b/api/models/model.py index 3bd68d1d956..066d2acdce0 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -3,16 +3,20 @@ from __future__ import annotations import json import re import uuid -from collections.abc import Mapping, Sequence +from collections.abc import Callable, Mapping, Sequence from datetime import datetime from decimal import Decimal from enum import StrEnum, auto +from functools import lru_cache from typing import TYPE_CHECKING, Any, Literal, NotRequired, cast from uuid import uuid4 import sqlalchemy as sa from flask import request from flask_login import UserMixin # type: ignore[import-untyped] +from graphon.enums import WorkflowExecutionStatus +from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType +from graphon.file import helpers as file_helpers from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text from sqlalchemy.orm import Mapped, Session, mapped_column from typing_extensions import TypedDict @@ -20,27 +24,32 @@ from typing_extensions import TypedDict from configs import dify_config from constants import DEFAULT_FILE_NUMBER_LIMITS from core.tools.signature import sign_tool_file -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod -from dify_graph.file import helpers as file_helpers from extensions.storage.storage_type import StorageType from libs.helper import generate_string # type: ignore[import-not-found] from libs.uuid_utils import uuidv7 +from models.utils.file_input_compat import build_file_from_input_mapping from .account import Account, Tenant from .base import Base, TypeBase, gen_uuidv4_string from .engine import db from .enums import ( + ApiTokenType, AppMCPServerStatus, AppStatus, BannerStatus, + ConversationFromSource, ConversationStatus, CreatorUserRole, + CustomizeTokenStrategy, FeedbackFromSource, FeedbackRating, + InvokeFrom, MessageChainType, MessageFileBelongsTo, MessageStatus, + PromptType, + ProviderQuotaType, + TagType, ) from .provider_ids import GenericProviderID from .types import EnumText, LongText, StringUUID @@ -52,6 +61,32 @@ if TYPE_CHECKING: # --- TypedDict definitions for structured dict return types --- +@lru_cache(maxsize=1) +def _get_file_access_controller(): + from core.app.file_access import DatabaseFileAccessController + + return DatabaseFileAccessController() + + +def _resolve_app_tenant_id(app_id: str) -> str: + resolved_tenant_id = db.session.scalar(select(App.tenant_id).where(App.id == app_id)) + if not resolved_tenant_id: + raise ValueError(f"Unable to resolve tenant_id for app {app_id}") + return resolved_tenant_id + + +def _build_app_tenant_resolver(app_id: str, owner_tenant_id: str | None = None) -> Callable[[], str]: + resolved_tenant_id = owner_tenant_id + + def resolve_owner_tenant_id() -> str: + nonlocal resolved_tenant_id + if resolved_tenant_id is None: + resolved_tenant_id = _resolve_app_tenant_id(app_id) + return resolved_tenant_id + + return resolve_owner_tenant_id + + class EnabledConfig(TypedDict): enabled: bool @@ -584,7 +619,9 @@ class AppModelConfig(TypeBase): __tablename__ = "app_model_configs" __table_args__ = (sa.PrimaryKeyConstraint("id", name="app_model_config_pkey"), sa.Index("app_app_id_idx", "app_id")) - id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False) + id: Mapped[str] = mapped_column( + StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False + ) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) provider: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) model_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) @@ -614,8 +651,11 @@ class AppModelConfig(TypeBase): agent_mode: Mapped[str | None] = mapped_column(LongText, default=None) sensitive_word_avoidance: Mapped[str | None] = mapped_column(LongText, default=None) retriever_resource: Mapped[str | None] = mapped_column(LongText, default=None) - prompt_type: Mapped[str] = mapped_column( - String(255), nullable=False, server_default=sa.text("'simple'"), default="simple" + prompt_type: Mapped[PromptType] = mapped_column( + EnumText(PromptType, length=255), + nullable=False, + server_default=sa.text("'simple'"), + default=PromptType.SIMPLE, ) chat_prompt_config: Mapped[str | None] = mapped_column(LongText, default=None) completion_prompt_config: Mapped[str | None] = mapped_column(LongText, default=None) @@ -767,7 +807,7 @@ class AppModelConfig(TypeBase): "dataset_query_variable": self.dataset_query_variable, "pre_prompt": self.pre_prompt, "agent_mode": self.agent_mode_dict, - "prompt_type": self.prompt_type, + "prompt_type": self.prompt_type.value if isinstance(self.prompt_type, PromptType) else self.prompt_type, "chat_prompt_config": self.chat_prompt_config_dict, "completion_prompt_config": self.completion_prompt_config_dict, "dataset_configs": self.dataset_configs_dict, @@ -811,7 +851,7 @@ class AppModelConfig(TypeBase): self.retriever_resource = ( json.dumps(model_config.get("retriever_resource")) if model_config.get("retriever_resource") else None ) - self.prompt_type = model_config.get("prompt_type", "simple") + self.prompt_type = PromptType(model_config.get("prompt_type", "simple")) self.chat_prompt_config = ( json.dumps(model_config.get("chat_prompt_config")) if model_config.get("chat_prompt_config") else None ) @@ -933,7 +973,9 @@ class AccountTrialAppRecord(Base): class ExporleBanner(TypeBase): __tablename__ = "exporle_banners" __table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),) - id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv4_string, init=False) + id: Mapped[str] = mapped_column( + StringUUID, insert_default=gen_uuidv4_string, default_factory=gen_uuidv4_string, init=False + ) content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False) link: Mapped[str] = mapped_column(String(255), nullable=False) sort: Mapped[int] = mapped_column(sa.Integer, nullable=False) @@ -1022,10 +1064,12 @@ class Conversation(Base): # # Its value corresponds to the members of `InvokeFrom`. # (api/core/app/entities/app_invoke_entities.py) - invoke_from = mapped_column(String(255), nullable=True) + invoke_from: Mapped[InvokeFrom | None] = mapped_column(EnumText(InvokeFrom, length=255), nullable=True) # ref: ConversationSource. - from_source: Mapped[str] = mapped_column(String(255), nullable=False) + from_source: Mapped[ConversationFromSource] = mapped_column( + EnumText(ConversationFromSource, length=255), nullable=False + ) from_end_user_id = mapped_column(StringUUID) from_account_id = mapped_column(StringUUID) read_at = mapped_column(sa.DateTime) @@ -1046,23 +1090,26 @@ class Conversation(Base): @property def inputs(self) -> dict[str, Any]: inputs = self._inputs.copy() + # Compatibility bridge: stored input payloads may come from before or after the + # graph-layer file refactor. Newer rows may omit `tenant_id`, so keep tenant + # resolution at the SQLAlchemy model boundary instead of pushing ownership back + # into `graphon.file.File`. + tenant_resolver = _build_app_tenant_resolver( + app_id=self.app_id, + owner_tenant_id=cast(str | None, getattr(self, "_owner_tenant_id", None)), + ) # Convert file mapping to File object for key, value in inputs.items(): - # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. - from factories import file_factory - if ( isinstance(value, dict) and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY ): value_dict = cast(dict[str, Any], value) - if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: - value_dict["tool_file_id"] = value_dict["related_id"] - elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: - value_dict["upload_file_id"] = value_dict["related_id"] - tenant_id = cast(str, value_dict.get("tenant_id", "")) - inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id) + inputs[key] = build_file_from_input_mapping( + file_mapping=value_dict, + tenant_resolver=tenant_resolver, + ) elif isinstance(value, list): value_list = cast(list[Any], value) if all( @@ -1075,15 +1122,12 @@ class Conversation(Base): if not isinstance(item, dict): continue item_dict = cast(dict[str, Any], item) - if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: - item_dict["tool_file_id"] = item_dict["related_id"] - elif item_dict["transfer_method"] in [ - FileTransferMethod.LOCAL_FILE, - FileTransferMethod.REMOTE_URL, - ]: - item_dict["upload_file_id"] = item_dict["related_id"] - tenant_id = cast(str, item_dict.get("tenant_id", "")) - file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id)) + file_list.append( + build_file_from_input_mapping( + file_mapping=item_dict, + tenant_resolver=tenant_resolver, + ) + ) inputs[key] = file_list return inputs @@ -1374,8 +1418,10 @@ class Message(Base): ) error: Mapped[str | None] = mapped_column(LongText) message_metadata: Mapped[str | None] = mapped_column(LongText) - invoke_from: Mapped[str | None] = mapped_column(String(255), nullable=True) - from_source: Mapped[str] = mapped_column(String(255), nullable=False) + invoke_from: Mapped[InvokeFrom | None] = mapped_column(EnumText(InvokeFrom, length=255), nullable=True) + from_source: Mapped[ConversationFromSource] = mapped_column( + EnumText(ConversationFromSource, length=255), nullable=False + ) from_end_user_id: Mapped[str | None] = mapped_column(StringUUID) from_account_id: Mapped[str | None] = mapped_column(StringUUID) created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=func.current_timestamp()) @@ -1389,21 +1435,23 @@ class Message(Base): @property def inputs(self) -> dict[str, Any]: inputs = self._inputs.copy() + # Compatibility bridge: message inputs are persisted as JSON and must remain + # readable across file payload shape changes. Do not assume `tenant_id` + # is serialized into each file mapping going forward. + tenant_resolver = _build_app_tenant_resolver( + app_id=self.app_id, + owner_tenant_id=cast(str | None, getattr(self, "_owner_tenant_id", None)), + ) for key, value in inputs.items(): - # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. - from factories import file_factory - if ( isinstance(value, dict) and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY ): value_dict = cast(dict[str, Any], value) - if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: - value_dict["tool_file_id"] = value_dict["related_id"] - elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: - value_dict["upload_file_id"] = value_dict["related_id"] - tenant_id = cast(str, value_dict.get("tenant_id", "")) - inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id) + inputs[key] = build_file_from_input_mapping( + file_mapping=value_dict, + tenant_resolver=tenant_resolver, + ) elif isinstance(value, list): value_list = cast(list[Any], value) if all( @@ -1416,15 +1464,12 @@ class Message(Base): if not isinstance(item, dict): continue item_dict = cast(dict[str, Any], item) - if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: - item_dict["tool_file_id"] = item_dict["related_id"] - elif item_dict["transfer_method"] in [ - FileTransferMethod.LOCAL_FILE, - FileTransferMethod.REMOTE_URL, - ]: - item_dict["upload_file_id"] = item_dict["related_id"] - tenant_id = cast(str, item_dict.get("tenant_id", "")) - file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id)) + file_list.append( + build_file_from_input_mapping( + file_mapping=item_dict, + tenant_resolver=tenant_resolver, + ) + ) inputs[key] = file_list return inputs @@ -1599,6 +1644,7 @@ class Message(Base): "upload_file_id": message_file.upload_file_id, }, tenant_id=current_app.tenant_id, + access_controller=_get_file_access_controller(), ) elif message_file.transfer_method == FileTransferMethod.REMOTE_URL: if message_file.url is None: @@ -1612,6 +1658,7 @@ class Message(Base): "url": message_file.url, }, tenant_id=current_app.tenant_id, + access_controller=_get_file_access_controller(), ) elif message_file.transfer_method == FileTransferMethod.TOOL_FILE: if message_file.upload_file_id is None: @@ -1626,6 +1673,7 @@ class Message(Base): file = file_factory.build_from_mapping( mapping=mapping, tenant_id=current_app.tenant_id, + access_controller=_get_file_access_controller(), ) else: raise ValueError( @@ -1776,7 +1824,7 @@ class MessageFile(TypeBase): StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False ) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - type: Mapped[str] = mapped_column(String(255), nullable=False) + type: Mapped[FileType] = mapped_column(EnumText(FileType, length=255), nullable=False) transfer_method: Mapped[FileTransferMethod] = mapped_column( EnumText(FileTransferMethod, length=255), nullable=False ) @@ -1838,7 +1886,9 @@ class AppAnnotationHitHistory(TypeBase): sa.Index("app_annotation_hit_histories_message_idx", "message_id"), ) - id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False) + id: Mapped[str] = mapped_column( + StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False + ) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) annotation_id: Mapped[str] = mapped_column(StringUUID, nullable=False) source: Mapped[str] = mapped_column(LongText, nullable=False) @@ -2039,7 +2089,9 @@ class Site(Base): use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) _custom_disclaimer: Mapped[str] = mapped_column("custom_disclaimer", LongText, default="") customize_domain = mapped_column(String(255)) - customize_token_strategy: Mapped[str] = mapped_column(String(255), nullable=False) + customize_token_strategy: Mapped[CustomizeTokenStrategy] = mapped_column( + EnumText(CustomizeTokenStrategy, length=255), nullable=False + ) prompt_public: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) status: Mapped[AppStatus] = mapped_column( EnumText(AppStatus, length=255), nullable=False, server_default=sa.text("'normal'"), default=AppStatus.NORMAL @@ -2088,7 +2140,7 @@ class ApiToken(Base): # bug: this uses setattr so idk the field. id = mapped_column(StringUUID, default=lambda: str(uuid4())) app_id = mapped_column(StringUUID, nullable=True) tenant_id = mapped_column(StringUUID, nullable=True) - type = mapped_column(String(16), nullable=False) + type: Mapped[ApiTokenType] = mapped_column(EnumText(ApiTokenType, length=16), nullable=False) token: Mapped[str] = mapped_column(String(255), nullable=False) last_used_at = mapped_column(sa.DateTime, nullable=True) created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -2398,7 +2450,7 @@ class Tag(TypeBase): StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False ) tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) - type: Mapped[str] = mapped_column(String(16), nullable=False) + type: Mapped[TagType] = mapped_column(EnumText(TagType, length=16), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column( @@ -2483,7 +2535,9 @@ class TenantCreditPool(TypeBase): StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False ) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - pool_type: Mapped[str] = mapped_column(String(40), nullable=False, default="trial", server_default="trial") + pool_type: Mapped[ProviderQuotaType] = mapped_column( + EnumText(ProviderQuotaType, length=40), nullable=False, default=ProviderQuotaType.TRIAL, server_default="trial" + ) quota_limit: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) quota_used: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) created_at: Mapped[datetime] = mapped_column( diff --git a/api/models/tools.py b/api/models/tools.py index c09f054e7da..d8731fb8a8a 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -11,14 +11,19 @@ from deprecated import deprecated from sqlalchemy import ForeignKey, String, func, select from sqlalchemy.orm import Mapped, mapped_column +from core.plugin.entities.plugin_daemon import CredentialType from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle -from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration +from core.tools.entities.tool_entities import ( + ApiProviderSchemaType, + ToolProviderType, + WorkflowToolParameterConfiguration, +) from .base import TypeBase from .engine import db from .model import Account, App, Tenant -from .types import LongText, StringUUID +from .types import EnumText, LongText, StringUUID if TYPE_CHECKING: from core.entities.mcp_provider import MCPProviderEntity @@ -105,8 +110,11 @@ class BuiltinToolProvider(TypeBase): ) is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"), default=False) # credential type, e.g., "api-key", "oauth2" - credential_type: Mapped[str] = mapped_column( - String(32), nullable=False, server_default=sa.text("'api-key'"), default="api-key" + credential_type: Mapped[CredentialType] = mapped_column( + EnumText(CredentialType, length=32), + nullable=False, + server_default=sa.text("'api-key'"), + default=CredentialType.API_KEY, ) expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1"), default=-1) @@ -141,7 +149,9 @@ class ApiToolProvider(TypeBase): icon: Mapped[str] = mapped_column(String(255), nullable=False) # original schema schema: Mapped[str] = mapped_column(LongText, nullable=False) - schema_type_str: Mapped[str] = mapped_column(String(40), nullable=False) + schema_type_str: Mapped[ApiProviderSchemaType] = mapped_column( + EnumText(ApiProviderSchemaType, length=40), nullable=False + ) # who created this tool user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # tenant id @@ -208,7 +218,7 @@ class ToolLabelBinding(TypeBase): # tool id tool_id: Mapped[str] = mapped_column(String(64), nullable=False) # tool type - tool_type: Mapped[str] = mapped_column(String(40), nullable=False) + tool_type: Mapped[ToolProviderType] = mapped_column(EnumText(ToolProviderType, length=40), nullable=False) # label name label_name: Mapped[str] = mapped_column(String(40), nullable=False) @@ -386,7 +396,7 @@ class ToolModelInvoke(TypeBase): # provider provider: Mapped[str] = mapped_column(String(255), nullable=False) # type - tool_type: Mapped[str] = mapped_column(String(40), nullable=False) + tool_type: Mapped[ToolProviderType] = mapped_column(EnumText(ToolProviderType, length=40), nullable=False) # tool name tool_name: Mapped[str] = mapped_column(String(128), nullable=False) # invoke parameters diff --git a/api/models/trigger.py b/api/models/trigger.py index 627b854060c..5233a6e2711 100644 --- a/api/models/trigger.py +++ b/api/models/trigger.py @@ -102,7 +102,9 @@ class TriggerSubscription(TypeBase): credentials: Mapped[TriggerCredentials] = mapped_column( sa.JSON, nullable=False, comment="Subscription credentials JSON" ) - credential_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="oauth or api_key") + credential_type: Mapped[CredentialType] = mapped_column( + EnumText(CredentialType, length=50), nullable=False, comment="oauth or api_key" + ) credential_expires_at: Mapped[int] = mapped_column( Integer, default=-1, comment="OAuth token expiration timestamp, -1 for never" ) @@ -144,7 +146,7 @@ class TriggerSubscription(TypeBase): endpoint=generate_plugin_trigger_endpoint_url(self.endpoint_id), parameters=self.parameters, properties=self.properties, - credential_type=CredentialType(self.credential_type), + credential_type=self.credential_type, credentials=self.credentials, workflows_in_use=-1, ) diff --git a/api/models/utils/__init__.py b/api/models/utils/__init__.py new file mode 100644 index 00000000000..b390b8106b0 --- /dev/null +++ b/api/models/utils/__init__.py @@ -0,0 +1,3 @@ +from .file_input_compat import build_file_from_input_mapping + +__all__ = ["build_file_from_input_mapping"] diff --git a/api/models/utils/file_input_compat.py b/api/models/utils/file_input_compat.py new file mode 100644 index 00000000000..f71583c1cde --- /dev/null +++ b/api/models/utils/file_input_compat.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +from collections.abc import Callable, Mapping +from functools import lru_cache +from typing import Any + +from graphon.file import File, FileTransferMethod + +from core.workflow.file_reference import parse_file_reference + + +@lru_cache(maxsize=1) +def _get_file_access_controller(): + from core.app.file_access import DatabaseFileAccessController + + return DatabaseFileAccessController() + + +def resolve_file_record_id(file_mapping: Mapping[str, Any]) -> str | None: + reference = file_mapping.get("reference") + if isinstance(reference, str) and reference: + parsed_reference = parse_file_reference(reference) + if parsed_reference is not None: + return parsed_reference.record_id + + related_id = file_mapping.get("related_id") + if isinstance(related_id, str) and related_id: + parsed_reference = parse_file_reference(related_id) + if parsed_reference is not None: + return parsed_reference.record_id + + return None + + +def resolve_file_mapping_tenant_id( + *, + file_mapping: Mapping[str, Any], + tenant_resolver: Callable[[], str], +) -> str: + tenant_id = file_mapping.get("tenant_id") + if isinstance(tenant_id, str) and tenant_id: + return tenant_id + + return tenant_resolver() + + +def build_file_from_stored_mapping( + *, + file_mapping: Mapping[str, Any], + tenant_id: str, +) -> File: + """ + Canonicalize a persisted file payload against the current tenant context. + + Stored JSON rows can outlive file schema changes, so rebuild storage-backed + files through the workflow factory instead of trusting serialized metadata. + Pure external ``REMOTE_URL`` payloads without a backing upload row are + passed through because there is no server-owned record to rebind. + """ + + # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. + from factories import file_factory + + mapping = dict(file_mapping) + mapping.pop("tenant_id", None) + record_id = resolve_file_record_id(mapping) + transfer_method = FileTransferMethod.value_of(mapping["transfer_method"]) + + if transfer_method == FileTransferMethod.TOOL_FILE and record_id: + mapping["tool_file_id"] = record_id + elif transfer_method in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL] and record_id: + mapping["upload_file_id"] = record_id + elif transfer_method == FileTransferMethod.DATASOURCE_FILE and record_id: + mapping["datasource_file_id"] = record_id + + if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None: + remote_url = mapping.get("remote_url") + if not isinstance(remote_url, str) or not remote_url: + url = mapping.get("url") + if isinstance(url, str) and url: + mapping["remote_url"] = url + return File.model_validate(mapping) + + return file_factory.build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + access_controller=_get_file_access_controller(), + ) + + +def build_file_from_input_mapping( + *, + file_mapping: Mapping[str, Any], + tenant_resolver: Callable[[], str], +) -> File: + """ + Rehydrate persisted model input payloads into graph `File` objects. + + This compatibility layer exists because model JSON rows can outlive file payload + schema changes. Legacy rows may carry `related_id` and `tenant_id`, while newer + rows may only carry `reference`. Keep ownership resolution here, at the model + boundary, instead of pushing tenant data back into `graphon.file.File`. + """ + + transfer_method = FileTransferMethod.value_of(file_mapping["transfer_method"]) + record_id = resolve_file_record_id(file_mapping) + if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None: + return build_file_from_stored_mapping( + file_mapping=file_mapping, + tenant_id="", + ) + + tenant_id = resolve_file_mapping_tenant_id(file_mapping=file_mapping, tenant_resolver=tenant_resolver) + return build_file_from_stored_mapping( + file_mapping=file_mapping, + tenant_id=tenant_id, + ) diff --git a/api/models/workflow.py b/api/models/workflow.py index e7b20d0e659..f8868cb73cc 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1,3 +1,4 @@ +import copy import json import logging from collections.abc import Generator, Mapping, Sequence @@ -7,6 +8,19 @@ from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union, cast from uuid import uuid4 import sqlalchemy as sa +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause +from graphon.enums import ( + BuiltinNodeTypes, + NodeType, + WorkflowExecutionStatus, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.file import File +from graphon.file.constants import maybe_file_object +from graphon.variables import utils as variable_utils +from graphon.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable from sqlalchemy import ( DateTime, Index, @@ -23,17 +37,11 @@ from sqlalchemy.orm import Mapped, mapped_column from typing_extensions import deprecated from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE -from dify_graph.constants import ( +from core.workflow.human_input_compat import normalize_node_config_for_graph +from core.workflow.variable_prefixes import ( CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID, ) -from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from dify_graph.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause -from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey -from dify_graph.file.constants import maybe_file_object -from dify_graph.file.models import File -from dify_graph.variables import utils as variable_utils -from dify_graph.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable from extensions.ext_storage import Storage from factories.variable_factory import TypeMismatchError, build_segment_with_type from libs.datetime_utils import naive_utc_now @@ -45,9 +53,10 @@ if TYPE_CHECKING: from .model import AppMode, UploadFile +from graphon.variables import SecretVariable, Segment, SegmentType, VariableBase + from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE from core.helper import encrypter -from dify_graph.variables import SecretVariable, Segment, SegmentType, VariableBase from factories import variable_factory from libs import helper @@ -56,6 +65,7 @@ from .base import Base, DefaultFieldsMixin, TypeBase from .engine import db from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom from .types import EnumText, LongText, StringUUID +from .utils.file_input_compat import build_file_from_stored_mapping logger = logging.getLogger(__name__) @@ -63,6 +73,15 @@ SerializedWorkflowValue = dict[str, Any] SerializedWorkflowVariables = dict[str, SerializedWorkflowValue] +def _resolve_workflow_app_tenant_id(app_id: str) -> str: + from .model import App + + tenant_id = db.session.scalar(select(App.tenant_id).where(App.id == app_id)) + if not tenant_id: + raise ValueError(f"Unable to resolve tenant_id for app {app_id}") + return tenant_id + + class WorkflowContentDict(TypedDict): graph: Mapping[str, Any] features: dict[str, Any] @@ -272,7 +291,7 @@ class Workflow(Base): # bug node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes)) except StopIteration: raise NodeNotFoundError(node_id) - return NodeConfigDictAdapter.validate_python(node_config) + return NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config)) @staticmethod def get_node_type_from_node_config(node_config: NodeConfigDict) -> NodeType: @@ -302,26 +321,40 @@ class Workflow(Base): # bug def features(self) -> str: """ Convert old features structure to new features structure. + + This property avoids rewriting the underlying JSON when normalization + produces no effective change, to prevent marking the row dirty on read. """ if not self._features: return self._features - features = json.loads(self._features) - if features.get("file_upload", {}).get("image", {}).get("enabled", False): - image_enabled = True - image_number_limits = int(features["file_upload"]["image"].get("number_limits", DEFAULT_FILE_NUMBER_LIMITS)) - image_transfer_methods = features["file_upload"]["image"].get( - "transfer_methods", ["remote_url", "local_file"] - ) - features["file_upload"]["enabled"] = image_enabled - features["file_upload"]["number_limits"] = image_number_limits - features["file_upload"]["allowed_file_upload_methods"] = image_transfer_methods - features["file_upload"]["allowed_file_types"] = features["file_upload"].get("allowed_file_types", ["image"]) - features["file_upload"]["allowed_file_extensions"] = features["file_upload"].get( - "allowed_file_extensions", [] - ) - del features["file_upload"]["image"] - self._features = json.dumps(features) + # Parse once and deep-copy before normalization to detect in-place changes. + original_dict = self._decode_features_payload(self._features) + if original_dict is None: + return self._features + + # Fast-path: if the legacy file_upload.image.enabled shape is absent, skip + # deep-copy and normalization entirely and return the stored JSON. + file_upload_payload = original_dict.get("file_upload") + if not isinstance(file_upload_payload, dict): + return self._features + file_upload = cast(dict[str, Any], file_upload_payload) + + image_payload = file_upload.get("image") + if not isinstance(image_payload, dict): + return self._features + image = cast(dict[str, Any], image_payload) + if "enabled" not in image: + return self._features + + normalized_dict = self._normalize_features_payload(copy.deepcopy(original_dict)) + + if normalized_dict == original_dict: + # No effective change; return stored JSON unchanged. + return self._features + + # Normalization changed the payload: persist the normalized JSON. + self._features = json.dumps(normalized_dict) return self._features @features.setter @@ -332,6 +365,44 @@ class Workflow(Base): # bug def features_dict(self) -> dict[str, Any]: return json.loads(self.features) if self.features else {} + @property + def serialized_features(self) -> str: + """Return the stored features JSON without triggering compatibility rewrites.""" + return self._features + + @property + def normalized_features_dict(self) -> dict[str, Any]: + """Decode features with legacy normalization without mutating the model state.""" + if not self._features: + return {} + + features = self._decode_features_payload(self._features) + return self._normalize_features_payload(features) if features is not None else {} + + @staticmethod + def _decode_features_payload(features: str) -> dict[str, Any] | None: + """Decode workflow features JSON when it contains an object payload.""" + payload = json.loads(features) + return cast(dict[str, Any], payload) if isinstance(payload, dict) else None + + @staticmethod + def _normalize_features_payload(features: dict[str, Any]) -> dict[str, Any]: + if features.get("file_upload", {}).get("image", {}).get("enabled", False): + image_number_limits = int(features["file_upload"]["image"].get("number_limits", DEFAULT_FILE_NUMBER_LIMITS)) + image_transfer_methods = features["file_upload"]["image"].get( + "transfer_methods", ["remote_url", "local_file"] + ) + features["file_upload"]["enabled"] = True + features["file_upload"]["number_limits"] = image_number_limits + features["file_upload"]["allowed_file_upload_methods"] = image_transfer_methods + features["file_upload"]["allowed_file_types"] = features["file_upload"].get("allowed_file_types", ["image"]) + features["file_upload"]["allowed_file_extensions"] = features["file_upload"].get( + "allowed_file_extensions", [] + ) + del features["file_upload"]["image"] + + return features + def walk_nodes( self, specific_node_type: NodeType | None = None ) -> Generator[tuple[str, Mapping[str, Any]], None, None]: @@ -365,7 +436,7 @@ class Workflow(Base): # bug "selected": false, } - For specific node type, refer to `dify_graph.nodes` + For specific node type, refer to `graphon.nodes` """ graph_dict = self.graph_dict if "nodes" not in graph_dict: @@ -517,6 +588,31 @@ class Workflow(Base): # bug ) self._environment_variables = environment_variables_json + @staticmethod + def normalize_environment_variable_mappings( + mappings: Sequence[Mapping[str, Any]], + ) -> list[dict[str, Any]]: + """Convert masked secret placeholders into the draft hidden sentinel. + + Regular draft sync requests should preserve existing secrets without shipping + plaintext values back from the client. The dedicated restore endpoint now + copies published secrets server-side, so draft sync only needs to normalize + the UI mask into `HIDDEN_VALUE`. + """ + masked_secret_value = encrypter.full_mask_token() + normalized_mappings: list[dict[str, Any]] = [] + + for mapping in mappings: + normalized_mapping = dict(mapping) + if ( + normalized_mapping.get("value_type") == SegmentType.SECRET.value + and normalized_mapping.get("value") == masked_secret_value + ): + normalized_mapping["value"] = HIDDEN_VALUE + normalized_mappings.append(normalized_mapping) + + return normalized_mappings + def to_dict(self, *, include_secret: bool = False) -> WorkflowContentDict: environment_variables = list(self.environment_variables) environment_variables = [ @@ -564,6 +660,12 @@ class Workflow(Base): # bug ensure_ascii=False, ) + def copy_serialized_variable_storage_from(self, source_workflow: "Workflow") -> None: + """Copy stored variable JSON directly for same-tenant restore flows.""" + self._environment_variables = source_workflow._environment_variables + self._conversation_variables = source_workflow._conversation_variables + self._rag_pipeline_variables = source_workflow._rag_pipeline_variables + @staticmethod def version_from_datetime(d: datetime) -> str: return str(d) @@ -846,7 +948,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo inputs: Mapped[str | None] = mapped_column(LongText) process_data: Mapped[str | None] = mapped_column(LongText) outputs: Mapped[str | None] = mapped_column(LongText) - status: Mapped[str] = mapped_column(String(255)) + status: Mapped[WorkflowNodeExecutionStatus] = mapped_column(EnumText(WorkflowNodeExecutionStatus, length=255)) error: Mapped[str | None] = mapped_column(LongText) elapsed_time: Mapped[float] = mapped_column(sa.Float, server_default=sa.text("0")) execution_metadata: Mapped[str | None] = mapped_column(LongText) @@ -1137,7 +1239,9 @@ class WorkflowAppLog(TypeBase): app_id: Mapped[str] = mapped_column(StringUUID) workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False) workflow_run_id: Mapped[str] = mapped_column(StringUUID) - created_from: Mapped[str] = mapped_column(String(255), nullable=False) + created_from: Mapped[WorkflowAppLogCreatedFrom] = mapped_column( + EnumText(WorkflowAppLogCreatedFrom, length=255), nullable=False + ) created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column( @@ -1217,10 +1321,14 @@ class WorkflowArchiveLog(TypeBase): log_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) log_created_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) - log_created_from: Mapped[str | None] = mapped_column(String(255), nullable=True) + log_created_from: Mapped[WorkflowAppLogCreatedFrom | None] = mapped_column( + EnumText(WorkflowAppLogCreatedFrom, length=255), nullable=True + ) run_version: Mapped[str] = mapped_column(String(255), nullable=False) - run_status: Mapped[str] = mapped_column(String(255), nullable=False) + run_status: Mapped[WorkflowExecutionStatus] = mapped_column( + EnumText(WorkflowExecutionStatus, length=255), nullable=False + ) run_triggered_from: Mapped[WorkflowRunTriggeredFrom] = mapped_column( EnumText(WorkflowRunTriggeredFrom, length=255), nullable=False ) @@ -1359,8 +1467,6 @@ class WorkflowDraftVariable(Base): # From `VARIABLE_PATTERN`, we may conclude that the length of a top level variable is less than # 80 chars. - # - # ref: api/dify_graph/entities/variable_pool.py:18 name: Mapped[str] = mapped_column(sa.String(255), nullable=False) description: Mapped[str] = mapped_column( sa.String(255), @@ -1475,10 +1581,9 @@ class WorkflowDraftVariable(Base): def _loads_value(self) -> Segment: value = json.loads(self.value) - return self.build_segment_with_type(self.value_type, value) + return self.build_segment_from_serialized_value(self.value_type, value) - @staticmethod - def rebuild_file_types(value: Any): + def _rebuild_file_types(self, value: Any): # NOTE(QuantumGhost): Temporary workaround for structured data handling. # By this point, `output` has been converted to dict by # `WorkflowEntry.handle_special_values`, so we need to @@ -1492,13 +1597,72 @@ class WorkflowDraftVariable(Base): if isinstance(value, dict): if not maybe_file_object(value): return cast(Any, value) - return File.model_validate(value) + tenant_id = _resolve_workflow_app_tenant_id(self.app_id) + return build_file_from_stored_mapping( + file_mapping=cast(dict[str, Any], value), + tenant_id=tenant_id, + ) elif isinstance(value, list) and value: value_list = cast(list[Any], value) first: Any = value_list[0] if not maybe_file_object(first): return cast(Any, value) - file_list: list[File] = [File.model_validate(cast(dict[str, Any], i)) for i in value_list] + tenant_id = _resolve_workflow_app_tenant_id(self.app_id) + file_list: list[File] = [] + for item in value_list: + file_list.append( + build_file_from_stored_mapping( + file_mapping=cast(dict[str, Any], item), + tenant_id=tenant_id, + ) + ) + return cast(Any, file_list) + else: + return cast(Any, value) + + def build_segment_from_serialized_value(self, segment_type: SegmentType, value: Any) -> Segment: + # Persisted draft variable rows may contain historical file payloads. + # Rebuild them through the file factory so tenant ownership, signed URLs, + # and storage-backed metadata come from canonical records instead of the + # serialized JSON blob. + if segment_type == SegmentType.FILE: + if isinstance(value, File): + return build_segment_with_type(segment_type, value) + elif isinstance(value, dict): + file = self._rebuild_file_types(value) + return build_segment_with_type(segment_type, file) + else: + raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}") + if segment_type == SegmentType.ARRAY_FILE: + if not isinstance(value, list): + raise TypeMismatchError(f"expected list for ArrayFileSegment, got {type(value)}") + file_list = self._rebuild_file_types(value) + return build_segment_with_type(segment_type=segment_type, value=file_list) + + return build_segment_with_type(segment_type=segment_type, value=value) + + @staticmethod + def rebuild_file_types(value: Any): + # Keep the class-level fallback for callers that only need lightweight + # structural reconstruction. Persisted draft-variable payloads should go + # through `build_segment_from_serialized_value()` so file metadata is + # rebuilt from canonical storage records. + if isinstance(value, dict): + if not maybe_file_object(value): + return cast(Any, value) + normalized_file = dict(value) + normalized_file.pop("tenant_id", None) + return File.model_validate(normalized_file) + elif isinstance(value, list) and value: + value_list = cast(list[Any], value) + first: Any = value_list[0] + if not maybe_file_object(first): + return cast(Any, value) + file_list: list[File] = [] + for item in value_list: + normalized_file = dict(cast(dict[str, Any], item)) + normalized_file.pop("tenant_id", None) + file_list.append(File.model_validate(normalized_file)) return cast(Any, file_list) else: return cast(Any, value) diff --git a/api/pyproject.toml b/api/pyproject.toml index f824fe7c236..a09b474bf5d 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dify-api" -version = "1.13.2" +version = "1.13.3" requires-python = ">=3.11,<3.13" dependencies = [ @@ -8,7 +8,7 @@ dependencies = [ "arize-phoenix-otel~=0.15.0", "azure-identity==1.25.3", "beautifulsoup4==4.14.3", - "boto3==1.42.68", + "boto3==1.42.78", "bs4~=0.0.1", "cachetools~=5.3.0", "celery~=5.6.2", @@ -23,47 +23,46 @@ dependencies = [ "gevent~=25.9.1", "gmpy2~=2.3.0", "google-api-core>=2.19.1", - "google-api-python-client==2.192.0", + "google-api-python-client==2.193.0", "google-auth>=2.47.0", "google-auth-httplib2==0.3.0", "google-cloud-aiplatform>=1.123.0", "googleapis-common-protos>=1.65.0", - "gunicorn~=25.1.0", + "graphon>=0.1.2", + "gunicorn~=25.3.0", "httpx[socks]~=0.28.0", "jieba==0.42.1", "json-repair>=0.55.1", - "jsonschema>=4.25.1", - "langfuse~=2.51.3", + "langfuse>=3.0.0,<5.0.0", "langsmith~=0.7.16", "markdown~=3.10.2", "mlflow-skinny>=3.0.0", "numpy~=1.26.4", "openpyxl~=3.1.5", "opik~=1.10.37", - "litellm==1.82.2", # Pinned to avoid madoka dependency issue - "opentelemetry-api==1.28.0", - "opentelemetry-distro==0.49b0", - "opentelemetry-exporter-otlp==1.28.0", - "opentelemetry-exporter-otlp-proto-common==1.28.0", - "opentelemetry-exporter-otlp-proto-grpc==1.28.0", - "opentelemetry-exporter-otlp-proto-http==1.28.0", - "opentelemetry-instrumentation==0.49b0", - "opentelemetry-instrumentation-celery==0.49b0", - "opentelemetry-instrumentation-flask==0.49b0", - "opentelemetry-instrumentation-httpx==0.49b0", - "opentelemetry-instrumentation-redis==0.49b0", - "opentelemetry-instrumentation-sqlalchemy==0.49b0", + "litellm==1.82.6", # Pinned to avoid madoka dependency issue + "opentelemetry-api==1.40.0", + "opentelemetry-distro==0.61b0", + "opentelemetry-exporter-otlp==1.40.0", + "opentelemetry-exporter-otlp-proto-common==1.40.0", + "opentelemetry-exporter-otlp-proto-grpc==1.40.0", + "opentelemetry-exporter-otlp-proto-http==1.40.0", + "opentelemetry-instrumentation==0.61b0", + "opentelemetry-instrumentation-celery==0.61b0", + "opentelemetry-instrumentation-flask==0.61b0", + "opentelemetry-instrumentation-httpx==0.61b0", + "opentelemetry-instrumentation-redis==0.61b0", + "opentelemetry-instrumentation-sqlalchemy==0.61b0", "opentelemetry-propagator-b3==1.40.0", - "opentelemetry-proto==1.28.0", - "opentelemetry-sdk==1.28.0", - "opentelemetry-semantic-conventions==0.49b0", - "opentelemetry-util-http==0.49b0", + "opentelemetry-proto==1.40.0", + "opentelemetry-sdk==1.40.0", + "opentelemetry-semantic-conventions==0.61b0", + "opentelemetry-util-http==0.61b0", "pandas[excel,output-formatting,performance]~=3.0.1", "psycogreen~=1.0.2", "psycopg2-binary~=2.9.6", "pycryptodome==3.23.0", "pydantic~=2.12.5", - "pydantic-extra-types~=2.11.0", "pydantic-settings~=2.13.1", "pyjwt~=2.12.0", "pypdfium2==5.6.0", @@ -71,16 +70,16 @@ dependencies = [ "python-dotenv==1.2.2", "pyyaml~=6.0.1", "readabilipy~=0.3.0", - "redis[hiredis]~=7.3.0", - "resend~=2.23.0", - "sentry-sdk[flask]~=2.54.0", + "redis[hiredis]~=7.4.0", + "resend~=2.26.0", + "sentry-sdk[flask]~=2.55.0", "sqlalchemy~=2.0.29", - "starlette==0.52.1", + "starlette==1.0.0", "tiktoken~=0.12.0", "transformers~=5.3.0", "unstructured[docx,epub,md,ppt,pptx]~=0.21.5", + "pypandoc~=1.13", "yarl~=1.23.0", - "webvtt-py~=0.5.1", "sseclient-py~=1.9.0", "httpx-sse~=0.4.0", "sendgrid~=6.12.3", @@ -91,7 +90,7 @@ dependencies = [ "apscheduler>=3.11.0", "weave>=0.52.16", "fastopenapi[flask]>=0.7.0", - "bleach~=6.2.0", + "bleach~=6.3.0", ] # Before adding new dependency, consider place it in # alphabet order (a-z) and suitable group. @@ -118,7 +117,7 @@ dev = [ "ruff~=0.15.5", "pytest~=9.0.2", "pytest-benchmark~=5.2.3", - "pytest-cov~=7.0.0", + "pytest-cov~=7.1.0", "pytest-env~=1.6.0", "pytest-mock~=3.15.1", "testcontainers~=4.14.1", @@ -129,7 +128,6 @@ dev = [ "types-defusedxml~=0.7.0", "types-deprecated~=1.3.1", "types-docutils~=0.22.3", - "types-jsonschema~=4.26.0", "types-flask-cors~=6.0.0", "types-flask-migrate~=4.1.0", "types-gevent~=25.9.0", @@ -149,7 +147,7 @@ dev = [ "types-python-dateutil~=2.9.0", "types-pywin32~=311.0.0", "types-pyyaml~=6.0.12", - "types-regex~=2026.2.28", + "types-regex~=2026.3.32", "types-shapely~=2.1.0", "types-simplejson>=3.20.0", "types-six>=1.17.0", @@ -173,7 +171,7 @@ dev = [ "sseclient-py>=1.8.0", "pytest-timeout>=2.4.0", "pytest-xdist>=3.8.0", - "pyrefly>=0.55.0", + "pyrefly>=0.57.1", ] ############################################################ @@ -202,10 +200,10 @@ tools = ["cloudscraper~=1.2.71", "nltk~=3.9.1"] # Required by vector store clients ############################################################ vdb = [ - "alibabacloud_gpdb20160503~=3.8.0", + "alibabacloud_gpdb20160503~=5.1.0", "alibabacloud_tea_openapi~=0.4.3", "chromadb==0.5.20", - "clickhouse-connect~=0.14.1", + "clickhouse-connect~=0.15.0", "clickzetta-connector-python>=0.8.102", "couchbase~=4.5.0", "elasticsearch==8.14.0", @@ -218,38 +216,18 @@ vdb = [ "pyobvector~=0.2.17", "qdrant-client==1.9.0", "intersystems-irispython>=5.1.0", - "tablestore==6.4.1", - "tcvectordb~=2.0.0", + "tablestore==6.4.2", + "tcvectordb~=2.1.0", "tidb-vector==0.0.15", "upstash-vector==0.8.0", "volcengine-compat~=1.0.0", "weaviate-client==4.20.4", - "xinference-client~=2.3.1", + "xinference-client~=2.4.0", "mo-vector~=0.1.13", "mysql-connector-python>=9.3.0", "holo-search-sdk>=0.4.1", ] -[tool.mypy] - -[[tool.mypy.overrides]] -# targeted ignores for current type-check errors -# TODO(QuantumGhost): suppress type errors in HITL related code. -# fix the type error later -module = [ - "configs.middleware.cache.redis_pubsub_config", - "extensions.ext_redis", - "tasks.workflow_execution_tasks", - "dify_graph.nodes.base.node", - "services.human_input_delivery_test_service", - "core.app.apps.advanced_chat.app_generator", - "controllers.console.human_input_form", - "controllers.console.app.workflow_run", - "repositories.sqlalchemy_api_workflow_node_execution_repository", - "extensions.logstore.repositories.logstore_api_workflow_run_repository", -] -ignore_errors = true - [tool.pyrefly] project-includes = ["."] project-excludes = [".venv", "migrations/"] diff --git a/api/pyrefly-local-excludes.txt b/api/pyrefly-local-excludes.txt index ad3c1e83895..43f604c2de9 100644 --- a/api/pyrefly-local-excludes.txt +++ b/api/pyrefly-local-excludes.txt @@ -109,34 +109,16 @@ core/trigger/debug/event_selectors.py core/trigger/entities/entities.py core/trigger/provider.py core/workflow/workflow_entry.py -dify_graph/entities/workflow_execution.py -dify_graph/file/file_manager.py -dify_graph/graph_engine/error_handler.py -dify_graph/graph_engine/layers/execution_limits.py -dify_graph/nodes/agent/agent_node.py -dify_graph/nodes/base/node.py -dify_graph/nodes/code/code_node.py -dify_graph/nodes/datasource/datasource_node.py -dify_graph/nodes/document_extractor/node.py -dify_graph/nodes/human_input/human_input_node.py -dify_graph/nodes/if_else/if_else_node.py -dify_graph/nodes/iteration/iteration_node.py -dify_graph/nodes/knowledge_index/knowledge_index_node.py +enterprise/telemetry/contracts.py +enterprise/telemetry/draft_trace.py +enterprise/telemetry/enterprise_trace.py +enterprise/telemetry/entities/__init__.py +enterprise/telemetry/event_handlers.py +enterprise/telemetry/exporter.py +enterprise/telemetry/id_generator.py +enterprise/telemetry/metric_handler.py +enterprise/telemetry/telemetry_log.py core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py -dify_graph/nodes/list_operator/node.py -dify_graph/nodes/llm/node.py -dify_graph/nodes/loop/loop_node.py -dify_graph/nodes/parameter_extractor/parameter_extractor_node.py -dify_graph/nodes/question_classifier/question_classifier_node.py -dify_graph/nodes/start/start_node.py -dify_graph/nodes/template_transform/template_transform_node.py -dify_graph/nodes/tool/tool_node.py -dify_graph/nodes/trigger_plugin/trigger_event_node.py -dify_graph/nodes/trigger_schedule/trigger_schedule_node.py -dify_graph/nodes/trigger_webhook/node.py -dify_graph/nodes/variable_aggregator/variable_aggregator_node.py -dify_graph/nodes/variable_assigner/v1/node.py -dify_graph/nodes/variable_assigner/v2/node.py extensions/logstore/repositories/logstore_api_workflow_run_repository.py extensions/otel/instrumentation.py extensions/otel/runtime.py diff --git a/api/repositories/api_workflow_node_execution_repository.py b/api/repositories/api_workflow_node_execution_repository.py index 2fa065bcc8d..3595ea33f07 100644 --- a/api/repositories/api_workflow_node_execution_repository.py +++ b/api/repositories/api_workflow_node_execution_repository.py @@ -16,7 +16,7 @@ from typing import Protocol from sqlalchemy.orm import Session -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.repositories.factory import WorkflowNodeExecutionRepository from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py index a96c4acb31c..1a2a539c802 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -38,11 +38,11 @@ from collections.abc import Callable, Sequence from datetime import datetime from typing import Protocol +from graphon.entities.pause_reason import PauseReason +from graphon.enums import WorkflowType from sqlalchemy.orm import Session -from dify_graph.entities.pause_reason import PauseReason -from dify_graph.enums import WorkflowType -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.repositories.factory import WorkflowExecutionRepository from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun diff --git a/api/repositories/entities/workflow_pause.py b/api/repositories/entities/workflow_pause.py index be28b7e6137..03ce574dca2 100644 --- a/api/repositories/entities/workflow_pause.py +++ b/api/repositories/entities/workflow_pause.py @@ -10,7 +10,7 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from datetime import datetime -from dify_graph.entities.pause_reason import PauseReason +from graphon.entities.pause_reason import PauseReason class WorkflowPauseEntity(ABC): diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py index 77e40fc6fc6..d5c6a203b14 100644 --- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -10,11 +10,11 @@ from collections.abc import Sequence from datetime import datetime from typing import Protocol, cast +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from sqlalchemy import asc, delete, desc, func, select from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, sessionmaker -from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload from repositories.api_workflow_node_execution_repository import ( DifyAPIWorkflowNodeExecutionRepository, diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index fdd3e123e49..413936b542b 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -28,14 +28,14 @@ from decimal import Decimal from typing import Any, cast import sqlalchemy as sa +from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause +from graphon.enums import WorkflowExecutionStatus, WorkflowType +from graphon.nodes.human_input.entities import FormDefinition from pydantic import ValidationError from sqlalchemy import and_, delete, func, null, or_, select, tuple_ from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, selectinload, sessionmaker -from dify_graph.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause -from dify_graph.enums import WorkflowExecutionStatus, WorkflowType -from dify_graph.nodes.human_input.entities import FormDefinition from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from libs.helper import convert_datetime_to_date @@ -43,7 +43,7 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.time_parser import get_time_threshold from libs.uuid_utils import uuidv7 from models.enums import WorkflowRunTriggeredFrom -from models.human_input import HumanInputForm, HumanInputFormRecipient, RecipientType +from models.human_input import HumanInputForm from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.entities.workflow_pause import WorkflowPauseEntity @@ -61,25 +61,13 @@ class _WorkflowRunError(Exception): pass -def _select_recipient_token( - recipients: Sequence[HumanInputFormRecipient], - recipient_type: RecipientType, -) -> str | None: - for recipient in recipients: - if recipient.recipient_type == recipient_type and recipient.access_token: - return recipient.access_token - return None - - def _build_human_input_required_reason( reason_model: WorkflowPauseReason, form_model: HumanInputForm | None, - recipients: Sequence[HumanInputFormRecipient], ) -> HumanInputRequired: form_content = "" inputs = [] actions = [] - display_in_ui = False resolved_default_values: dict[str, Any] = {} node_title = "Human Input" form_id = reason_model.form_id @@ -99,25 +87,16 @@ def _build_human_input_required_reason( form_content = definition.form_content inputs = list(definition.inputs) actions = list(definition.user_actions) - display_in_ui = bool(definition.display_in_ui) resolved_default_values = dict(definition.default_values) node_title = definition.node_title or node_title - form_token = ( - _select_recipient_token(recipients, RecipientType.BACKSTAGE) - or _select_recipient_token(recipients, RecipientType.CONSOLE) - or _select_recipient_token(recipients, RecipientType.STANDALONE_WEB_APP) - ) - return HumanInputRequired( form_id=form_id, form_content=form_content, inputs=inputs, actions=actions, - display_in_ui=display_in_ui, node_id=node_id, node_title=node_title, - form_token=form_token, resolved_default_values=resolved_default_values, ) @@ -823,22 +802,16 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED and reason.form_id ] form_models: dict[str, HumanInputForm] = {} - recipient_models_by_form: dict[str, list[HumanInputFormRecipient]] = {} if form_ids: form_stmt = select(HumanInputForm).where(HumanInputForm.id.in_(form_ids)) for form in session.scalars(form_stmt).all(): form_models[form.id] = form - recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids)) - for recipient in session.scalars(recipient_stmt).all(): - recipient_models_by_form.setdefault(recipient.form_id, []).append(recipient) - pause_reasons: list[PauseReason] = [] for reason in pause_reason_models: if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED: form_model = form_models.get(reason.form_id) - recipients = recipient_models_by_form.get(reason.form_id, []) - pause_reasons.append(_build_human_input_required_reason(reason, form_model, recipients)) + pause_reasons.append(_build_human_input_required_reason(reason, form_model)) else: pause_reasons.append(reason.to_entity()) return pause_reasons diff --git a/api/repositories/sqlalchemy_execution_extra_content_repository.py b/api/repositories/sqlalchemy_execution_extra_content_repository.py index 508db22eb03..feba5f7eb65 100644 --- a/api/repositories/sqlalchemy_execution_extra_content_repository.py +++ b/api/repositories/sqlalchemy_execution_extra_content_repository.py @@ -7,6 +7,9 @@ from collections import defaultdict from collections.abc import Sequence from typing import Any +from graphon.nodes.human_input.entities import FormDefinition +from graphon.nodes.human_input.enums import HumanInputFormStatus +from graphon.nodes.human_input.human_input_node import HumanInputNode from sqlalchemy import select from sqlalchemy.orm import Session, selectinload, sessionmaker @@ -18,9 +21,6 @@ from core.entities.execution_extra_content import ( from core.entities.execution_extra_content import ( HumanInputContent as HumanInputContentDomainModel, ) -from dify_graph.nodes.human_input.entities import FormDefinition -from dify_graph.nodes.human_input.enums import HumanInputFormStatus -from dify_graph.nodes.human_input.human_input_node import HumanInputNode from models.execution_extra_content import ( ExecutionExtraContent as ExecutionExtraContentModel, ) diff --git a/api/schedule/create_tidb_serverless_task.py b/api/schedule/create_tidb_serverless_task.py index 8b9d973d6dc..6ceb3ef856a 100644 --- a/api/schedule/create_tidb_serverless_task.py +++ b/api/schedule/create_tidb_serverless_task.py @@ -8,6 +8,7 @@ from configs import dify_config from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService from extensions.ext_database import db from models.dataset import TidbAuthBinding +from models.enums import TidbAuthBindingStatus @app.celery.task(queue="dataset") @@ -57,7 +58,7 @@ def create_clusters(batch_size): account=new_cluster["account"], password=new_cluster["password"], active=False, - status="CREATING", + status=TidbAuthBindingStatus.CREATING, ) db.session.add(tidb_auth_binding) db.session.commit() diff --git a/api/schedule/update_tidb_serverless_status_task.py b/api/schedule/update_tidb_serverless_status_task.py index 1befa0e8b5e..10003b1b975 100644 --- a/api/schedule/update_tidb_serverless_status_task.py +++ b/api/schedule/update_tidb_serverless_status_task.py @@ -9,6 +9,7 @@ from configs import dify_config from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService from extensions.ext_database import db from models.dataset import TidbAuthBinding +from models.enums import TidbAuthBindingStatus @app.celery.task(queue="dataset") @@ -18,7 +19,10 @@ def update_tidb_serverless_status_task(): try: # check the number of idle tidb serverless tidb_serverless_list = db.session.scalars( - select(TidbAuthBinding).where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING") + select(TidbAuthBinding).where( + TidbAuthBinding.active == False, + TidbAuthBinding.status == TidbAuthBindingStatus.CREATING, + ) ).all() if len(tidb_serverless_list) == 0: return diff --git a/api/services/account_service.py b/api/services/account_service.py index bd520f54cfc..cc8ef08857a 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -7,9 +7,19 @@ from datetime import UTC, datetime, timedelta from hashlib import sha256 from typing import Any, cast -from pydantic import BaseModel +from pydantic import BaseModel, TypeAdapter from sqlalchemy import func, select from sqlalchemy.orm import Session +from typing_extensions import TypedDict + + +class InvitationData(TypedDict): + account_id: str + email: str + workspace_id: str + + +_invitation_adapter: TypeAdapter[InvitationData] = TypeAdapter(InvitationData) from werkzeug.exceptions import Unauthorized from configs import dify_config @@ -1571,7 +1581,7 @@ class RegisterService: @classmethod def get_invitation_by_token( cls, token: str, workspace_id: str | None = None, email: str | None = None - ) -> dict[str, str] | None: + ) -> InvitationData | None: if workspace_id is not None and email is not None: email_hash = sha256(email.encode()).hexdigest() cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}" @@ -1590,7 +1600,7 @@ class RegisterService: if not data: return None - invitation: dict = json.loads(data) + invitation = _invitation_adapter.validate_json(data) return invitation @classmethod diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 68cb3438caa..dd73e103746 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -11,6 +11,12 @@ from uuid import uuid4 import yaml from Crypto.Cipher import AES from Crypto.Util.Padding import pad, unpad +from graphon.enums import BuiltinNodeTypes +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes.llm.entities import LLMNodeData +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData +from graphon.nodes.tool.entities import ToolNodeData from packaging import version from packaging.version import parse as parse_version from pydantic import BaseModel, Field @@ -27,12 +33,6 @@ from core.trigger.constants import ( ) from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from core.workflow.nodes.trigger_schedule.trigger_schedule_node import TriggerScheduleNode -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.nodes.llm.entities import LLMNodeData -from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData -from dify_graph.nodes.tool.entities import ToolNodeData from events.app_event import app_model_config_was_updated, app_was_created from extensions.ext_redis import redis_client from factories import variable_factory diff --git a/api/services/app_service.py b/api/services/app_service.py index c5d1479a202..e9aeb6c43d0 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -4,6 +4,8 @@ from typing import Any, TypedDict, cast import sqlalchemy as sa from flask_sqlalchemy.pagination import Pagination +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from configs import dify_config from constants.model_template import default_app_templates @@ -12,9 +14,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from events.app_event import app_was_created +from events.app_event import app_was_created, app_was_deleted, app_was_updated from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from libs.login import current_user @@ -92,7 +92,7 @@ class AppService: default_model_config = default_model_config.copy() if default_model_config else None if default_model_config and "model" in default_model_config: # get model provider - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=account.current_tenant_id or "") # get default model instance try: @@ -124,11 +124,19 @@ class AppService: "completion_params": {}, } else: - provider, model = model_manager.get_default_provider_model_name( - tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM - ) - default_model_config["model"]["provider"] = provider - default_model_config["model"]["name"] = model + try: + provider, model = model_manager.get_default_provider_model_name( + tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM + ) + except Exception: + logger.exception("Get default provider model failed, tenant_id: %s", tenant_id) + provider = default_model_config["model"].get("provider") + model = default_model_config["model"].get("name") + + if provider: + default_model_config["model"]["provider"] = provider + if model: + default_model_config["model"]["name"] = model default_model_dict = default_model_config["model"] default_model_config["model"] = json.dumps(default_model_dict) @@ -197,6 +205,7 @@ class AppService: tenant_id=current_user.current_tenant_id, app_id=app.id, agent_tool=agent_tool_entity, + user_id=current_user.id, ) manager = ToolParameterConfigurationManager( tenant_id=current_user.current_tenant_id, @@ -241,7 +250,7 @@ class AppService: class ArgsDict(TypedDict): name: str description: str - icon_type: str + icon_type: IconType | str | None icon: str icon_background: str use_icon_as_answer_icon: bool @@ -257,7 +266,13 @@ class AppService: assert current_user is not None app.name = args["name"] app.description = args["description"] - app.icon_type = IconType(args["icon_type"]) if args["icon_type"] else None + icon_type = args.get("icon_type") + if icon_type is None: + resolved_icon_type = app.icon_type + else: + resolved_icon_type = IconType(icon_type) + + app.icon_type = resolved_icon_type app.icon = args["icon"] app.icon_background = args["icon_background"] app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False) @@ -266,6 +281,8 @@ class AppService: app.updated_at = naive_utc_now() db.session.commit() + app_was_updated.send(app) + return app def update_app_name(self, app: App, name: str) -> App: @@ -281,6 +298,8 @@ class AppService: app.updated_at = naive_utc_now() db.session.commit() + app_was_updated.send(app) + return app def update_app_icon(self, app: App, icon: str, icon_background: str) -> App: @@ -298,6 +317,8 @@ class AppService: app.updated_at = naive_utc_now() db.session.commit() + app_was_updated.send(app) + return app def update_app_site_status(self, app: App, enable_site: bool) -> App: @@ -315,6 +336,8 @@ class AppService: app.updated_at = naive_utc_now() db.session.commit() + app_was_updated.send(app) + return app def update_app_api_status(self, app: App, enable_api: bool) -> App: @@ -333,6 +356,8 @@ class AppService: app.updated_at = naive_utc_now() db.session.commit() + app_was_updated.send(app) + return app def delete_app(self, app: App): @@ -340,6 +365,8 @@ class AppService: Delete app :param app: App instance """ + app_was_deleted.send(app) + db.session.delete(app) db.session.commit() diff --git a/api/services/app_task_service.py b/api/services/app_task_service.py index d5562300441..0842e9d3e7f 100644 --- a/api/services/app_task_service.py +++ b/api/services/app_task_service.py @@ -5,9 +5,10 @@ like stopping tasks, handling both legacy Redis flag mechanism and new GraphEngine command channel mechanism. """ +from graphon.graph_engine.manager import GraphEngineManager + from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.graph_engine.manager import GraphEngineManager from extensions.ext_redis import redis_client from models.model import AppMode diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 1794ea9947b..90e72d5f34f 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -5,11 +5,11 @@ from collections.abc import Generator from typing import cast from flask import Response, stream_with_context +from graphon.model_runtime.entities.model_entities import ModelType from werkzeug.datastructures import FileStorage from constants import AUDIO_EXTENSIONS from core.model_manager import ModelManager -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from models.enums import MessageStatus from models.model import App, AppMode, Message @@ -61,7 +61,7 @@ class AudioService: message = f"Audio size larger than {FILE_SIZE} mb" raise AudioTooLargeServiceError(message) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=app_model.tenant_id, user_id=end_user) model_instance = model_manager.get_default_model_instance( tenant_id=app_model.tenant_id, model_type=ModelType.SPEECH2TEXT ) @@ -71,7 +71,7 @@ class AudioService: buffer = io.BytesIO(file_content) buffer.name = "temp.mp3" - return {"text": model_instance.invoke_speech2text(file=buffer, user=end_user)} + return {"text": model_instance.invoke_speech2text(file=buffer)} @classmethod def transcript_tts( @@ -109,7 +109,7 @@ class AudioService: voice = cast(str | None, text_to_speech_dict.get("voice")) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=app_model.tenant_id, user_id=end_user) model_instance = model_manager.get_default_model_instance( tenant_id=app_model.tenant_id, model_type=ModelType.TTS ) @@ -123,9 +123,7 @@ class AudioService: else: raise ValueError("Sorry, no voice available.") - return model_instance.invoke_tts( - content_text=text_content.strip(), user=end_user, tenant_id=app_model.tenant_id, voice=voice - ) + return model_instance.invoke_tts(content_text=text_content.strip(), voice=voice) except Exception as e: raise e @@ -155,7 +153,7 @@ class AudioService: @classmethod def transcript_tts_voices(cls, tenant_id: str, language: str): - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.TTS) if model_instance is None: raise ProviderNotSupportTextToSpeechServiceError() diff --git a/api/services/auth/api_key_auth_base.py b/api/services/auth/api_key_auth_base.py index dd74a8f1b53..2e1b723e82e 100644 --- a/api/services/auth/api_key_auth_base.py +++ b/api/services/auth/api_key_auth_base.py @@ -1,8 +1,16 @@ from abc import ABC, abstractmethod +from typing import Any + +from typing_extensions import TypedDict + + +class AuthCredentials(TypedDict): + auth_type: str + config: dict[str, Any] class ApiKeyAuthBase(ABC): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): self.credentials = credentials @abstractmethod diff --git a/api/services/auth/api_key_auth_factory.py b/api/services/auth/api_key_auth_factory.py index 7ae31b07688..6e183b70e3a 100644 --- a/api/services/auth/api_key_auth_factory.py +++ b/api/services/auth/api_key_auth_factory.py @@ -1,9 +1,9 @@ -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials from services.auth.auth_type import AuthType class ApiKeyAuthFactory: - def __init__(self, provider: str, credentials: dict): + def __init__(self, provider: str, credentials: AuthCredentials): auth_factory = self.get_apikey_auth_factory(provider) self.auth = auth_factory(credentials) diff --git a/api/services/auth/api_key_auth_service.py b/api/services/auth/api_key_auth_service.py index 56aaf407eeb..3282dcfb113 100644 --- a/api/services/auth/api_key_auth_service.py +++ b/api/services/auth/api_key_auth_service.py @@ -35,15 +35,13 @@ class ApiKeyAuthService: @staticmethod def get_auth_credentials(tenant_id: str, category: str, provider: str): - data_source_api_key_bindings = ( - db.session.query(DataSourceApiKeyAuthBinding) - .where( + data_source_api_key_bindings = db.session.scalar( + select(DataSourceApiKeyAuthBinding).where( DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.category == category, DataSourceApiKeyAuthBinding.provider == provider, DataSourceApiKeyAuthBinding.disabled.is_(False), ) - .first() ) if not data_source_api_key_bindings: return None @@ -54,10 +52,11 @@ class ApiKeyAuthService: @staticmethod def delete_provider_auth(tenant_id: str, binding_id: str): - data_source_api_key_binding = ( - db.session.query(DataSourceApiKeyAuthBinding) - .where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id) - .first() + data_source_api_key_binding = db.session.scalar( + select(DataSourceApiKeyAuthBinding).where( + DataSourceApiKeyAuthBinding.tenant_id == tenant_id, + DataSourceApiKeyAuthBinding.id == binding_id, + ) ) if data_source_api_key_binding: db.session.delete(data_source_api_key_binding) diff --git a/api/services/auth/firecrawl/firecrawl.py b/api/services/auth/firecrawl/firecrawl.py index b0027069317..c9e5610aead 100644 --- a/api/services/auth/firecrawl/firecrawl.py +++ b/api/services/auth/firecrawl/firecrawl.py @@ -2,11 +2,11 @@ import json import httpx -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials class FirecrawlAuth(ApiKeyAuthBase): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): super().__init__(credentials) auth_type = credentials.get("auth_type") if auth_type != "bearer": diff --git a/api/services/auth/jina.py b/api/services/auth/jina.py index afaed28ac98..e5e2319ce13 100644 --- a/api/services/auth/jina.py +++ b/api/services/auth/jina.py @@ -2,11 +2,11 @@ import json import httpx -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials class JinaAuth(ApiKeyAuthBase): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): super().__init__(credentials) auth_type = credentials.get("auth_type") if auth_type != "bearer": diff --git a/api/services/auth/jina/jina.py b/api/services/auth/jina/jina.py index afaed28ac98..e5e2319ce13 100644 --- a/api/services/auth/jina/jina.py +++ b/api/services/auth/jina/jina.py @@ -2,11 +2,11 @@ import json import httpx -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials class JinaAuth(ApiKeyAuthBase): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): super().__init__(credentials) auth_type = credentials.get("auth_type") if auth_type != "bearer": diff --git a/api/services/auth/watercrawl/watercrawl.py b/api/services/auth/watercrawl/watercrawl.py index b2d28a83d14..cbdc908690f 100644 --- a/api/services/auth/watercrawl/watercrawl.py +++ b/api/services/auth/watercrawl/watercrawl.py @@ -3,11 +3,11 @@ from urllib.parse import urljoin import httpx -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials class WatercrawlAuth(ApiKeyAuthBase): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): super().__init__(credentials) auth_type = credentials.get("auth_type") if auth_type != "x-api-key": diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 5ab47c799ad..70d4ce1ee6b 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -335,7 +335,11 @@ class BillingService: # Redis returns bytes, decode to string and parse JSON json_str = cached_value.decode("utf-8") if isinstance(cached_value, bytes) else cached_value plan_dict = json.loads(json_str) + # NOTE (hj24): New billing versions may return timestamp as str, and validate_python + # in non-strict mode will coerce it to the expected int type. + # To preserve compatibility, always keep non-strict mode here and avoid strict mode. subscription_plan = subscription_adapter.validate_python(plan_dict) + # NOTE END tenant_plans[tenant_id] = subscription_plan except Exception: logger.exception( diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index 0e0eab00ad1..1c128524ad4 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -6,11 +6,11 @@ from concurrent.futures import ThreadPoolExecutor import click from flask import Flask, current_app +from graphon.model_runtime.utils.encoders import jsonable_encoder from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from configs import dify_config -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from enums.cloud_plan import CloudPlan from extensions.ext_database import db from extensions.ext_storage import storage diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 566c27c0f3c..ba1e7bb8266 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -3,6 +3,7 @@ import logging from collections.abc import Callable, Sequence from typing import Any, Union +from graphon.variables.types import SegmentType from sqlalchemy import asc, desc, func, or_, select from sqlalchemy.orm import Session @@ -10,7 +11,6 @@ from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.db.session_factory import session_factory from core.llm_generator.llm_generator import LLMGenerator -from dify_graph.variables.types import SegmentType from extensions.ext_database import db from factories import variable_factory from libs.datetime_utils import naive_utc_now diff --git a/api/services/conversation_variable_updater.py b/api/services/conversation_variable_updater.py index f00e3fe01e4..95a8951951c 100644 --- a/api/services/conversation_variable_updater.py +++ b/api/services/conversation_variable_updater.py @@ -1,7 +1,7 @@ +from graphon.variables.variables import VariableBase from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker -from dify_graph.variables.variables import VariableBase from models import ConversationVariable diff --git a/api/services/credit_pool_service.py b/api/services/credit_pool_service.py index 1954602571f..2894826935d 100644 --- a/api/services/credit_pool_service.py +++ b/api/services/credit_pool_service.py @@ -7,6 +7,7 @@ from configs import dify_config from core.errors.error import QuotaExceededError from extensions.ext_database import db from models import TenantCreditPool +from models.enums import ProviderQuotaType logger = logging.getLogger(__name__) @@ -16,7 +17,10 @@ class CreditPoolService: def create_default_pool(cls, tenant_id: str) -> TenantCreditPool: """create default credit pool for new tenant""" credit_pool = TenantCreditPool( - tenant_id=tenant_id, quota_limit=dify_config.HOSTED_POOL_CREDITS, quota_used=0, pool_type="trial" + tenant_id=tenant_id, + quota_limit=dify_config.HOSTED_POOL_CREDITS, + quota_used=0, + pool_type=ProviderQuotaType.TRIAL, ) db.session.add(credit_pool) db.session.commit() diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index cdab90a3dce..83363125c38 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -10,6 +10,9 @@ from collections.abc import Sequence from typing import Any, Literal, cast import sqlalchemy as sa +from graphon.file import helpers as file_helpers +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType +from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from redis.exceptions import LockNotOwnedError from sqlalchemy import exists, func, select from sqlalchemy.orm import Session @@ -21,11 +24,8 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.helper.name_generator import generate_incremental_name from core.model_manager import ModelManager from core.rag.index_processor.constant.built_in_field import BuiltInField -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod -from dify_graph.file import helpers as file_helpers -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType -from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from enums.cloud_plan import CloudPlan from events.dataset_event import dataset_was_deleted from events.document_event import document_was_deleted @@ -58,6 +58,7 @@ from models.enums import ( IndexingStatus, ProcessRuleMode, SegmentStatus, + SegmentType, ) from models.model import UploadFile from models.provider_ids import ModelProviderID @@ -227,8 +228,8 @@ class DatasetService: if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first(): raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.") embedding_model = None - if indexing_technique == "high_quality": - model_manager = ModelManager() + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) if embedding_model_provider and embedding_model_name: # check if embedding model setting is valid DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model_name) @@ -253,7 +254,10 @@ class DatasetService: retrieval_model.reranking_model.reranking_provider_name, retrieval_model.reranking_model.reranking_model_name, ) - dataset = Dataset(name=name, indexing_technique=indexing_technique) + dataset = Dataset( + name=name, + indexing_technique=IndexTechniqueType(indexing_technique) if indexing_technique else None, + ) # dataset = Dataset(name=name, provider=provider, config=config) dataset.description = description dataset.created_by = account.id @@ -348,9 +352,9 @@ class DatasetService: @staticmethod def check_dataset_model_setting(dataset): - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, @@ -367,7 +371,7 @@ class DatasetService: @staticmethod def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model: str): try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_manager.get_model_instance( tenant_id=tenant_id, provider=embedding_model_provider, @@ -384,7 +388,7 @@ class DatasetService: @staticmethod def check_is_multimodal_model(tenant_id: str, model_provider: str, model: str): try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, provider=model_provider, @@ -405,7 +409,7 @@ class DatasetService: @staticmethod def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str): try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_manager.get_model_instance( tenant_id=tenant_id, provider=reranking_model_provider, @@ -716,13 +720,13 @@ class DatasetService: if "indexing_technique" not in data: return None if dataset.indexing_technique != data["indexing_technique"]: - if data["indexing_technique"] == "economy": + if data["indexing_technique"] == IndexTechniqueType.ECONOMY: # Remove embedding model configuration for economy mode filtered_data["embedding_model"] = None filtered_data["embedding_model_provider"] = None filtered_data["collection_binding_id"] = None return "remove" - elif data["indexing_technique"] == "high_quality": + elif data["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY: # Configure embedding model for high quality mode DatasetService._configure_embedding_model_for_high_quality(data, filtered_data) return "add" @@ -742,7 +746,7 @@ class DatasetService: """ # assert isinstance(current_user, Account) and current_user.current_tenant_id is not None try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None embedding_model = model_manager.get_model_instance( @@ -860,7 +864,7 @@ class DatasetService: """ # assert isinstance(current_user, Account) and current_user.current_tenant_id is not None - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) try: assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None @@ -952,9 +956,9 @@ class DatasetService: dataset = session.merge(dataset) if not has_published: dataset.chunk_structure = knowledge_configuration.chunk_structure - dataset.indexing_technique = knowledge_configuration.indexing_technique - if knowledge_configuration.indexing_technique == "high_quality": - model_manager = ModelManager() + dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique) + if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, # ignore type error provider=knowledge_configuration.embedding_model_provider or "", @@ -975,7 +979,7 @@ class DatasetService: embedding_model_name, ) dataset.collection_binding_id = dataset_collection_binding.id - elif knowledge_configuration.indexing_technique == "economy": + elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY: dataset.keyword_number = knowledge_configuration.keyword_number else: raise ValueError("Invalid index method") @@ -990,13 +994,13 @@ class DatasetService: action = None if dataset.indexing_technique != knowledge_configuration.indexing_technique: # if update indexing_technique - if knowledge_configuration.indexing_technique == "economy": + if knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY: raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.") - elif knowledge_configuration.indexing_technique == "high_quality": + elif knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: action = "add" # get embedding model setting try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=knowledge_configuration.embedding_model_provider, @@ -1017,7 +1021,7 @@ class DatasetService: ) dataset.is_multimodal = is_multimodal dataset.collection_binding_id = dataset_collection_binding.id - dataset.indexing_technique = knowledge_configuration.indexing_technique + dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique) except LLMBadRequestError: raise ValueError( "No Embedding Model available. Please configure a valid provider " @@ -1028,7 +1032,7 @@ class DatasetService: else: # add default plugin id to both setting sets, to make sure the plugin model provider is consistent # Skip embedding model checks if not provided in the update request - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: skip_embedding_update = False try: # Handle existing model provider @@ -1049,7 +1053,7 @@ class DatasetService: or knowledge_configuration.embedding_model != dataset.embedding_model ): action = "update" - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = None try: embedding_model = model_manager.get_model_instance( @@ -1088,7 +1092,7 @@ class DatasetService: ) except ProviderTokenNotInitError as ex: raise ValueError(ex.description) - elif dataset.indexing_technique == "economy": + elif dataset.indexing_technique == IndexTechniqueType.ECONOMY: if dataset.keyword_number != knowledge_configuration.keyword_number: dataset.keyword_number = knowledge_configuration.keyword_number dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() @@ -1439,7 +1443,7 @@ class DocumentService: .filter( Document.id.in_(document_id_list), Document.dataset_id == dataset_id, - Document.doc_form != "qa_model", # Skip qa_model documents + Document.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents ) .update({Document.need_summary: need_summary}, synchronize_session=False) ) @@ -1906,9 +1910,9 @@ class DocumentService: if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: raise ValueError("Indexing technique is invalid") - dataset.indexing_technique = knowledge_config.indexing_technique - if knowledge_config.indexing_technique == "high_quality": - model_manager = ModelManager() + dataset.indexing_technique = IndexTechniqueType(knowledge_config.indexing_technique) + if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY: + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) if knowledge_config.embedding_model and knowledge_config.embedding_model_provider: dataset_embedding_model = knowledge_config.embedding_model dataset_embedding_model_provider = knowledge_config.embedding_model_provider @@ -2039,7 +2043,7 @@ class DocumentService: document.dataset_process_rule_id = dataset_process_rule.id document.updated_at = naive_utc_now() document.created_from = created_from - document.doc_form = knowledge_config.doc_form + document.doc_form = IndexStructureType(knowledge_config.doc_form) document.doc_language = knowledge_config.doc_language document.data_source_info = json.dumps(data_source_info) document.batch = batch @@ -2220,7 +2224,7 @@ class DocumentService: # dataset.indexing_technique = knowledge_config.indexing_technique # if knowledge_config.indexing_technique == "high_quality": - # model_manager = ModelManager() + # model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) # if knowledge_config.embedding_model and knowledge_config.embedding_model_provider: # dataset_embedding_model = knowledge_config.embedding_model # dataset_embedding_model_provider = knowledge_config.embedding_model_provider @@ -2639,7 +2643,7 @@ class DocumentService: document.splitting_completed_at = None document.updated_at = naive_utc_now() document.created_from = created_from - document.doc_form = document_data.doc_form + document.doc_form = IndexStructureType(document_data.doc_form) db.session.add(document) db.session.commit() # update document segment @@ -2688,7 +2692,7 @@ class DocumentService: dataset_collection_binding_id = None retrieval_model = None - if knowledge_config.indexing_technique == "high_quality": + if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY: assert knowledge_config.embedding_model_provider assert knowledge_config.embedding_model dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( @@ -2711,7 +2715,7 @@ class DocumentService: tenant_id=tenant_id, name="", data_source_type=knowledge_config.data_source.info_list.data_source_type, - indexing_technique=knowledge_config.indexing_technique, + indexing_technique=IndexTechniqueType(knowledge_config.indexing_technique), created_by=account.id, embedding_model=knowledge_config.embedding_model, embedding_model_provider=knowledge_config.embedding_model_provider, @@ -3100,7 +3104,7 @@ class DocumentService: class SegmentService: @classmethod def segment_create_args_validate(cls, args: dict, document: Document): - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: if "answer" not in args or not args["answer"]: raise ValueError("Answer is required") if not args["answer"].strip(): @@ -3124,8 +3128,8 @@ class SegmentService: doc_id = str(uuid.uuid4()) segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == "high_quality": - model_manager = ModelManager() + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, @@ -3157,7 +3161,7 @@ class SegmentService: completed_at=naive_utc_now(), created_by=current_user.id, ) - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: segment_document.word_count += len(args["answer"]) segment_document.answer = args["answer"] @@ -3207,8 +3211,8 @@ class SegmentService: try: with redis_client.lock(lock_name, timeout=600): embedding_model = None - if dataset.indexing_technique == "high_quality": - model_manager = ModelManager() + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, @@ -3229,9 +3233,9 @@ class SegmentService: doc_id = str(uuid.uuid4()) segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == "high_quality" and embedding_model: + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY and embedding_model: # calc embedding use tokens - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: tokens = embedding_model.get_text_embedding_num_tokens( texts=[content + segment_item["answer"]] )[0] @@ -3254,7 +3258,7 @@ class SegmentService: completed_at=naive_utc_now(), created_by=current_user.id, ) - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: segment_document.answer = segment_item["answer"] segment_document.word_count += len(segment_item["answer"]) increment_word_count += segment_document.word_count @@ -3321,7 +3325,7 @@ class SegmentService: content = args.content or segment.content if segment.content == content: segment.word_count = len(content) - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: segment.answer = args.answer segment.word_count += len(args.answer) if args.answer else 0 word_count_change = segment.word_count - word_count_change @@ -3344,9 +3348,9 @@ class SegmentService: if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: # regenerate child chunks # get embedding model instance - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) if dataset.embedding_model_provider: embedding_model_instance = model_manager.get_model_instance( @@ -3381,7 +3385,7 @@ class SegmentService: # When user manually provides summary, allow saving even if summary_index_setting doesn't exist # summary_index_setting is only needed for LLM generation, not for manual summary vectorization # Vectorization uses dataset.embedding_model, which doesn't require summary_index_setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # Query existing summary from database from models.dataset import DocumentSegmentSummary @@ -3408,8 +3412,8 @@ class SegmentService: else: segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == "high_quality": - model_manager = ModelManager() + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, @@ -3418,7 +3422,7 @@ class SegmentService: ) # calc embedding use tokens - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: segment.answer = args.answer tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])[0] # type: ignore else: @@ -3435,7 +3439,7 @@ class SegmentService: segment.enabled = True segment.disabled_at = None segment.disabled_by = None - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: segment.answer = args.answer segment.word_count += len(args.answer) if args.answer else 0 word_count_change = segment.word_count - word_count_change @@ -3448,9 +3452,9 @@ class SegmentService: db.session.commit() if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: # get embedding model instance - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) if dataset.embedding_model_provider: embedding_model_instance = model_manager.get_model_instance( @@ -3480,7 +3484,7 @@ class SegmentService: # update segment vector index VectorService.update_segment_vector(args.keywords, segment, dataset) # Handle summary index when content changed - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: from models.dataset import DocumentSegmentSummary existing_summary = ( @@ -3786,7 +3790,7 @@ class SegmentService: child_chunk.word_count = len(child_chunk.content) child_chunk.updated_by = current_user.id child_chunk.updated_at = naive_utc_now() - child_chunk.type = "customized" + child_chunk.type = SegmentType.CUSTOMIZED update_child_chunks.append(child_chunk) else: new_child_chunks_args.append(child_chunk_update_args) @@ -3845,7 +3849,7 @@ class SegmentService: child_chunk.word_count = len(content) child_chunk.updated_by = current_user.id child_chunk.updated_at = naive_utc_now() - child_chunk.type = "customized" + child_chunk.type = SegmentType.CUSTOMIZED db.session.add(child_chunk) VectorService.update_child_chunk_vector([], [child_chunk], [], dataset) db.session.commit() diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index f3b2adb965d..06f83a18f7a 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -3,6 +3,7 @@ import time from collections.abc import Mapping from typing import Any +from graphon.model_runtime.entities.provider_entities import FormType from sqlalchemy.orm import Session from configs import dify_config @@ -14,7 +15,6 @@ from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.datasource import PluginDatasourceManager from core.plugin.impl.oauth import OAuthHandler from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter -from dify_graph.model_runtime.entities.provider_entities import FormType from extensions.ext_database import db from extensions.ext_redis import redis_client from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index 9dd595f5160..a944ef6acdd 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -1,6 +1,15 @@ from collections.abc import Sequence from enum import StrEnum +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + ModelCredentialSchema, + ProviderCredentialSchema, + ProviderHelpEntity, + SimpleProviderEntity, +) from pydantic import BaseModel, ConfigDict, model_validator from configs import dify_config @@ -15,15 +24,6 @@ from core.entities.provider_entities import ( QuotaConfiguration, UnaddedModelConfiguration, ) -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - ModelCredentialSchema, - ProviderCredentialSchema, - ProviderHelpEntity, - SimpleProviderEntity, -) from models.provider import ProviderType diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 4cf42b7f44a..64852c222f3 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -4,12 +4,12 @@ from typing import Any, Union, cast from urllib.parse import urlparse import httpx +from graphon.nodes.http_request.exc import InvalidHttpMethodError from sqlalchemy import select from constants import HIDDEN_VALUE from core.helper import ssrf_proxy from core.rag.entities.metadata_entities import MetadataCondition -from dify_graph.nodes.http_request.exc import InvalidHttpMethodError from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.dataset import ( diff --git a/api/services/file_service.py b/api/services/file_service.py index a7060f3b928..50a326d8138 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -8,6 +8,7 @@ from tempfile import NamedTemporaryFile from typing import Literal, Union from zipfile import ZIP_DEFLATED, ZipFile +from graphon.file import helpers as file_helpers from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker from werkzeug.exceptions import NotFound @@ -20,7 +21,6 @@ from constants import ( VIDEO_EXTENSIONS, ) from core.rag.extractor.extract_processor import ExtractProcessor -from dify_graph.file import helpers as file_helpers from extensions.ext_database import db from extensions.ext_storage import storage from extensions.storage.storage_type import StorageType diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 9993d24c70d..82e0b0f8b1f 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -3,13 +3,14 @@ import logging import time from typing import Any +from graphon.model_runtime.entities import LLMMode + from core.app.app_config.entities import ModelConfig from core.rag.datasource.retrieval_service import RetrievalService from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod -from dify_graph.model_runtime.entities import LLMMode from extensions.ext_database import db from models import Account from models.dataset import Dataset, DatasetQuery diff --git a/api/services/human_input_delivery_test_service.py b/api/services/human_input_delivery_test_service.py index 229e6608da7..77576fa4c0d 100644 --- a/api/services/human_input_delivery_test_service.py +++ b/api/services/human_input_delivery_test_service.py @@ -4,18 +4,18 @@ from dataclasses import dataclass, field from enum import StrEnum from typing import Protocol +from graphon.runtime import VariablePool from sqlalchemy import Engine, select from sqlalchemy.orm import sessionmaker from configs import dify_config -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( DeliveryChannelConfig, EmailDeliveryConfig, EmailDeliveryMethod, ExternalRecipient, MemberRecipient, ) -from dify_graph.runtime import VariablePool from extensions.ext_database import db from extensions.ext_mail import mail from libs.email_template_renderer import render_email_template @@ -177,21 +177,21 @@ class EmailDeliveryTestHandler: def _resolve_recipients(self, *, tenant_id: str, method: EmailDeliveryMethod) -> list[str]: recipients = method.config.recipients emails: list[str] = [] - member_user_ids: list[str] = [] + bound_reference_ids: list[str] = [] for recipient in recipients.items: if isinstance(recipient, MemberRecipient): - member_user_ids.append(recipient.user_id) + bound_reference_ids.append(recipient.reference_id) elif isinstance(recipient, ExternalRecipient): if recipient.email: emails.append(recipient.email) - if recipients.whole_workspace: - member_user_ids = [] + if recipients.include_bound_group: + bound_reference_ids = [] member_emails = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=None) emails.extend(member_emails.values()) - elif member_user_ids: - member_emails = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=member_user_ids) - for user_id in member_user_ids: + elif bound_reference_ids: + member_emails = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=bound_reference_ids) + for user_id in bound_reference_ids: email = member_emails.get(user_id) if email: emails.append(email) @@ -220,7 +220,7 @@ class EmailDeliveryTestHandler: stmt = stmt.where(Account.id.in_(unique_ids)) with self._session_factory() as session: - rows = session.execute(stmt).all() + rows = session.execute(stmt).tuples().all() return dict(rows) @staticmethod diff --git a/api/services/human_input_service.py b/api/services/human_input_service.py index 2e74c509632..02a6620fc74 100644 --- a/api/services/human_input_service.py +++ b/api/services/human_input_service.py @@ -3,6 +3,12 @@ from collections.abc import Mapping from datetime import datetime, timedelta from typing import Any +from graphon.nodes.human_input.entities import ( + FormDefinition, + HumanInputSubmissionValidationError, + validate_human_input_submission, +) +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker @@ -11,12 +17,6 @@ from core.repositories.human_input_repository import ( HumanInputFormRecord, HumanInputFormSubmissionRepository, ) -from dify_graph.nodes.human_input.entities import ( - FormDefinition, - HumanInputSubmissionValidationError, - validate_human_input_submission, -) -from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import ensure_naive_utc, naive_utc_now from libs.exception import BaseHTTPException from models.human_input import RecipientType diff --git a/api/services/knowledge_service.py b/api/services/knowledge_service.py index 02fe1d19bc4..3d6fdb08a38 100644 --- a/api/services/knowledge_service.py +++ b/api/services/knowledge_service.py @@ -1,12 +1,20 @@ import boto3 +from pydantic import BaseModel, Field from configs import dify_config +class BedrockRetrievalSetting(BaseModel): + """Retrieval settings for Amazon Bedrock knowledge base queries.""" + + top_k: int | None = Field(default=None, description="Maximum number of results to retrieve") + score_threshold: float = Field(default=0.0, description="Minimum relevance score threshold") + + class ExternalDatasetTestService: # this service is only for internal testing @staticmethod - def knowledge_retrieval(retrieval_setting: dict, query: str, knowledge_id: str): + def knowledge_retrieval(retrieval_setting: BedrockRetrievalSetting, query: str, knowledge_id: str): # get bedrock client client = boto3.client( "bedrock-agent-runtime", @@ -20,7 +28,7 @@ class ExternalDatasetTestService: knowledgeBaseId=knowledge_id, retrievalConfiguration={ "vectorSearchConfiguration": { - "numberOfResults": retrieval_setting.get("top_k"), + "numberOfResults": retrieval_setting.top_k, "overrideSearchType": "HYBRID", } }, @@ -33,7 +41,7 @@ class ExternalDatasetTestService: retrieval_results = response.get("retrievalResults") for retrieval_result in retrieval_results: # filter out results with score less than threshold - if retrieval_result.get("score") < retrieval_setting.get("score_threshold", 0.0): + if retrieval_result.get("score") < retrieval_setting.score_threshold: continue result = { "metadata": retrieval_result.get("metadata"), diff --git a/api/services/message_service.py b/api/services/message_service.py index fc87802f510..a04f9cbe012 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -1,7 +1,8 @@ -import json from collections.abc import Sequence from typing import Union +from graphon.model_runtime.entities.model_entities import ModelType +from pydantic import TypeAdapter from sqlalchemy.orm import sessionmaker from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager @@ -12,12 +13,11 @@ from core.model_manager import ModelManager from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account from models.enums import FeedbackFromSource, FeedbackRating -from models.model import App, AppMode, AppModelConfig, EndUser, Message, MessageFeedback +from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, EndUser, Message, MessageFeedback from repositories.execution_extra_content_repository import ExecutionExtraContentRepository from repositories.sqlalchemy_execution_extra_content_repository import ( SQLAlchemyExecutionExtraContentRepository, @@ -31,6 +31,8 @@ from services.errors.message import ( ) from services.workflow_service import WorkflowService +_app_model_config_adapter: TypeAdapter[AppModelConfigDict] = TypeAdapter(AppModelConfigDict) + def _create_execution_extra_content_repository() -> ExecutionExtraContentRepository: session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) @@ -255,7 +257,7 @@ class MessageService: app_model=app_model, conversation_id=message.conversation_id, user=user ) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=app_model.tenant_id) if app_model.mode == AppMode.ADVANCED_CHAT: workflow_service = WorkflowService() @@ -286,7 +288,9 @@ class MessageService: .first() ) else: - conversation_override_model_configs = json.loads(conversation.override_model_configs) + conversation_override_model_configs = _app_model_config_adapter.validate_json( + conversation.override_model_configs + ) app_model_config = AppModelConfig( app_id=app_model.id, ) diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index bf3b6db3edc..25de411e434 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -1,8 +1,13 @@ import json import logging -from json import JSONDecodeError -from typing import Union +from typing import Any, Union +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ModelCredentialSchema, + ProviderCredentialSchema, +) +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from sqlalchemy import or_, select from constants import HIDDEN_VALUE @@ -10,13 +15,8 @@ from core.entities.provider_configuration import ProviderConfiguration from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.model_manager import LBModelManager +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager from core.provider_manager import ProviderManager -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.entities.provider_entities import ( - ModelCredentialSchema, - ProviderCredentialSchema, -) -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.enums import CredentialSourceType @@ -26,8 +26,9 @@ logger = logging.getLogger(__name__) class ModelLoadBalancingService: - def __init__(self): - self.provider_manager = ProviderManager() + @staticmethod + def _get_provider_manager(tenant_id: str) -> ProviderManager: + return create_plugin_provider_manager(tenant_id=tenant_id) def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str): """ @@ -40,7 +41,7 @@ class ModelLoadBalancingService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -61,7 +62,7 @@ class ModelLoadBalancingService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -83,7 +84,7 @@ class ModelLoadBalancingService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -166,10 +167,10 @@ class ModelLoadBalancingService: try: if load_balancing_config.encrypted_config: - credentials: dict[str, object] = json.loads(load_balancing_config.encrypted_config) + credentials: dict[str, Any] = json.loads(load_balancing_config.encrypted_config) else: credentials = {} - except JSONDecodeError: + except (json.JSONDecodeError, ValueError): credentials = {} # Get provider credential secret variables @@ -222,8 +223,8 @@ class ModelLoadBalancingService: :param config_id: load balancing config id :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_manager = create_plugin_provider_manager(tenant_id=tenant_id) + provider_configurations = provider_manager.get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -254,7 +255,7 @@ class ModelLoadBalancingService: credentials = json.loads(load_balancing_model_config.encrypted_config) else: credentials = {} - except JSONDecodeError: + except (json.JSONDecodeError, ValueError): credentials = {} # Get credential form schemas from model credential schema or provider credential schema @@ -310,7 +311,7 @@ class ModelLoadBalancingService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -495,8 +496,8 @@ class ModelLoadBalancingService: :param config_id: load balancing config id :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + assembly = create_plugin_model_assembly(tenant_id=tenant_id) + provider_configurations = assembly.provider_manager.get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -532,6 +533,7 @@ class ModelLoadBalancingService: model=model, credentials=credentials, load_balancing_model_config=load_balancing_model_config, + model_provider_factory=assembly.model_provider_factory, ) def _custom_credentials_validate( @@ -542,6 +544,7 @@ class ModelLoadBalancingService: model: str, credentials: dict, load_balancing_model_config: LoadBalancingModelConfig | None = None, + model_provider_factory: ModelProviderFactory | None = None, validate: bool = True, ): """ @@ -552,6 +555,7 @@ class ModelLoadBalancingService: :param model: model name :param credentials: credentials :param load_balancing_model_config: load balancing model config + :param model_provider_factory: model provider factory sharing the active runtime :param validate: validate credentials :return: """ @@ -570,7 +574,7 @@ class ModelLoadBalancingService: original_credentials = json.loads(load_balancing_model_config.encrypted_config) else: original_credentials = {} - except JSONDecodeError: + except (json.JSONDecodeError, ValueError): original_credentials = {} # encrypt credentials @@ -581,7 +585,8 @@ class ModelLoadBalancingService: credentials[key] = encrypter.decrypt_token(tenant_id, original_credentials[key]) if validate: - model_provider_factory = ModelProviderFactory(tenant_id) + if model_provider_factory is None: + model_provider_factory = provider_configuration.get_model_provider_factory() if isinstance(credential_schemas, ModelCredentialSchema): credentials = model_provider_factory.model_credentials_validate( provider=provider_configuration.provider.provider, diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 0ddd6b9b1af..3f37c9b176d 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -1,9 +1,10 @@ import logging +from graphon.model_runtime.entities.model_entities import ModelType, ParameterRule + from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity +from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory, create_plugin_provider_manager from core.provider_manager import ProviderManager -from dify_graph.model_runtime.entities.model_entities import ModelType, ParameterRule -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from models.provider import ProviderType from services.entities.model_provider_entities import ( CustomConfigurationResponse, @@ -25,8 +26,9 @@ class ModelProviderService: Model Provider Service """ - def __init__(self): - self.provider_manager = ProviderManager() + @staticmethod + def _get_provider_manager(tenant_id: str) -> ProviderManager: + return create_plugin_provider_manager(tenant_id=tenant_id) def _get_provider_configuration(self, tenant_id: str, provider: str): """ @@ -43,7 +45,7 @@ class ModelProviderService: ProviderNotFoundError: If provider doesn't exist """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) provider_configuration = provider_configurations.get(provider) if not provider_configuration: @@ -60,7 +62,7 @@ class ModelProviderService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) provider_responses = [] for provider_configuration in provider_configurations.values(): @@ -138,7 +140,7 @@ class ModelProviderService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider available models return [ @@ -146,6 +148,26 @@ class ModelProviderService: for model in provider_configurations.get_models(provider=provider) ] + def get_provider_available_credentials(self, tenant_id: str, provider: str): + return self._get_provider_manager(tenant_id).get_provider_available_credentials( + tenant_id=tenant_id, + provider_name=provider, + ) + + def get_provider_model_available_credentials( + self, + tenant_id: str, + provider: str, + model_type: str, + model: str, + ): + return self._get_provider_manager(tenant_id).get_provider_model_available_credentials( + tenant_id=tenant_id, + provider_name=provider, + model_type=model_type, + model_name=model, + ) + def get_provider_credential(self, tenant_id: str, provider: str, credential_id: str | None = None) -> dict | None: """ get provider credentials. @@ -391,7 +413,7 @@ class ModelProviderService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider available models models = provider_configurations.get_models(model_type=ModelType.value_of(model_type), only_active=True) @@ -476,7 +498,9 @@ class ModelProviderService: model_type_enum = ModelType.value_of(model_type) try: - result = self.provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type_enum) + result = self._get_provider_manager(tenant_id).get_default_model( + tenant_id=tenant_id, model_type=model_type_enum + ) return ( DefaultModelResponse( model=result.model, @@ -507,7 +531,7 @@ class ModelProviderService: :return: """ model_type_enum = ModelType.value_of(model_type) - self.provider_manager.update_default_model_record( + self._get_provider_manager(tenant_id).update_default_model_record( tenant_id=tenant_id, model_type=model_type_enum, provider=provider, model=model ) @@ -523,7 +547,7 @@ class ModelProviderService: :param lang: language (zh_Hans or en_US) :return: """ - model_provider_factory = ModelProviderFactory(tenant_id) + model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id) byte_data, mime_type = model_provider_factory.get_provider_icon(provider, icon_type, lang) return byte_data, mime_type diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py index df5fa3e233b..1562d4e6966 100644 --- a/api/services/plugin/plugin_migration.py +++ b/api/services/plugin/plugin_migration.py @@ -12,7 +12,9 @@ import click import sqlalchemy as sa import tqdm from flask import Flask, current_app +from pydantic import TypeAdapter from sqlalchemy.orm import Session +from typing_extensions import TypedDict from core.agent.entities import AgentToolEntity from core.helper import marketplace @@ -33,6 +35,14 @@ logger = logging.getLogger(__name__) excluded_providers = ["time", "audio", "code", "webscraper"] +class _TenantPluginRecord(TypedDict): + tenant_id: str + plugins: list[str] + + +_tenant_plugin_adapter: TypeAdapter[_TenantPluginRecord] = TypeAdapter(_TenantPluginRecord) + + class PluginMigration: @classmethod def extract_plugins(cls, filepath: str, workers: int): @@ -308,9 +318,8 @@ class PluginMigration: logger.info("Extracting unique plugins from %s", extracted_plugins) with open(extracted_plugins) as f: for line in f: - data = json.loads(line) - new_plugin_ids = data.get("plugins", []) - for plugin_id in new_plugin_ids: + data = _tenant_plugin_adapter.validate_json(line) + for plugin_id in data["plugins"]: if plugin_id not in plugin_ids: plugin_ids.append(plugin_id) @@ -381,21 +390,23 @@ class PluginMigration: Read line by line, and install plugins for each tenant. """ for line in f: - data = json.loads(line) - tenant_id = data.get("tenant_id") - plugin_ids = data.get("plugins", []) - current_not_installed = { - "tenant_id": tenant_id, - "plugin_not_exist": [], - } + data = _tenant_plugin_adapter.validate_json(line) + tenant_id = data["tenant_id"] + plugin_ids = data["plugins"] + plugin_not_exist: list[str] = [] # get plugin unique identifier for plugin_id in plugin_ids: unique_identifier = plugins.get(plugin_id) if unique_identifier: - current_not_installed["plugin_not_exist"].append(plugin_id) + plugin_not_exist.append(plugin_id) - if current_not_installed["plugin_not_exist"]: - not_installed.append(current_not_installed) + if plugin_not_exist: + not_installed.append( + { + "tenant_id": tenant_id, + "plugin_not_exist": plugin_not_exist, + } + ) thread_pool.submit(install, tenant_id, plugin_ids) diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index f3aedafac9f..bcf5973d7b2 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -9,6 +9,15 @@ from typing import Any, Union, cast from uuid import uuid4 from flask_login import current_user +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes, ErrorStrategy, NodeType, WorkflowNodeExecutionStatus +from graphon.errors import WorkflowNodeRunFailedError +from graphon.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config +from graphon.runtime import VariablePool +from graphon.variables.variables import Variable, VariableBase from sqlalchemy import func, select from sqlalchemy.orm import Session, sessionmaker @@ -34,25 +43,19 @@ from core.rag.entities.event import ( DatasourceErrorEvent, DatasourceProcessingEvent, ) -from core.repositories.factory import DifyCoreRepositoryFactory +from core.repositories.factory import DifyCoreRepositoryFactory, OrderConfig from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository from core.workflow.node_factory import LATEST_VERSION, get_node_type_classes_mapping -from core.workflow.workflow_entry import WorkflowEntry -from dify_graph.entities.workflow_node_execution import ( - WorkflowNodeExecution, - WorkflowNodeExecutionStatus, +from core.workflow.system_variables import ( + SystemVariableKey, + build_bootstrap_variables, + build_system_variables, + default_system_variables, + get_system_segment, ) -from dify_graph.enums import BuiltinNodeTypes, ErrorStrategy, NodeType, SystemVariableKey -from dify_graph.errors import WorkflowNodeRunFailedError -from dify_graph.graph_events import NodeRunFailedEvent, NodeRunSucceededEvent -from dify_graph.graph_events.base import GraphNodeEventBase -from dify_graph.node_events.base import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig -from dify_graph.runtime import VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables.variables import VariableBase +from core.workflow.variable_pool_initializer import add_variables_to_pool +from core.workflow.workflow_entry import WorkflowEntry +from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account @@ -79,14 +82,21 @@ from services.entities.knowledge_entities.rag_pipeline_entities import ( KnowledgeConfiguration, PipelineTemplateInfoEntity, ) -from services.errors.app import WorkflowHashNotEqualError +from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory from services.tools.builtin_tools_manage_service import BuiltinToolManageService from services.workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader +from services.workflow_restore import apply_published_workflow_snapshot_to_draft logger = logging.getLogger(__name__) +def _build_seeded_variable_pool(variables: Sequence[Variable]) -> VariablePool: + variable_pool = VariablePool() + add_variables_to_pool(variable_pool, variables) + return variable_pool + + class RagPipelineService: def __init__(self, session_maker: sessionmaker | None = None): """Initialize RagPipelineService with repository dependencies.""" @@ -234,6 +244,21 @@ class RagPipelineService: return workflow + def get_published_workflow_by_id(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None: + """Fetch a published workflow snapshot by ID for restore operations.""" + workflow = ( + db.session.query(Workflow) + .where( + Workflow.tenant_id == pipeline.tenant_id, + Workflow.app_id == pipeline.id, + Workflow.id == workflow_id, + ) + .first() + ) + if workflow and workflow.version == Workflow.VERSION_DRAFT: + raise IsDraftWorkflowError("source workflow must be published") + return workflow + def get_all_published_workflow( self, *, @@ -327,6 +352,42 @@ class RagPipelineService: # return draft workflow return workflow + def restore_published_workflow_to_draft( + self, + *, + pipeline: Pipeline, + workflow_id: str, + account: Account, + ) -> Workflow: + """Restore a published pipeline workflow snapshot into the draft workflow. + + Pipelines reuse the shared draft-restore field copy helper, but still own + the pipeline-specific flush/link step that wires a newly created draft + back onto ``pipeline.workflow_id``. + """ + source_workflow = self.get_published_workflow_by_id(pipeline=pipeline, workflow_id=workflow_id) + if not source_workflow: + raise WorkflowNotFoundError("Workflow not found.") + + draft_workflow = self.get_draft_workflow(pipeline=pipeline) + draft_workflow, is_new_draft = apply_published_workflow_snapshot_to_draft( + tenant_id=pipeline.tenant_id, + app_id=pipeline.id, + source_workflow=source_workflow, + draft_workflow=draft_workflow, + account=account, + updated_at_factory=lambda: datetime.now(UTC).replace(tzinfo=None), + ) + + if is_new_draft: + db.session.add(draft_workflow) + db.session.flush() + pipeline.workflow_id = draft_workflow.id + + db.session.commit() + + return draft_workflow + def publish_workflow( self, *, @@ -469,13 +530,7 @@ class RagPipelineService: node_id=node_id, user_inputs=user_inputs, user_id=account.id, - variable_pool=VariablePool( - system_variables=SystemVariable.default(), - user_inputs=user_inputs, - environment_variables=[], - conversation_variables=[], - rag_pipeline_variables=[], - ), + variable_pool=_build_seeded_variable_pool(default_system_variables()), variable_loader=DraftVarLoader( engine=db.engine, app_id=pipeline.id, @@ -519,6 +574,13 @@ class RagPipelineService: outputs=workflow_node_execution.outputs, ) session.commit() + if workflow_node_execution_db_model is not None: + enqueue_draft_node_execution_trace( + execution=workflow_node_execution_db_model, + outputs=workflow_node_execution.outputs, + workflow_execution_id=None, + user_id=account.id, + ) return workflow_node_execution_db_model def run_datasource_workflow_node( @@ -907,10 +969,10 @@ class RagPipelineService: workflow_node_execution.error = error # update document status variable_pool = node_instance.graph_runtime_state.variable_pool - invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) + invoke_from = get_system_segment(variable_pool, SystemVariableKey.INVOKE_FROM) if invoke_from: if invoke_from.value == InvokeFrom.PUBLISHED_PIPELINE: - document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) + document_id = get_system_segment(variable_pool, SystemVariableKey.DOCUMENT_ID) if document_id: document = db.session.query(Document).where(Document.id == document_id.value).first() if document: @@ -1224,7 +1286,7 @@ class RagPipelineService: else: enclosing_node_id = None - system_inputs = SystemVariable( + system_inputs = build_system_variables( datasource_type=args.get("datasource_type", "online_document"), datasource_info=args.get("datasource_info", {}), ) @@ -1235,12 +1297,11 @@ class RagPipelineService: node_id=node_id, user_inputs={}, user_id=current_user.id, - variable_pool=VariablePool( - system_variables=system_inputs, - user_inputs={}, - environment_variables=[], - conversation_variables=[], - rag_pipeline_variables=[], + variable_pool=_build_seeded_variable_pool( + build_bootstrap_variables( + system_variables=system_inputs, + rag_pipeline_variables=(), + ) ), variable_loader=DraftVarLoader( engine=db.engine, @@ -1282,6 +1343,12 @@ class RagPipelineService: outputs=workflow_node_execution.outputs, ) session.commit() + enqueue_draft_node_execution_trace( + execution=workflow_node_execution_db_model, + outputs=workflow_node_execution.outputs, + workflow_execution_id=None, + user_id=current_user.id, + ) return workflow_node_execution_db_model def get_recommended_plugins(self, type: str) -> dict: diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index deb59da8d34..04156713f4f 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -14,6 +14,12 @@ import yaml # type: ignore from Crypto.Cipher import AES from Crypto.Util.Padding import pad, unpad from flask_login import current_user +from graphon.enums import BuiltinNodeTypes +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes.llm.entities import LLMNodeData +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData +from graphon.nodes.tool.entities import ToolNodeData from packaging import version from pydantic import BaseModel, Field from sqlalchemy import select @@ -22,15 +28,10 @@ from sqlalchemy.orm import Session from core.helper import ssrf_proxy from core.helper.name_generator import generate_incremental_name from core.plugin.entities.plugin import PluginDependency +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.workflow.nodes.datasource.entities import DatasourceNodeData from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.nodes.llm.entities import LLMNodeData -from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData -from dify_graph.nodes.tool.entities import ToolNodeData from extensions.ext_redis import redis_client from factories import variable_factory from models import Account @@ -311,13 +312,13 @@ class RagPipelineDslService: "icon_background": icon_background, "icon_url": icon_url, }, - indexing_technique=knowledge_configuration.indexing_technique, + indexing_technique=IndexTechniqueType(knowledge_configuration.indexing_technique), created_by=account.id, retrieval_model=knowledge_configuration.retrieval_model.model_dump(), runtime_mode=DatasetRuntimeMode.RAG_PIPELINE, chunk_structure=knowledge_configuration.chunk_structure, ) - if knowledge_configuration.indexing_technique == "high_quality": + if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: dataset_collection_binding = ( self._session.query(DatasetCollectionBinding) .where( @@ -343,7 +344,7 @@ class RagPipelineDslService: dataset.collection_binding_id = dataset_collection_binding_id dataset.embedding_model = knowledge_configuration.embedding_model dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider - elif knowledge_configuration.indexing_technique == "economy": + elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY: dataset.keyword_number = knowledge_configuration.keyword_number # Update summary_index_setting if provided if knowledge_configuration.summary_index_setting is not None: @@ -443,18 +444,18 @@ class RagPipelineDslService: "icon_background": icon_background, "icon_url": icon_url, }, - indexing_technique=knowledge_configuration.indexing_technique, + indexing_technique=IndexTechniqueType(knowledge_configuration.indexing_technique), created_by=account.id, retrieval_model=knowledge_configuration.retrieval_model.model_dump(), runtime_mode=DatasetRuntimeMode.RAG_PIPELINE, chunk_structure=knowledge_configuration.chunk_structure, ) else: - dataset.indexing_technique = knowledge_configuration.indexing_technique + dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique) dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE dataset.chunk_structure = knowledge_configuration.chunk_structure - if knowledge_configuration.indexing_technique == "high_quality": + if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: dataset_collection_binding = ( self._session.query(DatasetCollectionBinding) .where( @@ -480,7 +481,7 @@ class RagPipelineDslService: dataset.collection_binding_id = dataset_collection_binding_id dataset.embedding_model = knowledge_configuration.embedding_model dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider - elif knowledge_configuration.indexing_technique == "economy": + elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY: dataset.keyword_number = knowledge_configuration.keyword_number # Update summary_index_setting if provided if knowledge_configuration.summary_index_setting is not None: @@ -772,7 +773,7 @@ class RagPipelineDslService: ) case _ if typ == KNOWLEDGE_INDEX_NODE_TYPE: knowledge_index_entity = KnowledgeConfiguration.model_validate(node["data"]) - if knowledge_index_entity.indexing_technique == "high_quality": + if knowledge_index_entity.indexing_technique == IndexTechniqueType.HIGH_QUALITY: if knowledge_index_entity.embedding_model_provider: dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency( diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py index 1d0aafd5fd2..215a8c85285 100644 --- a/api/services/rag_pipeline/rag_pipeline_transform_service.py +++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py @@ -9,6 +9,7 @@ from flask_login import current_user from constants import DOCUMENT_EXTENSIONS from core.plugin.impl.plugin import PluginInstaller +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from factories import variable_factory @@ -79,9 +80,9 @@ class RagPipelineTransformService: pipeline = self._create_pipeline(pipeline_yaml) # save chunk structure to dataset - if doc_form == "hierarchical_model": + if doc_form == IndexStructureType.PARENT_CHILD_INDEX: dataset.chunk_structure = "hierarchical_model" - elif doc_form == "text_model": + elif doc_form == IndexStructureType.PARAGRAPH_INDEX: dataset.chunk_structure = "text_model" else: raise ValueError("Unsupported doc form") @@ -101,38 +102,38 @@ class RagPipelineTransformService: def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: str | None): pipeline_yaml = {} - if doc_form == "text_model": + if doc_form == IndexStructureType.PARAGRAPH_INDEX: match datasource_type: case DataSourceType.UPLOAD_FILE: - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: # get graph from transform.file-general-high-quality.yml with open(f"{Path(__file__).parent}/transform/file-general-high-quality.yml") as f: pipeline_yaml = yaml.safe_load(f) - if indexing_technique == "economy": + if indexing_technique == IndexTechniqueType.ECONOMY: # get graph from transform.file-general-economy.yml with open(f"{Path(__file__).parent}/transform/file-general-economy.yml") as f: pipeline_yaml = yaml.safe_load(f) case DataSourceType.NOTION_IMPORT: - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: # get graph from transform.notion-general-high-quality.yml with open(f"{Path(__file__).parent}/transform/notion-general-high-quality.yml") as f: pipeline_yaml = yaml.safe_load(f) - if indexing_technique == "economy": + if indexing_technique == IndexTechniqueType.ECONOMY: # get graph from transform.notion-general-economy.yml with open(f"{Path(__file__).parent}/transform/notion-general-economy.yml") as f: pipeline_yaml = yaml.safe_load(f) case DataSourceType.WEBSITE_CRAWL: - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: # get graph from transform.website-crawl-general-high-quality.yml with open(f"{Path(__file__).parent}/transform/website-crawl-general-high-quality.yml") as f: pipeline_yaml = yaml.safe_load(f) - if indexing_technique == "economy": + if indexing_technique == IndexTechniqueType.ECONOMY: # get graph from transform.website-crawl-general-economy.yml with open(f"{Path(__file__).parent}/transform/website-crawl-general-economy.yml") as f: pipeline_yaml = yaml.safe_load(f) case _: raise ValueError("Unsupported datasource type") - elif doc_form == "hierarchical_model": + elif doc_form == IndexStructureType.PARENT_CHILD_INDEX: match datasource_type: case DataSourceType.UPLOAD_FILE: # get graph from transform.file-parentchild.yml @@ -169,11 +170,11 @@ class RagPipelineTransformService: ): knowledge_configuration_dict = node.get("data", {}) - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: knowledge_configuration.embedding_model = dataset.embedding_model knowledge_configuration.embedding_model_provider = dataset.embedding_model_provider if retrieval_model: - if indexing_technique == "economy": + if indexing_technique == IndexTechniqueType.ECONOMY: retrieval_model.search_method = RetrievalMethod.KEYWORD_SEARCH knowledge_configuration.retrieval_model = retrieval_model else: diff --git a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py index 00a21448005..2c1f99a3bc9 100644 --- a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py +++ b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py @@ -27,11 +27,11 @@ from dataclasses import dataclass, field from typing import Any import click +from graphon.enums import WorkflowType from sqlalchemy import inspect from sqlalchemy.orm import Session, sessionmaker from configs import dify_config -from dify_graph.enums import WorkflowType from enums.cloud_plan import CloudPlan from extensions.ext_database import db from libs.archive_storage import ( diff --git a/api/services/retention/workflow_run/restore_archived_workflow_run.py b/api/services/retention/workflow_run/restore_archived_workflow_run.py index 64dad7ba52a..c8362738ee8 100644 --- a/api/services/retention/workflow_run/restore_archived_workflow_run.py +++ b/api/services/retention/workflow_run/restore_archived_workflow_run.py @@ -6,7 +6,6 @@ back to the database. """ import io -import json import logging import time import zipfile @@ -17,8 +16,23 @@ from datetime import datetime from typing import Any, cast import click +from pydantic import TypeAdapter from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.engine import CursorResult +from typing_extensions import TypedDict + + +class _TableInfo(TypedDict, total=False): + row_count: int + + +class ArchiveManifest(TypedDict, total=False): + tables: dict[str, _TableInfo] + schema_version: str + + +_manifest_adapter: TypeAdapter[ArchiveManifest] = TypeAdapter(ArchiveManifest) + from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker from extensions.ext_database import db @@ -239,12 +253,12 @@ class WorkflowRunRestore: return self.workflow_run_repo @staticmethod - def _load_manifest_from_zip(archive: zipfile.ZipFile) -> dict[str, Any]: + def _load_manifest_from_zip(archive: zipfile.ZipFile) -> ArchiveManifest: try: data = archive.read("manifest.json") except KeyError as e: raise ValueError("manifest.json missing from archive bundle") from e - return json.loads(data.decode("utf-8")) + return _manifest_adapter.validate_json(data) def _restore_table_records( self, @@ -332,7 +346,7 @@ class WorkflowRunRestore: return result - def _get_schema_version(self, manifest: dict[str, Any]) -> str: + def _get_schema_version(self, manifest: ArchiveManifest) -> str: schema_version = manifest.get("schema_version") if not schema_version: logger.warning("Manifest missing schema_version; defaulting to 1.0") diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py index 943dfc972bd..12053377e24 100644 --- a/api/services/summary_index_service.py +++ b/api/services/summary_index_service.py @@ -6,16 +6,17 @@ import uuid from datetime import UTC, datetime from typing import Any +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm import Session from core.db.session_factory import session_factory from core.model_manager import ModelManager from core.rag.datasource.vdb.vector_factory import Vector from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.models.document import Document -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.model_runtime.entities.model_entities import ModelType from libs import helper from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary from models.dataset import Document as DatasetDocument @@ -140,7 +141,7 @@ class SummaryIndexService: session: Optional SQLAlchemy session. If provided, uses this session instead of creating a new one. If not provided, creates a new session and commits automatically. """ - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: logger.warning( "Summary vectorization skipped for dataset %s: indexing_technique is not high_quality", dataset.id, @@ -191,7 +192,7 @@ class SummaryIndexService: # Calculate embedding tokens for summary (for logging and statistics) embedding_tokens = 0 try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, @@ -200,7 +201,8 @@ class SummaryIndexService: ) if embedding_model: tokens_list = embedding_model.get_text_embedding_num_tokens([summary_content]) - embedding_tokens = tokens_list[0] if tokens_list else 0 + raw_embedding_tokens = tokens_list[0] if tokens_list else 0 + embedding_tokens = raw_embedding_tokens if isinstance(raw_embedding_tokens, int) else 0 except Exception as e: logger.warning("Failed to calculate embedding tokens for summary: %s", str(e)) @@ -724,7 +726,7 @@ class SummaryIndexService: List of created DocumentSegmentSummary instances """ # Only generate summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: logger.info( "Skipping summary generation for dataset %s: indexing_technique is %s, not 'high_quality'", dataset.id, @@ -851,7 +853,7 @@ class SummaryIndexService: ) # Remove from vector database (but keep records) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id] if summary_node_ids: try: @@ -889,7 +891,7 @@ class SummaryIndexService: segment_ids: List of segment IDs to enable summaries for. If None, enable all. """ # Only enable summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: return with session_factory.create_session() as session: @@ -981,7 +983,7 @@ class SummaryIndexService: return # Delete from vector database - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id] if summary_node_ids: vector = Vector(dataset) @@ -1012,7 +1014,7 @@ class SummaryIndexService: Updated DocumentSegmentSummary instance, or None if indexing technique is not high_quality """ # Only update summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: return None # When user manually provides summary, allow saving even if summary_index_setting doesn't exist diff --git a/api/services/tag_service.py b/api/services/tag_service.py index bd3585acf44..70bf7f16f24 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound from extensions.ext_database import db from models.dataset import Dataset +from models.enums import TagType from models.model import App, Tag, TagBinding @@ -83,7 +84,7 @@ class TagService: raise ValueError("Tag name already exists") tag = Tag( name=args["name"], - type=args["type"], + type=TagType(args["type"]), created_by=current_user.id, tenant_id=current_user.current_tenant_id, ) diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index dc883f0daa8..2a56bc0c71e 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -1,10 +1,11 @@ import json import logging -from collections.abc import Mapping from typing import Any, cast +from graphon.model_runtime.utils.encoders import jsonable_encoder from httpx import get from sqlalchemy import select +from typing_extensions import TypedDict from core.entities.provider_entities import ProviderConfig from core.tools.__base.tool_runtime import ToolRuntime @@ -20,7 +21,6 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager from core.tools.utils.encryption import create_tool_provider_encrypter from core.tools.utils.parser import ApiBasedToolSchemaParser -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from models.tools import ApiToolProvider from services.tools.tools_transform_service import ToolTransformService @@ -28,9 +28,16 @@ from services.tools.tools_transform_service import ToolTransformService logger = logging.getLogger(__name__) +class ApiSchemaParseResult(TypedDict): + schema_type: str + parameters_schema: list[dict[str, Any]] + credentials_schema: list[dict[str, Any]] + warning: dict[str, str] + + class ApiToolManageService: @staticmethod - def parser_api_schema(schema: str) -> Mapping[str, Any]: + def parser_api_schema(schema: str) -> ApiSchemaParseResult: """ parse api schema to tool bundle """ @@ -71,7 +78,7 @@ class ApiToolManageService: ] return cast( - Mapping, + ApiSchemaParseResult, jsonable_encoder( { "schema_type": schema_type, diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 6797a67ddef..8e3c36e0998 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -275,7 +275,7 @@ class BuiltinToolManageService: user_id=user_id, provider=provider, encrypted_credentials=json.dumps(encrypter.encrypt(credentials)), - credential_type=api_type.value, + credential_type=api_type, name=name, expires_at=expires_at if expires_at is not None else -1, ) @@ -314,7 +314,7 @@ class BuiltinToolManageService: .filter_by( tenant_id=tenant_id, provider=provider, - credential_type=credential_type.value, + credential_type=credential_type, ) .order_by(BuiltinToolProvider.created_at.desc()) .all() diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py index 0be106f5977..deb26438a8f 100644 --- a/api/services/tools/mcp_tools_manage_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -18,6 +18,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache from core.mcp.auth.auth_flow import auth from core.mcp.auth_client import MCPClientWithAuthRetry from core.mcp.error import MCPAuthError, MCPError +from core.mcp.types import Tool as MCPTool from core.tools.entities.api_entities import ToolProviderApiEntity from core.tools.utils.encryption import ProviderConfigEncrypter from models.tools import MCPToolProvider @@ -681,7 +682,7 @@ class MCPToolManageService: raise ValueError(f"Failed to re-connect MCP server: {e}") from e def _build_tool_provider_response( - self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list + self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list[MCPTool] ) -> ToolProviderApiEntity: """Build API response for tool provider.""" user = db_provider.load_user() @@ -703,7 +704,7 @@ class MCPToolManageService: raise ValueError(f"MCP tool {server_url} already exists") if "unique_mcp_provider_server_identifier" in error_msg: raise ValueError(f"MCP tool {server_identifier} already exists") - raise + raise error def _is_valid_url(self, url: str) -> bool: """Validate URL format.""" diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index b6e5367023c..7cd61e3162a 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -3,7 +3,7 @@ import logging from collections.abc import Mapping from typing import Any, Union -from pydantic import ValidationError +from pydantic import TypeAdapter, ValidationError from yarl import URL from configs import dify_config @@ -31,6 +31,8 @@ from services.plugin.plugin_service import PluginService logger = logging.getLogger(__name__) +_mcp_tools_adapter: TypeAdapter[list[MCPTool]] = TypeAdapter(list[MCPTool]) + class ToolTransformService: _MCP_SCHEMA_TYPE_RESOLUTION_MAX_DEPTH = 10 @@ -53,7 +55,7 @@ class ToolTransformService: if isinstance(icon, str): return json.loads(icon) return icon - except Exception: + except (json.JSONDecodeError, ValueError): return {"background": "#252525", "content": "\ud83d\ude01"} elif provider_type == ToolProviderType.MCP: return icon @@ -247,8 +249,8 @@ class ToolTransformService: response = provider_entity.to_api_response(user_name=user_name, include_sensitive=include_sensitive) try: - mcp_tools = [MCPTool(**tool) for tool in json.loads(db_provider.tools)] - except (ValidationError, json.JSONDecodeError): + mcp_tools = _mcp_tools_adapter.validate_json(db_provider.tools) + except (ValidationError, ValueError): mcp_tools = [] # Add additional fields specific to the transform response["id"] = db_provider.server_identifier if not for_list else db_provider.id @@ -423,7 +425,7 @@ class ToolTransformService: id=provider.id, name=provider.name, provider=provider.provider, - credential_type=CredentialType.of(provider.credential_type), + credential_type=provider.credential_type, is_default=provider.is_default, credentials=credentials, ) diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 101b2fe5a25..fb6b5bea24d 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -2,6 +2,7 @@ import json import logging from datetime import datetime +from graphon.model_runtime.utils.encoders import jsonable_encoder from sqlalchemy import or_, select from sqlalchemy.orm import Session @@ -12,7 +13,6 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from models.model import App from models.tools import WorkflowToolProvider diff --git a/api/services/trigger/schedule_service.py b/api/services/trigger/schedule_service.py index 7e9d010d2f3..25e80770b83 100644 --- a/api/services/trigger/schedule_service.py +++ b/api/services/trigger/schedule_service.py @@ -2,6 +2,7 @@ import json import logging from datetime import datetime +from graphon.entities.graph_config import NodeConfigDict from sqlalchemy import select from sqlalchemy.orm import Session @@ -13,7 +14,6 @@ from core.workflow.nodes.trigger_schedule.entities import ( VisualConfig, ) from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError, ScheduleNotFoundError -from dify_graph.entities.graph_config import NodeConfigDict from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h from models.account import Account, TenantAccountJoin from models.trigger import WorkflowSchedulePlan diff --git a/api/services/trigger/trigger_provider_service.py b/api/services/trigger/trigger_provider_service.py index 688993c7987..008d8bdb8af 100644 --- a/api/services/trigger/trigger_provider_service.py +++ b/api/services/trigger/trigger_provider_service.py @@ -198,7 +198,7 @@ class TriggerProviderService: credentials=dict(credential_encrypter.encrypt(dict(credentials))) if credential_encrypter else {}, - credential_type=credential_type.value, + credential_type=credential_type, credential_expires_at=credential_expires_at, expires_at=expires_at, ) diff --git a/api/services/trigger/trigger_service.py b/api/services/trigger/trigger_service.py index 24bbeda3293..d72c0416092 100644 --- a/api/services/trigger/trigger_service.py +++ b/api/services/trigger/trigger_service.py @@ -5,6 +5,7 @@ from collections.abc import Mapping from typing import Any from flask import Request, Response +from graphon.entities.graph_config import NodeConfigDict from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.orm import Session @@ -18,7 +19,6 @@ from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager from core.trigger.utils.encryption import create_trigger_provider_encrypter_for_subscription from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData -from dify_graph.entities.graph_config import NodeConfigDict from extensions.ext_database import db from extensions.ext_redis import redis_client from models.model import App diff --git a/api/services/trigger/trigger_subscription_builder_service.py b/api/services/trigger/trigger_subscription_builder_service.py index 37f852da3eb..889717df727 100644 --- a/api/services/trigger/trigger_subscription_builder_service.py +++ b/api/services/trigger/trigger_subscription_builder_service.py @@ -1,4 +1,3 @@ -import json import logging import uuid from collections.abc import Mapping @@ -7,6 +6,7 @@ from datetime import datetime from typing import Any from flask import Request, Response +from pydantic import TypeAdapter from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.entities.request import TriggerDispatchResponse @@ -29,6 +29,8 @@ from services.trigger.trigger_provider_service import TriggerProviderService logger = logging.getLogger(__name__) +_request_logs_adapter: TypeAdapter[list[RequestLog]] = TypeAdapter(list[RequestLog]) + class TriggerSubscriptionBuilderService: """Service for managing trigger providers and credentials""" @@ -398,7 +400,7 @@ class TriggerSubscriptionBuilderService: cache_key = cls.encode_cache_key(endpoint_id) subscription_cache = redis_client.get(cache_key) if subscription_cache: - return SubscriptionBuilder.model_validate(json.loads(subscription_cache)) + return SubscriptionBuilder.model_validate_json(subscription_cache) return None @@ -423,12 +425,16 @@ class TriggerSubscriptionBuilderService: ) key = f"trigger:subscription:builder:logs:{endpoint_id}" - logs = json.loads(redis_client.get(key) or "[]") - logs.append(log.model_dump(mode="json")) + logs = _request_logs_adapter.validate_json(redis_client.get(key) or b"[]") + logs.append(log) # Keep last N logs logs = logs[-cls.__VALIDATION_REQUEST_CACHE_COUNT__ :] - redis_client.setex(key, cls.__VALIDATION_REQUEST_CACHE_EXPIRE_SECONDS__, json.dumps(logs, default=str)) + redis_client.setex( + key, + cls.__VALIDATION_REQUEST_CACHE_EXPIRE_SECONDS__, + _request_logs_adapter.dump_json(logs), + ) @classmethod def list_logs(cls, endpoint_id: str) -> list[RequestLog]: @@ -437,7 +443,7 @@ class TriggerSubscriptionBuilderService: logs_json = redis_client.get(key) if not logs_json: return [] - return [RequestLog.model_validate(log) for log in json.loads(logs_json)] + return _request_logs_adapter.validate_json(logs_json) @classmethod def process_builder_validation_endpoint(cls, endpoint_id: str, request: Request) -> Response | None: diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py index 3c1a4cc7472..c03275497d7 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -7,6 +7,9 @@ from typing import Any import orjson from flask import request +from graphon.entities.graph_config import NodeConfigDict +from graphon.file import FileTransferMethod +from graphon.variables.types import ArrayValidation, SegmentType from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.orm import Session @@ -15,6 +18,7 @@ from werkzeug.exceptions import RequestEntityTooLarge from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.file_access import DatabaseFileAccessController from core.tools.tool_file_manager import ToolFileManager from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE from core.workflow.nodes.trigger_webhook.entities import ( @@ -23,9 +27,6 @@ from core.workflow.nodes.trigger_webhook.entities import ( WebhookData, WebhookParameter, ) -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.file.models import FileTransferMethod -from dify_graph.variables.types import ArrayValidation, SegmentType from enums.quota_type import QuotaType from extensions.ext_database import db from extensions.ext_redis import redis_client @@ -46,6 +47,7 @@ except ImportError: magic = None # type: ignore[assignment] logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() class WebhookService: @@ -422,6 +424,7 @@ class WebhookService: return file_factory.build_from_mapping( mapping=mapping, tenant_id=webhook_trigger.tenant_id, + access_controller=_file_access_controller, ) @classmethod diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py index 60dc1dedb81..62916cc2c93 100644 --- a/api/services/variable_truncator.py +++ b/api/services/variable_truncator.py @@ -5,10 +5,9 @@ from abc import ABC, abstractmethod from collections.abc import Mapping from typing import Any, Generic, TypeAlias, TypeVar, overload -from configs import dify_config -from dify_graph.file.models import File -from dify_graph.nodes.variable_assigner.common.helpers import UpdatedVariable -from dify_graph.variables.segments import ( +from graphon.file import File +from graphon.nodes.variable_assigner.common.helpers import UpdatedVariable +from graphon.variables.segments import ( ArrayFileSegment, ArraySegment, BooleanSegment, @@ -20,7 +19,9 @@ from dify_graph.variables.segments import ( Segment, StringSegment, ) -from dify_graph.variables.utils import dumps_with_segments +from graphon.variables.utils import dumps_with_segments + +from configs import dify_config _MAX_DEPTH = 100 diff --git a/api/services/vector_service.py b/api/services/vector_service.py index b66fdd7a20e..3f78b823a63 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -1,14 +1,15 @@ import logging +from graphon.model_runtime.entities.model_entities import ModelType + from core.model_manager import ModelInstance, ModelManager from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector from core.rag.index_processor.constant.doc_type import DocType -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import AttachmentDocument, Document -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from models import UploadFile from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding @@ -45,9 +46,9 @@ class VectorService: if not processing_rule: raise ValueError("No processing rule found.") # get embedding model instance - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) if dataset.embedding_model_provider: embedding_model_instance = model_manager.get_model_instance( @@ -112,7 +113,7 @@ class VectorService: "dataset_id": segment.dataset_id, }, ) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # update vector index vector = Vector(dataset=dataset) vector.delete_by_ids([segment.index_node_id]) @@ -197,7 +198,7 @@ class VectorService: "dataset_id": child_segment.dataset_id, }, ) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # save vector index vector = Vector(dataset=dataset) vector.add_texts([child_document], duplicate_check=True) @@ -237,7 +238,7 @@ class VectorService: delete_node_ids.append(update_child_chunk.index_node_id) for delete_child_chunk in delete_child_chunks: delete_node_ids.append(delete_child_chunk.index_node_id) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # update vector index vector = Vector(dataset=dataset) if delete_node_ids: @@ -252,7 +253,7 @@ class VectorService: @classmethod def update_multimodel_vector(cls, segment: DocumentSegment, attachment_ids: list[str], dataset: Dataset): - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: return attachments = segment.attachments diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 006483fe979..31367f72fab 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -1,5 +1,12 @@ import json -from typing import Any, TypedDict +from typing import Any + +from graphon.file import FileUploadConfig +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes import BuiltinNodeTypes +from graphon.variables.input_entities import VariableEntity +from typing_extensions import TypedDict from core.app.app_config.entities import ( DatasetEntity, @@ -15,11 +22,6 @@ from core.app.apps.completion.app_config_manager import CompletionAppConfigManag from core.helper import encrypter from core.prompt.simple_prompt_transform import SimplePromptTransform from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from dify_graph.file.models import FileUploadConfig -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.nodes import BuiltinNodeTypes -from dify_graph.variables.input_entities import VariableEntity from events.app_event import app_was_created from extensions.ext_database import db from models import Account @@ -34,6 +36,17 @@ class _NodeType(TypedDict): data: dict[str, Any] +class _EdgeType(TypedDict): + id: str + source: str + target: str + + +class WorkflowGraph(TypedDict): + nodes: list[_NodeType] + edges: list[_EdgeType] + + class WorkflowConverter: """ App Convert to Workflow Mode @@ -107,7 +120,7 @@ class WorkflowConverter: app_config = self._convert_to_app_config(app_model=app_model, app_model_config=app_model_config) # init workflow graph - graph: dict[str, Any] = {"nodes": [], "edges": []} + graph: WorkflowGraph = {"nodes": [], "edges": []} # Convert list: # - variables -> start @@ -385,7 +398,7 @@ class WorkflowConverter: self, original_app_mode: AppMode, new_app_mode: AppMode, - graph: dict, + graph: WorkflowGraph, model_config: ModelConfigEntity, prompt_template: PromptTemplateEntity, file_upload: FileUploadConfig | None = None, @@ -595,7 +608,7 @@ class WorkflowConverter: "data": {"title": "ANSWER", "type": BuiltinNodeTypes.ANSWER, "answer": "{{#llm.text#}}"}, } - def _create_edge(self, source: str, target: str): + def _create_edge(self, source: str, target: str) -> _EdgeType: """ Create Edge :param source: source node id @@ -604,7 +617,7 @@ class WorkflowConverter: """ return {"id": f"{source}-{target}", "source": source, "target": target} - def _append_node(self, graph: dict[str, Any], node: _NodeType): + def _append_node(self, graph: WorkflowGraph, node: _NodeType): """ Append Node to Graph diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index 7147fe1eab1..bf178e8a44d 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -3,10 +3,11 @@ import uuid from datetime import datetime from typing import Any +from graphon.enums import WorkflowExecutionStatus from sqlalchemy import and_, func, or_, select from sqlalchemy.orm import Session +from typing_extensions import TypedDict -from dify_graph.enums import WorkflowExecutionStatus from models import Account, App, EndUser, TenantAccountJoin, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun from models.enums import AppTriggerType, CreatorUserRole from models.trigger import WorkflowTriggerLog @@ -14,6 +15,10 @@ from services.plugin.plugin_service import PluginService from services.workflow.entities import TriggerMetadata +class LogViewDetails(TypedDict): + trigger_metadata: dict[str, Any] | None + + # Since the workflow_app_log table has exceeded 100 million records, we use an additional details field to extend it class LogView: """Lightweight wrapper for WorkflowAppLog with computed details. @@ -22,12 +27,12 @@ class LogView: - Proxies all other attributes to the underlying `WorkflowAppLog` """ - def __init__(self, log: WorkflowAppLog, details: dict | None): + def __init__(self, log: WorkflowAppLog, details: LogViewDetails | None): self.log = log self.details_ = details @property - def details(self) -> dict | None: + def details(self) -> LogViewDetails | None: return self.details_ def __getattr__(self, name): diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index fb1a3f30c01..98e338a2d4a 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -6,6 +6,19 @@ from concurrent.futures import ThreadPoolExecutor from enum import StrEnum from typing import Any, ClassVar +from graphon.enums import NodeType +from graphon.file import File +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.variable_assigner.common.helpers import get_updated_variables +from graphon.variable_loader import VariableLoader +from graphon.variables import Segment, StringSegment, VariableBase +from graphon.variables.consts import SELECTORS_LENGTH +from graphon.variables.segments import ( + ArrayFileSegment, + FileSegment, +) +from graphon.variables.types import SegmentType +from graphon.variables.utils import dumps_with_segments from sqlalchemy import Engine, orm, select from sqlalchemy.dialects.mysql import insert as mysql_insert from sqlalchemy.dialects.postgresql import insert as pg_insert @@ -14,28 +27,23 @@ from sqlalchemy.sql.expression import and_, or_ from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.file_access import DatabaseFileAccessController from core.trigger.constants import is_trigger_node_type -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from dify_graph.enums import NodeType, SystemVariableKey -from dify_graph.file.models import File -from dify_graph.nodes import BuiltinNodeTypes -from dify_graph.nodes.variable_assigner.common.helpers import get_updated_variables -from dify_graph.variable_loader import VariableLoader -from dify_graph.variables import Segment, StringSegment, VariableBase -from dify_graph.variables.consts import SELECTORS_LENGTH -from dify_graph.variables.segments import ( - ArrayFileSegment, - FileSegment, +from core.workflow.system_variables import SystemVariableKey +from core.workflow.variable_prefixes import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, + RAG_PIPELINE_VARIABLE_NODE_ID, + SYSTEM_VARIABLE_NODE_ID, ) -from dify_graph.variables.types import SegmentType -from dify_graph.variables.utils import dumps_with_segments from extensions.ext_storage import storage from factories.file_factory import StorageKeyLoader from factories.variable_factory import build_segment, segment_to_variable from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 from models import Account, App, Conversation -from models.enums import DraftVariableType +from models.enums import ConversationFromSource, DraftVariableType +from models.utils.file_input_compat import build_file_from_stored_mapping from models.workflow import Workflow, WorkflowDraftVariable, WorkflowDraftVariableFile, is_system_variable_editable from repositories.factory import DifyAPIRepositoryFactory from services.file_service import FileService @@ -71,7 +79,7 @@ class UpdateNotSupportedError(WorkflowDraftVariableError): class DraftVarLoader(VariableLoader): # This implements the VariableLoader interface for loading draft variables. # - # ref: dify_graph.variable_loader.VariableLoader + # ref: graphon.variable_loader.VariableLoader # Database engine used for loading variables. _engine: Engine @@ -120,7 +128,11 @@ class DraftVarLoader(VariableLoader): elif isinstance(value, ArrayFileSegment): files.extend(value.value) with Session(bind=self._engine) as session: - storage_key_loader = StorageKeyLoader(session, tenant_id=self._tenant_id) + storage_key_loader = StorageKeyLoader( + session, + tenant_id=self._tenant_id, + access_controller=DatabaseFileAccessController(), + ) storage_key_loader.load_storage_keys(files) offloaded_draft_vars = [] @@ -174,7 +186,7 @@ class DraftVarLoader(VariableLoader): return (draft_var.node_id, draft_var.name), variable deserialized = json.loads(content) - segment = WorkflowDraftVariable.build_segment_with_type(variable_file.value_type, deserialized) + segment = draft_var.build_segment_from_serialized_value(variable_file.value_type, deserialized) variable = segment_to_variable( segment=segment, selector=draft_var.get_selector(), @@ -601,7 +613,7 @@ class WorkflowDraftVariableService: system_instruction_tokens=0, status="normal", invoke_from=InvokeFrom.DEBUGGER, - from_source="console", + from_source=ConversationFromSource.CONSOLE, from_end_user_id=None, from_account_id=account_id, ) @@ -838,6 +850,12 @@ class DraftVariableSaver: self._user = user self._enclosing_node_id = enclosing_node_id + def _resolve_app_tenant_id(self) -> str: + tenant_id = self._session.scalar(select(App.tenant_id).where(App.id == self._app_id)) + if not tenant_id: + raise ValueError(f"Unable to resolve tenant_id for app {self._app_id}") + return tenant_id + def _create_dummy_output_variable(self): return WorkflowDraftVariable.new_node_variable( app_id=self._app_id, @@ -892,27 +910,18 @@ class DraftVariableSaver: for name, value in output.items(): value_seg = _build_segment_for_serialized_values(value) node_id, name = self._normalize_variable_for_start_node(name) - # If node_id is not `sys`, it means that the variable is a user-defined input field - # in `Start` node. - if node_id != SYSTEM_VARIABLE_NODE_ID: - draft_vars.append( - WorkflowDraftVariable.new_node_variable( - app_id=self._app_id, - user_id=self._user.id, - node_id=self._node_id, - name=name, - node_execution_id=self._node_execution_id, - value=value_seg, - visible=True, - editable=True, - ) - ) - has_non_sys_variables = True - else: + if node_id == SYSTEM_VARIABLE_NODE_ID: if name == SystemVariableKey.FILES: # Here we know the type of variable must be `array[file]`, we - # just build files from the value. - files = [File.model_validate(v) for v in value] + # just rebuild files from the serialized payload. + tenant_id = self._resolve_app_tenant_id() + files = [ + build_file_from_stored_mapping( + file_mapping=v, + tenant_id=tenant_id, + ) + for v in value + ] if files: value_seg = WorkflowDraftVariable.build_segment_with_type(SegmentType.ARRAY_FILE, files) else: @@ -928,15 +937,47 @@ class DraftVariableSaver: editable=self._should_variable_be_editable(node_id, name), ) ) + elif node_id == CONVERSATION_VARIABLE_NODE_ID: + draft_vars.append( + WorkflowDraftVariable.new_conversation_variable( + app_id=self._app_id, + user_id=self._user.id, + name=name, + value=value_seg, + ) + ) + has_non_sys_variables = True + else: + draft_vars.append( + WorkflowDraftVariable.new_node_variable( + app_id=self._app_id, + user_id=self._user.id, + node_id=node_id, + name=name, + node_execution_id=self._node_execution_id, + value=value_seg, + visible=self._should_variable_be_visible(node_id, self._node_type, name), + editable=self._should_variable_be_editable(node_id, name), + ) + ) + has_non_sys_variables = True if not has_non_sys_variables: draft_vars.append(self._create_dummy_output_variable()) return draft_vars def _normalize_variable_for_start_node(self, name: str) -> tuple[str, str]: - if not name.startswith(f"{SYSTEM_VARIABLE_NODE_ID}."): - return self._node_id, name - _, name_ = name.split(".", maxsplit=1) - return SYSTEM_VARIABLE_NODE_ID, name_ + for reserved_node_id in ( + SYSTEM_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, + CONVERSATION_VARIABLE_NODE_ID, + RAG_PIPELINE_VARIABLE_NODE_ID, + ): + prefix = f"{reserved_node_id}." + if name.startswith(prefix): + _, name_ = name.split(".", maxsplit=1) + return reserved_node_id, name_ + + return self._node_id, name def _build_variables_from_mapping(self, output: Mapping[str, Any]) -> list[WorkflowDraftVariable]: draft_vars = [] diff --git a/api/services/workflow_event_snapshot_service.py b/api/services/workflow_event_snapshot_service.py index 8f323ebb8b2..601e9261fc6 100644 --- a/api/services/workflow_event_snapshot_service.py +++ b/api/services/workflow_event_snapshot_service.py @@ -9,6 +9,10 @@ from collections.abc import Generator, Mapping, Sequence from dataclasses import dataclass from typing import Any +from graphon.entities import WorkflowStartReason +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import desc, select from sqlalchemy.orm import Session, sessionmaker @@ -22,10 +26,6 @@ from core.app.entities.task_entities import ( WorkflowStartStreamResponse, ) from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext -from dify_graph.entities import WorkflowStartReason -from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus -from dify_graph.runtime import GraphRuntimeState -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from models.model import AppMode, Message from models.workflow import WorkflowNodeExecutionTriggeredFrom, WorkflowRun from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot diff --git a/api/services/workflow_restore.py b/api/services/workflow_restore.py new file mode 100644 index 00000000000..083235d228c --- /dev/null +++ b/api/services/workflow_restore.py @@ -0,0 +1,58 @@ +"""Shared helpers for restoring published workflow snapshots into drafts. + +Both app workflows and RAG pipeline workflows restore the same workflow fields +from a published snapshot into a draft. Keeping that field-copy logic in one +place prevents the two restore paths from drifting when we add or adjust draft +state in the future. Restore stays within a tenant, so we can safely reuse the +serialized workflow storage blobs without decrypting and re-encrypting secrets. +""" + +from collections.abc import Callable +from datetime import datetime + +from models import Account +from models.workflow import Workflow, WorkflowType + +UpdatedAtFactory = Callable[[], datetime] + + +def apply_published_workflow_snapshot_to_draft( + *, + tenant_id: str, + app_id: str, + source_workflow: Workflow, + draft_workflow: Workflow | None, + account: Account, + updated_at_factory: UpdatedAtFactory, +) -> tuple[Workflow, bool]: + """Copy a published workflow snapshot into a draft workflow record. + + The caller remains responsible for source lookup, validation, flushing, and + post-commit side effects. This helper only centralizes the shared draft + creation/update semantics used by both restore entry points. Features are + copied from the stored JSON payload so restore does not normalize and dirty + the published source row before the caller commits. + """ + if not draft_workflow: + workflow_type = ( + source_workflow.type.value if isinstance(source_workflow.type, WorkflowType) else source_workflow.type + ) + draft_workflow = Workflow( + tenant_id=tenant_id, + app_id=app_id, + type=workflow_type, + version=Workflow.VERSION_DRAFT, + graph=source_workflow.graph, + features=source_workflow.serialized_features, + created_by=account.id, + ) + draft_workflow.copy_serialized_variable_storage_from(source_workflow) + return draft_workflow, True + + draft_workflow.graph = source_workflow.graph + draft_workflow.features = source_workflow.serialized_features + draft_workflow.updated_by = account.id + draft_workflow.updated_at = updated_at_factory() + draft_workflow.copy_serialized_variable_storage_from(source_workflow) + + return draft_workflow, False diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index e13cdd5f278..3b3ee6dd92e 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -5,6 +5,31 @@ import uuid from collections.abc import Callable, Generator, Mapping, Sequence from typing import Any, cast +from graphon.entities import GraphInitParams, WorkflowNodeExecution +from graphon.entities.graph_config import NodeConfigDict +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import ( + ErrorStrategy, + NodeType, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.errors import WorkflowNodeRunFailedError +from graphon.file import File +from graphon.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent +from graphon.node_events import NodeRunResult +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.base.node import Node +from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config +from graphon.nodes.human_input.entities import HumanInputNodeData, validate_human_input_submission +from graphon.nodes.human_input.enums import HumanInputFormKind +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.nodes.start.entities import StartNodeData +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import load_into_variable_pool +from graphon.variables import VariableBase +from graphon.variables.input_entities import VariableEntityType +from graphon.variables.variables import Variable from sqlalchemy import exists, select from sqlalchemy.orm import Session, sessionmaker @@ -12,43 +37,22 @@ from configs import dify_config from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context +from core.app.file_access import DatabaseFileAccessController +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager from core.repositories import DifyCoreRepositoryFactory -from core.repositories.human_input_repository import HumanInputFormRepositoryImpl +from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl from core.trigger.constants import is_trigger_node_type -from core.workflow.node_factory import LATEST_VERSION, get_node_type_classes_mapping, is_start_node_type -from core.workflow.workflow_entry import WorkflowEntry -from dify_graph.entities import GraphInitParams, WorkflowNodeExecution -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import ( - ErrorStrategy, - NodeType, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from dify_graph.errors import WorkflowNodeRunFailedError -from dify_graph.file import File -from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes import BuiltinNodeTypes -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( DeliveryChannelConfig, - HumanInputNodeData, - apply_debug_email_recipient, - validate_human_input_submission, + normalize_human_input_node_data_for_graph, + parse_human_input_delivery_methods, ) -from dify_graph.nodes.human_input.enums import HumanInputFormKind -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.repositories.human_input_form_repository import FormCreateParams -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variable_loader import load_into_variable_pool -from dify_graph.variables import VariableBase -from dify_graph.variables.input_entities import VariableEntityType -from dify_graph.variables.variables import Variable +from core.workflow.node_factory import LATEST_VERSION, get_node_type_classes_mapping, is_start_node_type +from core.workflow.node_runtime import DifyHumanInputNodeRuntime, apply_dify_debug_email_recipient +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables, default_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool +from core.workflow.workflow_entry import WorkflowEntry +from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace from enums.cloud_plan import CloudPlan from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db @@ -63,7 +67,12 @@ from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeEx from repositories.factory import DifyAPIRepositoryFactory from services.billing_service import BillingService from services.enterprise.plugin_manager_service import PluginCredentialType -from services.errors.app import IsDraftWorkflowError, TriggerNodeLimitExceededError, WorkflowHashNotEqualError +from services.errors.app import ( + IsDraftWorkflowError, + TriggerNodeLimitExceededError, + WorkflowHashNotEqualError, + WorkflowNotFoundError, +) from services.workflow.workflow_converter import WorkflowConverter from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError @@ -75,6 +84,9 @@ from .human_input_delivery_test_service import ( HumanInputDeliveryTestService, ) from .workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader, WorkflowDraftVariableService +from .workflow_restore import apply_published_workflow_snapshot_to_draft + +_file_access_controller = DatabaseFileAccessController() class WorkflowService: @@ -279,6 +291,43 @@ class WorkflowService: # return draft workflow return workflow + def restore_published_workflow_to_draft( + self, + *, + app_model: App, + workflow_id: str, + account: Account, + ) -> Workflow: + """Restore a published workflow snapshot into the draft workflow. + + Secret environment variables are copied server-side from the selected + published workflow so the normal draft sync flow stays stateless. + """ + source_workflow = self.get_published_workflow_by_id(app_model=app_model, workflow_id=workflow_id) + if not source_workflow: + raise WorkflowNotFoundError("Workflow not found.") + + self.validate_features_structure(app_model=app_model, features=source_workflow.normalized_features_dict) + self.validate_graph_structure(graph=source_workflow.graph_dict) + + draft_workflow = self.get_draft_workflow(app_model=app_model) + draft_workflow, is_new_draft = apply_published_workflow_snapshot_to_draft( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + source_workflow=source_workflow, + draft_workflow=draft_workflow, + account=account, + updated_at_factory=naive_utc_now, + ) + + if is_new_draft: + db.session.add(draft_workflow) + + db.session.commit() + app_draft_workflow_was_synced.send(app_model, synced_draft_workflow=draft_workflow) + + return draft_workflow + def publish_workflow( self, *, @@ -443,13 +492,15 @@ class WorkflowService: :raises ValueError: If the model configuration is invalid or credentials fail policy checks """ try: - from core.model_manager import ModelManager - from core.provider_manager import ProviderManager - from dify_graph.model_runtime.entities.model_entities import ModelType + from graphon.model_runtime.entities.model_entities import ModelType + + # Model instance resolution and provider status lookup must reuse the + # same request-scoped runtime so validation does not silently split + # provider discovery and credential reads across different caches. + assembly = create_plugin_model_assembly(tenant_id=tenant_id) # Get model instance to validate provider+model combination - model_manager = ModelManager() - model_manager.get_model_instance( + assembly.model_manager.get_model_instance( tenant_id=tenant_id, provider=provider, model_type=ModelType.LLM, model=model_name ) @@ -458,8 +509,7 @@ class WorkflowService: # If it fails, an exception will be raised # Additionally, check the model status to ensure it's ACTIVE - provider_manager = ProviderManager() - provider_configurations = provider_manager.get_configurations(tenant_id) + provider_configurations = assembly.provider_manager.get_configurations(tenant_id) models = provider_configurations.get_models(provider=provider, model_type=ModelType.LLM) target_model = None @@ -564,11 +614,10 @@ class WorkflowService: :return: True if load balancing is enabled, False otherwise """ try: - from core.provider_manager import ProviderManager - from dify_graph.model_runtime.entities.model_entities import ModelType + from graphon.model_runtime.entities.model_entities import ModelType # Get provider configurations - provider_manager = ProviderManager() + provider_manager = create_plugin_provider_manager(tenant_id=tenant_id) provider_configurations = provider_manager.get_configurations(tenant_id) provider_configuration = provider_configurations.get(provider) @@ -722,6 +771,7 @@ class WorkflowService: user_id=account.id, user_inputs=user_inputs, workflow=draft_workflow, + node_id=node_id, # NOTE(QuantumGhost): We rely on `DraftVarLoader` to load conversation variables. conversation_variables=[], node_type=node_type, @@ -729,11 +779,13 @@ class WorkflowService: ) else: - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs=user_inputs, - environment_variables=draft_workflow.environment_variables, - conversation_variables=[], + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=default_system_variables(), + environment_variables=draft_workflow.environment_variables, + ), ) variable_loader = DraftVarLoader( @@ -798,6 +850,13 @@ class WorkflowService: draft_var_saver.save(process_data=node_execution.process_data, outputs=outputs) session.commit() + enqueue_draft_node_execution_trace( + execution=workflow_node_execution, + outputs=outputs, + workflow_execution_id=None, + user_id=account.id, + ) + return workflow_node_execution def get_human_input_form_preview( @@ -852,7 +911,6 @@ class WorkflowService: node_id=node_id, node_title=node.title, resolved_default_values=resolved_default_values, - form_token=None, ) return human_input_required.model_dump(mode="json") @@ -952,17 +1010,20 @@ class WorkflowService: if node_type != BuiltinNodeTypes.HUMAN_INPUT: raise ValueError("Node type must be human-input.") - node_data = HumanInputNodeData.model_validate(node_config["data"], from_attributes=True) + node_data = HumanInputNodeData.model_validate( + normalize_human_input_node_data_for_graph(node_config["data"]), + from_attributes=True, + ) delivery_method = self._resolve_human_input_delivery_method( node_data=node_data, delivery_method_id=delivery_method_id, ) if delivery_method is None: raise ValueError("Delivery method not found.") - delivery_method = apply_debug_email_recipient( + delivery_method = apply_dify_debug_email_recipient( delivery_method, enabled=True, - user_id=account.id, + actor_id=account.id, ) variable_pool = self._build_human_input_variable_pool( @@ -1012,7 +1073,7 @@ class WorkflowService: node_data: HumanInputNodeData, delivery_method_id: str, ) -> DeliveryChannelConfig | None: - for method in node_data.delivery_methods: + for method in parse_human_input_delivery_methods(node_data): if str(method.id) == delivery_method_id: return method return None @@ -1027,9 +1088,8 @@ class WorkflowService: rendered_content: str, resolved_default_values: Mapping[str, Any], ) -> tuple[str, list[DeliveryTestEmailRecipient]]: - repo = HumanInputFormRepositoryImpl(tenant_id=app_model.tenant_id) + repo = HumanInputFormRepositoryImpl(tenant_id=app_model.tenant_id, app_id=app_model.id) params = FormCreateParams( - app_id=app_model.id, workflow_execution_id=None, node_id=node_id, form_config=node_data, @@ -1058,7 +1118,7 @@ class WorkflowService: continue try: payload = json.loads(recipient.recipient_payload) - except Exception: + except (json.JSONDecodeError, ValueError): logger.exception("Failed to parse human input recipient payload for delivery test.") continue email = payload.get("email") @@ -1095,7 +1155,7 @@ class WorkflowService: config=node_config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, - form_repository=HumanInputFormRepositoryImpl(tenant_id=workflow.tenant_id), + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) return node @@ -1112,11 +1172,13 @@ class WorkflowService: draft_var_srv = WorkflowDraftVariableService(session) draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user_id) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={}, - environment_variables=workflow.environment_variables, - conversation_variables=[], + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=default_system_variables(), + environment_variables=workflow.environment_variables, + ), ) variable_loader = DraftVarLoader( @@ -1376,10 +1438,10 @@ class WorkflowService: Raises: ValueError: If the node data format is invalid """ - from dify_graph.nodes.human_input.entities import HumanInputNodeData + from graphon.nodes.human_input.entities import HumanInputNodeData try: - HumanInputNodeData.model_validate(node_data) + HumanInputNodeData.model_validate(normalize_human_input_node_data_for_graph(node_data)) except Exception as e: raise ValueError(f"Invalid HumanInput node data: {str(e)}") @@ -1468,38 +1530,48 @@ def _setup_variable_pool( user_id: str, user_inputs: Mapping[str, Any], workflow: Workflow, + node_id: str, node_type: NodeType, conversation_id: str, conversation_variables: list[VariableBase], ): # Only inject system variables for START node type. if is_start_node_type(node_type): - system_variable = SystemVariable( - user_id=user_id, - app_id=workflow.app_id, - timestamp=int(naive_utc_now().timestamp()), - workflow_id=workflow.id, - files=files or [], - workflow_execution_id=str(uuid.uuid4()), - ) + system_variable_values: dict[str, Any] = { + "user_id": user_id, + "app_id": workflow.app_id, + "timestamp": int(naive_utc_now().timestamp()), + "workflow_id": workflow.id, + "files": files or [], + "workflow_execution_id": str(uuid.uuid4()), + } - # Only add chatflow-specific variables for non-workflow types + # Only add chatflow-specific variables for non-workflow types. if workflow.type != WorkflowType.WORKFLOW: - system_variable.query = query - system_variable.conversation_id = conversation_id - system_variable.dialogue_count = 1 + system_variable_values.update( + { + "query": query, + "conversation_id": conversation_id, + "dialogue_count": 1, + } + ) + + system_variable = build_system_variables(system_variable_values) else: - system_variable = SystemVariable.default() + system_variable = default_system_variables() # init variable pool - variable_pool = VariablePool( - system_variables=system_variable, - user_inputs=user_inputs, - environment_variables=workflow.environment_variables, - # Based on the definition of `Variable`, - # `VariableBase` instances can be safely used as `Variable` since they are compatible. - conversation_variables=cast(list[Variable], conversation_variables), # + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_variable, + environment_variables=workflow.environment_variables, + conversation_variables=cast(list[Variable], conversation_variables), + ), ) + if is_start_node_type(node_type): + add_node_inputs_to_pool(variable_pool, node_id=node_id, inputs=user_inputs) return variable_pool @@ -1524,7 +1596,7 @@ def _rebuild_single_file(tenant_id: str, value: Any, variable_entity_type: Varia if variable_entity_type == VariableEntityType.FILE: if not isinstance(value, dict): raise ValueError(f"expected dict for file object, got {type(value)}") - return build_from_mapping(mapping=value, tenant_id=tenant_id) + return build_from_mapping(mapping=value, tenant_id=tenant_id, access_controller=_file_access_controller) elif variable_entity_type == VariableEntityType.FILE_LIST: if not isinstance(value, list): raise ValueError(f"expected list for file list object, got {type(value)}") @@ -1532,6 +1604,6 @@ def _rebuild_single_file(tenant_id: str, value: Any, variable_entity_type: Varia return [] if not isinstance(value[0], dict): raise ValueError(f"expected dict for first element in the file list, got {type(value)}") - return build_from_mappings(mappings=value, tenant_id=tenant_id) + return build_from_mappings(mappings=value, tenant_id=tenant_id, access_controller=_file_access_controller) else: raise Exception("unreachable") diff --git a/api/tasks/annotation/add_annotation_to_index_task.py b/api/tasks/annotation/add_annotation_to_index_task.py index a9a8b892c2a..dafa36cc343 100644 --- a/api/tasks/annotation/add_annotation_to_index_task.py +++ b/api/tasks/annotation/add_annotation_to_index_task.py @@ -5,6 +5,7 @@ import click from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -36,7 +37,7 @@ def add_annotation_to_index_task( dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, collection_binding_id=dataset_collection_binding.id, diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py index fc6bf034542..c734e1321ba 100644 --- a/api/tasks/annotation/batch_import_annotations_task.py +++ b/api/tasks/annotation/batch_import_annotations_task.py @@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound from core.db.session_factory import session_factory from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset @@ -67,7 +68,7 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, collection_binding_id=dataset_collection_binding.id, diff --git a/api/tasks/annotation/delete_annotation_index_task.py b/api/tasks/annotation/delete_annotation_index_task.py index 432732af957..c9aa8fadb78 100644 --- a/api/tasks/annotation/delete_annotation_index_task.py +++ b/api/tasks/annotation/delete_annotation_index_task.py @@ -5,6 +5,7 @@ import click from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -26,7 +27,7 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, collection_binding_id=dataset_collection_binding.id, ) diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py index 7b5cd46b007..41cf7ccbf61 100644 --- a/api/tasks/annotation/disable_annotation_reply_task.py +++ b/api/tasks/annotation/disable_annotation_reply_task.py @@ -7,6 +7,7 @@ from sqlalchemy import exists, select from core.db.session_factory import session_factory from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_redis import redis_client from models.dataset import Dataset from models.model import App, AppAnnotationSetting, MessageAnnotation @@ -44,7 +45,7 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, collection_binding_id=app_annotation_setting.collection_binding_id, ) diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index 1fe43c3d624..2c07fe0f31f 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -7,6 +7,7 @@ from sqlalchemy import select from core.db.session_factory import session_factory from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now @@ -64,7 +65,7 @@ def enable_annotation_reply_task( old_dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=old_dataset_collection_binding.provider_name, embedding_model=old_dataset_collection_binding.model_name, collection_binding_id=old_dataset_collection_binding.id, @@ -93,7 +94,7 @@ def enable_annotation_reply_task( dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=embedding_provider_name, embedding_model=embedding_model_name, collection_binding_id=dataset_collection_binding.id, diff --git a/api/tasks/annotation/update_annotation_to_index_task.py b/api/tasks/annotation/update_annotation_to_index_task.py index 6ff34c0e745..f41da1d373e 100644 --- a/api/tasks/annotation/update_annotation_to_index_task.py +++ b/api/tasks/annotation/update_annotation_to_index_task.py @@ -5,6 +5,7 @@ import click from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -37,7 +38,7 @@ def update_annotation_to_index_task( dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, collection_binding_id=dataset_collection_binding.id, diff --git a/api/tasks/app_generate/workflow_execute_task.py b/api/tasks/app_generate/workflow_execute_task.py index 174aa50343b..489467651d0 100644 --- a/api/tasks/app_generate/workflow_execute_task.py +++ b/api/tasks/app_generate/workflow_execute_task.py @@ -7,6 +7,7 @@ from typing import Annotated, Any, TypeAlias, Union from celery import shared_task from flask import current_app, json +from graphon.runtime import GraphRuntimeState from pydantic import BaseModel, Discriminator, Field, Tag from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker @@ -21,7 +22,6 @@ from core.app.entities.app_invoke_entities import ( ) from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext from core.repositories import DifyCoreRepositoryFactory -from dify_graph.runtime import GraphRuntimeState from extensions.ext_database import db from libs.flask_utils import set_login_user from models.account import Account @@ -239,13 +239,18 @@ def _resolve_user_for_run(session: Session, workflow_run: WorkflowRun) -> Accoun def _publish_streaming_response( - response_stream: Generator[str | Mapping[str, Any], None, None], workflow_run_id: str, app_mode: AppMode + response_stream: Generator[str | Mapping[str, Any] | BaseModel, None, None], + workflow_run_id: str, + app_mode: AppMode, ) -> None: topic = MessageBasedAppGenerator.get_response_topic(app_mode, workflow_run_id) for event in response_stream: try: - payload = json.dumps(event) - except TypeError: + if isinstance(event, BaseModel): + payload = json.dumps(event.model_dump(mode="json"), ensure_ascii=False) + else: + payload = json.dumps(event, ensure_ascii=False, default=str) + except (TypeError, ValueError): logger.exception("error while encoding event") continue diff --git a/api/tasks/async_workflow_tasks.py b/api/tasks/async_workflow_tasks.py index d247cf5cf71..0a73c912798 100644 --- a/api/tasks/async_workflow_tasks.py +++ b/api/tasks/async_workflow_tasks.py @@ -10,6 +10,7 @@ from datetime import UTC, datetime from typing import Any from celery import shared_task +from graphon.runtime import GraphRuntimeState from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -21,7 +22,6 @@ from core.app.layers.timeslice_layer import TimeSliceLayer from core.app.layers.trigger_post_layer import TriggerPostLayer from core.db.session_factory import session_factory from core.repositories import DifyCoreRepositoryFactory -from dify_graph.runtime import GraphRuntimeState from extensions.ext_database import db from models.account import Account from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom, WorkflowTriggerStatus diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 49dee009194..20335d9b9f9 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -7,11 +7,12 @@ from pathlib import Path import click import pandas as pd from celery import shared_task +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import func from core.db.session_factory import session_factory from core.model_manager import ModelManager -from dify_graph.model_runtime.entities.model_entities import ModelType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from extensions.ext_storage import storage from libs import helper @@ -109,7 +110,7 @@ def batch_create_segment_to_index_task( df = pd.read_csv(file_path) content = [] for _, row in df.iterrows(): - if document_config["doc_form"] == "qa_model": + if document_config["doc_form"] == IndexStructureType.QA_INDEX: data = {"content": row.iloc[0], "answer": row.iloc[1]} else: data = {"content": row.iloc[0]} @@ -119,8 +120,8 @@ def batch_create_segment_to_index_task( document_segments = [] embedding_model = None - if dataset_config["indexing_technique"] == "high_quality": - model_manager = ModelManager() + if dataset_config["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY: + model_manager = ModelManager.for_tenant(tenant_id=dataset_config["tenant_id"]) embedding_model = model_manager.get_model_instance( tenant_id=dataset_config["tenant_id"], provider=dataset_config["embedding_model_provider"], @@ -159,7 +160,7 @@ def batch_create_segment_to_index_task( status="completed", completed_at=naive_utc_now(), ) - if document_config["doc_form"] == "qa_model": + if document_config["doc_form"] == IndexStructureType.QA_INDEX: segment_document.answer = segment["answer"] segment_document.word_count += len(segment["answer"]) word_count_change += segment_document.word_count diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index e05d63426c6..23a80fa1065 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -10,6 +10,7 @@ from configs import dify_config from core.db.session_factory import session_factory from core.entities.document_task import DocumentTask from core.indexing_runner import DocumentIsPausedError, IndexingRunner +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from libs.datetime_utils import naive_utc_now @@ -126,7 +127,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): logger.warning("Dataset %s not found after indexing", dataset_id) return - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: summary_index_setting = dataset.summary_index_setting if summary_index_setting and summary_index_setting.get("enable"): # expire all session to get latest document's indexing status @@ -150,7 +151,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): ) if ( document.indexing_status == IndexingStatus.COMPLETED - and document.doc_form != "qa_model" + and document.doc_form != IndexStructureType.QA_INDEX and document.need_summary is True ): try: diff --git a/api/tasks/enterprise_telemetry_task.py b/api/tasks/enterprise_telemetry_task.py new file mode 100644 index 00000000000..7d5ea7c0a5a --- /dev/null +++ b/api/tasks/enterprise_telemetry_task.py @@ -0,0 +1,52 @@ +"""Celery worker for enterprise metric/log telemetry events. + +This module defines the Celery task that processes telemetry envelopes +from the enterprise_telemetry queue. It deserializes envelopes and +dispatches them to the EnterpriseMetricHandler. +""" + +import json +import logging + +from celery import shared_task + +from enterprise.telemetry.contracts import TelemetryEnvelope +from enterprise.telemetry.metric_handler import EnterpriseMetricHandler + +logger = logging.getLogger(__name__) + + +@shared_task(queue="enterprise_telemetry") +def process_enterprise_telemetry(envelope_json: str) -> None: + """Process enterprise metric/log telemetry envelope. + + This task is enqueued by the TelemetryGateway for metric/log-only + events. It deserializes the envelope and dispatches to the handler. + + Best-effort processing: logs errors but never raises, to avoid + failing user requests due to telemetry issues. + + Args: + envelope_json: JSON-serialized TelemetryEnvelope. + """ + try: + # Deserialize envelope + envelope_dict = json.loads(envelope_json) + envelope = TelemetryEnvelope.model_validate(envelope_dict) + + # Process through handler + handler = EnterpriseMetricHandler() + handler.handle(envelope) + + logger.debug( + "Successfully processed telemetry envelope: tenant_id=%s, event_id=%s, case=%s", + envelope.tenant_id, + envelope.event_id, + envelope.case, + ) + except Exception: + # Best-effort: log and drop on error, never fail user request + logger.warning( + "Failed to process enterprise telemetry envelope, dropping event", + exc_info=True, + ) diff --git a/api/tasks/generate_summary_index_task.py b/api/tasks/generate_summary_index_task.py index 6493833edc1..e3d82d28516 100644 --- a/api/tasks/generate_summary_index_task.py +++ b/api/tasks/generate_summary_index_task.py @@ -7,6 +7,7 @@ import click from celery import shared_task from core.db.session_factory import session_factory +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models.dataset import Dataset, DocumentSegment from models.dataset import Document as DatasetDocument from services.summary_index_service import SummaryIndexService @@ -59,7 +60,7 @@ def generate_summary_index_task(dataset_id: str, document_id: str, segment_ids: return # Only generate summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: logger.info( click.style( f"Skipping summary generation for dataset {dataset_id}: " diff --git a/api/tasks/human_input_timeout_tasks.py b/api/tasks/human_input_timeout_tasks.py index dd3b6a45308..ca73b4d3745 100644 --- a/api/tasks/human_input_timeout_tasks.py +++ b/api/tasks/human_input_timeout_tasks.py @@ -2,13 +2,13 @@ import logging from datetime import timedelta from celery import shared_task +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from sqlalchemy import or_, select from sqlalchemy.orm import sessionmaker from configs import dify_config from core.repositories.human_input_repository import HumanInputFormSubmissionRepository -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from extensions.ext_database import db from extensions.ext_storage import storage from libs.datetime_utils import ensure_naive_utc, naive_utc_now diff --git a/api/tasks/mail_human_input_delivery_task.py b/api/tasks/mail_human_input_delivery_task.py index d2417833590..a316eec7b95 100644 --- a/api/tasks/mail_human_input_delivery_task.py +++ b/api/tasks/mail_human_input_delivery_task.py @@ -6,13 +6,13 @@ from typing import Any import click from celery import shared_task +from graphon.runtime import GraphRuntimeState, VariablePool from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext -from dify_graph.nodes.human_input.entities import EmailDeliveryConfig, EmailDeliveryMethod -from dify_graph.runtime import GraphRuntimeState, VariablePool +from core.workflow.human_input_compat import EmailDeliveryConfig, EmailDeliveryMethod from extensions.ext_database import db from extensions.ext_mail import mail from models.human_input import ( diff --git a/api/tasks/ops_trace_task.py b/api/tasks/ops_trace_task.py index 72e3b42ca7a..c95b8db0784 100644 --- a/api/tasks/ops_trace_task.py +++ b/api/tasks/ops_trace_task.py @@ -39,17 +39,36 @@ def process_trace_tasks(file_info): trace_info["documents"] = [Document.model_validate(doc) for doc in trace_info["documents"]] try: + trace_type = trace_info_info_map.get(trace_info_type) + if trace_type: + trace_info = trace_type(**trace_info) + + from extensions.ext_enterprise_telemetry import is_enabled as is_ee_telemetry_enabled + + if is_ee_telemetry_enabled(): + from enterprise.telemetry.enterprise_trace import EnterpriseOtelTrace + + try: + EnterpriseOtelTrace().trace(trace_info) + except Exception: + logger.exception("Enterprise trace failed for app_id: %s", app_id) + if trace_instance: with current_app.app_context(): - trace_type = trace_info_info_map.get(trace_info_type) - if trace_type: - trace_info = trace_type(**trace_info) trace_instance.trace(trace_info) + logger.info("Processing trace tasks success, app_id: %s", app_id) except Exception as e: - logger.info("error:\n\n\n%s\n\n\n\n", e) + logger.exception("Processing trace tasks failed, app_id: %s", app_id) failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}" redis_client.incr(failed_key) - logger.info("Processing trace tasks failed, app_id: %s", app_id) finally: - storage.delete(file_path) + try: + storage.delete(file_path) + except Exception as e: + logger.warning( + "Failed to delete trace file %s for app_id %s: %s", + file_path, + app_id, + e, + ) diff --git a/api/tasks/regenerate_summary_index_task.py b/api/tasks/regenerate_summary_index_task.py index 39c2f4103e4..6f490ab7eab 100644 --- a/api/tasks/regenerate_summary_index_task.py +++ b/api/tasks/regenerate_summary_index_task.py @@ -9,6 +9,7 @@ from celery import shared_task from sqlalchemy import or_, select from core.db.session_factory import session_factory +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary from models.dataset import Document as DatasetDocument from services.summary_index_service import SummaryIndexService @@ -52,7 +53,7 @@ def regenerate_summary_index_task( return # Only regenerate summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: logger.info( click.style( f"Skipping summary regeneration for dataset {dataset_id}: " @@ -106,7 +107,7 @@ def regenerate_summary_index_task( ), DatasetDocument.enabled == True, # Document must be enabled DatasetDocument.archived == False, # Document must not be archived - DatasetDocument.doc_form != "qa_model", # Skip qa_model documents + DatasetDocument.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents ) .order_by(DocumentSegment.document_id.asc(), DocumentSegment.position.asc()) .all() @@ -209,7 +210,7 @@ def regenerate_summary_index_task( for dataset_document in dataset_documents: # Skip qa_model documents - if dataset_document.doc_form == "qa_model": + if dataset_document.doc_form == IndexStructureType.QA_INDEX: continue try: diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py index 75ae1f63169..56626e372ea 100644 --- a/api/tasks/trigger_processing_tasks.py +++ b/api/tasks/trigger_processing_tasks.py @@ -12,6 +12,7 @@ from datetime import UTC, datetime from typing import Any from celery import shared_task +from graphon.enums import WorkflowExecutionStatus from sqlalchemy import func, select from sqlalchemy.orm import Session @@ -27,7 +28,6 @@ from core.trigger.entities.entities import TriggerProviderEntity from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData -from dify_graph.enums import WorkflowExecutionStatus from enums.quota_type import QuotaType, unlimited from models.enums import ( AppTriggerType, @@ -179,7 +179,7 @@ def _record_trigger_failure_log( app_id=workflow.app_id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from=WorkflowAppLogCreatedFrom.SERVICE_API.value, + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=created_by_role, created_by=created_by, ) diff --git a/api/tasks/workflow_execution_tasks.py b/api/tasks/workflow_execution_tasks.py index f41118e5925..0c7f74c180a 100644 --- a/api/tasks/workflow_execution_tasks.py +++ b/api/tasks/workflow_execution_tasks.py @@ -9,11 +9,11 @@ import json import logging from celery import shared_task +from graphon.entities import WorkflowExecution +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import select from core.db.session_factory import session_factory -from dify_graph.entities.workflow_execution import WorkflowExecution -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import CreatorUserRole, WorkflowRun from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tasks/workflow_node_execution_tasks.py b/api/tasks/workflow_node_execution_tasks.py index 466ef6c8588..f25ebe3bae4 100644 --- a/api/tasks/workflow_node_execution_tasks.py +++ b/api/tasks/workflow_node_execution_tasks.py @@ -9,13 +9,13 @@ import json import logging from celery import shared_task +from graphon.entities.workflow_node_execution import ( + WorkflowNodeExecution, +) +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import select from core.db.session_factory import session_factory -from dify_graph.entities.workflow_node_execution import ( - WorkflowNodeExecution, -) -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import CreatorUserRole, WorkflowNodeExecutionModel from models.workflow import WorkflowNodeExecutionTriggeredFrom @@ -125,7 +125,7 @@ def _create_node_execution_from_domain( else: node_execution.execution_metadata = "{}" - node_execution.status = execution.status.value + node_execution.status = execution.status node_execution.error = execution.error node_execution.elapsed_time = execution.elapsed_time node_execution.created_by_role = creator_user_role @@ -159,7 +159,7 @@ def _update_node_execution_from_domain(node_execution: WorkflowNodeExecutionMode node_execution.execution_metadata = "{}" # Update other fields - node_execution.status = execution.status.value + node_execution.status = execution.status node_execution.error = execution.error node_execution.elapsed_time = execution.elapsed_time node_execution.finished_at = execution.finished_at diff --git a/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py b/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py index afb6938baa9..d10e5ed13ce 100644 --- a/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py +++ b/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py @@ -13,6 +13,7 @@ from controllers.console.app import wraps from libs.datetime_utils import naive_utc_now from models import App, Tenant from models.account import Account, TenantAccountJoin, TenantAccountRole +from models.enums import ConversationFromSource from models.model import AppMode from services.app_generate_service import AppGenerateService @@ -154,7 +155,7 @@ class TestChatMessageApiPermissions: re_sign_file_url_answer="", answer_tokens=0, provider_response_latency=0.0, - from_source="console", + from_source=ConversationFromSource.CONSOLE, from_end_user_id=None, from_account_id=mock_account.id, feedbacks=[], diff --git a/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py b/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py index 4fdbb7d9f3b..91245e879e4 100644 --- a/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py +++ b/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py @@ -1,8 +1,9 @@ from collections.abc import Generator +from graphon.node_events import StreamCompletedEvent + from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceMessage -from dify_graph.node_events import StreamCompletedEvent def _gen_var_stream() -> Generator[DatasourceMessage, None, None]: diff --git a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py index 3e79792b5bc..3fdea109762 100644 --- a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py +++ b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py @@ -1,7 +1,8 @@ +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult, StreamCompletedEvent + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from core.workflow.nodes.datasource.datasource_node import DatasourceNode -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult, StreamCompletedEvent class _Seg: diff --git a/api/tests/integration_tests/factories/test_storage_key_loader.py b/api/tests/integration_tests/factories/test_storage_key_loader.py index db4bbc1ca16..c1bb8e12453 100644 --- a/api/tests/integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/integration_tests/factories/test_storage_key_loader.py @@ -4,9 +4,10 @@ from unittest.mock import patch from uuid import uuid4 import pytest +from graphon.file import File, FileTransferMethod, FileType from sqlalchemy.orm import Session -from dify_graph.file import File, FileTransferMethod, FileType +from core.app.file_access import DatabaseFileAccessController from extensions.ext_database import db from extensions.storage.storage_type import StorageType from factories.file_factory import StorageKeyLoader @@ -35,7 +36,11 @@ class TestStorageKeyLoader(unittest.TestCase): self.test_tool_files = [] # Create StorageKeyLoader instance - self.loader = StorageKeyLoader(self.session, self.tenant_id) + self.loader = StorageKeyLoader( + self.session, + self.tenant_id, + access_controller=DatabaseFileAccessController(), + ) def tearDown(self): """Clean up test data after each test method.""" @@ -192,19 +197,16 @@ class TestStorageKeyLoader(unittest.TestCase): # Should not raise any exceptions self.loader.load_storage_keys([]) - def test_load_storage_keys_tenant_mismatch(self): - """Test tenant_id validation.""" - # Create file with different tenant_id + def test_load_storage_keys_ignores_legacy_file_tenant_id(self): + """Legacy file tenant_id should not override the loader tenant scope.""" upload_file = self._create_upload_file() file = self._create_file( related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4()) ) - # Should raise ValueError for tenant mismatch - with pytest.raises(ValueError) as context: - self.loader.load_storage_keys([file]) + self.loader.load_storage_keys([file]) - assert "invalid file, expected tenant_id" in str(context.value) + assert file._storage_key == upload_file.key def test_load_storage_keys_missing_file_id(self): """Test with None file.related_id.""" @@ -313,7 +315,7 @@ class TestStorageKeyLoader(unittest.TestCase): with pytest.raises(ValueError) as context: self.loader.load_storage_keys([file_other]) - assert "invalid file, expected tenant_id" in str(context.value) + assert "Upload file not found for id:" in str(context.value) # Current tenant's file should still work self.loader.load_storage_keys([file_current]) @@ -337,7 +339,7 @@ class TestStorageKeyLoader(unittest.TestCase): with pytest.raises(ValueError) as context: self.loader.load_storage_keys([file_current, file_other]) - assert "invalid file, expected tenant_id" in str(context.value) + assert "Upload file not found for id:" in str(context.value) def test_load_storage_keys_duplicate_file_ids(self): """Test handling of duplicate file IDs in the batch.""" @@ -364,6 +366,10 @@ class TestStorageKeyLoader(unittest.TestCase): # Create loader with different session (same underlying connection) with Session(bind=db.engine) as other_session: - other_loader = StorageKeyLoader(other_session, self.tenant_id) + other_loader = StorageKeyLoader( + other_session, + self.tenant_id, + access_controller=DatabaseFileAccessController(), + ) with pytest.raises(ValueError): other_loader.load_storage_keys([file]) diff --git a/api/tests/integration_tests/libs/test_api_token_cache_integration.py b/api/tests/integration_tests/libs/test_api_token_cache_integration.py index 1d7b835fd2d..a942690cbdf 100644 --- a/api/tests/integration_tests/libs/test_api_token_cache_integration.py +++ b/api/tests/integration_tests/libs/test_api_token_cache_integration.py @@ -13,6 +13,7 @@ from unittest.mock import patch import pytest from extensions.ext_redis import redis_client +from models.enums import ApiTokenType from models.model import ApiToken from services.api_token_service import ApiTokenCache, CachedApiToken @@ -279,7 +280,7 @@ class TestEndToEndCacheFlow: test_token = ApiToken() test_token.id = "test-e2e-id" test_token.token = test_token_value - test_token.type = test_scope + test_token.type = ApiTokenType.APP test_token.app_id = "test-app" test_token.tenant_id = "test-tenant" test_token.last_used_at = None diff --git a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py index 4e184c93fde..ce04a158a82 100644 --- a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py +++ b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py @@ -4,27 +4,27 @@ from collections.abc import Generator, Sequence from decimal import Decimal from json import dumps -from core.plugin.entities.plugin_daemon import PluginModelProviderEntity -from core.plugin.impl.model import PluginModelClient - # import monkeypatch -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.llm_entities import ( +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.llm_entities import ( LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage, ) -from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool -from dify_graph.model_runtime.entities.model_entities import ( +from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import ( AIModelEntity, FetchFrom, ModelFeature, ModelPropertyKey, ModelType, ) -from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity +from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity + +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from core.plugin.impl.model import PluginModelClient class MockModelClass(PluginModelClient): diff --git a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py index 9d3a8696917..5c6636f31ec 100644 --- a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py @@ -3,14 +3,14 @@ import unittest import uuid import pytest +from graphon.nodes import BuiltinNodeTypes +from graphon.variables.segments import StringSegment +from graphon.variables.types import SegmentType +from graphon.variables.variables import StringVariable from sqlalchemy import delete from sqlalchemy.orm import Session -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from dify_graph.nodes import BuiltinNodeTypes -from dify_graph.variables.segments import StringSegment -from dify_graph.variables.types import SegmentType -from dify_graph.variables.variables import StringVariable +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from extensions.ext_database import db from extensions.ext_storage import storage from extensions.storage.storage_type import StorageType diff --git a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py index bc83c6cc12a..38dc8bbb281 100644 --- a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -2,10 +2,10 @@ import uuid from unittest.mock import patch import pytest +from graphon.variables.segments import StringSegment from sqlalchemy import delete from core.db.session_factory import session_factory -from dify_graph.variables.segments import StringSegment from extensions.storage.storage_type import StorageType from models import Tenant from models.enums import CreatorUserRole @@ -192,7 +192,8 @@ class TestDeleteDraftVariablesWithOffloadIntegration: @pytest.fixture def setup_offload_test_data(self, app_and_tenant): tenant, app = app_and_tenant - from dify_graph.variables.types import SegmentType + from graphon.variables.types import SegmentType + from libs.datetime_utils import naive_utc_now with session_factory.create_session() as session: @@ -423,7 +424,8 @@ class TestDeleteDraftVariablesSessionCommit: @pytest.fixture def setup_offload_test_data(self, app_and_tenant): """Create test data with offload files for session commit tests.""" - from dify_graph.variables.types import SegmentType + from graphon.variables.types import SegmentType + from libs.datetime_utils import naive_utc_now tenant, app = app_and_tenant diff --git a/api/tests/integration_tests/workflow/nodes/__mock/model.py b/api/tests/integration_tests/workflow/nodes/__mock/model.py index 5b0f86fed11..c0143faa853 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/model.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/model.py @@ -1,11 +1,12 @@ from unittest.mock import MagicMock +from graphon.model_runtime.entities.model_entities import ModelType + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration from core.model_manager import ModelInstance -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory from models.provider import ProviderType @@ -15,7 +16,7 @@ def get_mocked_fetch_model_config( mode: str, credentials: dict, ): - model_provider_factory = ModelProviderFactory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b") + model_provider_factory = create_plugin_model_provider_factory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b") model_type_instance = model_provider_factory.get_model_type_instance(provider, ModelType.LLM) provider_model_bundle = ProviderModelBundle( configuration=ProviderConfiguration( diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index e3a2b6b8664..ce0c8bf8ca8 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -2,17 +2,17 @@ import time import uuid import pytest +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.node_events import NodeRunResult +from graphon.nodes.code.code_node import CodeNode +from graphon.nodes.code.limits import CodeNodeLimits +from graphon.runtime import GraphRuntimeState, VariablePool from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.code.code_node import CodeNode -from dify_graph.nodes.code.limits import CodeNodeLimits -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock from tests.workflow_test_utils import build_test_graph_init_params @@ -44,7 +44,7 @@ def init_code_node(code_config: dict): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -172,7 +172,7 @@ def test_execute_code_output_validator(setup_code_executor_mock): result = node._run() assert isinstance(result, NodeRunResult) assert result.status == WorkflowNodeExecutionStatus.FAILED - assert result.error == "Output result must be a string, got int instead" + assert result.error == "Output result must be a string, got int instead." def test_execute_code_output_validator_depth(): diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index f885f69e552..ce18486fafc 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -3,18 +3,19 @@ import uuid from urllib.parse import urlencode import pytest +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file.file_manager import file_manager +from graphon.graph import Graph +from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig +from graphon.runtime import GraphRuntimeState, VariablePool from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.helper.ssrf_proxy import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager from core.workflow.node_factory import DifyNodeFactory -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.file.file_manager import file_manager -from dify_graph.graph import Graph -from dify_graph.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.node_runtime import DifyFileReferenceFactory +from core.workflow.system_variables import build_system_variables from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock from tests.workflow_test_utils import build_test_graph_init_params @@ -54,7 +55,7 @@ def init_http_node(config: dict): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -81,6 +82,7 @@ def init_http_node(config: dict): http_client=ssrf_proxy, tool_file_manager_factory=ToolFileManager, file_manager=file_manager, + file_reference_factory=DifyFileReferenceFactory(init_params.run_context), ) return node @@ -189,20 +191,21 @@ def test_custom_authorization_header(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): """Test: In custom authentication mode, when the api_key is empty, AuthorizationConfigError should be raised.""" - from dify_graph.enums import BuiltinNodeTypes - from dify_graph.nodes.http_request.entities import ( + from graphon.enums import BuiltinNodeTypes + from graphon.nodes.http_request.entities import ( HttpRequestNodeAuthorization, HttpRequestNodeData, HttpRequestNodeTimeout, ) - from dify_graph.nodes.http_request.exc import AuthorizationConfigError - from dify_graph.nodes.http_request.executor import Executor - from dify_graph.runtime import VariablePool - from dify_graph.system_variable import SystemVariable + from graphon.nodes.http_request.exc import AuthorizationConfigError + from graphon.nodes.http_request.executor import Executor + from graphon.runtime import VariablePool + + from core.workflow.system_variables import build_system_variables # Create variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="test", files=[]), + system_variables=build_system_variables(user_id="test", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -700,7 +703,7 @@ def test_nested_object_variable_selector(setup_http_mock): # Create independent variable pool for this test only variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -728,6 +731,7 @@ def test_nested_object_variable_selector(setup_http_mock): http_client=ssrf_proxy, tool_file_manager_factory=ToolFileManager, file_manager=file_manager, + file_reference_factory=DifyFileReferenceFactory(init_params.run_context), ) result = node._run() diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index d628348f1eb..f0f3fcead19 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -4,16 +4,19 @@ import uuid from collections.abc import Generator from unittest.mock import MagicMock, patch +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.node_events import StreamCompletedEvent +from graphon.nodes.llm.file_saver import LLMFileSaver +from graphon.nodes.llm.node import LLMNode +from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory +from graphon.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol +from graphon.nodes.protocols import HttpClientProtocol +from graphon.runtime import GraphRuntimeState, VariablePool + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.llm_generator.output_parser.structured_output import _parse_structured_output from core.model_manager import ModelInstance -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.node_events import StreamCompletedEvent -from dify_graph.nodes.llm.node import LLMNode -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer -from dify_graph.nodes.protocols import HttpClientProtocol -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables from extensions.ext_database import db from tests.workflow_test_utils import build_test_graph_init_params @@ -51,7 +54,7 @@ def init_llm_node(config: dict) -> LLMNode: # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="aaa", app_id=app_id, workflow_id=workflow_id, @@ -66,6 +69,11 @@ def init_llm_node(config: dict) -> LLMNode: variable_pool.add(["abc", "output"], "sunny") graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + prompt_message_serializer = MagicMock(spec=PromptMessageSerializerProtocol) + prompt_message_serializer.serialize.side_effect = lambda *, model_mode, prompt_messages: [ + message.model_dump(mode="json") for message in prompt_messages + ] + llm_file_saver = MagicMock(spec=LLMFileSaver) node = LLMNode( id=str(uuid.uuid4()), @@ -75,7 +83,8 @@ def init_llm_node(config: dict) -> LLMNode: credentials_provider=MagicMock(spec=CredentialsProvider), model_factory=MagicMock(spec=ModelFactory), model_instance=MagicMock(spec=ModelInstance), - template_renderer=MagicMock(spec=TemplateRenderer), + llm_file_saver=llm_file_saver, + prompt_message_serializer=prompt_message_serializer, http_client=MagicMock(spec=HttpClientProtocol), ) @@ -115,8 +124,8 @@ def test_execute_llm(): from decimal import Decimal from unittest.mock import MagicMock - from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage - from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage + from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage + from graphon.model_runtime.entities.message_entities import AssistantPromptMessage # Create mock model instance mock_model_instance = MagicMock(spec=ModelInstance) @@ -159,8 +168,8 @@ def test_execute_llm(): return mock_model_instance # Mock fetch_prompt_messages to avoid database calls - def mock_fetch_prompt_messages_1(*_args, **_kwargs): - from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage + def mock_fetch_prompt_messages_1(**_kwargs): + from graphon.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage return [ SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."), @@ -231,8 +240,8 @@ def test_execute_llm_with_jinja2(): from decimal import Decimal from unittest.mock import MagicMock - from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage - from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage + from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage + from graphon.model_runtime.entities.message_entities import AssistantPromptMessage # Create mock model instance mock_model_instance = MagicMock(spec=ModelInstance) @@ -276,7 +285,7 @@ def test_execute_llm_with_jinja2(): # Mock fetch_prompt_messages to avoid database calls def mock_fetch_prompt_messages_2(**_kwargs): - from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage + from graphon.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage return [ SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."), diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index 62d9af01963..3bf44df349c 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -3,14 +3,16 @@ import time import uuid from unittest.mock import MagicMock +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage +from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory +from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from graphon.runtime import GraphRuntimeState, VariablePool + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.model_manager import ModelInstance -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.model_runtime.entities import AssistantPromptMessage, UserPromptMessage -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory -from dify_graph.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.node_runtime import DifyPromptMessageSerializer +from core.workflow.system_variables import build_system_variables from extensions.ext_database import db from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance from tests.workflow_test_utils import build_test_graph_init_params @@ -56,7 +58,7 @@ def init_parameter_extractor_node(config: dict, memory=None): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="aaa", files=[], query="what's the weather in SF", conversation_id="abababa" ), user_inputs={}, @@ -77,6 +79,7 @@ def init_parameter_extractor_node(config: dict, memory=None): model_factory=MagicMock(spec=ModelFactory), model_instance=MagicMock(spec=ModelInstance), memory=memory, + prompt_message_serializer=DifyPromptMessageSerializer(), ) return node diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 7bb4f905c31..2d728569bee 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -1,14 +1,15 @@ import time import uuid +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.template_rendering import TemplateRenderError + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.nodes.template_transform.template_renderer import TemplateRenderError -from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params @@ -66,7 +67,7 @@ def test_execute_template_transform(): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -90,7 +91,7 @@ def test_execute_template_transform(): config=config, graph_init_params=init_params, graph_runtime_state=graph_runtime_state, - template_renderer=_SimpleJinja2Renderer(), + jinja2_template_renderer=_SimpleJinja2Renderer(), ) # execute node diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index a6717ada316..750ced7075e 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -2,16 +2,18 @@ import time import uuid from unittest.mock import MagicMock, patch +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.node_events import StreamCompletedEvent +from graphon.nodes.protocols import ToolFileManagerProtocol +from graphon.nodes.tool.tool_node import ToolNode +from graphon.runtime import GraphRuntimeState, VariablePool + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.tools.utils.configuration import ToolParameterConfigurationManager from core.workflow.node_factory import DifyNodeFactory -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.node_events import StreamCompletedEvent -from dify_graph.nodes.protocols import ToolFileManagerProtocol -from dify_graph.nodes.tool.tool_node import ToolNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.node_runtime import DifyToolNodeRuntime +from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params @@ -40,7 +42,7 @@ def init_tool_node(config: dict): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -64,11 +66,12 @@ def init_tool_node(config: dict): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, tool_file_manager_factory=tool_file_manager_factory, + runtime=DifyToolNodeRuntime(init_params.run_context), ) return node -def test_tool_variable_invoke(): +def test_tool_variable_invoke(monkeypatch): node = init_tool_node( config={ "id": "1", @@ -103,7 +106,7 @@ def test_tool_variable_invoke(): assert item.node_run_result.outputs.get("text") is not None -def test_tool_mixed_invoke(): +def test_tool_mixed_invoke(monkeypatch): node = init_tool_node( config={ "id": "1", diff --git a/api/tests/test_containers_integration_tests/conftest.py b/api/tests/test_containers_integration_tests/conftest.py index 0bdd3bdc471..be8a1c6aab0 100644 --- a/api/tests/test_containers_integration_tests/conftest.py +++ b/api/tests/test_containers_integration_tests/conftest.py @@ -32,6 +32,10 @@ from extensions.ext_database import db # Configure logging for test containers logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) +_TEST_SANDBOX_IMAGE = os.getenv("TEST_SANDBOX_IMAGE", "langgenius/dify-sandbox:0.2.12") + +DEFAULT_SANDBOX_TEST_IMAGE = "langgenius/dify-sandbox:0.2.14" +SANDBOX_TEST_IMAGE_ENV = "DIFY_SANDBOX_TEST_IMAGE" class _CloserProtocol(Protocol): @@ -163,10 +167,11 @@ class DifyTestContainers: wait_for_logs(self.redis, "Ready to accept connections", timeout=30) logger.info("Redis container is ready and accepting connections") - # Start Dify Sandbox container for code execution environment - # Dify Sandbox provides a secure environment for executing user code + # Start Dify Sandbox container for code execution environment. + # Default to the production-pinned image while allowing local overrides for debugging. logger.info("Initializing Dify Sandbox container...") - self.dify_sandbox = DockerContainer(image="langgenius/dify-sandbox:latest").with_network(self.network) + sandbox_image = os.getenv(SANDBOX_TEST_IMAGE_ENV, DEFAULT_SANDBOX_TEST_IMAGE) + self.dify_sandbox = DockerContainer(image=sandbox_image).with_network(self.network) self.dify_sandbox.with_exposed_ports(8194) self.dify_sandbox.env = { "API_KEY": "test_api_key", @@ -176,7 +181,12 @@ class DifyTestContainers: sandbox_port = self.dify_sandbox.get_exposed_port(8194) os.environ["CODE_EXECUTION_ENDPOINT"] = f"http://{sandbox_host}:{sandbox_port}" os.environ["CODE_EXECUTION_API_KEY"] = "test_api_key" - logger.info("Dify Sandbox container started successfully - Host: %s, Port: %s", sandbox_host, sandbox_port) + logger.info( + "Dify Sandbox container started successfully - Image: %s Host: %s, Port: %s", + sandbox_image, + sandbox_host, + sandbox_port, + ) # Wait for Dify Sandbox to be ready logger.info("Waiting for Dify Sandbox to be ready to accept connections...") @@ -186,7 +196,7 @@ class DifyTestContainers: # Start Dify Plugin Daemon container for plugin management # Dify Plugin Daemon provides plugin lifecycle management and execution logger.info("Initializing Dify Plugin Daemon container...") - self.dify_plugin_daemon = DockerContainer(image="langgenius/dify-plugin-daemon:0.5.4-local").with_network( + self.dify_plugin_daemon = DockerContainer(image="langgenius/dify-plugin-daemon:0.5.3-local").with_network( self.network ) self.dify_plugin_daemon.with_exposed_ports(5002) diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py index 6f2e008d443..5cc458fe2ef 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py @@ -4,16 +4,16 @@ import json import uuid from flask.testing import FlaskClient +from graphon.enums import WorkflowExecutionStatus from sqlalchemy.orm import Session from configs import dify_config from constants import HEADER_NAME_CSRF_TOKEN -from dify_graph.enums import WorkflowExecutionStatus from libs.datetime_utils import naive_utc_now from libs.token import _real_cookie_name, generate_csrf_token from models import Account, DifySetup, Tenant, TenantAccountJoin from models.account import AccountStatus, TenantAccountRole -from models.enums import CreatorUserRole +from models.enums import ConversationFromSource, CreatorUserRole from models.model import App, AppMode, Conversation, Message from models.workflow import WorkflowRun from services.account_service import AccountService @@ -75,7 +75,7 @@ def _create_conversation(db_session: Session, app_id: str, account_id: str) -> C inputs={}, status="normal", mode=AppMode.CHAT, - from_source=CreatorUserRole.ACCOUNT, + from_source=ConversationFromSource.CONSOLE, from_account_id=account_id, ) db_session.add(conversation) @@ -124,7 +124,7 @@ def _create_message( answer_price_unit=0.001, currency="USD", status="normal", - from_source=CreatorUserRole.ACCOUNT, + from_source=ConversationFromSource.CONSOLE, from_account_id=account_id, workflow_run_id=workflow_run_id, inputs={"query": "Hello"}, diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_message.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_message.py new file mode 100644 index 00000000000..6b51ec98bcd --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_message.py @@ -0,0 +1,342 @@ +"""Authenticated controller integration tests for console message APIs.""" + +from datetime import timedelta +from decimal import Decimal +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from flask.testing import FlaskClient +from sqlalchemy import select +from sqlalchemy.orm import Session + +from controllers.console.app.message import ChatMessagesQuery, FeedbackExportQuery, MessageFeedbackPayload +from controllers.console.app.message import attach_message_extra_contents as _attach_message_extra_contents +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from libs.datetime_utils import naive_utc_now +from models.enums import ConversationFromSource, FeedbackRating +from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback +from services.errors.conversation import ConversationNotExistsError +from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, + create_console_app, +) + + +def _create_conversation(db_session: Session, app_id: str, account_id: str, mode: AppMode) -> Conversation: + conversation = Conversation( + app_id=app_id, + app_model_config_id=None, + model_provider=None, + model_id="", + override_model_configs=None, + mode=mode, + name="Test Conversation", + inputs={}, + introduction="", + system_instruction="", + system_instruction_tokens=0, + status="normal", + from_source=ConversationFromSource.CONSOLE, + from_account_id=account_id, + ) + db_session.add(conversation) + db_session.commit() + return conversation + + +def _create_message( + db_session: Session, + app_id: str, + conversation_id: str, + account_id: str, + *, + created_at_offset_seconds: int = 0, +) -> Message: + created_at = naive_utc_now() + timedelta(seconds=created_at_offset_seconds) + message = Message( + app_id=app_id, + model_provider=None, + model_id="", + override_model_configs=None, + conversation_id=conversation_id, + inputs={}, + query="Hello", + message={"type": "text", "content": "Hello"}, + message_tokens=1, + message_unit_price=Decimal("0.0001"), + message_price_unit=Decimal("0.001"), + answer="Hi there", + answer_tokens=1, + answer_unit_price=Decimal("0.0001"), + answer_price_unit=Decimal("0.001"), + parent_message_id=None, + provider_response_latency=0, + total_price=Decimal("0.0002"), + currency="USD", + from_source=ConversationFromSource.CONSOLE, + from_account_id=account_id, + created_at=created_at, + updated_at=created_at, + app_mode=AppMode.CHAT, + ) + db_session.add(message) + db_session.commit() + return message + + +class TestMessageValidators: + def test_chat_messages_query_validators(self) -> None: + assert ChatMessagesQuery.empty_to_none("") is None + assert ChatMessagesQuery.empty_to_none("val") == "val" + assert ChatMessagesQuery.validate_uuid(None) is None + assert ( + ChatMessagesQuery.validate_uuid("123e4567-e89b-12d3-a456-426614174000") + == "123e4567-e89b-12d3-a456-426614174000" + ) + + def test_message_feedback_validators(self) -> None: + assert ( + MessageFeedbackPayload.validate_message_id("123e4567-e89b-12d3-a456-426614174000") + == "123e4567-e89b-12d3-a456-426614174000" + ) + + def test_feedback_export_validators(self) -> None: + assert FeedbackExportQuery.parse_bool(None) is None + assert FeedbackExportQuery.parse_bool(True) is True + assert FeedbackExportQuery.parse_bool("1") is True + assert FeedbackExportQuery.parse_bool("0") is False + assert FeedbackExportQuery.parse_bool("off") is False + + with pytest.raises(ValueError): + FeedbackExportQuery.parse_bool("invalid") + + +def test_chat_message_list_not_found( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/chat-messages", + query_string={"conversation_id": str(uuid4())}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 404 + payload = response.get_json() + assert payload is not None + assert payload["code"] == "not_found" + + +def test_chat_message_list_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode) + _create_message(db_session_with_containers, app.id, conversation.id, account.id, created_at_offset_seconds=0) + second = _create_message( + db_session_with_containers, + app.id, + conversation.id, + account.id, + created_at_offset_seconds=1, + ) + + with patch( + "controllers.console.app.message.attach_message_extra_contents", + side_effect=_attach_message_extra_contents, + ): + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/chat-messages", + query_string={"conversation_id": conversation.id, "limit": 1}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert payload["limit"] == 1 + assert payload["has_more"] is True + assert len(payload["data"]) == 1 + assert payload["data"][0]["id"] == second.id + + +def test_message_feedback_not_found( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + + response = test_client_with_containers.post( + f"/console/api/apps/{app.id}/feedbacks", + json={"message_id": str(uuid4()), "rating": "like"}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 404 + payload = response.get_json() + assert payload is not None + assert payload["code"] == "not_found" + + +def test_message_feedback_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode) + message = _create_message(db_session_with_containers, app.id, conversation.id, account.id) + + response = test_client_with_containers.post( + f"/console/api/apps/{app.id}/feedbacks", + json={"message_id": message.id, "rating": "like"}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"result": "success"} + + feedback = db_session_with_containers.scalar( + select(MessageFeedback).where(MessageFeedback.message_id == message.id) + ) + assert feedback is not None + assert feedback.rating == FeedbackRating.LIKE + assert feedback.from_account_id == account.id + + +def test_message_annotation_count( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode) + message = _create_message(db_session_with_containers, app.id, conversation.id, account.id) + db_session_with_containers.add( + MessageAnnotation( + app_id=app.id, + conversation_id=conversation.id, + message_id=message.id, + question="Q", + content="A", + account_id=account.id, + ) + ) + db_session_with_containers.commit() + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/annotations/count", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"count": 1} + + +def test_message_suggested_questions_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + message_id = str(uuid4()) + + with patch( + "controllers.console.app.message.MessageService.get_suggested_questions_after_answer", + return_value=["q1", "q2"], + ): + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/chat-messages/{message_id}/suggested-questions", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"data": ["q1", "q2"]} + + +@pytest.mark.parametrize( + ("exc", "expected_status", "expected_code"), + [ + (MessageNotExistsError(), 404, "not_found"), + (ConversationNotExistsError(), 404, "not_found"), + (ProviderTokenNotInitError(), 400, "provider_not_initialize"), + (QuotaExceededError(), 400, "provider_quota_exceeded"), + (ModelCurrentlyNotSupportError(), 400, "model_currently_not_support"), + (SuggestedQuestionsAfterAnswerDisabledError(), 403, "app_suggested_questions_after_answer_disabled"), + (Exception(), 500, "internal_server_error"), + ], +) +def test_message_suggested_questions_errors( + exc: Exception, + expected_status: int, + expected_code: str, + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + message_id = str(uuid4()) + + with patch( + "controllers.console.app.message.MessageService.get_suggested_questions_after_answer", + side_effect=exc, + ): + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/chat-messages/{message_id}/suggested-questions", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == expected_status + payload = response.get_json() + assert payload is not None + assert payload["code"] == expected_code + + +def test_message_feedback_export_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + + with patch("services.feedback_service.FeedbackService.export_feedbacks", return_value={"exported": True}): + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/feedbacks/export", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"exported": True} + + +def test_message_api_get_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode) + message = _create_message(db_session_with_containers, app.id, conversation.id, account.id) + + with patch( + "controllers.console.app.message.attach_message_extra_contents", + side_effect=_attach_message_extra_contents, + ): + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/messages/{message.id}", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert payload["id"] == message.id diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_statistic.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_statistic.py new file mode 100644 index 00000000000..963cfe53e55 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_statistic.py @@ -0,0 +1,334 @@ +"""Controller integration tests for console statistic routes.""" + +from datetime import timedelta +from decimal import Decimal +from unittest.mock import patch +from uuid import uuid4 + +from flask.testing import FlaskClient +from sqlalchemy.orm import Session + +from core.app.entities.app_invoke_entities import InvokeFrom +from libs.datetime_utils import naive_utc_now +from models.enums import ConversationFromSource, FeedbackFromSource, FeedbackRating +from models.model import AppMode, Conversation, Message, MessageFeedback +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, + create_console_app, +) + + +def _create_conversation( + db_session: Session, + app_id: str, + account_id: str, + *, + mode: AppMode, + created_at_offset_days: int = 0, +) -> Conversation: + created_at = naive_utc_now() + timedelta(days=created_at_offset_days) + conversation = Conversation( + app_id=app_id, + app_model_config_id=None, + model_provider=None, + model_id="", + override_model_configs=None, + mode=mode, + name="Stats Conversation", + inputs={}, + introduction="", + system_instruction="", + system_instruction_tokens=0, + status="normal", + from_source=ConversationFromSource.CONSOLE, + from_account_id=account_id, + created_at=created_at, + updated_at=created_at, + ) + db_session.add(conversation) + db_session.commit() + return conversation + + +def _create_message( + db_session: Session, + app_id: str, + conversation_id: str, + *, + from_account_id: str | None, + from_end_user_id: str | None = None, + message_tokens: int = 1, + answer_tokens: int = 1, + total_price: Decimal = Decimal("0.01"), + provider_response_latency: float = 1.0, + created_at_offset_days: int = 0, +) -> Message: + created_at = naive_utc_now() + timedelta(days=created_at_offset_days) + message = Message( + app_id=app_id, + model_provider=None, + model_id="", + override_model_configs=None, + conversation_id=conversation_id, + inputs={}, + query="Hello", + message={"type": "text", "content": "Hello"}, + message_tokens=message_tokens, + message_unit_price=Decimal("0.001"), + message_price_unit=Decimal("0.001"), + answer="Hi there", + answer_tokens=answer_tokens, + answer_unit_price=Decimal("0.001"), + answer_price_unit=Decimal("0.001"), + parent_message_id=None, + provider_response_latency=provider_response_latency, + total_price=total_price, + currency="USD", + invoke_from=InvokeFrom.EXPLORE, + from_source=ConversationFromSource.CONSOLE, + from_end_user_id=from_end_user_id, + from_account_id=from_account_id, + created_at=created_at, + updated_at=created_at, + app_mode=AppMode.CHAT, + ) + db_session.add(message) + db_session.commit() + return message + + +def _create_like_feedback( + db_session: Session, + app_id: str, + conversation_id: str, + message_id: str, + account_id: str, +) -> None: + db_session.add( + MessageFeedback( + app_id=app_id, + conversation_id=conversation_id, + message_id=message_id, + rating=FeedbackRating.LIKE, + from_source=FeedbackFromSource.ADMIN, + from_account_id=account_id, + ) + ) + db_session.commit() + + +def test_daily_message_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/daily-messages", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["message_count"] == 1 + + +def test_daily_conversation_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/daily-conversations", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["conversation_count"] == 1 + + +def test_daily_terminals_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message( + db_session_with_containers, + app.id, + conversation.id, + from_account_id=None, + from_end_user_id=str(uuid4()), + ) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/daily-end-users", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["terminal_count"] == 1 + + +def test_daily_token_cost_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message( + db_session_with_containers, + app.id, + conversation.id, + from_account_id=account.id, + message_tokens=40, + answer_tokens=60, + total_price=Decimal("0.02"), + ) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/token-costs", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload["data"][0]["token_count"] == 100 + assert Decimal(payload["data"][0]["total_price"]) == Decimal("0.02") + + +def test_average_session_interaction_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/average-session-interactions", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["interactions"] == 2.0 + + +def test_user_satisfaction_rate_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + first = _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + for _ in range(9): + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + _create_like_feedback(db_session_with_containers, app.id, conversation.id, first.id, account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/user-satisfaction-rate", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["rate"] == 100.0 + + +def test_average_response_time_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.COMPLETION) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message( + db_session_with_containers, + app.id, + conversation.id, + from_account_id=account.id, + provider_response_latency=1.234, + ) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/average-response-time", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["latency"] == 1234.0 + + +def test_tokens_per_second_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message( + db_session_with_containers, + app.id, + conversation.id, + from_account_id=account.id, + answer_tokens=31, + provider_response_latency=2.0, + ) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/tokens-per-second", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["tps"] == 15.5 + + +def test_invalid_time_range( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + + with patch("controllers.console.app.statistic.parse_time_range", side_effect=ValueError("Invalid time")): + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/daily-messages?start=invalid&end=invalid", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "Invalid time" + + +def test_time_range_params_passed( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + import datetime + + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + start = datetime.datetime.now() + end = datetime.datetime.now() + + with patch("controllers.console.app.statistic.parse_time_range", return_value=(start, end)) as mock_parse: + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/daily-messages?start=something&end=something", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + mock_parse.assert_called_once_with("something", "something", "UTC") diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py new file mode 100644 index 00000000000..8ddf867370e --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py @@ -0,0 +1,415 @@ +"""Authenticated controller integration tests for workflow draft variable APIs.""" + +import uuid + +from flask.testing import FlaskClient +from graphon.variables.segments import StringSegment +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID +from factories.variable_factory import segment_to_variable +from models import Workflow +from models.model import AppMode +from models.workflow import WorkflowDraftVariable +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, + create_console_app, +) + + +def _create_draft_workflow( + db_session: Session, + app_id: str, + tenant_id: str, + account_id: str, + *, + environment_variables: list | None = None, + conversation_variables: list | None = None, +) -> Workflow: + workflow = Workflow.new( + tenant_id=tenant_id, + app_id=app_id, + type="workflow", + version=Workflow.VERSION_DRAFT, + graph='{"nodes": [], "edges": []}', + features="{}", + created_by=account_id, + environment_variables=environment_variables or [], + conversation_variables=conversation_variables or [], + rag_pipeline_variables=[], + ) + db_session.add(workflow) + db_session.commit() + return workflow + + +def _create_node_variable( + db_session: Session, + app_id: str, + user_id: str, + *, + node_id: str = "node_1", + name: str = "test_var", +) -> WorkflowDraftVariable: + variable = WorkflowDraftVariable.new_node_variable( + app_id=app_id, + user_id=user_id, + node_id=node_id, + name=name, + value=StringSegment(value="test_value"), + node_execution_id=str(uuid.uuid4()), + visible=True, + editable=True, + ) + db_session.add(variable) + db_session.commit() + return variable + + +def _create_system_variable( + db_session: Session, app_id: str, user_id: str, name: str = "query" +) -> WorkflowDraftVariable: + variable = WorkflowDraftVariable.new_sys_variable( + app_id=app_id, + user_id=user_id, + name=name, + value=StringSegment(value="system-value"), + node_execution_id=str(uuid.uuid4()), + editable=True, + ) + db_session.add(variable) + db_session.commit() + return variable + + +def _build_environment_variable(name: str, value: str): + return segment_to_variable( + segment=StringSegment(value=value), + selector=[ENVIRONMENT_VARIABLE_NODE_ID, name], + name=name, + description=f"Environment variable {name}", + ) + + +def _build_conversation_variable(name: str, value: str): + return segment_to_variable( + segment=StringSegment(value=value), + selector=[CONVERSATION_VARIABLE_NODE_ID, name], + name=name, + description=f"Conversation variable {name}", + ) + + +def test_workflow_variable_collection_get_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/variables?page=1&limit=20", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"items": [], "total": 0} + + +def test_workflow_variable_collection_get_not_exist( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 404 + payload = response.get_json() + assert payload is not None + assert payload["code"] == "draft_workflow_not_exist" + + +def test_workflow_variable_collection_delete( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_node_variable(db_session_with_containers, app.id, account.id) + _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_2", name="other_var") + + response = test_client_with_containers.delete( + f"/console/api/apps/{app.id}/workflows/draft/variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 204 + remaining = db_session_with_containers.scalars( + select(WorkflowDraftVariable).where( + WorkflowDraftVariable.app_id == app.id, + WorkflowDraftVariable.user_id == account.id, + ) + ).all() + assert remaining == [] + + +def test_node_variable_collection_get_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + node_variable = _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_123") + _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_456", name="other") + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/nodes/node_123/variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert [item["id"] for item in payload["items"]] == [node_variable.id] + + +def test_node_variable_collection_get_invalid_node_id( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/nodes/sys/variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 400 + payload = response.get_json() + assert payload is not None + assert payload["code"] == "invalid_param" + + +def test_node_variable_collection_delete( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + target = _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_123") + untouched = _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_456") + target_id = target.id + untouched_id = untouched.id + + response = test_client_with_containers.delete( + f"/console/api/apps/{app.id}/workflows/draft/nodes/node_123/variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 204 + assert ( + db_session_with_containers.scalar(select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == target_id)) + is None + ) + assert ( + db_session_with_containers.scalar(select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == untouched_id)) + is not None + ) + + +def test_variable_api_get_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id) + variable = _create_node_variable(db_session_with_containers, app.id, account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert payload["id"] == variable.id + assert payload["name"] == "test_var" + + +def test_variable_api_get_not_found( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/variables/{uuid.uuid4()}", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 404 + payload = response.get_json() + assert payload is not None + assert payload["code"] == "not_found" + + +def test_variable_api_patch_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id) + variable = _create_node_variable(db_session_with_containers, app.id, account.id) + + response = test_client_with_containers.patch( + f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}", + headers=authenticate_console_client(test_client_with_containers, account), + json={"name": "renamed_var"}, + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert payload["id"] == variable.id + assert payload["name"] == "renamed_var" + + refreshed = db_session_with_containers.scalar( + select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == variable.id) + ) + assert refreshed is not None + assert refreshed.name == "renamed_var" + + +def test_variable_api_delete_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id) + variable = _create_node_variable(db_session_with_containers, app.id, account.id) + + response = test_client_with_containers.delete( + f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 204 + assert ( + db_session_with_containers.scalar(select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == variable.id)) + is None + ) + + +def test_variable_reset_api_put_success_returns_no_content_without_execution( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id) + variable = _create_node_variable(db_session_with_containers, app.id, account.id) + + response = test_client_with_containers.put( + f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}/reset", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 204 + assert ( + db_session_with_containers.scalar(select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == variable.id)) + is None + ) + + +def test_conversation_variable_collection_get( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow( + db_session_with_containers, + app.id, + tenant.id, + account.id, + conversation_variables=[_build_conversation_variable("session_name", "Alice")], + ) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/conversation-variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert [item["name"] for item in payload["items"]] == ["session_name"] + + created = db_session_with_containers.scalars( + select(WorkflowDraftVariable).where( + WorkflowDraftVariable.app_id == app.id, + WorkflowDraftVariable.user_id == account.id, + WorkflowDraftVariable.node_id == CONVERSATION_VARIABLE_NODE_ID, + ) + ).all() + assert len(created) == 1 + + +def test_system_variable_collection_get( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + variable = _create_system_variable(db_session_with_containers, app.id, account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/system-variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert [item["id"] for item in payload["items"]] == [variable.id] + + +def test_environment_variable_collection_get( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow( + db_session_with_containers, + app.id, + tenant.id, + account.id, + environment_variables=[_build_environment_variable("api_key", "secret-value")], + ) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/environment-variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert payload["items"][0]["name"] == "api_key" + assert payload["items"][0]["value"] == "secret-value" diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py new file mode 100644 index 00000000000..00309c25d65 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py @@ -0,0 +1,131 @@ +"""Controller integration tests for API key data source auth routes.""" + +import json +from unittest.mock import patch + +from flask.testing import FlaskClient +from sqlalchemy import select +from sqlalchemy.orm import Session + +from models.source import DataSourceApiKeyAuthBinding +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, +) + + +def test_get_api_key_auth_data_source( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + binding = DataSourceApiKeyAuthBinding( + tenant_id=tenant.id, + category="api_key", + provider="custom_provider", + credentials=json.dumps({"auth_type": "api_key", "config": {"api_key": "encrypted"}}), + disabled=False, + ) + db_session_with_containers.add(binding) + db_session_with_containers.commit() + + response = test_client_with_containers.get( + "/console/api/api-key-auth/data-source", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert len(payload["sources"]) == 1 + assert payload["sources"][0]["provider"] == "custom_provider" + + +def test_get_api_key_auth_data_source_empty( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + + response = test_client_with_containers.get( + "/console/api/api-key-auth/data-source", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"sources": []} + + +def test_create_binding_successful( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + + with ( + patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"), + patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth"), + ): + response = test_client_with_containers.post( + "/console/api/api-key-auth/data-source/binding", + json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"result": "success"} + + +def test_create_binding_failure( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + + with ( + patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"), + patch( + "controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth", + side_effect=ValueError("Invalid structure"), + ), + ): + response = test_client_with_containers.post( + "/console/api/api-key-auth/data-source/binding", + json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 500 + payload = response.get_json() + assert payload is not None + assert payload["code"] == "auth_failed" + assert payload["message"] == "Invalid structure" + + +def test_delete_binding_successful( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + binding = DataSourceApiKeyAuthBinding( + tenant_id=tenant.id, + category="api_key", + provider="custom_provider", + credentials=json.dumps({"auth_type": "api_key", "config": {"api_key": "encrypted"}}), + disabled=False, + ) + db_session_with_containers.add(binding) + db_session_with_containers.commit() + + response = test_client_with_containers.delete( + f"/console/api/api-key-auth/data-source/{binding.id}", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 204 + assert ( + db_session_with_containers.scalar( + select(DataSourceApiKeyAuthBinding).where(DataSourceApiKeyAuthBinding.id == binding.id) + ) + is None + ) diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_oauth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_oauth.py new file mode 100644 index 00000000000..81b54232615 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_oauth.py @@ -0,0 +1,120 @@ +"""Controller integration tests for console OAuth data source routes.""" + +from unittest.mock import MagicMock, patch + +from flask.testing import FlaskClient +from sqlalchemy.orm import Session + +from models.source import DataSourceOauthBinding +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, +) + + +def test_get_oauth_url_successful( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + provider = MagicMock() + provider.get_authorization_url.return_value = "http://oauth.provider/auth" + + with ( + patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": provider}), + patch("controllers.console.auth.data_source_oauth.dify_config.NOTION_INTEGRATION_TYPE", None), + ): + response = test_client_with_containers.get( + "/console/api/oauth/data-source/notion", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert tenant.id == account.current_tenant_id + assert response.status_code == 200 + assert response.get_json() == {"data": "http://oauth.provider/auth"} + provider.get_authorization_url.assert_called_once() + + +def test_get_oauth_url_invalid_provider( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}): + response = test_client_with_containers.get( + "/console/api/oauth/data-source/unknown_provider", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 400 + assert response.get_json() == {"error": "Invalid provider"} + + +def test_oauth_callback_successful(test_client_with_containers: FlaskClient) -> None: + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}): + response = test_client_with_containers.get("/console/api/oauth/data-source/callback/notion?code=mock_code") + + assert response.status_code == 302 + assert "code=mock_code" in response.location + + +def test_oauth_callback_missing_code(test_client_with_containers: FlaskClient) -> None: + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}): + response = test_client_with_containers.get("/console/api/oauth/data-source/callback/notion") + + assert response.status_code == 302 + assert "error=Access%20denied" in response.location + + +def test_oauth_callback_invalid_provider(test_client_with_containers: FlaskClient) -> None: + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}): + response = test_client_with_containers.get("/console/api/oauth/data-source/callback/invalid?code=mock_code") + + assert response.status_code == 400 + assert response.get_json() == {"error": "Invalid provider"} + + +def test_get_binding_successful(test_client_with_containers: FlaskClient) -> None: + provider = MagicMock() + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": provider}): + response = test_client_with_containers.get("/console/api/oauth/data-source/binding/notion?code=auth_code_123") + + assert response.status_code == 200 + assert response.get_json() == {"result": "success"} + provider.get_access_token.assert_called_once_with("auth_code_123") + + +def test_get_binding_missing_code(test_client_with_containers: FlaskClient) -> None: + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}): + response = test_client_with_containers.get("/console/api/oauth/data-source/binding/notion?code=") + + assert response.status_code == 400 + assert response.get_json() == {"error": "Invalid code"} + + +def test_sync_successful( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + binding = DataSourceOauthBinding( + tenant_id=tenant.id, + access_token="test-access-token", + provider="notion", + source_info={"workspace_name": "Workspace", "workspace_icon": None, "workspace_id": tenant.id, "pages": []}, + disabled=False, + ) + db_session_with_containers.add(binding) + db_session_with_containers.commit() + + provider = MagicMock() + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": provider}): + response = test_client_with_containers.get( + f"/console/api/oauth/data-source/notion/{binding.id}/sync", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"result": "success"} + provider.sync_data_source.assert_called_once_with(binding.id) diff --git a/api/tests/unit_tests/controllers/console/auth/test_email_register.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py similarity index 82% rename from api/tests/unit_tests/controllers/console/auth/test_email_register.py rename to api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py index 724c80f18c7..879c337319c 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_email_register.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py @@ -1,8 +1,11 @@ +"""Testcontainers integration tests for email register controller endpoints.""" + +from __future__ import annotations + from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest -from flask import Flask from controllers.console.auth.email_register import ( EmailRegisterCheckApi, @@ -13,14 +16,11 @@ from services.account_service import AccountService @pytest.fixture -def app(): - flask_app = Flask(__name__) - flask_app.config["TESTING"] = True - return flask_app +def app(flask_app_with_containers): + return flask_app_with_containers class TestEmailRegisterSendEmailApi: - @patch("controllers.console.auth.email_register.Session") @patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.email_register.AccountService.send_email_register_email") @patch("controllers.console.auth.email_register.BillingService.is_email_in_freeze") @@ -33,20 +33,15 @@ class TestEmailRegisterSendEmailApi: mock_is_freeze, mock_send_mail, mock_get_account, - mock_session_cls, app, ): mock_send_mail.return_value = "token-123" mock_is_freeze.return_value = False mock_account = MagicMock() - - mock_session = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session mock_get_account.return_value = mock_account feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) with ( - patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")), patch("controllers.console.auth.email_register.dify_config", SimpleNamespace(BILLING_ENABLED=True)), patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags), @@ -61,7 +56,6 @@ class TestEmailRegisterSendEmailApi: assert response == {"result": "success", "data": "token-123"} mock_is_freeze.assert_called_once_with("invitee@example.com") mock_send_mail.assert_called_once_with(email="invitee@example.com", account=mock_account, language="en-US") - mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session) mock_extract_ip.assert_called_once() mock_is_email_send_ip_limit.assert_called_once_with("127.0.0.1") @@ -89,7 +83,6 @@ class TestEmailRegisterCheckApi: feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) with ( - patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")), patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags), ): @@ -114,7 +107,6 @@ class TestEmailRegisterResetApi: @patch("controllers.console.auth.email_register.AccountService.reset_login_error_rate_limit") @patch("controllers.console.auth.email_register.AccountService.login") @patch("controllers.console.auth.email_register.EmailRegisterResetApi._create_new_account") - @patch("controllers.console.auth.email_register.Session") @patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token") @patch("controllers.console.auth.email_register.AccountService.get_email_register_data") @@ -125,7 +117,6 @@ class TestEmailRegisterResetApi: mock_get_data, mock_revoke_token, mock_get_account, - mock_session_cls, mock_create_account, mock_login, mock_reset_login_rate, @@ -136,14 +127,10 @@ class TestEmailRegisterResetApi: token_pair = MagicMock() token_pair.model_dump.return_value = {"access_token": "a", "refresh_token": "r"} mock_login.return_value = token_pair - - mock_session = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session mock_get_account.return_value = None feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) with ( - patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")), patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags), ): @@ -159,19 +146,19 @@ class TestEmailRegisterResetApi: mock_reset_login_rate.assert_called_once_with("invitee@example.com") mock_revoke_token.assert_called_once_with("token-123") mock_extract_ip.assert_called_once() - mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session) -def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(): +def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(): + """Test that case fallback tries lowercase when exact match fails.""" mock_session = MagicMock() - first_query = MagicMock() - first_query.scalar_one_or_none.return_value = None + first_result = MagicMock() + first_result.scalar_one_or_none.return_value = None expected_account = MagicMock() - second_query = MagicMock() - second_query.scalar_one_or_none.return_value = expected_account - mock_session.execute.side_effect = [first_query, second_query] + second_result = MagicMock() + second_result.scalar_one_or_none.return_value = expected_account + mock_session.execute.side_effect = [first_result, second_result] - account = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) + result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) - assert account is expected_account + assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/unit_tests/controllers/console/auth/test_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py similarity index 82% rename from api/tests/unit_tests/controllers/console/auth/test_forgot_password.py rename to api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py index 8403777dc9f..7b7393dade9 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_forgot_password.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py @@ -1,8 +1,11 @@ +"""Testcontainers integration tests for forgot password controller endpoints.""" + +from __future__ import annotations + from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest -from flask import Flask from controllers.console.auth.forgot_password import ( ForgotPasswordCheckApi, @@ -13,14 +16,11 @@ from services.account_service import AccountService @pytest.fixture -def app(): - flask_app = Flask(__name__) - flask_app.config["TESTING"] = True - return flask_app +def app(flask_app_with_containers): + return flask_app_with_containers class TestForgotPasswordSendEmailApi: - @patch("controllers.console.auth.forgot_password.Session") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email") @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit", return_value=False) @@ -31,19 +31,15 @@ class TestForgotPasswordSendEmailApi: mock_is_ip_limit, mock_send_email, mock_get_account, - mock_session_cls, app, ): mock_account = MagicMock() mock_get_account.return_value = mock_account mock_send_email.return_value = "token-123" - mock_session = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session wraps_features = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) controller_features = SimpleNamespace(is_allow_register=True) with ( - patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")), patch( "controllers.console.auth.forgot_password.FeatureService.get_system_features", return_value=controller_features, @@ -59,7 +55,6 @@ class TestForgotPasswordSendEmailApi: response = ForgotPasswordSendEmailApi().post() assert response == {"result": "success", "data": "token-123"} - mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) mock_send_email.assert_called_once_with( account=mock_account, email="user@example.com", @@ -117,7 +112,6 @@ class TestForgotPasswordCheckApi: class TestForgotPasswordResetApi: @patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account") - @patch("controllers.console.auth.forgot_password.Session") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @@ -126,7 +120,6 @@ class TestForgotPasswordResetApi: mock_get_reset_data, mock_revoke_token, mock_get_account, - mock_session_cls, mock_update_account, app, ): @@ -134,12 +127,8 @@ class TestForgotPasswordResetApi: mock_account = MagicMock() mock_get_account.return_value = mock_account - mock_session = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session - wraps_features = SimpleNamespace(enable_email_password_login=True) with ( - patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")), patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features), ): @@ -157,20 +146,22 @@ class TestForgotPasswordResetApi: assert response == {"result": "success"} mock_get_reset_data.assert_called_once_with("token-123") mock_revoke_token.assert_called_once_with("token-123") - mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) mock_update_account.assert_called_once() -def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(): +def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(): + """Test that case fallback tries lowercase when exact match fails.""" + from unittest.mock import MagicMock + mock_session = MagicMock() - first_query = MagicMock() - first_query.scalar_one_or_none.return_value = None + first_result = MagicMock() + first_result.scalar_one_or_none.return_value = None expected_account = MagicMock() - second_query = MagicMock() - second_query.scalar_one_or_none.return_value = expected_account - mock_session.execute.side_effect = [first_query, second_query] + second_result = MagicMock() + second_result.scalar_one_or_none.return_value = expected_account + mock_session.execute.side_effect = [first_result, second_result] - account = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session) + result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session) - assert account is expected_account + assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py similarity index 92% rename from api/tests/unit_tests/controllers/console/auth/test_oauth.py rename to api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py index 6345c2ab23a..a2f1328579f 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py @@ -1,7 +1,10 @@ +"""Testcontainers integration tests for OAuth controller endpoints.""" + +from __future__ import annotations + from unittest.mock import MagicMock, patch import pytest -from flask import Flask from controllers.console.auth.oauth import ( OAuthCallback, @@ -18,10 +21,8 @@ from services.errors.account import AccountRegisterError class TestGetOAuthProviders: @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers @pytest.mark.parametrize( ("github_config", "google_config", "expected_github", "expected_google"), @@ -64,10 +65,8 @@ class TestOAuthLogin: return OAuthLogin() @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers @pytest.fixture def mock_oauth_provider(self): @@ -131,10 +130,8 @@ class TestOAuthCallback: return OAuthCallback() @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers @pytest.fixture def oauth_setup(self): @@ -190,15 +187,8 @@ class TestOAuthCallback: (KeyError("Missing key"), "OAuth process failed"), ], ) - @patch("controllers.console.auth.oauth.db") @patch("controllers.console.auth.oauth.get_oauth_providers") - def test_should_handle_oauth_exceptions( - self, mock_get_providers, mock_db, resource, app, exception, expected_error - ): - # Mock database session - mock_db.session = MagicMock() - mock_db.session.rollback = MagicMock() - + def test_should_handle_oauth_exceptions(self, mock_get_providers, resource, app, exception, expected_error): # Import the real requests module to create a proper exception import httpx @@ -258,7 +248,6 @@ class TestOAuthCallback: ) @patch("controllers.console.auth.oauth.AccountService") @patch("controllers.console.auth.oauth.TenantService") - @patch("controllers.console.auth.oauth.db") @patch("controllers.console.auth.oauth.dify_config") @patch("controllers.console.auth.oauth.get_oauth_providers") @patch("controllers.console.auth.oauth._generate_account") @@ -269,7 +258,6 @@ class TestOAuthCallback: mock_generate_account, mock_get_providers, mock_config, - mock_db, mock_tenant_service, mock_account_service, resource, @@ -278,10 +266,6 @@ class TestOAuthCallback: account_status, expected_redirect, ): - # Mock database session - mock_db.session = MagicMock() - mock_db.session.rollback = MagicMock() - mock_db.session.commit = MagicMock() mock_config.CONSOLE_WEB_URL = "http://localhost:3000" mock_get_providers.return_value = {"github": oauth_setup["provider"]} @@ -306,14 +290,12 @@ class TestOAuthCallback: @patch("controllers.console.auth.oauth.dify_config") @patch("controllers.console.auth.oauth.get_oauth_providers") @patch("controllers.console.auth.oauth._generate_account") - @patch("controllers.console.auth.oauth.db") @patch("controllers.console.auth.oauth.TenantService") @patch("controllers.console.auth.oauth.AccountService") def test_should_activate_pending_account( self, mock_account_service, mock_tenant_service, - mock_db, mock_generate_account, mock_get_providers, mock_config, @@ -338,12 +320,10 @@ class TestOAuthCallback: assert mock_account.status == AccountStatus.ACTIVE assert mock_account.initialized_at is not None - mock_db.session.commit.assert_called_once() @patch("controllers.console.auth.oauth.dify_config") @patch("controllers.console.auth.oauth.get_oauth_providers") @patch("controllers.console.auth.oauth._generate_account") - @patch("controllers.console.auth.oauth.db") @patch("controllers.console.auth.oauth.TenantService") @patch("controllers.console.auth.oauth.AccountService") @patch("controllers.console.auth.oauth.redirect") @@ -352,7 +332,6 @@ class TestOAuthCallback: mock_redirect, mock_account_service, mock_tenant_service, - mock_db, mock_generate_account, mock_get_providers, mock_config, @@ -414,6 +393,10 @@ class TestOAuthCallback: class TestAccountGeneration: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + @pytest.fixture def user_info(self): return OAuthUserInfo(id="123", name="Test User", email="test@example.com") @@ -425,15 +408,10 @@ class TestAccountGeneration: return account @patch("controllers.console.auth.oauth.AccountService.get_account_by_email_with_case_fallback") - @patch("controllers.console.auth.oauth.Session") @patch("controllers.console.auth.oauth.Account") - @patch("controllers.console.auth.oauth.db") def test_should_get_account_by_openid_or_email( - self, mock_db, mock_account_model, mock_session, mock_get_account, user_info, mock_account + self, mock_account_model, mock_get_account, flask_req_ctx_with_containers, user_info, mock_account ): - # Mock db.engine for Session creation - mock_db.engine = MagicMock() - # Test OpenID found mock_account_model.get_by_openid.return_value = mock_account result = _get_account_by_openid_or_email("github", user_info) @@ -443,15 +421,14 @@ class TestAccountGeneration: # Test fallback to email lookup mock_account_model.get_by_openid.return_value = None - mock_session_instance = MagicMock() - mock_session.return_value.__enter__.return_value = mock_session_instance mock_get_account.return_value = mock_account result = _get_account_by_openid_or_email("github", user_info) assert result == mock_account - mock_get_account.assert_called_once_with(user_info.email, session=mock_session_instance) + mock_get_account.assert_called_once() - def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(self): + def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(self): + """Test that case fallback tries lowercase when exact match fails.""" mock_session = MagicMock() first_result = MagicMock() first_result.scalar_one_or_none.return_value = None @@ -462,7 +439,7 @@ class TestAccountGeneration: result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) - assert result == expected_account + assert result is expected_account assert mock_session.execute.call_count == 2 @pytest.mark.parametrize( @@ -478,10 +455,8 @@ class TestAccountGeneration: @patch("controllers.console.auth.oauth.RegisterService") @patch("controllers.console.auth.oauth.AccountService") @patch("controllers.console.auth.oauth.TenantService") - @patch("controllers.console.auth.oauth.db") def test_should_handle_account_generation_scenarios( self, - mock_db, mock_tenant_service, mock_account_service, mock_register_service, @@ -519,10 +494,8 @@ class TestAccountGeneration: @patch("controllers.console.auth.oauth.RegisterService") @patch("controllers.console.auth.oauth.AccountService") @patch("controllers.console.auth.oauth.TenantService") - @patch("controllers.console.auth.oauth.db") def test_should_register_with_lowercase_email( self, - mock_db, mock_tenant_service, mock_account_service, mock_register_service, diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth_server.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth_server.py new file mode 100644 index 00000000000..2ef27133d8b --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth_server.py @@ -0,0 +1,365 @@ +"""Controller integration tests for console OAuth server routes.""" + +from unittest.mock import patch + +from flask.testing import FlaskClient +from sqlalchemy.orm import Session + +from models.model import OAuthProviderApp +from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, + ensure_dify_setup, +) + + +def _build_oauth_provider_app() -> OAuthProviderApp: + return OAuthProviderApp( + app_icon="icon_url", + client_id="test_client_id", + client_secret="test_secret", + app_label={"en-US": "Test App"}, + redirect_uris=["http://localhost/callback"], + scope="read,write", + ) + + +def test_oauth_provider_successful_post( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider", + json={"client_id": "test_client_id", "redirect_uri": "http://localhost/callback"}, + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert payload["app_icon"] == "icon_url" + assert payload["app_label"] == {"en-US": "Test App"} + assert payload["scope"] == "read,write" + + +def test_oauth_provider_invalid_redirect_uri( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider", + json={"client_id": "test_client_id", "redirect_uri": "http://invalid/callback"}, + ) + + assert response.status_code == 400 + payload = response.get_json() + assert payload is not None + assert "redirect_uri is invalid" in payload["message"] + + +def test_oauth_provider_invalid_client_id( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + response = test_client_with_containers.post( + "/console/api/oauth/provider", + json={"client_id": "test_invalid_client_id", "redirect_uri": "http://localhost/callback"}, + ) + + assert response.status_code == 404 + payload = response.get_json() + assert payload is not None + assert "client_id is invalid" in payload["message"] + + +def test_oauth_authorize_successful( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + + with ( + patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ), + patch( + "controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_authorization_code", + return_value="auth_code_123", + ) as mock_sign, + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/authorize", + json={"client_id": "test_client_id"}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"code": "auth_code_123"} + mock_sign.assert_called_once_with("test_client_id", account.id) + + +def test_oauth_token_authorization_code_grant( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with ( + patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ), + patch( + "controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token", + return_value=("access_123", "refresh_123"), + ), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={ + "client_id": "test_client_id", + "grant_type": "authorization_code", + "code": "auth_code", + "client_secret": "test_secret", + "redirect_uri": "http://localhost/callback", + }, + ) + + assert response.status_code == 200 + assert response.get_json() == { + "access_token": "access_123", + "token_type": "Bearer", + "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, + "refresh_token": "refresh_123", + } + + +def test_oauth_token_authorization_code_grant_missing_code( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={ + "client_id": "test_client_id", + "grant_type": "authorization_code", + "client_secret": "test_secret", + "redirect_uri": "http://localhost/callback", + }, + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "code is required" + + +def test_oauth_token_authorization_code_grant_invalid_secret( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={ + "client_id": "test_client_id", + "grant_type": "authorization_code", + "code": "auth_code", + "client_secret": "invalid_secret", + "redirect_uri": "http://localhost/callback", + }, + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "client_secret is invalid" + + +def test_oauth_token_authorization_code_grant_invalid_redirect_uri( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={ + "client_id": "test_client_id", + "grant_type": "authorization_code", + "code": "auth_code", + "client_secret": "test_secret", + "redirect_uri": "http://invalid/callback", + }, + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "redirect_uri is invalid" + + +def test_oauth_token_refresh_token_grant( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with ( + patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ), + patch( + "controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token", + return_value=("new_access", "new_refresh"), + ), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={"client_id": "test_client_id", "grant_type": "refresh_token", "refresh_token": "refresh_123"}, + ) + + assert response.status_code == 200 + assert response.get_json() == { + "access_token": "new_access", + "token_type": "Bearer", + "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, + "refresh_token": "new_refresh", + } + + +def test_oauth_token_refresh_token_grant_missing_token( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={"client_id": "test_client_id", "grant_type": "refresh_token"}, + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "refresh_token is required" + + +def test_oauth_token_invalid_grant_type( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={"client_id": "test_client_id", "grant_type": "invalid_grant"}, + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "invalid grant_type" + + +def test_oauth_account_successful_retrieval( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + account.avatar = "avatar_url" + db_session_with_containers.commit() + + with ( + patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ), + patch( + "controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token", + return_value=account, + ), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/account", + json={"client_id": "test_client_id"}, + headers={"Authorization": "Bearer valid_access_token"}, + ) + + assert response.status_code == 200 + assert response.get_json() == { + "name": "Test User", + "email": account.email, + "avatar": "avatar_url", + "interface_language": "en-US", + "timezone": "UTC", + } + + +def test_oauth_account_missing_authorization_header( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/account", + json={"client_id": "test_client_id"}, + ) + + assert response.status_code == 401 + assert response.get_json() == {"error": "Authorization header is required"} + + +def test_oauth_account_invalid_authorization_header_format( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/account", + json={"client_id": "test_client_id"}, + headers={"Authorization": "InvalidFormat"}, + ) + + assert response.status_code == 401 + assert response.get_json() == {"error": "Invalid Authorization header format"} diff --git a/api/tests/unit_tests/controllers/console/auth/test_password_reset.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py similarity index 81% rename from api/tests/unit_tests/controllers/console/auth/test_password_reset.py rename to api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py index 9488cf528ed..8f9db287e31 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_password_reset.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py @@ -1,17 +1,10 @@ -""" -Test suite for password reset authentication flows. +"""Testcontainers integration tests for password reset authentication flows.""" -This module tests the password reset mechanism including: -- Password reset email sending -- Verification code validation -- Password reset with token -- Rate limiting and security checks -""" +from __future__ import annotations from unittest.mock import MagicMock, patch import pytest -from flask import Flask from controllers.console.auth.error import ( EmailCodeError, @@ -28,31 +21,12 @@ from controllers.console.auth.forgot_password import ( from controllers.console.error import AccountNotFound, EmailSendIpLimitError -@pytest.fixture(autouse=True) -def _mock_forgot_password_session(): - with patch("controllers.console.auth.forgot_password.Session") as mock_session_cls: - mock_session = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session - mock_session_cls.return_value.__exit__.return_value = None - yield mock_session - - -@pytest.fixture(autouse=True) -def _mock_forgot_password_db(): - with patch("controllers.console.auth.forgot_password.db") as mock_db: - mock_db.engine = MagicMock() - yield mock_db - - class TestForgotPasswordSendEmailApi: """Test cases for sending password reset emails.""" @pytest.fixture - def app(self): - """Create Flask test application.""" - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers @pytest.fixture def mock_account(self): @@ -62,7 +36,6 @@ class TestForgotPasswordSendEmailApi: account.name = "Test User" return account - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email") @@ -73,20 +46,10 @@ class TestForgotPasswordSendEmailApi: mock_send_email, mock_get_account, mock_is_ip_limit, - mock_wraps_db, app, mock_account, ): - """ - Test successful password reset email sending. - - Verifies that: - - Email is sent to valid account - - Reset token is generated and returned - - IP rate limiting is checked - """ # Arrange - mock_wraps_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = False mock_get_account.return_value = mock_account mock_send_email.return_value = "reset_token_123" @@ -104,9 +67,8 @@ class TestForgotPasswordSendEmailApi: assert response["data"] == "reset_token_123" mock_send_email.assert_called_once() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit") - def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, mock_db, app): + def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, app): """ Test password reset email blocked by IP rate limit. @@ -115,7 +77,6 @@ class TestForgotPasswordSendEmailApi: - No email is sent when rate limited """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = True # Act & Assert @@ -133,7 +94,6 @@ class TestForgotPasswordSendEmailApi: (None, "en-US"), # Defaults to en-US when not provided ], ) - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email") @@ -144,7 +104,6 @@ class TestForgotPasswordSendEmailApi: mock_send_email, mock_get_account, mock_is_ip_limit, - mock_wraps_db, app, mock_account, language_input, @@ -158,7 +117,6 @@ class TestForgotPasswordSendEmailApi: - Unsupported languages default to en-US """ # Arrange - mock_wraps_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = False mock_get_account.return_value = mock_account mock_send_email.return_value = "token" @@ -180,13 +138,9 @@ class TestForgotPasswordCheckApi: """Test cases for verifying password reset codes.""" @pytest.fixture - def app(self): - """Create Flask test application.""" - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @@ -199,7 +153,6 @@ class TestForgotPasswordCheckApi: mock_revoke_token, mock_get_data, mock_is_rate_limit, - mock_db, app, ): """ @@ -212,7 +165,6 @@ class TestForgotPasswordCheckApi: - Rate limit is reset on success """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "test@example.com", "code": "123456"} mock_generate_token.return_value = (None, "new_token") @@ -236,7 +188,6 @@ class TestForgotPasswordCheckApi: ) mock_reset_rate_limit.assert_called_once_with("test@example.com") - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @@ -249,10 +200,8 @@ class TestForgotPasswordCheckApi: mock_revoke_token, mock_get_data, mock_is_rate_limit, - mock_db, app, ): - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "User@Example.com", "code": "999888"} mock_generate_token.return_value = (None, "fresh-token") @@ -271,9 +220,8 @@ class TestForgotPasswordCheckApi: mock_revoke_token.assert_called_once_with("upper_token") mock_reset_rate_limit.assert_called_once_with("user@example.com") - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") - def test_verify_code_rate_limited(self, mock_is_rate_limit, mock_db, app): + def test_verify_code_rate_limited(self, mock_is_rate_limit, app): """ Test code verification blocked by rate limit. @@ -282,7 +230,6 @@ class TestForgotPasswordCheckApi: - Prevents brute force attacks on verification codes """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = True # Act & Assert @@ -295,10 +242,9 @@ class TestForgotPasswordCheckApi: with pytest.raises(EmailPasswordResetLimitError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, mock_db, app): + def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, app): """ Test code verification with invalid token. @@ -306,7 +252,6 @@ class TestForgotPasswordCheckApi: - InvalidTokenError is raised for invalid/expired tokens """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = None @@ -320,10 +265,9 @@ class TestForgotPasswordCheckApi: with pytest.raises(InvalidTokenError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, mock_db, app): + def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, app): """ Test code verification with mismatched email. @@ -332,7 +276,6 @@ class TestForgotPasswordCheckApi: - Prevents token abuse """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "original@example.com", "code": "123456"} @@ -346,11 +289,10 @@ class TestForgotPasswordCheckApi: with pytest.raises(InvalidEmailError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.add_forgot_password_error_rate_limit") - def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, mock_db, app): + def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, app): """ Test code verification with incorrect code. @@ -359,7 +301,6 @@ class TestForgotPasswordCheckApi: - Rate limit counter is incremented """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "test@example.com", "code": "123456"} @@ -380,11 +321,8 @@ class TestForgotPasswordResetApi: """Test cases for resetting password with verified token.""" @pytest.fixture - def app(self): - """Create Flask test application.""" - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers @pytest.fixture def mock_account(self): @@ -394,7 +332,6 @@ class TestForgotPasswordResetApi: account.name = "Test User" return account - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @@ -405,7 +342,6 @@ class TestForgotPasswordResetApi: mock_get_account, mock_revoke_token, mock_get_data, - mock_wraps_db, app, mock_account, ): @@ -418,7 +354,6 @@ class TestForgotPasswordResetApi: - Success response is returned """ # Arrange - mock_wraps_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"} mock_get_account.return_value = mock_account mock_get_tenants.return_value = [MagicMock()] @@ -436,9 +371,8 @@ class TestForgotPasswordResetApi: assert response["result"] == "success" mock_revoke_token.assert_called_once_with("valid_token") - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_reset_password_mismatch(self, mock_get_data, mock_db, app): + def test_reset_password_mismatch(self, mock_get_data, app): """ Test password reset with mismatched passwords. @@ -447,7 +381,6 @@ class TestForgotPasswordResetApi: - No password update occurs """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"} # Act & Assert @@ -460,9 +393,8 @@ class TestForgotPasswordResetApi: with pytest.raises(PasswordMismatchError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_reset_password_invalid_token(self, mock_get_data, mock_db, app): + def test_reset_password_invalid_token(self, mock_get_data, app): """ Test password reset with invalid token. @@ -470,7 +402,6 @@ class TestForgotPasswordResetApi: - InvalidTokenError is raised for invalid/expired tokens """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = None # Act & Assert @@ -483,9 +414,8 @@ class TestForgotPasswordResetApi: with pytest.raises(InvalidTokenError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_reset_password_wrong_phase(self, mock_get_data, mock_db, app): + def test_reset_password_wrong_phase(self, mock_get_data, app): """ Test password reset with token not in reset phase. @@ -494,7 +424,6 @@ class TestForgotPasswordResetApi: - Prevents use of verification-phase tokens for reset """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "phase": "verify"} # Act & Assert @@ -507,13 +436,10 @@ class TestForgotPasswordResetApi: with pytest.raises(InvalidTokenError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") - def test_reset_password_account_not_found( - self, mock_get_account, mock_revoke_token, mock_get_data, mock_wraps_db, app - ): + def test_reset_password_account_not_found(self, mock_get_account, mock_revoke_token, mock_get_data, app): """ Test password reset for non-existent account. @@ -521,7 +447,6 @@ class TestForgotPasswordResetApi: - AccountNotFound is raised when account doesn't exist """ # Arrange - mock_wraps_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "nonexistent@example.com", "phase": "reset"} mock_get_account.return_value = None diff --git a/api/tests/test_containers_integration_tests/controllers/console/helpers.py b/api/tests/test_containers_integration_tests/controllers/console/helpers.py new file mode 100644 index 00000000000..9e2084f3939 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/helpers.py @@ -0,0 +1,85 @@ +"""Shared helpers for authenticated console controller integration tests.""" + +import uuid + +from flask.testing import FlaskClient +from sqlalchemy import select +from sqlalchemy.orm import Session + +from configs import dify_config +from constants import HEADER_NAME_CSRF_TOKEN +from libs.datetime_utils import naive_utc_now +from libs.token import _real_cookie_name, generate_csrf_token +from models import Account, DifySetup, Tenant, TenantAccountJoin +from models.account import AccountStatus, TenantAccountRole +from models.model import App, AppMode +from services.account_service import AccountService + + +def ensure_dify_setup(db_session: Session) -> None: + """Create a setup marker once so setup-protected console routes can be exercised.""" + if db_session.scalar(select(DifySetup).limit(1)) is not None: + return + + db_session.add(DifySetup(version=dify_config.project.version)) + db_session.commit() + + +def create_console_account_and_tenant(db_session: Session) -> tuple[Account, Tenant]: + """Create an initialized owner account with a current tenant.""" + account = Account( + email=f"test-{uuid.uuid4()}@example.com", + name="Test User", + interface_language="en-US", + status=AccountStatus.ACTIVE, + ) + account.initialized_at = naive_utc_now() + db_session.add(account) + db_session.commit() + + tenant = Tenant(name="Test Tenant", status="normal") + db_session.add(tenant) + db_session.commit() + + db_session.add( + TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + ) + db_session.commit() + + account.set_tenant_id(tenant.id) + account.timezone = "UTC" + db_session.commit() + + ensure_dify_setup(db_session) + return account, tenant + + +def create_console_app(db_session: Session, tenant_id: str, account_id: str, mode: AppMode) -> App: + """Create a minimal app row that can be loaded by get_app_model.""" + app = App( + tenant_id=tenant_id, + name="Test App", + mode=mode, + enable_site=True, + enable_api=True, + created_by=account_id, + ) + db_session.add(app) + db_session.commit() + return app + + +def authenticate_console_client(test_client: FlaskClient, account: Account) -> dict[str, str]: + """Attach console auth cookies/headers for endpoints guarded by login_required.""" + access_token = AccountService.get_account_jwt_token(account) + csrf_token = generate_csrf_token(account.id) + test_client.set_cookie(_real_cookie_name("csrf_token"), csrf_token, domain="localhost") + return { + "Authorization": f"Bearer {access_token}", + HEADER_NAME_CSRF_TOKEN: csrf_token, + } diff --git a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py index 96fb7ea2935..2b4c1b59abf 100644 --- a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py @@ -22,6 +22,13 @@ import uuid from time import time import pytest +from graphon.entities.pause_reason import SchedulingPause +from graphon.enums import WorkflowExecutionStatus +from graphon.graph_engine.entities.commands import GraphEngineCommand +from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError +from graphon.graph_events import GraphRunPausedEvent +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool from sqlalchemy import Engine, delete, select from sqlalchemy.orm import Session @@ -31,16 +38,7 @@ from core.app.layers.pause_state_persist_layer import ( PauseStatePersistenceLayer, WorkflowResumptionContext, ) -from dify_graph.entities.pause_reason import SchedulingPause -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.graph_engine.entities.commands import GraphEngineCommand -from dify_graph.graph_engine.layers.base import GraphEngineLayerNotInitializedError -from dify_graph.graph_events.graph import GraphRunPausedEvent -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.runtime.graph_runtime_state import GraphRuntimeState -from dify_graph.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState -from dify_graph.runtime.read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper -from dify_graph.runtime.variable_pool import SystemVariable, VariablePool +from core.workflow.system_variables import build_system_variables from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from models import Account @@ -212,7 +210,7 @@ class TestPauseStatePersistenceLayerTestContainers: execution_id = workflow_run_id or getattr(self, "test_workflow_run_id", None) or str(uuid.uuid4()) # Create variable pool - variable_pool = VariablePool(system_variables=SystemVariable(workflow_execution_id=execution_id)) + variable_pool = VariablePool(system_variables=build_system_variables(workflow_execution_id=execution_id)) if variables: for (node_id, var_key), value in variables.items(): variable_pool.add([node_id, var_key], value) @@ -544,7 +542,7 @@ class TestPauseStatePersistenceLayerTestContainers: layer.initialize(graph_runtime_state, command_channel) # Import other event types - from dify_graph.graph_events.graph import ( + from graphon.graph_events import ( GraphRunFailedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, diff --git a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py index 781e297fa4b..00d7496a409 100644 --- a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py +++ b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py @@ -4,6 +4,7 @@ from unittest.mock import patch import pytest from faker import Faker +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest from models.dataset import Dataset, Document @@ -38,7 +39,7 @@ class TestGetAvailableDatasetsIntegration: provider="dify", data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) db_session_with_containers.add(dataset) db_session_with_containers.flush() @@ -55,7 +56,7 @@ class TestGetAvailableDatasetsIntegration: name=f"Document {i}", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -112,7 +113,7 @@ class TestGetAvailableDatasetsIntegration: created_from=DocumentCreatedFrom.WEB, name=f"Archived Document {i}", created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=True, # Archived @@ -165,7 +166,7 @@ class TestGetAvailableDatasetsIntegration: created_from=DocumentCreatedFrom.WEB, name=f"Disabled Document {i}", created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_status=IndexingStatus.COMPLETED, enabled=False, # Disabled archived=False, @@ -218,7 +219,7 @@ class TestGetAvailableDatasetsIntegration: created_from=DocumentCreatedFrom.WEB, name=f"Document {status}", created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_status=status, # Not completed enabled=True, archived=False, @@ -336,7 +337,7 @@ class TestGetAvailableDatasetsIntegration: created_from=DocumentCreatedFrom.WEB, name=f"Document for {dataset.name}", created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, @@ -416,7 +417,7 @@ class TestGetAvailableDatasetsIntegration: created_from=DocumentCreatedFrom.WEB, name=f"Document {i}", created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, @@ -459,7 +460,7 @@ class TestKnowledgeRetrievalIntegration: provider="dify", data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) db_session_with_containers.add(dataset) @@ -476,7 +477,7 @@ class TestKnowledgeRetrievalIntegration: indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) db_session_with_containers.commit() diff --git a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py index 9d0fad4b12a..13caad799eb 100644 --- a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py @@ -4,23 +4,20 @@ from __future__ import annotations from uuid import uuid4 +from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData, UserAction from sqlalchemy import Engine, select from sqlalchemy.orm import Session -from core.repositories.human_input_repository import HumanInputFormRepositoryImpl -from dify_graph.nodes.human_input.entities import ( +from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl +from core.workflow.human_input_compat import ( DeliveryChannelConfig, EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - FormDefinition, - HumanInputNodeData, MemberRecipient, - UserAction, WebAppDeliveryMethod, ) -from dify_graph.repositories.human_input_form_repository import FormCreateParams from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.human_input import ( EmailExternalRecipientPayload, @@ -68,7 +65,6 @@ def _build_form_params(delivery_methods: list[DeliveryChannelConfig]) -> FormCre user_actions=[UserAction(id="approve", title="Approve")], ) return FormCreateParams( - app_id=str(uuid4()), workflow_execution_id=str(uuid4()), node_id="human-input-node", form_config=form_config, @@ -84,7 +80,7 @@ def _build_email_delivery( ) -> EmailDeliveryMethod: return EmailDeliveryMethod( config=EmailDeliveryConfig( - recipients=EmailRecipients(whole_workspace=whole_workspace, items=recipients), + recipients=EmailRecipients(include_bound_group=whole_workspace, items=recipients), subject="Approval Needed", body="Please review", ) @@ -100,7 +96,7 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["member1@example.com", "member2@example.com"], ) - repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id, app_id=str(uuid4())) params = _build_form_params( delivery_methods=[_build_email_delivery(whole_workspace=True, recipients=[])], ) @@ -129,13 +125,13 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["primary@example.com", "secondary@example.com"], ) - repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id, app_id=str(uuid4())) params = _build_form_params( delivery_methods=[ _build_email_delivery( whole_workspace=False, recipients=[ - MemberRecipient(user_id=members[0].id), + MemberRecipient(reference_id=members[0].id), ExternalRecipient(email="external@example.com"), ], ) @@ -173,10 +169,9 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["prefill@example.com"], ) - repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id, app_id=str(uuid4())) resolved_values = {"greeting": "Hello!"} params = FormCreateParams( - app_id=str(uuid4()), workflow_execution_id=str(uuid4()), node_id="human-input-node", form_config=HumanInputNodeData( @@ -210,9 +205,8 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["ui@example.com"], ) - repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id, app_id=str(uuid4())) params = FormCreateParams( - app_id=str(uuid4()), workflow_execution_id=str(uuid4()), node_id="human-input-node", form_config=HumanInputNodeData( diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py index 9733735df36..0a9b476afc5 100644 --- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py +++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py @@ -4,28 +4,29 @@ from datetime import timedelta from unittest.mock import MagicMock import pytest +from graphon.enums import WorkflowType +from graphon.graph import Graph +from graphon.graph_engine import GraphEngine +from graphon.graph_engine.command_channels import InMemoryChannel +from graphon.nodes.end.end_node import EndNode +from graphon.nodes.end.entities import EndNodeData +from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction +from graphon.nodes.human_input.enums import HumanInputFormStatus +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.nodes.start.entities import StartNodeData +from graphon.nodes.start.start_node import StartNode +from graphon.runtime import GraphRuntimeState, VariablePool from sqlalchemy import delete, select from sqlalchemy.orm import Session from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.workflow.layers import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.repositories.human_input_repository import HumanInputFormEntity, HumanInputFormRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository -from dify_graph.enums import WorkflowType -from dify_graph.graph import Graph -from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from dify_graph.graph_engine.graph_engine import GraphEngine -from dify_graph.nodes.end.end_node import EndNode -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction -from dify_graph.nodes.human_input.enums import HumanInputFormStatus -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables from libs.datetime_utils import naive_utc_now from models import Account from models.account import Tenant, TenantAccountJoin, TenantAccountRole @@ -39,7 +40,7 @@ def _mock_form_repository_without_submission() -> HumanInputFormRepository: repo = MagicMock(spec=HumanInputFormRepository) form_entity = MagicMock(spec=HumanInputFormEntity) form_entity.id = "test-form-id" - form_entity.web_app_token = "test-form-token" + form_entity.submission_token = "test-form-token" form_entity.recipients = [] form_entity.rendered_content = "rendered" form_entity.submitted = False @@ -52,7 +53,7 @@ def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepos repo = MagicMock(spec=HumanInputFormRepository) form_entity = MagicMock(spec=HumanInputFormEntity) form_entity.id = "test-form-id" - form_entity.web_app_token = "test-form-token" + form_entity.submission_token = "test-form-token" form_entity.recipients = [] form_entity.rendered_content = "rendered" form_entity.submitted = True @@ -66,7 +67,7 @@ def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepos def _build_runtime_state(workflow_execution_id: str, app_id: str, workflow_id: str, user_id: str) -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( workflow_execution_id=workflow_execution_id, app_id=app_id, workflow_id=workflow_id, @@ -120,6 +121,7 @@ def _build_graph( graph_init_params=params, graph_runtime_state=runtime_state, form_repository=form_repository, + runtime=DifyHumanInputNodeRuntime(params.run_context), ) end_data = EndNodeData( diff --git a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py index 8e70fc0bb00..cc72dc1cf39 100644 --- a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py @@ -4,9 +4,10 @@ from unittest.mock import patch from uuid import uuid4 import pytest +from graphon.file import File, FileTransferMethod, FileType from sqlalchemy.orm import Session -from dify_graph.file import File, FileTransferMethod, FileType +from core.app.file_access import DatabaseFileAccessController from extensions.ext_database import db from extensions.storage.storage_type import StorageType from factories.file_factory import StorageKeyLoader @@ -35,7 +36,11 @@ class TestStorageKeyLoader(unittest.TestCase): self.test_tool_files = [] # Create StorageKeyLoader instance - self.loader = StorageKeyLoader(self.session, self.tenant_id) + self.loader = StorageKeyLoader( + self.session, + self.tenant_id, + access_controller=DatabaseFileAccessController(), + ) def tearDown(self): """Clean up test data after each test method.""" @@ -193,19 +198,16 @@ class TestStorageKeyLoader(unittest.TestCase): # Should not raise any exceptions self.loader.load_storage_keys([]) - def test_load_storage_keys_tenant_mismatch(self): - """Test tenant_id validation.""" - # Create file with different tenant_id + def test_load_storage_keys_ignores_legacy_file_tenant_id(self): + """Legacy file tenant_id should not override the loader tenant scope.""" upload_file = self._create_upload_file() file = self._create_file( related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4()) ) - # Should raise ValueError for tenant mismatch - with pytest.raises(ValueError) as context: - self.loader.load_storage_keys([file]) + self.loader.load_storage_keys([file]) - assert "invalid file, expected tenant_id" in str(context.value) + assert file._storage_key == upload_file.key def test_load_storage_keys_missing_file_id(self): """Test with None file.related_id.""" @@ -314,7 +316,7 @@ class TestStorageKeyLoader(unittest.TestCase): with pytest.raises(ValueError) as context: self.loader.load_storage_keys([file_other]) - assert "invalid file, expected tenant_id" in str(context.value) + assert "Upload file not found for id:" in str(context.value) # Current tenant's file should still work self.loader.load_storage_keys([file_current]) @@ -338,7 +340,7 @@ class TestStorageKeyLoader(unittest.TestCase): with pytest.raises(ValueError) as context: self.loader.load_storage_keys([file_current, file_other]) - assert "invalid file, expected tenant_id" in str(context.value) + assert "Upload file not found for id:" in str(context.value) def test_load_storage_keys_duplicate_file_ids(self): """Test handling of duplicate file IDs in the batch.""" @@ -365,6 +367,10 @@ class TestStorageKeyLoader(unittest.TestCase): # Create loader with different session (same underlying connection) with Session(bind=db.engine) as other_session: - other_loader = StorageKeyLoader(other_session, self.tenant_id) + other_loader = StorageKeyLoader( + other_session, + self.tenant_id, + access_controller=DatabaseFileAccessController(), + ) with pytest.raises(ValueError): other_loader.load_storage_keys([file]) diff --git a/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py b/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py index 573f84cb0bc..b745aed1417 100644 --- a/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py +++ b/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py @@ -1,12 +1,15 @@ from __future__ import annotations from dataclasses import dataclass -from datetime import datetime, timedelta +from datetime import timedelta from decimal import Decimal from uuid import uuid4 -from dify_graph.nodes.human_input.entities import FormDefinition, UserAction +from graphon.nodes.human_input.entities import FormDefinition, UserAction + +from libs.datetime_utils import naive_utc_now from models.account import Account, Tenant, TenantAccountJoin +from models.enums import ConversationFromSource, InvokeFrom from models.execution_extra_content import HumanInputContent from models.human_input import HumanInputForm, HumanInputFormStatus from models.model import App, Conversation, Message @@ -78,8 +81,8 @@ def create_human_input_message_fixture(db_session) -> HumanInputMessageFixture: introduction="", system_instruction="", status="normal", - invoke_from="console", - from_source="console", + invoke_from=InvokeFrom.EXPLORE, + from_source=ConversationFromSource.CONSOLE, from_account_id=account.id, from_end_user_id=None, ) @@ -101,7 +104,7 @@ def create_human_input_message_fixture(db_session) -> HumanInputMessageFixture: answer_unit_price=Decimal("0.001"), provider_response_latency=0.5, currency="USD", - from_source="console", + from_source=ConversationFromSource.CONSOLE, from_account_id=account.id, workflow_run_id=workflow_run_id, ) @@ -116,7 +119,7 @@ def create_human_input_message_fixture(db_session) -> HumanInputMessageFixture: inputs=[], user_actions=[UserAction(id=action_id, title=action_text)], rendered_content="Rendered block", - expiration_time=datetime.utcnow() + timedelta(days=1), + expiration_time=naive_utc_now() + timedelta(days=1), node_title=node_title, display_in_ui=True, ) @@ -128,7 +131,7 @@ def create_human_input_message_fixture(db_session) -> HumanInputMessageFixture: form_definition=form_definition.model_dump_json(), rendered_content="Rendered block", status=HumanInputFormStatus.SUBMITTED, - expiration_time=datetime.utcnow() + timedelta(days=1), + expiration_time=naive_utc_now() + timedelta(days=1), selected_action_id=action_id, ) db_session.add(form) diff --git a/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_streams_channel.py b/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_streams_channel.py new file mode 100644 index 00000000000..a79208f649d --- /dev/null +++ b/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_streams_channel.py @@ -0,0 +1,227 @@ +""" +Integration tests for Redis Streams broadcast channel implementation using TestContainers. + +This suite focuses on the semantics that differ from Redis Pub/Sub: +- Every active subscription should receive each newly published message. +- Each subscription should only observe messages published after its listener starts. +""" + +import threading +import time +import uuid +from collections.abc import Iterator +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +import redis +from testcontainers.redis import RedisContainer + +from libs.broadcast_channel.channel import BroadcastChannel, Subscription, Topic +from libs.broadcast_channel.exc import SubscriptionClosedError +from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel + + +class TestRedisStreamsBroadcastChannelIntegration: + """Integration tests for Redis Streams broadcast channel with a real Redis instance.""" + + @pytest.fixture(scope="class") + def redis_container(self) -> Iterator[RedisContainer]: + """Create a Redis container for integration testing.""" + with RedisContainer(image="redis:6-alpine") as container: + yield container + + @pytest.fixture(scope="class") + def redis_client(self, redis_container: RedisContainer) -> redis.Redis: + """Create a Redis client connected to the test container.""" + host = redis_container.get_container_host_ip() + port = redis_container.get_exposed_port(6379) + return redis.Redis(host=host, port=port, decode_responses=False) + + @pytest.fixture + def broadcast_channel(self, redis_client: redis.Redis) -> BroadcastChannel: + """Create a StreamsBroadcastChannel instance with a real Redis client.""" + return StreamsBroadcastChannel(redis_client) + + @classmethod + def _get_test_topic_name(cls) -> str: + return f"test_streams_topic_{uuid.uuid4()}" + + @staticmethod + def _start_subscription(subscription: Subscription) -> None: + """Start the background listener and confirm the subscription queue is empty.""" + assert subscription.receive(timeout=0.05) is None + + @staticmethod + def _receive_message(subscription: Subscription, *, timeout_seconds: float = 2.0) -> bytes: + """Poll until a message is received or the timeout expires.""" + deadline = time.monotonic() + timeout_seconds + while time.monotonic() < deadline: + message = subscription.receive(timeout=0.1) + if message is not None: + return message + pytest.fail("Timed out waiting for a message") + + def test_close_an_active_subscription_should_stop_iteration(self, broadcast_channel: BroadcastChannel) -> None: + """Closing an active subscription should terminate the iterator cleanly.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + subscription = topic.subscribe() + consuming_event = threading.Event() + + def consume() -> list[bytes]: + messages: list[bytes] = [] + consuming_event.set() + for message in subscription: + messages.append(message) + return messages + + with ThreadPoolExecutor(max_workers=1) as executor: + consumer_future = executor.submit(consume) + assert consuming_event.wait(timeout=1.0) + subscription.close() + assert consumer_future.result(timeout=2.0) == [] + + def test_end_to_end_messaging(self, broadcast_channel: BroadcastChannel) -> None: + """A producer should publish a message that a live subscription can consume.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + producer = topic.as_producer() + subscription = topic.subscribe() + message = b"hello streams" + + try: + self._start_subscription(subscription) + producer.publish(message) + + assert self._receive_message(subscription) == message + assert subscription.receive(timeout=0.1) is None + finally: + subscription.close() + + def test_multiple_subscriptions_each_receive_each_new_message(self, broadcast_channel: BroadcastChannel) -> None: + """Each active subscription should receive the same newly published message.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + subscriptions = [topic.subscribe() for _ in range(3)] + new_message = b"message-visible-to-every-subscriber" + + try: + for subscription in subscriptions: + self._start_subscription(subscription) + + topic.publish(new_message) + + for subscription in subscriptions: + assert self._receive_message(subscription) == new_message + assert subscription.receive(timeout=0.1) is None + finally: + for subscription in subscriptions: + subscription.close() + + def test_each_subscription_only_receives_messages_published_after_it_starts( + self, + broadcast_channel: BroadcastChannel, + ) -> None: + """A late subscription should not replay messages that existed before its listener started.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + first_subscription = topic.subscribe() + second_subscription = topic.subscribe() + message_before_any_subscription = b"before-any-subscription" + message_after_first_subscription = b"after-first-subscription" + message_after_second_subscription = b"after-second-subscription" + + try: + topic.publish(message_before_any_subscription) + + self._start_subscription(first_subscription) + topic.publish(message_after_first_subscription) + + assert self._receive_message(first_subscription) == message_after_first_subscription + assert first_subscription.receive(timeout=0.1) is None + + self._start_subscription(second_subscription) + topic.publish(message_after_second_subscription) + + assert self._receive_message(first_subscription) == message_after_second_subscription + assert self._receive_message(second_subscription) == message_after_second_subscription + assert first_subscription.receive(timeout=0.1) is None + assert second_subscription.receive(timeout=0.1) is None + finally: + first_subscription.close() + second_subscription.close() + + def test_topic_isolation(self, broadcast_channel: BroadcastChannel) -> None: + """Messages from different topics should remain isolated.""" + topic1 = broadcast_channel.topic(self._get_test_topic_name()) + topic2 = broadcast_channel.topic(self._get_test_topic_name()) + message1 = b"message-for-topic-1" + message2 = b"message-for-topic-2" + + def consume_single_message(topic: Topic) -> bytes: + subscription = topic.subscribe() + try: + self._start_subscription(subscription) + return self._receive_message(subscription) + finally: + subscription.close() + + with ThreadPoolExecutor(max_workers=3) as executor: + consumer1_future = executor.submit(consume_single_message, topic1) + consumer2_future = executor.submit(consume_single_message, topic2) + time.sleep(0.1) + topic1.publish(message1) + topic2.publish(message2) + + assert consumer1_future.result(timeout=5.0) == message1 + assert consumer2_future.result(timeout=5.0) == message2 + + def test_concurrent_producers_publish_all_messages(self, broadcast_channel: BroadcastChannel) -> None: + """Concurrent producers should not lose messages for a live subscription.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + subscription = topic.subscribe() + producer_count = 4 + messages_per_producer = 4 + expected_total = producer_count * messages_per_producer + consumer_ready = threading.Event() + + def produce_messages(producer_idx: int) -> set[bytes]: + producer = topic.as_producer() + produced: set[bytes] = set() + for message_idx in range(messages_per_producer): + payload = f"producer-{producer_idx}-message-{message_idx}".encode() + produced.add(payload) + producer.publish(payload) + time.sleep(0.001) + return produced + + def consume_messages() -> set[bytes]: + received: set[bytes] = set() + try: + self._start_subscription(subscription) + consumer_ready.set() + while len(received) < expected_total: + message = subscription.receive(timeout=0.2) + if message is not None: + received.add(message) + return received + finally: + subscription.close() + + with ThreadPoolExecutor(max_workers=producer_count + 1) as executor: + consumer_future = executor.submit(consume_messages) + assert consumer_ready.wait(timeout=2.0) + + producer_futures = [executor.submit(produce_messages, idx) for idx in range(producer_count)] + expected_messages: set[bytes] = set() + for future in as_completed(producer_futures, timeout=10.0): + expected_messages.update(future.result()) + + assert consumer_future.result(timeout=10.0) == expected_messages + + def test_receive_raises_subscription_closed_after_close(self, broadcast_channel: BroadcastChannel) -> None: + """Calling receive on a closed subscription should raise SubscriptionClosedError.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + subscription = topic.subscribe() + + self._start_subscription(subscription) + subscription.close() + + with pytest.raises(SubscriptionClosedError): + subscription.receive(timeout=0.1) diff --git a/api/tests/test_containers_integration_tests/repositories/test_execution_extra_content_repository.py b/api/tests/test_containers_integration_tests/repositories/test_execution_extra_content_repository.py deleted file mode 100644 index c9058626d10..00000000000 --- a/api/tests/test_containers_integration_tests/repositories/test_execution_extra_content_repository.py +++ /dev/null @@ -1,27 +0,0 @@ -from __future__ import annotations - -from sqlalchemy.orm import sessionmaker - -from extensions.ext_database import db -from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository -from tests.test_containers_integration_tests.helpers.execution_extra_content import ( - create_human_input_message_fixture, -) - - -def test_get_by_message_ids_returns_human_input_content(db_session_with_containers): - fixture = create_human_input_message_fixture(db_session_with_containers) - repository = SQLAlchemyExecutionExtraContentRepository( - session_maker=sessionmaker(bind=db.engine, expire_on_commit=False) - ) - - results = repository.get_by_message_ids([fixture.message.id]) - - assert len(results) == 1 - assert len(results[0]) == 1 - content = results[0][0] - assert content.submitted is True - assert content.form_submission_data is not None - assert content.form_submission_data.action_id == fixture.action_id - assert content.form_submission_data.action_text == fixture.action_text - assert content.form_submission_data.rendered_content == fixture.form.rendered_content diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py index 458862b0ece..a68b3a08c7a 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py @@ -5,10 +5,10 @@ from __future__ import annotations from datetime import timedelta from uuid import uuid4 +from graphon.enums import WorkflowNodeExecutionStatus from sqlalchemy import Engine, delete from sqlalchemy.orm import Session, sessionmaker -from dify_graph.enums import WorkflowNodeExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py index 76e586e65f8..d28cfda1598 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py @@ -8,19 +8,28 @@ from unittest.mock import Mock from uuid import uuid4 import pytest +from graphon.entities import WorkflowExecution +from graphon.entities.pause_reason import HumanInputRequired, PauseReasonType +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.entities import FormDefinition, FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType, HumanInputFormStatus from sqlalchemy import Engine, delete, select from sqlalchemy.orm import Session, sessionmaker -from dify_graph.entities import WorkflowExecution -from dify_graph.entities.pause_reason import PauseReasonType -from dify_graph.enums import WorkflowExecutionStatus from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom -from models.workflow import WorkflowAppLog, WorkflowPause, WorkflowPauseReason, WorkflowRun +from models.human_input import ( + HumanInputDelivery, + HumanInputForm, + HumanInputFormRecipient, +) +from models.workflow import WorkflowAppLog, WorkflowAppLogCreatedFrom, WorkflowPause, WorkflowPauseReason, WorkflowRun from repositories.entities.workflow_pause import WorkflowPauseEntity from repositories.sqlalchemy_api_workflow_run_repository import ( DifyAPISQLAlchemyWorkflowRunRepository, + _build_human_input_required_reason, + _PrivateWorkflowPauseEntity, _WorkflowRunError, ) @@ -90,6 +99,19 @@ def _cleanup_scope_data(session: Session, scope: _TestScope) -> None: WorkflowRun.app_id == scope.app_id, ) ) + + form_ids_subquery = select(HumanInputForm.id).where( + HumanInputForm.tenant_id == scope.tenant_id, + HumanInputForm.app_id == scope.app_id, + ) + session.execute(delete(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids_subquery))) + session.execute(delete(HumanInputDelivery).where(HumanInputDelivery.form_id.in_(form_ids_subquery))) + session.execute( + delete(HumanInputForm).where( + HumanInputForm.tenant_id == scope.tenant_id, + HumanInputForm.app_id == scope.app_id, + ) + ) session.commit() for state_key in scope.state_keys: @@ -193,7 +215,7 @@ class TestDeleteRunsWithRelated: app_id=test_scope.app_id, workflow_id=test_scope.workflow_id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=test_scope.user_id, ) @@ -253,7 +275,7 @@ class TestCountRunsWithRelated: app_id=test_scope.app_id, workflow_id=test_scope.workflow_id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=test_scope.user_id, ) @@ -504,3 +526,180 @@ class TestDeleteWorkflowPause: with pytest.raises(_WorkflowRunError, match="WorkflowPause not found"): repository.delete_workflow_pause(pause_entity=pause_entity) + + +class TestPrivateWorkflowPauseEntity: + """Integration tests for _PrivateWorkflowPauseEntity using real DB models.""" + + def test_properties( + self, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Entity properties delegate to the persisted WorkflowPause model.""" + + workflow_run = _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.RUNNING, + ) + pause = WorkflowPause( + id=str(uuid4()), + workflow_id=test_scope.workflow_id, + workflow_run_id=workflow_run.id, + state_object_key=f"workflow-state-{uuid4()}.json", + ) + db_session_with_containers.add(pause) + db_session_with_containers.commit() + db_session_with_containers.refresh(pause) + test_scope.state_keys.add(pause.state_object_key) + + entity = _PrivateWorkflowPauseEntity(pause_model=pause, reason_models=[], human_input_form=[]) + + assert entity.id == pause.id + assert entity.workflow_execution_id == workflow_run.id + assert entity.resumed_at is None + + def test_get_state( + self, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """get_state loads state data from storage using the persisted state_object_key.""" + + workflow_run = _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.RUNNING, + ) + state_key = f"workflow-state-{uuid4()}.json" + pause = WorkflowPause( + id=str(uuid4()), + workflow_id=test_scope.workflow_id, + workflow_run_id=workflow_run.id, + state_object_key=state_key, + ) + db_session_with_containers.add(pause) + db_session_with_containers.commit() + db_session_with_containers.refresh(pause) + test_scope.state_keys.add(state_key) + + expected_state = b'{"test": "state"}' + storage.save(state_key, expected_state) + + entity = _PrivateWorkflowPauseEntity(pause_model=pause, reason_models=[], human_input_form=[]) + result = entity.get_state() + + assert result == expected_state + + def test_get_state_caching( + self, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """get_state caches the result so storage is only accessed once.""" + + workflow_run = _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.RUNNING, + ) + state_key = f"workflow-state-{uuid4()}.json" + pause = WorkflowPause( + id=str(uuid4()), + workflow_id=test_scope.workflow_id, + workflow_run_id=workflow_run.id, + state_object_key=state_key, + ) + db_session_with_containers.add(pause) + db_session_with_containers.commit() + db_session_with_containers.refresh(pause) + test_scope.state_keys.add(state_key) + + expected_state = b'{"test": "state"}' + storage.save(state_key, expected_state) + + entity = _PrivateWorkflowPauseEntity(pause_model=pause, reason_models=[], human_input_form=[]) + result1 = entity.get_state() + # Delete from storage to prove second call uses cache + storage.delete(state_key) + test_scope.state_keys.discard(state_key) + result2 = entity.get_state() + + assert result1 == expected_state + assert result2 == expected_state + + +class TestBuildHumanInputRequiredReason: + """Integration tests for _build_human_input_required_reason using real DB models.""" + + def test_builds_reason_from_form_definition( + self, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Build the graph pause reason from the stored form definition.""" + + expiration_time = naive_utc_now() + form_definition = FormDefinition( + form_content="content", + inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")], + user_actions=[UserAction(id="approve", title="Approve")], + rendered_content="rendered", + expiration_time=expiration_time, + default_values={"name": "Alice"}, + node_title="Ask Name", + display_in_ui=True, + ) + + form_model = HumanInputForm( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + workflow_run_id=str(uuid4()), + node_id="node-1", + form_definition=form_definition.model_dump_json(), + rendered_content="rendered", + status=HumanInputFormStatus.WAITING, + expiration_time=expiration_time, + ) + db_session_with_containers.add(form_model) + db_session_with_containers.flush() + + # Create a pause so the reason has a valid pause_id + workflow_run = _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.RUNNING, + ) + pause = WorkflowPause( + id=str(uuid4()), + workflow_id=test_scope.workflow_id, + workflow_run_id=workflow_run.id, + state_object_key=f"workflow-state-{uuid4()}.json", + ) + db_session_with_containers.add(pause) + db_session_with_containers.flush() + test_scope.state_keys.add(pause.state_object_key) + + reason_model = WorkflowPauseReason( + pause_id=pause.id, + type_=PauseReasonType.HUMAN_INPUT_REQUIRED, + form_id=form_model.id, + node_id="node-1", + message="", + ) + db_session_with_containers.add(reason_model) + db_session_with_containers.commit() + + # Refresh to ensure we have DB-round-tripped objects + db_session_with_containers.refresh(form_model) + db_session_with_containers.refresh(reason_model) + + reason = _build_human_input_required_reason(reason_model, form_model) + + assert isinstance(reason, HumanInputRequired) + assert reason.node_title == "Ask Name" + assert reason.form_content == "content" + assert reason.inputs[0].output_variable_name == "name" + assert reason.actions[0].id == "approve" + assert reason.resolved_default_values == {"name": "Alice"} diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py new file mode 100644 index 00000000000..7f44eb6ca37 --- /dev/null +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py @@ -0,0 +1,408 @@ +"""Integration tests for SQLAlchemyExecutionExtraContentRepository using Testcontainers. + +Part of #32454 — replaces the mock-based unit tests with real database interactions. +""" + +from __future__ import annotations + +from collections.abc import Generator +from dataclasses import dataclass +from datetime import timedelta +from decimal import Decimal +from uuid import uuid4 + +import pytest +from graphon.nodes.human_input.entities import FormDefinition, UserAction +from graphon.nodes.human_input.enums import HumanInputFormStatus +from sqlalchemy import Engine, delete, select +from sqlalchemy.orm import Session, sessionmaker + +from libs.datetime_utils import naive_utc_now +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.enums import ConversationFromSource, InvokeFrom +from models.execution_extra_content import ExecutionExtraContent, HumanInputContent +from models.human_input import ( + ConsoleRecipientPayload, + HumanInputDelivery, + HumanInputForm, + HumanInputFormRecipient, + RecipientType, +) +from models.model import App, Conversation, Message +from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository + + +@dataclass +class _TestScope: + """Per-test data scope used to isolate DB rows. + + IDs are populated after flushing the base entities to the database. + """ + + tenant_id: str = "" + app_id: str = "" + user_id: str = "" + + +def _cleanup_scope_data(session: Session, scope: _TestScope) -> None: + """Remove test-created DB rows for a test scope.""" + form_ids_subquery = select(HumanInputForm.id).where( + HumanInputForm.tenant_id == scope.tenant_id, + ) + session.execute(delete(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids_subquery))) + session.execute(delete(HumanInputDelivery).where(HumanInputDelivery.form_id.in_(form_ids_subquery))) + session.execute( + delete(ExecutionExtraContent).where( + ExecutionExtraContent.workflow_run_id.in_( + select(HumanInputForm.workflow_run_id).where(HumanInputForm.tenant_id == scope.tenant_id) + ) + ) + ) + session.execute(delete(HumanInputForm).where(HumanInputForm.tenant_id == scope.tenant_id)) + session.execute(delete(Message).where(Message.app_id == scope.app_id)) + session.execute(delete(Conversation).where(Conversation.app_id == scope.app_id)) + session.execute(delete(App).where(App.id == scope.app_id)) + session.execute(delete(TenantAccountJoin).where(TenantAccountJoin.tenant_id == scope.tenant_id)) + session.execute(delete(Account).where(Account.id == scope.user_id)) + session.execute(delete(Tenant).where(Tenant.id == scope.tenant_id)) + session.commit() + + +def _seed_base_entities(session: Session, scope: _TestScope) -> None: + """Create the base tenant, account, and app needed by tests.""" + tenant = Tenant(name="Test Tenant") + session.add(tenant) + session.flush() + scope.tenant_id = tenant.id + + account = Account( + name="Test Account", + email=f"test_{uuid4()}@example.com", + password="hashed-password", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) + session.add(account) + session.flush() + scope.user_id = account.id + + tenant_join = TenantAccountJoin( + tenant_id=scope.tenant_id, + account_id=scope.user_id, + role=TenantAccountRole.OWNER, + current=True, + ) + session.add(tenant_join) + + app = App( + tenant_id=scope.tenant_id, + name="Test App", + description="", + mode="chat", + icon_type="emoji", + icon="bot", + icon_background="#FFFFFF", + enable_site=False, + enable_api=True, + api_rpm=100, + api_rph=100, + is_demo=False, + is_public=False, + is_universal=False, + created_by=scope.user_id, + updated_by=scope.user_id, + ) + session.add(app) + session.flush() + scope.app_id = app.id + + +def _create_conversation(session: Session, scope: _TestScope) -> Conversation: + conversation = Conversation( + app_id=scope.app_id, + mode="chat", + name="Test Conversation", + summary="", + introduction="", + system_instruction="", + status="normal", + invoke_from=InvokeFrom.EXPLORE, + from_source=ConversationFromSource.CONSOLE, + from_account_id=scope.user_id, + from_end_user_id=None, + ) + conversation.inputs = {} + session.add(conversation) + session.flush() + return conversation + + +def _create_message( + session: Session, + scope: _TestScope, + conversation_id: str, + workflow_run_id: str, +) -> Message: + message = Message( + app_id=scope.app_id, + conversation_id=conversation_id, + inputs={}, + query="test query", + message={"messages": []}, + answer="test answer", + message_tokens=50, + message_unit_price=Decimal("0.001"), + answer_tokens=80, + answer_unit_price=Decimal("0.001"), + provider_response_latency=0.5, + currency="USD", + from_source=ConversationFromSource.CONSOLE, + from_account_id=scope.user_id, + workflow_run_id=workflow_run_id, + ) + session.add(message) + session.flush() + return message + + +def _create_submitted_form( + session: Session, + scope: _TestScope, + *, + workflow_run_id: str, + action_id: str = "approve", + action_title: str = "Approve", + node_title: str = "Approval", +) -> HumanInputForm: + expiration_time = naive_utc_now() + timedelta(days=1) + form_definition = FormDefinition( + form_content="content", + inputs=[], + user_actions=[UserAction(id=action_id, title=action_title)], + rendered_content="rendered", + expiration_time=expiration_time, + node_title=node_title, + display_in_ui=True, + ) + form = HumanInputForm( + tenant_id=scope.tenant_id, + app_id=scope.app_id, + workflow_run_id=workflow_run_id, + node_id="node-id", + form_definition=form_definition.model_dump_json(), + rendered_content=f"Rendered {action_title}", + status=HumanInputFormStatus.SUBMITTED, + expiration_time=expiration_time, + selected_action_id=action_id, + ) + session.add(form) + session.flush() + return form + + +def _create_waiting_form( + session: Session, + scope: _TestScope, + *, + workflow_run_id: str, + default_values: dict | None = None, +) -> HumanInputForm: + expiration_time = naive_utc_now() + timedelta(days=1) + form_definition = FormDefinition( + form_content="content", + inputs=[], + user_actions=[UserAction(id="approve", title="Approve")], + rendered_content="rendered", + expiration_time=expiration_time, + default_values=default_values or {"name": "John"}, + node_title="Approval", + display_in_ui=True, + ) + form = HumanInputForm( + tenant_id=scope.tenant_id, + app_id=scope.app_id, + workflow_run_id=workflow_run_id, + node_id="node-id", + form_definition=form_definition.model_dump_json(), + rendered_content="Rendered block", + status=HumanInputFormStatus.WAITING, + expiration_time=expiration_time, + ) + session.add(form) + session.flush() + return form + + +def _create_human_input_content( + session: Session, + *, + workflow_run_id: str, + message_id: str, + form_id: str, +) -> HumanInputContent: + content = HumanInputContent.new( + workflow_run_id=workflow_run_id, + message_id=message_id, + form_id=form_id, + ) + session.add(content) + return content + + +def _create_recipient( + session: Session, + *, + form_id: str, + delivery_id: str, + recipient_type: RecipientType = RecipientType.CONSOLE, + access_token: str = "token-1", +) -> HumanInputFormRecipient: + payload = ConsoleRecipientPayload(account_id=None) + recipient = HumanInputFormRecipient( + form_id=form_id, + delivery_id=delivery_id, + recipient_type=recipient_type, + recipient_payload=payload.model_dump_json(), + access_token=access_token, + ) + session.add(recipient) + return recipient + + +def _create_delivery(session: Session, *, form_id: str) -> HumanInputDelivery: + from core.workflow.human_input_compat import DeliveryMethodType + from models.human_input import ConsoleDeliveryPayload + + delivery = HumanInputDelivery( + form_id=form_id, + delivery_method_type=DeliveryMethodType.WEBAPP, + channel_payload=ConsoleDeliveryPayload().model_dump_json(), + ) + session.add(delivery) + session.flush() + return delivery + + +@pytest.fixture +def repository(db_session_with_containers: Session) -> SQLAlchemyExecutionExtraContentRepository: + """Build a repository backed by the testcontainers database engine.""" + engine = db_session_with_containers.get_bind() + assert isinstance(engine, Engine) + return SQLAlchemyExecutionExtraContentRepository(sessionmaker(bind=engine, expire_on_commit=False)) + + +@pytest.fixture +def test_scope(db_session_with_containers: Session) -> Generator[_TestScope]: + """Provide an isolated scope and clean related data after each test.""" + scope = _TestScope() + _seed_base_entities(db_session_with_containers, scope) + db_session_with_containers.commit() + yield scope + _cleanup_scope_data(db_session_with_containers, scope) + + +class TestGetByMessageIds: + """Tests for SQLAlchemyExecutionExtraContentRepository.get_by_message_ids.""" + + def test_groups_contents_by_message( + self, + db_session_with_containers: Session, + repository: SQLAlchemyExecutionExtraContentRepository, + test_scope: _TestScope, + ) -> None: + """Submitted forms are correctly mapped and grouped by message ID.""" + workflow_run_id = str(uuid4()) + conversation = _create_conversation(db_session_with_containers, test_scope) + msg1 = _create_message(db_session_with_containers, test_scope, conversation.id, workflow_run_id) + msg2 = _create_message(db_session_with_containers, test_scope, conversation.id, workflow_run_id) + + form = _create_submitted_form( + db_session_with_containers, + test_scope, + workflow_run_id=workflow_run_id, + action_id="approve", + action_title="Approve", + ) + _create_human_input_content( + db_session_with_containers, + workflow_run_id=workflow_run_id, + message_id=msg1.id, + form_id=form.id, + ) + db_session_with_containers.commit() + + result = repository.get_by_message_ids([msg1.id, msg2.id]) + + assert len(result) == 2 + # msg1 has one submitted content + assert len(result[0]) == 1 + content = result[0][0] + assert content.submitted is True + assert content.workflow_run_id == workflow_run_id + assert content.form_submission_data is not None + assert content.form_submission_data.action_id == "approve" + assert content.form_submission_data.action_text == "Approve" + assert content.form_submission_data.rendered_content == "Rendered Approve" + assert content.form_submission_data.node_id == "node-id" + assert content.form_submission_data.node_title == "Approval" + # msg2 has no content + assert result[1] == [] + + def test_returns_unsubmitted_form_definition( + self, + db_session_with_containers: Session, + repository: SQLAlchemyExecutionExtraContentRepository, + test_scope: _TestScope, + ) -> None: + """Waiting forms return full form_definition with resolved token and defaults.""" + workflow_run_id = str(uuid4()) + conversation = _create_conversation(db_session_with_containers, test_scope) + msg = _create_message(db_session_with_containers, test_scope, conversation.id, workflow_run_id) + + form = _create_waiting_form( + db_session_with_containers, + test_scope, + workflow_run_id=workflow_run_id, + default_values={"name": "John"}, + ) + delivery = _create_delivery(db_session_with_containers, form_id=form.id) + _create_recipient( + db_session_with_containers, + form_id=form.id, + delivery_id=delivery.id, + access_token="token-1", + ) + _create_human_input_content( + db_session_with_containers, + workflow_run_id=workflow_run_id, + message_id=msg.id, + form_id=form.id, + ) + db_session_with_containers.commit() + + result = repository.get_by_message_ids([msg.id]) + + assert len(result) == 1 + assert len(result[0]) == 1 + domain_content = result[0][0] + assert domain_content.submitted is False + assert domain_content.workflow_run_id == workflow_run_id + assert domain_content.form_definition is not None + form_def = domain_content.form_definition + assert form_def.form_id == form.id + assert form_def.node_id == "node-id" + assert form_def.node_title == "Approval" + assert form_def.form_content == "Rendered block" + assert form_def.display_in_ui is True + assert form_def.form_token == "token-1" + assert form_def.resolved_default_values == {"name": "John"} + assert form_def.expiration_time == int(form.expiration_time.timestamp()) + + def test_empty_message_ids_returns_empty_list( + self, + repository: SQLAlchemyExecutionExtraContentRepository, + ) -> None: + """Passing no message IDs returns an empty list without hitting the DB.""" + result = repository.get_by_message_ids([]) + assert result == [] diff --git a/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py new file mode 100644 index 00000000000..c5e9201ee37 --- /dev/null +++ b/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py @@ -0,0 +1,391 @@ +"""Integration tests for get_paginated_workflow_runs and get_workflow_runs_count using testcontainers.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import timedelta +from uuid import uuid4 + +import pytest +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus +from sqlalchemy import Engine, delete +from sqlalchemy import exc as sa_exc +from sqlalchemy.orm import Session, sessionmaker + +from libs.datetime_utils import naive_utc_now +from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom +from models.workflow import WorkflowRun, WorkflowType +from repositories.sqlalchemy_api_workflow_run_repository import DifyAPISQLAlchemyWorkflowRunRepository + + +class _TestWorkflowRunRepository(DifyAPISQLAlchemyWorkflowRunRepository): + """Concrete repository for tests where save() is not under test.""" + + def save(self, execution: WorkflowExecution) -> None: + return None + + +@dataclass +class _TestScope: + """Per-test data scope used to isolate DB rows.""" + + tenant_id: str = field(default_factory=lambda: str(uuid4())) + app_id: str = field(default_factory=lambda: str(uuid4())) + workflow_id: str = field(default_factory=lambda: str(uuid4())) + user_id: str = field(default_factory=lambda: str(uuid4())) + + +def _create_workflow_run( + session: Session, + scope: _TestScope, + *, + status: WorkflowExecutionStatus, + triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING, + created_at_offset: timedelta | None = None, +) -> WorkflowRun: + """Create and persist a workflow run bound to the current test scope.""" + now = naive_utc_now() + workflow_run = WorkflowRun( + id=str(uuid4()), + tenant_id=scope.tenant_id, + app_id=scope.app_id, + workflow_id=scope.workflow_id, + type=WorkflowType.WORKFLOW, + triggered_from=triggered_from, + version="draft", + graph="{}", + inputs="{}", + status=status, + created_by_role=CreatorUserRole.ACCOUNT, + created_by=scope.user_id, + created_at=now + created_at_offset if created_at_offset is not None else now, + ) + session.add(workflow_run) + session.commit() + return workflow_run + + +def _cleanup_scope_data(session: Session, scope: _TestScope) -> None: + """Remove test-created DB rows for a test scope.""" + session.execute( + delete(WorkflowRun).where( + WorkflowRun.tenant_id == scope.tenant_id, + WorkflowRun.app_id == scope.app_id, + ) + ) + session.commit() + + +@pytest.fixture +def repository(db_session_with_containers: Session) -> DifyAPISQLAlchemyWorkflowRunRepository: + """Build a repository backed by the testcontainers database engine.""" + engine = db_session_with_containers.get_bind() + assert isinstance(engine, Engine) + return _TestWorkflowRunRepository(session_maker=sessionmaker(bind=engine, expire_on_commit=False)) + + +@pytest.fixture +def test_scope(db_session_with_containers: Session) -> _TestScope: + """Provide an isolated scope and clean related data after each test.""" + scope = _TestScope() + yield scope + _cleanup_scope_data(db_session_with_containers, scope) + + +class TestGetPaginatedWorkflowRuns: + """Integration tests for get_paginated_workflow_runs.""" + + def test_returns_runs_without_status_filter( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Return all runs for the given tenant/app when no status filter is applied.""" + for status in ( + WorkflowExecutionStatus.SUCCEEDED, + WorkflowExecutionStatus.FAILED, + WorkflowExecutionStatus.RUNNING, + ): + _create_workflow_run(db_session_with_containers, test_scope, status=status) + + result = repository.get_paginated_workflow_runs( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + limit=20, + last_id=None, + status=None, + ) + + assert len(result.data) == 3 + assert result.limit == 20 + assert result.has_more is False + + def test_filters_by_status( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Return only runs matching the requested status.""" + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED) + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED) + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.FAILED) + + result = repository.get_paginated_workflow_runs( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + limit=20, + last_id=None, + status="succeeded", + ) + + assert len(result.data) == 2 + assert all(run.status == WorkflowExecutionStatus.SUCCEEDED for run in result.data) + + def test_pagination_has_more( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Return has_more=True when more records exist beyond the limit.""" + for i in range(5): + _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.SUCCEEDED, + created_at_offset=timedelta(seconds=i), + ) + + result = repository.get_paginated_workflow_runs( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + limit=3, + last_id=None, + status=None, + ) + + assert len(result.data) == 3 + assert result.has_more is True + + def test_cursor_based_pagination( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Cursor-based pagination returns the next page of results.""" + for i in range(5): + _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.SUCCEEDED, + created_at_offset=timedelta(seconds=i), + ) + + # First page + page1 = repository.get_paginated_workflow_runs( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + limit=3, + last_id=None, + status=None, + ) + assert len(page1.data) == 3 + assert page1.has_more is True + + # Second page using cursor + page2 = repository.get_paginated_workflow_runs( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + limit=3, + last_id=page1.data[-1].id, + status=None, + ) + assert len(page2.data) == 2 + assert page2.has_more is False + + # No overlap between pages + page1_ids = {r.id for r in page1.data} + page2_ids = {r.id for r in page2.data} + assert page1_ids.isdisjoint(page2_ids) + + def test_invalid_last_id_raises( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + test_scope: _TestScope, + ) -> None: + """Raise ValueError when last_id refers to a non-existent run.""" + with pytest.raises(ValueError, match="Last workflow run not exists"): + repository.get_paginated_workflow_runs( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + limit=20, + last_id=str(uuid4()), + status=None, + ) + + def test_tenant_isolation( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Runs from other tenants are not returned.""" + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED) + + other_scope = _TestScope(app_id=test_scope.app_id) + try: + _create_workflow_run(db_session_with_containers, other_scope, status=WorkflowExecutionStatus.SUCCEEDED) + + result = repository.get_paginated_workflow_runs( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + limit=20, + last_id=None, + status=None, + ) + + assert len(result.data) == 1 + assert result.data[0].tenant_id == test_scope.tenant_id + finally: + _cleanup_scope_data(db_session_with_containers, other_scope) + + +class TestGetWorkflowRunsCount: + """Integration tests for get_workflow_runs_count.""" + + def test_count_without_status_filter( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Count all runs grouped by status when no status filter is applied.""" + for _ in range(3): + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED) + for _ in range(2): + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.FAILED) + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.RUNNING) + + result = repository.get_workflow_runs_count( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + status=None, + ) + + assert result["total"] == 6 + assert result["succeeded"] == 3 + assert result["failed"] == 2 + assert result["running"] == 1 + assert result["stopped"] == 0 + assert result["partial-succeeded"] == 0 + + def test_count_with_status_filter( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Count only runs matching the requested status.""" + for _ in range(3): + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED) + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.FAILED) + + result = repository.get_workflow_runs_count( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + status="succeeded", + ) + + assert result["total"] == 3 + assert result["succeeded"] == 3 + assert result["failed"] == 0 + + def test_count_with_invalid_status_raises( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Invalid status raises StatementError because the column uses an enum type.""" + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED) + + with pytest.raises(sa_exc.StatementError) as exc_info: + repository.get_workflow_runs_count( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + status="invalid_status", + ) + assert isinstance(exc_info.value.orig, ValueError) + + def test_count_with_time_range( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Time range filter excludes runs created outside the window.""" + # Recent run (within 1 day) + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED) + # Old run (8 days ago) + _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.SUCCEEDED, + created_at_offset=timedelta(days=-8), + ) + + result = repository.get_workflow_runs_count( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + status=None, + time_range="7d", + ) + + assert result["total"] == 1 + assert result["succeeded"] == 1 + + def test_count_with_status_and_time_range( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Both status and time_range filters apply together.""" + # Recent succeeded + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED) + # Recent failed + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.FAILED) + # Old succeeded (outside time range) + _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.SUCCEEDED, + created_at_offset=timedelta(days=-8), + ) + + result = repository.get_workflow_runs_count( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + status="succeeded", + time_range="7d", + ) + + assert result["total"] == 1 + assert result["succeeded"] == 1 + assert result["failed"] == 0 diff --git a/api/dify_graph/model_runtime/__init__.py b/api/tests/test_containers_integration_tests/services/auth/__init__.py similarity index 100% rename from api/dify_graph/model_runtime/__init__.py rename to api/tests/test_containers_integration_tests/services/auth/__init__.py diff --git a/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py b/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py new file mode 100644 index 00000000000..177fb95ff3a --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py @@ -0,0 +1,289 @@ +from __future__ import annotations + +import json +from unittest.mock import Mock, patch +from uuid import uuid4 + +import pytest + +from models.source import DataSourceApiKeyAuthBinding +from services.auth.api_key_auth_service import ApiKeyAuthService + + +class TestApiKeyAuthService: + @pytest.fixture + def tenant_id(self) -> str: + return str(uuid4()) + + @pytest.fixture + def category(self) -> str: + return "search" + + @pytest.fixture + def provider(self) -> str: + return "google" + + @pytest.fixture + def mock_credentials(self) -> dict: + return {"auth_type": "api_key", "config": {"api_key": "test_secret_key_123"}} + + @pytest.fixture + def mock_args(self, category, provider, mock_credentials) -> dict: + return {"category": category, "provider": provider, "credentials": mock_credentials} + + def _create_binding(self, db_session, *, tenant_id, category, provider, credentials=None, disabled=False): + binding = DataSourceApiKeyAuthBinding( + tenant_id=tenant_id, + category=category, + provider=provider, + credentials=json.dumps(credentials, ensure_ascii=False) if credentials else None, + disabled=disabled, + ) + db_session.add(binding) + db_session.commit() + return binding + + def test_get_provider_auth_list_success( + self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + ): + self._create_binding(db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider) + db_session_with_containers.expire_all() + + result = ApiKeyAuthService.get_provider_auth_list(tenant_id) + + assert len(result) >= 1 + tenant_results = [r for r in result if r.tenant_id == tenant_id] + assert len(tenant_results) == 1 + assert tenant_results[0].provider == provider + + def test_get_provider_auth_list_empty(self, flask_app_with_containers, db_session_with_containers, tenant_id): + result = ApiKeyAuthService.get_provider_auth_list(tenant_id) + + tenant_results = [r for r in result if r.tenant_id == tenant_id] + assert tenant_results == [] + + def test_get_provider_auth_list_filters_disabled( + self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + ): + self._create_binding( + db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider, disabled=True + ) + db_session_with_containers.expire_all() + + result = ApiKeyAuthService.get_provider_auth_list(tenant_id) + + tenant_results = [r for r in result if r.tenant_id == tenant_id] + assert tenant_results == [] + + @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") + @patch("services.auth.api_key_auth_service.encrypter") + def test_create_provider_auth_success( + self, mock_encrypter, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args + ): + mock_auth_instance = Mock() + mock_auth_instance.validate_credentials.return_value = True + mock_factory.return_value = mock_auth_instance + mock_encrypter.encrypt_token.return_value = "encrypted_test_key_123" + + ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + + mock_factory.assert_called_once() + mock_auth_instance.validate_credentials.assert_called_once() + mock_encrypter.encrypt_token.assert_called_once_with(tenant_id, "test_secret_key_123") + + db_session_with_containers.expire_all() + bindings = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(tenant_id=tenant_id).all() + assert len(bindings) == 1 + + @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") + def test_create_provider_auth_validation_failed( + self, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args + ): + mock_auth_instance = Mock() + mock_auth_instance.validate_credentials.return_value = False + mock_factory.return_value = mock_auth_instance + + ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + + db_session_with_containers.expire_all() + bindings = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(tenant_id=tenant_id).all() + assert len(bindings) == 0 + + @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") + @patch("services.auth.api_key_auth_service.encrypter") + def test_create_provider_auth_encrypts_api_key( + self, mock_encrypter, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args + ): + mock_auth_instance = Mock() + mock_auth_instance.validate_credentials.return_value = True + mock_factory.return_value = mock_auth_instance + mock_encrypter.encrypt_token.return_value = "encrypted_test_key_123" + + original_key = mock_args["credentials"]["config"]["api_key"] + + ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + + assert mock_args["credentials"]["config"]["api_key"] == "encrypted_test_key_123" + assert mock_args["credentials"]["config"]["api_key"] != original_key + mock_encrypter.encrypt_token.assert_called_once_with(tenant_id, original_key) + + def test_get_auth_credentials_success( + self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider, mock_credentials + ): + self._create_binding( + db_session_with_containers, + tenant_id=tenant_id, + category=category, + provider=provider, + credentials=mock_credentials, + ) + db_session_with_containers.expire_all() + + result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider) + + assert result == mock_credentials + + def test_get_auth_credentials_not_found( + self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + ): + result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider) + + assert result is None + + def test_get_auth_credentials_json_parsing( + self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + ): + special_credentials = {"auth_type": "api_key", "config": {"api_key": "key_with_中文_and_special_chars_!@#$%"}} + self._create_binding( + db_session_with_containers, + tenant_id=tenant_id, + category=category, + provider=provider, + credentials=special_credentials, + ) + db_session_with_containers.expire_all() + + result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider) + + assert result == special_credentials + assert result["config"]["api_key"] == "key_with_中文_and_special_chars_!@#$%" + + def test_delete_provider_auth_success( + self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + ): + binding = self._create_binding( + db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider + ) + binding_id = binding.id + db_session_with_containers.expire_all() + + ApiKeyAuthService.delete_provider_auth(tenant_id, binding_id) + + db_session_with_containers.expire_all() + remaining = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(id=binding_id).first() + assert remaining is None + + def test_delete_provider_auth_not_found(self, flask_app_with_containers, db_session_with_containers, tenant_id): + # Should not raise when binding not found + ApiKeyAuthService.delete_provider_auth(tenant_id, str(uuid4())) + + def test_validate_api_key_auth_args_success(self, mock_args): + ApiKeyAuthService.validate_api_key_auth_args(mock_args) + + def test_validate_api_key_auth_args_missing_category(self, mock_args): + del mock_args["category"] + with pytest.raises(ValueError, match="category is required"): + ApiKeyAuthService.validate_api_key_auth_args(mock_args) + + def test_validate_api_key_auth_args_empty_category(self, mock_args): + mock_args["category"] = "" + with pytest.raises(ValueError, match="category is required"): + ApiKeyAuthService.validate_api_key_auth_args(mock_args) + + def test_validate_api_key_auth_args_missing_provider(self, mock_args): + del mock_args["provider"] + with pytest.raises(ValueError, match="provider is required"): + ApiKeyAuthService.validate_api_key_auth_args(mock_args) + + def test_validate_api_key_auth_args_empty_provider(self, mock_args): + mock_args["provider"] = "" + with pytest.raises(ValueError, match="provider is required"): + ApiKeyAuthService.validate_api_key_auth_args(mock_args) + + def test_validate_api_key_auth_args_missing_credentials(self, mock_args): + del mock_args["credentials"] + with pytest.raises(ValueError, match="credentials is required"): + ApiKeyAuthService.validate_api_key_auth_args(mock_args) + + def test_validate_api_key_auth_args_empty_credentials(self, mock_args): + mock_args["credentials"] = None + with pytest.raises(ValueError, match="credentials is required"): + ApiKeyAuthService.validate_api_key_auth_args(mock_args) + + def test_validate_api_key_auth_args_invalid_credentials_type(self, mock_args): + mock_args["credentials"] = "not_a_dict" + with pytest.raises(ValueError, match="credentials must be a dictionary"): + ApiKeyAuthService.validate_api_key_auth_args(mock_args) + + def test_validate_api_key_auth_args_missing_auth_type(self, mock_args): + del mock_args["credentials"]["auth_type"] + with pytest.raises(ValueError, match="auth_type is required"): + ApiKeyAuthService.validate_api_key_auth_args(mock_args) + + def test_validate_api_key_auth_args_empty_auth_type(self, mock_args): + mock_args["credentials"]["auth_type"] = "" + with pytest.raises(ValueError, match="auth_type is required"): + ApiKeyAuthService.validate_api_key_auth_args(mock_args) + + @pytest.mark.parametrize( + "malicious_input", + [ + "", + "'; DROP TABLE users; --", + "../../../etc/passwd", + "\\x00\\x00", + "A" * 10000, + ], + ) + def test_validate_api_key_auth_args_malicious_input(self, malicious_input, mock_args): + mock_args["category"] = malicious_input + ApiKeyAuthService.validate_api_key_auth_args(mock_args) + + @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") + @patch("services.auth.api_key_auth_service.encrypter") + def test_create_provider_auth_database_error_handling( + self, mock_encrypter, mock_factory, flask_app_with_containers, tenant_id, mock_args + ): + mock_auth_instance = Mock() + mock_auth_instance.validate_credentials.return_value = True + mock_factory.return_value = mock_auth_instance + mock_encrypter.encrypt_token.return_value = "encrypted_key" + + with patch("services.auth.api_key_auth_service.db.session") as mock_session: + mock_session.commit.side_effect = Exception("Database error") + with pytest.raises(Exception, match="Database error"): + ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + + @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") + def test_create_provider_auth_factory_exception(self, mock_factory, tenant_id, mock_args): + mock_factory.side_effect = Exception("Factory error") + with pytest.raises(Exception, match="Factory error"): + ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + + @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") + @patch("services.auth.api_key_auth_service.encrypter") + def test_create_provider_auth_encryption_exception(self, mock_encrypter, mock_factory, tenant_id, mock_args): + mock_auth_instance = Mock() + mock_auth_instance.validate_credentials.return_value = True + mock_factory.return_value = mock_auth_instance + mock_encrypter.encrypt_token.side_effect = Exception("Encryption error") + with pytest.raises(Exception, match="Encryption error"): + ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + + def test_validate_api_key_auth_args_none_input(self): + with pytest.raises(TypeError): + ApiKeyAuthService.validate_api_key_auth_args(None) + + def test_validate_api_key_auth_args_dict_credentials_with_list_auth_type(self, mock_args): + mock_args["credentials"]["auth_type"] = ["api_key"] + ApiKeyAuthService.validate_api_key_auth_args(mock_args) diff --git a/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py b/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py new file mode 100644 index 00000000000..dc4c0fda1d4 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py @@ -0,0 +1,264 @@ +""" +API Key Authentication System Integration Tests +""" + +from __future__ import annotations + +from concurrent.futures import ThreadPoolExecutor +from unittest.mock import Mock, patch +from uuid import uuid4 + +import httpx +import pytest + +from models.source import DataSourceApiKeyAuthBinding +from services.auth.api_key_auth_factory import ApiKeyAuthFactory +from services.auth.api_key_auth_service import ApiKeyAuthService +from services.auth.auth_type import AuthType + + +class TestAuthIntegration: + @pytest.fixture + def tenant_id_1(self) -> str: + return str(uuid4()) + + @pytest.fixture + def tenant_id_2(self) -> str: + return str(uuid4()) + + @pytest.fixture + def category(self) -> str: + return "search" + + @pytest.fixture + def firecrawl_credentials(self) -> dict: + return {"auth_type": "bearer", "config": {"api_key": "fc_test_key_123"}} + + @pytest.fixture + def jina_credentials(self) -> dict: + return {"auth_type": "bearer", "config": {"api_key": "jina_test_key_456"}} + + @patch("services.auth.firecrawl.firecrawl.httpx.post") + @patch("services.auth.api_key_auth_service.encrypter.encrypt_token") + def test_end_to_end_auth_flow( + self, + mock_encrypt, + mock_http, + flask_app_with_containers, + db_session_with_containers, + tenant_id_1, + category, + firecrawl_credentials, + ): + mock_http.return_value = self._create_success_response() + mock_encrypt.return_value = "encrypted_fc_test_key_123" + + args = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials} + ApiKeyAuthService.create_provider_auth(tenant_id_1, args) + + mock_http.assert_called_once() + call_args = mock_http.call_args + assert "https://api.firecrawl.dev/v1/crawl" in call_args[0][0] + assert call_args[1]["headers"]["Authorization"] == "Bearer fc_test_key_123" + + mock_encrypt.assert_called_once_with(tenant_id_1, "fc_test_key_123") + + db_session_with_containers.expire_all() + bindings = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(tenant_id=tenant_id_1).all() + assert len(bindings) == 1 + assert bindings[0].provider == AuthType.FIRECRAWL + + @patch("services.auth.firecrawl.firecrawl.httpx.post") + def test_cross_component_integration(self, mock_http, firecrawl_credentials): + mock_http.return_value = self._create_success_response() + factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, firecrawl_credentials) + result = factory.validate_credentials() + + assert result is True + mock_http.assert_called_once() + + @patch("services.auth.api_key_auth_service.encrypter.encrypt_token") + @patch("services.auth.firecrawl.firecrawl.httpx.post") + @patch("services.auth.jina.jina.httpx.post") + def test_multi_tenant_isolation( + self, + mock_jina_http, + mock_fc_http, + mock_encrypt, + flask_app_with_containers, + db_session_with_containers, + tenant_id_1, + tenant_id_2, + category, + firecrawl_credentials, + jina_credentials, + ): + mock_fc_http.return_value = self._create_success_response() + mock_jina_http.return_value = self._create_success_response() + mock_encrypt.return_value = "encrypted_key" + + args1 = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials} + ApiKeyAuthService.create_provider_auth(tenant_id_1, args1) + + args2 = {"category": category, "provider": AuthType.JINA, "credentials": jina_credentials} + ApiKeyAuthService.create_provider_auth(tenant_id_2, args2) + + db_session_with_containers.expire_all() + + result1 = ApiKeyAuthService.get_provider_auth_list(tenant_id_1) + result2 = ApiKeyAuthService.get_provider_auth_list(tenant_id_2) + + assert len(result1) == 1 + assert result1[0].tenant_id == tenant_id_1 + assert len(result2) == 1 + assert result2[0].tenant_id == tenant_id_2 + + def test_cross_tenant_access_prevention( + self, flask_app_with_containers, db_session_with_containers, tenant_id_2, category + ): + result = ApiKeyAuthService.get_auth_credentials(tenant_id_2, category, AuthType.FIRECRAWL) + + assert result is None + + def test_sensitive_data_protection(self): + credentials_with_secrets = { + "auth_type": "bearer", + "config": {"api_key": "super_secret_key_do_not_log", "secret": "another_secret"}, + } + + factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, credentials_with_secrets) + factory_str = str(factory) + + assert "super_secret_key_do_not_log" not in factory_str + assert "another_secret" not in factory_str + + @patch("services.auth.firecrawl.firecrawl.httpx.post") + @patch("services.auth.api_key_auth_service.encrypter.encrypt_token", return_value="encrypted_key") + def test_concurrent_creation_safety( + self, + mock_encrypt, + mock_http, + flask_app_with_containers, + db_session_with_containers, + tenant_id_1, + category, + firecrawl_credentials, + ): + app = flask_app_with_containers + mock_http.return_value = self._create_success_response() + + results = [] + exceptions = [] + + def create_auth(): + try: + with app.app_context(): + thread_args = { + "category": category, + "provider": AuthType.FIRECRAWL, + "credentials": {"auth_type": "bearer", "config": {"api_key": "fc_test_key_123"}}, + } + ApiKeyAuthService.create_provider_auth(tenant_id_1, thread_args) + results.append("success") + except Exception as e: + exceptions.append(e) + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(create_auth) for _ in range(5)] + for future in futures: + future.result() + + assert len(results) == 5 + assert len(exceptions) == 0 + + @pytest.mark.parametrize( + "invalid_input", + [ + None, + {}, + {"auth_type": "bearer"}, + {"auth_type": "bearer", "config": {}}, + ], + ) + def test_invalid_input_boundary(self, invalid_input): + with pytest.raises((ValueError, KeyError, TypeError, AttributeError)): + ApiKeyAuthFactory(AuthType.FIRECRAWL, invalid_input) + + @patch("services.auth.firecrawl.firecrawl.httpx.post") + def test_http_error_handling(self, mock_http, firecrawl_credentials): + mock_response = Mock() + mock_response.status_code = 401 + mock_response.text = '{"error": "Unauthorized"}' + mock_response.raise_for_status.side_effect = httpx.HTTPError("Unauthorized") + mock_http.return_value = mock_response + + factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, firecrawl_credentials) + with pytest.raises((httpx.HTTPError, Exception)): + factory.validate_credentials() + + @patch("services.auth.firecrawl.firecrawl.httpx.post") + def test_network_failure_recovery( + self, + mock_http, + flask_app_with_containers, + db_session_with_containers, + tenant_id_1, + category, + firecrawl_credentials, + ): + mock_http.side_effect = httpx.RequestError("Network timeout") + + args = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials} + + with pytest.raises(httpx.RequestError): + ApiKeyAuthService.create_provider_auth(tenant_id_1, args) + + db_session_with_containers.expire_all() + bindings = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(tenant_id=tenant_id_1).all() + assert len(bindings) == 0 + + @pytest.mark.parametrize( + ("provider", "credentials"), + [ + (AuthType.FIRECRAWL, {"auth_type": "bearer", "config": {"api_key": "fc_key"}}), + (AuthType.JINA, {"auth_type": "bearer", "config": {"api_key": "jina_key"}}), + (AuthType.WATERCRAWL, {"auth_type": "x-api-key", "config": {"api_key": "wc_key"}}), + ], + ) + def test_all_providers_factory_creation(self, provider, credentials): + auth_class = ApiKeyAuthFactory.get_apikey_auth_factory(provider) + assert auth_class is not None + + factory = ApiKeyAuthFactory(provider, credentials) + assert factory.auth is not None + + @patch("services.auth.api_key_auth_service.encrypter.encrypt_token") + @patch("services.auth.firecrawl.firecrawl.httpx.post") + def test_get_auth_credentials_returns_stored_credentials( + self, + mock_http, + mock_encrypt, + flask_app_with_containers, + db_session_with_containers, + tenant_id_1, + category, + firecrawl_credentials, + ): + mock_http.return_value = self._create_success_response() + mock_encrypt.return_value = "encrypted_key" + + args = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials} + ApiKeyAuthService.create_provider_auth(tenant_id_1, args) + + db_session_with_containers.expire_all() + + result = ApiKeyAuthService.get_auth_credentials(tenant_id_1, category, AuthType.FIRECRAWL) + assert result is not None + assert result["config"]["api_key"] == "encrypted_key" + + def _create_success_response(self, status_code=200): + mock_response = Mock() + mock_response.status_code = status_code + mock_response.json.return_value = {"status": "success"} + mock_response.raise_for_status.return_value = None + return mock_response diff --git a/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py b/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py index 6b35f867d75..02c3d1a80e3 100644 --- a/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py +++ b/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py @@ -13,6 +13,7 @@ import pytest from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import AppDatasetJoin, Dataset, DatasetPermissionEnum from models.enums import DataSourceType @@ -74,7 +75,7 @@ class DatasetUpdateDeleteTestDataFactory: name=name, description="Test description", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, permission=permission, provider="vendor", diff --git a/api/tests/test_containers_integration_tests/services/document_service_status.py b/api/tests/test_containers_integration_tests/services/document_service_status.py index f995ac7bef6..42d587b7f77 100644 --- a/api/tests/test_containers_integration_tests/services/document_service_status.py +++ b/api/tests/test_containers_integration_tests/services/document_service_status.py @@ -13,6 +13,7 @@ from uuid import uuid4 import pytest +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.storage.storage_type import StorageType from models import Account from models.dataset import Dataset, Document @@ -91,7 +92,7 @@ class DocumentStatusTestDataFactory: name=name, created_from=DocumentCreatedFrom.WEB, created_by=created_by, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) document.id = document_id document.indexing_status = indexing_status diff --git a/api/dify_graph/model_runtime/callbacks/__init__.py b/api/tests/test_containers_integration_tests/services/enterprise/__init__.py similarity index 100% rename from api/dify_graph/model_runtime/callbacks/__init__.py rename to api/tests/test_containers_integration_tests/services/enterprise/__init__.py diff --git a/api/tests/test_containers_integration_tests/services/enterprise/test_account_deletion_sync.py b/api/tests/test_containers_integration_tests/services/enterprise/test_account_deletion_sync.py new file mode 100644 index 00000000000..4e8255d8ed9 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/enterprise/test_account_deletion_sync.py @@ -0,0 +1,200 @@ +"""Integration tests for account deletion synchronization. + +Verifies enterprise account deletion sync functionality including +Redis queuing, error handling, and community vs enterprise behavior. +""" + +from __future__ import annotations + +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from redis import RedisError + +from extensions.ext_redis import redis_client +from models.account import TenantAccountJoin +from services.enterprise.account_deletion_sync import ( + _queue_task, + sync_account_deletion, + sync_workspace_member_removal, +) + + +class TestQueueTask: + def test_queue_task_success(self): + workspace_id = str(uuid4()) + member_id = str(uuid4()) + + result = _queue_task(workspace_id=workspace_id, member_id=member_id, source="test_source") + + assert result is True + + import json + + raw = redis_client.rpop("enterprise:member:sync:queue") + assert raw is not None + task_data = json.loads(raw) + assert task_data["workspace_id"] == workspace_id + assert task_data["member_id"] == member_id + assert task_data["source"] == "test_source" + assert task_data["type"] == "sync_member_deletion_from_workspace" + assert task_data["retry_count"] == 0 + assert "task_id" in task_data + assert "created_at" in task_data + + def test_queue_task_redis_error(self, caplog): + with patch("services.enterprise.account_deletion_sync.redis_client") as mock_redis: + mock_redis.lpush.side_effect = RedisError("Connection failed") + + result = _queue_task(workspace_id="ws-123", member_id="member-456", source="test_source") + + assert result is False + assert "Failed to queue account deletion sync" in caplog.text + + def test_queue_task_type_error(self, caplog): + with patch("services.enterprise.account_deletion_sync.redis_client") as mock_redis: + mock_redis.lpush.side_effect = TypeError("Cannot serialize") + + result = _queue_task(workspace_id="ws-123", member_id="member-456", source="test_source") + + assert result is False + assert "Failed to queue account deletion sync" in caplog.text + + +class TestSyncWorkspaceMemberRemoval: + @pytest.fixture + def mock_queue_task(self): + with patch("services.enterprise.account_deletion_sync._queue_task") as mock_queue: + mock_queue.return_value = True + yield mock_queue + + def test_sync_workspace_member_removal_enterprise_enabled(self, mock_queue_task): + workspace_id = str(uuid4()) + member_id = str(uuid4()) + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + result = sync_workspace_member_removal(workspace_id=workspace_id, member_id=member_id, source="removed") + + assert result is True + mock_queue_task.assert_called_once_with(workspace_id=workspace_id, member_id=member_id, source="removed") + + def test_sync_workspace_member_removal_enterprise_disabled(self, mock_queue_task): + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = False + + result = sync_workspace_member_removal( + workspace_id=str(uuid4()), member_id=str(uuid4()), source="test_source" + ) + + assert result is True + mock_queue_task.assert_not_called() + + def test_sync_workspace_member_removal_queue_failure(self, mock_queue_task): + mock_queue_task.return_value = False + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + result = sync_workspace_member_removal( + workspace_id=str(uuid4()), member_id=str(uuid4()), source="test_source" + ) + + assert result is False + + +class TestSyncAccountDeletion: + @pytest.fixture + def mock_queue_task(self): + with patch("services.enterprise.account_deletion_sync._queue_task") as mock_queue: + mock_queue.return_value = True + yield mock_queue + + def test_sync_account_deletion_enterprise_disabled(self, mock_queue_task): + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = False + + result = sync_account_deletion(account_id=str(uuid4()), source="account_deleted") + + assert result is True + mock_queue_task.assert_not_called() + + def test_sync_account_deletion_multiple_workspaces( + self, flask_app_with_containers, db_session_with_containers, mock_queue_task + ): + account_id = str(uuid4()) + tenant_ids = [str(uuid4()) for _ in range(3)] + + for tenant_id in tenant_ids: + join = TenantAccountJoin(tenant_id=tenant_id, account_id=account_id) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + result = sync_account_deletion(account_id=account_id, source="account_deleted") + + assert result is True + assert mock_queue_task.call_count == 3 + + queued_workspace_ids = {call.kwargs["workspace_id"] for call in mock_queue_task.call_args_list} + assert queued_workspace_ids == set(tenant_ids) + + def test_sync_account_deletion_no_workspaces( + self, flask_app_with_containers, db_session_with_containers, mock_queue_task + ): + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + result = sync_account_deletion(account_id=str(uuid4()), source="account_deleted") + + assert result is True + mock_queue_task.assert_not_called() + + def test_sync_account_deletion_partial_failure( + self, flask_app_with_containers, db_session_with_containers, mock_queue_task + ): + account_id = str(uuid4()) + tenant_ids = [str(uuid4()) for _ in range(3)] + fail_tenant = tenant_ids[1] + + for tenant_id in tenant_ids: + join = TenantAccountJoin(tenant_id=tenant_id, account_id=account_id) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + def queue_side_effect(workspace_id, member_id, source): + return workspace_id != fail_tenant + + mock_queue_task.side_effect = queue_side_effect + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + result = sync_account_deletion(account_id=account_id, source="account_deleted") + + assert result is False + assert mock_queue_task.call_count == 3 + + def test_sync_account_deletion_all_failures( + self, flask_app_with_containers, db_session_with_containers, mock_queue_task + ): + account_id = str(uuid4()) + tenant_id = str(uuid4()) + + join = TenantAccountJoin(tenant_id=tenant_id, account_id=account_id) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + mock_queue_task.return_value = False + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + result = sync_account_deletion(account_id=account_id, source="account_deleted") + + assert result is False + mock_queue_task.assert_called_once() diff --git a/api/dify_graph/model_runtime/errors/__init__.py b/api/tests/test_containers_integration_tests/services/plugin/__init__.py similarity index 100% rename from api/dify_graph/model_runtime/errors/__init__.py rename to api/tests/test_containers_integration_tests/services/plugin/__init__.py diff --git a/api/tests/unit_tests/services/plugin/test_plugin_parameter_service.py b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_parameter_service.py similarity index 78% rename from api/tests/unit_tests/services/plugin/test_plugin_parameter_service.py rename to api/tests/test_containers_integration_tests/services/plugin/test_plugin_parameter_service.py index bfa9fe976b7..ce9f10e207d 100644 --- a/api/tests/unit_tests/services/plugin/test_plugin_parameter_service.py +++ b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_parameter_service.py @@ -6,10 +6,14 @@ HIDDEN_VALUE replacement, and error handling for missing records. from __future__ import annotations +import json from unittest.mock import MagicMock, patch +from uuid import uuid4 import pytest +from core.plugin.entities.plugin_daemon import CredentialType +from models.tools import BuiltinToolProvider from services.plugin.plugin_parameter_service import PluginParameterService @@ -39,67 +43,73 @@ class TestGetDynamicSelectOptionsTool: @patch("services.plugin.plugin_parameter_service.DynamicSelectClient") @patch("services.plugin.plugin_parameter_service.create_tool_provider_encrypter") - @patch("services.plugin.plugin_parameter_service.db") @patch("services.plugin.plugin_parameter_service.ToolManager") - def test_fetches_credentials_with_credential_id(self, mock_tool_mgr, mock_db, mock_encrypter_fn, mock_client_cls): + def test_fetches_credentials_with_credential_id( + self, + mock_tool_mgr, + mock_encrypter_fn, + mock_client_cls, + flask_app_with_containers, + db_session_with_containers, + ): + tenant_id = str(uuid4()) provider_ctrl = MagicMock() provider_ctrl.need_credentials = True mock_tool_mgr.get_builtin_provider.return_value = provider_ctrl encrypter = MagicMock() encrypter.decrypt.return_value = {"api_key": "decrypted"} mock_encrypter_fn.return_value = (encrypter, None) + mock_client_cls.return_value.fetch_dynamic_select_options.return_value.options = ["opt1"] - # Mock the Session/query chain - db_record = MagicMock() - db_record.credentials = {"api_key": "encrypted"} - db_record.credential_type = "api_key" + db_record = BuiltinToolProvider( + tenant_id=tenant_id, + user_id=str(uuid4()), + provider="google", + name="API KEY 1", + encrypted_credentials=json.dumps({"api_key": "encrypted"}), + credential_type=CredentialType.API_KEY, + ) + db_session_with_containers.add(db_record) + db_session_with_containers.commit() - with patch("services.plugin.plugin_parameter_service.Session") as mock_session_cls: - mock_session = MagicMock() - mock_session_cls.return_value.__enter__ = MagicMock(return_value=mock_session) - mock_session_cls.return_value.__exit__ = MagicMock(return_value=False) - mock_session.query.return_value.where.return_value.first.return_value = db_record - mock_client_cls.return_value.fetch_dynamic_select_options.return_value.options = ["opt1"] - - result = PluginParameterService.get_dynamic_select_options( - tenant_id="t1", - user_id="u1", - plugin_id="p1", - provider="google", - action="search", - parameter="engine", - credential_id="cred-1", - provider_type="tool", - ) + result = PluginParameterService.get_dynamic_select_options( + tenant_id=tenant_id, + user_id="u1", + plugin_id="p1", + provider="google", + action="search", + parameter="engine", + credential_id=db_record.id, + provider_type="tool", + ) assert result == ["opt1"] @patch("services.plugin.plugin_parameter_service.create_tool_provider_encrypter") - @patch("services.plugin.plugin_parameter_service.db") @patch("services.plugin.plugin_parameter_service.ToolManager") - def test_raises_when_tool_provider_not_found(self, mock_tool_mgr, mock_db, mock_encrypter_fn): + def test_raises_when_tool_provider_not_found( + self, + mock_tool_mgr, + mock_encrypter_fn, + flask_app_with_containers, + db_session_with_containers, + ): provider_ctrl = MagicMock() provider_ctrl.need_credentials = True mock_tool_mgr.get_builtin_provider.return_value = provider_ctrl mock_encrypter_fn.return_value = (MagicMock(), None) - with patch("services.plugin.plugin_parameter_service.Session") as mock_session_cls: - mock_session = MagicMock() - mock_session_cls.return_value.__enter__ = MagicMock(return_value=mock_session) - mock_session_cls.return_value.__exit__ = MagicMock(return_value=False) - mock_session.query.return_value.where.return_value.order_by.return_value.first.return_value = None - - with pytest.raises(ValueError, match="not found"): - PluginParameterService.get_dynamic_select_options( - tenant_id="t1", - user_id="u1", - plugin_id="p1", - provider="google", - action="search", - parameter="engine", - credential_id=None, - provider_type="tool", - ) + with pytest.raises(ValueError, match="not found"): + PluginParameterService.get_dynamic_select_options( + tenant_id=str(uuid4()), + user_id="u1", + plugin_id="p1", + provider="google", + action="search", + parameter="engine", + credential_id=None, + provider_type="tool", + ) class TestGetDynamicSelectOptionsTrigger: diff --git a/api/tests/unit_tests/services/plugin/test_plugin_service.py b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_service.py similarity index 78% rename from api/tests/unit_tests/services/plugin/test_plugin_service.py rename to api/tests/test_containers_integration_tests/services/plugin/test_plugin_service.py index 09b9ab498be..0cdae572fba 100644 --- a/api/tests/unit_tests/services/plugin/test_plugin_service.py +++ b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_service.py @@ -8,15 +8,27 @@ verification, marketplace upgrade flows, and uninstall with credential cleanup. from __future__ import annotations from unittest.mock import MagicMock, patch +from uuid import uuid4 import pytest +from sqlalchemy import select from core.plugin.entities.plugin import PluginInstallationSource from core.plugin.entities.plugin_daemon import PluginVerification +from models.provider import Provider, ProviderCredential, TenantPreferredModelProvider from services.errors.plugin import PluginInstallationForbiddenError from services.feature_service import PluginInstallationScope from services.plugin.plugin_service import PluginService -from tests.unit_tests.services.plugin.conftest import make_features + + +def _make_features( + restrict_to_marketplace: bool = False, + scope: PluginInstallationScope = PluginInstallationScope.ALL, +) -> MagicMock: + features = MagicMock() + features.plugin_installation_permission.restrict_to_marketplace_only = restrict_to_marketplace + features.plugin_installation_permission.plugin_installation_scope = scope + return features class TestFetchLatestPluginVersion: @@ -80,14 +92,14 @@ class TestFetchLatestPluginVersion: class TestCheckMarketplaceOnlyPermission: @patch("services.plugin.plugin_service.FeatureService") def test_raises_when_restricted(self, mock_fs): - mock_fs.get_system_features.return_value = make_features(restrict_to_marketplace=True) + mock_fs.get_system_features.return_value = _make_features(restrict_to_marketplace=True) with pytest.raises(PluginInstallationForbiddenError): PluginService._check_marketplace_only_permission() @patch("services.plugin.plugin_service.FeatureService") def test_passes_when_not_restricted(self, mock_fs): - mock_fs.get_system_features.return_value = make_features(restrict_to_marketplace=False) + mock_fs.get_system_features.return_value = _make_features(restrict_to_marketplace=False) PluginService._check_marketplace_only_permission() # should not raise @@ -95,7 +107,7 @@ class TestCheckMarketplaceOnlyPermission: class TestCheckPluginInstallationScope: @patch("services.plugin.plugin_service.FeatureService") def test_official_only_allows_langgenius(self, mock_fs): - mock_fs.get_system_features.return_value = make_features(scope=PluginInstallationScope.OFFICIAL_ONLY) + mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.OFFICIAL_ONLY) verification = MagicMock() verification.authorized_category = PluginVerification.AuthorizedCategory.Langgenius @@ -103,14 +115,14 @@ class TestCheckPluginInstallationScope: @patch("services.plugin.plugin_service.FeatureService") def test_official_only_rejects_third_party(self, mock_fs): - mock_fs.get_system_features.return_value = make_features(scope=PluginInstallationScope.OFFICIAL_ONLY) + mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.OFFICIAL_ONLY) with pytest.raises(PluginInstallationForbiddenError): PluginService._check_plugin_installation_scope(None) @patch("services.plugin.plugin_service.FeatureService") def test_official_and_partners_allows_partner(self, mock_fs): - mock_fs.get_system_features.return_value = make_features( + mock_fs.get_system_features.return_value = _make_features( scope=PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS ) verification = MagicMock() @@ -120,7 +132,7 @@ class TestCheckPluginInstallationScope: @patch("services.plugin.plugin_service.FeatureService") def test_official_and_partners_rejects_none(self, mock_fs): - mock_fs.get_system_features.return_value = make_features( + mock_fs.get_system_features.return_value = _make_features( scope=PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS ) @@ -129,7 +141,7 @@ class TestCheckPluginInstallationScope: @patch("services.plugin.plugin_service.FeatureService") def test_none_scope_always_raises(self, mock_fs): - mock_fs.get_system_features.return_value = make_features(scope=PluginInstallationScope.NONE) + mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.NONE) verification = MagicMock() verification.authorized_category = PluginVerification.AuthorizedCategory.Langgenius @@ -138,7 +150,7 @@ class TestCheckPluginInstallationScope: @patch("services.plugin.plugin_service.FeatureService") def test_all_scope_passes_any(self, mock_fs): - mock_fs.get_system_features.return_value = make_features(scope=PluginInstallationScope.ALL) + mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.ALL) PluginService._check_plugin_installation_scope(None) # should not raise @@ -209,9 +221,9 @@ class TestUpgradePluginWithMarketplace: @patch("services.plugin.plugin_service.dify_config") def test_skips_download_when_already_installed(self, mock_config, mock_installer_cls, mock_fs, mock_marketplace): mock_config.MARKETPLACE_ENABLED = True - mock_fs.get_system_features.return_value = make_features() + mock_fs.get_system_features.return_value = _make_features() installer = mock_installer_cls.return_value - installer.fetch_plugin_manifest.return_value = MagicMock() # no exception = already installed + installer.fetch_plugin_manifest.return_value = MagicMock() installer.upgrade_plugin.return_value = MagicMock() PluginService.upgrade_plugin_with_marketplace("t1", "old-uid", "new-uid") @@ -225,7 +237,7 @@ class TestUpgradePluginWithMarketplace: @patch("services.plugin.plugin_service.dify_config") def test_downloads_when_not_installed(self, mock_config, mock_installer_cls, mock_fs, mock_download): mock_config.MARKETPLACE_ENABLED = True - mock_fs.get_system_features.return_value = make_features() + mock_fs.get_system_features.return_value = _make_features() installer = mock_installer_cls.return_value installer.fetch_plugin_manifest.side_effect = RuntimeError("not found") mock_download.return_value = b"pkg-bytes" @@ -244,7 +256,7 @@ class TestUpgradePluginWithGithub: @patch("services.plugin.plugin_service.FeatureService") @patch("services.plugin.plugin_service.PluginInstaller") def test_checks_marketplace_permission_and_delegates(self, mock_installer_cls, mock_fs): - mock_fs.get_system_features.return_value = make_features() + mock_fs.get_system_features.return_value = _make_features() installer = mock_installer_cls.return_value installer.upgrade_plugin.return_value = MagicMock() @@ -259,7 +271,7 @@ class TestUploadPkg: @patch("services.plugin.plugin_service.FeatureService") @patch("services.plugin.plugin_service.PluginInstaller") def test_runs_permission_and_scope_checks(self, mock_installer_cls, mock_fs): - mock_fs.get_system_features.return_value = make_features() + mock_fs.get_system_features.return_value = _make_features() upload_resp = MagicMock() upload_resp.verification = None mock_installer_cls.return_value.upload_pkg.return_value = upload_resp @@ -283,7 +295,7 @@ class TestInstallFromMarketplacePkg: @patch("services.plugin.plugin_service.dify_config") def test_downloads_when_not_cached(self, mock_config, mock_installer_cls, mock_fs, mock_download): mock_config.MARKETPLACE_ENABLED = True - mock_fs.get_system_features.return_value = make_features() + mock_fs.get_system_features.return_value = _make_features() installer = mock_installer_cls.return_value installer.fetch_plugin_manifest.side_effect = RuntimeError("not found") mock_download.return_value = b"pkg" @@ -298,14 +310,14 @@ class TestInstallFromMarketplacePkg: assert result == "task-id" installer.install_from_identifiers.assert_called_once() call_args = installer.install_from_identifiers.call_args[0] - assert call_args[1] == ["resolved-uid"] # uses response uid, not input + assert call_args[1] == ["resolved-uid"] @patch("services.plugin.plugin_service.FeatureService") @patch("services.plugin.plugin_service.PluginInstaller") @patch("services.plugin.plugin_service.dify_config") def test_uses_cached_when_already_downloaded(self, mock_config, mock_installer_cls, mock_fs): mock_config.MARKETPLACE_ENABLED = True - mock_fs.get_system_features.return_value = make_features() + mock_fs.get_system_features.return_value = _make_features() installer = mock_installer_cls.return_value installer.fetch_plugin_manifest.return_value = MagicMock() decode_resp = MagicMock() @@ -317,7 +329,7 @@ class TestInstallFromMarketplacePkg: installer.install_from_identifiers.assert_called_once() call_args = installer.install_from_identifiers.call_args[0] - assert call_args[1] == ["uid-1"] # uses original uid + assert call_args[1] == ["uid-1"] class TestUninstall: @@ -332,26 +344,70 @@ class TestUninstall: assert result is True installer.uninstall.assert_called_once_with("t1", "install-1") - @patch("services.plugin.plugin_service.db") @patch("services.plugin.plugin_service.PluginInstaller") - def test_cleans_credentials_when_plugin_found(self, mock_installer_cls, mock_db): + def test_cleans_credentials_when_plugin_found( + self, mock_installer_cls, flask_app_with_containers, db_session_with_containers + ): + tenant_id = str(uuid4()) + plugin_id = "org/myplugin" + provider_name = f"{plugin_id}/model-provider" + + credential = ProviderCredential( + tenant_id=tenant_id, + provider_name=provider_name, + credential_name="default", + encrypted_config="{}", + ) + db_session_with_containers.add(credential) + db_session_with_containers.flush() + credential_id = credential.id + + provider = Provider( + tenant_id=tenant_id, + provider_name=provider_name, + credential_id=credential_id, + ) + db_session_with_containers.add(provider) + db_session_with_containers.flush() + provider_id = provider.id + + pref = TenantPreferredModelProvider( + tenant_id=tenant_id, + provider_name=provider_name, + preferred_provider_type="custom", + ) + db_session_with_containers.add(pref) + db_session_with_containers.commit() + plugin = MagicMock() plugin.installation_id = "install-1" - plugin.plugin_id = "org/myplugin" + plugin.plugin_id = plugin_id installer = mock_installer_cls.return_value installer.list_plugins.return_value = [plugin] installer.uninstall.return_value = True - # Mock Session context manager - mock_session = MagicMock() - mock_db.engine = MagicMock() - mock_session.scalars.return_value.all.return_value = [] # no credentials found - - with patch("services.plugin.plugin_service.Session") as mock_session_cls: - mock_session_cls.return_value.__enter__ = MagicMock(return_value=mock_session) - mock_session_cls.return_value.__exit__ = MagicMock(return_value=False) - - result = PluginService.uninstall("t1", "install-1") + with patch("services.plugin.plugin_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = False + result = PluginService.uninstall(tenant_id, "install-1") assert result is True installer.uninstall.assert_called_once() + + db_session_with_containers.expire_all() + + remaining_creds = db_session_with_containers.scalars( + select(ProviderCredential).where(ProviderCredential.id == credential_id) + ).all() + assert len(remaining_creds) == 0 + + updated_provider = db_session_with_containers.get(Provider, provider_id) + assert updated_provider is not None + assert updated_provider.credential_id is None + + remaining_prefs = db_session_with_containers.scalars( + select(TenantPreferredModelProvider).where( + TenantPreferredModelProvider.tenant_id == tenant_id, + TenantPreferredModelProvider.provider_name == provider_name, + ) + ).all() + assert len(remaining_prefs) == 0 diff --git a/api/dify_graph/model_runtime/model_providers/__base/__init__.py b/api/tests/test_containers_integration_tests/services/recommend_app/__init__.py similarity index 100% rename from api/dify_graph/model_runtime/model_providers/__base/__init__.py rename to api/tests/test_containers_integration_tests/services/recommend_app/__init__.py diff --git a/api/tests/test_containers_integration_tests/services/recommend_app/test_database_retrieval.py b/api/tests/test_containers_integration_tests/services/recommend_app/test_database_retrieval.py new file mode 100644 index 00000000000..2b842629a72 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/recommend_app/test_database_retrieval.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +from unittest.mock import patch +from uuid import uuid4 + +from models.model import App, RecommendedApp, Site +from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval +from services.recommend_app.recommend_app_type import RecommendAppType + + +def _create_app(db_session, *, tenant_id: str, is_public: bool = True) -> App: + app = App( + tenant_id=tenant_id, + name=f"app-{uuid4()}", + mode="chat", + enable_site=True, + enable_api=True, + is_public=is_public, + ) + app.id = str(uuid4()) + db_session.add(app) + db_session.commit() + return app + + +def _create_site(db_session, *, app_id: str) -> Site: + site = Site( + app_id=app_id, + title=f"site-{uuid4()}", + default_language="en-US", + customize_token_strategy="not_allow", + description="desc", + copyright="copy", + privacy_policy="pp", + custom_disclaimer="cd", + ) + site.id = str(uuid4()) + db_session.add(site) + db_session.commit() + return site + + +def _create_recommended_app( + db_session, + *, + app_id: str, + category: str = "chat", + language: str = "en-US", + is_listed: bool = True, + position: int = 1, +) -> RecommendedApp: + rec = RecommendedApp( + app_id=app_id, + description={"en-US": "test"}, + copyright="copy", + privacy_policy="pp", + category=category, + language=language, + is_listed=is_listed, + position=position, + ) + rec.id = str(uuid4()) + db_session.add(rec) + db_session.commit() + return rec + + +class TestDatabaseRecommendAppRetrieval: + def test_get_type(self): + assert DatabaseRecommendAppRetrieval().get_type() == RecommendAppType.DATABASE + + def test_get_recommended_apps_delegates(self): + with patch.object( + DatabaseRecommendAppRetrieval, + "fetch_recommended_apps_from_db", + return_value={"recommended_apps": [], "categories": []}, + ) as mock_fetch: + result = DatabaseRecommendAppRetrieval().get_recommended_apps_and_categories("en-US") + mock_fetch.assert_called_once_with("en-US") + assert result == {"recommended_apps": [], "categories": []} + + def test_get_recommend_app_detail_delegates(self): + with patch.object( + DatabaseRecommendAppRetrieval, + "fetch_recommended_app_detail_from_db", + return_value={"id": "app-1"}, + ) as mock_fetch: + result = DatabaseRecommendAppRetrieval().get_recommend_app_detail("app-1") + mock_fetch.assert_called_once_with("app-1") + assert result == {"id": "app-1"} + + +class TestFetchRecommendedAppsFromDb: + def test_returns_apps_and_sorted_categories(self, flask_app_with_containers, db_session_with_containers): + tenant_id = str(uuid4()) + app1 = _create_app(db_session_with_containers, tenant_id=tenant_id) + _create_site(db_session_with_containers, app_id=app1.id) + _create_recommended_app(db_session_with_containers, app_id=app1.id, category="writing") + + app2 = _create_app(db_session_with_containers, tenant_id=tenant_id) + _create_site(db_session_with_containers, app_id=app2.id) + _create_recommended_app(db_session_with_containers, app_id=app2.id, category="assistant") + + db_session_with_containers.expire_all() + + result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US") + + app_ids = {r["app_id"] for r in result["recommended_apps"]} + assert app1.id in app_ids + assert app2.id in app_ids + assert "assistant" in result["categories"] + assert "writing" in result["categories"] + + def test_falls_back_to_default_language_when_empty(self, flask_app_with_containers, db_session_with_containers): + tenant_id = str(uuid4()) + app1 = _create_app(db_session_with_containers, tenant_id=tenant_id) + _create_site(db_session_with_containers, app_id=app1.id) + _create_recommended_app(db_session_with_containers, app_id=app1.id, language="en-US") + + db_session_with_containers.expire_all() + + result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("fr-FR") + + app_ids = {r["app_id"] for r in result["recommended_apps"]} + assert app1.id in app_ids + + def test_skips_non_public_apps(self, flask_app_with_containers, db_session_with_containers): + tenant_id = str(uuid4()) + app1 = _create_app(db_session_with_containers, tenant_id=tenant_id, is_public=False) + _create_site(db_session_with_containers, app_id=app1.id) + _create_recommended_app(db_session_with_containers, app_id=app1.id) + + db_session_with_containers.expire_all() + + result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US") + + app_ids = {r["app_id"] for r in result["recommended_apps"]} + assert app1.id not in app_ids + + def test_skips_apps_without_site(self, flask_app_with_containers, db_session_with_containers): + tenant_id = str(uuid4()) + app1 = _create_app(db_session_with_containers, tenant_id=tenant_id) + _create_recommended_app(db_session_with_containers, app_id=app1.id) + + db_session_with_containers.expire_all() + + result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US") + + app_ids = {r["app_id"] for r in result["recommended_apps"]} + assert app1.id not in app_ids + + +class TestFetchRecommendedAppDetailFromDb: + def test_returns_none_when_not_listed(self, flask_app_with_containers, db_session_with_containers): + result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db(str(uuid4())) + + assert result is None + + def test_returns_none_when_app_not_public(self, flask_app_with_containers, db_session_with_containers): + tenant_id = str(uuid4()) + app1 = _create_app(db_session_with_containers, tenant_id=tenant_id, is_public=False) + _create_recommended_app(db_session_with_containers, app_id=app1.id) + + db_session_with_containers.expire_all() + + result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db(app1.id) + + assert result is None + + @patch("services.recommend_app.database.database_retrieval.AppDslService") + def test_returns_detail_on_success(self, mock_dsl, flask_app_with_containers, db_session_with_containers): + tenant_id = str(uuid4()) + app1 = _create_app(db_session_with_containers, tenant_id=tenant_id) + _create_site(db_session_with_containers, app_id=app1.id) + _create_recommended_app(db_session_with_containers, app_id=app1.id) + mock_dsl.export_dsl.return_value = "exported_yaml" + + db_session_with_containers.expire_all() + + result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db(app1.id) + + assert result is not None + assert result["id"] == app1.id + assert result["export_data"] == "exported_yaml" diff --git a/api/tests/test_containers_integration_tests/services/test_agent_service.py b/api/tests/test_containers_integration_tests/services/test_agent_service.py index ee34b65831b..4f3c0e42000 100644 --- a/api/tests/test_containers_integration_tests/services/test_agent_service.py +++ b/api/tests/test_containers_integration_tests/services/test_agent_service.py @@ -7,7 +7,7 @@ from sqlalchemy.orm import Session from core.plugin.impl.exc import PluginDaemonClientSideError from models import Account -from models.enums import MessageFileBelongsTo +from models.enums import ConversationFromSource, MessageFileBelongsTo from models.model import AppModelConfig, Conversation, EndUser, Message, MessageAgentThought from services.account_service import AccountService, TenantService from services.agent_service import AgentService @@ -28,7 +28,7 @@ class TestAgentService: patch("services.agent_service.current_user", create_autospec(Account, instance=True)) as mock_current_user, patch("services.app_service.FeatureService", autospec=True) as mock_feature_service, patch("services.app_service.EnterpriseService", autospec=True) as mock_enterprise_service, - patch("services.app_service.ModelManager", autospec=True) as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant", autospec=True) as mock_model_manager, patch("services.account_service.FeatureService", autospec=True) as mock_account_feature_service, ): # Setup default mock returns for agent service @@ -165,7 +165,7 @@ class TestAgentService: inputs={}, status="normal", mode="chat", - from_source="api", + from_source=ConversationFromSource.API, ) db_session_with_containers.add(conversation) db_session_with_containers.commit() @@ -204,7 +204,7 @@ class TestAgentService: answer_unit_price=0.001, provider_response_latency=1.5, currency="USD", - from_source="api", + from_source=ConversationFromSource.API, ) db_session_with_containers.add(message) db_session_with_containers.commit() @@ -406,7 +406,7 @@ class TestAgentService: inputs={}, status="normal", mode="chat", - from_source="api", + from_source=ConversationFromSource.API, ) db_session_with_containers.add(conversation) db_session_with_containers.commit() @@ -445,7 +445,7 @@ class TestAgentService: answer_unit_price=0.001, provider_response_latency=1.5, currency="USD", - from_source="api", + from_source=ConversationFromSource.API, ) db_session_with_containers.add(message) db_session_with_containers.commit() @@ -478,7 +478,7 @@ class TestAgentService: inputs={}, status="normal", mode="chat", - from_source="api", + from_source=ConversationFromSource.API, ) db_session_with_containers.add(conversation) db_session_with_containers.commit() @@ -517,7 +517,7 @@ class TestAgentService: answer_unit_price=0.001, provider_response_latency=1.5, currency="USD", - from_source="api", + from_source=ConversationFromSource.API, ) db_session_with_containers.add(message) db_session_with_containers.commit() @@ -624,7 +624,7 @@ class TestAgentService: inputs={}, status="normal", mode="chat", - from_source="api", + from_source=ConversationFromSource.API, app_model_config_id=None, # Explicitly set to None ) db_session_with_containers.add(conversation) @@ -647,7 +647,7 @@ class TestAgentService: answer_unit_price=0.001, provider_response_latency=1.5, currency="USD", - from_source="api", + from_source=ConversationFromSource.API, ) db_session_with_containers.add(message) db_session_with_containers.commit() @@ -841,7 +841,8 @@ class TestAgentService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) - from dify_graph.file import FileTransferMethod, FileType + from graphon.file import FileTransferMethod, FileType + from models.enums import CreatorUserRole # Add files to message diff --git a/api/tests/test_containers_integration_tests/services/test_annotation_service.py b/api/tests/test_containers_integration_tests/services/test_annotation_service.py index a260d823a27..95fc73f45a4 100644 --- a/api/tests/test_containers_integration_tests/services/test_annotation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_annotation_service.py @@ -6,6 +6,7 @@ from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound from models import Account +from models.enums import ConversationFromSource, InvokeFrom from models.model import MessageAnnotation from services.annotation_service import AppAnnotationService from services.app_service import AppService @@ -136,8 +137,8 @@ class TestAnnotationService: system_instruction="", system_instruction_tokens=0, status="normal", - invoke_from="console", - from_source="console", + invoke_from=InvokeFrom.EXPLORE, + from_source=ConversationFromSource.CONSOLE, from_end_user_id=None, from_account_id=account.id, ) @@ -174,8 +175,8 @@ class TestAnnotationService: provider_response_latency=0, total_price=0, currency="USD", - invoke_from="console", - from_source="console", + invoke_from=InvokeFrom.EXPLORE, + from_source=ConversationFromSource.CONSOLE, from_end_user_id=None, from_account_id=account.id, ) @@ -721,7 +722,7 @@ class TestAnnotationService: query=f"Query {i}: {fake.sentence()}", user_id=account.id, message_id=fake.uuid4(), - from_source="console", + from_source=ConversationFromSource.CONSOLE, score=0.8 + (i * 0.1), ) @@ -772,7 +773,7 @@ class TestAnnotationService: query=query, user_id=account.id, message_id=message_id, - from_source="console", + from_source=ConversationFromSource.CONSOLE, score=score, ) diff --git a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py index 7ce7357b41a..b8e022503fd 100644 --- a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py +++ b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py @@ -525,3 +525,147 @@ class TestAPIBasedExtensionService: # Try to get extension with wrong tenant ID with pytest.raises(ValueError, match="API based extension is not found"): APIBasedExtensionService.get_with_tenant_id(tenant2.id, created_extension.id) + + def test_save_extension_api_key_exactly_four_chars_rejected( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """API key with exactly 4 characters should be rejected (boundary).""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + assert tenant is not None + + extension_data = APIBasedExtension( + tenant_id=tenant.id, + name=fake.company(), + api_endpoint=f"https://{fake.domain_name()}/api", + api_key="1234", + ) + + with pytest.raises(ValueError, match="api_key must be at least 5 characters"): + APIBasedExtensionService.save(extension_data) + + def test_save_extension_api_key_exactly_five_chars_accepted( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """API key with exactly 5 characters should be accepted (boundary).""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + assert tenant is not None + + extension_data = APIBasedExtension( + tenant_id=tenant.id, + name=fake.company(), + api_endpoint=f"https://{fake.domain_name()}/api", + api_key="12345", + ) + + saved = APIBasedExtensionService.save(extension_data) + assert saved.id is not None + + def test_save_extension_requestor_constructor_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Exception raised by requestor constructor is wrapped in ValueError.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + assert tenant is not None + + mock_external_service_dependencies["requestor"].side_effect = RuntimeError("bad config") + + extension_data = APIBasedExtension( + tenant_id=tenant.id, + name=fake.company(), + api_endpoint=f"https://{fake.domain_name()}/api", + api_key=fake.password(length=20), + ) + + with pytest.raises(ValueError, match="connection error: bad config"): + APIBasedExtensionService.save(extension_data) + + def test_save_extension_network_exception( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Network exceptions during ping are wrapped in ValueError.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + assert tenant is not None + + mock_external_service_dependencies["requestor_instance"].request.side_effect = ConnectionError( + "network failure" + ) + + extension_data = APIBasedExtension( + tenant_id=tenant.id, + name=fake.company(), + api_endpoint=f"https://{fake.domain_name()}/api", + api_key=fake.password(length=20), + ) + + with pytest.raises(ValueError, match="connection error: network failure"): + APIBasedExtensionService.save(extension_data) + + def test_save_extension_update_duplicate_name_rejected( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Updating an existing extension to use another extension's name should fail.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + assert tenant is not None + + ext1 = APIBasedExtensionService.save( + APIBasedExtension( + tenant_id=tenant.id, + name="Extension Alpha", + api_endpoint=f"https://{fake.domain_name()}/api", + api_key=fake.password(length=20), + ) + ) + ext2 = APIBasedExtensionService.save( + APIBasedExtension( + tenant_id=tenant.id, + name="Extension Beta", + api_endpoint=f"https://{fake.domain_name()}/api", + api_key=fake.password(length=20), + ) + ) + + # Try to rename ext2 to ext1's name + ext2.name = "Extension Alpha" + with pytest.raises(ValueError, match="name must be unique, it is already existed"): + APIBasedExtensionService.save(ext2) + + def test_get_all_returns_empty_for_different_tenant( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Extensions from one tenant should not be visible to another.""" + fake = Faker() + _, tenant1 = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + _, tenant2 = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + assert tenant1 is not None + + APIBasedExtensionService.save( + APIBasedExtension( + tenant_id=tenant1.id, + name=fake.company(), + api_endpoint=f"https://{fake.domain_name()}/api", + api_key=fake.password(length=20), + ) + ) + + assert tenant2 is not None + result = APIBasedExtensionService.get_all_by_tenant_id(tenant2.id) + assert result == [] diff --git a/api/tests/unit_tests/services/test_api_token_service.py b/api/tests/test_containers_integration_tests/services/test_api_token_service.py similarity index 71% rename from api/tests/unit_tests/services/test_api_token_service.py rename to api/tests/test_containers_integration_tests/services/test_api_token_service.py index ad4de93b25e..a2028d3ed3d 100644 --- a/api/tests/unit_tests/services/test_api_token_service.py +++ b/api/tests/test_containers_integration_tests/services/test_api_token_service.py @@ -1,80 +1,63 @@ +from __future__ import annotations + from datetime import datetime -from types import SimpleNamespace from unittest.mock import MagicMock, patch +from uuid import uuid4 import pytest from werkzeug.exceptions import Unauthorized import services.api_token_service as api_token_service_module +from models.model import ApiToken from services.api_token_service import ApiTokenCache, CachedApiToken -@pytest.fixture -def mock_db_session(): - """Fixture providing common DB session mocking for query_token_from_db tests.""" - fake_engine = MagicMock() - - session = MagicMock() - session_context = MagicMock() - session_context.__enter__.return_value = session - session_context.__exit__.return_value = None - - with ( - patch.object(api_token_service_module, "db", new=SimpleNamespace(engine=fake_engine)), - patch.object(api_token_service_module, "Session", return_value=session_context) as mock_session_class, - patch.object(api_token_service_module.ApiTokenCache, "set") as mock_cache_set, - patch.object(api_token_service_module, "record_token_usage") as mock_record_usage, - ): - yield { - "session": session, - "mock_session_class": mock_session_class, - "mock_cache_set": mock_cache_set, - "mock_record_usage": mock_record_usage, - "fake_engine": fake_engine, - } - - class TestQueryTokenFromDb: - def test_should_return_api_token_and_cache_when_token_exists(self, mock_db_session): - """Test DB lookup success path caches token and records usage.""" - # Arrange - auth_token = "token-123" - scope = "app" - api_token = MagicMock() + def test_should_return_api_token_and_cache_when_token_exists( + self, flask_app_with_containers, db_session_with_containers + ): + tenant_id = str(uuid4()) + app_id = str(uuid4()) + token_value = f"app-test-{uuid4()}" - mock_db_session["session"].scalar.return_value = api_token + api_token = ApiToken() + api_token.id = str(uuid4()) + api_token.app_id = app_id + api_token.tenant_id = tenant_id + api_token.type = "app" + api_token.token = token_value + db_session_with_containers.add(api_token) + db_session_with_containers.commit() - # Act - result = api_token_service_module.query_token_from_db(auth_token, scope) + with ( + patch.object(api_token_service_module.ApiTokenCache, "set") as mock_cache_set, + patch.object(api_token_service_module, "record_token_usage") as mock_record_usage, + ): + result = api_token_service_module.query_token_from_db(token_value, "app") - # Assert - assert result == api_token - mock_db_session["mock_session_class"].assert_called_once_with( - mock_db_session["fake_engine"], expire_on_commit=False - ) - mock_db_session["mock_cache_set"].assert_called_once_with(auth_token, scope, api_token) - mock_db_session["mock_record_usage"].assert_called_once_with(auth_token, scope) + assert result.id == api_token.id + assert result.token == token_value + mock_cache_set.assert_called_once() + mock_record_usage.assert_called_once_with(token_value, "app") - def test_should_cache_null_and_raise_unauthorized_when_token_not_found(self, mock_db_session): - """Test DB lookup miss path caches null marker and raises Unauthorized.""" - # Arrange - auth_token = "missing-token" - scope = "app" + def test_should_cache_null_and_raise_unauthorized_when_token_not_found( + self, flask_app_with_containers, db_session_with_containers + ): + with ( + patch.object(api_token_service_module.ApiTokenCache, "set") as mock_cache_set, + patch.object(api_token_service_module, "record_token_usage") as mock_record_usage, + ): + with pytest.raises(Unauthorized, match="Access token is invalid"): + api_token_service_module.query_token_from_db(f"missing-{uuid4()}", "app") - mock_db_session["session"].scalar.return_value = None - - # Act / Assert - with pytest.raises(Unauthorized, match="Access token is invalid"): - api_token_service_module.query_token_from_db(auth_token, scope) - - mock_db_session["mock_cache_set"].assert_called_once_with(auth_token, scope, None) - mock_db_session["mock_record_usage"].assert_not_called() + mock_cache_set.assert_called_once() + call_args = mock_cache_set.call_args[0] + assert call_args[2] is None # cached None + mock_record_usage.assert_not_called() class TestRecordTokenUsage: def test_should_write_active_key_with_iso_timestamp_and_ttl(self): - """Test record_token_usage writes usage timestamp with one-hour TTL.""" - # Arrange auth_token = "token-123" scope = "dataset" fixed_time = datetime(2026, 2, 24, 12, 0, 0) @@ -84,26 +67,18 @@ class TestRecordTokenUsage: patch.object(api_token_service_module, "naive_utc_now", return_value=fixed_time), patch.object(api_token_service_module, "redis_client") as mock_redis, ): - # Act api_token_service_module.record_token_usage(auth_token, scope) - # Assert mock_redis.set.assert_called_once_with(expected_key, fixed_time.isoformat(), ex=3600) def test_should_not_raise_when_redis_write_fails(self): - """Test record_token_usage swallows Redis errors.""" - # Arrange with patch.object(api_token_service_module, "redis_client") as mock_redis: mock_redis.set.side_effect = Exception("redis unavailable") - - # Act / Assert api_token_service_module.record_token_usage("token-123", "app") class TestFetchTokenWithSingleFlight: def test_should_return_cached_token_when_lock_acquired_and_cache_filled(self): - """Test single-flight returns cache when another request already populated it.""" - # Arrange auth_token = "token-123" scope = "app" cached_token = CachedApiToken( @@ -115,39 +90,26 @@ class TestFetchTokenWithSingleFlight: last_used_at=None, created_at=None, ) - lock = MagicMock() lock.acquire.return_value = True with ( patch.object(api_token_service_module, "redis_client") as mock_redis, - patch.object(api_token_service_module.ApiTokenCache, "get", return_value=cached_token) as mock_cache_get, + patch.object(api_token_service_module.ApiTokenCache, "get", return_value=cached_token), patch.object(api_token_service_module, "query_token_from_db") as mock_query_db, ): mock_redis.lock.return_value = lock - - # Act result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope) - # Assert assert result == cached_token - mock_redis.lock.assert_called_once_with( - f"api_token_query_lock:{scope}:{auth_token}", - timeout=10, - blocking_timeout=5, - ) lock.acquire.assert_called_once_with(blocking=True) lock.release.assert_called_once() - mock_cache_get.assert_called_once_with(auth_token, scope) mock_query_db.assert_not_called() def test_should_query_db_when_lock_acquired_and_cache_missed(self): - """Test single-flight queries DB when cache remains empty after lock acquisition.""" - # Arrange auth_token = "token-123" scope = "app" db_token = MagicMock() - lock = MagicMock() lock.acquire.return_value = True @@ -157,22 +119,16 @@ class TestFetchTokenWithSingleFlight: patch.object(api_token_service_module, "query_token_from_db", return_value=db_token) as mock_query_db, ): mock_redis.lock.return_value = lock - - # Act result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope) - # Assert assert result == db_token mock_query_db.assert_called_once_with(auth_token, scope) lock.release.assert_called_once() def test_should_query_db_directly_when_lock_not_acquired(self): - """Test lock timeout branch falls back to direct DB query.""" - # Arrange auth_token = "token-123" scope = "app" db_token = MagicMock() - lock = MagicMock() lock.acquire.return_value = False @@ -182,19 +138,14 @@ class TestFetchTokenWithSingleFlight: patch.object(api_token_service_module, "query_token_from_db", return_value=db_token) as mock_query_db, ): mock_redis.lock.return_value = lock - - # Act result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope) - # Assert assert result == db_token mock_cache_get.assert_not_called() mock_query_db.assert_called_once_with(auth_token, scope) lock.release.assert_not_called() def test_should_reraise_unauthorized_from_db_query(self): - """Test Unauthorized from DB query is propagated unchanged.""" - # Arrange auth_token = "token-123" scope = "app" lock = MagicMock() @@ -210,20 +161,15 @@ class TestFetchTokenWithSingleFlight: ), ): mock_redis.lock.return_value = lock - - # Act / Assert with pytest.raises(Unauthorized, match="Access token is invalid"): api_token_service_module.fetch_token_with_single_flight(auth_token, scope) lock.release.assert_called_once() def test_should_fallback_to_db_query_when_lock_raises_exception(self): - """Test Redis lock errors fall back to direct DB query.""" - # Arrange auth_token = "token-123" scope = "app" db_token = MagicMock() - lock = MagicMock() lock.acquire.side_effect = RuntimeError("redis lock error") @@ -232,11 +178,8 @@ class TestFetchTokenWithSingleFlight: patch.object(api_token_service_module, "query_token_from_db", return_value=db_token) as mock_query_db, ): mock_redis.lock.return_value = lock - - # Act result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope) - # Assert assert result == db_token mock_query_db.assert_called_once_with(auth_token, scope) @@ -244,8 +187,6 @@ class TestFetchTokenWithSingleFlight: class TestApiTokenCacheTenantBranches: @patch("services.api_token_service.redis_client") def test_delete_with_scope_should_remove_from_tenant_index_when_tenant_found(self, mock_redis): - """Test scoped delete removes cache key and tenant index membership.""" - # Arrange token = "token-123" scope = "app" cache_key = ApiTokenCache._make_cache_key(token, scope) @@ -261,18 +202,14 @@ class TestApiTokenCacheTenantBranches: mock_redis.get.return_value = cached_token.model_dump_json().encode("utf-8") with patch.object(ApiTokenCache, "_remove_from_tenant_index") as mock_remove_index: - # Act result = ApiTokenCache.delete(token, scope) - # Assert assert result is True mock_redis.delete.assert_called_once_with(cache_key) mock_remove_index.assert_called_once_with("tenant-1", cache_key) @patch("services.api_token_service.redis_client") def test_invalidate_by_tenant_should_delete_all_indexed_cache_keys(self, mock_redis): - """Test tenant invalidation deletes indexed cache entries and index key.""" - # Arrange tenant_id = "tenant-1" index_key = ApiTokenCache._make_tenant_index_key(tenant_id) mock_redis.smembers.return_value = { @@ -280,10 +217,8 @@ class TestApiTokenCacheTenantBranches: b"api_token:any:token-2", } - # Act result = ApiTokenCache.invalidate_by_tenant(tenant_id) - # Assert assert result is True mock_redis.smembers.assert_called_once_with(index_key) mock_redis.delete.assert_any_call("api_token:app:token-1") @@ -293,7 +228,6 @@ class TestApiTokenCacheTenantBranches: class TestApiTokenCacheCoreBranches: def test_cached_api_token_repr_should_include_id_and_type(self): - """Test CachedApiToken __repr__ includes key identity fields.""" token = CachedApiToken( id="id-123", app_id="app-123", @@ -303,11 +237,9 @@ class TestApiTokenCacheCoreBranches: last_used_at=None, created_at=None, ) - assert repr(token) == "" def test_serialize_token_should_handle_cached_api_token_instances(self): - """Test serialization path when input is already a CachedApiToken.""" token = CachedApiToken( id="id-123", app_id="app-123", @@ -317,35 +249,25 @@ class TestApiTokenCacheCoreBranches: last_used_at=None, created_at=None, ) - serialized = ApiTokenCache._serialize_token(token) - assert isinstance(serialized, bytes) assert b'"id":"id-123"' in serialized - assert b'"token":"token-123"' in serialized def test_deserialize_token_should_return_none_for_null_markers(self): - """Test null cache marker deserializes to None.""" assert ApiTokenCache._deserialize_token("null") is None assert ApiTokenCache._deserialize_token(b"null") is None def test_deserialize_token_should_return_none_for_invalid_payload(self): - """Test invalid serialized payload returns None.""" assert ApiTokenCache._deserialize_token("not-json") is None @patch("services.api_token_service.redis_client") def test_get_should_return_none_on_cache_miss(self, mock_redis): - """Test cache miss branch in ApiTokenCache.get.""" mock_redis.get.return_value = None - result = ApiTokenCache.get("token-123", "app") - assert result is None - mock_redis.get.assert_called_once_with("api_token:app:token-123") @patch("services.api_token_service.redis_client") def test_get_should_deserialize_cached_payload_on_cache_hit(self, mock_redis): - """Test cache hit branch in ApiTokenCache.get.""" token = CachedApiToken( id="id-123", app_id="app-123", @@ -356,48 +278,34 @@ class TestApiTokenCacheCoreBranches: created_at=None, ) mock_redis.get.return_value = token.model_dump_json().encode("utf-8") - result = ApiTokenCache.get("token-123", "app") - assert isinstance(result, CachedApiToken) assert result.id == "id-123" @patch("services.api_token_service.redis_client") def test_add_to_tenant_index_should_skip_when_tenant_id_missing(self, mock_redis): - """Test tenant index update exits early for missing tenant id.""" ApiTokenCache._add_to_tenant_index(None, "api_token:app:token-123") - mock_redis.sadd.assert_not_called() - mock_redis.expire.assert_not_called() @patch("services.api_token_service.redis_client") def test_add_to_tenant_index_should_swallow_index_update_errors(self, mock_redis): - """Test tenant index update handles Redis write errors gracefully.""" mock_redis.sadd.side_effect = Exception("redis down") - ApiTokenCache._add_to_tenant_index("tenant-123", "api_token:app:token-123") - mock_redis.sadd.assert_called_once() @patch("services.api_token_service.redis_client") def test_remove_from_tenant_index_should_skip_when_tenant_id_missing(self, mock_redis): - """Test tenant index removal exits early for missing tenant id.""" ApiTokenCache._remove_from_tenant_index(None, "api_token:app:token-123") - mock_redis.srem.assert_not_called() @patch("services.api_token_service.redis_client") def test_remove_from_tenant_index_should_swallow_redis_errors(self, mock_redis): - """Test tenant index removal handles Redis errors gracefully.""" mock_redis.srem.side_effect = Exception("redis down") - ApiTokenCache._remove_from_tenant_index("tenant-123", "api_token:app:token-123") - mock_redis.srem.assert_called_once() @patch("services.api_token_service.redis_client") def test_set_should_return_false_when_cache_write_raises_exception(self, mock_redis): - """Test set returns False when Redis setex fails.""" mock_redis.setex.side_effect = Exception("redis write failed") api_token = MagicMock() api_token.id = "id-123" @@ -407,60 +315,41 @@ class TestApiTokenCacheCoreBranches: api_token.token = "token-123" api_token.last_used_at = None api_token.created_at = None - result = ApiTokenCache.set("token-123", "app", api_token) - assert result is False @patch("services.api_token_service.redis_client") def test_delete_without_scope_should_return_false_when_scan_fails(self, mock_redis): - """Test delete(scope=None) returns False when scan_iter raises.""" mock_redis.scan_iter.side_effect = Exception("scan failed") - result = ApiTokenCache.delete("token-123", None) - assert result is False @patch("services.api_token_service.redis_client") def test_delete_with_scope_should_continue_when_tenant_lookup_raises(self, mock_redis): - """Test scoped delete still succeeds when tenant lookup from cache fails.""" token = "token-123" scope = "app" cache_key = ApiTokenCache._make_cache_key(token, scope) mock_redis.get.side_effect = Exception("get failed") - result = ApiTokenCache.delete(token, scope) - assert result is True mock_redis.delete.assert_called_once_with(cache_key) @patch("services.api_token_service.redis_client") def test_delete_with_scope_should_return_false_when_delete_raises(self, mock_redis): - """Test scoped delete returns False when delete operation fails.""" - token = "token-123" - scope = "app" mock_redis.get.return_value = None mock_redis.delete.side_effect = Exception("delete failed") - - result = ApiTokenCache.delete(token, scope) - + result = ApiTokenCache.delete("token-123", "app") assert result is False @patch("services.api_token_service.redis_client") def test_invalidate_by_tenant_should_return_true_when_index_not_found(self, mock_redis): - """Test tenant invalidation returns True when tenant index is empty.""" mock_redis.smembers.return_value = set() - result = ApiTokenCache.invalidate_by_tenant("tenant-123") - assert result is True mock_redis.delete.assert_not_called() @patch("services.api_token_service.redis_client") def test_invalidate_by_tenant_should_return_false_when_redis_raises(self, mock_redis): - """Test tenant invalidation returns False when Redis operation fails.""" mock_redis.smembers.side_effect = Exception("redis failed") - result = ApiTokenCache.invalidate_by_tenant("tenant-123") - assert result is False diff --git a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py index 8a362e1f5e7..33955d5d84b 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py @@ -26,7 +26,7 @@ class TestAppDslService: patch("services.app_dsl_service.redis_client") as mock_redis_client, patch("services.app_dsl_service.app_was_created") as mock_app_was_created, patch("services.app_dsl_service.app_model_config_was_updated") as mock_app_model_config_was_updated, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, ): diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py index d79f80c0094..fa57dd4a6ff 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Session from constants.model_template import default_app_templates from models import Account -from models.model import App, Site +from models.model import App, IconType, Site from services.account_service import AccountService, TenantService from tests.test_containers_integration_tests.helpers import generate_valid_password @@ -23,7 +23,7 @@ class TestAppService: with ( patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.account_service.FeatureService") as mock_account_feature_service, ): # Setup default mock returns for app service @@ -463,6 +463,109 @@ class TestAppService: assert updated_app.tenant_id == app.tenant_id assert updated_app.created_by == app.created_by + def test_update_app_should_preserve_icon_type_when_omitted( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """ + Test update_app keeps the persisted icon_type when the update payload omits it. + """ + fake = Faker() + + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=generate_valid_password(fake), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + from services.app_service import AppService + + app_service = AppService() + app = app_service.create_app( + tenant.id, + { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🎯", + "icon_background": "#45B7D1", + }, + account, + ) + + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.id = account.id + mock_current_user.current_tenant_id = account.current_tenant_id + + with patch("services.app_service.current_user", mock_current_user): + updated_app = app_service.update_app( + app, + { + "name": "Updated App Name", + "description": "Updated app description", + "icon_type": None, + "icon": "🔄", + "icon_background": "#FF8C42", + "use_icon_as_answer_icon": True, + }, + ) + + assert updated_app.icon_type == IconType.EMOJI + + def test_update_app_should_reject_empty_icon_type( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """ + Test update_app rejects an explicit empty icon_type. + """ + fake = Faker() + + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=generate_valid_password(fake), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + from services.app_service import AppService + + app_service = AppService() + app = app_service.create_app( + tenant.id, + { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🎯", + "icon_background": "#45B7D1", + }, + account, + ) + + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.id = account.id + mock_current_user.current_tenant_id = account.current_tenant_id + + with patch("services.app_service.current_user", mock_current_user): + with pytest.raises(ValueError): + app_service.update_app( + app, + { + "name": "Updated App Name", + "description": "Updated app description", + "icon_type": "", + "icon": "🔄", + "icon_background": "#FF8C42", + "use_icon_as_answer_icon": True, + }, + ) + def test_update_app_name_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app name update. @@ -1142,3 +1245,51 @@ class TestAppService: assert paginated_apps is not None assert paginated_apps.total == 1 assert all("50%" in app.name for app in paginated_apps.items) + + def test_get_app_code_by_id_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test get_app_code_by_id raises ValueError when site is missing.""" + from uuid import uuid4 + + from services.app_service import AppService + + with pytest.raises(ValueError, match="not found"): + AppService.get_app_code_by_id(str(uuid4())) + + def test_get_app_id_by_code_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test get_app_id_by_code raises ValueError when code does not exist.""" + from services.app_service import AppService + + with pytest.raises(ValueError, match="not found"): + AppService.get_app_id_by_code("nonexistent-code") + + def test_get_app_meta_returns_empty_when_workflow_missing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test get_app_meta returns empty tool_icons when workflow is None.""" + from types import SimpleNamespace + + from services.app_service import AppService + + app_service = AppService() + workflow_app = SimpleNamespace(mode="workflow", workflow=None) + + meta = app_service.get_app_meta(workflow_app) + assert meta == {"tool_icons": {}} + + def test_get_app_meta_returns_empty_when_model_config_missing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test get_app_meta returns empty tool_icons when app_model_config is None.""" + from types import SimpleNamespace + + from services.app_service import AppService + + app_service = AppService() + chat_app = SimpleNamespace(mode="chat", app_model_config=None) + + meta = app_service.get_app_meta(chat_app) + assert meta == {"tool_icons": {}} diff --git a/api/tests/test_containers_integration_tests/services/test_attachment_service.py b/api/tests/test_containers_integration_tests/services/test_attachment_service.py new file mode 100644 index 00000000000..768a8baee2d --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_attachment_service.py @@ -0,0 +1,80 @@ +"""Testcontainers integration tests for AttachmentService.""" + +import base64 +from datetime import UTC, datetime +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from werkzeug.exceptions import NotFound + +import services.attachment_service as attachment_service_module +from extensions.ext_database import db +from extensions.storage.storage_type import StorageType +from models.enums import CreatorUserRole +from models.model import UploadFile +from services.attachment_service import AttachmentService + + +class TestAttachmentService: + def _create_upload_file(self, db_session_with_containers, *, tenant_id: str | None = None) -> UploadFile: + upload_file = UploadFile( + tenant_id=tenant_id or str(uuid4()), + storage_type=StorageType.OPENDAL, + key=f"upload/{uuid4()}.txt", + name="test-file.txt", + size=100, + extension="txt", + mime_type="text/plain", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=str(uuid4()), + created_at=datetime.now(UTC), + used=False, + ) + db_session_with_containers.add(upload_file) + db_session_with_containers.commit() + return upload_file + + def test_should_initialize_with_sessionmaker(self): + session_factory = sessionmaker() + + service = AttachmentService(session_factory=session_factory) + + assert service._session_maker is session_factory + + def test_should_initialize_with_engine(self): + engine = create_engine("sqlite:///:memory:") + + service = AttachmentService(session_factory=engine) + session = service._session_maker() + try: + assert session.bind == engine + finally: + session.close() + engine.dispose() + + @pytest.mark.parametrize("invalid_session_factory", [None, "not-a-session-factory", 1]) + def test_should_raise_assertion_error_for_invalid_session_factory(self, invalid_session_factory): + with pytest.raises(AssertionError, match="must be a sessionmaker or an Engine."): + AttachmentService(session_factory=invalid_session_factory) + + def test_should_return_base64_when_file_exists(self, db_session_with_containers): + upload_file = self._create_upload_file(db_session_with_containers) + service = AttachmentService(session_factory=sessionmaker(bind=db.engine)) + + with patch.object(attachment_service_module.storage, "load_once", return_value=b"binary-content") as mock_load: + result = service.get_file_base64(upload_file.id) + + assert result == base64.b64encode(b"binary-content").decode() + mock_load.assert_called_once_with(upload_file.key) + + def test_should_raise_not_found_when_file_missing(self, db_session_with_containers): + service = AttachmentService(session_factory=sessionmaker(bind=db.engine)) + + with patch.object(attachment_service_module.storage, "load_once") as mock_load: + with pytest.raises(NotFound, match="File not found"): + service.get_file_base64(str(uuid4())) + + mock_load.assert_not_called() diff --git a/api/tests/test_containers_integration_tests/services/test_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_conversation_service.py index 5f64e6f6744..6180d98b1ef 100644 --- a/api/tests/test_containers_integration_tests/services/test_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_conversation_service.py @@ -10,6 +10,7 @@ from sqlalchemy import select from core.app.entities.app_invoke_entities import InvokeFrom from models.account import Account, Tenant, TenantAccountJoin +from models.enums import ConversationFromSource from models.model import App, Conversation, EndUser, Message, MessageAnnotation from services.annotation_service import AppAnnotationService from services.conversation_service import ConversationService @@ -107,7 +108,7 @@ class ConversationServiceIntegrationTestDataFactory: system_instruction_tokens=0, status="normal", invoke_from=invoke_from.value, - from_source="api" if isinstance(user, EndUser) else "console", + from_source=ConversationFromSource.API if isinstance(user, EndUser) else ConversationFromSource.CONSOLE, from_end_user_id=user.id if isinstance(user, EndUser) else None, from_account_id=user.id if isinstance(user, Account) else None, dialogue_count=0, @@ -154,7 +155,7 @@ class ConversationServiceIntegrationTestDataFactory: currency="USD", status="normal", invoke_from=InvokeFrom.WEB_APP.value, - from_source="api" if isinstance(user, EndUser) else "console", + from_source=ConversationFromSource.API if isinstance(user, EndUser) else ConversationFromSource.CONSOLE, from_end_user_id=user.id if isinstance(user, EndUser) else None, from_account_id=user.id if isinstance(user, Account) else None, ) diff --git a/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py b/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py new file mode 100644 index 00000000000..fb0adbbcc26 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py @@ -0,0 +1,58 @@ +"""Testcontainers integration tests for ConversationVariableUpdater.""" + +from uuid import uuid4 + +import pytest +from graphon.variables import StringVariable +from sqlalchemy.orm import sessionmaker + +from extensions.ext_database import db +from models.workflow import ConversationVariable +from services.conversation_variable_updater import ConversationVariableNotFoundError, ConversationVariableUpdater + + +class TestConversationVariableUpdater: + def _create_conversation_variable( + self, db_session_with_containers, *, conversation_id: str, variable: StringVariable, app_id: str | None = None + ) -> ConversationVariable: + row = ConversationVariable( + id=variable.id, + conversation_id=conversation_id, + app_id=app_id or str(uuid4()), + data=variable.model_dump_json(), + ) + db_session_with_containers.add(row) + db_session_with_containers.commit() + return row + + def test_should_update_conversation_variable_data_and_commit(self, db_session_with_containers): + conversation_id = str(uuid4()) + variable = StringVariable(id=str(uuid4()), name="topic", value="old value") + self._create_conversation_variable( + db_session_with_containers, conversation_id=conversation_id, variable=variable + ) + + updated_variable = StringVariable(id=variable.id, name="topic", value="new value") + updater = ConversationVariableUpdater(sessionmaker(bind=db.engine)) + + updater.update(conversation_id=conversation_id, variable=updated_variable) + + db_session_with_containers.expire_all() + row = db_session_with_containers.get(ConversationVariable, (variable.id, conversation_id)) + assert row is not None + assert row.data == updated_variable.model_dump_json() + + def test_should_raise_not_found_when_variable_missing(self, db_session_with_containers): + conversation_id = str(uuid4()) + variable = StringVariable(id=str(uuid4()), name="topic", value="value") + updater = ConversationVariableUpdater(sessionmaker(bind=db.engine)) + + with pytest.raises(ConversationVariableNotFoundError, match="conversation variable not found in the database"): + updater.update(conversation_id=conversation_id, variable=variable) + + def test_should_do_nothing_when_flush_is_called(self, db_session_with_containers): + updater = ConversationVariableUpdater(sessionmaker(bind=db.engine)) + + result = updater.flush() + + assert result is None diff --git a/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py b/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py new file mode 100644 index 00000000000..0f63d986422 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py @@ -0,0 +1,104 @@ +"""Testcontainers integration tests for CreditPoolService.""" + +from uuid import uuid4 + +import pytest + +from core.errors.error import QuotaExceededError +from models import TenantCreditPool +from models.enums import ProviderQuotaType +from services.credit_pool_service import CreditPoolService + + +class TestCreditPoolService: + def _create_tenant_id(self) -> str: + return str(uuid4()) + + def test_create_default_pool(self, db_session_with_containers): + tenant_id = self._create_tenant_id() + + pool = CreditPoolService.create_default_pool(tenant_id) + + assert isinstance(pool, TenantCreditPool) + assert pool.tenant_id == tenant_id + assert pool.pool_type == ProviderQuotaType.TRIAL + assert pool.quota_used == 0 + assert pool.quota_limit > 0 + + def test_get_pool_returns_pool_when_exists(self, db_session_with_containers): + tenant_id = self._create_tenant_id() + CreditPoolService.create_default_pool(tenant_id) + + result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type=ProviderQuotaType.TRIAL) + + assert result is not None + assert result.tenant_id == tenant_id + assert result.pool_type == ProviderQuotaType.TRIAL + + def test_get_pool_returns_none_when_not_exists(self, db_session_with_containers): + result = CreditPoolService.get_pool(tenant_id=self._create_tenant_id(), pool_type=ProviderQuotaType.TRIAL) + + assert result is None + + def test_check_credits_available_returns_false_when_no_pool(self, db_session_with_containers): + result = CreditPoolService.check_credits_available(tenant_id=self._create_tenant_id(), credits_required=10) + + assert result is False + + def test_check_credits_available_returns_true_when_sufficient(self, db_session_with_containers): + tenant_id = self._create_tenant_id() + CreditPoolService.create_default_pool(tenant_id) + + result = CreditPoolService.check_credits_available(tenant_id=tenant_id, credits_required=10) + + assert result is True + + def test_check_credits_available_returns_false_when_insufficient(self, db_session_with_containers): + tenant_id = self._create_tenant_id() + pool = CreditPoolService.create_default_pool(tenant_id) + # Exhaust credits + pool.quota_used = pool.quota_limit + db_session_with_containers.commit() + + result = CreditPoolService.check_credits_available(tenant_id=tenant_id, credits_required=1) + + assert result is False + + def test_check_and_deduct_credits_raises_when_no_pool(self, db_session_with_containers): + with pytest.raises(QuotaExceededError, match="Credit pool not found"): + CreditPoolService.check_and_deduct_credits(tenant_id=self._create_tenant_id(), credits_required=10) + + def test_check_and_deduct_credits_raises_when_no_remaining(self, db_session_with_containers): + tenant_id = self._create_tenant_id() + pool = CreditPoolService.create_default_pool(tenant_id) + pool.quota_used = pool.quota_limit + db_session_with_containers.commit() + + with pytest.raises(QuotaExceededError, match="No credits remaining"): + CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=10) + + def test_check_and_deduct_credits_deducts_required_amount(self, db_session_with_containers): + tenant_id = self._create_tenant_id() + CreditPoolService.create_default_pool(tenant_id) + credits_required = 10 + + result = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=credits_required) + + assert result == credits_required + db_session_with_containers.expire_all() + pool = CreditPoolService.get_pool(tenant_id=tenant_id) + assert pool.quota_used == credits_required + + def test_check_and_deduct_credits_caps_at_remaining(self, db_session_with_containers): + tenant_id = self._create_tenant_id() + pool = CreditPoolService.create_default_pool(tenant_id) + remaining = 5 + pool.quota_used = pool.quota_limit - remaining + db_session_with_containers.commit() + + result = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=200) + + assert result == remaining + db_session_with_containers.expire_all() + updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id) + assert updated_pool.quota_used == pool.quota_limit diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py index 975af3d4282..71c8874f79a 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py @@ -9,6 +9,7 @@ from uuid import uuid4 import pytest +from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import ( @@ -69,7 +70,7 @@ class DatasetPermissionTestDataFactory: name=name, description="desc", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, permission=permission, provider="vendor", @@ -397,6 +398,68 @@ class TestDatasetPermissionServiceClearPartialMemberList: class TestDatasetServiceCheckDatasetPermission: """Verify dataset access checks against persisted partial-member permissions.""" + def test_check_dataset_permission_different_tenant_should_fail(self, db_session_with_containers): + """Test that users from different tenants cannot access dataset.""" + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + other_user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, owner.id, permission=DatasetPermissionEnum.ALL_TEAM + ) + + with pytest.raises(NoPermissionError): + DatasetService.check_dataset_permission(dataset, other_user) + + def test_check_dataset_permission_owner_can_access_any_dataset(self, db_session_with_containers): + """Test that tenant owners can access any dataset regardless of permission level.""" + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + creator, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, tenant=tenant + ) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, creator.id, permission=DatasetPermissionEnum.ONLY_ME + ) + + DatasetService.check_dataset_permission(dataset, owner) + + def test_check_dataset_permission_only_me_creator_can_access(self, db_session_with_containers): + """Test ONLY_ME permission allows only the dataset creator to access.""" + creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, creator.id, permission=DatasetPermissionEnum.ONLY_ME + ) + + DatasetService.check_dataset_permission(dataset, creator) + + def test_check_dataset_permission_only_me_others_cannot_access(self, db_session_with_containers): + """Test ONLY_ME permission denies access to non-creators.""" + creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) + other, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, tenant=tenant + ) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, creator.id, permission=DatasetPermissionEnum.ONLY_ME + ) + + with pytest.raises(NoPermissionError): + DatasetService.check_dataset_permission(dataset, other) + + def test_check_dataset_permission_all_team_allows_access(self, db_session_with_containers): + """Test ALL_TEAM permission allows any team member to access the dataset.""" + creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) + member, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, tenant=tenant + ) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, creator.id, permission=DatasetPermissionEnum.ALL_TEAM + ) + + DatasetService.check_dataset_permission(dataset, member) + def test_check_dataset_permission_partial_members_with_permission_success(self, db_session_with_containers): """ Test that user with explicit permission can access partial_members dataset. @@ -443,6 +506,16 @@ class TestDatasetServiceCheckDatasetPermission: with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"): DatasetService.check_dataset_permission(dataset, user) + def test_check_dataset_permission_partial_team_creator_can_access(self, db_session_with_containers): + """Test PARTIAL_TEAM permission allows creator to access without explicit permission.""" + creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, creator.id, permission=DatasetPermissionEnum.PARTIAL_TEAM + ) + + DatasetService.check_dataset_permission(dataset, creator) + class TestDatasetServiceCheckDatasetOperatorPermission: """Verify operator permission checks against persisted partial-member permissions.""" diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_service.py index ac3d9f9604e..f9bfa570cbb 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service.py @@ -9,10 +9,11 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod -from dify_graph.model_runtime.entities.model_entities import ModelType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings, Pipeline from models.enums import DatasetRuntimeMode, DataSourceType, DocumentCreatedFrom, IndexingStatus @@ -62,7 +63,7 @@ class DatasetServiceIntegrationDataFactory: name: str = "Test Dataset", description: str | None = "Test description", provider: str = "vendor", - indexing_technique: str | None = "high_quality", + indexing_technique: str | None = IndexTechniqueType.HIGH_QUALITY, permission: str = DatasetPermissionEnum.ONLY_ME, retrieval_model: dict | None = None, embedding_model_provider: str | None = None, @@ -106,7 +107,7 @@ class DatasetServiceIntegrationDataFactory: created_from=DocumentCreatedFrom.WEB, created_by=created_by, indexing_status=IndexingStatus.COMPLETED, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) db_session_with_containers.flush() @@ -156,13 +157,13 @@ class TestDatasetServiceCreateDataset: tenant_id=tenant.id, name="Economy Dataset", description=None, - indexing_technique="economy", + indexing_technique=IndexTechniqueType.ECONOMY, account=account, ) # Assert db_session_with_containers.refresh(result) - assert result.indexing_technique == "economy" + assert result.indexing_technique == IndexTechniqueType.ECONOMY assert result.embedding_model_provider is None assert result.embedding_model is None @@ -173,20 +174,20 @@ class TestDatasetServiceCreateDataset: embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model() # Act - with patch("services.dataset_service.ModelManager") as mock_model_manager: + with patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager: mock_model_manager.return_value.get_default_model_instance.return_value = embedding_model result = DatasetService.create_empty_dataset( tenant_id=tenant.id, name="High Quality Dataset", description=None, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, account=account, ) # Assert db_session_with_containers.refresh(result) - assert result.indexing_technique == "high_quality" + assert result.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert result.embedding_model_provider == embedding_model.provider assert result.embedding_model == embedding_model.model_name mock_model_manager.return_value.get_default_model_instance.assert_called_once_with( @@ -263,7 +264,7 @@ class TestDatasetServiceCreateDataset: # Act with ( - patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, patch("services.dataset_service.DatasetService.check_reranking_model_setting") as mock_check_reranking, ): mock_model_manager.return_value.get_default_model_instance.return_value = embedding_model @@ -272,7 +273,7 @@ class TestDatasetServiceCreateDataset: tenant_id=tenant.id, name="Dataset With Reranking", description=None, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, account=account, retrieval_model=retrieval_model, ) @@ -296,7 +297,7 @@ class TestDatasetServiceCreateDataset: # Act with ( - patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, patch("services.dataset_service.DatasetService.check_embedding_model_setting") as mock_check_embedding, ): mock_model_manager.return_value.get_model_instance.return_value = embedding_model @@ -305,7 +306,7 @@ class TestDatasetServiceCreateDataset: tenant_id=tenant.id, name="Custom Embedding Dataset", description=None, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, account=account, embedding_model_provider=embedding_provider, embedding_model_name=embedding_model_name, @@ -313,7 +314,7 @@ class TestDatasetServiceCreateDataset: # Assert db_session_with_containers.refresh(result) - assert result.indexing_technique == "high_quality" + assert result.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert result.embedding_model_provider == embedding_provider assert result.embedding_model == embedding_model_name mock_check_embedding.assert_called_once_with(tenant.id, embedding_provider, embedding_model_name) @@ -588,7 +589,7 @@ class TestDatasetServiceUpdateAndDeleteDataset: db_session_with_containers, tenant_id=tenant.id, created_by=account.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, chunk_structure="text_model", ) DatasetServiceIntegrationDataFactory.create_document( @@ -684,14 +685,14 @@ class TestDatasetServiceRetrievalConfiguration: db_session_with_containers, tenant_id=tenant.id, created_by=account.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, retrieval_model={"search_method": "semantic_search", "top_k": 2, "score_threshold": 0.0}, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=str(uuid4()), ) update_data = { - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "retrieval_model": { "search_method": "full_text_search", "top_k": 10, @@ -706,3 +707,104 @@ class TestDatasetServiceRetrievalConfiguration: db_session_with_containers.refresh(dataset) assert result.id == dataset.id assert dataset.retrieval_model == update_data["retrieval_model"] + + +class TestDocumentServicePauseRecoverRetry: + """Tests for pause/recover/retry orchestration using real DB and Redis.""" + + def _create_indexing_document(self, db_session_with_containers, indexing_status="indexing"): + factory = DatasetServiceIntegrationDataFactory + account, tenant = factory.create_account_with_tenant(db_session_with_containers) + dataset = factory.create_dataset(db_session_with_containers, tenant.id, account.id) + doc = factory.create_document(db_session_with_containers, dataset, account.id) + doc.indexing_status = indexing_status + db_session_with_containers.commit() + return doc, account + + def test_pause_document_success(self, db_session_with_containers): + from extensions.ext_redis import redis_client + from services.dataset_service import DocumentService + + doc, account = self._create_indexing_document(db_session_with_containers, indexing_status="indexing") + + with patch("services.dataset_service.current_user") as mock_user: + mock_user.id = account.id + DocumentService.pause_document(doc) + + db_session_with_containers.refresh(doc) + assert doc.is_paused is True + assert doc.paused_by == account.id + assert doc.paused_at is not None + + cache_key = f"document_{doc.id}_is_paused" + assert redis_client.get(cache_key) is not None + redis_client.delete(cache_key) + + def test_pause_document_invalid_status_error(self, db_session_with_containers): + from services.dataset_service import DocumentService + from services.errors.document import DocumentIndexingError + + doc, account = self._create_indexing_document(db_session_with_containers, indexing_status="completed") + + with patch("services.dataset_service.current_user") as mock_user: + mock_user.id = account.id + with pytest.raises(DocumentIndexingError): + DocumentService.pause_document(doc) + + def test_recover_document_success(self, db_session_with_containers): + from extensions.ext_redis import redis_client + from services.dataset_service import DocumentService + + doc, account = self._create_indexing_document(db_session_with_containers, indexing_status="indexing") + + # Pause first + with patch("services.dataset_service.current_user") as mock_user: + mock_user.id = account.id + DocumentService.pause_document(doc) + + # Recover + with patch("services.dataset_service.recover_document_indexing_task") as recover_task: + DocumentService.recover_document(doc) + + db_session_with_containers.refresh(doc) + assert doc.is_paused is False + assert doc.paused_by is None + assert doc.paused_at is None + + cache_key = f"document_{doc.id}_is_paused" + assert redis_client.get(cache_key) is None + recover_task.delay.assert_called_once_with(doc.dataset_id, doc.id) + + def test_retry_document_indexing_success(self, db_session_with_containers): + from extensions.ext_redis import redis_client + from services.dataset_service import DocumentService + + factory = DatasetServiceIntegrationDataFactory + account, tenant = factory.create_account_with_tenant(db_session_with_containers) + dataset = factory.create_dataset(db_session_with_containers, tenant.id, account.id) + doc1 = factory.create_document(db_session_with_containers, dataset, account.id, name="doc1.txt") + doc2 = factory.create_document(db_session_with_containers, dataset, account.id, name="doc2.txt") + doc2.position = 2 + doc1.indexing_status = "error" + doc2.indexing_status = "error" + db_session_with_containers.commit() + + with ( + patch("services.dataset_service.current_user") as mock_user, + patch("services.dataset_service.retry_document_indexing_task") as retry_task, + ): + mock_user.id = account.id + DocumentService.retry_document(dataset.id, [doc1, doc2]) + + db_session_with_containers.refresh(doc1) + db_session_with_containers.refresh(doc2) + assert doc1.indexing_status == "waiting" + assert doc2.indexing_status == "waiting" + + # Verify redis keys were set + assert redis_client.get(f"document_{doc1.id}_is_retried") is not None + assert redis_client.get(f"document_{doc2.id}_is_retried") is not None + retry_task.delay.assert_called_once_with(dataset.id, [doc1.id, doc2.id], account.id) + + # Cleanup + redis_client.delete(f"document_{doc1.id}_is_retried", f"document_{doc2.id}_is_retried") diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py index 7983b1cd937..c1d088755c1 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py @@ -13,6 +13,7 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus from services.dataset_service import DocumentService @@ -79,7 +80,7 @@ class DocumentBatchUpdateIntegrationDataFactory: name=name, created_from=DocumentCreatedFrom.WEB, created_by=created_by or str(uuid4()), - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) document.id = document_id or str(uuid4()) document.enabled = enabled @@ -694,3 +695,19 @@ class TestDatasetServiceBatchUpdateDocumentStatus: patched_dependencies["redis_client"].setex.assert_called_once_with(f"document_{doc1.id}_indexing", 600, 1) patched_dependencies["add_task"].delay.assert_called_once_with(doc1.id) + + def test_batch_update_invalid_action_raises_value_error( + self, db_session_with_containers: Session, patched_dependencies + ): + """Test that an invalid action raises ValueError.""" + factory = DocumentBatchUpdateIntegrationDataFactory + dataset = factory.create_dataset(db_session_with_containers) + doc = factory.create_document(db_session_with_containers, dataset) + user = UserDouble(id=str(uuid4())) + + patched_dependencies["redis_client"].get.return_value = None + + with pytest.raises(ValueError, match="Invalid action"): + DocumentService.batch_update_document_status( + dataset=dataset, document_ids=[doc.id], action="invalid_action", user=user + ) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py new file mode 100644 index 00000000000..c486ff56136 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py @@ -0,0 +1,60 @@ +"""Testcontainers integration tests for DatasetService.create_empty_rag_pipeline_dataset.""" + +from __future__ import annotations + +from unittest.mock import Mock, patch +from uuid import uuid4 + +import pytest + +from models.account import Account, Tenant, TenantAccountJoin +from services.dataset_service import DatasetService +from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity + + +class TestDatasetServiceCreateRagPipelineDataset: + def _create_tenant_and_account(self, db_session_with_containers) -> tuple[Tenant, Account]: + tenant = Tenant(name=f"Tenant {uuid4()}") + db_session_with_containers.add(tenant) + db_session_with_containers.flush() + + account = Account( + name=f"Account {uuid4()}", + email=f"ds_create_{uuid4()}@example.com", + password="hashed", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) + db_session_with_containers.add(account) + db_session_with_containers.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role="owner", + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + return tenant, account + + def _build_entity(self, name: str = "Test Dataset") -> RagPipelineDatasetCreateEntity: + icon_info = IconInfo(icon="\U0001f4d9", icon_background="#FFF4ED", icon_type="emoji") + return RagPipelineDatasetCreateEntity( + name=name, + description="", + icon_info=icon_info, + permission="only_me", + ) + + def test_create_rag_pipeline_dataset_raises_when_current_user_id_is_none(self, db_session_with_containers): + tenant, _ = self._create_tenant_and_account(db_session_with_containers) + + mock_user = Mock(id=None) + with patch("services.dataset_service.current_user", mock_user): + with pytest.raises(ValueError, match="Current user or current user id not found"): + DatasetService.create_empty_rag_pipeline_dataset( + tenant_id=tenant.id, + rag_pipeline_dataset_create_entity=self._build_entity(), + ) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py index ed070527c95..3cac964d89a 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py @@ -3,6 +3,7 @@ from unittest.mock import patch from uuid import uuid4 +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document from models.enums import DataSourceType, DocumentCreatedFrom @@ -78,7 +79,7 @@ class DatasetDeleteIntegrationDataFactory: tenant_id: str, dataset_id: str, created_by: str, - doc_form: str = "text_model", + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, ) -> Document: """Persist a document so dataset.doc_form resolves through the real document path.""" document = Document( @@ -108,7 +109,7 @@ class TestDatasetServiceDeleteDataset: db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, chunk_structure=None, index_struct='{"type": "paragraph"}', collection_binding_id=str(uuid4()), @@ -119,7 +120,7 @@ class TestDatasetServiceDeleteDataset: tenant_id=tenant.id, dataset_id=dataset.id, created_by=owner.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) # Act @@ -207,7 +208,7 @@ class TestDatasetServiceDeleteDataset: db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, chunk_structure=None, index_struct='{"type": "paragraph"}', collection_binding_id=str(uuid4()), diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py index c4b3a57bb2f..87239b2cb33 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py @@ -12,6 +12,7 @@ from uuid import uuid4 from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom @@ -64,7 +65,7 @@ class SegmentServiceTestDataFactory: name=f"Test Dataset {uuid4()}", description="Test description", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, permission=DatasetPermissionEnum.ONLY_ME, provider="vendor", diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py index 3021d8984d2..2f90d16176a 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py @@ -15,6 +15,7 @@ from uuid import uuid4 from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import ( AppDatasetJoin, @@ -102,7 +103,7 @@ class DatasetRetrievalTestDataFactory: name=name, description="desc", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, permission=permission, provider="vendor", diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py index fd819482471..a814466e14f 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py @@ -2,9 +2,10 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm import Session -from dify_graph.model_runtime.entities.model_entities import ModelType +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, ExternalKnowledgeBindings from models.enums import DataSourceType @@ -53,7 +54,7 @@ class DatasetUpdateTestDataFactory: provider: str = "vendor", name: str = "old_name", description: str = "old_description", - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, retrieval_model: str = "old_model", permission: str = "only_me", embedding_model_provider: str | None = None, @@ -241,7 +242,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, @@ -250,7 +251,7 @@ class TestDatasetServiceUpdateDataset: update_data = { "name": "new_name", "description": "new_description", - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "retrieval_model": "new_model", "embedding_model_provider": "openai", "embedding_model": "text-embedding-ada-002", @@ -261,7 +262,7 @@ class TestDatasetServiceUpdateDataset: assert dataset.name == "new_name" assert dataset.description == "new_description" - assert dataset.indexing_technique == "high_quality" + assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert dataset.retrieval_model == "new_model" assert dataset.embedding_model_provider == "openai" assert dataset.embedding_model == "text-embedding-ada-002" @@ -276,7 +277,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, @@ -285,7 +286,7 @@ class TestDatasetServiceUpdateDataset: update_data = { "name": "new_name", "description": None, - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "retrieval_model": "new_model", "embedding_model_provider": None, "embedding_model": None, @@ -312,14 +313,14 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, ) update_data = { - "indexing_technique": "economy", + "indexing_technique": IndexTechniqueType.ECONOMY, "retrieval_model": "new_model", } @@ -328,7 +329,7 @@ class TestDatasetServiceUpdateDataset: mock_task.delay.assert_called_once_with(dataset.id, "remove") db_session_with_containers.refresh(dataset) - assert dataset.indexing_technique == "economy" + assert dataset.indexing_technique == IndexTechniqueType.ECONOMY assert dataset.embedding_model is None assert dataset.embedding_model_provider is None assert dataset.collection_binding_id is None @@ -343,7 +344,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="economy", + indexing_technique=IndexTechniqueType.ECONOMY, ) embedding_model = Mock() @@ -354,7 +355,7 @@ class TestDatasetServiceUpdateDataset: binding.id = str(uuid4()) update_data = { - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "embedding_model_provider": "openai", "embedding_model": "text-embedding-ada-002", "retrieval_model": "new_model", @@ -362,7 +363,7 @@ class TestDatasetServiceUpdateDataset: with ( patch("services.dataset_service.current_user", user), - patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, patch( "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" ) as mock_get_binding, @@ -383,7 +384,7 @@ class TestDatasetServiceUpdateDataset: mock_task.delay.assert_called_once_with(dataset.id, "add") db_session_with_containers.refresh(dataset) - assert dataset.indexing_technique == "high_quality" + assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert dataset.embedding_model == "text-embedding-ada-002" assert dataset.embedding_model_provider == "openai" assert dataset.collection_binding_id == binding.id @@ -403,7 +404,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, @@ -411,7 +412,7 @@ class TestDatasetServiceUpdateDataset: update_data = { "name": "new_name", - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "retrieval_model": "new_model", } @@ -419,7 +420,7 @@ class TestDatasetServiceUpdateDataset: db_session_with_containers.refresh(dataset) assert dataset.name == "new_name" - assert dataset.indexing_technique == "high_quality" + assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert dataset.embedding_model_provider == "openai" assert dataset.embedding_model == "text-embedding-ada-002" assert dataset.collection_binding_id == existing_binding_id @@ -435,7 +436,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, @@ -449,7 +450,7 @@ class TestDatasetServiceUpdateDataset: binding.id = str(uuid4()) update_data = { - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "embedding_model_provider": "openai", "embedding_model": "text-embedding-3-small", "retrieval_model": "new_model", @@ -457,7 +458,7 @@ class TestDatasetServiceUpdateDataset: with ( patch("services.dataset_service.current_user", user), - patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, patch( "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" ) as mock_get_binding, @@ -531,11 +532,11 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="economy", + indexing_technique=IndexTechniqueType.ECONOMY, ) update_data = { - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "embedding_model_provider": "invalid_provider", "embedding_model": "invalid_model", "retrieval_model": "new_model", @@ -543,7 +544,7 @@ class TestDatasetServiceUpdateDataset: with ( patch("services.dataset_service.current_user", user), - patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, ): mock_model_manager.return_value.get_model_instance.side_effect = Exception("No Embedding Model available") diff --git a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py index 5f86cb2ae94..c8f04e92159 100644 --- a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py +++ b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py @@ -5,9 +5,9 @@ Testcontainers integration tests for archived workflow run deletion service. from datetime import UTC, datetime, timedelta from uuid import uuid4 +from graphon.enums import WorkflowExecutionStatus from sqlalchemy import select -from dify_graph.enums import WorkflowExecutionStatus from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowArchiveLog, WorkflowRun from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion @@ -141,3 +141,73 @@ class TestArchivedWorkflowRunDeletion: db_session_with_containers.expunge_all() deleted_run = db_session_with_containers.get(WorkflowRun, run_id) assert deleted_run is None + + def test_delete_run_dry_run(self, db_session_with_containers): + """Dry run should return success without actually deleting.""" + tenant_id = str(uuid4()) + run = self._create_workflow_run( + db_session_with_containers, + tenant_id=tenant_id, + created_at=datetime.now(UTC), + ) + run_id = run.id + deleter = ArchivedWorkflowRunDeletion(dry_run=True) + + result = deleter._delete_run(run) + + assert result.success is True + assert result.run_id == run_id + # Run should still exist because it's a dry run + db_session_with_containers.expire_all() + assert db_session_with_containers.get(WorkflowRun, run_id) is not None + + def test_delete_run_exception_returns_error(self, db_session_with_containers): + """Exception during deletion should return failure result.""" + from unittest.mock import MagicMock, patch + + tenant_id = str(uuid4()) + run = self._create_workflow_run( + db_session_with_containers, + tenant_id=tenant_id, + created_at=datetime.now(UTC), + ) + deleter = ArchivedWorkflowRunDeletion(dry_run=False) + + with patch.object(deleter, "_get_workflow_run_repo") as mock_get_repo: + mock_repo = MagicMock() + mock_get_repo.return_value = mock_repo + mock_repo.delete_runs_with_related.side_effect = Exception("Database error") + + result = deleter._delete_run(run) + + assert result.success is False + assert result.error == "Database error" + + def test_delete_by_run_id_success(self, db_session_with_containers): + """Successfully delete an archived workflow run by ID.""" + tenant_id = str(uuid4()) + base_time = datetime.now(UTC) + run = self._create_workflow_run( + db_session_with_containers, + tenant_id=tenant_id, + created_at=base_time, + ) + self._create_archive_log(db_session_with_containers, run=run) + run_id = run.id + + deleter = ArchivedWorkflowRunDeletion() + result = deleter.delete_by_run_id(run_id) + + assert result.success is True + db_session_with_containers.expunge_all() + assert db_session_with_containers.get(WorkflowRun, run_id) is None + + def test_get_workflow_run_repo_caches_instance(self, db_session_with_containers): + """_get_workflow_run_repo should return a cached repo on subsequent calls.""" + deleter = ArchivedWorkflowRunDeletion() + + repo1 = deleter._get_workflow_run_repo() + repo2 = deleter._get_workflow_run_repo() + + assert repo1 is repo2 + assert deleter.workflow_run_repo is repo1 diff --git a/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py b/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py index c6aa89c7339..c0047df8107 100644 --- a/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py +++ b/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py @@ -3,6 +3,7 @@ from uuid import uuid4 from sqlalchemy import select +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus from services.dataset_service import DocumentService @@ -42,7 +43,7 @@ def _create_document( name=f"doc-{uuid4()}", created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) document.id = str(uuid4()) document.indexing_status = indexing_status @@ -142,3 +143,11 @@ def test_apply_display_status_filter_returns_same_when_invalid(db_session_with_c rows = db_session_with_containers.scalars(filtered).all() assert {row.id for row in rows} == {doc1.id, doc2.id} + + +def test_normalize_display_status_alias_mapping(): + """Test that normalize_display_status maps aliases correctly.""" + assert DocumentService.normalize_display_status("ACTIVE") == "available" + assert DocumentService.normalize_display_status("enabled") == "available" + assert DocumentService.normalize_display_status("archived") == "archived" + assert DocumentService.normalize_display_status("unknown") is None diff --git a/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py b/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py index bffa520ce69..34532ed7f81 100644 --- a/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py +++ b/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py @@ -7,6 +7,7 @@ from uuid import uuid4 import pytest +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.storage.storage_type import StorageType from models import Account from models.dataset import Dataset, Document @@ -69,7 +70,7 @@ def make_document( name=name, created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) doc.id = document_id doc.indexing_status = "completed" diff --git a/api/tests/test_containers_integration_tests/services/test_end_user_service.py b/api/tests/test_containers_integration_tests/services/test_end_user_service.py index ae811db7689..cafabc939b2 100644 --- a/api/tests/test_containers_integration_tests/services/test_end_user_service.py +++ b/api/tests/test_containers_integration_tests/services/test_end_user_service.py @@ -414,3 +414,144 @@ class TestEndUserServiceGetEndUserById: ) assert result is None + + +class TestEndUserServiceCreateBatch: + """Integration tests for EndUserService.create_end_user_batch.""" + + @pytest.fixture + def factory(self): + return TestEndUserServiceFactory() + + def _create_multiple_apps(self, db_session_with_containers, factory, count: int = 3): + """Create multiple apps under the same tenant.""" + first_app = factory.create_app_and_account(db_session_with_containers) + tenant_id = first_app.tenant_id + apps = [first_app] + for _ in range(count - 1): + app = App( + tenant_id=tenant_id, + name=f"App {uuid4()}", + description="", + mode="chat", + icon_type="emoji", + icon="bot", + icon_background="#FFFFFF", + enable_site=False, + enable_api=True, + api_rpm=100, + api_rph=100, + is_demo=False, + is_public=False, + is_universal=False, + created_by=first_app.created_by, + updated_by=first_app.updated_by, + ) + db_session_with_containers.add(app) + db_session_with_containers.commit() + all_apps = db_session_with_containers.query(App).filter(App.tenant_id == tenant_id).all() + return tenant_id, all_apps + + def test_create_batch_empty_app_ids(self, db_session_with_containers): + result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, tenant_id=str(uuid4()), app_ids=[], user_id="user-1" + ) + assert result == {} + + def test_create_batch_creates_users_for_all_apps(self, db_session_with_containers, factory): + tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=3) + app_ids = [a.id for a in apps] + user_id = f"user-{uuid4()}" + + result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + assert len(result) == 3 + for app_id in app_ids: + assert app_id in result + assert result[app_id].session_id == user_id + assert result[app_id].type == InvokeFrom.SERVICE_API + + def test_create_batch_default_session_id(self, db_session_with_containers, factory): + tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2) + app_ids = [a.id for a in apps] + + result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id="" + ) + + assert len(result) == 2 + for end_user in result.values(): + assert end_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID + assert end_user._is_anonymous is True + + def test_create_batch_deduplicate_app_ids(self, db_session_with_containers, factory): + tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2) + app_ids = [apps[0].id, apps[1].id, apps[0].id, apps[1].id] + user_id = f"user-{uuid4()}" + + result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + assert len(result) == 2 + + def test_create_batch_returns_existing_users(self, db_session_with_containers, factory): + tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2) + app_ids = [a.id for a in apps] + user_id = f"user-{uuid4()}" + + # Create batch first time + first_result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + # Create batch second time — should return existing users + second_result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + assert len(second_result) == 2 + for app_id in app_ids: + assert first_result[app_id].id == second_result[app_id].id + + def test_create_batch_partial_existing_users(self, db_session_with_containers, factory): + tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=3) + user_id = f"user-{uuid4()}" + + # Create for first 2 apps + first_result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, + tenant_id=tenant_id, + app_ids=[apps[0].id, apps[1].id], + user_id=user_id, + ) + + # Create for all 3 apps — should reuse first 2, create 3rd + all_result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, + tenant_id=tenant_id, + app_ids=[a.id for a in apps], + user_id=user_id, + ) + + assert len(all_result) == 3 + assert all_result[apps[0].id].id == first_result[apps[0].id].id + assert all_result[apps[1].id].id == first_result[apps[1].id].id + assert all_result[apps[2].id].session_id == user_id + + @pytest.mark.parametrize( + "invoke_type", + [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP, InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER], + ) + def test_create_batch_all_invoke_types(self, db_session_with_containers, invoke_type, factory): + tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=1) + user_id = f"user-{uuid4()}" + + result = EndUserService.create_end_user_batch( + type=invoke_type, tenant_id=tenant_id, app_ids=[apps[0].id], user_id=user_id + ) + + assert len(result) == 1 + assert result[apps[0].id].type == invoke_type diff --git a/api/tests/test_containers_integration_tests/services/test_file_service_zip_and_lookup.py b/api/tests/test_containers_integration_tests/services/test_file_service_zip_and_lookup.py new file mode 100644 index 00000000000..4e0a726cc72 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_file_service_zip_and_lookup.py @@ -0,0 +1,96 @@ +""" +Testcontainers integration tests for FileService helpers. + +Covers: +- ZIP tempfile building (sanitization + deduplication + content writes) +- tenant-scoped batch lookup behavior (get_upload_files_by_ids) +""" + +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace +from typing import Any +from uuid import uuid4 +from zipfile import ZipFile + +import pytest + +import services.file_service as file_service_module +from extensions.storage.storage_type import StorageType +from models.enums import CreatorUserRole +from models.model import UploadFile +from services.file_service import FileService + + +def _create_upload_file(db_session, *, tenant_id: str, key: str, name: str) -> UploadFile: + upload_file = UploadFile( + tenant_id=tenant_id, + storage_type=StorageType.OPENDAL, + key=key, + name=name, + size=100, + extension="txt", + mime_type="text/plain", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=str(uuid4()), + created_at=datetime.now(UTC), + used=False, + ) + db_session.add(upload_file) + db_session.commit() + return upload_file + + +def test_build_upload_files_zip_tempfile_sanitizes_and_dedupes_names(monkeypatch: pytest.MonkeyPatch) -> None: + """Ensure ZIP entry names are safe and unique while preserving extensions.""" + upload_files: list[Any] = [ + SimpleNamespace(name="a/b.txt", key="k1"), + SimpleNamespace(name="c/b.txt", key="k2"), + SimpleNamespace(name="../b.txt", key="k3"), + ] + + data_by_key: dict[str, list[bytes]] = {"k1": [b"one"], "k2": [b"two"], "k3": [b"three"]} + + def _load(key: str, stream: bool = True) -> list[bytes]: + assert stream is True + return data_by_key[key] + + monkeypatch.setattr(file_service_module.storage, "load", _load) + + with FileService.build_upload_files_zip_tempfile(upload_files=upload_files) as tmp: + with ZipFile(tmp, mode="r") as zf: + assert zf.namelist() == ["b.txt", "b (1).txt", "b (2).txt"] + assert zf.read("b.txt") == b"one" + assert zf.read("b (1).txt") == b"two" + assert zf.read("b (2).txt") == b"three" + + +def test_get_upload_files_by_ids_returns_empty_when_no_ids(db_session_with_containers) -> None: + """Ensure empty input returns an empty mapping without hitting the database.""" + assert FileService.get_upload_files_by_ids(str(uuid4()), []) == {} + + +def test_get_upload_files_by_ids_returns_id_keyed_mapping(db_session_with_containers) -> None: + """Ensure batch lookup returns a dict keyed by stringified UploadFile ids.""" + tenant_id = str(uuid4()) + file1 = _create_upload_file(db_session_with_containers, tenant_id=tenant_id, key="k1", name="file1.txt") + file2 = _create_upload_file(db_session_with_containers, tenant_id=tenant_id, key="k2", name="file2.txt") + + result = FileService.get_upload_files_by_ids(tenant_id, [file1.id, file1.id, file2.id]) + + assert set(result.keys()) == {file1.id, file2.id} + assert result[file1.id].id == file1.id + assert result[file2.id].id == file2.id + + +def test_get_upload_files_by_ids_filters_by_tenant(db_session_with_containers) -> None: + """Ensure files from other tenants are not returned.""" + tenant_a = str(uuid4()) + tenant_b = str(uuid4()) + file_a = _create_upload_file(db_session_with_containers, tenant_id=tenant_a, key="ka", name="a.txt") + _create_upload_file(db_session_with_containers, tenant_id=tenant_b, key="kb", name="b.txt") + + result = FileService.get_upload_files_by_ids(tenant_a, [file_a.id]) + + assert set(result.keys()) == {file_a.id} diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py index 70d05792cee..c46b8fba0bd 100644 --- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py @@ -3,14 +3,14 @@ import uuid from unittest.mock import MagicMock import pytest +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.human_input.entities import HumanInputNodeData -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - HumanInputNodeData, ) from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.model import App, AppMode @@ -54,7 +54,7 @@ def _create_app_with_draft_workflow(session, *, delivery_method_id: uuid.UUID) - enabled=True, config=EmailDeliveryConfig( recipients=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ExternalRecipient(email="recipient@example.com")], ), subject="Test {{recipient_email}}", diff --git a/api/tests/unit_tests/services/test_human_input_delivery_test_service.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py similarity index 79% rename from api/tests/unit_tests/services/test_human_input_delivery_test_service.py rename to api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py index a23c44b26ea..0f252515f72 100644 --- a/api/tests/unit_tests/services/test_human_input_delivery_test_service.py +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py @@ -1,18 +1,22 @@ +from __future__ import annotations + from types import SimpleNamespace from unittest.mock import MagicMock, patch +from uuid import uuid4 import pytest +from graphon.runtime import VariablePool from sqlalchemy.engine import Engine from configs import dify_config -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, MemberRecipient, ) -from dify_graph.runtime import VariablePool +from models.account import Account, TenantAccountJoin from services import human_input_delivery_test_service as service_module from services.human_input_delivery_test_service import ( DeliveryTestContext, @@ -28,15 +32,12 @@ from services.human_input_delivery_test_service import ( ) -@pytest.fixture -def mock_db(monkeypatch): - mock_db = MagicMock() - monkeypatch.setattr(service_module, "db", mock_db) - return mock_db - - def _make_valid_email_config(): - return EmailDeliveryConfig(recipients=EmailRecipients(whole_workspace=False, items=[]), subject="Subj", body="Body") + return EmailDeliveryConfig( + recipients=EmailRecipients(include_bound_group=False, items=[]), + subject="Subj", + body="Body", + ) def test_build_form_link(): @@ -87,7 +88,7 @@ class TestDeliveryTestRegistry: with pytest.raises(DeliveryTestUnsupportedError, match="Delivery method does not support test send."): registry.dispatch(context=context, method=method) - def test_default(self, mock_db): + def test_default(self, flask_app_with_containers, db_session_with_containers): registry = DeliveryTestRegistry.default() assert len(registry._handlers) == 1 assert isinstance(registry._handlers[0], EmailDeliveryTestHandler) @@ -246,62 +247,70 @@ class TestEmailDeliveryTestHandler: _, kwargs = mock_mail_send.call_args assert kwargs["subject"] == "Notice BCC:test@example.com" - def test_resolve_recipients(self): + def test_resolve_recipients_external(self): handler = EmailDeliveryTestHandler(session_factory=MagicMock()) - - # Test Case 1: External Recipient method = EmailDeliveryMethod( config=EmailDeliveryConfig( - recipients=EmailRecipients(items=[ExternalRecipient(email="ext@example.com")], whole_workspace=False), + recipients=EmailRecipients( + items=[ExternalRecipient(email="ext@example.com")], include_bound_group=False + ), subject="", body="", ) ) assert handler._resolve_recipients(tenant_id="t1", method=method) == ["ext@example.com"] - # Test Case 2: Member Recipient + def test_resolve_recipients_member(self, flask_app_with_containers, db_session_with_containers): + tenant_id = str(uuid4()) + account = Account(name="Test User", email="member@example.com") + db_session_with_containers.add(account) + db_session_with_containers.commit() + + join = TenantAccountJoin(tenant_id=tenant_id, account_id=account.id) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + from extensions.ext_database import db + + handler = EmailDeliveryTestHandler(session_factory=db.engine) method = EmailDeliveryMethod( config=EmailDeliveryConfig( - recipients=EmailRecipients(items=[MemberRecipient(user_id="u1")], whole_workspace=False), + recipients=EmailRecipients(items=[MemberRecipient(reference_id=account.id)], include_bound_group=False), subject="", body="", ) ) - handler._query_workspace_member_emails = MagicMock(return_value={"u1": "u1@example.com"}) - assert handler._resolve_recipients(tenant_id="t1", method=method) == ["u1@example.com"] + assert handler._resolve_recipients(tenant_id=tenant_id, method=method) == ["member@example.com"] - # Test Case 3: Whole Workspace + def test_resolve_recipients_whole_workspace(self, flask_app_with_containers, db_session_with_containers): + tenant_id = str(uuid4()) + account1 = Account(name="User 1", email=f"u1-{uuid4()}@example.com") + account2 = Account(name="User 2", email=f"u2-{uuid4()}@example.com") + db_session_with_containers.add_all([account1, account2]) + db_session_with_containers.commit() + + for acc in [account1, account2]: + join = TenantAccountJoin(tenant_id=tenant_id, account_id=acc.id) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + from extensions.ext_database import db + + handler = EmailDeliveryTestHandler(session_factory=db.engine) method = EmailDeliveryMethod( - config=EmailDeliveryConfig(recipients=EmailRecipients(items=[], whole_workspace=True), subject="", body="") + config=EmailDeliveryConfig( + recipients=EmailRecipients(items=[], include_bound_group=True), + subject="", + body="", + ) ) - handler._query_workspace_member_emails = MagicMock( - return_value={"u1": "u1@example.com", "u2": "u2@example.com"} - ) - recipients = handler._resolve_recipients(tenant_id="t1", method=method) - assert set(recipients) == {"u1@example.com", "u2@example.com"} + recipients = handler._resolve_recipients(tenant_id=tenant_id, method=method) + assert set(recipients) == {account1.email, account2.email} - def test_query_workspace_member_emails(self): - mock_session = MagicMock() - mock_session_factory = MagicMock(return_value=mock_session) - mock_session.__enter__.return_value = mock_session - - handler = EmailDeliveryTestHandler(session_factory=mock_session_factory) - - # Empty user_ids + def test_query_workspace_member_emails_empty_ids(self): + handler = EmailDeliveryTestHandler(session_factory=MagicMock()) assert handler._query_workspace_member_emails(tenant_id="t1", user_ids=[]) == {} - # user_ids is None (all) - mock_execute = MagicMock() - mock_session.execute.return_value = mock_execute - mock_execute.all.return_value = [("u1", "u1@example.com")] - - result = handler._query_workspace_member_emails(tenant_id="t1", user_ids=None) - assert result == {"u1": "u1@example.com"} - - # user_ids with values - result = handler._query_workspace_member_emails(tenant_id="t1", user_ids=["u1"]) - assert result == {"u1": "u1@example.com"} - def test_build_substitutions(self): context = DeliveryTestContext( tenant_id="t1", @@ -313,7 +322,8 @@ class TestEmailDeliveryTestHandler: recipients=[DeliveryTestEmailRecipient(email="test@example.com", form_token="token123")], ) - subs = EmailDeliveryTestHandler._build_substitutions(context=context, recipient_email="test@example.com") + with patch.object(dify_config, "APP_WEB_URL", "http://example.com"): + subs = EmailDeliveryTestHandler._build_substitutions(context=context, recipient_email="test@example.com") assert subs["node_title"] == "title" assert subs["form_content"] == "content" @@ -322,7 +332,6 @@ class TestEmailDeliveryTestHandler: assert subs["form_token"] == "token123" assert "form/token123" in subs["form_link"] - # Without matching recipient subs_no_match = EmailDeliveryTestHandler._build_substitutions( context=context, recipient_email="other@example.com" ) diff --git a/api/tests/test_containers_integration_tests/services/test_message_export_service.py b/api/tests/test_containers_integration_tests/services/test_message_export_service.py index 805bab9b9d5..00dfe9dda41 100644 --- a/api/tests/test_containers_integration_tests/services/test_message_export_service.py +++ b/api/tests/test_containers_integration_tests/services/test_message_export_service.py @@ -7,7 +7,7 @@ import pytest from sqlalchemy.orm import Session from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole -from models.enums import FeedbackFromSource, FeedbackRating +from models.enums import ConversationFromSource, FeedbackFromSource, FeedbackRating from models.model import ( App, AppAnnotationHitHistory, @@ -94,7 +94,7 @@ class TestAppMessageExportServiceIntegration: name="conv", inputs={"seed": 1}, status="normal", - from_source="api", + from_source=ConversationFromSource.API, from_end_user_id=str(uuid.uuid4()), ) session.add(conversation) @@ -129,7 +129,7 @@ class TestAppMessageExportServiceIntegration: total_price=Decimal("0.003"), currency="USD", message_metadata=message_metadata, - from_source="api", + from_source=ConversationFromSource.API, from_end_user_id=conversation.from_end_user_id, created_at=created_at, ) diff --git a/api/tests/test_containers_integration_tests/services/test_message_service.py b/api/tests/test_containers_integration_tests/services/test_message_service.py index af666a03752..bdf6d9b9517 100644 --- a/api/tests/test_containers_integration_tests/services/test_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_message_service.py @@ -4,7 +4,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session -from models.enums import FeedbackRating +from models.enums import ConversationFromSource, FeedbackRating, InvokeFrom from models.model import MessageFeedback from services.app_service import AppService from services.errors.message import ( @@ -25,7 +25,7 @@ class TestMessageService: """Mock setup for external service dependencies.""" with ( patch("services.account_service.FeatureService") as mock_account_feature_service, - patch("services.message_service.ModelManager") as mock_model_manager, + patch("services.message_service.ModelManager.for_tenant") as mock_model_manager, patch("services.message_service.WorkflowService") as mock_workflow_service, patch("services.message_service.AdvancedChatAppConfigManager") as mock_app_config_manager, patch("services.message_service.LLMGenerator") as mock_llm_generator, @@ -149,8 +149,8 @@ class TestMessageService: system_instruction="", system_instruction_tokens=0, status="normal", - invoke_from="console", - from_source="console", + invoke_from=InvokeFrom.EXPLORE, + from_source=ConversationFromSource.CONSOLE, from_end_user_id=None, from_account_id=account.id, ) @@ -187,8 +187,8 @@ class TestMessageService: provider_response_latency=0, total_price=0, currency="USD", - invoke_from="console", - from_source="console", + invoke_from=InvokeFrom.EXPLORE, + from_source=ConversationFromSource.CONSOLE, from_end_user_id=None, from_account_id=account.id, ) diff --git a/api/tests/test_containers_integration_tests/services/test_message_service_extra_contents.py b/api/tests/test_containers_integration_tests/services/test_message_service_extra_contents.py index 772365ba540..f2cb667204f 100644 --- a/api/tests/test_containers_integration_tests/services/test_message_service_extra_contents.py +++ b/api/tests/test_containers_integration_tests/services/test_message_service_extra_contents.py @@ -4,6 +4,7 @@ from decimal import Decimal import pytest +from models.enums import ConversationFromSource from models.model import Message from services import message_service from tests.test_containers_integration_tests.helpers.execution_extra_content import ( @@ -36,7 +37,7 @@ def test_attach_message_extra_contents_assigns_serialized_payload(db_session_wit total_price=Decimal(0), currency="USD", status="normal", - from_source="console", + from_source=ConversationFromSource.CONSOLE, from_account_id=fixture.account.id, ) db_session_with_containers.add(message_without_extra_content) diff --git a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py index 863f013e195..2340dd2a03d 100644 --- a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py +++ b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py @@ -1,17 +1,27 @@ +from __future__ import annotations + import datetime import json import uuid from decimal import Decimal -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest from faker import Faker +from graphon.file import FileType from sqlalchemy.orm import Session from enums.cloud_plan import CloudPlan from extensions.ext_redis import redis_client from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole -from models.enums import DataSourceType, FeedbackFromSource, FeedbackRating, MessageChainType, MessageFileBelongsTo +from models.enums import ( + ConversationFromSource, + DataSourceType, + FeedbackFromSource, + FeedbackRating, + MessageChainType, + MessageFileBelongsTo, +) from models.model import ( App, AppAnnotationHitHistory, @@ -166,7 +176,7 @@ class TestMessagesCleanServiceIntegration: name="Test conversation", inputs={}, status="normal", - from_source=FeedbackFromSource.USER, + from_source=ConversationFromSource.API, from_end_user_id=str(uuid.uuid4()), ) db_session_with_containers.add(conversation) @@ -196,7 +206,7 @@ class TestMessagesCleanServiceIntegration: answer_unit_price=Decimal("0.002"), total_price=Decimal("0.003"), currency="USD", - from_source=FeedbackFromSource.USER, + from_source=ConversationFromSource.API, from_account_id=conversation.from_end_user_id, created_at=created_at, ) @@ -246,7 +256,7 @@ class TestMessagesCleanServiceIntegration: # MessageFile file = MessageFile( message_id=message.id, - type="image", + type=FileType.IMAGE, transfer_method="local_file", url="http://example.com/test.jpg", belongs_to=MessageFileBelongsTo.USER, @@ -1161,3 +1171,66 @@ class TestMessagesCleanServiceIntegration: # Verify all messages were deleted assert db_session_with_containers.query(Message).where(Message.id.in_(msg_ids)).count() == 0 + + def test_from_time_range_validation(self): + """Test that from_time_range raises ValueError for invalid inputs.""" + policy = MagicMock(spec=BillingDisabledPolicy) + now = datetime.datetime.now() + + with pytest.raises(ValueError, match="start_from .* must be less than end_before"): + MessagesCleanService.from_time_range(policy, now, now) + + with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): + MessagesCleanService.from_time_range(policy, now - datetime.timedelta(days=1), now, batch_size=0) + + def test_from_time_range_success(self): + """Test that from_time_range creates a service with correct parameters.""" + policy = MagicMock(spec=BillingDisabledPolicy) + start = datetime.datetime(2024, 1, 1) + end = datetime.datetime(2024, 2, 1) + + service = MessagesCleanService.from_time_range(policy, start, end) + assert service._start_from == start + assert service._end_before == end + + def test_from_days_validation(self): + """Test that from_days raises ValueError for invalid inputs.""" + policy = MagicMock(spec=BillingDisabledPolicy) + + with pytest.raises(ValueError, match="days .* must be greater than or equal to 0"): + MessagesCleanService.from_days(policy, days=-1) + + with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): + MessagesCleanService.from_days(policy, days=30, batch_size=0) + + def test_from_days_success(self): + """Test that from_days creates a service with correct parameters.""" + policy = MagicMock(spec=BillingDisabledPolicy) + + with patch("services.retention.conversation.messages_clean_service.naive_utc_now") as mock_now: + fixed_now = datetime.datetime(2024, 6, 1) + mock_now.return_value = fixed_now + + service = MessagesCleanService.from_days(policy, days=10) + assert service._start_from is None + assert service._end_before == fixed_now - datetime.timedelta(days=10) + + def test_batch_delete_message_relations_empty(self, db_session_with_containers: Session): + """Test that batch_delete_message_relations with empty list does nothing.""" + # Get execute call count before + MessagesCleanService._batch_delete_message_relations(db_session_with_containers, []) + # No exception means success — empty list is a no-op + + def test_run_calls_clean_messages(self): + """Test that run() delegates to _clean_messages_by_time_range.""" + policy = MagicMock(spec=BillingDisabledPolicy) + service = MessagesCleanService( + policy=policy, + end_before=datetime.datetime.now(), + batch_size=10, + ) + with patch.object(service, "_clean_messages_by_time_range") as mock_clean: + mock_clean.return_value = {"total_deleted": 5} + result = service.run() + assert result == {"total_deleted": 5} + mock_clean.assert_called_once() diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py b/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py new file mode 100644 index 00000000000..b55a19eaa9e --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py @@ -0,0 +1,182 @@ +from __future__ import annotations + +from unittest.mock import Mock, patch +from uuid import uuid4 + +import pytest +from sqlalchemy import select + +from models.dataset import Dataset, DatasetMetadataBinding, Document +from models.enums import DataSourceType, DocumentCreatedFrom +from services.entities.knowledge_entities.knowledge_entities import ( + DocumentMetadataOperation, + MetadataDetail, + MetadataOperationData, +) +from services.metadata_service import MetadataService + + +def _create_dataset(db_session, *, tenant_id: str, built_in_field_enabled: bool = False) -> Dataset: + dataset = Dataset( + tenant_id=tenant_id, + name=f"dataset-{uuid4()}", + data_source_type=DataSourceType.UPLOAD_FILE, + created_by=str(uuid4()), + ) + dataset.id = str(uuid4()) + dataset.built_in_field_enabled = built_in_field_enabled + db_session.add(dataset) + db_session.commit() + return dataset + + +def _create_document(db_session, *, dataset_id: str, tenant_id: str, doc_metadata: dict | None = None) -> Document: + document = Document( + tenant_id=tenant_id, + dataset_id=dataset_id, + position=1, + data_source_type=DataSourceType.UPLOAD_FILE, + data_source_info="{}", + batch=f"batch-{uuid4()}", + name=f"doc-{uuid4()}", + created_from=DocumentCreatedFrom.WEB, + created_by=str(uuid4()), + ) + document.id = str(uuid4()) + document.doc_metadata = doc_metadata + db_session.add(document) + db_session.commit() + return document + + +class TestMetadataPartialUpdate: + @pytest.fixture + def tenant_id(self) -> str: + return str(uuid4()) + + @pytest.fixture + def user_id(self) -> str: + return str(uuid4()) + + @pytest.fixture + def mock_current_account(self, user_id, tenant_id): + account = Mock(id=user_id, current_tenant_id=tenant_id) + with patch("services.metadata_service.current_account_with_tenant", return_value=(account, tenant_id)): + yield account + + def test_partial_update_merges_metadata( + self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account + ): + dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) + document = _create_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant_id, + doc_metadata={"existing_key": "existing_value"}, + ) + + meta_id = str(uuid4()) + operation = DocumentMetadataOperation( + document_id=document.id, + metadata_list=[MetadataDetail(id=meta_id, name="new_key", value="new_value")], + partial_update=True, + ) + metadata_args = MetadataOperationData(operation_data=[operation]) + + MetadataService.update_documents_metadata(dataset, metadata_args) + db_session_with_containers.expire_all() + + updated_doc = db_session_with_containers.get(Document, document.id) + assert updated_doc is not None + assert updated_doc.doc_metadata["existing_key"] == "existing_value" + assert updated_doc.doc_metadata["new_key"] == "new_value" + + def test_full_update_replaces_metadata( + self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account + ): + dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) + document = _create_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant_id, + doc_metadata={"existing_key": "existing_value"}, + ) + + meta_id = str(uuid4()) + operation = DocumentMetadataOperation( + document_id=document.id, + metadata_list=[MetadataDetail(id=meta_id, name="new_key", value="new_value")], + partial_update=False, + ) + metadata_args = MetadataOperationData(operation_data=[operation]) + + MetadataService.update_documents_metadata(dataset, metadata_args) + db_session_with_containers.expire_all() + + updated_doc = db_session_with_containers.get(Document, document.id) + assert updated_doc is not None + assert updated_doc.doc_metadata == {"new_key": "new_value"} + assert "existing_key" not in updated_doc.doc_metadata + + def test_partial_update_skips_existing_binding( + self, flask_app_with_containers, db_session_with_containers, tenant_id, user_id, mock_current_account + ): + dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) + document = _create_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant_id, + doc_metadata={"existing_key": "existing_value"}, + ) + + meta_id = str(uuid4()) + existing_binding = DatasetMetadataBinding( + tenant_id=tenant_id, + dataset_id=dataset.id, + document_id=document.id, + metadata_id=meta_id, + created_by=user_id, + ) + db_session_with_containers.add(existing_binding) + db_session_with_containers.commit() + + operation = DocumentMetadataOperation( + document_id=document.id, + metadata_list=[MetadataDetail(id=meta_id, name="existing_key", value="existing_value")], + partial_update=True, + ) + metadata_args = MetadataOperationData(operation_data=[operation]) + + MetadataService.update_documents_metadata(dataset, metadata_args) + db_session_with_containers.expire_all() + + bindings = db_session_with_containers.scalars( + select(DatasetMetadataBinding).where( + DatasetMetadataBinding.document_id == document.id, + DatasetMetadataBinding.metadata_id == meta_id, + ) + ).all() + assert len(bindings) == 1 + + def test_rollback_called_on_commit_failure( + self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account + ): + dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) + document = _create_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant_id, + doc_metadata={"existing_key": "existing_value"}, + ) + + meta_id = str(uuid4()) + operation = DocumentMetadataOperation( + document_id=document.id, + metadata_list=[MetadataDetail(id=meta_id, name="key", value="value")], + partial_update=True, + ) + metadata_args = MetadataOperationData(operation_data=[operation]) + + with patch("services.metadata_service.db.session.commit", side_effect=RuntimeError("database connection lost")): + with pytest.raises(RuntimeError, match="database connection lost"): + MetadataService.update_documents_metadata(dataset, metadata_args) diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_service.py b/api/tests/test_containers_integration_tests/services/test_metadata_service.py index e847329c5b0..8b1349be9a8 100644 --- a/api/tests/test_containers_integration_tests/services/test_metadata_service.py +++ b/api/tests/test_containers_integration_tests/services/test_metadata_service.py @@ -5,6 +5,7 @@ from faker import Faker from sqlalchemy.orm import Session from core.rag.index_processor.constant.built_in_field import BuiltInField +from core.rag.index_processor.constant.index_type import IndexStructureType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding, Document from models.enums import DatasetMetadataType, DataSourceType, DocumentCreatedFrom @@ -139,7 +140,7 @@ class TestMetadataService: name=fake.file_name(), created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", ) diff --git a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py index 989df424991..ca6e7afeabc 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py @@ -18,11 +18,10 @@ class TestModelLoadBalancingService: def mock_external_service_dependencies(self): """Mock setup for external service dependencies.""" with ( - patch("services.model_load_balancing_service.ProviderManager", autospec=True) as mock_provider_manager, - patch("services.model_load_balancing_service.LBModelManager", autospec=True) as mock_lb_model_manager, patch( - "services.model_load_balancing_service.ModelProviderFactory", autospec=True - ) as mock_model_provider_factory, + "services.model_load_balancing_service.create_plugin_provider_manager", autospec=True + ) as mock_provider_manager, + patch("services.model_load_balancing_service.LBModelManager", autospec=True) as mock_lb_model_manager, patch("services.model_load_balancing_service.encrypter", autospec=True) as mock_encrypter, ): # Setup default mock returns @@ -46,9 +45,6 @@ class TestModelLoadBalancingService: # Mock LBModelManager mock_lb_model_manager.get_config_in_cooldown_and_ttl.return_value = (False, 0) - # Mock ModelProviderFactory - mock_model_provider_factory_instance = mock_model_provider_factory.return_value - # Mock credential schemas mock_credential_schema = MagicMock() mock_credential_schema.credential_form_schemas = [] @@ -61,7 +57,6 @@ class TestModelLoadBalancingService: yield { "provider_manager": mock_provider_manager, "lb_model_manager": mock_lb_model_manager, - "model_provider_factory": mock_model_provider_factory, "encrypter": mock_encrypter, "provider_config": mock_provider_config, "provider_model_setting": mock_provider_model_setting, diff --git a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py index 6afc5aa43c7..ba926bf6758 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py @@ -2,10 +2,10 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType from sqlalchemy.orm import Session from core.entities.model_entities import ModelStatus -from dify_graph.model_runtime.entities.model_entities import FetchFrom, ModelType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.provider import Provider, ProviderModel, ProviderModelSetting, ProviderType from services.model_provider_service import ModelProviderService @@ -18,8 +18,12 @@ class TestModelProviderService: def mock_external_service_dependencies(self): """Mock setup for external service dependencies.""" with ( - patch("services.model_provider_service.ProviderManager", autospec=True) as mock_provider_manager, - patch("services.model_provider_service.ModelProviderFactory", autospec=True) as mock_model_provider_factory, + patch( + "services.model_provider_service.create_plugin_provider_manager", autospec=True + ) as mock_provider_manager, + patch( + "services.model_provider_service.create_plugin_model_provider_factory", autospec=True + ) as mock_model_provider_factory, ): # Setup default mock returns mock_provider_manager.return_value.get_configurations.return_value = MagicMock() @@ -401,9 +405,10 @@ class TestModelProviderService: mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value # Create mock models + from graphon.model_runtime.entities.common_entities import I18nObject + from graphon.model_runtime.entities.provider_entities import ProviderEntity + from core.entities.model_entities import ModelWithProviderEntity, SimpleModelProviderEntity - from dify_graph.model_runtime.entities.common_entities import I18nObject - from dify_graph.model_runtime.entities.provider_entities import ProviderEntity # Create real model objects instead of mocks provider_entity_1 = SimpleModelProviderEntity( @@ -639,8 +644,9 @@ class TestModelProviderService: mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value # Create mock default model response + from graphon.model_runtime.entities.common_entities import I18nObject + from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity - from dify_graph.model_runtime.entities.common_entities import I18nObject mock_default_model = DefaultModelEntity( model="gpt-3.5-turbo", diff --git a/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py b/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py new file mode 100644 index 00000000000..c146a5924bc --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py @@ -0,0 +1,174 @@ +"""Testcontainers integration tests for OAuthServerService.""" + +from __future__ import annotations + +import uuid +from typing import cast +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from werkzeug.exceptions import BadRequest + +from models.model import OAuthProviderApp +from services.oauth_server import ( + OAUTH_ACCESS_TOKEN_EXPIRES_IN, + OAUTH_ACCESS_TOKEN_REDIS_KEY, + OAUTH_AUTHORIZATION_CODE_REDIS_KEY, + OAUTH_REFRESH_TOKEN_EXPIRES_IN, + OAUTH_REFRESH_TOKEN_REDIS_KEY, + OAuthGrantType, + OAuthServerService, +) + + +class TestOAuthServerServiceGetProviderApp: + """DB-backed tests for get_oauth_provider_app.""" + + def _create_oauth_provider_app(self, db_session_with_containers, *, client_id: str) -> OAuthProviderApp: + app = OAuthProviderApp( + app_icon="icon.png", + client_id=client_id, + client_secret=str(uuid4()), + app_label={"en-US": "Test OAuth App"}, + redirect_uris=["https://example.com/callback"], + scope="read", + ) + db_session_with_containers.add(app) + db_session_with_containers.commit() + return app + + def test_get_oauth_provider_app_returns_app_when_exists(self, db_session_with_containers): + client_id = f"client-{uuid4()}" + created = self._create_oauth_provider_app(db_session_with_containers, client_id=client_id) + + result = OAuthServerService.get_oauth_provider_app(client_id) + + assert result is not None + assert result.client_id == client_id + assert result.id == created.id + + def test_get_oauth_provider_app_returns_none_when_not_exists(self, db_session_with_containers): + result = OAuthServerService.get_oauth_provider_app(f"nonexistent-{uuid4()}") + + assert result is None + + +class TestOAuthServerServiceTokenOperations: + """Redis-backed tests for token sign/validate operations.""" + + @pytest.fixture + def mock_redis(self): + with patch("services.oauth_server.redis_client") as mock: + yield mock + + def test_sign_authorization_code_stores_and_returns_code(self, mock_redis): + deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000111") + with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid): + code = OAuthServerService.sign_oauth_authorization_code("client-1", "user-1") + + assert code == str(deterministic_uuid) + mock_redis.set.assert_called_once_with( + OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code=code), + "user-1", + ex=600, + ) + + def test_sign_access_token_raises_bad_request_for_invalid_code(self, mock_redis): + mock_redis.get.return_value = None + + with pytest.raises(BadRequest, match="invalid code"): + OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.AUTHORIZATION_CODE, + code="bad-code", + client_id="client-1", + ) + + def test_sign_access_token_issues_tokens_for_valid_code(self, mock_redis): + token_uuids = [ + uuid.UUID("00000000-0000-0000-0000-000000000201"), + uuid.UUID("00000000-0000-0000-0000-000000000202"), + ] + with patch("services.oauth_server.uuid.uuid4", side_effect=token_uuids): + mock_redis.get.return_value = b"user-1" + + access_token, refresh_token = OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.AUTHORIZATION_CODE, + code="code-1", + client_id="client-1", + ) + + assert access_token == str(token_uuids[0]) + assert refresh_token == str(token_uuids[1]) + code_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code="code-1") + mock_redis.delete.assert_called_once_with(code_key) + mock_redis.set.assert_any_call( + OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token), + b"user-1", + ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN, + ) + mock_redis.set.assert_any_call( + OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-1", token=refresh_token), + b"user-1", + ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN, + ) + + def test_sign_access_token_raises_bad_request_for_invalid_refresh_token(self, mock_redis): + mock_redis.get.return_value = None + + with pytest.raises(BadRequest, match="invalid refresh token"): + OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.REFRESH_TOKEN, + refresh_token="stale-token", + client_id="client-1", + ) + + def test_sign_access_token_issues_new_token_for_valid_refresh(self, mock_redis): + deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000301") + with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid): + mock_redis.get.return_value = b"user-1" + + access_token, returned_refresh = OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.REFRESH_TOKEN, + refresh_token="refresh-1", + client_id="client-1", + ) + + assert access_token == str(deterministic_uuid) + assert returned_refresh == "refresh-1" + + def test_sign_access_token_returns_none_for_unknown_grant_type(self, mock_redis): + grant_type = cast(OAuthGrantType, "invalid-grant-type") + + result = OAuthServerService.sign_oauth_access_token(grant_type=grant_type, client_id="client-1") + + assert result is None + + def test_sign_refresh_token_stores_with_expected_expiry(self, mock_redis): + deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000401") + with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid): + refresh_token = OAuthServerService._sign_oauth_refresh_token("client-2", "user-2") + + assert refresh_token == str(deterministic_uuid) + mock_redis.set.assert_called_once_with( + OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-2", token=refresh_token), + "user-2", + ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN, + ) + + def test_validate_access_token_returns_none_when_not_found(self, mock_redis): + mock_redis.get.return_value = None + + result = OAuthServerService.validate_oauth_access_token("client-1", "missing-token") + + assert result is None + + def test_validate_access_token_loads_user_when_exists(self, mock_redis): + mock_redis.get.return_value = b"user-88" + expected_user = MagicMock() + + with patch("services.oauth_server.AccountService.load_user", return_value=expected_user) as mock_load: + result = OAuthServerService.validate_oauth_access_token("client-1", "access-token") + + assert result is expected_user + mock_load.assert_called_once_with("user-88") diff --git a/api/tests/test_containers_integration_tests/services/test_restore_archived_workflow_run.py b/api/tests/test_containers_integration_tests/services/test_restore_archived_workflow_run.py index ba4310e22e5..7036524918b 100644 --- a/api/tests/test_containers_integration_tests/services/test_restore_archived_workflow_run.py +++ b/api/tests/test_containers_integration_tests/services/test_restore_archived_workflow_run.py @@ -2,17 +2,43 @@ Testcontainers integration tests for workflow run restore functionality. """ +from __future__ import annotations + +from datetime import datetime from uuid import uuid4 from sqlalchemy import select -from models.workflow import WorkflowPause +from models.workflow import WorkflowPause, WorkflowRun from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore class TestWorkflowRunRestore: """Tests for the WorkflowRunRestore class.""" + def test_restore_initialization(self): + """Restore service should respect dry_run flag.""" + restore = WorkflowRunRestore(dry_run=True) + + assert restore.dry_run is True + + def test_convert_datetime_fields(self): + """ISO datetime strings should be converted to datetime objects.""" + record = { + "id": "test-id", + "created_at": "2024-01-01T12:00:00", + "finished_at": "2024-01-01T12:05:00", + "name": "test", + } + + restore = WorkflowRunRestore() + result = restore._convert_datetime_fields(record, WorkflowRun) + + assert isinstance(result["created_at"], datetime) + assert result["created_at"].year == 2024 + assert result["created_at"].month == 1 + assert result["name"] == "test" + def test_restore_table_records_returns_rowcount(self, db_session_with_containers): """Restore should return inserted rowcount.""" restore = WorkflowRunRestore() diff --git a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py index dd743d46c2f..70aa813142c 100644 --- a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py @@ -4,6 +4,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from models.enums import ConversationFromSource from models.model import EndUser, Message from models.web import SavedMessage from services.app_service import AppService @@ -19,7 +20,7 @@ class TestSavedMessageService: """Mock setup for external service dependencies.""" with ( patch("services.account_service.FeatureService") as mock_account_feature_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.saved_message_service.MessageService") as mock_message_service, ): # Setup default mock returns @@ -132,11 +133,14 @@ class TestSavedMessageService: # Create a simple conversation first from models.model import Conversation + is_account = hasattr(user, "current_tenant") + from_source = ConversationFromSource.CONSOLE if is_account else ConversationFromSource.API + conversation = Conversation( app_id=app.id, - from_source="account" if hasattr(user, "current_tenant") else "end_user", - from_end_user_id=user.id if not hasattr(user, "current_tenant") else None, - from_account_id=user.id if hasattr(user, "current_tenant") else None, + from_source=from_source, + from_end_user_id=user.id if not is_account else None, + from_account_id=user.id if is_account else None, name=fake.sentence(nb_words=3), inputs={}, status="normal", @@ -150,9 +154,9 @@ class TestSavedMessageService: message = Message( app_id=app.id, conversation_id=conversation.id, - from_source="account" if hasattr(user, "current_tenant") else "end_user", - from_end_user_id=user.id if not hasattr(user, "current_tenant") else None, - from_account_id=user.id if hasattr(user, "current_tenant") else None, + from_source=from_source, + from_end_user_id=user.id if not is_account else None, + from_account_id=user.id if is_account else None, inputs={}, query=fake.sentence(nb_words=5), message=fake.text(max_nb_chars=100), @@ -392,11 +396,6 @@ class TestSavedMessageService: assert "User is required" in str(exc_info.value) - # Verify no database operations were performed - - saved_messages = db_session_with_containers.query(SavedMessage).all() - assert len(saved_messages) == 0 - def test_save_error_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test error handling when saving message with no user. @@ -493,124 +492,140 @@ class TestSavedMessageService: # The message should still exist, only the saved_message should be deleted assert db_session_with_containers.query(Message).where(Message.id == message.id).first() is not None - def test_pagination_by_last_id_error_no_user( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): - """ - Test error handling when no user is provided. - - This test verifies: - - Proper error handling for missing user - - ValueError is raised when user is None - - No database operations are performed - """ - # Arrange: Create test data - fake = Faker() + def test_save_for_end_user(self, db_session_with_containers: Session, mock_external_service_dependencies): + """Test saving a message for an EndUser.""" app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + end_user = self._create_test_end_user(db_session_with_containers, app) + message = self._create_test_message(db_session_with_containers, app, end_user) - # Act & Assert: Verify proper error handling - with pytest.raises(ValueError) as exc_info: - SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=10) + mock_external_service_dependencies["message_service"].get_message.return_value = message - assert "User is required" in str(exc_info.value) + SavedMessageService.save(app_model=app, user=end_user, message_id=message.id) - # Verify no database operations were performed for this specific test - # Note: We don't check total count as other tests may have created data - # Instead, we verify that the error was properly raised - pass - - def test_save_error_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies): - """ - Test error handling when saving message with no user. - - This test verifies: - - Method returns early when user is None - - No database operations are performed - - No exceptions are raised - """ - # Arrange: Create test data - fake = Faker() - app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - message = self._create_test_message(db_session_with_containers, app, account) - - # Act: Execute the method under test with None user - result = SavedMessageService.save(app_model=app, user=None, message_id=message.id) - - # Assert: Verify the expected outcomes - assert result is None - - # Verify no saved message was created - - saved_message = ( + saved = ( db_session_with_containers.query(SavedMessage) - .where( - SavedMessage.app_id == app.id, - SavedMessage.message_id == message.id, - ) + .where(SavedMessage.app_id == app.id, SavedMessage.message_id == message.id) .first() ) + assert saved is not None + assert saved.created_by == end_user.id + assert saved.created_by_role == "end_user" - assert saved_message is None - - def test_delete_success_existing_message( + def test_save_duplicate_is_idempotent( self, db_session_with_containers: Session, mock_external_service_dependencies ): - """ - Test successful deletion of an existing saved message. - - This test verifies: - - Proper deletion of existing saved message - - Correct database state after deletion - - No errors during deletion process - """ - # Arrange: Create test data - fake = Faker() + """Test that saving an already-saved message does not create a duplicate.""" app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) message = self._create_test_message(db_session_with_containers, app, account) - # Create a saved message first - saved_message = SavedMessage( - app_id=app.id, - message_id=message.id, - created_by_role="account", - created_by=account.id, - ) + mock_external_service_dependencies["message_service"].get_message.return_value = message - db_session_with_containers.add(saved_message) + # Save once + SavedMessageService.save(app_model=app, user=account, message_id=message.id) + # Save again + SavedMessageService.save(app_model=app, user=account, message_id=message.id) + + count = ( + db_session_with_containers.query(SavedMessage) + .where(SavedMessage.app_id == app.id, SavedMessage.message_id == message.id) + .count() + ) + assert count == 1 + + def test_delete_without_user_does_nothing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that deleting without a user is a no-op.""" + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + message = self._create_test_message(db_session_with_containers, app, account) + + # Pre-create a saved message + saved = SavedMessage(app_id=app.id, message_id=message.id, created_by_role="account", created_by=account.id) + db_session_with_containers.add(saved) db_session_with_containers.commit() - # Verify saved message exists + SavedMessageService.delete(app_model=app, user=None, message_id=message.id) + + # Should still exist + assert ( + db_session_with_containers.query(SavedMessage) + .where(SavedMessage.app_id == app.id, SavedMessage.message_id == message.id) + .first() + is not None + ) + + def test_delete_non_existent_does_nothing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that deleting a non-existent saved message is a no-op.""" + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Should not raise — use a valid UUID that doesn't exist in DB + from uuid import uuid4 + + SavedMessageService.delete(app_model=app, user=account, message_id=str(uuid4())) + + def test_delete_for_end_user(self, db_session_with_containers: Session, mock_external_service_dependencies): + """Test deleting a saved message for an EndUser.""" + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + end_user = self._create_test_end_user(db_session_with_containers, app) + message = self._create_test_message(db_session_with_containers, app, end_user) + + saved = SavedMessage(app_id=app.id, message_id=message.id, created_by_role="end_user", created_by=end_user.id) + db_session_with_containers.add(saved) + db_session_with_containers.commit() + + SavedMessageService.delete(app_model=app, user=end_user, message_id=message.id) + + assert ( + db_session_with_containers.query(SavedMessage) + .where(SavedMessage.app_id == app.id, SavedMessage.message_id == message.id) + .first() + is None + ) + + def test_delete_only_affects_own_saved_messages( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that delete only removes the requesting user's saved message.""" + app, account1 = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + end_user = self._create_test_end_user(db_session_with_containers, app) + message = self._create_test_message(db_session_with_containers, app, account1) + + # Both users save the same message + saved_account = SavedMessage( + app_id=app.id, message_id=message.id, created_by_role="account", created_by=account1.id + ) + saved_end_user = SavedMessage( + app_id=app.id, message_id=message.id, created_by_role="end_user", created_by=end_user.id + ) + db_session_with_containers.add_all([saved_account, saved_end_user]) + db_session_with_containers.commit() + + # Delete only account1's saved message + SavedMessageService.delete(app_model=app, user=account1, message_id=message.id) + + # Account's saved message should be gone assert ( db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, - SavedMessage.created_by_role == "account", - SavedMessage.created_by == account.id, + SavedMessage.created_by == account1.id, ) .first() - is not None + is None ) - - # Act: Execute the method under test - SavedMessageService.delete(app_model=app, user=account, message_id=message.id) - - # Assert: Verify the expected outcomes - # Check if saved message was deleted from database - deleted_saved_message = ( + # End user's saved message should still exist + assert ( db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, - SavedMessage.created_by_role == "account", - SavedMessage.created_by == account.id, + SavedMessage.created_by == end_user.id, ) .first() + is not None ) - - assert deleted_saved_message is None - - # Verify database state - db_session_with_containers.commit() - # The message should still exist, only the saved_message should be deleted - assert db_session_with_containers.query(Message).where(Message.id == message.id).first() is not None diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index fa6e6515298..f504f355890 100644 --- a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -7,9 +7,10 @@ from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset -from models.enums import DataSourceType +from models.enums import DataSourceType, TagType from models.model import App, Tag, TagBinding from services.tag_service import TagService @@ -102,7 +103,7 @@ class TestTagService: provider="vendor", permission="only_me", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, tenant_id=tenant_id, created_by=mock_external_service_dependencies["current_user"].id, ) @@ -547,7 +548,7 @@ class TestTagService: assert result is not None assert len(result) == 1 assert result[0].name == "python_tag" - assert result[0].type == "app" + assert result[0].type == TagType.APP assert result[0].tenant_id == tenant.id def test_get_tag_by_tag_name_no_matches( @@ -638,7 +639,7 @@ class TestTagService: # Verify all tags are returned for tag in result: - assert tag.type == "app" + assert tag.type == TagType.APP assert tag.tenant_id == tenant.id assert tag.id in [t.id for t in tags] diff --git a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py index 425611744b3..f2307fbd7df 100644 --- a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py @@ -7,6 +7,7 @@ from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom from models import Account +from models.enums import ConversationFromSource from models.model import Conversation, EndUser from models.web import PinnedConversation from services.account_service import AccountService, TenantService @@ -24,7 +25,7 @@ class TestWebConversationService: with ( patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.account_service.FeatureService") as mock_account_feature_service, ): # Setup default mock returns for app service @@ -145,7 +146,7 @@ class TestWebConversationService: system_instruction_tokens=50, status="normal", invoke_from=InvokeFrom.WEB_APP, - from_source="console" if isinstance(user, Account) else "api", + from_source=ConversationFromSource.CONSOLE if isinstance(user, Account) else ConversationFromSource.API, from_end_user_id=user.id if isinstance(user, EndUser) else None, from_account_id=user.id if isinstance(user, Account) else None, dialogue_count=0, diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py index 8ab8df2a5a8..749c6fff5bc 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ -1,20 +1,24 @@ +from __future__ import annotations + import json import uuid from datetime import UTC, datetime, timedelta +from types import SimpleNamespace from unittest.mock import patch import pytest from faker import Faker +from graphon.enums import WorkflowExecutionStatus from sqlalchemy.orm import Session -from dify_graph.entities.workflow_execution import WorkflowExecutionStatus -from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun -from models.enums import CreatorUserRole +from models import EndUser, Workflow, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun +from models.enums import AppTriggerType, CreatorUserRole, WorkflowRunTriggeredFrom +from models.workflow import WorkflowAppLogCreatedFrom from services.account_service import AccountService, TenantService # Delay import of AppService to avoid circular dependency # from services.app_service import AppService -from services.workflow_app_service import WorkflowAppService +from services.workflow_app_service import LogView, WorkflowAppService from tests.test_containers_integration_tests.helpers import generate_valid_password @@ -27,7 +31,7 @@ class TestWorkflowAppService: with ( patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.account_service.FeatureService") as mock_account_feature_service, ): # Setup default mock returns for app service @@ -221,7 +225,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -357,7 +361,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run_1.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -399,7 +403,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run_2.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -441,7 +445,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run_4.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -521,7 +525,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -627,7 +631,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -732,7 +736,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -860,7 +864,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -902,7 +906,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="web-app", + created_from=WorkflowAppLogCreatedFrom.WEB_APP, created_by_role=CreatorUserRole.END_USER, created_by=end_user.id, ) @@ -1037,7 +1041,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -1125,7 +1129,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -1279,7 +1283,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -1379,7 +1383,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -1481,7 +1485,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -1524,3 +1528,168 @@ class TestWorkflowAppService: # Should not find tenant2's data when searching from tenant1's context assert result_cross_tenant["total"] == 0 + + def test_get_paginate_workflow_app_logs_raises_when_account_filter_email_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + service = WorkflowAppService() + + with pytest.raises(ValueError, match="Account not found: nonexistent@example.com"): + service.get_paginate_workflow_app_logs( + session=db_session_with_containers, + app_model=app, + created_by_account="nonexistent@example.com", + ) + + def test_get_paginate_workflow_app_logs_filters_by_account( + self, db_session_with_containers, mock_external_service_dependencies + ): + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + service = WorkflowAppService() + workflow, workflow_run, _log = self._create_test_workflow_data(db_session_with_containers, app, account) + + result = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, + app_model=app, + created_by_account=account.email, + ) + + assert result["total"] >= 0 + assert isinstance(result["data"], list) + + def test_get_paginate_workflow_archive_logs(self, db_session_with_containers, mock_external_service_dependencies): + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + service = WorkflowAppService() + + end_user = EndUser( + tenant_id=app.tenant_id, + app_id=app.id, + type="browser", + is_anonymous=False, + session_id="session-1", + ) + db_session_with_containers.add(end_user) + db_session_with_containers.commit() + + now = datetime.now(UTC) + archive_defaults = { + "workflow_id": str(uuid.uuid4()), + "run_version": "1.0.0", + "run_status": WorkflowExecutionStatus.SUCCEEDED, + "run_triggered_from": WorkflowRunTriggeredFrom.APP_RUN, + "run_error": None, + "run_elapsed_time": 1.0, + "run_total_tokens": 0, + "run_total_steps": 0, + "run_created_at": now, + "run_finished_at": now, + "run_exceptions_count": 0, + "trigger_metadata": '{"type":"trigger-webhook"}', + "log_created_at": now, + "log_created_from": WorkflowAppLogCreatedFrom.SERVICE_API, + } + archive_account = WorkflowArchiveLog( + tenant_id=app.tenant_id, + app_id=app.id, + workflow_run_id=str(uuid.uuid4()), + log_id=str(uuid.uuid4()), + created_by=account.id, + created_by_role=CreatorUserRole.ACCOUNT, + **archive_defaults, + ) + archive_end_user = WorkflowArchiveLog( + tenant_id=app.tenant_id, + app_id=app.id, + workflow_run_id=str(uuid.uuid4()), + log_id=str(uuid.uuid4()), + created_by=end_user.id, + created_by_role=CreatorUserRole.END_USER, + **archive_defaults, + ) + db_session_with_containers.add_all([archive_account, archive_end_user]) + db_session_with_containers.commit() + + result = service.get_paginate_workflow_archive_logs( + session=db_session_with_containers, + app_model=app, + page=1, + limit=20, + ) + + assert result["total"] == 2 + assert len(result["data"]) == 2 + account_item = next(d for d in result["data"] if d["created_by_account"] is not None) + end_user_item = next(d for d in result["data"] if d["created_by_end_user"] is not None) + assert account_item["created_by_account"].id == account.id + assert end_user_item["created_by_end_user"].id == end_user.id + + +class TestLogView: + def test_details_and_proxy_attributes(self): + log = SimpleNamespace(id="log-1", status="succeeded") + view = LogView(log=log, details={"trigger_metadata": {"type": "plugin"}}) + + assert view.details == {"trigger_metadata": {"type": "plugin"}} + assert view.status == "succeeded" + + +class TestHandleTriggerMetadata: + def test_returns_empty_dict_when_metadata_missing(self): + service = WorkflowAppService() + assert service.handle_trigger_metadata("tenant-1", None) == {} + + def test_enriches_plugin_icons(self): + service = WorkflowAppService() + meta = { + "type": AppTriggerType.TRIGGER_PLUGIN.value, + "icon_filename": "light.png", + "icon_dark_filename": "dark.png", + } + with patch( + "services.workflow_app_service.PluginService.get_plugin_icon_url", + side_effect=["https://cdn/light.png", "https://cdn/dark.png"], + ) as mock_icon: + result = service.handle_trigger_metadata("tenant-1", json.dumps(meta)) + + assert result["icon"] == "https://cdn/light.png" + assert result["icon_dark"] == "https://cdn/dark.png" + assert mock_icon.call_count == 2 + + def test_non_plugin_metadata_without_icon_lookup(self): + service = WorkflowAppService() + meta = {"type": AppTriggerType.TRIGGER_WEBHOOK.value} + with patch("services.workflow_app_service.PluginService.get_plugin_icon_url") as mock_icon: + result = service.handle_trigger_metadata("tenant-1", json.dumps(meta)) + + assert result["type"] == AppTriggerType.TRIGGER_WEBHOOK.value + mock_icon.assert_not_called() + + +class TestSafeJsonLoads: + @pytest.mark.parametrize( + ("value", "expected"), + [ + (None, None), + ("", None), + ('{"k":"v"}', {"k": "v"}), + ("not-json", None), + ({"raw": True}, {"raw": True}), + ], + ) + def test_handles_various_inputs(self, value, expected): + assert WorkflowAppService._safe_json_loads(value) == expected + + +class TestSafeParseUuid: + def test_returns_none_for_short_or_invalid_values(self): + service = WorkflowAppService() + assert service._safe_parse_uuid("short") is None + assert service._safe_parse_uuid("x" * 40) is None + + def test_returns_uuid_for_valid_string(self): + service = WorkflowAppService() + raw = str(uuid.uuid4()) + result = service._safe_parse_uuid(raw) + assert result is not None + assert str(result) == raw diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py index 572cf72fa06..0c281c8c33b 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py @@ -1,9 +1,9 @@ import pytest from faker import Faker +from graphon.variables.segments import StringSegment from sqlalchemy.orm import Session -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from dify_graph.variables.segments import StringSegment +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from models import App, Workflow from models.enums import DraftVariableType from models.workflow import WorkflowDraftVariable @@ -482,7 +482,7 @@ class TestWorkflowDraftVariableService: fake = Faker() app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) workflow = self._create_test_workflow(db_session_with_containers, app, fake=fake) - from dify_graph.variables.variables import StringVariable + from graphon.variables.variables import StringVariable conv_var = StringVariable( id=fake.uuid4(), @@ -734,7 +734,7 @@ class TestWorkflowDraftVariableService: fake = Faker() app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) workflow = self._create_test_workflow(db_session_with_containers, app, fake=fake) - from dify_graph.variables.variables import StringVariable + from graphon.variables.variables import StringVariable conv_var1 = StringVariable( id=fake.uuid4(), diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py index e080d6ef6be..d02a0782810 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py @@ -7,7 +7,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session -from models.enums import CreatorUserRole +from models.enums import ConversationFromSource, CreatorUserRole from models.model import ( Message, ) @@ -27,7 +27,7 @@ class TestWorkflowRunService: with ( patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.account_service.FeatureService") as mock_account_feature_service, ): # Setup default mock returns for app service @@ -165,7 +165,7 @@ class TestWorkflowRunService: inputs={}, status="normal", mode="chat", - from_source=CreatorUserRole.ACCOUNT, + from_source=ConversationFromSource.CONSOLE, from_account_id=account.id, ) db_session_with_containers.add(conversation) @@ -186,7 +186,7 @@ class TestWorkflowRunService: message.answer_price_unit = 0.001 message.currency = "USD" message.status = "normal" - message.from_source = CreatorUserRole.ACCOUNT + message.from_source = ConversationFromSource.CONSOLE message.from_account_id = account.id message.workflow_run_id = workflow_run.id message.inputs = {"input": "test input"} diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_service.py index 056db417502..b5ce8a53de4 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_service.py @@ -555,6 +555,124 @@ class TestWorkflowService: assert len(result_workflows) == 2 assert all(wf.marked_name for wf in result_workflows) + def test_get_all_published_workflow_no_workflow_id(self, db_session_with_containers: Session): + """Test that an app with no workflow_id returns empty results.""" + # Arrange + fake = Faker() + app = self._create_test_app(db_session_with_containers, fake) + app.workflow_id = None + db_session_with_containers.commit() + + workflow_service = WorkflowService() + + # Act + result_workflows, has_more = workflow_service.get_all_published_workflow( + session=db_session_with_containers, app_model=app, page=1, limit=10, user_id=None + ) + + # Assert + assert result_workflows == [] + assert has_more is False + + def test_get_all_published_workflow_basic(self, db_session_with_containers: Session): + """Test basic retrieval of published workflows.""" + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + app = self._create_test_app(db_session_with_containers, fake) + + workflow1 = self._create_test_workflow(db_session_with_containers, app, account, fake) + workflow1.version = "2024.01.01.001" + workflow2 = self._create_test_workflow(db_session_with_containers, app, account, fake) + workflow2.version = "2024.01.02.001" + + app.workflow_id = workflow1.id + db_session_with_containers.commit() + + workflow_service = WorkflowService() + + # Act + result_workflows, has_more = workflow_service.get_all_published_workflow( + session=db_session_with_containers, app_model=app, page=1, limit=10, user_id=None + ) + + # Assert + assert len(result_workflows) == 2 + assert has_more is False + + def test_get_all_published_workflow_combined_filters(self, db_session_with_containers: Session): + """Test combined user_id and named_only filters.""" + # Arrange + fake = Faker() + account1 = self._create_test_account(db_session_with_containers, fake) + account2 = self._create_test_account(db_session_with_containers, fake) + app = self._create_test_app(db_session_with_containers, fake) + + # account1 named + wf1 = self._create_test_workflow(db_session_with_containers, app, account1, fake) + wf1.version = "2024.01.01.001" + wf1.marked_name = "Named by user1" + wf1.created_by = account1.id + + # account1 unnamed + wf2 = self._create_test_workflow(db_session_with_containers, app, account1, fake) + wf2.version = "2024.01.02.001" + wf2.marked_name = "" + wf2.created_by = account1.id + + # account2 named + wf3 = self._create_test_workflow(db_session_with_containers, app, account2, fake) + wf3.version = "2024.01.03.001" + wf3.marked_name = "Named by user2" + wf3.created_by = account2.id + + app.workflow_id = wf1.id + db_session_with_containers.commit() + + workflow_service = WorkflowService() + + # Act - Filter by account1 + named_only + result_workflows, has_more = workflow_service.get_all_published_workflow( + session=db_session_with_containers, + app_model=app, + page=1, + limit=10, + user_id=account1.id, + named_only=True, + ) + + # Assert - Only wf1 matches (account1 + named) + assert len(result_workflows) == 1 + assert result_workflows[0].marked_name == "Named by user1" + assert result_workflows[0].created_by == account1.id + + def test_get_all_published_workflow_empty_result(self, db_session_with_containers: Session): + """Test that querying with no matching workflows returns empty.""" + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + app = self._create_test_app(db_session_with_containers, fake) + + # Create a draft workflow (no version set = draft) + workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) + app.workflow_id = workflow.id + db_session_with_containers.commit() + + workflow_service = WorkflowService() + + # Act - Filter by a user that has no workflows + result_workflows, has_more = workflow_service.get_all_published_workflow( + session=db_session_with_containers, + app_model=app, + page=1, + limit=10, + user_id="00000000-0000-0000-0000-000000000000", + ) + + # Assert + assert result_workflows == [] + assert has_more is False + def test_sync_draft_workflow_create_new(self, db_session_with_containers: Session): """ Test creating a new draft workflow through sync operation. @@ -802,6 +920,81 @@ class TestWorkflowService: with pytest.raises(ValueError, match="No valid workflow found"): workflow_service.publish_workflow(session=db_session_with_containers, app_model=app, account=account) + def test_restore_published_workflow_to_draft_does_not_persist_normalized_source_features( + self, db_session_with_containers: Session + ): + """Restore copies legacy feature JSON into draft without rewriting the source row.""" + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + app = self._create_test_app(db_session_with_containers, fake) + app.mode = AppMode.ADVANCED_CHAT + + legacy_features = { + "file_upload": { + "image": { + "enabled": True, + "number_limits": 6, + "transfer_methods": ["remote_url", "local_file"], + } + }, + "opening_statement": "", + "retriever_resource": {"enabled": True}, + "sensitive_word_avoidance": {"enabled": False}, + "speech_to_text": {"enabled": False}, + "suggested_questions": [], + "suggested_questions_after_answer": {"enabled": False}, + "text_to_speech": {"enabled": False, "language": "", "voice": ""}, + } + published_workflow = Workflow( + id=fake.uuid4(), + tenant_id=app.tenant_id, + app_id=app.id, + type=WorkflowType.WORKFLOW, + version="2026.03.19.001", + graph=json.dumps({"nodes": [], "edges": []}), + features=json.dumps(legacy_features), + created_by=account.id, + updated_by=account.id, + environment_variables=[], + conversation_variables=[], + ) + draft_workflow = Workflow( + id=fake.uuid4(), + tenant_id=app.tenant_id, + app_id=app.id, + type=WorkflowType.WORKFLOW, + version=Workflow.VERSION_DRAFT, + graph=json.dumps({"nodes": [], "edges": []}), + features=json.dumps({}), + created_by=account.id, + updated_by=account.id, + environment_variables=[], + conversation_variables=[], + ) + db_session_with_containers.add(published_workflow) + db_session_with_containers.add(draft_workflow) + db_session_with_containers.commit() + + workflow_service = WorkflowService() + + restored_workflow = workflow_service.restore_published_workflow_to_draft( + app_model=app, + workflow_id=published_workflow.id, + account=account, + ) + + db_session_with_containers.expire_all() + refreshed_published_workflow = ( + db_session_with_containers.query(Workflow).filter_by(id=published_workflow.id).first() + ) + refreshed_draft_workflow = db_session_with_containers.query(Workflow).filter_by(id=draft_workflow.id).first() + + assert restored_workflow.id == draft_workflow.id + assert refreshed_published_workflow is not None + assert refreshed_draft_workflow is not None + assert refreshed_published_workflow.serialized_features == json.dumps(legacy_features) + assert refreshed_draft_workflow.serialized_features == json.dumps(legacy_features) + def test_get_default_block_configs(self, db_session_with_containers: Session): """ Test retrieval of default block configurations for all node types. @@ -1428,10 +1621,10 @@ class TestWorkflowService: import uuid from datetime import datetime - from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus - from dify_graph.graph_events import NodeRunSucceededEvent - from dify_graph.node_events import NodeRunResult - from dify_graph.nodes.base.node import Node + from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus + from graphon.graph_events import NodeRunSucceededEvent + from graphon.node_events import NodeRunResult + from graphon.nodes.base.node import Node # Create mock node mock_node = MagicMock(spec=Node) @@ -1473,12 +1666,12 @@ class TestWorkflowService: # Assert assert result is not None assert result.node_id == node_id - from dify_graph.enums import BuiltinNodeTypes + from graphon.enums import BuiltinNodeTypes assert result.node_type == BuiltinNodeTypes.START # Should match the mock node type assert result.title == "Test Node" # Import the enum for comparison - from dify_graph.enums import WorkflowNodeExecutionStatus + from graphon.enums import WorkflowNodeExecutionStatus assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.inputs is not None @@ -1503,10 +1696,10 @@ class TestWorkflowService: import uuid from datetime import datetime - from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus - from dify_graph.graph_events import NodeRunFailedEvent - from dify_graph.node_events import NodeRunResult - from dify_graph.nodes.base.node import Node + from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus + from graphon.graph_events import NodeRunFailedEvent + from graphon.node_events import NodeRunResult + from graphon.nodes.base.node import Node # Create mock node mock_node = MagicMock(spec=Node) @@ -1548,7 +1741,7 @@ class TestWorkflowService: assert result is not None assert result.node_id == node_id # Import the enum for comparison - from dify_graph.enums import WorkflowNodeExecutionStatus + from graphon.enums import WorkflowNodeExecutionStatus assert result.status == WorkflowNodeExecutionStatus.FAILED assert result.error is not None @@ -1572,10 +1765,10 @@ class TestWorkflowService: import uuid from datetime import datetime - from dify_graph.enums import BuiltinNodeTypes, ErrorStrategy, WorkflowNodeExecutionStatus - from dify_graph.graph_events import NodeRunFailedEvent - from dify_graph.node_events import NodeRunResult - from dify_graph.nodes.base.node import Node + from graphon.enums import BuiltinNodeTypes, ErrorStrategy, WorkflowNodeExecutionStatus + from graphon.graph_events import NodeRunFailedEvent + from graphon.node_events import NodeRunResult + from graphon.nodes.base.node import Node # Create mock node with continue_on_error mock_node = MagicMock(spec=Node) @@ -1618,7 +1811,7 @@ class TestWorkflowService: assert result is not None assert result.node_id == node_id # Import the enum for comparison - from dify_graph.enums import WorkflowNodeExecutionStatus + from graphon.enums import WorkflowNodeExecutionStatus assert result.status == WorkflowNodeExecutionStatus.EXCEPTION # Should be EXCEPTION, not FAILED assert result.outputs is not None diff --git a/api/tests/test_containers_integration_tests/services/test_workspace_service.py b/api/tests/test_containers_integration_tests/services/test_workspace_service.py index 92dec24c7d8..4e89d906f16 100644 --- a/api/tests/test_containers_integration_tests/services/test_workspace_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workspace_service.py @@ -1,4 +1,6 @@ -from unittest.mock import patch +from __future__ import annotations + +from unittest.mock import MagicMock, patch import pytest from faker import Faker @@ -534,3 +536,283 @@ class TestWorkspaceService: # Verify database state db_session_with_containers.refresh(tenant) assert tenant.id is not None + + def test_get_tenant_info_should_raise_assertion_when_join_missing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """TenantAccountJoin must exist; missing join should raise AssertionError.""" + fake = Faker() + account = Account(email=fake.email(), name=fake.name(), interface_language="en-US", status="active") + db_session_with_containers.add(account) + db_session_with_containers.commit() + + tenant = Tenant(name=fake.company(), status="normal", plan="basic") + db_session_with_containers.add(tenant) + db_session_with_containers.commit() + + # No TenantAccountJoin created + with patch("services.workspace_service.current_user", account): + with pytest.raises(AssertionError, match="TenantAccountJoin not found"): + WorkspaceService.get_tenant_info(tenant) + + def test_get_tenant_info_should_set_replace_webapp_logo_to_none_when_flag_absent( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """replace_webapp_logo should be None when custom_config_dict does not have the key.""" + import json + + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + tenant.custom_config = json.dumps({}) + db_session_with_containers.commit() + + mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True + mock_external_service_dependencies["tenant_service"].has_roles.return_value = True + + with patch("services.workspace_service.current_user", account): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert result["custom_config"]["replace_webapp_logo"] is None + + def test_get_tenant_info_should_use_files_url_for_logo_url( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """The logo URL should use dify_config.FILES_URL as the base.""" + import json + + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + tenant.custom_config = json.dumps({"replace_webapp_logo": True}) + db_session_with_containers.commit() + + custom_base = "https://cdn.mycompany.io" + mock_external_service_dependencies["dify_config"].FILES_URL = custom_base + mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True + mock_external_service_dependencies["tenant_service"].has_roles.return_value = True + + with patch("services.workspace_service.current_user", account): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert result["custom_config"]["replace_webapp_logo"].startswith(custom_base) + + def test_get_tenant_info_should_not_include_cloud_fields_in_self_hosted( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """next_credit_reset_date and trial_credits should NOT appear in SELF_HOSTED mode.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + mock_external_service_dependencies["dify_config"].EDITION = "SELF_HOSTED" + mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = False + mock_external_service_dependencies["tenant_service"].has_roles.return_value = False + + with patch("services.workspace_service.current_user", account): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert "next_credit_reset_date" not in result + assert "trial_credits" not in result + assert "trial_credits_used" not in result + + def test_get_tenant_info_cloud_credit_reset_date( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """next_credit_reset_date should be present in CLOUD edition.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + mock_external_service_dependencies["dify_config"].EDITION = "CLOUD" + feature = mock_external_service_dependencies["feature_service"].get_features.return_value + feature.can_replace_logo = False + feature.next_credit_reset_date = "2025-02-01" + feature.billing.subscription.plan = "professional" + mock_external_service_dependencies["tenant_service"].has_roles.return_value = False + + with ( + patch("services.workspace_service.current_user", account), + patch("services.credit_pool_service.CreditPoolService.get_pool", return_value=None), + ): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert result["next_credit_reset_date"] == "2025-02-01" + + def test_get_tenant_info_cloud_paid_pool_not_full( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """trial_credits come from paid pool when plan is not sandbox and pool is not full.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + mock_external_service_dependencies["dify_config"].EDITION = "CLOUD" + feature = mock_external_service_dependencies["feature_service"].get_features.return_value + feature.can_replace_logo = False + feature.next_credit_reset_date = "2025-02-01" + feature.billing.subscription.plan = "professional" + mock_external_service_dependencies["tenant_service"].has_roles.return_value = False + + paid_pool = MagicMock(quota_limit=1000, quota_used=200) + + with ( + patch("services.workspace_service.current_user", account), + patch("services.credit_pool_service.CreditPoolService.get_pool", return_value=paid_pool), + ): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert result["trial_credits"] == 1000 + assert result["trial_credits_used"] == 200 + + def test_get_tenant_info_cloud_paid_pool_unlimited( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """quota_limit == -1 means unlimited; service should use paid pool.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + mock_external_service_dependencies["dify_config"].EDITION = "CLOUD" + feature = mock_external_service_dependencies["feature_service"].get_features.return_value + feature.can_replace_logo = False + feature.next_credit_reset_date = "2025-02-01" + feature.billing.subscription.plan = "professional" + mock_external_service_dependencies["tenant_service"].has_roles.return_value = False + + paid_pool = MagicMock(quota_limit=-1, quota_used=999) + + with ( + patch("services.workspace_service.current_user", account), + patch("services.credit_pool_service.CreditPoolService.get_pool", side_effect=[paid_pool, None]), + ): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert result["trial_credits"] == -1 + assert result["trial_credits_used"] == 999 + + def test_get_tenant_info_cloud_fall_back_to_trial_when_paid_full( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """When paid pool is exhausted, switch to trial pool.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + mock_external_service_dependencies["dify_config"].EDITION = "CLOUD" + feature = mock_external_service_dependencies["feature_service"].get_features.return_value + feature.can_replace_logo = False + feature.next_credit_reset_date = "2025-02-01" + feature.billing.subscription.plan = "professional" + mock_external_service_dependencies["tenant_service"].has_roles.return_value = False + + paid_pool = MagicMock(quota_limit=500, quota_used=500) + trial_pool = MagicMock(quota_limit=100, quota_used=10) + + with ( + patch("services.workspace_service.current_user", account), + patch("services.credit_pool_service.CreditPoolService.get_pool", side_effect=[paid_pool, trial_pool]), + ): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert result["trial_credits"] == 100 + assert result["trial_credits_used"] == 10 + + def test_get_tenant_info_cloud_fall_back_to_trial_when_paid_none( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """When paid_pool is None, fall back to trial pool.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + mock_external_service_dependencies["dify_config"].EDITION = "CLOUD" + feature = mock_external_service_dependencies["feature_service"].get_features.return_value + feature.can_replace_logo = False + feature.next_credit_reset_date = "2025-02-01" + feature.billing.subscription.plan = "professional" + mock_external_service_dependencies["tenant_service"].has_roles.return_value = False + + trial_pool = MagicMock(quota_limit=50, quota_used=5) + + with ( + patch("services.workspace_service.current_user", account), + patch("services.credit_pool_service.CreditPoolService.get_pool", side_effect=[None, trial_pool]), + ): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert result["trial_credits"] == 50 + assert result["trial_credits_used"] == 5 + + def test_get_tenant_info_cloud_sandbox_uses_trial_pool( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """When plan is SANDBOX, skip paid pool and use trial pool.""" + from enums.cloud_plan import CloudPlan + + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + mock_external_service_dependencies["dify_config"].EDITION = "CLOUD" + feature = mock_external_service_dependencies["feature_service"].get_features.return_value + feature.can_replace_logo = False + feature.next_credit_reset_date = "2025-02-01" + feature.billing.subscription.plan = CloudPlan.SANDBOX + mock_external_service_dependencies["tenant_service"].has_roles.return_value = False + + paid_pool = MagicMock(quota_limit=1000, quota_used=0) + trial_pool = MagicMock(quota_limit=200, quota_used=20) + + with ( + patch("services.workspace_service.current_user", account), + patch("services.credit_pool_service.CreditPoolService.get_pool", side_effect=[paid_pool, trial_pool]), + ): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert result["trial_credits"] == 200 + assert result["trial_credits_used"] == 20 + + def test_get_tenant_info_cloud_both_pools_none( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """When both paid and trial pools are absent, trial_credits should not be set.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + mock_external_service_dependencies["dify_config"].EDITION = "CLOUD" + feature = mock_external_service_dependencies["feature_service"].get_features.return_value + feature.can_replace_logo = False + feature.next_credit_reset_date = "2025-02-01" + feature.billing.subscription.plan = "professional" + mock_external_service_dependencies["tenant_service"].has_roles.return_value = False + + with ( + patch("services.workspace_service.current_user", account), + patch("services.credit_pool_service.CreditPoolService.get_pool", side_effect=[None, None]), + ): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert "trial_credits" not in result + assert "trial_credits_used" not in result diff --git a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py index bffdca623a4..d3e765055a0 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py @@ -536,3 +536,151 @@ class TestApiToolManageService: # Verify mock interactions mock_external_service_dependencies["encrypter"].assert_called_once() mock_external_service_dependencies["provider_controller"].from_db.assert_called_once() + + def test_delete_api_tool_provider_success( + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test successful deletion of an API tool provider.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + schema = self._create_test_openapi_schema() + provider_name = fake.unique.word() + + ApiToolManageService.create_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name=provider_name, + icon={"content": "🔧", "background": "#FFF"}, + credentials={"auth_type": "none"}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema=schema, + privacy_policy="", + custom_disclaimer="", + labels=[], + ) + + provider = ( + db_session_with_containers.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name) + .first() + ) + assert provider is not None + + result = ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, provider_name) + + assert result == {"result": "success"} + deleted = ( + db_session_with_containers.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name) + .first() + ) + assert deleted is None + + def test_delete_api_tool_provider_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test deletion raises ValueError when provider not found.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="you have not added provider"): + ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, "nonexistent") + + def test_update_api_tool_provider_not_found( + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test update raises ValueError when original provider not found.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="does not exists"): + ApiToolManageService.update_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name="new-name", + original_provider="nonexistent", + icon={}, + credentials={"auth_type": "none"}, + _schema_type=ApiProviderSchemaType.OPENAPI, + schema=self._create_test_openapi_schema(), + privacy_policy=None, + custom_disclaimer="", + labels=[], + ) + + def test_update_api_tool_provider_missing_auth_type( + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test update raises ValueError when auth_type is missing from credentials.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + schema = self._create_test_openapi_schema() + provider_name = fake.unique.word() + + ApiToolManageService.create_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name=provider_name, + icon={"content": "🔧", "background": "#FFF"}, + credentials={"auth_type": "none"}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema=schema, + privacy_policy="", + custom_disclaimer="", + labels=[], + ) + + with pytest.raises(ValueError, match="auth_type is required"): + ApiToolManageService.update_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name=provider_name, + original_provider=provider_name, + icon={}, + credentials={}, + _schema_type=ApiProviderSchemaType.OPENAPI, + schema=schema, + privacy_policy=None, + custom_disclaimer="", + labels=[], + ) + + def test_list_api_tool_provider_tools_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test listing tools raises ValueError when provider not found.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="you have not added provider"): + ApiToolManageService.list_api_tool_provider_tools(account.id, tenant.id, "nonexistent") + + def test_test_api_tool_preview_invalid_schema_type( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test preview raises ValueError for invalid schema type.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="invalid schema type"): + ApiToolManageService.test_api_tool_preview( + tenant_id=tenant.id, + provider_name="provider-a", + tool_name="tool-a", + credentials={"auth_type": "none"}, + parameters={}, + schema_type="bad-schema-type", + schema="schema", + ) diff --git a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py index 0f38218c51b..2dc50cc7202 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py @@ -1,12 +1,24 @@ +from __future__ import annotations + from unittest.mock import Mock, patch import pytest from faker import Faker from sqlalchemy.orm import Session -from core.tools.entities.api_entities import ToolProviderApiEntity +from core.tools.__base.tool import Tool +from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolProviderType +from core.tools.entities.tool_entities import ( + ApiProviderSchemaType, + ToolDescription, + ToolEntity, + ToolIdentity, + ToolParameter, + ToolProviderEntity, + ToolProviderIdentity, + ToolProviderType, +) from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider from services.plugin.plugin_service import PluginService from services.tools.tools_transform_service import ToolTransformService @@ -52,7 +64,7 @@ class TestToolTransformService: user_id="test_user_id", credentials_str='{"auth_type": "api_key_header", "api_key": "test_key"}', schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, tools_str="[]", ) elif provider_type == "builtin": @@ -659,7 +671,7 @@ class TestToolTransformService: user_id=fake.uuid4(), credentials_str='{"auth_type": "api_key_header", "api_key": "test_key"}', schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, tools_str="[]", ) @@ -695,7 +707,7 @@ class TestToolTransformService: user_id=fake.uuid4(), credentials_str='{"auth_type": "api_key_query", "api_key": "test_key"}', schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, tools_str="[]", ) @@ -731,7 +743,7 @@ class TestToolTransformService: user_id=fake.uuid4(), credentials_str='{"auth_type": "api_key", "api_key": "test_key"}', schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, tools_str="[]", ) @@ -786,3 +798,192 @@ class TestToolTransformService: assert result is not None assert result == mock_controller mock_from_db.assert_called_once_with(provider) + + +def _mock_tool(*, base_params, runtime_params): + """Helper to build a Mock tool with real entity objects. + + Tool is abstract and requires runtime behaviour (fork_tool_runtime, + get_runtime_parameters), so it stays as a Mock. Everything else uses + real Pydantic instances. + """ + entity = ToolEntity( + identity=ToolIdentity( + author="test_author", + name="test_tool", + label=I18nObject(en_US="Test Tool"), + provider="test_provider", + ), + parameters=base_params or [], + description=ToolDescription( + human=I18nObject(en_US="Test description"), + llm="Test description for LLM", + ), + output_schema={}, + ) + mock_tool = Mock(spec=Tool) + mock_tool.entity = entity + mock_tool.get_runtime_parameters.return_value = runtime_params + mock_tool.fork_tool_runtime.return_value = mock_tool + return mock_tool + + +def _param(name, *, form=ToolParameter.ToolParameterForm.FORM, label=None): + return ToolParameter( + name=name, + label=I18nObject(en_US=label or name), + human_description=I18nObject(en_US=name), + type=ToolParameter.ToolParameterType.STRING, + form=form, + ) + + +class TestConvertToolEntityToApiEntity: + """Tests for ToolTransformService.convert_tool_entity_to_api_entity.""" + + def test_parameter_override(self): + base = [_param("param1", label="Base 1"), _param("param2", label="Base 2")] + runtime = [_param("param1", label="Runtime 1")] + tool = _mock_tool(base_params=base, runtime_params=runtime) + + result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None) + + assert isinstance(result, ToolApiEntity) + assert len(result.parameters) == 2 + assert next(p for p in result.parameters if p.name == "param1").label.en_US == "Runtime 1" + assert next(p for p in result.parameters if p.name == "param2").label.en_US == "Base 2" + + def test_additional_runtime_parameters(self): + base = [_param("param1", label="Base 1")] + runtime = [_param("param1", label="Runtime 1"), _param("runtime_only", label="Runtime Only")] + tool = _mock_tool(base_params=base, runtime_params=runtime) + + result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None) + + assert len(result.parameters) == 2 + names = [p.name for p in result.parameters] + assert "param1" in names + assert "runtime_only" in names + + def test_non_form_runtime_parameters_excluded(self): + base = [_param("param1")] + runtime = [ + _param("param1", label="Runtime 1"), + _param("llm_param", form=ToolParameter.ToolParameterForm.LLM), + ] + tool = _mock_tool(base_params=base, runtime_params=runtime) + + result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None) + + assert len(result.parameters) == 1 + assert result.parameters[0].name == "param1" + + def test_empty_parameters(self): + tool = _mock_tool(base_params=[], runtime_params=[]) + + result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None) + + assert isinstance(result, ToolApiEntity) + assert len(result.parameters) == 0 + + def test_none_parameters(self): + tool = _mock_tool(base_params=None, runtime_params=[]) + + result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None) + + assert isinstance(result, ToolApiEntity) + assert len(result.parameters) == 0 + + def test_parameter_order_preserved(self): + base = [_param("p1", label="B1"), _param("p2", label="B2"), _param("p3", label="B3")] + runtime = [_param("p2", label="R2"), _param("p4", label="R4")] + tool = _mock_tool(base_params=base, runtime_params=runtime) + + result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None) + + assert [p.name for p in result.parameters] == ["p1", "p2", "p3", "p4"] + assert result.parameters[1].label.en_US == "R2" + + +class TestWorkflowProviderToUserProvider: + """Tests for ToolTransformService.workflow_provider_to_user_provider.""" + + @staticmethod + def _make_controller(provider_id="provider_123", **identity_overrides): + from core.tools.workflow_as_tool.provider import WorkflowToolProviderController + + defaults = { + "author": "test_author", + "name": "test_workflow_tool", + "description": I18nObject(en_US="Test description"), + "icon": '{"type": "emoji", "content": "🔧"}', + "icon_dark": None, + "label": I18nObject(en_US="Test Workflow Tool"), + } + defaults.update(identity_overrides) + identity = ToolProviderIdentity(**defaults) + entity = ToolProviderEntity(identity=identity) + return WorkflowToolProviderController(entity=entity, provider_id=provider_id) + + def test_with_workflow_app_id(self): + ctrl = self._make_controller() + + result = ToolTransformService.workflow_provider_to_user_provider( + provider_controller=ctrl, + labels=["l1", "l2"], + workflow_app_id="app_123", + ) + + assert isinstance(result, ToolProviderApiEntity) + assert result.id == "provider_123" + assert result.type == ToolProviderType.WORKFLOW + assert result.workflow_app_id == "app_123" + assert result.labels == ["l1", "l2"] + assert result.is_team_authorization is True + + def test_without_workflow_app_id(self): + ctrl = self._make_controller() + + result = ToolTransformService.workflow_provider_to_user_provider( + provider_controller=ctrl, + labels=["l1"], + ) + + assert result.workflow_app_id is None + + def test_workflow_app_id_none_explicit(self): + ctrl = self._make_controller() + + result = ToolTransformService.workflow_provider_to_user_provider( + provider_controller=ctrl, + labels=None, + workflow_app_id=None, + ) + + assert result.workflow_app_id is None + assert result.labels == [] + + def test_preserves_other_fields(self): + ctrl = self._make_controller( + "provider_456", + author="another_author", + name="another_workflow_tool", + description=I18nObject(en_US="Another desc", zh_Hans="Another desc"), + icon='{"type": "emoji", "content": "⚙️"}', + icon_dark='{"type": "emoji", "content": "🔧"}', + label=I18nObject(en_US="Another Tool", zh_Hans="Another Tool"), + ) + + result = ToolTransformService.workflow_provider_to_user_provider( + provider_controller=ctrl, + labels=["automation"], + workflow_app_id="app_456", + ) + + assert result.id == "provider_456" + assert result.author == "another_author" + assert result.name == "another_workflow_tool" + assert result.type == ToolProviderType.WORKFLOW + assert result.workflow_app_id == "app_456" + assert result.is_team_authorization is True + assert result.allow_delete is True diff --git a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py index 34906a4e540..21a19758798 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py @@ -25,7 +25,7 @@ class TestWorkflowToolManageService: with ( patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.account_service.FeatureService") as mock_account_feature_service, patch( "services.tools.workflow_tools_manage_service.WorkflowToolProviderController" @@ -1043,3 +1043,112 @@ class TestWorkflowToolManageService: # After the fix, this should always be 0 # For now, we document that the record may exist, demonstrating the bug # assert tool_count == 0 # Expected after fix + + def test_delete_workflow_tool_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test successful deletion of a workflow tool.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + tool_name = fake.unique.word() + + WorkflowToolManageService.create_workflow_tool( + user_id=account.id, + tenant_id=account.current_tenant.id, + workflow_app_id=app.id, + name=tool_name, + label=fake.word(), + icon={"type": "emoji", "emoji": "🔧"}, + description=fake.text(max_nb_chars=200), + parameters=self._create_test_workflow_tool_parameters(), + ) + + tool = ( + db_session_with_containers.query(WorkflowToolProvider) + .where(WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.name == tool_name) + .first() + ) + assert tool is not None + + result = WorkflowToolManageService.delete_workflow_tool(account.id, account.current_tenant.id, tool.id) + + assert result == {"result": "success"} + deleted = ( + db_session_with_containers.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool.id).first() + ) + assert deleted is None + + def test_list_tenant_workflow_tools_empty( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test listing workflow tools when none exist returns empty list.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + result = WorkflowToolManageService.list_tenant_workflow_tools(account.id, account.current_tenant.id) + + assert result == [] + + def test_get_workflow_tool_by_tool_id_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that get_workflow_tool_by_tool_id raises ValueError when tool not found.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="Tool not found"): + WorkflowToolManageService.get_workflow_tool_by_tool_id(account.id, account.current_tenant.id, fake.uuid4()) + + def test_get_workflow_tool_by_app_id_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that get_workflow_tool_by_app_id raises ValueError when tool not found.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="Tool not found"): + WorkflowToolManageService.get_workflow_tool_by_app_id(account.id, account.current_tenant.id, fake.uuid4()) + + def test_list_single_workflow_tools_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that list_single_workflow_tools raises ValueError when tool not found.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="not found"): + WorkflowToolManageService.list_single_workflow_tools(account.id, account.current_tenant.id, fake.uuid4()) + + def test_create_workflow_tool_with_labels( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that labels are forwarded to ToolLabelManager when provided.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + result = WorkflowToolManageService.create_workflow_tool( + user_id=account.id, + tenant_id=account.current_tenant.id, + workflow_app_id=app.id, + name=fake.unique.word(), + label=fake.word(), + icon={"type": "emoji", "emoji": "🔧"}, + description=fake.text(max_nb_chars=200), + parameters=self._create_test_workflow_tool_parameters(), + labels=["label-1", "label-2"], + ) + + assert result == {"result": "success"} + mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_called_once() diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py index c3fe6a2950d..ce2fd2eeb1a 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py @@ -1,11 +1,19 @@ +from __future__ import annotations + import json -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest from faker import Faker +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.message_entities import PromptMessageRole +from graphon.variables.input_entities import VariableEntity, VariableEntityType from sqlalchemy.orm import Session from core.app.app_config.entities import ( + AdvancedChatMessageEntity, + AdvancedChatPromptTemplateEntity, + AdvancedCompletionPromptTemplateEntity, DatasetEntity, DatasetRetrieveConfigEntity, ExternalDataVariableEntity, @@ -13,10 +21,8 @@ from core.app.app_config.entities import ( PromptTemplateEntity, ) from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType from models import Account, Tenant -from models.api_based_extension import APIBasedExtension +from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint from models.model import App, AppMode, AppModelConfig from models.workflow import Workflow from services.workflow.workflow_converter import WorkflowConverter @@ -548,3 +554,198 @@ class TestWorkflowConverter: # Verify single retrieval config is None for multiple strategy assert node["data"]["single_retrieval_config"] is None + + +@pytest.fixture +def default_variables(): + return [ + VariableEntity(variable="text_input", label="text-input", type=VariableEntityType.TEXT_INPUT), + VariableEntity(variable="paragraph", label="paragraph", type=VariableEntityType.PARAGRAPH), + VariableEntity(variable="select", label="select", type=VariableEntityType.SELECT), + ] + + +class TestConvertToHttpRequestNodeVariants: + """Tests for chatbot vs workflow differences in HTTP request node conversion.""" + + @staticmethod + def _setup(app_mode, default_variables): + app_model = App( + tenant_id="tenant_id", + mode=app_mode, + name="test", + icon_type="emoji", + icon="🤖", + icon_background="#FFFFFF", + ) + + ext = APIBasedExtension(tenant_id="tenant_id", name="api-1", api_key="enc", api_endpoint="https://dify.ai") + ext.id = "ext_id" + + converter = WorkflowConverter() + converter._get_api_based_extension = MagicMock(return_value=ext) + + from core.helper import encrypter + + encrypter.decrypt_token = MagicMock(return_value="api_key") + + ext_vars = [ + ExternalDataVariableEntity( + variable="external_variable", type="api", config={"api_based_extension_id": "ext_id"} + ) + ] + nodes, _ = converter._convert_to_http_request_node( + app_model=app_model, + variables=default_variables, + external_data_variables=ext_vars, + ) + return nodes + + def test_chatbot_query_uses_sys_query(self, default_variables): + nodes = self._setup(AppMode.CHAT, default_variables) + + body = json.loads(nodes[0]["data"]["body"]["data"]) + assert body["params"]["query"] == "{{#sys.query#}}" + assert body["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY + assert nodes[1]["data"]["type"] == "code" + + def test_workflow_query_is_empty(self, default_variables): + nodes = self._setup(AppMode.WORKFLOW, default_variables) + + body = json.loads(nodes[0]["data"]["body"]["data"]) + assert body["params"]["query"] == "" + + +class TestConvertToKnowledgeRetrievalNodeVariants: + """Tests for chatbot vs workflow differences in knowledge retrieval node.""" + + @staticmethod + def _dataset_config(query_variable=None): + return DatasetEntity( + dataset_ids=["ds1", "ds2"], + retrieve_config=DatasetRetrieveConfigEntity( + query_variable=query_variable, + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + top_k=5, + score_threshold=0.8, + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"}, + reranking_enabled=True, + ), + ) + + @staticmethod + def _model_config(): + return ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[]) + + def test_chatbot_uses_sys_query(self): + node = WorkflowConverter()._convert_to_knowledge_retrieval_node( + new_app_mode=AppMode.ADVANCED_CHAT, + dataset_config=self._dataset_config(), + model_config=self._model_config(), + ) + assert node["data"]["query_variable_selector"] == ["sys", "query"] + + def test_workflow_uses_start_variable(self): + node = WorkflowConverter()._convert_to_knowledge_retrieval_node( + new_app_mode=AppMode.WORKFLOW, + dataset_config=self._dataset_config(query_variable="query"), + model_config=self._model_config(), + ) + assert node["data"]["query_variable_selector"] == ["start", "query"] + + +class TestConvertToLlmNode: + """Tests for LLM node conversion across model modes and prompt types.""" + + @staticmethod + def _model_config(model, mode): + return ModelConfigEntity( + provider="openai", + model=model, + mode=mode.value, + parameters={}, + stop=[], + ) + + @staticmethod + def _graph(default_variables): + start = WorkflowConverter()._convert_to_start_node(default_variables) + return {"nodes": [start], "edges": []} + + def test_simple_chat_model(self, default_variables): + prompt = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + simple_prompt_template="You are helpful {{text_input}}, {{paragraph}}, {{select}}.", + ) + node = WorkflowConverter()._convert_to_llm_node( + original_app_mode=AppMode.CHAT, + new_app_mode=AppMode.ADVANCED_CHAT, + model_config=self._model_config("gpt-4", LLMMode.CHAT), + graph=self._graph(default_variables), + prompt_template=prompt, + ) + assert node["data"]["type"] == "llm" + assert node["data"]["model"]["mode"] == LLMMode.CHAT.value + assert node["data"]["context"]["enabled"] is False + expected = "You are helpful {{#start.text_input#}}, {{#start.paragraph#}}, {{#start.select#}}.\n" + assert node["data"]["prompt_template"][0]["text"] == expected + + def test_simple_completion_model(self, default_variables): + prompt = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + simple_prompt_template="You are helpful {{text_input}}, {{paragraph}}, {{select}}.", + ) + node = WorkflowConverter()._convert_to_llm_node( + original_app_mode=AppMode.CHAT, + new_app_mode=AppMode.ADVANCED_CHAT, + model_config=self._model_config("gpt-3.5-turbo-instruct", LLMMode.COMPLETION), + graph=self._graph(default_variables), + prompt_template=prompt, + ) + assert node["data"]["model"]["mode"] == LLMMode.COMPLETION.value + expected = "You are helpful {{#start.text_input#}}, {{#start.paragraph#}}, {{#start.select#}}.\n" + assert node["data"]["prompt_template"]["text"] == expected + + def test_advanced_chat_model(self, default_variables): + prompt = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.ADVANCED, + advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity( + messages=[ + AdvancedChatMessageEntity( + text="You are helpful named {{name}}.\n\nContext:\n{{#context#}}", + role=PromptMessageRole.SYSTEM, + ), + AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER), + AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT), + ] + ), + ) + node = WorkflowConverter()._convert_to_llm_node( + original_app_mode=AppMode.CHAT, + new_app_mode=AppMode.ADVANCED_CHAT, + model_config=self._model_config("gpt-4", LLMMode.CHAT), + graph=self._graph(default_variables), + prompt_template=prompt, + ) + assert isinstance(node["data"]["prompt_template"], list) + assert len(node["data"]["prompt_template"]) == 3 + + def test_advanced_completion_model(self, default_variables): + prompt = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.ADVANCED, + advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity( + prompt="You are helpful named {{name}}.\n\nContext:\n{{#context#}}\n\nHuman: hi\nAssistant: ", + role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity( + user="Human", assistant="Assistant" + ), + ), + ) + node = WorkflowConverter()._convert_to_llm_node( + original_app_mode=AppMode.CHAT, + new_app_mode=AppMode.ADVANCED_CHAT, + model_config=self._model_config("gpt-3.5-turbo-instruct", LLMMode.COMPLETION), + graph=self._graph(default_variables), + prompt_template=prompt, + ) + assert isinstance(node["data"]["prompt_template"], dict) + assert "text" in node["data"]["prompt_template"] diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_deletion.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_deletion.py new file mode 100644 index 00000000000..29e1e240b49 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_deletion.py @@ -0,0 +1,158 @@ +"""Testcontainers integration tests for WorkflowService.delete_workflow.""" + +import json +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session, sessionmaker + +from extensions.ext_database import db +from models.account import Account, Tenant, TenantAccountJoin +from models.model import App +from models.tools import WorkflowToolProvider +from models.workflow import Workflow +from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService + + +class TestWorkflowDeletion: + def _create_tenant_and_account(self, session: Session) -> tuple[Tenant, Account]: + tenant = Tenant(name=f"Tenant {uuid4()}") + session.add(tenant) + session.flush() + + account = Account( + name=f"Account {uuid4()}", + email=f"wf_del_{uuid4()}@example.com", + password="hashed", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) + session.add(account) + session.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role="owner", + current=True, + ) + session.add(join) + session.flush() + return tenant, account + + def _create_app(self, session: Session, *, tenant: Tenant, account: Account, workflow_id: str | None = None) -> App: + app = App( + tenant_id=tenant.id, + name=f"App {uuid4()}", + description="", + mode="workflow", + icon_type="emoji", + icon="bot", + icon_background="#FFFFFF", + enable_site=False, + enable_api=True, + api_rpm=100, + api_rph=100, + is_demo=False, + is_public=False, + is_universal=False, + created_by=account.id, + updated_by=account.id, + workflow_id=workflow_id, + ) + session.add(app) + session.flush() + return app + + def _create_workflow( + self, session: Session, *, tenant: Tenant, app: App, account: Account, version: str = "1.0" + ) -> Workflow: + workflow = Workflow( + id=str(uuid4()), + tenant_id=tenant.id, + app_id=app.id, + type="workflow", + version=version, + graph=json.dumps({"nodes": [], "edges": []}), + _features=json.dumps({}), + created_by=account.id, + updated_by=account.id, + ) + session.add(workflow) + session.flush() + return workflow + + def _create_tool_provider( + self, session: Session, *, tenant: Tenant, app: App, account: Account, version: str + ) -> WorkflowToolProvider: + provider = WorkflowToolProvider( + name=f"tool-{uuid4()}", + label=f"Tool {uuid4()}", + icon="wrench", + app_id=app.id, + version=version, + user_id=account.id, + tenant_id=tenant.id, + description="test tool provider", + ) + session.add(provider) + session.flush() + return provider + + def test_delete_workflow_success(self, db_session_with_containers): + tenant, account = self._create_tenant_and_account(db_session_with_containers) + app = self._create_app(db_session_with_containers, tenant=tenant, account=account) + workflow = self._create_workflow( + db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0" + ) + db_session_with_containers.commit() + workflow_id = workflow.id + + service = WorkflowService(sessionmaker(bind=db.engine)) + result = service.delete_workflow( + session=db_session_with_containers, workflow_id=workflow_id, tenant_id=tenant.id + ) + + assert result is True + db_session_with_containers.expire_all() + assert db_session_with_containers.get(Workflow, workflow_id) is None + + def test_delete_draft_workflow_raises_error(self, db_session_with_containers): + tenant, account = self._create_tenant_and_account(db_session_with_containers) + app = self._create_app(db_session_with_containers, tenant=tenant, account=account) + workflow = self._create_workflow( + db_session_with_containers, tenant=tenant, app=app, account=account, version="draft" + ) + db_session_with_containers.commit() + + service = WorkflowService(sessionmaker(bind=db.engine)) + with pytest.raises(DraftWorkflowDeletionError): + service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id) + + def test_delete_workflow_in_use_by_app_raises_error(self, db_session_with_containers): + tenant, account = self._create_tenant_and_account(db_session_with_containers) + app = self._create_app(db_session_with_containers, tenant=tenant, account=account) + workflow = self._create_workflow( + db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0" + ) + # Point app to this workflow + app.workflow_id = workflow.id + db_session_with_containers.commit() + + service = WorkflowService(sessionmaker(bind=db.engine)) + with pytest.raises(WorkflowInUseError, match="currently in use by app"): + service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id) + + def test_delete_workflow_published_as_tool_raises_error(self, db_session_with_containers): + tenant, account = self._create_tenant_and_account(db_session_with_containers) + app = self._create_app(db_session_with_containers, tenant=tenant, account=account) + workflow = self._create_workflow( + db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0" + ) + self._create_tool_provider(db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0") + db_session_with_containers.commit() + + service = WorkflowService(sessionmaker(bind=db.engine)) + with pytest.raises(WorkflowInUseError, match="published as a tool"): + service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id) diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py index af9e8d0b2cd..7c43bf676b0 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py @@ -1,10 +1,10 @@ from datetime import datetime, timedelta from uuid import uuid4 +from graphon.enums import WorkflowNodeExecutionStatus from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker -from dify_graph.enums import WorkflowNodeExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel diff --git a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py index 94173c34bff..4b04c1accb1 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py @@ -4,7 +4,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetAutoDisableLog, Document, DocumentSegment @@ -81,7 +81,7 @@ class TestAddDocumentToIndexTask: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py index 210d9eb39eb..6cbbe431370 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py @@ -13,6 +13,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.storage.storage_type import StorageType from libs.datetime_utils import naive_utc_now from models import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -152,7 +153,7 @@ class TestBatchCleanDocumentTask: created_from=DocumentCreatedFrom.WEB, created_by=account.id, indexing_status=IndexingStatus.COMPLETED, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) @@ -392,7 +393,12 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Execute the task with non-existent dataset - batch_clean_document_task(document_ids=[document_id], dataset_id=dataset_id, doc_form="text_model", file_ids=[]) + batch_clean_document_task( + document_ids=[document_id], + dataset_id=dataset_id, + doc_form=IndexStructureType.PARAGRAPH_INDEX, + file_ids=[], + ) # Verify that no index processing occurred mock_external_service_dependencies["index_processor"].clean.assert_not_called() @@ -525,7 +531,11 @@ class TestBatchCleanDocumentTask: account = self._create_test_account(db_session_with_containers) # Test different doc_form types - doc_forms = ["text_model", "qa_model", "hierarchical_model"] + doc_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + ] for doc_form in doc_forms: dataset = self._create_test_dataset(db_session_with_containers, account) diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index 202ccb0098f..f9ae33b32ff 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -19,6 +19,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.storage.storage_type import StorageType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -53,7 +54,10 @@ class TestBatchCreateSegmentToIndexTask: """Mock setup for external service dependencies.""" with ( patch("tasks.batch_create_segment_to_index_task.storage", autospec=True) as mock_storage, - patch("tasks.batch_create_segment_to_index_task.ModelManager", autospec=True) as mock_model_manager, + patch( + "tasks.batch_create_segment_to_index_task.ModelManager.for_tenant", + autospec=True, + ) as mock_model_manager, patch("tasks.batch_create_segment_to_index_task.VectorService", autospec=True) as mock_vector_service, ): # Setup default mock returns @@ -141,7 +145,7 @@ class TestBatchCreateSegmentToIndexTask: name=fake.company(), description=fake.text(), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model="text-embedding-ada-002", embedding_model_provider="openai", created_by=account.id, @@ -179,7 +183,7 @@ class TestBatchCreateSegmentToIndexTask: indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=0, ) @@ -221,17 +225,17 @@ class TestBatchCreateSegmentToIndexTask: return upload_file - def _create_test_csv_content(self, content_type="text_model"): + def _create_test_csv_content(self, content_type=IndexStructureType.PARAGRAPH_INDEX): """ Helper method to create test CSV content. Args: - content_type: Type of content to create ("text_model" or "qa_model") + content_type: Type of content to create (IndexStructureType.PARAGRAPH_INDEX or IndexStructureType.QA_INDEX) Returns: str: CSV content as string """ - if content_type == "qa_model": + if content_type == IndexStructureType.QA_INDEX: csv_content = "content,answer\n" csv_content += "This is the first segment content,This is the first answer\n" csv_content += "This is the second segment content,This is the second answer\n" @@ -264,7 +268,7 @@ class TestBatchCreateSegmentToIndexTask: upload_file = self._create_test_upload_file(db_session_with_containers, account, tenant) # Create CSV content - csv_content = self._create_test_csv_content("text_model") + csv_content = self._create_test_csv_content(IndexStructureType.PARAGRAPH_INDEX) # Mock storage to return our CSV content mock_storage = mock_external_service_dependencies["storage"] @@ -451,7 +455,7 @@ class TestBatchCreateSegmentToIndexTask: indexing_status=IndexingStatus.COMPLETED, enabled=False, # Document is disabled archived=False, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=0, ), # Archived document @@ -467,7 +471,7 @@ class TestBatchCreateSegmentToIndexTask: indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=True, # Document is archived - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=0, ), # Document with incomplete indexing @@ -483,7 +487,7 @@ class TestBatchCreateSegmentToIndexTask: indexing_status=IndexingStatus.INDEXING, # Not completed enabled=True, archived=False, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=0, ), ] @@ -655,7 +659,7 @@ class TestBatchCreateSegmentToIndexTask: db_session_with_containers.commit() # Create CSV content - csv_content = self._create_test_csv_content("text_model") + csv_content = self._create_test_csv_content(IndexStructureType.PARAGRAPH_INDEX) # Mock storage to return our CSV content mock_storage = mock_external_service_dependencies["storage"] diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py index 1cd698b870a..1dd37fbc92c 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py @@ -18,6 +18,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.storage.storage_type import StorageType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import ( @@ -153,7 +154,7 @@ class TestCleanDatasetTask: tenant_id=tenant.id, name="test_dataset", description="Test dataset for cleanup testing", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=str(uuid.uuid4()), created_by=account.id, @@ -192,7 +193,7 @@ class TestCleanDatasetTask: indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=100, created_at=datetime.now(), updated_at=datetime.now(), @@ -869,7 +870,7 @@ class TestCleanDatasetTask: tenant_id=tenant.id, name=long_name, description=long_description, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph", "max_length": 10000}', collection_binding_id=str(uuid.uuid4()), created_by=account.id, diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py index a2a190fd69f..926c839c8b6 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py @@ -12,6 +12,7 @@ from unittest.mock import Mock, patch import pytest from faker import Faker +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from services.account_service import AccountService, TenantService @@ -114,7 +115,7 @@ class TestCleanNotionDocumentTask: name=f"Notion Page {i}", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", # Set doc_form to ensure dataset.doc_form works + doc_form=IndexStructureType.PARAGRAPH_INDEX, # Set doc_form to ensure dataset.doc_form works doc_language="en", indexing_status=IndexingStatus.COMPLETED, ) @@ -261,7 +262,7 @@ class TestCleanNotionDocumentTask: # Test different index types # Note: Only testing text_model to avoid dependency on external services - index_types = ["text_model"] + index_types = [IndexStructureType.PARAGRAPH_INDEX] for index_type in index_types: # Create dataset (doc_form will be set via document creation) diff --git a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py index 132f43c3208..9f8e37fc9ec 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py @@ -12,6 +12,7 @@ from uuid import uuid4 import pytest from faker import Faker +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -120,7 +121,7 @@ class TestCreateSegmentToIndexTask: description=fake.text(max_nb_chars=100), tenant_id=tenant_id, data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", created_by=account_id, @@ -141,7 +142,7 @@ class TestCreateSegmentToIndexTask: enabled=True, archived=False, indexing_status=IndexingStatus.COMPLETED, - doc_form="qa_model", + doc_form=IndexStructureType.QA_INDEX, ) db_session_with_containers.add(document) db_session_with_containers.commit() @@ -301,7 +302,7 @@ class TestCreateSegmentToIndexTask: enabled=True, archived=False, indexing_status=IndexingStatus.COMPLETED, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) db_session_with_containers.commit() @@ -552,7 +553,11 @@ class TestCreateSegmentToIndexTask: - Processing completes successfully for different forms """ # Arrange: Test different doc_forms - doc_forms = ["qa_model", "text_model", "web_model"] + doc_forms = [ + IndexStructureType.QA_INDEX, + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.PARAGRAPH_INDEX, + ] for doc_form in doc_forms: # Create fresh test data for each form diff --git a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py index 67f9dc70113..13ea94348a6 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py @@ -8,6 +8,7 @@ import pytest from faker import Faker from core.indexing_runner import DocumentIsPausedError +from core.rag.index_processor.constant.index_type import IndexTechniqueType from enums.cloud_plan import CloudPlan from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document @@ -141,7 +142,7 @@ class TestDatasetIndexingTaskIntegration: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py index e80b37ac1b4..d457b59d588 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py @@ -12,6 +12,7 @@ from unittest.mock import ANY, Mock, patch import pytest from faker import Faker +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from services.account_service import AccountService, TenantService @@ -107,7 +108,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -167,7 +168,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -187,7 +188,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -268,7 +269,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="parent_child_index", + doc_form=IndexStructureType.PARENT_CHILD_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -288,7 +289,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="parent_child_index", + doc_form=IndexStructureType.PARENT_CHILD_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -416,7 +417,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -505,7 +506,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -525,7 +526,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -601,7 +602,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="qa_index", + doc_form=IndexStructureType.QA_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -638,7 +639,7 @@ class TestDealDatasetVectorIndexTask: assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor was initialized with custom index type - mock_index_processor_factory.assert_called_once_with("qa_index") + mock_index_processor_factory.assert_called_once_with(IndexStructureType.QA_INDEX) mock_factory = mock_index_processor_factory.return_value mock_processor = mock_factory.init_index_processor.return_value mock_processor.load.assert_called_once() @@ -677,7 +678,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -714,7 +715,7 @@ class TestDealDatasetVectorIndexTask: assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor was initialized with the document's index type - mock_index_processor_factory.assert_called_once_with("text_model") + mock_index_processor_factory.assert_called_once_with(IndexStructureType.PARAGRAPH_INDEX) mock_factory = mock_index_processor_factory.return_value mock_processor = mock_factory.init_index_processor.return_value mock_processor.load.assert_called_once() @@ -753,7 +754,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -775,7 +776,7 @@ class TestDealDatasetVectorIndexTask: name=f"Test Document {i}", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -856,7 +857,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -876,7 +877,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -953,7 +954,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -973,7 +974,7 @@ class TestDealDatasetVectorIndexTask: name="Enabled Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -992,7 +993,7 @@ class TestDealDatasetVectorIndexTask: name="Disabled Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=False, # This document should be skipped @@ -1074,7 +1075,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -1094,7 +1095,7 @@ class TestDealDatasetVectorIndexTask: name="Active Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -1113,7 +1114,7 @@ class TestDealDatasetVectorIndexTask: name="Archived Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -1195,7 +1196,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -1215,7 +1216,7 @@ class TestDealDatasetVectorIndexTask: name="Completed Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -1234,7 +1235,7 @@ class TestDealDatasetVectorIndexTask: name="Incomplete Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.INDEXING, # This document should be skipped enabled=True, diff --git a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py index 6fc2a53f9cd..8a69707b38b 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py @@ -12,7 +12,7 @@ from unittest.mock import MagicMock, patch from faker import Faker -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models import Account, Dataset, Document, DocumentSegment, Tenant from models.enums import DataSourceType, DocumentCreatedFrom, DocumentDocType, IndexingStatus, SegmentStatus from tasks.delete_segment_from_index_task import delete_segment_from_index_task @@ -108,7 +108,7 @@ class TestDeleteSegmentFromIndexTask: dataset.provider = "vendor" dataset.permission = "only_me" dataset.data_source_type = DataSourceType.UPLOAD_FILE - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.index_struct = '{"type": "paragraph"}' dataset.created_by = account.id dataset.created_at = fake.date_time_this_year() diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py index da42fc7167c..5bdf7d1389c 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py @@ -15,6 +15,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -99,7 +100,7 @@ class TestDisableSegmentFromIndexTask: name=fake.sentence(nb_words=3), description=fake.text(max_nb_chars=200), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -113,7 +114,7 @@ class TestDisableSegmentFromIndexTask: dataset: Dataset, tenant: Tenant, account: Account, - doc_form: str = "text_model", + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, ) -> Document: """ Helper method to create a test document. @@ -476,7 +477,11 @@ class TestDisableSegmentFromIndexTask: - Index processor clean method is called correctly """ # Test different document forms - doc_forms = ["text_model", "qa_model", "table_model"] + doc_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + ] for doc_form in doc_forms: # Arrange: Create test data for each form diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py index 4bc9bb47496..3e9a0c8f7f0 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -11,6 +11,7 @@ from unittest.mock import MagicMock, patch from faker import Faker from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models import Account, Dataset, DocumentSegment from models import Document as DatasetDocument from models.dataset import DatasetProcessRule @@ -102,7 +103,7 @@ class TestDisableSegmentsFromIndexTask: provider="vendor", permission="only_me", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, updated_by=account.id, embedding_model="text-embedding-ada-002", @@ -153,7 +154,7 @@ class TestDisableSegmentsFromIndexTask: document.indexing_status = "completed" document.enabled = True document.archived = False - document.doc_form = "text_model" # Use text_model form for testing + document.doc_form = IndexStructureType.PARAGRAPH_INDEX # Use text_model form for testing document.doc_language = "en" db_session_with_containers.add(document) db_session_with_containers.commit() @@ -500,7 +501,11 @@ class TestDisableSegmentsFromIndexTask: segment_ids = [segment.id for segment in segments] # Test different document forms - doc_forms = ["text_model", "qa_model", "hierarchical_model"] + doc_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + ] for doc_form in doc_forms: # Update document form diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py index 6a17a19a548..d4021143eff 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py @@ -14,6 +14,7 @@ from uuid import uuid4 import pytest from core.indexing_runner import DocumentIsPausedError, IndexingRunner +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus @@ -56,7 +57,7 @@ class DocumentIndexingSyncTaskTestDataFactory: name=f"dataset-{uuid4()}", description="sync test dataset", data_source_type=DataSourceType.NOTION_IMPORT, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, ) db_session_with_containers.add(dataset) @@ -85,7 +86,7 @@ class DocumentIndexingSyncTaskTestDataFactory: created_by=created_by, indexing_status=indexing_status, enabled=True, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", ) db_session_with_containers.add(document) diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py index 9421b072853..cf1a8666f35 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py @@ -5,6 +5,7 @@ import pytest from faker import Faker from core.entities.document_task import DocumentTask +from core.rag.index_processor.constant.index_type import IndexTechniqueType from enums.cloud_plan import CloudPlan from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document @@ -99,7 +100,7 @@ class TestDocumentIndexingTasks: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -181,7 +182,7 @@ class TestDocumentIndexingTasks: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py index 2fbea1388cd..d94abf2b40a 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus @@ -63,7 +64,7 @@ class TestDocumentIndexingUpdateTask: name=fake.company(), description=fake.text(max_nb_chars=64), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -80,7 +81,7 @@ class TestDocumentIndexingUpdateTask: created_by=account.id, indexing_status=IndexingStatus.WAITING, enabled=True, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) db_session_with_containers.commit() diff --git a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py index f1f5a4b1053..6a8e1869580 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py @@ -4,6 +4,7 @@ import pytest from faker import Faker from core.indexing_runner import DocumentIsPausedError +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from enums.cloud_plan import CloudPlan from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -109,7 +110,7 @@ class TestDuplicateDocumentIndexingTasks: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -130,7 +131,7 @@ class TestDuplicateDocumentIndexingTasks: created_by=account.id, indexing_status=IndexingStatus.WAITING, enabled=True, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) documents.append(document) @@ -244,7 +245,7 @@ class TestDuplicateDocumentIndexingTasks: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -265,7 +266,7 @@ class TestDuplicateDocumentIndexingTasks: created_by=account.id, indexing_status=IndexingStatus.WAITING, enabled=True, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) documents.append(document) @@ -524,7 +525,7 @@ class TestDuplicateDocumentIndexingTasks: created_by=dataset.created_by, indexing_status=IndexingStatus.WAITING, enabled=True, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) extra_documents.append(document) diff --git a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py index 54b50016a8a..e2f35067e3d 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py @@ -4,7 +4,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -81,7 +81,7 @@ class TestEnableSegmentsToIndexTask: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py index 0876a39f821..a16f3ff773b 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py @@ -3,22 +3,22 @@ from datetime import UTC, datetime from unittest.mock import patch import pytest +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.entities import HumanInputNodeData +from graphon.runtime import GraphRuntimeState, VariablePool from configs import dify_config from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - HumanInputNodeData, MemberRecipient, ) -from dify_graph.runtime import GraphRuntimeState, VariablePool from extensions.ext_storage import storage from models.account import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom @@ -79,9 +79,9 @@ def _build_form(db_session_with_containers, tenant, account, *, app_id: str, wor delivery_method = EmailDeliveryMethod( config=EmailDeliveryConfig( recipients=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(user_id=account.id), + MemberRecipient(reference_id=account.id), ExternalRecipient(email="external@example.com"), ], ), @@ -96,9 +96,8 @@ def _build_form(db_session_with_containers, tenant, account, *, app_id: str, wor delivery_methods=[delivery_method], ) - repo = HumanInputFormRepositoryImpl(tenant_id=tenant.id) + repo = HumanInputFormRepositoryImpl(tenant_id=tenant.id, app_id=app_id) params = FormCreateParams( - app_id=app_id, workflow_execution_id=workflow_execution_id, node_id="node-1", form_config=node_data, diff --git a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py index 5bded4d6705..96cf9cebf5e 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -2,10 +2,10 @@ import uuid from unittest.mock import ANY, call, patch import pytest +from graphon.variables.segments import StringSegment +from graphon.variables.types import SegmentType from core.db.session_factory import session_factory -from dify_graph.variables.segments import StringSegment -from dify_graph.variables.types import SegmentType from extensions.storage.storage_type import StorageType from libs.datetime_utils import naive_utc_now from models import Tenant diff --git a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py index ca76fa0a4b5..159ab51304a 100644 --- a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py +++ b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py @@ -24,11 +24,11 @@ from dataclasses import dataclass from datetime import timedelta import pytest +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus from sqlalchemy import delete, select from sqlalchemy.orm import Session, selectinload, sessionmaker -from dify_graph.entities import WorkflowExecution -from dify_graph.enums import WorkflowExecutionStatus from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from models import Account diff --git a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py index 4ea8d8c1c73..7539bae6855 100644 --- a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py +++ b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py @@ -10,6 +10,7 @@ from typing import Any import pytest from flask import Flask, Response from flask.testing import FlaskClient +from graphon.enums import BuiltinNodeTypes from sqlalchemy.orm import Session from configs import dify_config @@ -23,7 +24,6 @@ from core.trigger.debug import event_selectors from core.trigger.debug.event_bus import TriggerDebugEventBus from core.trigger.debug.event_selectors import PluginTriggerDebugEventPoller, WebhookTriggerDebugEventPoller from core.trigger.debug.events import PluginTriggerDebugEvent, build_plugin_pool_key -from dify_graph.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models.account import Account, Tenant from models.enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTriggerStatus diff --git a/api/tests/unit_tests/conftest.py b/api/tests/unit_tests/conftest.py index 3f75fd2851f..55873b06a8a 100644 --- a/api/tests/unit_tests/conftest.py +++ b/api/tests/unit_tests/conftest.py @@ -123,27 +123,26 @@ def _configure_session_factory(_unit_test_engine): def setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account): """ - Helper to set up the mock DB query chain for tenant/account authentication. + Helper to set up the mock DB execute chain for tenant/account authentication. - This configures the mock to return (tenant, account) for the join query used - by validate_app_token and validate_dataset_token decorators. + This configures the mock to return (tenant, account) for the + db.session.execute(select(...).join().join().where()).one_or_none() + query used by validate_app_token decorator. Args: mock_db: The mocked db object mock_tenant: Mock tenant object to return mock_account: Mock account object to return """ - query = mock_db.session.query.return_value - join_chain = query.join.return_value.join.return_value - where_chain = join_chain.where.return_value - where_chain.one_or_none.return_value = (mock_tenant, mock_account) + mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_account) def setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta): """ - Helper to set up the mock DB query chain for dataset tenant authentication. + Helper to set up the mock DB execute chain for dataset tenant authentication. - This configures the mock to return (tenant, tenant_account) for the where chain + This configures the mock to return (tenant, tenant_account) for the + db.session.execute(select(...).where().where().where().where()).one_or_none() query used by validate_dataset_token decorator. Args: @@ -151,6 +150,4 @@ def setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta): mock_tenant: Mock tenant object to return mock_ta: Mock tenant account object to return """ - query = mock_db.session.query.return_value - where_chain = query.where.return_value.where.return_value.where.return_value.where.return_value - where_chain.one_or_none.return_value = (mock_tenant, mock_ta) + mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta) diff --git a/api/tests/unit_tests/controllers/console/app/test_app_apis.py b/api/tests/unit_tests/controllers/console/app/test_app_apis.py index 60b8ee96fec..1d1e119fd61 100644 --- a/api/tests/unit_tests/controllers/console/app/test_app_apis.py +++ b/api/tests/unit_tests/controllers/console/app/test_app_apis.py @@ -7,14 +7,19 @@ from __future__ import annotations import uuid from types import SimpleNamespace -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest +from pydantic import ValidationError from werkzeug.exceptions import BadRequest, NotFound +from controllers.console import console_ns from controllers.console.app import ( annotation as annotation_module, ) +from controllers.console.app import ( + app as app_module, +) from controllers.console.app import ( completion as completion_module, ) @@ -203,6 +208,48 @@ class TestCompletionEndpoints: method(app_model=MagicMock(id="app-1")) +class TestAppEndpoints: + """Tests for app endpoints.""" + + def test_app_put_should_preserve_icon_type_when_payload_omits_it(self, app, monkeypatch): + api = app_module.AppApi() + method = _unwrap(api.put) + payload = { + "name": "Updated App", + "description": "Updated description", + "icon": "🤖", + "icon_background": "#FFFFFF", + } + app_service = MagicMock() + app_service.update_app.return_value = SimpleNamespace() + response_model = MagicMock() + response_model.model_dump.return_value = {"id": "app-1"} + + monkeypatch.setattr(app_module, "AppService", lambda: app_service) + monkeypatch.setattr(app_module.AppDetailWithSite, "model_validate", MagicMock(return_value=response_model)) + + with ( + app.test_request_context("/console/api/apps/app-1", method="PUT", json=payload), + patch.object(type(console_ns), "payload", payload), + ): + response = method(app_model=SimpleNamespace(icon_type=app_module.IconType.EMOJI)) + + assert response == {"id": "app-1"} + assert app_service.update_app.call_args.args[1]["icon_type"] is None + + def test_update_app_payload_should_reject_empty_icon_type(self): + with pytest.raises(ValidationError): + app_module.UpdateAppPayload.model_validate( + { + "name": "Updated App", + "description": "Updated description", + "icon_type": "", + "icon": "🤖", + "icon_background": "#FFFFFF", + } + ) + + # ========== OpsTrace Tests ========== class TestOpsTraceEndpoints: """Tests for ops_trace endpoint.""" @@ -281,12 +328,10 @@ class TestSiteEndpoints: method = _unwrap(api.post) site = MagicMock() - query = MagicMock() - query.where.return_value.first.return_value = site monkeypatch.setattr( site_module.db, "session", - MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None), + MagicMock(scalar=lambda *_args, **_kwargs: site, commit=lambda: None), ) monkeypatch.setattr( site_module, @@ -305,12 +350,10 @@ class TestSiteEndpoints: method = _unwrap(api.post) site = MagicMock() - query = MagicMock() - query.where.return_value.first.return_value = site monkeypatch.setattr( site_module.db, "session", - MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None), + MagicMock(scalar=lambda *_args, **_kwargs: site, commit=lambda: None), ) monkeypatch.setattr(site_module.Site, "generate_code", lambda *_args, **_kwargs: "code") monkeypatch.setattr( diff --git a/api/tests/unit_tests/controllers/console/app/test_audio.py b/api/tests/unit_tests/controllers/console/app/test_audio.py index 021e9a07840..c52bc02420e 100644 --- a/api/tests/unit_tests/controllers/console/app/test_audio.py +++ b/api/tests/unit_tests/controllers/console/app/test_audio.py @@ -4,6 +4,7 @@ import io from types import SimpleNamespace import pytest +from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.datastructures import FileStorage from werkzeug.exceptions import InternalServerError @@ -20,7 +21,6 @@ from controllers.console.app.error import ( UnsupportedAudioTypeError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError from services.audio_service import AudioService from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.audio import ( diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_api.py b/api/tests/unit_tests/controllers/console/app/test_conversation_api.py index 5db8e5c3321..11b3b3470d6 100644 --- a/api/tests/unit_tests/controllers/console/app/test_conversation_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_conversation_api.py @@ -82,12 +82,8 @@ def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: p def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> None: conversation = SimpleNamespace(id="c1", app_id="app-1") - query = MagicMock() - query.where.return_value = query - query.first.return_value = conversation - session = MagicMock() - session.query.return_value = query + session.scalar.return_value = conversation monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1")) monkeypatch.setattr(conversation_module.db, "session", session) @@ -101,12 +97,8 @@ def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> No def test_get_conversation_missing_raises_not_found(monkeypatch: pytest.MonkeyPatch) -> None: - query = MagicMock() - query.where.return_value = query - query.first.return_value = None - session = MagicMock() - session.query.return_value = query + session.scalar.return_value = None monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1")) monkeypatch.setattr(conversation_module.db, "session", session) diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py b/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py index 460da06ecc8..f588ab261d8 100644 --- a/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py +++ b/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py @@ -24,7 +24,7 @@ def test_get_conversation_mark_read_keeps_updated_at_unchanged(): ), patch("controllers.console.app.conversation.db.session", autospec=True) as mock_session, ): - mock_session.query.return_value.where.return_value.first.return_value = conversation + mock_session.scalar.return_value = conversation _get_conversation(app_model, "conversation-id") diff --git a/api/tests/unit_tests/controllers/console/app/test_generator_api.py b/api/tests/unit_tests/controllers/console/app/test_generator_api.py index f83bc18da31..e64c508b822 100644 --- a/api/tests/unit_tests/controllers/console/app/test_generator_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_generator_api.py @@ -73,8 +73,7 @@ def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: None) - monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: None)) with app.test_request_context( "/console/api/instruction-generate", @@ -99,8 +98,7 @@ def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.Monkey monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) app_model = SimpleNamespace(id="app-1") - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model)) _install_workflow_service(monkeypatch, workflow=None) with app.test_request_context( @@ -126,8 +124,7 @@ def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch) monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) app_model = SimpleNamespace(id="app-1") - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model)) workflow = SimpleNamespace(graph_dict={"nodes": []}) _install_workflow_service(monkeypatch, workflow=workflow) @@ -155,8 +152,7 @@ def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) -> monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) app_model = SimpleNamespace(id="app-1") - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model)) workflow = SimpleNamespace( graph_dict={ diff --git a/api/tests/unit_tests/controllers/console/app/test_model_config_api.py b/api/tests/unit_tests/controllers/console/app/test_model_config_api.py index 61d92bb5c76..a0e2edb8cf6 100644 --- a/api/tests/unit_tests/controllers/console/app/test_model_config_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_model_config_api.py @@ -92,10 +92,7 @@ def test_post_encrypts_agent_tool_parameters(app, monkeypatch: pytest.MonkeyPatc ) session = MagicMock() - query = MagicMock() - query.where.return_value = query - query.first.return_value = original_config - session.query.return_value = query + session.get.return_value = original_config monkeypatch.setattr(model_config_module.db, "session", session) monkeypatch.setattr( diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow.py b/api/tests/unit_tests/controllers/console/app/test_workflow.py index f100080eaa3..36076368805 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow.py @@ -5,12 +5,11 @@ from types import SimpleNamespace from unittest.mock import Mock import pytest +from graphon.file import File, FileTransferMethod, FileType from werkzeug.exceptions import HTTPException, NotFound from controllers.console.app import workflow as workflow_module from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync -from dify_graph.file.enums import FileTransferMethod, FileType -from dify_graph.file.models import File def _unwrap(func): @@ -129,6 +128,136 @@ def test_sync_draft_workflow_hash_mismatch(app, monkeypatch: pytest.MonkeyPatch) handler(api, app_model=SimpleNamespace(id="app")) +def test_restore_published_workflow_to_draft_success(app, monkeypatch: pytest.MonkeyPatch) -> None: + workflow = SimpleNamespace( + unique_hash="restored-hash", + updated_at=None, + created_at=datetime(2024, 1, 1), + ) + user = SimpleNamespace(id="account-1") + + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1")) + monkeypatch.setattr( + workflow_module, + "WorkflowService", + lambda: SimpleNamespace(restore_published_workflow_to_draft=lambda **_kwargs: workflow), + ) + + api = workflow_module.DraftWorkflowRestoreApi() + handler = _unwrap(api.post) + + with app.test_request_context( + "/apps/app/workflows/published-workflow/restore", + method="POST", + ): + response = handler( + api, + app_model=SimpleNamespace(id="app", tenant_id="tenant-1"), + workflow_id="published-workflow", + ) + + assert response["result"] == "success" + assert response["hash"] == "restored-hash" + + +def test_restore_published_workflow_to_draft_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None: + user = SimpleNamespace(id="account-1") + + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1")) + monkeypatch.setattr( + workflow_module, + "WorkflowService", + lambda: SimpleNamespace( + restore_published_workflow_to_draft=lambda **_kwargs: (_ for _ in ()).throw( + workflow_module.WorkflowNotFoundError("Workflow not found") + ) + ), + ) + + api = workflow_module.DraftWorkflowRestoreApi() + handler = _unwrap(api.post) + + with app.test_request_context( + "/apps/app/workflows/published-workflow/restore", + method="POST", + ): + with pytest.raises(NotFound): + handler( + api, + app_model=SimpleNamespace(id="app", tenant_id="tenant-1"), + workflow_id="published-workflow", + ) + + +def test_restore_published_workflow_to_draft_returns_400_for_draft_source(app, monkeypatch: pytest.MonkeyPatch) -> None: + user = SimpleNamespace(id="account-1") + + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1")) + monkeypatch.setattr( + workflow_module, + "WorkflowService", + lambda: SimpleNamespace( + restore_published_workflow_to_draft=lambda **_kwargs: (_ for _ in ()).throw( + workflow_module.IsDraftWorkflowError( + "Cannot use draft workflow version. Workflow ID: draft-workflow. " + "Please use a published workflow version or leave workflow_id empty." + ) + ) + ), + ) + + api = workflow_module.DraftWorkflowRestoreApi() + handler = _unwrap(api.post) + + with app.test_request_context( + "/apps/app/workflows/draft-workflow/restore", + method="POST", + ): + with pytest.raises(HTTPException) as exc: + handler( + api, + app_model=SimpleNamespace(id="app", tenant_id="tenant-1"), + workflow_id="draft-workflow", + ) + + assert exc.value.code == 400 + assert exc.value.description == workflow_module.RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE + + +def test_restore_published_workflow_to_draft_returns_400_for_invalid_structure( + app, monkeypatch: pytest.MonkeyPatch +) -> None: + user = SimpleNamespace(id="account-1") + + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1")) + monkeypatch.setattr( + workflow_module, + "WorkflowService", + lambda: SimpleNamespace( + restore_published_workflow_to_draft=lambda **_kwargs: (_ for _ in ()).throw( + ValueError("invalid workflow graph") + ) + ), + ) + + api = workflow_module.DraftWorkflowRestoreApi() + handler = _unwrap(api.post) + + with app.test_request_context( + "/apps/app/workflows/published-workflow/restore", + method="POST", + ): + with pytest.raises(HTTPException) as exc: + handler( + api, + app_model=SimpleNamespace(id="app", tenant_id="tenant-1"), + workflow_id="published-workflow", + ) + + assert exc.value.code == 400 + assert exc.value.description == "invalid workflow graph" + + def test_draft_workflow_get_not_found(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( workflow_module, "WorkflowService", lambda: SimpleNamespace(get_draft_workflow=lambda **_k: None) diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py index 83601dc1b98..e11102acb1e 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py @@ -6,14 +6,14 @@ from unittest.mock import Mock import pytest from flask import Flask +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType from controllers.console import wraps as console_wraps from controllers.console.app import workflow_run as workflow_run_module from controllers.web.error import NotFoundError -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.nodes.human_input.entities import FormInput, UserAction -from dify_graph.nodes.human_input.enums import FormInputType from libs import login as login_lib from models.account import Account, AccountStatus, TenantAccountRole from models.workflow import WorkflowRun @@ -67,7 +67,6 @@ def test_pause_details_returns_backstage_input_url(app: Flask, monkeypatch: pyte actions=[UserAction(id="approve", title="Approve")], node_id="node-1", node_title="Ask Name", - form_token="backstage-token", ) pause_entity = _PauseEntity(paused_at=datetime(2024, 1, 1, 12, 0, 0), reasons=[reason]) @@ -78,6 +77,11 @@ def test_pause_details_returns_backstage_input_url(app: Flask, monkeypatch: pyte "create_api_workflow_run_repository", lambda *_, **__: repo, ) + monkeypatch.setattr( + workflow_run_module, + "_load_form_tokens_by_form_id", + lambda _form_ids: {"form-1": "backstage-token"}, + ) with app.test_request_context("/console/api/workflow/run-1/pause-details", method="GET"): response, status = workflow_run_module.ConsoleWorkflowPauseDetailsApi().get(workflow_run_id="run-1") diff --git a/api/tests/unit_tests/controllers/console/app/test_wraps.py b/api/tests/unit_tests/controllers/console/app/test_wraps.py index 7664e492dab..b5f751f5a5a 100644 --- a/api/tests/unit_tests/controllers/console/app/test_wraps.py +++ b/api/tests/unit_tests/controllers/console/app/test_wraps.py @@ -11,10 +11,8 @@ from models.model import AppMode def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None: app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1") - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1")) - monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(scalar=lambda *_args, **_kwargs: app_model)) @wraps_module.get_app_model def handler(app_model): @@ -25,10 +23,8 @@ def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None: def test_get_app_model_rejects_wrong_mode(monkeypatch: pytest.MonkeyPatch) -> None: app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1") - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1")) - monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(scalar=lambda *_args, **_kwargs: app_model)) @wraps_module.get_app_model(mode=[AppMode.COMPLETION]) def handler(app_model): diff --git a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py index f34702a2579..740da1f1df1 100644 --- a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py +++ b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py @@ -5,6 +5,7 @@ from unittest.mock import MagicMock, patch import pytest from flask_restx import marshal +from graphon.variables.types import SegmentType from controllers.console.app.workflow_draft_variable import ( _WORKFLOW_DRAFT_VARIABLE_FIELDS, @@ -13,8 +14,7 @@ from controllers.console.app.workflow_draft_variable import ( _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, _serialize_full_content, ) -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from dify_graph.variables.types import SegmentType +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from factories.variable_factory import build_segment from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 @@ -310,13 +310,11 @@ def test_workflow_node_variables_fields(): def test_workflow_file_variable_with_signed_url(): """Test that File type variables include signed URLs in API responses.""" - from dify_graph.file.enums import FileTransferMethod, FileType - from dify_graph.file.models import File + from graphon.file import File, FileTransferMethod, FileType # Create a File object with LOCAL_FILE transfer method (which generates signed URLs) test_file = File( id="test_file_id", - tenant_id="test_tenant_id", type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="test_upload_file_id", @@ -368,13 +366,11 @@ def test_workflow_file_variable_with_signed_url(): def test_workflow_file_variable_remote_url(): """Test that File type variables with REMOTE_URL transfer method return the remote URL.""" - from dify_graph.file.enums import FileTransferMethod, FileType - from dify_graph.file.models import File + from graphon.file import File, FileTransferMethod, FileType # Create a File object with REMOTE_URL transfer method test_file = File( id="test_file_id", - tenant_id="test_tenant_id", type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/test.jpg", diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py index 9014edc39e4..9c9f8da87c1 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from werkzeug.exceptions import Forbidden, NotFound from controllers.console import console_ns @@ -17,7 +18,6 @@ from controllers.console.datasets.rag_pipeline.datasource_auth import ( DatasourceUpdateProviderNameApi, ) from core.plugin.impl.oauth import OAuthHandler -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError from services.datasource_provider_service import DatasourceProviderService from services.plugin.oauth_service import OAuthProxyService diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py index b4c0903f63e..6ef8ccfdbd3 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch import pytest from flask import Response +from graphon.variables.types import SegmentType from controllers.console import console_ns from controllers.console.app.error import DraftWorkflowNotExist @@ -14,8 +15,7 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable impor RagPipelineVariableResetApi, ) from controllers.web.error import InvalidArgumentError, NotFoundError -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from dify_graph.variables.types import SegmentType +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID from models.account import Account diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py index 7775cbdd81a..a3c0592d766 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py @@ -2,7 +2,7 @@ from datetime import datetime from unittest.mock import MagicMock, patch import pytest -from werkzeug.exceptions import Forbidden, NotFound +from werkzeug.exceptions import BadRequest, Forbidden, HTTPException, NotFound import services from controllers.console import console_ns @@ -19,13 +19,15 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline_workflow import ( RagPipelineDraftNodeRunApi, RagPipelineDraftRunIterationNodeApi, RagPipelineDraftRunLoopNodeApi, + RagPipelineDraftWorkflowRestoreApi, RagPipelineRecommendedPluginApi, RagPipelineTaskStopApi, RagPipelineTransformApi, RagPipelineWorkflowLastRunApi, ) from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError -from services.errors.app import WorkflowHashNotEqualError +from libs.datetime_utils import naive_utc_now +from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError @@ -116,6 +118,86 @@ class TestDraftWorkflowApi: response, status = method(api, pipeline) assert status == 400 + def test_restore_published_workflow_to_draft_success(self, app): + api = RagPipelineDraftWorkflowRestoreApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock(id="account-1") + workflow = MagicMock(unique_hash="restored-hash", updated_at=None, created_at=datetime(2024, 1, 1)) + + service = MagicMock() + service.restore_published_workflow_to_draft.return_value = workflow + + with ( + app.test_request_context("/", method="POST"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + result = method(api, pipeline, "published-workflow") + + assert result["result"] == "success" + assert result["hash"] == "restored-hash" + + def test_restore_published_workflow_to_draft_not_found(self, app): + api = RagPipelineDraftWorkflowRestoreApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock(id="account-1") + + service = MagicMock() + service.restore_published_workflow_to_draft.side_effect = WorkflowNotFoundError("Workflow not found") + + with ( + app.test_request_context("/", method="POST"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + with pytest.raises(NotFound): + method(api, pipeline, "published-workflow") + + def test_restore_published_workflow_to_draft_returns_400_for_draft_source(self, app): + api = RagPipelineDraftWorkflowRestoreApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock(id="account-1") + + service = MagicMock() + service.restore_published_workflow_to_draft.side_effect = IsDraftWorkflowError( + "source workflow must be published" + ) + + with ( + app.test_request_context("/", method="POST"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + with pytest.raises(HTTPException) as exc: + method(api, pipeline, "draft-workflow") + + assert exc.value.code == 400 + assert exc.value.description == "source workflow must be published" + class TestDraftRunNodes: def test_iteration_node_success(self, app): @@ -291,7 +373,7 @@ class TestPublishedPipelineApis: workflow = MagicMock( id="w1", - created_at=datetime.utcnow(), + created_at=naive_utc_now(), ) session = MagicMock() @@ -610,6 +692,57 @@ class TestRagPipelineByIdApi: result, status = method(api, pipeline, "w1") assert status == 400 + def test_delete_success(self, app): + api = RagPipelineByIdApi() + method = unwrap(api.delete) + + pipeline = MagicMock(tenant_id="t1", workflow_id="active-workflow", id="pipeline-1") + + workflow_service = MagicMock() + + session = MagicMock() + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = None + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + with ( + app.test_request_context("/", method="DELETE"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session", + return_value=session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.WorkflowService", + return_value=workflow_service, + ), + ): + result = method(api, pipeline, "old-workflow") + + workflow_service.delete_workflow.assert_called_once_with( + session=session, + workflow_id="old-workflow", + tenant_id="t1", + ) + session.commit.assert_called_once() + assert result == (None, 204) + + def test_delete_active_workflow_rejected(self, app): + api = RagPipelineByIdApi() + method = unwrap(api.delete) + + pipeline = MagicMock(tenant_id="t1", workflow_id="active-workflow", id="pipeline-1") + + with app.test_request_context("/", method="DELETE"): + with pytest.raises(BadRequest, match="currently in use by pipeline"): + method(api, pipeline, "active-workflow") + class TestRagPipelineWorkflowLastRunApi: def test_last_run_success(self, app): diff --git a/api/tests/unit_tests/controllers/console/datasets/test_data_source.py b/api/tests/unit_tests/controllers/console/datasets/test_data_source.py index 3060062adff..d841f67f9ba 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_data_source.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_data_source.py @@ -11,6 +11,7 @@ from controllers.console.datasets.data_source import ( DataSourceNotionDocumentSyncApi, DataSourceNotionListApi, ) +from core.rag.index_processor.constant.index_type import IndexStructureType def unwrap(func): @@ -343,7 +344,7 @@ class TestDataSourceNotionApi: } ], "process_rule": {"rules": {}}, - "doc_form": "text_model", + "doc_form": IndexStructureType.PARAGRAPH_INDEX, "doc_language": "English", } diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py index 0ee76e504be..8555900f4ec 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py @@ -28,6 +28,7 @@ from controllers.console.datasets.datasets import ( from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.provider_manager import ProviderManager +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.storage.storage_type import StorageType from models.enums import CreatorUserRole from models.model import ApiToken, UploadFile @@ -416,7 +417,7 @@ class TestDatasetApiGet: "check_dataset_permission", return_value=None, ), - patch("controllers.console.datasets.datasets.ProviderManager") as provider_manager_mock, + patch("controllers.console.datasets.datasets.create_plugin_provider_manager") as provider_manager_mock, ): # embedding models exist → embedding_available stays True provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] @@ -520,7 +521,7 @@ class TestDatasetApiGet: "check_dataset_permission", return_value=None, ), - patch("controllers.console.datasets.datasets.ProviderManager") as provider_manager_mock, + patch("controllers.console.datasets.datasets.create_plugin_provider_manager") as provider_manager_mock, ): # embedding model NOT configured provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] @@ -579,7 +580,7 @@ class TestDatasetApiGet: "get_dataset_partial_member_list", return_value=partial_members, ), - patch("controllers.console.datasets.datasets.ProviderManager") as provider_manager_mock, + patch("controllers.console.datasets.datasets.create_plugin_provider_manager") as provider_manager_mock, ): provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] @@ -1146,7 +1147,7 @@ class TestDatasetIndexingEstimateApi: }, "process_rule": {"chunk_size": 100}, "indexing_technique": "high_quality", - "doc_form": "text_model", + "doc_form": IndexStructureType.PARAGRAPH_INDEX, "doc_language": "English", "dataset_id": None, } @@ -1475,8 +1476,8 @@ class TestDatasetIndexingStatusApi: return_value=MagicMock(all=lambda: [document]), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 3)), + "controllers.console.datasets.datasets.db.session.scalar", + return_value=3, ), ): response, status = method(api, "dataset-1") @@ -1525,13 +1526,6 @@ class TestDatasetIndexingStatusApi: document.error = None document.stopped_at = None - # First count = completed segments, second = total segments - query_mock = MagicMock() - query_mock.where.side_effect = [ - MagicMock(count=lambda: 2), - MagicMock(count=lambda: 5), - ] - with ( app.test_request_context("/"), patch( @@ -1543,8 +1537,8 @@ class TestDatasetIndexingStatusApi: return_value=MagicMock(all=lambda: [document]), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=query_mock, + "controllers.console.datasets.datasets.db.session.scalar", + side_effect=[2, 5], ), ): response, status = method(api, "dataset-1") @@ -1590,8 +1584,8 @@ class TestDatasetApiKeyApi: return_value=(MagicMock(), "tenant-1"), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 3)), + "controllers.console.datasets.datasets.db.session.scalar", + return_value=3, ), patch( "controllers.console.datasets.datasets.ApiToken.generate_api_key", @@ -1624,8 +1618,8 @@ class TestDatasetApiKeyApi: return_value=(MagicMock(), "tenant-1"), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 10)), + "controllers.console.datasets.datasets.db.session.scalar", + return_value=10, ), ): with pytest.raises(BadRequest) as exc_info: @@ -1652,8 +1646,8 @@ class TestDatasetApiDeleteApi: return_value=(MagicMock(), "tenant-1"), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(first=lambda: mock_key)), + "controllers.console.datasets.datasets.db.session.scalar", + return_value=mock_key, ), patch( "controllers.console.datasets.datasets.db.session.commit", @@ -1680,8 +1674,8 @@ class TestDatasetApiDeleteApi: return_value=(MagicMock(), "tenant-1"), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(first=lambda: None)), + "controllers.console.datasets.datasets.db.session.scalar", + return_value=None, ), ): with pytest.raises(NotFound): diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py index f23dd5b44aa..ce2278de4f8 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py @@ -30,6 +30,7 @@ from controllers.console.datasets.error import ( InvalidActionError, InvalidMetadataError, ) +from core.rag.index_processor.constant.index_type import IndexStructureType from models.enums import DataSourceType, IndexingStatus @@ -66,7 +67,7 @@ def document(): indexing_status=IndexingStatus.INDEXING, data_source_type=DataSourceType.UPLOAD_FILE, data_source_info_dict={"upload_file_id": "file-1"}, - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, archived=False, is_paused=False, dataset_process_rule=None, @@ -139,8 +140,8 @@ class TestDatasetDocumentListApi: return_value=pagination, ), patch( - "controllers.console.datasets.datasets_document.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(count=count_mock)), + "controllers.console.datasets.datasets_document.db.session.scalar", + return_value=2, ), patch( "controllers.console.datasets.datasets_document.DocumentService.enrich_documents_with_summary_index_status", @@ -699,10 +700,8 @@ class TestDocumentPipelineExecutionLogApi: return_value=MagicMock(), ), patch( - "controllers.console.datasets.datasets_document.db.session.query", - return_value=MagicMock( - filter_by=lambda **k: MagicMock(order_by=lambda *a: MagicMock(first=lambda: log)) - ), + "controllers.console.datasets.datasets_document.db.session.scalar", + return_value=log, ), ): response, status = method(api, "ds-1", "doc-1") @@ -765,8 +764,8 @@ class TestDocumentGenerateSummaryApi: summary_index_setting={"enable": True}, ) - doc1 = MagicMock(id="doc-1", doc_form="qa_model") - doc2 = MagicMock(id="doc-2", doc_form="text") + doc1 = MagicMock(id="doc-1", doc_form=IndexStructureType.QA_INDEX) + doc2 = MagicMock(id="doc-2", doc_form=IndexStructureType.PARAGRAPH_INDEX) payload = {"document_list": ["doc-1", "doc-2"]} @@ -822,19 +821,16 @@ class TestDocumentIndexingEstimateApi: data_source_type=DataSourceType.UPLOAD_FILE, data_source_info_dict={"upload_file_id": "file-1"}, tenant_id="tenant-1", - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, dataset_process_rule=None, ) - query_mock = MagicMock() - query_mock.where.return_value.first.return_value = None - with ( app.test_request_context("/"), patch.object(api, "get_document", return_value=document), patch( - "controllers.console.datasets.datasets_document.db.session.query", - return_value=query_mock, + "controllers.console.datasets.datasets_document.db.session.scalar", + return_value=None, ), ): with pytest.raises(NotFound): @@ -849,7 +845,7 @@ class TestDocumentIndexingEstimateApi: data_source_type=DataSourceType.UPLOAD_FILE, data_source_info_dict={"upload_file_id": "file-1"}, tenant_id="tenant-1", - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, dataset_process_rule=None, ) @@ -862,10 +858,8 @@ class TestDocumentIndexingEstimateApi: app.test_request_context("/"), patch.object(api, "get_document", return_value=document), patch( - "controllers.console.datasets.datasets_document.db.session.query", - return_value=MagicMock( - where=MagicMock(return_value=MagicMock(first=MagicMock(return_value=upload_file))) - ), + "controllers.console.datasets.datasets_document.db.session.scalar", + return_value=upload_file, ), patch( "controllers.console.datasets.datasets_document.ExtractSetting", @@ -973,7 +967,7 @@ class TestDocumentBatchIndexingEstimateApi: "mode": "single", "only_main_content": True, }, - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) with ( @@ -1001,7 +995,7 @@ class TestDocumentBatchIndexingEstimateApi: "notion_page_id": "p1", "type": "page", }, - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) with ( @@ -1024,7 +1018,7 @@ class TestDocumentBatchIndexingEstimateApi: indexing_status=IndexingStatus.INDEXING, data_source_type="unknown", data_source_info_dict={}, - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) with app.test_request_context("/"), patch.object(api, "get_batch_documents", return_value=[document]): @@ -1238,12 +1232,8 @@ class TestDocumentPermissionCases: return_value=None, ), patch( - "controllers.console.datasets.datasets_document.db.session.query", - return_value=MagicMock( - where=lambda *a: MagicMock( - order_by=lambda *b: MagicMock(limit=lambda n: MagicMock(one_or_none=lambda: process_rule)) - ) - ), + "controllers.console.datasets.datasets_document.db.session.scalar", + return_value=process_rule, ), ): result = method(api) @@ -1353,7 +1343,7 @@ class TestDocumentIndexingEdgeCases: data_source_type=DataSourceType.UPLOAD_FILE, data_source_info_dict={"upload_file_id": "file-1"}, tenant_id="tenant-1", - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, dataset_process_rule=None, ) @@ -1363,8 +1353,8 @@ class TestDocumentIndexingEdgeCases: app.test_request_context("/"), patch.object(api, "get_document", return_value=document), patch( - "controllers.console.datasets.datasets_document.db.session.query", - return_value=MagicMock(where=lambda *a: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_document.db.session.scalar", + return_value=upload_file, ), patch( "controllers.console.datasets.datasets_document.ExtractSetting", diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py index e67e4daad93..693b06e95bc 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py @@ -1,4 +1,3 @@ -from datetime import datetime from types import SimpleNamespace from unittest.mock import MagicMock, patch @@ -24,6 +23,8 @@ from controllers.console.datasets.error import ( InvalidActionError, ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError +from core.rag.index_processor.constant.index_type import IndexStructureType +from libs.datetime_utils import naive_utc_now from models.dataset import ChildChunk, DocumentSegment from models.model import UploadFile @@ -53,8 +54,8 @@ def _segment(): disabled_by=None, status="normal", created_by="u1", - created_at=datetime.utcnow(), - updated_at=datetime.utcnow(), + created_at=naive_utc_now(), + updated_at=naive_utc_now(), updated_by="u1", indexing_at=None, completed_at=None, @@ -366,7 +367,7 @@ class TestDatasetDocumentSegmentAddApi: dataset.indexing_technique = "economy" document = MagicMock() - document.doc_form = "text" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX segment = MagicMock() segment.id = "seg-1" @@ -505,7 +506,7 @@ class TestDatasetDocumentSegmentUpdateApi: dataset.indexing_technique = "economy" document = MagicMock() - document.doc_form = "text" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX segment = MagicMock() @@ -525,8 +526,8 @@ class TestDatasetDocumentSegmentUpdateApi: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=segment, ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", @@ -620,8 +621,8 @@ class TestDatasetDocumentSegmentBatchImportApi: return_value=MagicMock(), ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=upload_file, ), patch( "controllers.console.datasets.datasets_segments.redis_client.setnx", @@ -705,8 +706,8 @@ class TestDatasetDocumentSegmentBatchImportApi: return_value=MagicMock(), ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: None)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=None, ), ): with pytest.raises(NotFound): @@ -737,8 +738,8 @@ class TestDatasetDocumentSegmentBatchImportApi: return_value=MagicMock(), ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=upload_file, ), ): with pytest.raises(ValueError): @@ -769,8 +770,8 @@ class TestDatasetDocumentSegmentBatchImportApi: return_value=MagicMock(), ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=upload_file, ), patch( "controllers.console.datasets.datasets_segments.redis_client.setnx", @@ -830,8 +831,8 @@ class TestChildChunkAddApi: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=segment, ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", @@ -879,8 +880,8 @@ class TestChildChunkAddApi: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=segment, ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", @@ -923,11 +924,8 @@ class TestChildChunkUpdateApi: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - side_effect=[ - MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), - MagicMock(where=lambda *a, **k: MagicMock(first=lambda: child_chunk)), - ], + "controllers.console.datasets.datasets_segments.db.session.scalar", + side_effect=[segment, child_chunk], ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", @@ -969,11 +967,8 @@ class TestChildChunkUpdateApi: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - side_effect=[ - MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), - MagicMock(where=lambda *a, **k: MagicMock(first=lambda: child_chunk)), - ], + "controllers.console.datasets.datasets_segments.db.session.scalar", + side_effect=[segment, child_chunk], ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", @@ -1179,8 +1174,8 @@ class TestSegmentOperationCases: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=upload_file, ), ): with pytest.raises(NotFound): @@ -1214,8 +1209,8 @@ class TestSegmentOperationCases: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=upload_file, ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", diff --git a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py index e7ae37ae454..710c9be684c 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services @@ -20,7 +21,6 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.model_runtime.errors.invoke import InvokeError from models.account import Account from services.dataset_service import DatasetService from services.hit_testing_service import HitTestingService diff --git a/api/tests/unit_tests/controllers/console/datasets/test_wraps.py b/api/tests/unit_tests/controllers/console/datasets/test_wraps.py index 90f00711c16..e358435de4a 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_wraps.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_wraps.py @@ -26,12 +26,9 @@ class TestGetRagPipeline: return_value=(Mock(), "tenant-1"), ) - mock_query = Mock() - mock_query.where.return_value.first.return_value = None - mocker.patch( - "controllers.console.datasets.wraps.db.session.query", - return_value=mock_query, + "controllers.console.datasets.wraps.db.session.scalar", + return_value=None, ) with pytest.raises(PipelineNotFoundError): @@ -51,12 +48,9 @@ class TestGetRagPipeline: return_value=(Mock(), "tenant-1"), ) - mock_query = Mock() - mock_query.where.return_value.first.return_value = pipeline - mocker.patch( - "controllers.console.datasets.wraps.db.session.query", - return_value=mock_query, + "controllers.console.datasets.wraps.db.session.scalar", + return_value=pipeline, ) result = dummy_view(pipeline_id="pipeline-1") @@ -76,12 +70,9 @@ class TestGetRagPipeline: return_value=(Mock(), "tenant-1"), ) - mock_query = Mock() - mock_query.where.return_value.first.return_value = pipeline - mocker.patch( - "controllers.console.datasets.wraps.db.session.query", - return_value=mock_query, + "controllers.console.datasets.wraps.db.session.scalar", + return_value=pipeline, ) result = dummy_view(pipeline_id="pipeline-1") @@ -100,18 +91,15 @@ class TestGetRagPipeline: return_value=(Mock(), "tenant-1"), ) - def where_side_effect(*args, **kwargs): - assert args[0].right.value == "123" - return Mock(first=lambda: pipeline) - - mock_query = Mock() - mock_query.where.side_effect = where_side_effect - - mocker.patch( - "controllers.console.datasets.wraps.db.session.query", - return_value=mock_query, + mock_scalar = mocker.patch( + "controllers.console.datasets.wraps.db.session.scalar", + return_value=pipeline, ) result = dummy_view(pipeline_id=123) assert result is pipeline + # Verify the pipeline_id was cast to string in the where clause + stmt = mock_scalar.call_args[0][0] + where_clauses = stmt.whereclause.clauses + assert where_clauses[0].right.value == "123" diff --git a/api/tests/unit_tests/controllers/console/explore/test_audio.py b/api/tests/unit_tests/controllers/console/explore/test_audio.py index 0afbc5a8f7b..66c9ba48c59 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_audio.py +++ b/api/tests/unit_tests/controllers/console/explore/test_audio.py @@ -2,6 +2,7 @@ from io import BytesIO from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import InternalServerError import controllers.console.explore.audio as audio_module @@ -19,7 +20,6 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.model_runtime.errors.invoke import InvokeError from services.errors.audio import ( AudioTooLargeServiceError, NoAudioUploadedServiceError, diff --git a/api/tests/unit_tests/controllers/console/explore/test_banner.py b/api/tests/unit_tests/controllers/console/explore/test_banner.py index 4414f1eb5f6..c8f674f5150 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_banner.py +++ b/api/tests/unit_tests/controllers/console/explore/test_banner.py @@ -24,13 +24,8 @@ class TestBannerApi: banner.status = BannerStatus.ENABLED banner.created_at = datetime(2024, 1, 1) - query = MagicMock() - query.where.return_value = query - query.order_by.return_value = query - query.all.return_value = [banner] - session = MagicMock() - session.query.return_value = query + session.scalars.return_value.all.return_value = [banner] with app.test_request_context("/?language=fr-FR"), patch.object(banner_module.db, "session", session): result = method(api) @@ -58,16 +53,14 @@ class TestBannerApi: banner.status = BannerStatus.ENABLED banner.created_at = None - query = MagicMock() - query.where.return_value = query - query.order_by.return_value = query - query.all.side_effect = [ + scalars_result = MagicMock() + scalars_result.all.side_effect = [ [], [banner], ] session = MagicMock() - session.query.return_value = query + session.scalars.return_value = scalars_result with app.test_request_context("/?language=es-ES"), patch.object(banner_module.db, "session", session): result = method(api) @@ -87,13 +80,8 @@ class TestBannerApi: api = banner_module.BannerApi() method = unwrap(api.get) - query = MagicMock() - query.where.return_value = query - query.order_by.return_value = query - query.all.return_value = [] - session = MagicMock() - session.query.return_value = query + session.scalars.return_value.all.return_value = [] with app.test_request_context("/"), patch.object(banner_module.db, "session", session): result = method(api) diff --git a/api/tests/unit_tests/controllers/console/explore/test_installed_app.py b/api/tests/unit_tests/controllers/console/explore/test_installed_app.py index 3983a6a97ef..93652e75d2c 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_installed_app.py +++ b/api/tests/unit_tests/controllers/console/explore/test_installed_app.py @@ -260,11 +260,10 @@ class TestInstalledAppsCreateApi: app_entity.tenant_id = "t2" session = MagicMock() - session.query.return_value.where.return_value.first.side_effect = [ - recommended, - app_entity, - None, - ] + # scalar() is called for recommended_app and installed_app lookups + session.scalar.side_effect = [recommended, None] + # get() is called for app PK lookup + session.get.return_value = app_entity with ( app.test_request_context("/", json={"app_id": "a1"}), @@ -282,7 +281,7 @@ class TestInstalledAppsCreateApi: method = unwrap(api.post) session = MagicMock() - session.query.return_value.where.return_value.first.return_value = None + session.scalar.return_value = None with ( app.test_request_context("/", json={"app_id": "a1"}), @@ -300,10 +299,10 @@ class TestInstalledAppsCreateApi: app_entity = MagicMock(is_public=False) session = MagicMock() - session.query.return_value.where.return_value.first.side_effect = [ - recommended, - app_entity, - ] + # scalar() returns recommended_app + session.scalar.return_value = recommended + # get() returns the app entity + session.get.return_value = app_entity with ( app.test_request_context("/", json={"app_id": "a1"}), diff --git a/api/tests/unit_tests/controllers/console/explore/test_message.py b/api/tests/unit_tests/controllers/console/explore/test_message.py index 6b5c304884e..2e4ca4f2a44 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_message.py +++ b/api/tests/unit_tests/controllers/console/explore/test_message.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import InternalServerError, NotFound import controllers.console.explore.message as module @@ -21,7 +22,6 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.model_runtime.errors.invoke import InvokeError from services.errors.conversation import ConversationNotExistsError from services.errors.message import ( FirstMessageNotExistsError, diff --git a/api/tests/unit_tests/controllers/console/explore/test_trial.py b/api/tests/unit_tests/controllers/console/explore/test_trial.py index d85114c8fbe..04beb31389c 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_trial.py +++ b/api/tests/unit_tests/controllers/console/explore/test_trial.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import controllers.console.explore.trial as module @@ -25,7 +26,6 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.model_runtime.errors.invoke import InvokeError from models import Account from models.account import TenantStatus from models.model import AppMode @@ -958,8 +958,8 @@ class TestTrialSitApi: app_model = MagicMock() app_model.id = "a1" - with app.test_request_context("/"), patch.object(module.db.session, "query") as mock_query: - mock_query.return_value.where.return_value.first.return_value = None + with app.test_request_context("/"), patch.object(module.db.session, "scalar") as mock_scalar: + mock_scalar.return_value = None with pytest.raises(Forbidden): method(api, app_model) @@ -973,8 +973,8 @@ class TestTrialSitApi: app_model.tenant = MagicMock() app_model.tenant.status = TenantStatus.ARCHIVE - with app.test_request_context("/"), patch.object(module.db.session, "query") as mock_query: - mock_query.return_value.where.return_value.first.return_value = site + with app.test_request_context("/"), patch.object(module.db.session, "scalar") as mock_scalar: + mock_scalar.return_value = site with pytest.raises(Forbidden): method(api, app_model) @@ -990,10 +990,10 @@ class TestTrialSitApi: with ( app.test_request_context("/"), - patch.object(module.db.session, "query") as mock_query, + patch.object(module.db.session, "scalar") as mock_scalar, patch.object(module.SiteResponse, "model_validate") as mock_validate, ): - mock_query.return_value.where.return_value.first.return_value = site + mock_scalar.return_value = site mock_validate_result = MagicMock() mock_validate_result.model_dump.return_value = {"name": "test", "icon": "icon"} mock_validate.return_value = mock_validate_result diff --git a/api/tests/unit_tests/controllers/console/explore/test_wraps.py b/api/tests/unit_tests/controllers/console/explore/test_wraps.py index 67e7a32591b..2c1acfc3d65 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_wraps.py +++ b/api/tests/unit_tests/controllers/console/explore/test_wraps.py @@ -34,9 +34,9 @@ def test_installed_app_required_not_found(): "controllers.console.explore.wraps.current_account_with_tenant", return_value=(MagicMock(), "tenant-1"), ), - patch("controllers.console.explore.wraps.db.session.query") as q, + patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock, ): - q.return_value.where.return_value.first.return_value = None + scalar_mock.return_value = None with pytest.raises(NotFound): view("app-id") @@ -54,11 +54,11 @@ def test_installed_app_required_app_deleted(): "controllers.console.explore.wraps.current_account_with_tenant", return_value=(MagicMock(), "tenant-1"), ), - patch("controllers.console.explore.wraps.db.session.query") as q, + patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock, patch("controllers.console.explore.wraps.db.session.delete"), patch("controllers.console.explore.wraps.db.session.commit"), ): - q.return_value.where.return_value.first.return_value = installed_app + scalar_mock.return_value = installed_app with pytest.raises(NotFound): view("app-id") @@ -76,9 +76,9 @@ def test_installed_app_required_success(): "controllers.console.explore.wraps.current_account_with_tenant", return_value=(MagicMock(), "tenant-1"), ), - patch("controllers.console.explore.wraps.db.session.query") as q, + patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock, ): - q.return_value.where.return_value.first.return_value = installed_app + scalar_mock.return_value = installed_app result = view("app-id") assert result == installed_app @@ -149,9 +149,9 @@ def test_trial_app_required_not_allowed(): "controllers.console.explore.wraps.current_account_with_tenant", return_value=(MagicMock(id="user-1"), None), ), - patch("controllers.console.explore.wraps.db.session.query") as q, + patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock, ): - q.return_value.where.return_value.first.return_value = None + scalar_mock.return_value = None with pytest.raises(TrialAppNotAllowed): view("app-id") @@ -170,9 +170,9 @@ def test_trial_app_required_limit_exceeded(): "controllers.console.explore.wraps.current_account_with_tenant", return_value=(MagicMock(id="user-1"), None), ), - patch("controllers.console.explore.wraps.db.session.query") as q, + patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock, ): - q.return_value.where.return_value.first.side_effect = [ + scalar_mock.side_effect = [ trial_app, record, ] @@ -194,9 +194,9 @@ def test_trial_app_required_success(): "controllers.console.explore.wraps.current_account_with_tenant", return_value=(MagicMock(id="user-1"), None), ), - patch("controllers.console.explore.wraps.db.session.query") as q, + patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock, ): - q.return_value.where.return_value.first.side_effect = [ + scalar_mock.side_effect = [ trial_app, record, ] diff --git a/api/tests/unit_tests/controllers/console/tag/test_tags.py b/api/tests/unit_tests/controllers/console/tag/test_tags.py index 769edc8d1cb..e89b89c8b11 100644 --- a/api/tests/unit_tests/controllers/console/tag/test_tags.py +++ b/api/tests/unit_tests/controllers/console/tag/test_tags.py @@ -11,6 +11,7 @@ from controllers.console.tag.tags import ( TagListApi, TagUpdateDeleteApi, ) +from models.enums import TagType def unwrap(func): @@ -52,7 +53,7 @@ def tag(): tag = MagicMock() tag.id = "tag-1" tag.name = "test-tag" - tag.type = "knowledge" + tag.type = TagType.KNOWLEDGE return tag diff --git a/api/tests/unit_tests/controllers/console/test_apikey.py b/api/tests/unit_tests/controllers/console/test_apikey.py index 018257f815a..2dff9c4037f 100644 --- a/api/tests/unit_tests/controllers/console/test_apikey.py +++ b/api/tests/unit_tests/controllers/console/test_apikey.py @@ -8,6 +8,7 @@ from controllers.console.apikey import ( BaseApiKeyResource, _get_resource, ) +from models.enums import ApiTokenType @pytest.fixture @@ -45,14 +46,14 @@ def bypass_permissions(): class DummyApiKeyListResource(BaseApiKeyListResource): - resource_type = "app" + resource_type = ApiTokenType.APP resource_model = MagicMock() resource_id_field = "app_id" token_prefix = "app-" class DummyApiKeyResource(BaseApiKeyResource): - resource_type = "app" + resource_type = ApiTokenType.APP resource_model = MagicMock() resource_id_field = "app_id" @@ -114,7 +115,7 @@ class TestBaseApiKeyResource: def test_delete_key_not_found(self, tenant_context_admin, db_mock): resource = DummyApiKeyResource() - db_mock.session.query.return_value.where.return_value.first.return_value = None + db_mock.session.scalar.return_value = None with patch("controllers.console.apikey._get_resource"): with pytest.raises(Exception) as exc_info: @@ -125,7 +126,7 @@ class TestBaseApiKeyResource: def test_delete_success(self, tenant_context_admin, db_mock): resource = DummyApiKeyResource() - db_mock.session.query.return_value.where.return_value.first.return_value = MagicMock() + db_mock.session.scalar.return_value = MagicMock() with ( patch("controllers.console.apikey._get_resource"), diff --git a/api/tests/unit_tests/controllers/console/test_wraps.py b/api/tests/unit_tests/controllers/console/test_wraps.py index 6777077de89..f6e096a97b2 100644 --- a/api/tests/unit_tests/controllers/console/test_wraps.py +++ b/api/tests/unit_tests/controllers/console/test_wraps.py @@ -328,7 +328,7 @@ class TestSystemSetup: def test_should_raise_not_init_validate_error_with_init_password(self, mock_environ_get, mock_db): """Test NotInitValidateError when INIT_PASSWORD is set but setup not complete""" # Arrange - mock_db.session.query.return_value.first.return_value = None # No setup + mock_db.session.scalar.return_value = None # No setup mock_environ_get.return_value = "some_password" @setup_required @@ -345,7 +345,7 @@ class TestSystemSetup: def test_should_raise_not_setup_error_without_init_password(self, mock_environ_get, mock_db): """Test NotSetupError when no INIT_PASSWORD and setup not complete""" # Arrange - mock_db.session.query.return_value.first.return_value = None # No setup + mock_db.session.scalar.return_value = None # No setup mock_environ_get.return_value = None # No INIT_PASSWORD @setup_required diff --git a/api/tests/unit_tests/controllers/console/workspace/test_accounts.py b/api/tests/unit_tests/controllers/console/workspace/test_accounts.py index 00d322fdea0..42be02cdaf1 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_accounts.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_accounts.py @@ -55,9 +55,9 @@ class TestAccountInitApi: patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")), patch("controllers.console.workspace.account.db.session.commit", return_value=None), patch("controllers.console.workspace.account.dify_config.EDITION", "CLOUD"), - patch("controllers.console.workspace.account.db.session.query") as query_mock, + patch("controllers.console.workspace.account.db.session.scalar") as scalar_mock, ): - query_mock.return_value.where.return_value.first.return_value = MagicMock(status="unused") + scalar_mock.return_value = MagicMock(status="unused") resp = method(api) assert resp["result"] == "success" diff --git a/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py b/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py index f2e57eb65f4..9c42ee9529a 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py @@ -11,11 +11,10 @@ from unittest.mock import MagicMock import pytest from flask import Flask from flask.views import MethodView +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from werkzeug.exceptions import Forbidden -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError - if not hasattr(builtins, "MethodView"): builtins.MethodView = MethodView # type: ignore[attr-defined] diff --git a/api/tests/unit_tests/controllers/console/workspace/test_members.py b/api/tests/unit_tests/controllers/console/workspace/test_members.py index b6708d1f6f6..718b57ba6ba 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_members.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_members.py @@ -207,10 +207,10 @@ class TestMemberCancelInviteApi: with ( app.test_request_context("/"), patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), - patch("controllers.console.workspace.members.db.session.query") as q, + patch("controllers.console.workspace.members.db.session.get") as get_mock, patch("controllers.console.workspace.members.TenantService.remove_member_from_tenant"), ): - q.return_value.where.return_value.first.return_value = member + get_mock.return_value = member result, status = method(api, member.id) assert status == 200 @@ -226,9 +226,9 @@ class TestMemberCancelInviteApi: with ( app.test_request_context("/"), patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), - patch("controllers.console.workspace.members.db.session.query") as q, + patch("controllers.console.workspace.members.db.session.get") as get_mock, ): - q.return_value.where.return_value.first.return_value = None + get_mock.return_value = None with pytest.raises(HTTPException): method(api, "x") @@ -244,13 +244,13 @@ class TestMemberCancelInviteApi: with ( app.test_request_context("/"), patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), - patch("controllers.console.workspace.members.db.session.query") as q, + patch("controllers.console.workspace.members.db.session.get") as get_mock, patch( "controllers.console.workspace.members.TenantService.remove_member_from_tenant", side_effect=services.errors.account.CannotOperateSelfError("x"), ), ): - q.return_value.where.return_value.first.return_value = member + get_mock.return_value = member result, status = method(api, member.id) assert status == 400 @@ -266,13 +266,13 @@ class TestMemberCancelInviteApi: with ( app.test_request_context("/"), patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), - patch("controllers.console.workspace.members.db.session.query") as q, + patch("controllers.console.workspace.members.db.session.get") as get_mock, patch( "controllers.console.workspace.members.TenantService.remove_member_from_tenant", side_effect=services.errors.account.NoPermissionError("x"), ), ): - q.return_value.where.return_value.first.return_value = member + get_mock.return_value = member result, status = method(api, member.id) assert status == 403 @@ -288,13 +288,13 @@ class TestMemberCancelInviteApi: with ( app.test_request_context("/"), patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), - patch("controllers.console.workspace.members.db.session.query") as q, + patch("controllers.console.workspace.members.db.session.get") as get_mock, patch( "controllers.console.workspace.members.TenantService.remove_member_from_tenant", side_effect=services.errors.account.MemberNotInTenantError(), ), ): - q.return_value.where.return_value.first.return_value = member + get_mock.return_value = member result, status = method(api, member.id) assert status == 404 diff --git a/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py index af0c2c55945..fb9eec98cb5 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from pydantic_core import ValidationError from werkzeug.exceptions import Forbidden @@ -13,7 +14,6 @@ from controllers.console.workspace.model_providers import ( ModelProviderValidateApi, PreferredProviderTypeUpdateApi, ) -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError VALID_UUID = "123e4567-e89b-12d3-a456-426614174000" INVALID_UUID = "123" diff --git a/api/tests/unit_tests/controllers/console/workspace/test_models.py b/api/tests/unit_tests/controllers/console/workspace/test_models.py index 43b8e1ac2e8..c829327bc7a 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_models.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_models.py @@ -2,6 +2,8 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from controllers.console.workspace.models import ( DefaultModelApi, @@ -14,8 +16,6 @@ from controllers.console.workspace.models import ( ModelProviderModelParameterRuleApi, ModelProviderModelValidateApi, ) -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError def unwrap(func): diff --git a/api/tests/unit_tests/controllers/console/workspace/test_plugin.py b/api/tests/unit_tests/controllers/console/workspace/test_plugin.py index eb19243225e..ce5fd1c4669 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_plugin.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_plugin.py @@ -90,8 +90,8 @@ class TestPluginListLatestVersionsApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginDebuggingKeyApi: @@ -120,8 +120,8 @@ class TestPluginDebuggingKeyApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginListApi: @@ -202,8 +202,9 @@ class TestPluginUploadFromPkgApi: patch("controllers.console.workspace.plugin.dify_config.PLUGIN_MAX_PACKAGE_SIZE", 0), patch("controllers.console.workspace.plugin.PluginService.upload_pkg") as upload_pkg_mock, ): - with pytest.raises(ValueError): + with pytest.raises(ValueError) as exc_info: method(api) + assert "File size exceeds the maximum allowed size" in str(exc_info.value) upload_pkg_mock.assert_not_called() @@ -365,8 +366,8 @@ class TestPluginListInstallationsFromIdsApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginUploadFromGithubApi: @@ -401,8 +402,8 @@ class TestPluginUploadFromGithubApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginUploadFromBundleApi: @@ -449,8 +450,9 @@ class TestPluginUploadFromBundleApi: patch("controllers.console.workspace.plugin.dify_config.PLUGIN_MAX_BUNDLE_SIZE", 0), patch("controllers.console.workspace.plugin.PluginService.upload_bundle") as upload_bundle_mock, ): - with pytest.raises(ValueError): + with pytest.raises(ValueError) as exc_info: method(api) + assert "File size exceeds the maximum allowed size" in str(exc_info.value) upload_bundle_mock.assert_not_called() @@ -495,8 +497,8 @@ class TestPluginInstallFromGithubApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginInstallFromMarketplaceApi: @@ -532,8 +534,8 @@ class TestPluginInstallFromMarketplaceApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginFetchMarketplacePkgApi: @@ -562,8 +564,8 @@ class TestPluginFetchMarketplacePkgApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginFetchManifestApi: @@ -595,8 +597,8 @@ class TestPluginFetchManifestApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginFetchInstallTasksApi: @@ -625,8 +627,8 @@ class TestPluginFetchInstallTasksApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginFetchInstallTaskApi: @@ -655,8 +657,8 @@ class TestPluginFetchInstallTaskApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api, "t") + result = method(api, "t") + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginDeleteInstallTaskApi: @@ -685,8 +687,8 @@ class TestPluginDeleteInstallTaskApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api, "t") + result = method(api, "t") + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginDeleteAllInstallTaskItemsApi: @@ -717,8 +719,8 @@ class TestPluginDeleteAllInstallTaskItemsApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginDeleteInstallTaskItemApi: @@ -747,8 +749,8 @@ class TestPluginDeleteInstallTaskItemApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api, "task1", "item1") + result = method(api, "task1", "item1") + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginUpgradeFromMarketplaceApi: @@ -790,8 +792,8 @@ class TestPluginUpgradeFromMarketplaceApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginUpgradeFromGithubApi: @@ -839,8 +841,8 @@ class TestPluginUpgradeFromGithubApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginFetchDynamicSelectOptionsWithCredentialsApi: @@ -894,8 +896,8 @@ class TestPluginFetchDynamicSelectOptionsWithCredentialsApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginChangePreferencesApi: diff --git a/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py index 94c3019d5e3..44feacf2ada 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py @@ -4,7 +4,7 @@ from __future__ import annotations import builtins import importlib -from contextlib import contextmanager +from contextlib import ExitStack, contextmanager from types import ModuleType, SimpleNamespace from unittest.mock import MagicMock, patch @@ -18,7 +18,6 @@ if not hasattr(builtins, "MethodView"): _CONTROLLER_MODULE: ModuleType | None = None _WRAPS_MODULE: ModuleType | None = None -_CONTROLLER_PATCHERS: list[patch] = [] @contextmanager @@ -37,6 +36,14 @@ def app() -> Flask: @pytest.fixture def controller_module(monkeypatch: pytest.MonkeyPatch): + """ + Import the controller with auth decorators neutralized only during import. + + The imported view classes retain those no-op decorators after import, so we + can restore the original globals immediately and avoid leaking auth patches + into unrelated tests such as libs.login unit coverage. + """ + module_name = "controllers.console.workspace.tool_providers" global _CONTROLLER_MODULE if _CONTROLLER_MODULE is None: @@ -51,13 +58,12 @@ def controller_module(monkeypatch: pytest.MonkeyPatch): ("controllers.console.wraps.is_admin_or_owner_required", _noop), ("controllers.console.wraps.enterprise_license_required", _noop), ] - for target, value in patch_targets: - patcher = patch(target, value) - patcher.start() - _CONTROLLER_PATCHERS.append(patcher) monkeypatch.setenv("DIFY_SETUP_READY", "true") - with _mock_db(): - _CONTROLLER_MODULE = importlib.import_module(module_name) + with ExitStack() as stack: + for target, value in patch_targets: + stack.enter_context(patch(target, value)) + with _mock_db(): + _CONTROLLER_MODULE = importlib.import_module(module_name) module = _CONTROLLER_MODULE monkeypatch.setattr(module, "jsonable_encoder", lambda payload: payload) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py index 06f666fa609..b2d13dbbdf3 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py @@ -1,4 +1,3 @@ -from datetime import datetime from io import BytesIO from unittest.mock import MagicMock, patch @@ -26,6 +25,7 @@ from controllers.console.workspace.workspace import ( WorkspacePermissionApi, ) from enums.cloud_plan import CloudPlan +from libs.datetime_utils import naive_utc_now from models.account import TenantStatus @@ -36,7 +36,7 @@ def unwrap(func): class TestTenantListApi: - def test_get_success(self, app): + def test_get_success_saas_path(self, app): api = TenantListApi() method = unwrap(api.get) @@ -44,19 +44,15 @@ class TestTenantListApi: id="t1", name="Tenant 1", status="active", - created_at=datetime.utcnow(), + created_at=naive_utc_now(), ) tenant2 = MagicMock( id="t2", name="Tenant 2", status="active", - created_at=datetime.utcnow(), + created_at=naive_utc_now(), ) - features = MagicMock() - features.billing.enabled = True - features.billing.subscription.plan = CloudPlan.SANDBOX - with ( app.test_request_context("/workspaces"), patch( @@ -66,15 +62,141 @@ class TestTenantListApi: "controllers.console.workspace.workspace.TenantService.get_join_tenants", return_value=[tenant1, tenant2], ), - patch("controllers.console.workspace.workspace.FeatureService.get_features", return_value=features), + patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False), + patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", True), + patch("controllers.console.workspace.workspace.dify_config.EDITION", "CLOUD"), + patch( + "controllers.console.workspace.workspace.BillingService.get_plan_bulk", + return_value={ + "t1": {"plan": CloudPlan.TEAM, "expiration_date": 0}, + "t2": {"plan": CloudPlan.PROFESSIONAL, "expiration_date": 0}, + }, + ) as get_plan_bulk_mock, + patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock, ): result, status = method(api) assert status == 200 assert len(result["workspaces"]) == 2 assert result["workspaces"][0]["current"] is True + assert result["workspaces"][0]["plan"] == CloudPlan.TEAM + assert result["workspaces"][1]["plan"] == CloudPlan.PROFESSIONAL + get_plan_bulk_mock.assert_called_once_with(["t1", "t2"]) + get_features_mock.assert_not_called() - def test_get_billing_disabled(self, app): + def test_get_saas_path_partial_fallback_does_not_gate_plan_on_billing_enabled(self, app): + """Bulk omits a tenant: resolve plan via subscription.plan only; billing.enabled is not used. + + billing.enabled is mocked False to prove the endpoint does not gate on it for this path + (SaaS contract treats enabled as on; display follows subscription.plan). + """ + api = TenantListApi() + method = unwrap(api.get) + + tenant1 = MagicMock( + id="t1", + name="Tenant 1", + status="active", + created_at=naive_utc_now(), + ) + tenant2 = MagicMock( + id="t2", + name="Tenant 2", + status="active", + created_at=naive_utc_now(), + ) + + features_t2 = MagicMock() + features_t2.billing.enabled = False + features_t2.billing.subscription.plan = CloudPlan.PROFESSIONAL + + with ( + app.test_request_context("/workspaces"), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + patch( + "controllers.console.workspace.workspace.TenantService.get_join_tenants", + return_value=[tenant1, tenant2], + ), + patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False), + patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", True), + patch("controllers.console.workspace.workspace.dify_config.EDITION", "CLOUD"), + patch( + "controllers.console.workspace.workspace.BillingService.get_plan_bulk", + return_value={"t1": {"plan": CloudPlan.TEAM, "expiration_date": 0}}, + ) as get_plan_bulk_mock, + patch( + "controllers.console.workspace.workspace.FeatureService.get_features", + return_value=features_t2, + ) as get_features_mock, + ): + result, status = method(api) + + assert status == 200 + assert result["workspaces"][0]["plan"] == CloudPlan.TEAM + assert result["workspaces"][1]["plan"] == CloudPlan.PROFESSIONAL + get_plan_bulk_mock.assert_called_once_with(["t1", "t2"]) + get_features_mock.assert_called_once_with("t2") + + def test_get_saas_path_falls_back_to_legacy_feature_path_on_bulk_error(self, app): + """Test fallback to FeatureService when bulk billing returns empty result. + + BillingService.get_plan_bulk catches exceptions internally and returns empty dict, + so we simulate the real failure mode by returning empty dict for non-empty input. + """ + api = TenantListApi() + method = unwrap(api.get) + + tenant1 = MagicMock( + id="t1", + name="Tenant 1", + status="active", + created_at=naive_utc_now(), + ) + tenant2 = MagicMock( + id="t2", + name="Tenant 2", + status="active", + created_at=naive_utc_now(), + ) + + features = MagicMock() + features.billing.enabled = False + features.billing.subscription.plan = CloudPlan.TEAM + + with ( + app.test_request_context("/workspaces"), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t2") + ), + patch( + "controllers.console.workspace.workspace.TenantService.get_join_tenants", + return_value=[tenant1, tenant2], + ), + patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False), + patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", True), + patch("controllers.console.workspace.workspace.dify_config.EDITION", "CLOUD"), + patch( + "controllers.console.workspace.workspace.BillingService.get_plan_bulk", + return_value={}, # Simulates real failure: empty result for non-empty input + ) as get_plan_bulk_mock, + patch( + "controllers.console.workspace.workspace.FeatureService.get_features", + return_value=features, + ) as get_features_mock, + patch("controllers.console.workspace.workspace.logger.warning") as logger_warning_mock, + ): + result, status = method(api) + + assert status == 200 + assert result["workspaces"][0]["plan"] == CloudPlan.TEAM + assert result["workspaces"][1]["plan"] == CloudPlan.TEAM + get_plan_bulk_mock.assert_called_once_with(["t1", "t2"]) + assert get_features_mock.call_count == 2 + logger_warning_mock.assert_called_once() + + def test_get_billing_disabled_community_path(self, app): api = TenantListApi() method = unwrap(api.get) @@ -82,11 +204,12 @@ class TestTenantListApi: id="t1", name="Tenant", status="active", - created_at=datetime.utcnow(), + created_at=naive_utc_now(), ) features = MagicMock() features.billing.enabled = False + features.billing.subscription.plan = CloudPlan.SANDBOX with ( app.test_request_context("/workspaces"), @@ -98,15 +221,83 @@ class TestTenantListApi: "controllers.console.workspace.workspace.TenantService.get_join_tenants", return_value=[tenant], ), + patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False), + patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", False), + patch("controllers.console.workspace.workspace.dify_config.EDITION", "SELF_HOSTED"), patch( "controllers.console.workspace.workspace.FeatureService.get_features", return_value=features, - ), + ) as get_features_mock, ): result, status = method(api) assert status == 200 assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX + get_features_mock.assert_called_once_with("t1") + + def test_get_enterprise_only_skips_feature_service(self, app): + api = TenantListApi() + method = unwrap(api.get) + + tenant1 = MagicMock( + id="t1", + name="Tenant 1", + status="active", + created_at=naive_utc_now(), + ) + tenant2 = MagicMock( + id="t2", + name="Tenant 2", + status="active", + created_at=naive_utc_now(), + ) + + with ( + app.test_request_context("/workspaces"), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t2") + ), + patch( + "controllers.console.workspace.workspace.TenantService.get_join_tenants", + return_value=[tenant1, tenant2], + ), + patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", True), + patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", False), + patch("controllers.console.workspace.workspace.dify_config.EDITION", "SELF_HOSTED"), + patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock, + ): + result, status = method(api) + + assert status == 200 + assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX + assert result["workspaces"][1]["plan"] == CloudPlan.SANDBOX + assert result["workspaces"][0]["current"] is False + assert result["workspaces"][1]["current"] is True + get_features_mock.assert_not_called() + + def test_get_enterprise_only_with_empty_tenants(self, app): + api = TenantListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/workspaces"), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), None) + ), + patch( + "controllers.console.workspace.workspace.TenantService.get_join_tenants", + return_value=[], + ), + patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", True), + patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", False), + patch("controllers.console.workspace.workspace.dify_config.EDITION", "SELF_HOSTED"), + patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock, + ): + result, status = method(api) + + assert status == 200 + assert result["workspaces"] == [] + get_features_mock.assert_not_called() class TestWorkspaceListApi: @@ -114,7 +305,7 @@ class TestWorkspaceListApi: api = WorkspaceListApi() method = unwrap(api.get) - tenant = MagicMock(id="t1", name="T", status="active", created_at=datetime.utcnow()) + tenant = MagicMock(id="t1", name="T", status="active", created_at=naive_utc_now()) paginate_result = MagicMock( items=[tenant], @@ -140,7 +331,7 @@ class TestWorkspaceListApi: id="t1", name="T", status="active", - created_at=datetime.utcnow(), + created_at=naive_utc_now(), ) paginate_result = MagicMock( @@ -258,12 +449,12 @@ class TestSwitchWorkspaceApi: "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") ), patch("controllers.console.workspace.workspace.TenantService.switch_tenant"), - patch("controllers.console.workspace.workspace.db.session.query") as query_mock, + patch("controllers.console.workspace.workspace.db.session.get") as get_mock, patch( "controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t2"} ), ): - query_mock.return_value.get.return_value = tenant + get_mock.return_value = tenant result = method(api) assert result["result"] == "success" @@ -297,9 +488,9 @@ class TestSwitchWorkspaceApi: return_value=(MagicMock(), "t1"), ), patch("controllers.console.workspace.workspace.TenantService.switch_tenant"), - patch("controllers.console.workspace.workspace.db.session.query") as query_mock, + patch("controllers.console.workspace.workspace.db.session.get") as get_mock, ): - query_mock.return_value.get.return_value = None + get_mock.return_value = None with pytest.raises(ValueError): method(api) diff --git a/api/tests/unit_tests/controllers/files/test_tool_files.py b/api/tests/unit_tests/controllers/files/test_tool_files.py index e5df7a1eea5..edb91c3f262 100644 --- a/api/tests/unit_tests/controllers/files/test_tool_files.py +++ b/api/tests/unit_tests/controllers/files/test_tool_files.py @@ -18,10 +18,10 @@ def fake_request(args: dict): class DummyToolFile: - def __init__(self, mimetype="text/plain", size=10, name="tool.txt"): - self.mimetype = mimetype + def __init__(self, mime_type="text/plain", size=10, filename="tool.txt"): + self.mime_type = mime_type self.size = size - self.name = name + self.filename = filename @pytest.fixture(autouse=True) @@ -87,8 +87,8 @@ class TestToolFileApi: stream = iter([b"data"]) tool_file = DummyToolFile( - mimetype="application/pdf", - name="doc.pdf", + mime_type="application/pdf", + filename="doc.pdf", ) mock_tool_file_manager.return_value.get_file_generator_by_tool_file_id.return_value = ( diff --git a/api/tests/unit_tests/controllers/inner_api/app/__init__.py b/api/tests/unit_tests/controllers/inner_api/app/__init__.py new file mode 100644 index 00000000000..8b137891791 --- /dev/null +++ b/api/tests/unit_tests/controllers/inner_api/app/__init__.py @@ -0,0 +1 @@ + diff --git a/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py b/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py new file mode 100644 index 00000000000..4a5f91cc5d5 --- /dev/null +++ b/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py @@ -0,0 +1,245 @@ +"""Unit tests for inner_api app DSL import/export endpoints. + +Tests Pydantic model validation, endpoint handler logic, and the +_get_active_account helper. Auth/setup decorators are tested separately +in test_auth_wraps.py; handler tests use inspect.unwrap() to bypass them. +""" + +import inspect +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from pydantic import ValidationError + +from controllers.inner_api.app.dsl import ( + EnterpriseAppDSLExport, + EnterpriseAppDSLImport, + InnerAppDSLImportPayload, + _get_active_account, +) +from services.app_dsl_service import ImportStatus + + +class TestInnerAppDSLImportPayload: + """Test InnerAppDSLImportPayload Pydantic model validation.""" + + def test_valid_payload_all_fields(self): + data = { + "yaml_content": "version: 0.6.0\nkind: app\n", + "creator_email": "user@example.com", + "name": "My App", + "description": "A test app", + } + payload = InnerAppDSLImportPayload.model_validate(data) + assert payload.yaml_content == data["yaml_content"] + assert payload.creator_email == "user@example.com" + assert payload.name == "My App" + assert payload.description == "A test app" + + def test_valid_payload_optional_fields_omitted(self): + data = { + "yaml_content": "version: 0.6.0\n", + "creator_email": "user@example.com", + } + payload = InnerAppDSLImportPayload.model_validate(data) + assert payload.name is None + assert payload.description is None + + def test_missing_yaml_content_fails(self): + with pytest.raises(ValidationError) as exc_info: + InnerAppDSLImportPayload.model_validate({"creator_email": "a@b.com"}) + assert "yaml_content" in str(exc_info.value) + + def test_missing_creator_email_fails(self): + with pytest.raises(ValidationError) as exc_info: + InnerAppDSLImportPayload.model_validate({"yaml_content": "test"}) + assert "creator_email" in str(exc_info.value) + + +class TestGetActiveAccount: + """Test the _get_active_account helper function.""" + + @patch("controllers.inner_api.app.dsl.db") + def test_returns_active_account(self, mock_db): + mock_account = MagicMock() + mock_account.status = "active" + mock_db.session.scalar.return_value = mock_account + + result = _get_active_account("user@example.com") + + assert result is mock_account + mock_db.session.scalar.assert_called_once() + + @patch("controllers.inner_api.app.dsl.db") + def test_returns_none_for_inactive_account(self, mock_db): + mock_account = MagicMock() + mock_account.status = "banned" + mock_db.session.scalar.return_value = mock_account + + result = _get_active_account("banned@example.com") + + assert result is None + + @patch("controllers.inner_api.app.dsl.db") + def test_returns_none_for_nonexistent_email(self, mock_db): + mock_db.session.scalar.return_value = None + + result = _get_active_account("missing@example.com") + + assert result is None + + +class TestEnterpriseAppDSLImport: + """Test EnterpriseAppDSLImport endpoint handler logic. + + Uses inspect.unwrap() to bypass auth/setup decorators. + """ + + @pytest.fixture + def api_instance(self): + return EnterpriseAppDSLImport() + + @pytest.fixture + def _mock_import_deps(self): + """Patch db, Session, and AppDslService for import handler tests.""" + with ( + patch("controllers.inner_api.app.dsl.db"), + patch("controllers.inner_api.app.dsl.Session") as mock_session, + patch("controllers.inner_api.app.dsl.AppDslService") as mock_dsl_cls, + ): + mock_session.return_value.__enter__ = MagicMock(return_value=MagicMock()) + mock_session.return_value.__exit__ = MagicMock(return_value=False) + self._mock_dsl = MagicMock() + mock_dsl_cls.return_value = self._mock_dsl + yield + + def _make_import_result(self, status: ImportStatus, **kwargs) -> "Import": + from services.app_dsl_service import Import + + result = Import( + id="import-id", + status=status, + app_id=kwargs.get("app_id", "app-123"), + app_mode=kwargs.get("app_mode", "workflow"), + ) + return result + + @pytest.mark.usefixtures("_mock_import_deps") + @patch("controllers.inner_api.app.dsl._get_active_account") + def test_import_success_returns_200(self, mock_get_account, api_instance, app: Flask): + mock_account = MagicMock() + mock_get_account.return_value = mock_account + self._mock_dsl.import_app.return_value = self._make_import_result(ImportStatus.COMPLETED) + + unwrapped = inspect.unwrap(api_instance.post) + with app.test_request_context(): + with patch("controllers.inner_api.app.dsl.inner_api_ns") as mock_ns: + mock_ns.payload = { + "yaml_content": "version: 0.6.0\n", + "creator_email": "user@example.com", + } + result = unwrapped(api_instance, workspace_id="ws-123") + + body, status_code = result + assert status_code == 200 + assert body["status"] == "completed" + mock_account.set_tenant_id.assert_called_once_with("ws-123") + + @pytest.mark.usefixtures("_mock_import_deps") + @patch("controllers.inner_api.app.dsl._get_active_account") + def test_import_pending_returns_202(self, mock_get_account, api_instance, app: Flask): + mock_get_account.return_value = MagicMock() + self._mock_dsl.import_app.return_value = self._make_import_result(ImportStatus.PENDING) + + unwrapped = inspect.unwrap(api_instance.post) + with app.test_request_context(): + with patch("controllers.inner_api.app.dsl.inner_api_ns") as mock_ns: + mock_ns.payload = {"yaml_content": "test", "creator_email": "u@e.com"} + body, status_code = unwrapped(api_instance, workspace_id="ws-123") + + assert status_code == 202 + assert body["status"] == "pending" + + @pytest.mark.usefixtures("_mock_import_deps") + @patch("controllers.inner_api.app.dsl._get_active_account") + def test_import_failed_returns_400(self, mock_get_account, api_instance, app: Flask): + mock_get_account.return_value = MagicMock() + self._mock_dsl.import_app.return_value = self._make_import_result(ImportStatus.FAILED) + + unwrapped = inspect.unwrap(api_instance.post) + with app.test_request_context(): + with patch("controllers.inner_api.app.dsl.inner_api_ns") as mock_ns: + mock_ns.payload = {"yaml_content": "test", "creator_email": "u@e.com"} + body, status_code = unwrapped(api_instance, workspace_id="ws-123") + + assert status_code == 400 + assert body["status"] == "failed" + + @patch("controllers.inner_api.app.dsl._get_active_account") + def test_import_account_not_found_returns_404(self, mock_get_account, api_instance, app: Flask): + mock_get_account.return_value = None + + unwrapped = inspect.unwrap(api_instance.post) + with app.test_request_context(): + with patch("controllers.inner_api.app.dsl.inner_api_ns") as mock_ns: + mock_ns.payload = {"yaml_content": "test", "creator_email": "missing@e.com"} + result = unwrapped(api_instance, workspace_id="ws-123") + + body, status_code = result + assert status_code == 404 + assert "missing@e.com" in body["message"] + + +class TestEnterpriseAppDSLExport: + """Test EnterpriseAppDSLExport endpoint handler logic. + + Uses inspect.unwrap() to bypass auth/setup decorators. + """ + + @pytest.fixture + def api_instance(self): + return EnterpriseAppDSLExport() + + @patch("controllers.inner_api.app.dsl.AppDslService") + @patch("controllers.inner_api.app.dsl.db") + def test_export_success_returns_200(self, mock_db, mock_dsl_cls, api_instance, app: Flask): + mock_app = MagicMock() + mock_db.session.get.return_value = mock_app + mock_dsl_cls.export_dsl.return_value = "version: 0.6.0\nkind: app\n" + + unwrapped = inspect.unwrap(api_instance.get) + with app.test_request_context("?include_secret=false"): + result = unwrapped(api_instance, app_id="app-123") + + body, status_code = result + assert status_code == 200 + assert body["data"] == "version: 0.6.0\nkind: app\n" + mock_dsl_cls.export_dsl.assert_called_once_with(app_model=mock_app, include_secret=False) + + @patch("controllers.inner_api.app.dsl.AppDslService") + @patch("controllers.inner_api.app.dsl.db") + def test_export_with_secret(self, mock_db, mock_dsl_cls, api_instance, app: Flask): + mock_app = MagicMock() + mock_db.session.get.return_value = mock_app + mock_dsl_cls.export_dsl.return_value = "yaml-data" + + unwrapped = inspect.unwrap(api_instance.get) + with app.test_request_context("?include_secret=true"): + result = unwrapped(api_instance, app_id="app-123") + + body, status_code = result + assert status_code == 200 + mock_dsl_cls.export_dsl.assert_called_once_with(app_model=mock_app, include_secret=True) + + @patch("controllers.inner_api.app.dsl.db") + def test_export_app_not_found_returns_404(self, mock_db, api_instance, app: Flask): + mock_db.session.get.return_value = None + + unwrapped = inspect.unwrap(api_instance.get) + with app.test_request_context("?include_secret=false"): + result = unwrapped(api_instance, app_id="nonexistent") + + body, status_code = result + assert status_code == 404 + assert "app not found" in body["message"] diff --git a/api/tests/unit_tests/controllers/service_api/app/test_app.py b/api/tests/unit_tests/controllers/service_api/app/test_app.py index f8e9cf9b801..1507bf7a5fc 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_app.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_app.py @@ -65,7 +65,7 @@ class TestAppParameterApi: mock_tenant.status = "normal" # Mock DB queries for app and tenant - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -112,7 +112,7 @@ class TestAppParameterApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -153,7 +153,7 @@ class TestAppParameterApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -192,7 +192,7 @@ class TestAppParameterApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -255,7 +255,7 @@ class TestAppMetaApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -323,7 +323,7 @@ class TestAppInfoApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -380,7 +380,7 @@ class TestAppInfoApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app, mock_tenant, ] @@ -426,7 +426,7 @@ class TestAppInfoApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app, mock_tenant, ] @@ -478,7 +478,7 @@ class TestAppInfoApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app, mock_tenant, ] diff --git a/api/tests/unit_tests/controllers/service_api/app/test_audio.py b/api/tests/unit_tests/controllers/service_api/app/test_audio.py index 1923ab7fa7b..5a8cb4619f4 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_audio.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_audio.py @@ -13,6 +13,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.datastructures import FileStorage from werkzeug.exceptions import InternalServerError @@ -29,7 +30,6 @@ from controllers.service_api.app.error import ( UnsupportedAudioTypeError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError from services.audio_service import AudioService from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.audio import ( diff --git a/api/tests/unit_tests/controllers/service_api/app/test_completion.py b/api/tests/unit_tests/controllers/service_api/app/test_completion.py index 4e4482f7049..57681d8f5bc 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_completion.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_completion.py @@ -16,6 +16,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import ValidationError from werkzeug.exceptions import BadRequest, NotFound @@ -34,7 +35,6 @@ from controllers.service_api.app.error import ( NotChatAppError, ) from core.errors.error import QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService from services.app_task_service import AppTaskService diff --git a/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py b/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py index 1bdcd0f1a31..d83c22f2cf6 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py @@ -79,10 +79,13 @@ class TestFilePreviewApi: mock_message_file.message_id = mock_message.id with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock database queries - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + # Mock scalar() for MessageFile and Message queries + mock_db.session.scalar.side_effect = [ mock_message_file, # MessageFile query mock_message, # Message query + ] + # Mock get() for UploadFile and App PK lookups + mock_db.session.get.side_effect = [ mock_upload_file, # UploadFile query mock_app, # App query for tenant validation ] @@ -100,8 +103,8 @@ class TestFilePreviewApi: app_id = str(uuid.uuid4()) with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock MessageFile not found - mock_db.session.query.return_value.where.return_value.first.return_value = None + # Mock MessageFile not found via scalar() + mock_db.session.scalar.return_value = None # Execute and assert exception with pytest.raises(FileNotFoundError) as exc_info: @@ -115,8 +118,8 @@ class TestFilePreviewApi: app_id = str(uuid.uuid4()) with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock MessageFile found but Message not owned by app - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + # Mock MessageFile found but Message not owned by app via scalar() + mock_db.session.scalar.side_effect = [ mock_message_file, # MessageFile query - found None, # Message query - not found (access denied) ] @@ -133,12 +136,13 @@ class TestFilePreviewApi: app_id = str(uuid.uuid4()) with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock MessageFile and Message found but UploadFile not found - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + # Mock scalar() for MessageFile and Message + mock_db.session.scalar.side_effect = [ mock_message_file, # MessageFile query - found mock_message, # Message query - found - None, # UploadFile query - not found ] + # Mock get() for UploadFile - not found + mock_db.session.get.return_value = None # Execute and assert exception with pytest.raises(FileNotFoundError) as exc_info: @@ -161,10 +165,13 @@ class TestFilePreviewApi: mock_message_file.message_id = mock_message.id with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock database queries - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + # Mock scalar() for MessageFile and Message queries + mock_db.session.scalar.side_effect = [ mock_message_file, # MessageFile query mock_message, # Message query + ] + # Mock get() for UploadFile and App PK lookups + mock_db.session.get.side_effect = [ mock_upload_file, # UploadFile query mock_app, # App query for tenant validation ] @@ -262,10 +269,13 @@ class TestFilePreviewApi: mock_storage.load.return_value = mock_generator with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock database queries - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + # Mock scalar() for MessageFile and Message queries + mock_db.session.scalar.side_effect = [ mock_message_file, # MessageFile query mock_message, # Message query + ] + # Mock get() for UploadFile and App PK lookups + mock_db.session.get.side_effect = [ mock_upload_file, # UploadFile query mock_app, # App query for tenant validation ] @@ -301,10 +311,13 @@ class TestFilePreviewApi: mock_storage.load.side_effect = Exception("Storage error") with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock database queries for validation - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + # Mock scalar() for MessageFile and Message queries + mock_db.session.scalar.side_effect = [ mock_message_file, # MessageFile query mock_message, # Message query + ] + # Mock get() for UploadFile and App PK lookups + mock_db.session.get.side_effect = [ mock_upload_file, # UploadFile query mock_app, # App query for tenant validation ] @@ -327,8 +340,8 @@ class TestFilePreviewApi: app_id = str(uuid.uuid4()) with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock database query to raise unexpected exception - mock_db.session.query.side_effect = Exception("Unexpected database error") + # Mock database scalar to raise unexpected exception + mock_db.session.scalar.side_effect = Exception("Unexpected database error") # Execute and assert exception with pytest.raises(FileAccessDeniedError) as exc_info: diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py index 4eada73b82c..b1f036c6f36 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py @@ -19,6 +19,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from graphon.enums import WorkflowExecutionStatus from werkzeug.exceptions import BadRequest, NotFound from controllers.service_api.app.error import NotWorkflowAppError @@ -35,7 +36,6 @@ from controllers.service_api.app.workflow import ( WorkflowTaskStopApi, ) from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError -from dify_graph.enums import WorkflowExecutionStatus from models.model import App, AppMode from services.app_generate_service import AppGenerateService from services.errors.app import IsDraftWorkflowError, WorkflowNotFoundError @@ -315,7 +315,7 @@ class TestWorkflowStopMechanism: def test_graph_engine_manager_has_send_stop_command(self): """Test GraphEngineManager has send_stop_command method.""" - from dify_graph.graph_engine.manager import GraphEngineManager + from graphon.graph_engine.manager import GraphEngineManager assert hasattr(GraphEngineManager, "send_stop_command") diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py index 9e95f45a0af..4b8e3a738cb 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py @@ -1,7 +1,8 @@ from types import SimpleNamespace +from graphon.enums import WorkflowExecutionStatus + from controllers.service_api.app.workflow import WorkflowRunOutputsField, WorkflowRunStatusField -from dify_graph.enums import WorkflowExecutionStatus def test_workflow_run_status_field_with_enum() -> None: diff --git a/api/tests/unit_tests/controllers/service_api/conftest.py b/api/tests/unit_tests/controllers/service_api/conftest.py index 4337a0c8c0e..eddba5a5170 100644 --- a/api/tests/unit_tests/controllers/service_api/conftest.py +++ b/api/tests/unit_tests/controllers/service_api/conftest.py @@ -12,6 +12,7 @@ from unittest.mock import Mock import pytest from flask import Flask +from core.rag.index_processor.constant.index_type import IndexStructureType from models.account import TenantStatus from models.model import App, AppMode, EndUser from tests.unit_tests.conftest import setup_mock_tenant_account_query @@ -118,11 +119,8 @@ class AuthenticationMocker: @staticmethod def setup_db_queries(mock_db, mock_app, mock_tenant, mock_account=None): - """Configure mock_db to return app and tenant in sequence.""" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ - mock_app, - mock_tenant, - ] + """Configure mock_db to return app and tenant via session.get().""" + mock_db.session.get.side_effect = [mock_app, mock_tenant] if mock_account: mock_ta = Mock() @@ -135,11 +133,9 @@ class AuthenticationMocker: mock_ta = Mock() mock_ta.account_id = mock_account.id - mock_query = mock_db.session.query.return_value - target_mock = mock_query.where.return_value.where.return_value.where.return_value.where.return_value - target_mock.one_or_none.return_value = (mock_tenant, mock_ta) + mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_account + mock_db.session.get.return_value = mock_account @pytest.fixture @@ -175,7 +171,7 @@ def mock_document(): document.name = "test_document.txt" document.indexing_status = "completed" document.enabled = True - document.doc_form = "text_model" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX return document diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py index 7cb2f1050ca..910d781cd02 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py @@ -35,6 +35,7 @@ from controllers.service_api.dataset.dataset import ( from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError from models.account import Account from models.dataset import DatasetPermissionEnum +from models.enums import TagType from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService from services.tag_service import TagService @@ -277,7 +278,7 @@ class TestDatasetTagsApi: mock_tag = Mock() mock_tag.id = "tag_1" mock_tag.name = "Test Tag" - mock_tag.type = "knowledge" + mock_tag.type = TagType.KNOWLEDGE mock_tag.binding_count = "0" # Required for Pydantic validation - must be string mock_tag_service.get_tags.return_value = [mock_tag] @@ -316,7 +317,7 @@ class TestDatasetTagsApi: mock_tag = Mock() mock_tag.id = "new_tag_1" mock_tag.name = "New Tag" - mock_tag.type = "knowledge" + mock_tag.type = TagType.KNOWLEDGE mock_tag_service.save_tags.return_value = mock_tag mock_service_api_ns.payload = {"name": "New Tag"} @@ -378,7 +379,7 @@ class TestDatasetTagsApi: mock_tag = Mock() mock_tag.id = "tag_1" mock_tag.name = "Updated Tag" - mock_tag.type = "knowledge" + mock_tag.type = TagType.KNOWLEDGE mock_tag.binding_count = "5" mock_tag_service.update_tags.return_value = mock_tag mock_tag_service.get_tag_binding_count.return_value = 5 @@ -866,7 +867,7 @@ class TestTagService: mock_tag = Mock() mock_tag.id = str(uuid.uuid4()) mock_tag.name = "New Tag" - mock_tag.type = "knowledge" + mock_tag.type = TagType.KNOWLEDGE mock_save.return_value = mock_tag result = TagService.save_tags({"name": "New Tag", "type": "knowledge"}) @@ -941,11 +942,11 @@ class TestDatasetListApiGet: """Test suite for DatasetListApi.get() endpoint. ``get`` has no billing decorators but calls ``current_user``, - ``DatasetService``, ``ProviderManager``, and ``marshal``. + ``DatasetService``, ``create_plugin_provider_manager``, and ``marshal``. """ @patch("controllers.service_api.dataset.dataset.marshal") - @patch("controllers.service_api.dataset.dataset.ProviderManager") + @patch("controllers.service_api.dataset.dataset.create_plugin_provider_manager") @patch("controllers.service_api.dataset.dataset.current_user") @patch("controllers.service_api.dataset.dataset.DatasetService") def test_list_datasets_success( @@ -1043,12 +1044,12 @@ class TestDatasetApiGet: """Test suite for DatasetApi.get() endpoint. ``get`` has no billing decorators but calls ``DatasetService``, - ``ProviderManager``, ``marshal``, and ``current_user``. + ``create_plugin_provider_manager``, ``marshal``, and ``current_user``. """ @patch("controllers.service_api.dataset.dataset.DatasetPermissionService") @patch("controllers.service_api.dataset.dataset.marshal") - @patch("controllers.service_api.dataset.dataset.ProviderManager") + @patch("controllers.service_api.dataset.dataset.create_plugin_provider_manager") @patch("controllers.service_api.dataset.dataset.current_user") @patch("controllers.service_api.dataset.dataset.DatasetService") def test_get_dataset_success( diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py index 5c48ef18040..e9c3e6d3769 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py @@ -31,6 +31,7 @@ from controllers.service_api.dataset.segment import ( SegmentCreatePayload, SegmentListQuery, ) +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import ChildChunk, Dataset, Document, DocumentSegment from models.enums import IndexingStatus from services.dataset_service import DocumentService, SegmentService @@ -767,6 +768,7 @@ class TestSegmentApiGet: ``current_account_with_tenant()`` and ``marshal``. """ + @patch("controllers.service_api.dataset.segment.SummaryIndexService") @patch("controllers.service_api.dataset.segment.marshal") @patch("controllers.service_api.dataset.segment.SegmentService") @patch("controllers.service_api.dataset.segment.DocumentService") @@ -779,6 +781,7 @@ class TestSegmentApiGet: mock_doc_svc, mock_seg_svc, mock_marshal, + mock_summary_svc, app, mock_tenant, mock_dataset, @@ -787,10 +790,11 @@ class TestSegmentApiGet: """Test successful segment list retrieval.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset - mock_doc_svc.get_document.return_value = Mock(doc_form="text_model") + mock_db.session.scalar.return_value = mock_dataset + mock_doc_svc.get_document.return_value = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX) mock_seg_svc.get_segments.return_value = ([mock_segment], 1) - mock_marshal.return_value = [{"id": mock_segment.id}] + mock_marshal.return_value = {"id": mock_segment.id} + mock_summary_svc.get_segments_summaries.return_value = {} # Act with app.test_request_context( @@ -812,7 +816,7 @@ class TestSegmentApiGet: """Test 404 when dataset not found.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -832,7 +836,7 @@ class TestSegmentApiGet: """Test 404 when document not found.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = None # Act & Assert @@ -871,6 +875,7 @@ class TestSegmentApiPost: mock_rate_limit.enabled = False mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit + @patch("controllers.service_api.dataset.segment.SummaryIndexService") @patch("controllers.service_api.dataset.segment.marshal") @patch("controllers.service_api.dataset.segment.SegmentService") @patch("controllers.service_api.dataset.segment.DocumentService") @@ -887,6 +892,7 @@ class TestSegmentApiPost: mock_doc_svc, mock_seg_svc, mock_marshal, + mock_summary_svc, app, mock_tenant, mock_dataset, @@ -898,17 +904,18 @@ class TestSegmentApiPost: mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_dataset.indexing_technique = "economy" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc = Mock() mock_doc.indexing_status = "completed" mock_doc.enabled = True - mock_doc.doc_form = "text_model" + mock_doc.doc_form = IndexStructureType.PARAGRAPH_INDEX mock_doc_svc.get_document.return_value = mock_doc mock_seg_svc.segment_create_args_validate.return_value = None mock_seg_svc.multi_create_segment.return_value = [mock_segment] - mock_marshal.return_value = [{"id": mock_segment.id}] + mock_marshal.return_value = {"id": mock_segment.id} + mock_summary_svc.get_segments_summaries.return_value = {} segments_data = [{"content": "Test segment content", "answer": "Test answer"}] @@ -949,7 +956,7 @@ class TestSegmentApiPost: mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_dataset.indexing_technique = "economy" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc = Mock() mock_doc.indexing_status = "completed" @@ -991,7 +998,7 @@ class TestSegmentApiPost: self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc = Mock() mock_doc.indexing_status = "indexing" # Not completed @@ -1042,7 +1049,7 @@ class TestDatasetSegmentApiDelete: """Test successful segment deletion.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc = Mock() @@ -1086,12 +1093,12 @@ class TestDatasetSegmentApiDelete: """Test 404 when segment not found.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc = Mock() mock_doc.indexing_status = "completed" mock_doc.enabled = True - mock_doc.doc_form = "text_model" + mock_doc.doc_form = IndexStructureType.PARAGRAPH_INDEX mock_doc_svc.get_document.return_value = mock_doc mock_seg_svc.get_segment_by_id.return_value = None # Segment not found @@ -1128,7 +1135,7 @@ class TestDatasetSegmentApiDelete: """Test 404 when dataset not found for delete.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -1162,7 +1169,7 @@ class TestDatasetSegmentApiDelete: """Test 404 when document not found for delete.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc_svc.get_document.return_value = None @@ -1205,6 +1212,7 @@ class TestDatasetSegmentApiUpdate: mock_rate_limit.enabled = False mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit + @patch("controllers.service_api.dataset.segment.SummaryIndexService") @patch("controllers.service_api.dataset.segment.marshal") @patch("controllers.service_api.dataset.segment.SegmentService") @patch("controllers.service_api.dataset.segment.DocumentService") @@ -1223,6 +1231,7 @@ class TestDatasetSegmentApiUpdate: mock_doc_svc, mock_seg_svc, mock_marshal, + mock_summary_svc, app, mock_tenant, mock_dataset, @@ -1232,13 +1241,14 @@ class TestDatasetSegmentApiUpdate: self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_dataset.indexing_technique = "economy" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = mock_segment updated = Mock() mock_seg_svc.update_segment.return_value = updated mock_marshal.return_value = {"id": mock_segment.id} + mock_summary_svc.get_segment_summary.return_value = None with app.test_request_context( f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{mock_segment.id}", @@ -1279,7 +1289,7 @@ class TestDatasetSegmentApiUpdate: """Test 404 when dataset not found for update.""" self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with app.test_request_context( f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id", @@ -1320,7 +1330,7 @@ class TestDatasetSegmentApiUpdate: self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_dataset.indexing_technique = "economy" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = None @@ -1348,6 +1358,7 @@ class TestDatasetSegmentApiGetSingle: ``current_account_with_tenant()`` and ``marshal``. """ + @patch("controllers.service_api.dataset.segment.SummaryIndexService") @patch("controllers.service_api.dataset.segment.marshal") @patch("controllers.service_api.dataset.segment.SegmentService") @patch("controllers.service_api.dataset.segment.DocumentService") @@ -1362,6 +1373,7 @@ class TestDatasetSegmentApiGetSingle: mock_doc_svc, mock_seg_svc, mock_marshal, + mock_summary_svc, app, mock_tenant, mock_dataset, @@ -1369,12 +1381,13 @@ class TestDatasetSegmentApiGetSingle: ): """Test successful single segment retrieval.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None - mock_doc = Mock(doc_form="text_model") + mock_doc = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX) mock_doc_svc.get_document.return_value = mock_doc mock_seg_svc.get_segment_by_id.return_value = mock_segment mock_marshal.return_value = {"id": mock_segment.id} + mock_summary_svc.get_segment_summary.return_value = None with app.test_request_context( f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{mock_segment.id}", @@ -1390,7 +1403,56 @@ class TestDatasetSegmentApiGetSingle: assert status == 200 assert "data" in response - assert response["doc_form"] == "text_model" + assert response["doc_form"] == IndexStructureType.PARAGRAPH_INDEX + + @patch("controllers.service_api.dataset.segment.SummaryIndexService") + @patch("controllers.service_api.dataset.segment.marshal") + @patch("controllers.service_api.dataset.segment.SegmentService") + @patch("controllers.service_api.dataset.segment.DocumentService") + @patch("controllers.service_api.dataset.segment.DatasetService") + @patch("controllers.service_api.dataset.segment.current_account_with_tenant") + @patch("controllers.service_api.dataset.segment.db") + def test_get_single_segment_includes_summary( + self, + mock_db, + mock_account_fn, + mock_dataset_svc, + mock_doc_svc, + mock_seg_svc, + mock_marshal, + mock_summary_svc, + app, + mock_tenant, + mock_dataset, + mock_segment, + ): + """Test that single segment response includes summary content from SummaryIndexService.""" + mock_account_fn.return_value = (Mock(), mock_tenant.id) + mock_db.session.scalar.return_value = mock_dataset + mock_dataset_svc.check_dataset_model_setting.return_value = None + mock_doc = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX) + mock_doc_svc.get_document.return_value = mock_doc + mock_seg_svc.get_segment_by_id.return_value = mock_segment + mock_marshal.return_value = {"id": mock_segment.id, "summary": None} + + mock_summary_record = Mock() + mock_summary_record.summary_content = "This is the segment summary" + mock_summary_svc.get_segment_summary.return_value = mock_summary_record + + with app.test_request_context( + f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{mock_segment.id}", + method="GET", + ): + api = DatasetSegmentApi() + response, status = api.get( + tenant_id=mock_tenant.id, + dataset_id=mock_dataset.id, + document_id="doc-id", + segment_id=mock_segment.id, + ) + + assert status == 200 + assert response["data"]["summary"] == "This is the segment summary" @patch("controllers.service_api.dataset.segment.current_account_with_tenant") @patch("controllers.service_api.dataset.segment.db") @@ -1404,7 +1466,7 @@ class TestDatasetSegmentApiGetSingle: ): """Test 404 when dataset not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with app.test_request_context( f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id", @@ -1435,7 +1497,7 @@ class TestDatasetSegmentApiGetSingle: ): """Test 404 when document not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc_svc.get_document.return_value = None @@ -1470,7 +1532,7 @@ class TestDatasetSegmentApiGetSingle: ): """Test 404 when segment not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = None @@ -1514,7 +1576,7 @@ class TestChildChunkApiGet: ): """Test successful child chunk list retrieval.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = Mock() @@ -1553,7 +1615,7 @@ class TestChildChunkApiGet: ): """Test 404 when dataset not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with app.test_request_context( f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id/child_chunks", @@ -1582,7 +1644,7 @@ class TestChildChunkApiGet: ): """Test 404 when document not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = None with app.test_request_context( @@ -1614,7 +1676,7 @@ class TestChildChunkApiGet: ): """Test 404 when segment not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = None @@ -1675,7 +1737,7 @@ class TestChildChunkApiPost: self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_dataset.indexing_technique = "economy" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = Mock() mock_child = Mock() @@ -1716,7 +1778,7 @@ class TestChildChunkApiPost: """Test 404 when dataset not found.""" self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with app.test_request_context( f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id/child_chunks", @@ -1754,7 +1816,7 @@ class TestChildChunkApiPost: """Test 404 when segment not found.""" self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = None @@ -1807,7 +1869,7 @@ class TestDatasetChildChunkApiDelete: ): """Test successful child chunk deletion.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc = Mock() mock_doc_svc.get_document.return_value = mock_doc @@ -1857,7 +1919,7 @@ class TestDatasetChildChunkApiDelete: ): """Test 404 when child chunk not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() segment_id = str(uuid.uuid4()) @@ -1898,7 +1960,7 @@ class TestDatasetChildChunkApiDelete: ): """Test 404 when segment does not belong to the document.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() segment_id = str(uuid.uuid4()) @@ -1938,7 +2000,7 @@ class TestDatasetChildChunkApiDelete: ): """Test 404 when child chunk does not belong to the segment.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() segment_id = str(uuid.uuid4()) diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py index e6e841be199..12d5e7345d2 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py @@ -35,6 +35,7 @@ from controllers.service_api.dataset.document import ( InvalidMetadataError, ) from controllers.service_api.dataset.error import ArchivedDocumentImmutableError +from core.rag.index_processor.constant.index_type import IndexStructureType from models.enums import IndexingStatus from services.dataset_service import DocumentService from services.entities.knowledge_entities.knowledge_entities import ProcessRule, RetrievalModel @@ -52,7 +53,7 @@ class TestDocumentTextCreatePayload: def test_payload_with_defaults(self): """Test payload default values.""" payload = DocumentTextCreatePayload(name="Doc", text="Content") - assert payload.doc_form == "text_model" + assert payload.doc_form == IndexStructureType.PARAGRAPH_INDEX assert payload.doc_language == "English" assert payload.process_rule is None assert payload.indexing_technique is None @@ -62,14 +63,14 @@ class TestDocumentTextCreatePayload: payload = DocumentTextCreatePayload( name="Full Document", text="Complete document content here", - doc_form="qa_model", + doc_form=IndexStructureType.QA_INDEX, doc_language="Chinese", indexing_technique="high_quality", embedding_model="text-embedding-ada-002", embedding_model_provider="openai", ) assert payload.name == "Full Document" - assert payload.doc_form == "qa_model" + assert payload.doc_form == IndexStructureType.QA_INDEX assert payload.doc_language == "Chinese" assert payload.indexing_technique == "high_quality" assert payload.embedding_model == "text-embedding-ada-002" @@ -147,8 +148,8 @@ class TestDocumentTextUpdate: def test_payload_with_doc_form_update(self): """Test payload with doc_form update.""" - payload = DocumentTextUpdate(doc_form="qa_model") - assert payload.doc_form == "qa_model" + payload = DocumentTextUpdate(doc_form=IndexStructureType.QA_INDEX) + assert payload.doc_form == IndexStructureType.QA_INDEX def test_payload_with_language_update(self): """Test payload with doc_language update.""" @@ -158,7 +159,7 @@ class TestDocumentTextUpdate: def test_payload_default_values(self): """Test payload default values.""" payload = DocumentTextUpdate() - assert payload.doc_form == "text_model" + assert payload.doc_form == IndexStructureType.PARAGRAPH_INDEX assert payload.doc_language == "English" @@ -272,14 +273,24 @@ class TestDocumentDocForm: def test_text_model_form(self): """Test text_model form.""" - doc_form = "text_model" - valid_forms = ["text_model", "qa_model", "hierarchical_model", "parent_child_model"] + doc_form = IndexStructureType.PARAGRAPH_INDEX + valid_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + "parent_child_model", + ] assert doc_form in valid_forms def test_qa_model_form(self): """Test qa_model form.""" - doc_form = "qa_model" - valid_forms = ["text_model", "qa_model", "hierarchical_model", "parent_child_model"] + doc_form = IndexStructureType.QA_INDEX + valid_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + "parent_child_model", + ] assert doc_form in valid_forms @@ -504,7 +515,7 @@ class TestDocumentApiGet: doc.name = "test_document.txt" doc.indexing_status = "completed" doc.enabled = True - doc.doc_form = "text_model" + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX doc.doc_language = "English" doc.doc_type = "book" doc.doc_metadata_details = {"source": "upload"} @@ -706,7 +717,7 @@ class TestDocumentApiDelete: dataset_id = str(uuid.uuid4()) mock_dataset = Mock() mock_dataset.id = dataset_id - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = mock_document mock_doc_svc.check_archived.return_value = False @@ -735,7 +746,7 @@ class TestDocumentApiDelete: document_id = str(uuid.uuid4()) mock_dataset = Mock() mock_dataset.id = dataset_id - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = None @@ -756,7 +767,7 @@ class TestDocumentApiDelete: dataset_id = str(uuid.uuid4()) mock_dataset = Mock() mock_dataset.id = dataset_id - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = mock_document mock_doc_svc.check_archived.return_value = True @@ -777,7 +788,7 @@ class TestDocumentApiDelete: # Arrange dataset_id = str(uuid.uuid4()) document_id = str(uuid.uuid4()) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -798,7 +809,7 @@ class TestDocumentListApi: def test_list_documents_success(self, mock_db, mock_doc_svc, mock_marshal, app, mock_tenant, mock_dataset): """Test successful document list retrieval.""" # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_pagination = Mock() mock_pagination.items = [Mock(), Mock()] @@ -827,7 +838,7 @@ class TestDocumentListApi: def test_list_documents_dataset_not_found(self, mock_db, app, mock_tenant, mock_dataset): """Test 404 when dataset not found.""" # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -849,8 +860,6 @@ class TestDocumentIndexingStatusApi: """Test successful indexing status retrieval.""" # Arrange batch_id = "batch_123" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset - mock_doc = Mock() mock_doc.id = str(uuid.uuid4()) mock_doc.is_paused = False @@ -866,8 +875,8 @@ class TestDocumentIndexingStatusApi: mock_doc_svc.get_batch_documents.return_value = [mock_doc] - # Mock segment count queries - mock_db.session.query.return_value.where.return_value.where.return_value.count.return_value = 5 + # scalar() called 3 times: dataset lookup, completed_segments count, total_segments count + mock_db.session.scalar.side_effect = [mock_dataset, 5, 5] mock_marshal.return_value = {"id": mock_doc.id, "indexing_status": "completed"} # Act @@ -887,7 +896,7 @@ class TestDocumentIndexingStatusApi: """Test 404 when dataset not found.""" # Arrange batch_id = "batch_123" - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -904,7 +913,7 @@ class TestDocumentIndexingStatusApi: """Test 404 when no documents found for batch.""" # Arrange batch_id = "batch_empty" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_batch_documents.return_value = [] # Act & Assert @@ -975,7 +984,7 @@ class TestDocumentAddByTextApi: # Arrange — neutralise billing decorators self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset.indexing_technique = "economy" mock_current_user.id = str(uuid.uuid4()) @@ -1024,7 +1033,7 @@ class TestDocumentAddByTextApi: # Arrange — neutralise billing decorators self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -1053,7 +1062,7 @@ class TestDocumentAddByTextApi: self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_dataset.indexing_technique = None - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset # Act & Assert with app.test_request_context( @@ -1139,7 +1148,7 @@ class TestDocumentUpdateByTextApiPost: _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_dataset.indexing_technique = "economy" mock_dataset.latest_process_rule = Mock() - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_current_user.id = "user-1" mock_upload = Mock() @@ -1182,7 +1191,7 @@ class TestDocumentUpdateByTextApiPost: ): """Test ValueError when dataset not found.""" _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None doc_id = str(uuid.uuid4()) with app.test_request_context( @@ -1221,7 +1230,7 @@ class TestDocumentAddByFileApiPost: ): """Test ValueError when dataset not found.""" _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None from io import BytesIO @@ -1252,7 +1261,7 @@ class TestDocumentAddByFileApiPost: """Test ValueError when dataset is external.""" _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_dataset.provider = "external" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset from io import BytesIO @@ -1287,7 +1296,7 @@ class TestDocumentAddByFileApiPost: mock_dataset.provider = "vendor" mock_dataset.indexing_technique = "economy" mock_dataset.chunk_structure = None - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset with app.test_request_context( f"/datasets/{mock_dataset.id}/document/create_by_file", @@ -1317,7 +1326,7 @@ class TestDocumentAddByFileApiPost: mock_dataset.provider = "vendor" mock_dataset.indexing_technique = None mock_dataset.chunk_structure = None - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset from io import BytesIO @@ -1355,7 +1364,7 @@ class TestDocumentUpdateByFileApiPost: ): """Test ValueError when dataset not found.""" _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None from io import BytesIO @@ -1391,7 +1400,7 @@ class TestDocumentUpdateByFileApiPost: """Test ValueError when dataset is external.""" _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_dataset.provider = "external" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset from io import BytesIO @@ -1439,7 +1448,7 @@ class TestDocumentUpdateByFileApiPost: mock_dataset.chunk_structure = None mock_dataset.latest_process_rule = Mock() mock_dataset.created_by_account = Mock() - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_current_user.id = "user-1" mock_upload = Mock() diff --git a/api/tests/unit_tests/controllers/service_api/test_site.py b/api/tests/unit_tests/controllers/service_api/test_site.py index b58caf3be18..c0b40d070a5 100644 --- a/api/tests/unit_tests/controllers/service_api/test_site.py +++ b/api/tests/unit_tests/controllers/service_api/test_site.py @@ -88,7 +88,7 @@ class TestAppSiteApi: mock_app_model.tenant = mock_tenant # Mock wraps.db for authentication - mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_wraps_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -98,7 +98,7 @@ class TestAppSiteApi: setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account) # Mock site.db for site query - mock_db.session.query.return_value.where.return_value.first.return_value = mock_site + mock_db.session.scalar.return_value = mock_site # Act with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -109,7 +109,7 @@ class TestAppSiteApi: assert response["title"] == "Test Site" assert response["icon"] == "icon-url" assert response["description"] == "Site description" - mock_db.session.query.assert_called_once_with(Site) + mock_db.session.scalar.assert_called_once() @patch("controllers.service_api.wraps.user_logged_in") @patch("controllers.service_api.app.site.db") @@ -140,7 +140,7 @@ class TestAppSiteApi: mock_tenant.status = TenantStatus.NORMAL mock_app_model.tenant = mock_tenant - mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_wraps_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -150,7 +150,7 @@ class TestAppSiteApi: setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account) # Mock site query to return None - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -187,7 +187,7 @@ class TestAppSiteApi: mock_tenant = Mock() mock_tenant.status = TenantStatus.NORMAL - mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_wraps_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -197,7 +197,7 @@ class TestAppSiteApi: setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account) # Mock site query - mock_db.session.query.return_value.where.return_value.first.return_value = mock_site + mock_db.session.scalar.return_value = mock_site # Set tenant status to archived AFTER authentication mock_app_model.tenant.status = TenantStatus.ARCHIVE @@ -230,7 +230,7 @@ class TestAppSiteApi: mock_tenant.status = TenantStatus.NORMAL mock_app_model.tenant = mock_tenant - mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_wraps_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -258,7 +258,7 @@ class TestAppSiteApi: mock_site.icon_type = "image" mock_site.created_at = "2024-01-01T00:00:00" mock_site.updated_at = "2024-01-01T00:00:00" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_site + mock_db.session.scalar.return_value = mock_site # Act with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -267,4 +267,4 @@ class TestAppSiteApi: # Assert # The query was executed successfully (site returned), which validates the correct query was made - mock_db.session.query.assert_called_once_with(Site) + mock_db.session.scalar.assert_called_once() diff --git a/api/tests/unit_tests/controllers/service_api/test_wraps.py b/api/tests/unit_tests/controllers/service_api/test_wraps.py index 9c2d075f417..a2008e024b9 100644 --- a/api/tests/unit_tests/controllers/service_api/test_wraps.py +++ b/api/tests/unit_tests/controllers/service_api/test_wraps.py @@ -144,14 +144,10 @@ class TestValidateAppToken: mock_ta = Mock() mock_ta.account_id = mock_account.id - # Use side_effect to return app first, then tenant - mock_db.session.query.return_value.where.return_value.first.side_effect = [ - mock_app, - mock_tenant, - mock_account, - ] + # Use side_effect to return app first, then tenant via session.get() + mock_db.session.get.side_effect = [mock_app, mock_tenant] - # Mock the tenant owner query + # Mock the tenant owner query (execute(select(...)).one_or_none()) setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta) @validate_app_token @@ -175,7 +171,7 @@ class TestValidateAppToken: mock_api_token.app_id = str(uuid.uuid4()) mock_validate_token.return_value = mock_api_token - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.get.return_value = None @validate_app_token def protected_view(**kwargs): @@ -198,7 +194,7 @@ class TestValidateAppToken: mock_app = Mock() mock_app.status = "abnormal" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_app + mock_db.session.get.return_value = mock_app @validate_app_token def protected_view(**kwargs): @@ -222,7 +218,7 @@ class TestValidateAppToken: mock_app = Mock() mock_app.status = "normal" mock_app.enable_api = False - mock_db.session.query.return_value.where.return_value.first.return_value = mock_app + mock_db.session.get.return_value = mock_app @validate_app_token def protected_view(**kwargs): @@ -474,11 +470,11 @@ class TestValidateDatasetToken: mock_account.id = mock_ta.account_id mock_account.current_tenant = mock_tenant - # Mock the tenant account join query + # Mock the tenant account join query (execute(select(...)).one_or_none()) setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta) - # Mock the account query - mock_db.session.query.return_value.where.return_value.first.return_value = mock_account + # Mock the account lookup via session.get() + mock_db.session.get.return_value = mock_account @validate_dataset_token def protected_view(tenant_id): @@ -501,7 +497,7 @@ class TestValidateDatasetToken: mock_api_token.tenant_id = str(uuid.uuid4()) mock_validate_token.return_value = mock_api_token - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None @validate_dataset_token def protected_view(dataset_id=None, **kwargs): diff --git a/api/tests/unit_tests/controllers/web/test_audio.py b/api/tests/unit_tests/controllers/web/test_audio.py index 01f34345aa3..cbfc8fa6130 100644 --- a/api/tests/unit_tests/controllers/web/test_audio.py +++ b/api/tests/unit_tests/controllers/web/test_audio.py @@ -8,6 +8,7 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask +from graphon.model_runtime.errors.invoke import InvokeError from controllers.web.audio import AudioApi, TextApi from controllers.web.error import ( @@ -21,7 +22,6 @@ from controllers.web.error import ( UnsupportedAudioTypeError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError from services.errors.audio import ( AudioTooLargeServiceError, NoAudioUploadedServiceError, diff --git a/api/tests/unit_tests/controllers/web/test_completion.py b/api/tests/unit_tests/controllers/web/test_completion.py index e88bcf2ae65..49039d03fe1 100644 --- a/api/tests/unit_tests/controllers/web/test_completion.py +++ b/api/tests/unit_tests/controllers/web/test_completion.py @@ -7,6 +7,7 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask +from graphon.model_runtime.errors.invoke import InvokeError from controllers.web.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi from controllers.web.error import ( @@ -18,7 +19,6 @@ from controllers.web.error import ( ProviderQuotaExceededError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError def _completion_app() -> SimpleNamespace: diff --git a/api/tests/unit_tests/core/agent/test_base_agent_runner.py b/api/tests/unit_tests/core/agent/test_base_agent_runner.py index 683cc0e36f7..db4b293b163 100644 --- a/api/tests/unit_tests/core/agent/test_base_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_base_agent_runner.py @@ -621,7 +621,7 @@ class TestConvertDatasetRetrieverTool: class TestBaseAgentRunnerInit: def test_init_sets_stream_tool_call_and_files(self, mocker): session = mocker.MagicMock() - session.query.return_value.where.return_value.count.return_value = 2 + session.scalar.return_value = 2 mocker.patch.object(module.db, "session", session) mocker.patch.object(BaseAgentRunner, "organize_agent_history", return_value=[]) diff --git a/api/tests/unit_tests/core/agent/test_cot_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_agent_runner.py index f6d1edbaf01..bc7aea0ef92 100644 --- a/api/tests/unit_tests/core/agent/test_cot_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_cot_agent_runner.py @@ -2,11 +2,11 @@ import json from unittest.mock import MagicMock import pytest +from graphon.model_runtime.entities.llm_entities import LLMUsage from core.agent.cot_agent_runner import CotAgentRunner from core.agent.entities import AgentScratchpadUnit from core.agent.errors import AgentMaxIterationError -from dify_graph.model_runtime.entities.llm_entities import LLMUsage class DummyRunner(CotAgentRunner): @@ -387,7 +387,7 @@ class TestRun: runner.update_prompt_message_tool.assert_called_once() def test_historic_with_assistant_and_tool_calls(self, runner): - from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, ToolPromptMessage + from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, ToolPromptMessage assistant = AssistantPromptMessage(content="thinking") assistant.tool_calls = [MagicMock(function=MagicMock(name="tool", arguments='{"a":1}'))] @@ -400,7 +400,7 @@ class TestRun: assert isinstance(result, list) def test_historic_final_flush_branch(self, runner): - from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage + from graphon.model_runtime.entities.message_entities import AssistantPromptMessage assistant = AssistantPromptMessage(content="final") runner.history_prompt_messages = [assistant] @@ -458,7 +458,7 @@ class TestFillInputsEdgeCases: class TestOrganizeHistoricPromptMessagesExtended: def test_user_message_flushes_scratchpad(self, runner, mocker): - from dify_graph.model_runtime.entities.message_entities import UserPromptMessage + from graphon.model_runtime.entities.message_entities import UserPromptMessage user_message = UserPromptMessage(content="Hi") @@ -473,7 +473,7 @@ class TestOrganizeHistoricPromptMessagesExtended: assert result == ["final"] def test_tool_message_without_scratchpad_raises(self, runner): - from dify_graph.model_runtime.entities.message_entities import ToolPromptMessage + from graphon.model_runtime.entities.message_entities import ToolPromptMessage runner.history_prompt_messages = [ToolPromptMessage(content="obs", tool_call_id="1")] diff --git a/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py index f9d69d11960..97206019b9f 100644 --- a/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py @@ -1,9 +1,9 @@ from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.entities.message_entities import TextPromptMessageContent from core.agent.cot_chat_agent_runner import CotChatAgentRunner -from dify_graph.model_runtime.entities.message_entities import TextPromptMessageContent from tests.unit_tests.core.agent.conftest import ( DummyAgentConfig, DummyAppConfig, @@ -93,7 +93,7 @@ class TestOrganizeUserQuery: @patch("core.agent.cot_chat_agent_runner.UserPromptMessage") @patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content") def test_organize_user_query_with_image_file_default_config(self, mock_to_prompt, mock_user_prompt, runner): - from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent + from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent mock_content = ImagePromptMessageContent( url="http://test", @@ -118,7 +118,7 @@ class TestOrganizeUserQuery: @patch("core.agent.cot_chat_agent_runner.UserPromptMessage") @patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content") def test_organize_user_query_with_image_file_high_detail(self, mock_to_prompt, mock_user_prompt, runner): - from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent + from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent mock_content = ImagePromptMessageContent( url="http://test", diff --git a/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py index ab822bb57df..defc8b4b642 100644 --- a/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py @@ -1,15 +1,15 @@ import json import pytest - -from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, TextPromptMessageContent, UserPromptMessage, ) +from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner + # ----------------------------- # Fixtures # ----------------------------- diff --git a/api/tests/unit_tests/core/agent/test_fc_agent_runner.py b/api/tests/unit_tests/core/agent/test_fc_agent_runner.py index 299c9b31d23..a44a0650eb4 100644 --- a/api/tests/unit_tests/core/agent/test_fc_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_fc_agent_runner.py @@ -3,19 +3,19 @@ from typing import Any from unittest.mock import MagicMock import pytest - -from core.agent.errors import AgentMaxIterationError -from core.agent.fc_agent_runner import FunctionCallAgentRunner -from core.app.apps.base_app_queue_manager import PublishFrom -from core.app.entities.queue_entities import QueueMessageFileEvent -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.message_entities import ( DocumentPromptMessageContent, ImagePromptMessageContent, TextPromptMessageContent, UserPromptMessage, ) +from core.agent.errors import AgentMaxIterationError +from core.agent.fc_agent_runner import FunctionCallAgentRunner +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.entities.queue_entities import QueueMessageFileEvent + # ============================== # Dummy Helper Classes # ============================== diff --git a/api/dify_graph/model_runtime/model_providers/__init__.py b/api/tests/unit_tests/core/app/app_config/__init__.py similarity index 100% rename from api/dify_graph/model_runtime/model_providers/__init__.py rename to api/tests/unit_tests/core/app/app_config/__init__.py diff --git a/api/tests/unit_tests/core/app/app_config/common/test_parameters_mapping.py b/api/tests/unit_tests/core/app/app_config/common/test_parameters_mapping.py new file mode 100644 index 00000000000..1c5b6ed944b --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/common/test_parameters_mapping.py @@ -0,0 +1,227 @@ +from unittest.mock import MagicMock + +import pytest + +# Module under test +from core.app.app_config.common import parameters_mapping + + +class TestGetParametersFromFeatureDict: + """Test suite for get_parameters_from_feature_dict""" + + @pytest.fixture + def mock_config(self, monkeypatch): + """Mock dify_config values""" + mock = MagicMock() + mock.UPLOAD_IMAGE_FILE_SIZE_LIMIT = 1 + mock.UPLOAD_VIDEO_FILE_SIZE_LIMIT = 2 + mock.UPLOAD_AUDIO_FILE_SIZE_LIMIT = 3 + mock.UPLOAD_FILE_SIZE_LIMIT = 4 + mock.WORKFLOW_FILE_UPLOAD_LIMIT = 5 + + monkeypatch.setattr(parameters_mapping, "dify_config", mock) + return mock + + @pytest.fixture + def mock_default_file_limits(self, monkeypatch): + """Mock DEFAULT_FILE_NUMBER_LIMITS constant""" + monkeypatch.setattr(parameters_mapping, "DEFAULT_FILE_NUMBER_LIMITS", 99) + return 99 + + @pytest.fixture + def minimal_inputs(self): + return {}, [] + + @pytest.mark.parametrize( + ("feature_key", "expected_default"), + [ + ("suggested_questions", []), + ("suggested_questions_after_answer", {"enabled": False}), + ("speech_to_text", {"enabled": False}), + ("text_to_speech", {"enabled": False}), + ("retriever_resource", {"enabled": False}), + ("annotation_reply", {"enabled": False}), + ("more_like_this", {"enabled": False}), + ( + "sensitive_word_avoidance", + {"enabled": False, "type": "", "configs": []}, + ), + ], + ) + def test_defaults_when_key_missing( + self, + feature_key, + expected_default, + mock_config, + mock_default_file_limits, + ): + # Arrange + features = {} + user_input = [] + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=user_input, + ) + + # Assert + assert result[feature_key] == expected_default + + def test_opening_statement_present(self, mock_config, mock_default_file_limits): + # Arrange + features = {"opening_statement": "Hello"} + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=[], + ) + + # Assert + assert result["opening_statement"] == "Hello" + + def test_opening_statement_missing_returns_none(self, mock_config, mock_default_file_limits): + # Arrange + features = {} + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=[], + ) + + # Assert + assert result["opening_statement"] is None + + def test_all_features_provided(self, mock_config, mock_default_file_limits): + # Arrange + features = { + "opening_statement": "Hi", + "suggested_questions": ["Q1"], + "suggested_questions_after_answer": {"enabled": True}, + "speech_to_text": {"enabled": True}, + "text_to_speech": {"enabled": True}, + "retriever_resource": {"enabled": True}, + "annotation_reply": {"enabled": True}, + "more_like_this": {"enabled": True}, + "sensitive_word_avoidance": { + "enabled": True, + "type": "strict", + "configs": ["a"], + }, + "file_upload": { + "image": { + "enabled": True, + "number_limits": 10, + "detail": "low", + "transfer_methods": ["local_file"], + } + }, + } + user_input = [{"name": "field1"}] + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=user_input, + ) + + # Assert + for key in features: + assert result[key] == features[key] + assert result["user_input_form"] == user_input + + def test_file_upload_default_structure(self, mock_config, mock_default_file_limits): + # Arrange + features = {} + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=[], + ) + + # Assert + file_upload = result["file_upload"] + assert file_upload["image"]["enabled"] is False + assert file_upload["image"]["number_limits"] == 99 + assert file_upload["image"]["detail"] == "high" + assert "remote_url" in file_upload["image"]["transfer_methods"] + assert "local_file" in file_upload["image"]["transfer_methods"] + + def test_system_parameters_from_config(self, mock_config, mock_default_file_limits): + # Arrange + features = {} + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=[], + ) + + # Assert + system_params = result["system_parameters"] + assert system_params["image_file_size_limit"] == 1 + assert system_params["video_file_size_limit"] == 2 + assert system_params["audio_file_size_limit"] == 3 + assert system_params["file_size_limit"] == 4 + assert system_params["workflow_file_upload_limit"] == 5 + + @pytest.mark.parametrize( + ("features_dict", "user_input_form"), + [ + (None, []), + ([], []), + ("invalid", []), + ], + ) + def test_invalid_features_dict_type_raises(self, features_dict, user_input_form): + # Act & Assert + with pytest.raises(AttributeError): + parameters_mapping.get_parameters_from_feature_dict( + features_dict=features_dict, + user_input_form=user_input_form, + ) + + @pytest.mark.parametrize( + "user_input_form", + [None, "invalid", 123], + ) + def test_user_input_form_invalid_type(self, mock_config, mock_default_file_limits, user_input_form): + # Arrange + features = {} + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=user_input_form, + ) + + # Assert + assert result["user_input_form"] == user_input_form + + def test_empty_user_input_form(self, mock_config, mock_default_file_limits): + features = {} + user_input = [] + + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=user_input, + ) + + assert result["user_input_form"] == [] + + def test_feature_values_none(self, mock_config, mock_default_file_limits): + features = { + "suggested_questions": None, + "speech_to_text": None, + } + + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=[], + ) + + assert result["suggested_questions"] is None + assert result["speech_to_text"] is None diff --git a/api/tests/unit_tests/core/app/app_config/common/test_sensitive_word_avoidance_manager.py b/api/tests/unit_tests/core/app/app_config/common/test_sensitive_word_avoidance_manager.py new file mode 100644 index 00000000000..013ed0cbc41 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/common/test_sensitive_word_avoidance_manager.py @@ -0,0 +1,202 @@ +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.common.sensitive_word_avoidance.manager import ( + SensitiveWordAvoidanceConfigManager, +) + + +class TestSensitiveWordAvoidanceConfigManagerConvert: + """Tests for convert classmethod""" + + @pytest.mark.parametrize( + "config", + [ + {}, + {"sensitive_word_avoidance": None}, + {"sensitive_word_avoidance": {}}, + {"sensitive_word_avoidance": {"enabled": False}}, + ], + ) + def test_convert_returns_none_when_disabled_or_missing(self, config): + # Act + result = SensitiveWordAvoidanceConfigManager.convert(config) + + # Assert + assert result is None + + def test_convert_returns_entity_when_enabled(self, mocker): + # Arrange + mock_entity = MagicMock() + mocker.patch( + "core.app.app_config.common.sensitive_word_avoidance.manager.SensitiveWordAvoidanceEntity", + return_value=mock_entity, + ) + + config = { + "sensitive_word_avoidance": { + "enabled": True, + "type": "mock_type", + "config": {"key": "value"}, + } + } + + # Act + result = SensitiveWordAvoidanceConfigManager.convert(config) + + # Assert + assert result == mock_entity + + def test_convert_enabled_without_type_or_config(self, mocker): + # Arrange + mock_entity = MagicMock() + patched = mocker.patch( + "core.app.app_config.common.sensitive_word_avoidance.manager.SensitiveWordAvoidanceEntity", + return_value=mock_entity, + ) + + config = {"sensitive_word_avoidance": {"enabled": True}} + + # Act + result = SensitiveWordAvoidanceConfigManager.convert(config) + + # Assert + patched.assert_called_once_with(type=None, config={}) + assert result == mock_entity + + +class TestSensitiveWordAvoidanceConfigManagerValidateAndSetDefaults: + """Tests for validate_and_set_defaults classmethod""" + + @pytest.fixture + def base_config(self): + return {} + + def test_validate_sets_default_when_missing(self, base_config): + # Act + config, fields = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id="tenant1", config=base_config.copy() + ) + + # Assert + assert config["sensitive_word_avoidance"]["enabled"] is False + assert fields == ["sensitive_word_avoidance"] + + def test_validate_raises_when_not_dict(self): + config = {"sensitive_word_avoidance": "invalid"} + + with pytest.raises(ValueError, match="must be of dict type"): + SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config) + + @pytest.mark.parametrize( + "config", + [ + {"sensitive_word_avoidance": {"enabled": False}}, + {"sensitive_word_avoidance": {"enabled": None}}, + {"sensitive_word_avoidance": {}}, + ], + ) + def test_validate_disables_when_enabled_false_or_missing(self, config): + # Act + result_config, _ = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id="tenant1", config=config + ) + + # Assert + assert result_config["sensitive_word_avoidance"]["enabled"] is False + + def test_validate_raises_when_enabled_true_without_type(self): + config = {"sensitive_word_avoidance": {"enabled": True}} + + with pytest.raises(ValueError, match="type is required"): + SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config) + + def test_validate_raises_when_type_not_string(self): + config = { + "sensitive_word_avoidance": { + "enabled": True, + "type": 123, + } + } + + with pytest.raises(ValueError, match="must be a string"): + SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config) + + def test_validate_raises_when_config_not_dict(self): + config = { + "sensitive_word_avoidance": { + "enabled": True, + "type": "mock_type", + "config": "invalid", + } + } + + with pytest.raises(ValueError, match="must be a dict"): + SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config) + + def test_validate_calls_moderation_factory(self, mocker): + # Arrange + mock_validate = mocker.patch( + "core.app.app_config.common.sensitive_word_avoidance.manager.ModerationFactory.validate_config" + ) + + config = { + "sensitive_word_avoidance": { + "enabled": True, + "type": "mock_type", + "config": {"k": "v"}, + } + } + + # Act + result_config, fields = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id="tenant1", config=config + ) + + # Assert + mock_validate.assert_called_once_with(name="mock_type", tenant_id="tenant1", config={"k": "v"}) + assert result_config["sensitive_word_avoidance"]["enabled"] is True + assert fields == ["sensitive_word_avoidance"] + + def test_validate_sets_empty_dict_when_config_none(self, mocker): + # Arrange + mock_validate = mocker.patch( + "core.app.app_config.common.sensitive_word_avoidance.manager.ModerationFactory.validate_config" + ) + + config = { + "sensitive_word_avoidance": { + "enabled": True, + "type": "mock_type", + "config": None, + } + } + + # Act + SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config) + + # Assert + mock_validate.assert_called_once_with(name="mock_type", tenant_id="tenant1", config={}) + + def test_validate_only_structure_validate_skips_factory(self, mocker): + # Arrange + mock_validate = mocker.patch( + "core.app.app_config.common.sensitive_word_avoidance.manager.ModerationFactory.validate_config" + ) + + config = { + "sensitive_word_avoidance": { + "enabled": True, + "type": "mock_type", + "config": {"k": "v"}, + } + } + + # Act + SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id="tenant1", config=config, only_structure_validate=True + ) + + # Assert + mock_validate.assert_not_called() diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_agent_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_agent_manager.py new file mode 100644 index 00000000000..992b580376b --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_agent_manager.py @@ -0,0 +1,236 @@ +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager + + +class TestAgentConfigManagerConvert: + @pytest.fixture + def base_config(self): + return { + "agent_mode": { + "enabled": True, + "strategy": "cot", + "tools": [], + }, + "model": { + "provider": "openai", + "name": "gpt-4", + "mode": "completion", + }, + } + + def test_convert_returns_none_when_agent_mode_missing(self): + config = {"model": {"provider": "openai", "name": "gpt-4"}} + + result = AgentConfigManager.convert(config) + + assert result is None + + @pytest.mark.parametrize("agent_mode_value", [None, {}, {"enabled": False}]) + def test_convert_returns_none_when_agent_mode_disabled(self, agent_mode_value, base_config): + config = base_config.copy() + config["agent_mode"] = agent_mode_value + + result = AgentConfigManager.convert(config) + + assert result is None + + @pytest.mark.parametrize( + ("strategy_input", "expected_enum"), + [ + ("function_call", "FUNCTION_CALLING"), + ("cot", "CHAIN_OF_THOUGHT"), + ("react", "CHAIN_OF_THOUGHT"), + ], + ) + def test_convert_strategy_mapping(self, strategy_input, expected_enum, base_config): + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": strategy_input, + "tools": [], + } + + result = AgentConfigManager.convert(config) + + assert result is not None + assert result.strategy.name == expected_enum + + def test_convert_unknown_strategy_openai_defaults_to_function_calling(self, base_config): + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": "unknown_strategy", + "tools": [], + } + config["model"]["provider"] = "openai" + + result = AgentConfigManager.convert(config) + + assert result.strategy.name == "FUNCTION_CALLING" + + def test_convert_unknown_strategy_non_openai_defaults_to_chain_of_thought(self, base_config): + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": "unknown_strategy", + "tools": [], + } + config["model"]["provider"] = "anthropic" + + result = AgentConfigManager.convert(config) + + assert result.strategy.name == "CHAIN_OF_THOUGHT" + + def test_convert_skips_disabled_tools(self, mocker, base_config): + # Patch AgentEntity to bypass pydantic validation + mock_agent_entity = mocker.patch( + "core.app.app_config.easy_ui_based_app.agent.manager.AgentEntity", + return_value=MagicMock(), + ) + + mock_validate = mocker.patch( + "core.app.app_config.easy_ui_based_app.agent.manager.AgentToolEntity.model_validate", + return_value={ + "provider_type": "type2", + "provider_id": "id2", + "tool_name": "tool2", + "tool_parameters": {}, + "credential_id": None, + }, + ) + + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": "cot", + "tools": [ + { + "provider_type": "type1", + "provider_id": "id1", + "tool_name": "tool1", + "enabled": False, + }, + { + "provider_type": "type2", + "provider_id": "id2", + "tool_name": "tool2", + "enabled": True, + "extra_key": "x", + }, + ], + } + + AgentConfigManager.convert(config) + + mock_validate.assert_called_once() + mock_agent_entity.assert_called_once() + + def test_convert_tool_requires_minimum_keys(self, mocker, base_config): + mock_validate = mocker.patch( + "core.app.app_config.easy_ui_based_app.agent.manager.AgentToolEntity.model_validate", + return_value=MagicMock(), + ) + + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": "cot", + "tools": [ + {"a": 1, "b": 2}, # insufficient keys + ], + } + + result = AgentConfigManager.convert(config) + + assert result is not None + assert result.tools == [] + mock_validate.assert_not_called() + + def test_convert_completion_mode_prompt_defaults(self, base_config): + config = base_config.copy() + config["agent_mode"]["prompt"] = {} + config["model"]["mode"] = "completion" + + result = AgentConfigManager.convert(config) + + assert result is not None + assert result.prompt.first_prompt is not None + assert result.prompt.next_iteration is not None + + def test_convert_chat_mode_prompt_defaults(self, base_config): + config = base_config.copy() + config["agent_mode"]["prompt"] = {} + config["model"]["mode"] = "chat" + + result = AgentConfigManager.convert(config) + + assert result is not None + assert result.prompt.first_prompt is not None + assert result.prompt.next_iteration is not None + + def test_convert_router_strategy_returns_none(self, base_config): + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": "router", + "tools": [], + } + + result = AgentConfigManager.convert(config) + + assert result is None + + def test_convert_react_router_strategy_returns_none(self, base_config): + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": "react_router", + "tools": [], + } + + result = AgentConfigManager.convert(config) + + assert result is None + + def test_convert_max_iteration_default(self, base_config): + config = base_config.copy() + config["agent_mode"].pop("max_iteration", None) + + result = AgentConfigManager.convert(config) + + assert result.max_iteration == 10 + + def test_convert_custom_max_iteration(self, base_config): + config = base_config.copy() + config["agent_mode"]["max_iteration"] = 25 + + result = AgentConfigManager.convert(config) + + assert result.max_iteration == 25 + + def test_convert_missing_model_raises_key_error(self, base_config): + config = base_config.copy() + del config["model"] + + with pytest.raises(KeyError): + AgentConfigManager.convert(config) + + @pytest.mark.parametrize( + ("invalid_config", "should_raise"), + [ + (None, True), + (123, True), + ("", False), + ([], False), + ], + ) + def test_convert_invalid_input_type_behavior(self, invalid_config, should_raise): + if should_raise: + with pytest.raises(TypeError): + AgentConfigManager.convert(invalid_config) # type: ignore + else: + result = AgentConfigManager.convert(invalid_config) # type: ignore + assert result is None diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_dataset_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_dataset_manager.py new file mode 100644 index 00000000000..a688e2a5c50 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_dataset_manager.py @@ -0,0 +1,319 @@ +import uuid +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager +from core.entities.agent_entities import PlanningStrategy +from models.model import AppMode + +# ============================== +# Fixtures +# ============================== + + +@pytest.fixture +def valid_uuid(): + return str(uuid.uuid4()) + + +@pytest.fixture +def base_config(valid_uuid): + return { + "dataset_configs": { + "retrieval_model": "multiple", + "datasets": { + "strategy": "router", + "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}], + }, + } + } + + +@pytest.fixture +def mock_dataset_service(mocker, valid_uuid): + mock_dataset = MagicMock() + mock_dataset.tenant_id = "tenant1" + + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset", + return_value=mock_dataset, + ) + + +# ============================== +# convert tests +# ============================== + + +class TestDatasetConfigManagerConvert: + def test_convert_returns_none_when_no_datasets(self): + config = {"dataset_configs": {"datasets": {"datasets": []}}} + result = DatasetConfigManager.convert(config) + assert result is None + + def test_convert_single_retrieval(self, valid_uuid): + config = { + "dataset_query_variable": "query", + "dataset_configs": { + "retrieval_model": "single", + "datasets": { + "strategy": "router", + "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}], + }, + }, + } + + result = DatasetConfigManager.convert(config) + assert result is not None + assert result.dataset_ids == [valid_uuid] + assert result.retrieve_config.query_variable == "query" + + def test_convert_single_with_metadata_configs(self, valid_uuid, mocker): + mock_retrieve_config = MagicMock() + mock_entity = MagicMock() + mock_entity.dataset_ids = [valid_uuid] + mock_entity.retrieve_config = mock_retrieve_config + + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.ModelConfig", + return_value={"mock": "model"}, + ) + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.MetadataFilteringCondition", + return_value={"mock": "condition"}, + ) + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetRetrieveConfigEntity", + return_value=mock_retrieve_config, + ) + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetEntity", + return_value=mock_entity, + ) + + config = { + "dataset_query_variable": "query", + "dataset_configs": { + "retrieval_model": "single", + "metadata_filtering_mode": "manual", + "metadata_model_config": {"any": "value"}, + "metadata_filtering_conditions": {"any": "value"}, + "datasets": { + "strategy": "router", + "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}], + }, + }, + } + result = DatasetConfigManager.convert(config) + assert result.dataset_ids == [valid_uuid] + assert result.retrieve_config is mock_retrieve_config + + def test_convert_multiple_defaults(self, valid_uuid): + config = { + "dataset_configs": { + "retrieval_model": "multiple", + "datasets": { + "strategy": "router", + "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}], + }, + } + } + result = DatasetConfigManager.convert(config) + assert result.retrieve_config.top_k == 4 + assert result.retrieve_config.score_threshold is None + assert result.retrieve_config.reranking_enabled is True + + def test_convert_agent_mode_disabled_tool(self, valid_uuid): + config = { + "agent_mode": { + "enabled": True, + "tools": [{"dataset": {"id": valid_uuid, "enabled": False}}], + } + } + result = DatasetConfigManager.convert(config) + assert result is None + + def test_convert_dataset_configs_none(self): + config = {"dataset_configs": None} + with pytest.raises(TypeError): + DatasetConfigManager.convert(config) + + def test_convert_agent_mode_old_style_old_format(self, valid_uuid): + config = { + "agent_mode": { + "enabled": True, + "tools": [{"dataset": {"id": valid_uuid, "enabled": True}}], + } + } + result = DatasetConfigManager.convert(config) + assert result.dataset_ids == [valid_uuid] + assert result.retrieve_config.query_variable is None + + def test_convert_multiple_with_score_threshold(self, valid_uuid): + config = { + "dataset_query_variable": "query", + "dataset_configs": { + "retrieval_model": "multiple", + "top_k": 5, + "score_threshold": 0.8, + "score_threshold_enabled": True, + "datasets": { + "strategy": "router", + "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}], + }, + }, + } + + result = DatasetConfigManager.convert(config) + assert result.retrieve_config.top_k == 5 + assert result.retrieve_config.score_threshold == 0.8 + + @pytest.mark.parametrize( + "dataset_entry", + [ + {}, + {"invalid": {}}, + {"dataset": {"id": None, "enabled": True}}, + {"dataset": {"id": "", "enabled": False}}, + ], + ) + def test_convert_ignores_invalid_dataset_entries(self, dataset_entry): + config = { + "dataset_configs": { + "retrieval_model": "multiple", + "datasets": {"strategy": "router", "datasets": [dataset_entry]}, + } + } + result = DatasetConfigManager.convert(config) + assert result is None + + def test_convert_agent_mode_old_style(self, valid_uuid): + config = { + "agent_mode": { + "enabled": True, + "tools": [{"dataset": {"id": valid_uuid, "enabled": True}}], + } + } + result = DatasetConfigManager.convert(config) + assert result.dataset_ids == [valid_uuid] + + +# ============================== +# validate_and_set_defaults tests +# ============================== + + +class TestValidateAndSetDefaults: + def test_validate_sets_defaults(self): + config = {} + updated, fields = DatasetConfigManager.validate_and_set_defaults("tenant1", AppMode.CHAT, config) + assert "dataset_configs" in updated + assert updated["dataset_configs"]["retrieval_model"] == "single" + assert isinstance(fields, list) + + def test_validate_raises_when_dataset_configs_not_dict(self): + config = {"dataset_configs": "invalid"} + with pytest.raises(AttributeError): + DatasetConfigManager.validate_and_set_defaults("tenant1", AppMode.CHAT, config) + + def test_validate_requires_query_variable_in_completion_mode(self, valid_uuid): + config = { + "dataset_configs": { + "datasets": { + "strategy": "router", + "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}], + } + } + } + with pytest.raises(ValueError): + DatasetConfigManager.validate_and_set_defaults("tenant1", AppMode.COMPLETION, config) + + +# ============================== +# extract_dataset_config_for_legacy_compatibility tests +# ============================== + + +class TestExtractDatasetConfig: + def test_extract_sets_defaults(self): + config = {} + result = DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config) + assert "agent_mode" in result + assert result["agent_mode"]["enabled"] is False + assert result["agent_mode"]["tools"] == [] + + def test_extract_invalid_agent_mode_type(self): + config = {"agent_mode": "invalid"} + with pytest.raises(ValueError): + DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config) + + def test_extract_invalid_enabled_type(self): + config = {"agent_mode": {"enabled": "yes"}} + with pytest.raises(ValueError): + DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config) + + def test_extract_invalid_tools_type(self): + config = {"agent_mode": {"enabled": True, "tools": "invalid"}} + with pytest.raises(ValueError): + DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config) + + def test_extract_invalid_uuid(self, mocker): + invalid_uuid = "not-a-uuid" + config = { + "agent_mode": { + "enabled": True, + "strategy": PlanningStrategy.ROUTER, + "tools": [{"dataset": {"id": invalid_uuid, "enabled": True}}], + } + } + with pytest.raises(ValueError): + DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config) + + def test_extract_dataset_not_exists(self, valid_uuid, mocker): + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset", + return_value=None, + ) + config = { + "agent_mode": { + "enabled": True, + "strategy": PlanningStrategy.ROUTER, + "tools": [{"dataset": {"id": valid_uuid, "enabled": True}}], + } + } + with pytest.raises(ValueError): + DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config) + + +# ============================== +# is_dataset_exists tests +# ============================== + + +class TestIsDatasetExists: + def test_dataset_exists_true(self, mocker, valid_uuid): + mock_dataset = MagicMock() + mock_dataset.tenant_id = "tenant1" + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset", + return_value=mock_dataset, + ) + + assert DatasetConfigManager.is_dataset_exists("tenant1", valid_uuid) + + def test_dataset_exists_false_when_not_found(self, mocker, valid_uuid): + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset", + return_value=None, + ) + assert not DatasetConfigManager.is_dataset_exists("tenant1", valid_uuid) + + def test_dataset_exists_false_when_tenant_mismatch(self, mocker, valid_uuid): + mock_dataset = MagicMock() + mock_dataset.tenant_id = "other" + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset", + return_value=mock_dataset, + ) + assert not DatasetConfigManager.is_dataset_exists("tenant1", valid_uuid) diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py new file mode 100644 index 00000000000..5ee66da94ab --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py @@ -0,0 +1,234 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelPropertyKey + +from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter +from core.entities.model_entities import ModelStatus +from core.errors.error import ( + ModelCurrentlyNotSupportError, + ProviderTokenNotInitError, + QuotaExceededError, +) + + +class TestModelConfigConverter: + @pytest.fixture(autouse=True) + def patch_response_entity(self, mocker): + """ + Patch ModelConfigWithCredentialsEntity to bypass Pydantic validation + and return a simple namespace object instead. + """ + + def _factory(**kwargs): + return SimpleNamespace(**kwargs) + + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.ModelConfigWithCredentialsEntity", + side_effect=_factory, + ) + + @pytest.fixture + def mock_app_config(self): + app_config = MagicMock() + app_config.tenant_id = "tenant_1" + + model_config = MagicMock() + model_config.provider = "openai" + model_config.model = "gpt-4" + model_config.parameters = {"temperature": 0.5} + model_config.mode = None + + app_config.model = model_config + return app_config + + @pytest.fixture + def mock_provider_bundle(self): + bundle = MagicMock() + + # configuration + configuration = MagicMock() + configuration.provider.provider = "openai" + configuration.get_current_credentials.return_value = {"api_key": "key"} + + provider_model = MagicMock() + provider_model.status = ModelStatus.ACTIVE + configuration.get_provider_model.return_value = provider_model + + bundle.configuration = configuration + + # model type instance + model_type_instance = MagicMock() + model_schema = MagicMock() + model_schema.model_properties = {} + model_type_instance.get_model_schema.return_value = model_schema + bundle.model_type_instance = model_type_instance + + return bundle + + @pytest.fixture + def patch_provider_manager(self, mocker, mock_provider_bundle): + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", + return_value=mock_manager, + ) + return mock_manager + + # ============================= + # Positive Scenarios + # ============================= + + def test_convert_success_default_mode(self, mock_app_config, patch_provider_manager): + result = ModelConfigConverter.convert(mock_app_config) + + assert result.provider == "openai" + assert result.model == "gpt-4" + assert result.mode == LLMMode.CHAT + assert result.parameters == {"temperature": 0.5} + assert result.stop == [] + + def test_convert_success_with_stop_parameter(self, mock_app_config, patch_provider_manager): + mock_app_config.model.parameters = {"temperature": 0.7, "stop": ["\n"]} + + result = ModelConfigConverter.convert(mock_app_config) + + assert result.parameters == {"temperature": 0.7} + assert result.stop == ["\n"] + + def test_convert_mode_from_schema_valid(self, mock_app_config, mock_provider_bundle, mocker): + mock_app_config.model.mode = None + + mock_provider_bundle.model_type_instance.get_model_schema.return_value.model_properties = { + ModelPropertyKey.MODE: LLMMode.COMPLETION.value + } + + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", + return_value=mock_manager, + ) + + result = ModelConfigConverter.convert(mock_app_config) + assert result.mode == LLMMode.COMPLETION + + def test_convert_mode_from_schema_invalid_fallback(self, mock_app_config, mock_provider_bundle, mocker): + mock_provider_bundle.model_type_instance.get_model_schema.return_value.model_properties = { + ModelPropertyKey.MODE: "invalid" + } + + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", + return_value=mock_manager, + ) + + result = ModelConfigConverter.convert(mock_app_config) + assert result.mode == LLMMode.CHAT + + # ============================= + # Credential Errors + # ============================= + + def test_convert_credentials_none_raises(self, mock_app_config, mock_provider_bundle, mocker): + mock_provider_bundle.configuration.get_current_credentials.return_value = None + + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", + return_value=mock_manager, + ) + + with pytest.raises(ProviderTokenNotInitError): + ModelConfigConverter.convert(mock_app_config) + + # ============================= + # Provider Model Errors + # ============================= + + def test_convert_provider_model_none_raises(self, mock_app_config, mock_provider_bundle, mocker): + mock_provider_bundle.configuration.get_provider_model.return_value = None + + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", + return_value=mock_manager, + ) + + with pytest.raises(ValueError): + ModelConfigConverter.convert(mock_app_config) + + @pytest.mark.parametrize( + ("status", "expected_exception"), + [ + (ModelStatus.NO_CONFIGURE, ProviderTokenNotInitError), + (ModelStatus.NO_PERMISSION, ModelCurrentlyNotSupportError), + (ModelStatus.QUOTA_EXCEEDED, QuotaExceededError), + ], + ) + def test_convert_provider_model_status_errors( + self, mock_app_config, mock_provider_bundle, mocker, status, expected_exception + ): + mock_provider = MagicMock() + mock_provider.status = status + mock_provider_bundle.configuration.get_provider_model.return_value = mock_provider + + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", + return_value=mock_manager, + ) + + with pytest.raises(expected_exception): + ModelConfigConverter.convert(mock_app_config) + + # ============================= + # Schema Errors + # ============================= + + def test_convert_model_schema_none_raises(self, mock_app_config, mock_provider_bundle, mocker): + mock_provider_bundle.model_type_instance.get_model_schema.return_value = None + + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", + return_value=mock_manager, + ) + + with pytest.raises(ValueError): + ModelConfigConverter.convert(mock_app_config) + + # ============================= + # Edge Cases + # ============================= + + @pytest.mark.parametrize( + "parameters", + [ + {}, + {"stop": []}, + {"stop": ["END"], "max_tokens": 100}, + ], + ) + def test_convert_parameter_edge_cases(self, mock_app_config, patch_provider_manager, parameters): + mock_app_config.model.parameters = parameters.copy() + + result = ModelConfigConverter.convert(mock_app_config) + + if "stop" in parameters: + assert result.stop == parameters.get("stop") + expected_params = parameters.copy() + expected_params.pop("stop", None) + assert result.parameters == expected_params + else: + assert result.stop == [] + assert result.parameters == parameters diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_manager.py new file mode 100644 index 00000000000..68bca485bbd --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_manager.py @@ -0,0 +1,216 @@ +from unittest.mock import MagicMock + +import pytest + +# Target +from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager + +# ----------------------------- +# Fixtures +# ----------------------------- + + +@pytest.fixture +def valid_completion_params(): + return {"temperature": 0.7, "stop": ["\n"]} + + +@pytest.fixture +def valid_model_list(): + model = MagicMock() + model.model = "gpt-4" + model.model_properties = {"mode": "chat"} + return [model] + + +@pytest.fixture +def provider_entities(): + provider = MagicMock() + provider.provider = "openai/gpt" + return [provider] + + +@pytest.fixture +def valid_config(): + return { + "model": {"provider": "openai/gpt", "name": "gpt-4", "completion_params": {"temperature": 0.5, "stop": ["END"]}} + } + + +# ----------------------------- +# Test Class +# ----------------------------- + + +class TestModelConfigManager: + @staticmethod + def _patch_model_assembly(mocker, *, provider_entities, model_list): + assembly = MagicMock() + assembly.model_provider_factory.get_providers.return_value = provider_entities + assembly.provider_manager.get_configurations.return_value.get_models.return_value = model_list + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.manager.create_plugin_model_assembly", + return_value=assembly, + ) + return assembly + + # ========================================================== + # convert + # ========================================================== + + def test_convert_success(self, valid_config): + result = ModelConfigManager.convert(valid_config) + + assert result.provider == "openai/gpt" + assert result.model == "gpt-4" + assert result.parameters == {"temperature": 0.5} + assert result.stop == ["END"] + + def test_convert_missing_model(self): + with pytest.raises(ValueError, match="model is required"): + ModelConfigManager.convert({}) + + def test_convert_without_stop(self): + config = {"model": {"provider": "openai/gpt", "name": "gpt-4", "completion_params": {"temperature": 0.9}}} + result = ModelConfigManager.convert(config) + assert result.stop == [] + assert result.parameters == {"temperature": 0.9} + + # ========================================================== + # validate_model_completion_params + # ========================================================== + + @pytest.mark.parametrize( + "invalid_cp", + [None, "string", 123, []], + ) + def test_validate_model_completion_params_invalid_type(self, invalid_cp): + with pytest.raises(ValueError, match="must be of object type"): + ModelConfigManager.validate_model_completion_params(invalid_cp) + + def test_validate_model_completion_params_default_stop(self): + cp = {"temperature": 0.2} + result = ModelConfigManager.validate_model_completion_params(cp) + assert result["stop"] == [] + + def test_validate_model_completion_params_invalid_stop_type(self): + cp = {"stop": "invalid"} + with pytest.raises(ValueError, match="must be of list type"): + ModelConfigManager.validate_model_completion_params(cp) + + def test_validate_model_completion_params_stop_length_exceeded(self): + cp = {"stop": [1, 2, 3, 4, 5]} + with pytest.raises(ValueError, match="less than 4"): + ModelConfigManager.validate_model_completion_params(cp) + + # ========================================================== + # validate_and_set_defaults + # ========================================================== + + def test_validate_and_set_defaults_success(self, mocker, valid_config, provider_entities, valid_model_list): + self._patch_model_assembly( + mocker, + provider_entities=provider_entities, + model_list=valid_model_list, + ) + + updated_config, keys = ModelConfigManager.validate_and_set_defaults("tenant1", valid_config) + + assert updated_config["model"]["mode"] == "chat" + assert keys == ["model"] + + def test_validate_and_set_defaults_missing_model(self): + with pytest.raises(ValueError, match="model is required"): + ModelConfigManager.validate_and_set_defaults("tenant1", {}) + + def test_validate_and_set_defaults_model_not_dict(self): + with pytest.raises(ValueError, match="object type"): + ModelConfigManager.validate_and_set_defaults("tenant1", {"model": "invalid"}) + + def test_validate_and_set_defaults_missing_provider(self, mocker, provider_entities): + config = {"model": {"name": "gpt-4", "completion_params": {}}} + self._patch_model_assembly(mocker, provider_entities=provider_entities, model_list=[]) + + with pytest.raises(ValueError, match="model.provider is required"): + ModelConfigManager.validate_and_set_defaults("tenant1", config) + + def test_validate_and_set_defaults_invalid_provider(self, mocker, provider_entities): + config = {"model": {"provider": "invalid/provider", "name": "gpt-4", "completion_params": {}}} + self._patch_model_assembly(mocker, provider_entities=provider_entities, model_list=[]) + + with pytest.raises(ValueError, match="model.provider is required"): + ModelConfigManager.validate_and_set_defaults("tenant1", config) + + def test_validate_and_set_defaults_missing_name(self, mocker, provider_entities): + config = {"model": {"provider": "openai/gpt", "completion_params": {}}} + self._patch_model_assembly(mocker, provider_entities=provider_entities, model_list=[]) + + with pytest.raises(ValueError, match="model.name is required"): + ModelConfigManager.validate_and_set_defaults("tenant1", config) + + def test_validate_and_set_defaults_empty_models(self, mocker, provider_entities): + config = {"model": {"provider": "openai/gpt", "name": "gpt-4", "completion_params": {}}} + self._patch_model_assembly(mocker, provider_entities=provider_entities, model_list=[]) + + with pytest.raises(ValueError, match="must be in the specified model list"): + ModelConfigManager.validate_and_set_defaults("tenant1", config) + + def test_validate_and_set_defaults_invalid_model_name(self, mocker, provider_entities, valid_model_list): + config = {"model": {"provider": "openai/gpt", "name": "invalid", "completion_params": {}}} + self._patch_model_assembly( + mocker, + provider_entities=provider_entities, + model_list=valid_model_list, + ) + + with pytest.raises(ValueError, match="must be in the specified model list"): + ModelConfigManager.validate_and_set_defaults("tenant1", config) + + def test_validate_and_set_defaults_default_mode_when_missing(self, mocker, provider_entities): + model = MagicMock() + model.model = "gpt-4" + model.model_properties = {} + + config = {"model": {"provider": "openai/gpt", "name": "gpt-4", "completion_params": {}}} + self._patch_model_assembly(mocker, provider_entities=provider_entities, model_list=[model]) + + updated_config, _ = ModelConfigManager.validate_and_set_defaults("tenant1", config) + + assert updated_config["model"]["mode"] == "completion" + + def test_validate_and_set_defaults_missing_completion_params(self, mocker, provider_entities, valid_model_list): + config = {"model": {"provider": "openai/gpt", "name": "gpt-4"}} + self._patch_model_assembly( + mocker, + provider_entities=provider_entities, + model_list=valid_model_list, + ) + + with pytest.raises(ValueError, match="completion_params is required"): + ModelConfigManager.validate_and_set_defaults("tenant1", config) + + def test_validate_and_set_defaults_provider_without_slash_converted(self, mocker, valid_model_list): + """ + Covers branch where provider does not contain '/' and + ModelProviderID conversion is triggered (line 64). + """ + config = { + "model": { + "provider": "openai", # no slash -> triggers conversion + "name": "gpt-4", + "completion_params": {}, + } + } + + # Mock ModelProviderID to return formatted provider + mock_provider_id = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderID") + mock_provider_id.return_value = "openai/gpt" + provider_entity = MagicMock() + provider_entity.provider = "openai/gpt" + self._patch_model_assembly(mocker, provider_entities=[provider_entity], model_list=valid_model_list) + + updated_config, _ = ModelConfigManager.validate_and_set_defaults("tenant1", config) + + # Ensure conversion happened + mock_provider_id.assert_called_once_with("openai") + assert updated_config["model"]["provider"] == "openai/gpt" diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_prompt_template_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_prompt_template_manager.py new file mode 100644 index 00000000000..fd49072cd53 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_prompt_template_manager.py @@ -0,0 +1,292 @@ +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.easy_ui_based_app.prompt_template.manager import ( + PromptTemplateConfigManager, +) + +# ----------------------------- +# Helpers +# ----------------------------- + + +class DummyEnumValue: + def __init__(self, value): + self.value = value + + +class DummyPromptType: + def __init__(self): + self.SIMPLE = "simple" + self.ADVANCED = "advanced" + + def value_of(self, value): + return value + + def __iter__(self): + return iter([DummyEnumValue("simple"), DummyEnumValue("advanced")]) + + +# ----------------------------- +# Convert Tests +# ----------------------------- + + +class TestPromptTemplateConfigManagerConvert: + def test_convert_missing_prompt_type_raises(self): + with pytest.raises(ValueError, match="prompt_type is required"): + PromptTemplateConfigManager.convert({}) + + def test_convert_simple_prompt(self, mocker): + mock_prompt_entity_cls = MagicMock() + mock_prompt_entity_cls.PromptType = DummyPromptType() + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity", + mock_prompt_entity_cls, + ) + + mock_prompt_entity_cls.return_value = "simple_entity" + + config = {"prompt_type": "simple", "pre_prompt": "hello"} + + result = PromptTemplateConfigManager.convert(config) + + assert result == "simple_entity" + mock_prompt_entity_cls.assert_called_once_with(prompt_type="simple", simple_prompt_template="hello") + + def test_convert_advanced_chat_valid(self, mocker): + mock_prompt_entity_cls = MagicMock() + mock_prompt_entity_cls.PromptType = DummyPromptType() + mock_prompt_entity_cls.return_value = "advanced_entity" + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity", + mock_prompt_entity_cls, + ) + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptMessageRole.value_of", + return_value="role_enum", + ) + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.AdvancedChatMessageEntity", + return_value="chat_msg", + ) + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.AdvancedChatPromptTemplateEntity", + return_value="chat_template", + ) + + config = { + "prompt_type": "advanced", + "chat_prompt_config": {"prompt": [{"text": "hi", "role": "user"}]}, + } + + result = PromptTemplateConfigManager.convert(config) + + assert result == "advanced_entity" + + @pytest.mark.parametrize( + "message", + [ + {"text": 123, "role": "user"}, + {"text": "hi", "role": 123}, + ], + ) + def test_convert_advanced_invalid_message_fields(self, mocker, message): + mock_prompt_entity_cls = MagicMock() + mock_prompt_entity_cls.PromptType = DummyPromptType() + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity", + mock_prompt_entity_cls, + ) + + config = { + "prompt_type": "advanced", + "chat_prompt_config": {"prompt": [message]}, + } + + with pytest.raises(ValueError): + PromptTemplateConfigManager.convert(config) + + def test_convert_advanced_completion_with_roles(self, mocker): + mock_prompt_entity_cls = MagicMock() + mock_prompt_entity_cls.PromptType = DummyPromptType() + mock_prompt_entity_cls.return_value = "advanced_entity" + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity", + mock_prompt_entity_cls, + ) + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.AdvancedCompletionPromptTemplateEntity", + return_value="completion_template", + ) + + config = { + "prompt_type": "advanced", + "completion_prompt_config": { + "prompt": {"text": "complete"}, + "conversation_histories_role": { + "user_prefix": "U", + "assistant_prefix": "A", + }, + }, + } + + result = PromptTemplateConfigManager.convert(config) + + assert result == "advanced_entity" + + +# ----------------------------- +# validate_and_set_defaults +# ----------------------------- + + +class TestValidateAndSetDefaults: + def setup_method(self): + self.valid_model = {"mode": "chat"} + + def _patch_prompt_type(self, mocker): + mock_prompt_entity_cls = MagicMock() + mock_prompt_entity_cls.PromptType = DummyPromptType() + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity", + mock_prompt_entity_cls, + ) + return mock_prompt_entity_cls + + def test_default_prompt_type_set(self, mocker): + self._patch_prompt_type(mocker) + + config = {"model": self.valid_model} + + result, keys = PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + assert result["prompt_type"] == "simple" + assert isinstance(keys, list) + + def test_invalid_prompt_type_raises(self, mocker): + class InvalidEnum(DummyPromptType): + def __iter__(self): + return iter([DummyEnumValue("valid")]) + + mock_prompt_entity_cls = MagicMock() + mock_prompt_entity_cls.PromptType = InvalidEnum() + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity", + mock_prompt_entity_cls, + ) + + config = {"prompt_type": "invalid", "model": self.valid_model} + + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + def test_invalid_chat_prompt_config_type(self, mocker): + self._patch_prompt_type(mocker) + + config = { + "prompt_type": "simple", + "chat_prompt_config": "invalid", + "model": self.valid_model, + } + + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + def test_simple_mode_invalid_pre_prompt_type(self, mocker): + self._patch_prompt_type(mocker) + + config = { + "prompt_type": "simple", + "pre_prompt": 123, + "model": self.valid_model, + } + + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + def test_advanced_requires_one_config(self, mocker): + self._patch_prompt_type(mocker) + + config = { + "prompt_type": "advanced", + "chat_prompt_config": {}, + "completion_prompt_config": {}, + "model": {"mode": "chat"}, + } + + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + def test_advanced_invalid_model_mode(self, mocker): + self._patch_prompt_type(mocker) + + config = { + "prompt_type": "advanced", + "chat_prompt_config": {"prompt": []}, + "model": {"mode": "invalid"}, + } + + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + def test_advanced_chat_prompt_length_exceeds(self, mocker): + self._patch_prompt_type(mocker) + + config = { + "prompt_type": "advanced", + "chat_prompt_config": {"prompt": [{}] * 11}, + "model": {"mode": "chat"}, + } + + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + def test_completion_prefix_defaults_set_when_empty(self, mocker): + self._patch_prompt_type(mocker) + + config = { + "prompt_type": "advanced", + "completion_prompt_config": { + "prompt": {"text": "hi"}, + "conversation_histories_role": { + "user_prefix": "", + "assistant_prefix": "", + }, + }, + "model": {"mode": "completion"}, + } + + updated, _ = PromptTemplateConfigManager.validate_and_set_defaults("chat", config) + + roles = updated["completion_prompt_config"]["conversation_histories_role"] + assert roles["user_prefix"] == "Human" + assert roles["assistant_prefix"] == "Assistant" + + +# ----------------------------- +# validate_post_prompt +# ----------------------------- + + +class TestValidatePostPrompt: + @pytest.mark.parametrize("value", [None, ""]) + def test_post_prompt_defaults(self, value): + config = {"post_prompt": value} + result = PromptTemplateConfigManager.validate_post_prompt_and_set_defaults(config) + assert result["post_prompt"] == "" + + def test_post_prompt_invalid_type(self): + config = {"post_prompt": 123} + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_post_prompt_and_set_defaults(config) diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py new file mode 100644 index 00000000000..e2f3c16335f --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py @@ -0,0 +1,286 @@ +import pytest +from graphon.variables.input_entities import VariableEntityType + +from core.app.app_config.easy_ui_based_app.variables.manager import ( + BasicVariablesConfigManager, +) + + +class TestBasicVariablesConfigManagerConvert: + def test_convert_empty_config(self): + config = {} + + variables, external = BasicVariablesConfigManager.convert(config) + + assert variables == [] + assert external == [] + + def test_convert_external_data_tools_enabled_and_disabled(self, mocker): + config = { + "external_data_tools": [ + {"enabled": False}, + { + "enabled": True, + "variable": "ext_var", + "type": "tool_type", + "config": {"k": "v"}, + }, + ] + } + + variables, external = BasicVariablesConfigManager.convert(config) + + assert variables == [] + assert len(external) == 1 + assert external[0].variable == "ext_var" + assert external[0].type == "tool_type" + + def test_convert_user_input_form_variable_types(self): + config = { + "user_input_form": [ + { + VariableEntityType.TEXT_INPUT: { + "variable": "name", + "label": "Name", + "description": "desc", + "required": True, + "max_length": 50, + } + }, + { + VariableEntityType.SELECT: { + "variable": "choice", + "label": "Choice", + "options": ["a", "b"], + } + }, + { + VariableEntityType.EXTERNAL_DATA_TOOL: { + "variable": "ext", + "type": "tool", + "config": {"x": 1}, + } + }, + ] + } + + variables, external = BasicVariablesConfigManager.convert(config) + + assert len(variables) == 2 + assert len(external) == 1 + + def test_convert_external_data_tool_without_config_skipped(self): + config = { + "user_input_form": [ + { + VariableEntityType.EXTERNAL_DATA_TOOL: { + "variable": "ext", + "type": "tool", + } + } + ] + } + + variables, external = BasicVariablesConfigManager.convert(config) + + assert variables == [] + assert external == [] + + +class TestValidateVariablesAndSetDefaults: + def test_validate_sets_empty_user_input_form_if_missing(self): + config = {} + + updated, keys = BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + assert updated["user_input_form"] == [] + assert "user_input_form" in keys + + def test_validate_user_input_form_not_list_raises(self): + config = {"user_input_form": "invalid"} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_invalid_key_raises(self): + config = {"user_input_form": [{"invalid": {}}]} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_missing_label_raises(self): + config = {"user_input_form": [{VariableEntityType.TEXT_INPUT: {"variable": "name"}}]} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_label_not_string_raises(self): + config = {"user_input_form": [{VariableEntityType.TEXT_INPUT: {"variable": "name", "label": 123}}]} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_missing_variable_raises(self): + config = {"user_input_form": [{VariableEntityType.TEXT_INPUT: {"label": "Name"}}]} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_variable_not_string_raises(self): + config = {"user_input_form": [{VariableEntityType.TEXT_INPUT: {"label": "Name", "variable": 123}}]} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + @pytest.mark.parametrize( + "variable_name", + ["1invalid", "invalid space", "", None], + ) + def test_validate_variable_invalid_pattern_raises(self, variable_name): + config = { + "user_input_form": [ + { + VariableEntityType.TEXT_INPUT: { + "label": "Name", + "variable": variable_name, + } + } + ] + } + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_required_default_and_type(self): + config = { + "user_input_form": [ + { + VariableEntityType.TEXT_INPUT: { + "label": "Name", + "variable": "valid_name", + } + } + ] + } + + updated, _ = BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + assert updated["user_input_form"][0][VariableEntityType.TEXT_INPUT]["required"] is False + + def test_validate_required_not_bool_raises(self): + config = { + "user_input_form": [ + { + VariableEntityType.TEXT_INPUT: { + "label": "Name", + "variable": "valid_name", + "required": "yes", + } + } + ] + } + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_select_options_default_not_in_options_raises(self): + config = { + "user_input_form": [ + { + VariableEntityType.SELECT: { + "label": "Choice", + "variable": "choice", + "options": ["a", "b"], + "default": "c", + } + } + ] + } + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_select_options_not_list_raises(self): + config = { + "user_input_form": [ + { + VariableEntityType.SELECT: { + "label": "Choice", + "variable": "choice", + "options": "not_list", + } + } + ] + } + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + +class TestValidateExternalDataToolsAndSetDefaults: + def test_validate_sets_empty_external_data_tools_if_missing(self): + config = {} + + updated, keys = BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant", config) + + assert updated["external_data_tools"] == [] + assert "external_data_tools" in keys + + def test_validate_external_data_tools_not_list_raises(self): + config = {"external_data_tools": "invalid"} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant", config) + + def test_validate_disabled_tool_skipped(self, mocker): + config = {"external_data_tools": [{"enabled": False}]} + + spy = mocker.patch( + "core.app.app_config.easy_ui_based_app.variables.manager.ExternalDataToolFactory.validate_config" + ) + + updated, _ = BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant", config) + + spy.assert_not_called() + assert updated["external_data_tools"][0]["enabled"] is False + + def test_validate_enabled_tool_missing_type_raises(self): + config = {"external_data_tools": [{"enabled": True, "config": {}}]} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant", config) + + def test_validate_enabled_tool_calls_factory(self, mocker): + config = {"external_data_tools": [{"enabled": True, "type": "tool", "config": {"a": 1}}]} + + spy = mocker.patch( + "core.app.app_config.easy_ui_based_app.variables.manager.ExternalDataToolFactory.validate_config" + ) + + BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant_id", config) + + spy.assert_called_once_with(name="tool", tenant_id="tenant_id", config={"a": 1}) + + +class TestValidateAndSetDefaultsIntegration: + def test_validate_and_set_defaults_calls_both(self, mocker): + config = {} + + spy_var = mocker.patch.object( + BasicVariablesConfigManager, + "validate_variables_and_set_defaults", + return_value=(config, ["user_input_form"]), + ) + spy_ext = mocker.patch.object( + BasicVariablesConfigManager, + "validate_external_data_tools_and_set_defaults", + return_value=(config, ["external_data_tools"]), + ) + + updated, keys = BasicVariablesConfigManager.validate_and_set_defaults("tenant", config) + + spy_var.assert_called_once() + spy_ext.assert_called_once() + assert "user_input_form" in keys + assert "external_data_tools" in keys + assert updated == config diff --git a/api/dify_graph/model_runtime/schema_validators/__init__.py b/api/tests/unit_tests/core/app/app_config/features/__init__.py similarity index 100% rename from api/dify_graph/model_runtime/schema_validators/__init__.py rename to api/tests/unit_tests/core/app/app_config/features/__init__.py diff --git a/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py b/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py index de99833aac7..8bde9c1f979 100644 --- a/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py +++ b/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py @@ -1,6 +1,7 @@ +from graphon.file import FileTransferMethod, FileUploadConfig, ImageConfig +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent + from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from dify_graph.file.models import FileTransferMethod, FileUploadConfig, ImageConfig -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent def test_convert_with_vision(): diff --git a/api/tests/unit_tests/core/app/app_config/features/test_additional_feature_managers.py b/api/tests/unit_tests/core/app/app_config/features/test_additional_feature_managers.py new file mode 100644 index 00000000000..dd00c3defc3 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/features/test_additional_feature_managers.py @@ -0,0 +1,115 @@ +import pytest + +from core.app.app_config.entities import TextToSpeechEntity +from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager +from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager +from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager +from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager +from core.app.app_config.features.suggested_questions_after_answer.manager import ( + SuggestedQuestionsAfterAnswerConfigManager, +) +from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager + + +class TestAdditionalFeatureManagers: + def test_opening_statement_validate_defaults(self): + config, keys = OpeningStatementConfigManager.validate_and_set_defaults({}) + assert config["opening_statement"] == "" + assert config["suggested_questions"] == [] + assert set(keys) == {"opening_statement", "suggested_questions"} + + def test_opening_statement_validate_types(self): + with pytest.raises(ValueError): + OpeningStatementConfigManager.validate_and_set_defaults({"opening_statement": 123}) + with pytest.raises(ValueError): + OpeningStatementConfigManager.validate_and_set_defaults( + {"opening_statement": "hi", "suggested_questions": "bad"} + ) + with pytest.raises(ValueError): + OpeningStatementConfigManager.validate_and_set_defaults( + {"opening_statement": "hi", "suggested_questions": [1]} + ) + + def test_opening_statement_convert(self): + opening, questions = OpeningStatementConfigManager.convert( + {"opening_statement": "hello", "suggested_questions": ["q1"]} + ) + assert opening == "hello" + assert questions == ["q1"] + + def test_retrieval_resource_validate(self): + config, keys = RetrievalResourceConfigManager.validate_and_set_defaults({}) + assert config["retriever_resource"]["enabled"] is False + assert keys == ["retriever_resource"] + + with pytest.raises(ValueError): + RetrievalResourceConfigManager.validate_and_set_defaults({"retriever_resource": "bad"}) + with pytest.raises(ValueError): + RetrievalResourceConfigManager.validate_and_set_defaults({"retriever_resource": {"enabled": "yes"}}) + + def test_retrieval_resource_convert(self): + assert RetrievalResourceConfigManager.convert({"retriever_resource": {"enabled": True}}) is True + assert RetrievalResourceConfigManager.convert({"retriever_resource": {"enabled": False}}) is False + + def test_speech_to_text_validate_and_convert(self): + config, keys = SpeechToTextConfigManager.validate_and_set_defaults({}) + assert config["speech_to_text"]["enabled"] is False + assert keys == ["speech_to_text"] + + with pytest.raises(ValueError): + SpeechToTextConfigManager.validate_and_set_defaults({"speech_to_text": "bad"}) + with pytest.raises(ValueError): + SpeechToTextConfigManager.validate_and_set_defaults({"speech_to_text": {"enabled": "yes"}}) + + assert SpeechToTextConfigManager.convert({"speech_to_text": {"enabled": True}}) is True + assert SpeechToTextConfigManager.convert({"speech_to_text": {"enabled": False}}) is False + + def test_suggested_questions_after_answer_validate_and_convert(self): + config, keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults({}) + assert config["suggested_questions_after_answer"]["enabled"] is False + assert keys == ["suggested_questions_after_answer"] + + with pytest.raises(ValueError): + SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( + {"suggested_questions_after_answer": "bad"} + ) + with pytest.raises(ValueError): + SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( + {"suggested_questions_after_answer": {"enabled": "yes"}} + ) + + assert ( + SuggestedQuestionsAfterAnswerConfigManager.convert({"suggested_questions_after_answer": {"enabled": True}}) + is True + ) + assert ( + SuggestedQuestionsAfterAnswerConfigManager.convert({"suggested_questions_after_answer": {"enabled": False}}) + is False + ) + + def test_text_to_speech_validate_and_convert(self): + config, keys = TextToSpeechConfigManager.validate_and_set_defaults({}) + assert config["text_to_speech"]["enabled"] is False + assert keys == ["text_to_speech"] + + with pytest.raises(ValueError): + TextToSpeechConfigManager.validate_and_set_defaults({"text_to_speech": "bad"}) + with pytest.raises(ValueError): + TextToSpeechConfigManager.validate_and_set_defaults({"text_to_speech": {"enabled": "yes"}}) + + result = TextToSpeechConfigManager.convert( + {"text_to_speech": {"enabled": True, "voice": "v", "language": "en"}} + ) + assert isinstance(result, TextToSpeechEntity) + assert result.voice == "v" + assert result.language == "en" + + def test_more_like_this_convert_and_validate(self): + config, keys = MoreLikeThisConfigManager.validate_and_set_defaults({}) + assert config["more_like_this"]["enabled"] is False + assert keys == ["more_like_this"] + + assert MoreLikeThisConfigManager.convert({"more_like_this": {"enabled": True}}) is True + assert MoreLikeThisConfigManager.convert({"more_like_this": {"enabled": False}}) is False + with pytest.raises(ValueError): + MoreLikeThisConfigManager.validate_and_set_defaults({"more_like_this": "bad"}) diff --git a/api/tests/unit_tests/core/app/app_config/test_base_app_config_manager.py b/api/tests/unit_tests/core/app/app_config/test_base_app_config_manager.py new file mode 100644 index 00000000000..e99852cf761 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/test_base_app_config_manager.py @@ -0,0 +1,180 @@ +from collections import UserDict +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.base_app_config_manager import BaseAppConfigManager + + +class TestBaseAppConfigManager: + @pytest.fixture + def mock_config_dict(self): + return {"key": "value", "another": 123} + + @pytest.fixture + def mock_app_additional_features(self, mocker): + mock_instance = MagicMock() + mocker.patch( + "core.app.app_config.base_app_config_manager.AppAdditionalFeatures", + return_value=mock_instance, + ) + return mock_instance + + @pytest.fixture + def mock_managers(self, mocker): + retrieval = mocker.patch( + "core.app.app_config.base_app_config_manager.RetrievalResourceConfigManager.convert", + return_value="retrieval_result", + ) + file_upload = mocker.patch( + "core.app.app_config.base_app_config_manager.FileUploadConfigManager.convert", + return_value="file_upload_result", + ) + opening_statement = mocker.patch( + "core.app.app_config.base_app_config_manager.OpeningStatementConfigManager.convert", + return_value=("opening_result", "suggested_result"), + ) + suggested_after = mocker.patch( + "core.app.app_config.base_app_config_manager.SuggestedQuestionsAfterAnswerConfigManager.convert", + return_value="suggested_after_result", + ) + more_like_this = mocker.patch( + "core.app.app_config.base_app_config_manager.MoreLikeThisConfigManager.convert", + return_value="more_like_this_result", + ) + speech_to_text = mocker.patch( + "core.app.app_config.base_app_config_manager.SpeechToTextConfigManager.convert", + return_value="speech_to_text_result", + ) + text_to_speech = mocker.patch( + "core.app.app_config.base_app_config_manager.TextToSpeechConfigManager.convert", + return_value="text_to_speech_result", + ) + + return { + "retrieval": retrieval, + "file_upload": file_upload, + "opening_statement": opening_statement, + "suggested_after": suggested_after, + "more_like_this": more_like_this, + "speech_to_text": speech_to_text, + "text_to_speech": text_to_speech, + } + + @pytest.mark.parametrize( + ("app_mode", "expected_is_vision"), + [ + ("CHAT", True), + ("COMPLETION", True), + ("AGENT_CHAT", True), + ("OTHER", False), + ], + ) + def test_convert_features_all_modes( + self, + mocker, + mock_config_dict, + mock_app_additional_features, + mock_managers, + app_mode, + expected_is_vision, + ): + # Arrange + mock_app_mode = MagicMock() + mock_app_mode.CHAT = "CHAT" + mock_app_mode.COMPLETION = "COMPLETION" + mock_app_mode.AGENT_CHAT = "AGENT_CHAT" + + mocker.patch( + "core.app.app_config.base_app_config_manager.AppMode", + mock_app_mode, + ) + + # Act + result = BaseAppConfigManager.convert_features(mock_config_dict, app_mode) + + # Assert + assert result == mock_app_additional_features + mock_managers["retrieval"].assert_called_once_with(config=dict(mock_config_dict.items())) + mock_managers["file_upload"].assert_called_once() + _, kwargs = mock_managers["file_upload"].call_args + assert kwargs["config"] == dict(mock_config_dict.items()) + assert kwargs["is_vision"] is expected_is_vision + + mock_managers["opening_statement"].assert_called_once_with(config=dict(mock_config_dict.items())) + mock_managers["suggested_after"].assert_called_once_with(config=dict(mock_config_dict.items())) + mock_managers["more_like_this"].assert_called_once_with(config=dict(mock_config_dict.items())) + mock_managers["speech_to_text"].assert_called_once_with(config=dict(mock_config_dict.items())) + mock_managers["text_to_speech"].assert_called_once_with(config=dict(mock_config_dict.items())) + + def test_convert_features_empty_config(self, mocker, mock_app_additional_features, mock_managers): + # Arrange + empty_config = {} + mock_app_mode = MagicMock() + mock_app_mode.CHAT = "CHAT" + mock_app_mode.COMPLETION = "COMPLETION" + mock_app_mode.AGENT_CHAT = "AGENT_CHAT" + + mocker.patch( + "core.app.app_config.base_app_config_manager.AppMode", + mock_app_mode, + ) + + # Act + result = BaseAppConfigManager.convert_features(empty_config, "CHAT") + + # Assert + assert result == mock_app_additional_features + for manager in mock_managers.values(): + assert manager.called + + @pytest.mark.parametrize( + "invalid_config", + [ + None, + "string", + 123, + 12.34, + [], + ], + ) + def test_convert_features_invalid_config_raises(self, invalid_config): + # Act & Assert + with pytest.raises((TypeError, AttributeError)): + BaseAppConfigManager.convert_features(invalid_config, "CHAT") + + def test_convert_features_manager_exception_propagates(self, mocker, mock_config_dict): + # Arrange + mocker.patch( + "core.app.app_config.base_app_config_manager.RetrievalResourceConfigManager.convert", + side_effect=RuntimeError("manager failure"), + ) + + # Act & Assert + with pytest.raises(RuntimeError): + BaseAppConfigManager.convert_features(mock_config_dict, "CHAT") + + def test_convert_features_mapping_subclass(self, mocker, mock_app_additional_features, mock_managers): + # Arrange + class CustomMapping(UserDict): + pass + + custom_config = CustomMapping({"a": 1}) + + mock_app_mode = MagicMock() + mock_app_mode.CHAT = "CHAT" + mock_app_mode.COMPLETION = "COMPLETION" + mock_app_mode.AGENT_CHAT = "AGENT_CHAT" + + mocker.patch( + "core.app.app_config.base_app_config_manager.AppMode", + mock_app_mode, + ) + + # Act + result = BaseAppConfigManager.convert_features(custom_config, "CHAT") + + # Assert + assert result == mock_app_additional_features + for manager in mock_managers.values(): + assert manager.called diff --git a/api/tests/unit_tests/core/app/app_config/test_entities.py b/api/tests/unit_tests/core/app/app_config/test_entities.py new file mode 100644 index 00000000000..000f83cd5a2 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/test_entities.py @@ -0,0 +1,43 @@ +import pytest +from graphon.variables.input_entities import VariableEntity, VariableEntityType + +from core.app.app_config.entities import ( + DatasetRetrieveConfigEntity, + PromptTemplateEntity, +) + + +class TestAppConfigEntities: + def test_variable_entity_coerces_none_description_and_options(self): + entity = VariableEntity( + variable="query", + label="Query", + description=None, + type=VariableEntityType.TEXT_INPUT, + options=None, + ) + + assert entity.description == "" + assert entity.options == [] + + def test_variable_entity_rejects_invalid_json_schema(self): + with pytest.raises(ValueError): + VariableEntity( + variable="query", + label="Query", + type=VariableEntityType.TEXT_INPUT, + json_schema={"type": "string", "minLength": "bad"}, + ) + + def test_prompt_template_value_of(self): + assert PromptTemplateEntity.PromptType.value_of("simple") == PromptTemplateEntity.PromptType.SIMPLE + with pytest.raises(ValueError): + PromptTemplateEntity.PromptType.value_of("missing") + + def test_dataset_retrieve_strategy_value_of(self): + assert ( + DatasetRetrieveConfigEntity.RetrieveStrategy.value_of("single") + == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE + ) + with pytest.raises(ValueError): + DatasetRetrieveConfigEntity.RetrieveStrategy.value_of("missing") diff --git a/api/tests/unit_tests/core/app/app_config/workflow_ui_based_app/test_workflow_ui_based_app_manager.py b/api/tests/unit_tests/core/app/app_config/workflow_ui_based_app/test_workflow_ui_based_app_manager.py new file mode 100644 index 00000000000..fa128aca874 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/workflow_ui_based_app/test_workflow_ui_based_app_manager.py @@ -0,0 +1,222 @@ +import pytest + +from core.app.app_config.workflow_ui_based_app.variables.manager import ( + WorkflowVariablesConfigManager, +) + +# ============================= +# Fixtures +# ============================= + + +@pytest.fixture +def mock_workflow(mocker): + workflow = mocker.MagicMock() + workflow.graph_dict = {"nodes": []} + return workflow + + +@pytest.fixture +def mock_variable_entity(mocker): + return mocker.patch("core.app.app_config.workflow_ui_based_app.variables.manager.VariableEntity") + + +@pytest.fixture +def mock_rag_entity(mocker): + return mocker.patch("core.app.app_config.workflow_ui_based_app.variables.manager.RagPipelineVariableEntity") + + +# ============================= +# Test Convert (user_input_form) +# ============================= + + +class TestWorkflowVariablesConfigManagerConvert: + def test_convert_success_multiple_variables(self, mock_workflow, mock_variable_entity): + # Arrange + input_variables = [{"name": "var1"}, {"name": "var2"}] + mock_workflow.user_input_form.return_value = input_variables + mock_variable_entity.model_validate.side_effect = lambda x: {"validated": x} + + # Act + result = WorkflowVariablesConfigManager.convert(mock_workflow) + + # Assert + assert result == [{"validated": v} for v in input_variables] + assert mock_variable_entity.model_validate.call_count == 2 + + def test_convert_empty_list(self, mock_workflow, mock_variable_entity): + # Arrange + mock_workflow.user_input_form.return_value = [] + + # Act + result = WorkflowVariablesConfigManager.convert(mock_workflow) + + # Assert + assert result == [] + mock_variable_entity.model_validate.assert_not_called() + + def test_convert_none_returned_raises(self, mock_workflow): + # Arrange + mock_workflow.user_input_form.return_value = None + + # Act & Assert + with pytest.raises(TypeError): + WorkflowVariablesConfigManager.convert(mock_workflow) + + def test_convert_validation_error_propagates(self, mock_workflow, mock_variable_entity): + # Arrange + mock_workflow.user_input_form.return_value = [{"invalid": "data"}] + mock_variable_entity.model_validate.side_effect = ValueError("validation error") + + # Act & Assert + with pytest.raises(ValueError): + WorkflowVariablesConfigManager.convert(mock_workflow) + + +# ============================= +# Test convert_rag_pipeline_variable +# ============================= + + +class TestWorkflowVariablesConfigManagerConvertRag: + def test_no_rag_pipeline_variables(self, mock_workflow): + # Arrange + mock_workflow.rag_pipeline_variables = [] + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert result == [] + + def test_rag_pipeline_none(self, mock_workflow): + # Arrange + mock_workflow.rag_pipeline_variables = None + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert result == [] + + def test_no_matching_node_keeps_all(self, mock_workflow, mock_rag_entity): + # Arrange + mock_workflow.rag_pipeline_variables = [ + {"variable": "var1", "belong_to_node_id": "node1"}, + ] + mock_workflow.graph_dict = {"nodes": []} + mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x} + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert result == [{"validated": mock_workflow.rag_pipeline_variables[0]}] + + def test_string_pattern_removes_variable(self, mock_workflow, mock_rag_entity): + # Arrange + mock_workflow.rag_pipeline_variables = [ + {"variable": "var1", "belong_to_node_id": "node1"}, + {"variable": "var2", "belong_to_node_id": "node1"}, + ] + + mock_workflow.graph_dict = { + "nodes": [ + { + "id": "node1", + "data": {"datasource_parameters": {"param1": {"value": "{{#parent.var1#}}"}}}, + } + ] + } + + mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x} + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert len(result) == 1 + assert result[0]["validated"]["variable"] == "var2" + + def test_list_value_removes_variable(self, mock_workflow, mock_rag_entity): + # Arrange + mock_workflow.rag_pipeline_variables = [ + {"variable": "var1", "belong_to_node_id": "node1"}, + {"variable": "var2", "belong_to_node_id": "node1"}, + ] + + mock_workflow.graph_dict = { + "nodes": [ + { + "id": "node1", + "data": {"datasource_parameters": {"param1": {"value": ["x", "var1"]}}}, + } + ] + } + + mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x} + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert len(result) == 1 + assert result[0]["validated"]["variable"] == "var2" + + @pytest.mark.parametrize( + ("belong_to_node_id", "expected_count"), + [ + ("node1", 1), + ("shared", 1), + ("other_node", 0), + ], + ) + def test_belong_to_node_filtering(self, mock_workflow, mock_rag_entity, belong_to_node_id, expected_count): + # Arrange + mock_workflow.rag_pipeline_variables = [ + {"variable": "var1", "belong_to_node_id": belong_to_node_id}, + ] + mock_workflow.graph_dict = {"nodes": []} + mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x} + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert len(result) == expected_count + + def test_invalid_pattern_does_not_remove(self, mock_workflow, mock_rag_entity): + # Arrange + mock_workflow.rag_pipeline_variables = [ + {"variable": "var1", "belong_to_node_id": "node1"}, + ] + + mock_workflow.graph_dict = { + "nodes": [ + { + "id": "node1", + "data": {"datasource_parameters": {"param1": {"value": "invalid_pattern"}}}, + } + ] + } + + mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x} + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert len(result) == 1 + + def test_validation_error_propagates(self, mock_workflow, mock_rag_entity): + # Arrange + mock_workflow.rag_pipeline_variables = [ + {"variable": "var1", "belong_to_node_id": "node1"}, + ] + mock_workflow.graph_dict = {"nodes": []} + mock_rag_entity.model_validate.side_effect = RuntimeError("validation failed") + + # Act & Assert + with pytest.raises(RuntimeError): + WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py index 441d2fcd179..af5d203f126 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py @@ -9,10 +9,17 @@ from pydantic import BaseModel, ValidationError from constants import UUID_NIL from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig -from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator, _refresh_model +from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator +from core.app.apps.advanced_chat.generate_task_pipeline import ( + ConversationSnapshot, + MessageSnapshot, + WorkflowSnapshot, +) from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from core.ops.ops_trace_manager import TraceQueueManager +from libs.datetime_utils import naive_utc_now +from models.enums import MessageStatus from models.model import AppMode @@ -363,8 +370,15 @@ class TestAdvancedChatAppGeneratorInternals: workflow_run_id="run-id", ) + workflow = SimpleNamespace(id="wf-1", tenant_id="tenant", features={"feature": True}, features_dict={}) conversation = SimpleNamespace(id="conv-1", mode=AppMode.ADVANCED_CHAT, override_model_configs=None) - message = SimpleNamespace(id="msg-1") + message = SimpleNamespace( + id="msg-1", + query="hello", + created_at=naive_utc_now(), + status=MessageStatus.NORMAL, + answer="", + ) db_session = SimpleNamespace(commit=MagicMock(), refresh=MagicMock(), close=MagicMock()) captured: dict[str, object] = {} thread_data: dict[str, object] = {} @@ -394,19 +408,6 @@ class TestAdvancedChatAppGeneratorInternals: thread_data["started"] = True monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.threading.Thread", _Thread) - monkeypatch.setattr("core.app.apps.advanced_chat.app_generator._refresh_model", lambda session, model: model) - - class _Session: - def __init__(self, *args, **kwargs): - _ = args, kwargs - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) monkeypatch.setattr( "core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object(), session=db_session) ) @@ -424,7 +425,7 @@ class TestAdvancedChatAppGeneratorInternals: pause_state_config = SimpleNamespace(session_factory="session-factory", state_owner_user_id="owner") response = generator._generate( - workflow=SimpleNamespace(features={"feature": True}), + workflow=workflow, user=SimpleNamespace(id="user"), invoke_from=InvokeFrom.WEB_APP, application_generate_entity=application_generate_entity, @@ -444,6 +445,9 @@ class TestAdvancedChatAppGeneratorInternals: db_session.refresh.assert_called_once_with(conversation) db_session.close.assert_called_once() assert captured["draft_var_saver_factory"] == "draft-factory" + assert isinstance(captured["workflow"], WorkflowSnapshot) + assert isinstance(captured["conversation"], ConversationSnapshot) + assert isinstance(captured["message"], MessageSnapshot) def test_generate_internal_flow_with_existing_records_skips_init(self, monkeypatch): generator = AdvancedChatAppGenerator() @@ -464,8 +468,15 @@ class TestAdvancedChatAppGeneratorInternals: workflow_run_id="run-id", ) + workflow = SimpleNamespace(id="wf-2", tenant_id="tenant", features={}, features_dict={}) conversation = SimpleNamespace(id="conv-2", mode=AppMode.ADVANCED_CHAT, override_model_configs=None) - message = SimpleNamespace(id="msg-2") + message = SimpleNamespace( + id="msg-2", + query="hello", + created_at=naive_utc_now(), + status=MessageStatus.NORMAL, + answer="", + ) db_session = SimpleNamespace(close=MagicMock(), commit=MagicMock(), refresh=MagicMock()) init_records = MagicMock() thread_data: dict[str, object] = {} @@ -491,19 +502,6 @@ class TestAdvancedChatAppGeneratorInternals: thread_data["started"] = True monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.threading.Thread", _Thread) - monkeypatch.setattr("core.app.apps.advanced_chat.app_generator._refresh_model", lambda session, model: model) - - class _Session: - def __init__(self, *args, **kwargs): - _ = args, kwargs - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) monkeypatch.setattr( "core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object(), session=db_session) ) @@ -519,7 +517,7 @@ class TestAdvancedChatAppGeneratorInternals: ) response = generator._generate( - workflow=SimpleNamespace(features={}), + workflow=workflow, user=SimpleNamespace(id="user"), invoke_from=InvokeFrom.WEB_APP, application_generate_entity=application_generate_entity, @@ -940,10 +938,16 @@ class TestAdvancedChatAppGeneratorInternals: with pytest.raises(GenerateTaskStoppedError): generator._handle_advanced_chat_response( application_generate_entity=application_generate_entity, - workflow=SimpleNamespace(), + workflow=WorkflowSnapshot(id="wf", tenant_id="tenant", features_dict={}), queue_manager=SimpleNamespace(), - conversation=SimpleNamespace(id="conv", mode=AppMode.ADVANCED_CHAT), - message=SimpleNamespace(id="msg"), + conversation=ConversationSnapshot(id="conv", mode=AppMode.ADVANCED_CHAT), + message=MessageSnapshot( + id="msg", + query="hello", + created_at=naive_utc_now(), + status=MessageStatus.NORMAL, + answer="", + ), user=SimpleNamespace(), draft_var_saver_factory=lambda **kwargs: None, stream=False, @@ -981,10 +985,16 @@ class TestAdvancedChatAppGeneratorInternals: with pytest.raises(ValueError, match="other error"): generator._handle_advanced_chat_response( application_generate_entity=application_generate_entity, - workflow=SimpleNamespace(), + workflow=WorkflowSnapshot(id="wf", tenant_id="tenant", features_dict={}), queue_manager=SimpleNamespace(), - conversation=SimpleNamespace(id="conv", mode=AppMode.ADVANCED_CHAT), - message=SimpleNamespace(id="msg"), + conversation=ConversationSnapshot(id="conv", mode=AppMode.ADVANCED_CHAT), + message=MessageSnapshot( + id="msg", + query="hello", + created_at=naive_utc_now(), + status=MessageStatus.NORMAL, + answer="", + ), user=SimpleNamespace(), draft_var_saver_factory=lambda **kwargs: None, stream=False, @@ -992,31 +1002,6 @@ class TestAdvancedChatAppGeneratorInternals: logger_exception.assert_called_once() - def test_refresh_model_returns_detached_model(self, monkeypatch): - source_model = SimpleNamespace(id="source-id") - detached_model = SimpleNamespace(id="source-id", detached=True) - - class _Session: - def __init__(self, *args, **kwargs): - _ = args, kwargs - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def get(self, model_type, model_id): - _ = model_type - return detached_model if model_id == "source-id" else None - - monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) - monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object())) - - refreshed = _refresh_model(session=SimpleNamespace(), model=source_model) - - assert refreshed is detached_model - def test_generate_worker_handles_invoke_auth_error(self, monkeypatch): generator = AdvancedChatAppGenerator() generator._dialogue_count = 1 @@ -1053,7 +1038,7 @@ class TestAdvancedChatAppGeneratorInternals: _ = kwargs def run(self): - from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError + from graphon.model_runtime.errors.invoke import InvokeAuthorizationError raise InvokeAuthorizationError("bad key") diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py index 15aceef2c72..061719d15a5 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py @@ -3,14 +3,27 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 +from graphon.variables import SegmentType from sqlalchemy.orm import Session from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom -from dify_graph.variables import SegmentType from factories import variable_factory from models import ConversationVariable, Workflow +MINIMAL_GRAPH = { + "nodes": [ + { + "id": "start", + "data": { + "type": "start", + "title": "Start", + }, + } + ], + "edges": [], +} + class TestAdvancedChatAppRunnerConversationVariables: """Test that AdvancedChatAppRunner correctly handles conversation variables.""" @@ -49,7 +62,7 @@ class TestAdvancedChatAppRunnerConversationVariables: mock_workflow.app_id = app_id mock_workflow.id = workflow_id mock_workflow.type = "chat" - mock_workflow.graph_dict = {} + mock_workflow.graph_dict = MINIMAL_GRAPH mock_workflow.environment_variables = [] # Create existing conversation variable (only var1 exists in DB) @@ -200,7 +213,7 @@ class TestAdvancedChatAppRunnerConversationVariables: mock_workflow.app_id = app_id mock_workflow.id = workflow_id mock_workflow.type = "chat" - mock_workflow.graph_dict = {} + mock_workflow.graph_dict = MINIMAL_GRAPH mock_workflow.environment_variables = [] # Mock conversation and message @@ -349,7 +362,7 @@ class TestAdvancedChatAppRunnerConversationVariables: mock_workflow.app_id = app_id mock_workflow.id = workflow_id mock_workflow.type = "chat" - mock_workflow.graph_dict = {} + mock_workflow.graph_dict = MINIMAL_GRAPH mock_workflow.environment_variables = [] # Create existing conversation variables (both exist in DB) diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py index 5792a2f1e24..079df0b4e61 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py @@ -8,6 +8,19 @@ from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, from core.app.entities.queue_entities import QueueStopEvent from core.moderation.base import ModerationError +MINIMAL_GRAPH = { + "nodes": [ + { + "id": "start", + "data": { + "type": "start", + "title": "Start", + }, + } + ], + "edges": [], +} + @pytest.fixture def build_runner(): @@ -30,7 +43,7 @@ def build_runner(): mock_workflow.tenant_id = str(uuid4()) mock_workflow.app_id = app_id mock_workflow.type = "chat" - mock_workflow.graph_dict = {} + mock_workflow.graph_dict = MINIMAL_GRAPH mock_workflow.environment_variables = [] mock_app_config = MagicMock() diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py index 5b199e0c52d..e9fdeefee4d 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py @@ -1,5 +1,7 @@ from collections.abc import Generator +from graphon.enums import WorkflowNodeExecutionStatus + from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter from core.app.entities.task_entities import ( ChatbotAppBlockingResponse, @@ -10,7 +12,6 @@ from core.app.entities.task_entities import ( NodeStartStreamResponse, PingStreamResponse, ) -from dify_graph.enums import WorkflowNodeExecutionStatus class TestAdvancedChatGenerateResponseConverter: diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py index 83a6e0f231c..a6d85989556 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py @@ -6,6 +6,8 @@ from types import SimpleNamespace from unittest import mock import pytest +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus from core.app.apps.advanced_chat import generate_task_pipeline as pipeline_module from core.app.entities.app_invoke_entities import InvokeFrom @@ -17,11 +19,9 @@ from core.app.entities.queue_entities import ( QueueWorkflowSucceededEvent, ) from core.app.entities.task_entities import StreamEvent -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import WorkflowExecutionStatus from models.enums import MessageStatus from models.execution_extra_content import HumanInputContent -from models.model import EndUser +from models.model import AppMode, EndUser def _build_pipeline() -> pipeline_module.AdvancedChatAppGenerateTaskPipeline: @@ -137,7 +137,6 @@ def test_handle_workflow_paused_event_persists_human_input_extra_content() -> No actions=[], node_id="node-1", node_title="Approval", - form_token="token-1", resolved_default_values={}, ) event = QueueWorkflowPausedEvent(reasons=[reason], outputs={}, paused_nodes=["node-1"]) @@ -160,8 +159,8 @@ def test_resume_appends_chunks_to_paused_answer() -> None: task_id="task-1", ) queue_manager = SimpleNamespace(graph_runtime_state=None) - conversation = SimpleNamespace(id="conversation-1", mode="advanced-chat") - message = SimpleNamespace( + conversation = pipeline_module.ConversationSnapshot(id="conversation-1", mode=AppMode.ADVANCED_CHAT) + message = pipeline_module.MessageSnapshot( id="message-1", created_at=datetime(2024, 1, 1), query="hello", @@ -171,7 +170,7 @@ def test_resume_appends_chunks_to_paused_answer() -> None: user = EndUser() user.id = "user-1" user.session_id = "session-1" - workflow = SimpleNamespace(id="workflow-1", tenant_id="tenant-1", features_dict={}) + workflow = pipeline_module.WorkflowSnapshot(id="workflow-1", tenant_id="tenant-1", features_dict={}) pipeline = pipeline_module.AdvancedChatAppGenerateTaskPipeline( application_generate_entity=application_generate_entity, @@ -185,14 +184,33 @@ def test_resume_appends_chunks_to_paused_answer() -> None: draft_var_saver_factory=SimpleNamespace(), ) - pipeline._get_message = mock.Mock(return_value=message) + stored_message = SimpleNamespace( + id="message-1", + answer="before", + status=MessageStatus.PAUSED, + updated_at=None, + provider_response_latency=0, + message_tokens=0, + message_unit_price=0, + message_price_unit=0, + answer_tokens=0, + answer_unit_price=0, + answer_price_unit=0, + total_price=0, + currency="USD", + message_metadata=None, + invoke_from=InvokeFrom.WEB_APP, + from_account_id=None, + from_end_user_id="user-1", + ) + pipeline._get_message = mock.Mock(return_value=stored_message) pipeline._recorded_files = [] list(pipeline._handle_text_chunk_event(QueueTextChunkEvent(text="after"))) pipeline._save_message(session=mock.Mock()) - assert message.answer == "beforeafter" - assert message.status == MessageStatus.NORMAL + assert stored_message.answer == "beforeafter" + assert stored_message.status == MessageStatus.NORMAL def test_workflow_succeeded_emits_message_end_before_workflow_finished() -> None: diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py index 0a244b3fea6..82b2e51019f 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py @@ -1,13 +1,19 @@ from __future__ import annotations from contextlib import contextmanager -from datetime import datetime from types import SimpleNamespace import pytest +from graphon.enums import BuiltinNodeTypes +from graphon.runtime import GraphRuntimeState, VariablePool from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig -from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline +from core.app.apps.advanced_chat.generate_task_pipeline import ( + AdvancedChatAppGenerateTaskPipeline, + ConversationSnapshot, + MessageSnapshot, + WorkflowSnapshot, +) from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from core.app.entities.queue_entities import ( QueueAdvancedChatMessageEndEvent, @@ -42,11 +48,11 @@ from core.app.entities.task_entities import ( PingStreamResponse, ) from core.base.tts.app_generator_tts_publisher import AudioTrunk -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables +from libs.datetime_utils import naive_utc_now from models.enums import MessageStatus from models.model import AppMode, EndUser +from tests.workflow_test_utils import build_test_variable_pool def _make_pipeline(): @@ -72,15 +78,15 @@ def _make_pipeline(): workflow_run_id="run-id", ) - message = SimpleNamespace( + message = MessageSnapshot( id="message-id", query="hello", - created_at=datetime.utcnow(), + created_at=naive_utc_now(), status=MessageStatus.NORMAL, answer="", ) - conversation = SimpleNamespace(id="conv-id", mode=AppMode.ADVANCED_CHAT) - workflow = SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}) + conversation = ConversationSnapshot(id="conv-id", mode=AppMode.ADVANCED_CHAT) + workflow = WorkflowSnapshot(id="workflow-id", tenant_id="tenant", features_dict={}) user = EndUser(tenant_id="tenant", type="session", name="tester", session_id="session") pipeline = AdvancedChatAppGenerateTaskPipeline( @@ -166,7 +172,7 @@ class TestAdvancedChatGenerateTaskPipeline: def test_handle_workflow_started_event_sets_run_id(self, monkeypatch): pipeline = _make_pipeline() pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=build_test_variable_pool(variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._workflow_response_converter.workflow_start_to_stream_response = lambda **kwargs: "started" @@ -256,7 +262,7 @@ class TestAdvancedChatGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="LLM", - start_at=datetime.utcnow(), + start_at=naive_utc_now(), node_run_index=1, ) iter_next = QueueIterationNextEvent( @@ -272,7 +278,7 @@ class TestAdvancedChatGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="LLM", - start_at=datetime.utcnow(), + start_at=naive_utc_now(), node_run_index=1, ) loop_start = QueueLoopStartEvent( @@ -280,7 +286,7 @@ class TestAdvancedChatGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="LLM", - start_at=datetime.utcnow(), + start_at=naive_utc_now(), node_run_index=1, ) loop_next = QueueLoopNextEvent( @@ -296,7 +302,7 @@ class TestAdvancedChatGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="LLM", - start_at=datetime.utcnow(), + start_at=naive_utc_now(), node_run_index=1, ) @@ -311,7 +317,7 @@ class TestAdvancedChatGenerateTaskPipeline: pipeline = _make_pipeline() pipeline._workflow_run_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish" @@ -359,7 +365,7 @@ class TestAdvancedChatGenerateTaskPipeline: node_execution_id="exec", node_id="node", node_type=BuiltinNodeTypes.LLM, - start_at=datetime.utcnow(), + start_at=naive_utc_now(), inputs={}, outputs={}, process_data={}, @@ -369,7 +375,7 @@ class TestAdvancedChatGenerateTaskPipeline: node_execution_id="exec", node_id="node", node_type=BuiltinNodeTypes.LLM, - start_at=datetime.utcnow(), + start_at=naive_utc_now(), inputs={}, outputs={}, process_data={}, @@ -472,7 +478,7 @@ class TestAdvancedChatGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="title", - expiration_time=datetime.utcnow(), + expiration_time=naive_utc_now(), ) assert list(pipeline._handle_human_input_form_filled_event(filled_event)) == ["filled"] @@ -522,7 +528,7 @@ class TestAdvancedChatGenerateTaskPipeline: self.items = items graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) @@ -556,7 +562,7 @@ class TestAdvancedChatGenerateTaskPipeline: def test_handle_message_end_event_applies_output_moderation(self, monkeypatch): pipeline = _make_pipeline() pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._base_task_pipeline.handle_output_moderation_when_task_finished = lambda answer: "safe" @@ -590,7 +596,7 @@ class TestAdvancedChatGenerateTaskPipeline: node_execution_id="exec", node_id="node", node_type=BuiltinNodeTypes.LLM, - start_at=datetime.utcnow(), + start_at=naive_utc_now(), inputs={}, outputs={}, process_data={}, diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py index 53f26d15928..7dc43581501 100644 --- a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py @@ -1,12 +1,12 @@ import contextlib import pytest +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError class DummyAccount: diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py index 5603115b30c..08250bc3b6f 100644 --- a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py @@ -1,10 +1,10 @@ import pytest +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey from core.agent.entities import AgentEntity from core.app.apps.agent_chat.app_runner import AgentChatAppRunner from core.moderation.base import ModerationError -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey @pytest.fixture diff --git a/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py b/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py index 3cdffbb4cdd..68bcffb0e83 100644 --- a/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py +++ b/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py @@ -2,6 +2,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from core.app.apps.chat.app_generator import ChatAppGenerator from core.app.apps.chat.app_runner import ChatAppRunner @@ -9,7 +10,6 @@ from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.moderation.base import ModerationError -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py index 67b3777c406..f255d2c7df2 100644 --- a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py +++ b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py @@ -4,13 +4,13 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from graphon.file import FileTransferMethod, FileType +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from core.app.apps.base_app_queue_manager import PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueMessageFileEvent -from dify_graph.file.enums import FileTransferMethod, FileType -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent from models.enums import CreatorUserRole diff --git a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py index b0789bbc1ef..4a94a2b4f1b 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py +++ b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py @@ -1,15 +1,16 @@ from types import SimpleNamespace import pytest +from graphon.runtime import GraphRuntimeState, VariablePool from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport -from dify_graph.runtime import GraphRuntimeState -from dify_graph.runtime.variable_pool import VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables +from core.workflow.variable_pool_initializer import add_variables_to_pool def _make_state(workflow_run_id: str | None) -> GraphRuntimeState: - variable_pool = VariablePool(system_variables=SystemVariable(workflow_execution_id=workflow_run_id)) + variable_pool = VariablePool() + add_variables_to_pool(variable_pool, build_system_variables(workflow_execution_id=workflow_run_id)) return GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py index 72430a33477..328cd12f12e 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py @@ -1,8 +1,9 @@ from collections.abc import Mapping, Sequence +from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType +from graphon.variables.segments import ArrayFileSegment, FileSegment + from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter -from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType -from dify_graph.variables.segments import ArrayFileSegment, FileSegment class TestWorkflowResponseConverterFetchFilesFromVariableValue: @@ -12,7 +13,6 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue: """Create a test File object""" return File( id=file_id, - tenant_id="test_tenant", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related_123", @@ -223,7 +223,7 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue: assert len(result) == 1 file_dict = result[0] assert file_dict["id"] == "property_test" - assert file_dict["tenant_id"] == "test_tenant" + assert "tenant_id" not in file_dict assert file_dict["type"] == "document" assert file_dict["transfer_method"] == "local_file" assert file_dict["filename"] == "property_test.txt" diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py index 4ed7d73cd06..bc11bf41744 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py @@ -1,16 +1,17 @@ from datetime import UTC, datetime from types import SimpleNamespace +from graphon.entities import WorkflowStartReason +from graphon.runtime import GraphRuntimeState, VariablePool + from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueHumanInputFormFilledEvent, QueueHumanInputFormTimeoutEvent -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables def _build_converter(): - system_variables = SystemVariable( + system_variables = build_system_variables( files=[], user_id="user-1", app_id="app-1", diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py index 5879e8fb9b1..c9e146ff126 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py @@ -1,15 +1,16 @@ from types import SimpleNamespace +from graphon.entities import WorkflowStartReason +from graphon.runtime import GraphRuntimeState, VariablePool + from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables def _build_converter() -> WorkflowResponseConverter: """Construct a minimal WorkflowResponseConverter for testing.""" - system_variables = SystemVariable( + system_variables = build_system_variables( files=[], user_id="user-1", app_id="app-1", diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py index 374af5ddc47..0fde7565d24 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py @@ -10,6 +10,8 @@ from typing import Any from unittest.mock import Mock import pytest +from graphon.entities import WorkflowStartReason +from graphon.enums import BuiltinNodeTypes from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter @@ -24,9 +26,7 @@ from core.app.entities.queue_entities import ( QueueNodeStartedEvent, QueueNodeSucceededEvent, ) -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables from libs.datetime_utils import naive_utc_now from models import Account from models.model import AppMode @@ -54,7 +54,7 @@ class TestWorkflowResponseConverter: mock_user.name = "Test User" mock_user.email = "test@example.com" - system_variables = SystemVariable(workflow_id="wf-id", workflow_execution_id="initial-run-id") + system_variables = build_system_variables(workflow_id="wf-id", workflow_execution_id="initial-run-id") return WorkflowResponseConverter( application_generate_entity=mock_entity, user=mock_user, @@ -451,9 +451,9 @@ class TestWorkflowResponseConverterServiceApiTruncation: account.id = "test_user_id" return account - def create_test_system_variables(self) -> SystemVariable: + def create_test_system_variables(self): """Create test system variables.""" - return SystemVariable() + return build_system_variables() def create_test_converter(self, invoke_from: InvokeFrom) -> WorkflowResponseConverter: """Create WorkflowResponseConverter with specified invoke_from.""" diff --git a/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py b/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py index 51f33bac352..619d66085a4 100644 --- a/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py @@ -2,11 +2,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent import core.app.apps.completion.app_runner as module from core.app.apps.completion.app_runner import CompletionAppRunner from core.moderation.base import ModerationError -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent @pytest.fixture diff --git a/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py b/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py index 27147573539..96af9fbdeef 100644 --- a/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py @@ -3,13 +3,13 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError import core.app.apps.completion.app_generator as module from core.app.apps.completion.app_generator import CompletionAppGenerator from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError from services.errors.app import MoreLikeThisDisabledError from services.errors.message import MessageNotExistsError diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py index 94ed8166b92..6cdcab29abe 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py @@ -1,5 +1,7 @@ from collections.abc import Generator +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus + from core.app.apps.pipeline.generate_response_converter import WorkflowAppGenerateResponseConverter from core.app.entities.task_entities import ( AppStreamResponse, @@ -10,7 +12,6 @@ from core.app.entities.task_entities import ( WorkflowAppBlockingResponse, WorkflowAppStreamResponse, ) -from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus def test_convert_blocking_full_and_simple_response(): diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py index 72f7552bd1d..4fe82efcb33 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py @@ -1,4 +1,5 @@ import pytest +from graphon.model_runtime.entities.llm_entities import LLMResult import core.app.apps.pipeline.pipeline_queue_manager as module from core.app.apps.base_app_queue_manager import PublishFrom @@ -13,7 +14,6 @@ from core.app.entities.queue_entities import ( QueueWorkflowPartialSuccessEvent, QueueWorkflowSucceededEvent, ) -from dify_graph.model_runtime.entities.llm_entities import LLMResult def test_publish_sets_stop_listen_and_raises_on_stopped(mocker): diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py index eec95b7f392..ab70996f0aa 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py @@ -22,11 +22,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.graph_events import GraphRunFailedEvent import core.app.apps.pipeline.pipeline_runner as module from core.app.apps.pipeline.pipeline_runner import PipelineRunner from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.graph_events import GraphRunFailedEvent def _build_app_generate_entity() -> SimpleNamespace: @@ -284,7 +284,12 @@ def test_run_normal_path_builds_graph(mocker): return_value=SimpleNamespace(belong_to_node_id="start", variable="input1"), ) mocker.patch.object(module, "RAGPipelineVariableInput", side_effect=lambda **kwargs: SimpleNamespace(**kwargs)) - mocker.patch.object(module, "VariablePool", side_effect=lambda **kwargs: SimpleNamespace(**kwargs)) + + class FakeVariablePool: + def add(self, selector, value): + return None + + mocker.patch.object(module, "VariablePool", return_value=FakeVariablePool()) workflow_entry = MagicMock() workflow_entry.graph_engine = MagicMock() diff --git a/api/tests/unit_tests/core/app/apps/test_advanced_chat_app_generator.py b/api/tests/unit_tests/core/app/apps/test_advanced_chat_app_generator.py index a25e3ec3f54..f48a7fb38e2 100644 --- a/api/tests/unit_tests/core/app/apps/test_advanced_chat_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_advanced_chat_app_generator.py @@ -11,6 +11,7 @@ from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from core.app.task_pipeline import message_cycle_manager from core.app.task_pipeline.message_cycle_manager import MessageCycleManager +from models.enums import ConversationFromSource from models.model import AppMode, Conversation, Message @@ -92,7 +93,7 @@ def test_init_generate_records_marks_existing_conversation(): system_instruction_tokens=0, status="normal", invoke_from=InvokeFrom.WEB_APP.value, - from_source="api", + from_source=ConversationFromSource.API, from_end_user_id="user-id", from_account_id=None, ) diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py index a3ced023949..6167be3bbdb 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py @@ -1,9 +1,7 @@ -from unittest.mock import MagicMock - import pytest +from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.app.apps.base_app_generator import BaseAppGenerator -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType def test_validate_inputs_with_zero(): @@ -403,11 +401,11 @@ class TestBaseAppGeneratorExtras: monkeypatch.setattr( "core.app.apps.base_app_generator.file_factory.build_from_mapping", - lambda mapping, tenant_id, config, strict_type_validation=False: "file-object", + lambda mapping, tenant_id, config, strict_type_validation=False, access_controller=None: "file-object", ) monkeypatch.setattr( "core.app.apps.base_app_generator.file_factory.build_from_mappings", - lambda mappings, tenant_id, config: ["file-1", "file-2"], + lambda mappings, tenant_id, config, access_controller=None: ["file-1", "file-2"], ) user_inputs = { @@ -478,8 +476,9 @@ class TestBaseAppGeneratorExtras: assert converted[1] == "event: ping\n\n" def test_get_draft_var_saver_factory_debugger(self): + from graphon.enums import BuiltinNodeTypes + from core.app.entities.app_invoke_entities import InvokeFrom - from dify_graph.enums import BuiltinNodeTypes from models import Account base_app_generator = BaseAppGenerator() @@ -489,7 +488,6 @@ class TestBaseAppGeneratorExtras: factory = base_app_generator._get_draft_var_saver_factory(InvokeFrom.DEBUGGER, account) saver = factory( - session=MagicMock(), app_id="app-id", node_id="node-id", node_type=BuiltinNodeTypes.START, diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_queue_manager.py b/api/tests/unit_tests/core/app/apps/test_base_app_queue_manager.py index c6dc20ffc66..842d14bbd25 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_queue_manager.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_queue_manager.py @@ -59,3 +59,18 @@ class TestBaseAppQueueManager: bad = SimpleNamespace(_sa_instance_state=True) with pytest.raises(TypeError): manager._check_for_sqlalchemy_models(bad) + + def test_stop_listen_defers_graph_runtime_state_cleanup_until_listener_exits(self): + with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis: + mock_redis.setex.return_value = True + mock_redis.get.return_value = None + manager = DummyQueueManager(task_id="t1", user_id="u1", invoke_from=InvokeFrom.SERVICE_API) + + runtime_state = SimpleNamespace(name="runtime-state") + manager.graph_runtime_state = runtime_state + + manager.stop_listen() + + assert manager.graph_runtime_state is runtime_state + assert list(manager.listen()) == [] + assert manager.graph_runtime_state is None diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_runner.py b/api/tests/unit_tests/core/app/apps/test_base_app_runner.py index aabeb545539..1dee7fdab66 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_runner.py @@ -4,6 +4,15 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessageRole, + TextPromptMessageContent, +) +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.errors.invoke import InvokeBadRequestError from core.app.app_config.entities import ( AdvancedChatMessageEntity, @@ -14,15 +23,6 @@ from core.app.app_config.entities import ( from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from dify_graph.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - ImagePromptMessageContent, - PromptMessageRole, - TextPromptMessageContent, -) -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey -from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/apps/test_pause_resume.py b/api/tests/unit_tests/core/app/apps/test_pause_resume.py index 2f73a8cda8a..a126bc85f75 100644 --- a/api/tests/unit_tests/core/app/apps/test_pause_resume.py +++ b/api/tests/unit_tests/core/app/apps/test_pause_resume.py @@ -3,33 +3,34 @@ import time from types import ModuleType, SimpleNamespace from typing import Any -import dify_graph.nodes.human_input.entities # noqa: F401 -from core.app.apps.advanced_chat import app_generator as adv_app_gen_module -from core.app.apps.workflow import app_generator as wf_app_gen_module -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig -from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from dify_graph.entities.pause_reason import SchedulingPause -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.graph_engine import GraphEngine -from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from dify_graph.graph_events import ( +import graphon.nodes.human_input.entities # noqa: F401 +from graphon.entities import WorkflowStartReason +from graphon.entities.base_node_data import BaseNodeData, RetryConfig +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.entities.pause_reason import SchedulingPause +from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.graph_engine import GraphEngine +from graphon.graph_engine.command_channels import InMemoryChannel +from graphon.graph_events import ( GraphEngineEvent, GraphRunPausedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunSucceededEvent, ) -from dify_graph.node_events import NodeRunResult, PauseRequestedEvent -from dify_graph.nodes.base.entities import OutputVariableEntity -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from graphon.node_events import NodeRunResult, PauseRequestedEvent +from graphon.nodes.base.entities import OutputVariableEntity +from graphon.nodes.base.node import Node +from graphon.nodes.end.entities import EndNodeData +from graphon.nodes.start.entities import StartNodeData +from graphon.runtime import GraphRuntimeState, VariablePool + +from core.app.apps.advanced_chat import app_generator as adv_app_gen_module +from core.app.apps.workflow import app_generator as wf_app_gen_module +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params if "core.ops.ops_trace_manager" not in sys.modules: @@ -162,11 +163,11 @@ def _build_graph(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> G def _build_runtime_state(run_id: str) -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"), + system_variables=build_system_variables(user_id="user", app_id="app", workflow_id="workflow"), user_inputs={}, conversation_variables=[], ) - variable_pool.system_variables.workflow_execution_id = run_id + variable_pool.add(("sys", "workflow_run_id"), run_id) return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py index 3f1dd14569d..de5bca161c7 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py @@ -1,9 +1,28 @@ from __future__ import annotations -from datetime import datetime +from datetime import UTC, datetime from types import SimpleNamespace import pytest +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import BuiltinNodeTypes +from graphon.graph_events import ( + GraphRunPausedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunAgentLogEvent, + NodeRunExceptionEvent, + NodeRunFailedEvent, + NodeRunIterationSucceededEvent, + NodeRunLoopFailedEvent, + NodeRunRetryEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, +) +from graphon.node_events import NodeRunResult +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables.variables import StringVariable from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom @@ -11,25 +30,16 @@ from core.app.entities.queue_entities import ( QueueAgentLogEvent, QueueIterationCompletedEvent, QueueLoopCompletedEvent, + QueueNodeExceptionEvent, + QueueNodeFailedEvent, + QueueNodeRetryEvent, + QueueNodeSucceededEvent, QueueTextChunkEvent, QueueWorkflowPausedEvent, QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.graph_events import ( - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunAgentLogEvent, - NodeRunIterationSucceededEvent, - NodeRunLoopFailedEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, -) -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import default_system_variables class TestWorkflowBasedAppRunner: @@ -44,7 +54,7 @@ class TestWorkflowBasedAppRunner: runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app") runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable.default()), + variable_pool=VariablePool(system_variables=default_system_variables()), start_at=0.0, ) @@ -78,12 +88,12 @@ class TestWorkflowBasedAppRunner: workflow = SimpleNamespace(environment_variables=[], graph_dict={}) with pytest.raises(ValueError, match="Neither single_iteration_run nor single_loop_run"): - runner._prepare_single_node_execution(workflow, None, None) + runner._prepare_single_node_execution(workflow, None, None, user_id="00000000-0000-0000-0000-000000000001") def test_get_graph_and_variable_pool_for_single_node_run(self, monkeypatch): runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app") graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable.default()), + variable_pool=VariablePool(system_variables=default_system_variables()), start_at=0.0, ) @@ -126,11 +136,102 @@ class TestWorkflowBasedAppRunner: graph_runtime_state=graph_runtime_state, node_type_filter_key="iteration_id", node_type_label="iteration", + user_id="00000000-0000-0000-0000-000000000001", ) assert graph is not None assert variable_pool is graph_runtime_state.variable_pool + def test_get_graph_and_variable_pool_preloads_constructor_variables_before_graph_init(self, monkeypatch): + variable_loader = SimpleNamespace( + load_variables=lambda selectors: ( + [ + StringVariable( + name="conversation_id", + value="conv-1", + selector=["sys", "conversation_id"], + ) + ] + if selectors + else [] + ) + ) + runner = WorkflowBasedAppRunner( + queue_manager=SimpleNamespace(), + variable_loader=variable_loader, + app_id="app", + ) + graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=default_system_variables()), + start_at=0.0, + ) + + workflow = SimpleNamespace( + tenant_id="tenant", + id="workflow", + graph_dict={ + "nodes": [ + {"id": "loop-node", "data": {"type": "loop", "version": "1", "title": "Loop"}}, + { + "id": "llm-child", + "data": { + "type": "llm", + "version": "1", + "loop_id": "loop-node", + "memory": object(), + }, + }, + ], + "edges": [], + }, + ) + + class _LoopNodeCls: + @staticmethod + def extract_variable_selector_to_variable_mapping(graph_config, config): + return {} + + def _validate_node_config(value): + return {"id": value["id"], "data": SimpleNamespace(**value["data"])} + + def _graph_init(**kwargs): + variable_pool = graph_runtime_state.variable_pool + assert variable_pool.get(["sys", "conversation_id"]) is not None + return SimpleNamespace() + + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.NodeConfigDictAdapter.validate_python", + _validate_node_config, + ) + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.Graph.init", + _graph_init, + ) + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.resolve_workflow_node_class", + lambda **_kwargs: _LoopNodeCls, + ) + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.load_into_variable_pool", + lambda **kwargs: None, + ) + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.WorkflowEntry.mapping_user_inputs_to_variable_pool", + lambda **kwargs: None, + ) + + graph, variable_pool = runner._get_graph_and_variable_pool_for_single_node_run( + workflow=workflow, + node_id="loop-node", + user_inputs={}, + graph_runtime_state=graph_runtime_state, + node_type_filter_key="loop_id", + node_type_label="loop", + ) + + assert graph is not None + assert variable_pool.get(["sys", "conversation_id"]).value == "conv-1" + def test_handle_graph_run_events_and_pause_notifications(self, monkeypatch): published: list[object] = [] @@ -140,7 +241,7 @@ class TestWorkflowBasedAppRunner: runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app") graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable.default()), + variable_pool=VariablePool(system_variables=default_system_variables()), start_at=0.0, ) graph_runtime_state.register_paused_node("node-1") @@ -183,7 +284,7 @@ class TestWorkflowBasedAppRunner: runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app") graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable.default()), + variable_pool=VariablePool(system_variables=default_system_variables()), start_at=0.0, ) workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state)) @@ -195,7 +296,7 @@ class TestWorkflowBasedAppRunner: node_id="node", node_type=BuiltinNodeTypes.START, node_title="Start", - start_at=datetime.utcnow(), + start_at=datetime.now(UTC), ), ) runner._handle_event( @@ -232,7 +333,7 @@ class TestWorkflowBasedAppRunner: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="Iter", - start_at=datetime.utcnow(), + start_at=datetime.now(UTC), inputs={}, outputs={"ok": True}, metadata={}, @@ -246,7 +347,7 @@ class TestWorkflowBasedAppRunner: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="Loop", - start_at=datetime.utcnow(), + start_at=datetime.now(UTC), inputs={}, outputs={}, metadata={}, @@ -259,3 +360,87 @@ class TestWorkflowBasedAppRunner: assert any(isinstance(event, QueueAgentLogEvent) for event in published) assert any(isinstance(event, QueueIterationCompletedEvent) for event in published) assert any(isinstance(event, QueueLoopCompletedEvent) for event in published) + + @pytest.mark.parametrize( + ("event_factory", "queue_event_cls"), + [ + ( + lambda result, start_at, finished_at: NodeRunSucceededEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + start_at=start_at, + finished_at=finished_at, + node_run_result=result, + ), + QueueNodeSucceededEvent, + ), + ( + lambda result, start_at, finished_at: NodeRunFailedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + start_at=start_at, + finished_at=finished_at, + error="boom", + node_run_result=result, + ), + QueueNodeFailedEvent, + ), + ( + lambda result, start_at, finished_at: NodeRunExceptionEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + start_at=start_at, + finished_at=finished_at, + error="boom", + node_run_result=result, + ), + QueueNodeExceptionEvent, + ), + ( + lambda result, start_at, _finished_at: NodeRunRetryEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + node_title="Start", + start_at=start_at, + error="boom", + retry_index=1, + node_run_result=result, + ), + QueueNodeRetryEvent, + ), + ], + ) + def test_handle_start_node_result_events_project_outputs(self, event_factory, queue_event_cls): + published: list[object] = [] + + class _QueueManager: + def publish(self, event, publish_from): + published.append(event) + + runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app") + graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=default_system_variables()), + start_at=0.0, + ) + workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state)) + started_at = datetime.now(UTC) + finished_at = datetime.now(UTC) + result = NodeRunResult( + inputs={"question": "hello"}, + outputs={ + "question": "hello", + "sys.query": "hello", + "env.API_KEY": "secret", + "conversation.session_id": "session-1", + }, + ) + + runner._handle_event(workflow_entry, event_factory(result, started_at, finished_at)) + + queue_event = published[-1] + assert isinstance(queue_event, queue_event_cls) + assert queue_event.outputs == {"question": "hello"} diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py index 13882792214..aa789d9ff3a 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py @@ -1,11 +1,11 @@ from unittest.mock import MagicMock import pytest +from graphon.entities.pause_reason import HumanInputRequired +from graphon.graph_events import GraphRunPausedEvent from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.queue_entities import QueueWorkflowPausedEvent -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.graph_events.graph import GraphRunPausedEvent class _DummyQueueManager: diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py index 178e26118ee..9e30faecf23 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py @@ -4,20 +4,20 @@ from typing import Any from unittest.mock import MagicMock, patch import pytest +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.runtime import GraphRuntimeState, VariablePool from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_runner import WorkflowAppRunner from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity -from dify_graph.entities.graph_config import NodeConfigDictAdapter -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import default_system_variables from models.workflow import Workflow def _make_graph_state(): variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -100,6 +100,7 @@ def test_run_uses_single_node_execution_branch( workflow=workflow, single_iteration_run=single_iteration_run, single_loop_run=single_loop_run, + user_id="user", ) init_graph.assert_not_called() @@ -158,6 +159,7 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None: graph_runtime_state=graph_runtime_state, node_type_filter_key="loop_id", node_type_label="loop", + user_id="00000000-0000-0000-0000-000000000001", ) assert seen_configs == [workflow.graph_dict["nodes"][0]] diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py index 65c6bd66547..8a717e1dccd 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py @@ -3,6 +3,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.entities import WorkflowStartReason +from graphon.entities.pause_reason import HumanInputRequired +from graphon.graph_events import GraphRunPausedEvent +from graphon.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType from core.app.apps.common import workflow_response_converter from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter @@ -10,13 +15,9 @@ from core.app.apps.workflow.app_runner import WorkflowAppRunner from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueWorkflowPausedEvent from core.app.entities.task_entities import HumanInputRequiredResponse, WorkflowPauseStreamResponse -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.graph_events.graph import GraphRunPausedEvent -from dify_graph.nodes.human_input.entities import FormInput, UserAction -from dify_graph.nodes.human_input.enums import FormInputType -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables from models.account import Account +from models.human_input import RecipientType class _RecordingWorkflowAppRunner(WorkflowAppRunner): @@ -74,7 +75,6 @@ def test_graph_run_paused_event_emits_queue_pause_event(): actions=[], node_id="node-human", node_title="Human Step", - form_token="tok", ) event = GraphRunPausedEvent(reasons=[reason], outputs={"foo": "bar"}) workflow_entry = SimpleNamespace( @@ -98,7 +98,7 @@ def _build_converter(): invoke_from=InvokeFrom.SERVICE_API, app_config=SimpleNamespace(app_id="app-id", tenant_id="tenant-id"), ) - system_variables = SystemVariable( + system_variables = build_system_variables( user_id="user", app_id="app-id", workflow_id="workflow-id", @@ -128,7 +128,21 @@ def test_queue_workflow_paused_event_to_stream_responses(monkeypatch: pytest.Mon class _FakeSession: def execute(self, _stmt): - return [("form-1", expiration_time)] + return [("form-1", expiration_time, '{"display_in_ui": true}')] + + def scalars(self, _stmt): + return [ + SimpleNamespace( + form_id="form-1", + recipient_type=RecipientType.CONSOLE, + access_token="console-token", + ), + SimpleNamespace( + form_id="form-1", + recipient_type=RecipientType.BACKSTAGE, + access_token="backstage-token", + ), + ] def __enter__(self): return self @@ -146,10 +160,8 @@ def test_queue_workflow_paused_event_to_stream_responses(monkeypatch: pytest.Mon FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="field", default=None), ], actions=[UserAction(id="approve", title="Approve")], - display_in_ui=True, node_id="node-id", node_title="Human Step", - form_token="token", ) queue_event = QueueWorkflowPausedEvent( reasons=[reason], @@ -170,7 +182,6 @@ def test_queue_workflow_paused_event_to_stream_responses(monkeypatch: pytest.Mon assert pause_resp.data.paused_nodes == ["node-id"] assert pause_resp.data.outputs == {} assert pause_resp.data.reasons[0]["form_id"] == "form-1" - assert pause_resp.data.reasons[0]["display_in_ui"] is True assert isinstance(responses[0], HumanInputRequiredResponse) hi_resp = responses[0] @@ -180,4 +191,5 @@ def test_queue_workflow_paused_event_to_stream_responses(monkeypatch: pytest.Mon assert hi_resp.data.inputs[0].output_variable_name == "field" assert hi_resp.data.actions[0].id == "approve" assert hi_resp.data.display_in_ui is True + assert hi_resp.data.form_token == "backstage-token" assert hi_resp.data.expiration_time == int(expiration_time.timestamp()) diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py index 62e94a75802..b768e813bd7 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py @@ -1,5 +1,7 @@ from collections.abc import Generator +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus + from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter from core.app.entities.task_entities import ( ErrorStreamResponse, @@ -9,7 +11,6 @@ from core.app.entities.task_entities import ( WorkflowAppBlockingResponse, WorkflowAppStreamResponse, ) -from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus class TestWorkflowGenerateResponseConverter: diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py index 5b23e710358..29df903aa8b 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py @@ -2,16 +2,18 @@ import time from contextlib import contextmanager from unittest.mock import MagicMock +from graphon.entities import WorkflowStartReason +from graphon.runtime import GraphRuntimeState + from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import QueueWorkflowStartedEvent -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables from models.account import Account from models.model import AppMode +from tests.workflow_test_utils import build_test_variable_pool def _build_workflow_app_config() -> WorkflowUIBasedAppConfig: @@ -37,11 +39,7 @@ def _build_generate_entity(run_id: str) -> WorkflowAppGenerateEntity: def _build_runtime_state(run_id: str) -> GraphRuntimeState: - variable_pool = VariablePool( - system_variables=SystemVariable(workflow_execution_id=run_id), - user_inputs={}, - conversation_variables=[], - ) + variable_pool = build_test_variable_pool(variables=build_system_variables(workflow_execution_id=run_id)) return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py index f35710d2076..dabd2594b43 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py @@ -1,10 +1,11 @@ from __future__ import annotations from contextlib import contextmanager -from datetime import datetime from types import SimpleNamespace import pytest +from graphon.enums import BuiltinNodeTypes, WorkflowExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline @@ -44,11 +45,11 @@ from core.app.entities.task_entities import ( WorkflowStartStreamResponse, ) from core.base.tts.app_generator_tts_publisher import AudioTrunk -from dify_graph.enums import BuiltinNodeTypes, WorkflowExecutionStatus -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables, system_variables_to_mapping +from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.model import AppMode, EndUser +from tests.workflow_test_utils import build_test_variable_pool def _make_pipeline(): @@ -164,7 +165,7 @@ class TestWorkflowGenerateTaskPipeline: def test_handle_workflow_started_event_sets_run_id(self, monkeypatch): pipeline = _make_pipeline() pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=build_test_variable_pool(variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._workflow_response_converter.workflow_start_to_stream_response = lambda **kwargs: "started" @@ -191,7 +192,7 @@ class TestWorkflowGenerateTaskPipeline: node_execution_id="exec", node_id="node", node_type=BuiltinNodeTypes.START, - start_at=datetime.utcnow(), + start_at=naive_utc_now(), inputs={}, outputs={}, process_data={}, @@ -205,7 +206,7 @@ class TestWorkflowGenerateTaskPipeline: pipeline = _make_pipeline() pipeline._workflow_execution_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=build_test_variable_pool(variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish" @@ -244,7 +245,7 @@ class TestWorkflowGenerateTaskPipeline: node_execution_id="exec", node_id="node", node_type=BuiltinNodeTypes.START, - start_at=datetime.utcnow(), + start_at=naive_utc_now(), inputs={}, outputs={}, process_data={}, @@ -257,7 +258,7 @@ class TestWorkflowGenerateTaskPipeline: pipeline = _make_pipeline() pipeline._workflow_execution_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish" @@ -302,7 +303,7 @@ class TestWorkflowGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="LLM", - start_at=datetime.utcnow(), + start_at=naive_utc_now(), node_run_index=1, ) iter_next = QueueIterationNextEvent( @@ -318,7 +319,7 @@ class TestWorkflowGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="LLM", - start_at=datetime.utcnow(), + start_at=naive_utc_now(), node_run_index=1, ) loop_start = QueueLoopStartEvent( @@ -326,7 +327,7 @@ class TestWorkflowGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="LLM", - start_at=datetime.utcnow(), + start_at=naive_utc_now(), node_run_index=1, ) loop_next = QueueLoopNextEvent( @@ -342,7 +343,7 @@ class TestWorkflowGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="LLM", - start_at=datetime.utcnow(), + start_at=naive_utc_now(), node_run_index=1, ) filled_event = QueueHumanInputFormFilledEvent( @@ -358,7 +359,7 @@ class TestWorkflowGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="title", - expiration_time=datetime.utcnow(), + expiration_time=naive_utc_now(), ) agent_event = QueueAgentLogEvent( id="log", @@ -451,7 +452,7 @@ class TestWorkflowGenerateTaskPipeline: ) assert pipeline._created_by_role == CreatorUserRole.END_USER - assert pipeline._workflow_system_variables.user_id == "session-id" + assert system_variables_to_mapping(pipeline._workflow_system_variables)["user_id"] == "session-id" def test_process_returns_stream_and_blocking_variants(self): pipeline = _make_pipeline() @@ -647,7 +648,7 @@ class TestWorkflowGenerateTaskPipeline: node_title="title", node_type=BuiltinNodeTypes.LLM, node_run_index=1, - start_at=datetime.utcnow(), + start_at=naive_utc_now(), provider_type="provider", provider_id="provider-id", error="error", @@ -659,7 +660,7 @@ class TestWorkflowGenerateTaskPipeline: node_title="title", node_type=BuiltinNodeTypes.LLM, node_run_index=1, - start_at=datetime.utcnow(), + start_at=naive_utc_now(), provider_type="provider", provider_id="provider-id", ) @@ -684,7 +685,7 @@ class TestWorkflowGenerateTaskPipeline: node_execution_id="exec-id", node_id="node", node_type=BuiltinNodeTypes.START, - start_at=datetime.utcnow(), + start_at=naive_utc_now(), inputs={}, outputs={}, process_data={}, @@ -699,7 +700,7 @@ class TestWorkflowGenerateTaskPipeline: pipeline = _make_pipeline() pipeline._workflow_execution_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) @@ -727,7 +728,7 @@ class TestWorkflowGenerateTaskPipeline: pipeline = _make_pipeline() pipeline._workflow_execution_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._handle_ping_event = lambda event, **kwargs: iter(["ping"]) @@ -743,7 +744,7 @@ class TestWorkflowGenerateTaskPipeline: def test_process_stream_response_main_match_paths_and_cleanup(self): pipeline = _make_pipeline() pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._base_task_pipeline.queue_manager.listen = lambda: iter( @@ -815,7 +816,7 @@ class TestWorkflowGenerateTaskPipeline: pipeline._save_workflow_app_log(session=_Session(), workflow_run_id=None) assert len(added) == count_before - def test_save_output_for_event_writes_draft_variables(self, monkeypatch): + def test_save_output_for_event_writes_draft_variables(self): pipeline = _make_pipeline() saver_calls: list[tuple[object, object]] = [] captured_factory_args: dict[str, object] = {} @@ -828,36 +829,14 @@ class TestWorkflowGenerateTaskPipeline: captured_factory_args.update(kwargs) return _Saver() - class _Begin: - def __enter__(self): - return None - - def __exit__(self, exc_type, exc, tb): - return False - - class _Session: - def __init__(self, *args, **kwargs): - _ = args, kwargs - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def begin(self): - return _Begin() - pipeline._draft_var_saver_factory = _factory - monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.Session", _Session) - monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.db", SimpleNamespace(engine=object())) event = QueueNodeSucceededEvent( node_execution_id="exec-id", node_id="node-id", node_type=BuiltinNodeTypes.START, in_loop_id="loop-id", - start_at=datetime.utcnow(), + start_at=naive_utc_now(), process_data={"k": "v"}, outputs={"out": 1}, ) diff --git a/api/tests/unit_tests/core/app/entities/test_queue_entities.py b/api/tests/unit_tests/core/app/entities/test_queue_entities.py new file mode 100644 index 00000000000..7c21b00966e --- /dev/null +++ b/api/tests/unit_tests/core/app/entities/test_queue_entities.py @@ -0,0 +1,12 @@ +from core.app.entities.queue_entities import QueueStopEvent + + +class TestQueueEntities: + def test_get_stop_reason_for_known_stop_by(self): + event = QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL) + assert event.get_stop_reason() == "Stopped by user." + + def test_get_stop_reason_for_unknown_stop_by(self): + event = QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL) + event.stopped_by = "unknown" + assert event.get_stop_reason() == "Stopped by unknown reason." diff --git a/api/tests/unit_tests/core/app/entities/test_rag_pipeline_invoke_entities.py b/api/tests/unit_tests/core/app/entities/test_rag_pipeline_invoke_entities.py new file mode 100644 index 00000000000..1e0ef6d6d68 --- /dev/null +++ b/api/tests/unit_tests/core/app/entities/test_rag_pipeline_invoke_entities.py @@ -0,0 +1,17 @@ +from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity + + +class TestRagPipelineInvokeEntity: + def test_defaults_and_fields(self): + entity = RagPipelineInvokeEntity( + pipeline_id="pipe-1", + application_generate_entity={"foo": "bar"}, + user_id="user-1", + tenant_id="tenant-1", + workflow_id="workflow-1", + streaming=True, + ) + + assert entity.workflow_execution_id is None + assert entity.workflow_thread_pool_id is None + assert entity.streaming is True diff --git a/api/tests/unit_tests/core/app/entities/test_task_entities.py b/api/tests/unit_tests/core/app/entities/test_task_entities.py new file mode 100644 index 00000000000..014a0cba729 --- /dev/null +++ b/api/tests/unit_tests/core/app/entities/test_task_entities.py @@ -0,0 +1,79 @@ +from graphon.enums import WorkflowNodeExecutionStatus + +from core.app.entities.task_entities import ( + NodeFinishStreamResponse, + NodeRetryStreamResponse, + NodeStartStreamResponse, + StreamEvent, +) + + +class TestTaskEntities: + def test_node_start_to_ignore_detail_dict(self): + data = NodeStartStreamResponse.Data( + id="exec-1", + node_id="node-1", + node_type="answer", + title="Answer", + index=1, + predecessor_node_id=None, + inputs={"foo": "bar"}, + created_at=1, + ) + response = NodeStartStreamResponse(task_id="task-1", workflow_run_id="run-1", data=data) + + payload = response.to_ignore_detail_dict() + + assert payload["event"] == StreamEvent.NODE_STARTED.value + assert payload["data"]["inputs"] is None + assert payload["data"]["extras"] == {} + + def test_node_finish_to_ignore_detail_dict(self): + data = NodeFinishStreamResponse.Data( + id="exec-1", + node_id="node-1", + node_type="answer", + title="Answer", + index=1, + predecessor_node_id=None, + inputs={"foo": "bar"}, + process_data={"step": 1}, + outputs={"answer": "ok"}, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + elapsed_time=0.1, + created_at=1, + finished_at=2, + ) + response = NodeFinishStreamResponse(task_id="task-1", workflow_run_id="run-1", data=data) + + payload = response.to_ignore_detail_dict() + + assert payload["event"] == StreamEvent.NODE_FINISHED.value + assert payload["data"]["inputs"] is None + assert payload["data"]["outputs"] is None + assert payload["data"]["files"] == [] + + def test_node_retry_to_ignore_detail_dict(self): + data = NodeRetryStreamResponse.Data( + id="exec-1", + node_id="node-1", + node_type="answer", + title="Answer", + index=1, + predecessor_node_id=None, + inputs={"foo": "bar"}, + process_data={"step": 1}, + outputs={"answer": "ok"}, + status=WorkflowNodeExecutionStatus.RETRY, + elapsed_time=0.1, + created_at=1, + finished_at=2, + retry_index=2, + ) + response = NodeRetryStreamResponse(task_id="task-1", workflow_run_id="run-1", data=data) + + payload = response.to_ignore_detail_dict() + + assert payload["event"] == StreamEvent.NODE_RETRY.value + assert payload["data"]["retry_index"] == 2 + assert payload["data"]["outputs"] is None diff --git a/api/tests/unit_tests/core/app/features/rate_limiting/test_rate_limit.py b/api/tests/unit_tests/core/app/features/rate_limiting/test_rate_limit.py index 3db10c1c729..538b130cace 100644 --- a/api/tests/unit_tests/core/app/features/rate_limiting/test_rate_limit.py +++ b/api/tests/unit_tests/core/app/features/rate_limiting/test_rate_limit.py @@ -68,8 +68,8 @@ class TestRateLimit: assert rate_limit.disabled() assert not hasattr(rate_limit, "initialized") - def test_should_skip_reinitialization_of_existing_instance(self, redis_patch): - """Test that existing instance doesn't reinitialize.""" + def test_should_flush_cache_when_reinitializing_existing_instance(self, redis_patch): + """Test existing instance refreshes Redis cache on reinitialization.""" redis_patch.configure_mock( **{ "exists.return_value": False, @@ -82,7 +82,37 @@ class TestRateLimit: RateLimit("client1", 10) + redis_patch.setex.assert_called_once_with( + "dify:rate_limit:client1:max_active_requests", + timedelta(days=1), + 10, + ) + + def test_should_reinitialize_after_being_disabled(self, redis_patch): + """Test disabled instance can be reinitialized and writes max_active_requests to Redis.""" + redis_patch.configure_mock( + **{ + "exists.return_value": False, + "setex.return_value": True, + } + ) + + # First construct with max_active_requests = 0 (disabled), which should skip initialization. + RateLimit("client1", 0) + + # Redis should not have been written to during disabled initialization. redis_patch.setex.assert_not_called() + redis_patch.reset_mock() + + # Reinitialize with a positive max_active_requests value; this should not raise + # and must write the max_active_requests key to Redis. + RateLimit("client1", 10) + + redis_patch.setex.assert_called_once_with( + "dify:rate_limit:client1:max_active_requests", + timedelta(days=1), + 10, + ) def test_should_be_disabled_when_max_requests_is_zero_or_negative(self): """Test disabled state for zero or negative limits.""" diff --git a/api/tests/unit_tests/core/app/features/test_annotation_reply.py b/api/tests/unit_tests/core/app/features/test_annotation_reply.py new file mode 100644 index 00000000000..e721a77079f --- /dev/null +++ b/api/tests/unit_tests/core/app/features/test_annotation_reply.py @@ -0,0 +1,163 @@ +import logging +from types import SimpleNamespace +from unittest.mock import Mock, patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature + + +class TestAnnotationReplyFeature: + def test_query_returns_none_when_setting_missing(self): + feature = AnnotationReplyFeature() + + with patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db: + mock_db.session.scalar.return_value = None + + result = feature.query( + app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"), + message=SimpleNamespace(id="msg-1"), + query="hi", + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + ) + + assert result is None + + def test_query_returns_none_when_binding_missing(self): + feature = AnnotationReplyFeature() + annotation_setting = SimpleNamespace(collection_binding_detail=None) + + with patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db: + mock_db.session.scalar.return_value = annotation_setting + + result = feature.query( + app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"), + message=SimpleNamespace(id="msg-1"), + query="hi", + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + ) + + assert result is None + + def test_query_returns_annotation_and_records_history_for_api(self): + feature = AnnotationReplyFeature() + annotation_setting = SimpleNamespace( + score_threshold=None, + collection_binding_detail=SimpleNamespace(provider_name="prov", model_name="model"), + ) + dataset_binding = SimpleNamespace(id="binding-1") + annotation = SimpleNamespace( + id="ann-1", + question_text="question", + content="content", + account_id="acct-1", + account=SimpleNamespace(name="Alice"), + ) + document = SimpleNamespace(metadata={"annotation_id": "ann-1", "score": 0.8}) + vector_instance = Mock() + vector_instance.search_by_vector.return_value = [document] + + with ( + patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db, + patch( + "core.app.features.annotation_reply.annotation_reply.DatasetCollectionBindingService" + ) as mock_binding_service, + patch("core.app.features.annotation_reply.annotation_reply.Vector") as mock_vector, + patch( + "core.app.features.annotation_reply.annotation_reply.AppAnnotationService" + ) as mock_annotation_service, + ): + mock_db.session.scalar.return_value = annotation_setting + mock_binding_service.get_dataset_collection_binding.return_value = dataset_binding + mock_vector.return_value = vector_instance + mock_annotation_service.get_annotation_by_id.return_value = annotation + + result = feature.query( + app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"), + message=SimpleNamespace(id="msg-1"), + query="hi", + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + ) + + assert result == annotation + mock_annotation_service.add_annotation_history.assert_called_once() + _, _, _, _, _, _, _, from_source, score = mock_annotation_service.add_annotation_history.call_args[0] + assert from_source == "api" + assert score == 0.8 + + def test_query_returns_annotation_and_records_history_for_console(self): + feature = AnnotationReplyFeature() + annotation_setting = SimpleNamespace( + score_threshold=0.5, + collection_binding_detail=SimpleNamespace(provider_name="prov", model_name="model"), + ) + dataset_binding = SimpleNamespace(id="binding-1") + annotation = SimpleNamespace( + id="ann-1", + question_text="question", + content="content", + account_id="acct-1", + account=None, + ) + document = SimpleNamespace(metadata={"annotation_id": "ann-1", "score": 0.6}) + vector_instance = Mock() + vector_instance.search_by_vector.return_value = [document] + + with ( + patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db, + patch( + "core.app.features.annotation_reply.annotation_reply.DatasetCollectionBindingService" + ) as mock_binding_service, + patch("core.app.features.annotation_reply.annotation_reply.Vector") as mock_vector, + patch( + "core.app.features.annotation_reply.annotation_reply.AppAnnotationService" + ) as mock_annotation_service, + ): + mock_db.session.scalar.return_value = annotation_setting + mock_binding_service.get_dataset_collection_binding.return_value = dataset_binding + mock_vector.return_value = vector_instance + mock_annotation_service.get_annotation_by_id.return_value = annotation + + result = feature.query( + app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"), + message=SimpleNamespace(id="msg-1"), + query="hi", + user_id="user-1", + invoke_from=InvokeFrom.EXPLORE, + ) + + assert result == annotation + _, _, _, _, _, _, _, from_source, _ = mock_annotation_service.add_annotation_history.call_args[0] + assert from_source == "console" + + def test_query_logs_and_returns_none_on_exception(self, caplog): + feature = AnnotationReplyFeature() + annotation_setting = SimpleNamespace( + score_threshold=None, + collection_binding_detail=SimpleNamespace(provider_name="prov", model_name="model"), + ) + + with ( + patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db, + patch( + "core.app.features.annotation_reply.annotation_reply.DatasetCollectionBindingService" + ) as mock_binding_service, + patch("core.app.features.annotation_reply.annotation_reply.Vector") as mock_vector, + ): + mock_db.session.scalar.return_value = annotation_setting + mock_binding_service.get_dataset_collection_binding.return_value = SimpleNamespace(id="binding-1") + mock_vector.return_value.search_by_vector.side_effect = RuntimeError("boom") + + with caplog.at_level(logging.WARNING): + result = feature.query( + app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"), + message=SimpleNamespace(id="msg-1"), + query="hi", + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + ) + + assert result is None + assert "Query annotation failed" in caplog.text diff --git a/api/tests/unit_tests/core/app/features/test_hosting_moderation.py b/api/tests/unit_tests/core/app/features/test_hosting_moderation.py new file mode 100644 index 00000000000..01194c16f50 --- /dev/null +++ b/api/tests/unit_tests/core/app/features/test_hosting_moderation.py @@ -0,0 +1,30 @@ +from types import SimpleNamespace +from unittest.mock import Mock, patch + +from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature + + +class TestHostingModerationFeature: + def test_check_aggregates_text_and_calls_moderation(self): + application_generate_entity = Mock() + application_generate_entity.model_conf = {"model": "mock"} + application_generate_entity.app_config = SimpleNamespace(tenant_id="tenant-1") + + prompt_messages = [ + SimpleNamespace(content="hello"), + SimpleNamespace(content=123), + SimpleNamespace(content="world"), + ] + + with patch("core.app.features.hosting_moderation.hosting_moderation.moderation.check_moderation") as mock_check: + mock_check.return_value = True + + feature = HostingModerationFeature() + result = feature.check(application_generate_entity, prompt_messages) + + assert result is True + mock_check.assert_called_once_with( + tenant_id="tenant-1", + model_config=application_generate_entity.model_conf, + text="hello\nworld\n", + ) diff --git a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py index bdc889d9415..a78c1b428fb 100644 --- a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py @@ -1,18 +1,18 @@ from collections.abc import Sequence -from datetime import datetime from unittest.mock import Mock +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.graph_engine.command_channels import CommandChannel +from graphon.graph_events import NodeRunSucceededEvent, NodeRunVariableUpdatedEvent +from graphon.node_events import NodeRunResult +from graphon.runtime import ReadOnlyGraphRuntimeState +from graphon.variables import StringVariable +from graphon.variables.segments import Segment, StringSegment + from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID -from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus -from dify_graph.graph_engine.protocols.command_channel import CommandChannel -from dify_graph.graph_events.node import NodeRunSucceededEvent -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.variable_assigner.common import helpers as common_helpers -from dify_graph.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState -from dify_graph.system_variable import SystemVariable -from dify_graph.variables import StringVariable -from dify_graph.variables.segments import Segment +from core.workflow.system_variables import SystemVariableKey +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID +from libs.datetime_utils import naive_utc_now class MockReadOnlyVariablePool: @@ -36,31 +36,38 @@ def _build_graph_runtime_state( conversation_id: str | None = None, ) -> ReadOnlyGraphRuntimeState: graph_runtime_state = Mock(spec=ReadOnlyGraphRuntimeState) + if conversation_id is not None: + variable_pool._variables[("sys", SystemVariableKey.CONVERSATION_ID.value)] = StringSegment( + value=conversation_id + ) graph_runtime_state.variable_pool = variable_pool - graph_runtime_state.system_variable = SystemVariable(conversation_id=conversation_id).as_view() return graph_runtime_state -def _build_node_run_succeeded_event( - *, - node_type: NodeType, - outputs: dict[str, object] | None = None, - process_data: dict[str, object] | None = None, -) -> NodeRunSucceededEvent: +def _build_node_run_succeeded_event() -> NodeRunSucceededEvent: return NodeRunSucceededEvent( id="node-exec-id", node_id="assigner", - node_type=node_type, - start_at=datetime.utcnow(), + node_type=BuiltinNodeTypes.LLM, + start_at=naive_utc_now(), node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs=outputs or {}, - process_data=process_data or {}, + outputs={}, + process_data={}, ), ) -def test_persists_conversation_variables_from_assigner_output(): +def _build_variable_updated_event(variable: StringVariable) -> NodeRunVariableUpdatedEvent: + return NodeRunVariableUpdatedEvent( + id="node-exec-id", + node_id="assigner", + node_type=BuiltinNodeTypes.VARIABLE_ASSIGNER, + variable=variable, + ) + + +def test_persists_conversation_variables_from_variable_update_event(): conversation_id = "conv-123" variable = StringVariable( id="var-1", @@ -68,55 +75,26 @@ def test_persists_conversation_variables_from_assigner_output(): value="updated", selector=[CONVERSATION_VARIABLE_NODE_ID, "name"], ) - process_data = common_helpers.set_updated_variables( - {}, [common_helpers.variable_to_processed_data(variable.selector, variable)] - ) - - variable_pool = MockReadOnlyVariablePool({(CONVERSATION_VARIABLE_NODE_ID, "name"): variable}) - updater = Mock() layer = ConversationVariablePersistenceLayer(updater) - layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel)) + layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool(), conversation_id), Mock(spec=CommandChannel)) - event = _build_node_run_succeeded_event(node_type=BuiltinNodeTypes.VARIABLE_ASSIGNER, process_data=process_data) + event = _build_variable_updated_event(variable) layer.on_event(event) updater.update.assert_called_once_with(conversation_id=conversation_id, variable=variable) - updater.flush.assert_called_once() -def test_skips_when_outputs_missing(): +def test_skips_non_variable_update_events(): conversation_id = "conv-456" - variable = StringVariable( - id="var-2", - name="name", - value="updated", - selector=[CONVERSATION_VARIABLE_NODE_ID, "name"], - ) - - variable_pool = MockReadOnlyVariablePool({(CONVERSATION_VARIABLE_NODE_ID, "name"): variable}) - updater = Mock() layer = ConversationVariablePersistenceLayer(updater) - layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel)) + layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool(), conversation_id), Mock(spec=CommandChannel)) - event = _build_node_run_succeeded_event(node_type=BuiltinNodeTypes.VARIABLE_ASSIGNER) + event = _build_node_run_succeeded_event() layer.on_event(event) updater.update.assert_not_called() - updater.flush.assert_not_called() - - -def test_skips_non_assigner_nodes(): - updater = Mock() - layer = ConversationVariablePersistenceLayer(updater) - layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool()), Mock(spec=CommandChannel)) - - event = _build_node_run_succeeded_event(node_type=BuiltinNodeTypes.LLM) - layer.on_event(event) - - updater.update.assert_not_called() - updater.flush.assert_not_called() def test_skips_non_conversation_variables(): @@ -127,18 +105,11 @@ def test_skips_non_conversation_variables(): value="updated", selector=["environment", "name"], ) - process_data = common_helpers.set_updated_variables( - {}, [common_helpers.variable_to_processed_data(non_conversation_variable.selector, non_conversation_variable)] - ) - - variable_pool = MockReadOnlyVariablePool() - updater = Mock() layer = ConversationVariablePersistenceLayer(updater) - layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel)) + layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool(), conversation_id), Mock(spec=CommandChannel)) - event = _build_node_run_succeeded_event(node_type=BuiltinNodeTypes.VARIABLE_ASSIGNER, process_data=process_data) + event = _build_variable_updated_event(non_conversation_variable) layer.on_event(event) updater.update.assert_not_called() - updater.flush.assert_not_called() diff --git a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py index 035f0ee05c9..035e64325bb 100644 --- a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py @@ -4,6 +4,17 @@ from time import time from unittest.mock import Mock import pytest +from graphon.entities.pause_reason import SchedulingPause +from graphon.graph_engine.entities.commands import GraphEngineCommand +from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError +from graphon.graph_events import ( + GraphRunFailedEvent, + GraphRunPausedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, +) +from graphon.runtime import ReadOnlyVariablePool +from graphon.variables.segments import Segment from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity @@ -13,17 +24,7 @@ from core.app.layers.pause_state_persist_layer import ( _AdvancedChatAppGenerateEntityWrapper, _WorkflowGenerateEntityWrapper, ) -from dify_graph.entities.pause_reason import SchedulingPause -from dify_graph.graph_engine.entities.commands import GraphEngineCommand -from dify_graph.graph_engine.layers.base import GraphEngineLayerNotInitializedError -from dify_graph.graph_events.graph import ( - GraphRunFailedEvent, - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, -) -from dify_graph.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool -from dify_graph.variables.segments import Segment +from core.workflow.system_variables import SystemVariableKey from models.model import AppMode from repositories.factory import DifyAPIRepositoryFactory @@ -51,17 +52,6 @@ class TestDataFactory: return GraphRunFailedEvent(error=error, exceptions_count=exceptions_count) -class MockSystemVariableReadOnlyView: - """Minimal read-only system variable view for testing.""" - - def __init__(self, workflow_execution_id: str | None = None) -> None: - self._workflow_execution_id = workflow_execution_id - - @property - def workflow_execution_id(self) -> str | None: - return self._workflow_execution_id - - class MockReadOnlyVariablePool: """Mock implementation of ReadOnlyVariablePool for testing.""" @@ -76,13 +66,14 @@ class MockReadOnlyVariablePool: return None mock_segment = Mock(spec=Segment) mock_segment.value = value + mock_segment.text = value if isinstance(value, str) else None return mock_segment def get_all_by_node(self, node_id: str) -> dict[str, object]: return {key: value for (nid, key), value in self._variables.items() if nid == node_id} def get_by_prefix(self, prefix: str) -> dict[str, object]: - return {f"{nid}.{key}": value for (nid, key), value in self._variables.items() if nid.startswith(prefix)} + return {key: value for (nid, key), value in self._variables.items() if nid == prefix} class MockReadOnlyGraphRuntimeState: @@ -105,12 +96,10 @@ class MockReadOnlyGraphRuntimeState: self._ready_queue_size = ready_queue_size self._exceptions_count = exceptions_count self._outputs = outputs or {} - self._variable_pool = MockReadOnlyVariablePool(variables) - self._system_variable = MockSystemVariableReadOnlyView(workflow_execution_id) - - @property - def system_variable(self) -> MockSystemVariableReadOnlyView: - return self._system_variable + resolved_variables = dict(variables or {}) + if workflow_execution_id is not None: + resolved_variables[("sys", SystemVariableKey.WORKFLOW_EXECUTION_ID.value)] = workflow_execution_id + self._variable_pool = MockReadOnlyVariablePool(resolved_variables) @property def variable_pool(self) -> ReadOnlyVariablePool: @@ -161,7 +150,9 @@ class MockReadOnlyGraphRuntimeState: "exceptions_count": self._exceptions_count, "outputs": self._outputs, "variables": {f"{k[0]}.{k[1]}": v for k, v in self._variable_pool._variables.items()}, - "workflow_execution_id": self._system_variable.workflow_execution_id, + "workflow_execution_id": self._variable_pool._variables.get( + ("sys", SystemVariableKey.WORKFLOW_EXECUTION_ID.value) + ), } ) diff --git a/api/tests/unit_tests/core/app/layers/test_suspend_layer.py b/api/tests/unit_tests/core/app/layers/test_suspend_layer.py new file mode 100644 index 00000000000..95931f4f8ba --- /dev/null +++ b/api/tests/unit_tests/core/app/layers/test_suspend_layer.py @@ -0,0 +1,20 @@ +from graphon.graph_events import GraphRunPausedEvent + +from core.app.layers.suspend_layer import SuspendLayer + + +class TestSuspendLayer: + def test_on_event_accepts_paused_event(self): + layer = SuspendLayer() + assert layer.is_paused() is False + layer.on_graph_start() + assert layer.is_paused() is False + layer.on_event(GraphRunPausedEvent()) + assert layer.is_paused() is True + + def test_on_event_ignores_other_events(self): + layer = SuspendLayer() + layer.on_graph_start() + initial_state = layer.is_paused() + layer.on_event(object()) + assert layer.is_paused() is initial_state diff --git a/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py b/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py new file mode 100644 index 00000000000..7cf6eb4f310 --- /dev/null +++ b/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py @@ -0,0 +1,99 @@ +from unittest.mock import Mock, patch + +from graphon.graph_engine.entities.commands import CommandType, GraphEngineCommand + +from core.app.layers.timeslice_layer import TimeSliceLayer +from services.workflow.entities import WorkflowScheduleCFSPlanEntity +from services.workflow.scheduler import SchedulerCommand + + +class TestTimeSliceLayer: + def test_init_starts_scheduler_when_not_running(self): + scheduler = Mock() + scheduler.running = False + + with patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler): + _ = TimeSliceLayer(cfs_plan_scheduler=Mock(plan=Mock())) + + scheduler.start.assert_called_once() + + def test_on_graph_start_adds_job_for_time_slice(self): + scheduler = Mock() + scheduler.running = True + plan = WorkflowScheduleCFSPlanEntity( + schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice, + granularity=3, + ) + cfs_plan_scheduler = Mock(plan=plan) + + with ( + patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler), + patch("core.app.layers.timeslice_layer.uuid.uuid4") as mock_uuid, + ): + mock_uuid.return_value.hex = "job-1" + layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler) + layer.on_graph_start() + + assert layer.schedule_id == "job-1" + scheduler.add_job.assert_called_once() + + def test_on_graph_end_removes_job(self): + scheduler = Mock() + scheduler.running = True + plan = WorkflowScheduleCFSPlanEntity( + schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice, + granularity=3, + ) + cfs_plan_scheduler = Mock(plan=plan) + + with patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler): + layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler) + layer.schedule_id = "job-1" + layer.on_graph_end(None) + + scheduler.remove_job.assert_called_once_with("job-1") + + def test_checker_job_removes_when_stopped(self): + scheduler = Mock() + scheduler.running = True + cfs_plan_scheduler = Mock(plan=Mock()) + + with patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler): + layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler) + layer.stopped = True + layer._checker_job("job-1") + + scheduler.remove_job.assert_called_once_with("job-1") + + def test_checker_job_handles_resource_limit_without_command_channel(self): + scheduler = Mock() + scheduler.running = True + cfs_plan_scheduler = Mock(plan=Mock()) + cfs_plan_scheduler.can_schedule.return_value = SchedulerCommand.RESOURCE_LIMIT_REACHED + + with ( + patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler), + patch("core.app.layers.timeslice_layer.logger") as mock_logger, + ): + layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler) + layer._checker_job("job-1") + + scheduler.remove_job.assert_called_once_with("job-1") + mock_logger.exception.assert_called_once() + + def test_checker_job_sends_pause_command(self): + scheduler = Mock() + scheduler.running = True + cfs_plan_scheduler = Mock(plan=Mock()) + cfs_plan_scheduler.can_schedule.return_value = SchedulerCommand.RESOURCE_LIMIT_REACHED + + with patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler): + layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler) + layer.command_channel = Mock() + layer._checker_job("job-1") + + scheduler.remove_job.assert_called_once_with("job-1") + layer.command_channel.send_command.assert_called_once() + sent_command = layer.command_channel.send_command.call_args[0][0] + assert isinstance(sent_command, GraphEngineCommand) + assert sent_command.command_type == CommandType.PAUSE diff --git a/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py new file mode 100644 index 00000000000..aa9285789b3 --- /dev/null +++ b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py @@ -0,0 +1,109 @@ +from datetime import UTC, datetime, timedelta +from types import SimpleNamespace +from unittest.mock import Mock, patch + +from graphon.graph_events import GraphRunFailedEvent, GraphRunSucceededEvent +from graphon.runtime import VariablePool + +from core.app.layers.trigger_post_layer import TriggerPostLayer +from core.workflow.system_variables import build_system_variables +from models.enums import WorkflowTriggerStatus + + +class TestTriggerPostLayer: + def test_on_event_updates_trigger_log(self): + trigger_log = SimpleNamespace( + status=None, + workflow_run_id=None, + outputs=None, + elapsed_time=None, + total_tokens=None, + finished_at=None, + ) + runtime_state = SimpleNamespace( + outputs={"answer": "ok"}, + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")), + total_tokens=12, + ) + + with ( + patch("core.app.layers.trigger_post_layer.session_factory") as mock_session_factory, + patch("core.app.layers.trigger_post_layer.SQLAlchemyWorkflowTriggerLogRepository") as mock_repo_cls, + patch("core.app.layers.trigger_post_layer.datetime") as mock_datetime, + ): + mock_datetime.now.return_value = datetime(2026, 2, 20, tzinfo=UTC) + + session = Mock() + mock_session_factory.create_session.return_value.__enter__.return_value = session + + repo = Mock() + repo.get_by_id.return_value = trigger_log + mock_repo_cls.return_value = repo + + layer = TriggerPostLayer( + cfs_plan_scheduler_entity=Mock(), + start_time=datetime(2026, 2, 20, tzinfo=UTC) - timedelta(seconds=10), + trigger_log_id="log-1", + ) + layer.initialize(runtime_state, Mock()) + + layer.on_event(GraphRunSucceededEvent()) + + assert trigger_log.status == WorkflowTriggerStatus.SUCCEEDED + assert trigger_log.workflow_run_id == "run-1" + assert trigger_log.outputs is not None + assert trigger_log.elapsed_time is not None + assert trigger_log.total_tokens == 12 + assert trigger_log.finished_at is not None + repo.update.assert_called_once_with(trigger_log) + session.commit.assert_called_once() + + def test_on_event_handles_missing_trigger_log(self): + runtime_state = SimpleNamespace( + outputs={}, + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")), + total_tokens=0, + ) + + with ( + patch("core.app.layers.trigger_post_layer.session_factory") as mock_session_factory, + patch("core.app.layers.trigger_post_layer.SQLAlchemyWorkflowTriggerLogRepository") as mock_repo_cls, + patch("core.app.layers.trigger_post_layer.logger") as mock_logger, + ): + session = Mock() + mock_session_factory.create_session.return_value.__enter__.return_value = session + + repo = Mock() + repo.get_by_id.return_value = None + mock_repo_cls.return_value = repo + + layer = TriggerPostLayer( + cfs_plan_scheduler_entity=Mock(), + start_time=datetime(2026, 2, 20, tzinfo=UTC), + trigger_log_id="missing", + ) + layer.initialize(runtime_state, Mock()) + + layer.on_event(GraphRunFailedEvent(error="boom")) + + mock_logger.exception.assert_called_once() + session.commit.assert_not_called() + + def test_on_event_ignores_non_status_events(self): + runtime_state = SimpleNamespace( + outputs={}, + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")), + total_tokens=0, + ) + + with patch("core.app.layers.trigger_post_layer.session_factory") as mock_session_factory: + layer = TriggerPostLayer( + cfs_plan_scheduler_entity=Mock(), + start_time=datetime(2026, 2, 20, tzinfo=UTC), + trigger_log_id="log-1", + ) + layer.initialize(runtime_state, Mock()) + + layer.on_event(Mock()) + + mock_session_factory.create_session.assert_not_called() diff --git a/api/dify_graph/model_runtime/utils/__init__.py b/api/tests/unit_tests/core/app/task_pipeline/__init__.py similarity index 100% rename from api/dify_graph/model_runtime/utils/__init__.py rename to api/tests/unit_tests/core/app/task_pipeline/__init__.py diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py b/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py new file mode 100644 index 00000000000..58aa7d74782 --- /dev/null +++ b/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py @@ -0,0 +1,91 @@ +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError + +from core.app.entities.queue_entities import QueueErrorEvent +from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline +from core.errors.error import QuotaExceededError +from models.enums import MessageStatus + + +class TestBasedGenerateTaskPipeline: + @pytest.fixture + def pipeline(self): + app_config = SimpleNamespace( + tenant_id="tenant-1", + app_id="app-1", + sensitive_word_avoidance=None, + ) + app_generate_entity = SimpleNamespace(task_id="task-1", app_config=app_config) + return BasedGenerateTaskPipeline( + application_generate_entity=app_generate_entity, + queue_manager=Mock(), + stream=True, + ) + + def test_error_to_desc_quota_exceeded(self, pipeline): + message = pipeline._error_to_desc(QuotaExceededError()) + assert "quota" in message.lower() + + def test_handle_error_wraps_invoke_authorization(self, pipeline): + event = QueueErrorEvent(error=InvokeAuthorizationError()) + err = pipeline.handle_error(event=event) + assert isinstance(err, InvokeAuthorizationError) + assert str(err) == "Incorrect API key provided" + + def test_handle_error_preserves_invoke_error(self, pipeline): + event = QueueErrorEvent(error=InvokeError("bad")) + err = pipeline.handle_error(event=event) + assert err is event.error + + def test_handle_error_updates_message_when_found(self, pipeline): + event = QueueErrorEvent(error=ValueError("oops")) + message = SimpleNamespace(status=MessageStatus.NORMAL, error=None) + session = Mock() + session.scalar.return_value = message + + err = pipeline.handle_error(event=event, session=session, message_id="msg-1") + + assert err is event.error + assert message.status == MessageStatus.ERROR + assert message.error == "oops" + + def test_handle_error_returns_err_when_message_missing(self, pipeline): + event = QueueErrorEvent(error=ValueError("oops")) + session = Mock() + session.scalar.return_value = None + + err = pipeline.handle_error(event=event, session=session, message_id="msg-1") + + assert err is event.error + + def test_error_to_stream_response_and_ping(self, pipeline): + error_response = pipeline.error_to_stream_response(ValueError("boom")) + ping_response = pipeline.ping_stream_response() + + assert error_response.task_id == "task-1" + assert ping_response.task_id == "task-1" + + def test_handle_output_moderation_when_flagged(self, pipeline): + handler = Mock() + handler.moderation_completion.return_value = ("filtered", True) + pipeline.output_moderation_handler = handler + + result = pipeline.handle_output_moderation_when_task_finished("raw") + + assert result == "filtered" + handler.stop_thread.assert_called_once() + assert pipeline.output_moderation_handler is None + + def test_handle_output_moderation_when_not_flagged(self, pipeline): + handler = Mock() + handler.moderation_completion.return_value = ("safe", False) + pipeline.output_moderation_handler = handler + + result = pipeline.handle_output_moderation_when_task_finished("raw") + + assert result is None + handler.stop_thread.assert_called_once() + assert pipeline.output_moderation_handler is None diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py index 13fbca6e261..4aaa10a81af 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py @@ -2,6 +2,8 @@ from types import SimpleNamespace from unittest.mock import ANY, Mock, patch import pytest +from graphon.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult +from graphon.model_runtime.entities.message_entities import TextPromptMessageContent from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import ChatAppGenerateEntity @@ -26,8 +28,6 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.base.tts import AppGeneratorTTSPublisher from core.ops.ops_trace_manager import TraceQueueManager -from dify_graph.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult -from dify_graph.model_runtime.entities.message_entities import TextPromptMessageContent from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py new file mode 100644 index 00000000000..f7e7b7e20ef --- /dev/null +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py @@ -0,0 +1,1228 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest +from graphon.file import FileTransferMethod +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, TextPromptMessageContent + +from core.app.app_config.entities import ( + AppAdditionalFeatures, + EasyUIBasedAppConfig, + EasyUIBasedAppModelConfigFrom, + ModelConfigEntity, + PromptTemplateEntity, +) +from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, CompletionAppGenerateEntity, InvokeFrom +from core.app.entities.queue_entities import ( + QueueAgentMessageEvent, + QueueAgentThoughtEvent, + QueueAnnotationReplyEvent, + QueueErrorEvent, + QueueLLMChunkEvent, + QueueMessageEndEvent, + QueueMessageFileEvent, + QueueMessageReplaceEvent, + QueuePingEvent, + QueueRetrieverResourcesEvent, + QueueStopEvent, +) +from core.app.entities.task_entities import ( + ChatbotAppStreamResponse, + CompletionAppStreamResponse, + ErrorStreamResponse, + MessageAudioEndStreamResponse, + MessageAudioStreamResponse, + MessageEndStreamResponse, + PingStreamResponse, +) +from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline +from core.base.tts import AudioTrunk +from models.model import AppMode + + +class _DummyModelConf: + def __init__(self) -> None: + self.model = "mock" + + +def _make_app_config(app_mode: AppMode) -> EasyUIBasedAppConfig: + return EasyUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=app_mode, + app_model_config_from=EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG, + app_model_config_id="model-config", + app_model_config_dict={}, + model=ModelConfigEntity(provider="mock", model="mock"), + prompt_template=PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + simple_prompt_template="hi", + ), + additional_features=AppAdditionalFeatures(), + variables=[], + ) + + +def _make_entity(entity_cls, app_mode: AppMode): + app_config = _make_app_config(app_mode) + return entity_cls.model_construct( + task_id="task", + app_config=app_config, + model_conf=_DummyModelConf(), + file_upload_config=None, + conversation_id=None, + inputs={}, + query="hello", + files=[], + parent_message_id=None, + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + call_depth=0, + trace_manager=None, + ) + + +class TestEasyUiBasedGenerateTaskPipeline: + def test_to_blocking_response_chat(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + pipeline._task_state.llm_result.message.content = "answer" + + def _gen(): + yield MessageEndStreamResponse(task_id="task", id="msg") + + response = pipeline._to_blocking_response(_gen()) + + assert response.data.answer == "answer" + + def test_to_blocking_response_completion(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.COMPLETION) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(CompletionAppGenerateEntity, AppMode.COMPLETION), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + pipeline._task_state.llm_result.message.content = "answer" + + def _gen(): + yield MessageEndStreamResponse(task_id="task", id="msg") + + response = pipeline._to_blocking_response(_gen()) + + assert response.data.answer == "answer" + + def test_listen_audio_msg_returns_none_when_no_publisher(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + + assert pipeline._listen_audio_msg(publisher=None, task_id="task") is None + + def test_process_stream_response_handles_chunks_and_end(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + chunk = LLMResultChunk( + model="mock", + prompt_messages=[], + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage( + content=[TextPromptMessageContent(data="hi"), TextPromptMessageContent(data="yo")] + ), + ), + ) + llm_result = LLMResult( + model="mock", + prompt_messages=[], + message=AssistantPromptMessage(content="done"), + usage=LLMUsage.empty_usage(), + ) + + events = [ + SimpleNamespace(event=QueueLLMChunkEvent(chunk=chunk)), + SimpleNamespace(event=QueueMessageReplaceEvent(text="replace", reason="output_moderation")), + SimpleNamespace(event=QueuePingEvent()), + SimpleNamespace(event=QueueMessageEndEvent(llm_result=llm_result)), + ] + + pipeline.queue_manager.listen = lambda: iter(events) + pipeline._message_cycle_manager.get_message_event_type = lambda message_id: None + pipeline._message_cycle_manager.message_to_stream_response = lambda **kwargs: "chunk" + pipeline._message_cycle_manager.message_replace_to_stream_response = lambda **kwargs: "replace" + pipeline.handle_output_moderation_when_task_finished = lambda completion: None + pipeline._message_end_to_stream_response = lambda: "end" + pipeline._save_message = lambda **kwargs: None + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def commit(self): + return None + + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", + _Session, + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + responses = list(pipeline._process_stream_response(publisher=None)) + + assert "chunk" in responses + assert "replace" in responses + assert any(isinstance(item, PingStreamResponse) for item in responses) + assert responses[-1] == "end" + + def test_handle_output_moderation_chunk_directs_output(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + events: list[object] = [] + + class _Moderation: + def should_direct_output(self): + return True + + def get_final_output(self): + return "final" + + pipeline.output_moderation_handler = _Moderation() + pipeline.queue_manager.publish = lambda event, publish_from: events.append(event) + + result = pipeline._handle_output_moderation_chunk("token") + + assert result is True + assert any(isinstance(event, QueueLLMChunkEvent) for event in events) + assert any(isinstance(event, QueueStopEvent) for event in events) + + def test_handle_stop_updates_usage(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + class _ModelType: + def calc_response_usage(self, model, credentials, prompt_tokens, completion_tokens): + return LLMUsage.from_metadata( + { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + } + ) + + class _ModelConf: + def __init__(self) -> None: + self.model = "mock" + self.credentials = {} + self.provider_model_bundle = SimpleNamespace(model_type_instance=_ModelType()) + + app_config = _make_app_config(AppMode.CHAT) + application_generate_entity = ChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + model_conf=_ModelConf(), + file_upload_config=None, + conversation_id=None, + inputs={}, + query="hello", + files=[], + parent_message_id=None, + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + call_depth=0, + trace_manager=None, + ) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + pipeline._task_state.llm_result.prompt_messages = [AssistantPromptMessage(content="prompt")] + pipeline._task_state.llm_result.message = AssistantPromptMessage(content="answer") + + calls: list[int] = [] + + class _FakeModelInstance: + def __init__(self, provider_model_bundle, model): + pass + + def get_llm_num_tokens(self, messages): + calls.append(1) + return 10 if len(calls) == 1 else 5 + + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.ModelInstance", + _FakeModelInstance, + ) + + pipeline._handle_stop(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL)) + + assert pipeline._task_state.llm_result.usage.prompt_tokens == 10 + assert pipeline._task_state.llm_result.usage.completion_tokens == 5 + + def test_record_files_builds_file_payloads(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + + message_files = [ + SimpleNamespace( + id="mf-1", + message_id="msg", + transfer_method=FileTransferMethod.REMOTE_URL, + url="http://example.com/a.png", + upload_file_id=None, + type="image", + ), + SimpleNamespace( + id="mf-2", + message_id="msg", + transfer_method=FileTransferMethod.LOCAL_FILE, + url="", + upload_file_id="upload-1", + type="image", + ), + SimpleNamespace( + id="mf-3", + message_id="msg", + transfer_method=FileTransferMethod.TOOL_FILE, + url="tool/file.bin", + upload_file_id=None, + type="file", + ), + ] + upload_files = [ + SimpleNamespace( + id="upload-1", + name="local.png", + mime_type="image/png", + size=123, + extension="png", + ) + ] + + class _Result: + def __init__(self, items): + self._items = items + + def all(self): + return self._items + + class _Session: + def __init__(self, *args, **kwargs): + self.calls = 0 + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def scalars(self, *args, **kwargs): + self.calls += 1 + return _Result(message_files if self.calls == 1 else upload_files) + + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", + _Session, + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + monkeypatch.setattr( + "core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url", + lambda **kwargs: "signed-url", + ) + monkeypatch.setattr( + "core.app.task_pipeline.message_file_utils.sign_tool_file", + lambda **kwargs: "signed-tool", + ) + + response = pipeline._message_end_to_stream_response() + files = response.files + + assert files + assert len(files) == 3 + + def test_process_stream_response_handles_annotation_and_error(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + agent_chunk = LLMResultChunk( + model="mock", + prompt_messages=[], + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content="agent"), + ), + ) + + events = [ + SimpleNamespace(event=QueueAnnotationReplyEvent(message_annotation_id="ann")), + SimpleNamespace(event=QueueAgentThoughtEvent(agent_thought_id="thought")), + SimpleNamespace(event=QueueMessageFileEvent(message_file_id="file")), + SimpleNamespace(event=QueueAgentMessageEvent(chunk=agent_chunk)), + SimpleNamespace(event=QueueErrorEvent(error=ValueError("boom"))), + ] + + pipeline.queue_manager.listen = lambda: iter(events) + pipeline._message_cycle_manager.handle_annotation_reply = lambda event: SimpleNamespace(content="annotated") + pipeline._agent_thought_to_stream_response = lambda event: "thought" + pipeline._message_cycle_manager.message_file_to_stream_response = lambda event: "file" + pipeline._agent_message_to_stream_response = lambda **kwargs: "agent" + pipeline.handle_error = lambda **kwargs: ValueError("boom") + pipeline.error_to_stream_response = lambda err: err + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def commit(self): + return None + + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", + _Session, + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + responses = list(pipeline._process_stream_response(publisher=None)) + + assert "thought" in responses + assert "file" in responses + assert "agent" in responses + assert isinstance(responses[-1], ValueError) + assert pipeline._task_state.llm_result.message.content == "annotatedagent" + + def test_agent_thought_to_stream_response_returns_payload(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + agent_thought = SimpleNamespace( + id="thought", + position=1, + thought="t", + observation="o", + tool="tool", + tool_labels={}, + tool_input="input", + files=[], + ) + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def query(self, *args, **kwargs): + return self + + def where(self, *args, **kwargs): + return self + + def first(self): + return agent_thought + + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", + _Session, + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + response = pipeline._agent_thought_to_stream_response(QueueAgentThoughtEvent(agent_thought_id="thought")) + + assert response is not None + assert response.id == "thought" + + def test_process_routes_to_stream_and_starts_conversation_name_generation(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + pipeline._message_cycle_manager.generate_conversation_name = Mock(return_value=object()) + pipeline._wrapper_process_stream_response = lambda trace_manager: iter(["payload"]) + pipeline._to_stream_response = lambda generator: "streamed" + + result = pipeline.process() + + assert result == "streamed" + pipeline._message_cycle_manager.generate_conversation_name.assert_called_once_with( + conversation_id="conv", query="hello" + ) + + def test_process_routes_to_blocking_for_completion_mode(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.COMPLETION) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(CompletionAppGenerateEntity, AppMode.COMPLETION), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + pipeline._message_cycle_manager.generate_conversation_name = Mock() + pipeline._wrapper_process_stream_response = lambda trace_manager: iter(["payload"]) + pipeline._to_blocking_response = lambda generator: "blocking" + + result = pipeline.process() + + assert result == "blocking" + pipeline._message_cycle_manager.generate_conversation_name.assert_not_called() + + def test_to_blocking_response_raises_error_stream_exception(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + + def _gen(): + yield ErrorStreamResponse(task_id="task", err=ValueError("stream error")) + + with pytest.raises(ValueError, match="stream error"): + pipeline._to_blocking_response(_gen()) + + def test_to_blocking_response_raises_when_generator_ends_without_message_end(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + + def _gen(): + yield PingStreamResponse(task_id="task") + + with pytest.raises(RuntimeError, match="queue listening stopped unexpectedly"): + pipeline._to_blocking_response(_gen()) + + def test_to_stream_response_wraps_completion_stream_events(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.COMPLETION) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(CompletionAppGenerateEntity, AppMode.COMPLETION), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + def _gen(): + yield PingStreamResponse(task_id="task") + + response = list(pipeline._to_stream_response(_gen()))[0] + + assert isinstance(response, CompletionAppStreamResponse) + assert response.message_id == "msg" + + def test_to_stream_response_wraps_chat_stream_events(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + def _gen(): + yield PingStreamResponse(task_id="task") + + response = list(pipeline._to_stream_response(_gen()))[0] + + assert isinstance(response, ChatbotAppStreamResponse) + assert response.conversation_id == "conv" + + def test_listen_audio_msg_returns_audio_response_for_non_finish_audio(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + publisher = SimpleNamespace(check_and_get_audio=lambda: AudioTrunk("responding", "abc")) + + response = pipeline._listen_audio_msg(publisher=publisher, task_id="task") + + assert isinstance(response, MessageAudioStreamResponse) + assert response.audio == "abc" + + def test_listen_audio_msg_returns_none_for_finish_audio(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + publisher = SimpleNamespace(check_and_get_audio=lambda: AudioTrunk("finish", "abc")) + + assert pipeline._listen_audio_msg(publisher=publisher, task_id="task") is None + + def test_wrapper_process_stream_response_without_tts_publisher(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + pipeline._process_stream_response = lambda publisher, trace_manager: iter(["payload"]) + + responses = list(pipeline._wrapper_process_stream_response()) + + assert responses == ["payload"] + + def test_wrapper_process_stream_response_with_tts_publisher(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + entity = _make_entity(ChatAppGenerateEntity, AppMode.CHAT) + entity.app_config.app_model_config_dict = { + "text_to_speech": {"autoPlay": "enabled", "enabled": True, "voice": "v", "language": "en"} + } + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=entity, + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + class _Publisher: + def check_and_get_audio(self): + return AudioTrunk("finish", "") + + inline_audio = MessageAudioStreamResponse(task_id="task", audio="inline") + audio_calls = iter([inline_audio, None]) + pipeline._listen_audio_msg = lambda publisher, task_id: next(audio_calls) + pipeline._process_stream_response = lambda publisher, trace_manager: iter(["payload"]) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.AppGeneratorTTSPublisher", + lambda tenant_id, voice, language: _Publisher(), + ) + + responses = list(pipeline._wrapper_process_stream_response()) + + assert responses[0] == inline_audio + assert responses[1] == "payload" + assert isinstance(responses[-1], MessageAudioEndStreamResponse) + + def test_wrapper_process_stream_response_timeout_yields_audio_chunk(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + entity = _make_entity(ChatAppGenerateEntity, AppMode.CHAT) + entity.app_config.app_model_config_dict = { + "text_to_speech": {"autoPlay": "enabled", "enabled": True, "voice": "v", "language": "en"} + } + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=entity, + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + class _Publisher: + def __init__(self): + self._events = iter([None, AudioTrunk("responding", "later"), AudioTrunk("finish", "")]) + + def check_and_get_audio(self): + return next(self._events) + + clock = {"value": 0.0} + + def _fake_time(): + clock["value"] += 0.1 + return clock["value"] + + pipeline._process_stream_response = lambda publisher, trace_manager: iter([]) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.AppGeneratorTTSPublisher", + lambda tenant_id, voice, language: _Publisher(), + ) + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.time.time", _fake_time) + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.time.sleep", lambda _: None) + + responses = list(pipeline._wrapper_process_stream_response()) + + assert any(isinstance(item, MessageAudioStreamResponse) for item in responses) + assert isinstance(responses[-1], MessageAudioEndStreamResponse) + + def test_process_stream_response_handles_stop_event_and_output_replacement(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + pipeline._task_state.llm_result.message.content = "raw answer" + pipeline.queue_manager.listen = lambda: iter( + [SimpleNamespace(event=QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL))] + ) + pipeline._handle_stop = Mock() + pipeline.handle_output_moderation_when_task_finished = lambda answer: "moderated answer" + pipeline._message_cycle_manager.message_replace_to_stream_response = lambda answer: f"replace:{answer}" + pipeline._save_message = lambda **kwargs: None + pipeline._message_end_to_stream_response = lambda: "end" + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def commit(self): + return None + + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + responses = list(pipeline._process_stream_response(publisher=None)) + + assert responses == ["replace:moderated answer", "end"] + pipeline._handle_stop.assert_called_once() + + def test_process_stream_response_handles_retriever_unknown_and_empty_chunk(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + retriever_event = QueueRetrieverResourcesEvent(retriever_resources=[]) + chunk = LLMResultChunk( + model="mock", + prompt_messages=[], + delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=None)), + ) + handled = {"retriever": 0} + + def _handle_retriever_resources(event): + handled["retriever"] += 1 + + pipeline._message_cycle_manager.handle_retriever_resources = _handle_retriever_resources + pipeline.queue_manager.listen = lambda: iter( + [ + SimpleNamespace(event=retriever_event), + SimpleNamespace(event=SimpleNamespace()), + SimpleNamespace(event=QueueLLMChunkEvent(chunk=chunk)), + ] + ) + + responses = list(pipeline._process_stream_response(publisher=None)) + + assert responses == [] + assert handled["retriever"] == 1 + + def test_process_stream_response_skips_when_output_moderation_directs_chunk(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + chunk = LLMResultChunk( + model="mock", + prompt_messages=[], + delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content="x")), + ) + pipeline._handle_output_moderation_chunk = lambda text: True + pipeline.queue_manager.listen = lambda: iter([SimpleNamespace(event=QueueLLMChunkEvent(chunk=chunk))]) + + responses = list(pipeline._process_stream_response(publisher=None)) + + assert responses == [] + + def test_process_stream_response_ignores_unsupported_chunk_content_types(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + chunk = SimpleNamespace( + prompt_messages=[], + delta=SimpleNamespace(message=SimpleNamespace(content=[object(), "ok"])), + ) + pipeline._message_cycle_manager.get_message_event_type = lambda message_id: None + pipeline._message_cycle_manager.message_to_stream_response = lambda **kwargs: kwargs["answer"] + pipeline.queue_manager.listen = lambda: iter( + [SimpleNamespace(event=QueueLLMChunkEvent.model_construct(chunk=chunk))] + ) + + responses = list(pipeline._process_stream_response(publisher=None)) + + assert responses == ["ok"] + + def test_process_stream_response_reaches_post_loop_branch_with_thread_reference(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + pipeline._conversation_name_generate_thread = object() + pipeline.queue_manager.listen = lambda: iter([]) + + assert list(pipeline._process_stream_response(publisher=None)) == [] + + def test_save_message_persists_fields_and_emits_trace(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + pipeline.start_at = 10.0 + pipeline._model_config = SimpleNamespace(mode="chat") + pipeline._task_state.llm_result.prompt_messages = [AssistantPromptMessage(content="prompt")] + pipeline._task_state.llm_result.message = AssistantPromptMessage(content=" {{name}} hello ") + pipeline._task_state.llm_result.usage = LLMUsage.from_metadata( + {"prompt_tokens": 3, "completion_tokens": 5, "total_price": "1.23"} + ) + + message_obj = SimpleNamespace(id="msg") + conversation_obj = SimpleNamespace(id="conv") + session = Mock() + session.scalar.side_effect = [message_obj, conversation_obj] + trace_manager = SimpleNamespace(add_trace_task=Mock()) + sent_payloads: list[tuple[tuple[object, ...], dict[str, object]]] = [] + + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.PromptMessageUtil.prompt_messages_to_prompt_for_saving", + lambda mode, prompt_messages: "serialized-prompt", + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.PromptTemplateParser.remove_template_variables", + lambda text: text.replace("{{name}}", "").strip(), + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.naive_utc_now", + lambda: datetime(2024, 1, 1, tzinfo=UTC), + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.time.perf_counter", lambda: 15.0 + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.message_was_created.send", + lambda *args, **kwargs: sent_payloads.append((args, kwargs)), + ) + + pipeline._save_message(session=session, trace_manager=trace_manager) + + assert message_obj.message == "serialized-prompt" + assert message_obj.answer == "hello" + assert message_obj.provider_response_latency == 5.0 + assert trace_manager.add_trace_task.called + assert len(sent_payloads) == 1 + + def test_save_message_raises_when_message_not_found(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + session = Mock() + session.scalar.return_value = None + + with pytest.raises(ValueError, match="message msg not found"): + pipeline._save_message(session=session) + + def test_save_message_raises_when_conversation_not_found(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + session = Mock() + session.scalar.side_effect = [SimpleNamespace(id="msg"), None] + + with pytest.raises(ValueError, match="Conversation conv not found"): + pipeline._save_message(session=session) + + def test_message_end_to_stream_response_includes_usage_metadata(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + pipeline._task_state.llm_result.usage = LLMUsage.from_metadata({"prompt_tokens": 1, "completion_tokens": 2}) + + class _Result: + def all(self): + return [] + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def scalars(self, *args, **kwargs): + return _Result() + + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + response = pipeline._message_end_to_stream_response() + + assert response.id == "msg" + assert response.metadata["usage"]["prompt_tokens"] == 1 + + def test_record_files_returns_none_when_message_has_no_files(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + + class _Result: + def all(self): + return [] + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def scalars(self, *args, **kwargs): + return _Result() + + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + response = pipeline._message_end_to_stream_response() + + assert response.files is None + + def test_record_files_handles_local_fallback_and_tool_url_variants(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + message_files = [ + SimpleNamespace( + id="mf-local-fallback", + message_id="msg", + transfer_method=FileTransferMethod.LOCAL_FILE, + url="", + upload_file_id="upload-missing", + type="file", + ), + SimpleNamespace( + id="mf-tool-http", + message_id="msg", + transfer_method=FileTransferMethod.TOOL_FILE, + url="http://cdn.example.com/file.txt?x=1", + upload_file_id=None, + type="file", + ), + SimpleNamespace( + id="mf-tool-noext", + message_id="msg", + transfer_method=FileTransferMethod.TOOL_FILE, + url="tool/path/toolid", + upload_file_id=None, + type="file", + ), + ] + + class _Result: + def __init__(self, items): + self._items = items + + def all(self): + return self._items + + class _Session: + def __init__(self, *args, **kwargs): + self.calls = 0 + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def scalars(self, *args, **kwargs): + self.calls += 1 + return _Result(message_files if self.calls == 1 else []) + + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + monkeypatch.setattr( + "core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url", + lambda **kwargs: "local-fallback-signed", + ) + monkeypatch.setattr( + "core.app.task_pipeline.message_file_utils.sign_tool_file", + lambda **kwargs: "tool-signed", + ) + + response = pipeline._message_end_to_stream_response() + files = response.files + + assert files is not None + assert files[0]["url"] == "local-fallback-signed" + assert files[1]["filename"] == "file.txt" + assert files[2]["extension"] == ".bin" + + def test_agent_message_to_stream_response_builds_payload(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + response = pipeline._agent_message_to_stream_response(answer="hello", message_id="msg") + + assert response.id == "msg" + assert response.answer == "hello" + + def test_agent_thought_to_stream_response_returns_none_when_not_found(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def query(self, *args, **kwargs): + return self + + def where(self, *args, **kwargs): + return self + + def first(self): + return None + + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + response = pipeline._agent_thought_to_stream_response(QueueAgentThoughtEvent(agent_thought_id="missing")) + + assert response is None + + def test_handle_output_moderation_chunk_appends_token_when_not_directing(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + appended_tokens: list[str] = [] + + class _Moderation: + def should_direct_output(self): + return False + + def append_new_token(self, text): + appended_tokens.append(text) + + pipeline.output_moderation_handler = _Moderation() + + result = pipeline._handle_output_moderation_chunk("next-token") + + assert result is False + assert appended_tokens == ["next-token"] diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py index 582990c88a9..31b73130666 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py @@ -17,11 +17,11 @@ import uuid from unittest.mock import MagicMock, Mock, patch import pytest +from graphon.file import FileTransferMethod, FileType from sqlalchemy.orm import Session from core.app.entities.task_entities import MessageEndStreamResponse from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline -from dify_graph.file.enums import FileTransferMethod from models.model import MessageFile, UploadFile @@ -51,7 +51,7 @@ class TestMessageEndStreamResponseFiles: message_file.transfer_method = FileTransferMethod.LOCAL_FILE message_file.upload_file_id = str(uuid.uuid4()) message_file.url = None - message_file.type = "image" + message_file.type = FileType.IMAGE return message_file @pytest.fixture @@ -63,7 +63,7 @@ class TestMessageEndStreamResponseFiles: message_file.transfer_method = FileTransferMethod.REMOTE_URL message_file.upload_file_id = None message_file.url = "https://example.com/image.jpg" - message_file.type = "image" + message_file.type = FileType.IMAGE return message_file @pytest.fixture @@ -75,7 +75,7 @@ class TestMessageEndStreamResponseFiles: message_file.transfer_method = FileTransferMethod.TOOL_FILE message_file.upload_file_id = None message_file.url = "tool_file_123.png" - message_file.type = "image" + message_file.type = FileType.IMAGE return message_file @pytest.fixture diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_exc.py b/api/tests/unit_tests/core/app/task_pipeline/test_exc.py new file mode 100644 index 00000000000..9ea7e96e73d --- /dev/null +++ b/api/tests/unit_tests/core/app/task_pipeline/test_exc.py @@ -0,0 +1,11 @@ +from core.app.task_pipeline.exc import RecordNotFoundError, WorkflowRunNotFoundError + + +class TestTaskPipelineExceptions: + def test_record_not_found_error_message(self): + err = RecordNotFoundError("Message", "msg-1") + assert str(err) == "Message with id msg-1 not found" + + def test_workflow_run_not_found_error_message(self): + err = WorkflowRunNotFoundError("run-1") + assert str(err) == "WorkflowRun with id run-1 not found" diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py index c0c636715df..07ee75ed35f 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py @@ -1,12 +1,16 @@ """Unit tests for the message cycle manager optimization.""" +from types import SimpleNamespace from unittest.mock import Mock, patch import pytest -from flask import current_app +from flask import Flask, current_app -from core.app.entities.task_entities import MessageStreamResponse, StreamEvent +from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueRetrieverResourcesEvent +from core.app.entities.task_entities import MessageStreamResponse, StreamEvent, TaskStateMetadata from core.app.task_pipeline.message_cycle_manager import MessageCycleManager +from core.rag.entities.citation_metadata import RetrievalSourceMetadata +from models.model import AppMode class TestMessageCycleManagerOptimization: @@ -90,6 +94,16 @@ class TestMessageCycleManagerOptimization: assert result == StreamEvent.MESSAGE mock_session.scalar.assert_called_once() + def test_get_message_event_type_uses_cache_without_query(self, message_cycle_manager): + """Return MESSAGE_FILE directly from in-memory cache without opening a DB session.""" + message_cycle_manager._message_has_file.add("cached-message") + + with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: + result = message_cycle_manager.get_message_event_type("cached-message") + + assert result == StreamEvent.MESSAGE_FILE + mock_session_factory.create_session.assert_not_called() + def test_message_to_stream_response_with_precomputed_event_type(self, message_cycle_manager): """MessageCycleManager.message_to_stream_response expects a valid event_type; callers should precompute it.""" with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: @@ -180,3 +194,390 @@ class TestMessageCycleManagerOptimization: assert chunk2_response.event == StreamEvent.MESSAGE assert chunk1_response.answer == "Chunk 1" assert chunk2_response.answer == "Chunk 2" + + def test_generate_conversation_name_returns_none_for_completion(self, message_cycle_manager): + """Return None when completion entities are used for conversation naming. + + Args: message_cycle_manager with DummyCompletion injected as CompletionAppGenerateEntity. + Returns: None, indicating no name generation for completion apps. + Side effects: None expected. + """ + + class DummyCompletion: + pass + + with patch("core.app.task_pipeline.message_cycle_manager.CompletionAppGenerateEntity", DummyCompletion): + message_cycle_manager._application_generate_entity = DummyCompletion() + result = message_cycle_manager.generate_conversation_name(conversation_id="c1", query="hi") + + assert result is None + + def test_generate_conversation_name_starts_thread_and_flips_first_message_flag(self, message_cycle_manager): + """Spawn background generation thread for the first chat message.""" + message_cycle_manager._application_generate_entity.is_new_conversation = True + message_cycle_manager._application_generate_entity.extras = {"auto_generate_conversation_name": True} + flask_app = object() + + class DummyTimer: + def __init__(self, interval, function, args=None, kwargs=None): + self.interval = interval + self.function = function + self.args = args or [] + self.kwargs = kwargs + self.daemon = False + self.started = False + + def start(self): + self.started = True + + with ( + patch( + "core.app.task_pipeline.message_cycle_manager.current_app", + new=SimpleNamespace(_get_current_object=lambda: flask_app), + ), + patch("core.app.task_pipeline.message_cycle_manager.Timer", DummyTimer), + ): + thread = message_cycle_manager.generate_conversation_name(conversation_id="conv-1", query="hello") + + assert isinstance(thread, DummyTimer) + assert thread.interval == 1 + assert thread.function == message_cycle_manager._generate_conversation_name_worker + assert thread.started is True + assert thread.daemon is True + assert thread.kwargs["flask_app"] is flask_app + assert thread.kwargs["conversation_id"] == "conv-1" + assert thread.kwargs["query"] == "hello" + assert message_cycle_manager._application_generate_entity.is_new_conversation is False + + def test_generate_conversation_name_skips_thread_when_auto_generate_disabled(self, message_cycle_manager): + """Skip thread creation when auto naming is disabled but still mark conversation as not new.""" + message_cycle_manager._application_generate_entity.is_new_conversation = True + message_cycle_manager._application_generate_entity.extras = {"auto_generate_conversation_name": False} + + with patch("core.app.task_pipeline.message_cycle_manager.Timer") as mock_timer: + result = message_cycle_manager.generate_conversation_name(conversation_id="conv-2", query="hello") + + assert result is None + assert message_cycle_manager._application_generate_entity.is_new_conversation is False + mock_timer.assert_not_called() + + def test_generate_conversation_name_worker_returns_when_conversation_missing(self, message_cycle_manager): + """Return early when the conversation cannot be found.""" + flask_app = Flask(__name__) + db_session = Mock() + db_session.scalar.return_value = None + + with patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db: + mock_db.session = db_session + message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-missing", "hello") + + db_session.commit.assert_not_called() + db_session.close.assert_not_called() + + def test_generate_conversation_name_worker_returns_when_app_missing(self, message_cycle_manager): + """Return early when non-completion conversation has no app relation.""" + flask_app = Flask(__name__) + conversation = SimpleNamespace(mode=AppMode.CHAT, app=None, app_id="app-id") + db_session = Mock() + db_session.scalar.return_value = conversation + + with patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db: + mock_db.session = db_session + message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-1", "hello") + + db_session.commit.assert_not_called() + db_session.close.assert_not_called() + + def test_generate_conversation_name_worker_uses_cached_name(self, message_cycle_manager): + """Use cached conversation name when present and avoid LLM call.""" + flask_app = Flask(__name__) + conversation = SimpleNamespace( + mode=AppMode.CHAT, + app=SimpleNamespace(tenant_id="tenant-1"), + app_id="app-id", + name="", + ) + db_session = Mock() + db_session.scalar.return_value = conversation + + with ( + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + patch("core.app.task_pipeline.message_cycle_manager.redis_client") as mock_redis, + patch("core.app.task_pipeline.message_cycle_manager.LLMGenerator") as mock_llm_generator, + ): + mock_db.session = db_session + mock_redis.get.return_value = b"cached-title" + + message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-1", "hello") + + assert conversation.name == "cached-title" + db_session.commit.assert_called_once() + db_session.close.assert_called_once() + mock_llm_generator.generate_conversation_name.assert_not_called() + mock_redis.setex.assert_not_called() + + def test_generate_conversation_name_worker_generates_and_caches_name(self, message_cycle_manager): + """Generate conversation name and write it to redis cache on cache miss.""" + flask_app = Flask(__name__) + conversation = SimpleNamespace( + mode=AppMode.CHAT, + app=SimpleNamespace(tenant_id="tenant-1"), + app_id="app-id", + name="", + ) + db_session = Mock() + db_session.scalar.return_value = conversation + + with ( + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + patch("core.app.task_pipeline.message_cycle_manager.redis_client") as mock_redis, + patch("core.app.task_pipeline.message_cycle_manager.LLMGenerator") as mock_llm_generator, + ): + mock_db.session = db_session + mock_redis.get.return_value = None + mock_llm_generator.generate_conversation_name.return_value = "generated-title" + + message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-1", "hello") + + assert conversation.name == "generated-title" + db_session.commit.assert_called_once() + db_session.close.assert_called_once() + mock_redis.setex.assert_called_once() + + def test_generate_conversation_name_worker_falls_back_when_generation_fails(self, message_cycle_manager): + """Fallback to truncated query when LLM generation fails.""" + flask_app = Flask(__name__) + conversation = SimpleNamespace( + mode=AppMode.CHAT, + app=SimpleNamespace(tenant_id="tenant-1"), + app_id="app-id", + name="", + ) + db_session = Mock() + db_session.scalar.return_value = conversation + long_query = "q" * 60 + + with ( + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + patch("core.app.task_pipeline.message_cycle_manager.redis_client") as mock_redis, + patch("core.app.task_pipeline.message_cycle_manager.LLMGenerator") as mock_llm_generator, + patch("core.app.task_pipeline.message_cycle_manager.dify_config") as mock_dify_config, + patch("core.app.task_pipeline.message_cycle_manager.logger") as mock_logger, + ): + mock_db.session = db_session + mock_redis.get.return_value = None + mock_llm_generator.generate_conversation_name.side_effect = RuntimeError("generation failed") + mock_dify_config.DEBUG = True + + message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-1", long_query) + + assert conversation.name == (long_query[:47] + "...") + db_session.commit.assert_called_once() + db_session.close.assert_called_once() + mock_logger.exception.assert_called_once() + + def test_handle_annotation_reply_sets_metadata(self, message_cycle_manager): + """Populate task metadata from annotation reply events. + + Args: message_cycle_manager with TaskStateMetadata and a mocked AppAnnotationService. + Returns: The fetched annotation object. + Side effects: Updates metadata.annotation_reply with id and account name. + """ + message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata()) + + annotation = SimpleNamespace( + id="ann-1", + account_id="acct-1", + account=SimpleNamespace(name="Alice"), + ) + + with patch("core.app.task_pipeline.message_cycle_manager.AppAnnotationService") as mock_service: + mock_service.get_annotation_by_id.return_value = annotation + + result = message_cycle_manager.handle_annotation_reply( + QueueAnnotationReplyEvent(message_annotation_id="ann-1") + ) + + assert result == annotation + assert message_cycle_manager._task_state.metadata.annotation_reply.id == "ann-1" + assert message_cycle_manager._task_state.metadata.annotation_reply.account.name == "Alice" + + def test_handle_annotation_reply_returns_none_when_missing(self, message_cycle_manager): + """Return None and keep metadata unchanged when annotation is not found.""" + message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata()) + + with patch("core.app.task_pipeline.message_cycle_manager.AppAnnotationService") as mock_service: + mock_service.get_annotation_by_id.return_value = None + + result = message_cycle_manager.handle_annotation_reply( + QueueAnnotationReplyEvent(message_annotation_id="missing") + ) + + assert result is None + assert message_cycle_manager._task_state.metadata.annotation_reply is None + + def test_handle_retriever_resources_merges_and_deduplicates(self, message_cycle_manager): + """Merge retriever resources, deduplicate, and preserve ordering positions. + + Args: message_cycle_manager with show_retrieve_source enabled and existing metadata. + Returns: None. + Side effects: Updates metadata.retriever_resources with unique items and positions. + """ + message_cycle_manager._application_generate_entity.app_config = SimpleNamespace( + additional_features=SimpleNamespace(show_retrieve_source=True) + ) + existing = RetrievalSourceMetadata(dataset_id="d1", document_id="doc1") + message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata(retriever_resources=[existing])) + + duplicate = RetrievalSourceMetadata(dataset_id="d1", document_id="doc1") + new_resource = RetrievalSourceMetadata(dataset_id="d2", document_id="doc2") + + event = QueueRetrieverResourcesEvent(retriever_resources=[duplicate, new_resource]) + message_cycle_manager.handle_retriever_resources(event) + + assert len(message_cycle_manager._task_state.metadata.retriever_resources) == 2 + assert message_cycle_manager._task_state.metadata.retriever_resources[0].position == 1 + assert message_cycle_manager._task_state.metadata.retriever_resources[1].position == 2 + + def test_message_file_to_stream_response_builds_signed_url(self, message_cycle_manager): + """Build a stream response with a signed tool file URL. + + Args: message_cycle_manager with mocked Session/db and sign_tool_file. + Returns: MessageStreamResponse with signed url and belongs_to normalized to user. + Side effects: Calls sign_tool_file for tool file ids. + """ + message_cycle_manager._application_generate_entity.task_id = "task-1" + + message_file = SimpleNamespace( + id="file-1", + type="image", + belongs_to=None, + url="tool://file.verylongextension", + message_id="msg-1", + ) + + session = Mock() + session.scalar.return_value = message_file + + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_cls, + patch("core.app.task_pipeline.message_cycle_manager.sign_tool_file") as mock_sign, + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + ): + mock_db.engine = Mock() + mock_session_cls.return_value.__enter__.return_value = session + mock_sign.return_value = "signed-url" + + response = message_cycle_manager.message_file_to_stream_response(SimpleNamespace(message_file_id="file-1")) + + assert response.url == "signed-url" + assert response.belongs_to == "user" + mock_sign.assert_called_once_with(tool_file_id="file", extension=".bin") + + def test_handle_retriever_resources_requires_features(self, message_cycle_manager): + """Raise when retriever resources are handled without feature config. + + Args: message_cycle_manager with additional_features unset and empty metadata. + Raises: ValueError when show_retrieve_source configuration is missing. + """ + message_cycle_manager._application_generate_entity.app_config = SimpleNamespace(additional_features=None) + message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata()) + + with pytest.raises(ValueError): + message_cycle_manager.handle_retriever_resources(QueueRetrieverResourcesEvent(retriever_resources=[])) + + def test_handle_retriever_resources_skips_none_entries(self, message_cycle_manager): + """Ignore null resource entries while preserving valid resources.""" + message_cycle_manager._application_generate_entity.app_config = SimpleNamespace( + additional_features=SimpleNamespace(show_retrieve_source=True) + ) + message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata(retriever_resources=[])) + resource = RetrievalSourceMetadata(dataset_id="d1", document_id="doc1") + + message_cycle_manager.handle_retriever_resources(SimpleNamespace(retriever_resources=[None, resource])) + + assert len(message_cycle_manager._task_state.metadata.retriever_resources) == 1 + assert message_cycle_manager._task_state.metadata.retriever_resources[0].position == 1 + + def test_message_file_to_stream_response_uses_http_url_directly(self, message_cycle_manager): + """Use original URL when message file URL is already HTTP.""" + message_cycle_manager._application_generate_entity.task_id = "task-http" + message_file = SimpleNamespace( + id="file-http", + type="image", + belongs_to="assistant", + url="http://example.com/pic.png", + message_id="msg-http", + ) + + session = Mock() + session.scalar.return_value = message_file + + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_cls, + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + ): + mock_db.engine = Mock() + mock_session_cls.return_value.__enter__.return_value = session + + response = message_cycle_manager.message_file_to_stream_response( + SimpleNamespace(message_file_id="file-http") + ) + + assert response is not None + assert response.url == "http://example.com/pic.png" + assert "msg-http" in message_cycle_manager._message_has_file + + def test_message_file_to_stream_response_defaults_extension_to_bin_without_dot(self, message_cycle_manager): + """Default tool file extension to .bin when URL has no extension part.""" + message_cycle_manager._application_generate_entity.task_id = "task-bin" + message_file = SimpleNamespace( + id="file-bin", + type="file", + belongs_to="assistant", + url="tool-file-id", + message_id="msg-bin", + ) + + session = Mock() + session.scalar.return_value = message_file + + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_cls, + patch("core.app.task_pipeline.message_cycle_manager.sign_tool_file") as mock_sign, + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + ): + mock_db.engine = Mock() + mock_session_cls.return_value.__enter__.return_value = session + mock_sign.return_value = "signed-bin-url" + + response = message_cycle_manager.message_file_to_stream_response( + SimpleNamespace(message_file_id="file-bin") + ) + + assert response is not None + assert response.url == "signed-bin-url" + mock_sign.assert_called_once_with(tool_file_id="tool-file-id", extension=".bin") + + def test_message_file_to_stream_response_returns_none_when_file_missing(self, message_cycle_manager): + """Return None when message file lookup does not find a record.""" + session = Mock() + session.scalar.return_value = None + + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_cls, + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + ): + mock_db.engine = Mock() + mock_session_cls.return_value.__enter__.return_value = session + + response = message_cycle_manager.message_file_to_stream_response(SimpleNamespace(message_file_id="missing")) + + assert response is None + + def test_message_replace_to_stream_response_returns_reason(self, message_cycle_manager): + """Include the provided replacement reason in the stream payload.""" + response = message_cycle_manager.message_replace_to_stream_response("replaced", reason="moderation") + + assert response.answer == "replaced" + assert response.reason == "moderation" diff --git a/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py b/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py new file mode 100644 index 00000000000..29df7eea863 --- /dev/null +++ b/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py @@ -0,0 +1,58 @@ +from types import SimpleNamespace +from unittest.mock import patch + +from graphon.model_runtime.entities.model_entities import ModelPropertyKey + +from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager +from core.app.app_config.entities import ModelConfigEntity +from models.provider_ids import ModelProviderID + + +def test_validate_and_set_defaults_reuses_single_model_assembly(): + provider_name = str(ModelProviderID("openai")) + provider_entity = SimpleNamespace(provider=provider_name) + model = SimpleNamespace(model="gpt-4o-mini", model_properties={ModelPropertyKey.MODE: "chat"}) + provider_configurations = SimpleNamespace(get_models=lambda **kwargs: [model]) + assembly = SimpleNamespace( + model_provider_factory=SimpleNamespace(get_providers=lambda: [provider_entity]), + provider_manager=SimpleNamespace(get_configurations=lambda tenant_id: provider_configurations), + ) + config = { + "model": { + "provider": "openai", + "name": "gpt-4o-mini", + "completion_params": {"stop": []}, + } + } + + with patch( + "core.app.app_config.easy_ui_based_app.model_config.manager.create_plugin_model_assembly", + return_value=assembly, + ) as mock_assembly: + result, keys = ModelConfigManager.validate_and_set_defaults("tenant-1", config) + + assert result["model"]["provider"] == provider_name + assert result["model"]["mode"] == "chat" + assert keys == ["model"] + mock_assembly.assert_called_once_with(tenant_id="tenant-1") + + +def test_convert_keeps_model_config_shape(): + config = { + "model": { + "provider": "openai", + "name": "gpt-4o-mini", + "mode": "chat", + "completion_params": {"temperature": 0.3, "stop": ["END"]}, + } + } + + result = ModelConfigManager.convert(config) + + assert result == ModelConfigEntity( + provider="openai", + model="gpt-4o-mini", + mode="chat", + parameters={"temperature": 0.3}, + stop=["END"], + ) diff --git a/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py b/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py index 0f8a846d112..dc2d82ccd6c 100644 --- a/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py +++ b/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py @@ -2,14 +2,14 @@ from datetime import UTC, datetime from unittest.mock import Mock import pytest +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus, WorkflowType +from graphon.node_events import NodeRunResult from core.app.workflow.layers.persistence import ( PersistenceWorkflowInfo, WorkflowPersistenceLayer, _NodeRuntimeSnapshot, ) -from dify_graph.enums import WorkflowNodeExecutionStatus, WorkflowType -from dify_graph.node_events import NodeRunResult def _build_layer() -> WorkflowPersistenceLayer: @@ -58,3 +58,42 @@ def test_update_node_execution_prefers_event_finished_at(monkeypatch: pytest.Mon assert node_execution.finished_at == event_finished_at assert node_execution.elapsed_time == 2.0 + + +def test_update_node_execution_projects_start_outputs() -> None: + layer = _build_layer() + node_execution = Mock() + node_execution.id = "node-exec-2" + node_execution.node_type = BuiltinNodeTypes.START + node_execution.created_at = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC).replace(tzinfo=None) + node_execution.update_from_mapping = Mock() + + layer._node_snapshots[node_execution.id] = _NodeRuntimeSnapshot( + node_id="start", + title="Start", + predecessor_node_id=None, + iteration_id=None, + loop_id=None, + created_at=node_execution.created_at, + ) + + layer._update_node_execution( + node_execution, + NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"question": "hello"}, + outputs={ + "question": "hello", + "sys.query": "hello", + "env.API_KEY": "secret", + }, + ), + WorkflowNodeExecutionStatus.SUCCEEDED, + ) + + node_execution.update_from_mapping.assert_called_once_with( + inputs={"question": "hello"}, + process_data={}, + outputs={"question": "hello"}, + metadata={}, + ) diff --git a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py new file mode 100644 index 00000000000..7be9d6ac1eb --- /dev/null +++ b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py @@ -0,0 +1,370 @@ +from __future__ import annotations + +import base64 +import hashlib +import hmac +from types import SimpleNamespace +from unittest.mock import MagicMock, patch +from urllib.parse import parse_qs, urlparse + +import pytest +from graphon.file import File, FileTransferMethod, FileType + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.file_access import DatabaseFileAccessController, FileAccessScope +from core.app.workflow import file_runtime +from core.app.workflow.file_runtime import DifyWorkflowFileRuntime, bind_dify_workflow_file_runtime +from core.workflow.file_reference import build_file_reference +from models import ToolFile, UploadFile + + +def _build_file( + *, + transfer_method: FileTransferMethod, + reference: str | None = None, + remote_url: str | None = None, + extension: str | None = None, +) -> File: + return File( + id="file-id", + type=FileType.IMAGE, + transfer_method=transfer_method, + reference=reference, + remote_url=remote_url, + filename="diagram.png", + extension=extension, + mime_type="image/png", + size=128, + ) + + +def _build_runtime() -> DifyWorkflowFileRuntime: + return DifyWorkflowFileRuntime(file_access_controller=DatabaseFileAccessController()) + + +def test_resolve_file_url_returns_remote_url() -> None: + runtime = _build_runtime() + file = _build_file( + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/diagram.png", + ) + + assert runtime.resolve_file_url(file=file) == "https://example.com/diagram.png" + + +def test_resolve_file_url_requires_file_reference() -> None: + runtime = _build_runtime() + file = SimpleNamespace(transfer_method=FileTransferMethod.LOCAL_FILE, reference=None) + + with pytest.raises(ValueError, match="Missing file reference"): + runtime.resolve_file_url(file=file) + + +def test_resolve_file_url_requires_extension_for_tool_files() -> None: + runtime = _build_runtime() + file = _build_file( + transfer_method=FileTransferMethod.TOOL_FILE, + reference=build_file_reference(record_id="tool-file-id"), + extension=None, + ) + + with pytest.raises(ValueError, match="Missing file extension"): + runtime.resolve_file_url(file=file) + + +def test_resolve_file_url_uses_tool_signatures_for_tool_and_datasource_files( + monkeypatch: pytest.MonkeyPatch, +) -> None: + sign_tool_file = MagicMock(return_value="https://signed.example.com/file") + monkeypatch.setattr(file_runtime, "sign_tool_file", sign_tool_file) + runtime = _build_runtime() + + tool_file = _build_file( + transfer_method=FileTransferMethod.TOOL_FILE, + reference=build_file_reference(record_id="tool-file-id"), + extension=".png", + ) + datasource_file = _build_file( + transfer_method=FileTransferMethod.DATASOURCE_FILE, + reference=build_file_reference(record_id="datasource-file-id"), + extension=".png", + ) + + assert runtime.resolve_file_url(file=tool_file) == "https://signed.example.com/file" + assert runtime.resolve_file_url(file=datasource_file) == "https://signed.example.com/file" + assert sign_tool_file.call_count == 2 + + +def test_resolve_upload_file_url_signs_internal_urls_and_supports_attachments( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr("core.app.workflow.file_runtime.time.time", lambda: 1700000000) + monkeypatch.setattr("core.app.workflow.file_runtime.os.urandom", lambda _: b"\x01" * 16) + monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr( + "core.app.workflow.file_runtime.dify_config.INTERNAL_FILES_URL", + "https://internal.example.com", + ) + + runtime = _build_runtime() + url = runtime.resolve_upload_file_url( + upload_file_id="upload-file-id", + as_attachment=True, + for_external=False, + ) + parsed = urlparse(url) + query = parse_qs(parsed.query) + + assert parsed.netloc == "internal.example.com" + assert parsed.path == "/files/upload-file-id/file-preview" + assert query["as_attachment"] == ["true"] + assert query["timestamp"] == ["1700000000"] + + +def test_verify_preview_signature_validates_signature_and_expiration(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.app.workflow.file_runtime.time.time", lambda: 1700000000) + monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.FILES_ACCESS_TIMEOUT", 60) + runtime = _build_runtime() + payload = "file-preview|upload-file-id|1700000000|nonce" + sign = base64.urlsafe_b64encode(hmac.new(b"unit-secret", payload.encode(), hashlib.sha256).digest()).decode() + + assert ( + runtime.verify_preview_signature( + preview_kind="file", + file_id="upload-file-id", + timestamp="1700000000", + nonce="nonce", + sign=sign, + ) + is True + ) + assert ( + runtime.verify_preview_signature( + preview_kind="file", + file_id="upload-file-id", + timestamp="1700000000", + nonce="nonce", + sign="bad-signature", + ) + is False + ) + + monkeypatch.setattr("core.app.workflow.file_runtime.time.time", lambda: 1700000100) + assert ( + runtime.verify_preview_signature( + preview_kind="file", + file_id="upload-file-id", + timestamp="1700000000", + nonce="nonce", + sign=sign, + ) + is False + ) + + +def test_load_file_bytes_returns_bytes_and_rejects_non_bytes(monkeypatch: pytest.MonkeyPatch) -> None: + runtime = _build_runtime() + file = _build_file( + transfer_method=FileTransferMethod.LOCAL_FILE, + reference=build_file_reference(record_id="upload-file-id"), + ) + session = MagicMock() + session.get.return_value = SimpleNamespace(key="canonical-storage-key") + + class _SessionContext: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(file_runtime.session_factory, "create_session", lambda: _SessionContext()) + monkeypatch.setattr(file_runtime.storage, "load", lambda *args, **kwargs: b"image-bytes") + + assert runtime.load_file_bytes(file=file) == b"image-bytes" + session.get.assert_called_with(UploadFile, "upload-file-id") + + monkeypatch.setattr(file_runtime.storage, "load", lambda *args, **kwargs: "not-bytes") + with pytest.raises(ValueError, match="is not a bytes object"): + runtime.load_file_bytes(file=file) + + +def test_resolve_storage_key_ignores_encoded_reference_when_unscoped(monkeypatch: pytest.MonkeyPatch) -> None: + runtime = _build_runtime() + file = _build_file( + transfer_method=FileTransferMethod.LOCAL_FILE, + reference=build_file_reference(record_id="upload-file-id", storage_key="tampered-storage-key"), + ) + session = MagicMock() + session.get.return_value = SimpleNamespace(key="canonical-storage-key") + + class _SessionContext: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(file_runtime.session_factory, "create_session", lambda: _SessionContext()) + + assert runtime._resolve_storage_key(file=file) == "canonical-storage-key" + session.get.assert_called_once_with(UploadFile, "upload-file-id") + + +def test_resolve_storage_key_uses_canonical_record_when_scope_is_bound(monkeypatch: pytest.MonkeyPatch) -> None: + controller = MagicMock() + controller.current_scope.return_value = FileAccessScope( + tenant_id="tenant-id", + user_id="end-user-id", + user_from=UserFrom.END_USER, + invoke_from=InvokeFrom.WEB_APP, + ) + controller.get_upload_file.return_value = SimpleNamespace(key="canonical-storage-key") + runtime = DifyWorkflowFileRuntime(file_access_controller=controller) + file = _build_file( + transfer_method=FileTransferMethod.LOCAL_FILE, + reference=build_file_reference(record_id="upload-file-id", storage_key="tampered-storage-key"), + ) + session = MagicMock() + + class _SessionContext: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(file_runtime.session_factory, "create_session", lambda: _SessionContext()) + + assert runtime._resolve_storage_key(file=file) == "canonical-storage-key" + controller.get_upload_file.assert_called_once_with(session=session, file_id="upload-file-id") + + +def test_resolve_upload_file_url_rejects_unauthorized_scoped_access(monkeypatch: pytest.MonkeyPatch) -> None: + controller = MagicMock() + controller.current_scope.return_value = FileAccessScope( + tenant_id="tenant-id", + user_id="end-user-id", + user_from=UserFrom.END_USER, + invoke_from=InvokeFrom.WEB_APP, + ) + controller.get_upload_file.return_value = None + runtime = DifyWorkflowFileRuntime(file_access_controller=controller) + session = MagicMock() + + class _SessionContext: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(file_runtime.session_factory, "create_session", lambda: _SessionContext()) + + with pytest.raises(ValueError, match="Upload file upload-file-id not found"): + runtime.resolve_upload_file_url(upload_file_id="upload-file-id") + + +@pytest.mark.parametrize( + ("transfer_method", "record_id", "expected_storage_key"), + [ + (FileTransferMethod.LOCAL_FILE, "upload-file-id", "upload-storage-key"), + (FileTransferMethod.DATASOURCE_FILE, "upload-file-id", "upload-storage-key"), + (FileTransferMethod.TOOL_FILE, "tool-file-id", "tool-storage-key"), + ], +) +def test_resolve_storage_key_loads_database_records( + monkeypatch: pytest.MonkeyPatch, + transfer_method: FileTransferMethod, + record_id: str, + expected_storage_key: str, +) -> None: + runtime = _build_runtime() + file = _build_file( + transfer_method=transfer_method, + reference=build_file_reference(record_id=record_id), + extension=".png", + ) + session = MagicMock() + + def get(model_class, value): + if transfer_method in {FileTransferMethod.LOCAL_FILE, FileTransferMethod.DATASOURCE_FILE}: + assert model_class is UploadFile + return SimpleNamespace(key="upload-storage-key") + assert model_class is ToolFile + return SimpleNamespace(file_key="tool-storage-key") + + session.get.side_effect = get + + class _SessionContext: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(file_runtime.session_factory, "create_session", lambda: _SessionContext()) + + assert runtime._resolve_storage_key(file=file) == expected_storage_key + + +@pytest.mark.parametrize( + ("transfer_method", "expected_message"), + [ + (FileTransferMethod.LOCAL_FILE, "Upload file upload-file-id not found"), + (FileTransferMethod.TOOL_FILE, "Tool file tool-file-id not found"), + ], +) +def test_resolve_storage_key_raises_when_records_are_missing( + monkeypatch: pytest.MonkeyPatch, + transfer_method: FileTransferMethod, + expected_message: str, +) -> None: + runtime = _build_runtime() + record_id = "upload-file-id" if transfer_method == FileTransferMethod.LOCAL_FILE else "tool-file-id" + file = _build_file( + transfer_method=transfer_method, + reference=build_file_reference(record_id=record_id), + extension=".png", + ) + session = MagicMock() + session.get.return_value = None + + class _SessionContext: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(file_runtime.session_factory, "create_session", lambda: _SessionContext()) + + with pytest.raises(ValueError, match=expected_message): + runtime._resolve_storage_key(file=file) + + +def test_runtime_helper_wrappers_delegate_to_config_and_io(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.MULTIMODAL_SEND_FORMAT", "url") + runtime = _build_runtime() + + assert runtime.multimodal_send_format == "url" + + with patch.object(file_runtime.ssrf_proxy, "get", return_value="response") as mock_get: + assert runtime.http_get("http://example", follow_redirects=False) == "response" + mock_get.assert_called_once_with("http://example", follow_redirects=False) + + with patch.object(file_runtime.storage, "load", return_value=b"data") as mock_load: + assert runtime.storage_load("path", stream=True) == b"data" + mock_load.assert_called_once_with("path", stream=True) + + +def test_bind_dify_workflow_file_runtime_registers_runtime(monkeypatch: pytest.MonkeyPatch) -> None: + set_runtime = MagicMock() + monkeypatch.setattr(file_runtime, "set_workflow_file_runtime", set_runtime) + + bind_dify_workflow_file_runtime() + + set_runtime.assert_called_once() + assert isinstance(set_runtime.call_args.args[0], DifyWorkflowFileRuntime) diff --git a/api/tests/unit_tests/core/app/workflow/test_node_factory.py b/api/tests/unit_tests/core/app/workflow/test_node_factory.py new file mode 100644 index 00000000000..8497261d453 --- /dev/null +++ b/api/tests/unit_tests/core/app/workflow/test_node_factory.py @@ -0,0 +1,161 @@ +from types import SimpleNamespace + +import pytest +from graphon.enums import BuiltinNodeTypes + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context +from core.workflow.node_factory import DifyNodeFactory + + +class DummyNode: + def __init__(self, *, id, config, graph_init_params, graph_runtime_state, **kwargs): + self.id = id + self.config = config + self.graph_init_params = graph_init_params + self.graph_runtime_state = graph_runtime_state + self.kwargs = kwargs + + +class DummyCodeNode(DummyNode): + @classmethod + def default_code_providers(cls): + return () + + +class DummyTemplateTransformNode(DummyNode): + pass + + +class DummyHttpRequestNode(DummyNode): + pass + + +class DummyKnowledgeRetrievalNode(DummyNode): + pass + + +class DummyDocumentExtractorNode(DummyNode): + pass + + +class TestDifyNodeFactory: + @staticmethod + def _stub_node_resolution(monkeypatch, node_class): + monkeypatch.setattr( + "core.workflow.node_factory.resolve_workflow_node_class", + lambda **_kwargs: node_class, + ) + + def _factory(self, monkeypatch): + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_STRING_LENGTH", 10) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_NUMBER", 10) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MIN_NUMBER", -10) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_PRECISION", 4) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_DEPTH", 2) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH", 2) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_STRING_ARRAY_LENGTH", 2) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH", 2) + monkeypatch.setattr("core.workflow.node_factory.dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH", 100) + monkeypatch.setattr("core.workflow.node_factory.dify_config.UNSTRUCTURED_API_URL", "http://u") + monkeypatch.setattr("core.workflow.node_factory.dify_config.UNSTRUCTURED_API_KEY", "key") + + run_context = build_dify_run_context( + tenant_id="tenant", + app_id="app", + user_id="user", + user_from=UserFrom.END_USER, + invoke_from=InvokeFrom.WEB_APP, + ) + + return DifyNodeFactory( + graph_init_params=SimpleNamespace(run_context=run_context), + graph_runtime_state=SimpleNamespace(), + ) + + def test_create_node_unknown_type(self, monkeypatch): + factory = self._factory(monkeypatch) + + with pytest.raises(ValueError): + factory.create_node({"id": "node-1", "data": {"type": "unknown"}}) + + def test_create_node_missing_mapping(self, monkeypatch): + factory = self._factory(monkeypatch) + monkeypatch.setattr("core.workflow.node_factory.get_node_type_classes_mapping", lambda: {}) + + with pytest.raises(ValueError): + factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.START}}) + + def test_create_node_missing_latest_class(self, monkeypatch): + factory = self._factory(monkeypatch) + monkeypatch.setattr( + "core.workflow.node_factory.get_node_type_classes_mapping", + lambda: {BuiltinNodeTypes.START: {"1": None}}, + ) + monkeypatch.setattr("core.workflow.node_factory.LATEST_VERSION", "latest") + + with pytest.raises(ValueError): + factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.START}}) + + def test_create_node_selects_versioned_class(self, monkeypatch): + factory = self._factory(monkeypatch) + selected_versions: list[tuple[str, str]] = [] + + class DummyNodeV2(DummyNode): + pass + + def _get_mapping(): + selected_versions.append(("snapshot", "called")) + return {BuiltinNodeTypes.START: {"1": DummyNode, "2": DummyNodeV2}} + + monkeypatch.setattr("core.workflow.node_factory.get_node_type_classes_mapping", _get_mapping) + + node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.START, "version": "2"}}) + + assert isinstance(node, DummyNodeV2) + assert node.id == "node-1" + assert selected_versions == [("snapshot", "called")] + + def test_create_node_code_branch(self, monkeypatch): + factory = self._factory(monkeypatch) + self._stub_node_resolution(monkeypatch, DummyCodeNode) + + node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.CODE}}) + + assert isinstance(node, DummyCodeNode) + assert node.id == "node-1" + + def test_create_node_template_transform_branch(self, monkeypatch): + factory = self._factory(monkeypatch) + self._stub_node_resolution(monkeypatch, DummyTemplateTransformNode) + + node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.TEMPLATE_TRANSFORM}}) + + assert isinstance(node, DummyTemplateTransformNode) + assert "jinja2_template_renderer" in node.kwargs + + def test_create_node_http_request_branch(self, monkeypatch): + factory = self._factory(monkeypatch) + self._stub_node_resolution(monkeypatch, DummyHttpRequestNode) + + node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.HTTP_REQUEST}}) + + assert isinstance(node, DummyHttpRequestNode) + assert "http_request_config" in node.kwargs + + def test_create_node_knowledge_retrieval_branch(self, monkeypatch): + factory = self._factory(monkeypatch) + self._stub_node_resolution(monkeypatch, DummyKnowledgeRetrievalNode) + + node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL}}) + + assert isinstance(node, DummyKnowledgeRetrievalNode) + assert node.kwargs == {} + + def test_create_node_document_extractor_branch(self, monkeypatch): + factory = self._factory(monkeypatch) + self._stub_node_resolution(monkeypatch, DummyDocumentExtractorNode) + + node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.DOCUMENT_EXTRACTOR}}) + + assert isinstance(node, DummyDocumentExtractorNode) + assert "unstructured_api_config" in node.kwargs diff --git a/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py b/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py new file mode 100644 index 00000000000..a47d3db6f5b --- /dev/null +++ b/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py @@ -0,0 +1,210 @@ +from __future__ import annotations + +from types import SimpleNamespace + +from graphon.enums import BuiltinNodeTypes + +from core.app.workflow.layers.observability import ObservabilityLayer + + +class TestObservabilityLayerExtras: + def test_init_tracer_enabled_sets_tracer(self, monkeypatch): + tracer = object() + monkeypatch.setattr("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True) + monkeypatch.setattr("core.app.workflow.layers.observability.is_instrument_flag_enabled", lambda: False) + monkeypatch.setattr("core.app.workflow.layers.observability.get_tracer", lambda _: tracer) + + layer = ObservabilityLayer() + + assert layer._is_disabled is False + assert layer._tracer is tracer + + def test_init_tracer_disables_when_get_tracer_fails(self, monkeypatch, caplog): + monkeypatch.setattr("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True) + monkeypatch.setattr("core.app.workflow.layers.observability.is_instrument_flag_enabled", lambda: False) + + def _raise(*_args, **_kwargs): + raise RuntimeError("tracer init failed") + + monkeypatch.setattr("core.app.workflow.layers.observability.get_tracer", _raise) + + layer = ObservabilityLayer() + + assert layer._is_disabled is True + assert layer._tracer is None + assert "Failed to get OpenTelemetry tracer" in caplog.text + + def test_init_tracer_disables_when_otel_disabled(self, monkeypatch): + monkeypatch.setattr("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", False) + monkeypatch.setattr("core.app.workflow.layers.observability.is_instrument_flag_enabled", lambda: False) + + layer = ObservabilityLayer() + + assert layer._is_disabled is True + + def test_get_parser_uses_registry_when_node_type_matches(self): + layer = ObservabilityLayer() + + parser = layer._get_parser(SimpleNamespace(node_type=BuiltinNodeTypes.TOOL)) + + assert parser is layer._parsers[BuiltinNodeTypes.TOOL] + + def test_get_parser_defaults_when_node_type_missing(self): + layer = ObservabilityLayer() + + parser = layer._get_parser(SimpleNamespace(node_type=None)) + + assert parser is layer._default_parser + + def test_on_graph_start_clears_contexts(self): + layer = ObservabilityLayer() + layer._node_contexts["exec"] = SimpleNamespace(span=object(), token="token") + + layer.on_graph_start() + + assert layer._node_contexts == {} + + def test_on_event_is_noop(self): + layer = ObservabilityLayer() + + layer.on_event(object()) + + def test_on_graph_end_clears_unfinished_contexts(self, caplog): + layer = ObservabilityLayer() + layer._node_contexts["exec"] = SimpleNamespace(span=object(), token="token") + + layer.on_graph_end(error=None) + + assert layer._node_contexts == {} + assert "node spans were not properly ended" in caplog.text + + def test_on_node_run_start_skips_without_execution_id(self): + layer = ObservabilityLayer() + layer._is_disabled = False + layer._tracer = None + + layer.on_node_run_start(SimpleNamespace(execution_id=None, title="node", id="node")) + + assert layer._node_contexts == {} + + def test_on_node_run_start_skips_when_disabled(self): + layer = ObservabilityLayer() + layer._is_disabled = True + layer._tracer = SimpleNamespace(start_span=lambda *_args, **_kwargs: object()) + + layer.on_node_run_start(SimpleNamespace(execution_id="exec", title="node", id="node")) + + assert layer._node_contexts == {} + + def test_on_node_run_start_skips_when_execution_id_missing_even_with_tracer(self): + layer = ObservabilityLayer() + layer._is_disabled = False + calls: list[str] = [] + layer._tracer = SimpleNamespace(start_span=lambda *_args, **_kwargs: calls.append("called")) + + layer.on_node_run_start(SimpleNamespace(execution_id=None, title="node", id="node")) + + assert calls == [] + + def test_on_node_run_start_logs_warning_when_span_creation_fails(self, caplog): + layer = ObservabilityLayer() + layer._is_disabled = False + + def _raise(*_args, **_kwargs): + raise RuntimeError("start failed") + + layer._tracer = SimpleNamespace(start_span=_raise) + + layer.on_node_run_start(SimpleNamespace(execution_id="exec", title="node", id="node")) + + assert "Failed to create OpenTelemetry span for node" in caplog.text + + def test_on_node_run_end_without_context_noop(self): + layer = ObservabilityLayer() + layer._is_disabled = False + + layer.on_node_run_end(SimpleNamespace(execution_id="missing", id="node"), error=None) + + assert layer._node_contexts == {} + + def test_on_node_run_end_skips_when_disabled(self): + layer = ObservabilityLayer() + layer._is_disabled = True + layer._node_contexts["exec"] = SimpleNamespace(span=object(), token="token") + + layer.on_node_run_end(SimpleNamespace(execution_id="exec", id="node"), error=None) + + assert "exec" in layer._node_contexts + + def test_on_node_run_end_skips_without_execution_id(self): + layer = ObservabilityLayer() + layer._is_disabled = False + + layer.on_node_run_end(SimpleNamespace(execution_id=None, id="node"), error=None) + + assert layer._node_contexts == {} + + def test_on_node_run_end_calls_span_end(self, monkeypatch): + layer = ObservabilityLayer() + layer._is_disabled = False + ended: list[str] = [] + + class _Parser: + def parse(self, **_kwargs): + return None + + span = SimpleNamespace(end=lambda: ended.append("ended")) + layer._default_parser = _Parser() + layer._node_contexts["exec"] = SimpleNamespace(span=span, token="token") + + monkeypatch.setattr("core.app.workflow.layers.observability.context_api.detach", lambda _token: None) + + node = SimpleNamespace(execution_id="exec", title="Node", id="node", node_type=None) + layer.on_node_run_end(node, error=None) + + assert ended == ["ended"] + assert "exec" not in layer._node_contexts + + def test_on_node_run_end_logs_detach_failure(self, monkeypatch, caplog): + layer = ObservabilityLayer() + layer._is_disabled = False + + class _Parser: + def parse(self, **_kwargs): + return None + + layer._default_parser = _Parser() + layer._node_contexts["exec"] = SimpleNamespace(span=SimpleNamespace(end=lambda: None), token="bad-token") + + def _raise(*_args, **_kwargs): + raise RuntimeError("detach failed") + + monkeypatch.setattr("core.app.workflow.layers.observability.context_api.detach", _raise) + + node = SimpleNamespace(execution_id="exec", title="Node", id="node", node_type=None) + layer.on_node_run_end(node, error=None) + + assert "Failed to detach OpenTelemetry token" in caplog.text + assert "exec" not in layer._node_contexts + + def test_on_node_run_start_and_end_creates_span(self, monkeypatch): + layer = ObservabilityLayer() + layer._is_disabled = False + + span = SimpleNamespace(end=lambda: None) + tracer = SimpleNamespace(start_span=lambda *args, **kwargs: span) + + monkeypatch.setattr("core.app.workflow.layers.observability.context_api.get_current", lambda: object()) + monkeypatch.setattr("core.app.workflow.layers.observability.set_span_in_context", lambda s: object()) + monkeypatch.setattr("core.app.workflow.layers.observability.context_api.attach", lambda ctx: "token") + monkeypatch.setattr("core.app.workflow.layers.observability.context_api.detach", lambda token: None) + + layer._tracer = tracer + + node = SimpleNamespace(execution_id="exec", title="Node", id="node", node_type=None) + + layer.on_node_run_start(node) + assert "exec" in layer._node_contexts + + layer.on_node_run_end(node, error=None) + assert "exec" not in layer._node_contexts diff --git a/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py new file mode 100644 index 00000000000..d8a68f6d000 --- /dev/null +++ b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py @@ -0,0 +1,498 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace + +import pytest +from graphon.entities import WorkflowNodeExecution +from graphon.entities.pause_reason import SchedulingPause +from graphon.enums import ( + BuiltinNodeTypes, + WorkflowExecutionStatus, + WorkflowNodeExecutionStatus, + WorkflowType, +) +from graphon.graph_events import ( + GraphRunAbortedEvent, + GraphRunFailedEvent, + GraphRunPartialSucceededEvent, + GraphRunPausedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunExceptionEvent, + NodeRunFailedEvent, + NodeRunPauseRequestedEvent, + NodeRunRetryEvent, + NodeRunStartedEvent, + NodeRunSucceededEvent, +) +from graphon.node_events import NodeRunResult +from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool + +from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity +from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.workflow.system_variables import SystemVariableKey, build_system_variables + + +class _RepoRecorder: + def __init__(self) -> None: + self.saved: list[object] = [] + self.saved_exec_data: list[object] = [] + + def save(self, entity): + self.saved.append(entity) + + def save_execution_data(self, entity): + self.saved_exec_data.append(entity) + + +def _naive_utc_now() -> datetime: + return datetime.now(UTC).replace(tzinfo=None) + + +def _make_layer( + system_variables: list | None = None, + *, + extras: dict | None = None, + trace_manager: object | None = None, +): + system_variables = system_variables or build_system_variables( + workflow_execution_id="run-id", + conversation_id="conv-id", + ) + runtime_state = GraphRuntimeState(variable_pool=VariablePool(system_variables=system_variables), start_at=0.0) + read_only_state = ReadOnlyGraphRuntimeStateWrapper(runtime_state) + + application_generate_entity = WorkflowAppGenerateEntity.model_construct( + task_id="task", + app_config=SimpleNamespace(app_id="app", tenant_id="tenant"), + inputs={"foo": "bar"}, + files=[], + user_id="user", + stream=False, + invoke_from=None, + trace_manager=None, + workflow_execution_id="run-id", + extras=extras or {}, + call_depth=0, + ) + + workflow_info = PersistenceWorkflowInfo( + workflow_id="workflow-id", + workflow_type=WorkflowType.WORKFLOW, + version="1", + graph_data={"nodes": [], "edges": []}, + ) + + workflow_execution_repo = _RepoRecorder() + workflow_node_execution_repo = _RepoRecorder() + + layer = WorkflowPersistenceLayer( + application_generate_entity=application_generate_entity, + workflow_info=workflow_info, + workflow_execution_repository=workflow_execution_repo, + workflow_node_execution_repository=workflow_node_execution_repo, + trace_manager=trace_manager, + ) + layer.initialize(read_only_state, command_channel=None) + + return layer, workflow_execution_repo, workflow_node_execution_repo, runtime_state + + +class TestWorkflowPersistenceLayer: + def test_on_graph_start_resets_state(self): + layer, _, _, _ = _make_layer() + layer._workflow_execution = object() + layer._node_execution_cache["cached"] = object() + layer._node_snapshots["cached"] = object() + layer._node_sequence = 9 + + layer.on_graph_start() + + assert layer._workflow_execution is None + assert layer._node_execution_cache == {} + assert layer._node_snapshots == {} + assert layer._node_sequence == 0 + + def test_get_execution_id_requires_system_variable(self): + layer, _, _, _ = _make_layer(build_system_variables()) + + with pytest.raises(ValueError, match="workflow_execution_id must be provided"): + layer._get_execution_id() + + def test_prepare_workflow_inputs_excludes_conversation_id(self, monkeypatch): + layer, _, _, _ = _make_layer() + + monkeypatch.setattr( + "core.workflow.workflow_entry.WorkflowEntry.handle_special_values", + lambda inputs: inputs, + ) + + inputs = layer._prepare_workflow_inputs() + + assert "sys.conversation_id" not in inputs + assert inputs[f"sys.{SystemVariableKey.WORKFLOW_EXECUTION_ID.value}"] == "run-id" + + def test_fail_running_node_executions_marks_failed(self): + layer, _, node_repo, _ = _make_layer() + + execution = WorkflowNodeExecution( + id="exec-id", + workflow_id="workflow-id", + workflow_execution_id="run-id", + index=1, + node_id="node", + node_type=BuiltinNodeTypes.START, + title="Start", + created_at=_naive_utc_now(), + ) + layer._node_execution_cache[execution.id] = execution + + layer._fail_running_node_executions(error_message="boom") + + assert execution.status == WorkflowNodeExecutionStatus.FAILED + assert node_repo.saved + + def test_handle_graph_run_started_saves_execution(self): + layer, exec_repo, _, _ = _make_layer() + + layer._handle_graph_run_started() + + assert exec_repo.saved + + def test_handle_graph_run_succeeded_updates_execution(self): + layer, exec_repo, _, runtime_state = _make_layer() + layer._handle_graph_run_started() + runtime_state.total_tokens = 3 + runtime_state.node_run_steps = 2 + runtime_state.outputs = {"out": "v"} + + layer._handle_graph_run_succeeded(GraphRunSucceededEvent(outputs={"ok": True})) + + saved = exec_repo.saved[-1] + assert saved.status == WorkflowExecutionStatus.SUCCEEDED + assert saved.total_tokens == 3 + assert saved.total_steps == 2 + + def test_handle_graph_run_partial_succeeded_updates_execution(self): + layer, exec_repo, _, runtime_state = _make_layer() + layer._handle_graph_run_started() + runtime_state.total_tokens = 5 + runtime_state.node_run_steps = 4 + runtime_state._graph_execution = SimpleNamespace(exceptions_count=2) + + layer._handle_graph_run_partial_succeeded( + GraphRunPartialSucceededEvent(outputs={"ok": True}, exceptions_count=2) + ) + + saved = exec_repo.saved[-1] + assert saved.status == WorkflowExecutionStatus.PARTIAL_SUCCEEDED + assert saved.exceptions_count == 2 + assert saved.total_tokens == 5 + + def test_handle_graph_run_failed_marks_nodes_and_enqueues_trace(self): + trace_tasks: list[object] = [] + trace_manager = SimpleNamespace(user_id="user", add_trace_task=lambda task: trace_tasks.append(task)) + layer, exec_repo, node_repo, _ = _make_layer(extras={"external_trace_id": "trace"}, trace_manager=trace_manager) + layer._handle_graph_run_started() + + running = WorkflowNodeExecution( + id="node-exec", + workflow_id="workflow-id", + workflow_execution_id="run-id", + index=1, + node_id="node", + node_type=BuiltinNodeTypes.START, + title="Start", + created_at=_naive_utc_now(), + ) + layer._node_execution_cache[running.id] = running + + layer._handle_graph_run_failed(GraphRunFailedEvent(error="boom", exceptions_count=1)) + + assert node_repo.saved + assert exec_repo.saved[-1].status == WorkflowExecutionStatus.FAILED + assert trace_tasks + + def test_handle_graph_run_aborted_sets_status(self): + layer, exec_repo, _, _ = _make_layer() + layer._handle_graph_run_started() + + layer._handle_graph_run_aborted(GraphRunAbortedEvent(reason=None, outputs={})) + + saved = exec_repo.saved[-1] + assert saved.status == WorkflowExecutionStatus.STOPPED + assert saved.error_message + + def test_handle_graph_run_paused_updates_outputs(self): + layer, exec_repo, _, runtime_state = _make_layer() + layer._handle_graph_run_started() + runtime_state.total_tokens = 7 + runtime_state.node_run_steps = 5 + + layer._handle_graph_run_paused(GraphRunPausedEvent(outputs={"pause": True})) + + saved = exec_repo.saved[-1] + assert saved.status == WorkflowExecutionStatus.PAUSED + assert saved.outputs == {"pause": True} + assert saved.finished_at is None + + def test_handle_node_started_and_retry(self): + layer, _, node_repo, _ = _make_layer() + layer._handle_graph_run_started() + + start_event = NodeRunStartedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + node_title="Start", + start_at=_naive_utc_now(), + predecessor_node_id="prev", + in_iteration_id="iter", + in_loop_id="loop", + ) + layer._handle_node_started(start_event) + + assert node_repo.saved + assert "exec" in layer._node_execution_cache + assert layer._node_snapshots["exec"].node_id == "node" + + retry_event = NodeRunRetryEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + node_title="Start", + start_at=_naive_utc_now(), + error="retry", + retry_index=1, + ) + layer._handle_node_retry(retry_event) + assert node_repo.saved_exec_data + + def test_handle_node_result_events_update_execution(self): + layer, _, node_repo, _ = _make_layer() + layer._handle_graph_run_started() + + start_event = NodeRunStartedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + node_title="LLM", + start_at=_naive_utc_now(), + ) + layer._handle_node_started(start_event) + + result = NodeRunResult(inputs={"a": 1}, process_data={"b": 2}, outputs={"c": 3}, metadata={}) + success_event = NodeRunSucceededEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + start_at=_naive_utc_now(), + node_run_result=result, + ) + layer._handle_node_succeeded(success_event) + + failed_event = NodeRunFailedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + start_at=_naive_utc_now(), + error="boom", + node_run_result=result, + ) + layer._handle_node_failed(failed_event) + + exception_event = NodeRunExceptionEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + start_at=_naive_utc_now(), + error="err", + node_run_result=result, + ) + layer._handle_node_exception(exception_event) + + assert node_repo.saved_exec_data + + def test_handle_node_pause_requested_skips_outputs(self): + layer, _, _, _ = _make_layer() + layer._handle_graph_run_started() + start_event = NodeRunStartedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + node_title="LLM", + start_at=_naive_utc_now(), + ) + layer._handle_node_started(start_event) + + domain_execution = layer._node_execution_cache["exec"] + domain_execution.inputs = {"old": True} + + result = NodeRunResult(inputs={"new": True}, outputs={"out": 1}, process_data={"p": 1}, metadata={}) + pause_event = NodeRunPauseRequestedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + reason=SchedulingPause(message="pause"), + node_run_result=result, + ) + layer._handle_node_pause_requested(pause_event) + + assert domain_execution.status == WorkflowNodeExecutionStatus.PAUSED + assert domain_execution.inputs == {"old": True} + + def test_get_node_execution_raises_for_missing(self): + layer, _, _, _ = _make_layer() + with pytest.raises(ValueError, match="Node execution not found"): + layer._get_node_execution("missing") + + def test_get_workflow_execution_raises_when_uninitialized(self): + layer, _, _, _ = _make_layer() + + with pytest.raises(ValueError, match="workflow execution not initialized"): + layer._get_workflow_execution() + + def test_next_node_sequence_increments(self): + layer, _, _, _ = _make_layer() + assert layer._next_node_sequence() == 1 + assert layer._next_node_sequence() == 2 + + def test_on_graph_end_is_noop(self): + layer, _, _, _ = _make_layer() + + assert layer.on_graph_end(error=None) is None + + def test_on_event_dispatches_to_all_known_handlers(self): + layer, _, _, _ = _make_layer() + called: list[str] = [] + + def _record(name: str): + def _handler(*_args, **_kwargs): + called.append(name) + + return _handler + + layer._handle_graph_run_started = _record("started") + layer._handle_graph_run_succeeded = _record("succeeded") + layer._handle_graph_run_partial_succeeded = _record("partial") + layer._handle_graph_run_failed = _record("failed") + layer._handle_graph_run_aborted = _record("aborted") + layer._handle_graph_run_paused = _record("paused") + layer._handle_node_started = _record("node_started") + layer._handle_node_retry = _record("node_retry") + layer._handle_node_succeeded = _record("node_succeeded") + layer._handle_node_failed = _record("node_failed") + layer._handle_node_exception = _record("node_exception") + layer._handle_node_pause_requested = _record("node_paused") + + node_result = NodeRunResult() + now = _naive_utc_now() + events = [ + GraphRunStartedEvent(), + GraphRunSucceededEvent(outputs={"ok": True}), + GraphRunPartialSucceededEvent(outputs={"ok": True}, exceptions_count=1), + GraphRunFailedEvent(error="boom", exceptions_count=1), + GraphRunAbortedEvent(reason="stop", outputs={"x": 1}), + GraphRunPausedEvent(outputs={"pause": True}), + NodeRunStartedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + node_title="Start", + start_at=now, + ), + NodeRunRetryEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + node_title="Start", + start_at=now, + error="retry", + retry_index=1, + ), + NodeRunSucceededEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + start_at=now, + node_run_result=node_result, + ), + NodeRunFailedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + start_at=now, + error="failed", + node_run_result=node_result, + ), + NodeRunExceptionEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + start_at=now, + error="error", + node_run_result=node_result, + ), + NodeRunPauseRequestedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + reason=SchedulingPause(message="pause"), + node_run_result=node_result, + ), + ] + expected_order = [ + "started", + "succeeded", + "partial", + "failed", + "aborted", + "paused", + "node_started", + "node_retry", + "node_succeeded", + "node_failed", + "node_exception", + "node_paused", + ] + + for event in events: + layer.on_event(event) + + assert called == expected_order + + def test_on_event_dispatches_retry_before_started_for_retry_event(self): + layer, _, _, _ = _make_layer() + called: list[str] = [] + + def _record(name: str): + def _handler(*_args, **_kwargs): + called.append(name) + + return _handler + + layer._handle_node_started = _record("node_started") + layer._handle_node_retry = _record("node_retry") + + layer.on_event( + NodeRunRetryEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + node_title="Start", + start_at=_naive_utc_now(), + error="retry", + retry_index=1, + ) + ) + + assert called == ["node_retry"] + + def test_enqueue_trace_task_skips_when_disabled(self): + trace_tasks: list[object] = [] + layer, exec_repo, _, _ = _make_layer() + layer._handle_graph_run_started() + layer._handle_graph_run_succeeded(GraphRunSucceededEvent(outputs={"ok": True})) + assert exec_repo.saved + assert not trace_tasks diff --git a/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py b/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py index 3759b6aa37a..5ff9774b525 100644 --- a/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py +++ b/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py @@ -28,10 +28,7 @@ def mock_model_instance(mocker): def mock_model_manager(mocker, mock_model_instance): manager = mocker.MagicMock() manager.get_default_model_instance.return_value = mock_model_instance - mocker.patch( - "core.base.tts.app_generator_tts_publisher.ModelManager", - return_value=manager, - ) + mocker.patch("core.base.tts.app_generator_tts_publisher.ModelManager.for_tenant", return_value=manager) return manager @@ -64,16 +61,14 @@ class TestInvoiceTTS: [None, "", " "], ) def test_invoice_tts_empty_or_none_returns_none(self, text, mock_model_instance): - result = _invoice_tts(text, mock_model_instance, "tenant", "voice1") + result = _invoice_tts(text, mock_model_instance, "voice1") assert result is None mock_model_instance.invoke_tts.assert_not_called() def test_invoice_tts_valid_text(self, mock_model_instance): - result = _invoice_tts(" hello ", mock_model_instance, "tenant", "voice1") + result = _invoice_tts(" hello ", mock_model_instance, "voice1") mock_model_instance.invoke_tts.assert_called_once_with( content_text="hello", - user="responding_tts", - tenant_id="tenant", voice="voice1", ) assert result == [b"audio1", b"audio2"] @@ -306,14 +301,15 @@ class TestAppGeneratorTTSPublisher: publisher = AppGeneratorTTSPublisher("tenant", "voice1") publisher.executor = MagicMock() - from core.app.entities.queue_entities import QueueAgentMessageEvent - from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta - from dify_graph.model_runtime.entities.message_entities import ( + from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta + from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, TextPromptMessageContent, ) + from core.app.entities.queue_entities import QueueAgentMessageEvent + chunk = LLMResultChunk( model="model", delta=LLMResultChunkDelta( @@ -341,9 +337,10 @@ class TestAppGeneratorTTSPublisher: publisher = AppGeneratorTTSPublisher("tenant", "voice1") publisher.executor = MagicMock() + from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta + from graphon.model_runtime.entities.message_entities import AssistantPromptMessage + from core.app.entities.queue_entities import QueueAgentMessageEvent - from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta - from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage chunk = LLMResultChunk( model="model", diff --git a/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py b/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py index b37c4c57a1f..8e5670e9be3 100644 --- a/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py +++ b/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py @@ -114,13 +114,9 @@ class TestOnToolEnd: document = mocker.Mock() document.metadata = {"document_id": "doc-1", "doc_id": "node-1"} - mock_query = mocker.Mock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - handler.on_tool_end([document]) - mock_query.update.assert_called_once() + mock_db.session.execute.assert_called_once() mock_db.session.commit.assert_called_once() def test_on_tool_end_non_parent_child_index(self, handler, mocker): @@ -138,13 +134,9 @@ class TestOnToolEnd: "dataset_id": "dataset-1", } - mock_query = mocker.Mock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - handler.on_tool_end([document]) - mock_query.update.assert_called_once() + mock_db.session.execute.assert_called_once() mock_db.session.commit.assert_called_once() def test_on_tool_end_empty_documents(self, handler): diff --git a/api/tests/unit_tests/core/datasource/test_datasource_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_manager.py index d5eeae912cb..b0c72ee42f5 100644 --- a/api/tests/unit_tests/core/datasource/test_datasource_manager.py +++ b/api/tests/unit_tests/core/datasource/test_datasource_manager.py @@ -2,15 +2,15 @@ import types from collections.abc import Generator import pytest +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType +from graphon.node_events import StreamChunkEvent, StreamCompletedEvent from contexts.wrapper import RecyclableContextVar from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceMessage, DatasourceProviderType from core.datasource.errors import DatasourceProviderNotFoundError -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.file import File -from dify_graph.file.enums import FileTransferMethod, FileType -from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent +from core.workflow.file_reference import parse_file_reference def _gen_messages_text_only(text: str) -> Generator[DatasourceMessage, None, None]: @@ -428,11 +428,8 @@ def test_stream_node_events_builds_file_and_variables_from_messages(mocker): return fake_tool_file mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_Session()) - mocker.patch( - "core.datasource.datasource_manager.file_factory.get_file_type_by_mime_type", return_value=FileType.IMAGE - ) + mocker.patch("core.datasource.datasource_manager.get_file_type_by_mime_type", return_value=FileType.IMAGE) built = File( - tenant_id="t1", type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id="tool_file_1", @@ -533,7 +530,6 @@ def test_stream_node_events_online_drive_sets_variable_pool_file_and_outputs(moc mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored")) file_in = File( - tenant_id="t1", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.TOOL_FILE, related_id="tf", @@ -664,6 +660,8 @@ def test_get_upload_file_by_id_builds_file(mocker): f = DatasourceManager.get_upload_file_by_id(file_id="fid", tenant_id="t1") assert f.related_id == "fid" assert f.extension == ".txt" + assert parse_file_reference(f.reference).storage_key is None + assert f.storage_key == "k" def test_get_upload_file_by_id_raises_when_missing(mocker): diff --git a/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py b/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py index 43f582feb75..fbaf6d497d7 100644 --- a/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py +++ b/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py @@ -1,11 +1,10 @@ from unittest.mock import MagicMock, patch import pytest +from graphon.file import File, FileTransferMethod, FileType from core.datasource.entities.datasource_entities import DatasourceMessage from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer -from dify_graph.file import File -from dify_graph.file.enums import FileTransferMethod, FileType from models.tools import ToolFile diff --git a/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py b/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py index 2e4f6d34fb8..ff9fd0d8f3f 100644 --- a/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py +++ b/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py @@ -1,11 +1,12 @@ +from graphon.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType + from core.entities.execution_extra_content import ( ExecutionExtraContentDomainModel, HumanInputContent, HumanInputFormDefinition, HumanInputFormSubmissionData, ) -from dify_graph.nodes.human_input.entities import FormInput, UserAction -from dify_graph.nodes.human_input.enums import FormInputType from models.execution_extra_content import ExecutionContentType diff --git a/api/tests/unit_tests/core/entities/test_entities_model_entities.py b/api/tests/unit_tests/core/entities/test_entities_model_entities.py index 7a3d5e84ed0..2acd278a31e 100644 --- a/api/tests/unit_tests/core/entities/test_entities_model_entities.py +++ b/api/tests/unit_tests/core/entities/test_entities_model_entities.py @@ -8,6 +8,9 @@ drive provider mapping behavior. """ import pytest +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity from core.entities.model_entities import ( DefaultModelEntity, @@ -16,9 +19,6 @@ from core.entities.model_entities import ( ProviderModelWithStatusEntity, SimpleModelProviderEntity, ) -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import FetchFrom, ModelType -from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity def _build_model_with_status(status: ModelStatus) -> ProviderModelWithStatusEntity: diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py index 95d58757f10..8cf0409c4c2 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py @@ -6,6 +6,17 @@ from typing import Any from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FieldModelSchema, + FormType, + ModelCredentialSchema, + ProviderCredentialSchema, + ProviderEntity, +) from constants import HIDDEN_VALUE from core.entities.model_entities import ModelStatus @@ -24,17 +35,6 @@ from core.entities.provider_entities import ( SystemConfiguration, SystemConfigurationStatus, ) -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from dify_graph.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FieldModelSchema, - FormType, - ModelCredentialSchema, - ProviderCredentialSchema, - ProviderEntity, -) from models.enums import CredentialSourceType from models.provider import ProviderType from models.provider_ids import ModelProviderID @@ -350,7 +350,7 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None: mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "restored-key", "region": "us"} - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"): with patch( "core.entities.provider_configuration.encrypter.encrypt_token", @@ -380,7 +380,9 @@ def test_validate_provider_credentials_opens_session_when_not_passed() -> None: with patch("core.entities.provider_configuration.db") as mock_db: mock_db.engine = Mock() mock_session_cls.return_value.__enter__.return_value = mock_session - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): validated = configuration.validate_provider_credentials(credentials={"region": "us"}) assert validated == {"region": "us"} @@ -434,12 +436,16 @@ def test_get_model_type_instance_and_schema_delegate_to_factory() -> None: mock_factory.get_model_type_instance.return_value = mock_model_type_instance mock_factory.get_model_schema.return_value = mock_schema - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", + return_value=mock_factory, + ) as mock_factory_builder: model_type_instance = configuration.get_model_type_instance(ModelType.LLM) model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"}) assert model_type_instance is mock_model_type_instance assert model_schema is mock_schema + assert mock_factory_builder.call_count == 2 mock_factory.get_model_type_instance.assert_called_once_with(provider="openai", model_type=ModelType.LLM) mock_factory.get_model_schema.assert_called_once_with( provider="openai", @@ -449,6 +455,33 @@ def test_get_model_type_instance_and_schema_delegate_to_factory() -> None: ) +def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> None: + configuration = _build_provider_configuration() + bound_runtime = Mock() + configuration.bind_model_runtime(bound_runtime) + + mock_factory = Mock() + mock_model_type_instance = Mock() + mock_schema = _build_ai_model("gpt-4o") + mock_factory.get_model_type_instance.return_value = mock_model_type_instance + mock_factory.get_model_schema.return_value = mock_schema + + with ( + patch( + "core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory + ) as mock_factory_cls, + patch("core.entities.provider_configuration.create_plugin_model_provider_factory") as mock_factory_builder, + ): + model_type_instance = configuration.get_model_type_instance(ModelType.LLM) + model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"}) + + assert model_type_instance is mock_model_type_instance + assert model_schema is mock_schema + assert mock_factory_cls.call_count == 2 + mock_factory_cls.assert_called_with(model_runtime=bound_runtime) + mock_factory_builder.assert_not_called() + + def test_get_provider_model_returns_none_when_model_not_found() -> None: configuration = _build_provider_configuration() fake_model = SimpleNamespace(model="other-model") @@ -475,7 +508,7 @@ def test_get_provider_models_system_deduplicates_sorts_and_filters_active() -> N mock_factory = Mock() mock_factory.get_provider_schema.return_value = provider_schema - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): all_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=False) active_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=True) @@ -689,7 +722,7 @@ def test_validate_provider_credentials_handles_invalid_original_json() -> None: mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "new-key"} - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"): validated = configuration.validate_provider_credentials( credentials={"openai_api_key": HIDDEN_VALUE}, @@ -1034,7 +1067,7 @@ def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless mock_factory = Mock() mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"): with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): validated = configuration.validate_custom_model_credentials( @@ -1050,7 +1083,9 @@ def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless mock_factory = Mock() mock_factory.model_credentials_validate.return_value = {"region": "us"} with _patched_session(session): - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): validated = configuration.validate_custom_model_credentials( model_type=ModelType.LLM, model="gpt-4o", @@ -1540,7 +1575,7 @@ def test_validate_provider_credentials_uses_empty_original_when_record_missing() mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): validated = configuration.validate_provider_credentials( credentials={"openai_api_key": HIDDEN_VALUE}, @@ -1662,7 +1697,7 @@ def test_validate_custom_model_credentials_handles_invalid_original_json() -> No mock_factory = Mock() mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): validated = configuration.validate_custom_model_credentials( model_type=ModelType.LLM, diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_entities.py b/api/tests/unit_tests/core/entities/test_entities_provider_entities.py index c5bfd05a1ef..8685d162831 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_entities.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_entities.py @@ -1,4 +1,5 @@ import pytest +from graphon.model_runtime.entities.model_entities import ModelType from core.entities.parameter_entities import AppSelectorScope from core.entities.provider_entities import ( @@ -8,7 +9,6 @@ from core.entities.provider_entities import ( ProviderQuotaType, ) from core.tools.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import ModelType def test_provider_quota_type_value_of_returns_enum_member() -> None: diff --git a/api/tests/unit_tests/core/file/test_models.py b/api/tests/unit_tests/core/file/test_models.py index deebf41320c..bb6e40e2247 100644 --- a/api/tests/unit_tests/core/file/test_models.py +++ b/api/tests/unit_tests/core/file/test_models.py @@ -1,4 +1,4 @@ -from dify_graph.file import File, FileTransferMethod, FileType +from graphon.file import File, FileTransferMethod, FileType def test_file(): @@ -15,18 +15,17 @@ def test_file(): storage_key="test-storage-key", url="https://example.com/image.png", ) - assert file.tenant_id == "test-tenant-id" assert file.type == FileType.IMAGE assert file.transfer_method == FileTransferMethod.TOOL_FILE assert file.related_id == "test-related-id" + assert file.storage_key == "test-storage-key" assert file.filename == "image.png" assert file.extension == ".png" assert file.mime_type == "image/png" assert file.size == 67 -def test_file_model_validate_with_legacy_fields(): - """Test `File` model can handle data containing compatibility fields.""" +def test_file_model_validate_accepts_legacy_tenant_id(): data = { "id": "test-file", "tenant_id": "test-tenant-id", @@ -45,10 +44,8 @@ def test_file_model_validate_with_legacy_fields(): "datasource_file_id": "datasource-file-789", } - # Should be able to create `File` object without raising an exception file = File.model_validate(data) - # The File object does not have tool_file_id, upload_file_id, or datasource_file_id as attributes. - # Instead, check it does not expose unrecognized legacy fields (should raise on getattr). - for legacy_field in ("tool_file_id", "upload_file_id", "datasource_file_id"): - assert not hasattr(file, legacy_field) + assert file.related_id == "test-related-id" + assert file.storage_key == "test-storage-key" + assert "tenant_id" not in file.model_dump() diff --git a/api/tests/unit_tests/core/helper/test_encrypter.py b/api/tests/unit_tests/core/helper/test_encrypter.py index 58900097428..f3ef7fccd0f 100644 --- a/api/tests/unit_tests/core/helper/test_encrypter.py +++ b/api/tests/unit_tests/core/helper/test_encrypter.py @@ -38,13 +38,13 @@ class TestObfuscatedToken: class TestEncryptToken: - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") def test_successful_encryption(self, mock_encrypt, mock_query): """Test successful token encryption""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_query.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_data" result = encrypt_token("tenant-123", "test_token") @@ -52,10 +52,10 @@ class TestEncryptToken: assert result == base64.b64encode(b"encrypted_data").decode() mock_encrypt.assert_called_with("test_token", "mock_public_key") - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") def test_tenant_not_found(self, mock_query): """Test error when tenant doesn't exist""" - mock_query.return_value.where.return_value.first.return_value = None + mock_query.return_value = None with pytest.raises(ValueError) as exc_info: encrypt_token("invalid-tenant", "test_token") @@ -119,7 +119,7 @@ class TestGetDecryptDecoding: class TestEncryptDecryptIntegration: - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") @patch("libs.rsa.decrypt") def test_should_encrypt_and_decrypt_consistently(self, mock_decrypt, mock_encrypt, mock_query): @@ -127,7 +127,7 @@ class TestEncryptDecryptIntegration: # Setup mock tenant mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_query.return_value = mock_tenant # Setup mock encryption/decryption original_token = "test_token_123" @@ -146,14 +146,14 @@ class TestEncryptDecryptIntegration: class TestSecurity: """Critical security tests for encryption system""" - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") def test_cross_tenant_isolation(self, mock_encrypt, mock_query): """Ensure tokens encrypted for one tenant cannot be used by another""" # Setup mock tenant mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "tenant1_public_key" - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_query.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_for_tenant1" # Encrypt token for tenant1 @@ -181,12 +181,12 @@ class TestSecurity: with pytest.raises(Exception, match="Decryption error"): decrypt_token("tenant-123", tampered) - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") def test_encryption_randomness(self, mock_encrypt, mock_query): """Ensure same plaintext produces different ciphertext""" mock_tenant = MagicMock(encrypt_public_key="key") - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_query.return_value = mock_tenant # Different outputs for same input mock_encrypt.side_effect = [b"enc1", b"enc2", b"enc3"] @@ -205,13 +205,13 @@ class TestEdgeCases: # Test empty string (which is a valid str type) assert obfuscated_token("") == "" - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") def test_should_handle_empty_token_encryption(self, mock_encrypt, mock_query): """Test encryption of empty token""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_query.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_empty" result = encrypt_token("tenant-123", "") @@ -219,13 +219,13 @@ class TestEdgeCases: assert result == base64.b64encode(b"encrypted_empty").decode() mock_encrypt.assert_called_with("", "mock_public_key") - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") def test_should_handle_special_characters_in_token(self, mock_encrypt, mock_query): """Test tokens containing special/unicode characters""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_query.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_special" # Test various special characters @@ -242,13 +242,13 @@ class TestEdgeCases: assert result == base64.b64encode(b"encrypted_special").decode() mock_encrypt.assert_called_with(token, "mock_public_key") - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") def test_should_handle_rsa_size_limits(self, mock_encrypt, mock_query): """Test behavior when token exceeds RSA encryption limits""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_query.return_value = mock_tenant # RSA 2048-bit can only encrypt ~245 bytes # The actual limit depends on padding scheme diff --git a/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py b/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py index 46c9dc6f9c6..b45f6fd9a77 100644 --- a/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py +++ b/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py @@ -2,6 +2,20 @@ import json from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMResultWithStructuredOutput, + LLMUsage, +) +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from graphon.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.output_parser.structured_output import ( @@ -16,20 +30,6 @@ from core.llm_generator.output_parser.structured_output import ( remove_additional_properties, ) from core.model_manager import ModelInstance -from dify_graph.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, - LLMResultChunkDelta, - LLMResultWithStructuredOutput, - LLMUsage, -) -from dify_graph.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - SystemPromptMessage, - TextPromptMessageContent, - UserPromptMessage, -) -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType class TestStructuredOutput: diff --git a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py index 5b7640696f1..62e714deb61 100644 --- a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py +++ b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py @@ -2,18 +2,18 @@ import json from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.app.app_config.entities import ModelConfig from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload from core.llm_generator.llm_generator import LLMGenerator -from dify_graph.model_runtime.entities.llm_entities import LLMMode, LLMResult -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError class TestLLMGenerator: @pytest.fixture def mock_model_instance(self): - with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager: + with patch("core.llm_generator.llm_generator.ModelManager.for_tenant") as mock_manager: instance = MagicMock() mock_manager.return_value.get_default_model_instance.return_value = instance mock_manager.return_value.get_model_instance.return_value = instance @@ -98,7 +98,7 @@ class TestLLMGenerator: assert questions[0] == "Question 1?" def test_generate_suggested_questions_after_answer_auth_error(self, mock_model_instance): - with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager: + with patch("core.llm_generator.llm_generator.ModelManager.for_tenant") as mock_manager: mock_manager.return_value.get_default_model_instance.side_effect = InvokeAuthorizationError("Auth failed") questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories") assert questions == [] @@ -314,8 +314,8 @@ class TestLLMGenerator: assert "An unexpected error occurred" in result["error"] def test_instruction_modify_legacy_no_last_run(self, mock_model_instance, model_config_entity): - with patch("extensions.ext_database.db.session.query") as mock_query: - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None # Mock __instruction_modify_common call via invoke_llm mock_response = MagicMock() @@ -328,12 +328,12 @@ class TestLLMGenerator: assert result == {"modified": "prompt"} def test_instruction_modify_legacy_with_last_run(self, mock_model_instance, model_config_entity): - with patch("extensions.ext_database.db.session.query") as mock_query: + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: last_run = MagicMock() last_run.query = "q" last_run.answer = "a" last_run.error = "e" - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = last_run + mock_scalar.return_value = last_run mock_response = MagicMock() mock_response.message.get_text_content.return_value = '{"modified": "prompt"}' @@ -483,8 +483,8 @@ class TestLLMGenerator: def test_instruction_modify_common_placeholders(self, mock_model_instance, model_config_entity): # Testing placeholders replacement via instruction_modify_legacy for convenience - with patch("extensions.ext_database.db.session.query") as mock_query: - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None mock_response = MagicMock() mock_response.message.get_text_content.return_value = '{"ok": true}' @@ -504,8 +504,8 @@ class TestLLMGenerator: assert "current_val" in user_msg_dict["instruction"] def test_instruction_modify_common_no_braces(self, mock_model_instance, model_config_entity): - with patch("extensions.ext_database.db.session.query") as mock_query: - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None mock_response = MagicMock() mock_response.message.get_text_content.return_value = "No braces here" mock_model_instance.invoke_llm.return_value = mock_response @@ -516,8 +516,8 @@ class TestLLMGenerator: assert "Could not find a valid JSON object" in result["error"] def test_instruction_modify_common_not_dict(self, mock_model_instance, model_config_entity): - with patch("extensions.ext_database.db.session.query") as mock_query: - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None mock_response = MagicMock() mock_response.message.get_text_content.return_value = "[1, 2, 3]" mock_model_instance.invoke_llm.return_value = mock_response @@ -528,7 +528,7 @@ class TestLLMGenerator: assert "An unexpected error occurred" in result["error"] def test_instruction_modify_common_other_node_type(self, mock_model_instance, model_config_entity): - with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager: + with patch("core.llm_generator.llm_generator.ModelManager.for_tenant") as mock_manager: instance = MagicMock() mock_manager.return_value.get_model_instance.return_value = instance mock_response = MagicMock() @@ -556,8 +556,8 @@ class TestLLMGenerator: ) def test_instruction_modify_common_invoke_error(self, mock_model_instance, model_config_entity): - with patch("extensions.ext_database.db.session.query") as mock_query: - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke Failed") result = LLMGenerator.instruction_modify_legacy( @@ -566,8 +566,8 @@ class TestLLMGenerator: assert "Failed to generate code" in result["error"] def test_instruction_modify_common_exception(self, mock_model_instance, model_config_entity): - with patch("extensions.ext_database.db.session.query") as mock_query: - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None mock_model_instance.invoke_llm.side_effect = Exception("Random error") result = LLMGenerator.instruction_modify_legacy( @@ -576,8 +576,8 @@ class TestLLMGenerator: assert "An unexpected error occurred" in result["error"] def test_instruction_modify_common_json_error(self, mock_model_instance, model_config_entity): - with patch("extensions.ext_database.db.session.query") as mock_query: - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None mock_response = MagicMock() mock_response.message.get_text_content.return_value = "No JSON here" diff --git a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py index f982765b1a5..313d18c695d 100644 --- a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py +++ b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py @@ -3,6 +3,7 @@ from unittest.mock import Mock, patch import jsonschema import pytest +from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from core.mcp import types @@ -18,7 +19,6 @@ from core.mcp.server.streamable_http import ( prepare_tool_arguments, process_mapping_response, ) -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType from models.model import App, AppMCPServer, AppMode, EndUser diff --git a/api/tests/unit_tests/core/memory/test_token_buffer_memory.py b/api/tests/unit_tests/core/memory/test_token_buffer_memory.py index 5ecfe018088..9a5fb319d7b 100644 --- a/api/tests/unit_tests/core/memory/test_token_buffer_memory.py +++ b/api/tests/unit_tests/core/memory/test_token_buffer_memory.py @@ -4,15 +4,15 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest - -from core.memory.token_buffer_memory import TokenBufferMemory -from dify_graph.model_runtime.entities import ( +from graphon.model_runtime.entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessageRole, TextPromptMessageContent, UserPromptMessage, ) + +from core.memory.token_buffer_memory import TokenBufferMemory from models.model import AppMode # --------------------------------------------------------------------------- diff --git a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py new file mode 100644 index 00000000000..6a672fdfd57 --- /dev/null +++ b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py @@ -0,0 +1,419 @@ +from unittest.mock import Mock + +import pytest +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FieldModelSchema, + FormType, + ModelCredentialSchema, + ProviderCredentialSchema, + ProviderEntity, +) +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel +from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel +from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel +from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from graphon.model_runtime.model_providers.__base.tts_model import TTSModel +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory + + +def _build_model(model: str, model_type: ModelType) -> AIModelEntity: + return AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=model_type, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={}, + ) + + +def _build_provider( + *, + provider: str, + provider_name: str, + supported_model_types: list[ModelType], + models: list[AIModelEntity] | None = None, + provider_credential_schema: ProviderCredentialSchema | None = None, + model_credential_schema: ModelCredentialSchema | None = None, +) -> ProviderEntity: + return ProviderEntity( + provider=provider, + provider_name=provider_name, + label=I18nObject(en_US=provider_name or provider), + supported_model_types=supported_model_types, + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + models=models or [], + provider_credential_schema=provider_credential_schema, + model_credential_schema=model_credential_schema, + ) + + +class _FakeModelRuntime: + def __init__(self, providers: list[ProviderEntity]) -> None: + self._providers = providers + self.validate_provider_credentials = Mock() + self.validate_model_credentials = Mock() + self.get_model_schema = Mock() + self.get_provider_icon = Mock() + + def fetch_model_providers(self) -> list[ProviderEntity]: + return self._providers + + +def test_model_provider_factory_resolves_runtime_provider_name() -> None: + provider = ProviderEntity( + provider="langgenius/openai/openai", + provider_name="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ) + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([provider])) + + provider_schema = factory.get_model_provider("openai") + + assert provider_schema.provider == "langgenius/openai/openai" + assert provider_schema.provider_name == "openai" + + +def test_model_provider_factory_resolves_canonical_short_name_independent_of_provider_order() -> None: + providers = [ + ProviderEntity( + provider="acme/openai/openai", + provider_name="", + label=I18nObject(en_US="Acme OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ProviderEntity( + provider="langgenius/openai/openai", + provider_name="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ] + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + + provider_schema = factory.get_model_provider("openai") + + assert provider_schema.provider == "langgenius/openai/openai" + assert provider_schema.provider_name == "openai" + + +def test_model_provider_factory_requires_runtime() -> None: + with pytest.raises(ValueError, match="model_runtime is required"): + ModelProviderFactory(model_runtime=None) # type: ignore[arg-type] + + +def test_model_provider_factory_get_providers_returns_runtime_providers() -> None: + providers = [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + + result = factory.get_providers() + + assert list(result) == providers + assert result is not providers + + +def test_model_provider_factory_get_provider_schema_delegates_to_provider_lookup() -> None: + provider = _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([provider])) + + result = factory.get_provider_schema("openai") + + assert result is provider + + +def test_model_provider_factory_raises_for_unknown_provider() -> None: + factory = ModelProviderFactory( + model_runtime=_FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + ) + ) + + with pytest.raises(ValueError, match="Invalid provider: anthropic"): + factory.get_model_provider("anthropic") + + +def test_model_provider_factory_get_models_filters_provider_and_model_type() -> None: + providers = [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM, ModelType.TTS], + models=[_build_model("gpt-4o-mini", ModelType.LLM), _build_model("tts-1", ModelType.TTS)], + ), + _build_provider( + provider="langgenius/cohere/cohere", + provider_name="cohere", + supported_model_types=[ModelType.RERANK], + models=[_build_model("rerank-v3", ModelType.RERANK)], + ), + ] + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + + results = factory.get_models(provider="openai", model_type=ModelType.LLM) + + assert len(results) == 1 + assert results[0].provider == "langgenius/openai/openai" + assert [model.model for model in results[0].models] == ["gpt-4o-mini"] + + +def test_model_provider_factory_get_models_skips_providers_without_requested_model_type() -> None: + providers = [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + models=[_build_model("gpt-4o-mini", ModelType.LLM)], + ), + _build_provider( + provider="langgenius/elevenlabs/elevenlabs", + provider_name="elevenlabs", + supported_model_types=[ModelType.TTS], + models=[_build_model("eleven_multilingual_v2", ModelType.TTS)], + ), + ] + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + + results = factory.get_models(model_type=ModelType.TTS) + + assert len(results) == 1 + assert results[0].provider == "langgenius/elevenlabs/elevenlabs" + assert [model.model for model in results[0].models] == ["eleven_multilingual_v2"] + + +def test_model_provider_factory_get_models_without_model_type_keeps_all_provider_models() -> None: + providers = [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM, ModelType.TTS], + models=[_build_model("gpt-4o-mini", ModelType.LLM), _build_model("tts-1", ModelType.TTS)], + ) + ] + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + + results = factory.get_models(provider="openai") + + assert len(results) == 1 + assert [model.model for model in results[0].models] == ["gpt-4o-mini", "tts-1"] + + +def test_model_provider_factory_validates_provider_credentials() -> None: + runtime = _FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + provider_credential_schema=ProviderCredentialSchema( + credential_form_schemas=[ + CredentialFormSchema( + variable="api_key", + label=I18nObject(en_US="API key"), + type=FormType.SECRET_INPUT, + required=True, + ) + ] + ), + ) + ] + ) + factory = ModelProviderFactory(model_runtime=runtime) + + filtered = factory.provider_credentials_validate( + provider="openai", + credentials={"api_key": "secret", "ignored": "value"}, + ) + + assert filtered == {"api_key": "secret"} + runtime.validate_provider_credentials.assert_called_once_with( + provider="langgenius/openai/openai", + credentials={"api_key": "secret"}, + ) + + +def test_model_provider_factory_provider_credentials_validate_requires_schema() -> None: + factory = ModelProviderFactory( + model_runtime=_FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + ) + ) + + with pytest.raises(ValueError, match="Provider openai does not have provider_credential_schema"): + factory.provider_credentials_validate(provider="openai", credentials={"api_key": "secret"}) + + +def test_model_provider_factory_validates_model_credentials() -> None: + runtime = _FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + model_credential_schema=ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model")), + credential_form_schemas=[ + CredentialFormSchema( + variable="api_key", + label=I18nObject(en_US="API key"), + type=FormType.SECRET_INPUT, + required=True, + ) + ], + ), + ) + ] + ) + factory = ModelProviderFactory(model_runtime=runtime) + + filtered = factory.model_credentials_validate( + provider="openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "secret", "ignored": "value"}, + ) + + assert filtered == {"api_key": "secret"} + runtime.validate_model_credentials.assert_called_once_with( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + ) + + +def test_model_provider_factory_model_credentials_validate_requires_schema() -> None: + factory = ModelProviderFactory( + model_runtime=_FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + ) + ) + + with pytest.raises(ValueError, match="Provider openai does not have model_credential_schema"): + factory.model_credentials_validate( + provider="openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + ) + + +def test_model_provider_factory_get_model_schema_and_icon_use_canonical_provider() -> None: + runtime = _FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + ) + runtime.get_model_schema.return_value = "schema" + runtime.get_provider_icon.return_value = (b"icon", "image/png") + factory = ModelProviderFactory(model_runtime=runtime) + + assert ( + factory.get_model_schema( + provider="openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials=None, + ) + == "schema" + ) + assert factory.get_provider_icon("openai", "icon_small", "en_US") == (b"icon", "image/png") + runtime.get_model_schema.assert_called_once_with( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={}, + ) + runtime.get_provider_icon.assert_called_once_with( + provider="langgenius/openai/openai", + icon_type="icon_small", + lang="en_US", + ) + + +@pytest.mark.parametrize( + ("model_type", "expected_type"), + [ + (ModelType.LLM, LargeLanguageModel), + (ModelType.TEXT_EMBEDDING, TextEmbeddingModel), + (ModelType.RERANK, RerankModel), + (ModelType.SPEECH2TEXT, Speech2TextModel), + (ModelType.MODERATION, ModerationModel), + (ModelType.TTS, TTSModel), + ], +) +def test_model_provider_factory_builds_model_type_instances( + model_type: ModelType, + expected_type: type[object], +) -> None: + factory = ModelProviderFactory( + model_runtime=_FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[model_type], + ) + ] + ) + ) + + instance = factory.get_model_type_instance("openai", model_type) + + assert isinstance(instance, expected_type) + + +def test_model_provider_factory_rejects_unsupported_model_type() -> None: + factory = ModelProviderFactory( + model_runtime=_FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + ) + ) + + with pytest.raises(ValueError, match="Unsupported model type: unsupported"): + factory.get_model_type_instance("openai", "unsupported") # type: ignore[arg-type] diff --git a/api/tests/unit_tests/core/moderation/api/test_api.py b/api/tests/unit_tests/core/moderation/api/test_api.py new file mode 100644 index 00000000000..558b20e5f88 --- /dev/null +++ b/api/tests/unit_tests/core/moderation/api/test_api.py @@ -0,0 +1,181 @@ +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.extension.api_based_extension_requestor import APIBasedExtensionPoint +from core.moderation.api.api import ApiModeration, ModerationInputParams, ModerationOutputParams +from core.moderation.base import ModerationAction, ModerationInputsResult, ModerationOutputsResult +from models.api_based_extension import APIBasedExtension + + +class TestApiModeration: + @pytest.fixture + def api_config(self): + return { + "inputs_config": { + "enabled": True, + }, + "outputs_config": { + "enabled": True, + }, + "api_based_extension_id": "test-extension-id", + } + + @pytest.fixture + def api_moderation(self, api_config): + return ApiModeration(app_id="test-app-id", tenant_id="test-tenant-id", config=api_config) + + def test_moderation_input_params(self): + params = ModerationInputParams(app_id="app-1", inputs={"key": "val"}, query="test query") + assert params.app_id == "app-1" + assert params.inputs == {"key": "val"} + assert params.query == "test query" + + # Test defaults + params_default = ModerationInputParams() + assert params_default.app_id == "" + assert params_default.inputs == {} + assert params_default.query == "" + + def test_moderation_output_params(self): + params = ModerationOutputParams(app_id="app-1", text="test text") + assert params.app_id == "app-1" + assert params.text == "test text" + + with pytest.raises(ValidationError): + ModerationOutputParams() + + @patch("core.moderation.api.api.ApiModeration._get_api_based_extension") + def test_validate_config_success(self, mock_get_extension, api_config): + mock_get_extension.return_value = MagicMock(spec=APIBasedExtension) + ApiModeration.validate_config("test-tenant-id", api_config) + mock_get_extension.assert_called_once_with("test-tenant-id", "test-extension-id") + + def test_validate_config_missing_extension_id(self): + config = { + "inputs_config": {"enabled": True}, + "outputs_config": {"enabled": True}, + } + with pytest.raises(ValueError, match="api_based_extension_id is required"): + ApiModeration.validate_config("test-tenant-id", config) + + @patch("core.moderation.api.api.ApiModeration._get_api_based_extension") + def test_validate_config_extension_not_found(self, mock_get_extension, api_config): + mock_get_extension.return_value = None + with pytest.raises(ValueError, match="API-based Extension not found"): + ApiModeration.validate_config("test-tenant-id", api_config) + + @patch("core.moderation.api.api.ApiModeration._get_config_by_requestor") + def test_moderation_for_inputs_enabled(self, mock_get_config, api_moderation): + mock_get_config.return_value = {"flagged": True, "action": "direct_output", "preset_response": "Blocked by API"} + + result = api_moderation.moderation_for_inputs(inputs={"q": "a"}, query="hello") + + assert isinstance(result, ModerationInputsResult) + assert result.flagged is True + assert result.action == ModerationAction.DIRECT_OUTPUT + assert result.preset_response == "Blocked by API" + + mock_get_config.assert_called_once_with( + APIBasedExtensionPoint.APP_MODERATION_INPUT, + {"app_id": "test-app-id", "inputs": {"q": "a"}, "query": "hello"}, + ) + + def test_moderation_for_inputs_disabled(self): + config = { + "inputs_config": {"enabled": False}, + "outputs_config": {"enabled": True}, + "api_based_extension_id": "ext-id", + } + moderation = ApiModeration("app-id", "tenant-id", config) + result = moderation.moderation_for_inputs(inputs={}, query="") + + assert result.flagged is False + assert result.action == ModerationAction.DIRECT_OUTPUT + assert result.preset_response == "" + + def test_moderation_for_inputs_no_config(self): + moderation = ApiModeration("app-id", "tenant-id", None) + with pytest.raises(ValueError, match="The config is not set"): + moderation.moderation_for_inputs({}, "") + + @patch("core.moderation.api.api.ApiModeration._get_config_by_requestor") + def test_moderation_for_outputs_enabled(self, mock_get_config, api_moderation): + mock_get_config.return_value = {"flagged": False, "action": "direct_output", "preset_response": ""} + + result = api_moderation.moderation_for_outputs(text="hello world") + + assert isinstance(result, ModerationOutputsResult) + assert result.flagged is False + + mock_get_config.assert_called_once_with( + APIBasedExtensionPoint.APP_MODERATION_OUTPUT, {"app_id": "test-app-id", "text": "hello world"} + ) + + def test_moderation_for_outputs_disabled(self): + config = { + "inputs_config": {"enabled": True}, + "outputs_config": {"enabled": False}, + "api_based_extension_id": "ext-id", + } + moderation = ApiModeration("app-id", "tenant-id", config) + result = moderation.moderation_for_outputs(text="test") + + assert result.flagged is False + assert result.action == ModerationAction.DIRECT_OUTPUT + + def test_moderation_for_outputs_no_config(self): + moderation = ApiModeration("app-id", "tenant-id", None) + with pytest.raises(ValueError, match="The config is not set"): + moderation.moderation_for_outputs("test") + + @patch("core.moderation.api.api.ApiModeration._get_api_based_extension") + @patch("core.moderation.api.api.decrypt_token") + @patch("core.moderation.api.api.APIBasedExtensionRequestor") + def test_get_config_by_requestor_success(self, mock_requestor_cls, mock_decrypt, mock_get_ext, api_moderation): + mock_ext = MagicMock(spec=APIBasedExtension) + mock_ext.api_endpoint = "http://api.test" + mock_ext.api_key = "encrypted-key" + mock_get_ext.return_value = mock_ext + + mock_decrypt.return_value = "decrypted-key" + + mock_requestor = MagicMock() + mock_requestor.request.return_value = {"flagged": True} + mock_requestor_cls.return_value = mock_requestor + + params = {"some": "params"} + result = api_moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params) + + assert result == {"flagged": True} + mock_get_ext.assert_called_once_with("test-tenant-id", "test-extension-id") + mock_decrypt.assert_called_once_with("test-tenant-id", "encrypted-key") + mock_requestor_cls.assert_called_once_with("http://api.test", "decrypted-key") + mock_requestor.request.assert_called_once_with(APIBasedExtensionPoint.APP_MODERATION_INPUT, params) + + def test_get_config_by_requestor_no_config(self): + moderation = ApiModeration("app-id", "tenant-id", None) + with pytest.raises(ValueError, match="The config is not set"): + moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, {}) + + @patch("core.moderation.api.api.ApiModeration._get_api_based_extension") + def test_get_config_by_requestor_extension_not_found(self, mock_get_ext, api_moderation): + mock_get_ext.return_value = None + with pytest.raises(ValueError, match="API-based Extension not found"): + api_moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, {}) + + @patch("core.moderation.api.api.db.session.scalar") + def test_get_api_based_extension(self, mock_scalar): + mock_ext = MagicMock(spec=APIBasedExtension) + mock_scalar.return_value = mock_ext + + result = ApiModeration._get_api_based_extension("tenant-1", "ext-1") + + assert result == mock_ext + mock_scalar.assert_called_once() + # Verify the call has the correct filters + args, kwargs = mock_scalar.call_args + stmt = args[0] + # We can't easily inspect the statement without complex sqlalchemy tricks, + # but calling it is usually enough for unit tests if we mock the result. diff --git a/api/tests/unit_tests/core/moderation/test_content_moderation.py b/api/tests/unit_tests/core/moderation/test_content_moderation.py index e61cde22e73..3a97ad5c5d2 100644 --- a/api/tests/unit_tests/core/moderation/test_content_moderation.py +++ b/api/tests/unit_tests/core/moderation/test_content_moderation.py @@ -324,7 +324,7 @@ class TestOpenAIModeration: with pytest.raises(ValueError, match="At least one of inputs_config or outputs_config must be enabled"): OpenAIModeration.validate_config("test-tenant", config) - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_inputs_no_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration): """Test input moderation when OpenAI API returns no violations.""" # Mock the model manager and instance @@ -341,7 +341,7 @@ class TestOpenAIModeration: assert result.action == ModerationAction.DIRECT_OUTPUT assert result.preset_response == "Content flagged by OpenAI moderation." - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_inputs_with_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration): """Test input moderation when OpenAI API detects violations.""" # Mock the model manager to return violation @@ -358,7 +358,7 @@ class TestOpenAIModeration: assert result.action == ModerationAction.DIRECT_OUTPUT assert result.preset_response == "Content flagged by OpenAI moderation." - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_inputs_query_included(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration): """Test that query is included in moderation check with special key.""" mock_instance = MagicMock() @@ -385,7 +385,7 @@ class TestOpenAIModeration: assert "u" in moderated_text assert "e" in moderated_text - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_inputs_disabled(self, mock_model_manager: Mock): """Test input moderation when inputs_config is disabled.""" config = { @@ -400,7 +400,7 @@ class TestOpenAIModeration: # Should not call the API when disabled mock_model_manager.assert_not_called() - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_outputs_no_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration): """Test output moderation when OpenAI API returns no violations.""" mock_instance = MagicMock() @@ -414,7 +414,7 @@ class TestOpenAIModeration: assert result.action == ModerationAction.DIRECT_OUTPUT assert result.preset_response == "Response blocked by moderation." - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_outputs_with_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration): """Test output moderation when OpenAI API detects violations.""" mock_instance = MagicMock() @@ -427,7 +427,7 @@ class TestOpenAIModeration: assert result.flagged is True assert result.action == ModerationAction.DIRECT_OUTPUT - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_outputs_disabled(self, mock_model_manager: Mock): """Test output moderation when outputs_config is disabled.""" config = { @@ -441,7 +441,7 @@ class TestOpenAIModeration: assert result.flagged is False mock_model_manager.assert_not_called() - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_model_manager_called_with_correct_params( self, mock_model_manager: Mock, openai_moderation: OpenAIModeration ): @@ -629,7 +629,7 @@ class TestPresetManagement: assert result.flagged is True assert result.preset_response == "Custom output blocked message" - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_preset_response_in_inputs(self, mock_model_manager: Mock): """Test preset response is properly returned for OpenAI input violations.""" mock_instance = MagicMock() @@ -650,7 +650,7 @@ class TestPresetManagement: assert result.flagged is True assert result.preset_response == "OpenAI input blocked" - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_preset_response_in_outputs(self, mock_model_manager: Mock): """Test preset response is properly returned for OpenAI output violations.""" mock_instance = MagicMock() @@ -989,7 +989,7 @@ class TestOpenAIModerationAdvanced: - Performance considerations """ - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_api_timeout_handling(self, mock_model_manager: Mock): """ Test graceful handling of OpenAI API timeouts. @@ -1012,7 +1012,7 @@ class TestOpenAIModerationAdvanced: with pytest.raises(TimeoutError): moderation.moderation_for_inputs({"text": "test"}, "") - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_api_rate_limit_handling(self, mock_model_manager: Mock): """ Test handling of OpenAI API rate limit errors. @@ -1035,7 +1035,7 @@ class TestOpenAIModerationAdvanced: with pytest.raises(Exception, match="Rate limit exceeded"): moderation.moderation_for_inputs({"text": "test"}, "") - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_with_multiple_input_fields(self, mock_model_manager: Mock): """ Test OpenAI moderation with multiple input fields. @@ -1079,7 +1079,7 @@ class TestOpenAIModerationAdvanced: assert "u" in moderated_text assert "e" in moderated_text - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_empty_text_handling(self, mock_model_manager: Mock): """ Test OpenAI moderation with empty text inputs. @@ -1103,7 +1103,7 @@ class TestOpenAIModerationAdvanced: assert result.flagged is False mock_instance.invoke_moderation.assert_called_once() - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_model_instance_fetched_on_each_call(self, mock_model_manager: Mock): """ Test that ModelManager fetches a fresh model instance on each call. diff --git a/api/tests/unit_tests/core/moderation/test_input_moderation.py b/api/tests/unit_tests/core/moderation/test_input_moderation.py new file mode 100644 index 00000000000..2dbc80cf14d --- /dev/null +++ b/api/tests/unit_tests/core/moderation/test_input_moderation.py @@ -0,0 +1,207 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.app.app_config.entities import AppConfig, SensitiveWordAvoidanceEntity +from core.moderation.base import ModerationAction, ModerationError, ModerationInputsResult +from core.moderation.input_moderation import InputModeration +from core.ops.entities.trace_entity import TraceTaskName +from core.ops.ops_trace_manager import TraceQueueManager + + +class TestInputModeration: + @pytest.fixture + def app_config(self): + config = MagicMock(spec=AppConfig) + config.sensitive_word_avoidance = None + return config + + @pytest.fixture + def input_moderation(self): + return InputModeration() + + def test_check_no_sensitive_word_avoidance(self, app_config, input_moderation): + app_id = "test_app_id" + tenant_id = "test_tenant_id" + inputs = {"input_key": "input_value"} + query = "test query" + message_id = "test_message_id" + + flagged, final_inputs, final_query = input_moderation.check( + app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id + ) + + assert flagged is False + assert final_inputs == inputs + assert final_query == query + + @patch("core.moderation.input_moderation.ModerationFactory") + def test_check_not_flagged(self, mock_factory_cls, app_config, input_moderation): + app_id = "test_app_id" + tenant_id = "test_tenant_id" + inputs = {"input_key": "input_value"} + query = "test query" + message_id = "test_message_id" + + # Setup config + sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity) + sensitive_word_config.type = "keywords" + sensitive_word_config.config = {"keywords": ["bad"]} + app_config.sensitive_word_avoidance = sensitive_word_config + + # Setup factory mock + mock_factory = mock_factory_cls.return_value + mock_result = ModerationInputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT) + mock_factory.moderation_for_inputs.return_value = mock_result + + flagged, final_inputs, final_query = input_moderation.check( + app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id + ) + + assert flagged is False + assert final_inputs == inputs + assert final_query == query + mock_factory_cls.assert_called_once_with( + name="keywords", app_id=app_id, tenant_id=tenant_id, config={"keywords": ["bad"]} + ) + mock_factory.moderation_for_inputs.assert_called_once_with(dict(inputs), query) + + @patch("core.moderation.input_moderation.ModerationFactory") + @patch("core.moderation.input_moderation.TraceTask") + def test_check_with_trace_manager(self, mock_trace_task, mock_factory_cls, app_config, input_moderation): + app_id = "test_app_id" + tenant_id = "test_tenant_id" + inputs = {"input_key": "input_value"} + query = "test query" + message_id = "test_message_id" + trace_manager = MagicMock(spec=TraceQueueManager) + + # Setup config + sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity) + sensitive_word_config.type = "keywords" + sensitive_word_config.config = {} + app_config.sensitive_word_avoidance = sensitive_word_config + + # Setup factory mock + mock_factory = mock_factory_cls.return_value + mock_result = ModerationInputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT) + mock_factory.moderation_for_inputs.return_value = mock_result + + input_moderation.check( + app_id=app_id, + tenant_id=tenant_id, + app_config=app_config, + inputs=inputs, + query=query, + message_id=message_id, + trace_manager=trace_manager, + ) + + trace_manager.add_trace_task.assert_called_once_with(mock_trace_task.return_value) + mock_trace_task.assert_called_once() + call_kwargs = mock_trace_task.call_args.kwargs + call_args = mock_trace_task.call_args.args + assert call_args[0] == TraceTaskName.MODERATION_TRACE + assert call_kwargs["message_id"] == message_id + assert call_kwargs["moderation_result"] == mock_result + assert call_kwargs["inputs"] == inputs + assert "timer" in call_kwargs + + @patch("core.moderation.input_moderation.ModerationFactory") + def test_check_flagged_direct_output(self, mock_factory_cls, app_config, input_moderation): + app_id = "test_app_id" + tenant_id = "test_tenant_id" + inputs = {"input_key": "input_value"} + query = "test query" + message_id = "test_message_id" + + # Setup config + sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity) + sensitive_word_config.type = "keywords" + sensitive_word_config.config = {} + app_config.sensitive_word_avoidance = sensitive_word_config + + # Setup factory mock + mock_factory = mock_factory_cls.return_value + mock_result = ModerationInputsResult( + flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="Blocked content" + ) + mock_factory.moderation_for_inputs.return_value = mock_result + + with pytest.raises(ModerationError) as excinfo: + input_moderation.check( + app_id=app_id, + tenant_id=tenant_id, + app_config=app_config, + inputs=inputs, + query=query, + message_id=message_id, + ) + + assert str(excinfo.value) == "Blocked content" + + @patch("core.moderation.input_moderation.ModerationFactory") + def test_check_flagged_overridden(self, mock_factory_cls, app_config, input_moderation): + app_id = "test_app_id" + tenant_id = "test_tenant_id" + inputs = {"input_key": "input_value"} + query = "test query" + message_id = "test_message_id" + + # Setup config + sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity) + sensitive_word_config.type = "keywords" + sensitive_word_config.config = {} + app_config.sensitive_word_avoidance = sensitive_word_config + + # Setup factory mock + mock_factory = mock_factory_cls.return_value + mock_result = ModerationInputsResult( + flagged=True, + action=ModerationAction.OVERRIDDEN, + inputs={"input_key": "overridden_value"}, + query="overridden query", + ) + mock_factory.moderation_for_inputs.return_value = mock_result + + flagged, final_inputs, final_query = input_moderation.check( + app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id + ) + + assert flagged is True + assert final_inputs == {"input_key": "overridden_value"} + assert final_query == "overridden query" + + @patch("core.moderation.input_moderation.ModerationFactory") + def test_check_flagged_other_action(self, mock_factory_cls, app_config, input_moderation): + app_id = "test_app_id" + tenant_id = "test_tenant_id" + inputs = {"input_key": "input_value"} + query = "test query" + message_id = "test_message_id" + + # Setup config + sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity) + sensitive_word_config.type = "keywords" + sensitive_word_config.config = {} + app_config.sensitive_word_avoidance = sensitive_word_config + + # Setup factory mock + mock_factory = mock_factory_cls.return_value + mock_result = MagicMock() + mock_result.flagged = True + mock_result.action = "NONE" # Some other action + mock_factory.moderation_for_inputs.return_value = mock_result + + flagged, final_inputs, final_query = input_moderation.check( + app_id=app_id, + tenant_id=tenant_id, + app_config=app_config, + inputs=inputs, + query=query, + message_id=message_id, + ) + + assert flagged is True + assert final_inputs == inputs + assert final_query == query diff --git a/api/tests/unit_tests/core/moderation/test_output_moderation.py b/api/tests/unit_tests/core/moderation/test_output_moderation.py new file mode 100644 index 00000000000..c6a7cd3f61c --- /dev/null +++ b/api/tests/unit_tests/core/moderation/test_output_moderation.py @@ -0,0 +1,234 @@ +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.entities.queue_entities import QueueMessageReplaceEvent +from core.moderation.base import ModerationAction, ModerationOutputsResult +from core.moderation.output_moderation import ModerationRule, OutputModeration + + +class TestOutputModeration: + @pytest.fixture + def mock_queue_manager(self): + return MagicMock(spec=AppQueueManager) + + @pytest.fixture + def moderation_rule(self): + return ModerationRule(type="keywords", config={"keywords": "badword"}) + + @pytest.fixture + def output_moderation(self, mock_queue_manager, moderation_rule): + return OutputModeration( + tenant_id="test_tenant", app_id="test_app", rule=moderation_rule, queue_manager=mock_queue_manager + ) + + def test_should_direct_output(self, output_moderation): + assert output_moderation.should_direct_output() is False + output_moderation.final_output = "blocked" + assert output_moderation.should_direct_output() is True + + def test_get_final_output(self, output_moderation): + assert output_moderation.get_final_output() == "" + output_moderation.final_output = "blocked" + assert output_moderation.get_final_output() == "blocked" + + def test_append_new_token(self, output_moderation): + with patch.object(OutputModeration, "start_thread") as mock_start: + output_moderation.append_new_token("hello") + assert output_moderation.buffer == "hello" + mock_start.assert_called_once() + + output_moderation.thread = MagicMock() + output_moderation.append_new_token(" world") + assert output_moderation.buffer == "hello world" + assert mock_start.call_count == 1 + + def test_moderation_completion_no_flag(self, output_moderation): + with patch.object(OutputModeration, "moderation") as mock_moderation: + mock_moderation.return_value = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT) + + output, flagged = output_moderation.moderation_completion("safe content") + + assert output == "safe content" + assert flagged is False + assert output_moderation.is_final_chunk is True + + def test_moderation_completion_flagged_direct_output(self, output_moderation, mock_queue_manager): + with patch.object(OutputModeration, "moderation") as mock_moderation: + mock_moderation.return_value = ModerationOutputsResult( + flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="preset" + ) + + output, flagged = output_moderation.moderation_completion("badword content", public_event=True) + + assert output == "preset" + assert flagged is True + mock_queue_manager.publish.assert_called_once() + args, _ = mock_queue_manager.publish.call_args + assert isinstance(args[0], QueueMessageReplaceEvent) + assert args[0].text == "preset" + assert args[1] == PublishFrom.TASK_PIPELINE + + def test_moderation_completion_flagged_overridden(self, output_moderation, mock_queue_manager): + with patch.object(OutputModeration, "moderation") as mock_moderation: + mock_moderation.return_value = ModerationOutputsResult( + flagged=True, action=ModerationAction.OVERRIDDEN, text="masked content" + ) + + output, flagged = output_moderation.moderation_completion("badword content", public_event=True) + + assert output == "masked content" + assert flagged is True + mock_queue_manager.publish.assert_called_once() + args, _ = mock_queue_manager.publish.call_args + assert args[0].text == "masked content" + + def test_start_thread(self, output_moderation): + mock_app = MagicMock(spec=Flask) + with patch("core.moderation.output_moderation.current_app") as mock_current_app: + mock_current_app._get_current_object.return_value = mock_app + with patch("threading.Thread") as mock_thread_class: + mock_thread_instance = MagicMock() + mock_thread_class.return_value = mock_thread_instance + + thread = output_moderation.start_thread() + + assert thread == mock_thread_instance + mock_thread_class.assert_called_once() + mock_thread_instance.start.assert_called_once() + + def test_stop_thread(self, output_moderation): + mock_thread = MagicMock() + mock_thread.is_alive.return_value = True + output_moderation.thread = mock_thread + + output_moderation.stop_thread() + assert output_moderation.thread_running is False + + output_moderation.thread_running = True + mock_thread.is_alive.return_value = False + output_moderation.stop_thread() + assert output_moderation.thread_running is True + + @patch("core.moderation.output_moderation.ModerationFactory") + def test_moderation_success(self, mock_factory_class, output_moderation): + mock_factory = mock_factory_class.return_value + mock_result = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT) + mock_factory.moderation_for_outputs.return_value = mock_result + + result = output_moderation.moderation("tenant", "app", "buffer") + + assert result == mock_result + mock_factory_class.assert_called_once_with( + name="keywords", app_id="app", tenant_id="tenant", config={"keywords": "badword"} + ) + + @patch("core.moderation.output_moderation.ModerationFactory") + def test_moderation_exception(self, mock_factory_class, output_moderation): + mock_factory_class.side_effect = Exception("error") + + result = output_moderation.moderation("tenant", "app", "buffer") + assert result is None + + def test_worker_loop_and_exit(self, output_moderation, mock_queue_manager): + mock_app = MagicMock(spec=Flask) + + # Test exit on thread_running=False + output_moderation.thread_running = False + output_moderation.worker(mock_app, 10) + # Should exit immediately + + def test_worker_no_flag(self, output_moderation): + mock_app = MagicMock(spec=Flask) + + with patch.object(OutputModeration, "moderation") as mock_moderation: + mock_moderation.return_value = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT) + + output_moderation.buffer = "safe" + output_moderation.is_final_chunk = True + + # To avoid infinite loop, we'll set thread_running to False after one iteration + def side_effect(*args, **kwargs): + output_moderation.thread_running = False + return mock_moderation.return_value + + mock_moderation.side_effect = side_effect + + output_moderation.worker(mock_app, 10) + + assert mock_moderation.called + + def test_worker_flagged_direct_output(self, output_moderation, mock_queue_manager): + mock_app = MagicMock(spec=Flask) + + with patch.object(OutputModeration, "moderation") as mock_moderation: + mock_moderation.return_value = ModerationOutputsResult( + flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="preset" + ) + + output_moderation.buffer = "badword" + output_moderation.is_final_chunk = True + + output_moderation.worker(mock_app, 10) + + assert output_moderation.final_output == "preset" + mock_queue_manager.publish.assert_called_once() + # It breaks on DIRECT_OUTPUT + + def test_worker_flagged_overridden(self, output_moderation, mock_queue_manager): + mock_app = MagicMock(spec=Flask) + + with patch.object(OutputModeration, "moderation") as mock_moderation: + # Use side_effect to change thread_running on second call + def side_effect(*args, **kwargs): + if mock_moderation.call_count > 1: + output_moderation.thread_running = False + return None + return ModerationOutputsResult(flagged=True, action=ModerationAction.OVERRIDDEN, text="masked") + + mock_moderation.side_effect = side_effect + + output_moderation.buffer = "badword" + output_moderation.is_final_chunk = True + + output_moderation.worker(mock_app, 10) + + mock_queue_manager.publish.assert_called_once() + args, _ = mock_queue_manager.publish.call_args + assert args[0].text == "masked" + + def test_worker_chunk_too_small(self, output_moderation): + mock_app = MagicMock(spec=Flask) + with patch("time.sleep") as mock_sleep: + # chunk_length < buffer_size and not is_final_chunk + output_moderation.buffer = "123" # length 3 + output_moderation.is_final_chunk = False + + def sleep_side_effect(seconds): + output_moderation.thread_running = False + + mock_sleep.side_effect = sleep_side_effect + + output_moderation.worker(mock_app, 10) # buffer_size 10 + + mock_sleep.assert_called_once_with(1) + + def test_worker_empty_not_flagged(self, output_moderation, mock_queue_manager): + mock_app = MagicMock(spec=Flask) + with patch.object(OutputModeration, "moderation") as mock_moderation: + # Return None (exception or no rule) + mock_moderation.return_value = None + + def side_effect(*args, **kwargs): + output_moderation.thread_running = False + + mock_moderation.side_effect = side_effect + + output_moderation.buffer = "something" + output_moderation.is_final_chunk = True + + output_moderation.worker(mock_app, 10) + + mock_queue_manager.publish.assert_not_called() diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py index dfd61acfa77..62d631a7541 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py +++ b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py @@ -5,6 +5,8 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from opentelemetry.trace import Link, SpanContext, SpanKind, Status, StatusCode, TraceFlags import core.ops.aliyun_trace.aliyun_trace as aliyun_trace_module @@ -34,8 +36,6 @@ from core.ops.entities.trace_entity import ( ToolTraceInfo, WorkflowTraceInfo, ) -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey class RecordingTraceClient: @@ -396,14 +396,14 @@ def test_get_workflow_node_executions_builds_repo_and_fetches( monkeypatch.setattr(aliyun_trace_module, "db", SimpleNamespace(engine="engine")) repo = MagicMock() - repo.get_by_workflow_run.return_value = ["node1"] + repo.get_by_workflow_execution.return_value = ["node1"] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr(aliyun_trace_module, "DifyCoreRepositoryFactory", mock_factory) result = trace_instance.get_workflow_node_executions(trace_info) assert result == ["node1"] - repo.get_by_workflow_run.assert_called_once_with(workflow_run_id=trace_info.workflow_run_id) + repo.get_by_workflow_execution.assert_called_once_with(workflow_execution_id=trace_info.workflow_run_id) def test_build_workflow_node_span_routes_llm_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py index 763fc90710b..2d2be12f051 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py +++ b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py @@ -1,6 +1,8 @@ import json from unittest.mock import MagicMock +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionStatus from opentelemetry.trace import Link, StatusCode from core.ops.aliyun_trace.entities.semconv import ( @@ -24,8 +26,6 @@ from core.ops.aliyun_trace.utils import ( serialize_json_data, ) from core.rag.models.document import Document -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import WorkflowNodeExecutionStatus from models import EndUser @@ -45,11 +45,8 @@ def test_get_user_id_from_message_data_with_end_user(monkeypatch): end_user_data = MagicMock(spec=EndUser) end_user_data.session_id = "session_id" - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = end_user_data - mock_session = MagicMock() - mock_session.query.return_value = mock_query + mock_session.get.return_value = end_user_data from core.ops.aliyun_trace.utils import db @@ -63,11 +60,8 @@ def test_get_user_id_from_message_data_end_user_not_found(monkeypatch): message_data.from_account_id = "account_id" message_data.from_end_user_id = "end_user_id" - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = None - mock_session = MagicMock() - mock_session.query.return_value = mock_query + mock_session.get.return_value = None from core.ops.aliyun_trace.utils import db diff --git a/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py b/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py index 1cee2f5b687..4ce9e22fd77 100644 --- a/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py +++ b/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py @@ -254,7 +254,7 @@ def test_workflow_trace_full(mock_db, mock_repo_factory, mock_sessionmaker, trac node1.id = "n1" node1.error = None - repo.get_by_workflow_run.return_value = [node1] + repo.get_by_workflow_execution.return_value = [node1] with patch.object(trace_instance, "get_service_account_with_tenant"): trace_instance.workflow_trace(info) diff --git a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py b/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py index 0ff135562cb..374371fb42d 100644 --- a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py +++ b/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py @@ -5,6 +5,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.enums import BuiltinNodeTypes from core.ops.entities.config_entity import LangfuseConfig from core.ops.entities.trace_entity import ( @@ -25,7 +26,6 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( UnitEnum, ) from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace -from dify_graph.enums import BuiltinNodeTypes from models import EndUser from models.enums import MessageStatus @@ -174,7 +174,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): node_other.metadata = None repo = MagicMock() - repo.get_by_workflow_run.return_value = [node_llm, node_other] + repo.get_by_workflow_execution.return_value = [node_llm, node_other] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo @@ -244,7 +244,7 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch): monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) repo = MagicMock() - repo.get_by_workflow_run.return_value = [] + repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) @@ -365,9 +365,7 @@ def test_message_trace_with_end_user(trace_instance, monkeypatch): mock_end_user = MagicMock(spec=EndUser) mock_end_user.session_id = "session-id-123" - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = mock_end_user - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db.session.query", lambda model: mock_query) + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db.session.get", lambda model, pk: mock_end_user) trace_instance.add_trace = MagicMock() trace_instance.add_generation = MagicMock() @@ -523,11 +521,11 @@ def test_generate_name_trace(trace_instance): def test_add_trace_success(trace_instance): data = LangfuseTrace(id="t1", name="trace") trace_instance.add_trace(data) - trace_instance.langfuse_client.trace.assert_called_once() + trace_instance.langfuse_client.api.ingestion.batch.assert_called_once() def test_add_trace_error(trace_instance): - trace_instance.langfuse_client.trace.side_effect = Exception("error") + trace_instance.langfuse_client.api.ingestion.batch.side_effect = Exception("error") data = LangfuseTrace(id="t1", name="trace") with pytest.raises(ValueError, match="LangFuse Failed to create trace: error"): trace_instance.add_trace(data) @@ -536,11 +534,11 @@ def test_add_trace_error(trace_instance): def test_add_span_success(trace_instance): data = LangfuseSpan(id="s1", name="span", trace_id="t1") trace_instance.add_span(data) - trace_instance.langfuse_client.span.assert_called_once() + trace_instance.langfuse_client.api.ingestion.batch.assert_called_once() def test_add_span_error(trace_instance): - trace_instance.langfuse_client.span.side_effect = Exception("error") + trace_instance.langfuse_client.api.ingestion.batch.side_effect = Exception("error") data = LangfuseSpan(id="s1", name="span", trace_id="t1") with pytest.raises(ValueError, match="LangFuse Failed to create span: error"): trace_instance.add_span(data) @@ -556,11 +554,11 @@ def test_update_span(trace_instance): def test_add_generation_success(trace_instance): data = LangfuseGeneration(id="g1", name="gen", trace_id="t1") trace_instance.add_generation(data) - trace_instance.langfuse_client.generation.assert_called_once() + trace_instance.langfuse_client.api.ingestion.batch.assert_called_once() def test_add_generation_error(trace_instance): - trace_instance.langfuse_client.generation.side_effect = Exception("error") + trace_instance.langfuse_client.api.ingestion.batch.side_effect = Exception("error") data = LangfuseGeneration(id="g1", name="gen", trace_id="t1") with pytest.raises(ValueError, match="LangFuse Failed to create generation: error"): trace_instance.add_generation(data) @@ -587,12 +585,12 @@ def test_api_check_error(trace_instance): def test_get_project_key_success(trace_instance): mock_data = MagicMock() mock_data.id = "proj-1" - trace_instance.langfuse_client.client.projects.get.return_value = MagicMock(data=[mock_data]) + trace_instance.langfuse_client.api.projects.get.return_value = MagicMock(data=[mock_data]) assert trace_instance.get_project_key() == "proj-1" def test_get_project_key_error(trace_instance): - trace_instance.langfuse_client.client.projects.get.side_effect = Exception("fail") + trace_instance.langfuse_client.api.projects.get.side_effect = Exception("fail") with pytest.raises(ValueError, match="LangFuse get project key failed: fail"): trace_instance.get_project_key() @@ -680,7 +678,7 @@ def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypat node.outputs = {} repo = MagicMock() - repo.get_by_workflow_run.return_value = [node] + repo.get_by_workflow_execution.return_value = [node] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) diff --git a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py b/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py index f656f7435f0..bfe916f0182 100644 --- a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py +++ b/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py @@ -3,6 +3,7 @@ from datetime import datetime, timedelta from unittest.mock import MagicMock import pytest +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from core.ops.entities.config_entity import LangSmithConfig from core.ops.entities.trace_entity import ( @@ -21,7 +22,6 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( LangSmithRunUpdateModel, ) from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser @@ -184,7 +184,7 @@ def test_workflow_trace(trace_instance, monkeypatch): node_retrieval.metadata = {} repo = MagicMock() - repo.get_by_workflow_run.return_value = [node_llm, node_other, node_retrieval] + repo.get_by_workflow_execution.return_value = [node_llm, node_other, node_retrieval] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo @@ -255,7 +255,7 @@ def test_workflow_trace_no_start_time(trace_instance, monkeypatch): monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) repo = MagicMock() - repo.get_by_workflow_run.return_value = [] + repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) @@ -319,9 +319,7 @@ def test_message_trace(trace_instance, monkeypatch): # Mock EndUser lookup mock_end_user = MagicMock(spec=EndUser) mock_end_user.session_id = "session-id-123" - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = mock_end_user - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db.session.query", lambda model: mock_query) + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db.session.get", lambda model, pk: mock_end_user) trace_instance.add_run = MagicMock() @@ -565,7 +563,7 @@ def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch, capl node_llm.metadata = {} repo = MagicMock() - repo.get_by_workflow_run.return_value = [node_llm] + repo.get_by_workflow_execution.return_value = [node_llm] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo diff --git a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py b/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py index cccedaa08cc..f4c485a9fc5 100644 --- a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py +++ b/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py @@ -9,6 +9,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from graphon.enums import BuiltinNodeTypes from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig from core.ops.entities.trace_entity import ( @@ -21,7 +22,6 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace, datetime_to_nanoseconds -from dify_graph.enums import BuiltinNodeTypes # ── Helpers ────────────────────────────────────────────────────────────────── @@ -330,7 +330,7 @@ class TestTraceDispatcher: class TestWorkflowTrace: def test_basic_workflow_no_nodes(self, trace_instance, mock_tracing, mock_db): - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + mock_db.session.scalars.return_value.all.return_value = [] span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" @@ -343,7 +343,7 @@ class TestWorkflowTrace: span.end.assert_called_once() def test_workflow_filters_sys_inputs_and_adds_query(self, trace_instance, mock_tracing, mock_db): - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + mock_db.session.scalars.return_value.all.return_value = [] span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" @@ -374,7 +374,7 @@ class TestWorkflowTrace: ), outputs='{"text": "hello world"}', ) - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [llm_node] + mock_db.session.scalars.return_value.all.return_value = [llm_node] workflow_span = MagicMock() node_span = MagicMock() @@ -397,7 +397,7 @@ class TestWorkflowTrace: } ), ) - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [qc_node] + mock_db.session.scalars.return_value.all.return_value = [qc_node] workflow_span = MagicMock() node_span = MagicMock() mock_tracing["start"].side_effect = [workflow_span, node_span] @@ -411,7 +411,7 @@ class TestWorkflowTrace: node_type=BuiltinNodeTypes.HTTP_REQUEST, process_data='{"url": "https://api.com"}', ) - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [http_node] + mock_db.session.scalars.return_value.all.return_value = [http_node] workflow_span = MagicMock() node_span = MagicMock() mock_tracing["start"].side_effect = [workflow_span, node_span] @@ -434,7 +434,7 @@ class TestWorkflowTrace: } ), ) - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [kr_node] + mock_db.session.scalars.return_value.all.return_value = [kr_node] workflow_span = MagicMock() node_span = MagicMock() mock_tracing["start"].side_effect = [workflow_span, node_span] @@ -448,7 +448,7 @@ class TestWorkflowTrace: def test_workflow_with_failed_node(self, trace_instance, mock_tracing, mock_db): failed_node = _make_node(status="failed") - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [failed_node] + mock_db.session.scalars.return_value.all.return_value = [failed_node] workflow_span = MagicMock() node_span = MagicMock() mock_tracing["start"].side_effect = [workflow_span, node_span] @@ -459,7 +459,7 @@ class TestWorkflowTrace: node_span.add_event.assert_called_once() def test_workflow_with_workflow_error(self, trace_instance, mock_tracing, mock_db): - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + mock_db.session.scalars.return_value.all.return_value = [] workflow_span = MagicMock() mock_tracing["start"].return_value = workflow_span mock_tracing["set"].return_value = "token" @@ -473,7 +473,7 @@ class TestWorkflowTrace: def test_workflow_node_no_inputs_no_outputs(self, trace_instance, mock_tracing, mock_db): node = _make_node(inputs=None, outputs=None) - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [node] + mock_db.session.scalars.return_value.all.return_value = [node] workflow_span = MagicMock() node_span = MagicMock() mock_tracing["start"].side_effect = [workflow_span, node_span] @@ -486,7 +486,7 @@ class TestWorkflowTrace: assert end_call.kwargs["outputs"] == {} def test_workflow_no_user_id_no_conversation_id(self, trace_instance, mock_tracing, mock_db): - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + mock_db.session.scalars.return_value.all.return_value = [] span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" @@ -501,7 +501,7 @@ class TestWorkflowTrace: def test_workflow_empty_query(self, trace_instance, mock_tracing, mock_db): """When query is empty string, it's falsy so no query key added.""" - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + mock_db.session.scalars.return_value.all.return_value = [] span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" @@ -680,12 +680,12 @@ class TestGetMessageUserId: def test_returns_end_user_session_id(self, trace_instance, mock_db): end_user = MagicMock() end_user.session_id = "session-1" - mock_db.session.query.return_value.where.return_value.first.return_value = end_user + mock_db.session.get.return_value = end_user result = trace_instance._get_message_user_id({"from_end_user_id": "eu-1"}) assert result == "session-1" def test_returns_account_id_when_no_end_user(self, trace_instance, mock_db): - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.get.return_value = None result = trace_instance._get_message_user_id({"from_end_user_id": "eu-1", "from_account_id": "acc-1"}) assert result == "acc-1" @@ -834,7 +834,7 @@ class TestGenerateNameTrace: class TestGetWorkflowNodes: def test_queries_db(self, trace_instance, mock_db): - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = ["n1", "n2"] + mock_db.session.scalars.return_value.all.return_value = ["n1", "n2"] result = trace_instance._get_workflow_nodes("run-1") assert result == ["n1", "n2"] diff --git a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py b/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py index b2cb7d51098..1cb32f2ee02 100644 --- a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py +++ b/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py @@ -5,6 +5,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from core.ops.entities.config_entity import OpikConfig from core.ops.entities.trace_entity import ( @@ -18,7 +19,6 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.opik_trace.opik_trace import OpikDataTrace, prepare_opik_uuid, wrap_dict, wrap_metadata -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser from models.enums import MessageStatus @@ -199,7 +199,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): node_other.metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS.value: 10} repo = MagicMock() - repo.get_by_workflow_run.return_value = [node_llm, node_other] + repo.get_by_workflow_execution.return_value = [node_llm, node_other] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo @@ -253,7 +253,7 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch): monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) repo = MagicMock() - repo.get_by_workflow_run.return_value = [] + repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) @@ -373,9 +373,7 @@ def test_message_trace_with_end_user(trace_instance, monkeypatch): mock_end_user = MagicMock(spec=EndUser) mock_end_user.session_id = "session-id-123" - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = mock_end_user - monkeypatch.setattr("core.ops.opik_trace.opik_trace.db.session.query", lambda model: mock_query) + monkeypatch.setattr("core.ops.opik_trace.opik_trace.db.session.get", lambda model, pk: mock_end_user) trace_instance.add_trace = MagicMock(return_value=MagicMock(id="trace_id_2")) trace_instance.add_span = MagicMock() @@ -657,7 +655,7 @@ def test_workflow_trace_usage_extraction_error_fixed(trace_instance, monkeypatch node.outputs = {} repo = MagicMock() - repo.get_by_workflow_run.return_value = [node] + repo.get_by_workflow_execution.return_value = [node] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py b/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py index a0b6d527208..696f859b6f7 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py +++ b/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py @@ -1,6 +1,8 @@ from datetime import datetime from unittest.mock import MagicMock, patch +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from opentelemetry.trace import StatusCode from core.ops.entities.trace_entity import ( @@ -25,8 +27,6 @@ from core.ops.tencent_trace.entities.semconv import ( ) from core.ops.tencent_trace.span_builder import TencentSpanBuilder from core.rag.models.document import Document -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus class TestTencentSpanBuilder: diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py b/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py index f259e4639fb..382e5dadc3b 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py +++ b/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py @@ -2,6 +2,8 @@ import logging from unittest.mock import MagicMock, patch import pytest +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes from core.ops.entities.config_entity import TencentConfig from core.ops.entities.trace_entity import ( @@ -14,8 +16,6 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.tencent_trace.tencent_trace import TencentDataTrace -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import BuiltinNodeTypes from models import Account, App, TenantAccountJoin logger = logging.getLogger(__name__) @@ -413,7 +413,7 @@ class TestTencentDataTrace: with patch( "core.ops.tencent_trace.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository" ) as mock_repo: - mock_repo.return_value.get_by_workflow_run.return_value = mock_executions + mock_repo.return_value.get_by_workflow_execution.return_value = mock_executions results = tencent_data_trace._get_workflow_node_executions(trace_info) diff --git a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py b/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py index 49d6b698eff..6b5cb5b09a8 100644 --- a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py +++ b/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py @@ -1,7 +1,7 @@ +from graphon.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes from openinference.semconv.trace import OpenInferenceSpanKindValues from core.ops.arize_phoenix_trace.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind -from dify_graph.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes class TestGetNodeSpanKind: diff --git a/api/tests/unit_tests/core/ops/test_lookup_helpers.py b/api/tests/unit_tests/core/ops/test_lookup_helpers.py new file mode 100644 index 00000000000..86aa68643da --- /dev/null +++ b/api/tests/unit_tests/core/ops/test_lookup_helpers.py @@ -0,0 +1,554 @@ +"""Unit tests for lookup helper functions in core.ops.ops_trace_manager. + +Covers: +- _lookup_app_and_workspace_names +- _lookup_credential_name +- _lookup_llm_credential_info +- TraceTask._get_user_id_from_metadata +""" + +from unittest.mock import MagicMock, patch + +import pytest + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_db_and_session_patches(scalar_side_effect=None, scalar_return_value=None): + """Return (mock_db, cm, session) ready to patch 'core.ops.ops_trace_manager.db' + and 'core.ops.ops_trace_manager.Session'. + + Provide either scalar_side_effect (list, for multiple calls) or + scalar_return_value (single value). + """ + mock_db = MagicMock() + mock_db.engine = MagicMock() + + session = MagicMock() + if scalar_side_effect is not None: + session.scalar.side_effect = scalar_side_effect + else: + session.scalar.return_value = scalar_return_value + + cm = MagicMock() + cm.__enter__ = MagicMock(return_value=session) + cm.__exit__ = MagicMock(return_value=False) + + return mock_db, cm, session + + +# --------------------------------------------------------------------------- +# _lookup_app_and_workspace_names +# --------------------------------------------------------------------------- + + +class TestLookupAppAndWorkspaceNames: + """Tests for _lookup_app_and_workspace_names(app_id, tenant_id).""" + + def test_both_found(self): + """Returns (app_name, workspace_name) when both records exist.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=["MyApp", "MyWorkspace"]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + app_name, workspace_name = _lookup_app_and_workspace_names("app-123", "tenant-456") + + assert app_name == "MyApp" + assert workspace_name == "MyWorkspace" + + def test_app_only_found(self): + """Returns (app_name, '') when tenant record is absent.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=["MyApp", None]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + app_name, workspace_name = _lookup_app_and_workspace_names("app-123", "tenant-456") + + assert app_name == "MyApp" + assert workspace_name == "" + + def test_tenant_only_found(self): + """Returns ('', workspace_name) when app record is absent.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=[None, "MyWorkspace"]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + app_name, workspace_name = _lookup_app_and_workspace_names("app-123", "tenant-456") + + assert app_name == "" + assert workspace_name == "MyWorkspace" + + def test_neither_found(self): + """Returns ('', '') when both DB lookups return None.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=[None, None]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + app_name, workspace_name = _lookup_app_and_workspace_names("app-123", "tenant-456") + + assert app_name == "" + assert workspace_name == "" + + def test_none_inputs_skips_db(self): + """Returns ('', '') immediately when both IDs are None — no DB access.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db = MagicMock() + mock_session_cls = MagicMock() + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", mock_session_cls), + ): + app_name, workspace_name = _lookup_app_and_workspace_names(None, None) + + mock_session_cls.assert_not_called() + assert app_name == "" + assert workspace_name == "" + + def test_app_id_none_only_queries_tenant(self): + """When app_id is None, only the tenant query is issued.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db, cm, session = _make_db_and_session_patches(scalar_return_value="OnlyWorkspace") + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + app_name, workspace_name = _lookup_app_and_workspace_names(None, "tenant-456") + + assert app_name == "" + assert workspace_name == "OnlyWorkspace" + assert session.scalar.call_count == 1 + + def test_tenant_id_none_only_queries_app(self): + """When tenant_id is None, only the app query is issued.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db, cm, session = _make_db_and_session_patches(scalar_return_value="OnlyApp") + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + app_name, workspace_name = _lookup_app_and_workspace_names("app-123", None) + + assert app_name == "OnlyApp" + assert workspace_name == "" + assert session.scalar.call_count == 1 + + +# --------------------------------------------------------------------------- +# _lookup_credential_name +# --------------------------------------------------------------------------- + + +class TestLookupCredentialName: + """Tests for _lookup_credential_name(credential_id, provider_type).""" + + @pytest.mark.parametrize("provider_type", ["builtin", "plugin", "api", "workflow", "mcp"]) + def test_known_provider_types_return_name(self, provider_type): + """Each valid provider_type results in a DB query and returns the credential name.""" + from core.ops.ops_trace_manager import _lookup_credential_name + + mock_db, cm, session = _make_db_and_session_patches(scalar_return_value="CredentialA") + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + result = _lookup_credential_name("cred-123", provider_type) + + assert result == "CredentialA" + session.scalar.assert_called_once() + + def test_credential_not_found_returns_empty_string(self): + """Returns '' when DB yields None for the given credential_id.""" + from core.ops.ops_trace_manager import _lookup_credential_name + + mock_db, cm, _session = _make_db_and_session_patches(scalar_return_value=None) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + result = _lookup_credential_name("cred-999", "api") + + assert result == "" + + def test_invalid_provider_type_returns_empty_string_without_db(self): + """Returns '' immediately for an unrecognised provider_type — no DB access.""" + from core.ops.ops_trace_manager import _lookup_credential_name + + mock_db = MagicMock() + mock_session_cls = MagicMock() + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", mock_session_cls), + ): + result = _lookup_credential_name("cred-123", "unknown_type") + + mock_session_cls.assert_not_called() + assert result == "" + + def test_none_credential_id_returns_empty_string_without_db(self): + """Returns '' immediately when credential_id is None — no DB access.""" + from core.ops.ops_trace_manager import _lookup_credential_name + + mock_db = MagicMock() + mock_session_cls = MagicMock() + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", mock_session_cls), + ): + result = _lookup_credential_name(None, "api") + + mock_session_cls.assert_not_called() + assert result == "" + + def test_none_provider_type_returns_empty_string_without_db(self): + """Returns '' immediately when provider_type is None — no DB access.""" + from core.ops.ops_trace_manager import _lookup_credential_name + + mock_db = MagicMock() + mock_session_cls = MagicMock() + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", mock_session_cls), + ): + result = _lookup_credential_name("cred-123", None) + + mock_session_cls.assert_not_called() + assert result == "" + + def test_builtin_and_plugin_map_to_same_model(self): + """Both 'builtin' and 'plugin' provider_types query BuiltinToolProvider.""" + from core.ops.ops_trace_manager import _PROVIDER_TYPE_TO_MODEL + from models.tools import BuiltinToolProvider + + assert _PROVIDER_TYPE_TO_MODEL["builtin"] is BuiltinToolProvider + assert _PROVIDER_TYPE_TO_MODEL["plugin"] is BuiltinToolProvider + + def test_api_maps_to_api_tool_provider(self): + """'api' maps to ApiToolProvider.""" + from core.ops.ops_trace_manager import _PROVIDER_TYPE_TO_MODEL + from models.tools import ApiToolProvider + + assert _PROVIDER_TYPE_TO_MODEL["api"] is ApiToolProvider + + def test_workflow_maps_to_workflow_tool_provider(self): + """'workflow' maps to WorkflowToolProvider.""" + from core.ops.ops_trace_manager import _PROVIDER_TYPE_TO_MODEL + from models.tools import WorkflowToolProvider + + assert _PROVIDER_TYPE_TO_MODEL["workflow"] is WorkflowToolProvider + + def test_mcp_maps_to_mcp_tool_provider(self): + """'mcp' maps to MCPToolProvider.""" + from core.ops.ops_trace_manager import _PROVIDER_TYPE_TO_MODEL + from models.tools import MCPToolProvider + + assert _PROVIDER_TYPE_TO_MODEL["mcp"] is MCPToolProvider + + +# --------------------------------------------------------------------------- +# _lookup_llm_credential_info +# --------------------------------------------------------------------------- + + +class TestLookupLlmCredentialInfo: + """Tests for _lookup_llm_credential_info(tenant_id, provider, model, model_type).""" + + def _provider_record(self, credential_id: str | None = None) -> MagicMock: + record = MagicMock() + record.credential_id = credential_id + return record + + def _model_record(self, credential_id: str | None = None) -> MagicMock: + record = MagicMock() + record.credential_id = credential_id + return record + + def test_model_level_credential_found(self): + """Returns model-level credential_id and name when ProviderModel has a credential.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + provider_record = self._provider_record(credential_id=None) + model_record = self._model_record(credential_id="model-cred-id") + + # scalar calls: (1) Provider, (2) ProviderModel, (3) ProviderModelCredential.credential_name + mock_db, cm, _session = _make_db_and_session_patches( + scalar_side_effect=[provider_record, model_record, "ModelCredName"] + ) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4") + + assert cred_id == "model-cred-id" + assert cred_name == "ModelCredName" + + def test_provider_level_fallback_when_no_model_credential(self): + """Falls back to provider-level credential when ProviderModel has no credential_id.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + provider_record = self._provider_record(credential_id="prov-cred-id") + model_record = self._model_record(credential_id=None) + + # scalar calls: (1) Provider, (2) ProviderModel (no cred), (3) ProviderCredential.credential_name + mock_db, cm, _session = _make_db_and_session_patches( + scalar_side_effect=[provider_record, model_record, "ProvCredName"] + ) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4") + + assert cred_id == "prov-cred-id" + assert cred_name == "ProvCredName" + + def test_provider_level_fallback_when_no_model_record(self): + """Falls back to provider-level credential when no ProviderModel row exists.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + provider_record = self._provider_record(credential_id="prov-cred-id") + + # scalar calls: (1) Provider, (2) ProviderModel → None, (3) ProviderCredential.credential_name + mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=[provider_record, None, "ProvCredName"]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4") + + assert cred_id == "prov-cred-id" + assert cred_name == "ProvCredName" + + def test_no_model_arg_uses_provider_level_only(self): + """When model is None, skips ProviderModel query and uses provider credential.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + provider_record = self._provider_record(credential_id="prov-cred-id") + + # scalar calls: (1) Provider, (2) ProviderCredential.credential_name — no ProviderModel + mock_db, cm, session = _make_db_and_session_patches(scalar_side_effect=[provider_record, "ProvCredName"]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", None) + + assert cred_id == "prov-cred-id" + assert cred_name == "ProvCredName" + assert session.scalar.call_count == 2 + + def test_provider_not_found_returns_none_and_empty(self): + """Returns (None, '') when Provider record does not exist.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + mock_db, cm, _session = _make_db_and_session_patches(scalar_return_value=None) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4") + + assert cred_id is None + assert cred_name == "" + + def test_none_tenant_id_returns_none_and_empty_without_db(self): + """Returns (None, '') immediately when tenant_id is None — no DB access.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + mock_db = MagicMock() + mock_session_cls = MagicMock() + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", mock_session_cls), + ): + cred_id, cred_name = _lookup_llm_credential_info(None, "openai", "gpt-4") + + mock_session_cls.assert_not_called() + assert cred_id is None + assert cred_name == "" + + def test_none_provider_returns_none_and_empty_without_db(self): + """Returns (None, '') immediately when provider is None — no DB access.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + mock_db = MagicMock() + mock_session_cls = MagicMock() + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", mock_session_cls), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", None, "gpt-4") + + mock_session_cls.assert_not_called() + assert cred_id is None + assert cred_name == "" + + def test_db_error_on_outer_query_returns_none_and_empty(self): + """Returns (None, '') and logs a warning when the outer DB query raises.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + mock_db, cm, session = _make_db_and_session_patches() + session.scalar.side_effect = Exception("DB connection failed") + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4") + + assert cred_id is None + assert cred_name == "" + + def test_credential_name_lookup_failure_returns_id_with_empty_name(self): + """When credential name sub-query fails, returns cred_id but '' for name.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + provider_record = self._provider_record(credential_id="prov-cred-id") + + # Provider found, no model record, then name lookup raises + mock_db, cm, _session = _make_db_and_session_patches( + scalar_side_effect=[provider_record, None, Exception("deleted")] + ) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4") + + assert cred_id == "prov-cred-id" + assert cred_name == "" + + def test_no_credential_on_provider_or_model_returns_none_id(self): + """Returns (None, '') when neither provider nor model has a credential_id.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + provider_record = self._provider_record(credential_id=None) + model_record = self._model_record(credential_id=None) + + mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=[provider_record, model_record]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4") + + assert cred_id is None + assert cred_name == "" + + +# --------------------------------------------------------------------------- +# TraceTask._get_user_id_from_metadata +# --------------------------------------------------------------------------- + + +class TestGetUserIdFromMetadata: + """Tests for TraceTask._get_user_id_from_metadata(metadata). + + Pure dict logic — no DB access required. + """ + + @pytest.fixture + def get_user_id(self): + """Return the classmethod under test.""" + from core.ops.ops_trace_manager import TraceTask + + return TraceTask._get_user_id_from_metadata + + def test_from_end_user_id_has_highest_priority(self, get_user_id): + """from_end_user_id takes precedence over all other keys.""" + metadata = { + "from_end_user_id": "eu-abc", + "from_account_id": "acc-xyz", + "user_id": "u-123", + } + assert get_user_id(metadata) == "end_user:eu-abc" + + def test_from_account_id_used_when_no_end_user(self, get_user_id): + """from_account_id is used when from_end_user_id is absent.""" + metadata = { + "from_account_id": "acc-xyz", + "user_id": "u-123", + } + assert get_user_id(metadata) == "account:acc-xyz" + + def test_user_id_used_when_no_end_user_or_account(self, get_user_id): + """user_id is used when both higher-priority keys are absent.""" + metadata = {"user_id": "u-123"} + assert get_user_id(metadata) == "user:u-123" + + def test_returns_anonymous_when_all_keys_absent(self, get_user_id): + """Returns 'anonymous' when metadata has none of the expected keys.""" + assert get_user_id({}) == "anonymous" + + def test_empty_string_end_user_id_is_skipped(self, get_user_id): + """Empty string for from_end_user_id is falsy and falls through to next key.""" + metadata = { + "from_end_user_id": "", + "from_account_id": "acc-xyz", + } + assert get_user_id(metadata) == "account:acc-xyz" + + def test_empty_string_account_id_is_skipped(self, get_user_id): + """Empty string for from_account_id is falsy and falls through to user_id.""" + metadata = { + "from_end_user_id": "", + "from_account_id": "", + "user_id": "u-123", + } + assert get_user_id(metadata) == "user:u-123" + + def test_empty_string_user_id_falls_through_to_anonymous(self, get_user_id): + """Empty string for user_id is falsy, so 'anonymous' is returned.""" + metadata = { + "from_end_user_id": "", + "from_account_id": "", + "user_id": "", + } + assert get_user_id(metadata) == "anonymous" + + def test_only_from_end_user_id_present(self, get_user_id): + """Minimal case: only from_end_user_id present.""" + assert get_user_id({"from_end_user_id": "eu-only"}) == "end_user:eu-only" + + def test_irrelevant_keys_do_not_interfere(self, get_user_id): + """Extra metadata keys have no effect on the result.""" + metadata = {"invoke_from": "web", "app_id": "a1"} + assert get_user_id(metadata) == "anonymous" diff --git a/api/tests/unit_tests/core/ops/test_opik_trace.py b/api/tests/unit_tests/core/ops/test_opik_trace.py index 76609671834..ad9d0846be1 100644 --- a/api/tests/unit_tests/core/ops/test_opik_trace.py +++ b/api/tests/unit_tests/core/ops/test_opik_trace.py @@ -130,7 +130,7 @@ class TestWorkflowTraceWithoutMessageId: def _run(self, trace_info: WorkflowTraceInfo, node_executions: list | None = None): instance = _make_opik_trace_instance() fake_repo = MagicMock() - fake_repo.get_by_workflow_run.return_value = node_executions or [] + fake_repo.get_by_workflow_execution.return_value = node_executions or [] with ( patch("core.ops.opik_trace.opik_trace.db") as mock_db, @@ -262,7 +262,7 @@ class TestWorkflowTraceWithMessageId: def _run(self, trace_info: WorkflowTraceInfo, node_executions: list | None = None): instance = _make_opik_trace_instance() fake_repo = MagicMock() - fake_repo.get_by_workflow_run.return_value = node_executions or [] + fake_repo.get_by_workflow_execution.return_value = node_executions or [] with ( patch("core.ops.opik_trace.opik_trace.db") as mock_db, diff --git a/api/tests/unit_tests/core/ops/test_ops_trace_manager.py b/api/tests/unit_tests/core/ops/test_ops_trace_manager.py index 2d325ccb0eb..e47df0121ea 100644 --- a/api/tests/unit_tests/core/ops/test_ops_trace_manager.py +++ b/api/tests/unit_tests/core/ops/test_ops_trace_manager.py @@ -86,6 +86,7 @@ def make_message_data(**overrides): created_at = datetime(2025, 2, 20, 12, 0, 0) base = { "id": "msg-id", + "app_id": "app-id", "conversation_id": "conv-id", "created_at": created_at, "updated_at": created_at + timedelta(seconds=3), @@ -156,17 +157,19 @@ def make_workflow_run(): ) -def configure_db_query(session, *, message_file=None, workflow_app_log=None): - def _side_effect(model): - query = MagicMock() - query.filter_by.return_value.first.return_value = None - if message_file and model.__name__ == "MessageFile": - query.filter_by.return_value.first.return_value = message_file - if workflow_app_log and model.__name__ == "WorkflowAppLog": - query.filter_by.return_value.first.return_value = workflow_app_log - return query +def configure_db_scalar(session, *, message_file=None, workflow_app_log=None): + """Configure session.scalar to return appropriate values for MessageFile/WorkflowAppLog lookups.""" + original_scalar = session.scalar - session.query.side_effect = _side_effect + def _side_effect(stmt): + stmt_str = str(stmt) + if "message_file" in stmt_str.lower(): + return message_file + if "workflow_app_log" in stmt_str.lower(): + return workflow_app_log + return original_scalar(stmt) + + session.scalar.side_effect = _side_effect class DummySessionContext: @@ -182,6 +185,9 @@ class DummySessionContext: def __exit__(self, exc_type, exc_val, exc_tb): return False + def execute(self, *args, **kwargs): + return self + def scalar(self, *args, **kwargs): if self._index >= len(self._values): return None @@ -189,6 +195,12 @@ class DummySessionContext: self._index += 1 return value + def scalars(self, *args, **kwargs): + return self + + def all(self): + return [] + @pytest.fixture(autouse=True) def patch_provider_map(monkeypatch): @@ -253,7 +265,7 @@ def workflow_repo_fixture(monkeypatch): def trace_task_message(monkeypatch, mock_db): message_data = make_message_data() monkeypatch.setattr("core.ops.ops_trace_manager.get_message_data", lambda msg_id: message_data) - configure_db_query(mock_db, message_file=FakeMessageFile(), workflow_app_log=SimpleNamespace(id="log-id")) + configure_db_scalar(mock_db, message_file=FakeMessageFile(), workflow_app_log=SimpleNamespace(id="log-id")) return message_data @@ -297,56 +309,53 @@ def test_obfuscated_decrypt_token(encryption_mocks): def test_get_decrypted_tracing_config_returns_config(encryption_mocks, mock_db): trace_config_data = SimpleNamespace(tracing_config={"secret_value": "enc", "other_value": "info"}) - mock_db.query.return_value.where.return_value.first.return_value = trace_config_data app = SimpleNamespace(id="app-id", tenant_id="tenant") - mock_db.scalar.return_value = app + mock_db.scalar.side_effect = [trace_config_data, app] decrypted = OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy") assert decrypted["other_value"] == "info" def test_get_decrypted_tracing_config_missing_trace_config(mock_db): - mock_db.query.return_value.where.return_value.first.return_value = None + mock_db.scalar.return_value = None assert OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy") is None def test_get_decrypted_tracing_config_raises_for_missing_app(mock_db): trace_config_data = SimpleNamespace(tracing_config={"secret_value": "enc"}) - mock_db.query.return_value.where.return_value.first.return_value = trace_config_data - mock_db.scalar.return_value = None + mock_db.scalar.side_effect = [trace_config_data, None] with pytest.raises(ValueError, match="App not found"): OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy") def test_get_decrypted_tracing_config_raises_for_none_config(mock_db): trace_config_data = SimpleNamespace(tracing_config=None) - mock_db.query.return_value.where.return_value.first.return_value = trace_config_data - mock_db.scalar.return_value = SimpleNamespace(tenant_id="tenant") + mock_db.scalar.side_effect = [trace_config_data, SimpleNamespace(tenant_id="tenant")] with pytest.raises(ValueError, match="Tracing config cannot be None"): OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy") def test_get_ops_trace_instance_handles_none_app(mock_db): - mock_db.query.return_value.where.return_value.first.return_value = None + mock_db.get.return_value = None assert OpsTraceManager.get_ops_trace_instance("app-id") is None def test_get_ops_trace_instance_returns_none_when_disabled(mock_db, monkeypatch): app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": False})) - mock_db.query.return_value.where.return_value.first.return_value = app + mock_db.get.return_value = app assert OpsTraceManager.get_ops_trace_instance("app-id") is None def test_get_ops_trace_instance_invalid_provider(mock_db, monkeypatch): app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": True, "tracing_provider": "missing"})) - mock_db.query.return_value.where.return_value.first.return_value = app + mock_db.get.return_value = app monkeypatch.setattr("core.ops.ops_trace_manager.provider_config_map", FakeProviderMap({})) assert OpsTraceManager.get_ops_trace_instance("app-id") is None def test_get_ops_trace_instance_success(monkeypatch, mock_db): app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": True, "tracing_provider": "dummy"})) - mock_db.query.return_value.where.return_value.first.return_value = app + mock_db.get.return_value = app monkeypatch.setattr( "core.ops.ops_trace_manager.OpsTraceManager.get_decrypted_tracing_config", classmethod(lambda cls, aid, provider: {"secret_value": "decrypted", "other_value": "info"}), @@ -380,7 +389,7 @@ def test_get_app_config_through_message_id_app_model_config(mock_db): def test_update_app_tracing_config_invalid_provider(mock_db, monkeypatch): - mock_db.query.return_value.where.return_value.first.return_value = None + mock_db.get.return_value = None with pytest.raises(ValueError, match="Invalid tracing provider"): OpsTraceManager.update_app_tracing_config("app", True, "bad") with pytest.raises(ValueError, match="App not found"): @@ -389,26 +398,26 @@ def test_update_app_tracing_config_invalid_provider(mock_db, monkeypatch): def test_update_app_tracing_config_success(mock_db): app = SimpleNamespace(id="app-id", tracing="{}") - mock_db.query.return_value.where.return_value.first.return_value = app + mock_db.get.return_value = app OpsTraceManager.update_app_tracing_config("app-id", True, "dummy") assert app.tracing is not None mock_db.commit.assert_called_once() def test_get_app_tracing_config_errors_when_missing(mock_db): - mock_db.query.return_value.where.return_value.first.return_value = None + mock_db.get.return_value = None with pytest.raises(ValueError, match="App not found"): OpsTraceManager.get_app_tracing_config("app") def test_get_app_tracing_config_returns_defaults(mock_db): - mock_db.query.return_value.where.return_value.first.return_value = SimpleNamespace(tracing=None) + mock_db.get.return_value = SimpleNamespace(tracing=None) assert OpsTraceManager.get_app_tracing_config("app-id") == {"enabled": False, "tracing_provider": None} def test_get_app_tracing_config_returns_payload(mock_db): payload = {"enabled": True, "tracing_provider": "dummy"} - mock_db.query.return_value.where.return_value.first.return_value = SimpleNamespace(tracing=json.dumps(payload)) + mock_db.get.return_value = SimpleNamespace(tracing=json.dumps(payload)) assert OpsTraceManager.get_app_tracing_config("app-id") == payload @@ -454,7 +463,7 @@ def test_trace_task_message_trace(trace_task_message, mock_db): def test_trace_task_workflow_trace(workflow_repo_fixture, mock_db): DummySessionContext.scalar_values = ["wf-app-log", "message-ref"] - execution = SimpleNamespace(id_="run-id") + execution = SimpleNamespace(id_="run-id", total_tokens=0) task = TraceTask( trace_type=TraceTaskName.WORKFLOW_TRACE, workflow_execution=execution, conversation_id="conv", user_id="user" ) @@ -491,7 +500,7 @@ def test_trace_task_dataset_retrieval_trace(trace_task_message): def test_trace_task_tool_trace(monkeypatch, mock_db): custom_message = make_message_data(agent_thoughts=[make_agent_thought("tool-a", datetime(2025, 2, 20, 12, 1, 0))]) monkeypatch.setattr("core.ops.ops_trace_manager.get_message_data", lambda _: custom_message) - configure_db_query(mock_db, message_file=FakeMessageFile()) + configure_db_scalar(mock_db, message_file=FakeMessageFile()) task = TraceTask(trace_type=TraceTaskName.TOOL_TRACE, message_id="msg-id") timer = {"start": 1, "end": 5} result = task.tool_trace("msg-id", timer, tool_name="tool-a", tool_inputs={"foo": 1}, tool_outputs="result") diff --git a/api/tests/unit_tests/core/ops/test_trace_queue_manager.py b/api/tests/unit_tests/core/ops/test_trace_queue_manager.py new file mode 100644 index 00000000000..a4903054e0a --- /dev/null +++ b/api/tests/unit_tests/core/ops/test_trace_queue_manager.py @@ -0,0 +1,194 @@ +"""Unit tests for TraceQueueManager telemetry guard. + +Verifies that TraceQueueManager.add_trace_task() only enqueues tasks when at +least one consumer is active: +- Enterprise telemetry is enabled (_enterprise_telemetry_enabled=True), OR +- A third-party trace instance (Langfuse, etc.) is configured + +When neither is active, tasks are silently dropped to avoid unnecessary work. + +When BOTH are false, tasks are silently dropped (correct behavior). +""" + +import queue +import sys +import types +from unittest.mock import MagicMock, patch + +import pytest + + +@pytest.fixture +def trace_queue_manager_and_task(monkeypatch): + """Fixture to provide TraceQueueManager and TraceTask with delayed imports.""" + module_name = "core.ops.ops_trace_manager" + if module_name not in sys.modules: + ops_stub = types.ModuleType(module_name) + + class StubTraceTask: + def __init__(self, trace_type): + self.trace_type = trace_type + self.app_id = None + + class StubTraceQueueManager: + def __init__(self, app_id=None): + self.app_id = app_id + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + self._enterprise_telemetry_enabled = is_enterprise_telemetry_enabled() + self.trace_instance = StubOpsTraceManager.get_ops_trace_instance(app_id) + + def add_trace_task(self, trace_task): + if self._enterprise_telemetry_enabled or self.trace_instance: + trace_task.app_id = self.app_id + from core.ops.ops_trace_manager import trace_manager_queue + + trace_manager_queue.put(trace_task) + + class StubOpsTraceManager: + @staticmethod + def get_ops_trace_instance(app_id): + return None + + ops_stub.TraceQueueManager = StubTraceQueueManager + ops_stub.TraceTask = StubTraceTask + ops_stub.OpsTraceManager = StubOpsTraceManager + ops_stub.trace_manager_queue = MagicMock(spec=queue.Queue) + monkeypatch.setitem(sys.modules, module_name, ops_stub) + + from core.ops.entities.trace_entity import TraceTaskName + + ops_module = __import__(module_name, fromlist=["TraceQueueManager", "TraceTask"]) + TraceQueueManager = ops_module.TraceQueueManager + TraceTask = ops_module.TraceTask + + return TraceQueueManager, TraceTask, TraceTaskName + + +class TestTraceQueueManagerTelemetryGuard: + """Test TraceQueueManager's telemetry guard in add_trace_task().""" + + def test_task_not_enqueued_when_telemetry_disabled_and_no_trace_instance(self, trace_queue_manager_and_task): + """Verify task is NOT enqueued when telemetry disabled and no trace instance. + + This is the core guard: when _enterprise_telemetry_enabled=False AND + trace_instance=None, the task should be silently dropped. + """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False), + patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="test-app-id") + manager.add_trace_task(trace_task) + + mock_queue.put.assert_not_called() + + def test_task_enqueued_when_telemetry_enabled(self, trace_queue_manager_and_task): + """Verify task IS enqueued when enterprise telemetry is enabled. + + When _enterprise_telemetry_enabled=True, the task should be enqueued + regardless of trace_instance state. + """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True), + patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="test-app-id") + manager.add_trace_task(trace_task) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.app_id == "test-app-id" + + def test_task_enqueued_when_trace_instance_configured(self, trace_queue_manager_and_task): + """Verify task IS enqueued when third-party trace instance is configured. + + When trace_instance is not None (e.g., Langfuse configured), the task + should be enqueued even if enterprise telemetry is disabled. + """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + mock_trace_instance = MagicMock() + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False), + patch( + "core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=mock_trace_instance + ), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="test-app-id") + manager.add_trace_task(trace_task) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.app_id == "test-app-id" + + def test_task_enqueued_when_both_telemetry_and_trace_instance_enabled(self, trace_queue_manager_and_task): + """Verify task IS enqueued when both telemetry and trace instance are enabled. + + When both _enterprise_telemetry_enabled=True AND trace_instance is set, + the task should definitely be enqueued. + """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + mock_trace_instance = MagicMock() + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True), + patch( + "core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=mock_trace_instance + ), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="test-app-id") + manager.add_trace_task(trace_task) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.app_id == "test-app-id" + + def test_app_id_set_before_enqueue(self, trace_queue_manager_and_task): + """Verify app_id is set on the task before enqueuing. + + The guard logic sets trace_task.app_id = self.app_id before calling + trace_manager_queue.put(trace_task). This test verifies that behavior. + """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True), + patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="expected-app-id") + manager.add_trace_task(trace_task) + + called_task = mock_queue.put.call_args[0][0] + assert called_task.app_id == "expected-app-id" diff --git a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py b/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py index 8057bbbad51..5014f40afca 100644 --- a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py +++ b/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py @@ -7,6 +7,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from weave.trace_server.trace_server_interface import TraceStatus from core.ops.entities.config_entity import WeaveConfig @@ -22,7 +23,6 @@ from core.ops.entities.trace_entity import ( ) from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel from core.ops.weave_trace.weave_trace import WeaveDataTrace -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey # ── Helpers ────────────────────────────────────────────────────────────────── @@ -589,7 +589,7 @@ class TestWorkflowTrace: nodes = [] repo = MagicMock() - repo.get_by_workflow_run.return_value = nodes + repo.get_by_workflow_execution.return_value = nodes mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo @@ -802,8 +802,8 @@ class TestMessageTrace: def test_basic_message_trace(self, trace_instance, monkeypatch): """message_trace creates message run and llm child run.""" monkeypatch.setattr( - "core.ops.weave_trace.weave_trace.db.session.query", - lambda model: MagicMock(where=lambda: MagicMock(first=lambda: None)), + "core.ops.weave_trace.weave_trace.db.session.get", + lambda model, pk: None, ) trace_instance.start_call = MagicMock() @@ -823,7 +823,7 @@ class TestMessageTrace: trace_instance.file_base_url = "http://files.test" mock_db = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.get.return_value = None monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() @@ -845,7 +845,7 @@ class TestMessageTrace: end_user.session_id = "session-xyz" mock_db = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = end_user + mock_db.session.get.return_value = end_user monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() @@ -865,7 +865,7 @@ class TestMessageTrace: def test_message_trace_no_end_user(self, trace_instance, monkeypatch): """message_trace handles when from_end_user_id is None.""" mock_db = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.get.return_value = None monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() @@ -883,7 +883,7 @@ class TestMessageTrace: def test_message_trace_trace_id_fallback_to_message_id(self, trace_instance, monkeypatch): """trace_id falls back to message_id when trace_id is None.""" mock_db = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.get.return_value = None monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() @@ -898,7 +898,7 @@ class TestMessageTrace: def test_message_trace_file_list_none(self, trace_instance, monkeypatch): """message_trace handles file_list=None gracefully.""" mock_db = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.get.return_value = None monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() diff --git a/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py b/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py new file mode 100644 index 00000000000..7491e79f305 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py @@ -0,0 +1,36 @@ +from unittest.mock import Mock, patch + +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly + + +def test_plugin_model_assembly_reuses_single_runtime_across_views(): + runtime = Mock(name="runtime") + provider_factory = Mock(name="provider_factory") + provider_manager = Mock(name="provider_manager") + model_manager = Mock(name="model_manager") + + with ( + patch( + "core.plugin.impl.model_runtime_factory.create_plugin_model_runtime", + return_value=runtime, + ) as mock_runtime_factory, + patch( + "core.plugin.impl.model_runtime_factory.ModelProviderFactory", + return_value=provider_factory, + ) as mock_provider_factory_cls, + patch("core.provider_manager.ProviderManager", return_value=provider_manager) as mock_provider_manager_cls, + patch("core.model_manager.ModelManager", return_value=model_manager) as mock_model_manager_cls, + ): + assembly = create_plugin_model_assembly(tenant_id="tenant-1", user_id="user-1") + + assert assembly.model_provider_factory is provider_factory + assert assembly.provider_manager is provider_manager + assert assembly.model_manager is model_manager + assert assembly.model_provider_factory is provider_factory + assert assembly.provider_manager is provider_manager + assert assembly.model_manager is model_manager + + mock_runtime_factory.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") + mock_provider_factory_cls.assert_called_once_with(model_runtime=runtime) + mock_provider_manager_cls.assert_called_once_with(model_runtime=runtime) + mock_model_manager_cls.assert_called_once_with(provider_manager=provider_manager) diff --git a/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py b/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py index c2778f082b8..3feb4159ade 100644 --- a/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py +++ b/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py @@ -332,27 +332,21 @@ class TestPluginAppBackwardsInvocation: PluginAppBackwardsInvocation._get_user("uid") def test_get_app_returns_app(self, mocker): - query_chain = MagicMock() - query_chain.where.return_value = query_chain app_obj = MagicMock(id="app") - query_chain.first.return_value = app_obj - db = SimpleNamespace(session=MagicMock(query=MagicMock(return_value=query_chain))) + db = SimpleNamespace(session=MagicMock(scalar=MagicMock(return_value=app_obj))) mocker.patch("core.plugin.backwards_invocation.app.db", db) assert PluginAppBackwardsInvocation._get_app("app", "tenant") is app_obj def test_get_app_raises_when_missing(self, mocker): - query_chain = MagicMock() - query_chain.where.return_value = query_chain - query_chain.first.return_value = None - db = SimpleNamespace(session=MagicMock(query=MagicMock(return_value=query_chain))) + db = SimpleNamespace(session=MagicMock(scalar=MagicMock(return_value=None))) mocker.patch("core.plugin.backwards_invocation.app.db", db) with pytest.raises(ValueError, match="app not found"): PluginAppBackwardsInvocation._get_app("app", "tenant") def test_get_app_raises_when_query_fails(self, mocker): - db = SimpleNamespace(session=MagicMock(query=MagicMock(side_effect=RuntimeError("db down")))) + db = SimpleNamespace(session=MagicMock(scalar=MagicMock(side_effect=RuntimeError("db down")))) mocker.patch("core.plugin.backwards_invocation.app.db", db) with pytest.raises(ValueError, match="app not found"): diff --git a/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py b/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py new file mode 100644 index 00000000000..543b278715d --- /dev/null +++ b/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py @@ -0,0 +1,62 @@ +from types import SimpleNamespace +from unittest.mock import patch + +from graphon.model_runtime.entities.message_entities import UserPromptMessage + +from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation +from core.plugin.entities.request import RequestInvokeSummary + + +def test_system_model_helpers_forward_user_id(): + with ( + patch( + "core.plugin.backwards_invocation.model.ModelInvocationUtils.get_max_llm_context_tokens", + return_value=4096, + ) as mock_max_tokens, + patch( + "core.plugin.backwards_invocation.model.ModelInvocationUtils.calculate_tokens", + return_value=7, + ) as mock_prompt_tokens, + ): + assert PluginModelBackwardsInvocation.get_system_model_max_tokens("tenant-1", user_id="user-1") == 4096 + assert ( + PluginModelBackwardsInvocation.get_prompt_tokens( + "tenant-1", + [UserPromptMessage(content="hello")], + user_id="user-1", + ) + == 7 + ) + + mock_max_tokens.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") + mock_prompt_tokens.assert_called_once_with( + tenant_id="tenant-1", + prompt_messages=[UserPromptMessage(content="hello")], + user_id="user-1", + ) + + +def test_invoke_summary_uses_same_user_scope_for_token_helpers(): + tenant = SimpleNamespace(id="tenant-1") + payload = RequestInvokeSummary(text="short", instruction="keep it concise") + + with ( + patch.object( + PluginModelBackwardsInvocation, + "get_system_model_max_tokens", + return_value=100, + ) as mock_max_tokens, + patch.object( + PluginModelBackwardsInvocation, + "get_prompt_tokens", + return_value=10, + ) as mock_prompt_tokens, + ): + assert PluginModelBackwardsInvocation.invoke_summary("user-1", tenant, payload) == "short" + + mock_max_tokens.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") + mock_prompt_tokens.assert_called_once_with( + tenant_id="tenant-1", + prompt_messages=[UserPromptMessage(content="short")], + user_id="user-1", + ) diff --git a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py new file mode 100644 index 00000000000..f8d0e127b1b --- /dev/null +++ b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py @@ -0,0 +1,506 @@ +"""Unit tests for the plugin-backed model runtime adapter.""" + +import datetime +import uuid +from types import SimpleNamespace +from unittest.mock import Mock, sentinel + +import pytest +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity + +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from core.plugin.impl import model_runtime as model_runtime_module +from core.plugin.impl.model import PluginModelClient +from core.plugin.impl.model_runtime import TENANT_SCOPE_SCHEMA_CACHE_USER_ID, PluginModelRuntime +from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime + + +def _build_model_schema() -> AIModelEntity: + return AIModelEntity( + model="gpt-4o-mini", + label=I18nObject(en_US="GPT-4o mini"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={}, + ) + + +class TestPluginModelRuntime: + """Validate the adapter keeps plugin-specific routing out of the runtime port.""" + + def test_fetch_model_providers_returns_runtime_entities(self) -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [ + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="openai", + tenant_id="tenant", + plugin_unique_identifier="langgenius/openai/openai", + plugin_id="langgenius/openai", + declaration=ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ) + ] + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + providers = runtime.fetch_model_providers() + + assert len(providers) == 1 + assert providers[0].provider == "langgenius/openai/openai" + assert providers[0].provider_name == "openai" + assert providers[0].label.en_US == "OpenAI" + client.fetch_model_providers.assert_called_once_with("tenant") + + def test_fetch_model_providers_only_exposes_short_name_for_canonical_provider(self) -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [ + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="openai", + tenant_id="tenant", + plugin_unique_identifier="acme/openai/openai", + plugin_id="acme/openai", + declaration=ProviderEntity( + provider="openai", + label=I18nObject(en_US="Acme OpenAI"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ), + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="openai", + tenant_id="tenant", + plugin_unique_identifier="langgenius/openai/openai", + plugin_id="langgenius/openai", + declaration=ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ), + ] + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + providers = runtime.fetch_model_providers() + + provider_aliases = {provider.provider: provider.provider_name for provider in providers} + assert provider_aliases["acme/openai/openai"] == "" + assert provider_aliases["langgenius/openai/openai"] == "openai" + + def test_fetch_model_providers_keeps_google_alias_on_canonical_gemini_provider(self) -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [ + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="google", + tenant_id="tenant", + plugin_unique_identifier="langgenius/gemini/google", + plugin_id="langgenius/gemini", + declaration=ProviderEntity( + provider="google", + label=I18nObject(en_US="Google"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ) + ] + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + providers = runtime.fetch_model_providers() + + assert providers[0].provider == "langgenius/gemini/google" + assert providers[0].provider_name == "google" + + def test_validate_provider_credentials_resolves_plugin_fields(self) -> None: + client = Mock(spec=PluginModelClient) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + runtime.validate_provider_credentials( + provider="langgenius/openai/openai", + credentials={"api_key": "secret"}, + ) + + client.validate_provider_credentials.assert_called_once_with( + tenant_id="tenant", + user_id="user", + plugin_id="langgenius/openai", + provider="openai", + credentials={"api_key": "secret"}, + ) + + def test_invoke_llm_resolves_plugin_fields(self) -> None: + client = Mock(spec=PluginModelClient) + client.invoke_llm.return_value = sentinel.result + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + result = runtime.invoke_llm( + provider="langgenius/openai/openai", + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + model_parameters={"temperature": 0.3}, + prompt_messages=[], + tools=None, + stop=None, + stream=False, + ) + + assert result is sentinel.result + client.invoke_llm.assert_called_once_with( + tenant_id="tenant", + user_id="user", + plugin_id="langgenius/openai", + provider="openai", + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + model_parameters={"temperature": 0.3}, + prompt_messages=[], + tools=None, + stop=None, + stream=False, + ) + + def test_invoke_llm_rejects_per_call_user_override(self) -> None: + client = Mock(spec=PluginModelClient) + client.invoke_llm.return_value = sentinel.result + runtime = PluginModelRuntime(tenant_id="tenant", user_id="bound-user", client=client) + + with pytest.raises(TypeError, match="unexpected keyword argument 'user_id'"): + runtime.invoke_llm( # type: ignore[call-arg] + provider="langgenius/openai/openai", + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + model_parameters={"temperature": 0.3}, + prompt_messages=[], + tools=None, + stop=None, + stream=False, + user_id="request-user", + ) + + client.invoke_llm.assert_not_called() + + def test_invoke_tts_uses_bound_runtime_user_when_runtime_is_unbound(self) -> None: + client = Mock(spec=PluginModelClient) + client.invoke_tts.return_value = iter([b"chunk"]) + runtime = PluginModelRuntime(tenant_id="tenant", user_id=None, client=client) + + result = runtime.invoke_tts( + provider="langgenius/openai/openai", + model="tts-1", + credentials={"api_key": "secret"}, + content_text="hello", + voice="alloy", + ) + + assert list(result) == [b"chunk"] + client.invoke_tts.assert_called_once_with( + tenant_id="tenant", + user_id=None, + plugin_id="langgenius/openai", + provider="openai", + model="tts-1", + credentials={"api_key": "secret"}, + content_text="hello", + voice="alloy", + ) + + def test_fetch_model_providers_uses_bound_runtime_cache(self) -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [] + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + runtime.fetch_model_providers() + runtime.fetch_model_providers() + + client.fetch_model_providers.assert_called_once_with("tenant") + + +def test_create_plugin_model_runtime_without_user_context() -> None: + runtime = create_plugin_model_runtime(tenant_id="tenant") + + assert runtime.user_id is None + + +def test_plugin_model_runtime_requires_client() -> None: + with pytest.raises(ValueError, match="client is required"): + PluginModelRuntime(tenant_id="tenant", user_id="user", client=None) # type: ignore[arg-type] + + +def test_get_model_schema_uses_cached_schema_without_hitting_client(monkeypatch: pytest.MonkeyPatch) -> None: + client = Mock(spec=PluginModelClient) + schema = _build_model_schema() + monkeypatch.setattr( + model_runtime_module, + "redis_client", + SimpleNamespace( + get=Mock(return_value=schema.model_dump_json()), + delete=Mock(), + setex=Mock(), + ), + ) + + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + result = runtime.get_model_schema( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + ) + + assert result == schema + client.get_model_schema.assert_not_called() + + +def test_get_model_schema_deletes_invalid_cache_and_refetches(monkeypatch: pytest.MonkeyPatch) -> None: + client = Mock(spec=PluginModelClient) + schema = _build_model_schema() + delete = Mock() + setex = Mock() + monkeypatch.setattr( + model_runtime_module, + "redis_client", + SimpleNamespace( + get=Mock(return_value="not-json"), + delete=delete, + setex=setex, + ), + ) + monkeypatch.setattr(model_runtime_module.dify_config, "PLUGIN_MODEL_SCHEMA_CACHE_TTL", 300) + client.get_model_schema.return_value = schema + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + result = runtime.get_model_schema( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + ) + + assert result == schema + delete.assert_called_once() + client.get_model_schema.assert_called_once_with( + tenant_id="tenant", + user_id="user", + plugin_id="langgenius/openai", + provider="openai", + model_type=ModelType.LLM.value, + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + ) + setex.assert_called_once() + + +def test_get_llm_num_tokens_returns_zero_when_plugin_counting_is_disabled(monkeypatch: pytest.MonkeyPatch) -> None: + client = Mock(spec=PluginModelClient) + monkeypatch.setattr(model_runtime_module.dify_config, "PLUGIN_BASED_TOKEN_COUNTING_ENABLED", False) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + assert ( + runtime.get_llm_num_tokens( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + prompt_messages=[], + tools=None, + ) + == 0 + ) + client.get_llm_num_tokens.assert_not_called() + + +def test_get_provider_icon_reads_requested_variant_and_detects_svg_mime(monkeypatch: pytest.MonkeyPatch) -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [ + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="openai", + tenant_id="tenant", + plugin_unique_identifier="langgenius/openai/openai", + plugin_id="langgenius/openai", + declaration=ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + icon_small=I18nObject(en_US="logo.svg"), + icon_small_dark=I18nObject(en_US="logo-dark.png"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ) + ] + fetch_asset = Mock(return_value=b"") + monkeypatch.setattr(model_runtime_module.PluginAssetManager, "fetch_asset", fetch_asset) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + icon_bytes, mime_type = runtime.get_provider_icon( + provider="langgenius/openai/openai", + icon_type="icon_small", + lang="en_US", + ) + + assert icon_bytes == b"" + assert mime_type == "image/svg+xml" + fetch_asset.assert_called_once_with(tenant_id="tenant", id="logo.svg") + + +def test_get_provider_icon_rejects_unsupported_types_and_missing_variants() -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [ + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="openai", + tenant_id="tenant", + plugin_unique_identifier="langgenius/openai/openai", + plugin_id="langgenius/openai", + declaration=ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ) + ] + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + with pytest.raises(ValueError, match="does not have small dark icon"): + runtime.get_provider_icon( + provider="langgenius/openai/openai", + icon_type="icon_small_dark", + lang="en_US", + ) + + with pytest.raises(ValueError, match="Unsupported icon type"): + runtime.get_provider_icon( + provider="langgenius/openai/openai", + icon_type="icon_large", + lang="en_US", + ) + + +def test_get_schema_cache_key_is_stable_across_credential_order() -> None: + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=Mock(spec=PluginModelClient)) + + first = runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"b": "2", "a": "1"}, + ) + second = runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"a": "1", "b": "2"}, + ) + + assert first == second + + +def test_get_schema_cache_key_separates_distinct_user_scopes() -> None: + first_runtime = PluginModelRuntime(tenant_id="tenant", user_id="user-a", client=Mock(spec=PluginModelClient)) + second_runtime = PluginModelRuntime(tenant_id="tenant", user_id="user-b", client=Mock(spec=PluginModelClient)) + + first = first_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"a": "1"}, + ) + second = second_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"a": "1"}, + ) + + assert first != second + + +def test_get_schema_cache_key_separates_tenant_scope_from_user_scope() -> None: + tenant_runtime = PluginModelRuntime(tenant_id="tenant", user_id=None, client=Mock(spec=PluginModelClient)) + user_runtime = PluginModelRuntime(tenant_id="tenant", user_id="user-a", client=Mock(spec=PluginModelClient)) + + tenant_key = tenant_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"a": "1"}, + ) + user_key = user_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"a": "1"}, + ) + + assert tenant_key != user_key + assert f":{TENANT_SCOPE_SCHEMA_CACHE_USER_ID}" in tenant_key + + +def test_get_schema_cache_key_separates_tenant_scope_from_empty_string_user_scope() -> None: + tenant_runtime = PluginModelRuntime(tenant_id="tenant", user_id=None, client=Mock(spec=PluginModelClient)) + empty_user_runtime = PluginModelRuntime(tenant_id="tenant", user_id="", client=Mock(spec=PluginModelClient)) + + tenant_key = tenant_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={}, + ) + empty_user_key = empty_user_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={}, + ) + + assert tenant_key != empty_user_key + assert empty_user_key.endswith(":") + assert TENANT_SCOPE_SCHEMA_CACHE_USER_ID not in empty_user_key + + +def test_get_provider_schema_supports_short_alias_and_rejects_invalid_provider() -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [ + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="openai", + tenant_id="tenant", + plugin_unique_identifier="langgenius/openai/openai", + plugin_id="langgenius/openai", + declaration=ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ) + ] + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + assert runtime._get_provider_schema("openai").provider == "langgenius/openai/openai" + + with pytest.raises(ValueError, match="Invalid provider"): + runtime._get_provider_schema("missing") diff --git a/api/tests/unit_tests/core/plugin/test_plugin_entities.py b/api/tests/unit_tests/core/plugin/test_plugin_entities.py index b0b64a601bc..a812b01c5bd 100644 --- a/api/tests/unit_tests/core/plugin/test_plugin_entities.py +++ b/api/tests/unit_tests/core/plugin/test_plugin_entities.py @@ -4,6 +4,12 @@ from enum import StrEnum import pytest from flask import Response +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) from pydantic import ValidationError from core.plugin.entities.endpoint import EndpointEntityWithInstance @@ -25,12 +31,6 @@ from core.plugin.entities.request import ( ) from core.plugin.utils.http_parser import serialize_response from core.tools.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - SystemPromptMessage, - ToolPromptMessage, - UserPromptMessage, -) class TestEndpointEntity: diff --git a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py index 4f038d4a5b6..3063ca01970 100644 --- a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py +++ b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py @@ -17,6 +17,14 @@ from unittest.mock import MagicMock, patch import httpx import pytest +from graphon.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from pydantic import BaseModel from core.plugin.entities.plugin_daemon import ( @@ -26,6 +34,7 @@ from core.plugin.entities.plugin_daemon import ( from core.plugin.impl.base import BasePluginClient from core.plugin.impl.exc import ( PluginDaemonBadRequestError, + PluginDaemonClientSideError, PluginDaemonInternalServerError, PluginDaemonNotFoundError, PluginDaemonUnauthorizedError, @@ -36,14 +45,6 @@ from core.plugin.impl.exc import ( ) from core.plugin.impl.plugin import PluginInstaller from core.plugin.impl.tool import PluginToolManager -from dify_graph.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError class TestPluginRuntimeExecution: @@ -557,7 +558,7 @@ class TestPluginRuntimeErrorHandling: with patch("httpx.request", return_value=mock_response, autospec=True): # Act & Assert - with pytest.raises(httpx.HTTPStatusError): + with pytest.raises(PluginDaemonInternalServerError): plugin_client._request_with_plugin_daemon_response("GET", "plugin/test-tenant/test", bool) def test_empty_data_response_error(self, plugin_client, mock_config): @@ -1808,8 +1809,8 @@ class TestPluginInstallerAdvanced: mock_response.raise_for_status = raise_for_status with patch("httpx.request", return_value=mock_response, autospec=True): - # Act & Assert - Should raise HTTPStatusError for 404 - with pytest.raises(httpx.HTTPStatusError): + # Act & Assert - Should raise PluginDaemonClientSideError for 404 + with pytest.raises(PluginDaemonClientSideError): installer.fetch_plugin_readme("test-tenant", "test-org/test-plugin", "en") def test_list_plugins_with_pagination(self, installer, mock_config): diff --git a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py index c7e94aa4cf5..90730dff5a4 100644 --- a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py +++ b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py @@ -1,13 +1,12 @@ from collections.abc import Generator import pytest +from graphon.file import File, FileTransferMethod, FileType from core.agent.entities import AgentInvokeMessage from core.plugin.utils.chunk_merger import FileChunk, merge_blob_chunks from core.plugin.utils.converter import convert_parameters_to_plugin_format from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolSelector -from dify_graph.file.enums import FileTransferMethod, FileType -from dify_graph.file.models import File class TestChunkMerger: diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 3d08525aba5..2b280dd6746 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -2,15 +2,8 @@ from typing import cast from unittest.mock import MagicMock, patch import pytest - -from configs import dify_config -from core.app.app_config.entities import ModelConfigEntity -from core.memory.token_buffer_memory import TokenBufferMemory -from core.prompt.advanced_prompt_transform import AdvancedPromptTransform -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig -from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.file import File, FileTransferMethod, FileType +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessageRole, @@ -18,6 +11,13 @@ from dify_graph.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) + +from configs import dify_config +from core.app.app_config.entities import ModelConfigEntity +from core.memory.token_buffer_memory import TokenBufferMemory +from core.prompt.advanced_prompt_transform import AdvancedPromptTransform +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import Conversation @@ -145,7 +145,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg prompt_transform = AdvancedPromptTransform() prompt_transform._calculate_rest_token = MagicMock(return_value=2000) - with patch("dify_graph.file.file_manager.to_prompt_message_content", autospec=True) as mock_get_encoded_string: + with patch("graphon.file.file_manager.to_prompt_message_content", autospec=True) as mock_get_encoded_string: mock_get_encoded_string.return_value = ImagePromptMessageContent( url=str(files[0].remote_url), format="jpg", mime_type="image/jpg" ) diff --git a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py index 634703740c5..4a54649b289 100644 --- a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py @@ -1,18 +1,19 @@ from unittest.mock import MagicMock +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel + from core.app.entities.app_invoke_entities import ( ModelConfigWithCredentialsEntity, ) from core.entities.provider_configuration import ProviderModelBundle from core.memory.token_buffer_memory import TokenBufferMemory from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform -from dify_graph.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - SystemPromptMessage, - ToolPromptMessage, - UserPromptMessage, -) -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from models.model import Conversation diff --git a/api/tests/unit_tests/core/prompt/test_prompt_message.py b/api/tests/unit_tests/core/prompt/test_prompt_message.py index 9fc300348a0..a4b3960b0a2 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_message.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_message.py @@ -1,6 +1,4 @@ -from core.prompt.simple_prompt_transform import ModelMode -from core.prompt.utils.prompt_message_util import PromptMessageUtil -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, AudioPromptMessageContent, ImagePromptMessageContent, @@ -9,6 +7,9 @@ from dify_graph.model_runtime.entities.message_entities import ( UserPromptMessage, ) +from core.prompt.simple_prompt_transform import ModelMode +from core.prompt.utils.prompt_message_util import PromptMessageUtil + def test_build_prompt_message_with_prompt_message_contents(): prompt = UserPromptMessage(content=[TextPromptMessageContent(data="Hello, World!")]) diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py index d379e3067a7..e35ce2c48a9 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py @@ -2,16 +2,16 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.entities.model_entities import ModelPropertyKey from core.prompt.prompt_transform import PromptTransform -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey # from core.app.app_config.entities import ModelConfigEntity # from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle -# from dify_graph.model_runtime.entities.message_entities import UserPromptMessage -# from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule -# from dify_graph.model_runtime.entities.provider_entities import ProviderEntity -# from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +# from graphon.model_runtime.entities.message_entities import UserPromptMessage +# from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule +# from graphon.model_runtime.entities.provider_entities import ProviderEntity +# from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel # from core.prompt.prompt_transform import PromptTransform diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py index e6d28224d75..3f188cfbb4b 100644 --- a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py @@ -2,6 +2,12 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + TextPromptMessageContent, + UserPromptMessage, +) from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory @@ -18,12 +24,6 @@ from core.prompt.prompt_templates.advanced_prompt_templates import ( CONTEXT, ) from core.prompt.simple_prompt_transform import SimplePromptTransform -from dify_graph.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - ImagePromptMessageContent, - TextPromptMessageContent, - UserPromptMessage, -) from models.model import AppMode, Conversation diff --git a/api/tests/unit_tests/core/rag/cleaner/test_clean_processor.py b/api/tests/unit_tests/core/rag/cleaner/test_clean_processor.py index 65ee62b8dd2..c7a4265a954 100644 --- a/api/tests/unit_tests/core/rag/cleaner/test_clean_processor.py +++ b/api/tests/unit_tests/core/rag/cleaner/test_clean_processor.py @@ -211,3 +211,16 @@ class TestCleanProcessor: text = "[Text with (parens) and symbols](https://example.com)" expected = "[Text with (parens) and symbols](https://example.com)" assert CleanProcessor.clean(text, process_rule) == expected + + def test_clean_remove_urls_emails_preserves_markdown_image_links(self): + """Remove plain URLs and emails while preserving markdown image links.""" + process_rule = {"rules": {"pre_processing_rules": [{"id": "remove_urls_emails", "enabled": True}]}} + text = "Email test@example.com and remove https://remove.com but keep ![diagram](https://example.com/image.png)" + result = CleanProcessor.clean(text, process_rule) + + assert result == "Email and remove but keep ![diagram](https://example.com/image.png)" + + def test_filter_string_returns_input_text(self): + """Test filter_string passthrough behavior.""" + processor = CleanProcessor() + assert processor.filter_string("raw text") == "raw text" diff --git a/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py b/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py new file mode 100644 index 00000000000..006b4e7345e --- /dev/null +++ b/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py @@ -0,0 +1,247 @@ +from unittest.mock import MagicMock, patch + +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError + +from core.rag.data_post_processor.data_post_processor import DataPostProcessor +from core.rag.data_post_processor.reorder import ReorderRunner +from core.rag.index_processor.constant.query_type import QueryType +from core.rag.models.document import Document +from core.rag.rerank.rerank_type import RerankMode + + +def _doc(content: str) -> Document: + return Document(page_content=content) + + +class TestDataPostProcessor: + def test_init_sets_rerank_and_reorder_runners(self): + rerank_runner = object() + reorder_runner = object() + + with patch.object(DataPostProcessor, "_get_rerank_runner", return_value=rerank_runner) as rerank_mock: + with patch.object(DataPostProcessor, "_get_reorder_runner", return_value=reorder_runner) as reorder_mock: + processor = DataPostProcessor( + tenant_id="tenant-1", + reranking_mode=RerankMode.WEIGHTED_SCORE, + reranking_model={"config": "value"}, + weights={"weight": "value"}, + reorder_enabled=True, + ) + + assert processor.rerank_runner is rerank_runner + assert processor.reorder_runner is reorder_runner + rerank_mock.assert_called_once_with( + RerankMode.WEIGHTED_SCORE, + "tenant-1", + {"config": "value"}, + {"weight": "value"}, + ) + reorder_mock.assert_called_once_with(True) + + def test_invoke_applies_rerank_then_reorder(self): + original_documents = [_doc("doc-a")] + reranked_documents = [_doc("doc-b")] + reordered_documents = [_doc("doc-c")] + + processor = DataPostProcessor.__new__(DataPostProcessor) + processor.rerank_runner = MagicMock() + processor.rerank_runner.run.return_value = reranked_documents + processor.reorder_runner = MagicMock() + processor.reorder_runner.run.return_value = reordered_documents + + result = processor.invoke( + query="how to test", + documents=original_documents, + score_threshold=0.3, + top_n=2, + query_type=QueryType.IMAGE_QUERY, + ) + + processor.rerank_runner.run.assert_called_once_with( + "how to test", + original_documents, + 0.3, + 2, + QueryType.IMAGE_QUERY, + ) + processor.reorder_runner.run.assert_called_once_with(reranked_documents) + assert result == reordered_documents + + def test_invoke_returns_original_documents_when_no_runner_is_configured(self): + documents = [_doc("doc-a"), _doc("doc-b")] + + processor = DataPostProcessor.__new__(DataPostProcessor) + processor.rerank_runner = None + processor.reorder_runner = None + + assert processor.invoke(query="query", documents=documents) == documents + + def test_get_rerank_runner_for_weighted_score(self): + weights_config = { + "vector_setting": { + "vector_weight": 0.7, + "embedding_provider_name": "provider-x", + "embedding_model_name": "embedding-y", + }, + "keyword_setting": {"keyword_weight": 0.3}, + } + expected_runner = object() + processor = DataPostProcessor.__new__(DataPostProcessor) + + with patch( + "core.rag.data_post_processor.data_post_processor.RerankRunnerFactory.create_rerank_runner", + return_value=expected_runner, + ) as factory_mock: + result = processor._get_rerank_runner( + reranking_mode=RerankMode.WEIGHTED_SCORE, + tenant_id="tenant-1", + reranking_model=None, + weights=weights_config, + ) + + assert result is expected_runner + kwargs = factory_mock.call_args.kwargs + assert kwargs["runner_type"] == RerankMode.WEIGHTED_SCORE + assert kwargs["tenant_id"] == "tenant-1" + assert kwargs["weights"].vector_setting.vector_weight == 0.7 + assert kwargs["weights"].vector_setting.embedding_provider_name == "provider-x" + assert kwargs["weights"].vector_setting.embedding_model_name == "embedding-y" + assert kwargs["weights"].keyword_setting.keyword_weight == 0.3 + + def test_get_rerank_runner_for_reranking_model_returns_none_without_model_instance(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + reranking_model = { + "reranking_provider_name": "provider-x", + "reranking_model_name": "model-y", + } + + with patch.object(DataPostProcessor, "_get_rerank_model_instance", return_value=None) as model_mock: + with patch( + "core.rag.data_post_processor.data_post_processor.RerankRunnerFactory.create_rerank_runner" + ) as factory_mock: + result = processor._get_rerank_runner( + reranking_mode=RerankMode.RERANKING_MODEL, + tenant_id="tenant-1", + reranking_model=reranking_model, + weights=None, + ) + + assert result is None + model_mock.assert_called_once_with("tenant-1", reranking_model) + factory_mock.assert_not_called() + + def test_get_rerank_runner_for_reranking_model_creates_runner_with_model_instance(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + model_instance = object() + expected_runner = object() + + with patch.object(DataPostProcessor, "_get_rerank_model_instance", return_value=model_instance): + with patch( + "core.rag.data_post_processor.data_post_processor.RerankRunnerFactory.create_rerank_runner", + return_value=expected_runner, + ) as factory_mock: + result = processor._get_rerank_runner( + reranking_mode=RerankMode.RERANKING_MODEL, + tenant_id="tenant-1", + reranking_model={ + "reranking_provider_name": "provider-x", + "reranking_model_name": "model-y", + }, + weights=None, + ) + + assert result is expected_runner + factory_mock.assert_called_once_with( + runner_type=RerankMode.RERANKING_MODEL, + rerank_model_instance=model_instance, + ) + + def test_get_rerank_runner_returns_none_for_unsupported_mode(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + + assert processor._get_rerank_runner("unsupported", "tenant-1", None, None) is None + assert processor._get_rerank_runner(RerankMode.WEIGHTED_SCORE, "tenant-1", None, None) is None + + def test_get_reorder_runner_by_flag(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + + assert isinstance(processor._get_reorder_runner(True), ReorderRunner) + assert processor._get_reorder_runner(False) is None + + def test_get_rerank_model_instance_returns_none_when_config_is_missing(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + assert processor._get_rerank_model_instance("tenant-1", None) is None + + def test_get_rerank_model_instance_returns_none_for_incomplete_config(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + + with patch("core.rag.data_post_processor.data_post_processor.ModelManager.for_tenant") as for_tenant_mock: + result = processor._get_rerank_model_instance( + tenant_id="tenant-1", + reranking_model={"reranking_provider_name": "provider-x"}, + ) + + assert result is None + for_tenant_mock.assert_called_once_with(tenant_id="tenant-1") + + def test_get_rerank_model_instance_success(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + model_instance = object() + + with patch("core.rag.data_post_processor.data_post_processor.ModelManager.for_tenant") as for_tenant_mock: + manager_instance = for_tenant_mock.return_value + manager_instance.get_model_instance.return_value = model_instance + + result = processor._get_rerank_model_instance( + tenant_id="tenant-1", + reranking_model={ + "reranking_provider_name": "provider-x", + "reranking_model_name": "reranker-1", + }, + ) + + assert result is model_instance + for_tenant_mock.assert_called_once_with(tenant_id="tenant-1") + manager_instance.get_model_instance.assert_called_once_with( + tenant_id="tenant-1", + provider="provider-x", + model_type=ModelType.RERANK, + model="reranker-1", + ) + + def test_get_rerank_model_instance_handles_authorization_error(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + + with patch("core.rag.data_post_processor.data_post_processor.ModelManager.for_tenant") as for_tenant_mock: + manager_instance = for_tenant_mock.return_value + manager_instance.get_model_instance.side_effect = InvokeAuthorizationError("not authorized") + + result = processor._get_rerank_model_instance( + tenant_id="tenant-1", + reranking_model={ + "reranking_provider_name": "provider-x", + "reranking_model_name": "reranker-1", + }, + ) + + assert result is None + for_tenant_mock.assert_called_once_with(tenant_id="tenant-1") + + +class TestReorderRunner: + def test_run_reorders_even_sized_document_list(self): + documents = [_doc("0"), _doc("1"), _doc("2"), _doc("3"), _doc("4"), _doc("5")] + + reordered = ReorderRunner().run(documents) + + assert [document.page_content for document in reordered] == ["0", "2", "4", "5", "3", "1"] + + def test_run_handles_odd_sized_and_empty_document_lists(self): + odd_documents = [_doc("0"), _doc("1"), _doc("2"), _doc("3"), _doc("4")] + runner = ReorderRunner() + + odd_reordered = runner.run(odd_documents) + + assert [document.page_content for document in odd_reordered] == ["0", "2", "4", "3", "1"] + assert runner.run([]) == [] diff --git a/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py new file mode 100644 index 00000000000..bbdd4769146 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py @@ -0,0 +1,410 @@ +import json +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +import core.rag.datasource.keyword.jieba.jieba as jieba_module +from core.rag.datasource.keyword.jieba.jieba import Jieba, dumps_with_sets, set_orjson_default +from core.rag.models.document import Document + + +class _DummyLock: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + +class _Field: + def __init__(self, name: str): + self._name = name + + def __eq__(self, other): + return ("eq", self._name, other) + + def in_(self, values): + return ("in", self._name, tuple(values)) + + +class _FakeQuery: + def __init__(self): + self.where_calls: list[tuple] = [] + + def where(self, *conditions): + self.where_calls.append(conditions) + return self + + +class _FakeExecuteResult: + def __init__(self, segments: list[SimpleNamespace]): + self._segments = segments + + def scalars(self): + return self + + def all(self): + return self._segments + + +class _FakeSelect: + def __init__(self): + self.where_conditions: tuple | None = None + + def where(self, *conditions): + self.where_conditions = conditions + return self + + +def _dataset_keyword_table(data_source_type: str = "database", keyword_table_dict: dict | None = None): + return SimpleNamespace( + data_source_type=data_source_type, + keyword_table_dict=keyword_table_dict, + keyword_table="", + ) + + +def _dataset(dataset_keyword_table=None, keyword_number=None): + return SimpleNamespace( + id="dataset-1", + tenant_id="tenant-1", + keyword_number=keyword_number, + dataset_keyword_table=dataset_keyword_table, + ) + + +@pytest.fixture +def patched_runtime(monkeypatch): + session = MagicMock() + db = SimpleNamespace(session=session) + storage = MagicMock() + lock = MagicMock(return_value=_DummyLock()) + redis_client = SimpleNamespace(lock=lock) + + monkeypatch.setattr(jieba_module, "db", db) + monkeypatch.setattr(jieba_module, "storage", storage) + monkeypatch.setattr(jieba_module, "redis_client", redis_client) + + return SimpleNamespace(session=session, storage=storage, lock=lock) + + +def test_create_indexes_documents_and_returns_self(monkeypatch, patched_runtime): + dataset = _dataset(_dataset_keyword_table(), keyword_number=2) + keyword = Jieba(dataset) + handler = MagicMock() + handler.extract_keywords.return_value = {"kw1", "kw2"} + + monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={})) + monkeypatch.setattr(keyword, "_update_segment_keywords", MagicMock()) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + result = keyword.create( + [ + Document(page_content="alpha", metadata={"doc_id": "node-1"}), + SimpleNamespace(page_content="ignored", metadata=None), + ] + ) + + assert result is keyword + keyword._update_segment_keywords.assert_called_once() + call_args = keyword._update_segment_keywords.call_args.args + assert call_args[0] == "dataset-1" + assert call_args[1] == "node-1" + assert set(call_args[2]) == {"kw1", "kw2"} + saved_table = keyword._save_dataset_keyword_table.call_args.args[0] + assert saved_table["kw1"] == {"node-1"} + assert saved_table["kw2"] == {"node-1"} + patched_runtime.lock.assert_called_once_with("keyword_indexing_lock_dataset-1", timeout=600) + + +def test_add_texts_supports_keywords_list_and_extract_fallback(monkeypatch, patched_runtime): + keyword = Jieba(_dataset(_dataset_keyword_table(), keyword_number=3)) + handler = MagicMock() + handler.extract_keywords.return_value = {"auto"} + + monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={})) + monkeypatch.setattr(keyword, "_update_segment_keywords", MagicMock()) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + texts = [ + Document(page_content="extract-this", metadata={"doc_id": "node-1"}), + Document(page_content="use-manual", metadata={"doc_id": "node-2"}), + ] + keyword.add_texts(texts, keywords_list=[[], ["manual"]]) + + assert keyword._update_segment_keywords.call_count == 2 + first_call = keyword._update_segment_keywords.call_args_list[0].args + second_call = keyword._update_segment_keywords.call_args_list[1].args + assert set(first_call[2]) == {"auto"} + assert second_call[2] == ["manual"] + keyword._save_dataset_keyword_table.assert_called_once() + + +def test_add_texts_without_keywords_list_always_uses_extractor(monkeypatch, patched_runtime): + keyword = Jieba(_dataset(_dataset_keyword_table(), keyword_number=1)) + handler = MagicMock() + handler.extract_keywords.return_value = {"from-extractor"} + + monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={})) + monkeypatch.setattr(keyword, "_update_segment_keywords", MagicMock()) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + keyword.add_texts([Document(page_content="content", metadata={"doc_id": "node-1"})]) + + handler.extract_keywords.assert_called_once_with("content", 1) + assert set(keyword._update_segment_keywords.call_args.args[2]) == {"from-extractor"} + + +def test_text_exists_handles_missing_and_existing_keyword_table(monkeypatch): + keyword = Jieba(_dataset(_dataset_keyword_table())) + + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value=None)) + assert keyword.text_exists("node-1") is False + + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={"k": {"node-1", "node-2"}})) + assert keyword.text_exists("node-2") is True + assert keyword.text_exists("node-x") is False + + +def test_delete_by_ids_updates_table_when_present(monkeypatch, patched_runtime): + keyword = Jieba(_dataset(_dataset_keyword_table())) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={"k": {"node-1", "node-2"}})) + monkeypatch.setattr(keyword, "_delete_ids_from_keyword_table", MagicMock(return_value={"k": {"node-2"}})) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + keyword.delete_by_ids(["node-1"]) + + keyword._delete_ids_from_keyword_table.assert_called_once_with({"k": {"node-1", "node-2"}}, ["node-1"]) + keyword._save_dataset_keyword_table.assert_called_once_with({"k": {"node-2"}}) + + +def test_delete_by_ids_saves_none_when_keyword_table_is_missing(monkeypatch, patched_runtime): + keyword = Jieba(_dataset(_dataset_keyword_table())) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value=None)) + monkeypatch.setattr(keyword, "_delete_ids_from_keyword_table", MagicMock()) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + keyword.delete_by_ids(["node-1"]) + + keyword._delete_ids_from_keyword_table.assert_not_called() + keyword._save_dataset_keyword_table.assert_called_once_with(None) + + +def test_search_returns_documents_in_rank_order_and_applies_filter(monkeypatch, patched_runtime): + class _FakeDocumentSegment: + dataset_id = _Field("dataset_id") + index_node_id = _Field("index_node_id") + document_id = _Field("document_id") + + keyword = Jieba(_dataset(_dataset_keyword_table())) + patched_runtime.session.scalars.return_value.all.return_value = [ + SimpleNamespace( + index_node_id="node-2", + content="segment-content", + index_node_hash="hash-2", + document_id="doc-2", + dataset_id="dataset-1", + ) + ] + + monkeypatch.setattr(jieba_module, "DocumentSegment", _FakeDocumentSegment) + monkeypatch.setattr(jieba_module, "select", lambda *_: _FakeSelect()) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={"k": {"node-1", "node-2"}})) + monkeypatch.setattr(keyword, "_retrieve_ids_by_query", MagicMock(return_value=["node-1", "node-2"])) + + documents = keyword.search("query", top_k=2, document_ids_filter=["doc-2"]) + + assert len(documents) == 1 + assert documents[0].page_content == "segment-content" + assert documents[0].metadata["doc_id"] == "node-2" + assert documents[0].metadata["doc_hash"] == "hash-2" + + +def test_delete_removes_keyword_table_and_optional_file(monkeypatch, patched_runtime): + db_keyword = _dataset_keyword_table(data_source_type="database") + file_keyword = _dataset_keyword_table(data_source_type="object_storage") + + keyword_db = Jieba(_dataset(db_keyword)) + keyword_db.delete() + patched_runtime.storage.delete.assert_not_called() + + keyword_file = Jieba(_dataset(file_keyword)) + keyword_file.delete() + + patched_runtime.storage.delete.assert_called_once_with("keyword_files/tenant-1/dataset-1.txt") + assert patched_runtime.session.delete.call_count == 2 + assert patched_runtime.session.commit.call_count == 2 + + +def test_save_dataset_keyword_table_to_database(monkeypatch, patched_runtime): + dataset_keyword_table = _dataset_keyword_table(data_source_type="database") + keyword = Jieba(_dataset(dataset_keyword_table)) + + keyword._save_dataset_keyword_table({"kw": {"node-1"}}) + + assert '"__type__":"keyword_table"' in dataset_keyword_table.keyword_table + assert '"index_id":"dataset-1"' in dataset_keyword_table.keyword_table + patched_runtime.session.commit.assert_called_once() + + +def test_save_dataset_keyword_table_to_file_storage(monkeypatch, patched_runtime): + dataset_keyword_table = _dataset_keyword_table(data_source_type="file") + keyword = Jieba(_dataset(dataset_keyword_table)) + patched_runtime.storage.exists.return_value = True + + keyword._save_dataset_keyword_table({"kw": {"node-1"}}) + + patched_runtime.storage.delete.assert_called_once_with("keyword_files/tenant-1/dataset-1.txt") + patched_runtime.storage.save.assert_called_once() + save_args = patched_runtime.storage.save.call_args.args + assert save_args[0] == "keyword_files/tenant-1/dataset-1.txt" + assert isinstance(save_args[1], bytes) + + +def test_get_dataset_keyword_table_returns_existing_table_data(monkeypatch, patched_runtime): + existing = _dataset_keyword_table( + keyword_table_dict={"__type__": "keyword_table", "__data__": {"table": {"kw": ["node-1"]}}} + ) + keyword = Jieba(_dataset(existing)) + assert keyword._get_dataset_keyword_table() == {"kw": ["node-1"]} + + missing_payload = _dataset_keyword_table(keyword_table_dict=None) + keyword_with_missing_payload = Jieba(_dataset(missing_payload)) + assert keyword_with_missing_payload._get_dataset_keyword_table() == {} + + +def test_get_dataset_keyword_table_creates_table_when_missing(monkeypatch, patched_runtime): + created_tables: list[SimpleNamespace] = [] + + def _fake_dataset_keyword_table(**kwargs): + kwargs.setdefault("keyword_table", "") + kwargs.setdefault("keyword_table_dict", None) + table = SimpleNamespace(**kwargs) + created_tables.append(table) + return table + + keyword = Jieba(_dataset(dataset_keyword_table=None)) + monkeypatch.setattr(jieba_module, "DatasetKeywordTable", _fake_dataset_keyword_table) + monkeypatch.setattr(jieba_module.dify_config, "KEYWORD_DATA_SOURCE_TYPE", "database") + + result = keyword._get_dataset_keyword_table() + + assert result == {} + assert len(created_tables) == 1 + assert created_tables[0].dataset_id == "dataset-1" + assert created_tables[0].data_source_type == "database" + assert '"index_id":"dataset-1"' in created_tables[0].keyword_table + patched_runtime.session.add.assert_called_once_with(created_tables[0]) + patched_runtime.session.commit.assert_called_once() + + +def test_add_and_delete_ids_from_keyword_table_helpers(): + keyword = Jieba(_dataset(_dataset_keyword_table())) + keyword_table = {"kw1": {"node-1"}, "kw2": {"node-1", "node-2"}} + + updated = keyword._add_text_to_keyword_table(keyword_table, "node-3", ["kw1", "kw3"]) + assert updated["kw1"] == {"node-1", "node-3"} + assert updated["kw3"] == {"node-3"} + + deleted = keyword._delete_ids_from_keyword_table(updated, ["node-1", "node-3"]) + assert "kw3" not in deleted + assert "kw1" not in deleted + assert deleted["kw2"] == {"node-2"} + + +def test_retrieve_ids_by_query_ranks_by_keyword_frequency(monkeypatch): + keyword = Jieba(_dataset(_dataset_keyword_table())) + handler = MagicMock() + handler.extract_keywords.return_value = ["kw-a", "kw-b"] + monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler) + + ranked_ids = keyword._retrieve_ids_by_query( + {"kw-a": {"node-1", "node-2"}, "kw-b": {"node-2"}, "kw-c": {"node-3"}}, + "query", + k=1, + ) + + assert ranked_ids == ["node-2"] + + +def test_update_segment_keywords_updates_when_segment_exists(monkeypatch, patched_runtime): + class _FakeDocumentSegment: + dataset_id = _Field("dataset_id") + index_node_id = _Field("index_node_id") + + monkeypatch.setattr(jieba_module, "DocumentSegment", _FakeDocumentSegment) + monkeypatch.setattr(jieba_module, "select", lambda *_: _FakeSelect()) + + keyword = Jieba(_dataset(_dataset_keyword_table())) + segment = SimpleNamespace(keywords=[]) + patched_runtime.session.scalar.return_value = segment + + keyword._update_segment_keywords("dataset-1", "node-1", ["kw1", "kw2"]) + + assert segment.keywords == ["kw1", "kw2"] + patched_runtime.session.add.assert_called_once_with(segment) + patched_runtime.session.commit.assert_called_once() + + patched_runtime.session.reset_mock() + patched_runtime.session.scalar.return_value = None + + keyword._update_segment_keywords("dataset-1", "node-missing", ["kw3"]) + + patched_runtime.session.add.assert_not_called() + patched_runtime.session.commit.assert_not_called() + + +def test_create_segment_keywords_and_update_segment_keywords_index(monkeypatch): + keyword = Jieba(_dataset(_dataset_keyword_table())) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={})) + monkeypatch.setattr(keyword, "_update_segment_keywords", MagicMock()) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + keyword.create_segment_keywords("node-1", ["kw"]) + keyword._update_segment_keywords.assert_called_once_with("dataset-1", "node-1", ["kw"]) + keyword._save_dataset_keyword_table.assert_called_once() + + keyword._save_dataset_keyword_table.reset_mock() + keyword.update_segment_keywords_index("node-2", ["kw2"]) + keyword._save_dataset_keyword_table.assert_called_once() + + +def test_multi_create_segment_keywords_uses_provided_and_extracted_keywords(monkeypatch): + keyword = Jieba(_dataset(_dataset_keyword_table(), keyword_number=2)) + handler = MagicMock() + handler.extract_keywords.return_value = {"auto"} + monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={})) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + first_segment = SimpleNamespace(index_node_id="node-1", content="first content", keywords=None) + second_segment = SimpleNamespace(index_node_id="node-2", content="second content", keywords=None) + + keyword.multi_create_segment_keywords( + [ + {"segment": first_segment, "keywords": ["manual"]}, + {"segment": second_segment, "keywords": []}, + ] + ) + + assert first_segment.keywords == ["manual"] + assert second_segment.keywords == ["auto"] + saved_table = keyword._save_dataset_keyword_table.call_args.args[0] + assert saved_table["manual"] == {"node-1"} + assert saved_table["auto"] == {"node-2"} + + +def test_set_orjson_default_and_dumps_with_sets(): + assert set(set_orjson_default({"a", "b"})) == {"a", "b"} + + with pytest.raises(TypeError, match="is not JSON serializable"): + set_orjson_default(("not", "a", "set")) + + payload = {"items": {"a", "b"}} + json_payload = dumps_with_sets(payload) + decoded = json.loads(json_payload) + assert set(decoded["items"]) == {"a", "b"} diff --git a/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba_keyword_table_handler.py b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba_keyword_table_handler.py new file mode 100644 index 00000000000..a4586c141bb --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba_keyword_table_handler.py @@ -0,0 +1,142 @@ +import sys +import types +from types import SimpleNamespace + +from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler +from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS + + +class _DummyTFIDF: + def __init__(self): + self.stop_words = set() + + @staticmethod + def extract_tags(sentence: str, top_k: int | None = 20, **kwargs): + return ["alpha_beta", "during", "gamma"] + + +def _install_fake_jieba_modules( + monkeypatch, + analyse_module: types.ModuleType, + jieba_attrs: dict[str, object] | None = None, + tfidf_module: types.ModuleType | None = None, +): + jieba_module = types.ModuleType("jieba") + jieba_module.__path__ = [] + if jieba_attrs: + for key, value in jieba_attrs.items(): + setattr(jieba_module, key, value) + + jieba_module.analyse = analyse_module + analyse_module.__package__ = "jieba" + + monkeypatch.setitem(sys.modules, "jieba", jieba_module) + monkeypatch.setitem(sys.modules, "jieba.analyse", analyse_module) + if tfidf_module is not None: + monkeypatch.setitem(sys.modules, "jieba.analyse.tfidf", tfidf_module) + else: + monkeypatch.delitem(sys.modules, "jieba.analyse.tfidf", raising=False) + + +def test_init_uses_existing_default_tfidf(monkeypatch): + analyse_module = types.ModuleType("jieba.analyse") + default_tfidf = _DummyTFIDF() + analyse_module.default_tfidf = default_tfidf + + _install_fake_jieba_modules(monkeypatch, analyse_module) + + handler = JiebaKeywordTableHandler() + + assert handler._tfidf is default_tfidf + assert handler._tfidf.stop_words == STOPWORDS + + +def test_load_tfidf_extractor_uses_tfidf_class_and_caches_default(monkeypatch): + analyse_module = types.ModuleType("jieba.analyse") + analyse_module.default_tfidf = None + + class _TFIDFFactory(_DummyTFIDF): + pass + + analyse_module.TFIDF = _TFIDFFactory + _install_fake_jieba_modules(monkeypatch, analyse_module) + + handler = JiebaKeywordTableHandler() + + assert isinstance(handler._tfidf, _TFIDFFactory) + assert analyse_module.default_tfidf is handler._tfidf + + +def test_load_tfidf_extractor_imports_from_tfidf_submodule(monkeypatch): + analyse_module = types.ModuleType("jieba.analyse") + analyse_module.default_tfidf = None + + tfidf_module = types.ModuleType("jieba.analyse.tfidf") + + class _ImportedTFIDF(_DummyTFIDF): + pass + + tfidf_module.TFIDF = _ImportedTFIDF + _install_fake_jieba_modules(monkeypatch, analyse_module, tfidf_module=tfidf_module) + + handler = JiebaKeywordTableHandler() + + assert isinstance(handler._tfidf, _ImportedTFIDF) + assert analyse_module.default_tfidf is handler._tfidf + + +def test_load_tfidf_extractor_falls_back_when_tfidf_unavailable(monkeypatch): + analyse_module = types.ModuleType("jieba.analyse") + analyse_module.default_tfidf = None + _install_fake_jieba_modules(monkeypatch, analyse_module) + + handler = JiebaKeywordTableHandler() + fallback_keywords = handler._tfidf.extract_tags("one two two and three", topK=1) + + assert fallback_keywords == ["two"] + + +def test_build_fallback_tfidf_uses_lcut_when_available(monkeypatch): + analyse_module = types.ModuleType("jieba.analyse") + _install_fake_jieba_modules(monkeypatch, analyse_module, jieba_attrs={"lcut": lambda _: ["x", "x", "y"]}) + + tfidf = JiebaKeywordTableHandler._build_fallback_tfidf() + + assert tfidf.extract_tags("ignored", topK=1) == ["x"] + + +def test_build_fallback_tfidf_uses_cut_when_lcut_is_missing(monkeypatch): + analyse_module = types.ModuleType("jieba.analyse") + _install_fake_jieba_modules( + monkeypatch, + analyse_module, + jieba_attrs={"cut": lambda _: iter(["foo", "foo", "bar"])}, + ) + + tfidf = JiebaKeywordTableHandler._build_fallback_tfidf() + + assert tfidf.extract_tags("ignored", topK=1) == ["foo"] + + +def test_extract_keywords_expands_subtokens(): + handler = JiebaKeywordTableHandler.__new__(JiebaKeywordTableHandler) + handler._tfidf = SimpleNamespace(extract_tags=lambda *_args, **_kwargs: ["alpha-beta", "during", "gamma"]) + + keywords = handler.extract_keywords("input text", max_keywords_per_chunk=3) + + assert "alpha-beta" in keywords + assert "alpha" in keywords + assert "beta" in keywords + assert "during" in keywords + assert "gamma" in keywords + + +def test_expand_tokens_with_subtokens_filters_stopwords_from_subtokens(): + handler = JiebaKeywordTableHandler.__new__(JiebaKeywordTableHandler) + + expanded = handler._expand_tokens_with_subtokens({"alpha-during-beta"}) + + assert "alpha-during-beta" in expanded + assert "alpha" in expanded + assert "beta" in expanded + assert "during" not in expanded diff --git a/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_stopwords.py b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_stopwords.py new file mode 100644 index 00000000000..1b1541ddd64 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_stopwords.py @@ -0,0 +1,6 @@ +from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS + + +def test_stopwords_loaded(): + assert "during" in STOPWORDS + assert "the" in STOPWORDS diff --git a/api/tests/unit_tests/core/rag/datasource/keyword/test_keyword_base.py b/api/tests/unit_tests/core/rag/datasource/keyword/test_keyword_base.py new file mode 100644 index 00000000000..55e22aea0ac --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/keyword/test_keyword_base.py @@ -0,0 +1,97 @@ +from types import SimpleNamespace + +import pytest + +from core.rag.datasource.keyword.keyword_base import BaseKeyword +from core.rag.models.document import Document + + +class _KeywordThatRaises(BaseKeyword): + def create(self, texts: list[Document], **kwargs): + return super().create(texts, **kwargs) + + def add_texts(self, texts: list[Document], **kwargs): + return super().add_texts(texts, **kwargs) + + def text_exists(self, id: str) -> bool: + return super().text_exists(id) + + def delete_by_ids(self, ids: list[str]): + return super().delete_by_ids(ids) + + def delete(self): + return super().delete() + + def search(self, query: str, **kwargs): + return super().search(query, **kwargs) + + +class _KeywordForHelpers(BaseKeyword): + def __init__(self, dataset, existing_ids: set[str] | None = None): + super().__init__(dataset) + self._existing_ids = existing_ids or set() + + def create(self, texts: list[Document], **kwargs): + return self + + def add_texts(self, texts: list[Document], **kwargs): + return None + + def text_exists(self, id: str) -> bool: + return id in self._existing_ids + + def delete_by_ids(self, ids: list[str]): + return None + + def delete(self): + return None + + def search(self, query: str, **kwargs): + return [] + + +def test_abstract_methods_raise_not_implemented(): + keyword = _KeywordThatRaises(SimpleNamespace(id="dataset-1")) + + with pytest.raises(NotImplementedError): + keyword.create([]) + + with pytest.raises(NotImplementedError): + keyword.add_texts([]) + + with pytest.raises(NotImplementedError): + keyword.text_exists("doc-1") + + with pytest.raises(NotImplementedError): + keyword.delete_by_ids(["doc-1"]) + + with pytest.raises(NotImplementedError): + keyword.delete() + + with pytest.raises(NotImplementedError): + keyword.search("query") + + +def test_filter_duplicate_texts_removes_existing_doc_ids(): + keyword = _KeywordForHelpers(SimpleNamespace(id="dataset-1"), existing_ids={"duplicate"}) + texts = [ + Document(page_content="keep", metadata={"doc_id": "keep"}), + Document(page_content="duplicate", metadata={"doc_id": "duplicate"}), + SimpleNamespace(page_content="without-metadata", metadata=None), + ] + + filtered = keyword._filter_duplicate_texts(texts) + + assert [text.metadata["doc_id"] for text in filtered if text.metadata] == ["keep"] + assert any(text.metadata is None for text in filtered) + + +def test_get_uuids_returns_only_docs_with_metadata(): + keyword = _KeywordForHelpers(SimpleNamespace(id="dataset-1")) + texts = [ + Document(page_content="doc-1", metadata={"doc_id": "doc-1"}), + Document(page_content="doc-2", metadata={"doc_id": "doc-2"}), + SimpleNamespace(page_content="doc-3", metadata=None), + ] + + assert keyword._get_uuids(texts) == ["doc-1", "doc-2"] diff --git a/api/tests/unit_tests/core/rag/datasource/keyword/test_keyword_factory.py b/api/tests/unit_tests/core/rag/datasource/keyword/test_keyword_factory.py new file mode 100644 index 00000000000..0d969a3270d --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/keyword/test_keyword_factory.py @@ -0,0 +1,84 @@ +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.rag.datasource.keyword.keyword_factory import Keyword +from core.rag.datasource.keyword.keyword_type import KeyWordType +from core.rag.models.document import Document + + +def test_get_keyword_factory_returns_jieba_factory(monkeypatch): + fake_module = types.ModuleType("core.rag.datasource.keyword.jieba.jieba") + + class FakeJieba: + pass + + fake_module.Jieba = FakeJieba + monkeypatch.setitem(sys.modules, "core.rag.datasource.keyword.jieba.jieba", fake_module) + + assert Keyword.get_keyword_factory(KeyWordType.JIEBA) is FakeJieba + + +def test_get_keyword_factory_raises_for_unsupported_type(): + with pytest.raises(ValueError, match="Keyword store unsupported is not supported"): + Keyword.get_keyword_factory("unsupported") + + +def test_keyword_initialization_uses_configured_factory(monkeypatch): + dataset = SimpleNamespace(id="dataset-1") + fake_processor = MagicMock() + + monkeypatch.setattr("core.rag.datasource.keyword.keyword_factory.dify_config.KEYWORD_STORE", KeyWordType.JIEBA) + monkeypatch.setattr(Keyword, "get_keyword_factory", staticmethod(lambda keyword_type: lambda _: fake_processor)) + + keyword = Keyword(dataset) + + assert keyword._keyword_processor is fake_processor + + +def test_keyword_methods_forward_to_processor(): + processor = MagicMock() + processor.text_exists.return_value = True + processor.search.return_value = [Document(page_content="matched", metadata={"doc_id": "doc-1"})] + + keyword = Keyword.__new__(Keyword) + keyword._keyword_processor = processor + + docs = [Document(page_content="doc", metadata={"doc_id": "doc-1"})] + keyword.create(docs, foo="bar") + keyword.add_texts(docs, batch=True) + assert keyword.text_exists("doc-1") is True + keyword.delete_by_ids(["doc-1"]) + keyword.delete() + assert keyword.search("query", top_k=1) == processor.search.return_value + + processor.create.assert_called_once_with(docs, foo="bar") + processor.add_texts.assert_called_once_with(docs, batch=True) + processor.text_exists.assert_called_once_with("doc-1") + processor.delete_by_ids.assert_called_once_with(["doc-1"]) + processor.delete.assert_called_once() + processor.search.assert_called_once_with("query", top_k=1) + + +def test_keyword_getattr_returns_callable_and_raises_for_invalid_attributes(): + class Processor: + value = 1 + + @staticmethod + def custom(): + return "ok" + + keyword = Keyword.__new__(Keyword) + keyword._keyword_processor = Processor() + + assert keyword.custom() == "ok" + + with pytest.raises(AttributeError): + _ = keyword.value + + keyword._keyword_processor = None + with pytest.raises(AttributeError): + _ = keyword.missing_method diff --git a/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py b/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py new file mode 100644 index 00000000000..5dbd62580aa --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py @@ -0,0 +1,1176 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, call, patch +from uuid import uuid4 + +import pytest + +from core.rag.datasource import retrieval_service as retrieval_service_module +from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.query_type import QueryType +from core.rag.models.document import Document +from core.rag.rerank.rerank_type import RerankMode +from core.rag.retrieval.retrieval_methods import RetrievalMethod +from models.dataset import Dataset + + +def create_mock_document( + content: str, + doc_id: str, + score: float = 0.8, + provider: str = "dify", + additional_metadata: dict | None = None, +) -> Document: + """ + Create a mock Document object for testing. + + This helper function standardizes document creation across tests, + ensuring consistent structure and reducing code duplication. + + Args: + content: The text content of the document + doc_id: Unique identifier for the document chunk + score: Relevance score (0.0 to 1.0) + provider: Document provider ("dify" or "external") + additional_metadata: Optional extra metadata fields + + Returns: + Document: A properly structured Document object + + Example: + >>> doc = create_mock_document("Python is great", "doc1", score=0.95) + >>> assert doc.metadata["score"] == 0.95 + """ + metadata = { + "doc_id": doc_id, + "document_id": str(uuid4()), + "dataset_id": str(uuid4()), + "score": score, + } + + # Merge additional metadata if provided + if additional_metadata: + metadata.update(additional_metadata) + + return Document( + page_content=content, + metadata=metadata, + provider=provider, + ) + + +class _ImmediateFuture: + def __init__(self, exception: Exception | None = None) -> None: + self._exception = exception + self.cancel_called = False + + def exception(self) -> Exception | None: + return self._exception + + def cancel(self) -> None: + self.cancel_called = True + + +class _ImmediateExecutor: + def __init__(self) -> None: + self.futures: list[_ImmediateFuture] = [] + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + def submit(self, fn, *args, **kwargs): + try: + fn(*args, **kwargs) + future = _ImmediateFuture() + except Exception as exc: # pragma: no cover - only for defensive parity with Future semantics + future = _ImmediateFuture(exc) + self.futures.append(future) + return future + + +class _FakeExecuteScalarResult: + def __init__(self, data: list) -> None: + self._data = data + + def all(self) -> list: + return self._data + + +class _FakeExecuteResult: + def __init__(self, data: list) -> None: + self._data = data + + def scalars(self) -> _FakeExecuteScalarResult: + return _FakeExecuteScalarResult(self._data) + + +class _FakeSummaryQuery: + def __init__(self, summaries: list) -> None: + self._summaries = summaries + + def filter(self, *args, **kwargs): + return self + + def all(self) -> list: + return self._summaries + + +class _FakeSession: + def __init__(self, execute_payloads: list[list], summaries: list) -> None: + self._payloads = list(execute_payloads) + self._summaries = summaries + + def execute(self, stmt): + data = self._payloads.pop(0) if self._payloads else [] + return _FakeExecuteResult(data) + + def query(self, model): + return _FakeSummaryQuery(self._summaries) + + +class _FakeSessionContext: + def __init__(self, session: _FakeSession) -> None: + self._session = session + + def __enter__(self) -> _FakeSession: + return self._session + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + +class _SimpleRetrievalChildChunk: + def __init__(self, id: str, content: str, score: float, position: int) -> None: + self.id = id + self.content = content + self.score = score + self.position = position + + +class _SimpleRetrievalSegment: + def __init__( + self, + segment, + child_chunks: list[_SimpleRetrievalChildChunk] | None = None, + score: float | None = None, + files: list[dict[str, str | int]] | None = None, + summary: str | None = None, + ) -> None: + self.segment = segment + self.child_chunks = child_chunks + self.score = score + self.files = files + self.summary = summary + + +class TestRetrievalServiceInternals: + @pytest.fixture + def internal_dataset(self) -> Dataset: + dataset = Mock(spec=Dataset) + dataset.id = "dataset-id" + dataset.tenant_id = "tenant-id" + dataset.is_multimodal = False + dataset.doc_form = IndexStructureType.PARENT_CHILD_INDEX + return dataset + + @pytest.fixture + def internal_flask_app(self): + app = MagicMock() + app.app_context.return_value.__enter__ = Mock() + app.app_context.return_value.__exit__.return_value = False + return app + + def test_retrieve_with_attachment_ids_only(self, monkeypatch, internal_dataset): + with ( + patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset", return_value=internal_dataset), + patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") as mock_retrieve, + ): + executor = _ImmediateExecutor() + monkeypatch.setattr(retrieval_service_module, "ThreadPoolExecutor", lambda *args, **kwargs: executor) + monkeypatch.setattr( + retrieval_service_module.concurrent.futures, + "as_completed", + lambda futures, timeout=None: iter(futures), + ) + + def side_effect( + flask_app, + retrieval_method, + dataset, + all_documents, + exceptions, + query=None, + top_k=4, + score_threshold=0.0, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, + document_ids_filter=None, + attachment_id=None, + ): + all_documents.append(create_mock_document(f"content-{attachment_id}", attachment_id or "none", 0.9)) + + mock_retrieve.side_effect = side_effect + + results = RetrievalService.retrieve( + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + dataset_id=internal_dataset.id, + query="", + attachment_ids=["att-1", "att-2"], + ) + + assert len(results) == 2 + assert {doc.metadata["doc_id"] for doc in results} == {"att-1", "att-2"} + assert mock_retrieve.call_count == 2 + + @patch("core.rag.datasource.retrieval_service.ExternalDatasetService.fetch_external_knowledge_retrieval") + @patch("core.rag.datasource.retrieval_service.MetadataCondition.model_validate") + @patch("core.rag.datasource.retrieval_service.db.session.scalar") + def test_external_retrieve_with_metadata_conditions(self, mock_scalar, mock_validate, mock_fetch): + mock_scalar.return_value = SimpleNamespace(tenant_id="tenant-1") + mock_validate.return_value = "validated-condition" + expected_documents = [create_mock_document("external-doc", "external-1", 0.8, provider="external")] + mock_fetch.return_value = expected_documents + + results = RetrievalService.external_retrieve( + dataset_id="dataset-1", + query="test query", + external_retrieval_model={"top_k": 3}, + metadata_filtering_conditions={"field": "source", "operator": "contains", "value": "manual"}, + ) + + assert results == expected_documents + mock_validate.assert_called_once() + mock_fetch.assert_called_once_with( + "tenant-1", + "dataset-1", + "test query", + {"top_k": 3}, + metadata_condition="validated-condition", + ) + + @patch("core.rag.datasource.retrieval_service.db.session.scalar") + def test_external_retrieve_returns_empty_when_dataset_not_found(self, mock_scalar): + mock_scalar.return_value = None + + results = RetrievalService.external_retrieve(dataset_id="missing", query="q") + + assert results == [] + + @patch("core.rag.datasource.retrieval_service.Session") + def test_get_dataset_queries_by_id(self, mock_session_class): + expected_dataset = Mock(spec=Dataset) + mock_session = Mock() + mock_session.query.return_value.where.return_value.first.return_value = expected_dataset + mock_session_class.return_value.__enter__.return_value = mock_session + + with patch.object(retrieval_service_module, "db", SimpleNamespace(engine=Mock())): + result = RetrievalService._get_dataset("dataset-123") + + assert result == expected_dataset + mock_session.query.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.Keyword") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_keyword_search_success(self, mock_get_dataset, mock_keyword_class, internal_dataset, internal_flask_app): + mock_get_dataset.return_value = internal_dataset + keyword_instance = Mock() + keyword_instance.search.return_value = [create_mock_document("keyword-content", "kw-1", 0.91)] + mock_keyword_class.return_value = keyword_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.keyword_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query='query "with quotes"', + top_k=5, + all_documents=all_documents, + exceptions=exceptions, + ) + + assert len(all_documents) == 1 + assert exceptions == [] + keyword_instance.search.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_keyword_search_appends_exception_when_dataset_missing(self, mock_get_dataset, internal_flask_app): + mock_get_dataset.return_value = None + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.keyword_search( + flask_app=internal_flask_app, + dataset_id="dataset-id", + query="query", + top_k=2, + all_documents=all_documents, + exceptions=exceptions, + ) + + assert all_documents == [] + assert exceptions == ["dataset not found"] + + @patch("core.rag.datasource.retrieval_service.Keyword") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_keyword_search_appends_exception_when_search_fails( + self, mock_get_dataset, mock_keyword_class, internal_dataset, internal_flask_app + ): + mock_get_dataset.return_value = internal_dataset + keyword_instance = Mock() + keyword_instance.search.side_effect = RuntimeError("keyword failed") + mock_keyword_class.return_value = keyword_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.keyword_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="query", + top_k=2, + all_documents=all_documents, + exceptions=exceptions, + ) + + assert all_documents == [] + assert exceptions == ["keyword failed"] + + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_embedding_search_text_without_reranking( + self, mock_get_dataset, mock_vector_class, internal_dataset, internal_flask_app + ): + internal_dataset.is_multimodal = False + mock_get_dataset.return_value = internal_dataset + vector_instance = Mock() + vector_instance.search_by_vector.return_value = [create_mock_document("vector-content", "vec-1", 0.7)] + mock_vector_class.return_value = vector_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.embedding_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="query", + top_k=4, + score_threshold=0.5, + reranking_model=None, + all_documents=all_documents, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + exceptions=exceptions, + document_ids_filter=["doc-1"], + query_type=QueryType.TEXT_QUERY, + ) + + assert len(all_documents) == 1 + assert exceptions == [] + vector_instance.search_by_vector.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_embedding_search_image_non_multimodal_returns_early( + self, mock_get_dataset, mock_vector_class, internal_dataset, internal_flask_app + ): + internal_dataset.is_multimodal = False + mock_get_dataset.return_value = internal_dataset + vector_instance = Mock() + mock_vector_class.return_value = vector_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.embedding_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="file-1", + top_k=4, + score_threshold=0.5, + reranking_model=None, + all_documents=all_documents, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + exceptions=exceptions, + query_type=QueryType.IMAGE_QUERY, + ) + + assert all_documents == [] + assert exceptions == [] + vector_instance.search_by_file.assert_not_called() + + @patch("core.rag.datasource.retrieval_service.ModelManager.for_tenant") + @patch("core.rag.datasource.retrieval_service.DataPostProcessor") + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_embedding_search_image_multimodal_with_vision_reranking( + self, + mock_get_dataset, + mock_vector_class, + mock_processor_class, + mock_model_manager_class, + internal_dataset, + internal_flask_app, + ): + internal_dataset.is_multimodal = True + mock_get_dataset.return_value = internal_dataset + original_docs = [create_mock_document("image-content", "img-doc", 0.73)] + reranked_docs = [create_mock_document("image-content-reranked", "img-doc", 0.97)] + + vector_instance = Mock() + vector_instance.search_by_file.return_value = original_docs + mock_vector_class.return_value = vector_instance + + processor_instance = Mock() + processor_instance.invoke.return_value = reranked_docs + mock_processor_class.return_value = processor_instance + + model_manager = Mock() + model_manager.check_model_support_vision.return_value = True + mock_model_manager_class.return_value = model_manager + + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.embedding_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="file-id", + top_k=4, + score_threshold=0.5, + reranking_model={ + "reranking_provider_name": "provider", + "reranking_model_name": "model", + }, + all_documents=all_documents, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + exceptions=exceptions, + query_type=QueryType.IMAGE_QUERY, + ) + + assert all_documents == reranked_docs + assert exceptions == [] + processor_instance.invoke.assert_called_once() + mock_model_manager_class.assert_called_once_with(tenant_id=internal_dataset.tenant_id) + model_manager.check_model_support_vision.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.ModelManager.for_tenant") + @patch("core.rag.datasource.retrieval_service.DataPostProcessor") + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_embedding_search_image_multimodal_without_vision_support( + self, + mock_get_dataset, + mock_vector_class, + mock_processor_class, + mock_model_manager_class, + internal_dataset, + internal_flask_app, + ): + internal_dataset.is_multimodal = True + mock_get_dataset.return_value = internal_dataset + original_docs = [create_mock_document("image-content", "img-doc", 0.73)] + + vector_instance = Mock() + vector_instance.search_by_file.return_value = original_docs + mock_vector_class.return_value = vector_instance + + processor_instance = Mock() + processor_instance.invoke.return_value = [create_mock_document("unused", "unused", 0.1)] + mock_processor_class.return_value = processor_instance + + model_manager = Mock() + model_manager.check_model_support_vision.return_value = False + mock_model_manager_class.return_value = model_manager + + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.embedding_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="file-id", + top_k=4, + score_threshold=0.5, + reranking_model={ + "reranking_provider_name": "provider", + "reranking_model_name": "model", + }, + all_documents=all_documents, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + exceptions=exceptions, + query_type=QueryType.IMAGE_QUERY, + ) + + assert all_documents == original_docs + assert exceptions == [] + mock_model_manager_class.assert_called_once_with(tenant_id=internal_dataset.tenant_id) + processor_instance.invoke.assert_not_called() + + @patch("core.rag.datasource.retrieval_service.DataPostProcessor") + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_embedding_search_text_with_reranking_non_multimodal( + self, mock_get_dataset, mock_vector_class, mock_processor_class, internal_dataset, internal_flask_app + ): + internal_dataset.is_multimodal = False + mock_get_dataset.return_value = internal_dataset + original_docs = [create_mock_document("vector-content", "vec-doc", 0.62)] + reranked_docs = [create_mock_document("vector-content-reranked", "vec-doc", 0.89)] + + vector_instance = Mock() + vector_instance.search_by_vector.return_value = original_docs + mock_vector_class.return_value = vector_instance + + processor_instance = Mock() + processor_instance.invoke.return_value = reranked_docs + mock_processor_class.return_value = processor_instance + + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.embedding_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="query", + top_k=4, + score_threshold=0.5, + reranking_model={ + "reranking_provider_name": "provider", + "reranking_model_name": "model", + }, + all_documents=all_documents, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + exceptions=exceptions, + query_type=QueryType.TEXT_QUERY, + ) + + assert all_documents == reranked_docs + assert exceptions == [] + processor_instance.invoke.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_embedding_search_appends_exception_when_vector_fails( + self, mock_get_dataset, mock_vector_class, internal_dataset, internal_flask_app + ): + mock_get_dataset.return_value = internal_dataset + vector_instance = Mock() + vector_instance.search_by_vector.side_effect = RuntimeError("vector failed") + mock_vector_class.return_value = vector_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.embedding_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="query", + top_k=4, + score_threshold=0.5, + reranking_model=None, + all_documents=all_documents, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + exceptions=exceptions, + query_type=QueryType.TEXT_QUERY, + ) + + assert all_documents == [] + assert exceptions == ["vector failed"] + + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_full_text_index_search_without_reranking( + self, mock_get_dataset, mock_vector_class, internal_dataset, internal_flask_app + ): + mock_get_dataset.return_value = internal_dataset + vector_instance = Mock() + vector_instance.search_by_full_text.return_value = [create_mock_document("fulltext", "ft-1", 0.68)] + mock_vector_class.return_value = vector_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.full_text_index_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query='query "x"', + top_k=4, + score_threshold=0.4, + reranking_model=None, + all_documents=all_documents, + retrieval_method=RetrievalMethod.FULL_TEXT_SEARCH, + exceptions=exceptions, + ) + + assert len(all_documents) == 1 + assert exceptions == [] + vector_instance.search_by_full_text.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.DataPostProcessor") + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_full_text_index_search_with_reranking( + self, mock_get_dataset, mock_vector_class, mock_processor_class, internal_dataset, internal_flask_app + ): + mock_get_dataset.return_value = internal_dataset + original_docs = [create_mock_document("fulltext", "ft-1", 0.68)] + reranked_docs = [create_mock_document("fulltext-reranked", "ft-1", 0.9)] + + vector_instance = Mock() + vector_instance.search_by_full_text.return_value = original_docs + mock_vector_class.return_value = vector_instance + + processor_instance = Mock() + processor_instance.invoke.return_value = reranked_docs + mock_processor_class.return_value = processor_instance + + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.full_text_index_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="query", + top_k=4, + score_threshold=0.4, + reranking_model={ + "reranking_provider_name": "provider", + "reranking_model_name": "model", + }, + all_documents=all_documents, + retrieval_method=RetrievalMethod.FULL_TEXT_SEARCH, + exceptions=exceptions, + ) + + assert all_documents == reranked_docs + assert exceptions == [] + processor_instance.invoke.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_full_text_index_search_dataset_not_found(self, mock_get_dataset, internal_flask_app): + mock_get_dataset.return_value = None + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.full_text_index_search( + flask_app=internal_flask_app, + dataset_id="dataset-id", + query="query", + top_k=4, + score_threshold=0.4, + reranking_model=None, + all_documents=all_documents, + retrieval_method=RetrievalMethod.FULL_TEXT_SEARCH, + exceptions=exceptions, + ) + + assert all_documents == [] + assert exceptions == ["dataset not found"] + + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_full_text_index_search_appends_exception_when_search_fails( + self, mock_get_dataset, mock_vector_class, internal_dataset, internal_flask_app + ): + mock_get_dataset.return_value = internal_dataset + vector_instance = Mock() + vector_instance.search_by_full_text.side_effect = RuntimeError("fulltext failed") + mock_vector_class.return_value = vector_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.full_text_index_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="query", + top_k=4, + score_threshold=0.4, + reranking_model=None, + all_documents=all_documents, + retrieval_method=RetrievalMethod.FULL_TEXT_SEARCH, + exceptions=exceptions, + ) + + assert all_documents == [] + assert exceptions == ["fulltext failed"] + + def test_format_retrieval_documents_with_empty_input_returns_empty_list(self): + assert RetrievalService.format_retrieval_documents([]) == [] + + def test_format_retrieval_documents_without_document_id_returns_empty_list(self): + documents = [Document(page_content="content", metadata={"doc_id": "doc-1", "score": 0.4}, provider="dify")] + + assert RetrievalService.format_retrieval_documents(documents) == [] + + def test_format_retrieval_documents_with_parent_child_summary_and_attachments(self, monkeypatch): + dataset_doc_parent = SimpleNamespace( + id="doc-parent", + doc_form=IndexStructureType.PARENT_CHILD_INDEX, + dataset_id="dataset-id", + ) + dataset_doc_text = SimpleNamespace(id="doc-text", doc_form="paragraph", dataset_id="dataset-id") + dataset_doc_parent_summary = SimpleNamespace( + id="doc-parent-summary", + doc_form=IndexStructureType.PARENT_CHILD_INDEX, + dataset_id="dataset-id", + ) + + scalars_result = Mock() + scalars_result.all.return_value = [ + dataset_doc_parent, + dataset_doc_text, + dataset_doc_parent_summary, + ] + monkeypatch.setattr(retrieval_service_module.db.session, "scalars", Mock(return_value=scalars_result)) + monkeypatch.setattr(retrieval_service_module, "RetrievalChildChunk", _SimpleRetrievalChildChunk) + monkeypatch.setattr(retrieval_service_module, "RetrievalSegments", _SimpleRetrievalSegment) + + input_documents = [ + Document( + page_content="child node content", + metadata={"document_id": "doc-parent", "doc_id": "child-node-1", "score": 0.7}, + provider="dify", + ), + Document( + page_content="parent image", + metadata={ + "document_id": "doc-parent", + "doc_id": "attach-node-1", + "doc_type": DocType.IMAGE, + "score": 0.8, + }, + provider="dify", + ), + Document( + page_content="text index node", + metadata={"document_id": "doc-text", "doc_id": "index-node-1", "score": 0.6}, + provider="dify", + ), + Document( + page_content="text image node", + metadata={ + "document_id": "doc-text", + "doc_id": "attach-text-1", + "doc_type": DocType.IMAGE, + "score": 0.65, + }, + provider="dify", + ), + Document( + page_content="summary candidate 1", + metadata={ + "document_id": "doc-text", + "doc_id": "summary-node-1", + "is_summary": True, + "original_chunk_id": "segment-summary", + "score": "0.9", + }, + provider="dify", + ), + Document( + page_content="summary candidate 2", + metadata={ + "document_id": "doc-text", + "doc_id": "summary-node-2", + "is_summary": True, + "original_chunk_id": "segment-summary", + "score": "0.95", + }, + provider="dify", + ), + Document( + page_content="invalid score summary", + metadata={ + "document_id": "doc-parent-summary", + "doc_id": "summary-parent-invalid", + "is_summary": True, + "original_chunk_id": "segment-parent-summary", + "score": "invalid", + }, + provider="dify", + ), + Document( + page_content="valid parent summary", + metadata={ + "document_id": "doc-parent-summary", + "doc_id": "summary-parent-valid", + "is_summary": True, + "original_chunk_id": "segment-parent-summary", + "score": "0.4", + }, + provider="dify", + ), + ] + + child_chunk = SimpleNamespace( + id="child-chunk-1", + segment_id="segment-parent", + index_node_id="child-node-1", + content="child details", + position=2, + ) + segment_parent = SimpleNamespace(id="segment-parent", document_id="doc-parent", index_node_id="parent-node") + segment_text = SimpleNamespace(id="segment-text", document_id="doc-text", index_node_id="index-node-1") + segment_summary = SimpleNamespace(id="segment-summary", document_id="doc-text", index_node_id="summary-node") + segment_parent_summary = SimpleNamespace( + id="segment-parent-summary", + document_id="doc-parent-summary", + index_node_id="summary-parent-node", + ) + + fake_session = _FakeSession( + execute_payloads=[ + [child_chunk], + [segment_text], + [segment_parent, segment_text], + [segment_summary, segment_parent_summary], + ], + summaries=[ + SimpleNamespace(chunk_id="segment-summary", summary_content="summary for text"), + SimpleNamespace(chunk_id="segment-parent-summary", summary_content="summary for parent"), + ], + ) + monkeypatch.setattr( + retrieval_service_module.session_factory, + "create_session", + lambda: _FakeSessionContext(fake_session), + ) + monkeypatch.setattr( + RetrievalService, + "get_segment_attachment_infos", + lambda attachment_ids, session: [ + { + "attachment_id": "attach-node-1", + "attachment_info": { + "id": "attach-node-1", + "name": "img-parent", + "extension": ".png", + "mime_type": "image/png", + "source_url": "signed://parent", + "size": 11, + }, + "segment_id": "segment-parent", + }, + { + "attachment_id": "attach-text-1", + "attachment_info": { + "id": "attach-text-1", + "name": "img-text", + "extension": ".png", + "mime_type": "image/png", + "source_url": "signed://text", + "size": 22, + }, + "segment_id": "segment-text", + }, + ], + ) + + result = RetrievalService.format_retrieval_documents(input_documents) + + assert len(result) == 4 + result_by_segment_id = {item.segment.id: item for item in result} + assert result_by_segment_id["segment-summary"].score == pytest.approx(0.95) + assert result_by_segment_id["segment-summary"].summary == "summary for text" + assert result_by_segment_id["segment-parent"].score == pytest.approx(0.8) + assert result_by_segment_id["segment-parent"].files is not None + assert len(result_by_segment_id["segment-parent"].child_chunks or []) == 1 + assert result_by_segment_id["segment-text"].score == pytest.approx(0.65) + assert result_by_segment_id["segment-parent-summary"].score == pytest.approx(0.4) + assert result_by_segment_id["segment-parent-summary"].summary == "summary for parent" + assert result_by_segment_id["segment-parent-summary"].child_chunks == [] + + def test_format_retrieval_documents_rolls_back_and_raises_when_db_fails(self, monkeypatch): + rollback = Mock() + monkeypatch.setattr(retrieval_service_module.db.session, "rollback", rollback) + monkeypatch.setattr(retrieval_service_module.db.session, "scalars", Mock(side_effect=RuntimeError("db error"))) + + documents = [Document(page_content="content", metadata={"document_id": "doc-1"}, provider="dify")] + + with pytest.raises(RuntimeError, match="db error"): + RetrievalService.format_retrieval_documents(documents) + + rollback.assert_called_once() + + def test_retrieve_internal_returns_early_without_query_or_attachment(self, internal_dataset, internal_flask_app): + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService()._retrieve( + flask_app=internal_flask_app, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + dataset=internal_dataset, + all_documents=all_documents, + exceptions=exceptions, + query=None, + attachment_id=None, + ) + + assert all_documents == [] + assert exceptions == [] + + def test_retrieve_internal_cancels_futures_when_future_has_exception(self, internal_dataset, internal_flask_app): + future_error = Mock() + future_error.exception.return_value = RuntimeError("future failed") + future_ok = Mock() + future_ok.exception.return_value = None + + with ( + patch("core.rag.datasource.retrieval_service.ThreadPoolExecutor") as mock_executor, + patch( + "core.rag.datasource.retrieval_service.concurrent.futures.as_completed", + return_value=[future_error, future_ok], + ), + ): + mock_executor_instance = Mock() + mock_executor_instance.submit.side_effect = [future_error, future_ok] + mock_executor.return_value.__enter__.return_value = mock_executor_instance + RetrievalService()._retrieve( + flask_app=internal_flask_app, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + dataset=internal_dataset, + all_documents=[], + exceptions=[], + query="query", + attachment_id="file-1", + ) + + future_error.cancel.assert_called() + future_ok.cancel.assert_called() + + def test_retrieve_internal_raises_value_error_when_exceptions_exist( + self, monkeypatch, internal_dataset, internal_flask_app + ): + executor = _ImmediateExecutor() + monkeypatch.setattr(retrieval_service_module, "ThreadPoolExecutor", lambda *args, **kwargs: executor) + monkeypatch.setattr( + retrieval_service_module.concurrent.futures, + "as_completed", + lambda futures, timeout=None: iter(futures), + ) + + with patch("core.rag.datasource.retrieval_service.RetrievalService.keyword_search") as mock_keyword_search: + mock_keyword_search.side_effect = lambda *args, **kwargs: None + with pytest.raises(ValueError, match="keyword error"): + RetrievalService()._retrieve( + flask_app=internal_flask_app, + retrieval_method=RetrievalMethod.KEYWORD_SEARCH, + dataset=internal_dataset, + all_documents=[], + exceptions=["keyword error"], + query="query", + ) + + def test_retrieve_internal_hybrid_weighted_attachment_flow(self, monkeypatch, internal_dataset, internal_flask_app): + executor = _ImmediateExecutor() + monkeypatch.setattr(retrieval_service_module, "ThreadPoolExecutor", lambda *args, **kwargs: executor) + monkeypatch.setattr( + retrieval_service_module.concurrent.futures, + "as_completed", + lambda futures, timeout=None: iter(futures), + ) + + text_doc = create_mock_document("text", "text-doc", 0.81) + image_doc = create_mock_document("image", "image-doc", 0.72) + fulltext_doc = create_mock_document("full", "full-doc", 0.65) + processed_doc = create_mock_document("processed", "processed-doc", 0.99) + + with ( + patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") as mock_embedding_search, + patch("core.rag.datasource.retrieval_service.RetrievalService.full_text_index_search") as mock_fulltext, + patch("core.rag.datasource.retrieval_service.DataPostProcessor") as mock_processor_class, + ): + + def embedding_side_effect( + flask_app, + dataset_id, + query, + top_k, + score_threshold, + reranking_model, + all_documents, + retrieval_method, + exceptions, + document_ids_filter=None, + query_type=QueryType.TEXT_QUERY, + ): + if query_type == QueryType.IMAGE_QUERY: + all_documents.append(image_doc) + else: + all_documents.append(text_doc) + + mock_embedding_search.side_effect = embedding_side_effect + + def fulltext_side_effect( + flask_app, + dataset_id, + query, + top_k, + score_threshold, + reranking_model, + all_documents, + retrieval_method, + exceptions, + document_ids_filter=None, + ): + all_documents.append(fulltext_doc) + + mock_fulltext.side_effect = fulltext_side_effect + processor_instance = Mock() + processor_instance.invoke.return_value = [processed_doc] + mock_processor_class.return_value = processor_instance + + all_documents: list[Document] = [] + RetrievalService()._retrieve( + flask_app=internal_flask_app, + retrieval_method=RetrievalMethod.HYBRID_SEARCH, + dataset=internal_dataset, + all_documents=all_documents, + exceptions=[], + query="query", + attachment_id="file-1", + reranking_mode=RerankMode.WEIGHTED_SCORE, + top_k=3, + ) + + assert len(all_documents) == 4 + assert any(doc.metadata["doc_id"] == "processed-doc" for doc in all_documents) + processor_instance.invoke.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.sign_upload_file", return_value="signed://file") + def test_get_segment_attachment_info_success(self, mock_sign): + upload_file = SimpleNamespace( + id="upload-1", + name="file-name", + extension="png", + mime_type="image/png", + size=42, + ) + binding = SimpleNamespace(segment_id="segment-1", attachment_id="upload-1") + upload_query = Mock() + upload_query.where.return_value.first.return_value = upload_file + binding_query = Mock() + binding_query.where.return_value.first.return_value = binding + session = Mock() + session.query.side_effect = [upload_query, binding_query] + + result = RetrievalService.get_segment_attachment_info("dataset-id", "tenant-id", "upload-1", session) + + assert result == { + "attachment_info": { + "id": "upload-1", + "name": "file-name", + "extension": ".png", + "mime_type": "image/png", + "source_url": "signed://file", + "size": 42, + }, + "segment_id": "segment-1", + } + mock_sign.assert_called_once_with("upload-1", "png") + + def test_get_segment_attachment_info_returns_none_when_binding_missing(self): + upload_file = SimpleNamespace( + id="upload-1", + name="file-name", + extension="png", + mime_type="image/png", + size=42, + ) + upload_query = Mock() + upload_query.where.return_value.first.return_value = upload_file + binding_query = Mock() + binding_query.where.return_value.first.return_value = None + session = Mock() + session.query.side_effect = [upload_query, binding_query] + + result = RetrievalService.get_segment_attachment_info("dataset-id", "tenant-id", "upload-1", session) + + assert result is None + + def test_get_segment_attachment_info_returns_none_when_upload_file_missing(self): + upload_query = Mock() + upload_query.where.return_value.first.return_value = None + session = Mock() + session.query.return_value = upload_query + + result = RetrievalService.get_segment_attachment_info("dataset-id", "tenant-id", "upload-1", session) + + assert result is None + + def test_get_segment_attachment_infos_returns_empty_when_upload_files_missing(self): + upload_query = Mock() + upload_query.where.return_value.all.return_value = [] + session = Mock() + session.query.return_value = upload_query + + result = RetrievalService.get_segment_attachment_infos(["upload-1"], session) + + assert result == [] + + def test_get_segment_attachment_infos_returns_empty_when_bindings_missing(self): + upload_file = SimpleNamespace( + id="upload-1", + name="file-name", + extension="png", + mime_type="image/png", + size=42, + ) + upload_query = Mock() + upload_query.where.return_value.all.return_value = [upload_file] + binding_query = Mock() + binding_query.where.return_value.all.return_value = [] + session = Mock() + session.query.side_effect = [upload_query, binding_query] + + result = RetrievalService.get_segment_attachment_infos(["upload-1"], session) + + assert result == [] + + @patch("core.rag.datasource.retrieval_service.sign_upload_file", return_value="signed://file") + def test_get_segment_attachment_infos_success(self, mock_sign): + upload_file_1 = SimpleNamespace( + id="upload-1", + name="file-1", + extension="png", + mime_type="image/png", + size=42, + ) + upload_file_2 = SimpleNamespace( + id="upload-2", + name="file-2", + extension="jpg", + mime_type="image/jpeg", + size=99, + ) + binding = SimpleNamespace(attachment_id="upload-1", segment_id="segment-1") + + upload_query = Mock() + upload_query.where.return_value.all.return_value = [upload_file_1, upload_file_2] + binding_query = Mock() + binding_query.where.return_value.all.return_value = [binding] + session = Mock() + session.query.side_effect = [upload_query, binding_query] + + result = RetrievalService.get_segment_attachment_infos(["upload-1", "upload-2"], session) + + assert result == [ + { + "attachment_id": "upload-1", + "attachment_info": { + "id": "upload-1", + "name": "file-1", + "extension": ".png", + "mime_type": "image/png", + "source_url": "signed://file", + "size": 42, + }, + "segment_id": "segment-1", + } + ] + mock_sign.assert_has_calls( + [ + call("upload-1", "png"), + call("upload-2", "jpg"), + ] + ) + assert mock_sign.call_count == 2 diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/alibabacloud_mysql/test_alibabacloud_mysql_factory.py b/api/tests/unit_tests/core/rag/datasource/vdb/alibabacloud_mysql/test_alibabacloud_mysql_factory.py new file mode 100644 index 00000000000..e063a49f22d --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/alibabacloud_mysql/test_alibabacloud_mysql_factory.py @@ -0,0 +1,74 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +import core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector as alibaba_module +from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import AlibabaCloudMySQLVectorFactory + + +def test_validate_distance_function_accepts_supported_values(): + factory = AlibabaCloudMySQLVectorFactory() + + assert factory._validate_distance_function("cosine") == "cosine" + assert factory._validate_distance_function("euclidean") == "euclidean" + + +def test_validate_distance_function_rejects_unsupported_values(): + factory = AlibabaCloudMySQLVectorFactory() + + with pytest.raises(ValueError, match="Invalid distance function"): + factory._validate_distance_function("dot_product") + + +def test_factory_init_vector_uses_existing_index_struct_class_prefix(monkeypatch): + factory = AlibabaCloudMySQLVectorFactory() + dataset = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "existing_collection"}}, + index_struct=None, + ) + + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_HOST", "host") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_PORT", 3306) + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_USER", "user") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_PASSWORD", "password") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_DATABASE", "db") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_MAX_CONNECTION", 5) + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_CHARSET", "utf8mb4") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION", "cosine") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_HNSW_M", 6) + + with patch.object(alibaba_module, "AlibabaCloudMySQLVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + assert vector_cls.call_args.kwargs["collection_name"] == "existing_collection" + + +def test_factory_init_vector_generates_collection_name_when_index_struct_is_missing(monkeypatch): + factory = AlibabaCloudMySQLVectorFactory() + dataset = SimpleNamespace( + id="dataset-2", + index_struct_dict=None, + index_struct=None, + ) + + monkeypatch.setattr(alibaba_module.Dataset, "gen_collection_name_by_id", lambda dataset_id: f"COL_{dataset_id}") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_HOST", "host") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_PORT", 3306) + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_USER", "user") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_PASSWORD", "password") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_DATABASE", "db") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_MAX_CONNECTION", 5) + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_CHARSET", "utf8mb4") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION", "euclidean") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_HNSW_M", 12) + + with patch.object(alibaba_module, "AlibabaCloudMySQLVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + vector_cls.assert_called_once() + assert vector_cls.call_args.kwargs["collection_name"] == "COL_dataset-2" + assert dataset.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector.py new file mode 100644 index 00000000000..545565cdf4a --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector.py @@ -0,0 +1,133 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +import core.rag.datasource.vdb.analyticdb.analyticdb_vector as analyticdb_module +from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector, AnalyticdbVectorFactory +from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig +from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig +from core.rag.models.document import Document + + +def test_init_prefers_openapi_when_api_config_is_provided(): + api_config = AnalyticdbVectorOpenAPIConfig( + access_key_id="ak", + access_key_secret="sk", + region_id="cn-hangzhou", + instance_id="instance-1", + account="account", + account_password="password", + namespace="dify", + namespace_password="ns-password", + ) + + with patch.object(analyticdb_module, "AnalyticdbVectorOpenAPI", return_value="openapi_runner") as openapi_cls: + vector = AnalyticdbVector("COLLECTION", api_config=api_config, sql_config=None) + + assert vector.analyticdb_vector == "openapi_runner" + openapi_cls.assert_called_once_with("COLLECTION", api_config) + + +def test_init_uses_sql_implementation_when_api_config_is_missing(): + sql_config = AnalyticdbVectorBySqlConfig( + host="localhost", + port=5432, + account="account", + account_password="password", + min_connection=1, + max_connection=2, + namespace="dify", + ) + + with patch.object(analyticdb_module, "AnalyticdbVectorBySql", return_value="sql_runner") as sql_cls: + vector = AnalyticdbVector("COLLECTION", api_config=None, sql_config=sql_config) + + assert vector.analyticdb_vector == "sql_runner" + sql_cls.assert_called_once_with("COLLECTION", sql_config) + + +def test_init_raises_when_both_configs_are_missing(): + with pytest.raises(ValueError, match="Either api_config or sql_config must be provided"): + AnalyticdbVector("COLLECTION", api_config=None, sql_config=None) + + +def test_vector_methods_delegate_to_underlying_implementation(): + runner = MagicMock() + runner.search_by_vector.return_value = [Document(page_content="v", metadata={"doc_id": "1"})] + runner.search_by_full_text.return_value = [Document(page_content="t", metadata={"doc_id": "2"})] + runner.text_exists.return_value = True + + vector = AnalyticdbVector.__new__(AnalyticdbVector) + vector.analyticdb_vector = runner + + texts = [Document(page_content="hello", metadata={"doc_id": "d1"})] + vector.create(texts=texts, embeddings=[[0.1, 0.2]]) + vector.add_texts(documents=texts, embeddings=[[0.1, 0.2]]) + assert vector.text_exists("d1") is True + vector.delete_by_ids(["d1"]) + vector.delete_by_metadata_field("document_id", "doc-1") + assert vector.search_by_vector([0.1, 0.2], top_k=2) == runner.search_by_vector.return_value + assert vector.search_by_full_text("hello", top_k=2) == runner.search_by_full_text.return_value + vector.delete() + + runner._create_collection_if_not_exists.assert_called_once_with(2) + runner.add_texts.assert_any_call(texts, [[0.1, 0.2]]) + runner.delete_by_ids.assert_called_once_with(["d1"]) + runner.delete_by_metadata_field.assert_called_once_with("document_id", "doc-1") + runner.delete.assert_called_once() + + +def test_get_type_is_analyticdb(): + vector = AnalyticdbVector.__new__(AnalyticdbVector) + assert vector.get_type() == "analyticdb" + + +def test_factory_builds_openapi_config_when_host_is_missing(monkeypatch): + factory = AnalyticdbVectorFactory() + dataset = SimpleNamespace(id="dataset-1", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(analyticdb_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_HOST", None) + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_KEY_ID", "ak") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_KEY_SECRET", "sk") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_REGION_ID", "cn-hz") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_INSTANCE_ID", "instance") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_ACCOUNT", "account") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_PASSWORD", "password") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_NAMESPACE", "dify") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_NAMESPACE_PASSWORD", "ns-password") + + with patch.object(analyticdb_module, "AnalyticdbVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + args = vector_cls.call_args.args + assert args[0] == "auto_collection" + assert isinstance(args[1], AnalyticdbVectorOpenAPIConfig) + assert args[2] is None + assert dataset.index_struct is not None + + +def test_factory_builds_sql_config_when_host_is_present(monkeypatch): + factory = AnalyticdbVectorFactory() + dataset = SimpleNamespace( + id="dataset-2", index_struct_dict={"vector_store": {"class_prefix": "EXISTING"}}, index_struct=None + ) + + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_HOST", "127.0.0.1") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_PORT", 5432) + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_ACCOUNT", "account") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_PASSWORD", "password") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_MIN_CONNECTION", 1) + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_MAX_CONNECTION", 3) + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_NAMESPACE", "dify") + + with patch.object(analyticdb_module, "AnalyticdbVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + args = vector_cls.call_args.args + assert args[0] == "existing" + assert args[1] is None + assert isinstance(args[2], AnalyticdbVectorBySqlConfig) diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_openapi.py b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_openapi.py new file mode 100644 index 00000000000..45777774d03 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_openapi.py @@ -0,0 +1,384 @@ +import json +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +import core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi as openapi_module +from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import ( + AnalyticdbVectorOpenAPI, + AnalyticdbVectorOpenAPIConfig, +) +from core.rag.models.document import Document + + +def _request_class(name: str): + class _Request: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + _Request.__name__ = name + return _Request + + +def _install_openapi_stubs(monkeypatch): + gpdb_package = types.ModuleType("alibabacloud_gpdb20160503") + gpdb_package.__path__ = [] + gpdb_models = types.ModuleType("alibabacloud_gpdb20160503.models") + for class_name in [ + "InitVectorDatabaseRequest", + "DescribeNamespaceRequest", + "CreateNamespaceRequest", + "DescribeCollectionRequest", + "CreateCollectionRequest", + "UpsertCollectionDataRequestRows", + "UpsertCollectionDataRequest", + "QueryCollectionDataRequest", + "DeleteCollectionDataRequest", + "DeleteCollectionRequest", + ]: + setattr(gpdb_models, class_name, _request_class(class_name)) + + class _Client: + def __init__(self, config): + self.config = config + + gpdb_client = types.ModuleType("alibabacloud_gpdb20160503.client") + gpdb_client.Client = _Client + gpdb_package.models = gpdb_models + + tea_openapi = types.ModuleType("alibabacloud_tea_openapi") + tea_openapi.__path__ = [] + tea_openapi_models = types.ModuleType("alibabacloud_tea_openapi.models") + + class OpenApiConfig: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + tea_openapi_models.Config = OpenApiConfig + tea_openapi.models = tea_openapi_models + + tea_package = types.ModuleType("Tea") + tea_package.__path__ = [] + tea_exceptions = types.ModuleType("Tea.exceptions") + + class TeaError(Exception): + def __init__(self, status_code=None, **kwargs): + super().__init__("TeaException") + status_code = kwargs.get("statusCode", status_code) + self.statusCode = status_code + self.status_code = status_code + + tea_exceptions.TeaException = TeaError + tea_package.exceptions = tea_exceptions + + monkeypatch.setitem(sys.modules, "alibabacloud_gpdb20160503", gpdb_package) + monkeypatch.setitem(sys.modules, "alibabacloud_gpdb20160503.models", gpdb_models) + monkeypatch.setitem(sys.modules, "alibabacloud_gpdb20160503.client", gpdb_client) + monkeypatch.setitem(sys.modules, "alibabacloud_tea_openapi", tea_openapi) + monkeypatch.setitem(sys.modules, "alibabacloud_tea_openapi.models", tea_openapi_models) + monkeypatch.setitem(sys.modules, "Tea", tea_package) + monkeypatch.setitem(sys.modules, "Tea.exceptions", tea_exceptions) + + return SimpleNamespace(models=gpdb_models, TeaException=TeaError, OpenApiConfig=OpenApiConfig) + + +def _config() -> AnalyticdbVectorOpenAPIConfig: + return AnalyticdbVectorOpenAPIConfig( + access_key_id="ak", + access_key_secret="sk", + region_id="cn-hangzhou", + instance_id="instance-1", + account="account", + account_password="password", + namespace="dify", + namespace_password="ns-password", + ) + + +@pytest.mark.parametrize( + ("field", "value", "error_message"), + [ + ("access_key_id", "", "ANALYTICDB_KEY_ID"), + ("access_key_secret", "", "ANALYTICDB_KEY_SECRET"), + ("region_id", "", "ANALYTICDB_REGION_ID"), + ("instance_id", "", "ANALYTICDB_INSTANCE_ID"), + ("account", "", "ANALYTICDB_ACCOUNT"), + ("account_password", "", "ANALYTICDB_PASSWORD"), + ("namespace_password", "", "ANALYTICDB_NAMESPACE_PASSWORD"), + ], +) +def test_openapi_config_validation(field, value, error_message): + values = _config().model_dump() + values[field] = value + + with pytest.raises(ValueError, match=error_message): + AnalyticdbVectorOpenAPIConfig.model_validate(values) + + +def test_openapi_config_to_client_params(): + config = _config() + params = config.to_analyticdb_client_params() + + assert params["access_key_id"] == "ak" + assert params["access_key_secret"] == "sk" + assert params["region_id"] == "cn-hangzhou" + assert params["read_timeout"] == 60000 + + +def test_init_creates_openapi_client_and_runs_initialize(monkeypatch): + stubs = _install_openapi_stubs(monkeypatch) + initialize_mock = MagicMock() + monkeypatch.setattr(openapi_module.AnalyticdbVectorOpenAPI, "_initialize", initialize_mock) + + vector = AnalyticdbVectorOpenAPI("COLLECTION_1", _config()) + + assert vector._collection_name == "collection_1" + assert isinstance(vector._client_config, stubs.OpenApiConfig) + assert vector._client_config.user_agent == "dify" + assert vector._client_config.access_key_id == "ak" + assert vector._client.config is vector._client_config + initialize_mock.assert_called_once_with() + + +def test_initialize_skips_when_cached(monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector.config = _config() + vector._initialize_vector_database = MagicMock() + vector._create_namespace_if_not_exists = MagicMock() + + vector._initialize() + + vector._initialize_vector_database.assert_not_called() + vector._create_namespace_if_not_exists.assert_not_called() + + +def test_initialize_runs_when_cache_is_missing(monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector.config = _config() + vector._initialize_vector_database = MagicMock() + vector._create_namespace_if_not_exists = MagicMock() + + vector._initialize() + + vector._initialize_vector_database.assert_called_once() + vector._create_namespace_if_not_exists.assert_called_once() + openapi_module.redis_client.set.assert_called_once() + + +def test_initialize_vector_database_calls_openapi_client(monkeypatch): + _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector.config = _config() + vector._client = MagicMock() + + vector._initialize_vector_database() + + request = vector._client.init_vector_database.call_args.args[0] + assert request.dbinstance_id == "instance-1" + assert request.region_id == "cn-hangzhou" + assert request.manager_account == "account" + assert request.manager_account_password == "password" + + +def test_create_namespace_creates_when_namespace_not_found(monkeypatch): + stubs = _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector.config = _config() + vector._client = MagicMock() + vector._client.describe_namespace.side_effect = stubs.TeaException(statusCode=404) + + vector._create_namespace_if_not_exists() + + vector._client.create_namespace.assert_called_once() + + +def test_create_namespace_raises_on_unexpected_api_error(monkeypatch): + stubs = _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector.config = _config() + vector._client = MagicMock() + vector._client.describe_namespace.side_effect = stubs.TeaException(statusCode=500) + + with pytest.raises(ValueError, match="failed to create namespace"): + vector._create_namespace_if_not_exists() + + +def test_create_namespace_noop_when_namespace_exists(monkeypatch): + _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector.config = _config() + vector._client = MagicMock() + + vector._create_namespace_if_not_exists() + + vector._client.describe_namespace.assert_called_once() + vector._client.create_namespace.assert_not_called() + + +def test_create_collection_if_not_exists_creates_when_missing(monkeypatch): + stubs = _install_openapi_stubs(monkeypatch) + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = MagicMock() + vector._client.describe_collection.side_effect = stubs.TeaException(statusCode=404) + + vector._create_collection_if_not_exists(embedding_dimension=1024) + + vector._client.create_collection.assert_called_once() + openapi_module.redis_client.set.assert_called_once() + + +def test_create_collection_if_not_exists_skips_when_cached(monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = MagicMock() + + vector._create_collection_if_not_exists(embedding_dimension=1024) + + vector._client.describe_collection.assert_not_called() + vector._client.create_collection.assert_not_called() + + +def test_create_collection_if_not_exists_raises_on_non_404_errors(monkeypatch): + stubs = _install_openapi_stubs(monkeypatch) + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = MagicMock() + vector._client.describe_collection.side_effect = stubs.TeaException(statusCode=500) + + with pytest.raises(ValueError, match="failed to create collection collection_1"): + vector._create_collection_if_not_exists(embedding_dimension=512) + + +def test_openapi_add_delete_and_search_methods(monkeypatch): + _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = MagicMock() + + documents = [ + Document(page_content="doc 1", metadata={"doc_id": "d1", "document_id": "doc-1"}), + SimpleNamespace(page_content="doc 2", metadata=None), + ] + embeddings = [[0.1, 0.2], [0.2, 0.3]] + vector.add_texts(documents, embeddings) + + upsert_request = vector._client.upsert_collection_data.call_args.args[0] + assert upsert_request.collection == "collection_1" + assert len(upsert_request.rows) == 1 + + vector._client.query_collection_data.return_value = SimpleNamespace( + body=SimpleNamespace(matches=SimpleNamespace(match=[SimpleNamespace()])) + ) + assert vector.text_exists("d1") is True + + vector.delete_by_ids(["d1", "d2"]) + request = vector._client.delete_collection_data.call_args.args[0] + assert request.collection_data_filter == "ref_doc_id IN ('d1','d2')" + + vector.delete_by_metadata_field("document_id", "doc-1") + request = vector._client.delete_collection_data.call_args.args[0] + assert request.collection_data_filter == "metadata_ ->> 'document_id' = 'doc-1'" + + match_high = SimpleNamespace( + score=0.9, + metadata={"metadata_": json.dumps({"document_id": "doc-1"}), "page_content": "high"}, + values=SimpleNamespace(value=[1.0, 2.0]), + ) + match_low = SimpleNamespace( + score=0.1, + metadata={"metadata_": json.dumps({"document_id": "doc-2"}), "page_content": "low"}, + values=SimpleNamespace(value=[3.0, 4.0]), + ) + vector._client.query_collection_data.return_value = SimpleNamespace( + body=SimpleNamespace(matches=SimpleNamespace(match=[match_low, match_high])) + ) + + docs_by_vector = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["doc-1"]) + assert len(docs_by_vector) == 1 + assert docs_by_vector[0].page_content == "high" + assert docs_by_vector[0].metadata["score"] == 0.9 + + docs_by_text = vector.search_by_full_text("hello", top_k=2, score_threshold=0.2) + assert len(docs_by_text) == 1 + assert docs_by_text[0].page_content == "high" + + +def test_text_exists_returns_false_when_matches_empty(monkeypatch): + _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = MagicMock() + vector._client.query_collection_data.return_value = SimpleNamespace( + body=SimpleNamespace(matches=SimpleNamespace(match=[])) + ) + + assert vector.text_exists("missing-id") is False + + +def test_openapi_delete_success(monkeypatch): + _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = MagicMock() + + vector.delete() + vector._client.delete_collection.assert_called_once() + + +def test_openapi_delete_propagates_errors(monkeypatch): + _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = MagicMock() + vector._client.delete_collection.side_effect = RuntimeError("boom") + + with pytest.raises(RuntimeError, match="boom"): + vector.delete() diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_sql.py b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_sql.py new file mode 100644 index 00000000000..8f1206696bb --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_sql.py @@ -0,0 +1,427 @@ +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import MagicMock + +import psycopg2.errors +import pytest + +import core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql as sql_module +from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import ( + AnalyticdbVectorBySql, + AnalyticdbVectorBySqlConfig, +) +from core.rag.models.document import Document + + +def _config_values() -> dict: + return { + "host": "localhost", + "port": 5432, + "account": "account", + "account_password": "password", + "min_connection": 1, + "max_connection": 2, + "namespace": "dify", + } + + +@pytest.mark.parametrize( + ("field", "value", "error_message"), + [ + ("host", "", "ANALYTICDB_HOST"), + ("port", 0, "ANALYTICDB_PORT"), + ("account", "", "ANALYTICDB_ACCOUNT"), + ("account_password", "", "ANALYTICDB_PASSWORD"), + ("min_connection", 0, "ANALYTICDB_MIN_CONNECTION"), + ("max_connection", 0, "ANALYTICDB_MAX_CONNECTION"), + ], +) +def test_sql_config_required_fields(field, value, error_message): + values = _config_values() + values[field] = value + + with pytest.raises(ValueError, match=error_message): + AnalyticdbVectorBySqlConfig.model_validate(values) + + +def test_sql_config_rejects_min_connection_greater_than_max_connection(): + values = _config_values() + values["min_connection"] = 10 + values["max_connection"] = 2 + + with pytest.raises(ValueError, match="ANALYTICDB_MIN_CONNECTION should less than ANALYTICDB_MAX_CONNECTION"): + AnalyticdbVectorBySqlConfig.model_validate(values) + + +def test_initialize_skips_when_cache_exists(monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(sql_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(sql_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(sql_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = AnalyticdbVectorBySqlConfig(**_config_values()) + vector._initialize_vector_database = MagicMock() + + vector._initialize() + + vector._initialize_vector_database.assert_not_called() + + +def test_initialize_runs_when_cache_is_missing(monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(sql_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(sql_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(sql_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = AnalyticdbVectorBySqlConfig(**_config_values()) + vector._initialize_vector_database = MagicMock() + + vector._initialize() + + vector._initialize_vector_database.assert_called_once() + sql_module.redis_client.set.assert_called_once() + + +def test_create_connection_pool_uses_psycopg2_pool(monkeypatch): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = AnalyticdbVectorBySqlConfig(**_config_values()) + vector.databaseName = "knowledgebase" + + pool_instance = MagicMock() + monkeypatch.setattr(sql_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool_instance)) + + pool = vector._create_connection_pool() + + assert pool is pool_instance + sql_module.psycopg2.pool.SimpleConnectionPool.assert_called_once() + + +def test_get_cursor_context_manager_handles_connection_lifecycle(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + cursor = MagicMock() + connection = MagicMock() + connection.cursor.return_value = cursor + pool = MagicMock() + pool.getconn.return_value = connection + vector.pool = pool + + with vector._get_cursor() as cur: + assert cur is cursor + + cursor.close.assert_called_once() + connection.commit.assert_called_once() + pool.putconn.assert_called_once_with(connection) + + +def test_add_texts_inserts_only_documents_with_metadata(monkeypatch): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + + monkeypatch.setattr(sql_module.uuid, "uuid4", lambda: "prefix-id") + monkeypatch.setattr(sql_module.psycopg2.extras, "execute_batch", MagicMock()) + + docs = [ + Document(page_content="doc 1", metadata={"doc_id": "d1", "document_id": "doc-1"}), + SimpleNamespace(page_content="doc 2", metadata=None), + ] + vector.add_texts(docs, [[0.1, 0.2], [0.2, 0.3]]) + + execute_args = sql_module.psycopg2.extras.execute_batch.call_args.args + assert execute_args[0] is cursor + assert len(execute_args[2]) == 1 + + +def test_text_exists_returns_true_and_false_based_on_query_result(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + + cursor.fetchone.return_value = ("row",) + assert vector.text_exists("d1") is True + + cursor.fetchone.return_value = None + assert vector.text_exists("d1") is False + + +def test_delete_by_ids_handles_empty_input_and_missing_table_error(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + vector.delete_by_ids([]) + cursor.execute.assert_not_called() + + cursor.execute.side_effect = psycopg2.errors.UndefinedTable("relation does not exist") + vector.delete_by_ids(["d1"]) + + +def test_delete_by_metadata_field_handles_missing_table_error(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + cursor.execute.side_effect = psycopg2.errors.UndefinedTable("relation does not exist") + vector.delete_by_metadata_field("document_id", "doc-1") + + +@pytest.mark.parametrize("invalid_top_k", [0, "x", -1]) +def test_search_by_vector_validates_top_k(invalid_top_k): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_vector([0.1, 0.2], top_k=invalid_top_k) + + +def test_search_by_vector_returns_documents_above_threshold(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + cursor.__iter__.return_value = iter( + [ + ("id1", [1.0], 0.8, "content 1", {"doc_id": "id1", "document_id": "doc-1"}), + ("id2", [2.0], 0.3, "content 2", {"doc_id": "id2", "document_id": "doc-2"}), + ] + ) + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["doc-1"]) + + assert len(docs) == 1 + assert docs[0].page_content == "content 1" + assert docs[0].metadata["score"] == 0.8 + + +@pytest.mark.parametrize("invalid_top_k", [0, "x", -1]) +def test_search_by_full_text_validates_top_k(invalid_top_k): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_full_text("query", top_k=invalid_top_k) + + +def test_search_by_full_text_returns_documents(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + cursor.__iter__.return_value = iter( + [ + ("id1", [1.0], "content 1", {"doc_id": "id1", "document_id": "doc-1"}, 0.9), + ] + ) + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + docs = vector.search_by_full_text("query", top_k=1, document_ids_filter=["doc-1"]) + + assert len(docs) == 1 + assert docs[0].metadata["score"] == 0.9 + assert docs[0].page_content == "content 1" + + +def test_delete_drops_table(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + vector.delete() + + cursor.execute.assert_called_once() + + +def test_init_normalizes_collection_name_and_creates_pool_when_missing(monkeypatch): + config = AnalyticdbVectorBySqlConfig(**_config_values()) + created_pool = MagicMock() + + monkeypatch.setattr(AnalyticdbVectorBySql, "_initialize", MagicMock()) + monkeypatch.setattr(AnalyticdbVectorBySql, "_create_connection_pool", MagicMock(return_value=created_pool)) + + vector = AnalyticdbVectorBySql("My_Collection", config) + + assert vector._collection_name == "my_collection" + assert vector.table_name == "dify.my_collection" + assert vector.databaseName == "knowledgebase" + assert vector.pool is created_pool + + +def test_initialize_vector_database_handles_existing_database_and_search_config(monkeypatch): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = AnalyticdbVectorBySqlConfig(**_config_values()) + vector.databaseName = "knowledgebase" + + bootstrap_cursor = MagicMock() + bootstrap_connection = MagicMock() + bootstrap_connection.cursor.return_value = bootstrap_cursor + bootstrap_cursor.execute.side_effect = RuntimeError("database already exists") + monkeypatch.setattr(sql_module.psycopg2, "connect", MagicMock(return_value=bootstrap_connection)) + + worker_cursor = MagicMock() + worker_connection = MagicMock() + worker_cursor.connection = worker_connection + + def _execute(sql, *args, **kwargs): + if "CREATE TEXT SEARCH CONFIGURATION zh_cn" in sql: + raise RuntimeError("already exists") + + worker_cursor.execute.side_effect = _execute + pooled_connection = MagicMock() + pooled_connection.cursor.return_value = worker_cursor + pool = MagicMock() + pool.getconn.return_value = pooled_connection + vector._create_connection_pool = MagicMock(return_value=pool) + + vector._initialize_vector_database() + + bootstrap_cursor.close.assert_called_once() + bootstrap_connection.close.assert_called_once() + vector._create_connection_pool.assert_called_once() + assert any( + "CREATE OR REPLACE FUNCTION public.to_tsquery_from_text" in call.args[0] + for call in worker_cursor.execute.call_args_list + ) + assert any("CREATE SCHEMA IF NOT EXISTS dify" in call.args[0] for call in worker_cursor.execute.call_args_list) + + +def test_initialize_vector_database_raises_runtime_error_when_zhparser_fails(monkeypatch): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = AnalyticdbVectorBySqlConfig(**_config_values()) + vector.databaseName = "knowledgebase" + + bootstrap_cursor = MagicMock() + bootstrap_connection = MagicMock() + bootstrap_connection.cursor.return_value = bootstrap_cursor + monkeypatch.setattr(sql_module.psycopg2, "connect", MagicMock(return_value=bootstrap_connection)) + + worker_cursor = MagicMock() + worker_connection = MagicMock() + worker_cursor.connection = worker_connection + worker_cursor.execute.side_effect = RuntimeError("zhparser unavailable") + + pooled_connection = MagicMock() + pooled_connection.cursor.return_value = worker_cursor + pool = MagicMock() + pool.getconn.return_value = pooled_connection + vector._create_connection_pool = MagicMock(return_value=pool) + + with pytest.raises(RuntimeError, match="Failed to create zhparser extension"): + vector._initialize_vector_database() + + worker_connection.rollback.assert_called_once() + + +def test_create_collection_if_not_exists_creates_table_indexes_and_cache(monkeypatch): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = AnalyticdbVectorBySqlConfig(**_config_values()) + vector._collection_name = "collection" + vector.table_name = "dify.collection" + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(sql_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(sql_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(sql_module.redis_client, "set", MagicMock()) + + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + + vector._create_collection_if_not_exists(embedding_dimension=3) + + assert any("CREATE TABLE IF NOT EXISTS dify.collection" in call.args[0] for call in cursor.execute.call_args_list) + assert any("CREATE INDEX collection_embedding_idx" in call.args[0] for call in cursor.execute.call_args_list) + sql_module.redis_client.set.assert_called_once() + + +def test_create_collection_if_not_exists_raises_for_non_existing_error(monkeypatch): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = AnalyticdbVectorBySqlConfig(**_config_values()) + vector._collection_name = "collection" + vector.table_name = "dify.collection" + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(sql_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(sql_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(sql_module.redis_client, "set", MagicMock()) + + cursor = MagicMock() + cursor.execute.side_effect = RuntimeError("permission denied") + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + + with pytest.raises(RuntimeError, match="permission denied"): + vector._create_collection_if_not_exists(embedding_dimension=3) + + +def test_delete_methods_raise_when_error_is_not_missing_table(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + + cursor.execute.side_effect = RuntimeError("unexpected delete failure") + with pytest.raises(RuntimeError, match="unexpected delete failure"): + vector.delete_by_ids(["doc-1"]) + + cursor.execute.side_effect = RuntimeError("unexpected metadata failure") + with pytest.raises(RuntimeError, match="unexpected metadata failure"): + vector.delete_by_metadata_field("document_id", "doc-1") diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/baidu/test_baidu_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/baidu/test_baidu_vector.py new file mode 100644 index 00000000000..c46c3d5e4bd --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/baidu/test_baidu_vector.py @@ -0,0 +1,542 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_pymochow_modules(): + pymochow = types.ModuleType("pymochow") + pymochow.__path__ = [] + pymochow_auth = types.ModuleType("pymochow.auth") + pymochow_auth.__path__ = [] + pymochow_credentials = types.ModuleType("pymochow.auth.bce_credentials") + pymochow_configuration = types.ModuleType("pymochow.configuration") + pymochow_exception = types.ModuleType("pymochow.exception") + pymochow_model = types.ModuleType("pymochow.model") + pymochow_model.__path__ = [] + pymochow_model_database = types.ModuleType("pymochow.model.database") + pymochow_model_enum = types.ModuleType("pymochow.model.enum") + pymochow_model_schema = types.ModuleType("pymochow.model.schema") + pymochow_model_table = types.ModuleType("pymochow.model.table") + + class _SimpleObject: + def __init__(self, *args, **kwargs): + self.args = args + for key, value in kwargs.items(): + setattr(self, key, value) + + class ServerError(Exception): + def __init__(self, code): + super().__init__(f"server error {code}") + self.code = code + + class ServerErrCode: + TABLE_NOT_EXIST = 1001 + DB_ALREADY_EXIST = 1002 + + class IndexType: + __members__ = {"HNSW": "HNSW"} + + class MetricType: + __members__ = {"IP": "IP"} + + class IndexState: + NORMAL = "NORMAL" + + class TableState: + NORMAL = "NORMAL" + + class InvertedIndexAnalyzer: + DEFAULT_ANALYZER = "DEFAULT_ANALYZER" + + class InvertedIndexParseMode: + COARSE_MODE = "COARSE_MODE" + + class InvertedIndexFieldAttribute: + ANALYZED = "ANALYZED" + + class FieldType: + STRING = "STRING" + TEXT = "TEXT" + JSON = "JSON" + FLOAT_VECTOR = "FLOAT_VECTOR" + + pymochow.MochowClient = _SimpleObject + pymochow_credentials.BceCredentials = _SimpleObject + pymochow_configuration.Configuration = _SimpleObject + pymochow_exception.ServerError = ServerError + pymochow_model_database.Database = _SimpleObject + + pymochow_model_enum.FieldType = FieldType + pymochow_model_enum.IndexState = IndexState + pymochow_model_enum.IndexType = IndexType + pymochow_model_enum.MetricType = MetricType + pymochow_model_enum.ServerErrCode = ServerErrCode + pymochow_model_enum.TableState = TableState + + for cls_name in [ + "AutoBuildRowCountIncrement", + "Field", + "FilteringIndex", + "HNSWParams", + "InvertedIndex", + "InvertedIndexParams", + "Schema", + "VectorIndex", + ]: + setattr(pymochow_model_schema, cls_name, _SimpleObject) + pymochow_model_schema.InvertedIndexAnalyzer = InvertedIndexAnalyzer + pymochow_model_schema.InvertedIndexFieldAttribute = InvertedIndexFieldAttribute + pymochow_model_schema.InvertedIndexParseMode = InvertedIndexParseMode + + for cls_name in ["AnnSearch", "BM25SearchRequest", "HNSWSearchParams", "Partition", "Row"]: + setattr(pymochow_model_table, cls_name, _SimpleObject) + + pymochow.auth = pymochow_auth + pymochow.model = pymochow_model + pymochow_auth.bce_credentials = pymochow_credentials + pymochow_model.database = pymochow_model_database + pymochow_model.enum = pymochow_model_enum + pymochow_model.schema = pymochow_model_schema + pymochow_model.table = pymochow_model_table + + modules = { + "pymochow": pymochow, + "pymochow.auth": pymochow_auth, + "pymochow.auth.bce_credentials": pymochow_credentials, + "pymochow.configuration": pymochow_configuration, + "pymochow.exception": pymochow_exception, + "pymochow.model": pymochow_model, + "pymochow.model.database": pymochow_model_database, + "pymochow.model.enum": pymochow_model_enum, + "pymochow.model.schema": pymochow_model_schema, + "pymochow.model.table": pymochow_model_table, + } + return modules + + +@pytest.fixture +def baidu_module(monkeypatch): + for name, module in _build_fake_pymochow_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + import core.rag.datasource.vdb.baidu.baidu_vector as module + + return importlib.reload(module) + + +def test_baidu_config_validation(baidu_module): + values = { + "endpoint": "https://example.com", + "account": "account", + "api_key": "key", + "database": "database", + } + config = baidu_module.BaiduConfig.model_validate(values) + assert config.endpoint == "https://example.com" + + for key, error_message in [ + ("endpoint", "BAIDU_VECTOR_DB_ENDPOINT"), + ("account", "BAIDU_VECTOR_DB_ACCOUNT"), + ("api_key", "BAIDU_VECTOR_DB_API_KEY"), + ("database", "BAIDU_VECTOR_DB_DATABASE"), + ]: + invalid = dict(values) + invalid[key] = "" + with pytest.raises(ValueError, match=error_message): + baidu_module.BaiduConfig.model_validate(invalid) + + +def test_get_search_result_handles_metadata_and_threshold(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + response = SimpleNamespace( + rows=[ + {"row": {"page_content": "doc1", "metadata": '{"document_id":"d1"}'}, "score": 0.9}, + {"row": {"page_content": "doc2", "metadata": {"document_id": "d2"}}, "score": 0.4}, + {"row": {"page_content": "doc3", "metadata": 123}, "score": 0.95}, + ] + ) + + docs = vector._get_search_res(response, score_threshold=0.8) + + assert len(docs) == 2 + assert docs[0].page_content == "doc1" + assert docs[0].metadata["score"] == 0.9 + assert docs[1].page_content == "doc3" + + +def test_delete_by_ids_and_delete_by_metadata_field(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + table = MagicMock() + vector._db = MagicMock() + vector._db.table.return_value = table + vector._collection_name = "collection_1" + + vector.delete_by_ids([]) + table.delete.assert_not_called() + + vector.delete_by_ids(["id1", "id2"]) + table.delete.assert_called_once() + + table.delete.reset_mock() + vector.delete_by_metadata_field("source", 'abc"def') + delete_filter = table.delete.call_args.kwargs["filter"] + assert delete_filter == 'metadata["source"] = "abc\\"def"' + + +def test_delete_handles_table_not_exist_error_and_raises_for_other_codes(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + vector._db = MagicMock() + + vector._db.drop_table.side_effect = baidu_module.ServerError(baidu_module.ServerErrCode.TABLE_NOT_EXIST) + vector.delete() + + vector._db.drop_table.side_effect = baidu_module.ServerError(9999) + with pytest.raises(baidu_module.ServerError): + vector.delete() + + +def test_init_database_uses_existing_or_creates_when_missing(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._client = MagicMock() + vector._client_config = SimpleNamespace(database="my_db") + + vector._client.list_databases.return_value = [SimpleNamespace(database_name="my_db")] + vector._client.database.return_value = "existing_db" + assert vector._init_database() == "existing_db" + + vector._client.list_databases.return_value = [] + vector._client.database.return_value = "created_db" + vector._client.create_database.side_effect = None + assert vector._init_database() == "created_db" + + vector._client.create_database.side_effect = baidu_module.ServerError(baidu_module.ServerErrCode.DB_ALREADY_EXIST) + assert vector._init_database() == "created_db" + + +def test_table_existed_checks_table_access(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + vector._db = MagicMock() + vector._db.table.return_value = MagicMock() + + assert vector._table_existed() is True + + vector._db.table.side_effect = baidu_module.ServerError(baidu_module.ServerErrCode.TABLE_NOT_EXIST) + assert vector._table_existed() is False + + vector._db.table.side_effect = baidu_module.ServerError(9999) + with pytest.raises(baidu_module.ServerError): + vector._table_existed() + + +def test_search_methods_delegate_to_database_table(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + vector._db = MagicMock() + vector._get_search_res = MagicMock(return_value=[Document(page_content="doc", metadata={"doc_id": "1"})]) + + table = MagicMock() + vector._db.table.return_value = table + table.search.return_value = "vector_result" + table.bm25_search.return_value = "bm25_result" + + result1 = vector.search_by_vector([0.1, 0.2], top_k=3, document_ids_filter=["doc-1"], score_threshold=0.2) + result2 = vector.search_by_full_text("query", top_k=3, document_ids_filter=["doc-1"], score_threshold=0.2) + + assert result1 == vector._get_search_res.return_value + assert result2 == vector._get_search_res.return_value + assert vector._get_search_res.call_count == 2 + + +def test_factory_initializes_collection_name_and_index_struct(baidu_module, monkeypatch): + factory = baidu_module.BaiduVectorFactory() + dataset = SimpleNamespace(id="dataset-1", index_struct_dict=None, index_struct=None) + monkeypatch.setattr(baidu_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_ENDPOINT", "https://endpoint") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS", 1000) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_ACCOUNT", "account") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_API_KEY", "key") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_DATABASE", "database") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_SHARD", 1) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_REPLICAS", 1) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER", "DEFAULT_ANALYZER") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE", "COARSE_MODE") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT", 500) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO", 0.05) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS", 300) + + with patch.object(baidu_module, "BaiduVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + assert vector_cls.call_args.kwargs["collection_name"] == "auto_collection" + assert dataset.index_struct is not None + + +def test_init_get_type_to_index_struct_and_create_delegate(baidu_module, monkeypatch): + init_client = MagicMock(return_value="client") + init_database = MagicMock(return_value="database") + monkeypatch.setattr(baidu_module.BaiduVector, "_init_client", init_client) + monkeypatch.setattr(baidu_module.BaiduVector, "_init_database", init_database) + + config = baidu_module.BaiduConfig( + endpoint="https://example.com", + account="account", + api_key="key", + database="db", + ) + vector = baidu_module.BaiduVector(collection_name="my_collection", config=config) + + assert vector.get_type() == baidu_module.VectorType.BAIDU + assert vector.to_index_struct()["vector_store"]["class_prefix"] == "my_collection" + assert vector._client == "client" + assert vector._db == "database" + + vector._create_table = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="p1", metadata={"doc_id": "d1"})] + vector.create(docs, [[0.1, 0.2]]) + vector._create_table.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_add_texts_batches_rows(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + table = MagicMock() + vector._db = MagicMock() + vector._db.table.return_value = table + + docs = [ + Document(page_content="doc-1", metadata={"doc_id": "id-1", "document_id": "doc-1"}), + Document(page_content="doc-2", metadata={"doc_id": "id-2", "document_id": "doc-2"}), + ] + vector.add_texts(docs, [[0.1, 0.2], [0.3, 0.4]]) + + assert table.upsert.call_count == 1 + inserted_rows = table.upsert.call_args.kwargs["rows"] + assert len(inserted_rows) == 2 + + +def test_add_texts_batches_more_than_batch_size(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + table = MagicMock() + vector._db = MagicMock() + vector._db.table.return_value = table + + docs = [ + Document(page_content=f"doc-{idx}", metadata={"doc_id": f"id-{idx}", "document_id": f"doc-{idx}"}) + for idx in range(1001) + ] + embeddings = [[0.1, 0.2] for _ in range(1001)] + + vector.add_texts(docs, embeddings) + + assert table.upsert.call_count == 2 + assert len(table.upsert.call_args_list[0].kwargs["rows"]) == 1000 + assert len(table.upsert.call_args_list[1].kwargs["rows"]) == 1 + + +def test_text_exists_returns_false_when_query_code_is_not_success(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + table = MagicMock() + vector._db = MagicMock() + vector._db.table.return_value = table + + table.query.return_value = SimpleNamespace(code=0) + assert vector.text_exists("id-1") is True + + table.query.return_value = SimpleNamespace(code=1) + assert vector.text_exists("id-1") is False + + table.query.return_value = None + assert vector.text_exists("id-1") is False + + +def test_get_search_result_handles_invalid_metadata_json(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + response = SimpleNamespace(rows=[{"row": {"page_content": "doc1", "metadata": "{bad json"}, "score": 0.7}]) + + docs = vector._get_search_res(response, score_threshold=0.1) + + assert len(docs) == 1 + assert docs[0].metadata["score"] == 0.7 + assert "document_id" not in docs[0].metadata + + +def test_init_client_constructs_configuration_and_client(baidu_module, monkeypatch): + credentials = MagicMock(return_value="credentials") + configuration = MagicMock(return_value="configuration") + client_cls = MagicMock(return_value="client") + monkeypatch.setattr(baidu_module, "BceCredentials", credentials) + monkeypatch.setattr(baidu_module, "Configuration", configuration) + monkeypatch.setattr(baidu_module, "MochowClient", client_cls) + + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + config = SimpleNamespace(account="account", api_key="key", endpoint="https://endpoint") + + client = vector._init_client(config) + + assert client == "client" + credentials.assert_called_once_with("account", "key") + configuration.assert_called_once_with(credentials="credentials", endpoint="https://endpoint") + client_cls.assert_called_once_with("configuration") + + +def test_init_database_raises_for_unknown_create_database_error(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._client = MagicMock() + vector._client_config = SimpleNamespace(database="my_db") + vector._client.list_databases.return_value = [] + vector._client.create_database.side_effect = baidu_module.ServerError(9999) + + with pytest.raises(baidu_module.ServerError): + vector._init_database() + + +def test_create_table_handles_cache_and_validation_paths(baidu_module, monkeypatch): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + vector._client_config = SimpleNamespace( + index_type="HNSW", + metric_type="IP", + inverted_index_analyzer="DEFAULT_ANALYZER", + inverted_index_parser_mode="COARSE_MODE", + auto_build_row_count_increment=500, + auto_build_row_count_increment_ratio=0.05, + rebuild_index_timeout_in_seconds=300, + replicas=1, + shard=1, + ) + vector._db = MagicMock() + table = MagicMock() + table.state = baidu_module.TableState.NORMAL + vector._db.describe_table.return_value = table + vector._table_existed = MagicMock(return_value=False) + vector.delete = MagicMock() + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(baidu_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(baidu_module.redis_client, "set", MagicMock()) + monkeypatch.setattr(baidu_module.time, "sleep", lambda _s: None) + monkeypatch.setattr(vector, "_wait_for_index_ready", MagicMock()) + + # Cached table skips all work. + monkeypatch.setattr(baidu_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_table(3) + vector._db.create_table.assert_not_called() + + # Existing table also skips creation. + monkeypatch.setattr(baidu_module.redis_client, "get", MagicMock(return_value=None)) + vector._table_existed.return_value = True + vector._create_table(3) + vector._db.create_table.assert_not_called() + + # Create table when cache is empty and table does not exist. + vector._table_existed.return_value = False + vector._create_table(3) + vector._db.create_table.assert_called_once() + baidu_module.redis_client.set.assert_called_once_with("vector_indexing_collection_1", 1, ex=3600) + table.rebuild_index.assert_called_once_with(vector.vector_index) + vector._wait_for_index_ready.assert_called_once_with(table, 3600) + + +def test_create_table_raises_for_invalid_index_or_metric(baidu_module, monkeypatch): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + vector._db = MagicMock() + vector._table_existed = MagicMock(return_value=False) + vector.delete = MagicMock() + vector._client_config = SimpleNamespace( + index_type="INVALID", + metric_type="IP", + inverted_index_analyzer="DEFAULT_ANALYZER", + inverted_index_parser_mode="COARSE_MODE", + auto_build_row_count_increment=500, + auto_build_row_count_increment_ratio=0.05, + rebuild_index_timeout_in_seconds=300, + replicas=1, + shard=1, + ) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(baidu_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(baidu_module.redis_client, "get", MagicMock(return_value=None)) + + with pytest.raises(ValueError, match="unsupported index_type"): + vector._create_table(3) + + vector._client_config.index_type = "HNSW" + vector._client_config.metric_type = "INVALID" + with pytest.raises(ValueError, match="unsupported metric_type"): + vector._create_table(3) + + +def test_create_table_raises_timeout_if_table_never_becomes_normal(baidu_module, monkeypatch): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + vector._client_config = SimpleNamespace( + index_type="HNSW", + metric_type="IP", + inverted_index_analyzer="DEFAULT_ANALYZER", + inverted_index_parser_mode="COARSE_MODE", + auto_build_row_count_increment=500, + auto_build_row_count_increment_ratio=0.05, + rebuild_index_timeout_in_seconds=300, + replicas=1, + shard=1, + ) + vector._db = MagicMock() + vector._db.describe_table.return_value = SimpleNamespace(state="CREATING") + vector._table_existed = MagicMock(return_value=False) + vector.delete = MagicMock() + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(baidu_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(baidu_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(baidu_module.time, "sleep", lambda _s: None) + monkeypatch.setattr(baidu_module.time, "time", MagicMock(side_effect=[0, 301])) + + with pytest.raises(TimeoutError, match="Table creation timeout"): + vector._create_table(3) + + +def test_factory_uses_existing_collection_prefix_when_index_struct_exists(baidu_module, monkeypatch): + factory = baidu_module.BaiduVectorFactory() + dataset = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_ENDPOINT", "https://endpoint") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS", 1000) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_ACCOUNT", "account") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_API_KEY", "key") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_DATABASE", "database") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_SHARD", 1) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_REPLICAS", 1) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER", "DEFAULT_ANALYZER") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE", "COARSE_MODE") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT", 500) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO", 0.05) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS", 300) + + with patch.object(baidu_module, "BaiduVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + assert vector_cls.call_args.kwargs["collection_name"] == "existing_collection" diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/chroma/test_chroma_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/chroma/test_chroma_vector.py new file mode 100644 index 00000000000..44427b7d879 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/chroma/test_chroma_vector.py @@ -0,0 +1,199 @@ +import importlib +import sys +import types +from collections import UserDict +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_chroma_modules(): + chroma = types.ModuleType("chromadb") + chroma.DEFAULT_TENANT = "default_tenant" + chroma.DEFAULT_DATABASE = "default_database" + + class Settings: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + class QueryResult(UserDict): + pass + + class _Collection: + def __init__(self): + self.upsert = MagicMock() + self.delete = MagicMock() + self.query = MagicMock() + self.get = MagicMock(return_value={}) + + class _Client: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.collection = _Collection() + self.get_or_create_collection = MagicMock(return_value=self.collection) + self.delete_collection = MagicMock() + + chroma.Settings = Settings + chroma.QueryResult = QueryResult + chroma.HttpClient = _Client + return chroma + + +@pytest.fixture +def chroma_module(monkeypatch): + fake_chroma = _build_fake_chroma_modules() + monkeypatch.setitem(sys.modules, "chromadb", fake_chroma) + import core.rag.datasource.vdb.chroma.chroma_vector as module + + return importlib.reload(module) + + +def test_chroma_config_to_params_builds_expected_payload(chroma_module): + config = chroma_module.ChromaConfig( + host="localhost", + port=8000, + tenant="tenant-1", + database="db-1", + auth_provider="provider", + auth_credentials="credentials", + ) + + params = config.to_chroma_params() + + assert params["host"] == "localhost" + assert params["port"] == 8000 + assert params["tenant"] == "tenant-1" + assert params["database"] == "db-1" + assert params["ssl"] is False + assert params["settings"].chroma_client_auth_provider == "provider" + assert params["settings"].chroma_client_auth_credentials == "credentials" + + +def test_create_collection_uses_redis_lock_and_cache(chroma_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(chroma_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(chroma_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(chroma_module.redis_client, "set", MagicMock()) + + vector = chroma_module.ChromaVector( + collection_name="collection_1", + config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"), + ) + vector.create_collection("collection_1") + + vector._client.get_or_create_collection.assert_called_once_with("collection_1") + chroma_module.redis_client.set.assert_called_once() + + +def test_create_with_empty_texts_is_noop(chroma_module): + vector = chroma_module.ChromaVector( + collection_name="collection_1", + config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"), + ) + vector.create([], []) + vector._client.get_or_create_collection.assert_not_called() + + +def test_create_with_texts_creates_collection_and_upserts(chroma_module): + vector = chroma_module.ChromaVector( + collection_name="collection_1", + config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"), + ) + docs = [Document(page_content="hello", metadata={"doc_id": "d1", "document_id": "doc-1"})] + vector.create(docs, [[0.1, 0.2]]) + + vector._client.get_or_create_collection.assert_called() + vector._client.collection.upsert.assert_called_once() + + +def test_delete_methods_and_text_exists(chroma_module): + vector = chroma_module.ChromaVector( + collection_name="collection_1", + config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"), + ) + + vector.delete_by_ids([]) + vector._client.collection.delete.assert_not_called() + + vector.delete_by_ids(["id-1"]) + vector._client.collection.delete.assert_called_with(ids=["id-1"]) + + vector.delete_by_metadata_field("document_id", "doc-1") + vector._client.collection.delete.assert_called_with(where={"document_id": {"$eq": "doc-1"}}) + + vector._client.collection.get.return_value = {"ids": ["id-1"]} + assert vector.text_exists("id-1") is True + vector._client.collection.get.return_value = {} + assert vector.text_exists("id-2") is False + + vector.delete() + vector._client.delete_collection.assert_called_once_with("collection_1") + + +def test_search_by_vector_handles_empty_results(chroma_module): + vector = chroma_module.ChromaVector( + collection_name="collection_1", + config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"), + ) + vector._client.collection.query.return_value = {"ids": [], "documents": [], "metadatas": [], "distances": []} + + assert vector.search_by_vector([0.1, 0.2], top_k=2) == [] + + +def test_search_by_vector_applies_score_threshold_and_sorting(chroma_module): + vector = chroma_module.ChromaVector( + collection_name="collection_1", + config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"), + ) + vector._client.collection.query.return_value = { + "ids": [["id-1", "id-2"]], + "documents": [["doc high", "doc low"]], + "metadatas": [[{"doc_id": "id-1"}, {"doc_id": "id-2"}]], + "distances": [[0.1, 0.8]], + } + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["doc-1"]) + + assert len(docs) == 1 + assert docs[0].page_content == "doc high" + assert docs[0].metadata["score"] == 0.9 + + +def test_search_by_full_text_returns_empty_list(chroma_module): + vector = chroma_module.ChromaVector( + collection_name="collection_1", + config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"), + ) + assert vector.search_by_full_text("query") == [] + + +def test_factory_init_vector_uses_existing_or_generated_collection(chroma_module, monkeypatch): + factory = chroma_module.ChromaVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", index_struct_dict={"vector_store": {"class_prefix": "EXISTING"}}, index_struct=None + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(chroma_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(chroma_module.dify_config, "CHROMA_HOST", "localhost") + monkeypatch.setattr(chroma_module.dify_config, "CHROMA_PORT", 8000) + monkeypatch.setattr(chroma_module.dify_config, "CHROMA_TENANT", None) + monkeypatch.setattr(chroma_module.dify_config, "CHROMA_DATABASE", None) + monkeypatch.setattr(chroma_module.dify_config, "CHROMA_AUTH_PROVIDER", None) + monkeypatch.setattr(chroma_module.dify_config, "CHROMA_AUTH_CREDENTIALS", None) + + with patch.object(chroma_module, "ChromaVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/clickzetta/test_clickzetta_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/clickzetta/test_clickzetta_vector.py new file mode 100644 index 00000000000..0ce5c04dd61 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/clickzetta/test_clickzetta_vector.py @@ -0,0 +1,927 @@ +import importlib +import queue +import sys +import types +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_clickzetta_module(): + clickzetta = types.ModuleType("clickzetta") + + class _FakeCursor: + def __init__(self): + self.execute = MagicMock() + self.executemany = MagicMock() + self.fetchall = MagicMock(return_value=[]) + self.fetchone = MagicMock(return_value=(0,)) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + class _FakeConnection: + def __init__(self): + self.cursor_obj = _FakeCursor() + + def cursor(self): + return self.cursor_obj + + def close(self): + return None + + def connect(**_kwargs): + return _FakeConnection() + + clickzetta.connect = connect + return clickzetta + + +@pytest.fixture +def clickzetta_module(monkeypatch): + monkeypatch.setitem(sys.modules, "clickzetta", _build_fake_clickzetta_module()) + import core.rag.datasource.vdb.clickzetta.clickzetta_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.ClickzettaConfig( + username="username", + password="password", + instance="instance", + service="service", + workspace="workspace", + vcluster="cluster", + schema_name="dify", + ) + + +@pytest.mark.parametrize( + ("field", "error_message"), + [ + ("username", "CLICKZETTA_USERNAME"), + ("password", "CLICKZETTA_PASSWORD"), + ("instance", "CLICKZETTA_INSTANCE"), + ("service", "CLICKZETTA_SERVICE"), + ("workspace", "CLICKZETTA_WORKSPACE"), + ("vcluster", "CLICKZETTA_VCLUSTER"), + ("schema_name", "CLICKZETTA_SCHEMA"), + ], +) +def test_clickzetta_config_validation(clickzetta_module, field, error_message): + values = _config(clickzetta_module).model_dump() + values[field] = "" + with pytest.raises(ValueError, match=error_message): + clickzetta_module.ClickzettaConfig.model_validate(values) + + +def test_parse_metadata_handles_valid_double_encoded_and_invalid_json(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + + parsed = vector._parse_metadata('{"document_id":"doc-1"}', "row-1") + assert parsed["doc_id"] == "row-1" + assert parsed["document_id"] == "doc-1" + + parsed_double = vector._parse_metadata('"{\\"document_id\\": \\"doc-2\\"}"', "row-2") + assert parsed_double["doc_id"] == "row-2" + assert parsed_double["document_id"] == "doc-2" + + parsed_fallback = vector._parse_metadata("not-json", "row-3") + assert parsed_fallback["doc_id"] == "row-3" + assert parsed_fallback["document_id"] == "row-3" + + +def test_safe_doc_id_and_vector_format_helpers(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + + assert vector._format_vector_simple([0.1, 0.2, 0.3]) == "0.1,0.2,0.3" + assert vector._safe_doc_id("abc-123_DEF") == "abc-123_DEF" + assert vector._safe_doc_id("ab c;\n") == "abc" + assert len(vector._safe_doc_id("a" * 300)) == 255 + + +def test_table_exists_returns_false_for_not_found_and_other_exceptions(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + + @contextmanager + def _ctx_not_found(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.execute.side_effect = RuntimeError("CZLH-42000 table or view not found") + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx_not_found + assert vector._table_exists() is False + + @contextmanager + def _ctx_other_error(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.execute.side_effect = RuntimeError("permission denied") + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx_other_error + assert vector._table_exists() is False + + +def test_text_exists_handles_missing_table_and_existing_rows(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + + vector._table_exists = MagicMock(return_value=False) + assert vector.text_exists("doc-1") is False + + vector._table_exists = MagicMock(return_value=True) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.fetchone.return_value = (1,) + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + assert vector.text_exists("doc-1") is True + + +def test_delete_by_ids_and_delete_by_metadata_field_short_circuit(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + vector._execute_write = MagicMock() + + vector.delete_by_ids([]) + vector._execute_write.assert_not_called() + + vector._table_exists = MagicMock(return_value=False) + vector.delete_by_ids(["doc-1"]) + vector._execute_write.assert_not_called() + + vector.delete_by_metadata_field("document_id", "doc-1") + vector._execute_write.assert_not_called() + + +def test_search_short_circuit_behaviors(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + + vector._table_exists = MagicMock(return_value=False) + assert vector.search_by_vector([0.1, 0.2], top_k=2) == [] + + vector._config.enable_inverted_index = False + assert vector.search_by_full_text("query", top_k=2) == [] + + +def test_search_by_like_returns_documents_with_default_score(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + vector._table_exists = MagicMock(return_value=True) + vector._parse_metadata = MagicMock(return_value={"document_id": "doc-1", "doc_id": "seg-1"}) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.fetchall.return_value = [("seg-1", "content", '{"document_id":"doc-1"}')] + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + docs = vector._search_by_like("query", top_k=3, document_ids_filter=["doc-1"]) + + assert len(docs) == 1 + assert docs[0].page_content == "content" + assert docs[0].metadata["score"] == 0.5 + + +def test_factory_initializes_clickzetta_vector(clickzetta_module, monkeypatch): + factory = clickzetta_module.ClickzettaVectorFactory() + dataset = SimpleNamespace(id="dataset-1") + + monkeypatch.setattr(clickzetta_module.Dataset, "gen_collection_name_by_id", lambda _id: "COLLECTION") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_USERNAME", "username") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_PASSWORD", "password") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_INSTANCE", "instance") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_SERVICE", "service") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_WORKSPACE", "workspace") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_VCLUSTER", "cluster") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_SCHEMA", "dify") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_BATCH_SIZE", 10) + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_ENABLE_INVERTED_INDEX", True) + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_ANALYZER_TYPE", "chinese") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_ANALYZER_MODE", "smart") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_VECTOR_DISTANCE_FUNCTION", "cosine_distance") + + with patch.object(clickzetta_module, "ClickzettaVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + assert vector_cls.call_args.kwargs["collection_name"] == "collection" + + +def test_connection_pool_singleton_and_config_key(clickzetta_module, monkeypatch): + clickzetta_module.ClickzettaConnectionPool._instance = None + monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock()) + + pool_1 = clickzetta_module.ClickzettaConnectionPool.get_instance() + pool_2 = clickzetta_module.ClickzettaConnectionPool.get_instance() + key = pool_1._get_config_key(_config(clickzetta_module)) + + assert pool_1 is pool_2 + assert "username:instance:service:workspace:cluster:dify" in key + + +def test_connection_pool_create_connection_retries_and_configures(clickzetta_module, monkeypatch): + monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock()) + pool = clickzetta_module.ClickzettaConnectionPool() + config = _config(clickzetta_module) + connection = MagicMock() + + monkeypatch.setattr(clickzetta_module.time, "sleep", lambda _s: None) + monkeypatch.setattr( + clickzetta_module.clickzetta, "connect", MagicMock(side_effect=[RuntimeError("boom"), connection]) + ) + pool._configure_connection = MagicMock() + + created = pool._create_connection(config) + + assert created is connection + assert clickzetta_module.clickzetta.connect.call_count == 2 + pool._configure_connection.assert_called_once_with(connection) + + +def test_connection_pool_create_connection_raises_after_retries(clickzetta_module, monkeypatch): + monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock()) + pool = clickzetta_module.ClickzettaConnectionPool() + config = _config(clickzetta_module) + + monkeypatch.setattr(clickzetta_module.time, "sleep", lambda _s: None) + monkeypatch.setattr(clickzetta_module.clickzetta, "connect", MagicMock(side_effect=RuntimeError("boom"))) + + with pytest.raises(RuntimeError, match="boom"): + pool._create_connection(config) + + +def test_connection_pool_configure_and_validate_connection(clickzetta_module): + monkeypatch = pytest.MonkeyPatch() + monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock()) + pool = clickzetta_module.ClickzettaConnectionPool() + + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + connection = MagicMock() + connection.cursor.return_value = cursor + + pool._configure_connection(connection) + assert cursor.execute.call_count >= 2 + assert pool._is_connection_valid(connection) is True + + bad_connection = MagicMock() + bad_connection.cursor.side_effect = RuntimeError("bad connection") + assert pool._is_connection_valid(bad_connection) is False + monkeypatch.undo() + + +def test_connection_pool_configure_connection_swallows_errors(clickzetta_module): + monkeypatch = pytest.MonkeyPatch() + monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock()) + pool = clickzetta_module.ClickzettaConnectionPool() + connection = MagicMock() + connection.cursor.side_effect = RuntimeError("cannot configure") + + pool._configure_connection(connection) + monkeypatch.undo() + + +def test_connection_pool_get_return_cleanup_and_shutdown(clickzetta_module, monkeypatch): + monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock()) + pool = clickzetta_module.ClickzettaConnectionPool() + config = _config(clickzetta_module) + key = pool._get_config_key(config) + + created_connection = MagicMock() + pool._create_connection = MagicMock(return_value=created_connection) + first = pool.get_connection(config) + assert first is created_connection + + reusable_connection = MagicMock() + pool._pools[key] = [(reusable_connection, clickzetta_module.time.time())] + pool._is_connection_valid = MagicMock(return_value=True) + reused = pool.get_connection(config) + assert reused is reusable_connection + + expired_connection = MagicMock() + pool._pools[key] = [(expired_connection, 0.0)] + pool._is_connection_valid = MagicMock(return_value=False) + monkeypatch.setattr(clickzetta_module.time, "time", MagicMock(return_value=1000.0)) + pool.get_connection(config) + expired_connection.close.assert_called_once() + + random_connection = MagicMock() + pool._is_connection_valid = MagicMock(return_value=True) + pool.return_connection(config, random_connection) + assert len(pool._pools[key]) == 1 + + pool._pools[key] = [(MagicMock(), 0.0), (MagicMock(), 1000.0)] + pool._connection_timeout = 10 + pool._cleanup_expired_connections() + assert len(pool._pools[key]) == 1 + + unknown_pool = MagicMock() + pool.return_connection(_config(clickzetta_module).model_copy(update={"workspace": "other"}), unknown_pool) + unknown_pool.close.assert_called_once() + + pool.shutdown() + assert pool._shutdown is True + + +def test_connection_pool_start_cleanup_thread_runs_worker_once(clickzetta_module, monkeypatch): + pool = clickzetta_module.ClickzettaConnectionPool.__new__(clickzetta_module.ClickzettaConnectionPool) + pool._shutdown = False + pool._cleanup_expired_connections = MagicMock(side_effect=lambda: setattr(pool, "_shutdown", True)) + + monkeypatch.setattr(clickzetta_module.time, "sleep", lambda _s: None) + + class _Thread: + def __init__(self, target, daemon): + self._target = target + self.daemon = daemon + self.started = False + + def start(self): + self.started = True + self._target() + + monkeypatch.setattr(clickzetta_module.threading, "Thread", _Thread) + pool._start_cleanup_thread() + + assert pool._cleanup_thread.started is True + pool._cleanup_expired_connections.assert_called_once() + + +def test_vector_init_connection_context_and_helpers(clickzetta_module, monkeypatch): + pool = MagicMock() + pool.get_connection.return_value = "conn" + monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "get_instance", MagicMock(return_value=pool)) + monkeypatch.setattr(clickzetta_module.ClickzettaVector, "_init_write_queue", MagicMock()) + + vector = clickzetta_module.ClickzettaVector("My-Collection", _config(clickzetta_module)) + assert vector._table_name == "my_collection" + + assert vector._get_connection() == "conn" + vector._return_connection("conn") + pool.return_connection.assert_called_with(vector._config, "conn") + + with vector.get_connection_context() as conn: + assert conn == "conn" + assert pool.return_connection.call_count >= 2 + + assert vector.get_type() == "clickzetta" + assert vector._ensure_connection() == "conn" + + +def test_write_queue_initialization_worker_and_execute_write(clickzetta_module, monkeypatch): + class _Thread: + def __init__(self, target, daemon): + self.target = target + self.daemon = daemon + self.started = 0 + + def start(self): + self.started += 1 + + monkeypatch.setattr(clickzetta_module.threading, "Thread", _Thread) + clickzetta_module.ClickzettaVector._write_queue = None + clickzetta_module.ClickzettaVector._write_thread = None + clickzetta_module.ClickzettaVector._shutdown = False + clickzetta_module.ClickzettaVector._init_write_queue() + clickzetta_module.ClickzettaVector._init_write_queue() + assert clickzetta_module.ClickzettaVector._write_thread.started == 1 + + result_queue_ok = queue.Queue() + result_queue_fail = queue.Queue() + clickzetta_module.ClickzettaVector._write_queue = queue.Queue() + clickzetta_module.ClickzettaVector._shutdown = False + clickzetta_module.ClickzettaVector._write_queue.put((lambda x: x + 1, (1,), {}, result_queue_ok)) + clickzetta_module.ClickzettaVector._write_queue.put( + (lambda: (_ for _ in ()).throw(RuntimeError("worker error")), (), {}, result_queue_fail) + ) + clickzetta_module.ClickzettaVector._write_queue.put(None) + clickzetta_module.ClickzettaVector._write_worker() + + assert result_queue_ok.get() == (True, 2) + failed = result_queue_fail.get() + assert failed[0] is False + assert isinstance(failed[1], RuntimeError) + + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + clickzetta_module.ClickzettaVector._write_queue = None + with pytest.raises(RuntimeError, match="Write queue not initialized"): + vector._execute_write(lambda: None) + + class _ImmediateSuccessQueue: + def put(self, task): + func, args, kwargs, result_q = task + result_q.put((True, func(*args, **kwargs))) + + clickzetta_module.ClickzettaVector._write_queue = _ImmediateSuccessQueue() + assert vector._execute_write(lambda x: x * 2, 3) == 6 + + class _ImmediateFailQueue: + def put(self, task): + _, _, _, result_q = task + result_q.put((False, ValueError("write failed"))) + + clickzetta_module.ClickzettaVector._write_queue = _ImmediateFailQueue() + with pytest.raises(ValueError, match="write failed"): + vector._execute_write(lambda: None) + + +def test_table_exists_true_and_create_invokes_write_and_add_texts(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + + @contextmanager + def _ctx_exists(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx_exists + assert vector._table_exists() is True + + vector._execute_write = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="content", metadata={"doc_id": "d1"})] + vector.create(docs, [[0.1, 0.2]]) + vector._execute_write.assert_called_once() + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_create_table_and_indexes_paths(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + vector._create_vector_index = MagicMock() + vector._create_inverted_index = MagicMock() + + vector._table_exists = MagicMock(return_value=True) + vector._create_table_and_indexes([[0.1, 0.2]]) + vector._create_vector_index.assert_not_called() + + vector._table_exists = MagicMock(return_value=False) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + vector._create_table_and_indexes([[0.1, 0.2, 0.3]]) + vector._create_vector_index.assert_called_once() + vector._create_inverted_index.assert_called_once() + + vector._config.enable_inverted_index = False + vector._create_vector_index.reset_mock() + vector._create_inverted_index.reset_mock() + vector._create_table_and_indexes([]) + vector._create_vector_index.assert_called_once() + vector._create_inverted_index.assert_not_called() + + +def test_create_vector_index_branches(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + cursor = MagicMock() + + cursor.fetchall.return_value = [("idx_table_vector", "embedding_vector")] + vector._create_vector_index(cursor) + assert cursor.execute.call_count == 1 + + cursor.reset_mock() + cursor.execute.side_effect = [RuntimeError("show index failed"), None] + vector._create_vector_index(cursor) + assert cursor.execute.call_count == 2 + + cursor.reset_mock() + cursor.execute.side_effect = [None, RuntimeError("already exists")] + cursor.fetchall.return_value = [] + vector._create_vector_index(cursor) + + cursor.reset_mock() + cursor.execute.side_effect = [None, RuntimeError("unexpected")] + cursor.fetchall.return_value = [] + with pytest.raises(RuntimeError, match="unexpected"): + vector._create_vector_index(cursor) + + +def test_create_inverted_index_branches(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + cursor = MagicMock() + + cursor.fetchall.return_value = [("idx_table_1_text", "INVERTED", "page_content")] + vector._create_inverted_index(cursor) + assert cursor.execute.call_count == 1 + + cursor.reset_mock() + cursor.execute.side_effect = [RuntimeError("show failed"), None] + vector._create_inverted_index(cursor) + assert cursor.execute.call_count == 2 + + cursor.reset_mock() + cursor.execute.side_effect = [ + None, + RuntimeError("already has index"), + None, + ] + cursor.fetchall.return_value = [("idx_table_1_text", "INVERTED", "page_content")] + vector._create_inverted_index(cursor) + + cursor.reset_mock() + cursor.execute.side_effect = [None, RuntimeError("other create failure")] + cursor.fetchall.return_value = [] + vector._create_inverted_index(cursor) + + +def test_add_texts_batches_and_insert_batch_behaviors(clickzetta_module, monkeypatch): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._config.batch_size = 2 + vector._table_name = "table_1" + vector._execute_write = MagicMock() + vector._safe_doc_id = MagicMock(side_effect=lambda doc_id: str(doc_id)) + + docs = [ + Document(page_content="doc-1", metadata={"doc_id": "id-1"}), + Document(page_content="doc-2", metadata={"doc_id": "id-2"}), + Document(page_content="doc-3", metadata={"doc_id": "id-3"}), + ] + vectors = [[0.1], [0.2], [0.3]] + + vector.add_texts([], []) + vector._execute_write.assert_not_called() + + added_ids = vector.add_texts(docs, vectors) + assert added_ids == ["id-1", "id-2", "id-3"] + assert vector._execute_write.call_count == 2 + assert vector._execute_write.call_args_list[0].args == ( + vector._insert_batch, + docs[:2], + vectors[:2], + ["id-1", "id-2"], + 0, + 2, + 2, + ) + assert vector._execute_write.call_args_list[1].args == ( + vector._insert_batch, + docs[2:], + vectors[2:], + ["id-3"], + 2, + 2, + 2, + ) + + vector._insert_batch([], [], [], 0, 2, 1) + vector._insert_batch(docs[:1], vectors, ["id-1"], 0, 2, 1) + + bad_doc = Document(page_content="doc-bad", metadata={"doc_id": "id-bad", "bad": {1}}) + good_doc = Document(page_content="doc-good", metadata={"doc_id": "id-good"}) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + vector._insert_batch( + [bad_doc, good_doc], + [[0.1, 0.2], [0.3, 0.4]], + ["id-bad", "id-good"], + 0, + 2, + 1, + ) + + @contextmanager + def _ctx_error(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.executemany.side_effect = RuntimeError("insert failed") + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx_error + with pytest.raises(RuntimeError, match="insert failed"): + vector._insert_batch([good_doc], [[0.1, 0.2]], ["id-good"], 0, 1, 1) + + monkeypatch.setattr(clickzetta_module.uuid, "uuid4", lambda: "generated-id") + vector._safe_doc_id = clickzetta_module.ClickzettaVector._safe_doc_id.__get__(vector) + assert vector._safe_doc_id("") == "generated-id" + assert vector._safe_doc_id("!!!") == "generated-id" + + +def test_delete_by_ids_and_metadata_impl_paths(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + vector._execute_write = MagicMock() + vector._table_exists = MagicMock(return_value=True) + + vector.delete_by_ids(["id-1", "id-2"]) + vector._execute_write.assert_called_once() + assert vector._execute_write.call_args.args[0] == vector._delete_by_ids_impl + + vector._execute_write.reset_mock() + vector.delete_by_metadata_field("document_id", "doc-1") + vector._execute_write.assert_called_once() + assert vector._execute_write.call_args.args[0] == vector._delete_by_metadata_field_impl + + vector._safe_doc_id = MagicMock(side_effect=lambda x: x) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + vector._delete_by_ids_impl(["id-1", "id-2"]) + vector._delete_by_metadata_field_impl("document_id", "doc-1") + + +def test_search_by_vector_covers_cosine_and_l2_paths(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._config.vector_distance_function = "cosine_distance" + vector._table_name = "table_1" + vector._table_exists = MagicMock(return_value=True) + vector._parse_metadata = MagicMock(return_value={"document_id": "doc-1", "doc_id": "seg-1"}) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.fetchall.return_value = [("seg-1", "content", '{"document_id":"doc-1"}', 0.2)] + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + cosine_docs = vector.search_by_vector( + [0.1, 0.2], top_k=3, score_threshold=0.5, document_ids_filter=["doc-1"], filter={"k": "v"} + ) + assert cosine_docs[0].metadata["score"] == pytest.approx(0.9) + + vector._config.vector_distance_function = "l2_distance" + l2_docs = vector.search_by_vector([0.1, 0.2], top_k=3, score_threshold=0.5) + assert l2_docs[0].metadata["score"] == pytest.approx(1 / 1.2) + + +def test_search_by_full_text_success_and_fallback(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + vector._table_exists = MagicMock(return_value=True) + + @contextmanager + def _ctx_success(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.fetchall.return_value = [ + ("seg-1", "content-1", '"{\\"document_id\\":\\"doc-1\\"}"'), + ("seg-2", "content-2", "invalid-json"), + ] + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx_success + docs = vector.search_by_full_text("search'value", top_k=2, document_ids_filter=["doc-1"], filter={"a": 1}) + assert len(docs) == 2 + assert docs[0].metadata["score"] == 1.0 + assert docs[1].metadata["doc_id"] == "seg-2" + + @contextmanager + def _ctx_failure(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.execute.side_effect = RuntimeError("full text failed") + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx_failure + vector._search_by_like = MagicMock(return_value=[Document(page_content="fallback", metadata={"score": 0.5})]) + fallback_docs = vector.search_by_full_text("query", top_k=1) + assert fallback_docs == vector._search_by_like.return_value + + +def test_search_by_like_missing_table_and_delete_table(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + vector._table_exists = MagicMock(return_value=False) + assert vector._search_by_like("query", top_k=1) == [] + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + vector.delete() + + +def test_clickzetta_pool_cleanup_and_shutdown_edge_paths(clickzetta_module): + pool = clickzetta_module.ClickzettaConnectionPool.__new__(clickzetta_module.ClickzettaConnectionPool) + pool._pools = {} + pool._pool_locks = {} + pool._max_pool_size = 1 + pool._connection_timeout = 10 + pool._lock = clickzetta_module.threading.Lock() + pool._shutdown = False + + config = _config(clickzetta_module) + key = pool._get_config_key(config) + pool._pools[key] = [(MagicMock(), 1.0)] + pool._pool_locks[key] = clickzetta_module.threading.Lock() + pool._is_connection_valid = MagicMock(return_value=False) + + conn = MagicMock() + pool.return_connection(config, conn) + conn.close.assert_called_once() + + pool._pools["missing-lock-key"] = [(MagicMock(), 0.0)] + pool._cleanup_expired_connections() + pool.shutdown() + assert pool._shutdown is True + + +def test_clickzetta_pool_cleanup_thread_and_worker_exception_paths(clickzetta_module, monkeypatch): + pool = clickzetta_module.ClickzettaConnectionPool.__new__(clickzetta_module.ClickzettaConnectionPool) + pool._shutdown = False + + def _cleanup_then_fail(): + pool._shutdown = True + raise RuntimeError("cleanup failed") + + pool._cleanup_expired_connections = MagicMock(side_effect=_cleanup_then_fail) + monkeypatch.setattr(clickzetta_module.time, "sleep", lambda _s: None) + + class _Thread: + def __init__(self, target, daemon): + self._target = target + self.daemon = daemon + + def start(self): + self._target() + + monkeypatch.setattr(clickzetta_module.threading, "Thread", _Thread) + pool._start_cleanup_thread() + pool._cleanup_expired_connections.assert_called_once() + + +def test_clickzetta_parse_metadata_and_write_worker_additional_branches(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + + parsed_non_dict = vector._parse_metadata("[1,2,3]", "row-1") + assert parsed_non_dict["doc_id"] == "row-1" + assert parsed_non_dict["document_id"] == "row-1" + + parsed_none = vector._parse_metadata(None, "row-2") + assert parsed_none["doc_id"] == "row-2" + assert parsed_none["document_id"] == "row-2" + + clickzetta_module.ClickzettaVector._shutdown = False + clickzetta_module.ClickzettaVector._write_queue = None + clickzetta_module.ClickzettaVector._write_worker() + + class _BadQueue: + def get(self, timeout): + clickzetta_module.ClickzettaVector._shutdown = True + raise RuntimeError("queue failed") + + clickzetta_module.ClickzettaVector._shutdown = False + clickzetta_module.ClickzettaVector._write_queue = _BadQueue() + clickzetta_module.ClickzettaVector._write_worker() + + +def test_clickzetta_inverted_index_existing_and_insert_non_dict_metadata(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + cursor = MagicMock() + cursor.fetchall.return_value = [("idx_table_1_text", "INVERTED", "page_content")] + cursor.execute.side_effect = [ + None, + RuntimeError("already has index with the same type cannot create inverted index"), + None, + ] + + vector._create_inverted_index(cursor) + + vector._safe_doc_id = MagicMock(side_effect=lambda value: str(value)) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor_obj = MagicMock() + cursor_obj.__enter__.return_value = cursor_obj + cursor_obj.__exit__.return_value = None + connection.cursor.return_value = cursor_obj + yield connection + + vector.get_connection_context = _ctx + vector._insert_batch( + [SimpleNamespace(page_content="content", metadata="not-a-dict")], + [[0.1, 0.2]], + ["doc-1"], + 0, + 1, + 1, + ) + + +def test_clickzetta_full_text_table_missing_and_non_dict_metadata(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._config.enable_inverted_index = True + vector._table_name = "table_1" + + vector._table_exists = MagicMock(return_value=False) + assert vector.search_by_full_text("query") == [] + + vector._table_exists = MagicMock(return_value=True) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.fetchall.return_value = [ + ("seg-1", "content-1", "[1,2,3]"), + ("seg-2", "content-2", None), + ] + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + docs = vector.search_by_full_text("query") + assert len(docs) == 2 + assert docs[0].metadata["doc_id"] == "seg-1" + assert docs[1].metadata["doc_id"] == "seg-2" diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/couchbase/test_couchbase_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/couchbase/test_couchbase_vector.py new file mode 100644 index 00000000000..9fea187615e --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/couchbase/test_couchbase_vector.py @@ -0,0 +1,364 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_couchbase_modules(): + couchbase = types.ModuleType("couchbase") + couchbase_auth = types.ModuleType("couchbase.auth") + couchbase_cluster = types.ModuleType("couchbase.cluster") + couchbase_management = types.ModuleType("couchbase.management") + couchbase_management_search = types.ModuleType("couchbase.management.search") + couchbase_options = types.ModuleType("couchbase.options") + couchbase_vector = types.ModuleType("couchbase.vector_search") + couchbase_search = types.ModuleType("couchbase.search") + + class PasswordAuthenticator: + def __init__(self, user, password): + self.user = user + self.password = password + + class ClusterOptions: + def __init__(self, auth): + self.auth = auth + + class SearchOptions: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class VectorQuery: + def __init__(self, field, vector, top_k): + self.field = field + self.vector = vector + self.top_k = top_k + + class VectorSearch: + @staticmethod + def from_vector_query(vector_query): + return {"vector_query": vector_query} + + class QueryStringQuery: + def __init__(self, query): + self.query = query + + class SearchRequest: + @staticmethod + def create(payload): + return {"payload": payload} + + class SearchIndex: + def __init__(self, name, params, source_name): + self.name = name + self.params = params + self.source_name = source_name + + class _QueryResult: + def __init__(self, rows=None): + self._rows = rows or [] + + def execute(self): + return self + + def __iter__(self): + return iter(self._rows) + + class _SearchIter: + def __init__(self, rows=None): + self._rows = rows or [] + + def rows(self): + return self._rows + + class _Collection: + def __init__(self): + self.upsert = MagicMock(return_value=True) + + class _SearchIndexManager: + def __init__(self): + self.upsert_index = MagicMock() + + class _Scope: + def __init__(self): + self._collection = _Collection() + self._search_index_manager = _SearchIndexManager() + self.search = MagicMock(return_value=_SearchIter()) + + def collection(self, _name): + return self._collection + + def search_indexes(self): + return self._search_index_manager + + class _CollectionManager: + def __init__(self): + self.create_collection = MagicMock() + self.drop_collection = MagicMock() + self.get_all_scopes = MagicMock(return_value=[]) + + class _Bucket: + def __init__(self): + self._scope = _Scope() + self._collections = _CollectionManager() + + def scope(self, _scope_name): + return self._scope + + def collections(self): + return self._collections + + class Cluster: + def __init__(self, connection_string, options): + self.connection_string = connection_string + self.options = options + self._bucket = _Bucket() + self.wait_until_ready = MagicMock() + self.query = MagicMock(return_value=_QueryResult()) + + def bucket(self, _name): + return self._bucket + + couchbase_auth.PasswordAuthenticator = PasswordAuthenticator + couchbase_cluster.Cluster = Cluster + couchbase_management_search.SearchIndex = SearchIndex + couchbase_options.ClusterOptions = ClusterOptions + couchbase_options.SearchOptions = SearchOptions + couchbase_vector.VectorQuery = VectorQuery + couchbase_vector.VectorSearch = VectorSearch + couchbase_search.QueryStringQuery = QueryStringQuery + couchbase_search.SearchRequest = SearchRequest + + couchbase.search = couchbase_search + couchbase.management = couchbase_management + + return { + "couchbase": couchbase, + "couchbase.auth": couchbase_auth, + "couchbase.cluster": couchbase_cluster, + "couchbase.management": couchbase_management, + "couchbase.management.search": couchbase_management_search, + "couchbase.options": couchbase_options, + "couchbase.vector_search": couchbase_vector, + "couchbase.search": couchbase_search, + } + + +@pytest.fixture +def couchbase_module(monkeypatch): + for name, module in _build_fake_couchbase_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.couchbase.couchbase_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.CouchbaseConfig( + connection_string="couchbase://localhost", + user="user", + password="pass", + bucket_name="bucket", + scope_name="scope", + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("connection_string", "", "CONNECTION_STRING is required"), + ("user", "", "COUCHBASE_USER is required"), + ("password", "", "COUCHBASE_PASSWORD is required"), + ("bucket_name", "", "COUCHBASE_PASSWORD is required"), + ("scope_name", "", "COUCHBASE_SCOPE_NAME is required"), + ], +) +def test_couchbase_config_validation(couchbase_module, field, value, message): + values = _config(couchbase_module).model_dump() + values[field] = value + with pytest.raises(ValidationError, match=message): + couchbase_module.CouchbaseConfig.model_validate(values) + + +def test_init_sets_cluster_handles(couchbase_module): + vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module)) + + assert vector._bucket_name == "bucket" + assert vector._scope_name == "scope" + vector._cluster.wait_until_ready.assert_called_once() + + +def test_create_and_create_collection_branches(couchbase_module, monkeypatch): + vector = couchbase_module.CouchbaseVector.__new__(couchbase_module.CouchbaseVector) + vector._collection_name = "collection_1" + vector._client_config = _config(couchbase_module) + vector._scope_name = "scope" + vector._bucket_name = "bucket" + vector._bucket = MagicMock() + vector._scope = MagicMock() + vector._collection_exists = MagicMock(return_value=False) + vector.add_texts = MagicMock() + + monkeypatch.setattr(couchbase_module.uuid, "uuid4", lambda: "a-b-c") + vector._create_collection = MagicMock() + docs = [Document(page_content="text", metadata={"doc_id": "id-1"})] + vector.create(docs, [[0.1, 0.2]]) + + vector._create_collection.assert_called_once_with(uuid="abc", vector_length=2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(couchbase_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(couchbase_module.redis_client, "set", MagicMock()) + + vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module)) + monkeypatch.setattr(couchbase_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection(vector_length=2, uuid="uuid-1") + vector._bucket.collections().create_collection.assert_not_called() + + monkeypatch.setattr(couchbase_module.redis_client, "get", MagicMock(return_value=None)) + vector._collection_exists = MagicMock(return_value=True) + vector._create_collection(vector_length=2, uuid="uuid-2") + vector._bucket.collections().create_collection.assert_not_called() + + vector._collection_exists = MagicMock(return_value=False) + vector._create_collection(vector_length=3, uuid="uuid-3") + + vector._bucket.collections().create_collection.assert_called_once_with("scope", "collection_1") + vector._scope.search_indexes().upsert_index.assert_called_once() + search_index = vector._scope.search_indexes().upsert_index.call_args.args[0] + assert search_index.name == "collection_1_search" + assert ( + search_index.params["mapping"]["types"]["scope.collection_1"]["properties"]["embedding"]["fields"][0]["dims"] + == 3 + ) + couchbase_module.redis_client.set.assert_called_once() + + +def test_collection_exists_get_type_and_add_texts(couchbase_module): + vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module)) + + scope_obj = SimpleNamespace(name="scope", collections=[SimpleNamespace(name="collection_1")]) + vector._bucket.collections().get_all_scopes.return_value = [scope_obj] + assert vector._collection_exists("collection_1") is True + + scope_obj = SimpleNamespace(name="scope", collections=[SimpleNamespace(name="other")]) + vector._bucket.collections().get_all_scopes.return_value = [scope_obj] + assert vector._collection_exists("collection_1") is False + + vector._get_uuids = MagicMock(return_value=["id-1", "id-2"]) + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + Document(page_content="b", metadata={"doc_id": "id-2"}), + ] + ids = vector.add_texts(docs, [[0.1], [0.2]]) + + assert ids == ["id-1", "id-2"] + assert vector._scope.collection("collection_1").upsert.call_count == 2 + assert vector.get_type() == couchbase_module.VectorType.COUCHBASE + + +def test_query_delete_helpers(couchbase_module): + vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module)) + + vector._cluster.query.return_value = SimpleNamespace(execute=lambda: iter([{"count": 2}])) + assert vector.text_exists("id-1") is True + + vector._cluster.query.return_value = SimpleNamespace(execute=lambda: iter([])) + assert vector.text_exists("id-2") is False + + query_result = MagicMock() + query_result.execute.return_value = None + vector._cluster.query.return_value = query_result + + vector.delete_by_ids(["id-1", "id-2"]) + vector.delete_by_document_id("id-1") + vector.delete_by_metadata_field("document_id", "doc-1") + assert vector._cluster.query.call_count >= 3 + + vector._cluster.query.side_effect = RuntimeError("delete failed") + vector.delete_by_ids(["id-3"]) + + +def test_search_methods_and_format_metadata(couchbase_module): + vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module)) + + row_1 = SimpleNamespace(fields={"text": "doc-a", "metadata.document_id": "d-1"}, score=0.9) + row_2 = SimpleNamespace(fields={"text": "doc-b", "metadata.document_id": "d-2"}, score=0.3) + vector._scope.search.return_value = SimpleNamespace(rows=lambda: [row_1, row_2]) + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].page_content == "doc-a" + assert docs[0].metadata["document_id"] == "d-1" + assert docs[0].metadata["score"] == pytest.approx(0.9) + + vector._scope.search.side_effect = RuntimeError("search error") + with pytest.raises(ValueError, match="Search failed"): + vector.search_by_vector([0.1], top_k=1) + + vector._scope.search.side_effect = None + row_3 = SimpleNamespace(fields={"text": "full-text", "metadata.doc_id": "x"}, score=0.7) + vector._scope.search.return_value = SimpleNamespace(rows=lambda: [row_3]) + docs = vector.search_by_full_text("hello", top_k=1) + assert len(docs) == 1 + assert docs[0].metadata["doc_id"] == "x" + + vector._scope.search.side_effect = RuntimeError("full text failed") + with pytest.raises(ValueError, match="Search failed"): + vector.search_by_full_text("hello", top_k=1) + + assert vector._format_metadata({"metadata.a": 1, "plain": 2}) == {"a": 1, "plain": 2} + + +def test_delete_collection_and_factory(couchbase_module, monkeypatch): + vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module)) + scopes = [ + SimpleNamespace(collections=[SimpleNamespace(name="other")]), + SimpleNamespace(collections=[SimpleNamespace(name="collection_1")]), + ] + vector._bucket.collections().get_all_scopes.return_value = scopes + + vector.delete() + vector._bucket.collections().drop_collection.assert_called_once_with("_default", "collection_1") + + factory = couchbase_module.CouchbaseVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(couchbase_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr( + couchbase_module, + "current_app", + SimpleNamespace( + config={ + "COUCHBASE_CONNECTION_STRING": "couchbase://localhost", + "COUCHBASE_USER": "user", + "COUCHBASE_PASSWORD": "pass", + "COUCHBASE_BUCKET_NAME": "bucket", + "COUCHBASE_SCOPE_NAME": "scope", + } + ), + ) + + with patch.object(couchbase_module, "CouchbaseVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_ja_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_ja_vector.py new file mode 100644 index 00000000000..edd29a46491 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_ja_vector.py @@ -0,0 +1,121 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + + +def _build_fake_elasticsearch_modules(): + elasticsearch = types.ModuleType("elasticsearch") + + class ConnectionError(Exception): + pass + + class Elasticsearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.ping = MagicMock(return_value=True) + self.info = MagicMock(return_value={"version": {"number": "8.12.0"}}) + self.indices = SimpleNamespace( + refresh=MagicMock(), delete=MagicMock(), exists=MagicMock(return_value=False), create=MagicMock() + ) + + elasticsearch.Elasticsearch = Elasticsearch + elasticsearch.ConnectionError = ConnectionError + return {"elasticsearch": elasticsearch} + + +@pytest.fixture +def elasticsearch_ja_module(monkeypatch): + for name, module in _build_fake_elasticsearch_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector as ja_module + import core.rag.datasource.vdb.elasticsearch.elasticsearch_vector as base_module + + importlib.reload(base_module) + return importlib.reload(ja_module) + + +def test_create_collection_cache_hit(elasticsearch_ja_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(elasticsearch_ja_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(elasticsearch_ja_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(elasticsearch_ja_module.redis_client, "set", MagicMock()) + + vector = elasticsearch_ja_module.ElasticSearchJaVector.__new__(elasticsearch_ja_module.ElasticSearchJaVector) + vector._collection_name = "test" + vector._client = MagicMock() + + vector.create_collection([[0.1, 0.2]], [{}]) + + vector._client.indices.create.assert_not_called() + elasticsearch_ja_module.redis_client.set.assert_not_called() + + +def test_create_collection_create_and_exists_paths(elasticsearch_ja_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(elasticsearch_ja_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(elasticsearch_ja_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(elasticsearch_ja_module.redis_client, "set", MagicMock()) + + vector = elasticsearch_ja_module.ElasticSearchJaVector.__new__(elasticsearch_ja_module.ElasticSearchJaVector) + vector._collection_name = "test" + vector._client = MagicMock() + + vector._client.indices.exists.return_value = False + vector.create_collection([[0.1, 0.2, 0.3]], [{}]) + + vector._client.indices.create.assert_called_once() + kwargs = vector._client.indices.create.call_args.kwargs + assert kwargs["index"] == "test" + assert kwargs["mappings"]["properties"][elasticsearch_ja_module.Field.VECTOR]["dims"] == 3 + elasticsearch_ja_module.redis_client.set.assert_called_once() + + vector._client.indices.create.reset_mock() + elasticsearch_ja_module.redis_client.set.reset_mock() + vector._client.indices.exists.return_value = True + vector.create_collection([[0.1, 0.2]], [{}]) + + vector._client.indices.create.assert_not_called() + elasticsearch_ja_module.redis_client.set.assert_called_once() + + +def test_ja_factory_uses_existing_or_generated_collection(elasticsearch_ja_module, monkeypatch): + factory = elasticsearch_ja_module.ElasticSearchJaVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(elasticsearch_ja_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr( + elasticsearch_ja_module, + "current_app", + SimpleNamespace( + config={ + "ELASTICSEARCH_HOST": "localhost", + "ELASTICSEARCH_PORT": 9200, + "ELASTICSEARCH_USERNAME": "elastic", + "ELASTICSEARCH_PASSWORD": "secret", + } + ), + ) + + with patch.object(elasticsearch_ja_module, "ElasticSearchJaVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["index_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["index_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_vector.py new file mode 100644 index 00000000000..9ecf0caa244 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_vector.py @@ -0,0 +1,405 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_elasticsearch_modules(): + elasticsearch = types.ModuleType("elasticsearch") + + class ConnectionError(Exception): + pass + + class Elasticsearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.ping = MagicMock(return_value=True) + self.info = MagicMock(return_value={"version": {"number": "8.12.0-SNAPSHOT"}}) + self.index = MagicMock() + self.exists = MagicMock(return_value=False) + self.delete = MagicMock() + self.search = MagicMock(return_value={"hits": {"hits": []}}) + self.indices = SimpleNamespace( + refresh=MagicMock(), + delete=MagicMock(), + exists=MagicMock(return_value=False), + create=MagicMock(), + ) + + elasticsearch.Elasticsearch = Elasticsearch + elasticsearch.ConnectionError = ConnectionError + return {"elasticsearch": elasticsearch} + + +@pytest.fixture +def elasticsearch_module(monkeypatch): + for name, module in _build_fake_elasticsearch_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.elasticsearch.elasticsearch_vector as module + + return importlib.reload(module) + + +def _regular_config(module, **overrides): + values = { + "host": "localhost", + "port": 9200, + "username": "elastic", + "password": "secret", + "verify_certs": False, + "request_timeout": 10, + "retry_on_timeout": True, + "max_retries": 3, + } + values.update(overrides) + return module.ElasticSearchConfig.model_validate(values) + + +def _cloud_config(module, **overrides): + values = { + "use_cloud": True, + "cloud_url": "https://cloud.example:9243", + "api_key": "api-key", + "verify_certs": True, + "ca_certs": "/tmp/ca.pem", + "request_timeout": 10, + "retry_on_timeout": True, + "max_retries": 3, + } + values.update(overrides) + return module.ElasticSearchConfig.model_validate(values) + + +@pytest.mark.parametrize( + ("values", "message"), + [ + ({"use_cloud": True, "cloud_url": None, "api_key": "x"}, "cloud_url is required"), + ({"use_cloud": True, "cloud_url": "https://cloud", "api_key": None}, "api_key is required"), + ({"host": None, "port": 9200, "username": "u", "password": "p"}, "HOST is required"), + ({"host": "h", "port": None, "username": "u", "password": "p"}, "PORT is required"), + ({"host": "h", "port": 9200, "username": None, "password": "p"}, "USERNAME is required"), + ({"host": "h", "port": 9200, "username": "u", "password": None}, "PASSWORD is required"), + ], +) +def test_elasticsearch_config_validation(elasticsearch_module, values, message): + with pytest.raises(ValidationError, match=message): + elasticsearch_module.ElasticSearchConfig.model_validate(values) + + +def test_init_client_cloud_configuration(elasticsearch_module): + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + client = MagicMock() + client.ping.return_value = True + + with patch.object(elasticsearch_module, "Elasticsearch", return_value=client) as es_cls: + result = vector._init_client(_cloud_config(elasticsearch_module)) + + assert result is client + kwargs = es_cls.call_args.kwargs + assert kwargs["hosts"] == ["https://cloud.example:9243"] + assert kwargs["api_key"] == "api-key" + assert kwargs["verify_certs"] is True + assert kwargs["ca_certs"] == "/tmp/ca.pem" + + +def test_init_client_regular_https_and_http_fallback(elasticsearch_module): + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + client = MagicMock() + client.ping.return_value = True + + with patch.object(elasticsearch_module, "Elasticsearch", return_value=client) as es_cls: + vector._init_client( + _regular_config( + elasticsearch_module, + host="https://es.example", + port=9443, + verify_certs=True, + ca_certs="/tmp/ca.pem", + ) + ) + kwargs = es_cls.call_args.kwargs + assert kwargs["hosts"] == ["https://es.example:9443"] + assert kwargs["verify_certs"] is True + assert kwargs["ca_certs"] == "/tmp/ca.pem" + + with patch.object(elasticsearch_module, "Elasticsearch", return_value=client) as es_cls: + vector._init_client(_regular_config(elasticsearch_module, host="es.internal", port=9200)) + kwargs = es_cls.call_args.kwargs + assert kwargs["hosts"] == ["http://es.internal:9200"] + assert "verify_certs" not in kwargs + + +def test_init_client_connection_failures(elasticsearch_module): + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + + client = MagicMock() + client.ping.return_value = False + with patch.object(elasticsearch_module, "Elasticsearch", return_value=client): + with pytest.raises(ConnectionError, match="Failed to connect"): + vector._init_client(_regular_config(elasticsearch_module)) + + with patch.object( + elasticsearch_module, + "Elasticsearch", + side_effect=elasticsearch_module.ElasticsearchConnectionError("boom"), + ): + with pytest.raises(ConnectionError, match="Vector database connection error"): + vector._init_client(_regular_config(elasticsearch_module)) + + with patch.object(elasticsearch_module, "Elasticsearch", side_effect=RuntimeError("oops")): + with pytest.raises(ConnectionError, match="initialization failed"): + vector._init_client(_regular_config(elasticsearch_module)) + + +def test_init_get_version_and_check_version(elasticsearch_module): + with ( + patch.object(elasticsearch_module.ElasticSearchVector, "_init_client", return_value=MagicMock()) as init_client, + patch.object(elasticsearch_module.ElasticSearchVector, "_get_version", return_value="8.10.0") as get_version, + patch.object(elasticsearch_module.ElasticSearchVector, "_check_version") as check_version, + ): + vector = elasticsearch_module.ElasticSearchVector( + "collection_1", _regular_config(elasticsearch_module), attributes=["doc_id"] + ) + + init_client.assert_called_once() + get_version.assert_called_once() + check_version.assert_called_once() + assert vector._attributes == ["doc_id"] + + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + vector._client = MagicMock() + vector._client.info.return_value = {"version": {"number": "8.13.2-SNAPSHOT"}} + assert vector._get_version() == "8.13.2" + + vector._version = "7.17.0" + with pytest.raises(ValueError, match="greater than 8.0.0"): + vector._check_version() + + vector._version = "8.0.0" + vector._check_version() + + +def test_crud_methods_and_get_type(elasticsearch_module): + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + vector._client.indices = SimpleNamespace(refresh=MagicMock(), delete=MagicMock()) + vector._get_uuids = MagicMock(return_value=["id-1", "id-2"]) + + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + Document(page_content="b", metadata={"doc_id": "id-2"}), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2]]) + assert ids == ["id-1", "id-2"] + assert vector._client.index.call_count == 2 + vector._client.indices.refresh.assert_called_once_with(index="collection_1") + + vector._client.exists.return_value = True + assert vector.text_exists("id-1") is True + + vector.delete_by_ids([]) + vector._client.delete.assert_not_called() + vector.delete_by_ids(["id-1", "id-2"]) + assert vector._client.delete.call_count == 2 + + vector._client.search.return_value = {"hits": {"hits": [{"_id": "id-1"}]}} + vector.delete_by_ids = MagicMock() + vector.delete_by_metadata_field("doc_id", "d1") + vector.delete_by_ids.assert_called_once_with(["id-1"]) + + vector.delete_by_ids.reset_mock() + vector._client.search.return_value = {"hits": {"hits": []}} + vector.delete_by_metadata_field("doc_id", "d2") + vector.delete_by_ids.assert_not_called() + + vector.delete() + vector._client.indices.delete.assert_called_once_with(index="collection_1") + assert vector.get_type() == elasticsearch_module.VectorType.ELASTICSEARCH + + +def test_search_by_vector_and_full_text(elasticsearch_module): + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_score": 0.8, + "_source": { + elasticsearch_module.Field.CONTENT_KEY: "doc-a", + elasticsearch_module.Field.VECTOR: [0.1], + elasticsearch_module.Field.METADATA_KEY: {"doc_id": "1", "document_id": "d-1"}, + }, + }, + { + "_score": 0.2, + "_source": { + elasticsearch_module.Field.CONTENT_KEY: "doc-b", + elasticsearch_module.Field.VECTOR: [0.2], + elasticsearch_module.Field.METADATA_KEY: {"doc_id": "2", "document_id": "d-2"}, + }, + }, + ] + } + } + + docs = vector.search_by_vector( + [0.1, 0.2], + top_k=2, + score_threshold=0.5, + document_ids_filter=["d-1", "d-2"], + ) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.8) + knn = vector._client.search.call_args.kwargs["knn"] + assert knn["k"] == 2 + assert knn["num_candidates"] == 3 + assert "filter" in knn + + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_source": { + elasticsearch_module.Field.CONTENT_KEY: "text-hit", + elasticsearch_module.Field.VECTOR: [0.3], + elasticsearch_module.Field.METADATA_KEY: {"doc_id": "3"}, + } + } + ] + } + } + docs = vector.search_by_full_text("hello", top_k=3, document_ids_filter=["d-3"]) + assert len(docs) == 1 + assert docs[0].page_content == "text-hit" + query = vector._client.search.call_args.kwargs["query"] + assert "bool" in query + + +def test_create_and_create_collection_paths(elasticsearch_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(elasticsearch_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(elasticsearch_module.redis_client, "set", MagicMock()) + + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + vector._client.indices = SimpleNamespace(exists=MagicMock(return_value=False), create=MagicMock()) + + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="a", metadata={"doc_id": "1"})] + vector.create(docs, [[0.1]]) + vector.create_collection.assert_called_once() + vector.add_texts.assert_called_once_with(docs, [[0.1]]) + + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + vector._client.indices = SimpleNamespace(exists=MagicMock(return_value=False), create=MagicMock()) + + monkeypatch.setattr(elasticsearch_module.redis_client, "get", MagicMock(return_value=1)) + vector.create_collection([[0.1, 0.2]], [{}]) + vector._client.indices.create.assert_not_called() + + monkeypatch.setattr(elasticsearch_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.indices.exists.return_value = False + vector.create_collection([[0.1, 0.2]], [{}]) + vector._client.indices.create.assert_called_once() + mappings = vector._client.indices.create.call_args.kwargs["mappings"] + assert mappings["properties"][elasticsearch_module.Field.VECTOR]["dims"] == 2 + elasticsearch_module.redis_client.set.assert_called_once() + + vector._client.indices.create.reset_mock() + elasticsearch_module.redis_client.set.reset_mock() + vector._client.indices.exists.return_value = True + vector.create_collection([[0.1, 0.2]], [{}]) + vector._client.indices.create.assert_not_called() + elasticsearch_module.redis_client.set.assert_called_once() + + +def test_elasticsearch_factory_branches(elasticsearch_module, monkeypatch): + factory = elasticsearch_module.ElasticSearchVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(elasticsearch_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + + monkeypatch.setattr( + elasticsearch_module, + "current_app", + SimpleNamespace( + config={ + "ELASTICSEARCH_USE_CLOUD": False, + "ELASTICSEARCH_HOST": "es-host", + "ELASTICSEARCH_PORT": 9200, + "ELASTICSEARCH_USERNAME": "elastic", + "ELASTICSEARCH_PASSWORD": "secret", + "ELASTICSEARCH_VERIFY_CERTS": False, + } + ), + ) + + with patch.object(elasticsearch_module, "ElasticSearchVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + assert result_1 == "vector" + cfg = vector_cls.call_args.kwargs["config"] + assert cfg.use_cloud is False + assert vector_cls.call_args.kwargs["index_name"] == "EXISTING_COLLECTION" + + monkeypatch.setattr( + elasticsearch_module, + "current_app", + SimpleNamespace( + config={ + "ELASTICSEARCH_USE_CLOUD": True, + "ELASTICSEARCH_CLOUD_URL": "https://cloud.elastic", + "ELASTICSEARCH_API_KEY": "api-key", + "ELASTICSEARCH_VERIFY_CERTS": True, + } + ), + ) + with patch.object(elasticsearch_module, "ElasticSearchVector", return_value="vector") as vector_cls: + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + assert result_2 == "vector" + cfg = vector_cls.call_args.kwargs["config"] + assert cfg.use_cloud is True + assert cfg.cloud_url == "https://cloud.elastic" + assert dataset_without_index.index_struct is not None + + monkeypatch.setattr( + elasticsearch_module, + "current_app", + SimpleNamespace( + config={ + "ELASTICSEARCH_USE_CLOUD": True, + "ELASTICSEARCH_CLOUD_URL": None, + "ELASTICSEARCH_HOST": "fallback-host", + "ELASTICSEARCH_PORT": 9201, + "ELASTICSEARCH_USERNAME": "elastic", + "ELASTICSEARCH_PASSWORD": "secret", + } + ), + ) + with patch.object(elasticsearch_module, "ElasticSearchVector", return_value="vector") as vector_cls: + factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + cfg = vector_cls.call_args.kwargs["config"] + assert cfg.use_cloud is False + assert cfg.host == "fallback-host" diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/hologres/test_hologres_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/hologres/test_hologres_vector.py new file mode 100644 index 00000000000..5d9e744ded4 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/hologres/test_hologres_vector.py @@ -0,0 +1,371 @@ +import importlib +import json +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_hologres_modules(): + holo_module = types.ModuleType("holo_search_sdk") + holo_types_module = types.ModuleType("holo_search_sdk.types") + + holo_types_module.BaseQuantizationType = str + holo_types_module.DistanceType = str + holo_types_module.TokenizerType = str + + def _connect(**kwargs): + client = MagicMock() + client.kwargs = kwargs + client.connect = MagicMock() + client.check_table_exist = MagicMock(return_value=False) + client.open_table = MagicMock(return_value=MagicMock()) + client.execute = MagicMock(return_value=[]) + client.drop_table = MagicMock() + return client + + holo_module.connect = MagicMock(side_effect=_connect) + + return { + "holo_search_sdk": holo_module, + "holo_search_sdk.types": holo_types_module, + } + + +@pytest.fixture +def hologres_module(monkeypatch): + for name, module in _build_fake_hologres_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.hologres.hologres_vector as module + + return importlib.reload(module) + + +def _valid_config(module): + return module.HologresVectorConfig( + host="localhost", + port=80, + database="dify", + access_key_id="ak", + access_key_secret="sk", + schema_name="public", + tokenizer="jieba", + distance_method="Cosine", + base_quantization_type="rabitq", + max_degree=64, + ef_construction=400, + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config HOLOGRES_HOST is required"), + ("database", "", "config HOLOGRES_DATABASE is required"), + ("access_key_id", "", "config HOLOGRES_ACCESS_KEY_ID is required"), + ("access_key_secret", "", "config HOLOGRES_ACCESS_KEY_SECRET is required"), + ], +) +def test_hologres_config_validation(hologres_module, field, value, message): + values = _valid_config(hologres_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + hologres_module.HologresVectorConfig.model_validate(values) + + +def test_init_client_and_get_type(hologres_module): + vector = hologres_module.HologresVector("Collection_One", _valid_config(hologres_module)) + + hologres_module.holo.connect.assert_called_once_with( + host="localhost", + port=80, + database="dify", + access_key_id="ak", + access_key_secret="sk", + schema="public", + ) + vector._client.connect.assert_called_once() + assert vector.table_name == "embedding_collection_one" + assert vector.get_type() == hologres_module.VectorType.HOLOGRES + + +def test_create_delegates_collection_creation_and_upsert(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})] + + result = vector.create(docs, [[0.1, 0.2]]) + + assert result is None + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_add_texts_returns_empty_for_empty_documents(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + + assert vector.add_texts([], []) == [] + vector._client.open_table.assert_not_called() + + +def test_add_texts_batches_and_serializes_metadata(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + table = vector._client.open_table.return_value + documents = [ + Document(page_content=f"doc-{i}", metadata={"doc_id": f"id-{i}", "document_id": f"document-{i}"}) + for i in range(100) + ] + documents.append(SimpleNamespace(page_content="doc-100", metadata=None)) + embeddings = [[float(i)] for i in range(len(documents))] + + ids = vector.add_texts(documents, embeddings) + + assert ids[:2] == ["id-0", "id-1"] + assert ids[-1] == "" + assert len(ids) == 101 + assert vector._client.open_table.call_count == 2 + assert table.upsert_multi.call_count == 2 + first_call = table.upsert_multi.call_args_list[0].kwargs + second_call = table.upsert_multi.call_args_list[1].kwargs + assert first_call["index_column"] == "id" + assert first_call["column_names"] == ["id", "text", "meta", "embedding"] + assert first_call["update_columns"] == ["text", "meta", "embedding"] + assert len(first_call["values"]) == 100 + assert json.loads(first_call["values"][0][2]) == {"doc_id": "id-0", "document_id": "document-0"} + assert second_call["values"][0][0] == "" + assert second_call["values"][0][2] == "{}" + + +def test_text_exists_handles_missing_and_present_tables(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.side_effect = [False, True] + vector._client.execute.return_value = [(1,)] + + assert vector.text_exists("seg-1") is False + assert vector.text_exists("seg-1") is True + vector._client.execute.assert_called_once() + + +def test_get_ids_by_metadata_field_returns_ids_or_none(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.execute.side_effect = [[("id-1",), ("id-2",)], []] + + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"] + assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None + + +def test_delete_by_ids_branches(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + + vector.delete_by_ids([]) + vector._client.check_table_exist.assert_not_called() + + vector._client.check_table_exist.return_value = False + vector.delete_by_ids(["id-1"]) + vector._client.execute.assert_not_called() + + vector._client.check_table_exist.return_value = True + vector.delete_by_ids(["id-1", "id-2"]) + vector._client.execute.assert_called_once() + + +def test_delete_by_metadata_field_branches(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.return_value = False + + vector.delete_by_metadata_field("document_id", "doc-1") + vector._client.execute.assert_not_called() + + vector._client.check_table_exist.return_value = True + vector.delete_by_metadata_field("document_id", "doc-1") + vector._client.execute.assert_called_once() + + +def test_search_by_vector_returns_empty_when_table_missing(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.return_value = False + + assert vector.search_by_vector([0.1, 0.2]) == [] + + +def test_search_by_vector_applies_filter_and_processes_results(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.return_value = True + table = vector._client.open_table.return_value + query = MagicMock() + table.search_vector.return_value = query + query.select.return_value = query + query.limit.return_value = query + query.where.return_value = query + query.fetchall.return_value = [ + (0.2, "seg-1", "doc-1", '{"doc_id":"seg-1","document_id":"doc-1"}'), + (0.9, "seg-2", "doc-2", {"doc_id": "seg-2", "document_id": "doc-2"}), + ] + + docs = vector.search_by_vector( + [0.1, 0.2], + top_k=2, + score_threshold=0.5, + document_ids_filter=["doc-1"], + ) + + assert len(docs) == 1 + assert docs[0].page_content == "doc-1" + assert docs[0].metadata["doc_id"] == "seg-1" + assert docs[0].metadata["score"] == pytest.approx(0.8) + table.search_vector.assert_called_once() + query.where.assert_called_once() + + +def test_search_by_full_text_returns_empty_when_table_missing(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.return_value = False + + assert vector.search_by_full_text("query") == [] + + +def test_search_by_full_text_applies_filter_and_processes_results(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.return_value = True + table = vector._client.open_table.return_value + search_query = MagicMock() + table.search_text.return_value = search_query + search_query.limit.return_value = search_query + search_query.where.return_value = search_query + search_query.fetchall.return_value = [ + ("seg-1", "doc-1", '{"doc_id":"seg-1"}', [0.1], 0.95), + ("seg-2", "doc-2", {"doc_id": "seg-2"}, [0.2], 0.7), + ] + + docs = vector.search_by_full_text("query", top_k=2, document_ids_filter=["doc-1"]) + + assert len(docs) == 2 + assert docs[0].metadata["doc_id"] == "seg-1" + assert docs[0].metadata["score"] == pytest.approx(0.95) + assert docs[1].metadata["score"] == pytest.approx(0.7) + table.search_text.assert_called_once() + search_query.where.assert_called_once() + + +def test_delete_handles_existing_and_missing_tables(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.side_effect = [False, True] + + vector.delete() + vector._client.drop_table.assert_not_called() + + vector.delete() + vector._client.drop_table.assert_called_once_with(vector.table_name) + + +def test_create_collection_returns_early_when_cache_hits(hologres_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = False + monkeypatch.setattr(hologres_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(hologres_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(hologres_module.redis_client, "set", MagicMock()) + + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._create_collection(3) + + vector._client.check_table_exist.assert_not_called() + hologres_module.redis_client.set.assert_not_called() + + +def test_create_collection_creates_table_and_indexes(hologres_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = False + monkeypatch.setattr(hologres_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(hologres_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(hologres_module.redis_client, "set", MagicMock()) + monkeypatch.setattr(hologres_module.time, "sleep", MagicMock()) + + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.side_effect = [False, False, True] + table = vector._client.open_table.return_value + + vector._create_collection(3) + + vector._client.execute.assert_called_once() + table.set_vector_index.assert_called_once_with( + column="embedding", + distance_method="Cosine", + base_quantization_type="rabitq", + max_degree=64, + ef_construction=400, + use_reorder=True, + ) + table.create_text_index.assert_called_once_with( + index_name="ft_idx_collection_one", + column="text", + tokenizer="jieba", + ) + hologres_module.redis_client.set.assert_called_once() + + +def test_create_collection_raises_when_table_never_becomes_ready(hologres_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = False + monkeypatch.setattr(hologres_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(hologres_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(hologres_module.redis_client, "set", MagicMock()) + monkeypatch.setattr(hologres_module.time, "sleep", MagicMock()) + + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.side_effect = [False] + [False] * 15 + + with pytest.raises(RuntimeError, match="was not ready after 30s"): + vector._create_collection(3) + + hologres_module.redis_client.set.assert_not_called() + + +def test_hologres_factory_uses_existing_or_generated_collection(hologres_module, monkeypatch): + factory = hologres_module.HologresVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "existing_collection"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(hologres_module.Dataset, "gen_collection_name_by_id", lambda _id: "generated_collection") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_HOST", "127.0.0.1") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_PORT", 80) + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_DATABASE", "dify") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_ACCESS_KEY_ID", "ak") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_ACCESS_KEY_SECRET", "sk") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_SCHEMA", "public") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_TOKENIZER", "jieba") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_DISTANCE_METHOD", "Cosine") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_BASE_QUANTIZATION_TYPE", "rabitq") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_MAX_DEGREE", 64) + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_EF_CONSTRUCTION", 400) + + with patch.object(hologres_module, "HologresVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "generated_collection" + generated_config = vector_cls.call_args_list[1].kwargs["config"] + assert generated_config.host == "127.0.0.1" + assert generated_config.database == "dify" + assert generated_config.access_key_id == "ak" + assert json.loads(dataset_without_index.index_struct) == { + "type": hologres_module.VectorType.HOLOGRES, + "vector_store": {"class_prefix": "generated_collection"}, + } diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/huawei/test_huawei_cloud_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/huawei/test_huawei_cloud_vector.py new file mode 100644 index 00000000000..9d23dfcf631 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/huawei/test_huawei_cloud_vector.py @@ -0,0 +1,243 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_elasticsearch_modules(): + elasticsearch = types.ModuleType("elasticsearch") + + class Elasticsearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.index = MagicMock() + self.exists = MagicMock(return_value=False) + self.delete = MagicMock() + self.search = MagicMock(return_value={"hits": {"hits": []}}) + self.indices = SimpleNamespace( + refresh=MagicMock(), delete=MagicMock(), exists=MagicMock(return_value=False), create=MagicMock() + ) + + elasticsearch.Elasticsearch = Elasticsearch + return {"elasticsearch": elasticsearch} + + +@pytest.fixture +def huawei_module(monkeypatch): + for name, module in _build_fake_elasticsearch_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.huawei.huawei_cloud_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.HuaweiCloudVectorConfig(hosts="http://localhost:9200", username="user", password="pass") + + +def test_create_ssl_context(huawei_module): + ctx = huawei_module.create_ssl_context() + assert ctx.check_hostname is False + assert ctx.verify_mode == huawei_module.ssl.CERT_NONE + + +def test_huawei_config_validation_and_params(huawei_module): + with pytest.raises(ValidationError, match="HOSTS is required"): + huawei_module.HuaweiCloudVectorConfig.model_validate({"hosts": ""}) + + config = _config(huawei_module) + params = config.to_elasticsearch_params() + assert params["hosts"] == ["http://localhost:9200"] + assert params["basic_auth"] == ("user", "pass") + + config = huawei_module.HuaweiCloudVectorConfig(hosts="host1,host2", username=None, password=None) + params = config.to_elasticsearch_params() + assert "basic_auth" not in params + + +def test_init_get_type_and_add_texts(huawei_module): + vector = huawei_module.HuaweiCloudVector("COLLECTION", _config(huawei_module)) + + assert vector._collection_name == "collection" + assert vector.get_type() == huawei_module.VectorType.HUAWEI_CLOUD + + vector._get_uuids = MagicMock(return_value=["id-1", "id-2"]) + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + Document(page_content="b", metadata={"doc_id": "id-2"}), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2]]) + assert ids == ["id-1", "id-2"] + assert vector._client.index.call_count == 2 + vector._client.indices.refresh.assert_called_once_with(index="collection") + + +def test_crud_methods(huawei_module): + vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module)) + + vector._client.exists.return_value = True + assert vector.text_exists("id-1") is True + + vector.delete_by_ids([]) + vector._client.delete.assert_not_called() + vector.delete_by_ids(["id-1"]) + vector._client.delete.assert_called_once_with(index="collection", id="id-1") + + vector._client.search.return_value = {"hits": {"hits": [{"_id": "id-1"}]}} + vector.delete_by_ids = MagicMock() + vector.delete_by_metadata_field("doc_id", "x") + vector.delete_by_ids.assert_called_once_with(["id-1"]) + + vector.delete_by_ids.reset_mock() + vector._client.search.return_value = {"hits": {"hits": []}} + vector.delete_by_metadata_field("doc_id", "x") + vector.delete_by_ids.assert_not_called() + + vector.delete() + vector._client.indices.delete.assert_called_once_with(index="collection") + + +def test_search_by_vector_and_full_text(huawei_module): + vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module)) + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_score": 0.9, + "_source": { + huawei_module.Field.CONTENT_KEY: "doc-a", + huawei_module.Field.VECTOR: [0.1], + huawei_module.Field.METADATA_KEY: {"doc_id": "1"}, + }, + }, + { + "_score": 0.1, + "_source": { + huawei_module.Field.CONTENT_KEY: "doc-b", + huawei_module.Field.VECTOR: [0.2], + huawei_module.Field.METADATA_KEY: {"doc_id": "2"}, + }, + }, + ] + } + } + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + + query_body = vector._client.search.call_args.kwargs["body"] + assert query_body["query"]["vector"][huawei_module.Field.VECTOR]["topk"] == 2 + + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_source": { + huawei_module.Field.CONTENT_KEY: "text-hit", + huawei_module.Field.VECTOR: [0.3], + huawei_module.Field.METADATA_KEY: {"doc_id": "3"}, + } + } + ] + } + } + docs = vector.search_by_full_text("hello", top_k=3) + assert len(docs) == 1 + assert docs[0].page_content == "text-hit" + + +def test_search_by_vector_skips_hits_without_metadata(huawei_module, monkeypatch): + class FakeDocument: + def __init__(self, page_content, vector, metadata): + self.page_content = page_content + self.vector = vector + self.metadata = None + + monkeypatch.setattr(huawei_module, "Document", FakeDocument) + + vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module)) + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_score": 0.9, + "_source": { + huawei_module.Field.CONTENT_KEY: "doc-a", + huawei_module.Field.VECTOR: [0.1], + huawei_module.Field.METADATA_KEY: {"doc_id": "1"}, + }, + } + ] + } + } + + docs = vector.search_by_vector([0.1, 0.2], top_k=1, score_threshold=0.5) + + assert docs == [] + + +def test_create_and_create_collection_paths(huawei_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(huawei_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(huawei_module.redis_client, "set", MagicMock()) + + vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module)) + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + + docs = [Document(page_content="a", metadata={"doc_id": "1"})] + vector.create(docs, [[0.1]]) + vector.create_collection.assert_called_once() + vector.add_texts.assert_called_once_with(docs, [[0.1]]) + + vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module)) + monkeypatch.setattr(huawei_module.redis_client, "get", MagicMock(return_value=1)) + vector.create_collection([[0.1, 0.2]], [{}]) + vector._client.indices.create.assert_not_called() + + monkeypatch.setattr(huawei_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.indices.exists.return_value = False + vector.create_collection([[0.1, 0.2]], [{}]) + vector._client.indices.create.assert_called_once() + + kwargs = vector._client.indices.create.call_args.kwargs + mappings = kwargs["mappings"] + assert mappings["properties"][huawei_module.Field.VECTOR]["dimension"] == 2 + assert kwargs["settings"] == {"index.vector": True} + huawei_module.redis_client.set.assert_called_once() + + +def test_huawei_factory_branches(huawei_module, monkeypatch): + factory = huawei_module.HuaweiCloudVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(huawei_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(huawei_module.dify_config, "HUAWEI_CLOUD_HOSTS", "http://huawei-es:9200") + monkeypatch.setattr(huawei_module.dify_config, "HUAWEI_CLOUD_USER", "user") + monkeypatch.setattr(huawei_module.dify_config, "HUAWEI_CLOUD_PASSWORD", "pass") + + with patch.object(huawei_module, "HuaweiCloudVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["index_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["index_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/iris/test_iris_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/iris/test_iris_vector.py new file mode 100644 index 00000000000..63338ca809c --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/iris/test_iris_vector.py @@ -0,0 +1,412 @@ +import importlib +import sys +import types +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_iris_module(): + iris = types.ModuleType("iris") + + def connect(**_kwargs): + conn = MagicMock() + conn.cursor.return_value = MagicMock() + return conn + + iris.connect = MagicMock(side_effect=connect) + return iris + + +@pytest.fixture +def iris_module(monkeypatch): + monkeypatch.setitem(sys.modules, "iris", _build_fake_iris_module()) + + import core.rag.datasource.vdb.iris.iris_vector as module + + reloaded = importlib.reload(module) + reloaded._pool_instance = None + return reloaded + + +def _config(module, **overrides): + values = { + "IRIS_HOST": "localhost", + "IRIS_SUPER_SERVER_PORT": 1972, + "IRIS_USER": "user", + "IRIS_PASSWORD": "pass", + "IRIS_DATABASE": "db", + "IRIS_SCHEMA": "schema", + "IRIS_CONNECTION_URL": "url", + "IRIS_MIN_CONNECTION": 1, + "IRIS_MAX_CONNECTION": 2, + "IRIS_TEXT_INDEX": True, + "IRIS_TEXT_INDEX_LANGUAGE": "en", + } + values.update(overrides) + return module.IrisVectorConfig.model_validate(values) + + +def test_get_iris_pool_singleton(iris_module): + iris_module._pool_instance = None + cfg = _config(iris_module) + + with patch.object(iris_module, "IrisConnectionPool", return_value="pool") as pool_cls: + pool_1 = iris_module.get_iris_pool(cfg) + pool_2 = iris_module.get_iris_pool(cfg) + + assert pool_1 == "pool" + assert pool_2 == "pool" + pool_cls.assert_called_once_with(cfg) + + +@pytest.fixture +def pool_with_min_max(iris_module): + cfg = _config(iris_module, IRIS_MIN_CONNECTION=2, IRIS_MAX_CONNECTION=3) + with patch.object(iris_module.IrisConnectionPool, "_create_connection", return_value=MagicMock()) as create_conn: + pool = iris_module.IrisConnectionPool(cfg) + yield pool, create_conn + + +def test_pool_initialization_respects_min_max(pool_with_min_max): + pool, create_conn = pool_with_min_max + assert len(pool._pool) == 2 + assert create_conn.call_count == 2 + + +@pytest.fixture +def pool_for_get_connection(iris_module): + cfg = _config(iris_module, IRIS_MIN_CONNECTION=2, IRIS_MAX_CONNECTION=3) + pool = iris_module.IrisConnectionPool(cfg) + return pool + + +def test_get_connection_returns_existing_and_increments(pool_for_get_connection): + pool = pool_for_get_connection + conn = MagicMock() + pool._pool = [conn] + pool._in_use = 0 + assert pool.get_connection() is conn + assert pool._in_use == 1 + + +def test_get_connection_creates_new_when_empty(pool_for_get_connection): + pool = pool_for_get_connection + pool._pool = [] + pool._in_use = 0 + pool._create_connection = MagicMock(return_value="new-conn") + assert pool.get_connection() == "new-conn" + + +def test_get_connection_raises_when_exhausted(pool_for_get_connection): + pool = pool_for_get_connection + pool._pool = [] + pool._in_use = pool._max_size + with pytest.raises(RuntimeError, match="exhausted"): + pool.get_connection() + + +@pytest.fixture +def pool_for_return_connection(iris_module): + cfg = _config(iris_module) + with patch.object(iris_module.IrisConnectionPool, "_initialize_pool", return_value=None): + pool = iris_module.IrisConnectionPool(cfg) + return pool + + +def test_return_connection_adds_healthy(pool_for_return_connection): + pool = pool_for_return_connection + pool._in_use = 1 + conn = MagicMock() + cursor = MagicMock() + conn.cursor.return_value = cursor + pool.return_connection(conn) + assert pool._pool[-1] is conn + assert pool._in_use == 0 + + +def test_return_connection_replaces_bad(pool_for_return_connection): + pool = pool_for_return_connection + pool._in_use = 1 + bad_conn = MagicMock() + bad_cursor = MagicMock() + bad_cursor.execute.side_effect = OSError("bad") + bad_conn.cursor.return_value = bad_cursor + replacement = MagicMock() + pool._create_connection = MagicMock(return_value=replacement) + pool.return_connection(bad_conn) + bad_conn.close.assert_called_once() + assert pool._pool[-1] is replacement + assert pool._in_use == 0 + + +def test_return_connection_ignores_none(pool_for_return_connection): + pool = pool_for_return_connection + before = len(pool._pool) + pool.return_connection(None) + assert len(pool._pool) == before + + +@pytest.fixture +def pool_for_schema_and_close(iris_module): + cfg = _config(iris_module) + with patch.object(iris_module.IrisConnectionPool, "_initialize_pool", return_value=None): + pool = iris_module.IrisConnectionPool(cfg) + conn = MagicMock() + cursor = MagicMock() + conn.cursor.return_value = cursor + pool._pool = [conn] + return pool, conn, cursor + + +def test_ensure_schema_exists_cached_noop(pool_for_schema_and_close): + pool, conn, cursor = pool_for_schema_and_close + pool._schemas_initialized = {"cached_schema"} + pool.ensure_schema_exists("cached_schema") + cursor.execute.assert_not_called() + + +def test_ensure_schema_exists_creates_new(pool_for_schema_and_close): + pool, conn, cursor = pool_for_schema_and_close + pool._schemas_initialized = set() + cursor.fetchone.return_value = (0,) + pool.ensure_schema_exists("new_schema") + assert "new_schema" in pool._schemas_initialized + assert any("CREATE SCHEMA" in call.args[0] for call in cursor.execute.call_args_list) + conn.commit.assert_called_once() + + +def test_ensure_schema_exists_existing_no_commit(pool_for_schema_and_close): + pool, conn, cursor = pool_for_schema_and_close + pool._schemas_initialized = set() + cursor.fetchone.return_value = (1,) + pool.ensure_schema_exists("existing_schema") + conn.commit.assert_not_called() + + +def test_ensure_schema_exists_rollback_on_error(pool_for_schema_and_close): + pool, conn, cursor = pool_for_schema_and_close + pool._schemas_initialized = set() + cursor.execute.side_effect = RuntimeError("schema failure") + with pytest.raises(RuntimeError, match="schema failure"): + pool.ensure_schema_exists("broken_schema") + conn.rollback.assert_called() + + +def test_close_all_closes_and_resets(iris_module): + cfg = _config(iris_module) + with patch.object(iris_module.IrisConnectionPool, "_initialize_pool", return_value=None): + pool = iris_module.IrisConnectionPool(cfg) + conn = MagicMock() + conn_2 = MagicMock() + conn_2.close.side_effect = OSError("close fail") + pool._pool = [conn, conn_2] + pool._schemas_initialized = {"x"} + pool.close_all() + assert pool._pool == [] + assert pool._in_use == 0 + assert pool._schemas_initialized == set() + + +def test_iris_vector_init_get_cursor_and_create(iris_module): + pool = MagicMock() + pool.get_connection.return_value = MagicMock() + + with patch.object(iris_module, "get_iris_pool", return_value=pool): + vector = iris_module.IrisVector("collection", _config(iris_module)) + + assert vector.table_name == "EMBEDDING_COLLECTION" + assert vector.schema == "schema" + assert vector.get_type() == iris_module.VectorType.IRIS + + conn = MagicMock() + cursor = MagicMock() + conn.cursor.return_value = cursor + vector.pool.get_connection.return_value = conn + + with vector._get_cursor() as got_cursor: + assert got_cursor is cursor + conn.commit.assert_called_once() + vector.pool.return_connection.assert_called_with(conn) + + conn = MagicMock() + cursor = MagicMock() + conn.cursor.return_value = cursor + vector.pool.get_connection.return_value = conn + with pytest.raises(RuntimeError, match="boom"): + with vector._get_cursor(): + raise RuntimeError("boom") + conn.rollback.assert_called_once() + + vector._create_collection = MagicMock() + vector.add_texts = MagicMock(return_value=["id-1"]) + docs = [Document(page_content="a", metadata={"doc_id": "id-1"})] + assert vector.create(docs, [[0.1, 0.2]]) == ["id-1"] + vector._create_collection.assert_called_once_with(2) + + +def test_iris_vector_crud_and_vector_search(iris_module, monkeypatch): + with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()): + vector = iris_module.IrisVector("collection", _config(iris_module)) + + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + monkeypatch.setattr(iris_module.uuid, "uuid4", lambda: "generated-id") + + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + SimpleNamespace(page_content="b", metadata=None), + ] + ids = vector.add_texts(docs, [[0.1], [0.2]]) + assert ids == ["id-1", "generated-id"] + assert cursor.execute.call_count == 2 + + cursor.fetchone.return_value = (1,) + assert vector.text_exists("id-1") is True + cursor.fetchone.return_value = None + assert vector.text_exists("id-2") is False + + vector._get_cursor = MagicMock(side_effect=RuntimeError("db down")) + assert vector.text_exists("id-3") is False + + vector._get_cursor = _cursor_ctx + vector.delete_by_ids([]) + before = cursor.execute.call_count + vector.delete_by_ids(["id-1", "id-2"]) + assert cursor.execute.call_count == before + 1 + + vector.delete_by_metadata_field("document_id", "doc-1") + assert "meta LIKE" in cursor.execute.call_args.args[0] + + cursor.fetchall.return_value = [ + ("id-1", "text-1", '{"document_id":"d-1"}', 0.9), + ("id-2", "text-2", '{"document_id":"d-2"}', 0.2), + ("id-x",), + ] + docs = vector.search_by_vector([0.1, 0.2], top_k=3, score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + + +def test_iris_vector_full_text_search_paths(iris_module, monkeypatch): + cfg = _config(iris_module, IRIS_TEXT_INDEX=True) + with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()): + vector = iris_module.IrisVector("collection", cfg) + + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + cursor.execute.side_effect = None + cursor.fetchall.return_value = [ + ("id-1", "text-1", '{"document_id":"d-1"}', 0.7), + ("id-2", "text-2", "{}", None), + ] + docs = vector.search_by_full_text("query", top_k=2, document_ids_filter=["d-1"]) + assert len(docs) == 2 + assert docs[0].metadata["score"] == pytest.approx(0.7) + assert docs[1].metadata["score"] == pytest.approx(0.0) + + cursor.reset_mock() + cursor.execute.side_effect = [RuntimeError("rank failed"), None] + cursor.fetchall.return_value = [("id-3", "text-3", "{}", 0.5)] + docs = vector.search_by_full_text("query", top_k=1) + assert len(docs) == 1 + assert cursor.execute.call_count == 2 + + cfg_like = _config(iris_module, IRIS_TEXT_INDEX=False) + with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()): + vector_like = iris_module.IrisVector("collection", cfg_like) + vector_like._get_cursor = _cursor_ctx + + fake_libs = types.ModuleType("libs") + fake_helper = types.ModuleType("libs.helper") + fake_helper.escape_like_pattern = lambda value: value.replace("%", "\\%") + monkeypatch.setitem(sys.modules, "libs", fake_libs) + monkeypatch.setitem(sys.modules, "libs.helper", fake_helper) + + cursor.reset_mock() + cursor.execute.side_effect = None + cursor.fetchall.return_value = [] + assert vector_like.search_by_full_text("100%", top_k=1) == [] + + +def test_iris_vector_delete_create_collection_and_factory(iris_module, monkeypatch): + with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()): + vector = iris_module.IrisVector("collection", _config(iris_module, IRIS_TEXT_INDEX=True)) + + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + vector.delete() + assert "DROP TABLE" in cursor.execute.call_args.args[0] + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(iris_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(iris_module.redis_client, "set", MagicMock()) + + monkeypatch.setattr(iris_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection(2) + cursor.execute.assert_called_once() + + cursor.reset_mock() + monkeypatch.setattr(iris_module.redis_client, "get", MagicMock(return_value=None)) + vector.pool.ensure_schema_exists = MagicMock() + vector._create_collection(3) + assert cursor.execute.call_count == 3 + iris_module.redis_client.set.assert_called_once() + + cursor.reset_mock() + vector.config.IRIS_TEXT_INDEX = False + vector._create_collection(3) + assert cursor.execute.call_count == 2 + + factory = iris_module.IrisVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(iris_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(iris_module.dify_config, "IRIS_HOST", "localhost") + monkeypatch.setattr(iris_module.dify_config, "IRIS_SUPER_SERVER_PORT", 1972) + monkeypatch.setattr(iris_module.dify_config, "IRIS_USER", "user") + monkeypatch.setattr(iris_module.dify_config, "IRIS_PASSWORD", "pass") + monkeypatch.setattr(iris_module.dify_config, "IRIS_DATABASE", "db") + monkeypatch.setattr(iris_module.dify_config, "IRIS_SCHEMA", "schema") + monkeypatch.setattr(iris_module.dify_config, "IRIS_CONNECTION_URL", "url") + monkeypatch.setattr(iris_module.dify_config, "IRIS_MIN_CONNECTION", 1) + monkeypatch.setattr(iris_module.dify_config, "IRIS_MAX_CONNECTION", 2) + monkeypatch.setattr(iris_module.dify_config, "IRIS_TEXT_INDEX", True) + monkeypatch.setattr(iris_module.dify_config, "IRIS_TEXT_INDEX_LANGUAGE", "en") + + with patch.object(iris_module, "IrisVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/lindorm/test_lindorm_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/lindorm/test_lindorm_vector.py new file mode 100644 index 00000000000..34357d5907d --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/lindorm/test_lindorm_vector.py @@ -0,0 +1,394 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_opensearch_modules(): + opensearchpy = types.ModuleType("opensearchpy") + opensearch_helpers = types.ModuleType("opensearchpy.helpers") + + class BulkIndexError(Exception): + def __init__(self, errors): + super().__init__("bulk error") + self.errors = errors + + class OpenSearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.indices = SimpleNamespace( + refresh=MagicMock(), + exists=MagicMock(return_value=False), + delete=MagicMock(), + create=MagicMock(), + ) + self.bulk = MagicMock(return_value={"errors": False, "items": []}) + self.search = MagicMock(return_value={"hits": {"hits": []}}) + self.delete_by_query = MagicMock() + self.get = MagicMock(return_value={"_id": "id"}) + self.exists = MagicMock(return_value=True) + + opensearch_helpers.BulkIndexError = BulkIndexError + opensearch_helpers.bulk = MagicMock() + + opensearchpy.OpenSearch = OpenSearch + opensearchpy.helpers = opensearch_helpers + + return { + "opensearchpy": opensearchpy, + "opensearchpy.helpers": opensearch_helpers, + } + + +@pytest.fixture +def lindorm_module(monkeypatch): + for name, module in _build_fake_opensearch_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.lindorm.lindorm_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.LindormVectorStoreConfig( + hosts="http://localhost:9200", + username="user", + password="pass", + using_ugc=False, + request_timeout=3.0, + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("hosts", None, "config URL is required"), + ("username", None, "config USERNAME is required"), + ("password", None, "config PASSWORD is required"), + ], +) +def test_lindorm_config_validation(lindorm_module, field, value, message): + values = _config(lindorm_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + lindorm_module.LindormVectorStoreConfig.model_validate(values) + + +def test_to_opensearch_params_and_init(lindorm_module): + cfg = _config(lindorm_module) + params = cfg.to_opensearch_params() + + assert params["hosts"] == "http://localhost:9200" + assert params["http_auth"] == ("user", "pass") + + vector = lindorm_module.LindormVectorStore("Collection", cfg, using_ugc=False) + assert vector._collection_name == "collection" + assert vector.get_type() == lindorm_module.VectorType.LINDORM + + with pytest.raises(ValueError, match="routing_value"): + lindorm_module.LindormVectorStore("c", cfg, using_ugc=True) + + vector_ugc = lindorm_module.LindormVectorStore("c", cfg, using_ugc=True, routing_value="ROUTE") + assert vector_ugc._routing == "route" + + +def test_create_refresh_and_add_texts_success(lindorm_module, monkeypatch): + vector = lindorm_module.LindormVectorStore( + "collection", _config(lindorm_module), using_ugc=True, routing_value="route" + ) + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + + docs = [Document(page_content="a", metadata={"doc_id": "id-1"})] + vector.create(docs, [[0.1]]) + vector.create_collection.assert_called_once_with([[0.1]], [{"doc_id": "id-1"}]) + vector.add_texts.assert_called_once_with(docs, [[0.1]]) + + vector = lindorm_module.LindormVectorStore( + "collection", _config(lindorm_module), using_ugc=True, routing_value="route" + ) + monkeypatch.setattr(lindorm_module.time, "sleep", MagicMock()) + + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + Document(page_content="b", metadata={"doc_id": "id-2"}), + Document(page_content="c", metadata={"doc_id": "id-3"}), + ] + embeddings = [[0.1], [0.2], [0.3]] + + vector.add_texts(docs, embeddings, batch_size=2, timeout=9) + + assert vector._client.bulk.call_count == 2 + actions = vector._client.bulk.call_args_list[0].args[0] + assert actions[0]["index"]["routing"] == "route" + assert actions[1][lindorm_module.ROUTING_FIELD] == "route" + vector.refresh() + vector._client.indices.refresh.assert_called_once_with(index="collection") + + +def test_add_texts_error_paths(lindorm_module): + vector = lindorm_module.LindormVectorStore("collection", _config(lindorm_module), using_ugc=False) + vector._client.bulk.return_value = {"errors": True, "items": [{"index": {"error": "boom"}}]} + + docs = [Document(page_content="a", metadata={"doc_id": "id-1"})] + with pytest.raises(Exception, match="RetryError"): + vector.add_texts(docs, [[0.1]], batch_size=1) + + vector._client.bulk.side_effect = RuntimeError("bulk failed") + with pytest.raises(Exception, match="RetryError"): + vector.add_texts(docs, [[0.1]], batch_size=1) + + +def test_metadata_lookup_and_delete_by_metadata(lindorm_module): + vector = lindorm_module.LindormVectorStore( + "collection", _config(lindorm_module), using_ugc=True, routing_value="route" + ) + vector._client.search.return_value = {"hits": {"hits": [{"_id": "id-1"}, {"_id": "id-2"}]}} + + ids = vector.get_ids_by_metadata_field("document_id", "doc-1") + assert ids == ["id-1", "id-2"] + query = vector._client.search.call_args.kwargs["body"] + must_conditions = query["query"]["bool"]["must"] + assert any("routing_field.keyword" in cond.get("term", {}) for cond in must_conditions) + + vector.delete_by_ids = MagicMock() + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete_by_ids.assert_called_once_with(["id-1", "id-2"]) + + vector._client.search.return_value = {"hits": {"hits": []}} + vector.delete_by_ids.reset_mock() + vector.delete_by_metadata_field("document_id", "doc-2") + vector.delete_by_ids.assert_not_called() + + +def test_delete_by_ids_paths(lindorm_module): + vector = lindorm_module.LindormVectorStore( + "collection", _config(lindorm_module), using_ugc=True, routing_value="route" + ) + + vector.delete_by_ids([]) + vector._client.indices.exists.assert_not_called() + + vector._client.indices.exists.return_value = False + vector.delete_by_ids(["id-1"]) + + vector._client.indices.exists.return_value = True + vector._client.exists.side_effect = [True, False] + lindorm_module.helpers.bulk.reset_mock() + vector.delete_by_ids(["id-1", "id-2"]) + lindorm_module.helpers.bulk.assert_called_once() + actions = lindorm_module.helpers.bulk.call_args.args[1] + assert len(actions) == 1 + assert actions[0]["routing"] == "route" + + lindorm_module.helpers.bulk.reset_mock() + lindorm_module.helpers.bulk.side_effect = lindorm_module.BulkIndexError( + errors=[ + {"delete": {"status": 404, "_id": "id-404"}}, + {"delete": {"status": 500, "_id": "id-500"}}, + ] + ) + vector._client.exists.side_effect = [True] + vector.delete_by_ids(["id-1"]) + + +def test_delete_and_text_exists(lindorm_module): + vector = lindorm_module.LindormVectorStore( + "collection", _config(lindorm_module), using_ugc=True, routing_value="route" + ) + vector.delete() + vector._client.delete_by_query.assert_called_once() + vector._client.indices.refresh.assert_called_once_with(index="collection") + + vector = lindorm_module.LindormVectorStore("collection", _config(lindorm_module), using_ugc=False) + vector._client.indices.exists.return_value = True + vector.delete() + vector._client.indices.delete.assert_called_once_with(index="collection", params={"timeout": 60}) + + vector._client.indices.delete.reset_mock() + vector._client.indices.exists.return_value = False + vector.delete() + vector._client.indices.delete.assert_not_called() + + assert vector.text_exists("id-1") is True + vector._client.get.side_effect = RuntimeError("missing") + assert vector.text_exists("id-1") is False + + +def test_search_by_vector_validation_and_success(lindorm_module): + vector = lindorm_module.LindormVectorStore( + "collection", _config(lindorm_module), using_ugc=True, routing_value="route" + ) + + with pytest.raises(ValueError, match="should be a list"): + vector.search_by_vector("bad") + + with pytest.raises(ValueError, match="should be floats"): + vector.search_by_vector([0.1, "bad"]) + + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_score": 0.9, + "_source": { + lindorm_module.Field.CONTENT_KEY: "doc-a", + lindorm_module.Field.VECTOR: [0.1], + lindorm_module.Field.METADATA_KEY: {"doc_id": "1", "document_id": "d-1"}, + }, + }, + { + "_score": 0.2, + "_source": { + lindorm_module.Field.CONTENT_KEY: "doc-b", + lindorm_module.Field.VECTOR: [0.2], + lindorm_module.Field.METADATA_KEY: {"doc_id": "2", "document_id": "d-2"}, + }, + }, + ] + } + } + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + + call_kwargs = vector._client.search.call_args.kwargs + query = call_kwargs["body"] + assert "ext" in query + assert query["query"]["knn"][lindorm_module.Field.VECTOR]["filter"]["bool"]["must"] + assert call_kwargs["params"]["routing"] == "route" + + vector._client.search.side_effect = RuntimeError("search failed") + with pytest.raises(RuntimeError, match="search failed"): + vector.search_by_vector([0.1]) + + +def test_search_by_full_text_success_and_error(lindorm_module): + vector = lindorm_module.LindormVectorStore( + "collection", _config(lindorm_module), using_ugc=True, routing_value="route" + ) + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_source": { + lindorm_module.Field.CONTENT_KEY: "doc-a", + lindorm_module.Field.VECTOR: [0.1], + lindorm_module.Field.METADATA_KEY: {"doc_id": "1"}, + } + } + ] + } + } + + docs = vector.search_by_full_text("hello", top_k=2, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert docs[0].page_content == "doc-a" + + query = vector._client.search.call_args.kwargs["body"] + assert query["query"]["bool"]["filter"] + + vector._client.search.side_effect = RuntimeError("full text failed") + with pytest.raises(RuntimeError, match="full text failed"): + vector.search_by_full_text("hello") + + +def test_create_collection_paths(lindorm_module, monkeypatch): + vector = lindorm_module.LindormVectorStore("collection", _config(lindorm_module), using_ugc=False) + + with pytest.raises(ValueError, match="cannot be empty"): + vector.create_collection([]) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(lindorm_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(lindorm_module.redis_client, "set", MagicMock()) + + monkeypatch.setattr(lindorm_module.redis_client, "get", MagicMock(return_value=1)) + vector.create_collection([[0.1, 0.2]]) + vector._client.indices.create.assert_not_called() + + monkeypatch.setattr(lindorm_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.indices.exists.return_value = False + vector.create_collection([[0.1, 0.2]], index_params={"index_type": "ivf", "space_type": "cosine"}) + vector._client.indices.create.assert_called_once() + body = vector._client.indices.create.call_args.kwargs["body"] + assert body["mappings"]["properties"][lindorm_module.Field.VECTOR]["method"]["name"] == "ivf" + assert body["mappings"]["properties"][lindorm_module.Field.VECTOR]["method"]["space_type"] == "cosine" + + vector._client.indices.create.reset_mock() + vector._client.indices.exists.return_value = True + vector.create_collection([[0.1, 0.2]]) + vector._client.indices.create.assert_not_called() + + +def test_lindorm_factory_branches(lindorm_module, monkeypatch): + factory = lindorm_module.LindormVectorStoreFactory() + + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_URL", "http://localhost:9200") + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USERNAME", "user") + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_PASSWORD", "pass") + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_QUERY_TIMEOUT", 3.0) + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_INDEX_TYPE", "hnsw") + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_DISTANCE_TYPE", "l2") + monkeypatch.setattr(lindorm_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + + dataset = SimpleNamespace(id="dataset-1", index_struct=None, index_struct_dict={}) + embeddings = SimpleNamespace(embed_query=lambda _q: [0.1, 0.2, 0.3]) + + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USING_UGC", None) + with pytest.raises(ValueError, match="LINDORM_USING_UGC is not set"): + factory.init_vector(dataset, attributes=[], embeddings=embeddings) + + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USING_UGC", False) + + dataset_existing_plain = SimpleNamespace( + id="dataset-1", + index_struct="{}", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING"}, "using_ugc": False}, + ) + with patch.object(lindorm_module, "LindormVectorStore", return_value="vector") as store_cls: + result = factory.init_vector(dataset_existing_plain, attributes=[], embeddings=embeddings) + assert result == "vector" + assert store_cls.call_args.args[0] == "existing" + + dataset_existing_ugc = SimpleNamespace( + id="dataset-1", + index_struct="{}", + index_struct_dict={ + "vector_store": {"class_prefix": "ROUTING"}, + "using_ugc": True, + "dimension": 1536, + "index_type": "hnsw", + "distance_type": "l2", + }, + ) + with patch.object(lindorm_module, "LindormVectorStore", return_value="vector") as store_cls: + factory.init_vector(dataset_existing_ugc, attributes=[], embeddings=embeddings) + assert store_cls.call_args.args[0] == "ugc_index_1536_hnsw_l2" + assert store_cls.call_args.kwargs["routing_value"] == "ROUTING" + + dataset_new = SimpleNamespace(id="dataset-2", index_struct=None, index_struct_dict={}) + + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USING_UGC", True) + with patch.object(lindorm_module, "LindormVectorStore", return_value="vector") as store_cls: + factory.init_vector(dataset_new, attributes=[], embeddings=embeddings) + assert store_cls.call_args.args[0] == "ugc_index_3_hnsw_l2" + assert store_cls.call_args.kwargs["routing_value"] == "auto_collection" + assert dataset_new.index_struct is not None + + dataset_new_plain = SimpleNamespace(id="dataset-3", index_struct=None, index_struct_dict={}) + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USING_UGC", False) + with patch.object(lindorm_module, "LindormVectorStore", return_value="vector") as store_cls: + factory.init_vector(dataset_new_plain, attributes=[], embeddings=embeddings) + assert store_cls.call_args.args[0] == "auto_collection" + assert store_cls.call_args.kwargs["routing_value"] is None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/matrixone/test_matrixone_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/matrixone/test_matrixone_vector.py new file mode 100644 index 00000000000..55e7b9112ec --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/matrixone/test_matrixone_vector.py @@ -0,0 +1,252 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_mo_vector_modules(): + mo_vector = types.ModuleType("mo_vector") + mo_vector.__path__ = [] + mo_vector_client = types.ModuleType("mo_vector.client") + + class MoVectorClient: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.create_full_text_index = MagicMock() + self.insert = MagicMock() + self.get = MagicMock(return_value=[]) + self.delete = MagicMock() + self.query_by_metadata = MagicMock(return_value=[]) + self.query = MagicMock(return_value=[]) + self.full_text_query = MagicMock(return_value=[]) + + mo_vector_client.MoVectorClient = MoVectorClient + mo_vector.client = mo_vector_client + return {"mo_vector": mo_vector, "mo_vector.client": mo_vector_client} + + +@pytest.fixture +def matrixone_module(monkeypatch): + for name, module in _build_fake_mo_vector_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.matrixone.matrixone_vector as module + + return importlib.reload(module) + + +def _valid_config(module): + return module.MatrixoneConfig( + host="localhost", + port=6001, + user="dump", + password="111", + database="dify", + metric="l2", + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config host is required"), + ("port", 0, "config port is required"), + ("user", "", "config user is required"), + ("password", "", "config password is required"), + ("database", "", "config database is required"), + ], +) +def test_matrixone_config_validation(matrixone_module, field, value, message): + values = _valid_config(matrixone_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + matrixone_module.MatrixoneConfig.model_validate(values) + + +def test_get_client_creates_full_text_index_when_cache_misses(matrixone_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(matrixone_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(matrixone_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(matrixone_module.redis_client, "set", MagicMock()) + + vector = matrixone_module.MatrixoneVector("Collection_1", _valid_config(matrixone_module)) + client = vector._get_client(dimension=3, create_table=True) + + assert client.kwargs["table_name"] == "collection_1" + client.create_full_text_index.assert_called_once() + matrixone_module.redis_client.set.assert_called_once() + + +def test_get_client_skips_index_creation_when_cache_hits(matrixone_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(matrixone_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(matrixone_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(matrixone_module.redis_client, "set", MagicMock()) + + vector = matrixone_module.MatrixoneVector("Collection_1", _valid_config(matrixone_module)) + client = vector._get_client(dimension=3, create_table=True) + + client.create_full_text_index.assert_not_called() + matrixone_module.redis_client.set.assert_not_called() + + +def test_ensure_client_initializes_client_for_decorated_methods(matrixone_module): + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + vector.client = None + fake_client = MagicMock() + fake_client.get.return_value = [{"id": "seg-1"}] + vector._get_client = MagicMock(return_value=fake_client) + + exists = vector.text_exists("seg-1") + + assert exists is True + vector._get_client.assert_called_once_with(None, False) + + +def test_search_by_full_text_parses_metadata_and_applies_threshold(matrixone_module): + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + vector.client = MagicMock() + vector.client.full_text_query.return_value = [ + SimpleNamespace(document="doc-a", metadata='{"doc_id":"1"}', distance=0.1), + SimpleNamespace(document="doc-b", metadata={"doc_id": "2"}, distance=0.7), + ] + + docs = vector.search_by_full_text("query", top_k=2, score_threshold=0.5, document_ids_filter=["doc-1"]) + + assert len(docs) == 1 + assert docs[0].page_content == "doc-a" + assert docs[0].metadata["doc_id"] == "1" + assert docs[0].metadata["score"] == pytest.approx(0.9) + assert vector.client.full_text_query.call_args.kwargs["filter"] == {"document_id": {"$in": ["doc-1"]}} + + +def test_get_type_and_create_delegate_to_add_texts(matrixone_module): + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + fake_client = MagicMock() + vector._get_client = MagicMock(return_value=fake_client) + vector.add_texts = MagicMock(return_value=["seg-1"]) + docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})] + + result = vector.create(docs, [[0.1, 0.2]]) + + assert vector.get_type() == "matrixone" + assert result == ["seg-1"] + vector._get_client.assert_called_once_with(2, True) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_get_client_handles_full_text_index_creation_error(matrixone_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(matrixone_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(matrixone_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(matrixone_module.redis_client, "set", MagicMock()) + + failing_client = MagicMock() + failing_client.create_full_text_index.side_effect = RuntimeError("boom") + monkeypatch.setattr(matrixone_module, "MoVectorClient", MagicMock(return_value=failing_client)) + + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + client = vector._get_client(dimension=3, create_table=True) + + assert client is failing_client + matrixone_module.redis_client.set.assert_not_called() + + +def test_add_texts_generates_ids_and_inserts(matrixone_module, monkeypatch): + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + vector.client = MagicMock() + monkeypatch.setattr(matrixone_module.uuid, "uuid4", lambda: "generated-uuid") + docs = [ + Document(page_content="a", metadata={"doc_id": "doc-a", "document_id": "d-1"}), + Document(page_content="b", metadata={"document_id": "d-2"}), + SimpleNamespace(page_content="c", metadata=None), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]]) + + # For current prod code, only docs with metadata get ids, so only two ids + assert ids == ["doc-a", "generated-uuid"] + vector.client.insert.assert_called_once() + insert_kwargs = vector.client.insert.call_args.kwargs + # All lists passed to insert should be the same length + texts = insert_kwargs["texts"] + embeddings = insert_kwargs["embeddings"] + metadatas = insert_kwargs["metadatas"] + ids_insert = insert_kwargs["ids"] + assert len(texts) == len(embeddings) == len(metadatas) == len(docs) + # ids may be shorter than docs for current prod code, but should match number of docs with metadata + assert ids_insert == ["doc-a", "generated-uuid"] + + +def test_delete_and_metadata_methods(matrixone_module): + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + vector.client = MagicMock() + vector.client.query_by_metadata.return_value = [SimpleNamespace(id="seg-1"), SimpleNamespace(id="seg-2")] + + vector.delete_by_ids([]) + vector.client.delete.assert_not_called() + + vector.delete_by_ids(["seg-1"]) + vector.delete_by_metadata_field("document_id", "doc-1") + ids = vector.get_ids_by_metadata_field("document_id", "doc-1") + vector.delete() + + assert ids == ["seg-1", "seg-2"] + assert vector.client.delete.call_count == 3 + + +def test_search_by_vector_builds_documents(matrixone_module): + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + vector.client = MagicMock() + vector.client.query.return_value = [ + SimpleNamespace(document="doc-a", metadata={"doc_id": "1"}), + SimpleNamespace(document="doc-b", metadata={"doc_id": "2"}), + ] + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, document_ids_filter=["d-1"]) + + assert len(docs) == 2 + assert docs[0].page_content == "doc-a" + assert docs[1].metadata["doc_id"] == "2" + assert vector.client.query.call_args.kwargs["filter"] == {"document_id": {"$in": ["d-1"]}} + + +def test_matrixone_factory_uses_existing_or_generated_collection(matrixone_module, monkeypatch): + factory = matrixone_module.MatrixoneVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(matrixone_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_HOST", "127.0.0.1") + monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_PORT", 6001) + monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_USER", "dump") + monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_PASSWORD", "111") + monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_DATABASE", "dify") + monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_METRIC", "l2") + + with patch.object(matrixone_module, "MatrixoneVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py b/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py index fb2ddfe162c..2ac2c40d38c 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py @@ -1,18 +1,414 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + import pytest from pydantic import ValidationError -from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig +from core.rag.models.document import Document -def test_default_value(): +def _build_fake_pymilvus_modules(): + pymilvus = types.ModuleType("pymilvus") + pymilvus.__path__ = [] + pymilvus_milvus_client = types.ModuleType("pymilvus.milvus_client") + pymilvus_orm = types.ModuleType("pymilvus.orm") + pymilvus_orm.__path__ = [] + pymilvus_orm_types = types.ModuleType("pymilvus.orm.types") + + class MilvusError(Exception): + pass + + class MilvusClient: + def __init__(self, **kwargs): + self.init_kwargs = kwargs + self.has_collection = MagicMock(return_value=False) + self.describe_collection = MagicMock( + return_value={"fields": [{"name": "id"}, {"name": "content"}, {"name": "metadata"}]} + ) + self.get_server_version = MagicMock(return_value="2.5.0") + self.insert = MagicMock(return_value=[1]) + self.query = MagicMock(return_value=[]) + self.delete = MagicMock() + self.drop_collection = MagicMock() + self.search = MagicMock(return_value=[[]]) + self.create_collection = MagicMock() + + class IndexParams: + def __init__(self): + self.indexes = [] + + def add_index(self, **kwargs): + self.indexes.append(kwargs) + + class DataType: + JSON = "JSON" + VARCHAR = "VARCHAR" + INT64 = "INT64" + SPARSE_FLOAT_VECTOR = "SPARSE_FLOAT_VECTOR" + FLOAT_VECTOR = "FLOAT_VECTOR" + + class FieldSchema: + def __init__(self, name, dtype, **kwargs): + self.name = name + self.dtype = dtype + self.kwargs = kwargs + + class CollectionSchema: + def __init__(self, fields): + self.fields = fields + self.functions = [] + + def add_function(self, func): + self.functions.append(func) + + class FunctionType: + BM25 = "BM25" + + class Function: + def __init__(self, **kwargs): + self.kwargs = kwargs + + def infer_dtype_bydata(_value): + return DataType.FLOAT_VECTOR + + pymilvus.MilvusException = MilvusError + pymilvus.MilvusClient = MilvusClient + pymilvus.IndexParams = IndexParams + pymilvus.CollectionSchema = CollectionSchema + pymilvus.DataType = DataType + pymilvus.FieldSchema = FieldSchema + pymilvus.Function = Function + pymilvus.FunctionType = FunctionType + pymilvus_milvus_client.IndexParams = IndexParams + pymilvus_orm.types = pymilvus_orm_types + pymilvus_orm_types.infer_dtype_bydata = infer_dtype_bydata + + # Attach submodules for dotted imports + pymilvus.milvus_client = pymilvus_milvus_client + pymilvus.orm = pymilvus_orm + + return { + "pymilvus": pymilvus, + "pymilvus.milvus_client": pymilvus_milvus_client, + "pymilvus.orm": pymilvus_orm, + "pymilvus.orm.types": pymilvus_orm_types, + } + + +@pytest.fixture +def milvus_module(monkeypatch): + for name, module in _build_fake_pymilvus_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.milvus.milvus_vector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "uri": "http://localhost:19530", + "user": "root", + "password": "Milvus", + "database": "default", + "enable_hybrid_search": False, + "analyzer_params": None, + } + values.update(overrides) + return module.MilvusConfig.model_validate(values) + + +def test_config_validation_and_defaults(milvus_module): valid_config = {"uri": "http://localhost:19530", "user": "root", "password": "Milvus"} for key in valid_config: config = valid_config.copy() del config[key] with pytest.raises(ValidationError) as e: - MilvusConfig.model_validate(config) + milvus_module.MilvusConfig.model_validate(config) assert e.value.errors()[0]["msg"] == f"Value error, config MILVUS_{key.upper()} is required" - config = MilvusConfig.model_validate(valid_config) + config = milvus_module.MilvusConfig.model_validate(valid_config) assert config.database == "default" + + token_config = milvus_module.MilvusConfig.model_validate( + {"uri": "http://localhost:19530", "token": "token-value", "database": "db-1"} + ) + assert token_config.token == "token-value" + + +def test_config_to_milvus_params(milvus_module): + config = _config(milvus_module, analyzer_params='{"tokenizer":"standard"}') + + params = config.to_milvus_params() + + assert params["uri"] == "http://localhost:19530" + assert params["db_name"] == "default" + assert params["analyzer_params"] == '{"tokenizer":"standard"}' + + +def test_init_client_supports_token_and_user_password(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + token_client = vector._init_client( + milvus_module.MilvusConfig.model_validate({"uri": "http://localhost:19530", "token": "abc", "database": "db"}) + ) + assert token_client.init_kwargs == {"uri": "http://localhost:19530", "token": "abc", "db_name": "db"} + + user_client = vector._init_client(_config(milvus_module)) + assert user_client.init_kwargs["uri"] == "http://localhost:19530" + assert user_client.init_kwargs["user"] == "root" + assert user_client.init_kwargs["password"] == "Milvus" + + +def test_init_loads_fields_when_collection_exists(milvus_module): + client = milvus_module.MilvusClient(uri="http://localhost:19530") + client.has_collection.return_value = True + client.describe_collection.return_value = { + "fields": [{"name": "id"}, {"name": "content"}, {"name": "metadata"}, {"name": "sparse_vector"}] + } + + with patch.object(milvus_module.MilvusVector, "_init_client", return_value=client): + with patch.object(milvus_module.MilvusVector, "_check_hybrid_search_support", return_value=False): + vector = milvus_module.MilvusVector("collection_1", _config(milvus_module)) + + assert "id" not in vector._fields + assert "content" in vector._fields + + +def test_load_collection_fields_from_argument_and_remote(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._client = MagicMock() + vector._collection_name = "collection_1" + vector._client.describe_collection.return_value = {"fields": [{"name": "id"}, {"name": "content"}]} + + vector._load_collection_fields(["id", "metadata"]) + assert vector._fields == ["metadata"] + + vector._load_collection_fields() + assert vector._fields == ["content"] + + +def test_check_hybrid_search_support_branches(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._client = MagicMock() + + vector._client_config = SimpleNamespace(enable_hybrid_search=False) + assert vector._check_hybrid_search_support() is False + + vector._client_config = SimpleNamespace(enable_hybrid_search=True) + vector._client.get_server_version.return_value = "Zilliz Cloud 2.4" + assert vector._check_hybrid_search_support() is True + + vector._client.get_server_version.return_value = "2.5.1" + assert vector._check_hybrid_search_support() is True + + vector._client.get_server_version.return_value = "2.4.9" + assert vector._check_hybrid_search_support() is False + + vector._client.get_server_version.side_effect = RuntimeError("boom") + assert vector._check_hybrid_search_support() is False + + +def test_get_type_and_create_delegate(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [SimpleNamespace(page_content="hello", metadata=None)] + + vector.create(docs, [[0.1, 0.2]]) + + assert vector.get_type() == "milvus" + vector.create_collection.assert_called_once() + create_args = vector.create_collection.call_args.args + assert create_args[0] == [[0.1, 0.2]] + assert create_args[1] == [{}] + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_add_texts_batches_and_raises_milvus_exception(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + vector._client.insert.side_effect = [["id-1"], ["id-2"]] + docs = [Document(page_content=f"text-{i}", metadata={"doc_id": f"d-{i}"}) for i in range(1001)] + embeddings = [[0.1, 0.2] for _ in range(1001)] + + ids = vector.add_texts(docs, embeddings) + assert ids == ["id-1", "id-2"] + assert vector._client.insert.call_count == 2 + + vector._client.insert.side_effect = milvus_module.MilvusException("insert failed") + with pytest.raises(milvus_module.MilvusException): + vector.add_texts([Document(page_content="x", metadata={})], [[0.1]]) + + +def test_get_ids_and_delete_methods(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + vector._client.query.return_value = [{"id": 1}, {"id": 2}] + + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == [1, 2] + vector._client.query.return_value = [] + assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None + + vector._client.has_collection.return_value = True + vector.get_ids_by_metadata_field = MagicMock(return_value=[101, 102]) + vector.delete_by_metadata_field("document_id", "doc-1") + vector._client.delete.assert_called_with(collection_name="collection_1", pks=[101, 102]) + + vector._client.delete.reset_mock() + vector._client.query.return_value = [{"id": 11}, {"id": 12}] + vector.delete_by_ids(["doc-a", "doc-b"]) + vector._client.delete.assert_called_with(collection_name="collection_1", pks=[11, 12]) + + vector._client.has_collection.return_value = True + vector.delete() + vector._client.drop_collection.assert_called_once_with("collection_1", None) + + +def test_text_exists_and_field_exists(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._collection_name = "collection_1" + vector._fields = ["content", "metadata"] + vector._client = MagicMock() + vector._client.has_collection.return_value = False + assert vector.text_exists("doc-1") is False + + vector._client.has_collection.return_value = True + vector._client.query.return_value = [{"id": 1}] + assert vector.text_exists("doc-1") is True + vector._client.query.return_value = [] + assert vector.text_exists("doc-1") is False + assert vector.field_exists("content") is True + assert vector.field_exists("unknown") is False + + +def test_process_search_results_and_search_methods(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + vector._fields = ["content", "metadata", "sparse_vector"] + + processed = vector._process_search_results( + [ + [ + {"entity": {"content": "doc-1", "metadata": {"doc_id": "1"}}, "distance": 0.9}, + {"entity": {"content": "doc-2", "metadata": {"doc_id": "2"}}, "distance": 0.2}, + ] + ], + [milvus_module.Field.CONTENT_KEY, milvus_module.Field.METADATA_KEY], + score_threshold=0.5, + ) + assert len(processed) == 1 + assert processed[0].metadata["score"] == 0.9 + + vector._client.search.return_value = [[{"entity": {"content": "doc"}, "distance": 0.8}]] + vector._process_search_results = MagicMock(return_value=["doc"]) + + docs = vector.search_by_vector([0.1, 0.2], top_k=3, document_ids_filter=["a", "b"], score_threshold=0.1) + assert docs == ["doc"] + assert vector._client.search.call_args.kwargs["filter"] == 'metadata["document_id"] in ["a", "b"]' + + vector._hybrid_search_enabled = False + assert vector.search_by_full_text("query") == [] + + vector._hybrid_search_enabled = True + vector._fields = [] + assert vector.search_by_full_text("query") == [] + + vector._fields = [milvus_module.Field.SPARSE_VECTOR] + vector._process_search_results = MagicMock(return_value=["full-text-doc"]) + full_text_docs = vector.search_by_full_text("query", top_k=2, document_ids_filter=["d-1"], score_threshold=0.2) + assert full_text_docs == ["full-text-doc"] + assert "document_id" in vector._client.search.call_args.kwargs["filter"] + + +def test_create_collection_cache_and_existing_collection(milvus_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(milvus_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(milvus_module.redis_client, "set", MagicMock()) + + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._collection_name = "collection_1" + vector._consistency_level = "Session" + vector._client_config = _config(milvus_module) + vector._hybrid_search_enabled = False + vector._client = MagicMock() + + monkeypatch.setattr(milvus_module.redis_client, "get", MagicMock(return_value=1)) + vector.create_collection([[0.1, 0.2]], metadatas=[{"doc_id": "1"}], index_params={"index_type": "HNSW"}) + vector._client.create_collection.assert_not_called() + + monkeypatch.setattr(milvus_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.has_collection.return_value = True + vector.create_collection([[0.1, 0.2]], metadatas=[{"doc_id": "1"}], index_params={"index_type": "HNSW"}) + milvus_module.redis_client.set.assert_called() + + +def test_create_collection_builds_schema_and_indexes(milvus_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(milvus_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(milvus_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(milvus_module.redis_client, "set", MagicMock()) + + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._collection_name = "collection_1" + vector._consistency_level = "Session" + vector._client = MagicMock() + vector._client.has_collection.return_value = False + vector._load_collection_fields = MagicMock() + + vector._client_config = _config(milvus_module, analyzer_params='{"tokenizer":"standard"}') + vector._hybrid_search_enabled = True + vector.create_collection( + embeddings=[[0.1, 0.2]], + metadatas=[{"doc_id": "1"}], + index_params={"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8}}, + ) + + call_kwargs = vector._client.create_collection.call_args.kwargs + schema = call_kwargs["schema"] + index_params_obj = call_kwargs["index_params"] + field_names = [f.name for f in schema.fields] + + assert milvus_module.Field.SPARSE_VECTOR in field_names + assert len(schema.functions) == 1 + assert len(index_params_obj.indexes) == 2 + assert call_kwargs["consistency_level"] == "Session" + + +def test_factory_initializes_milvus_vector(milvus_module, monkeypatch): + factory = milvus_module.MilvusVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(milvus_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(milvus_module.dify_config, "MILVUS_URI", "http://localhost:19530") + monkeypatch.setattr(milvus_module.dify_config, "MILVUS_TOKEN", "") + monkeypatch.setattr(milvus_module.dify_config, "MILVUS_USER", "root") + monkeypatch.setattr(milvus_module.dify_config, "MILVUS_PASSWORD", "Milvus") + monkeypatch.setattr(milvus_module.dify_config, "MILVUS_DATABASE", "default") + monkeypatch.setattr(milvus_module.dify_config, "MILVUS_ENABLE_HYBRID_SEARCH", True) + monkeypatch.setattr(milvus_module.dify_config, "MILVUS_ANALYZER_PARAMS", '{"tokenizer":"standard"}') + + with patch.object(milvus_module, "MilvusVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/myscale/test_myscale_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/myscale/test_myscale_vector.py new file mode 100644 index 00000000000..a75ba822385 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/myscale/test_myscale_vector.py @@ -0,0 +1,230 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_clickhouse_connect_module(): + clickhouse_connect = types.ModuleType("clickhouse_connect") + + class QueryResult: + def __init__(self, rows=None, named_rows=None): + self.row_count = len(rows or []) + self.result_rows = rows or [] + self._named_rows = named_rows or [] + + def named_results(self): + return self._named_rows + + class Client: + def __init__(self): + self.command = MagicMock() + self.query = MagicMock(return_value=QueryResult()) + + client = Client() + + def get_client(**_kwargs): + return client + + clickhouse_connect.get_client = get_client + clickhouse_connect.QueryResult = QueryResult + clickhouse_connect._fake_client = client + return clickhouse_connect + + +@pytest.fixture +def myscale_module(monkeypatch): + fake_module = _build_fake_clickhouse_connect_module() + monkeypatch.setitem(sys.modules, "clickhouse_connect", fake_module) + + import core.rag.datasource.vdb.myscale.myscale_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.MyScaleConfig( + host="localhost", + port=8123, + user="default", + password="", + database="dify", + fts_params="", + ) + + +def test_escape_str_replaces_backslash_and_quote(myscale_module): + escaped = myscale_module.MyScaleVector.escape_str(r"text\with'special") + assert escaped == "text with special" + + +def test_search_raises_for_invalid_top_k(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector._search("distance(vector, [0.1, 0.2])", myscale_module.SortOrder.ASC, top_k=0) + + +def test_search_builds_where_clause_for_cosine_threshold(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + vector._client.query.return_value = myscale_module.get_client().query.return_value.__class__( + named_rows=[{"text": "doc-1", "vector": [0.1, 0.2], "metadata": {"doc_id": "seg-1"}}] + ) + + docs = vector._search("distance(vector, [0.1, 0.2])", myscale_module.SortOrder.ASC, top_k=1, score_threshold=0.2) + + assert len(docs) == 1 + sql = vector._client.query.call_args.args[0] + assert "WHERE dist < 0.8" in sql + + +def test_delete_by_ids_short_circuits_on_empty_list(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + vector._client.command.reset_mock() + + vector.delete_by_ids([]) + vector._client.command.assert_not_called() + + +def test_factory_initializes_lower_case_collection_name(myscale_module, monkeypatch): + factory = myscale_module.MyScaleVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(myscale_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_HOST", "localhost") + monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_PORT", 8123) + monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_USER", "default") + monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_PASSWORD", "") + monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_DATABASE", "dify") + monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_FTS_PARAMS", "") + + with patch.object(myscale_module, "MyScaleVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None + + +def test_init_and_get_type_set_expected_defaults(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + + assert vector.get_type() == "myscale" + assert vector._vec_order == myscale_module.SortOrder.ASC + vector._client.command.assert_called_with("SET allow_experimental_object_type=1") + + +def test_create_calls_create_collection_and_add_texts(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + vector._create_collection = MagicMock() + vector.add_texts = MagicMock(return_value=["seg-1"]) + docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})] + + result = vector.create(docs, [[0.1, 0.2]]) + + assert result == ["seg-1"] + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once() + + +def test_create_collection_builds_expected_sql(myscale_module): + config = myscale_module.MyScaleConfig( + host="localhost", + port=8123, + user="default", + password="", + database="dify", + fts_params="tokenizer=unicode", + ) + vector = myscale_module.MyScaleVector("collection_1", config) + vector._client.command.reset_mock() + + vector._create_collection(3) + + assert vector._client.command.call_count == 2 + sql = vector._client.command.call_args_list[1].args[0] + assert "CREATE TABLE IF NOT EXISTS dify.collection_1" in sql + assert "CONSTRAINT cons_vec_len CHECK length(vector) = 3" in sql + assert "INDEX text_idx text TYPE fts('tokenizer=unicode')" in sql + + +def test_add_texts_inserts_rows_and_returns_ids(myscale_module, monkeypatch): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + monkeypatch.setattr(myscale_module.uuid, "uuid4", lambda: "generated-uuid") + docs = [ + Document(page_content=r"te'xt\1", metadata={"doc_id": "doc-a", "document_id": "d-1"}), + Document(page_content="text-2", metadata={"document_id": "d-2"}), + SimpleNamespace(page_content="text-3", metadata=None), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]]) + + assert ids == ["doc-a", "generated-uuid"] + sql = vector._client.command.call_args.args[0] + assert "INSERT INTO dify.collection_1" in sql + assert "te xt 1" in sql + + +def test_text_exists_and_metadata_operations(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + vector._client.query.return_value = SimpleNamespace(row_count=1, result_rows=[("id-1",), ("id-2",)]) + + assert vector.text_exists("id-1") is True + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"] + + vector.delete_by_ids(["id-1", "id-2"]) + vector.delete_by_metadata_field("document_id", "doc-1") + assert vector._client.command.call_count >= 2 + + +def test_search_delegation_methods(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + vector._search = MagicMock(return_value=["result"]) + + result_vector = vector.search_by_vector([0.1, 0.2], top_k=2) + result_text = vector.search_by_full_text("hello", top_k=2) + + assert result_vector == ["result"] + assert result_text == ["result"] + assert vector._search.call_count == 2 + + +def test_search_with_document_filter_and_exception(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + vector._client.query.return_value = SimpleNamespace( + named_results=lambda: [{"text": "doc", "vector": [0.1], "metadata": {"doc_id": "1"}}] + ) + + docs = vector._search( + "distance(vector, [0.1])", + myscale_module.SortOrder.ASC, + top_k=2, + document_ids_filter=["doc-1", "doc-2"], + ) + assert len(docs) == 1 + sql = vector._client.query.call_args.args[0] + assert "metadata['document_id'] in ('doc-1', 'doc-2')" in sql + + vector._client.query.side_effect = RuntimeError("boom") + assert vector._search("distance(vector, [0.1])", myscale_module.SortOrder.ASC, top_k=1) == [] + + +def test_delete_drops_table(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + vector._client.command.reset_mock() + + vector.delete() + + vector._client.command.assert_called_once_with("DROP TABLE IF EXISTS dify.collection_1") diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/oceanbase/test_oceanbase_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/oceanbase/test_oceanbase_vector.py new file mode 100644 index 00000000000..27d8198ec02 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/oceanbase/test_oceanbase_vector.py @@ -0,0 +1,553 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError +from sqlalchemy.exc import SQLAlchemyError + +from core.rag.models.document import Document + + +def _build_fake_pyobvector_module(): + pyobvector = types.ModuleType("pyobvector") + + class VECTOR: + def __init__(self, dim): + self.dim = dim + + def l2_distance(*_args, **_kwargs): + return "l2" + + def cosine_distance(*_args, **_kwargs): + return "cosine" + + def inner_product(*_args, **_kwargs): + return "inner_product" + + class ObVecClient: + def __init__(self, **_kwargs): + self.metadata_obj = SimpleNamespace(tables={}) + self.engine = MagicMock() + self.check_table_exists = MagicMock(return_value=False) + self.perform_raw_text_sql = MagicMock() + self.prepare_index_params = MagicMock() + self.create_table_with_index_params = MagicMock() + self.refresh_metadata = MagicMock() + self.insert = MagicMock() + self.refresh_index = MagicMock() + self.get = MagicMock() + self.delete = MagicMock() + self.set_ob_hnsw_ef_search = MagicMock() + self.ann_search = MagicMock(return_value=[]) + self.drop_table_if_exist = MagicMock() + + pyobvector.VECTOR = VECTOR + pyobvector.ObVecClient = ObVecClient + pyobvector.l2_distance = l2_distance + pyobvector.cosine_distance = cosine_distance + pyobvector.inner_product = inner_product + return pyobvector + + +@pytest.fixture +def oceanbase_module(monkeypatch): + monkeypatch.setitem(sys.modules, "pyobvector", _build_fake_pyobvector_module()) + + import core.rag.datasource.vdb.oceanbase.oceanbase_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.OceanBaseVectorConfig( + host="127.0.0.1", + port=2881, + user="root", + password="secret", + database="test", + enable_hybrid_search=True, + batch_size=10, + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config OCEANBASE_VECTOR_HOST is required"), + ("port", 0, "config OCEANBASE_VECTOR_PORT is required"), + ("user", "", "config OCEANBASE_VECTOR_USER is required"), + ("database", "", "config OCEANBASE_VECTOR_DATABASE is required"), + ], +) +def test_oceanbase_config_validation(oceanbase_module, field, value, message): + values = _config(oceanbase_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + oceanbase_module.OceanBaseVectorConfig.model_validate(values) + + +def test_init_rejects_invalid_collection_name(oceanbase_module): + with pytest.raises(ValueError, match="Invalid collection name"): + oceanbase_module.OceanBaseVector("invalid-name", _config(oceanbase_module)) + + +def test_distance_to_score_for_supported_metrics(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._config = SimpleNamespace(metric_type="l2") + assert vector._distance_to_score(3.0) == pytest.approx(0.25) + + vector._config = SimpleNamespace(metric_type="cosine") + assert vector._distance_to_score(0.2) == pytest.approx(0.8) + + vector._config = SimpleNamespace(metric_type="inner_product") + assert vector._distance_to_score(-0.2) == pytest.approx(0.2) + + +def test_get_distance_func_raises_for_unknown_metric(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._config = SimpleNamespace(metric_type="manhattan") + + with pytest.raises(ValueError, match="Unsupported metric_type"): + vector._get_distance_func() + + +def test_process_search_results_handles_json_and_score_threshold(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + rows = [ + ("doc-1", '{"doc_id":"1"}', 0.9), + ("doc-2", "not-json", 0.8), + ("doc-3", {"doc_id": "3"}, 0.3), + ] + + docs = vector._process_search_results(rows, score_threshold=0.5, score_key="rank") + + assert len(docs) == 2 + assert docs[0].metadata["doc_id"] == "1" + assert docs[0].metadata["rank"] == 0.9 + assert docs[1].metadata["rank"] == 0.8 + + +def test_search_by_vector_validates_document_id_format(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._hnsw_ef_search = -1 + vector._config = SimpleNamespace(metric_type="cosine") + vector._client = MagicMock() + + with pytest.raises(ValueError, match="Invalid document ID format"): + vector.search_by_vector([0.1, 0.2], document_ids_filter=["bad id"]) + + +def test_search_by_full_text_returns_empty_when_disabled(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._hybrid_search_enabled = False + vector._collection_name = "collection_1" + + assert vector.search_by_full_text("query") == [] + + +def test_check_hybrid_search_support_uses_version_comment(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._config = SimpleNamespace(enable_hybrid_search=True) + vector._client = MagicMock() + cursor = MagicMock() + cursor.fetchone.return_value = ("OceanBase_CE 4.3.5.1 (rxxxxxxxxx) (Built Mar 18 2025)",) + vector._client.perform_raw_text_sql.return_value = cursor + + assert vector._check_hybrid_search_support() is True + + cursor.fetchone.return_value = ("OceanBase_CE 4.3.4.0 (rxxxxxxxxx) (Built Mar 18 2025)",) + assert vector._check_hybrid_search_support() is False + + +def test_init_get_type_and_field_loading(oceanbase_module): + config = _config(oceanbase_module) + config.enable_hybrid_search = False + + table = SimpleNamespace(columns=[SimpleNamespace(name="id"), SimpleNamespace(name="text")]) + fake_client = oceanbase_module.ObVecClient() + fake_client.check_table_exists.return_value = True + fake_client.metadata_obj.tables = {"collection_1": table} + + with patch.object(oceanbase_module, "ObVecClient", return_value=fake_client): + vector = oceanbase_module.OceanBaseVector("collection_1", config) + + assert vector.get_type() == "oceanbase" + assert vector.field_exists("text") is True + + +def test_load_collection_fields_handles_missing_table_and_exception(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._fields = [] + vector._client = MagicMock() + vector._client.metadata_obj.tables = {} + + vector._load_collection_fields() + assert vector._fields == [] + + vector._client.metadata_obj.tables = {"collection_1": MagicMock(columns=MagicMock(side_effect=RuntimeError("x")))} + vector._load_collection_fields() + assert vector._fields == [] + + +def test_create_delegates_to_collection_and_insert(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="text", metadata={"doc_id": "1"})] + + vector.create(docs, [[0.1, 0.2]]) + + assert vector._vec_dim == 2 + vector._create_collection.assert_called_once() + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_create_collection_cache_and_existing_table_short_circuits(oceanbase_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(oceanbase_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(oceanbase_module.redis_client, "set", MagicMock()) + + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._vec_dim = 2 + vector._hybrid_search_enabled = False + vector._config = SimpleNamespace(metric_type="cosine", hnsw_m=16, hnsw_ef_construction=64) + vector._client = MagicMock() + vector.delete = MagicMock() + vector._load_collection_fields = MagicMock() + + monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection() + vector._client.check_table_exists.assert_not_called() + + monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.check_table_exists.return_value = True + vector._create_collection() + vector.delete.assert_not_called() + + +def test_create_collection_happy_path_with_hybrid_and_index(oceanbase_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(oceanbase_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(oceanbase_module.redis_client, "set", MagicMock()) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_FULLTEXT_PARSER", "ik") + monkeypatch.setattr(oceanbase_module, "Column", lambda *args, **kwargs: SimpleNamespace(args=args, kwargs=kwargs)) + monkeypatch.setattr(oceanbase_module, "VECTOR", lambda dim: SimpleNamespace(dim=dim)) + + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._vec_dim = 3 + vector._hybrid_search_enabled = True + vector._config = SimpleNamespace(metric_type="cosine", hnsw_m=16, hnsw_ef_construction=64) + vector._client = MagicMock() + vector._client.check_table_exists.return_value = False + vector._client.perform_raw_text_sql.side_effect = [ + [[None, None, None, None, None, None, "30"]], + None, + None, + ] + index_params = MagicMock() + vector._client.prepare_index_params.return_value = index_params + vector.delete = MagicMock() + vector._load_collection_fields = MagicMock() + + vector._create_collection() + + vector.delete.assert_called_once() + vector._client.create_table_with_index_params.assert_called_once() + index_params.add_index.assert_called_once() + vector._client.refresh_metadata.assert_called_once_with(["collection_1"]) + oceanbase_module.redis_client.set.assert_called_once() + + +def test_create_collection_error_paths(oceanbase_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(oceanbase_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(oceanbase_module, "Column", lambda *args, **kwargs: SimpleNamespace(args=args, kwargs=kwargs)) + monkeypatch.setattr(oceanbase_module, "VECTOR", lambda dim: SimpleNamespace(dim=dim)) + + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._vec_dim = 2 + vector._hybrid_search_enabled = True + vector._config = SimpleNamespace(metric_type="cosine", hnsw_m=16, hnsw_ef_construction=64) + vector._client = MagicMock() + vector._client.check_table_exists.return_value = False + vector._client.prepare_index_params.return_value = MagicMock() + vector.delete = MagicMock() + vector._load_collection_fields = MagicMock() + + vector._client.perform_raw_text_sql.return_value = [] + with pytest.raises(ValueError, match="ob_vector_memory_limit_percentage not found"): + vector._create_collection() + + vector._client.perform_raw_text_sql.side_effect = [ + [[None, None, None, None, None, None, "0"]], + RuntimeError("no privilege"), + ] + with pytest.raises(Exception, match="Failed to set ob_vector_memory_limit_percentage"): + vector._create_collection() + + vector._client.perform_raw_text_sql.side_effect = [[[None, None, None, None, None, None, "30"]]] + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_FULLTEXT_PARSER", "not-valid") + with pytest.raises(ValueError, match="Invalid OceanBase full-text parser"): + vector._create_collection() + + +def test_create_collection_fulltext_and_metadata_index_exceptions(oceanbase_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(oceanbase_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(oceanbase_module.redis_client, "set", MagicMock()) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_FULLTEXT_PARSER", "ik") + monkeypatch.setattr(oceanbase_module, "Column", lambda *args, **kwargs: SimpleNamespace(args=args, kwargs=kwargs)) + monkeypatch.setattr(oceanbase_module, "VECTOR", lambda dim: SimpleNamespace(dim=dim)) + + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._vec_dim = 2 + vector._hybrid_search_enabled = True + vector._config = SimpleNamespace(metric_type="cosine", hnsw_m=16, hnsw_ef_construction=64) + vector._client = MagicMock() + vector._client.check_table_exists.return_value = False + vector._client.prepare_index_params.return_value = MagicMock() + vector.delete = MagicMock() + vector._load_collection_fields = MagicMock() + + vector._client.perform_raw_text_sql.side_effect = [ + [[None, None, None, None, None, None, "30"]], + RuntimeError("fulltext failed"), + ] + with pytest.raises(Exception, match="Failed to add fulltext index"): + vector._create_collection() + + vector._hybrid_search_enabled = False + vector._client.perform_raw_text_sql.side_effect = [ + [[None, None, None, None, None, None, "30"]], + SQLAlchemyError("metadata index failed"), + ] + vector._create_collection() + vector._client.refresh_metadata.assert_called_once_with(["collection_1"]) + + +def test_check_hybrid_search_support_false_and_exception(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._config = SimpleNamespace(enable_hybrid_search=False) + vector._client = MagicMock() + assert vector._check_hybrid_search_support() is False + + vector._config = SimpleNamespace(enable_hybrid_search=True) + vector._client.perform_raw_text_sql.side_effect = RuntimeError("boom") + assert vector._check_hybrid_search_support() is False + + +def test_add_texts_batches_refresh_and_exceptions(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._config = SimpleNamespace(batch_size=2, hnsw_refresh_threshold=2) + vector._client = MagicMock() + vector._get_uuids = MagicMock(return_value=["id-1", "id-2", "id-3"]) + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + Document(page_content="b", metadata={"doc_id": "id-2"}), + Document(page_content="c", metadata={"doc_id": "id-3"}), + ] + + vector.add_texts(docs, [[0.1], [0.2], [0.3]]) + assert vector._client.insert.call_count == 2 + vector._client.refresh_index.assert_called_once() + + vector._client.insert.reset_mock() + vector._client.refresh_index.reset_mock() + vector._client.insert.side_effect = RuntimeError("insert failed") + with pytest.raises(Exception, match="Failed to insert batch"): + vector.add_texts([docs[0]], [[0.1]]) + + vector._client.insert.side_effect = None + vector._client.insert.return_value = None + vector._client.refresh_index.side_effect = SQLAlchemyError("refresh failed") + vector._config = SimpleNamespace(batch_size=10, hnsw_refresh_threshold=1) + vector._get_uuids.return_value = ["id-1"] + vector.add_texts([docs[0]], [[0.1]]) + + +def test_text_exists_and_delete_by_ids(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + vector._client.get.return_value = SimpleNamespace(rowcount=1) + assert vector.text_exists("id-1") is True + + vector._client.get.side_effect = RuntimeError("boom") + with pytest.raises(Exception, match="Failed to check text existence"): + vector.text_exists("id-1") + + vector.delete_by_ids([]) + vector._client.delete.assert_not_called() + + vector._client.delete.side_effect = None + vector.delete_by_ids(["id-1"]) + vector._client.delete.assert_called_once() + + vector._client.delete.side_effect = RuntimeError("boom") + with pytest.raises(Exception, match="Failed to delete documents"): + vector.delete_by_ids(["id-1"]) + + +def test_get_ids_and_delete_by_metadata_field(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + execute_result = [("id-1",), ("id-2",)] + + conn = MagicMock() + conn.__enter__.return_value = conn + conn.__exit__.return_value = None + conn.execute.return_value = execute_result + vector._client.engine.connect.return_value = conn + + ids = vector.get_ids_by_metadata_field("document_id", "doc-1") + assert ids == ["id-1", "id-2"] + + with pytest.raises(Exception, match="Failed to query documents by metadata field"): + vector.get_ids_by_metadata_field("bad key!", "doc-1") + + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"]) + vector.delete_by_ids = MagicMock() + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete_by_ids.assert_called_once_with(["id-1"]) + + vector.get_ids_by_metadata_field = MagicMock(return_value=[]) + vector.delete_by_ids.reset_mock() + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete_by_ids.assert_not_called() + + +def test_search_by_full_text_paths(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._hybrid_search_enabled = True + vector.field_exists = MagicMock(return_value=False) + + assert vector.search_by_full_text("query") == [] + + vector.field_exists.return_value = True + vector._client = MagicMock() + conn = MagicMock() + tx = MagicMock() + tx.__enter__.return_value = tx + tx.__exit__.return_value = None + conn.begin.return_value = tx + conn.__enter__.return_value = conn + conn.__exit__.return_value = None + conn.execute.return_value.fetchall.return_value = [("text-1", '{"doc_id":"1"}', 0.9)] + vector._client.engine.connect.return_value = conn + + docs = vector.search_by_full_text("query", top_k=2, document_ids_filter=["d-1"], score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].metadata["score"] == 0.9 + + with pytest.raises(Exception, match="Full-text search failed"): + vector.search_by_full_text("query", top_k=0) + + +def test_search_by_vector_paths(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._hnsw_ef_search = -1 + vector._config = SimpleNamespace(metric_type="cosine") + vector._client = MagicMock() + vector._client.ann_search.return_value = [("doc-1", '{"doc_id":"1"}', 0.2)] + vector._process_search_results = MagicMock(return_value=["doc"]) + + docs = vector.search_by_vector( + [0.1, 0.2], + ef_search=10, + top_k=3, + score_threshold=0.1, + document_ids_filter=["good_id"], + ) + assert docs == ["doc"] + vector._client.set_ob_hnsw_ef_search.assert_called_once_with(10) + + with pytest.raises(ValueError, match="Invalid score_threshold parameter"): + vector.search_by_vector([0.1], score_threshold="x") + + vector._client.ann_search.side_effect = RuntimeError("boom") + with pytest.raises(Exception, match="Vector search failed"): + vector.search_by_vector([0.1], score_threshold=0.1) + + +def test_get_distance_func_and_distance_to_score_errors(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._config = SimpleNamespace(metric_type="cosine") + assert vector._get_distance_func() is oceanbase_module.cosine_distance + + vector._config = SimpleNamespace(metric_type="unknown") + with pytest.raises(ValueError, match="Unsupported metric_type"): + vector._distance_to_score(0.1) + + +def test_delete_success_and_exception(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + + vector.delete() + vector._client.drop_table_if_exist.assert_called_once_with("collection_1") + + vector._client.drop_table_if_exist.side_effect = RuntimeError("boom") + with pytest.raises(Exception, match="Failed to delete collection"): + vector.delete() + + +def test_oceanbase_factory_uses_existing_or_generated_collection(oceanbase_module, monkeypatch): + factory = oceanbase_module.OceanBaseVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(oceanbase_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_HOST", "127.0.0.1") + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_PORT", 2881) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_USER", "root") + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_PASSWORD", "password") + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_DATABASE", "test") + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_ENABLE_HYBRID_SEARCH", True) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_BATCH_SIZE", 10) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_METRIC_TYPE", "cosine") + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_HNSW_M", 16) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_HNSW_EF_CONSTRUCTION", 64) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_HNSW_EF_SEARCH", -1) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_POOL_SIZE", 5) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_MAX_OVERFLOW", 10) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_HNSW_REFRESH_THRESHOLD", 1000) + + with patch.object(oceanbase_module, "OceanBaseVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].args[0] == "existing_collection" + assert vector_cls.call_args_list[1].args[0] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/opengauss/test_opengauss.py b/api/tests/unit_tests/core/rag/datasource/vdb/opengauss/test_opengauss.py new file mode 100644 index 00000000000..6641dbe4a09 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/opengauss/test_opengauss.py @@ -0,0 +1,400 @@ +import importlib +import sys +import types +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_psycopg2_modules(): + psycopg2 = types.ModuleType("psycopg2") + psycopg2.__path__ = [] + psycopg2_extras = types.ModuleType("psycopg2.extras") + psycopg2_pool = types.ModuleType("psycopg2.pool") + + class SimpleConnectionPool: + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + self.getconn = MagicMock() + self.putconn = MagicMock() + + psycopg2_pool.SimpleConnectionPool = SimpleConnectionPool + psycopg2_extras.execute_values = MagicMock() + + psycopg2.pool = psycopg2_pool + psycopg2.extras = psycopg2_extras + return { + "psycopg2": psycopg2, + "psycopg2.pool": psycopg2_pool, + "psycopg2.extras": psycopg2_extras, + } + + +@pytest.fixture +def opengauss_module(monkeypatch): + for name, module in _build_fake_psycopg2_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.opengauss.opengauss as module + + return importlib.reload(module) + + +def _config(module, *, enable_pq=False): + return module.OpenGaussConfig( + host="localhost", + port=6600, + user="postgres", + password="password", + database="dify", + min_connection=1, + max_connection=5, + enable_pq=enable_pq, + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config OPENGAUSS_HOST is required"), + ("port", 0, "config OPENGAUSS_PORT is required"), + ("user", "", "config OPENGAUSS_USER is required"), + ("password", "", "config OPENGAUSS_PASSWORD is required"), + ("database", "", "config OPENGAUSS_DATABASE is required"), + ("min_connection", 0, "config OPENGAUSS_MIN_CONNECTION is required"), + ("max_connection", 0, "config OPENGAUSS_MAX_CONNECTION is required"), + ], +) +def test_opengauss_config_validation(opengauss_module, field, value, message): + values = _config(opengauss_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + opengauss_module.OpenGaussConfig.model_validate(values) + + +def test_opengauss_config_validation_rejects_min_greater_than_max(opengauss_module): + values = _config(opengauss_module).model_dump() + values["min_connection"] = 6 + values["max_connection"] = 5 + + with pytest.raises(ValidationError, match="OPENGAUSS_MIN_CONNECTION should less than OPENGAUSS_MAX_CONNECTION"): + opengauss_module.OpenGaussConfig.model_validate(values) + + +def test_init_sets_table_name_and_vector_type(opengauss_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + + vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module)) + + assert vector.table_name == "embedding_collection_1" + assert vector.get_type() == "opengauss" + assert vector.pool is pool + + +def test_create_index_with_pq_executes_pq_sql(opengauss_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock()) + + vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module, enable_pq=True)) + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + vector._create_index(1536) + + executed_sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("enable_pq=on" in sql for sql in executed_sql) + assert any("SET hnsw_earlystop_threshold = 320" in sql for sql in executed_sql) + opengauss_module.redis_client.set.assert_called_once() + + +def test_create_index_skips_index_sql_for_large_dimension(opengauss_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock()) + + vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module, enable_pq=False)) + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + vector._create_index(3072) + + cursor.execute.assert_not_called() + opengauss_module.redis_client.set.assert_called_once() + + +def test_search_by_vector_validates_top_k(opengauss_module): + vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss) + + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_vector([0.1, 0.2], top_k=0) + + +def test_delete_by_ids_short_circuits_with_empty_input(opengauss_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module)) + vector._get_cursor = MagicMock() + + vector.delete_by_ids([]) + + vector._get_cursor.assert_not_called() + + +def test_get_cursor_closes_commits_and_returns_connection(opengauss_module): + vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss) + pool = MagicMock() + conn = MagicMock() + cur = MagicMock() + pool.getconn.return_value = conn + conn.cursor.return_value = cur + vector.pool = pool + + with vector._get_cursor() as got_cur: + assert got_cur is cur + + cur.close.assert_called_once() + conn.commit.assert_called_once() + pool.putconn.assert_called_once_with(conn) + + +def test_create_calls_collection_insert_and_index(opengauss_module): + vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss) + vector._create_collection = MagicMock() + vector.add_texts = MagicMock() + vector._create_index = MagicMock() + docs = [Document(page_content="text", metadata={"doc_id": "seg-1"})] + + vector.create(docs, [[0.1, 0.2]]) + + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + vector._create_index.assert_called_once_with(2) + + +def test_create_index_returns_early_on_cache_hit(opengauss_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock()) + + vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module)) + vector._get_cursor = MagicMock() + + vector._create_index(1536) + + vector._get_cursor.assert_not_called() + opengauss_module.redis_client.set.assert_not_called() + + +def test_create_index_without_pq_executes_standard_index_sql(opengauss_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock()) + + vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module, enable_pq=False)) + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + vector._create_index(1536) + + sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("embedding_cosine_embedding_collection_1_idx" in query for query in sql) + + +def test_add_texts_uses_execute_values(opengauss_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module)) + cursor = MagicMock() + opengauss_module.psycopg2.extras.execute_values.reset_mock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + docs = [ + Document(page_content="text-1", metadata={"doc_id": "seg-1", "document_id": "d-1"}), + SimpleNamespace(page_content="text-2", metadata=None), + ] + monkeypatch.setattr(opengauss_module.uuid, "uuid4", lambda: "generated-uuid") + + ids = vector.add_texts(docs, [[0.1], [0.2]]) + + assert ids == ["seg-1"] + opengauss_module.psycopg2.extras.execute_values.assert_called_once() + + +def test_text_exists_and_get_by_ids(opengauss_module): + vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + cursor.fetchone.return_value = ("seg-1",) + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1"), ({"doc_id": "2"}, "text-2")]) + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + assert vector.text_exists("seg-1") is True + docs = vector.get_by_ids(["seg-1", "seg-2"]) + assert len(docs) == 2 + assert docs[0].page_content == "text-1" + + +def test_delete_and_metadata_field_queries(opengauss_module): + vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + vector.delete_by_ids(["seg-1", "seg-2"]) + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete() + + sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("DELETE FROM embedding_collection_1 WHERE id IN %s" in query for query in sql) + assert any("meta->>%s = %s" in query for query in sql) + assert any("DROP TABLE IF EXISTS embedding_collection_1" in query for query in sql) + + +def test_search_by_vector_and_full_text(opengauss_module): + vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + cursor.__iter__.return_value = iter( + [ + ({"doc_id": "1"}, "text-1", 0.1), + ({"doc_id": "2"}, "text-2", 0.6), + ] + ) + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + + cursor.__iter__.return_value = iter([({"doc_id": "3"}, "full-text", 0.8)]) + full_docs = vector.search_by_full_text("hello world", top_k=2) + assert len(full_docs) == 1 + assert full_docs[0].page_content == "full-text" + + +def test_search_by_full_text_validates_top_k(opengauss_module): + vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss) + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_full_text("query", top_k=0) + + +def test_create_collection_cache_and_create_path(opengauss_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock()) + + vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module)) + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection(1536) + cursor.execute.assert_not_called() + + monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=None)) + vector._create_collection(1536) + cursor.execute.assert_called_once() + opengauss_module.redis_client.set.assert_called_once() + + +def test_opengauss_factory_uses_existing_or_generated_collection(opengauss_module, monkeypatch): + factory = opengauss_module.OpenGaussFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(opengauss_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_HOST", "localhost") + monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_PORT", 6600) + monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_USER", "postgres") + monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_PASSWORD", "password") + monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_DATABASE", "dify") + monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_MIN_CONNECTION", 1) + monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_MAX_CONNECTION", 5) + monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_ENABLE_PQ", False) + + with patch.object(opengauss_module, "OpenGauss", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/opensearch/test_opensearch_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/opensearch/test_opensearch_vector.py new file mode 100644 index 00000000000..1030158dd1a --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/opensearch/test_opensearch_vector.py @@ -0,0 +1,360 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_opensearch_modules(): + opensearchpy = types.ModuleType("opensearchpy") + opensearchpy_helpers = types.ModuleType("opensearchpy.helpers") + + class BulkIndexError(Exception): + def __init__(self, errors): + super().__init__("bulk error") + self.errors = errors + + class Urllib3AWSV4SignerAuth: + def __init__(self, credentials, region, service): + self.credentials = credentials + self.region = region + self.service = service + + class Urllib3HttpConnection: + pass + + class _IndicesClient: + def __init__(self): + self.exists = MagicMock(return_value=False) + self.create = MagicMock() + self.delete = MagicMock() + + class OpenSearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.indices = _IndicesClient() + self.search = MagicMock(return_value={"hits": {"hits": []}}) + self.get = MagicMock() + + helpers = SimpleNamespace(bulk=MagicMock()) + + opensearchpy.OpenSearch = OpenSearch + opensearchpy.Urllib3AWSV4SignerAuth = Urllib3AWSV4SignerAuth + opensearchpy.Urllib3HttpConnection = Urllib3HttpConnection + opensearchpy.helpers = helpers + opensearchpy_helpers.BulkIndexError = BulkIndexError + + return { + "opensearchpy": opensearchpy, + "opensearchpy.helpers": opensearchpy_helpers, + } + + +@pytest.fixture +def opensearch_module(monkeypatch): + for name, module in _build_fake_opensearch_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.opensearch.opensearch_vector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "host": "localhost", + "port": 9200, + "secure": True, + "verify_certs": True, + "auth_method": "basic", + "user": "admin", + "password": "secret", + } + values.update(overrides) + return module.OpenSearchConfig.model_validate(values) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config OPENSEARCH_HOST is required"), + ("port", 0, "config OPENSEARCH_PORT is required"), + ], +) +def test_config_validation_required_fields(opensearch_module, field, value, message): + values = _config(opensearch_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + opensearch_module.OpenSearchConfig.model_validate(values) + + +def test_config_validation_for_aws_auth_and_https_fields(opensearch_module): + values = { + "host": "localhost", + "port": 9200, + "secure": True, + "verify_certs": True, + "auth_method": "aws_managed_iam", + "user": "admin", + "password": "secret", + } + with pytest.raises(ValidationError, match="OPENSEARCH_AWS_REGION"): + opensearch_module.OpenSearchConfig.model_validate(values) + + values = _config(opensearch_module).model_dump() + values["OPENSEARCH_SECURE"] = False + values["OPENSEARCH_VERIFY_CERTS"] = True + with pytest.raises(ValidationError, match="verify_certs=True requires secure"): + opensearch_module.OpenSearchConfig.model_validate(values) + + +def test_create_aws_managed_iam_auth(opensearch_module, monkeypatch): + class _Session: + def get_credentials(self): + return "creds" + + boto3 = types.ModuleType("boto3") + boto3.Session = _Session + monkeypatch.setitem(sys.modules, "boto3", boto3) + + config = _config( + opensearch_module, + auth_method="aws_managed_iam", + aws_region="us-east-1", + aws_service="es", + ) + auth = config.create_aws_managed_iam_auth() + + assert auth.credentials == "creds" + assert auth.region == "us-east-1" + assert auth.service == "es" + + +def test_to_opensearch_params_supports_basic_and_aws(opensearch_module): + basic_params = _config(opensearch_module).to_opensearch_params() + assert basic_params["http_auth"] == ("admin", "secret") + + aws_config = _config( + opensearch_module, + auth_method="aws_managed_iam", + aws_region="us-west-2", + aws_service="es", + ) + with patch.object(opensearch_module.OpenSearchConfig, "create_aws_managed_iam_auth", return_value="iam-auth"): + aws_params = aws_config.to_opensearch_params() + + assert aws_params["http_auth"] == "iam-auth" + + +def test_init_and_create_delegate_calls(opensearch_module): + vector = opensearch_module.OpenSearchVector("Collection_1", _config(opensearch_module)) + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})] + + vector.create(docs, [[0.1, 0.2]]) + + assert vector.get_type() == "opensearch" + vector.create_collection.assert_called_once_with([[0.1, 0.2]], [{"doc_id": "seg-1"}]) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_add_texts_supports_regular_and_aoss_clients(opensearch_module, monkeypatch): + vector = opensearch_module.OpenSearchVector("Collection_1", _config(opensearch_module, aws_service="es")) + docs = [ + Document(page_content="a", metadata={"doc_id": "1"}), + Document(page_content="b", metadata={"doc_id": "2"}), + ] + + monkeypatch.setattr(opensearch_module, "uuid4", lambda: SimpleNamespace(hex="generated-id")) + opensearch_module.helpers.bulk.reset_mock() + vector.add_texts(docs, [[0.1], [0.2]]) + actions = opensearch_module.helpers.bulk.call_args.kwargs["actions"] + assert len(actions) == 2 + assert all("_id" in action for action in actions) + + vector._client_config.aws_service = "aoss" + opensearch_module.helpers.bulk.reset_mock() + vector.add_texts(docs, [[0.3], [0.4]]) + aoss_actions = opensearch_module.helpers.bulk.call_args.kwargs["actions"] + assert all("_id" not in action for action in aoss_actions) + + +def test_metadata_lookup_and_delete_by_metadata_field(opensearch_module): + vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module)) + vector._client.search.return_value = {"hits": {"hits": [{"_id": "id-1"}, {"_id": "id-2"}]}} + + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"] + + vector._client.search.return_value = {"hits": {"hits": []}} + assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None + + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"]) + vector.delete_by_ids = MagicMock() + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete_by_ids.assert_called_once_with(["id-1"]) + + +def test_delete_by_ids_branches_and_bulk_error_handling(opensearch_module): + vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module)) + opensearch_module.helpers.bulk.reset_mock() + vector._client.indices.exists.return_value = False + vector.delete_by_ids(["doc-1"]) + opensearch_module.helpers.bulk.assert_not_called() + + vector._client.indices.exists.return_value = True + vector.get_ids_by_metadata_field = MagicMock(side_effect=[["es-1"], None]) + vector.delete_by_ids(["doc-1", "doc-2"]) + opensearch_module.helpers.bulk.assert_called_once() + + opensearch_module.helpers.bulk.reset_mock() + vector.get_ids_by_metadata_field = MagicMock(return_value=["es-404"]) + opensearch_module.helpers.bulk.side_effect = opensearch_module.BulkIndexError( + [{"delete": {"status": 404, "_id": "es-404"}}] + ) + vector.delete_by_ids(["doc-404"]) + assert opensearch_module.helpers.bulk.call_count == 1 + + opensearch_module.helpers.bulk.side_effect = None + + +def test_delete_and_text_exists(opensearch_module): + vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module)) + vector.delete() + vector._client.indices.delete.assert_called_once_with(index="collection_1", ignore_unavailable=True) + + vector._client.get.return_value = {"_id": "id-1"} + assert vector.text_exists("id-1") is True + vector._client.get.side_effect = RuntimeError("not found") + assert vector.text_exists("id-1") is False + + +def test_search_by_vector_validates_and_builds_documents(opensearch_module): + vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module)) + + with pytest.raises(ValueError, match="query_vector should be a list"): + vector.search_by_vector("not-a-list") + + with pytest.raises(ValueError, match="should be floats"): + vector.search_by_vector([0.1, 1]) + + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_source": { + opensearch_module.Field.CONTENT_KEY: "doc-1", + opensearch_module.Field.METADATA_KEY: None, + }, + "_score": 0.9, + }, + { + "_source": { + opensearch_module.Field.CONTENT_KEY: "doc-2", + opensearch_module.Field.METADATA_KEY: {"doc_id": "2"}, + }, + "_score": 0.1, + }, + ] + } + } + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].page_content == "doc-1" + assert docs[0].metadata["score"] == pytest.approx(0.9) + + vector.search_by_vector([0.1, 0.2], top_k=3, document_ids_filter=["doc-a", "doc-b"]) + query = vector._client.search.call_args.kwargs["body"] + assert "script_score" in query["query"] + + +def test_search_by_vector_reraises_client_error(opensearch_module): + vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module)) + vector._client.search.side_effect = RuntimeError("boom") + + with pytest.raises(RuntimeError, match="boom"): + vector.search_by_vector([0.1, 0.2]) + + +def test_search_by_full_text_and_filters(opensearch_module): + vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module)) + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_source": { + opensearch_module.Field.METADATA_KEY: {"doc_id": "1"}, + opensearch_module.Field.VECTOR: [0.1], + opensearch_module.Field.CONTENT_KEY: "matched text", + } + }, + ] + } + } + + docs = vector.search_by_full_text("hello", document_ids_filter=["d-1"]) + + assert len(docs) == 1 + assert docs[0].page_content == "matched text" + query = vector._client.search.call_args.kwargs["body"] + assert query["query"]["bool"]["filter"] == [{"terms": {"metadata.document_id": ["d-1"]}}] + + +def test_create_collection_cache_and_create_path(opensearch_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(opensearch_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(opensearch_module.redis_client, "set", MagicMock()) + + vector = opensearch_module.OpenSearchVector("Collection_1", _config(opensearch_module)) + + monkeypatch.setattr(opensearch_module.redis_client, "get", MagicMock(return_value=1)) + vector._client.indices.create.reset_mock() + vector.create_collection([[0.1, 0.2]]) + vector._client.indices.create.assert_not_called() + + monkeypatch.setattr(opensearch_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.indices.exists.return_value = False + vector.create_collection([[0.1, 0.2]]) + vector._client.indices.create.assert_called_once() + index_body = vector._client.indices.create.call_args.kwargs["body"] + assert index_body["mappings"]["properties"]["vector"]["dimension"] == 2 + opensearch_module.redis_client.set.assert_called() + + +def test_opensearch_factory_initializes_expected_collection_name(opensearch_module, monkeypatch): + factory = opensearch_module.OpenSearchVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(opensearch_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_HOST", "localhost") + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_PORT", 9200) + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_SECURE", True) + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_VERIFY_CERTS", True) + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_AUTH_METHOD", "basic") + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_USER", "admin") + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_PASSWORD", "secret") + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_AWS_REGION", None) + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_AWS_SERVICE", None) + + with patch.object(opensearch_module, "OpenSearchVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/oracle/test_oraclevector.py b/api/tests/unit_tests/core/rag/datasource/vdb/oracle/test_oraclevector.py new file mode 100644 index 00000000000..817a7d342b3 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/oracle/test_oraclevector.py @@ -0,0 +1,375 @@ +import array +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import numpy +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_oracle_modules(): + jieba = types.ModuleType("jieba") + jieba_posseg = types.ModuleType("jieba.posseg") + jieba_posseg.cut = MagicMock(return_value=[]) + jieba.posseg = jieba_posseg + + oracledb = types.ModuleType("oracledb") + oracledb_connection = types.ModuleType("oracledb.connection") + + class Connection: + pass + + oracledb_connection.Connection = Connection + oracledb.defaults = SimpleNamespace(fetch_lobs=True) + oracledb.DB_TYPE_VECTOR = object() + oracledb.create_pool = MagicMock(return_value=MagicMock(release=MagicMock())) + oracledb.connect = MagicMock() + + return { + "jieba": jieba, + "jieba.posseg": jieba_posseg, + "oracledb": oracledb, + "oracledb.connection": oracledb_connection, + } + + +def _connection_with_cursor(cursor): + cursor_ctx = MagicMock() + cursor_ctx.__enter__.return_value = cursor + cursor_ctx.__exit__.return_value = None + + connection = MagicMock() + connection.__enter__.return_value = connection + connection.__exit__.return_value = None + connection.cursor.return_value = cursor_ctx + return connection + + +@pytest.fixture +def oracle_module(monkeypatch): + for name, module in _build_fake_oracle_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.oracle.oraclevector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "user": "system", + "password": "oracle", + "dsn": "oracle:1521/freepdb1", + "is_autonomous": False, + } + values.update(overrides) + return module.OracleVectorConfig.model_validate(values) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("user", "", "config ORACLE_USER is required"), + ("password", "", "config ORACLE_PASSWORD is required"), + ("dsn", "", "config ORACLE_DSN is required"), + ], +) +def test_oracle_config_validation_required_fields(oracle_module, field, value, message): + values = _config(oracle_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + oracle_module.OracleVectorConfig.model_validate(values) + + +def test_oracle_config_validation_autonomous_requirements(oracle_module): + with pytest.raises(ValidationError, match="config_dir is required"): + oracle_module.OracleVectorConfig.model_validate( + {"user": "u", "password": "p", "dsn": "d", "is_autonomous": True} + ) + + +def test_init_and_get_type(oracle_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(oracle_module.oracledb, "create_pool", MagicMock(return_value=pool)) + vector = oracle_module.OracleVector("collection_1", _config(oracle_module)) + + assert vector.get_type() == "oracle" + assert vector.table_name == "embedding_collection_1" + assert vector.pool is pool + + +def test_numpy_converters_and_type_handlers(oracle_module): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + + in_float64 = vector.numpy_converter_in(numpy.array([0.1], dtype=numpy.float64)) + in_float32 = vector.numpy_converter_in(numpy.array([0.1], dtype=numpy.float32)) + in_int8 = vector.numpy_converter_in(numpy.array([1], dtype=numpy.int8)) + assert in_float64.typecode == "d" + assert in_float32.typecode == "f" + assert in_int8.typecode == "b" + + cursor = MagicMock() + vector.input_type_handler(cursor, numpy.array([0.1], dtype=numpy.float32), 2) + cursor.var.assert_called_with( + oracle_module.oracledb.DB_TYPE_VECTOR, + arraysize=2, + inconverter=vector.numpy_converter_in, + ) + + metadata = SimpleNamespace(type_code=oracle_module.oracledb.DB_TYPE_VECTOR) + cursor.arraysize = 3 + vector.output_type_handler(cursor, metadata) + cursor.var.assert_called_with( + metadata.type_code, + arraysize=3, + outconverter=vector.numpy_converter_out, + ) + + out_int8 = vector.numpy_converter_out(array.array("b", [1])) + assert out_int8.dtype == numpy.int8 + out_float32 = vector.numpy_converter_out(array.array("f", [1.0])) + assert out_float32.dtype == numpy.float32 + out_float64 = vector.numpy_converter_out(array.array("d", [1.0])) + assert out_float64.dtype == numpy.float64 + + +def test_get_connection_supports_standard_and_autonomous_paths(oracle_module, monkeypatch): + connect = MagicMock(return_value="connection") + monkeypatch.setattr(oracle_module.oracledb, "connect", connect) + + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.config = _config(oracle_module) + assert vector._get_connection() == "connection" + connect.assert_called_with(user="system", password="oracle", dsn="oracle:1521/freepdb1") + + vector.config = _config( + oracle_module, + is_autonomous=True, + config_dir="/wallet", + wallet_location="/wallet", + wallet_password="pw", + ) + vector._get_connection() + assert connect.call_args.kwargs["config_dir"] == "/wallet" + assert connect.call_args.kwargs["wallet_location"] == "/wallet" + + +def test_create_delegates_collection_and_insert(oracle_module): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector._create_collection = MagicMock() + vector.add_texts = MagicMock(return_value=["seg-1"]) + docs = [Document(page_content="doc", metadata={"doc_id": "seg-1"})] + + result = vector.create(docs, [[0.1, 0.2]]) + + assert result == ["seg-1"] + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_add_texts_inserts_and_logs_on_failures(oracle_module, monkeypatch): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.table_name = "embedding_collection_1" + vector.input_type_handler = MagicMock() + vector.output_type_handler = MagicMock() + + cursor = MagicMock() + cursor.execute.side_effect = [None, RuntimeError("insert failed")] + connection = _connection_with_cursor(cursor) + vector._get_connection = MagicMock(return_value=connection) + + monkeypatch.setattr(oracle_module.uuid, "uuid4", lambda: "generated-uuid") + docs = [ + Document(page_content="a", metadata={"doc_id": "doc-a"}), + Document(page_content="b", metadata={"document_id": "doc-b"}), + SimpleNamespace(page_content="c", metadata=None), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]]) + + assert ids == ["doc-a", "generated-uuid"] + assert cursor.execute.call_count == 2 + assert connection.commit.call_count >= 1 + connection.close.assert_called() + + +def test_text_exists_and_get_by_ids(oracle_module): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.table_name = "embedding_collection_1" + vector.pool = MagicMock() + + cursor = MagicMock() + cursor.fetchone.return_value = ("id-1",) + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1"), ({"doc_id": "2"}, "text-2")]) + vector._get_connection = MagicMock(return_value=_connection_with_cursor(cursor)) + + assert vector.text_exists("id-1") is True + docs = vector.get_by_ids(["id-1", "id-2"]) + assert len(docs) == 2 + assert docs[0].page_content == "text-1" + vector.pool.release.assert_called_once() + assert vector.get_by_ids([]) == [] + + +def test_delete_methods(oracle_module): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.table_name = "embedding_collection_1" + + cursor = MagicMock() + vector._get_connection = MagicMock(return_value=_connection_with_cursor(cursor)) + + vector.delete_by_ids([]) + vector._get_connection.assert_not_called() + + vector.delete_by_ids(["id-1", "id-2"]) + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete() + + executed_sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("DELETE FROM embedding_collection_1 WHERE id IN" in sql for sql in executed_sql) + assert any("JSON_VALUE(meta" in sql for sql in executed_sql) + assert any("DROP TABLE IF EXISTS embedding_collection_1" in sql for sql in executed_sql) + + +def test_search_by_vector_with_threshold_and_filter(oracle_module): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.table_name = "embedding_collection_1" + vector.input_type_handler = MagicMock() + vector.output_type_handler = MagicMock() + + cursor = MagicMock() + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "doc-1", 0.1), ({"doc_id": "2"}, "doc-2", 0.8)]) + connection = _connection_with_cursor(cursor) + vector._get_connection = MagicMock(return_value=connection) + + docs = vector.search_by_vector( + [0.1, 0.2], + top_k=0, + score_threshold=0.5, + document_ids_filter=["d-1", "d-2"], + ) + + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + sql = cursor.execute.call_args.args[0] + assert "fetch first 4 rows only" in sql + assert "JSON_VALUE(meta, '$.document_id') IN (:2, :3)" in sql + + +def _fake_nltk_module(*, missing_data=False): + nltk = types.ModuleType("nltk") + nltk_corpus = types.ModuleType("nltk.corpus") + + class _Data: + @staticmethod + def find(_path): + if missing_data: + raise LookupError("missing") + return True + + nltk.data = _Data() + nltk.word_tokenize = lambda text: text.split() + nltk_corpus.stopwords = SimpleNamespace(words=lambda _lang: ["and", "the"]) + return nltk, nltk_corpus + + +def test_search_by_full_text_chinese_and_english_paths(oracle_module, monkeypatch): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.table_name = "embedding_collection_1" + + cursor = MagicMock() + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1", [0.1, 0.2])]) + vector._get_connection = MagicMock(return_value=_connection_with_cursor(cursor)) + + monkeypatch.setattr(oracle_module.pseg, "cut", MagicMock(return_value=[("张", "nr"), ("三", "nr"), ("。", "x")])) + zh_docs = vector.search_by_full_text("张三", top_k=2) + assert len(zh_docs) == 1 + zh_params = cursor.execute.call_args.args[1] + assert zh_params["kk"] == "张三" + + nltk, nltk_corpus = _fake_nltk_module(missing_data=False) + monkeypatch.setitem(sys.modules, "nltk", nltk) + monkeypatch.setitem(sys.modules, "nltk.corpus", nltk_corpus) + cursor.__iter__.return_value = iter([({"doc_id": "2"}, "text-2", [0.3, 0.4])]) + en_docs = vector.search_by_full_text("alice and bob", top_k=-1, document_ids_filter=["d-1"]) + assert len(en_docs) == 1 + en_sql = cursor.execute.call_args.args[0] + en_params = cursor.execute.call_args.args[1] + assert "fetch first 5 rows only" in en_sql + assert "doc_id_0" in en_params + + +def test_search_by_full_text_empty_query_and_missing_nltk(oracle_module, monkeypatch): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.table_name = "embedding_collection_1" + vector._get_connection = MagicMock() + + empty_result = vector.search_by_full_text("") + assert empty_result[0].page_content == "" + + nltk, nltk_corpus = _fake_nltk_module(missing_data=True) + monkeypatch.setitem(sys.modules, "nltk", nltk) + monkeypatch.setitem(sys.modules, "nltk.corpus", nltk_corpus) + with pytest.raises(LookupError, match="required NLTK data package"): + vector.search_by_full_text("english query") + + +def test_create_collection_cache_and_execute_path(oracle_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(oracle_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(oracle_module.redis_client, "set", MagicMock()) + + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector._collection_name = "collection_1" + vector.table_name = "embedding_collection_1" + + cursor = MagicMock() + vector._get_connection = MagicMock(return_value=_connection_with_cursor(cursor)) + + monkeypatch.setattr(oracle_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection(2) + cursor.execute.assert_not_called() + + monkeypatch.setattr(oracle_module.redis_client, "get", MagicMock(return_value=None)) + vector._create_collection(2) + executed_sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("CREATE TABLE IF NOT EXISTS embedding_collection_1" in sql for sql in executed_sql) + assert any("CREATE INDEX IF NOT EXISTS idx_docs_embedding_collection_1" in sql for sql in executed_sql) + oracle_module.redis_client.set.assert_called_once() + + +def test_oracle_factory_init_vector_uses_existing_or_generated_collection(oracle_module, monkeypatch): + factory = oracle_module.OracleVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(oracle_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_USER", "system") + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_PASSWORD", "oracle") + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_DSN", "oracle:1521/freepdb1") + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_CONFIG_DIR", None) + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_WALLET_LOCATION", None) + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_WALLET_PASSWORD", None) + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_IS_AUTONOMOUS", False) + + with patch.object(oracle_module, "OracleVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/pgvecto_rs/test_pgvecto_rs.py b/api/tests/unit_tests/core/rag/datasource/vdb/pgvecto_rs/test_pgvecto_rs.py new file mode 100644 index 00000000000..1aec81b8ac8 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/pgvecto_rs/test_pgvecto_rs.py @@ -0,0 +1,317 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError +from sqlalchemy.types import UserDefinedType + +from core.rag.models.document import Document + + +def _build_fake_pgvecto_modules(): + pgvecto_rs = types.ModuleType("pgvecto_rs") + pgvecto_rs_sqlalchemy = types.ModuleType("pgvecto_rs.sqlalchemy") + + class VECTOR(UserDefinedType): + def __init__(self, dim): + self.dim = dim + + pgvecto_rs_sqlalchemy.VECTOR = VECTOR + return { + "pgvecto_rs": pgvecto_rs, + "pgvecto_rs.sqlalchemy": pgvecto_rs_sqlalchemy, + } + + +class _FakeSessionContext: + def __init__(self, calls, execute_results=None): + self.calls = calls + self.execute_results = execute_results or [] + self.execute = MagicMock(side_effect=self._execute_side_effect) + self.commit = MagicMock() + + def _execute_side_effect(self, *args, **kwargs): + self.calls.append((args, kwargs)) + if self.execute_results: + return self.execute_results.pop(0) + return MagicMock() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return None + + +def _session_factory(calls, execute_results=None): + def _session(_client): + return _FakeSessionContext(calls=calls, execute_results=execute_results) + + return _session + + +@pytest.fixture +def pgvecto_module(monkeypatch): + for name, module in _build_fake_pgvecto_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.pgvecto_rs.collection as collection_module + import core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs as module + + return importlib.reload(module), importlib.reload(collection_module) + + +def _config(module, **overrides): + values = { + "host": "localhost", + "port": 5432, + "user": "postgres", + "password": "secret", + "database": "postgres", + } + values.update(overrides) + return module.PgvectoRSConfig.model_validate(values) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config PGVECTO_RS_HOST is required"), + ("port", 0, "config PGVECTO_RS_PORT is required"), + ("user", "", "config PGVECTO_RS_USER is required"), + ("password", "", "config PGVECTO_RS_PASSWORD is required"), + ("database", "", "config PGVECTO_RS_DATABASE is required"), + ], +) +def test_pgvecto_config_validation(pgvecto_module, field, value, message): + module, _ = pgvecto_module + values = _config(module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + module.PgvectoRSConfig.model_validate(values) + + +def test_collection_base_has_expected_annotations(pgvecto_module): + _, collection_module = pgvecto_module + annotations = collection_module.CollectionORM.__annotations__ + assert {"id", "text", "meta", "vector"} <= set(annotations) + + +def test_init_get_type_and_create_delegate(pgvecto_module, monkeypatch): + module, _ = pgvecto_module + session_calls = [] + monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine")) + monkeypatch.setattr(module, "Session", _session_factory(session_calls)) + + vector = module.PGVectoRS("collection_1", _config(module), dim=3) + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="hello", metadata={"doc_id": "1"})] + vector.create(docs, [[0.1, 0.2]]) + + assert vector.get_type() == module.VectorType.PGVECTO_RS + module.create_engine.assert_called_once_with("postgresql+psycopg2://postgres:secret@localhost:5432/postgres") + assert any("CREATE EXTENSION IF NOT EXISTS vectors" in str(args[0]) for args, _ in session_calls) + vector.create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_create_collection_cache_and_sql_execution(pgvecto_module, monkeypatch): + module, _ = pgvecto_module + session_calls = [] + monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine")) + monkeypatch.setattr(module, "Session", _session_factory(session_calls)) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(module.redis_client, "set", MagicMock()) + + vector = module.PGVectoRS("collection_1", _config(module), dim=3) + monkeypatch.setattr(module.redis_client, "get", MagicMock(return_value=1)) + vector.create_collection(3) + assert not any("CREATE TABLE IF NOT EXISTS collection_1" in str(args[0]) for args, _ in session_calls) + + monkeypatch.setattr(module.redis_client, "get", MagicMock(return_value=None)) + vector.create_collection(3) + assert any("CREATE TABLE IF NOT EXISTS collection_1" in str(args[0]) for args, _ in session_calls) + assert any("CREATE INDEX IF NOT EXISTS collection_1_embedding_index" in str(args[0]) for args, _ in session_calls) + module.redis_client.set.assert_called() + + +def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch): + module, _ = pgvecto_module + init_calls = [] + runtime_calls = [] + execute_results = [SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)]), SimpleNamespace(fetchall=lambda: [])] + + monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine")) + monkeypatch.setattr(module, "Session", _session_factory(init_calls)) + vector = module.PGVectoRS("collection_1", _config(module), dim=3) + + monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=list(execute_results))) + + class _InsertBuilder: + def __init__(self, table): + self.table = table + + def values(self, **kwargs): + return ("insert", kwargs) + + monkeypatch.setattr(module, "insert", lambda table: _InsertBuilder(table)) + monkeypatch.setattr(module, "uuid4", MagicMock(side_effect=["uuid-1", "uuid-2"])) + docs = [ + Document(page_content="a", metadata={"doc_id": "1"}), + Document(page_content="b", metadata={"doc_id": "2"}), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2]]) + assert ids == ["uuid-1", "uuid-2"] + assert any(call[0][0][0] == "insert" for call in runtime_calls if call[0]) + + monkeypatch.setattr( + module, + "Session", + _session_factory(runtime_calls, execute_results=[SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)])]), + ) + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"] + + monkeypatch.setattr( + module, + "Session", + _session_factory(runtime_calls, execute_results=[SimpleNamespace(fetchall=lambda: [])]), + ) + assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None + + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"]) + vector.delete_by_metadata_field("document_id", "doc-1") + assert any("DELETE FROM collection_1 WHERE id = ANY(:ids)" in str(args[0]) for args, _ in runtime_calls) + + runtime_calls.clear() + monkeypatch.setattr( + module, + "Session", + _session_factory( + runtime_calls, + execute_results=[ + SimpleNamespace(fetchall=lambda: [("row-id-1",)]), + MagicMock(), + ], + ), + ) + vector.delete_by_ids(["doc-1"]) + assert any("meta->>'doc_id' = ANY (:doc_ids)" in str(args[0]) for args, _ in runtime_calls) + assert any("DELETE FROM collection_1 WHERE id = ANY(:ids)" in str(args[0]) for args, _ in runtime_calls) + + runtime_calls.clear() + monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[MagicMock()])) + vector.delete() + assert any("DROP TABLE IF EXISTS collection_1" in str(args[0]) for args, _ in runtime_calls) + + +def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch): + module, _ = pgvecto_module + init_calls = [] + monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine")) + monkeypatch.setattr(module, "Session", _session_factory(init_calls)) + vector = module.PGVectoRS("collection_1", _config(module), dim=3) + + runtime_calls = [] + monkeypatch.setattr( + module, + "Session", + _session_factory( + runtime_calls, + execute_results=[ + SimpleNamespace(fetchall=lambda: [("id-1",)]), + SimpleNamespace(fetchall=lambda: []), + ], + ), + ) + assert vector.text_exists("doc-1") is True + assert vector.text_exists("doc-1") is False + + class _DistanceExpr: + def label(self, _name): + return self + + class _VectorColumn: + def op(self, _operator, return_type=None): + def _call(_query_vector): + return _DistanceExpr() + + return _call + + class _MetaFilter: + def in_(self, values): + return ("in", values) + + class _MetaColumn: + def __getitem__(self, _item): + return _MetaFilter() + + class _Stmt: + def __init__(self): + self.where_called = False + + def limit(self, _value): + return self + + def order_by(self, _value): + return self + + def where(self, _value): + self.where_called = True + return self + + stmt = _Stmt() + monkeypatch.setattr(module, "select", lambda *_args: stmt) + + vector._table = SimpleNamespace(vector=_VectorColumn(), meta=_MetaColumn()) + rows = [ + (SimpleNamespace(meta={"doc_id": "1"}, text="text-1"), 0.1), + (SimpleNamespace(meta={"doc_id": "2"}, text="text-2"), 0.8), + ] + monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[rows])) + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + assert stmt.where_called is True + assert vector.search_by_full_text("hello") == [] + + +def test_factory_uses_existing_or_generated_collection(pgvecto_module, monkeypatch): + module, _ = pgvecto_module + factory = module.PGVectoRSFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(module.dify_config, "PGVECTO_RS_HOST", "localhost") + monkeypatch.setattr(module.dify_config, "PGVECTO_RS_PORT", 5432) + monkeypatch.setattr(module.dify_config, "PGVECTO_RS_USER", "postgres") + monkeypatch.setattr(module.dify_config, "PGVECTO_RS_PASSWORD", "secret") + monkeypatch.setattr(module.dify_config, "PGVECTO_RS_DATABASE", "postgres") + + embeddings = MagicMock() + embeddings.embed_query.return_value = [0.1, 0.2, 0.3] + + with patch.object(module, "PGVectoRS", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=embeddings) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=embeddings) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/test_pgvector.py b/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/test_pgvector.py index 4998a9858fa..7505262eb78 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/test_pgvector.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/test_pgvector.py @@ -1,16 +1,19 @@ -import unittest +from contextlib import contextmanager +from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +import core.rag.datasource.vdb.pgvector.pgvector as pgvector_module from core.rag.datasource.vdb.pgvector.pgvector import ( PGVector, PGVectorConfig, ) +from core.rag.models.document import Document -class TestPGVector(unittest.TestCase): - def setUp(self): +class TestPGVector: + def setup_method(self, method): self.config = PGVectorConfig( host="localhost", port=5432, @@ -323,5 +326,172 @@ def test_config_validation_parametrized(invalid_config_override): PGVectorConfig(**config) -if __name__ == "__main__": - unittest.main() +def test_create_delegates_collection_creation_and_insert(): + vector = PGVector.__new__(PGVector) + vector._create_collection = MagicMock() + vector.add_texts = MagicMock(return_value=["doc-a"]) + docs = [Document(page_content="hello", metadata={"doc_id": "doc-a"})] + + result = vector.create(docs, [[0.1, 0.2]]) + + assert result == ["doc-a"] + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_add_texts_uses_execute_values_and_returns_ids(monkeypatch): + vector = PGVector.__new__(PGVector) + vector.table_name = "embedding_collection_1" + + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + monkeypatch.setattr(pgvector_module.uuid, "uuid4", lambda: "generated-uuid") + execute_values = MagicMock() + monkeypatch.setattr(pgvector_module.psycopg2.extras, "execute_values", execute_values) + + docs = [ + Document(page_content="a", metadata={"doc_id": "doc-a"}), + Document(page_content="b", metadata={"document_id": "doc-b"}), + SimpleNamespace(page_content="c", metadata=None), + ] + ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]]) + + assert ids == ["doc-a", "generated-uuid"] + execute_values.assert_called_once() + + +def test_text_get_and_delete_methods(): + vector = PGVector.__new__(PGVector) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + cursor.fetchone.return_value = ("id-1",) + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1"), ({"doc_id": "2"}, "text-2")]) + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + assert vector.text_exists("id-1") is True + docs = vector.get_by_ids(["id-1", "id-2"]) + assert len(docs) == 2 + assert docs[0].page_content == "text-1" + + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete() + executed_sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("meta->>%s = %s" in sql for sql in executed_sql) + assert any("DROP TABLE IF EXISTS embedding_collection_1" in sql for sql in executed_sql) + + +def test_delete_by_ids_handles_empty_undefined_table_and_generic_exception(monkeypatch): + vector = PGVector.__new__(PGVector) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + vector.delete_by_ids([]) + cursor.execute.assert_not_called() + + class _UndefinedTableError(Exception): + pass + + monkeypatch.setattr(pgvector_module.psycopg2.errors, "UndefinedTable", _UndefinedTableError) + cursor.execute.side_effect = _UndefinedTableError("missing") + vector.delete_by_ids(["doc-1"]) + + cursor.execute.side_effect = RuntimeError("boom") + with pytest.raises(RuntimeError, match="boom"): + vector.delete_by_ids(["doc-1"]) + + +def test_search_by_vector_supports_filter_and_threshold(): + vector = PGVector.__new__(PGVector) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1", 0.1), ({"doc_id": "2"}, "text-2", 0.8)]) + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_vector([0.1], top_k=0) + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + sql = cursor.execute.call_args.args[0] + assert "meta->>'document_id' in ('d-1')" in sql + + +def test_search_by_full_text_branches_for_bigm_and_standard(): + vector = PGVector.__new__(PGVector) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1", 0.7)]) + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_full_text("hello", top_k=0) + + vector.pg_bigm = False + docs = vector.search_by_full_text("hello world", top_k=2, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.7) + standard_sql = cursor.execute.call_args.args[0] + assert "to_tsvector(text) @@ plainto_tsquery(%s)" in standard_sql + + cursor.execute.reset_mock() + cursor.__iter__.return_value = iter([({"doc_id": "2"}, "text-2", 0.6)]) + vector.pg_bigm = True + vector.search_by_full_text("hello world", top_k=2, document_ids_filter=["d-2"]) + assert "SET pg_bigm.similarity_limit TO 0.000001" in cursor.execute.call_args_list[0].args[0] + assert "bigm_similarity" in cursor.execute.call_args_list[1].args[0] + + +def test_pgvector_factory_initializes_expected_collection_name(monkeypatch): + factory = pgvector_module.PGVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(pgvector_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_HOST", "localhost") + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_PORT", 5432) + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_USER", "postgres") + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_PASSWORD", "secret") + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_DATABASE", "postgres") + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_MIN_CONNECTION", 1) + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_MAX_CONNECTION", 5) + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_PG_BIGM", False) + + with patch.object(pgvector_module, "PGVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/pyvastbase/test_vastbase_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/pyvastbase/test_vastbase_vector.py new file mode 100644 index 00000000000..bd8df520ba2 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/pyvastbase/test_vastbase_vector.py @@ -0,0 +1,269 @@ +import importlib +import sys +import types +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_psycopg2_modules(): + psycopg2 = types.ModuleType("psycopg2") + psycopg2.__path__ = [] + psycopg2_extras = types.ModuleType("psycopg2.extras") + psycopg2_pool = types.ModuleType("psycopg2.pool") + + class SimpleConnectionPool: + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + self.getconn = MagicMock() + self.putconn = MagicMock() + + psycopg2_pool.SimpleConnectionPool = SimpleConnectionPool + psycopg2_extras.execute_values = MagicMock() + psycopg2.pool = psycopg2_pool + psycopg2.extras = psycopg2_extras + + return { + "psycopg2": psycopg2, + "psycopg2.pool": psycopg2_pool, + "psycopg2.extras": psycopg2_extras, + } + + +@pytest.fixture +def vastbase_module(monkeypatch): + for name, module in _build_fake_psycopg2_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.pyvastbase.vastbase_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.VastbaseVectorConfig( + host="localhost", + port=5432, + user="dify", + password="secret", + database="dify", + min_connection=1, + max_connection=5, + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config VASTBASE_HOST is required"), + ("port", 0, "config VASTBASE_PORT is required"), + ("user", "", "config VASTBASE_USER is required"), + ("password", "", "config VASTBASE_PASSWORD is required"), + ("database", "", "config VASTBASE_DATABASE is required"), + ("min_connection", 0, "config VASTBASE_MIN_CONNECTION is required"), + ("max_connection", 0, "config VASTBASE_MAX_CONNECTION is required"), + ], +) +def test_vastbase_config_validation(vastbase_module, field, value, message): + values = _config(vastbase_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + vastbase_module.VastbaseVectorConfig.model_validate(values) + + +def test_vastbase_config_rejects_invalid_connection_window(vastbase_module): + with pytest.raises(ValidationError, match="VASTBASE_MIN_CONNECTION should less than VASTBASE_MAX_CONNECTION"): + vastbase_module.VastbaseVectorConfig.model_validate( + { + "host": "localhost", + "port": 5432, + "user": "dify", + "password": "secret", + "database": "dify", + "min_connection": 6, + "max_connection": 5, + } + ) + + +def test_init_and_get_cursor_context_manager(vastbase_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(vastbase_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + + conn = MagicMock() + cur = MagicMock() + pool.getconn.return_value = conn + conn.cursor.return_value = cur + + vector = vastbase_module.VastbaseVector("collection_1", _config(vastbase_module)) + assert vector.get_type() == "vastbase" + assert vector.table_name == "embedding_collection_1" + + with vector._get_cursor() as got_cur: + assert got_cur is cur + + cur.close.assert_called_once() + conn.commit.assert_called_once() + pool.putconn.assert_called_once_with(conn) + + +def test_create_and_add_texts(vastbase_module, monkeypatch): + vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector) + vector.table_name = "embedding_collection_1" + vector._create_collection = MagicMock() + + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + monkeypatch.setattr(vastbase_module.uuid, "uuid4", lambda: "generated-uuid") + + docs = [ + Document(page_content="a", metadata={"doc_id": "doc-a"}), + Document(page_content="b", metadata={"document_id": "doc-b"}), + SimpleNamespace(page_content="c", metadata=None), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]]) + assert ids == ["doc-a", "generated-uuid"] + vastbase_module.psycopg2.extras.execute_values.assert_called_once() + + vector.add_texts = MagicMock(return_value=["doc-a"]) + result = vector.create(docs, [[0.1], [0.2], [0.3]]) + vector._create_collection.assert_called_once_with(1) + assert result == ["doc-a"] + + +def test_text_get_delete_and_metadata_methods(vastbase_module): + vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + cursor.fetchone.return_value = ("id-1",) + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1"), ({"doc_id": "2"}, "text-2")]) + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + assert vector.text_exists("id-1") is True + docs = vector.get_by_ids(["id-1", "id-2"]) + assert len(docs) == 2 + assert docs[0].page_content == "text-1" + + vector.delete_by_ids([]) + vector.delete_by_ids(["id-1"]) + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete() + executed_sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("DELETE FROM embedding_collection_1 WHERE id IN %s" in sql for sql in executed_sql) + assert any("meta->>%s = %s" in sql for sql in executed_sql) + assert any("DROP TABLE IF EXISTS embedding_collection_1" in sql for sql in executed_sql) + + +def test_search_by_vector_and_full_text(vastbase_module): + vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + cursor.__iter__.return_value = iter( + [ + ({"doc_id": "1"}, "text-1", 0.1), + ({"doc_id": "2"}, "text-2", 0.8), + ] + ) + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_vector([0.1, 0.2], top_k=0) + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_full_text("hello", top_k=0) + + cursor.__iter__.return_value = iter([({"doc_id": "3"}, "full-text", 0.7)]) + full_docs = vector.search_by_full_text("hello world", top_k=2) + assert len(full_docs) == 1 + assert full_docs[0].page_content == "full-text" + + +def test_create_collection_cache_and_dimension_branches(vastbase_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(vastbase_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(vastbase_module.redis_client, "set", MagicMock()) + + vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector) + vector._collection_name = "collection_1" + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + monkeypatch.setattr(vastbase_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection(3) + cursor.execute.assert_not_called() + + monkeypatch.setattr(vastbase_module.redis_client, "get", MagicMock(return_value=None)) + vector._create_collection(17000) + executed_sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("CREATE TABLE IF NOT EXISTS embedding_collection_1" in sql for sql in executed_sql) + assert all("embedding_cosine_v1_idx" not in sql for sql in executed_sql) + + cursor.execute.reset_mock() + vector._create_collection(3) + executed_sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("embedding_cosine_v1_idx" in sql for sql in executed_sql) + vastbase_module.redis_client.set.assert_called() + + +def test_vastbase_factory_uses_existing_or_generated_collection(vastbase_module, monkeypatch): + factory = vastbase_module.VastbaseVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(vastbase_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_HOST", "localhost") + monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_PORT", 5432) + monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_USER", "dify") + monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_PASSWORD", "secret") + monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_DATABASE", "dify") + monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_MIN_CONNECTION", 1) + monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_MAX_CONNECTION", 5) + + with patch.object(vastbase_module, "VastbaseVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/qdrant/test_qdrant_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/qdrant/test_qdrant_vector.py new file mode 100644 index 00000000000..04085065638 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/qdrant/test_qdrant_vector.py @@ -0,0 +1,328 @@ +import importlib +import os +import sys +import types +from collections import UserDict +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_qdrant_modules(): + qdrant_client = types.ModuleType("qdrant_client") + qdrant_http = types.ModuleType("qdrant_client.http") + qdrant_http_models = types.ModuleType("qdrant_client.http.models") + qdrant_http_exceptions = types.ModuleType("qdrant_client.http.exceptions") + qdrant_local_pkg = types.ModuleType("qdrant_client.local") + qdrant_local_mod = types.ModuleType("qdrant_client.local.qdrant_local") + + class UnexpectedResponseError(Exception): + def __init__(self, status_code): + super().__init__(f"status={status_code}") + self.status_code = status_code + + class FilterSelector: + def __init__(self, filter): + self.filter = filter + + class HnswConfigDiff: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class TextIndexParams: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class VectorParams: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class PointStruct: + def __init__(self, **kwargs): + self.id = kwargs["id"] + self.vector = kwargs["vector"] + self.payload = kwargs["payload"] + + class Filter: + def __init__(self, must=None): + self.must = must or [] + + class FieldCondition: + def __init__(self, key, match): + self.key = key + self.match = match + + class MatchValue: + def __init__(self, value): + self.value = value + + class MatchAny: + def __init__(self, any): + self.any = any + + class MatchText: + def __init__(self, text): + self.text = text + + class _Distance(UserDict): + def __getitem__(self, key): + return key + + class QdrantClient: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.get_collections = MagicMock(return_value=SimpleNamespace(collections=[])) + self.create_collection = MagicMock() + self.create_payload_index = MagicMock() + self.upsert = MagicMock() + self.delete = MagicMock() + self.delete_collection = MagicMock() + self.retrieve = MagicMock(return_value=[]) + self.search = MagicMock(return_value=[]) + self.scroll = MagicMock(return_value=([], None)) + + class QdrantLocal(QdrantClient): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._load = MagicMock() + + qdrant_client.QdrantClient = QdrantClient + qdrant_http_models.FilterSelector = FilterSelector + qdrant_http_models.HnswConfigDiff = HnswConfigDiff + qdrant_http_models.PayloadSchemaType = SimpleNamespace(KEYWORD="KEYWORD") + qdrant_http_models.TextIndexParams = TextIndexParams + qdrant_http_models.TextIndexType = SimpleNamespace(TEXT="TEXT") + qdrant_http_models.TokenizerType = SimpleNamespace(MULTILINGUAL="MULTILINGUAL") + qdrant_http_models.VectorParams = VectorParams + qdrant_http_models.Distance = _Distance() + qdrant_http_models.PointStruct = PointStruct + qdrant_http_models.Filter = Filter + qdrant_http_models.FieldCondition = FieldCondition + qdrant_http_models.MatchValue = MatchValue + qdrant_http_models.MatchAny = MatchAny + qdrant_http_models.MatchText = MatchText + qdrant_http_exceptions.UnexpectedResponse = UnexpectedResponseError + + qdrant_http.models = qdrant_http_models + qdrant_local_mod.QdrantLocal = QdrantLocal + qdrant_local_pkg.qdrant_local = qdrant_local_mod + + return { + "qdrant_client": qdrant_client, + "qdrant_client.http": qdrant_http, + "qdrant_client.http.models": qdrant_http_models, + "qdrant_client.http.exceptions": qdrant_http_exceptions, + "qdrant_client.local": qdrant_local_pkg, + "qdrant_client.local.qdrant_local": qdrant_local_mod, + } + + +@pytest.fixture +def qdrant_module(monkeypatch): + for name, module in _build_fake_qdrant_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.qdrant.qdrant_vector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "endpoint": "http://localhost:6333", + "api_key": "api-key", + "timeout": 20, + "root_path": "/tmp", + "grpc_port": 6334, + "prefer_grpc": False, + "replication_factor": 1, + "write_consistency_factor": 1, + } + values.update(overrides) + return module.QdrantConfig.model_validate(values) + + +def test_qdrant_config_to_params(qdrant_module): + url_params = _config(qdrant_module).to_qdrant_params().model_dump() + assert url_params["url"] == "http://localhost:6333" + assert url_params["verify"] is False + + path_config = _config(qdrant_module, endpoint="path:storage") + assert path_config.to_qdrant_params().path == os.path.join("/tmp", "storage") + + with pytest.raises(ValueError, match="Root path is not set"): + _config(qdrant_module, endpoint="path:storage", root_path=None).to_qdrant_params() + + +def test_init_and_basic_behaviour(qdrant_module): + vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module)) + assert vector.get_type() == qdrant_module.VectorType.QDRANT + assert vector.to_index_struct()["vector_store"]["class_prefix"] == "collection_1" + + docs = [Document(page_content="a", metadata={"doc_id": "a"})] + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + vector.create(docs, [[0.1]]) + vector.create_collection.assert_called_once_with("collection_1", 1) + vector.add_texts.assert_called_once() + + +def test_create_collection_and_add_texts(qdrant_module, monkeypatch): + vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module)) + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(qdrant_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(qdrant_module.redis_client, "set", MagicMock()) + + monkeypatch.setattr(qdrant_module.redis_client, "get", MagicMock(return_value=1)) + vector.create_collection("collection_1", 3) + vector._client.create_collection.assert_not_called() + + monkeypatch.setattr(qdrant_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.get_collections.return_value = SimpleNamespace(collections=[]) + vector.create_collection("collection_1", 3) + vector._client.create_collection.assert_called_once() + assert vector._client.create_payload_index.call_count == 4 + qdrant_module.redis_client.set.assert_called_once() + + # add_texts and generated batches + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + Document(page_content="b", metadata={"doc_id": "id-2"}), + ] + ids = vector.add_texts(docs, [[0.1], [0.2]]) + assert ids == ["id-1", "id-2"] + assert vector._client.upsert.call_count == 1 + + payloads = qdrant_module.QdrantVector._build_payloads( + ["a"], [{"doc_id": "id-1"}], "content", "metadata", "g1", "group_id" + ) + assert payloads[0]["group_id"] == "g1" + with pytest.raises(ValueError, match="At least one of the texts is None"): + qdrant_module.QdrantVector._build_payloads( + [None], [{"doc_id": "id-1"}], "content", "metadata", "g1", "group_id" + ) + + +def test_delete_and_exists_paths(qdrant_module): + vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module)) + unexpected = sys.modules["qdrant_client.http.exceptions"].UnexpectedResponse + + vector._client.delete.side_effect = unexpected(404) + vector.delete_by_metadata_field("document_id", "doc-1") + vector._client.delete.side_effect = None + + vector._client.delete.side_effect = unexpected(500) + with pytest.raises(unexpected): + vector.delete_by_metadata_field("document_id", "doc-1") + vector._client.delete.side_effect = None + + vector._client.delete.side_effect = unexpected(404) + vector.delete() + vector._client.delete.side_effect = unexpected(500) + with pytest.raises(unexpected): + vector.delete() + vector._client.delete.side_effect = None + + vector._client.delete.side_effect = unexpected(404) + vector.delete_by_ids(["doc-1"]) + vector._client.delete.side_effect = unexpected(500) + with pytest.raises(unexpected): + vector.delete_by_ids(["doc-1"]) + vector._client.delete.side_effect = None + + vector._client.get_collections.return_value = SimpleNamespace(collections=[SimpleNamespace(name="other")]) + assert vector.text_exists("id-1") is False + vector._client.get_collections.return_value = SimpleNamespace(collections=[SimpleNamespace(name="collection_1")]) + vector._client.retrieve.return_value = [{"id": "id-1"}] + assert vector.text_exists("id-1") is True + + +def test_search_and_helper_methods(qdrant_module): + vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module)) + assert vector.search_by_vector([0.1], score_threshold=1.0) == [] + + vector._client.search.return_value = [ + SimpleNamespace(payload=None, score=0.9, vector=[0.1]), + SimpleNamespace(payload={"metadata": {"doc_id": "1"}, "page_content": "doc-a"}, score=0.8, vector=[0.1]), + ] + docs = vector.search_by_vector([0.1], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.8) + + # full text search: keyword split, dedup and top_k limit + scroll_results = [ + ( + [ + SimpleNamespace(id="p1", payload={"page_content": "doc-1", "metadata": {"doc_id": "1"}}, vector=[0.1]), + SimpleNamespace(id="p2", payload={"page_content": "doc-2", "metadata": {"doc_id": "2"}}, vector=[0.2]), + ], + None, + ), + ( + [ + SimpleNamespace(id="p2", payload={"page_content": "doc-2", "metadata": {"doc_id": "2"}}, vector=[0.2]), + ], + None, + ), + ] + vector._client.scroll.side_effect = scroll_results + docs = vector.search_by_full_text("hello world", top_k=2, document_ids_filter=["d-1"]) + assert len(docs) == 2 + assert vector.search_by_full_text(" ", top_k=2) == [] + + local_client = qdrant_module.QdrantLocal() + vector._client = local_client + vector._reload_if_needed() + local_client._load.assert_called_once() + + doc = vector._document_from_scored_point( + SimpleNamespace(payload={"page_content": "doc", "metadata": {"doc_id": "1"}}, vector=[0.1]), + "page_content", + "metadata", + ) + assert doc.page_content == "doc" + + +def test_qdrant_factory_paths(qdrant_module, monkeypatch): + factory = qdrant_module.QdrantVectorFactory() + dataset = SimpleNamespace( + id="dataset-1", + tenant_id="tenant-1", + collection_binding_id=None, + index_struct_dict=None, + index_struct=None, + ) + monkeypatch.setattr(qdrant_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(qdrant_module, "current_app", SimpleNamespace(config=SimpleNamespace(root_path="/root"))) + monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_URL", "http://localhost:6333") + monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_API_KEY", "api-key") + monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_CLIENT_TIMEOUT", 20) + monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_GRPC_PORT", 6334) + monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_GRPC_ENABLED", False) + monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_REPLICATION_FACTOR", 1) + + with patch.object(qdrant_module, "QdrantVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + assert result == "vector" + assert vector_cls.call_args.kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset.index_struct is not None + + # collection binding lookup path + dataset.collection_binding_id = "binding-1" + dataset.index_struct_dict = {"vector_store": {"class_prefix": "existing"}} + monkeypatch.setattr(qdrant_module, "select", lambda _model: SimpleNamespace(where=lambda *_args: "stmt")) + qdrant_module.db.session.scalars = MagicMock( + return_value=SimpleNamespace(one_or_none=lambda: SimpleNamespace(collection_name="BOUND_COLLECTION")) + ) + with patch.object(qdrant_module, "QdrantVector", return_value="vector") as vector_cls: + factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + assert vector_cls.call_args.kwargs["collection_name"] == "BOUND_COLLECTION" + + qdrant_module.db.session.scalars = MagicMock(return_value=SimpleNamespace(one_or_none=lambda: None)) + with pytest.raises(ValueError, match="Dataset Collection Bindings does not exist"): + factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/relyt/test_relyt_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/relyt/test_relyt_vector.py new file mode 100644 index 00000000000..ca8cd5e5140 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/relyt/test_relyt_vector.py @@ -0,0 +1,303 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError +from sqlalchemy.types import UserDefinedType + +from core.rag.models.document import Document + + +def _build_fake_relyt_modules(): + pgvecto_rs = types.ModuleType("pgvecto_rs") + pgvecto_rs_sqlalchemy = types.ModuleType("pgvecto_rs.sqlalchemy") + + class VECTOR(UserDefinedType): + def __init__(self, dim): + self.dim = dim + + pgvecto_rs_sqlalchemy.VECTOR = VECTOR + return { + "pgvecto_rs": pgvecto_rs, + "pgvecto_rs.sqlalchemy": pgvecto_rs_sqlalchemy, + } + + +class _FakeSession: + def __init__(self, execute_result=None): + self.execute_result = execute_result or MagicMock(fetchall=lambda: []) + self.execute = MagicMock(return_value=self.execute_result) + self.commit = MagicMock() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return None + + +@pytest.fixture +def relyt_module(monkeypatch): + for name, module in _build_fake_relyt_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.relyt.relyt_vector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "host": "localhost", + "port": 5432, + "user": "postgres", + "password": "secret", + "database": "relyt", + } + values.update(overrides) + return module.RelytConfig.model_validate(values) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config RELYT_HOST is required"), + ("port", 0, "config RELYT_PORT is required"), + ("user", "", "config RELYT_USER is required"), + ("password", "", "config RELYT_PASSWORD is required"), + ("database", "", "config RELYT_DATABASE is required"), + ], +) +def test_relyt_config_validation(relyt_module, field, value, message): + values = _config(relyt_module).model_dump() + values[field] = value + with pytest.raises(ValidationError, match=message): + relyt_module.RelytConfig.model_validate(values) + + +def test_init_get_type_and_create_delegate(relyt_module, monkeypatch): + engine = MagicMock() + monkeypatch.setattr(relyt_module, "create_engine", MagicMock(return_value=engine)) + vector = relyt_module.RelytVector("collection_1", _config(relyt_module), group_id="group-1") + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})] + + vector.create(docs, [[0.1, 0.2]]) + + assert vector.get_type() == relyt_module.VectorType.RELYT + assert vector._url == "postgresql+psycopg2://postgres:secret@localhost:5432/relyt" + assert vector.embedding_dimension == 2 + vector.create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_create_collection_cache_and_sql_execution(relyt_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(relyt_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(relyt_module.redis_client, "set", MagicMock()) + + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + + monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=1)) + session = _FakeSession() + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + vector.create_collection(3) + session.execute.assert_not_called() + + monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=None)) + session = _FakeSession() + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + vector.create_collection(3) + executed_sql = [str(call.args[0]) for call in session.execute.call_args_list] + assert any("DROP TABLE IF EXISTS" in sql for sql in executed_sql) + assert any("CREATE TABLE IF NOT EXISTS" in sql for sql in executed_sql) + assert any("CREATE INDEX" in sql for sql in executed_sql) + relyt_module.redis_client.set.assert_called_once() + + +def test_add_texts_and_metadata_queries(relyt_module, monkeypatch): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector._group_id = "group-1" + vector.client = MagicMock() + + begin_ctx = MagicMock() + begin_ctx.__enter__.return_value = None + begin_ctx.__exit__.return_value = None + conn = MagicMock() + conn.__enter__.return_value = conn + conn.__exit__.return_value = None + conn.begin.return_value = begin_ctx + vector.client.connect.return_value = conn + + monkeypatch.setattr(relyt_module.uuid, "uuid1", MagicMock(side_effect=["id-1", "id-2"])) + docs = [ + Document(page_content="a", metadata={"doc_id": "d-1"}), + Document(page_content="b", metadata={"doc_id": "d-2"}), + ] + ids = vector.add_texts(docs, [[0.1], [0.2]]) + + assert ids == ["id-1", "id-2"] + assert conn.execute.call_count >= 1 + first_insert_values = conn.execute.call_args.args[0].compile().params + assert "group_id" in str(first_insert_values) + + session = _FakeSession(execute_result=MagicMock(fetchall=lambda: [("id-a",), ("id-b",)])) + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-a", "id-b"] + + session = _FakeSession(execute_result=MagicMock(fetchall=lambda: [])) + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None + + +# 1. delete_by_uuids: success and connect error +def test_delete_by_uuids_success_and_connect_error(relyt_module): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + with pytest.raises(ValueError, match="No ids provided"): + vector.delete_by_uuids(None) + conn = MagicMock() + conn.__enter__.return_value = conn + conn.__exit__.return_value = None + begin_ctx = MagicMock() + begin_ctx.__enter__.return_value = None + begin_ctx.__exit__.return_value = None + conn.begin.return_value = begin_ctx + vector.client.connect.return_value = conn + assert vector.delete_by_uuids(["id-1"]) is True + vector.client.connect.side_effect = RuntimeError("boom") + assert vector.delete_by_uuids(["id-1"]) is False + + +# 2. delete_by_metadata_field calls delete_by_uuids +def test_delete_by_metadata_field_calls_delete_by_uuids(relyt_module): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"]) + vector.delete_by_uuids = MagicMock(return_value=True) + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete_by_uuids.assert_called_once_with(["id-1"]) + + +# 3. delete_by_ids translates to uuids +def test_delete_by_ids_translates_to_uuids(relyt_module, monkeypatch): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + session = _FakeSession(execute_result=MagicMock(fetchall=lambda: [("uuid-1",), ("uuid-2",)])) + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + vector.delete_by_uuids = MagicMock(return_value=True) + vector.delete_by_ids(["doc-1", "doc-2"]) + vector.delete_by_uuids.assert_called_once_with(["uuid-1", "uuid-2"]) + + +# 4. text_exists True +def test_text_exists_true(relyt_module, monkeypatch): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + session = _FakeSession(execute_result=MagicMock(fetchall=lambda: [("id-1",)])) + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + assert vector.text_exists("doc-1") is True + + +# 5. text_exists False +def test_text_exists_false(relyt_module, monkeypatch): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + session = _FakeSession(execute_result=MagicMock(fetchall=lambda: [])) + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + assert vector.text_exists("doc-1") is False + + +# 6. similarity_search_with_score_by_vector returns Documents and scores +def test_similarity_search_with_score_by_vector(relyt_module): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + result_rows = [ + SimpleNamespace(document="doc-a", metadata={"doc_id": "1"}, distance=0.1), + SimpleNamespace(document="doc-b", metadata={"doc_id": "2"}, distance=0.8), + ] + conn = MagicMock() + conn.__enter__.return_value = conn + conn.__exit__.return_value = None + conn.execute.return_value.fetchall.return_value = result_rows + vector.client.connect.return_value = conn + similarities = vector.similarity_search_with_score_by_vector([0.1, 0.2], k=2, filter={"document_id": ["d-1"]}) + assert len(similarities) == 2 + assert similarities[0][0].page_content == "doc-a" + + +# 7. search_by_vector filters by score and ids +def test_search_by_vector_filters_by_score_and_ids(relyt_module): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + vector.similarity_search_with_score_by_vector = MagicMock( + return_value=[ + (Document(page_content="a", metadata={"doc_id": "1"}), 0.1), + (Document(page_content="b", metadata={}), 0.9), + ] + ) + docs = vector.search_by_vector([0.1], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert vector.search_by_full_text("query") == [] + + +# 8. delete commits session +def test_delete_commits_session(relyt_module, monkeypatch): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + session = _FakeSession() + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + vector.delete() + session.commit.assert_called_once() + + +def test_relyt_factory_existing_and_generated_collection(relyt_module, monkeypatch): + factory = relyt_module.RelytVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(relyt_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(relyt_module.dify_config, "RELYT_HOST", "localhost") + monkeypatch.setattr(relyt_module.dify_config, "RELYT_PORT", 5432) + monkeypatch.setattr(relyt_module.dify_config, "RELYT_USER", "postgres") + monkeypatch.setattr(relyt_module.dify_config, "RELYT_PASSWORD", "secret") + monkeypatch.setattr(relyt_module.dify_config, "RELYT_DATABASE", "relyt") + + with patch.object(relyt_module, "RelytVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/tablestore/test_tablestore_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/tablestore/test_tablestore_vector.py new file mode 100644 index 00000000000..e3b6676d9bf --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/tablestore/test_tablestore_vector.py @@ -0,0 +1,316 @@ +import importlib +import json +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_tablestore_module(): + tablestore = types.ModuleType("tablestore") + + class _BatchGetRowRequest: + def __init__(self): + self.items = [] + + def add(self, item): + self.items.append(item) + + class _TableInBatchGetRowItem: + def __init__(self, table_name, rows_to_get, columns_to_get, _unused, _ver): + self.table_name = table_name + self.rows_to_get = rows_to_get + self.columns_to_get = columns_to_get + + class _Row: + def __init__(self, primary_key, attribute_columns=None): + self.primary_key = primary_key + self.attribute_columns = attribute_columns or [] + + class _Client: + def __init__(self, *_args): + self.list_table = MagicMock(return_value=[]) + self.create_table = MagicMock() + self.list_search_index = MagicMock(return_value=[]) + self.create_search_index = MagicMock() + self.delete_search_index = MagicMock() + self.delete_table = MagicMock() + self.put_row = MagicMock() + self.delete_row = MagicMock() + self.get_row = MagicMock(return_value=(None, None, None)) + self.batch_get_row = MagicMock() + self.search = MagicMock() + + tablestore.OTSClient = _Client + tablestore.BatchGetRowRequest = _BatchGetRowRequest + tablestore.TableInBatchGetRowItem = _TableInBatchGetRowItem + tablestore.Row = _Row + tablestore.TableMeta = lambda name, schema: ("table_meta", name, schema) + tablestore.TableOptions = lambda: ("table_options",) + tablestore.CapacityUnit = lambda read, write: ("capacity", read, write) + tablestore.ReservedThroughput = lambda cap: ("reserved", cap) + tablestore.FieldSchema = lambda *args, **kwargs: ("field", args, kwargs) + tablestore.VectorOptions = lambda **kwargs: ("vector_options", kwargs) + tablestore.SearchIndexMeta = lambda field_schemas: ("search_index_meta", field_schemas) + tablestore.SearchQuery = lambda query, **kwargs: SimpleNamespace(query=query, **kwargs) + tablestore.TermQuery = lambda key, value: ("term_query", key, value) + tablestore.ColumnsToGet = lambda **kwargs: ("columns_to_get", kwargs) + tablestore.KnnVectorQuery = lambda **kwargs: SimpleNamespace(**kwargs) + tablestore.TermsQuery = lambda key, values: ("terms_query", key, values) + tablestore.Sort = lambda **kwargs: ("sort", kwargs) + tablestore.ScoreSort = lambda **kwargs: ("score_sort", kwargs) + tablestore.BoolQuery = lambda **kwargs: SimpleNamespace(**kwargs) + tablestore.MatchQuery = lambda **kwargs: ("match_query", kwargs) + + tablestore.FieldType = SimpleNamespace(TEXT="TEXT", VECTOR="VECTOR", KEYWORD="KEYWORD") + tablestore.AnalyzerType = SimpleNamespace(MAXWORD="MAXWORD") + tablestore.VectorDataType = SimpleNamespace(VD_FLOAT_32="VD_FLOAT_32") + tablestore.VectorMetricType = SimpleNamespace(VM_COSINE="VM_COSINE") + tablestore.ColumnReturnType = SimpleNamespace(SPECIFIED="SPECIFIED", ALL_FROM_INDEX="ALL_FROM_INDEX") + tablestore.SortOrder = SimpleNamespace(DESC="DESC") + return tablestore + + +@pytest.fixture +def tablestore_module(monkeypatch): + fake_module = _build_fake_tablestore_module() + monkeypatch.setitem(sys.modules, "tablestore", fake_module) + + import core.rag.datasource.vdb.tablestore.tablestore_vector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "access_key_id": "ak", + "access_key_secret": "sk", + "instance_name": "instance", + "endpoint": "endpoint", + "normalize_full_text_bm25_score": False, + } + values.update(overrides) + return module.TableStoreConfig.model_validate(values) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("access_key_id", "", "config ACCESS_KEY_ID is required"), + ("access_key_secret", "", "config ACCESS_KEY_SECRET is required"), + ("instance_name", "", "config INSTANCE_NAME is required"), + ("endpoint", "", "config ENDPOINT is required"), + ], +) +def test_tablestore_config_validation(tablestore_module, field, value, message): + values = _config(tablestore_module).model_dump() + values[field] = value + with pytest.raises(ValidationError, match=message): + tablestore_module.TableStoreConfig.model_validate(values) + + +def test_init_and_basic_delegation(tablestore_module): + vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module)) + assert vector.get_type() == tablestore_module.VectorType.TABLESTORE + assert vector._table_name == "collection_1" + assert vector._index_name == "collection_1_idx" + + vector._create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="hello", metadata={"doc_id": "d-1"})] + vector.create(docs, [[0.1, 0.2]]) + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(documents=docs, embeddings=[[0.1, 0.2]]) + + vector.create_collection([[0.1, 0.2]]) + assert vector._create_collection.call_count == 2 + + +def test_get_by_ids_text_exists_delete_and_wrappers(tablestore_module): + vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module)) + + # get_by_ids + ok_item = SimpleNamespace( + is_ok=True, + row=SimpleNamespace( + attribute_columns=[("metadata", json.dumps({"doc_id": "1"}), None), ("page_content", "text-1", None)] + ), + ) + fail_item = SimpleNamespace(is_ok=False, row=None) + batch_resp = SimpleNamespace(get_result_by_table=lambda _table: [ok_item, fail_item]) + vector._tablestore_client.batch_get_row.return_value = batch_resp + docs = vector.get_by_ids(["id-1"]) + assert len(docs) == 1 + assert docs[0].page_content == "text-1" + + # text_exists + vector._tablestore_client.get_row.return_value = (None, object(), None) + assert vector.text_exists("id-1") is True + vector._tablestore_client.get_row.return_value = (None, None, None) + assert vector.text_exists("id-1") is False + + # delete wrappers + vector._delete_row = MagicMock() + vector.delete_by_ids([]) + vector._delete_row.assert_not_called() + vector.delete_by_ids(["id-1", "id-2"]) + assert vector._delete_row.call_count == 2 + + vector._search_by_metadata = MagicMock(return_value=["id-a"]) + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-a"] + vector.delete_by_ids = MagicMock() + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete_by_ids.assert_called_once_with(["id-a"]) + + vector._search_by_vector = MagicMock(return_value=["vec-doc"]) + vector._search_by_full_text = MagicMock(return_value=["fts-doc"]) + assert vector.search_by_vector([0.1], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) == ["vec-doc"] + assert vector.search_by_full_text("query", top_k=2, score_threshold=0.3, document_ids_filter=["d-1"]) == ["fts-doc"] + + vector._delete_table_if_exist = MagicMock() + vector.delete() + vector._delete_table_if_exist.assert_called_once() + + +def test_create_collection_and_table_index_lifecycle(tablestore_module, monkeypatch): + vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module)) + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(tablestore_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(tablestore_module.redis_client, "set", MagicMock()) + + monkeypatch.setattr(tablestore_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_table_if_not_exist = MagicMock() + vector._create_search_index_if_not_exist = MagicMock() + vector._create_collection(3) + vector._create_table_if_not_exist.assert_not_called() + + monkeypatch.setattr(tablestore_module.redis_client, "get", MagicMock(return_value=None)) + vector._create_collection(3) + vector._create_table_if_not_exist.assert_called_once() + vector._create_search_index_if_not_exist.assert_called_once_with(3) + tablestore_module.redis_client.set.assert_called_once() + + vector = tablestore_module.TableStoreVector("collection_2", _config(tablestore_module)) + vector._tablestore_client.list_table.return_value = ["collection_2"] + assert vector._create_table_if_not_exist() is None + vector._tablestore_client.list_table.return_value = [] + vector._create_table_if_not_exist() + vector._tablestore_client.create_table.assert_called_once() + + vector._tablestore_client.list_search_index.return_value = [("collection_2", "collection_2_idx")] + assert vector._create_search_index_if_not_exist(3) is None + vector._tablestore_client.list_search_index.return_value = [] + vector._create_search_index_if_not_exist(3) + vector._tablestore_client.create_search_index.assert_called_once() + + vector._tablestore_client.list_search_index.return_value = [("collection_2", "idx_a"), ("collection_2", "idx_b")] + vector._delete_table_if_exist() + assert vector._tablestore_client.delete_search_index.call_count == 2 + vector._tablestore_client.delete_table.assert_called_once_with("collection_2") + + vector._delete_search_index() + vector._tablestore_client.delete_search_index.assert_called_with("collection_2", "collection_2_idx") + + +def test_write_row_and_search_helpers(tablestore_module): + vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module)) + + vector._write_row( + "id-1", + { + "page_content": "hello", + "vector": [0.1, 0.2], + "metadata": {"doc_id": "d-1", "document_id": "doc-1"}, + }, + ) + put_row_call = vector._tablestore_client.put_row.call_args + assert put_row_call.args[0] == "collection_1" + attrs = put_row_call.args[1].attribute_columns + assert any(item[0] == "metadata_tags" for item in attrs) + + vector._delete_row("id-1") + vector._tablestore_client.delete_row.assert_called_once() + + # metadata search pagination + first_page = SimpleNamespace(rows=[[(("id", "row-1"),)]], next_token=b"next") + second_page = SimpleNamespace(rows=[[(("id", "row-2"),)]], next_token=b"") + vector._tablestore_client.search.side_effect = [first_page, second_page] + ids = vector._search_by_metadata("document_id", "doc-1") + assert ids == ["row-1", "row-2"] + vector._tablestore_client.search.side_effect = None + + # vector search + hit1 = SimpleNamespace( + score=0.9, + row=( + None, + [("page_content", "doc-a"), ("metadata", json.dumps({"doc_id": "1"})), ("vector", json.dumps([0.1]))], + ), + ) + hit2 = SimpleNamespace( + score=0.2, + row=( + None, + [("page_content", "doc-b"), ("metadata", json.dumps({"doc_id": "2"})), ("vector", json.dumps([0.2]))], + ), + ) + vector._tablestore_client.search.return_value = SimpleNamespace(search_hits=[hit1, hit2]) + docs = vector._search_by_vector([0.1], document_ids_filter=["document_id=doc-1"], top_k=2, score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + + assert tablestore_module.TableStoreVector._normalize_score_exp_decay(0) == pytest.approx(0.0) + assert tablestore_module.TableStoreVector._normalize_score_exp_decay(100) <= 1.0 + + # full text search with and without normalized score filter + vector._normalize_full_text_bm25_score = True + hit3 = SimpleNamespace( + score=10.0, row=(None, [("page_content", "doc-c"), ("metadata", json.dumps({"doc_id": "3"}))]) + ) + hit4 = SimpleNamespace( + score=0.1, row=(None, [("page_content", "doc-d"), ("metadata", json.dumps({"doc_id": "4"}))]) + ) + vector._tablestore_client.search.return_value = SimpleNamespace(search_hits=[hit3, hit4]) + docs = vector._search_by_full_text("query", document_ids_filter=["document_id=doc-1"], top_k=2, score_threshold=0.2) + assert len(docs) == 1 + assert "score" in docs[0].metadata + + vector._normalize_full_text_bm25_score = False + vector._tablestore_client.search.return_value = SimpleNamespace(search_hits=[hit3]) + docs = vector._search_by_full_text("query", document_ids_filter=None, top_k=2, score_threshold=0.0) + assert len(docs) == 1 + assert "score" not in docs[0].metadata + + +def test_tablestore_factory_uses_existing_or_generated_collection(tablestore_module, monkeypatch): + factory = tablestore_module.TableStoreVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(tablestore_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(tablestore_module.dify_config, "TABLESTORE_ENDPOINT", "endpoint") + monkeypatch.setattr(tablestore_module.dify_config, "TABLESTORE_INSTANCE_NAME", "instance") + monkeypatch.setattr(tablestore_module.dify_config, "TABLESTORE_ACCESS_KEY_ID", "ak") + monkeypatch.setattr(tablestore_module.dify_config, "TABLESTORE_ACCESS_KEY_SECRET", "sk") + monkeypatch.setattr(tablestore_module.dify_config, "TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE", True) + + with patch.object(tablestore_module, "TableStoreVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/tencent/test_tencent_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/tencent/test_tencent_vector.py new file mode 100644 index 00000000000..d8f35a60197 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/tencent/test_tencent_vector.py @@ -0,0 +1,309 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_tencent_modules(): + tcvdb_text = types.ModuleType("tcvdb_text") + tcvdb_text_encoder = types.ModuleType("tcvdb_text.encoder") + tcvectordb = types.ModuleType("tcvectordb") + tcvectordb_model = types.ModuleType("tcvectordb.model") + tcvectordb_document = types.ModuleType("tcvectordb.model.document") + tcvectordb_index = types.ModuleType("tcvectordb.model.index") + tcvectordb_enum = types.ModuleType("tcvectordb.model.enum") + + class _BM25Encoder: + def encode_texts(self, text): + return {"encoded_text": text} + + def encode_queries(self, query): + return {"encoded_query": query} + + @classmethod + def default(cls, _lang): + return cls() + + class VectorDBError(Exception): + def __init__(self, message): + super().__init__(message) + self.message = message + + class RPCVectorDBClient: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.create_database_if_not_exists = MagicMock() + self.exists_collection = MagicMock(return_value=False) + self.describe_collection = MagicMock(return_value=SimpleNamespace(indexes=[])) + self.create_collection = MagicMock() + self.upsert = MagicMock() + self.query = MagicMock(return_value=[]) + self.delete = MagicMock() + self.search = MagicMock(return_value=[]) + self.hybrid_search = MagicMock(return_value=[]) + self.drop_collection = MagicMock() + + class _Document: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + class _HNSWSearchParams: + def __init__(self, ef): + self.ef = ef + + class _AnnSearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class _KeywordSearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class _WeightedRerank: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class _Filter: + @staticmethod + def in_(field, values): + return ("in", field, values) + + def __init__(self, condition): + self.condition = condition + + _Filter.In = staticmethod(_Filter.in_) + + class _HNSWParams: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class _FilterIndex: + def __init__(self, *args): + self.args = args + + class _VectorIndex: + def __init__(self, *args): + self.args = args + + class _SparseIndex: + def __init__(self, **kwargs): + self.kwargs = kwargs + + tcvectordb_enum.IndexType = SimpleNamespace( + __members__={"HNSW": "HNSW", "PRIMARY_KEY": "PRIMARY_KEY", "FILTER": "FILTER", "SPARSE_INVERTED": "SPARSE"}, + PRIMARY_KEY="PRIMARY_KEY", + FILTER="FILTER", + SPARSE_INVERTED="SPARSE", + ) + tcvectordb_enum.MetricType = SimpleNamespace(__members__={"IP": "IP"}, IP="IP") + tcvectordb_enum.FieldType = SimpleNamespace(String="String", Json="Json", SparseVector="SparseVector") + + tcvectordb_document.Document = _Document + tcvectordb_document.HNSWSearchParams = _HNSWSearchParams + tcvectordb_document.AnnSearch = _AnnSearch + tcvectordb_document.Filter = _Filter + tcvectordb_document.KeywordSearch = _KeywordSearch + tcvectordb_document.WeightedRerank = _WeightedRerank + + tcvectordb_index.HNSWParams = _HNSWParams + tcvectordb_index.FilterIndex = _FilterIndex + tcvectordb_index.VectorIndex = _VectorIndex + tcvectordb_index.SparseIndex = _SparseIndex + + tcvdb_text_encoder.BM25Encoder = _BM25Encoder + + tcvectordb_model.document = tcvectordb_document + tcvectordb_model.enum = tcvectordb_enum + tcvectordb_model.index = tcvectordb_index + + tcvectordb.RPCVectorDBClient = RPCVectorDBClient + tcvectordb.VectorDBException = VectorDBError + + return { + "tcvdb_text": tcvdb_text, + "tcvdb_text.encoder": tcvdb_text_encoder, + "tcvectordb": tcvectordb, + "tcvectordb.model": tcvectordb_model, + "tcvectordb.model.document": tcvectordb_document, + "tcvectordb.model.index": tcvectordb_index, + "tcvectordb.model.enum": tcvectordb_enum, + } + + +@pytest.fixture +def tencent_module(monkeypatch): + for name, module in _build_fake_tencent_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.tencent.tencent_vector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "url": "http://vdb.local", + "api_key": "api-key", + "timeout": 30, + "username": "user", + "database": "db", + "index_type": "HNSW", + "metric_type": "IP", + "shard": 1, + "replicas": 2, + "max_upsert_batch_size": 2, + "enable_hybrid_search": False, + } + values.update(overrides) + return module.TencentConfig.model_validate(values) + + +def test_config_and_init_paths(tencent_module): + config = _config(tencent_module) + assert config.to_tencent_params()["url"] == "http://vdb.local" + + vector = tencent_module.TencentVector("collection_1", config) + assert vector.get_type() == tencent_module.VectorType.TENCENT + assert vector._client.kwargs["key"] == "api-key" + + vector._client.exists_collection.return_value = True + vector._client.describe_collection.return_value = SimpleNamespace( + indexes=[SimpleNamespace(name="vector", dimension=768), SimpleNamespace(name="sparse_vector", dimension=0)] + ) + vector._client_config.enable_hybrid_search = True + vector._load_collection() + assert vector._enable_hybrid_search is True + assert vector._dimension == 768 + + vector._client.describe_collection.return_value = SimpleNamespace( + indexes=[SimpleNamespace(name="vector", dimension=512)] + ) + vector._load_collection() + assert vector._enable_hybrid_search is False + + +def test_create_collection_branches(tencent_module, monkeypatch): + vector = tencent_module.TencentVector("collection_1", _config(tencent_module)) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(tencent_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(tencent_module.redis_client, "set", MagicMock()) + + monkeypatch.setattr(tencent_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection(3) + vector._client.create_collection.assert_not_called() + + monkeypatch.setattr(tencent_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.exists_collection.return_value = True + vector._create_collection(3) + vector._client.create_collection.assert_not_called() + + vector._client.exists_collection.return_value = False + vector._client_config.index_type = "UNKNOWN" + with pytest.raises(ValueError, match="unsupported index_type"): + vector._create_collection(3) + + vector._client_config.index_type = "HNSW" + vector._client_config.metric_type = "UNKNOWN" + with pytest.raises(ValueError, match="unsupported metric_type"): + vector._create_collection(3) + + vector._client_config.metric_type = "IP" + vector._client.create_collection.side_effect = [ + tencent_module.VectorDBException("fieldType:json unsupported"), + None, + ] + vector._enable_hybrid_search = True + vector._create_collection(3) + assert vector._client.create_collection.call_count == 2 + tencent_module.redis_client.set.assert_called_once() + vector._client.create_collection.side_effect = None + + +def test_create_add_delete_and_search_behaviour(tencent_module): + vector = tencent_module.TencentVector("collection_1", _config(tencent_module, enable_hybrid_search=True)) + vector._create_collection = MagicMock() + docs = [ + Document(page_content="text-a", metadata={"doc_id": "a", "document_id": "doc-a"}), + Document(page_content="text-b", metadata={"doc_id": "b", "document_id": "doc-b"}), + Document(page_content="text-c", metadata={"doc_id": "c", "document_id": "doc-c"}), + ] + embeddings = [[0.1], [0.2], [0.3]] + vector.create(docs, embeddings) + vector._create_collection.assert_called_once_with(1) + + vector._client.upsert.reset_mock() + vector.add_texts(docs, embeddings) + assert vector._client.upsert.call_count == 2 + first_docs = vector._client.upsert.call_args_list[0].kwargs["documents"] + assert "sparse_vector" in first_docs[0].__dict__ + + vector._client.query.return_value = [{"id": "a"}] + assert vector.text_exists("a") is True + vector._client.query.return_value = [] + assert vector.text_exists("a") is False + + vector.delete_by_ids([]) + vector._client.delete.assert_not_called() + vector.delete_by_ids(["a", "b", "c"]) + assert vector._client.delete.call_count == 2 + vector.delete_by_metadata_field("document_id", "doc-a") + assert vector._client.delete.call_count >= 3 + + vector._client.search.return_value = [[{"metadata": {"doc_id": "1"}, "text": "vec-doc", "score": 0.9}]] + vec_docs = vector.search_by_vector([0.1], top_k=2, score_threshold=0.5, document_ids_filter=["doc-a"]) + assert len(vec_docs) == 1 + assert vec_docs[0].metadata["score"] == pytest.approx(0.9) + + vector._enable_hybrid_search = False + assert vector.search_by_full_text("query") == [] + vector._enable_hybrid_search = True + vector._client.hybrid_search.return_value = [[{"metadata": {"doc_id": "2"}, "text": "fts-doc", "score": 0.8}]] + fts_docs = vector.search_by_full_text("query", top_k=2, score_threshold=0.5, document_ids_filter=["doc-a"]) + assert len(fts_docs) == 1 + + # _get_search_res handles old string metadata format + compat_docs = vector._get_search_res([[{"metadata": '{"doc_id": "3"}', "text": "compat", "score": 0.2}]], 0.5) + assert len(compat_docs) == 1 + assert compat_docs[0].metadata["score"] == pytest.approx(0.8) + + vector._has_collection = MagicMock(return_value=True) + vector.delete() + vector._client.drop_collection.assert_called_once() + + +def test_tencent_factory_existing_and_generated_collection(tencent_module, monkeypatch): + factory = tencent_module.TencentVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(tencent_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_URL", "http://vdb.local") + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_API_KEY", "api-key") + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_TIMEOUT", 30) + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_USERNAME", "user") + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_DATABASE", "db") + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_SHARD", 1) + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_REPLICAS", 2) + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH", True) + + with patch.object(tencent_module, "TencentVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_base.py b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_base.py new file mode 100644 index 00000000000..369cda39bfe --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_base.py @@ -0,0 +1,88 @@ +from types import SimpleNamespace + +import pytest + +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.models.document import Document + + +class _DummyVector(BaseVector): + def __init__(self, collection_name: str, existing_ids: set[str] | None = None): + super().__init__(collection_name) + self._existing_ids = existing_ids or set() + + def get_type(self) -> str: + return "dummy" + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + return None + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + return None + + def text_exists(self, id: str) -> bool: + return id in self._existing_ids + + def delete_by_ids(self, ids: list[str]): + return None + + def delete_by_metadata_field(self, key: str, value: str): + return None + + def search_by_vector(self, query_vector: list[float], **kwargs): + return [] + + def search_by_full_text(self, query: str, **kwargs): + return [] + + def delete(self): + return None + + +@pytest.mark.parametrize( + ("base_method", "args"), + [ + (BaseVector.get_type, ()), + (BaseVector.create, ([], [])), + (BaseVector.add_texts, ([], [])), + (BaseVector.text_exists, ("doc-1",)), + (BaseVector.delete_by_ids, ([],)), + (BaseVector.get_ids_by_metadata_field, ("doc_id", "doc-1")), + (BaseVector.delete_by_metadata_field, ("doc_id", "doc-1")), + (BaseVector.search_by_vector, ([0.1],)), + (BaseVector.search_by_full_text, ("query",)), + (BaseVector.delete, ()), + ], +) +def test_base_vector_default_methods_raise_not_implemented(base_method, args): + vector = _DummyVector("collection_1") + + with pytest.raises(NotImplementedError): + base_method(vector, *args) + + +def test_filter_duplicate_texts_removes_existing_docs(): + vector = _DummyVector("collection_1", existing_ids={"dup"}) + docs = [ + SimpleNamespace(page_content="keep-no-meta", metadata=None), + Document(page_content="keep-no-doc-id", metadata={"document_id": "d1"}), + Document(page_content="remove-dup", metadata={"doc_id": "dup"}), + Document(page_content="keep-unique", metadata={"doc_id": "unique"}), + ] + + filtered = vector._filter_duplicate_texts(docs) + + assert [d.page_content for d in filtered] == ["keep-no-meta", "keep-no-doc-id", "keep-unique"] + + +def test_get_uuids_and_collection_name_property(): + vector = _DummyVector("collection_1") + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + SimpleNamespace(page_content="b", metadata=None), + Document(page_content="c", metadata={"document_id": "d-1"}), + Document(page_content="d", metadata={"doc_id": "id-2"}), + ] + + assert vector._get_uuids(docs) == ["id-1", "id-2"] + assert vector.collection_name == "collection_1" diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py new file mode 100644 index 00000000000..4e9ceddda9a --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py @@ -0,0 +1,434 @@ +import base64 +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _register_fake_factory_module(monkeypatch, module_path: str, class_name: str): + fake_module = types.ModuleType(module_path) + fake_cls = type(class_name, (), {}) + setattr(fake_module, class_name, fake_cls) + monkeypatch.setitem(sys.modules, module_path, fake_module) + return fake_cls + + +@pytest.fixture +def vector_factory_module(): + import importlib + + import core.rag.datasource.vdb.vector_factory as module + + return importlib.reload(module) + + +def test_gen_index_struct_dict(vector_factory_module): + result = vector_factory_module.AbstractVectorFactory.gen_index_struct_dict( + vector_factory_module.VectorType.WEAVIATE, + "collection_1", + ) + + assert result == { + "type": vector_factory_module.VectorType.WEAVIATE, + "vector_store": {"class_prefix": "collection_1"}, + } + + +@pytest.mark.parametrize( + ("vector_type", "module_path", "class_name"), + [ + ("CHROMA", "core.rag.datasource.vdb.chroma.chroma_vector", "ChromaVectorFactory"), + ("MILVUS", "core.rag.datasource.vdb.milvus.milvus_vector", "MilvusVectorFactory"), + ( + "ALIBABACLOUD_MYSQL", + "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector", + "AlibabaCloudMySQLVectorFactory", + ), + ("MYSCALE", "core.rag.datasource.vdb.myscale.myscale_vector", "MyScaleVectorFactory"), + ("PGVECTOR", "core.rag.datasource.vdb.pgvector.pgvector", "PGVectorFactory"), + ("VASTBASE", "core.rag.datasource.vdb.pyvastbase.vastbase_vector", "VastbaseVectorFactory"), + ("PGVECTO_RS", "core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs", "PGVectoRSFactory"), + ("QDRANT", "core.rag.datasource.vdb.qdrant.qdrant_vector", "QdrantVectorFactory"), + ("RELYT", "core.rag.datasource.vdb.relyt.relyt_vector", "RelytVectorFactory"), + ( + "ELASTICSEARCH", + "core.rag.datasource.vdb.elasticsearch.elasticsearch_vector", + "ElasticSearchVectorFactory", + ), + ( + "ELASTICSEARCH_JA", + "core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector", + "ElasticSearchJaVectorFactory", + ), + ("TIDB_VECTOR", "core.rag.datasource.vdb.tidb_vector.tidb_vector", "TiDBVectorFactory"), + ("WEAVIATE", "core.rag.datasource.vdb.weaviate.weaviate_vector", "WeaviateVectorFactory"), + ("TENCENT", "core.rag.datasource.vdb.tencent.tencent_vector", "TencentVectorFactory"), + ("ORACLE", "core.rag.datasource.vdb.oracle.oraclevector", "OracleVectorFactory"), + ( + "OPENSEARCH", + "core.rag.datasource.vdb.opensearch.opensearch_vector", + "OpenSearchVectorFactory", + ), + ("ANALYTICDB", "core.rag.datasource.vdb.analyticdb.analyticdb_vector", "AnalyticdbVectorFactory"), + ("COUCHBASE", "core.rag.datasource.vdb.couchbase.couchbase_vector", "CouchbaseVectorFactory"), + ("BAIDU", "core.rag.datasource.vdb.baidu.baidu_vector", "BaiduVectorFactory"), + ("VIKINGDB", "core.rag.datasource.vdb.vikingdb.vikingdb_vector", "VikingDBVectorFactory"), + ("UPSTASH", "core.rag.datasource.vdb.upstash.upstash_vector", "UpstashVectorFactory"), + ( + "TIDB_ON_QDRANT", + "core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector", + "TidbOnQdrantVectorFactory", + ), + ("LINDORM", "core.rag.datasource.vdb.lindorm.lindorm_vector", "LindormVectorStoreFactory"), + ("OCEANBASE", "core.rag.datasource.vdb.oceanbase.oceanbase_vector", "OceanBaseVectorFactory"), + ("SEEKDB", "core.rag.datasource.vdb.oceanbase.oceanbase_vector", "OceanBaseVectorFactory"), + ("OPENGAUSS", "core.rag.datasource.vdb.opengauss.opengauss", "OpenGaussFactory"), + ("TABLESTORE", "core.rag.datasource.vdb.tablestore.tablestore_vector", "TableStoreVectorFactory"), + ( + "HUAWEI_CLOUD", + "core.rag.datasource.vdb.huawei.huawei_cloud_vector", + "HuaweiCloudVectorFactory", + ), + ("MATRIXONE", "core.rag.datasource.vdb.matrixone.matrixone_vector", "MatrixoneVectorFactory"), + ("CLICKZETTA", "core.rag.datasource.vdb.clickzetta.clickzetta_vector", "ClickzettaVectorFactory"), + ("IRIS", "core.rag.datasource.vdb.iris.iris_vector", "IrisVectorFactory"), + ], +) +def test_get_vector_factory_supported(vector_factory_module, monkeypatch, vector_type, module_path, class_name): + expected_cls = _register_fake_factory_module(monkeypatch, module_path, class_name) + + result_cls = vector_factory_module.Vector.get_vector_factory(getattr(vector_factory_module.VectorType, vector_type)) + + assert result_cls is expected_cls + + +def test_get_vector_factory_unsupported(vector_factory_module): + with pytest.raises(ValueError, match="not supported"): + vector_factory_module.Vector.get_vector_factory("unknown") + + +def test_vector_init_uses_default_and_custom_attributes(vector_factory_module): + dataset = SimpleNamespace(id="dataset-1") + + with ( + patch.object(vector_factory_module.Vector, "_get_embeddings", return_value="embeddings"), + patch.object(vector_factory_module.Vector, "_init_vector", return_value="processor"), + ): + default_vector = vector_factory_module.Vector(dataset) + custom_vector = vector_factory_module.Vector(dataset, attributes=["doc_id"]) + + assert default_vector._attributes == ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"] + assert custom_vector._attributes == ["doc_id"] + assert default_vector._embeddings == "embeddings" + assert default_vector._vector_processor == "processor" + + +def test_init_vector_prefers_dataset_index_struct(vector_factory_module, monkeypatch): + calls = {"vector_type": None, "init_args": None} + + class _Factory: + def init_vector(self, dataset, attributes, embeddings): + calls["init_args"] = (dataset, attributes, embeddings) + return "vector-processor" + + monkeypatch.setattr( + vector_factory_module.Vector, + "get_vector_factory", + staticmethod(lambda vector_type: calls.update(vector_type=vector_type) or _Factory), + ) + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._dataset = SimpleNamespace( + index_struct_dict={"type": vector_factory_module.VectorType.UPSTASH}, tenant_id="tenant-1" + ) + vector._attributes = ["doc_id"] + vector._embeddings = "embeddings" + + result = vector._init_vector() + + assert result == "vector-processor" + assert calls["vector_type"] == vector_factory_module.VectorType.UPSTASH + assert calls["init_args"] == (vector._dataset, ["doc_id"], "embeddings") + + +def test_init_vector_uses_whitelist_override(vector_factory_module, monkeypatch): + class _Expr: + def __eq__(self, _other): + return "expr" + + calls = {"vector_type": None} + + class _Factory: + def init_vector(self, dataset, attributes, embeddings): + return "vector-processor" + + monkeypatch.setattr(vector_factory_module, "Whitelist", SimpleNamespace(tenant_id=_Expr(), category=_Expr())) + monkeypatch.setattr(vector_factory_module, "select", lambda _model: SimpleNamespace(where=lambda *_args: "stmt")) + monkeypatch.setattr( + vector_factory_module, + "db", + SimpleNamespace(session=SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(one_or_none=lambda: object()))), + ) + monkeypatch.setattr(vector_factory_module.dify_config, "VECTOR_STORE", vector_factory_module.VectorType.CHROMA) + monkeypatch.setattr(vector_factory_module.dify_config, "VECTOR_STORE_WHITELIST_ENABLE", True) + monkeypatch.setattr( + vector_factory_module.Vector, + "get_vector_factory", + staticmethod(lambda vector_type: calls.update(vector_type=vector_type) or _Factory), + ) + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._dataset = SimpleNamespace(index_struct_dict=None, tenant_id="tenant-1") + vector._attributes = ["doc_id"] + vector._embeddings = "embeddings" + + result = vector._init_vector() + + assert result == "vector-processor" + assert calls["vector_type"] == vector_factory_module.VectorType.TIDB_ON_QDRANT + + +def test_init_vector_raises_when_vector_store_missing(vector_factory_module, monkeypatch): + monkeypatch.setattr(vector_factory_module.dify_config, "VECTOR_STORE", None) + monkeypatch.setattr(vector_factory_module.dify_config, "VECTOR_STORE_WHITELIST_ENABLE", False) + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._dataset = SimpleNamespace(index_struct_dict=None, tenant_id="tenant-1") + vector._attributes = [] + vector._embeddings = "embeddings" + + with pytest.raises(ValueError, match="Vector store must be specified"): + vector._init_vector() + + +def test_create_batches_texts_and_skips_empty_input(vector_factory_module): + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._embeddings = MagicMock() + vector._vector_processor = MagicMock() + + docs = [Document(page_content=f"doc-{i}", metadata={"doc_id": f"id-{i}"}) for i in range(1001)] + vector._embeddings.embed_documents.side_effect = [ + [[0.1] for _ in range(1000)], + [[0.2]], + ] + + vector.create(texts=docs, trace_id="trace-1") + + assert vector._embeddings.embed_documents.call_count == 2 + assert vector._vector_processor.create.call_count == 2 + assert vector._vector_processor.create.call_args_list[0].kwargs["trace_id"] == "trace-1" + + vector._embeddings.embed_documents.reset_mock() + vector._vector_processor.create.reset_mock() + vector.create(texts=None) + vector._embeddings.embed_documents.assert_not_called() + vector._vector_processor.create.assert_not_called() + + +def test_create_multimodal_filters_missing_uploads(vector_factory_module, monkeypatch): + class _Field: + def in_(self, value): + return value + + def __eq__(self, value): + return value + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._embeddings = MagicMock() + vector._embeddings.embed_multimodal_documents.return_value = [[0.1, 0.2]] + vector._vector_processor = MagicMock() + + monkeypatch.setattr(vector_factory_module, "UploadFile", SimpleNamespace(id=_Field())) + monkeypatch.setattr(vector_factory_module, "select", lambda _model: SimpleNamespace(where=lambda *_args: "stmt")) + monkeypatch.setattr( + vector_factory_module, + "db", + SimpleNamespace( + session=SimpleNamespace( + scalars=lambda _stmt: SimpleNamespace(all=lambda: [SimpleNamespace(id="f-1", key="k-1")]) + ) + ), + ) + monkeypatch.setattr(vector_factory_module.storage, "load_once", MagicMock(return_value=b"abc")) + + docs = [ + Document(page_content="file-1", metadata={"doc_id": "f-1", "doc_type": "image"}), + Document(page_content="file-2", metadata={"doc_id": "f-2", "doc_type": "image"}), + ] + + vector.create_multimodal(file_documents=docs, request_id="r-1") + + file_base64 = base64.b64encode(b"abc").decode() + vector._embeddings.embed_multimodal_documents.assert_called_once_with( + [{"content": file_base64, "content_type": "image", "file_id": "f-1"}] + ) + vector._vector_processor.create.assert_called_once_with( + texts=[docs[0]], + embeddings=[[0.1, 0.2]], + request_id="r-1", + ) + + vector._embeddings.embed_multimodal_documents.reset_mock() + vector._vector_processor.create.reset_mock() + vector.create_multimodal(file_documents=None) + vector._embeddings.embed_multimodal_documents.assert_not_called() + vector._vector_processor.create.assert_not_called() + + +def test_add_texts_with_optional_duplicate_check(vector_factory_module): + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._embeddings = MagicMock() + vector._vector_processor = MagicMock() + vector._filter_duplicate_texts = MagicMock() + + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + Document(page_content="b", metadata={"doc_id": "id-2"}), + ] + vector._filter_duplicate_texts.return_value = [docs[0]] + vector._embeddings.embed_documents.return_value = [[0.1]] + + vector.add_texts(docs, duplicate_check=True, flag=True) + + vector._filter_duplicate_texts.assert_called_once_with(docs) + vector._vector_processor.create.assert_called_once_with( + texts=[docs[0]], embeddings=[[0.1]], duplicate_check=True, flag=True + ) + + vector._filter_duplicate_texts.reset_mock() + vector._vector_processor.create.reset_mock() + vector._embeddings.embed_documents.return_value = [[0.2], [0.3]] + + vector.add_texts(docs, duplicate_check=False) + + vector._filter_duplicate_texts.assert_not_called() + vector._vector_processor.create.assert_called_once() + + +def test_vector_delegation_methods(vector_factory_module): + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._embeddings = MagicMock() + vector._embeddings.embed_query.return_value = [0.1, 0.2] + vector._vector_processor = MagicMock() + vector._vector_processor.text_exists.return_value = True + vector._vector_processor.search_by_vector.return_value = ["vector-doc"] + vector._vector_processor.search_by_full_text.return_value = ["text-doc"] + + assert vector.text_exists("doc-1") is True + vector.delete_by_ids(["doc-1"]) + vector.delete_by_metadata_field("doc_id", "doc-1") + assert vector.search_by_vector("hello", top_k=3) == ["vector-doc"] + assert vector.search_by_full_text("hello", top_k=3) == ["text-doc"] + + vector._vector_processor.delete_by_ids.assert_called_once_with(["doc-1"]) + vector._vector_processor.delete_by_metadata_field.assert_called_once_with("doc_id", "doc-1") + + +def test_search_by_file_handles_missing_and_existing_upload(vector_factory_module, monkeypatch): + class _Field: + def __eq__(self, value): + return value + + upload_query = MagicMock() + upload_query.where.return_value = upload_query + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._embeddings = MagicMock() + vector._vector_processor = MagicMock() + + mock_session = SimpleNamespace(get=lambda _model, _id: None) + monkeypatch.setattr(vector_factory_module, "UploadFile", SimpleNamespace(id=_Field())) + monkeypatch.setattr(vector_factory_module, "db", SimpleNamespace(session=mock_session)) + + assert vector.search_by_file("file-1") == [] + + mock_session.get = lambda _model, _id: SimpleNamespace(key="blob-key") + monkeypatch.setattr(vector_factory_module.storage, "load_once", MagicMock(return_value=b"file-bytes")) + vector._embeddings.embed_multimodal_query.return_value = [0.3, 0.4] + vector._vector_processor.search_by_vector.return_value = ["hit"] + + result = vector.search_by_file("file-2", top_k=2) + + assert result == ["hit"] + payload = vector._embeddings.embed_multimodal_query.call_args.args[0] + assert payload["content_type"] == vector_factory_module.DocType.IMAGE + assert payload["file_id"] == "file-2" + + +def test_delete_clears_redis_cache_when_collection_exists(vector_factory_module, monkeypatch): + delete_mock = MagicMock() + redis_delete = MagicMock() + monkeypatch.setattr(vector_factory_module.redis_client, "delete", redis_delete) + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._vector_processor = SimpleNamespace(delete=delete_mock, collection_name="collection_1") + + vector.delete() + + delete_mock.assert_called_once() + redis_delete.assert_called_once_with("vector_indexing_collection_1") + + vector._vector_processor = SimpleNamespace(delete=delete_mock, collection_name="") + redis_delete.reset_mock() + vector.delete() + redis_delete.assert_not_called() + + +def test_get_embeddings_builds_cache_embedding(vector_factory_module, monkeypatch): + model_manager = MagicMock() + model_manager.get_model_instance.return_value = "model-instance" + + for_tenant_mock = MagicMock(return_value=model_manager) + monkeypatch.setattr(vector_factory_module.ModelManager, "for_tenant", for_tenant_mock) + monkeypatch.setattr(vector_factory_module, "CacheEmbedding", MagicMock(return_value="cached-embedding")) + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._dataset = SimpleNamespace( + tenant_id="tenant-1", + embedding_model_provider="openai", + embedding_model="text-embedding-3-small", + ) + + result = vector._get_embeddings() + + assert result == "cached-embedding" + for_tenant_mock.assert_called_once_with(tenant_id="tenant-1") + model_manager.get_model_instance.assert_called_once_with( + tenant_id="tenant-1", + provider="openai", + model_type=vector_factory_module.ModelType.TEXT_EMBEDDING, + model="text-embedding-3-small", + ) + + +def test_filter_duplicate_texts_and_getattr(vector_factory_module): + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector.text_exists = MagicMock(side_effect=lambda doc_id: doc_id == "dup") + + docs = [ + SimpleNamespace(page_content="no-meta", metadata=None), + Document(page_content="empty-doc-id", metadata={"doc_id": ""}), + Document(page_content="duplicate", metadata={"doc_id": "dup"}), + Document(page_content="unique", metadata={"doc_id": "ok"}), + ] + + filtered = vector._filter_duplicate_texts(docs) + assert [doc.page_content for doc in filtered] == ["no-meta", "empty-doc-id", "unique"] + + class _Processor: + def ping(self): + return "pong" + + vector._vector_processor = _Processor() + assert vector.ping() == "pong" + + with pytest.raises(AttributeError): + _ = vector.unknown_method + + vector._vector_processor = None + with pytest.raises(AttributeError, match="vector_processor"): + _ = vector.another_missing diff --git a/api/dify_graph/nodes/answer/__init__.py b/api/tests/unit_tests/core/rag/datasource/vdb/tidb_on_qdrant/__init__.py similarity index 100% rename from api/dify_graph/nodes/answer/__init__.py rename to api/tests/unit_tests/core/rag/datasource/vdb/tidb_on_qdrant/__init__.py diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/tidb_on_qdrant/test_tidb_on_qdrant_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/tidb_on_qdrant/test_tidb_on_qdrant_vector.py new file mode 100644 index 00000000000..c25af79ae4e --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/tidb_on_qdrant/test_tidb_on_qdrant_vector.py @@ -0,0 +1,160 @@ +from unittest.mock import patch + +import httpx +import pytest +from qdrant_client.http import models as rest +from qdrant_client.http.exceptions import UnexpectedResponse + +from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import ( + TidbOnQdrantConfig, + TidbOnQdrantVector, +) + + +class TestTidbOnQdrantVectorDeleteByIds: + """Unit tests for TidbOnQdrantVector.delete_by_ids method.""" + + @pytest.fixture + def vector_instance(self): + """Create a TidbOnQdrantVector instance for testing.""" + config = TidbOnQdrantConfig( + endpoint="http://localhost:6333", + api_key="test_api_key", + ) + + with patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector.qdrant_client.QdrantClient"): + vector = TidbOnQdrantVector( + collection_name="test_collection", + group_id="test_group", + config=config, + ) + return vector + + def test_delete_by_ids_with_multiple_ids(self, vector_instance): + """Test batch deletion with multiple document IDs.""" + ids = ["doc1", "doc2", "doc3"] + + vector_instance.delete_by_ids(ids) + + # Verify that delete was called once with MatchAny filter + vector_instance._client.delete.assert_called_once() + call_args = vector_instance._client.delete.call_args + + # Check collection name + assert call_args[1]["collection_name"] == "test_collection" + + # Verify filter uses MatchAny with all IDs + filter_selector = call_args[1]["points_selector"] + filter_obj = filter_selector.filter + assert len(filter_obj.must) == 1 + + field_condition = filter_obj.must[0] + assert field_condition.key == "metadata.doc_id" + assert isinstance(field_condition.match, rest.MatchAny) + assert set(field_condition.match.any) == {"doc1", "doc2", "doc3"} + + def test_delete_by_ids_with_single_id(self, vector_instance): + """Test deletion with a single document ID.""" + ids = ["doc1"] + + vector_instance.delete_by_ids(ids) + + # Verify that delete was called once + vector_instance._client.delete.assert_called_once() + call_args = vector_instance._client.delete.call_args + + # Verify filter uses MatchAny with single ID + filter_selector = call_args[1]["points_selector"] + filter_obj = filter_selector.filter + field_condition = filter_obj.must[0] + assert isinstance(field_condition.match, rest.MatchAny) + assert field_condition.match.any == ["doc1"] + + def test_delete_by_ids_with_empty_list(self, vector_instance): + """Test deletion with empty ID list returns early without API call.""" + vector_instance.delete_by_ids([]) + + # Verify that delete was NOT called + vector_instance._client.delete.assert_not_called() + + def test_delete_by_ids_with_404_error(self, vector_instance): + """Test that 404 errors (collection not found) are handled gracefully.""" + ids = ["doc1", "doc2"] + + # Mock a 404 error + error = UnexpectedResponse( + status_code=404, + reason_phrase="Not Found", + content=b"Collection not found", + headers=httpx.Headers(), + ) + vector_instance._client.delete.side_effect = error + + # Should not raise an exception + vector_instance.delete_by_ids(ids) + + # Verify delete was called + vector_instance._client.delete.assert_called_once() + + def test_delete_by_ids_with_unexpected_error(self, vector_instance): + """Test that non-404 errors are re-raised.""" + ids = ["doc1", "doc2"] + + # Mock a 500 error + error = UnexpectedResponse( + status_code=500, + reason_phrase="Internal Server Error", + content=b"Server error", + headers=httpx.Headers(), + ) + vector_instance._client.delete.side_effect = error + + # Should re-raise the exception + with pytest.raises(UnexpectedResponse) as exc_info: + vector_instance.delete_by_ids(ids) + + assert exc_info.value.status_code == 500 + + def test_delete_by_ids_with_large_batch(self, vector_instance): + """Test deletion with a large batch of IDs.""" + # Create 1000 IDs + ids = [f"doc_{i}" for i in range(1000)] + + vector_instance.delete_by_ids(ids) + + # Verify single delete call with all IDs + vector_instance._client.delete.assert_called_once() + call_args = vector_instance._client.delete.call_args + + filter_selector = call_args[1]["points_selector"] + filter_obj = filter_selector.filter + field_condition = filter_obj.must[0] + + # Verify all 1000 IDs are in the batch + assert len(field_condition.match.any) == 1000 + assert "doc_0" in field_condition.match.any + assert "doc_999" in field_condition.match.any + + def test_delete_by_ids_filter_structure(self, vector_instance): + """Test that the filter structure is correctly constructed.""" + ids = ["doc1", "doc2"] + + vector_instance.delete_by_ids(ids) + + call_args = vector_instance._client.delete.call_args + filter_selector = call_args[1]["points_selector"] + filter_obj = filter_selector.filter + + # Verify Filter structure + assert isinstance(filter_obj, rest.Filter) + assert filter_obj.must is not None + assert len(filter_obj.must) == 1 + + # Verify FieldCondition structure + field_condition = filter_obj.must[0] + assert isinstance(field_condition, rest.FieldCondition) + assert field_condition.key == "metadata.doc_id" + + # Verify MatchAny structure + assert isinstance(field_condition.match, rest.MatchAny) + assert field_condition.match.any == ids diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/tidb_vector/test_tidb_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/tidb_vector/test_tidb_vector.py new file mode 100644 index 00000000000..951a920f3bc --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/tidb_vector/test_tidb_vector.py @@ -0,0 +1,443 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +@pytest.fixture +def tidb_module(): + import core.rag.datasource.vdb.tidb_vector.tidb_vector as module + + return importlib.reload(module) + + +def _config(tidb_module): + return tidb_module.TiDBVectorConfig( + host="localhost", + port=4000, + user="root", + password="secret", + database="dify", + program_name="dify-app", + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config TIDB_VECTOR_HOST is required"), + ("port", 0, "config TIDB_VECTOR_PORT is required"), + ("user", "", "config TIDB_VECTOR_USER is required"), + ("database", "", "config TIDB_VECTOR_DATABASE is required"), + ("program_name", "", "config APPLICATION_NAME is required"), + ], +) +def test_tidb_config_validation(tidb_module, field, value, message): + values = _config(tidb_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + tidb_module.TiDBVectorConfig.model_validate(values) + + +def test_init_get_type_and_distance_func(tidb_module, monkeypatch): + monkeypatch.setattr(tidb_module, "create_engine", MagicMock(return_value="engine")) + + vector = tidb_module.TiDBVector("collection_1", _config(tidb_module), distance_func="L2") + + assert vector.get_type() == tidb_module.VectorType.TIDB_VECTOR + assert vector._url.startswith("mysql+pymysql://root:secret@localhost:4000/dify") + assert vector._dimension == 1536 + assert vector._get_distance_func() == "VEC_L2_DISTANCE" + + vector._distance_func = "cosine" + assert vector._get_distance_func() == "VEC_COSINE_DISTANCE" + + vector._distance_func = "other" + assert vector._get_distance_func() == "VEC_COSINE_DISTANCE" + + +def test_table_builds_columns_with_tidb_vector_type(tidb_module, monkeypatch): + fake_tidb_vector = types.ModuleType("tidb_vector") + fake_tidb_sqlalchemy = types.ModuleType("tidb_vector.sqlalchemy") + + class _VectorType: + def __init__(self, dim): + self.dim = dim + + fake_tidb_sqlalchemy.VectorType = _VectorType + + monkeypatch.setitem(sys.modules, "tidb_vector", fake_tidb_vector) + monkeypatch.setitem(sys.modules, "tidb_vector.sqlalchemy", fake_tidb_sqlalchemy) + monkeypatch.setattr(tidb_module, "create_engine", MagicMock(return_value=MagicMock())) + monkeypatch.setattr(tidb_module, "Column", lambda *args, **kwargs: SimpleNamespace(args=args, kwargs=kwargs)) + monkeypatch.setattr( + tidb_module, + "Table", + lambda name, _metadata, *columns, **_kwargs: SimpleNamespace(name=name, columns=columns), + ) + + vector = tidb_module.TiDBVector("collection_1", _config(tidb_module)) + table = vector._table(3) + + assert table.name == "collection_1" + column_names = [column.args[0] for column in table.columns] + assert tidb_module.Field.PRIMARY_KEY in column_names + assert tidb_module.Field.VECTOR in column_names + assert tidb_module.Field.TEXT_KEY in column_names + + +def test_create_calls_collection_and_add_texts(tidb_module): + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._create_collection = MagicMock() + vector.add_texts = MagicMock() + + docs = [Document(page_content="a", metadata={"doc_id": "id-1"})] + vector.create(docs, [[0.1, 0.2]]) + + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + assert vector._dimension == 2 + + +def test_create_collection_skips_when_cache_hit(tidb_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(tidb_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(tidb_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(tidb_module.redis_client, "set", MagicMock()) + + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._engine = MagicMock() + + tidb_module.Session = MagicMock() + + vector._create_collection(3) + + tidb_module.Session.assert_not_called() + tidb_module.redis_client.set.assert_not_called() + + +def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(tidb_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(tidb_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(tidb_module.redis_client, "set", MagicMock()) + + session = MagicMock() + + class _SessionCtx: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx()) + + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._engine = MagicMock() + vector._distance_func = "l2" + + vector._create_collection(3) + + session.begin.assert_called_once() + sql = str(session.execute.call_args.args[0]) + assert "VECTOR(3)" in sql + assert "VEC_L2_DISTANCE" in sql + session.commit.assert_called_once() + tidb_module.redis_client.set.assert_called_once() + + +def test_add_texts_batches_inserts_and_returns_ids(tidb_module, monkeypatch): + class _InsertStmt: + def __init__(self, table): + self.table = table + + def values(self, rows): + return {"table": self.table, "rows": rows} + + monkeypatch.setattr(tidb_module, "insert", lambda table: _InsertStmt(table)) + + conn = MagicMock() + transaction = MagicMock() + transaction.__enter__.return_value = None + transaction.__exit__.return_value = None + conn.begin.return_value = transaction + + connection_ctx = MagicMock() + connection_ctx.__enter__.return_value = conn + connection_ctx.__exit__.return_value = None + + engine = MagicMock() + engine.connect.return_value = connection_ctx + + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._engine = engine + vector._table = MagicMock(return_value="table") + + docs = [Document(page_content=f"text-{i}", metadata={"doc_id": f"id-{i}"}) for i in range(501)] + embeddings = [[float(i)] for i in range(501)] + + ids = vector.add_texts(docs, embeddings) + + assert ids[0] == "id-0" + assert len(ids) == 501 + assert conn.execute.call_count == 2 + + +@pytest.fixture +def tidb_vector_with_session(tidb_module, monkeypatch): + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._engine = MagicMock() + session = MagicMock() + + class _SessionCtx: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx()) + return vector, session, tidb_module + + +# 1. search_by_full_text returns empty +def test_search_by_full_text_returns_empty(tidb_vector_with_session): + vector, _, _ = tidb_vector_with_session + assert vector.search_by_full_text("query") == [] + + +# 2. text_exists returns True when ids found +def test_text_exists_returns_true_when_ids_found(tidb_vector_with_session): + vector, _, _ = tidb_vector_with_session + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"]) + assert vector.text_exists("doc-1") is True + + +# 3. text_exists returns False when no ids +def test_text_exists_returns_false_when_no_ids(tidb_vector_with_session): + vector, _, _ = tidb_vector_with_session + vector.get_ids_by_metadata_field = MagicMock(return_value=None) + assert vector.text_exists("doc-1") is False + + +# 4. delete_by_ids delegates to _delete_by_ids when ids found +def test_delete_by_ids_delegates_to_internal_delete(tidb_vector_with_session): + vector, session, tidb_module = tidb_vector_with_session + session.execute.return_value.fetchall.return_value = [("id-a",), ("id-b",)] + vector._delete_by_ids = MagicMock() + # Use real get_ids_by_metadata_field + vector.get_ids_by_metadata_field = tidb_module.TiDBVector.get_ids_by_metadata_field.__get__( + vector, tidb_module.TiDBVector + ) + vector.delete_by_ids(["doc-a", "doc-b"]) + vector._delete_by_ids.assert_called_once_with(["id-a", "id-b"]) + + +# 5. delete_by_ids skips when no ids found +def test_delete_by_ids_skips_when_no_ids_found(tidb_vector_with_session): + vector, session, tidb_module = tidb_vector_with_session + session.execute.return_value.fetchall.return_value = [] + vector._delete_by_ids = MagicMock() + # Use real get_ids_by_metadata_field + vector.get_ids_by_metadata_field = tidb_module.TiDBVector.get_ids_by_metadata_field.__get__( + vector, tidb_module.TiDBVector + ) + vector.delete_by_ids(["doc-c"]) + vector._delete_by_ids.assert_not_called() + + +# 6. get_ids_by_metadata_field returns ids and returns None +def test_get_ids_by_metadata_field_returns_ids_and_returns_none(tidb_vector_with_session): + vector, session, tidb_module = tidb_vector_with_session + # Returns ids + session.execute.return_value.fetchall.return_value = [("id-1",)] + assert vector.get_ids_by_metadata_field("doc_id", "doc-1") == ["id-1"] + # Returns None + session.execute.return_value.fetchall.return_value = [] + assert vector.get_ids_by_metadata_field("doc_id", "doc-1") is None + + +# 1. _delete_by_ids raises on None +def test__delete_by_ids_raises_on_none(tidb_module): + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + with pytest.raises(ValueError, match="No ids provided"): + vector._delete_by_ids(None) + + +# 2. _delete_by_ids returns True and calls execute +def test__delete_by_ids_returns_true_and_calls_execute(tidb_module): + class _IDColumn: + def in_(self, ids): + return ids + + class _Delete: + def where(self, condition): + return condition + + table = SimpleNamespace(c=SimpleNamespace(id=_IDColumn()), delete=lambda: _Delete()) + conn = MagicMock() + tx = MagicMock() + tx.__enter__.return_value = None + tx.__exit__.return_value = None + conn.begin.return_value = tx + conn_ctx = MagicMock() + conn_ctx.__enter__.return_value = conn + conn_ctx.__exit__.return_value = None + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._dimension = 2 + vector._engine = SimpleNamespace(connect=MagicMock(return_value=conn_ctx)) + vector._table = MagicMock(return_value=table) + assert vector._delete_by_ids(["id-1"]) is True + conn.execute.assert_called_once() + + +# 3. _delete_by_ids returns False on RuntimeError +def test__delete_by_ids_returns_false_on_runtime_error(tidb_module): + class _IDColumn: + def in_(self, ids): + return ids + + class _Delete: + def where(self, condition): + return condition + + table = SimpleNamespace(c=SimpleNamespace(id=_IDColumn()), delete=lambda: _Delete()) + conn = MagicMock() + tx = MagicMock() + tx.__enter__.return_value = None + tx.__exit__.return_value = None + conn.begin.return_value = tx + conn_ctx = MagicMock() + conn_ctx.__enter__.return_value = conn + conn_ctx.__exit__.return_value = None + conn.execute.side_effect = RuntimeError("delete failed") + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._dimension = 2 + vector._engine = SimpleNamespace(connect=MagicMock(return_value=conn_ctx)) + vector._table = MagicMock(return_value=table) + assert vector._delete_by_ids(["id-2"]) is False + + +# 4. delete_by_metadata_field calls _delete_by_ids when ids found +def test_delete_by_metadata_field_calls__delete_by_ids_when_ids_found(tidb_module): + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-3"]) + vector._delete_by_ids = MagicMock() + vector.delete_by_metadata_field("doc_id", "doc-3") + vector._delete_by_ids.assert_called_once_with(["id-3"]) + + +# 5. delete_by_metadata_field does nothing when no ids +def test_delete_by_metadata_field_does_nothing_when_no_ids(tidb_module): + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector.get_ids_by_metadata_field = MagicMock(return_value=[]) + vector._delete_by_ids = MagicMock() + vector.delete_by_metadata_field("doc_id", "doc-4") + vector._delete_by_ids.assert_not_called() + + +# Test search_by_vector filters and scores +def test_search_by_vector_filters_and_scores(tidb_module, monkeypatch): + session = MagicMock() + session.execute.return_value = [ + ('{"doc_id":"id-1","document_id":"d-1"}', "text-1", 0.2), + ('{"doc_id":"id-2","document_id":"d-2"}', "text-2", 0.4), + ] + session.commit = MagicMock() + + class _SessionCtx: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx()) + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._engine = MagicMock() + vector._distance_func = "cosine" + docs = vector.search_by_vector( + [0.1, 0.2], + top_k=2, + score_threshold=0.5, + document_ids_filter=["d-1", "d-2"], + ) + assert len(docs) == 2 + assert docs[0].metadata["score"] == pytest.approx(0.8) + assert docs[1].metadata["score"] == pytest.approx(0.6) + sql = str(session.execute.call_args.args[0]) + params = session.execute.call_args.kwargs["params"] + assert "meta->>'$.document_id' in ('d-1', 'd-2')" in sql + assert params["distance"] == pytest.approx(0.5) + assert params["top_k"] == 2 + session.commit.assert_not_called() + + +# Test delete drops table +def test_delete_drops_table(tidb_module, monkeypatch): + session = MagicMock() + session.execute.return_value = None + session.commit = MagicMock() + + class _SessionCtx: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx()) + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._engine = MagicMock() + vector.delete() + drop_sql = str(session.execute.call_args.args[0]) + assert "DROP TABLE IF EXISTS collection_1" in drop_sql + session.commit.assert_called_once() + + +def test_tidb_factory_uses_existing_or_generated_collection(tidb_module, monkeypatch): + factory = tidb_module.TiDBVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(tidb_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_HOST", "localhost") + monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_PORT", 4000) + monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_USER", "root") + monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_PASSWORD", "secret") + monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_DATABASE", "dify") + monkeypatch.setattr(tidb_module.dify_config, "APPLICATION_NAME", "dify-app") + + with patch.object(tidb_module, "TiDBVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/upstash/test_upstash_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/upstash/test_upstash_vector.py new file mode 100644 index 00000000000..ac8a63a44ba --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/upstash/test_upstash_vector.py @@ -0,0 +1,186 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_upstash_module(): + upstash_module = types.ModuleType("upstash_vector") + + class Vector: + def __init__(self, id, vector, metadata, data): + self.id = id + self.vector = vector + self.metadata = metadata + self.data = data + + class Index: + def __init__(self, url, token): + self.url = url + self.token = token + self.info = MagicMock(return_value=SimpleNamespace(dimension=8)) + self.upsert = MagicMock() + self.query = MagicMock(return_value=[]) + self.delete = MagicMock() + self.reset = MagicMock() + + upstash_module.Vector = Vector + upstash_module.Index = Index + return upstash_module + + +@pytest.fixture +def upstash_module(monkeypatch): + # Remove patched modules if present + for modname in ["upstash_vector", "core.rag.datasource.vdb.upstash.upstash_vector"]: + if modname in sys.modules: + monkeypatch.delitem(sys.modules, modname, raising=False) + monkeypatch.setitem(sys.modules, "upstash_vector", _build_fake_upstash_module()) + module = importlib.import_module("core.rag.datasource.vdb.upstash.upstash_vector") + return module + + +def _config(module): + return module.UpstashVectorConfig(url="https://upstash.example", token="token-123") + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("url", "", "Upstash URL is required"), + ("token", "", "Upstash Token is required"), + ], +) +def test_upstash_config_validation(upstash_module, field, value, message): + values = _config(upstash_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + upstash_module.UpstashVectorConfig.model_validate(values) + + +def test_init_get_type_and_dimension(upstash_module, monkeypatch): + vector = upstash_module.UpstashVector("collection_1", _config(upstash_module)) + + assert vector.get_type() == upstash_module.VectorType.UPSTASH + assert vector._table_name == "collection_1" + assert vector._get_index_dimension() == 8 + + vector.index.info.return_value = SimpleNamespace(dimension=None) + assert vector._get_index_dimension() == 1536 + + vector.index.info.return_value = None + assert vector._get_index_dimension() == 1536 + + monkeypatch.setattr(upstash_module, "uuid4", lambda: "generated-uuid") + docs = [Document(page_content="hello", metadata={"doc_id": "id-1"})] + vector.add_texts(docs, [[0.1, 0.2]]) + + vector.index.upsert.assert_called_once() + upsert_vectors = vector.index.upsert.call_args.kwargs["vectors"] + assert upsert_vectors[0].id == "generated-uuid" + + +def test_create_text_exists_and_delete_by_ids(upstash_module): + vector = upstash_module.UpstashVector("collection_1", _config(upstash_module)) + vector.add_texts = MagicMock() + + docs = [Document(page_content="hello", metadata={"doc_id": "id-1"})] + vector.create(docs, [[0.1]]) + vector.add_texts.assert_called_once_with(docs, [[0.1]]) + + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"]) + assert vector.text_exists("doc-1") is True + vector.get_ids_by_metadata_field.return_value = [] + assert vector.text_exists("doc-1") is False + + vector.get_ids_by_metadata_field = MagicMock(side_effect=[["item-1"], [], ["item-2"]]) + vector._delete_by_ids = MagicMock() + vector.delete_by_ids(["doc-1", "doc-2", "doc-3"]) + vector._delete_by_ids.assert_called_once_with(ids=["item-1", "item-2"]) + + +def test_delete_helpers_and_search(upstash_module): + vector = upstash_module.UpstashVector("collection_1", _config(upstash_module)) + + vector._delete_by_ids([]) + vector.index.delete.assert_not_called() + vector._delete_by_ids(["a", "b"]) + vector.index.delete.assert_called_once_with(ids=["a", "b"]) + + vector.index.query.return_value = [SimpleNamespace(id="x-1"), SimpleNamespace(id="x-2")] + ids = vector.get_ids_by_metadata_field("doc_id", "doc-1") + assert ids == ["x-1", "x-2"] + query_kwargs = vector.index.query.call_args.kwargs + assert query_kwargs["top_k"] == 1000 + assert query_kwargs["filter"] == "doc_id = 'doc-1'" + + vector._delete_by_ids = MagicMock() + vector.get_ids_by_metadata_field = MagicMock(return_value=["x-1"]) + vector.delete_by_metadata_field("doc_id", "doc-1") + vector._delete_by_ids.assert_called_once_with(["x-1"]) + + vector._delete_by_ids.reset_mock() + vector.get_ids_by_metadata_field.return_value = [] + vector.delete_by_metadata_field("doc_id", "doc-2") + vector._delete_by_ids.assert_not_called() + + +def test_search_by_vector_filter_threshold_and_delete(upstash_module): + vector = upstash_module.UpstashVector("collection_1", _config(upstash_module)) + vector.index.query.return_value = [ + SimpleNamespace(metadata={"document_id": "d-1"}, data="text-1", score=0.9), + SimpleNamespace(metadata={"document_id": "d-2"}, data="text-2", score=0.3), + SimpleNamespace(metadata=None, data="text-3", score=0.99), + SimpleNamespace(metadata={"document_id": "d-4"}, data=None, score=0.99), + ] + + docs = vector.search_by_vector( + [0.1, 0.2], + top_k=3, + score_threshold=0.5, + document_ids_filter=["d-1", "d-2"], + ) + + assert len(docs) == 1 + assert docs[0].page_content == "text-1" + assert docs[0].metadata["score"] == pytest.approx(0.9) + + search_kwargs = vector.index.query.call_args.kwargs + assert search_kwargs["top_k"] == 3 + assert search_kwargs["filter"] == "document_id in ('d-1', 'd-2')" + + assert vector.search_by_full_text("query") == [] + + vector.delete() + vector.index.reset.assert_called_once() + + +def test_upstash_factory_uses_existing_or_generated_collection(upstash_module, monkeypatch): + factory = upstash_module.UpstashVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(upstash_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(upstash_module.dify_config, "UPSTASH_VECTOR_URL", "https://upstash.example") + monkeypatch.setattr(upstash_module.dify_config, "UPSTASH_VECTOR_TOKEN", "token-123") + + with patch.object(upstash_module, "UpstashVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/vikingdb/test_vikingdb_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/vikingdb/test_vikingdb_vector.py new file mode 100644 index 00000000000..9da92af2d02 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/vikingdb/test_vikingdb_vector.py @@ -0,0 +1,310 @@ +import importlib +import json +import sys +import types +from collections import UserDict +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_vikingdb_modules(): + volcengine = types.ModuleType("volcengine") + volcengine.__path__ = [] + viking_db = types.ModuleType("volcengine.viking_db") + + class Data(UserDict): + def __init__(self, payload): + super().__init__(payload) + self.fields = payload + + class DistanceType: + L2 = "L2" + + class IndexType: + HNSW = "HNSW" + + class QuantType: + Float = "Float" + + class FieldType: + String = "string" + Text = "text" + Vector = "vector" + + class Field: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class VectorIndexParams: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class _Collection: + def __init__(self): + self.upsert_data = MagicMock() + self.fetch_data = MagicMock(return_value=None) + self.delete_data = MagicMock() + + class _Index: + def __init__(self): + self.search = MagicMock(return_value=[]) + self.search_by_vector = MagicMock(return_value=[]) + + class VikingDBService: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.create_collection = MagicMock() + self.create_index = MagicMock() + self.drop_index = MagicMock() + self.drop_collection = MagicMock() + self._collection = _Collection() + self._index = _Index() + self.get_collection = MagicMock(return_value=self._collection) + self.get_index = MagicMock(return_value=self._index) + + viking_db.Data = Data + viking_db.DistanceType = DistanceType + viking_db.Field = Field + viking_db.FieldType = FieldType + viking_db.IndexType = IndexType + viking_db.QuantType = QuantType + viking_db.VectorIndexParams = VectorIndexParams + viking_db.VikingDBService = VikingDBService + + return {"volcengine": volcengine, "volcengine.viking_db": viking_db} + + +@pytest.fixture +def vikingdb_module(monkeypatch): + for name, module in _build_fake_vikingdb_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.vikingdb.vikingdb_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.VikingDBConfig( + access_key="ak", + secret_key="sk", + host="host", + region="region", + scheme="https", + connection_timeout=10, + socket_timeout=20, + ) + + +def test_init_get_type_and_has_checks(vikingdb_module): + vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module)) + + assert vector.get_type() == vikingdb_module.VectorType.VIKINGDB + assert vector._index_name == "collection_1_idx" + + assert vector._has_collection() is True + assert vector._has_index() is True + + vector._client.get_collection.side_effect = RuntimeError("missing") + assert vector._has_collection() is False + vector._client.get_collection.side_effect = None + + vector._client.get_index.side_effect = RuntimeError("missing") + assert vector._has_index() is False + + +def test_create_collection_cache_and_creation_paths(vikingdb_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(vikingdb_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(vikingdb_module.redis_client, "set", MagicMock()) + + vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module)) + + monkeypatch.setattr(vikingdb_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection(3) + vector._client.create_collection.assert_not_called() + vector._client.create_index.assert_not_called() + + monkeypatch.setattr(vikingdb_module.redis_client, "get", MagicMock(return_value=None)) + vector._has_collection = MagicMock(return_value=False) + vector._has_index = MagicMock(return_value=False) + vector._create_collection(4) + + vector._client.create_collection.assert_called_once() + vector._client.create_index.assert_called_once() + vikingdb_module.redis_client.set.assert_called_once() + + +def test_create_and_add_texts(vikingdb_module): + vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module)) + vector._create_collection = MagicMock() + vector.add_texts = MagicMock() + + docs = [Document(page_content="hello", metadata={"doc_id": "id-1"})] + vector.create(docs, [[0.1, 0.2]]) + + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + vector = vikingdb_module.VikingDBVector("collection_2", "group-2", _config(vikingdb_module)) + docs = [ + Document(page_content="a", metadata={"doc_id": "id-a", "document_id": "d-1"}), + Document(page_content="b", metadata={"doc_id": "id-b", "document_id": "d-2"}), + ] + vector.add_texts(docs, [[0.1], [0.2]]) + + vector._client.get_collection.assert_called() + upsert_docs = vector._client.get_collection.return_value.upsert_data.call_args.args[0] + assert upsert_docs[0][vikingdb_module.vdb_Field.PRIMARY_KEY] == "id-a" + assert upsert_docs[0][vikingdb_module.vdb_Field.GROUP_KEY] == "group-2" + + +def test_text_exists_and_delete_operations(vikingdb_module): + vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module)) + + vector._client.get_collection.return_value.fetch_data.return_value = SimpleNamespace(fields={"message": "ok"}) + assert vector.text_exists("id-1") is True + + vector._client.get_collection.return_value.fetch_data.return_value = SimpleNamespace( + fields={"message": "data does not exist"} + ) + assert vector.text_exists("id-1") is False + + vector._client.get_collection.return_value.fetch_data.return_value = None + assert vector.text_exists("id-1") is False + + vector.delete_by_ids(["id-1"]) + vector._client.get_collection.return_value.delete_data.assert_called_once_with(["id-1"]) + + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-2"]) + vector.delete_by_ids = MagicMock() + vector.delete_by_metadata_field("doc_id", "doc-1") + vector.delete_by_ids.assert_called_once_with(["id-2"]) + + +def test_get_ids_and_search_helpers(vikingdb_module): + vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module)) + + vector._client.get_index.return_value.search.return_value = [] + assert vector.get_ids_by_metadata_field("doc_id", "x") == [] + + vector._client.get_index.return_value.search.return_value = [ + SimpleNamespace(id="a", fields={vikingdb_module.vdb_Field.METADATA_KEY: json.dumps({"doc_id": "x"})}), + SimpleNamespace(id="b", fields={vikingdb_module.vdb_Field.METADATA_KEY: json.dumps({"doc_id": "y"})}), + SimpleNamespace(id="c", fields={}), + ] + assert vector.get_ids_by_metadata_field("doc_id", "x") == ["a"] + + empty_docs = vector._get_search_res([], score_threshold=0.1) + assert empty_docs == [] + + results = [ + SimpleNamespace( + id="a", + score=0.3, + fields={ + vikingdb_module.vdb_Field.CONTENT_KEY: "doc-a", + vikingdb_module.vdb_Field.METADATA_KEY: json.dumps({"document_id": "d-1"}), + }, + ), + SimpleNamespace( + id="b", + score=0.9, + fields={ + vikingdb_module.vdb_Field.CONTENT_KEY: "doc-b", + vikingdb_module.vdb_Field.METADATA_KEY: json.dumps({"document_id": "d-2"}), + }, + ), + ] + + docs = vector._get_search_res(results, score_threshold=0.2) + assert [doc.page_content for doc in docs] == ["doc-b", "doc-a"] + + vector._client.get_index.return_value.search_by_vector.return_value = results + filtered_docs = vector.search_by_vector([0.1], top_k=2, score_threshold=0.2, document_ids_filter=["d-2"]) + assert len(filtered_docs) == 1 + assert filtered_docs[0].page_content == "doc-b" + assert vector.search_by_full_text("query") == [] + + +def test_delete_drops_index_and_collection_when_present(vikingdb_module): + vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module)) + vector._has_index = MagicMock(return_value=True) + vector._has_collection = MagicMock(return_value=True) + + vector.delete() + + vector._client.drop_index.assert_called_once_with("collection_1", "collection_1_idx") + vector._client.drop_collection.assert_called_once_with("collection_1") + + vector._client.drop_index.reset_mock() + vector._client.drop_collection.reset_mock() + vector._has_index.return_value = False + vector._has_collection.return_value = False + vector.delete() + + vector._client.drop_index.assert_not_called() + vector._client.drop_collection.assert_not_called() + + +def test_vikingdb_factory_validates_config_and_builds_vector(vikingdb_module, monkeypatch): + factory = vikingdb_module.VikingDBVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(vikingdb_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + + with patch.object(vikingdb_module, "VikingDBVector", return_value="vector") as vector_cls: + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_ACCESS_KEY", "ak") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SECRET_KEY", "sk") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_HOST", "host") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_REGION", "region") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SCHEME", "https") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_CONNECTION_TIMEOUT", 10) + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SOCKET_TIMEOUT", 20) + + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None + + +@pytest.mark.parametrize( + ("field", "message"), + [ + ("VIKINGDB_ACCESS_KEY", "VIKINGDB_ACCESS_KEY should not be None"), + ("VIKINGDB_SECRET_KEY", "VIKINGDB_SECRET_KEY should not be None"), + ("VIKINGDB_HOST", "VIKINGDB_HOST should not be None"), + ("VIKINGDB_REGION", "VIKINGDB_REGION should not be None"), + ("VIKINGDB_SCHEME", "VIKINGDB_SCHEME should not be None"), + ], +) +def test_vikingdb_factory_raises_when_required_config_missing(vikingdb_module, monkeypatch, field, message): + factory = vikingdb_module.VikingDBVectorFactory() + dataset = SimpleNamespace( + id="dataset-1", index_struct_dict={"vector_store": {"class_prefix": "existing"}}, index_struct=None + ) + + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_ACCESS_KEY", "ak") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SECRET_KEY", "sk") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_HOST", "host") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_REGION", "region") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SCHEME", "https") + monkeypatch.setattr(vikingdb_module.dify_config, field, None) + + with pytest.raises(ValueError, match=message): + factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weaviate_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weaviate_vector.py index 3bd656ba848..69d18330011 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weaviate_vector.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weaviate_vector.py @@ -7,10 +7,14 @@ Focuses on verifying that doc_type is properly handled in: - Full-text search result metadata (search_by_full_text) """ +import datetime +import json import unittest from types import SimpleNamespace from unittest.mock import MagicMock, patch +import pytest + from core.rag.datasource.vdb.weaviate import weaviate_vector as weaviate_vector_module from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector from core.rag.models.document import Document @@ -32,6 +36,10 @@ class TestWeaviateVector(unittest.TestCase): def tearDown(self): weaviate_vector_module._weaviate_client = None + def test_config_requires_endpoint(self): + with pytest.raises(ValueError, match="config WEAVIATE_ENDPOINT is required"): + WeaviateConfig(endpoint="") + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def _create_weaviate_vector(self, mock_weaviate_module): """Helper to create a WeaviateVector instance with mocked client.""" @@ -46,6 +54,85 @@ class TestWeaviateVector(unittest.TestCase): ) return wv, mock_client + def test_shutdown_client_logs_debug_when_close_fails(self): + mock_client = MagicMock() + mock_client.close.side_effect = RuntimeError("close failed") + weaviate_vector_module._weaviate_client = mock_client + + with patch.object(weaviate_vector_module.logger, "debug") as mock_debug: + weaviate_vector_module._shutdown_weaviate_client() + + assert weaviate_vector_module._weaviate_client is None + mock_client.close.assert_called_once() + mock_debug.assert_called_once() + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom") + def test_init_client_reuses_cached_client_without_reconnect(self, mock_connect): + cached_client = MagicMock() + cached_client.is_ready.return_value = True + weaviate_vector_module._weaviate_client = cached_client + + wv = WeaviateVector.__new__(WeaviateVector) + + client = wv._init_client(self.config) + + assert client is cached_client + mock_connect.assert_not_called() + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom") + def test_init_client_reuses_cached_client_after_lock_recheck(self, mock_connect): + cached_client = MagicMock() + cached_client.is_ready.side_effect = [False, True] + weaviate_vector_module._weaviate_client = cached_client + + wv = WeaviateVector.__new__(WeaviateVector) + + client = wv._init_client(self.config) + + assert client is cached_client + mock_connect.assert_not_called() + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.Auth.api_key", return_value="auth-token") + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom") + def test_init_client_parses_custom_grpc_endpoint_without_scheme(self, mock_connect, mock_api_key): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_connect.return_value = mock_client + + wv = WeaviateVector.__new__(WeaviateVector) + config = WeaviateConfig( + endpoint="https://weaviate.example.com", + grpc_endpoint="grpc.example.com:6000", + api_key="test-key", + batch_size=50, + ) + + client = wv._init_client(config) + + assert client is mock_client + assert mock_connect.call_args.kwargs == { + "http_host": "weaviate.example.com", + "http_port": 443, + "http_secure": True, + "grpc_host": "grpc.example.com", + "grpc_port": 6000, + "grpc_secure": False, + "auth_credentials": "auth-token", + "skip_init_checks": True, + } + mock_api_key.assert_called_once_with("test-key") + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom") + def test_init_client_raises_when_database_not_ready(self, mock_connect): + mock_client = MagicMock() + mock_client.is_ready.return_value = False + mock_connect.return_value = mock_client + + wv = WeaviateVector.__new__(WeaviateVector) + + with pytest.raises(ConnectionError, match="Vector database is not ready"): + wv._init_client(self.config) + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def test_init(self, mock_weaviate_module): """Test WeaviateVector initialization stores attributes including doc_type.""" @@ -62,6 +149,40 @@ class TestWeaviateVector(unittest.TestCase): assert wv._collection_name == self.collection_name assert "doc_type" in wv._attributes + def test_get_type_and_to_index_struct(self): + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + + assert wv.get_type() == weaviate_vector_module.VectorType.WEAVIATE + assert wv.to_index_struct() == { + "type": weaviate_vector_module.VectorType.WEAVIATE, + "vector_store": {"class_prefix": self.collection_name}, + } + + def test_get_collection_name_uses_existing_class_prefix_and_appends_suffix(self): + dataset = SimpleNamespace(index_struct_dict={"vector_store": {"class_prefix": "ExistingCollection"}}, id="ds-1") + wv = WeaviateVector.__new__(WeaviateVector) + + assert wv.get_collection_name(dataset) == "ExistingCollection_Node" + + def test_get_collection_name_generates_name_from_dataset_id(self): + dataset = SimpleNamespace(index_struct_dict=None, id="ds-2") + wv = WeaviateVector.__new__(WeaviateVector) + + with patch.object(weaviate_vector_module.Dataset, "gen_collection_name_by_id", return_value="Generated_Node"): + assert wv.get_collection_name(dataset) == "Generated_Node" + + def test_create_calls_collection_setup_then_add_texts(self): + doc = Document(page_content="hello", metadata={}) + wv = WeaviateVector.__new__(WeaviateVector) + wv._create_collection = MagicMock() + wv.add_texts = MagicMock() + + wv.create([doc], [[0.1, 0.2]]) + + wv._create_collection.assert_called_once() + wv.add_texts.assert_called_once_with([doc], [[0.1, 0.2]]) + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.redis_client") @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.dify_config") @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") @@ -111,6 +232,44 @@ class TestWeaviateVector(unittest.TestCase): f"doc_type should be in collection schema properties, got: {property_names}" ) + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.redis_client") + def test_create_collection_returns_early_when_cache_key_exists(self, mock_redis): + mock_lock = MagicMock() + mock_lock.__enter__ = MagicMock() + mock_lock.__exit__ = MagicMock() + mock_redis.lock.return_value = mock_lock + mock_redis.get.return_value = 1 + + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._ensure_properties = MagicMock() + + wv._create_collection() + + wv._client.collections.exists.assert_not_called() + wv._ensure_properties.assert_not_called() + mock_redis.set.assert_not_called() + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.redis_client") + def test_create_collection_logs_and_reraises_errors(self, mock_redis): + mock_lock = MagicMock() + mock_lock.__enter__ = MagicMock() + mock_lock.__exit__ = MagicMock(return_value=False) + mock_redis.lock.return_value = mock_lock + mock_redis.get.return_value = None + + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._client.collections.exists.side_effect = RuntimeError("create failed") + + with patch.object(weaviate_vector_module.logger, "exception") as mock_exception: + with pytest.raises(RuntimeError, match="create failed"): + wv._create_collection() + + mock_exception.assert_called_once() + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def test_ensure_properties_adds_missing_doc_type(self, mock_weaviate_module): """Test that _ensure_properties adds doc_type when it's missing from existing schema.""" @@ -146,6 +305,29 @@ class TestWeaviateVector(unittest.TestCase): added_names = [call.args[0].name for call in add_calls] assert "doc_type" in added_names, f"doc_type should be added to existing collection, added: {added_names}" + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_ensure_properties_adds_all_missing_core_properties(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_client.collections.exists.return_value = True + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + mock_cfg = MagicMock() + mock_cfg.properties = [SimpleNamespace(name="text")] + mock_col.config.get.return_value = mock_cfg + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + wv._ensure_properties() + + add_calls = mock_col.config.add_property.call_args_list + added_names = [call.args[0].name for call in add_calls] + assert added_names == ["document_id", "doc_id", "doc_type", "chunk_index"] + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def test_ensure_properties_skips_existing_doc_type(self, mock_weaviate_module): """Test that _ensure_properties does not add doc_type when it already exists.""" @@ -179,6 +361,30 @@ class TestWeaviateVector(unittest.TestCase): # No properties should be added mock_col.config.add_property.assert_not_called() + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_ensure_properties_logs_warning_when_property_addition_fails(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_client.collections.exists.return_value = True + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + mock_cfg = MagicMock() + mock_cfg.properties = [] + mock_col.config.get.return_value = mock_cfg + mock_col.config.add_property.side_effect = RuntimeError("cannot add") + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + + with patch.object(weaviate_vector_module.logger, "warning") as mock_warning: + wv._ensure_properties() + + assert mock_warning.call_count == 4 + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def test_search_by_vector_returns_doc_type_in_metadata(self, mock_weaviate_module): """Test that search_by_vector returns doc_type in document metadata. @@ -226,6 +432,58 @@ class TestWeaviateVector(unittest.TestCase): assert len(docs) == 1 assert docs[0].metadata.get("doc_type") == "image" + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_search_by_vector_uses_document_filter_and_default_distance(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_client.collections.exists.return_value = True + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + + mock_obj = MagicMock() + mock_obj.properties = { + "text": "fallback distance result", + "document_id": "doc-1", + "doc_id": "segment-1", + } + mock_obj.metadata = None + + mock_result = MagicMock() + mock_result.objects = [mock_obj] + mock_col.query.near_vector.return_value = mock_result + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + docs = wv.search_by_vector( + query_vector=[0.2] * 3, + document_ids_filter=["doc-1"], + top_k=2, + score_threshold=-1, + ) + + assert len(docs) == 1 + assert docs[0].metadata["score"] == 0.0 + assert mock_col.query.near_vector.call_args.kwargs["filters"] is not None + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_search_by_vector_returns_empty_when_collection_is_missing(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_client.collections.exists.return_value = False + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + + assert wv.search_by_vector(query_vector=[0.1] * 3) == [] + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def test_search_by_full_text_returns_doc_type_in_metadata(self, mock_weaviate_module): """Test that search_by_full_text also returns doc_type in document metadata.""" @@ -268,6 +526,49 @@ class TestWeaviateVector(unittest.TestCase): assert len(docs) == 1 assert docs[0].metadata.get("doc_type") == "image" + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_search_by_full_text_uses_document_filter(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_client.collections.exists.return_value = True + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + + mock_obj = MagicMock() + mock_obj.properties = {"text": "bm25 result", "doc_id": "segment-1"} + mock_obj.vector = [0.3, 0.4] + + mock_result = MagicMock() + mock_result.objects = [mock_obj] + mock_col.query.bm25.return_value = mock_result + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + docs = wv.search_by_full_text(query="bm25", document_ids_filter=["doc-1"]) + + assert len(docs) == 1 + assert docs[0].vector == [0.3, 0.4] + assert mock_col.query.bm25.call_args.kwargs["filters"] is not None + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_search_by_full_text_returns_empty_when_collection_is_missing(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_client.collections.exists.return_value = False + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + + assert wv.search_by_full_text(query="missing") == [] + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def test_add_texts_stores_doc_type_in_properties(self, mock_weaviate_module): """Test that add_texts includes doc_type from document metadata in stored properties.""" @@ -310,6 +611,135 @@ class TestWeaviateVector(unittest.TestCase): stored_props = call_kwargs.kwargs.get("properties") assert stored_props.get("doc_type") == "image", f"doc_type should be stored in properties, got: {stored_props}" + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_add_texts_falls_back_to_random_uuid_and_serializes_datetime_metadata(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + + mock_batch = MagicMock() + mock_batch.__enter__ = MagicMock(return_value=mock_batch) + mock_batch.__exit__ = MagicMock(return_value=False) + mock_col.batch.dynamic.return_value = mock_batch + + created_at = datetime.datetime(2024, 1, 2, 3, 4, 5, tzinfo=datetime.UTC) + doc = Document(page_content="text", metadata={"created_at": created_at}) + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + + with ( + patch.object(wv, "_get_uuids", return_value=["not-a-uuid"]), + patch("core.rag.datasource.vdb.weaviate.weaviate_vector._uuid.uuid4", return_value="fallback-uuid"), + ): + ids = wv.add_texts(documents=[doc], embeddings=[[]]) + + assert ids == ["fallback-uuid"] + call_kwargs = mock_batch.add_object.call_args + assert call_kwargs.kwargs["uuid"] == "fallback-uuid" + assert call_kwargs.kwargs["vector"] is None + assert call_kwargs.kwargs["properties"]["created_at"] == created_at.isoformat() + + def test_is_uuid_handles_invalid_values(self): + wv = WeaviateVector.__new__(WeaviateVector) + + assert wv._is_uuid("123e4567-e89b-12d3-a456-426614174000") is True + assert wv._is_uuid("not-a-uuid") is False + + def test_delete_by_metadata_field_returns_when_collection_is_missing(self): + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._client.collections.exists.return_value = False + + wv.delete_by_metadata_field("doc_id", "segment-1") + + wv._client.collections.use.assert_not_called() + + def test_delete_by_metadata_field_deletes_matching_objects(self): + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._client.collections.exists.return_value = True + mock_col = MagicMock() + wv._client.collections.use.return_value = mock_col + + wv.delete_by_metadata_field("doc_id", "segment-1") + + mock_col.data.delete_many.assert_called_once() + + def test_delete_removes_collection_when_present(self): + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._client.collections.exists.return_value = True + + wv.delete() + + wv._client.collections.delete.assert_called_once_with(self.collection_name) + + def test_text_exists_handles_missing_and_present_documents(self): + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._client.collections.exists.side_effect = [False, True] + mock_col = MagicMock() + wv._client.collections.use.return_value = mock_col + mock_col.query.fetch_objects.return_value = SimpleNamespace(objects=[SimpleNamespace()]) + + assert wv.text_exists("segment-1") is False + assert wv.text_exists("segment-1") is True + + def test_delete_by_ids_handles_missing_collections_and_404s(self): + class FakeUnexpectedStatusCodeError(Exception): + def __init__(self, status_code): + super().__init__(f"status={status_code}") + self.status_code = status_code + + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._client.collections.exists.side_effect = [False, True] + mock_col = MagicMock() + wv._client.collections.use.return_value = mock_col + mock_col.data.delete_by_id.side_effect = [FakeUnexpectedStatusCodeError(404), None] + + with patch.object(weaviate_vector_module, "UnexpectedStatusCodeError", FakeUnexpectedStatusCodeError): + wv.delete_by_ids(["ignored"]) + wv.delete_by_ids(["missing-id", "ok-id"]) + + assert mock_col.data.delete_by_id.call_count == 2 + + def test_delete_by_ids_reraises_non_404_errors(self): + class FakeUnexpectedStatusCodeError(Exception): + def __init__(self, status_code): + super().__init__(f"status={status_code}") + self.status_code = status_code + + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._client.collections.exists.return_value = True + mock_col = MagicMock() + wv._client.collections.use.return_value = mock_col + mock_col.data.delete_by_id.side_effect = FakeUnexpectedStatusCodeError(500) + + with patch.object(weaviate_vector_module, "UnexpectedStatusCodeError", FakeUnexpectedStatusCodeError): + with pytest.raises(FakeUnexpectedStatusCodeError, match="status=500"): + wv.delete_by_ids(["bad-id"]) + + def test_json_serializable_converts_datetime(self): + wv = WeaviateVector.__new__(WeaviateVector) + created_at = datetime.datetime(2024, 1, 2, 3, 4, 5, tzinfo=datetime.UTC) + + assert wv._json_serializable(created_at) == created_at.isoformat() + assert wv._json_serializable("plain") == "plain" + class TestVectorDefaultAttributes(unittest.TestCase): """Tests for Vector class default attributes list.""" @@ -331,5 +761,65 @@ class TestVectorDefaultAttributes(unittest.TestCase): assert "doc_type" in vector._attributes, f"doc_type should be in default attributes, got: {vector._attributes}" +class TestWeaviateVectorFactory(unittest.TestCase): + def test_init_vector_uses_existing_dataset_index_struct(self): + dataset = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "ExistingCollection_Node"}}, + index_struct=None, + ) + attributes = ["doc_id"] + + with ( + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_ENDPOINT", "http://localhost:8080"), + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_GRPC_ENDPOINT", "localhost:50051"), + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_API_KEY", "api-key"), + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_BATCH_SIZE", 88), + patch( + "core.rag.datasource.vdb.weaviate.weaviate_vector.WeaviateVector", return_value="vector" + ) as mock_vector, + ): + factory = weaviate_vector_module.WeaviateVectorFactory() + result = factory.init_vector(dataset, attributes, MagicMock()) + + assert result == "vector" + config = mock_vector.call_args.kwargs["config"] + assert mock_vector.call_args.kwargs["collection_name"] == "ExistingCollection_Node" + assert mock_vector.call_args.kwargs["attributes"] == attributes + assert config.endpoint == "http://localhost:8080" + assert config.grpc_endpoint == "localhost:50051" + assert config.api_key == "api-key" + assert config.batch_size == 88 + assert dataset.index_struct is None + + def test_init_vector_generates_collection_and_updates_index_struct(self): + dataset = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + attributes = ["doc_id", "doc_type"] + + with ( + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_ENDPOINT", "http://localhost:8080"), + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_GRPC_ENDPOINT", ""), + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_API_KEY", None), + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_BATCH_SIZE", 100), + patch.object( + weaviate_vector_module.Dataset, + "gen_collection_name_by_id", + return_value="GeneratedCollection_Node", + ), + patch( + "core.rag.datasource.vdb.weaviate.weaviate_vector.WeaviateVector", return_value="vector" + ) as mock_vector, + ): + factory = weaviate_vector_module.WeaviateVectorFactory() + result = factory.init_vector(dataset, attributes, MagicMock()) + + assert result == "vector" + assert mock_vector.call_args.kwargs["collection_name"] == "GeneratedCollection_Node" + assert json.loads(dataset.index_struct) == { + "type": weaviate_vector_module.VectorType.WEAVIATE, + "vector_store": {"class_prefix": "GeneratedCollection_Node"}, + } + + if __name__ == "__main__": unittest.main() diff --git a/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py b/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py index 13285cdad0c..a7b7c1595b6 100644 --- a/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py +++ b/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py @@ -163,11 +163,11 @@ class TestDatasetDocumentStoreAddDocuments: with ( patch("core.rag.docstore.dataset_docstore.db") as mock_db, - patch("core.rag.docstore.dataset_docstore.ModelManager") as mock_manager_class, + patch("core.rag.docstore.dataset_docstore.ModelManager.for_tenant") as mock_manager_class, ): mock_session = MagicMock() mock_db.session = mock_session - mock_db.session.query.return_value.where.return_value.scalar.return_value = None + mock_db.session.scalar.return_value = None mock_manager = MagicMock() mock_manager.get_model_instance.return_value = mock_model_instance @@ -211,7 +211,7 @@ class TestDatasetDocumentStoreAddDocuments: with patch("core.rag.docstore.dataset_docstore.db") as mock_db: mock_session = MagicMock() mock_db.session = mock_session - mock_db.session.query.return_value.where.return_value.scalar.return_value = 5 + mock_db.session.scalar.return_value = 5 with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_existing_segment): with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"): @@ -276,7 +276,7 @@ class TestDatasetDocumentStoreAddDocuments: with patch("core.rag.docstore.dataset_docstore.db") as mock_db: mock_session = MagicMock() mock_db.session = mock_session - mock_db.session.query.return_value.where.return_value.scalar.return_value = None + mock_db.session.scalar.return_value = None with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None): with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"): @@ -353,7 +353,7 @@ class TestDatasetDocumentStoreAddDocuments: with patch("core.rag.docstore.dataset_docstore.db") as mock_db: mock_session = MagicMock() mock_db.session = mock_session - mock_db.session.query.return_value.where.return_value.scalar.return_value = None + mock_db.session.scalar.return_value = None with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None): with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"): @@ -755,7 +755,7 @@ class TestDatasetDocumentStoreAddDocumentsUpdateChild: with patch("core.rag.docstore.dataset_docstore.db") as mock_db: mock_session = MagicMock() mock_db.session = mock_session - mock_db.session.query.return_value.where.return_value.scalar.return_value = 5 + mock_db.session.scalar.return_value = 5 with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_existing_segment): with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"): @@ -767,7 +767,7 @@ class TestDatasetDocumentStoreAddDocumentsUpdateChild: store.add_documents([mock_doc], save_child=True) - mock_db.session.query.return_value.where.return_value.delete.assert_called() + mock_db.session.execute.assert_called() mock_db.session.commit.assert_called() @@ -798,7 +798,7 @@ class TestDatasetDocumentStoreAddDocumentsUpdateAnswer: with patch("core.rag.docstore.dataset_docstore.db") as mock_db: mock_session = MagicMock() mock_db.session = mock_session - mock_db.session.query.return_value.where.return_value.scalar.return_value = 5 + mock_db.session.scalar.return_value = 5 with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_existing_segment): with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"): diff --git a/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py b/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py index a0db25174d5..35631861868 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py +++ b/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py @@ -12,11 +12,11 @@ from unittest.mock import Mock, patch import numpy as np import pytest +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage from sqlalchemy.exc import IntegrityError from core.rag.embedding.cached_embedding import CacheEmbedding -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey -from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage from models.dataset import Embedding @@ -28,6 +28,7 @@ class TestCacheEmbeddingMultimodalDocuments: """Create a mock ModelInstance for testing.""" model_instance = Mock() model_instance.model = "vision-embedding-model" + model_instance.model_name = "vision-embedding-model" model_instance.provider = "openai" model_instance.credentials = {"api_key": "test-key"} @@ -64,11 +65,11 @@ class TestCacheEmbeddingMultimodalDocuments: def test_embed_single_multimodal_document_cache_miss(self, mock_model_instance, sample_multimodal_result): """Test embedding a single multimodal document when cache is empty.""" - cache_embedding = CacheEmbedding(mock_model_instance, user="test-user") + cache_embedding = CacheEmbedding(mock_model_instance) documents = [{"file_id": "file123", "content": "test content"}] with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None mock_model_instance.invoke_multimodal_embedding.return_value = sample_multimodal_result result = cache_embedding.embed_multimodal_documents(documents) @@ -113,7 +114,7 @@ class TestCacheEmbeddingMultimodalDocuments: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result result = cache_embedding.embed_multimodal_documents(documents) @@ -133,7 +134,7 @@ class TestCacheEmbeddingMultimodalDocuments: mock_cached_embedding.get_embedding.return_value = normalized_cached with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = mock_cached_embedding + mock_session.scalar.return_value = mock_cached_embedding result = cache_embedding.embed_multimodal_documents(documents) @@ -179,18 +180,7 @@ class TestCacheEmbeddingMultimodalDocuments: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - call_count = [0] - - def mock_filter_by(**kwargs): - call_count[0] += 1 - mock_query = Mock() - if call_count[0] == 1: - mock_query.first.return_value = mock_cached_embedding - else: - mock_query.first.return_value = None - return mock_query - - mock_session.query.return_value.filter_by = mock_filter_by + mock_session.scalar.side_effect = [mock_cached_embedding, None, None] mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result result = cache_embedding.embed_multimodal_documents(documents) @@ -223,7 +213,7 @@ class TestCacheEmbeddingMultimodalDocuments: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result with patch("core.rag.embedding.cached_embedding.logger") as mock_logger: @@ -264,7 +254,7 @@ class TestCacheEmbeddingMultimodalDocuments: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None batch_results = [create_batch_result(10), create_batch_result(10), create_batch_result(5)] mock_model_instance.invoke_multimodal_embedding.side_effect = batch_results @@ -280,7 +270,7 @@ class TestCacheEmbeddingMultimodalDocuments: documents = [{"file_id": "file123"}] with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None mock_model_instance.invoke_multimodal_embedding.side_effect = Exception("API Error") with pytest.raises(Exception) as exc_info: @@ -297,7 +287,7 @@ class TestCacheEmbeddingMultimodalDocuments: documents = [{"file_id": "file123"}] with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None mock_model_instance.invoke_multimodal_embedding.return_value = sample_multimodal_result mock_session.commit.side_effect = IntegrityError("Duplicate key", None, None) @@ -316,13 +306,14 @@ class TestCacheEmbeddingMultimodalQuery: """Create a mock ModelInstance for testing.""" model_instance = Mock() model_instance.model = "vision-embedding-model" + model_instance.model_name = "vision-embedding-model" model_instance.provider = "openai" model_instance.credentials = {"api_key": "test-key"} return model_instance def test_embed_multimodal_query_cache_miss(self, mock_model_instance): """Test embedding multimodal query when Redis cache is empty.""" - cache_embedding = CacheEmbedding(mock_model_instance, user="test-user") + cache_embedding = CacheEmbedding(mock_model_instance) document = {"file_id": "file123"} vector = np.random.randn(1536) @@ -467,6 +458,7 @@ class TestCacheEmbeddingQueryErrors: """Create a mock ModelInstance for testing.""" model_instance = Mock() model_instance.model = "text-embedding-ada-002" + model_instance.model_name = "text-embedding-ada-002" model_instance.provider = "openai" model_instance.credentials = {"api_key": "test-key"} return model_instance @@ -532,24 +524,13 @@ class TestCacheEmbeddingQueryErrors: class TestCacheEmbeddingInitialization: """Test suite for CacheEmbedding initialization.""" - def test_initialization_with_user(self): - """Test CacheEmbedding initialization with user parameter.""" - model_instance = Mock() - model_instance.model = "test-model" - model_instance.provider = "test-provider" - - cache_embedding = CacheEmbedding(model_instance, user="test-user") - - assert cache_embedding._model_instance == model_instance - assert cache_embedding._user == "test-user" - - def test_initialization_without_user(self): - """Test CacheEmbedding initialization without user parameter.""" + def test_initialization_sets_model_instance(self): + """Test CacheEmbedding initialization stores the provided model instance.""" model_instance = Mock() model_instance.model = "test-model" + model_instance.model_name = "test-model" model_instance.provider = "test-provider" cache_embedding = CacheEmbedding(model_instance) assert cache_embedding._model_instance == model_instance - assert cache_embedding._user is None diff --git a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py index 6e71f0c61f6..408cf14a51b 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py +++ b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py @@ -49,17 +49,17 @@ from unittest.mock import Mock, patch import numpy as np import pytest -from sqlalchemy.exc import IntegrityError - -from core.entities.embedding_type import EmbeddingInputType -from core.rag.embedding.cached_embedding import CacheEmbedding -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey -from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage -from dify_graph.model_runtime.errors.invoke import ( +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage +from graphon.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError, ) +from sqlalchemy.exc import IntegrityError + +from core.entities.embedding_type import EmbeddingInputType +from core.rag.embedding.cached_embedding import CacheEmbedding from models.dataset import Embedding @@ -134,12 +134,12 @@ class TestCacheEmbeddingDocuments: - Correct return value """ # Arrange - cache_embedding = CacheEmbedding(mock_model_instance, user="test-user") + cache_embedding = CacheEmbedding(mock_model_instance) texts = ["Python is a programming language"] # Mock database query to return no cached embedding (cache miss) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None # Mock model invocation mock_model_instance.invoke_text_embedding.return_value = sample_embedding_result @@ -156,7 +156,6 @@ class TestCacheEmbeddingDocuments: # Verify model was invoked with correct parameters mock_model_instance.invoke_text_embedding.assert_called_once_with( texts=texts, - user="test-user", input_type=EmbeddingInputType.DOCUMENT, ) @@ -204,7 +203,7 @@ class TestCacheEmbeddingDocuments: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None mock_model_instance.invoke_text_embedding.return_value = embedding_result # Act @@ -241,7 +240,7 @@ class TestCacheEmbeddingDocuments: with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: # Mock database to return cached embedding (cache hit) - mock_session.query.return_value.filter_by.return_value.first.return_value = mock_cached_embedding + mock_session.scalar.return_value = mock_cached_embedding # Act result = cache_embedding.embed_documents(texts) @@ -314,19 +313,7 @@ class TestCacheEmbeddingDocuments: mock_hash.side_effect = generate_hash # Mock database to return cached embedding only for first text (hash_1) - call_count = [0] - - def mock_filter_by(**kwargs): - call_count[0] += 1 - mock_query = Mock() - # First call (hash_1) returns cached, others return None - if call_count[0] == 1: - mock_query.first.return_value = mock_cached_embedding - else: - mock_query.first.return_value = None - return mock_query - - mock_session.query.return_value.filter_by = mock_filter_by + mock_session.scalar.side_effect = [mock_cached_embedding, None, None] mock_model_instance.invoke_text_embedding.return_value = embedding_result # Act @@ -393,7 +380,7 @@ class TestCacheEmbeddingDocuments: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None # Mock model to return appropriate batch results batch_results = [ @@ -456,7 +443,7 @@ class TestCacheEmbeddingDocuments: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None mock_model_instance.invoke_text_embedding.return_value = embedding_result with patch("core.rag.embedding.cached_embedding.logger") as mock_logger: @@ -490,7 +477,7 @@ class TestCacheEmbeddingDocuments: texts = ["Test text"] with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None # Mock model to raise connection error mock_model_instance.invoke_text_embedding.side_effect = InvokeConnectionError("Failed to connect to API") @@ -516,7 +503,7 @@ class TestCacheEmbeddingDocuments: texts = ["Test text"] with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None # Mock model to raise rate limit error mock_model_instance.invoke_text_embedding.side_effect = InvokeRateLimitError("Rate limit exceeded") @@ -540,7 +527,7 @@ class TestCacheEmbeddingDocuments: texts = ["Test text"] with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None # Mock model to raise authorization error mock_model_instance.invoke_text_embedding.side_effect = InvokeAuthorizationError("Invalid API key") @@ -565,7 +552,7 @@ class TestCacheEmbeddingDocuments: texts = ["Test text"] with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None mock_model_instance.invoke_text_embedding.return_value = sample_embedding_result # Mock database commit to raise IntegrityError @@ -612,7 +599,7 @@ class TestCacheEmbeddingQuery: - Correct return value """ # Arrange - cache_embedding = CacheEmbedding(mock_model_instance, user="test-user") + cache_embedding = CacheEmbedding(mock_model_instance) query = "What is Python?" # Create embedding result @@ -651,7 +638,6 @@ class TestCacheEmbeddingQuery: # Verify model was invoked with QUERY input type mock_model_instance.invoke_text_embedding.assert_called_once_with( texts=[query], - user="test-user", input_type=EmbeddingInputType.QUERY, ) @@ -886,7 +872,7 @@ class TestEmbeddingModelSwitching: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None model_instance_ada.invoke_text_embedding.return_value = result_ada model_instance_3_small.invoke_text_embedding.return_value = result_3_small @@ -1049,7 +1035,7 @@ class TestEmbeddingDimensionValidation: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None mock_model_instance.invoke_text_embedding.return_value = embedding_result # Act @@ -1102,7 +1088,7 @@ class TestEmbeddingDimensionValidation: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None mock_model_instance.invoke_text_embedding.return_value = embedding_result # Act @@ -1188,7 +1174,7 @@ class TestEmbeddingDimensionValidation: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None model_instance_ada.invoke_text_embedding.return_value = result_ada model_instance_cohere.invoke_text_embedding.return_value = result_cohere @@ -1286,7 +1272,7 @@ class TestEmbeddingEdgeCases: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None mock_model_instance.invoke_text_embedding.return_value = embedding_result # Act @@ -1329,7 +1315,7 @@ class TestEmbeddingEdgeCases: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None mock_model_instance.invoke_text_embedding.return_value = embedding_result # Act @@ -1377,7 +1363,7 @@ class TestEmbeddingEdgeCases: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None mock_model_instance.invoke_text_embedding.return_value = embedding_result # Act @@ -1429,7 +1415,7 @@ class TestEmbeddingEdgeCases: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None mock_model_instance.invoke_text_embedding.return_value = embedding_result # Act @@ -1485,7 +1471,7 @@ class TestEmbeddingEdgeCases: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None mock_model_instance.invoke_text_embedding.return_value = embedding_result # Act @@ -1553,7 +1539,7 @@ class TestEmbeddingEdgeCases: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None mock_model_instance.invoke_text_embedding.return_value = embedding_result # Act @@ -1568,25 +1554,16 @@ class TestEmbeddingEdgeCases: norm = np.linalg.norm(emb) assert abs(norm - 1.0) < 0.01 - def test_embed_query_with_user_context(self, mock_model_instance): - """Test query embedding with user context parameter. + def test_embed_query_uses_bound_model_instance(self, mock_model_instance): + """Test query embedding using the provided model instance. Verifies: - - User parameter is passed correctly to model - - User context is used for tracking/logging - - Embedding generation works with user context - - Context: - -------- - The user parameter is important for: - 1. Usage tracking per user - 2. Rate limiting per user - 3. Audit logging - 4. Personalization (in some models) + - Embedding generation works with the injected model instance + - Query input type is preserved + - No extra binding step is required at call time """ # Arrange - user_id = "user-12345" - cache_embedding = CacheEmbedding(mock_model_instance, user=user_id) + cache_embedding = CacheEmbedding(mock_model_instance) query = "What is machine learning?" # Create embedding @@ -1620,24 +1597,20 @@ class TestEmbeddingEdgeCases: assert isinstance(result, list) assert len(result) == 1536 - # Verify user parameter was passed to model mock_model_instance.invoke_text_embedding.assert_called_once_with( texts=[query], - user=user_id, input_type=EmbeddingInputType.QUERY, ) - def test_embed_documents_with_user_context(self, mock_model_instance): - """Test document embedding with user context parameter. + def test_embed_documents_uses_bound_model_instance(self, mock_model_instance): + """Test document embedding using the provided model instance. Verifies: - - User parameter is passed correctly for document embeddings - - Batch processing maintains user context - - User tracking works across batches + - Batch processing uses the injected model instance + - Document input type is preserved """ # Arrange - user_id = "user-67890" - cache_embedding = CacheEmbedding(mock_model_instance, user=user_id) + cache_embedding = CacheEmbedding(mock_model_instance) texts = ["Document 1", "Document 2"] # Create embeddings @@ -1664,7 +1637,7 @@ class TestEmbeddingEdgeCases: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None mock_model_instance.invoke_text_embedding.return_value = embedding_result # Act @@ -1673,10 +1646,8 @@ class TestEmbeddingEdgeCases: # Assert assert len(result) == 2 - # Verify user parameter was passed mock_model_instance.invoke_text_embedding.assert_called_once() call_args = mock_model_instance.invoke_text_embedding.call_args - assert call_args.kwargs["user"] == user_id assert call_args.kwargs["input_type"] == EmbeddingInputType.DOCUMENT @@ -1745,7 +1716,7 @@ class TestEmbeddingCachePerformance: with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: # First call: cache miss - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None usage = EmbeddingUsage( tokens=5, @@ -1773,7 +1744,7 @@ class TestEmbeddingCachePerformance: assert len(result1) == 1 # Arrange - Second call: cache hit - mock_session.query.return_value.filter_by.return_value.first.return_value = mock_cached_embedding + mock_session.scalar.return_value = mock_cached_embedding # Act - Second call (cache hit) result2 = cache_embedding.embed_documents([text]) @@ -1833,7 +1804,7 @@ class TestEmbeddingCachePerformance: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None # Mock model to return appropriate batch results batch_results = [ diff --git a/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py index 2add12fd093..db49221583f 100644 --- a/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py +++ b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py @@ -164,6 +164,13 @@ class TestFirecrawlApp: with pytest.raises(Exception, match="No page found"): app.check_crawl_status("job-1") + def test_check_crawl_status_completed_with_null_total_raises(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + mocker.patch("httpx.get", return_value=_response(200, {"status": "completed", "total": None, "data": []})) + + with pytest.raises(Exception, match="No page found"): + app.check_crawl_status("job-1") + def test_check_crawl_status_non_completed(self, mocker: MockerFixture): app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") payload = {"status": "processing", "total": 5, "completed": 1, "data": []} @@ -203,6 +210,77 @@ class TestFirecrawlApp: with pytest.raises(Exception, match="Error saving crawl data"): app.check_crawl_status("job-err") + def test_check_crawl_status_follows_pagination(self, mocker: MockerFixture): + """When status is completed and next is present, follow pagination to collect all pages.""" + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + page1 = { + "status": "completed", + "total": 3, + "completed": 3, + "next": "https://custom.firecrawl.dev/v2/crawl/job-42?skip=1", + "data": [{"metadata": {"title": "p1", "description": "", "sourceURL": "https://p1"}, "markdown": "m1"}], + } + page2 = { + "status": "completed", + "total": 3, + "completed": 3, + "next": "https://custom.firecrawl.dev/v2/crawl/job-42?skip=2", + "data": [{"metadata": {"title": "p2", "description": "", "sourceURL": "https://p2"}, "markdown": "m2"}], + } + page3 = { + "status": "completed", + "total": 3, + "completed": 3, + "data": [{"metadata": {"title": "p3", "description": "", "sourceURL": "https://p3"}, "markdown": "m3"}], + } + mocker.patch("httpx.get", side_effect=[_response(200, page1), _response(200, page2), _response(200, page3)]) + mock_storage = MagicMock() + mock_storage.exists.return_value = False + mocker.patch.object(firecrawl_module, "storage", mock_storage) + + result = app.check_crawl_status("job-42") + + assert result["status"] == "completed" + assert result["total"] == 3 + assert len(result["data"]) == 3 + assert [d["title"] for d in result["data"]] == ["p1", "p2", "p3"] + + def test_check_crawl_status_pagination_error_raises(self, mocker: MockerFixture): + """An error while fetching a paginated page raises an exception; no partial data is returned.""" + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + page1 = { + "status": "completed", + "total": 2, + "completed": 2, + "next": "https://custom.firecrawl.dev/v2/crawl/job-99?skip=1", + "data": [{"metadata": {"title": "p1", "description": "", "sourceURL": "https://p1"}, "markdown": "m1"}], + } + mocker.patch("httpx.get", side_effect=[_response(200, page1), _response(500, {"error": "server error"})]) + + with pytest.raises(Exception, match="fetch next crawl page"): + app.check_crawl_status("job-99") + + def test_check_crawl_status_pagination_capped_at_total(self, mocker: MockerFixture): + """Pagination stops once pages_processed reaches total, even if next is present.""" + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + # total=1: only the first page should be processed; next must not be followed + page1 = { + "status": "completed", + "total": 1, + "completed": 1, + "next": "https://custom.firecrawl.dev/v2/crawl/job-cap?skip=1", + "data": [{"metadata": {"title": "p1", "description": "", "sourceURL": "https://p1"}, "markdown": "m1"}], + } + mock_get = mocker.patch("httpx.get", return_value=_response(200, page1)) + mock_storage = MagicMock() + mock_storage.exists.return_value = False + mocker.patch.object(firecrawl_module, "storage", mock_storage) + + result = app.check_crawl_status("job-cap") + + assert len(result["data"]) == 1 + mock_get.assert_called_once() # initial fetch only; next URL is not followed due to cap + def test_extract_common_fields_and_status_formatter(self): app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") diff --git a/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py index 6daee11f8f9..808e41867ed 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py @@ -405,35 +405,36 @@ class TestNotionMetadataAndCredentialMethods: class FakeDocumentModel: data_source_info = "data_source_info" + id = "id" - update_calls = [] + execute_calls = [] - class FakeQuery: - def filter_by(self, **kwargs): + class FakeUpdateStmt: + def where(self, *args): return self - def update(self, payload): - update_calls.append(payload) + def values(self, **kwargs): + return self class FakeSession: committed = False - def query(self, model): - assert model is FakeDocumentModel - return FakeQuery() + def execute(self, stmt): + execute_calls.append(stmt) def commit(self): self.committed = True fake_db = SimpleNamespace(session=FakeSession()) monkeypatch.setattr(notion_extractor, "DocumentModel", FakeDocumentModel) + monkeypatch.setattr(notion_extractor, "update", lambda model: FakeUpdateStmt()) monkeypatch.setattr(notion_extractor, "db", fake_db) monkeypatch.setattr(extractor, "get_notion_last_edited_time", lambda: "2026-01-01T00:00:00.000Z") doc_model = SimpleNamespace(id="doc-1", data_source_info_dict={"source": "notion"}) extractor.update_last_edited_time(doc_model) - assert update_calls + assert execute_calls assert fake_db.session.committed is True def test_get_notion_last_edited_time_uses_page_and_database_urls(self, mocker: MockerFixture): diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py index e6cc582398b..d4b987c8325 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py @@ -2,13 +2,14 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, ImagePromptMessageContent +from graphon.model_runtime.entities.model_entities import ModelFeature from core.entities.knowledge_entities import PreviewDetail +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor from core.rag.models.document import AttachmentDocument, Document -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, ImagePromptMessageContent -from dify_graph.model_runtime.entities.model_entities import ModelFeature class TestParagraphIndexProcessor: @@ -21,7 +22,7 @@ class TestParagraphIndexProcessor: dataset = Mock() dataset.id = "dataset-1" dataset.tenant_id = "tenant-1" - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.is_multimodal = True return dataset @@ -167,7 +168,7 @@ class TestParagraphIndexProcessor: def test_load_uses_keyword_add_texts_with_keywords_when_economy( self, processor: ParagraphIndexProcessor, dataset: Mock ) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY docs = [Document(page_content="chunk", metadata={})] with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls: @@ -178,7 +179,7 @@ class TestParagraphIndexProcessor: def test_load_uses_keyword_add_texts_without_keywords_when_economy( self, processor: ParagraphIndexProcessor, dataset: Mock ) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY docs = [Document(page_content="chunk", metadata={})] with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls: @@ -187,10 +188,10 @@ class TestParagraphIndexProcessor: mock_keyword_cls.return_value.add_texts.assert_called_once_with(docs) def test_clean_deletes_summaries_and_vector(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None: - segment_query = Mock() - segment_query.filter.return_value.all.return_value = [SimpleNamespace(id="seg-1")] + scalars_result = Mock() + scalars_result.all.return_value = [SimpleNamespace(id="seg-1")] session = Mock() - session.query.return_value = segment_query + session.scalars.return_value = scalars_result with ( patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session), @@ -208,7 +209,7 @@ class TestParagraphIndexProcessor: def test_clean_economy_deletes_summaries_and_keywords( self, processor: ParagraphIndexProcessor, dataset: Mock ) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY with ( patch( @@ -222,7 +223,7 @@ class TestParagraphIndexProcessor: mock_keyword_cls.return_value.delete.assert_called_once() def test_clean_deletes_keywords_by_ids(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls: processor.clean(dataset, ["node-2"], with_keywords=True) @@ -267,7 +268,7 @@ class TestParagraphIndexProcessor: def test_index_list_chunks_economy( self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock ) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY with ( patch( "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash", @@ -399,7 +400,9 @@ class TestParagraphIndexProcessor: model_instance.invoke_llm.return_value = self._llm_result("text summary") with ( - patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls, + patch( + "core.rag.index_processor.processor.paragraph_index_processor.create_plugin_provider_manager" + ) as mock_provider_manager, patch( "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance", return_value=model_instance, @@ -410,7 +413,7 @@ class TestParagraphIndexProcessor: ), patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, ): - mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock() + mock_provider_manager.return_value.get_provider_model_bundle.return_value = Mock() summary, usage = ParagraphIndexProcessor.generate_summary( "tenant-1", "text content", @@ -433,7 +436,9 @@ class TestParagraphIndexProcessor: image_content = ImagePromptMessageContent(format="url", mime_type="image/png", url="http://example.com/a.png") with ( - patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls, + patch( + "core.rag.index_processor.processor.paragraph_index_processor.create_plugin_provider_manager" + ) as mock_provider_manager, patch( "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance", return_value=model_instance, @@ -448,7 +453,7 @@ class TestParagraphIndexProcessor: ), patch("core.rag.index_processor.processor.paragraph_index_processor.deduct_llm_quota"), ): - mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock() + mock_provider_manager.return_value.get_provider_model_bundle.return_value = Mock() summary, _ = ParagraphIndexProcessor.generate_summary( "tenant-1", "text content", @@ -469,7 +474,9 @@ class TestParagraphIndexProcessor: image_file = SimpleNamespace() with ( - patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls, + patch( + "core.rag.index_processor.processor.paragraph_index_processor.create_plugin_provider_manager" + ) as mock_provider_manager, patch( "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance", return_value=model_instance, @@ -486,7 +493,7 @@ class TestParagraphIndexProcessor: ), patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, ): - mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock() + mock_provider_manager.return_value.get_provider_model_bundle.return_value = Mock() with pytest.raises(ValueError, match="Expected LLMResult"): ParagraphIndexProcessor.generate_summary( "tenant-1", @@ -524,10 +531,10 @@ class TestParagraphIndexProcessor: size=1, key="key", ) - query = Mock() - query.where.return_value.all.return_value = [image_upload, non_image_upload] + scalars_result = Mock() + scalars_result.all.return_value = [image_upload, non_image_upload] session = Mock() - session.query.return_value = query + session.scalars.return_value = scalars_result with ( patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session), @@ -558,10 +565,10 @@ class TestParagraphIndexProcessor: size=1, key="key", ) - query = Mock() - query.where.return_value.all.return_value = [image_upload] + scalars_result = Mock() + scalars_result.all.return_value = [image_upload] session = Mock() - session.query.return_value = query + session.scalars.return_value = scalars_result with ( patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session), diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py index 5c78cae7c1d..d363a0804d0 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py @@ -4,6 +4,7 @@ from unittest.mock import MagicMock, Mock, patch import pytest from core.entities.knowledge_entities import PreviewDetail +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor from core.rag.models.document import AttachmentDocument, ChildDocument, Document from services.entities.knowledge_entities.knowledge_entities import ParentMode @@ -19,7 +20,7 @@ class TestParentChildIndexProcessor: dataset = Mock() dataset.id = "dataset-1" dataset.tenant_id = "tenant-1" - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.is_multimodal = True return dataset @@ -207,11 +208,7 @@ class TestParentChildIndexProcessor: vector.create_multimodal.assert_called_once_with(multimodal_docs) def test_clean_with_precomputed_child_ids(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None: - delete_query = Mock() - where_query = Mock() - where_query.delete.return_value = 2 session = Mock() - session.query.return_value.where.return_value = where_query with ( patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls, @@ -226,16 +223,16 @@ class TestParentChildIndexProcessor: ) vector.delete_by_ids.assert_called_once_with(["child-1", "child-2"]) - where_query.delete.assert_called_once_with(synchronize_session=False) + session.execute.assert_called() session.commit.assert_called_once() def test_clean_queries_child_ids_when_not_precomputed( self, processor: ParentChildIndexProcessor, dataset: Mock ) -> None: - child_query = Mock() - child_query.join.return_value.where.return_value.all.return_value = [("child-1",), (None,), ("child-2",)] + execute_result = Mock() + execute_result.all.return_value = [("child-1",), (None,), ("child-2",)] session = Mock() - session.query.return_value = child_query + session.execute.return_value = execute_result with ( patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls, @@ -247,10 +244,7 @@ class TestParentChildIndexProcessor: vector.delete_by_ids.assert_called_once_with(["child-1", "child-2"]) def test_clean_dataset_wide_cleanup(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None: - where_query = Mock() - where_query.delete.return_value = 3 session = Mock() - session.query.return_value.where.return_value = where_query with ( patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls, @@ -260,7 +254,7 @@ class TestParentChildIndexProcessor: processor.clean(dataset, None, delete_child_chunks=True) vector.delete.assert_called_once() - where_query.delete.assert_called_once_with(synchronize_session=False) + session.execute.assert_called() session.commit.assert_called_once() def test_clean_deletes_summaries_when_requested(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None: diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py index 99323eeec9e..98c47bec8f9 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py @@ -6,6 +6,7 @@ import pytest from werkzeug.datastructures import FileStorage from core.entities.knowledge_entities import PreviewDetail +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor from core.rag.models.document import AttachmentDocument, Document @@ -33,7 +34,7 @@ class TestQAIndexProcessor: dataset = Mock() dataset.id = "dataset-1" dataset.tenant_id = "tenant-1" - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.is_multimodal = True return dataset @@ -207,7 +208,7 @@ class TestQAIndexProcessor: vector.create_multimodal.assert_called_once_with(multimodal_docs) def test_load_skips_vector_for_non_high_quality(self, processor: QAIndexProcessor, dataset: Mock) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY docs = [Document(page_content="Q1", metadata={"answer": "A1"})] with patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls: @@ -298,7 +299,7 @@ class TestQAIndexProcessor: def test_index_requires_high_quality( self, processor: QAIndexProcessor, dataset: Mock, dataset_document: Mock ) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY qa_chunks = SimpleNamespace(qa_chunks=[SimpleNamespace(question="Q1", answer="A1")]) with ( diff --git a/api/tests/unit_tests/core/rag/indexing/test_index_processor_base.py b/api/tests/unit_tests/core/rag/indexing/test_index_processor_base.py index b31bb6eea7e..12c5238f5e6 100644 --- a/api/tests/unit_tests/core/rag/indexing/test_index_processor_base.py +++ b/api/tests/unit_tests/core/rag/indexing/test_index_processor_base.py @@ -133,10 +133,10 @@ class TestBaseIndexProcessor: upload_b = SimpleNamespace(id="bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", name="b.png") upload_tool = SimpleNamespace(id="tool-upload-id", name="tool.png") upload_remote = SimpleNamespace(id="remote-upload-id", name="remote.png") - db_query = Mock() - db_query.where.return_value.all.return_value = [upload_a, upload_b, upload_tool, upload_remote] + scalars_result = Mock() + scalars_result.all.return_value = [upload_a, upload_b, upload_tool, upload_remote] db_session = Mock() - db_session.query.return_value = db_query + db_session.scalars.return_value = scalars_result with ( patch.object(processor, "_extract_markdown_images", return_value=images), @@ -170,10 +170,10 @@ class TestBaseIndexProcessor: def test_get_content_files_ignores_missing_upload_records(self, processor: _ForwardingBaseIndexProcessor) -> None: document = Document(page_content="ignored", metadata={"document_id": "doc-1", "dataset_id": "ds-1"}) images = ["/files/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/image-preview"] - db_query = Mock() - db_query.where.return_value.all.return_value = [] + scalars_result = Mock() + scalars_result.all.return_value = [] db_session = Mock() - db_session.query.return_value = db_query + db_session.scalars.return_value = scalars_result with ( patch.object(processor, "_extract_markdown_images", return_value=images), @@ -259,20 +259,16 @@ class TestBaseIndexProcessor: assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None def test_download_tool_file_returns_none_when_not_found(self, processor: _ForwardingBaseIndexProcessor) -> None: - db_query = Mock() - db_query.where.return_value.first.return_value = None db_session = Mock() - db_session.query.return_value = db_query + db_session.get.return_value = None with patch("core.rag.index_processor.index_processor_base.db.session", db_session): assert processor._download_tool_file("tool-id", current_user=Mock()) is None def test_download_tool_file_uploads_file_when_found(self, processor: _ForwardingBaseIndexProcessor) -> None: tool_file = SimpleNamespace(file_key="k1", name="tool.png", mimetype="image/png") - db_query = Mock() - db_query.where.return_value.first.return_value = tool_file db_session = Mock() - db_session.query.return_value = db_query + db_session.get.return_value = tool_file mock_db = Mock() mock_db.session = db_session mock_db.engine = Mock() diff --git a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py index b011ade8846..450e7166360 100644 --- a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py +++ b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py @@ -53,6 +53,7 @@ from typing import Any from unittest.mock import MagicMock, Mock, patch import pytest +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm.exc import ObjectDeletedError from core.errors.error import ProviderTokenNotInitError @@ -61,9 +62,8 @@ from core.indexing_runner import ( DocumentIsPausedError, IndexingRunner, ) -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.models.document import ChildDocument, Document -from dify_graph.model_runtime.entities.model_entities import ModelType from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, DatasetProcessRule from models.dataset import Document as DatasetDocument @@ -76,7 +76,7 @@ from models.dataset import Document as DatasetDocument def create_mock_dataset( dataset_id: str | None = None, tenant_id: str | None = None, - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, embedding_provider: str = "openai", embedding_model: str = "text-embedding-ada-002", ) -> Mock: @@ -445,7 +445,7 @@ class TestIndexingRunnerTransform: """Mock all external dependencies for transform tests.""" with ( patch("core.indexing_runner.db") as mock_db, - patch("core.indexing_runner.ModelManager") as mock_model_manager, + patch("core.indexing_runner.ModelManager.for_tenant") as mock_model_manager, ): yield { "db": mock_db, @@ -458,7 +458,7 @@ class TestIndexingRunnerTransform: dataset = Mock(spec=Dataset) dataset.id = str(uuid.uuid4()) dataset.tenant_id = str(uuid.uuid4()) - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.embedding_model_provider = "openai" dataset.embedding_model = "text-embedding-ada-002" return dataset @@ -482,7 +482,8 @@ class TestIndexingRunnerTransform: # Arrange runner = IndexingRunner() mock_embedding_instance = MagicMock() - runner.model_manager.get_model_instance.return_value = mock_embedding_instance + model_manager = mock_dependencies["model_manager"].return_value + model_manager.get_model_instance.return_value = mock_embedding_instance mock_processor = MagicMock() transformed_docs = [ @@ -509,7 +510,7 @@ class TestIndexingRunnerTransform: assert len(result) == 2 assert result[0].page_content == "Chunk 1" assert result[1].page_content == "Chunk 2" - runner.model_manager.get_model_instance.assert_called_once_with( + model_manager.get_model_instance.assert_called_once_with( tenant_id=sample_dataset.tenant_id, provider=sample_dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, @@ -521,7 +522,8 @@ class TestIndexingRunnerTransform: """Test transformation with economy indexing (no embeddings).""" # Arrange runner = IndexingRunner() - sample_dataset.indexing_technique = "economy" + model_manager = mock_dependencies["model_manager"].return_value + sample_dataset.indexing_technique = IndexTechniqueType.ECONOMY mock_processor = MagicMock() transformed_docs = [ @@ -539,14 +541,15 @@ class TestIndexingRunnerTransform: # Assert assert len(result) == 1 - runner.model_manager.get_model_instance.assert_not_called() + model_manager.get_model_instance.assert_not_called() def test_transform_with_custom_segmentation(self, mock_dependencies, sample_dataset, sample_text_docs): """Test transformation with custom segmentation rules.""" # Arrange runner = IndexingRunner() mock_embedding_instance = MagicMock() - runner.model_manager.get_model_instance.return_value = mock_embedding_instance + model_manager = mock_dependencies["model_manager"].return_value + model_manager.get_model_instance.return_value = mock_embedding_instance mock_processor = MagicMock() transformed_docs = [Document(page_content="Custom chunk", metadata={"doc_id": "custom1", "doc_hash": "hash1"})] @@ -586,7 +589,7 @@ class TestIndexingRunnerLoad: """Mock all external dependencies for load tests.""" with ( patch("core.indexing_runner.db") as mock_db, - patch("core.indexing_runner.ModelManager") as mock_model_manager, + patch("core.indexing_runner.ModelManager.for_tenant") as mock_model_manager, patch("core.indexing_runner.current_app") as mock_app, patch("core.indexing_runner.threading.Thread") as mock_thread, patch("core.indexing_runner.concurrent.futures.ThreadPoolExecutor") as mock_executor, @@ -605,7 +608,7 @@ class TestIndexingRunnerLoad: dataset = Mock(spec=Dataset) dataset.id = str(uuid.uuid4()) dataset.tenant_id = str(uuid.uuid4()) - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.embedding_model_provider = "openai" dataset.embedding_model = "text-embedding-ada-002" return dataset @@ -645,7 +648,8 @@ class TestIndexingRunnerLoad: runner = IndexingRunner() mock_embedding_instance = MagicMock() mock_embedding_instance.get_text_embedding_num_tokens.return_value = 100 - runner.model_manager.get_model_instance.return_value = mock_embedding_instance + model_manager = mock_dependencies["model_manager"].return_value + model_manager.get_model_instance.return_value = mock_embedding_instance mock_processor = MagicMock() @@ -664,7 +668,7 @@ class TestIndexingRunnerLoad: runner._load(mock_processor, sample_dataset, sample_dataset_document, sample_documents) # Assert - runner.model_manager.get_model_instance.assert_called_once() + model_manager.get_model_instance.assert_called_once() # Verify executor was used for parallel processing assert mock_executor_instance.submit.called @@ -674,7 +678,7 @@ class TestIndexingRunnerLoad: """Test loading with economy indexing (keyword only).""" # Arrange runner = IndexingRunner() - sample_dataset.indexing_technique = "economy" + sample_dataset.indexing_technique = IndexTechniqueType.ECONOMY mock_processor = MagicMock() @@ -701,7 +705,7 @@ class TestIndexingRunnerLoad: # Arrange runner = IndexingRunner() sample_dataset_document.doc_form = IndexStructureType.PARENT_CHILD_INDEX - sample_dataset.indexing_technique = "high_quality" + sample_dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY # Add child documents for doc in sample_documents: @@ -714,7 +718,8 @@ class TestIndexingRunnerLoad: mock_embedding_instance = MagicMock() mock_embedding_instance.get_text_embedding_num_tokens.return_value = 50 - runner.model_manager.get_model_instance.return_value = mock_embedding_instance + model_manager = mock_dependencies["model_manager"].return_value + model_manager.get_model_instance.return_value = mock_embedding_instance mock_processor = MagicMock() @@ -754,7 +759,7 @@ class TestIndexingRunnerRun: with ( patch("core.indexing_runner.db") as mock_db, patch("core.indexing_runner.IndexProcessorFactory") as mock_factory, - patch("core.indexing_runner.ModelManager") as mock_model_manager, + patch("core.indexing_runner.ModelManager.for_tenant") as mock_model_manager, patch("core.indexing_runner.storage") as mock_storage, patch("core.indexing_runner.threading.Thread") as mock_thread, ): @@ -795,7 +800,7 @@ class TestIndexingRunnerRun: mock_dataset = Mock(spec=Dataset) mock_dataset.id = doc.dataset_id mock_dataset.tenant_id = doc.tenant_id - mock_dataset.indexing_technique = "economy" + mock_dataset.indexing_technique = IndexTechniqueType.ECONOMY mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_dataset mock_process_rule = Mock(spec=DatasetProcessRule) @@ -949,7 +954,7 @@ class TestIndexingRunnerRun: mock_dependencies["db"].session.get.side_effect = get_side_effect mock_dataset = Mock(spec=Dataset) - mock_dataset.indexing_technique = "economy" + mock_dataset.indexing_technique = IndexTechniqueType.ECONOMY mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_dataset mock_process_rule = Mock(spec=DatasetProcessRule) diff --git a/api/tests/unit_tests/core/rag/rerank/test_reranker.py b/api/tests/unit_tests/core/rag/rerank/test_reranker.py index b150d677f1a..c279b00d3bc 100644 --- a/api/tests/unit_tests/core/rag/rerank/test_reranker.py +++ b/api/tests/unit_tests/core/rag/rerank/test_reranker.py @@ -17,6 +17,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, Mock, patch import pytest +from graphon.model_runtime.entities.rerank_entities import RerankDocument, RerankResult from core.model_manager import ModelInstance from core.rag.index_processor.constant.doc_type import DocType @@ -28,7 +29,6 @@ from core.rag.rerank.rerank_factory import RerankRunnerFactory from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.rerank.rerank_type import RerankMode from core.rag.rerank.weight_rerank import WeightRerankRunner -from dify_graph.model_runtime.entities.rerank_entities import RerankDocument, RerankResult def create_mock_model_instance() -> ModelInstance: @@ -57,7 +57,7 @@ class TestRerankModelRunner: @pytest.fixture(autouse=True) def mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False yield mock_mm @@ -352,12 +352,14 @@ class TestRerankModelRunner: # Assert: Empty result is returned assert len(result) == 0 - def test_user_parameter_passed_to_model(self, rerank_runner, mock_model_instance, sample_documents): - """Test that user parameter is passed to model invocation. + def test_run_uses_bound_model_instance( + self, rerank_runner, mock_model_instance, sample_documents, mock_model_manager + ): + """Test that rerank uses the bound model instance directly. Verifies: - - User ID is correctly forwarded to the model - - Model receives all expected parameters + - The injected model instance is used for invocation + - No late rebinding occurs through ModelManager.get_model_instance """ # Arrange: Mock rerank result mock_rerank_result = RerankResult( @@ -368,16 +370,18 @@ class TestRerankModelRunner: ) mock_model_instance.invoke_rerank.return_value = mock_rerank_result - # Act: Run reranking with user parameter + # Act: Run reranking result = rerank_runner.run( query="test", documents=sample_documents, - user="user123", ) - # Assert: User parameter is passed to model + # Assert: The injected model instance is invoked directly. + assert len(result) == 1 + mock_model_manager.return_value.get_model_instance.assert_not_called() call_kwargs = mock_model_instance.invoke_rerank.call_args.kwargs - assert call_kwargs["user"] == "user123" + assert call_kwargs["query"] == "test" + assert "user" not in call_kwargs class _ForwardingBaseRerankRunner(BaseRerankRunner): @@ -387,7 +391,6 @@ class _ForwardingBaseRerankRunner(BaseRerankRunner): documents: list[Document], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: return super().run( @@ -395,7 +398,6 @@ class _ForwardingBaseRerankRunner(BaseRerankRunner): documents=documents, score_threshold=score_threshold, top_n=top_n, - user=user, query_type=query_type, ) @@ -424,7 +426,7 @@ class TestRerankModelRunnerMultimodal: Document(page_content="doc", metadata={"doc_id": "doc1"}, provider="dify"), ] - with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant") as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False result = rerank_runner.run(query="image-file-id", documents=documents, query_type=QueryType.IMAGE_QUERY) @@ -441,7 +443,7 @@ class TestRerankModelRunnerMultimodal: ) with ( - patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm, + patch("core.rag.rerank.rerank_model.ModelManager.for_tenant") as mock_mm, patch.object( rerank_runner, "fetch_multimodal_rerank", @@ -471,12 +473,10 @@ class TestRerankModelRunnerMultimodal: metadata={}, provider="external", ) - query = Mock() - query.where.return_value.first.return_value = SimpleNamespace(key="image-key") rerank_result = RerankResult(model="rerank-model", docs=[]) with ( - patch("core.rag.rerank.rerank_model.db.session.query", return_value=query), + patch("core.rag.rerank.rerank_model.db.session.get", return_value=SimpleNamespace(key="image-key")), patch("core.rag.rerank.rerank_model.storage.load_once", return_value=b"image-bytes") as mock_load_once, patch.object( rerank_runner, @@ -502,12 +502,10 @@ class TestRerankModelRunnerMultimodal: metadata={"doc_id": "img-missing", "doc_type": DocType.IMAGE}, provider="dify", ) - query = Mock() - query.where.return_value.first.return_value = None rerank_result = RerankResult(model="rerank-model", docs=[]) with ( - patch("core.rag.rerank.rerank_model.db.session.query", return_value=query), + patch("core.rag.rerank.rerank_model.db.session.get", return_value=None), patch.object( rerank_runner, "fetch_text_rerank", @@ -531,16 +529,16 @@ class TestRerankModelRunnerMultimodal: metadata={"doc_id": "txt-1", "doc_type": DocType.TEXT}, provider="dify", ) - query_chain = Mock() - query_chain.where.return_value.first.return_value = SimpleNamespace(key="query-image-key") rerank_result = RerankResult( model="rerank-model", docs=[RerankDocument(index=0, text="text-content", score=0.77)], ) mock_model_instance.invoke_multimodal_rerank.return_value = rerank_result + session = MagicMock() + session.get.return_value = SimpleNamespace(key="query-image-key") with ( - patch("core.rag.rerank.rerank_model.db.session.query", return_value=query_chain), + patch("core.rag.rerank.rerank_model.db.session", session), patch("core.rag.rerank.rerank_model.storage.load_once", return_value=b"query-image-bytes"), ): result, unique_documents = rerank_runner.fetch_multimodal_rerank( @@ -548,7 +546,6 @@ class TestRerankModelRunnerMultimodal: documents=[text_doc], score_threshold=0.2, top_n=2, - user="user-1", query_type=QueryType.IMAGE_QUERY, ) @@ -557,13 +554,10 @@ class TestRerankModelRunnerMultimodal: invoke_kwargs = mock_model_instance.invoke_multimodal_rerank.call_args.kwargs assert invoke_kwargs["query"]["content_type"] == DocType.IMAGE assert invoke_kwargs["docs"][0]["content"] == "text-content" - assert invoke_kwargs["user"] == "user-1" + assert "user" not in invoke_kwargs def test_fetch_multimodal_rerank_raises_when_query_image_not_found(self, rerank_runner): - query_chain = Mock() - query_chain.where.return_value.first.return_value = None - - with patch("core.rag.rerank.rerank_model.db.session.query", return_value=query_chain): + with patch("core.rag.rerank.rerank_model.db.session.get", return_value=None): with pytest.raises(ValueError, match="Upload file not found for query"): rerank_runner.fetch_multimodal_rerank( query="missing-upload-id", @@ -595,7 +589,7 @@ class TestWeightRerankRunner: @pytest.fixture def mock_model_manager(self): """Mock ModelManager for embedding model.""" - with patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager: + with patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant") as mock_manager: yield mock_manager @pytest.fixture @@ -1145,7 +1139,7 @@ class TestRerankIntegration: @pytest.fixture(autouse=True) def mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False yield mock_mm @@ -1257,7 +1251,7 @@ class TestRerankEdgeCases: @pytest.fixture(autouse=True) def mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False yield mock_mm @@ -1527,7 +1521,7 @@ class TestRerankEdgeCases: # Mock dependencies with ( patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba, - patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager, + patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant") as mock_manager, patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache, ): mock_handler = MagicMock() @@ -1598,7 +1592,7 @@ class TestRerankPerformance: @pytest.fixture(autouse=True) def mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False yield mock_mm @@ -1673,7 +1667,7 @@ class TestRerankPerformance: with ( patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba, - patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager, + patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant") as mock_manager, patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache, ): mock_handler = MagicMock() @@ -1715,7 +1709,7 @@ class TestRerankErrorHandling: @pytest.fixture(autouse=True) def mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False yield mock_mm @@ -1824,7 +1818,7 @@ class TestRerankErrorHandling: with ( patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba, - patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager, + patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant") as mock_manager, patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache, ): mock_handler = MagicMock() diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index 665e98bd9c6..fee7b168ad0 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -6,6 +6,8 @@ from uuid import uuid4 import pytest from flask import Flask, current_app +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.model_entities import ModelFeature from sqlalchemy import column from core.app.app_config.entities import ( @@ -35,9 +37,8 @@ from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.workflow.nodes.knowledge_retrieval import exc from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.model_runtime.entities.model_entities import ModelFeature from models.dataset import Dataset +from models.enums import CreatorUserRole # ==================== Helper Functions ==================== @@ -3747,6 +3748,24 @@ class TestDatasetRetrievalAdditionalHelpers: mock_session.add_all.assert_called() mock_session.commit.assert_called() + def test_on_query_normalizes_workflow_end_user_role(self, retrieval: DatasetRetrieval) -> None: + with patch("core.rag.retrieval.dataset_retrieval.db.session") as mock_session: + retrieval._on_query( + query="python", + attachment_ids=None, + dataset_ids=["d1"], + app_id="a1", + user_from="end-user", + user_id="u1", + ) + + mock_session.add_all.assert_called_once() + added_queries = mock_session.add_all.call_args.args[0] + + assert len(added_queries) == 1 + assert added_queries[0].created_by_role == CreatorUserRole.END_USER + mock_session.commit.assert_called_once() + def test_handle_invoke_result(self, retrieval: DatasetRetrieval) -> None: usage = LLMUsage.empty_usage() chunk_1 = SimpleNamespace( @@ -3836,7 +3855,7 @@ class TestDatasetRetrievalAdditionalHelpers: model_instance.model_type_instance.get_model_schema.return_value = Mock() with ( - patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_manager, + patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_manager, patch("core.rag.retrieval.dataset_retrieval.ModelConfigWithCredentialsEntity") as mock_cfg_entity, ): mock_manager.return_value.get_model_instance.return_value = model_instance @@ -3952,11 +3971,10 @@ class TestDatasetRetrievalAdditionalHelpers: ) def test_get_metadata_filter_condition(self, retrieval: DatasetRetrieval) -> None: - db_query = Mock() - db_query.where.return_value = db_query - db_query.all.return_value = [SimpleNamespace(dataset_id="d1", id="doc-1")] + scalars_result = Mock() + scalars_result.all.return_value = [SimpleNamespace(dataset_id="d1", id="doc-1")] - with patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query): + with patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=scalars_result): mapping, condition = retrieval.get_metadata_filter_condition( dataset_ids=["d1"], query="python", @@ -3972,7 +3990,7 @@ class TestDatasetRetrievalAdditionalHelpers: automatic_filters = [{"condition": "contains", "metadata_name": "author", "value": "Alice"}] with ( - patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query), + patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=scalars_result), patch.object(retrieval, "_automatic_metadata_filter_func", return_value=automatic_filters), ): mapping, condition = retrieval.get_metadata_filter_condition( @@ -3993,7 +4011,7 @@ class TestDatasetRetrievalAdditionalHelpers: logical_operator="and", conditions=[AppCondition(name="author", comparison_operator="contains", value="{{name}}")], ) - with patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query): + with patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=scalars_result): mapping, condition = retrieval.get_metadata_filter_condition( dataset_ids=["d1"], query="python", @@ -4008,7 +4026,7 @@ class TestDatasetRetrievalAdditionalHelpers: assert condition is not None assert condition.conditions[0].value == "Alice" - with patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query): + with patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=scalars_result): with pytest.raises(ValueError, match="Invalid metadata filtering mode"): retrieval.get_metadata_filter_condition( dataset_ids=["d1"], @@ -4222,11 +4240,12 @@ class TestKnowledgeRetrievalCoverage: with ( patch.object(retrieval, "_check_knowledge_rate_limit"), patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="dataset-1")]), - patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager, + patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager, ): mock_model_manager.return_value.get_model_instance.return_value = model_instance with pytest.raises(Exception) as exc_info: retrieval.knowledge_retrieval(request) + mock_model_manager.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") assert error_cls in type(exc_info.value).__name__ @@ -4279,9 +4298,13 @@ class TestRetrieveCoverage: ), ) model_config = self._build_model_config() - model_config.provider_model_bundle.model_type_instance.get_model_schema.return_value = None - with patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager: - mock_model_manager.return_value.get_model_instance.return_value = Mock() + model_instance = Mock() + model_instance.model_name = "gpt-4" + model_instance.credentials = {"api_key": "secret"} + model_instance.provider_model_bundle = Mock() + model_instance.model_type_instance.get_model_schema.return_value = None + with patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager: + mock_model_manager.return_value.get_model_instance.return_value = model_instance result = retrieval.retrieve( app_id="app-1", user_id="user-1", @@ -4294,8 +4317,58 @@ class TestRetrieveCoverage: hit_callback=Mock(), message_id="m1", ) + mock_model_manager.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") assert result == (None, []) + def test_retrieve_uses_bound_model_instance_schema_and_updates_model_config( + self, retrieval: DatasetRetrieval + ) -> None: + config = DatasetEntity( + dataset_ids=["d1"], + retrieve_config=DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE, + metadata_filtering_mode="disabled", + ), + ) + model_config = self._build_model_config(features=[]) + model_config.provider_model_bundle.model_type_instance.get_model_schema.return_value = None + bound_schema = SimpleNamespace(features=[ModelFeature.TOOL_CALL]) + bound_bundle = Mock() + bound_model_instance = Mock() + bound_model_instance.model_name = "gpt-4" + bound_model_instance.credentials = {"api_key": "secret"} + bound_model_instance.provider_model_bundle = bound_bundle + bound_model_instance.model_type_instance.get_model_schema.return_value = bound_schema + + with ( + patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager, + patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]), + patch.object(retrieval, "get_metadata_filter_condition", return_value=(None, None)), + patch.object(retrieval, "single_retrieve", return_value=[]) as mock_single_retrieve, + ): + mock_model_manager.return_value.get_model_instance.return_value = bound_model_instance + context, files = retrieval.retrieve( + app_id="app-1", + user_id="user-1", + tenant_id="tenant-1", + model_config=model_config, + config=config, + query="python", + invoke_from=InvokeFrom.WEB_APP, + show_retrieve_source=False, + hit_callback=Mock(), + message_id="m1", + ) + + mock_model_manager.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") + mock_single_retrieve.assert_called_once() + assert mock_single_retrieve.call_args.args[8] == PlanningStrategy.ROUTER + assert model_config.provider_model_bundle is bound_bundle + assert model_config.credentials == {"api_key": "secret"} + assert model_config.model_schema is bound_schema + assert context == "" + assert files == [] + def test_single_strategy_with_external_documents(self, retrieval: DatasetRetrieval) -> None: retrieve_config = DatasetRetrieveConfigEntity( retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE, @@ -4312,12 +4385,17 @@ class TestRetrieveCoverage: extra={"title": "External", "dataset_name": "External DS"}, ) with ( - patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager, + patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager, patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]), patch.object(retrieval, "get_metadata_filter_condition", return_value=(None, None)), patch.object(retrieval, "single_retrieve", return_value=[external_doc]), ): - mock_model_manager.return_value.get_model_instance.return_value = Mock() + bound_model_instance = Mock() + bound_model_instance.model_name = "gpt-4" + bound_model_instance.credentials = {} + bound_model_instance.provider_model_bundle = Mock() + bound_model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace(features=[]) + mock_model_manager.return_value.get_model_instance.return_value = bound_model_instance context, files = retrieval.retrieve( app_id="app-1", user_id="user-1", @@ -4402,7 +4480,7 @@ class TestRetrieveCoverage: hit_callback = Mock() with ( - patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager, + patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager, patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]), patch.object(retrieval, "get_metadata_filter_condition", return_value=(None, None)), patch.object(retrieval, "multiple_retrieve", return_value=[external_doc, dify_doc]), @@ -4413,7 +4491,14 @@ class TestRetrieveCoverage: patch("core.rag.retrieval.dataset_retrieval.sign_upload_file", return_value="https://signed"), patch("core.rag.retrieval.dataset_retrieval.db.session.execute") as mock_execute, ): - mock_model_manager.return_value.get_model_instance.return_value = Mock() + bound_model_instance = Mock() + bound_model_instance.model_name = "gpt-4" + bound_model_instance.credentials = {} + bound_model_instance.provider_model_bundle = Mock() + bound_model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace( + features=[ModelFeature.TOOL_CALL] + ) + mock_model_manager.return_value.get_model_instance.return_value = bound_model_instance mock_execute.side_effect = [execute_attachments, execute_docs, execute_datasets] context, files = retrieval.retrieve( app_id="app-1", @@ -4800,8 +4885,8 @@ class TestInternalHooksCoverage: dataset_docs = [ SimpleNamespace(id="doc-a", doc_form=IndexStructureType.PARENT_CHILD_INDEX), SimpleNamespace(id="doc-b", doc_form=IndexStructureType.PARENT_CHILD_INDEX), - SimpleNamespace(id="doc-c", doc_form="qa_model"), - SimpleNamespace(id="doc-d", doc_form="qa_model"), + SimpleNamespace(id="doc-c", doc_form=IndexStructureType.QA_INDEX), + SimpleNamespace(id="doc-d", doc_form=IndexStructureType.QA_INDEX), ] child_chunks = [SimpleNamespace(index_node_id="idx-a", segment_id="seg-a")] segments = [SimpleNamespace(index_node_id="idx-c", id="seg-c")] diff --git a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py index cfa9094e129..5a2ecb82204 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py @@ -1,7 +1,8 @@ from unittest.mock import Mock +from graphon.model_runtime.entities.llm_entities import LLMUsage + from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter -from dify_graph.model_runtime.entities.llm_entities import LLMUsage class TestFunctionCallMultiDatasetRouter: diff --git a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py index e4295637392..539ac0f849f 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py @@ -1,10 +1,12 @@ from types import SimpleNamespace from unittest.mock import Mock, patch +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.message_entities import PromptMessageRole +from graphon.model_runtime.entities.model_entities import ModelType + from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole class TestReactMultiDatasetRouter: @@ -87,6 +89,7 @@ class TestReactMultiDatasetRouter: model_config = Mock() model_config.mode = "chat" model_config.parameters = {"temperature": 0.1} + model_instance = Mock() usage = LLMUsage.empty_usage() tools = [Mock(name="dataset-1"), Mock(name="dataset-2")] tools[0].name = "dataset-1" @@ -108,13 +111,14 @@ class TestReactMultiDatasetRouter: dataset_id, returned_usage = router._react_invoke( query="python", model_config=model_config, - model_instance=Mock(), + model_instance=model_instance, tools=tools, user_id="u1", tenant_id="t1", ) mock_chat_prompt.assert_called_once() + assert mock_prompt_transform.return_value.get_prompt.call_args.kwargs["model_instance"] is model_instance assert dataset_id == "dataset-2" assert returned_usage == usage @@ -162,7 +166,11 @@ class TestReactMultiDatasetRouter: model_instance = Mock() model_instance.invoke_llm.return_value = iter([chunk]) - with patch("core.rag.retrieval.router.multi_dataset_react_route.deduct_llm_quota") as mock_deduct: + with ( + patch("core.rag.retrieval.router.multi_dataset_react_route.ModelManager.for_tenant") as mock_manager, + patch("core.rag.retrieval.router.multi_dataset_react_route.deduct_llm_quota") as mock_deduct, + ): + mock_manager.return_value.get_model_instance.return_value = model_instance text, returned_usage = router._invoke_llm( completion_param={"temperature": 0.1}, model_instance=model_instance, @@ -174,6 +182,13 @@ class TestReactMultiDatasetRouter: assert text == "part" assert returned_usage == usage + mock_manager.assert_called_once_with(tenant_id="t1", user_id="u1") + mock_manager.return_value.get_model_instance.assert_called_once_with( + tenant_id="t1", + provider=model_instance.provider, + model_type=ModelType.LLM, + model=model_instance.model_name, + ) mock_deduct.assert_called_once() def test_handle_invoke_result_with_empty_usage(self) -> None: diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py index e7eecfa2972..e229d5fc1a5 100644 --- a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py @@ -9,9 +9,10 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowType from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository -from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowType from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py index 2a83a4e802f..7dbf78d0f0c 100644 --- a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py @@ -9,14 +9,14 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest - -from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository -from dify_graph.entities.workflow_node_execution import ( +from graphon.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, ) -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig +from graphon.enums import BuiltinNodeTypes + +from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository +from core.repositories.factory import OrderConfig from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.workflow import WorkflowNodeExecutionTriggeredFrom @@ -181,10 +181,10 @@ class TestCeleryWorkflowNodeExecutionRepository: repo.save(sample_workflow_node_execution) @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task") - def test_get_by_workflow_run_from_cache( + def test_get_by_workflow_execution_from_cache( self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution ): - """Test that get_by_workflow_run retrieves executions from cache.""" + """Test that get_by_workflow_execution retrieves executions from cache.""" repo = CeleryWorkflowNodeExecutionRepository( session_factory=mock_session_factory, user=mock_account, @@ -195,18 +195,18 @@ class TestCeleryWorkflowNodeExecutionRepository: # Save execution to cache first repo.save(sample_workflow_node_execution) - workflow_run_id = sample_workflow_node_execution.workflow_execution_id + workflow_execution_id = sample_workflow_node_execution.workflow_execution_id order_config = OrderConfig(order_by=["index"], order_direction="asc") - result = repo.get_by_workflow_run(workflow_run_id, order_config) + result = repo.get_by_workflow_execution(workflow_execution_id, order_config) # Verify results were retrieved from cache assert len(result) == 1 assert result[0].id == sample_workflow_node_execution.id assert result[0] is sample_workflow_node_execution - def test_get_by_workflow_run_without_order_config(self, mock_session_factory, mock_account): - """Test get_by_workflow_run without order configuration.""" + def test_get_by_workflow_execution_without_order_config(self, mock_session_factory, mock_account): + """Test get_by_workflow_execution without order configuration.""" repo = CeleryWorkflowNodeExecutionRepository( session_factory=mock_session_factory, user=mock_account, @@ -214,7 +214,7 @@ class TestCeleryWorkflowNodeExecutionRepository: triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) - result = repo.get_by_workflow_run("workflow-run-id") + result = repo.get_by_workflow_execution("workflow-run-id") # Should return empty list since nothing in cache assert len(result) == 0 @@ -236,7 +236,7 @@ class TestCeleryWorkflowNodeExecutionRepository: assert sample_workflow_node_execution.id in repo._execution_cache # Test retrieving from cache - result = repo.get_by_workflow_run(sample_workflow_node_execution.workflow_execution_id) + result = repo.get_by_workflow_execution(sample_workflow_node_execution.workflow_execution_id) assert len(result) == 1 assert result[0].id == sample_workflow_node_execution.id @@ -251,12 +251,12 @@ class TestCeleryWorkflowNodeExecutionRepository: ) # Create multiple executions for the same workflow - workflow_run_id = str(uuid4()) + workflow_execution_id = str(uuid4()) exec1 = WorkflowNodeExecution( id=str(uuid4()), node_execution_id=str(uuid4()), workflow_id=str(uuid4()), - workflow_execution_id=workflow_run_id, + workflow_execution_id=workflow_execution_id, index=1, node_id="node1", node_type=BuiltinNodeTypes.START, @@ -269,7 +269,7 @@ class TestCeleryWorkflowNodeExecutionRepository: id=str(uuid4()), node_execution_id=str(uuid4()), workflow_id=str(uuid4()), - workflow_execution_id=workflow_run_id, + workflow_execution_id=workflow_execution_id, index=2, node_id="node2", node_type=BuiltinNodeTypes.LLM, @@ -285,10 +285,10 @@ class TestCeleryWorkflowNodeExecutionRepository: # Verify both are cached and mapped assert len(repo._execution_cache) == 2 - assert len(repo._workflow_execution_mapping[workflow_run_id]) == 2 + assert len(repo._workflow_execution_mapping[workflow_execution_id]) == 2 # Test retrieval - result = repo.get_by_workflow_run(workflow_run_id) + result = repo.get_by_workflow_execution(workflow_execution_id) assert len(result) == 2 @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task") @@ -302,12 +302,12 @@ class TestCeleryWorkflowNodeExecutionRepository: ) # Create executions with different indices - workflow_run_id = str(uuid4()) + workflow_execution_id = str(uuid4()) exec1 = WorkflowNodeExecution( id=str(uuid4()), node_execution_id=str(uuid4()), workflow_id=str(uuid4()), - workflow_execution_id=workflow_run_id, + workflow_execution_id=workflow_execution_id, index=2, node_id="node2", node_type=BuiltinNodeTypes.START, @@ -320,7 +320,7 @@ class TestCeleryWorkflowNodeExecutionRepository: id=str(uuid4()), node_execution_id=str(uuid4()), workflow_id=str(uuid4()), - workflow_execution_id=workflow_run_id, + workflow_execution_id=workflow_execution_id, index=1, node_id="node1", node_type=BuiltinNodeTypes.LLM, @@ -336,14 +336,14 @@ class TestCeleryWorkflowNodeExecutionRepository: # Test ascending order order_config = OrderConfig(order_by=["index"], order_direction="asc") - result = repo.get_by_workflow_run(workflow_run_id, order_config) + result = repo.get_by_workflow_execution(workflow_execution_id, order_config) assert len(result) == 2 assert result[0].index == 1 assert result[1].index == 2 # Test descending order order_config = OrderConfig(order_by=["index"], order_direction="desc") - result = repo.get_by_workflow_run(workflow_run_id, order_config) + result = repo.get_by_workflow_execution(workflow_execution_id, order_config) assert len(result) == 2 assert result[0].index == 2 assert result[1].index == 1 diff --git a/api/tests/unit_tests/core/repositories/test_factory.py b/api/tests/unit_tests/core/repositories/test_factory.py index fe9eed03071..48327c39134 100644 --- a/api/tests/unit_tests/core/repositories/test_factory.py +++ b/api/tests/unit_tests/core/repositories/test_factory.py @@ -11,9 +11,12 @@ import pytest from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker -from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.repositories.factory import ( + DifyCoreRepositoryFactory, + RepositoryImportError, + WorkflowExecutionRepository, + WorkflowNodeExecutionRepository, +) from libs.module_loading import import_string from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py index 9af4d126647..0fc82dda530 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py @@ -7,6 +7,11 @@ from datetime import datetime from types import SimpleNamespace import pytest +from graphon.nodes.human_input.entities import ( + FormDefinition, + UserAction, +) +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from core.repositories.human_input_repository import ( HumanInputFormRecord, @@ -14,16 +19,13 @@ from core.repositories.human_input_repository import ( HumanInputFormSubmissionRepository, _WorkspaceMemberInfo, ) -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - FormDefinition, MemberRecipient, - UserAction, ) -from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import naive_utc_now from models.human_input import ( EmailExternalRecipientPayload, @@ -89,9 +91,9 @@ class TestHumanInputFormRepositoryImplHelpers: form_id="form-id", delivery_id="delivery-id", recipients_config=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(user_id="member-1"), + MemberRecipient(reference_id="member-1"), ExternalRecipient(email="external@example.com"), ], ), @@ -125,9 +127,9 @@ class TestHumanInputFormRepositoryImplHelpers: form_id="form-id", delivery_id="delivery-id", recipients_config=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(user_id="missing-member"), + MemberRecipient(reference_id="missing-member"), ExternalRecipient(email="external@example.com"), ], ), @@ -156,7 +158,7 @@ class TestHumanInputFormRepositoryImplHelpers: form_id="form-id", delivery_id="delivery-id", recipients_config=EmailRecipients( - whole_workspace=True, + include_bound_group=True, items=[], ), ) @@ -182,7 +184,7 @@ class TestHumanInputFormRepositoryImplHelpers: form_id="form-id", delivery_id="delivery-id", recipients_config=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ ExternalRecipient(email="external@example.com"), ExternalRecipient(email="external@example.com"), @@ -212,9 +214,9 @@ class TestHumanInputFormRepositoryImplHelpers: form_id="form-id", delivery_id="delivery-id", recipients_config=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(user_id="member-1"), + MemberRecipient(reference_id="member-1"), ExternalRecipient(email="shared@example.com"), ], ), @@ -243,7 +245,7 @@ class TestHumanInputFormRepositoryImplHelpers: method = EmailDeliveryMethod( config=EmailDeliveryConfig( recipients=EmailRecipients( - whole_workspace=True, + include_bound_group=True, items=[ExternalRecipient(email="external@example.com")], ), subject="subject", @@ -272,7 +274,7 @@ def _make_form_definition() -> str: inputs=[], user_actions=[UserAction(id="submit", title="Submit")], rendered_content="

hello

", - expiration_time=datetime.utcnow(), + expiration_time=naive_utc_now(), ).model_dump_json() @@ -421,22 +423,22 @@ class TestHumanInputFormRepositoryImplPublicMethods: ) session = _FakeSession(scalars_results=[form, [recipient]]) _patch_repo_session_factory(monkeypatch, session) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id", workflow_execution_id=form.workflow_run_id) - entity = repo.get_form(form.workflow_run_id, form.node_id) + entity = repo.get_form(form.node_id) assert entity is not None assert entity.id == form.id - assert entity.web_app_token == "token-123" + assert entity.submission_token == "token-123" assert len(entity.recipients) == 1 assert entity.recipients[0].token == "token-123" def test_get_form_returns_none_when_missing(self, monkeypatch: pytest.MonkeyPatch): session = _FakeSession(scalars_results=[None]) _patch_repo_session_factory(monkeypatch, session) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id", workflow_execution_id="run-1") - assert repo.get_form("run-1", "node-1") is None + assert repo.get_form("node-1") is None def test_get_form_returns_unsubmitted_state(self, monkeypatch: pytest.MonkeyPatch): form = _DummyForm( @@ -451,9 +453,9 @@ class TestHumanInputFormRepositoryImplPublicMethods: ) session = _FakeSession(scalars_results=[form, []]) _patch_repo_session_factory(monkeypatch, session) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id", workflow_execution_id=form.workflow_run_id) - entity = repo.get_form(form.workflow_run_id, form.node_id) + entity = repo.get_form(form.node_id) assert entity is not None assert entity.submitted is False @@ -476,9 +478,9 @@ class TestHumanInputFormRepositoryImplPublicMethods: ) session = _FakeSession(scalars_results=[form, []]) _patch_repo_session_factory(monkeypatch, session) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id", workflow_execution_id=form.workflow_run_id) - entity = repo.get_form(form.workflow_run_id, form.node_id) + entity = repo.get_form(form.node_id) assert entity is not None assert entity.submitted is True diff --git a/api/tests/unit_tests/core/repositories/test_human_input_repository.py b/api/tests/unit_tests/core/repositories/test_human_input_repository.py new file mode 100644 index 00000000000..8ff0e405874 --- /dev/null +++ b/api/tests/unit_tests/core/repositories/test_human_input_repository.py @@ -0,0 +1,679 @@ +from __future__ import annotations + +import dataclasses +import json +from collections.abc import Sequence +from datetime import datetime, timedelta +from types import SimpleNamespace +from typing import Any +from unittest.mock import MagicMock + +import pytest +from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus + +from core.repositories.human_input_repository import ( + FormCreateParams, + FormNotFoundError, + HumanInputFormRecord, + HumanInputFormRepositoryImpl, + HumanInputFormSubmissionRepository, + _HumanInputFormEntityImpl, + _HumanInputFormRecipientEntityImpl, + _InvalidTimeoutStatusError, + _WorkspaceMemberInfo, +) +from core.workflow.human_input_compat import ( + EmailDeliveryConfig, + EmailDeliveryMethod, + EmailRecipients, + ExternalRecipient, + MemberRecipient, + WebAppDeliveryMethod, +) +from libs.datetime_utils import naive_utc_now +from models.human_input import HumanInputFormRecipient, RecipientType + + +@pytest.fixture(autouse=True) +def _stub_select(monkeypatch: pytest.MonkeyPatch) -> None: + class _FakeSelect: + def join(self, *_args: Any, **_kwargs: Any) -> _FakeSelect: + return self + + def where(self, *_args: Any, **_kwargs: Any) -> _FakeSelect: + return self + + def options(self, *_args: Any, **_kwargs: Any) -> _FakeSelect: + return self + + monkeypatch.setattr("core.repositories.human_input_repository.select", lambda *_args, **_kwargs: _FakeSelect()) + monkeypatch.setattr("core.repositories.human_input_repository.selectinload", lambda *_args, **_kwargs: "_loader") + + +def _make_form_definition_json(*, include_expiration_time: bool) -> str: + payload: dict[str, Any] = { + "form_content": "hi", + "inputs": [], + "user_actions": [{"id": "submit", "title": "Submit"}], + "rendered_content": "

hi

", + } + if include_expiration_time: + payload["expiration_time"] = naive_utc_now() + return json.dumps(payload, default=str) + + +@dataclasses.dataclass +class _DummyForm: + id: str + workflow_run_id: str | None + node_id: str + tenant_id: str + app_id: str + form_definition: str + rendered_content: str + expiration_time: datetime + form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME + created_at: datetime = dataclasses.field(default_factory=naive_utc_now) + selected_action_id: str | None = None + submitted_data: str | None = None + submitted_at: datetime | None = None + submission_user_id: str | None = None + submission_end_user_id: str | None = None + completed_by_recipient_id: str | None = None + status: HumanInputFormStatus = HumanInputFormStatus.WAITING + + +@dataclasses.dataclass +class _DummyRecipient: + id: str + form_id: str + recipient_type: RecipientType + access_token: str | None + + +class _FakeScalarResult: + def __init__(self, obj: Any): + self._obj = obj + + def first(self) -> Any: + if isinstance(self._obj, list): + return self._obj[0] if self._obj else None + return self._obj + + def all(self) -> list[Any]: + if self._obj is None: + return [] + if isinstance(self._obj, list): + return list(self._obj) + return [self._obj] + + +class _FakeExecuteResult: + def __init__(self, rows: Sequence[tuple[Any, ...]]): + self._rows = list(rows) + + def all(self) -> list[tuple[Any, ...]]: + return list(self._rows) + + +class _FakeSession: + def __init__( + self, + *, + scalars_result: Any = None, + scalars_results: list[Any] | None = None, + forms: dict[str, _DummyForm] | None = None, + recipients: dict[str, _DummyRecipient] | None = None, + execute_rows: Sequence[tuple[Any, ...]] = (), + ): + if scalars_results is not None: + self._scalars_queue = list(scalars_results) + else: + self._scalars_queue = [scalars_result] + self._forms = forms or {} + self._recipients = recipients or {} + self._execute_rows = list(execute_rows) + self.added: list[Any] = [] + + def scalars(self, _query: Any) -> _FakeScalarResult: + if self._scalars_queue: + value = self._scalars_queue.pop(0) + else: + value = None + return _FakeScalarResult(value) + + def execute(self, _stmt: Any) -> _FakeExecuteResult: + return _FakeExecuteResult(self._execute_rows) + + def get(self, model_cls: Any, obj_id: str) -> Any: + name = getattr(model_cls, "__name__", "") + if name == "HumanInputForm": + return self._forms.get(obj_id) + if name == "HumanInputFormRecipient": + return self._recipients.get(obj_id) + return None + + def add(self, obj: Any) -> None: + self.added.append(obj) + + def add_all(self, objs: Sequence[Any]) -> None: + self.added.extend(list(objs)) + + def flush(self) -> None: + # Simulate DB default population for attributes referenced in entity wrappers. + for obj in self.added: + if hasattr(obj, "id") and obj.id in (None, ""): + obj.id = f"gen-{len(str(self.added))}" + if isinstance(obj, HumanInputFormRecipient) and obj.access_token is None: + if obj.recipient_type == RecipientType.CONSOLE: + obj.access_token = "token-console" + elif obj.recipient_type == RecipientType.BACKSTAGE: + obj.access_token = "token-backstage" + else: + obj.access_token = "token-webapp" + + def refresh(self, _obj: Any) -> None: + return None + + def begin(self) -> _FakeSession: + return self + + def __enter__(self) -> _FakeSession: + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + +class _SessionFactoryStub: + def __init__(self, session: _FakeSession): + self._session = session + + def create_session(self) -> _FakeSession: + return self._session + + +def _patch_session_factory(monkeypatch: pytest.MonkeyPatch, session: _FakeSession) -> None: + monkeypatch.setattr("core.repositories.human_input_repository.session_factory", _SessionFactoryStub(session)) + + +def test_recipient_entity_token_raises_when_missing() -> None: + recipient = SimpleNamespace(id="r1", access_token=None) + entity = _HumanInputFormRecipientEntityImpl(recipient) # type: ignore[arg-type] + with pytest.raises(AssertionError, match="access_token should not be None"): + _ = entity.token + + +def test_recipient_entity_id_and_token_success() -> None: + recipient = SimpleNamespace(id="r1", access_token="tok") + entity = _HumanInputFormRecipientEntityImpl(recipient) # type: ignore[arg-type] + assert entity.id == "r1" + assert entity.token == "tok" + + +def test_form_entity_submission_token_prefers_console_then_webapp_then_none() -> None: + form = _DummyForm( + id="f1", + workflow_run_id="run", + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="

x

", + expiration_time=naive_utc_now(), + ) + console = _DummyRecipient(id="c1", form_id=form.id, recipient_type=RecipientType.CONSOLE, access_token="ctok") + webapp = _DummyRecipient( + id="w1", form_id=form.id, recipient_type=RecipientType.STANDALONE_WEB_APP, access_token="wtok" + ) + + entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp, console]) # type: ignore[arg-type] + assert entity.submission_token == "ctok" + + entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp]) # type: ignore[arg-type] + assert entity.submission_token == "wtok" + + entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[]) # type: ignore[arg-type] + assert entity.submission_token is None + + +def test_form_entity_submitted_data_parsed() -> None: + form = _DummyForm( + id="f1", + workflow_run_id="run", + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="

x

", + expiration_time=naive_utc_now(), + submitted_data='{"a": 1}', + submitted_at=naive_utc_now(), + ) + entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[]) # type: ignore[arg-type] + assert entity.submitted is True + assert entity.submitted_data == {"a": 1} + assert entity.rendered_content == "

x

" + assert entity.selected_action_id is None + assert entity.status == HumanInputFormStatus.WAITING + + +def test_form_record_from_models_injects_expiration_time_when_missing() -> None: + expiration = naive_utc_now() + form = _DummyForm( + id="f1", + workflow_run_id=None, + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=False), + rendered_content="

x

", + expiration_time=expiration, + submitted_data='{"k": "v"}', + ) + record = HumanInputFormRecord.from_models(form, None) # type: ignore[arg-type] + assert record.definition.expiration_time == expiration + assert record.submitted_data == {"k": "v"} + assert record.submitted is False + + +def test_create_email_recipients_from_resolved_dedupes_and_skips_blank(monkeypatch: pytest.MonkeyPatch) -> None: + created: list[SimpleNamespace] = [] + + def fake_new(cls, form_id: str, delivery_id: str, payload: Any): # type: ignore[no-untyped-def] + recipient = SimpleNamespace( + id=f"{payload.TYPE}-{len(created)}", + form_id=form_id, + delivery_id=delivery_id, + recipient_type=payload.TYPE, + recipient_payload=payload.model_dump_json(), + access_token="tok", + ) + created.append(recipient) + return recipient + + monkeypatch.setattr("core.repositories.human_input_repository.HumanInputFormRecipient.new", classmethod(fake_new)) + + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + recipients = repo._create_email_recipients_from_resolved( # type: ignore[attr-defined] + form_id="f", + delivery_id="d", + members=[ + _WorkspaceMemberInfo(user_id="u1", email=""), + _WorkspaceMemberInfo(user_id="u2", email="a@example.com"), + _WorkspaceMemberInfo(user_id="u3", email="a@example.com"), + ], + external_emails=["", "a@example.com", "b@example.com", "b@example.com"], + ) + assert [r.recipient_type for r in recipients] == [RecipientType.EMAIL_MEMBER, RecipientType.EMAIL_EXTERNAL] + + +def test_query_workspace_members_by_ids_empty_returns_empty() -> None: + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + assert repo._query_workspace_members_by_ids(session=MagicMock(), restrict_to_user_ids=["", ""]) == [] + + +def test_query_workspace_members_by_ids_maps_rows() -> None: + session = _FakeSession(execute_rows=[("u1", "a@example.com"), ("u2", "b@example.com")]) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + rows = repo._query_workspace_members_by_ids(session=session, restrict_to_user_ids=["u1", "u2"]) + assert rows == [ + _WorkspaceMemberInfo(user_id="u1", email="a@example.com"), + _WorkspaceMemberInfo(user_id="u2", email="b@example.com"), + ] + + +def test_query_all_workspace_members_maps_rows() -> None: + session = _FakeSession(execute_rows=[("u1", "a@example.com")]) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + rows = repo._query_all_workspace_members(session=session) + assert rows == [_WorkspaceMemberInfo(user_id="u1", email="a@example.com")] + + +def test_repository_init_sets_tenant_id() -> None: + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + assert repo._tenant_id == "tenant" + + +def test_delivery_method_to_model_webapp_creates_delivery_and_recipient(monkeypatch: pytest.MonkeyPatch) -> None: + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: "del-1") + result = repo._delivery_method_to_model( + session=MagicMock(), form_id="form-1", delivery_method=WebAppDeliveryMethod() + ) + assert result.delivery.id == "del-1" + assert result.delivery.form_id == "form-1" + assert len(result.recipients) == 1 + assert result.recipients[0].recipient_type == RecipientType.STANDALONE_WEB_APP + + +def test_delivery_method_to_model_email_uses_build_email_recipients(monkeypatch: pytest.MonkeyPatch) -> None: + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: "del-1") + called: dict[str, Any] = {} + + def fake_build(*, session: Any, form_id: str, delivery_id: str, recipients_config: Any) -> list[Any]: + called.update( + {"session": session, "form_id": form_id, "delivery_id": delivery_id, "recipients_config": recipients_config} + ) + return ["r"] + + monkeypatch.setattr(repo, "_build_email_recipients", fake_build) + + method = EmailDeliveryMethod( + config=EmailDeliveryConfig( + recipients=EmailRecipients( + include_bound_group=False, + items=[MemberRecipient(reference_id="u1"), ExternalRecipient(email="e@example.com")], + ), + subject="s", + body="b", + ) + ) + result = repo._delivery_method_to_model(session="sess", form_id="form-1", delivery_method=method) + assert result.recipients == ["r"] + assert called["delivery_id"] == "del-1" + + +def test_build_email_recipients_uses_all_members_when_whole_workspace(monkeypatch: pytest.MonkeyPatch) -> None: + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + monkeypatch.setattr( + repo, + "_query_all_workspace_members", + lambda *, session: [_WorkspaceMemberInfo(user_id="u", email="a@example.com")], + ) + monkeypatch.setattr(repo, "_create_email_recipients_from_resolved", lambda **_: ["ok"]) + recipients = repo._build_email_recipients( + session=MagicMock(), + form_id="f", + delivery_id="d", + recipients_config=EmailRecipients(include_bound_group=True, items=[ExternalRecipient(email="e@example.com")]), + ) + assert recipients == ["ok"] + + +def test_build_email_recipients_uses_selected_members_when_not_whole_workspace(monkeypatch: pytest.MonkeyPatch) -> None: + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + + def fake_query(*, session: Any, restrict_to_user_ids: Sequence[str]) -> list[_WorkspaceMemberInfo]: + assert restrict_to_user_ids == ["u1"] + return [_WorkspaceMemberInfo(user_id="u1", email="a@example.com")] + + monkeypatch.setattr(repo, "_query_workspace_members_by_ids", fake_query) + monkeypatch.setattr(repo, "_create_email_recipients_from_resolved", lambda **_: ["ok"]) + recipients = repo._build_email_recipients( + session=MagicMock(), + form_id="f", + delivery_id="d", + recipients_config=EmailRecipients( + include_bound_group=False, + items=[MemberRecipient(reference_id="u1"), ExternalRecipient(email="e@example.com")], + ), + ) + assert recipients == ["ok"] + + +def test_get_form_returns_entity_and_none_when_missing(monkeypatch: pytest.MonkeyPatch) -> None: + _patch_session_factory(monkeypatch, _FakeSession(scalars_results=[None])) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant", workflow_execution_id="run") + assert repo.get_form("node") is None + + form = _DummyForm( + id="f1", + workflow_run_id="run", + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="

x

", + expiration_time=naive_utc_now(), + ) + recipient = _DummyRecipient( + id="r1", + form_id=form.id, + recipient_type=RecipientType.STANDALONE_WEB_APP, + access_token="tok", + ) + session = _FakeSession(scalars_results=[form, [recipient]]) + _patch_session_factory(monkeypatch, session) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant", workflow_execution_id="run") + entity = repo.get_form("node") + assert entity is not None + assert entity.id == "f1" + assert entity.recipients[0].id == "r1" + assert entity.recipients[0].token == "tok" + + +def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.MonkeyPatch) -> None: + fixed_now = datetime(2024, 1, 1, 0, 0, 0) + monkeypatch.setattr("core.repositories.human_input_repository.naive_utc_now", lambda: fixed_now) + + ids = iter(["form-id", "del-web", "del-console", "del-backstage"]) + monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: next(ids)) + + session = _FakeSession() + _patch_session_factory(monkeypatch, session) + repo = HumanInputFormRepositoryImpl( + tenant_id="tenant", + app_id="app", + workflow_execution_id="run", + invoke_source="debugger", + submission_actor_id="acc-1", + ) + + form_config = HumanInputNodeData( + title="Title", + delivery_methods=[], + form_content="hello", + inputs=[], + user_actions=[UserAction(id="submit", title="Submit")], + ) + params = FormCreateParams( + workflow_execution_id=None, + node_id="node", + form_config=form_config, + rendered_content="

hello

", + delivery_methods=[WebAppDeliveryMethod()], + display_in_ui=True, + resolved_default_values={}, + form_kind=HumanInputFormKind.RUNTIME, + ) + + entity = repo.create_form(params) + assert entity.id == "form-id" + assert entity.expiration_time == fixed_now + timedelta(hours=form_config.timeout) + # Console token should take precedence when console recipient is present. + assert entity.submission_token == "token-console" + assert len(entity.recipients) == 3 + + +def test_submission_get_by_token_returns_none_when_missing_or_form_missing(monkeypatch: pytest.MonkeyPatch) -> None: + _patch_session_factory(monkeypatch, _FakeSession(scalars_result=None)) + repo = HumanInputFormSubmissionRepository() + assert repo.get_by_token("tok") is None + + recipient = SimpleNamespace(form=None) + _patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient)) + repo = HumanInputFormSubmissionRepository() + assert repo.get_by_token("tok") is None + + +def test_submission_repository_init_no_args() -> None: + repo = HumanInputFormSubmissionRepository() + assert isinstance(repo, HumanInputFormSubmissionRepository) + + +def test_submission_get_by_token_and_get_by_form_id_success_paths(monkeypatch: pytest.MonkeyPatch) -> None: + form = _DummyForm( + id="f1", + workflow_run_id=None, + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="

x

", + expiration_time=naive_utc_now(), + ) + recipient = SimpleNamespace( + id="r1", + form_id=form.id, + recipient_type=RecipientType.STANDALONE_WEB_APP, + access_token="tok", + form=form, + ) + + _patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient)) + repo = HumanInputFormSubmissionRepository() + record = repo.get_by_token("tok") + assert record is not None + assert record.access_token == "tok" + + _patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient)) + repo = HumanInputFormSubmissionRepository() + record = repo.get_by_form_id_and_recipient_type(form_id=form.id, recipient_type=RecipientType.STANDALONE_WEB_APP) + assert record is not None + assert record.recipient_id == "r1" + + +def test_submission_get_by_form_id_returns_none_on_missing(monkeypatch: pytest.MonkeyPatch) -> None: + _patch_session_factory(monkeypatch, _FakeSession(scalars_result=None)) + repo = HumanInputFormSubmissionRepository() + assert repo.get_by_form_id_and_recipient_type(form_id="f", recipient_type=RecipientType.CONSOLE) is None + + +def test_mark_submitted_updates_and_raises_when_missing(monkeypatch: pytest.MonkeyPatch) -> None: + fixed_now = datetime(2024, 1, 1, 0, 0, 0) + monkeypatch.setattr("core.repositories.human_input_repository.naive_utc_now", lambda: fixed_now) + + missing_session = _FakeSession(forms={}) + _patch_session_factory(monkeypatch, missing_session) + repo = HumanInputFormSubmissionRepository() + with pytest.raises(FormNotFoundError, match="form not found"): + repo.mark_submitted( + form_id="missing", + recipient_id=None, + selected_action_id="a", + form_data={}, + submission_user_id=None, + submission_end_user_id=None, + ) + + form = _DummyForm( + id="f", + workflow_run_id=None, + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="

x

", + expiration_time=fixed_now, + ) + recipient = _DummyRecipient(id="r", form_id=form.id, recipient_type=RecipientType.CONSOLE, access_token="tok") + session = _FakeSession(forms={form.id: form}, recipients={recipient.id: recipient}) + _patch_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() + record = repo.mark_submitted( + form_id=form.id, + recipient_id=recipient.id, + selected_action_id="approve", + form_data={"k": "v"}, + submission_user_id="u", + submission_end_user_id="eu", + ) + assert form.status == HumanInputFormStatus.SUBMITTED + assert form.submitted_at == fixed_now + assert record.submitted_data == {"k": "v"} + + +def test_mark_timeout_invalid_status_raises(monkeypatch: pytest.MonkeyPatch) -> None: + form = _DummyForm( + id="f", + workflow_run_id=None, + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="

x

", + expiration_time=naive_utc_now(), + ) + session = _FakeSession(forms={form.id: form}) + _patch_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() + with pytest.raises(_InvalidTimeoutStatusError, match="invalid timeout status"): + repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.SUBMITTED) # type: ignore[arg-type] + + +def test_mark_timeout_already_timed_out_returns_record(monkeypatch: pytest.MonkeyPatch) -> None: + form = _DummyForm( + id="f", + workflow_run_id=None, + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="

x

", + expiration_time=naive_utc_now(), + status=HumanInputFormStatus.TIMEOUT, + ) + session = _FakeSession(forms={form.id: form}) + _patch_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() + record = repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.TIMEOUT, reason="r") + assert record.status == HumanInputFormStatus.TIMEOUT + + +def test_mark_timeout_submitted_raises_form_not_found(monkeypatch: pytest.MonkeyPatch) -> None: + form = _DummyForm( + id="f", + workflow_run_id=None, + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="

x

", + expiration_time=naive_utc_now(), + status=HumanInputFormStatus.SUBMITTED, + ) + session = _FakeSession(forms={form.id: form}) + _patch_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() + with pytest.raises(FormNotFoundError, match="form already submitted"): + repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.EXPIRED) + + +def test_mark_timeout_updates_fields(monkeypatch: pytest.MonkeyPatch) -> None: + form = _DummyForm( + id="f", + workflow_run_id=None, + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="

x

", + expiration_time=naive_utc_now(), + selected_action_id="a", + submitted_data="{}", + submission_user_id="u", + submission_end_user_id="eu", + completed_by_recipient_id="r", + status=HumanInputFormStatus.WAITING, + ) + session = _FakeSession(forms={form.id: form}) + _patch_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() + record = repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.EXPIRED) + assert form.status == HumanInputFormStatus.EXPIRED + assert form.selected_action_id is None + assert form.submitted_data is None + assert form.submission_user_id is None + assert form.submission_end_user_id is None + assert form.completed_by_recipient_id is None + assert record.status == HumanInputFormStatus.EXPIRED + + +def test_mark_timeout_raises_when_form_missing(monkeypatch: pytest.MonkeyPatch) -> None: + _patch_session_factory(monkeypatch, _FakeSession(forms={})) + repo = HumanInputFormSubmissionRepository() + with pytest.raises(FormNotFoundError, match="form not found"): + repo.mark_timeout(form_id="missing", timeout_status=HumanInputFormStatus.TIMEOUT) diff --git a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py index c66e50437ab..e5c3e854875 100644 --- a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py @@ -1,84 +1,292 @@ -from datetime import datetime +from datetime import UTC, datetime from unittest.mock import MagicMock from uuid import uuid4 -from sqlalchemy import create_engine +import pytest +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus, WorkflowType +from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository -from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowType -from models import Account, WorkflowRun +from models import Account, CreatorUserRole, EndUser, WorkflowRun from models.enums import WorkflowRunTriggeredFrom -def _build_repository_with_mocked_session(session: MagicMock) -> SQLAlchemyWorkflowExecutionRepository: - engine = create_engine("sqlite:///:memory:") - real_session_factory = sessionmaker(bind=engine, expire_on_commit=False) - - user = MagicMock(spec=Account) - user.id = str(uuid4()) - user.current_tenant_id = str(uuid4()) - - repository = SQLAlchemyWorkflowExecutionRepository( - session_factory=real_session_factory, - user=user, - app_id="app-id", - triggered_from=WorkflowRunTriggeredFrom.APP_RUN, - ) - - session_context = MagicMock() - session_context.__enter__.return_value = session - session_context.__exit__.return_value = False - repository._session_factory = MagicMock(return_value=session_context) - return repository - - -def _build_execution(*, execution_id: str, started_at: datetime) -> WorkflowExecution: - return WorkflowExecution.new( - id_=execution_id, - workflow_id="workflow-id", - workflow_type=WorkflowType.WORKFLOW, - workflow_version="1.0.0", - graph={"nodes": [], "edges": []}, - inputs={"query": "hello"}, - started_at=started_at, - ) - - -def test_save_uses_execution_started_at_when_record_does_not_exist(): +@pytest.fixture +def mock_session_factory(): + """Mock SQLAlchemy session factory.""" + session_factory = MagicMock(spec=sessionmaker) session = MagicMock() session.get.return_value = None - repository = _build_repository_with_mocked_session(session) - - started_at = datetime(2026, 1, 1, 12, 0, 0) - execution = _build_execution(execution_id=str(uuid4()), started_at=started_at) - - repository.save(execution) - - saved_model = session.merge.call_args.args[0] - assert saved_model.created_at == started_at - session.commit.assert_called_once() + session_factory.return_value.__enter__.return_value = session + return session_factory -def test_save_preserves_existing_created_at_when_record_already_exists(): - session = MagicMock() - repository = _build_repository_with_mocked_session(session) +@pytest.fixture +def mock_engine(): + """Mock SQLAlchemy Engine.""" + return MagicMock(spec=Engine) - execution_id = str(uuid4()) - existing_created_at = datetime(2026, 1, 1, 12, 0, 0) - existing_run = WorkflowRun() - existing_run.id = execution_id - existing_run.tenant_id = repository._tenant_id - existing_run.created_at = existing_created_at - session.get.return_value = existing_run - execution = _build_execution( - execution_id=execution_id, - started_at=datetime(2026, 1, 1, 12, 30, 0), +@pytest.fixture +def mock_account(): + """Mock Account user.""" + account = MagicMock(spec=Account) + account.id = str(uuid4()) + account.current_tenant_id = str(uuid4()) + return account + + +@pytest.fixture +def mock_end_user(): + """Mock EndUser.""" + user = MagicMock(spec=EndUser) + user.id = str(uuid4()) + user.tenant_id = str(uuid4()) + return user + + +@pytest.fixture +def sample_workflow_execution(): + """Sample WorkflowExecution for testing.""" + return WorkflowExecution( + id_=str(uuid4()), + workflow_id=str(uuid4()), + workflow_type=WorkflowType.WORKFLOW, + workflow_version="1.0", + graph={"nodes": [], "edges": []}, + inputs={"input1": "value1"}, + outputs={"output1": "result1"}, + status=WorkflowExecutionStatus.SUCCEEDED, + error_message="", + total_tokens=100, + total_steps=5, + exceptions_count=0, + started_at=datetime.now(UTC), + finished_at=datetime.now(UTC), ) - repository.save(execution) - saved_model = session.merge.call_args.args[0] - assert saved_model.created_at == existing_created_at - session.commit.assert_called_once() +class TestSQLAlchemyWorkflowExecutionRepository: + def test_init_with_sessionmaker(self, mock_session_factory, mock_account): + app_id = "test_app_id" + triggered_from = WorkflowRunTriggeredFrom.APP_RUN + + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, user=mock_account, app_id=app_id, triggered_from=triggered_from + ) + + assert repo._session_factory == mock_session_factory + assert repo._tenant_id == mock_account.current_tenant_id + assert repo._app_id == app_id + assert repo._triggered_from == triggered_from + assert repo._creator_user_id == mock_account.id + assert repo._creator_user_role == CreatorUserRole.ACCOUNT + + def test_init_with_engine(self, mock_engine, mock_account): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_engine, + user=mock_account, + app_id="test_app_id", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + assert isinstance(repo._session_factory, sessionmaker) + assert repo._session_factory.kw["bind"] == mock_engine + + def test_init_invalid_session_factory(self, mock_account): + with pytest.raises(ValueError, match="Invalid session_factory type"): + SQLAlchemyWorkflowExecutionRepository( + session_factory="invalid", user=mock_account, app_id=None, triggered_from=None + ) + + def test_init_no_tenant_id(self, mock_session_factory): + user = MagicMock(spec=Account) + user.current_tenant_id = None + + with pytest.raises(ValueError, match="User must have a tenant_id"): + SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, user=user, app_id=None, triggered_from=None + ) + + def test_init_with_end_user(self, mock_session_factory, mock_end_user): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, user=mock_end_user, app_id=None, triggered_from=None + ) + assert repo._tenant_id == mock_end_user.tenant_id + assert repo._creator_user_role == CreatorUserRole.END_USER + + def test_to_domain_model(self, mock_session_factory, mock_account): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, user=mock_account, app_id=None, triggered_from=None + ) + + db_model = MagicMock(spec=WorkflowRun) + db_model.id = str(uuid4()) + db_model.workflow_id = str(uuid4()) + db_model.type = "workflow" + db_model.version = "1.0" + db_model.inputs_dict = {"in": "val"} + db_model.outputs_dict = {"out": "val"} + db_model.graph_dict = {"nodes": []} + db_model.status = "succeeded" + db_model.error = "some error" + db_model.total_tokens = 50 + db_model.total_steps = 3 + db_model.exceptions_count = 1 + db_model.created_at = datetime.now(UTC) + db_model.finished_at = datetime.now(UTC) + + domain_model = repo._to_domain_model(db_model) + + assert domain_model.id_ == db_model.id + assert domain_model.workflow_id == db_model.workflow_id + assert domain_model.status == WorkflowExecutionStatus.SUCCEEDED + assert domain_model.inputs == db_model.inputs_dict + assert domain_model.error_message == "some error" + + def test_to_db_model(self, mock_session_factory, mock_account, sample_workflow_execution): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test_app", + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + ) + + # Make elapsed time deterministic to avoid flaky tests + sample_workflow_execution.started_at = datetime(2023, 1, 1, 0, 0, 0, tzinfo=UTC) + sample_workflow_execution.finished_at = datetime(2023, 1, 1, 0, 0, 10, tzinfo=UTC) + + db_model = repo._to_db_model(sample_workflow_execution) + + assert db_model.id == sample_workflow_execution.id_ + assert db_model.tenant_id == repo._tenant_id + assert db_model.app_id == "test_app" + assert db_model.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING + assert db_model.status == sample_workflow_execution.status.value + assert db_model.total_tokens == sample_workflow_execution.total_tokens + assert db_model.elapsed_time == 10.0 + + def test_to_db_model_edge_cases(self, mock_session_factory, mock_account, sample_workflow_execution): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test_app", + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + ) + # Test with empty/None fields + sample_workflow_execution.graph = None + sample_workflow_execution.inputs = None + sample_workflow_execution.outputs = None + sample_workflow_execution.error_message = None + sample_workflow_execution.finished_at = None + + db_model = repo._to_db_model(sample_workflow_execution) + + assert db_model.graph is None + assert db_model.inputs is None + assert db_model.outputs is None + assert db_model.error is None + assert db_model.elapsed_time == 0 + + def test_to_db_model_app_id_none(self, mock_session_factory, mock_account, sample_workflow_execution): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id=None, + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + db_model = repo._to_db_model(sample_workflow_execution) + assert not hasattr(db_model, "app_id") or db_model.app_id is None + assert db_model.tenant_id == repo._tenant_id + + def test_to_db_model_missing_context(self, mock_session_factory, mock_account, sample_workflow_execution): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, user=mock_account, app_id=None, triggered_from=None + ) + + # Test triggered_from missing + with pytest.raises(ValueError, match="triggered_from is required"): + repo._to_db_model(sample_workflow_execution) + + repo._triggered_from = WorkflowRunTriggeredFrom.APP_RUN + repo._creator_user_id = None + with pytest.raises(ValueError, match="created_by is required"): + repo._to_db_model(sample_workflow_execution) + + repo._creator_user_id = "some_id" + repo._creator_user_role = None + with pytest.raises(ValueError, match="created_by_role is required"): + repo._to_db_model(sample_workflow_execution) + + def test_save(self, mock_session_factory, mock_account, sample_workflow_execution): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test_app", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + repo.save(sample_workflow_execution) + + session = mock_session_factory.return_value.__enter__.return_value + session.merge.assert_called_once() + session.commit.assert_called_once() + + # Check cache + assert sample_workflow_execution.id_ in repo._execution_cache + cached_model = repo._execution_cache[sample_workflow_execution.id_] + assert cached_model.id == sample_workflow_execution.id_ + + def test_save_uses_execution_started_at_when_record_does_not_exist( + self, mock_session_factory, mock_account, sample_workflow_execution + ): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test_app", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + started_at = datetime(2026, 1, 1, 12, 0, 0, tzinfo=UTC) + sample_workflow_execution.started_at = started_at + + session = mock_session_factory.return_value.__enter__.return_value + session.get.return_value = None + + repo.save(sample_workflow_execution) + + saved_model = session.merge.call_args.args[0] + assert saved_model.created_at == started_at + session.commit.assert_called_once() + + def test_save_preserves_existing_created_at_when_record_already_exists( + self, mock_session_factory, mock_account, sample_workflow_execution + ): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test_app", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + execution_id = sample_workflow_execution.id_ + existing_created_at = datetime(2026, 1, 1, 12, 0, 0, tzinfo=UTC) + + existing_run = WorkflowRun() + existing_run.id = execution_id + existing_run.tenant_id = repo._tenant_id + existing_run.created_at = existing_created_at + + session = mock_session_factory.return_value.__enter__.return_value + session.get.return_value = existing_run + + sample_workflow_execution.started_at = datetime(2026, 1, 1, 12, 30, 0, tzinfo=UTC) + + repo.save(sample_workflow_execution) + + saved_model = session.merge.call_args.args[0] + assert saved_model.created_at == existing_created_at + session.commit.assert_called_once() diff --git a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py new file mode 100644 index 00000000000..5b4d26b7808 --- /dev/null +++ b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py @@ -0,0 +1,772 @@ +from __future__ import annotations + +import json +import logging +from collections.abc import Mapping +from datetime import UTC, datetime +from types import SimpleNamespace +from typing import Any +from unittest.mock import MagicMock, Mock + +import psycopg2.errors +import pytest +from graphon.entities import WorkflowNodeExecution +from graphon.enums import ( + BuiltinNodeTypes, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from sqlalchemy import Engine, create_engine +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import sessionmaker + +from configs import dify_config +from core.repositories.factory import OrderConfig +from core.repositories.sqlalchemy_workflow_node_execution_repository import ( + SQLAlchemyWorkflowNodeExecutionRepository, + _deterministic_json_dump, + _filter_by_offload_type, + _find_first, + _replace_or_append_offload, +) +from models import Account, EndUser +from models.enums import ExecutionOffLoadType +from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom + + +def _mock_account(*, tenant_id: str = "tenant", user_id: str = "user") -> Account: + user = Mock(spec=Account) + user.id = user_id + user.current_tenant_id = tenant_id + return user + + +def _mock_end_user(*, tenant_id: str = "tenant", user_id: str = "user") -> EndUser: + user = Mock(spec=EndUser) + user.id = user_id + user.tenant_id = tenant_id + return user + + +def _execution( + *, + execution_id: str = "exec-id", + node_execution_id: str = "node-exec-id", + workflow_run_id: str = "run-id", + status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.SUCCEEDED, + inputs: Mapping[str, Any] | None = None, + outputs: Mapping[str, Any] | None = None, + process_data: Mapping[str, Any] | None = None, + metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None, +) -> WorkflowNodeExecution: + return WorkflowNodeExecution( + id=execution_id, + node_execution_id=node_execution_id, + workflow_id="workflow-id", + workflow_execution_id=workflow_run_id, + index=1, + predecessor_node_id=None, + node_id="node-id", + node_type=BuiltinNodeTypes.LLM, + title="Title", + inputs=inputs, + outputs=outputs, + process_data=process_data, + status=status, + error=None, + elapsed_time=1.0, + metadata=metadata, + created_at=datetime.now(UTC), + finished_at=None, + ) + + +class _SessionCtx: + def __init__(self, session: Any): + self._session = session + + def __enter__(self) -> Any: + return self._session + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + +def _session_factory(session: Any) -> sessionmaker: + factory = Mock(spec=sessionmaker) + factory.return_value = _SessionCtx(session) + return factory + + +def test_init_accepts_engine_and_sessionmaker_and_sets_role(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + + engine: Engine = create_engine("sqlite:///:memory:") + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=engine, + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + assert isinstance(repo._session_factory, sessionmaker) + + sm = Mock(spec=sessionmaker) + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=sm, + user=_mock_end_user(), + app_id="app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, + ) + assert repo._creator_user_role.value == "end_user" + + +def test_init_rejects_invalid_session_factory_type(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + with pytest.raises(ValueError, match="Invalid session_factory type"): + SQLAlchemyWorkflowNodeExecutionRepository( # type: ignore[arg-type] + session_factory=object(), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + +def test_init_requires_tenant_id(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + user = _mock_account() + user.current_tenant_id = None + with pytest.raises(ValueError, match="User must have a tenant_id"): + SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=user, + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + +def test_create_truncator_uses_config(monkeypatch: pytest.MonkeyPatch) -> None: + created: dict[str, Any] = {} + + class FakeTruncator: + def __init__(self, *, max_size_bytes: int, array_element_limit: int, string_length_limit: int): + created.update( + { + "max_size_bytes": max_size_bytes, + "array_element_limit": array_element_limit, + "string_length_limit": string_length_limit, + } + ) + + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.VariableTruncator", + FakeTruncator, + ) + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + _ = repo._create_truncator() + assert created["max_size_bytes"] == dify_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE + + +def test_helpers_find_first_and_replace_or_append_and_filter() -> None: + assert _deterministic_json_dump({"b": 1, "a": 2}) == '{"a": 2, "b": 1}' + assert _find_first([], lambda _: True) is None + assert _find_first([1, 2, 3], lambda x: x > 1) == 2 + + off1 = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS) + off2 = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS) + assert _find_first([off1, off2], _filter_by_offload_type(ExecutionOffLoadType.OUTPUTS)) is off2 + + replaced = _replace_or_append_offload([off1, off2], WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)) + assert len(replaced) == 2 + assert [o.type_ for o in replaced] == [ExecutionOffLoadType.OUTPUTS, ExecutionOffLoadType.INPUTS] + + +def test_to_db_model_requires_constructor_context(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + execution = _execution(inputs={"b": 1, "a": 2}, metadata={WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 1}) + + # Happy path: deterministic json dump should be sorted + db_model = repo._to_db_model(execution) + assert json.loads(db_model.inputs or "{}") == {"a": 2, "b": 1} + assert json.loads(db_model.execution_metadata or "{}")["total_tokens"] == 1 + + repo._triggered_from = None + with pytest.raises(ValueError, match="triggered_from is required"): + repo._to_db_model(execution) + + +def test_to_db_model_requires_creator_user_id_and_role(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id="app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + execution = _execution() + db_model = repo._to_db_model(execution) + assert db_model.app_id == "app" + + repo._creator_user_id = None + with pytest.raises(ValueError, match="created_by is required"): + repo._to_db_model(execution) + + repo._creator_user_id = "user" + repo._creator_user_role = None + with pytest.raises(ValueError, match="created_by_role is required"): + repo._to_db_model(execution) + + +def test_is_duplicate_key_error_and_regenerate_id( + monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture +) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + unique = Mock(spec=psycopg2.errors.UniqueViolation) + duplicate_error = IntegrityError("dup", params=None, orig=unique) + assert repo._is_duplicate_key_error(duplicate_error) is True + assert repo._is_duplicate_key_error(IntegrityError("other", params=None, orig=None)) is False + + execution = _execution(execution_id="old-id") + db_model = WorkflowNodeExecutionModel() + db_model.id = "old-id" + monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "new-id") + caplog.set_level(logging.WARNING) + repo._regenerate_id_on_duplicate(execution, db_model) + assert execution.id == "new-id" + assert db_model.id == "new-id" + assert any("Duplicate key conflict" in r.message for r in caplog.records) + + +def test_persist_to_database_updates_existing_and_inserts_new(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + session = MagicMock() + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=_session_factory(session), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + db_model = WorkflowNodeExecutionModel() + db_model.id = "id1" + db_model.node_execution_id = "node1" + db_model.foo = "bar" # type: ignore[attr-defined] + db_model.__dict__["_private"] = "x" + + existing = SimpleNamespace() + session.get.return_value = existing + repo._persist_to_database(db_model) + assert existing.foo == "bar" + session.add.assert_not_called() + assert repo._node_execution_cache["node1"] is db_model + + session.reset_mock() + session.get.return_value = None + repo._node_execution_cache.clear() + repo._persist_to_database(db_model) + session.add.assert_called_once_with(db_model) + assert repo._node_execution_cache["node1"] is db_model + + +def test_truncate_and_upload_returns_none_when_no_values_or_not_truncated(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id="app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + assert repo._truncate_and_upload(None, "e", ExecutionOffLoadType.INPUTS) is None + + class FakeTruncator: + def truncate_variable_mapping(self, value: Any): # type: ignore[no-untyped-def] + return value, False + + monkeypatch.setattr(repo, "_create_truncator", lambda: FakeTruncator()) + assert repo._truncate_and_upload({"a": 1}, "e", ExecutionOffLoadType.INPUTS) is None + + +def test_truncate_and_upload_uploads_and_builds_offload(monkeypatch: pytest.MonkeyPatch) -> None: + uploaded: dict[str, Any] = {} + + class FakeFileService: + def upload_file(self, *, filename: str, content: bytes, mimetype: str, user: Any): # type: ignore[no-untyped-def] + uploaded.update({"filename": filename, "content": content, "mimetype": mimetype, "user": user}) + return SimpleNamespace(id="file-id", key="file-key") + + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", lambda *_: FakeFileService() + ) + monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "offload-id") + + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id="app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + class FakeTruncator: + def truncate_variable_mapping(self, value: Any): # type: ignore[no-untyped-def] + return {"truncated": True}, True + + monkeypatch.setattr(repo, "_create_truncator", lambda: FakeTruncator()) + + result = repo._truncate_and_upload({"a": 1}, "exec", ExecutionOffLoadType.INPUTS) + assert result is not None + assert result.truncated_value == {"truncated": True} + assert uploaded["filename"].startswith("node_execution_exec_inputs.json") + assert result.offload.file_id == "file-id" + assert result.offload.type_ == ExecutionOffLoadType.INPUTS + + +def test_to_domain_model_loads_offloaded_files(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + db_model = WorkflowNodeExecutionModel() + db_model.id = "id" + db_model.node_execution_id = "node-exec" + db_model.workflow_id = "wf" + db_model.workflow_run_id = "run" + db_model.index = 1 + db_model.predecessor_node_id = None + db_model.node_id = "node" + db_model.node_type = BuiltinNodeTypes.LLM + db_model.title = "t" + db_model.inputs = json.dumps({"trunc": "i"}) + db_model.process_data = json.dumps({"trunc": "p"}) + db_model.outputs = json.dumps({"trunc": "o"}) + db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED + db_model.error = None + db_model.elapsed_time = 0.1 + db_model.execution_metadata = json.dumps({"total_tokens": 3}) + db_model.created_at = datetime.now(UTC) + db_model.finished_at = None + + off_in = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS) + off_out = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS) + off_proc = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.PROCESS_DATA) + off_in.file = SimpleNamespace(key="k-in") + off_out.file = SimpleNamespace(key="k-out") + off_proc.file = SimpleNamespace(key="k-proc") + db_model.offload_data = [off_out, off_in, off_proc] + + def fake_load(key: str) -> bytes: + return json.dumps({"full": key}).encode() + + monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.storage.load", fake_load) + + domain = repo._to_domain_model(db_model) + assert domain.inputs == {"full": "k-in"} + assert domain.outputs == {"full": "k-out"} + assert domain.process_data == {"full": "k-proc"} + assert domain.get_truncated_inputs() == {"trunc": "i"} + assert domain.get_truncated_outputs() == {"trunc": "o"} + assert domain.get_truncated_process_data() == {"trunc": "p"} + + +def test_to_domain_model_returns_early_when_no_offload_data(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + db_model = WorkflowNodeExecutionModel() + db_model.id = "id" + db_model.node_execution_id = "node-exec" + db_model.workflow_id = "wf" + db_model.workflow_run_id = "run" + db_model.index = 1 + db_model.predecessor_node_id = None + db_model.node_id = "node" + db_model.node_type = BuiltinNodeTypes.LLM + db_model.title = "t" + db_model.inputs = json.dumps({"i": 1}) + db_model.process_data = json.dumps({"p": 2}) + db_model.outputs = json.dumps({"o": 3}) + db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED + db_model.error = None + db_model.elapsed_time = 0.1 + db_model.execution_metadata = "{}" + db_model.created_at = datetime.now(UTC) + db_model.finished_at = None + db_model.offload_data = [] + + domain = repo._to_domain_model(db_model) + assert domain.inputs == {"i": 1} + assert domain.outputs == {"o": 3} + + +def test_json_encode_uses_runtime_converter(monkeypatch: pytest.MonkeyPatch) -> None: + class FakeConverter: + def to_json_encodable(self, values: Mapping[str, Any]) -> Mapping[str, Any]: + return {"wrapped": values["a"]} + + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowRuntimeTypeConverter", + FakeConverter, + ) + assert SQLAlchemyWorkflowNodeExecutionRepository._json_encode({"a": 1}) == '{"wrapped": 1}' + + +def test_save_execution_data_handles_existing_db_model_and_truncation(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + session = MagicMock() + session.execute.return_value.scalars.return_value.first.return_value = SimpleNamespace( + id="id", + offload_data=[WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)], + inputs=None, + outputs=None, + process_data=None, + ) + session.merge = Mock() + session.flush = Mock() + session.begin.return_value.__enter__ = Mock(return_value=session) + session.begin.return_value.__exit__ = Mock(return_value=None) + + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=_session_factory(session), + user=_mock_account(), + app_id="app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + execution = _execution(inputs={"a": 1}, outputs={"b": 2}, process_data={"c": 3}) + + trunc_result = SimpleNamespace( + truncated_value={"trunc": True}, + offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS, file_id="f1"), + ) + monkeypatch.setattr( + repo, "_truncate_and_upload", lambda values, *_args, **_kwargs: trunc_result if values == {"a": 1} else None + ) + monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values, sort_keys=True)) + + repo.save_execution_data(execution) + # Inputs should be truncated, outputs/process_data encoded directly + db_model = session.merge.call_args.args[0] + assert json.loads(db_model.inputs) == {"trunc": True} + assert json.loads(db_model.outputs) == {"b": 2} + assert json.loads(db_model.process_data) == {"c": 3} + assert any(off.type_ == ExecutionOffLoadType.INPUTS for off in db_model.offload_data) + assert execution.get_truncated_inputs() == {"trunc": True} + + +def test_save_execution_data_truncates_outputs_and_process_data(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + existing = SimpleNamespace( + id="id", + offload_data=[], + inputs=None, + outputs=None, + process_data=None, + ) + session = MagicMock() + session.execute.return_value.scalars.return_value.first.return_value = existing + session.merge = Mock() + session.flush = Mock() + session.begin.return_value.__enter__ = Mock(return_value=session) + session.begin.return_value.__exit__ = Mock(return_value=None) + + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=_session_factory(session), + user=_mock_account(), + app_id="app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + execution = _execution(inputs={"a": 1}, outputs={"b": 2}, process_data={"c": 3}) + + def trunc(values: Mapping[str, Any], *_args: Any, **_kwargs: Any) -> Any: + if values == {"b": 2}: + return SimpleNamespace( + truncated_value={"b": "trunc"}, + offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS, file_id="f2"), + ) + if values == {"c": 3}: + return SimpleNamespace( + truncated_value={"c": "trunc"}, + offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.PROCESS_DATA, file_id="f3"), + ) + return None + + monkeypatch.setattr(repo, "_truncate_and_upload", trunc) + monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values, sort_keys=True)) + + repo.save_execution_data(execution) + db_model = session.merge.call_args.args[0] + assert json.loads(db_model.outputs) == {"b": "trunc"} + assert json.loads(db_model.process_data) == {"c": "trunc"} + assert execution.get_truncated_outputs() == {"b": "trunc"} + assert execution.get_truncated_process_data() == {"c": "trunc"} + + +def test_save_execution_data_handles_missing_db_model(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + session = MagicMock() + session.execute.return_value.scalars.return_value.first.return_value = None + session.merge = Mock() + session.flush = Mock() + session.begin.return_value.__enter__ = Mock(return_value=session) + session.begin.return_value.__exit__ = Mock(return_value=None) + + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=_session_factory(session), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + execution = _execution(inputs={"a": 1}) + fake_db_model = SimpleNamespace(id=execution.id, offload_data=[], inputs=None, outputs=None, process_data=None) + monkeypatch.setattr(repo, "_to_db_model", lambda *_: fake_db_model) + monkeypatch.setattr(repo, "_truncate_and_upload", lambda *_args, **_kwargs: None) + monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values)) + + repo.save_execution_data(execution) + merged = session.merge.call_args.args[0] + assert merged.inputs == '{"a": 1}' + + +def test_save_retries_duplicate_and_logs_non_duplicate( + monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture +) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + execution = _execution(execution_id="id") + unique = Mock(spec=psycopg2.errors.UniqueViolation) + duplicate_error = IntegrityError("dup", params=None, orig=unique) + other_error = IntegrityError("other", params=None, orig=None) + + calls = {"n": 0} + + def persist(_db_model: Any) -> None: + calls["n"] += 1 + if calls["n"] == 1: + raise duplicate_error + + monkeypatch.setattr(repo, "_persist_to_database", persist) + monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "new-id") + repo.save(execution) + assert execution.id == "new-id" + assert repo._node_execution_cache[execution.node_execution_id] is not None + + caplog.set_level(logging.ERROR) + monkeypatch.setattr(repo, "_persist_to_database", lambda _db: (_ for _ in ()).throw(other_error)) + with pytest.raises(IntegrityError): + repo.save(_execution(execution_id="id2", node_execution_id="node2")) + assert any("Non-duplicate key integrity error" in r.message for r in caplog.records) + + +def test_save_logs_and_reraises_on_unexpected_error( + monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture +) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + caplog.set_level(logging.ERROR) + monkeypatch.setattr(repo, "_persist_to_database", lambda _db: (_ for _ in ()).throw(RuntimeError("boom"))) + with pytest.raises(RuntimeError, match="boom"): + repo.save(_execution(execution_id="id3", node_execution_id="node3")) + assert any("Failed to save workflow node execution" in r.message for r in caplog.records) + + +def test_get_db_models_by_workflow_run_orders_and_caches(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + + class FakeStmt: + def __init__(self) -> None: + self.where_calls = 0 + self.order_by_args: tuple[Any, ...] | None = None + + def where(self, *_args: Any) -> FakeStmt: + self.where_calls += 1 + return self + + def order_by(self, *args: Any) -> FakeStmt: + self.order_by_args = args + return self + + stmt = FakeStmt() + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowNodeExecutionModel.preload_offload_data_and_files", + lambda _q: stmt, + ) + monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.select", lambda *_: "select") + + model1 = SimpleNamespace(node_execution_id="n1") + model2 = SimpleNamespace(node_execution_id=None) + session = MagicMock() + session.scalars.return_value.all.return_value = [model1, model2] + + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=_session_factory(session), + user=_mock_account(), + app_id="app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + order = OrderConfig(order_by=["index", "missing"], order_direction="desc") + db_models = repo.get_db_models_by_workflow_run("run", order) + assert db_models == [model1, model2] + assert repo._node_execution_cache["n1"] is model1 + assert stmt.order_by_args is not None + + +def test_get_db_models_by_workflow_run_uses_asc_order(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + + class FakeStmt: + def where(self, *_args: Any) -> FakeStmt: + return self + + def order_by(self, *args: Any) -> FakeStmt: + self.args = args # type: ignore[attr-defined] + return self + + stmt = FakeStmt() + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowNodeExecutionModel.preload_offload_data_and_files", + lambda _q: stmt, + ) + monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.select", lambda *_: "select") + + session = MagicMock() + session.scalars.return_value.all.return_value = [] + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=_session_factory(session), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + repo.get_db_models_by_workflow_run("run", OrderConfig(order_by=["index"], order_direction="asc")) + + +def test_get_by_workflow_run_maps_to_domain(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + db_models = [SimpleNamespace(id="db1"), SimpleNamespace(id="db2")] + monkeypatch.setattr(repo, "get_db_models_by_workflow_run", lambda *_args, **_kwargs: db_models) + monkeypatch.setattr(repo, "_to_domain_model", lambda m: f"domain:{m.id}") + + class FakeExecutor: + def __enter__(self) -> FakeExecutor: + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def map(self, func, items, timeout: int): # type: ignore[no-untyped-def] + assert timeout == 30 + return list(map(func, items)) + + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.ThreadPoolExecutor", + lambda max_workers: FakeExecutor(), + ) + + result = repo.get_by_workflow_execution("run", order_config=None) + assert result == ["domain:db1", "domain:db2"] diff --git a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py index 456c3dde12c..84fe522388e 100644 --- a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py +++ b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py @@ -4,17 +4,17 @@ from unittest.mock import MagicMock, Mock import psycopg2.errors import pytest +from graphon.entities.workflow_node_execution import ( + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, +) +from graphon.enums import BuiltinNodeTypes from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, ) -from dify_graph.entities.workflow_node_execution import ( - WorkflowNodeExecution, - WorkflowNodeExecutionStatus, -) -from dify_graph.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models import Account, WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py index eeab81a1789..27729e7f06b 100644 --- a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py +++ b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py @@ -11,17 +11,17 @@ from datetime import UTC, datetime from typing import Any from unittest.mock import MagicMock +from graphon.entities.workflow_node_execution import ( + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, +) +from graphon.enums import BuiltinNodeTypes from sqlalchemy import Engine from configs import dify_config from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, ) -from dify_graph.entities.workflow_node_execution import ( - WorkflowNodeExecution, - WorkflowNodeExecutionStatus, -) -from dify_graph.enums import BuiltinNodeTypes from models import Account, WorkflowNodeExecutionTriggeredFrom from models.enums import ExecutionOffLoadType from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload diff --git a/api/tests/unit_tests/core/schemas/test_registry.py b/api/tests/unit_tests/core/schemas/test_registry.py new file mode 100644 index 00000000000..5749e72eb03 --- /dev/null +++ b/api/tests/unit_tests/core/schemas/test_registry.py @@ -0,0 +1,137 @@ +import json +from unittest.mock import patch + +from core.schemas.registry import SchemaRegistry + + +class TestSchemaRegistry: + def test_initialization(self, tmp_path): + base_dir = tmp_path / "schemas" + base_dir.mkdir() + registry = SchemaRegistry(str(base_dir)) + assert registry.base_dir == base_dir + assert registry.versions == {} + assert registry.metadata == {} + + def test_default_registry_singleton(self): + registry1 = SchemaRegistry.default_registry() + registry2 = SchemaRegistry.default_registry() + assert registry1 is registry2 + assert isinstance(registry1, SchemaRegistry) + + def test_load_all_versions_non_existent_dir(self, tmp_path): + base_dir = tmp_path / "non_existent" + registry = SchemaRegistry(str(base_dir)) + registry.load_all_versions() + assert registry.versions == {} + + def test_load_all_versions_filtering(self, tmp_path): + base_dir = tmp_path / "schemas" + base_dir.mkdir() + (base_dir / "not_a_version_dir").mkdir() + (base_dir / "v1").mkdir() + (base_dir / "some_file.txt").write_text("content") + + registry = SchemaRegistry(str(base_dir)) + with patch.object(registry, "_load_version_dir") as mock_load: + registry.load_all_versions() + mock_load.assert_called_once() + assert mock_load.call_args[0][0] == "v1" + + def test_load_version_dir_filtering(self, tmp_path): + version_dir = tmp_path / "v1" + version_dir.mkdir() + (version_dir / "schema1.json").write_text("{}") + (version_dir / "not_a_schema.txt").write_text("content") + + registry = SchemaRegistry(str(tmp_path)) + with patch.object(registry, "_load_schema") as mock_load: + registry._load_version_dir("v1", version_dir) + mock_load.assert_called_once() + assert mock_load.call_args[0][1] == "schema1" + + def test_load_version_dir_non_existent(self, tmp_path): + version_dir = tmp_path / "non_existent" + registry = SchemaRegistry(str(tmp_path)) + registry._load_version_dir("v1", version_dir) + assert "v1" not in registry.versions + + def test_load_schema_success(self, tmp_path): + schema_path = tmp_path / "test.json" + schema_content = {"title": "Test Schema", "description": "A test schema"} + schema_path.write_text(json.dumps(schema_content)) + + registry = SchemaRegistry(str(tmp_path)) + registry.versions["v1"] = {} + registry._load_schema("v1", "test", schema_path) + + assert registry.versions["v1"]["test"] == schema_content + uri = "https://dify.ai/schemas/v1/test.json" + assert registry.metadata[uri]["title"] == "Test Schema" + assert registry.metadata[uri]["version"] == "v1" + + def test_load_schema_invalid_json(self, tmp_path, caplog): + schema_path = tmp_path / "invalid.json" + schema_path.write_text("invalid json") + + registry = SchemaRegistry(str(tmp_path)) + registry.versions["v1"] = {} + registry._load_schema("v1", "invalid", schema_path) + + assert "Failed to load schema v1/invalid" in caplog.text + + def test_load_schema_os_error(self, tmp_path, caplog): + schema_path = tmp_path / "error.json" + schema_path.write_text("{}") + + registry = SchemaRegistry(str(tmp_path)) + registry.versions["v1"] = {} + + with patch("builtins.open", side_effect=OSError("Read error")): + registry._load_schema("v1", "error", schema_path) + + assert "Failed to load schema v1/error" in caplog.text + + def test_get_schema(self): + registry = SchemaRegistry("/tmp") + registry.versions = {"v1": {"test": {"type": "object"}}} + + # Valid URI + assert registry.get_schema("https://dify.ai/schemas/v1/test.json") == {"type": "object"} + + # Invalid URI + assert registry.get_schema("invalid-uri") is None + + # Missing version + assert registry.get_schema("https://dify.ai/schemas/v2/test.json") is None + + def test_list_versions(self): + registry = SchemaRegistry("/tmp") + registry.versions = {"v2": {}, "v1": {}} + assert registry.list_versions() == ["v1", "v2"] + + def test_list_schemas(self): + registry = SchemaRegistry("/tmp") + registry.versions = {"v1": {"b": {}, "a": {}}} + + assert registry.list_schemas("v1") == ["a", "b"] + assert registry.list_schemas("v2") == [] + + def test_get_all_schemas_for_version(self): + registry = SchemaRegistry("/tmp") + registry.versions = {"v1": {"test": {"title": "Test Label"}}} + + results = registry.get_all_schemas_for_version("v1") + assert len(results) == 1 + assert results[0]["name"] == "test" + assert results[0]["label"] == "Test Label" + assert results[0]["schema"] == {"title": "Test Label"} + + # Default label if title missing + registry.versions["v1"]["no_title"] = {} + results = registry.get_all_schemas_for_version("v1") + item = next(r for r in results if r["name"] == "no_title") + assert item["label"] == "no_title" + + # Empty if version missing + assert registry.get_all_schemas_for_version("v2") == [] diff --git a/api/tests/unit_tests/core/schemas/test_schema_manager.py b/api/tests/unit_tests/core/schemas/test_schema_manager.py new file mode 100644 index 00000000000..cb07340c6d2 --- /dev/null +++ b/api/tests/unit_tests/core/schemas/test_schema_manager.py @@ -0,0 +1,80 @@ +from unittest.mock import MagicMock, patch + +from core.schemas.registry import SchemaRegistry +from core.schemas.schema_manager import SchemaManager + + +def test_init_with_provided_registry(): + mock_registry = MagicMock(spec=SchemaRegistry) + manager = SchemaManager(registry=mock_registry) + assert manager.registry == mock_registry + + +@patch("core.schemas.schema_manager.SchemaRegistry.default_registry") +def test_init_with_default_registry(mock_default_registry): + mock_registry = MagicMock(spec=SchemaRegistry) + mock_default_registry.return_value = mock_registry + + manager = SchemaManager() + + mock_default_registry.assert_called_once() + assert manager.registry == mock_registry + + +def test_get_all_schema_definitions(): + mock_registry = MagicMock(spec=SchemaRegistry) + expected_definitions = [{"name": "schema1", "schema": {}}, {"name": "schema2", "schema": {}}] + mock_registry.get_all_schemas_for_version.return_value = expected_definitions + + manager = SchemaManager(registry=mock_registry) + result = manager.get_all_schema_definitions(version="v2") + + mock_registry.get_all_schemas_for_version.assert_called_once_with("v2") + assert result == expected_definitions + + +def test_get_schema_by_name_success(): + mock_registry = MagicMock(spec=SchemaRegistry) + mock_schema = {"type": "object"} + mock_registry.get_schema.return_value = mock_schema + + manager = SchemaManager(registry=mock_registry) + result = manager.get_schema_by_name("my_schema", version="v1") + + expected_uri = "https://dify.ai/schemas/v1/my_schema.json" + mock_registry.get_schema.assert_called_once_with(expected_uri) + assert result == {"name": "my_schema", "schema": mock_schema} + + +def test_get_schema_by_name_not_found(): + mock_registry = MagicMock(spec=SchemaRegistry) + mock_registry.get_schema.return_value = None + + manager = SchemaManager(registry=mock_registry) + result = manager.get_schema_by_name("non_existent", version="v1") + + assert result is None + + +def test_list_available_schemas(): + mock_registry = MagicMock(spec=SchemaRegistry) + expected_schemas = ["schema1", "schema2"] + mock_registry.list_schemas.return_value = expected_schemas + + manager = SchemaManager(registry=mock_registry) + result = manager.list_available_schemas(version="v1") + + mock_registry.list_schemas.assert_called_once_with("v1") + assert result == expected_schemas + + +def test_list_available_versions(): + mock_registry = MagicMock(spec=SchemaRegistry) + expected_versions = ["v1", "v2"] + mock_registry.list_versions.return_value = expected_versions + + manager = SchemaManager(registry=mock_registry) + result = manager.list_available_versions() + + mock_registry.list_versions.assert_called_once() + assert result == expected_versions diff --git a/api/tests/unit_tests/core/telemetry/test_facade.py b/api/tests/unit_tests/core/telemetry/test_facade.py new file mode 100644 index 00000000000..36e8e1bbb13 --- /dev/null +++ b/api/tests/unit_tests/core/telemetry/test_facade.py @@ -0,0 +1,181 @@ +"""Unit tests for core.telemetry.emit() routing and enterprise-only filtering.""" + +from __future__ import annotations + +import queue +import sys +import types +from unittest.mock import MagicMock, patch + +import pytest + +from core.ops.entities.trace_entity import TraceTaskName +from core.telemetry.events import TelemetryContext, TelemetryEvent + + +@pytest.fixture +def telemetry_test_setup(monkeypatch): + module_name = "core.ops.ops_trace_manager" + ops_stub = types.ModuleType(module_name) + + class StubTraceTask: + def __init__(self, trace_type, **kwargs): + self.trace_type = trace_type + self.app_id = None + self.kwargs = kwargs + + class StubTraceQueueManager: + def __init__(self, app_id=None, user_id=None): + self.app_id = app_id + self.user_id = user_id + self.trace_instance = StubOpsTraceManager.get_ops_trace_instance(app_id) + + def add_trace_task(self, trace_task): + trace_task.app_id = self.app_id + from core.ops.ops_trace_manager import trace_manager_queue + + trace_manager_queue.put(trace_task) + + class StubOpsTraceManager: + @staticmethod + def get_ops_trace_instance(app_id): + return None + + ops_stub.TraceQueueManager = StubTraceQueueManager + ops_stub.TraceTask = StubTraceTask + ops_stub.OpsTraceManager = StubOpsTraceManager + ops_stub.trace_manager_queue = MagicMock(spec=queue.Queue) + monkeypatch.setitem(sys.modules, module_name, ops_stub) + + from core.telemetry import emit + + return emit, ops_stub.trace_manager_queue + + +class TestTelemetryEmit: + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_emit_enterprise_trace_creates_trace_task(self, mock_ee, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + event = TelemetryEvent( + name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={"key": "value"}, + ) + + emit_fn(event) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.trace_type == TraceTaskName.DRAFT_NODE_EXECUTION_TRACE + + def test_emit_community_trace_enqueued(self, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + event = TelemetryEvent( + name=TraceTaskName.WORKFLOW_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={}, + ) + + emit_fn(event) + + mock_queue.put.assert_called_once() + + def test_emit_enterprise_only_trace_dropped_when_ee_disabled(self, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + event = TelemetryEvent( + name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={}, + ) + + emit_fn(event) + + mock_queue.put.assert_not_called() + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_emit_all_enterprise_only_traces_allowed_when_ee_enabled(self, mock_ee, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + enterprise_only_traces = [ + TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + TraceTaskName.NODE_EXECUTION_TRACE, + TraceTaskName.PROMPT_GENERATION_TRACE, + ] + + for trace_name in enterprise_only_traces: + mock_queue.reset_mock() + + event = TelemetryEvent( + name=trace_name, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={}, + ) + + emit_fn(event) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.trace_type == trace_name + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_emit_passes_name_directly_to_trace_task(self, mock_ee, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + event = TelemetryEvent( + name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={"extra": "data"}, + ) + + emit_fn(event) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.trace_type == TraceTaskName.DRAFT_NODE_EXECUTION_TRACE + assert isinstance(called_task.trace_type, TraceTaskName) + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_emit_with_provided_trace_manager(self, mock_ee, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + mock_trace_manager = MagicMock() + mock_trace_manager.add_trace_task = MagicMock() + + event = TelemetryEvent( + name=TraceTaskName.NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={}, + ) + + emit_fn(event, trace_manager=mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + called_task = mock_trace_manager.add_trace_task.call_args[0][0] + assert called_task.trace_type == TraceTaskName.NODE_EXECUTION_TRACE diff --git a/api/tests/unit_tests/core/telemetry/test_gateway_integration.py b/api/tests/unit_tests/core/telemetry/test_gateway_integration.py new file mode 100644 index 00000000000..a68fce5e7fa --- /dev/null +++ b/api/tests/unit_tests/core/telemetry/test_gateway_integration.py @@ -0,0 +1,225 @@ +from __future__ import annotations + +import sys +from unittest.mock import MagicMock, patch + +import pytest + +from core.telemetry.gateway import emit, is_enterprise_telemetry_enabled +from enterprise.telemetry.contracts import TelemetryCase + + +class TestTelemetryCoreExports: + def test_is_enterprise_telemetry_enabled_exported(self) -> None: + from core.telemetry.gateway import is_enterprise_telemetry_enabled as exported_func + + assert callable(exported_func) + + +@pytest.fixture +def mock_ops_trace_manager(): + mock_module = MagicMock() + mock_trace_task_class = MagicMock() + mock_trace_task_class.return_value = MagicMock() + mock_module.TraceTask = mock_trace_task_class + mock_module.TraceQueueManager = MagicMock() + + mock_trace_entity = MagicMock() + mock_trace_task_name = MagicMock() + mock_trace_task_name.return_value = "workflow" + mock_trace_entity.TraceTaskName = mock_trace_task_name + + with ( + patch.dict(sys.modules, {"core.ops.ops_trace_manager": mock_module}), + patch.dict(sys.modules, {"core.ops.entities.trace_entity": mock_trace_entity}), + ): + yield mock_module, mock_trace_entity + + +class TestGatewayIntegrationTraceRouting: + @pytest.fixture + def mock_trace_manager(self) -> MagicMock: + return MagicMock() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_ce_eligible_trace_routed_to_trace_manager( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True): + context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"} + payload = {"workflow_run_id": "run-abc"} + + emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_ce_eligible_trace_routed_when_ee_disabled( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"workflow_run_id": "run-abc"} + + emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_enterprise_only_trace_dropped_when_ee_disabled( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_id": "node-abc"} + + emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_not_called() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_enterprise_only_trace_routed_when_ee_enabled( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_id": "node-abc"} + + emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + +class TestGatewayIntegrationMetricRouting: + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_metric_case_routes_to_celery_task( + self, + mock_ee_enabled: MagicMock, + ) -> None: + from enterprise.telemetry.contracts import TelemetryEnvelope + + with patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") as mock_delay: + context = {"tenant_id": "tenant-123"} + payload = {"app_id": "app-abc", "name": "My App"} + + emit(TelemetryCase.APP_CREATED, context, payload) + + mock_delay.assert_called_once() + envelope_json = mock_delay.call_args[0][0] + envelope = TelemetryEnvelope.model_validate_json(envelope_json) + assert envelope.case == TelemetryCase.APP_CREATED + assert envelope.tenant_id == "tenant-123" + assert envelope.payload["app_id"] == "app-abc" + + @pytest.mark.usefixtures("mock_ops_trace_manager") + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_tool_execution_trace_routed( + self, + mock_ee_enabled: MagicMock, + ) -> None: + mock_trace_manager = MagicMock() + context = {"tenant_id": "tenant-123", "app_id": "app-123"} + payload = {"tool_name": "test_tool", "tool_inputs": {}, "tool_outputs": "result"} + + emit(TelemetryCase.TOOL_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_moderation_check_trace_routed( + self, + mock_ee_enabled: MagicMock, + ) -> None: + mock_trace_manager = MagicMock() + context = {"tenant_id": "tenant-123", "app_id": "app-123"} + payload = {"message_id": "msg-123", "moderation_result": {"flagged": False}} + + emit(TelemetryCase.MODERATION_CHECK, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + +class TestGatewayIntegrationCEEligibility: + @pytest.fixture + def mock_trace_manager(self) -> MagicMock: + return MagicMock() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_workflow_run_is_ce_eligible( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"workflow_run_id": "run-abc"} + + emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_message_run_is_ce_eligible( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"message_id": "msg-abc", "conversation_id": "conv-123"} + + emit(TelemetryCase.MESSAGE_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_node_execution_not_ce_eligible( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_id": "node-abc"} + + emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_not_called() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_draft_node_execution_not_ce_eligible( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_execution_data": {}} + + emit(TelemetryCase.DRAFT_NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_not_called() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_prompt_generation_not_ce_eligible( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"} + payload = {"operation_type": "generate", "instruction": "test"} + + emit(TelemetryCase.PROMPT_GENERATION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_not_called() + + +class TestIsEnterpriseTelemetryEnabled: + def test_returns_false_when_exporter_import_fails(self) -> None: + with patch.dict(sys.modules, {"enterprise.telemetry.exporter": None}): + result = is_enterprise_telemetry_enabled() + assert result is False + + def test_function_is_callable(self) -> None: + assert callable(is_enterprise_telemetry_enabled) diff --git a/api/tests/unit_tests/core/test_file.py b/api/tests/unit_tests/core/test_file.py index 251d6fd25ea..ac65d0c02bc 100644 --- a/api/tests/unit_tests/core/test_file.py +++ b/api/tests/unit_tests/core/test_file.py @@ -1,6 +1,7 @@ import json -from dify_graph.file import File, FileTransferMethod, FileType, FileUploadConfig +from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig + from models.workflow import Workflow diff --git a/api/tests/unit_tests/core/test_model_manager.py b/api/tests/unit_tests/core/test_model_manager.py index 92e4b584736..f5efb78b614 100644 --- a/api/tests/unit_tests/core/test_model_manager.py +++ b/api/tests/unit_tests/core/test_model_manager.py @@ -2,11 +2,11 @@ from unittest.mock import MagicMock, patch import pytest import redis +from graphon.model_runtime.entities.model_entities import ModelType from pytest_mock import MockerFixture from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.model_manager import LBModelManager -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_redis import redis_client diff --git a/api/tests/unit_tests/core/test_provider_configuration.py b/api/tests/unit_tests/core/test_provider_configuration.py index 90ed1647aac..331166fe63c 100644 --- a/api/tests/unit_tests/core/test_provider_configuration.py +++ b/api/tests/unit_tests/core/test_provider_configuration.py @@ -1,6 +1,15 @@ from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FormOption, + FormType, + ProviderEntity, +) from core.entities.provider_configuration import ProviderConfiguration, SystemConfigurationStatus from core.entities.provider_entities import ( @@ -12,15 +21,6 @@ from core.entities.provider_entities import ( RestrictModel, SystemConfiguration, ) -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FormOption, - FormType, - ProviderEntity, -) from models.provider import Provider, ProviderType diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py index 69567c54eb8..259cb5fdd07 100644 --- a/api/tests/unit_tests/core/test_provider_manager.py +++ b/api/tests/unit_tests/core/test_provider_manager.py @@ -1,12 +1,26 @@ -from unittest.mock import Mock, PropertyMock, patch +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, PropertyMock, patch import pytest +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType +from pytest_mock import MockerFixture from core.entities.provider_entities import ModelSettings from core.provider_manager import ProviderManager -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import ModelType -from models.provider import LoadBalancingModelConfig, ProviderModelSetting +from models.provider import LoadBalancingModelConfig, ProviderModelSetting, TenantDefaultModel +from models.provider_ids import ModelProviderID + + +def _build_provider_manager(mocker: MockerFixture) -> ProviderManager: + return ProviderManager(model_runtime=mocker.Mock()) + + +def _build_session_context(session: Mock) -> MagicMock: + session_cm = MagicMock() + session_cm.__enter__.return_value = session + session_cm.__exit__.return_value = False + return session_cm @pytest.fixture @@ -28,7 +42,7 @@ def mock_provider_entity(): return mock_entity -def test__to_model_settings(mock_provider_entity): +def test__to_model_settings(mocker: MockerFixture, mock_provider_entity): # Mocking the inputs ps = ProviderModelSetting( tenant_id="tenant_id", @@ -69,7 +83,7 @@ def test__to_model_settings(mock_provider_entity): "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}, ): - provider_manager = ProviderManager() + provider_manager = _build_provider_manager(mocker) # Running the method result = provider_manager._to_model_settings( @@ -89,7 +103,7 @@ def test__to_model_settings(mock_provider_entity): assert result[0].load_balancing_configs[1].name == "first" -def test__to_model_settings_only_one_lb(mock_provider_entity): +def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_entity): # Mocking the inputs ps = ProviderModelSetting( @@ -119,7 +133,7 @@ def test__to_model_settings_only_one_lb(mock_provider_entity): "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}, ): - provider_manager = ProviderManager() + provider_manager = _build_provider_manager(mocker) # Running the method result = provider_manager._to_model_settings( @@ -137,7 +151,7 @@ def test__to_model_settings_only_one_lb(mock_provider_entity): assert len(result[0].load_balancing_configs) == 0 -def test__to_model_settings_lb_disabled(mock_provider_entity): +def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_entity): # Mocking the inputs ps = ProviderModelSetting( tenant_id="tenant_id", @@ -176,7 +190,7 @@ def test__to_model_settings_lb_disabled(mock_provider_entity): "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}, ): - provider_manager = ProviderManager() + provider_manager = _build_provider_manager(mocker) # Running the method result = provider_manager._to_model_settings( @@ -194,7 +208,7 @@ def test__to_model_settings_lb_disabled(mock_provider_entity): assert len(result[0].load_balancing_configs) == 0 -def test_get_default_model_uses_first_available_active_model(): +def test_get_default_model_uses_first_available_active_model(mocker: MockerFixture): mock_session = Mock() mock_session.scalar.return_value = None @@ -204,7 +218,7 @@ def test_get_default_model_uses_first_available_active_model(): Mock(model="gpt-4", provider=Mock(provider="openai")), ] - manager = ProviderManager() + manager = _build_provider_manager(mocker) with ( patch("core.provider_manager.db.session", mock_session), patch.object(manager, "get_configurations", return_value=provider_configurations), @@ -228,3 +242,345 @@ def test_get_default_model_uses_first_available_active_model(): assert saved_default_model.model_name == "gpt-3.5-turbo" assert saved_default_model.provider_name == "openai" mock_session.commit.assert_called_once() + + +def test_get_default_model_returns_none_when_no_default_or_active_models(mocker: MockerFixture): + mock_session = Mock() + mock_session.scalar.return_value = None + provider_configurations = Mock() + provider_configurations.get_models.return_value = [] + manager = _build_provider_manager(mocker) + + with ( + patch("core.provider_manager.db.session", mock_session), + patch.object(manager, "get_configurations", return_value=provider_configurations), + patch("core.provider_manager.ModelProviderFactory") as mock_factory_cls, + ): + result = manager.get_default_model("tenant-id", ModelType.LLM) + + assert result is None + provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True) + mock_factory_cls.assert_not_called() + mock_session.add.assert_not_called() + mock_session.commit.assert_not_called() + + +def test_get_default_model_uses_injected_runtime_for_existing_default_record(mocker: MockerFixture): + existing_default_model = TenantDefaultModel( + tenant_id="tenant-id", + provider_name="openai", + model_name="gpt-4", + model_type=ModelType.LLM.to_origin_model_type(), + ) + mock_session = Mock() + mock_session.scalar.return_value = existing_default_model + manager = _build_provider_manager(mocker) + + with ( + patch("core.provider_manager.db.session", mock_session), + patch("core.provider_manager.ModelProviderFactory") as mock_factory_cls, + ): + mock_factory_cls.return_value.get_provider_schema.return_value = Mock( + provider="openai", + label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"), + icon_small=I18nObject(en_US="icon_small.png", zh_Hans="icon_small.png"), + supported_model_types=[ModelType.LLM], + ) + + result = manager.get_default_model("tenant-id", ModelType.LLM) + + mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime) + assert result is not None + assert result.model == "gpt-4" + assert result.provider.provider == "openai" + + +def test_get_configurations_uses_injected_runtime_and_adds_provider_aliases(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_records = {"openai": [SimpleNamespace(provider_name="openai")]} + provider_model_records = {"openai": [SimpleNamespace(provider_name="openai")]} + preferred_provider_records = {"openai": SimpleNamespace(preferred_provider_type="system")} + + with ( + patch.object(manager, "_get_all_providers", return_value=provider_records), + patch.object(manager, "_init_trial_provider_records", return_value=provider_records), + patch.object(manager, "_get_all_provider_models", return_value=provider_model_records), + patch.object(manager, "_get_all_preferred_model_providers", return_value=preferred_provider_records), + patch.object(manager, "_get_all_provider_model_settings", return_value={}), + patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}), + patch.object(manager, "_get_all_provider_model_credentials", return_value={}), + patch("core.provider_manager.ModelProviderFactory") as mock_factory_cls, + ): + mock_factory_cls.return_value.get_providers.return_value = [] + + result = manager.get_configurations("tenant-id") + + expected_alias = str(ModelProviderID("openai")) + mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime) + assert result.tenant_id == "tenant-id" + assert expected_alias in provider_records + assert expected_alias in provider_model_records + assert expected_alias in preferred_provider_records + + +@pytest.mark.parametrize( + ("provider_name", "expected_provider_names"), + [ + ("openai", ["openai", "langgenius/openai/openai"]), + ("langgenius/openai/openai", ["langgenius/openai/openai", "openai"]), + ("langgenius/gemini/google", ["langgenius/gemini/google", "google"]), + ], +) +def test_get_provider_names_returns_short_and_full_aliases(provider_name: str, expected_provider_names: list[str]): + assert ProviderManager._get_provider_names(provider_name) == expected_provider_names + + +def test_get_provider_model_bundle_raises_for_unknown_provider(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + + with patch.object(manager, "get_configurations", return_value={}): + with pytest.raises(ValueError, match="Provider openai does not exist."): + manager.get_provider_model_bundle("tenant-id", "openai", ModelType.LLM) + + +def test_get_configurations_binds_manager_runtime_to_provider_configuration( + mocker: MockerFixture, mock_provider_entity +): + manager = _build_provider_manager(mocker) + provider_configuration = Mock() + provider_factory = Mock() + provider_factory.get_providers.return_value = [mock_provider_entity] + custom_configuration = SimpleNamespace(provider=None, models=[]) + system_configuration = SimpleNamespace(enabled=False, quota_configurations=[], current_quota_type=None) + + with ( + patch.object(manager, "_get_all_providers", return_value={"openai": []}), + patch.object(manager, "_init_trial_provider_records", return_value={"openai": []}), + patch.object(manager, "_get_all_provider_models", return_value={"openai": []}), + patch.object(manager, "_get_all_preferred_model_providers", return_value={}), + patch.object(manager, "_get_all_provider_model_settings", return_value={}), + patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}), + patch.object(manager, "_get_all_provider_model_credentials", return_value={}), + patch.object(manager, "_to_custom_configuration", return_value=custom_configuration), + patch.object(manager, "_to_system_configuration", return_value=system_configuration), + patch.object(manager, "_to_model_settings", return_value=[]), + patch("core.provider_manager.ModelProviderFactory", return_value=provider_factory), + patch("core.provider_manager.ProviderConfiguration", return_value=provider_configuration), + ): + manager.get_configurations("tenant-id") + + provider_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime) + + +def test_get_provider_model_bundle_returns_selected_model_type_instance(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_configuration = Mock() + model_type_instance = Mock() + provider_configuration.get_model_type_instance.return_value = model_type_instance + expected_bundle = Mock() + + with ( + patch.object(manager, "get_configurations", return_value={"openai": provider_configuration}), + patch("core.provider_manager.ProviderModelBundle", return_value=expected_bundle) as mock_bundle, + ): + result = manager.get_provider_model_bundle("tenant-id", "openai", ModelType.LLM) + + provider_configuration.get_model_type_instance.assert_called_once_with(ModelType.LLM) + mock_bundle.assert_called_once_with( + configuration=provider_configuration, + model_type_instance=model_type_instance, + ) + assert result is expected_bundle + + +def test_get_first_provider_first_model_returns_none_when_no_models(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_configurations = Mock() + provider_configurations.get_models.return_value = [] + + with patch.object(manager, "get_configurations", return_value=provider_configurations): + result = manager.get_first_provider_first_model("tenant-id", ModelType.LLM) + + assert result == (None, None) + provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=False) + + +def test_get_first_provider_first_model_returns_first_model_and_provider(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_configurations = Mock() + provider_configurations.get_models.return_value = [ + Mock(model="gpt-4", provider=Mock(provider="openai")), + Mock(model="gpt-4o", provider=Mock(provider="openai")), + ] + + with patch.object(manager, "get_configurations", return_value=provider_configurations): + result = manager.get_first_provider_first_model("tenant-id", ModelType.LLM) + + assert result == ("openai", "gpt-4") + + +def test_update_default_model_record_raises_for_unknown_provider(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + + with patch.object(manager, "get_configurations", return_value={}): + with pytest.raises(ValueError, match="Provider openai does not exist."): + manager.update_default_model_record("tenant-id", ModelType.LLM, "openai", "gpt-4") + + +def test_update_default_model_record_raises_for_unknown_model(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_configurations = MagicMock() + provider_configurations.__contains__.return_value = True + provider_configurations.get_models.return_value = [Mock(model="gpt-4")] + + with patch.object(manager, "get_configurations", return_value=provider_configurations): + with pytest.raises(ValueError, match="Model gpt-3.5-turbo does not exist."): + manager.update_default_model_record("tenant-id", ModelType.LLM, "openai", "gpt-3.5-turbo") + + provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True) + + +def test_update_default_model_record_updates_existing_record(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_configurations = MagicMock() + provider_configurations.__contains__.return_value = True + provider_configurations.get_models.return_value = [Mock(model="gpt-3.5-turbo")] + existing_default_model = TenantDefaultModel( + tenant_id="tenant-id", + provider_name="anthropic", + model_name="claude-3-sonnet", + model_type=ModelType.LLM.to_origin_model_type(), + ) + mock_session = Mock() + mock_session.scalar.return_value = existing_default_model + + with ( + patch.object(manager, "get_configurations", return_value=provider_configurations), + patch("core.provider_manager.db.session", mock_session), + ): + result = manager.update_default_model_record("tenant-id", ModelType.LLM, "openai", "gpt-3.5-turbo") + + assert result is existing_default_model + assert existing_default_model.provider_name == "openai" + assert existing_default_model.model_name == "gpt-3.5-turbo" + mock_session.commit.assert_called_once() + mock_session.add.assert_not_called() + + +def test_update_default_model_record_creates_record_with_origin_model_type(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_configurations = MagicMock() + provider_configurations.__contains__.return_value = True + provider_configurations.get_models.return_value = [Mock(model="gpt-4")] + mock_session = Mock() + mock_session.scalar.return_value = None + + with ( + patch.object(manager, "get_configurations", return_value=provider_configurations), + patch("core.provider_manager.db.session", mock_session), + ): + result = manager.update_default_model_record("tenant-id", ModelType.LLM, "openai", "gpt-4") + + mock_session.add.assert_called_once() + created_default_model = mock_session.add.call_args.args[0] + assert result is created_default_model + assert created_default_model.tenant_id == "tenant-id" + assert created_default_model.provider_name == "openai" + assert created_default_model.model_name == "gpt-4" + assert created_default_model.model_type == ModelType.LLM.to_origin_model_type() + mock_session.commit.assert_called_once() + + +def test_get_all_providers_normalizes_provider_names_with_model_provider_id() -> None: + session = Mock() + openai_provider = SimpleNamespace(provider_name="openai") + gemini_provider = SimpleNamespace(provider_name="langgenius/gemini/google") + session.scalars.return_value = [openai_provider, gemini_provider] + + with ( + patch("core.provider_manager.db", SimpleNamespace(engine=object())), + patch("core.provider_manager.Session", return_value=_build_session_context(session)), + ): + result = ProviderManager._get_all_providers("tenant-id") + + assert list(result[str(ModelProviderID("openai"))]) == [openai_provider] + assert list(result[str(ModelProviderID("langgenius/gemini/google"))]) == [gemini_provider] + + +@pytest.mark.parametrize( + "method_name", + [ + "_get_all_provider_models", + "_get_all_provider_model_settings", + "_get_all_provider_model_credentials", + ], +) +def test_provider_grouping_helpers_group_records_by_provider_name(method_name: str) -> None: + session = Mock() + openai_primary = SimpleNamespace(provider_name="openai") + openai_secondary = SimpleNamespace(provider_name="openai") + anthropic_record = SimpleNamespace(provider_name="anthropic") + session.scalars.return_value = [openai_primary, openai_secondary, anthropic_record] + + with ( + patch("core.provider_manager.db", SimpleNamespace(engine=object())), + patch("core.provider_manager.Session", return_value=_build_session_context(session)), + ): + result = getattr(ProviderManager, method_name)("tenant-id") + + assert list(result["openai"]) == [openai_primary, openai_secondary] + assert list(result["anthropic"]) == [anthropic_record] + + +def test_get_all_preferred_model_providers_returns_mapping_by_provider_name() -> None: + session = Mock() + openai_preference = SimpleNamespace(provider_name="openai") + anthropic_preference = SimpleNamespace(provider_name="anthropic") + session.scalars.return_value = [openai_preference, anthropic_preference] + + with ( + patch("core.provider_manager.db", SimpleNamespace(engine=object())), + patch("core.provider_manager.Session", return_value=_build_session_context(session)), + ): + result = ProviderManager._get_all_preferred_model_providers("tenant-id") + + assert result == { + "openai": openai_preference, + "anthropic": anthropic_preference, + } + + +def test_get_all_provider_load_balancing_configs_returns_empty_when_cached_flag_is_disabled() -> None: + with ( + patch("core.provider_manager.redis_client.get", return_value=b"False"), + patch("core.provider_manager.FeatureService.get_features") as mock_get_features, + patch("core.provider_manager.Session") as mock_session_cls, + ): + result = ProviderManager._get_all_provider_load_balancing_configs("tenant-id") + + assert result == {} + mock_get_features.assert_not_called() + mock_session_cls.assert_not_called() + + +def test_get_all_provider_load_balancing_configs_populates_cache_and_groups_configs() -> None: + session = Mock() + openai_config = SimpleNamespace(provider_name="openai") + anthropic_config = SimpleNamespace(provider_name="anthropic") + session.scalars.return_value = [openai_config, anthropic_config] + + with ( + patch("core.provider_manager.db", SimpleNamespace(engine=object())), + patch("core.provider_manager.redis_client.get", return_value=None), + patch("core.provider_manager.redis_client.setex") as mock_setex, + patch( + "core.provider_manager.FeatureService.get_features", + return_value=SimpleNamespace(model_load_balancing_enabled=True), + ), + patch("core.provider_manager.Session", return_value=_build_session_context(session)), + ): + result = ProviderManager._get_all_provider_load_balancing_configs("tenant-id") + + mock_setex.assert_called_once_with("tenant:tenant-id:model_load_balancing_enabled", 120, "True") + assert list(result["openai"]) == [openai_config] + assert list(result["anthropic"]) == [anthropic_config] diff --git a/api/tests/unit_tests/core/tools/test_builtin_tool_base.py b/api/tests/unit_tests/core/tools/test_builtin_tool_base.py index f123f60a34a..5d744f88c9b 100644 --- a/api/tests/unit_tests/core/tools/test_builtin_tool_base.py +++ b/api/tests/unit_tests/core/tools/test_builtin_tool_base.py @@ -6,13 +6,13 @@ from typing import Any from unittest.mock import patch import pytest +from graphon.model_runtime.entities.message_entities import UserPromptMessage from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage, ToolProviderType -from dify_graph.model_runtime.entities.message_entities import UserPromptMessage class _BuiltinDummyTool(BuiltinTool): @@ -27,12 +27,12 @@ class _BuiltinDummyTool(BuiltinTool): yield self.create_text_message("ok") -def _build_tool() -> _BuiltinDummyTool: +def _build_tool(user_id: str | None = None) -> _BuiltinDummyTool: entity = ToolEntity( identity=ToolIdentity(author="author", name="tool-a", label=I18nObject(en_US="tool-a"), provider="provider-a"), parameters=[], ) - runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER) + runtime = ToolRuntime(tenant_id="tenant-1", user_id=user_id, invoke_from=InvokeFrom.DEBUGGER) return _BuiltinDummyTool(provider="provider-a", entity=entity, runtime=runtime) @@ -45,7 +45,7 @@ def test_builtin_tool_fork_and_provider_type(): def test_invoke_model_calls_model_invocation_utils_invoke(): - tool = _build_tool() + tool = _build_tool(user_id="runtime-user") with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.invoke", return_value="result") as mock_invoke: assert ( tool.invoke_model( @@ -55,19 +55,47 @@ def test_invoke_model_calls_model_invocation_utils_invoke(): ) == "result" ) - mock_invoke.assert_called_once() + mock_invoke.assert_called_once_with( + user_id="u1", + tenant_id="tenant-1", + tool_type=ToolProviderType.BUILT_IN, + tool_name="tool-a", + prompt_messages=[UserPromptMessage(content="hello")], + caller_user_id="runtime-user", + ) def test_get_max_tokens_returns_value(): - tool = _build_tool() - with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.get_max_llm_context_tokens", return_value=4096): + tool = _build_tool(user_id="runtime-user") + with patch( + "core.tools.builtin_tool.tool.ModelInvocationUtils.get_max_llm_context_tokens", return_value=4096 + ) as mock_get: assert tool.get_max_tokens() == 4096 + mock_get.assert_called_once_with(tenant_id="tenant-1", user_id="runtime-user") def test_get_prompt_tokens_returns_value(): - tool = _build_tool() - with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7): + tool = _build_tool(user_id="runtime-user") + with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7) as mock_calculate: assert tool.get_prompt_tokens([UserPromptMessage(content="hello")]) == 7 + mock_calculate.assert_called_once_with( + tenant_id="tenant-1", + prompt_messages=[UserPromptMessage(content="hello")], + user_id="runtime-user", + ) + + +def test_get_prompt_tokens_falls_back_to_tenant_scope_when_runtime_user_id_missing(): + tool = _build_tool() + + with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7) as mock_calculate: + assert tool.get_prompt_tokens([UserPromptMessage(content="hello")]) == 7 + + mock_calculate.assert_called_once_with( + tenant_id="tenant-1", + prompt_messages=[UserPromptMessage(content="hello")], + user_id=None, + ) def test_runtime_none_raises(): diff --git a/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py b/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py index 62cfb6ce5bf..ee0ce51eec9 100644 --- a/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py +++ b/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py @@ -1,9 +1,13 @@ from __future__ import annotations +import calendar import math +from datetime import date from types import SimpleNamespace import pytest +from graphon.file import FileType +from graphon.model_runtime.entities.model_entities import ModelPropertyKey from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.__base.tool_runtime import ToolRuntime @@ -25,8 +29,6 @@ from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage from core.tools.errors import ToolInvokeError -from dify_graph.file.enums import FileType -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey def _build_builtin_tool(tool_cls: type[BuiltinTool]) -> BuiltinTool: @@ -98,7 +100,13 @@ def test_timezone_conversion_tool(): def test_weekday_tool(): weekday_tool = _build_builtin_tool(WeekdayTool) valid = list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "month": 1, "day": 1}))[0].message.text - assert "January 1, 2024" in valid + expected_date = date(2024, 1, 1) + expected_message = ( + f"{calendar.month_name[expected_date.month]} " + f"{expected_date.day}, {expected_date.year} " + f"is {calendar.day_name[expected_date.weekday()]}." + ) + assert valid == expected_message invalid = list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "month": 2, "day": 31}))[ 0 ].message.text @@ -186,13 +194,19 @@ def test_asr_invalid_file(): def test_asr_valid_file_invocation(monkeypatch): asr = _build_builtin_tool(ASRTool) - model_instance = type("M", (), {"invoke_speech2text": lambda self, file, user: "transcript"})() + model_instance = type("M", (), {"invoke_speech2text": lambda self, file: "transcript"})() model_manager = type("Mgr", (), {"get_model_instance": lambda *a, **k: model_instance})() monkeypatch.setattr("core.tools.builtin_tool.providers.audio.tools.asr.download", lambda file: b"audio-bytes") - monkeypatch.setattr("core.tools.builtin_tool.providers.audio.tools.asr.ModelManager", lambda: model_manager) + captured_manager_kwargs = {} + + monkeypatch.setattr( + "core.tools.builtin_tool.providers.audio.tools.asr.ModelManager.for_tenant", + lambda **kwargs: captured_manager_kwargs.update(kwargs) or model_manager, + ) audio_file = SimpleNamespace(type=FileType.AUDIO) ok = list(asr.invoke(user_id="u", tool_parameters={"audio_file": audio_file, "model": "p#m"}))[0].message.text assert ok == "transcript" + assert captured_manager_kwargs == {"tenant_id": "tenant-1", "user_id": "u"} def test_asr_available_models_and_runtime_parameters(monkeypatch): @@ -208,6 +222,7 @@ def test_asr_available_models_and_runtime_parameters(monkeypatch): def test_tts_invoke_returns_messages(monkeypatch): tts = _build_builtin_tool(TTSTool) + captured_manager_kwargs = {} voices_model_instance = type( "TTSM", (), @@ -217,11 +232,15 @@ def test_tts_invoke_returns_messages(monkeypatch): }, )() monkeypatch.setattr( - "core.tools.builtin_tool.providers.audio.tools.tts.ModelManager", - lambda: type("M", (), {"get_model_instance": lambda *a, **k: voices_model_instance})(), + "core.tools.builtin_tool.providers.audio.tools.tts.ModelManager.for_tenant", + lambda **kwargs: ( + captured_manager_kwargs.update(kwargs) + or type("M", (), {"get_model_instance": lambda *a, **k: voices_model_instance})() + ), ) messages = list(tts.invoke(user_id="u", tool_parameters={"model": "p#m", "text": "hello"})) assert [m.type for m in messages] == [ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.BLOB] + assert captured_manager_kwargs == {"tenant_id": "tenant-1", "user_id": "u"} def test_tts_get_available_models_requires_runtime(): @@ -254,8 +273,8 @@ def test_tts_tool_raises_when_voice_unavailable(monkeypatch, voices): }, )() monkeypatch.setattr( - "core.tools.builtin_tool.providers.audio.tools.tts.ModelManager", - lambda: type("Manager", (), {"get_model_instance": lambda *args, **kwargs: model_without_voice})(), + "core.tools.builtin_tool.providers.audio.tools.tts.ModelManager.for_tenant", + lambda **_: type("Manager", (), {"get_model_instance": lambda *args, **kwargs: model_without_voice})(), ) with pytest.raises(ValueError, match="no voice available"): list(tts.invoke(user_id="u", tool_parameters={"model": "p#m", "text": "hello"})) diff --git a/api/tests/unit_tests/core/tools/test_signature.py b/api/tests/unit_tests/core/tools/test_signature.py index a5242a78c50..353988d7a64 100644 --- a/api/tests/unit_tests/core/tools/test_signature.py +++ b/api/tests/unit_tests/core/tools/test_signature.py @@ -6,7 +6,13 @@ from urllib.parse import parse_qs, urlparse import pytest -from core.tools.signature import sign_tool_file, sign_upload_file, verify_tool_file_signature +from core.tools.signature import ( + get_signed_file_url_for_plugin, + sign_tool_file, + sign_upload_file, + verify_plugin_file_signature, + verify_tool_file_signature, +) def test_sign_tool_file_and_verify_roundtrip(monkeypatch: pytest.MonkeyPatch) -> None: @@ -117,3 +123,82 @@ def test_sign_upload_file_uses_files_url_fallback(monkeypatch: pytest.MonkeyPatc assert query["timestamp"][0] assert query["nonce"][0] assert query["sign"][0] + + +def test_get_signed_file_url_for_plugin_and_verify_roundtrip(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000) + monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x06" * 16) + monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "https://internal.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_ACCESS_TIMEOUT", 60) + + url = get_signed_file_url_for_plugin( + filename="report.pdf", + mimetype="application/pdf", + tenant_id="tenant-id", + user_id="user-id", + ) + parsed = urlparse(url) + query = parse_qs(parsed.query) + + assert parsed.netloc == "internal.example.com" + assert parsed.path == "/files/upload/for-plugin" + assert query["tenant_id"] == ["tenant-id"] + assert query["user_id"] == ["user-id"] + assert ( + verify_plugin_file_signature( + filename="report.pdf", + mimetype="application/pdf", + tenant_id="tenant-id", + user_id="user-id", + timestamp=query["timestamp"][0], + nonce=query["nonce"][0], + sign=query["sign"][0], + ) + is True + ) + + +def test_verify_plugin_file_signature_rejects_invalid_signatures(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000) + monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x07" * 16) + monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_ACCESS_TIMEOUT", 30) + + url = get_signed_file_url_for_plugin( + filename="report.pdf", + mimetype="application/pdf", + tenant_id="tenant-id", + user_id="user-id", + ) + query = parse_qs(urlparse(url).query) + + assert ( + verify_plugin_file_signature( + filename="report.pdf", + mimetype="application/pdf", + tenant_id="tenant-id", + user_id="user-id", + timestamp=query["timestamp"][0], + nonce=query["nonce"][0], + sign="bad-signature", + ) + is False + ) + + monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000100) + assert ( + verify_plugin_file_signature( + filename="report.pdf", + mimetype="application/pdf", + tenant_id="tenant-id", + user_id="user-id", + timestamp=query["timestamp"][0], + nonce=query["nonce"][0], + sign=query["sign"][0], + ) + is False + ) diff --git a/api/tests/unit_tests/core/tools/test_tool_file_manager.py b/api/tests/unit_tests/core/tools/test_tool_file_manager.py index cca8254dd69..7fcebde3c55 100644 --- a/api/tests/unit_tests/core/tools/test_tool_file_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_file_manager.py @@ -12,6 +12,7 @@ from unittest.mock import MagicMock, Mock, patch import httpx import pytest +from graphon.file import FileTransferMethod from core.tools.tool_file_manager import ToolFileManager @@ -232,7 +233,14 @@ def test_get_file_generator_returns_none_when_toolfile_missing() -> None: def test_get_file_generator_returns_stream_when_found() -> None: # Arrange manager = ToolFileManager() - tool_file = SimpleNamespace(file_key="k2", mimetype="image/png") + tool_file = SimpleNamespace( + id="tool123", + file_key="k2", + mimetype="image/png", + original_url=None, + name="image.png", + size=12, + ) session = Mock() session.query.return_value.where.return_value.first.return_value = tool_file @@ -240,10 +248,10 @@ def test_get_file_generator_returns_stream_when_found() -> None: with patch("core.tools.tool_file_manager.storage") as storage: stream = iter([b"a", b"b"]) storage.load_stream.return_value = stream - with ( - _patch_session_factory(session), - patch("core.tools.tool_file_manager.ToolFilePydanticModel.model_validate", return_value="validated-file"), - ): + with _patch_session_factory(session): result_stream, result_file = manager.get_file_generator_by_tool_file_id("tool123") assert list(result_stream) == [b"a", b"b"] - assert result_file == "validated-file" + assert result_file is not None + assert result_file.related_id == "tool123" + assert result_file.mime_type == "image/png" + assert result_file.transfer_method == FileTransferMethod.TOOL_FILE diff --git a/api/tests/unit_tests/core/tools/test_tool_label_manager.py b/api/tests/unit_tests/core/tools/test_tool_label_manager.py index 857f4aa1780..8c0e7e9419e 100644 --- a/api/tests/unit_tests/core/tools/test_tool_label_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_label_manager.py @@ -38,11 +38,9 @@ def test_tool_label_manager_filter_tool_labels(): def test_tool_label_manager_update_tool_labels_db(): controller = _api_controller("api-1") with patch("core.tools.tool_label_manager.db") as mock_db: - delete_query = mock_db.session.query.return_value.where.return_value - delete_query.delete.return_value = None ToolLabelManager.update_tool_labels(controller, ["search", "search", "invalid"]) - delete_query.delete.assert_called_once() + mock_db.session.execute.assert_called_once() # only one valid unique label should be inserted. assert mock_db.session.add.call_count == 1 mock_db.session.commit.assert_called_once() diff --git a/api/tests/unit_tests/core/tools/test_tool_manager.py b/api/tests/unit_tests/core/tools/test_tool_manager.py index 0f73e226547..31b68f0b3f3 100644 --- a/api/tests/unit_tests/core/tools/test_tool_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_manager.py @@ -15,6 +15,7 @@ from core.plugin.entities.plugin_daemon import CredentialType from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ( ApiProviderAuthType, + ToolInvokeFrom, ToolParameter, ToolProviderType, ) @@ -219,9 +220,7 @@ def test_get_tool_runtime_builtin_with_credentials_decrypts_and_forks(): with patch.object(ToolManager, "get_builtin_provider", return_value=controller): with patch("core.helper.credential_utils.check_credential_policy_compliance"): with patch("core.tools.tool_manager.db") as mock_db: - mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = ( - builtin_provider - ) + mock_db.session.scalar.return_value = builtin_provider encrypter = Mock() encrypter.decrypt.return_value = {"api_key": "secret"} cache = Mock() @@ -273,7 +272,7 @@ def test_get_tool_runtime_builtin_refreshes_expired_oauth_credentials( ) refreshed = SimpleNamespace(credentials={"token": "new"}, expires_at=123456) - mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = builtin_provider + mock_db.session.scalar.return_value = builtin_provider encrypter = Mock() encrypter.decrypt.return_value = {"token": "old"} encrypter.encrypt.return_value = {"token": "encrypted"} @@ -421,7 +420,7 @@ def test_get_agent_runtime_apply_runtime_parameters(): tool_runtime = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={})) tool_runtime.get_merged_runtime_parameters = Mock(return_value=[parameter]) - with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime): + with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime) as mock_get_tool_runtime: with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={"query": "hello"}): manager = Mock() manager.decrypt_tool_parameters.return_value = {"query": "decrypted"} @@ -437,12 +436,23 @@ def test_get_agent_runtime_apply_runtime_parameters(): tenant_id="tenant-1", app_id="app-1", agent_tool=agent_tool, + user_id="user-1", invoke_from=InvokeFrom.DEBUGGER, variable_pool=None, ) assert result is tool_runtime assert tool_runtime.runtime.runtime_parameters["query"] == "decrypted" + mock_get_tool_runtime.assert_called_once_with( + provider_type=ToolProviderType.API, + provider_id="api-1", + tool_name="search", + tenant_id="tenant-1", + user_id="user-1", + invoke_from=InvokeFrom.DEBUGGER, + tool_invoke_from=ToolInvokeFrom.AGENT, + credential_id=None, + ) def test_get_workflow_runtime_apply_runtime_parameters(): @@ -463,7 +473,7 @@ def test_get_workflow_runtime_apply_runtime_parameters(): ) tool_runtime2 = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={})) tool_runtime2.get_merged_runtime_parameters = Mock(return_value=[parameter]) - with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime2): + with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime2) as mock_get_tool_runtime: with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={"query": "workflow"}): manager = Mock() manager.decrypt_tool_parameters.return_value = {"query": "workflow-dec"} @@ -473,12 +483,23 @@ def test_get_workflow_runtime_apply_runtime_parameters(): app_id="app-1", node_id="node-1", workflow_tool=workflow_tool, + user_id="user-1", invoke_from=InvokeFrom.DEBUGGER, variable_pool=None, ) assert workflow_result is tool_runtime2 assert tool_runtime2.runtime.runtime_parameters["query"] == "workflow-dec" + mock_get_tool_runtime.assert_called_once_with( + provider_type=ToolProviderType.API, + provider_id="api-1", + tool_name="search", + tenant_id="tenant-1", + user_id="user-1", + invoke_from=InvokeFrom.DEBUGGER, + tool_invoke_from=ToolInvokeFrom.WORKFLOW, + credential_id=None, + ) def test_get_agent_runtime_raises_when_runtime_missing(): @@ -520,17 +541,28 @@ def test_get_tool_runtime_from_plugin_only_uses_form_parameters(): tool_entity = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={})) tool_entity.get_merged_runtime_parameters = Mock(return_value=[form_param, llm_param]) - with patch.object(ToolManager, "get_tool_runtime", return_value=tool_entity): + with patch.object(ToolManager, "get_tool_runtime", return_value=tool_entity) as mock_get_tool_runtime: result = ToolManager.get_tool_runtime_from_plugin( tool_type=ToolProviderType.API, tenant_id="tenant-1", provider="api-1", tool_name="search", tool_parameters={"q": "hello", "llm": "ignore"}, + user_id="user-1", ) assert result is tool_entity assert tool_entity.runtime.runtime_parameters == {"q": "hello"} + mock_get_tool_runtime.assert_called_once_with( + provider_type=ToolProviderType.API, + provider_id="api-1", + tool_name="search", + tenant_id="tenant-1", + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + tool_invoke_from=ToolInvokeFrom.PLUGIN, + credential_id=None, + ) def test_hardcoded_provider_icon_success(): @@ -664,12 +696,10 @@ def test_get_api_provider_controller_returns_controller_and_credentials(): privacy_policy="privacy", custom_disclaimer="disclaimer", ) - db_query = Mock() - db_query.where.return_value.first.return_value = provider controller = Mock() with patch("core.tools.tool_manager.db") as mock_db: - mock_db.session.query.return_value = db_query + mock_db.session.scalar.return_value = provider with patch( "core.tools.tool_manager.ApiToolProviderController.from_db", return_value=controller ) as mock_from_db: @@ -696,12 +726,10 @@ def test_user_get_api_provider_masks_credentials_and_adds_labels(): privacy_policy="privacy", custom_disclaimer="disclaimer", ) - db_query = Mock() - db_query.where.return_value.first.return_value = provider controller = Mock() with patch("core.tools.tool_manager.db") as mock_db: - mock_db.session.query.return_value = db_query + mock_db.session.scalar.return_value = provider with patch("core.tools.tool_manager.ApiToolProviderController.from_db", return_value=controller): encrypter = Mock() encrypter.decrypt.return_value = {"api_key_value": "secret"} @@ -716,7 +744,7 @@ def test_user_get_api_provider_masks_credentials_and_adds_labels(): def test_get_api_provider_controller_not_found_raises(): with patch("core.tools.tool_manager.db") as mock_db: - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with pytest.raises(ToolProviderNotFoundError, match="api provider missing not found"): ToolManager.get_api_provider_controller("tenant-1", "missing") @@ -775,14 +803,14 @@ def test_generate_tool_icon_urls_for_workflow_and_api(): workflow_provider = SimpleNamespace(icon='{"background": "#222", "content": "W"}') api_provider = SimpleNamespace(icon='{"background": "#333", "content": "A"}') with patch("core.tools.tool_manager.db") as mock_db: - mock_db.session.query.return_value.where.return_value.first.side_effect = [workflow_provider, api_provider] + mock_db.session.scalar.side_effect = [workflow_provider, api_provider] assert ToolManager.generate_workflow_tool_icon_url("tenant-1", "wf-1") == {"background": "#222", "content": "W"} assert ToolManager.generate_api_tool_icon_url("tenant-1", "api-1") == {"background": "#333", "content": "A"} def test_generate_tool_icon_urls_missing_workflow_and_api_use_default(): with patch("core.tools.tool_manager.db") as mock_db: - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None assert ToolManager.generate_workflow_tool_icon_url("tenant-1", "missing")["background"] == "#252525" assert ToolManager.generate_api_tool_icon_url("tenant-1", "missing")["background"] == "#252525" diff --git a/api/tests/unit_tests/core/tools/utils/test_configuration.py b/api/tests/unit_tests/core/tools/utils/test_configuration.py index 5ceaa08893b..ae5638784c4 100644 --- a/api/tests/unit_tests/core/tools/utils/test_configuration.py +++ b/api/tests/unit_tests/core/tools/utils/test_configuration.py @@ -110,7 +110,7 @@ def test_encrypt_tool_parameters(): assert encrypted["plain"] == "x" -def test_decrypt_tool_parameters_cache_hit_and_miss(): +def test_decrypt_tool_parameters_cache_hit_and_miss(monkeypatch): manager = _build_manager() with ( @@ -139,7 +139,7 @@ def test_delete_tool_parameters_cache(): mock_delete.assert_called_once() -def test_configuration_manager_decrypt_suppresses_errors(): +def test_configuration_manager_decrypt_suppresses_errors(monkeypatch): manager = _build_manager() with ( patch.object(ToolParameterCache, "get", return_value=None), diff --git a/api/tests/unit_tests/core/tools/utils/test_message_transformer.py b/api/tests/unit_tests/core/tools/utils/test_message_transformer.py index af3cdddd5f8..6454a5bcd1f 100644 --- a/api/tests/unit_tests/core/tools/utils/test_message_transformer.py +++ b/api/tests/unit_tests/core/tools/utils/test_message_transformer.py @@ -84,3 +84,24 @@ def test_transform_tool_invoke_messages_mimetype_key_present_but_none(): # meta is preserved (still contains mime_type: None) assert "mime_type" in (o.meta or {}) assert o.meta["mime_type"] is None + assert o.meta["tool_file_id"] == "fake-tool-file-id" + + +def test_transform_tool_invoke_messages_parses_existing_tool_file_link_meta(): + msg = ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE_LINK, + message=ToolInvokeMessage.TextMessage(text="/files/tools/existing-tool-file.png"), + meta={}, + ) + + out = list( + mt.ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=_gen([msg]), + user_id="u1", + tenant_id="t1", + conversation_id="c1", + ) + ) + + assert len(out) == 1 + assert out[0].meta["tool_file_id"] == "existing-tool-file" diff --git a/api/tests/unit_tests/core/tools/utils/test_misc_utils_extra.py b/api/tests/unit_tests/core/tools/utils/test_misc_utils_extra.py index 4ce73272bf0..a93624123e2 100644 --- a/api/tests/unit_tests/core/tools/utils/test_misc_utils_extra.py +++ b/api/tests/unit_tests/core/tools/utils/test_misc_utils_extra.py @@ -263,7 +263,7 @@ def test_single_dataset_retriever_non_economy_run_sorts_context_and_resources(): ) db_session = Mock() db_session.scalar.side_effect = [dataset, lookup_doc_low, lookup_doc_high] - db_session.query.return_value.filter_by.return_value.first.return_value = dataset + db_session.get.return_value = dataset tool = SingleDatasetRetrieverTool( tenant_id="tenant-1", @@ -444,7 +444,7 @@ def test_multi_dataset_retriever_run_orders_segments_and_returns_resources(): ) db_session = Mock() db_session.scalars.return_value.all.return_value = [segment_for_node_2, segment_for_node_1] - db_session.query.return_value.filter_by.return_value.first.side_effect = [ + db_session.get.side_effect = [ SimpleNamespace(id="dataset-2", name="Dataset Two"), SimpleNamespace(id="dataset-1", name="Dataset One"), ] diff --git a/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py index 2acae889b24..52f262e1cf1 100644 --- a/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py +++ b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py @@ -13,10 +13,8 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest - -from core.tools.utils.model_invocation_utils import InvokeModelError, ModelInvocationUtils -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey -from dify_graph.model_runtime.errors.invoke import ( +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, @@ -24,6 +22,8 @@ from dify_graph.model_runtime.errors.invoke import ( InvokeServerUnavailableError, ) +from core.tools.utils.model_invocation_utils import InvokeModelError, ModelInvocationUtils + def _mock_model_instance(*, schema: dict | None = None) -> SimpleNamespace: model_type_instance = Mock() @@ -60,20 +60,23 @@ def test_get_max_llm_context_tokens_branches(model_instance, expected, error_mat manager = Mock() manager.get_default_model_instance.return_value = model_instance - with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory: if error_match: with pytest.raises(InvokeModelError, match=error_match): - ModelInvocationUtils.get_max_llm_context_tokens("tenant") + ModelInvocationUtils.get_max_llm_context_tokens("tenant", user_id="user-1") else: - assert ModelInvocationUtils.get_max_llm_context_tokens("tenant") == expected + assert ModelInvocationUtils.get_max_llm_context_tokens("tenant", user_id="user-1") == expected + + mock_factory.assert_called_once_with(tenant_id="tenant", user_id="user-1") def test_calculate_tokens_handles_missing_model(): manager = Mock() manager.get_default_model_instance.return_value = None - with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory: with pytest.raises(InvokeModelError, match="Model not found"): ModelInvocationUtils.calculate_tokens("tenant", []) + mock_factory.assert_called_once_with(tenant_id="tenant", user_id=None) def test_invoke_success_and_error_mappings(): @@ -98,7 +101,7 @@ def test_invoke_success_and_error_mappings(): db_mock = SimpleNamespace(session=Mock()) - with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory: with patch("core.tools.utils.model_invocation_utils.ToolModelInvoke", _ToolModelInvoke): with patch("core.tools.utils.model_invocation_utils.db", db_mock): response = ModelInvocationUtils.invoke( @@ -107,11 +110,13 @@ def test_invoke_success_and_error_mappings(): tool_type="builtin", tool_name="tool-a", prompt_messages=[], + caller_user_id="caller-1", ) assert response.message.content == "ok" assert db_mock.session.add.call_count == 1 assert db_mock.session.commit.call_count == 2 + mock_factory.assert_called_once_with(tenant_id="tenant", user_id="caller-1") @pytest.mark.parametrize( @@ -145,7 +150,7 @@ def test_invoke_error_mappings(exc, expected): db_mock = SimpleNamespace(session=Mock()) - with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory: with patch("core.tools.utils.model_invocation_utils.ToolModelInvoke", _ToolModelInvoke): with patch("core.tools.utils.model_invocation_utils.db", db_mock): with pytest.raises(InvokeModelError, match=expected): @@ -156,3 +161,4 @@ def test_invoke_error_mappings(exc, expected): tool_name="tool-a", prompt_messages=[], ) + mock_factory.assert_called_once_with(tenant_id="tenant", user_id="u1") diff --git a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py index dd79b797186..0e3a7e623a8 100644 --- a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py +++ b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py @@ -1,9 +1,9 @@ import pytest +from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.tools.entities.tool_entities import ToolParameter, WorkflowToolParameterConfiguration from core.tools.errors import WorkflowToolHumanInputNotSupportedError from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType def test_ensure_no_human_input_nodes_passes_for_non_human_input(): diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py index dd140cbb276..2607861b59d 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py @@ -4,6 +4,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, Mock, patch import pytest +from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ( @@ -13,7 +14,6 @@ from core.tools.entities.tool_entities import ( ToolProviderType, ) from core.tools.workflow_as_tool.provider import WorkflowToolProviderController -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType def _controller() -> WorkflowToolProviderController: diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py index cc00f796980..c20edd74004 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py @@ -11,6 +11,7 @@ from typing import Any from unittest.mock import MagicMock, Mock, patch import pytest +from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod, FileType from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.__base.tool_runtime import ToolRuntime @@ -24,7 +25,6 @@ from core.tools.entities.tool_entities import ( ) from core.tools.errors import ToolInvokeError from core.tools.workflow_as_tool.tool import WorkflowTool -from dify_graph.file import FILE_MODEL_IDENTITY class StubScalars: @@ -439,6 +439,32 @@ def _setup_transform_args_tool(monkeypatch: pytest.MonkeyPatch) -> WorkflowTool: def test_transform_args_valid_files(monkeypatch: pytest.MonkeyPatch): """Transform args into parameters and files payloads.""" tool = _setup_transform_args_tool(monkeypatch) + build_file_from_stored_mapping = MagicMock( + side_effect=[ + SimpleNamespace( + transfer_method=FileTransferMethod.TOOL_FILE, + type=FileType.IMAGE, + reference="tool-1", + generate_url=lambda: None, + ), + SimpleNamespace( + transfer_method=FileTransferMethod.LOCAL_FILE, + type=FileType.DOCUMENT, + reference="upload-1", + generate_url=lambda: None, + ), + SimpleNamespace( + transfer_method=FileTransferMethod.REMOTE_URL, + type=FileType.DOCUMENT, + reference=None, + generate_url=lambda: "https://example.com/a.pdf", + ), + ] + ) + monkeypatch.setattr( + "core.tools.workflow_as_tool.tool.build_file_from_stored_mapping", + build_file_from_stored_mapping, + ) params, files = tool._transform_args( { @@ -470,6 +496,8 @@ def test_transform_args_valid_files(monkeypatch: pytest.MonkeyPatch): assert any(file_item.get("tool_file_id") == "tool-1" for file_item in files) assert any(file_item.get("upload_file_id") == "upload-1" for file_item in files) assert any(file_item.get("url") == "https://example.com/a.pdf" for file_item in files) + assert build_file_from_stored_mapping.call_count == 3 + assert all(call.kwargs["tenant_id"] == "test_tool" for call in build_file_from_stored_mapping.call_args_list) def test_transform_args_invalid_files(monkeypatch: pytest.MonkeyPatch): diff --git a/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py b/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py index bcb1d745e3a..78622b78b6b 100644 --- a/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py +++ b/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py @@ -11,6 +11,7 @@ from datetime import datetime from unittest.mock import MagicMock, patch import pytest +from graphon.enums import BuiltinNodeTypes, NodeType from core.plugin.entities.request import TriggerInvokeEventResponse from core.trigger.constants import ( @@ -26,7 +27,6 @@ from core.trigger.debug.event_selectors import ( select_trigger_debug_events, ) from core.trigger.debug.events import PluginTriggerDebugEvent, WebhookDebugEvent -from dify_graph.enums import BuiltinNodeTypes, NodeType from tests.unit_tests.core.trigger.conftest import VALID_PROVIDER_ID diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py index 91259c9a454..7406b88270b 100644 --- a/api/tests/unit_tests/core/variables/test_segment.py +++ b/api/tests/unit_tests/core/variables/test_segment.py @@ -2,14 +2,10 @@ import dataclasses import orjson import pytest -from pydantic import BaseModel - -from core.helper import encrypter -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.runtime import VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables.segment_group import SegmentGroup -from dify_graph.variables.segments import ( +from graphon.file import File, FileTransferMethod, FileType +from graphon.runtime import VariablePool +from graphon.variables.segment_group import SegmentGroup +from graphon.variables.segments import ( ArrayAnySegment, ArrayFileSegment, ArrayNumberSegment, @@ -25,13 +21,13 @@ from dify_graph.variables.segments import ( StringSegment, get_segment_discriminator, ) -from dify_graph.variables.types import SegmentType -from dify_graph.variables.utils import ( +from graphon.variables.types import SegmentType +from graphon.variables.utils import ( dumps_with_segments, segment_orjson_default, to_selector, ) -from dify_graph.variables.variables import ( +from graphon.variables.variables import ( ArrayAnyVariable, ArrayFileVariable, ArrayNumberVariable, @@ -46,16 +42,35 @@ from dify_graph.variables.variables import ( StringVariable, Variable, ) +from pydantic import BaseModel + +from core.helper import encrypter +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_variables_to_pool + + +def _build_variable_pool( + *, + system_variables: list[Variable] | None = None, + environment_variables: list[Variable] | None = None, +) -> VariablePool: + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_variables or [], + environment_variables=environment_variables or [], + ), + ) + return variable_pool def test_segment_group_to_text(): - variable_pool = VariablePool( - system_variables=SystemVariable(user_id="fake-user-id"), - user_inputs={}, + variable_pool = _build_variable_pool( + system_variables=build_system_variables(user_id="fake-user-id"), environment_variables=[ SecretVariable(name="secret_key", value="fake-secret-key"), ], - conversation_variables=[], ) variable_pool.add(("node_id", "custom_query"), "fake-user-query") template = ( @@ -71,11 +86,8 @@ def test_segment_group_to_text(): def test_convert_constant_to_segment_group(): - variable_pool = VariablePool( - system_variables=SystemVariable(user_id="1", app_id="1", workflow_id="1"), - user_inputs={}, - environment_variables=[], - conversation_variables=[], + variable_pool = _build_variable_pool( + system_variables=build_system_variables(user_id="1", app_id="1", workflow_id="1"), ) template = "Hello, world!" segments_group = variable_pool.convert_template(template) @@ -84,12 +96,7 @@ def test_convert_constant_to_segment_group(): def test_convert_variable_to_segment_group(): - variable_pool = VariablePool( - system_variables=SystemVariable(user_id="fake-user-id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[], - ) + variable_pool = _build_variable_pool(system_variables=build_system_variables(user_id="fake-user-id")) template = "{{#sys.user_id#}}" segments_group = variable_pool.convert_template(template) assert segments_group.text == "fake-user-id" @@ -116,7 +123,6 @@ def create_test_file( ) -> File: """Factory function to create File objects for testing""" return File( - tenant_id="test-tenant", type=file_type, transfer_method=transfer_method, filename=filename, @@ -190,7 +196,6 @@ class TestSegmentDumpAndLoad: loaded_file = loaded_segment.value assert isinstance(orig_file, File) assert isinstance(loaded_file, File) - assert loaded_file.tenant_id == orig_file.tenant_id assert loaded_file.type == orig_file.type assert loaded_file.filename == orig_file.filename else: @@ -234,7 +239,6 @@ class TestSegmentDumpAndLoad: loaded_file = loaded_variable.value assert isinstance(orig_file, File) assert isinstance(loaded_file, File) - assert loaded_file.tenant_id == orig_file.tenant_id assert loaded_file.type == orig_file.type assert loaded_file.filename == orig_file.filename else: diff --git a/api/tests/unit_tests/core/variables/test_segment_type.py b/api/tests/unit_tests/core/variables/test_segment_type.py index bb234d9bbd4..37ecd2890bb 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type.py +++ b/api/tests/unit_tests/core/variables/test_segment_type.py @@ -1,8 +1,7 @@ import pytest - -from dify_graph.variables.segment_group import SegmentGroup -from dify_graph.variables.segments import StringSegment -from dify_graph.variables.types import ArrayValidation, SegmentType +from graphon.variables.segment_group import SegmentGroup +from graphon.variables.segments import StringSegment +from graphon.variables.types import ArrayValidation, SegmentType class TestSegmentTypeIsArrayType: diff --git a/api/tests/unit_tests/core/variables/test_segment_type_validation.py b/api/tests/unit_tests/core/variables/test_segment_type_validation.py index 41ce4834476..09254e17a30 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type_validation.py +++ b/api/tests/unit_tests/core/variables/test_segment_type_validation.py @@ -9,11 +9,9 @@ from dataclasses import dataclass from typing import Any import pytest - -from dify_graph.file.enums import FileTransferMethod, FileType -from dify_graph.file.models import File -from dify_graph.variables.segment_group import SegmentGroup -from dify_graph.variables.segments import ( +from graphon.file import File, FileTransferMethod, FileType +from graphon.variables.segment_group import SegmentGroup +from graphon.variables.segments import ( ArrayFileSegment, BooleanSegment, FileSegment, @@ -22,7 +20,7 @@ from dify_graph.variables.segments import ( ObjectSegment, StringSegment, ) -from dify_graph.variables.types import ArrayValidation, SegmentType +from graphon.variables.types import ArrayValidation, SegmentType def create_test_file( diff --git a/api/tests/unit_tests/core/variables/test_variables.py b/api/tests/unit_tests/core/variables/test_variables.py index dd0fe2e65a7..75b01bf42e9 100644 --- a/api/tests/unit_tests/core/variables/test_variables.py +++ b/api/tests/unit_tests/core/variables/test_variables.py @@ -1,7 +1,5 @@ import pytest -from pydantic import ValidationError - -from dify_graph.variables import ( +from graphon.variables import ( ArrayFileVariable, ArrayVariable, FloatVariable, @@ -11,7 +9,8 @@ from dify_graph.variables import ( SegmentType, StringVariable, ) -from dify_graph.variables.variables import VariableBase +from graphon.variables.variables import VariableBase +from pydantic import ValidationError def test_frozen_variables(): diff --git a/api/tests/unit_tests/core/workflow/context/test_execution_context.py b/api/tests/unit_tests/core/workflow/context/test_execution_context.py index d09b8397c3b..3ce4bb753b9 100644 --- a/api/tests/unit_tests/core/workflow/context/test_execution_context.py +++ b/api/tests/unit_tests/core/workflow/context/test_execution_context.py @@ -9,7 +9,7 @@ from unittest.mock import MagicMock import pytest from pydantic import BaseModel -from dify_graph.context.execution_context import ( +from context.execution_context import ( AppContext, ExecutionContext, ExecutionContextBuilder, @@ -286,7 +286,7 @@ class TestCaptureCurrentContext: def test_capture_current_context_returns_context(self): """Test that capture_current_context returns a valid context.""" - from dify_graph.context.execution_context import capture_current_context + from context.execution_context import capture_current_context result = capture_current_context() @@ -303,7 +303,7 @@ class TestCaptureCurrentContext: test_var = contextvars.ContextVar("capture_test_var") test_var.set("test_value_123") - from dify_graph.context.execution_context import capture_current_context + from context.execution_context import capture_current_context result = capture_current_context() @@ -313,12 +313,12 @@ class TestCaptureCurrentContext: class TestTenantScopedContextRegistry: def setup_method(self): - from dify_graph.context import reset_context_provider + from context import reset_context_provider reset_context_provider() def teardown_method(self): - from dify_graph.context import reset_context_provider + from context import reset_context_provider reset_context_provider() @@ -333,7 +333,7 @@ class TestTenantScopedContextRegistry: assert read_context("workflow.sandbox", tenant_id="t2").base_url == "http://t2" def test_missing_provider_raises_keyerror(self): - from dify_graph.context import ContextProviderNotFoundError + from context import ContextProviderNotFoundError with pytest.raises(ContextProviderNotFoundError): read_context("missing", tenant_id="unknown") diff --git a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py deleted file mode 100644 index 22792eb5b3a..00000000000 --- a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py +++ /dev/null @@ -1,296 +0,0 @@ -import json -from time import time -from unittest.mock import MagicMock, patch - -import pytest - -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool -from dify_graph.variables.variables import StringVariable - - -class StubCoordinator: - def __init__(self) -> None: - self.state = "initial" - - def dumps(self) -> str: - return json.dumps({"state": self.state}) - - def loads(self, data: str) -> None: - payload = json.loads(data) - self.state = payload["state"] - - -class TestGraphRuntimeState: - def test_property_getters_and_setters(self): - # FIXME(-LAN-): Mock VariablePool if needed - variable_pool = VariablePool() - start_time = time() - - state = GraphRuntimeState(variable_pool=variable_pool, start_at=start_time) - - # Test variable_pool property (read-only) - assert state.variable_pool == variable_pool - - # Test start_at property - assert state.start_at == start_time - new_time = time() + 100 - state.start_at = new_time - assert state.start_at == new_time - - # Test total_tokens property - assert state.total_tokens == 0 - state.total_tokens = 100 - assert state.total_tokens == 100 - - # Test node_run_steps property - assert state.node_run_steps == 0 - state.node_run_steps = 5 - assert state.node_run_steps == 5 - - def test_outputs_immutability(self): - variable_pool = VariablePool() - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - - # Test that getting outputs returns a copy - outputs1 = state.outputs - outputs2 = state.outputs - assert outputs1 == outputs2 - assert outputs1 is not outputs2 # Different objects - - # Test that modifying retrieved outputs doesn't affect internal state - outputs = state.outputs - outputs["test"] = "value" - assert "test" not in state.outputs - - # Test set_output method - state.set_output("key1", "value1") - assert state.get_output("key1") == "value1" - - # Test update_outputs method - state.update_outputs({"key2": "value2", "key3": "value3"}) - assert state.get_output("key2") == "value2" - assert state.get_output("key3") == "value3" - - def test_llm_usage_immutability(self): - variable_pool = VariablePool() - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - - # Test that getting llm_usage returns a copy - usage1 = state.llm_usage - usage2 = state.llm_usage - assert usage1 is not usage2 # Different objects - - def test_type_validation(self): - variable_pool = VariablePool() - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - - # Test total_tokens validation - with pytest.raises(ValueError): - state.total_tokens = -1 - - # Test node_run_steps validation - with pytest.raises(ValueError): - state.node_run_steps = -1 - - def test_helper_methods(self): - variable_pool = VariablePool() - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - - # Test increment_node_run_steps - initial_steps = state.node_run_steps - state.increment_node_run_steps() - assert state.node_run_steps == initial_steps + 1 - - # Test add_tokens - initial_tokens = state.total_tokens - state.add_tokens(50) - assert state.total_tokens == initial_tokens + 50 - - # Test add_tokens validation - with pytest.raises(ValueError): - state.add_tokens(-1) - - def test_ready_queue_default_instantiation(self): - state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time()) - - queue = state.ready_queue - - from dify_graph.graph_engine.ready_queue import InMemoryReadyQueue - - assert isinstance(queue, InMemoryReadyQueue) - - def test_graph_execution_lazy_instantiation(self): - state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time()) - - execution = state.graph_execution - - from dify_graph.graph_engine.domain.graph_execution import GraphExecution - - assert isinstance(execution, GraphExecution) - assert execution.workflow_id == "" - assert state.graph_execution is execution - - def test_response_coordinator_configuration(self): - variable_pool = VariablePool() - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - - with pytest.raises(ValueError): - _ = state.response_coordinator - - mock_graph = MagicMock() - with patch( - "dify_graph.graph_engine.response_coordinator.ResponseStreamCoordinator", autospec=True - ) as coordinator_cls: - coordinator_instance = coordinator_cls.return_value - state.configure(graph=mock_graph) - - assert state.response_coordinator is coordinator_instance - coordinator_cls.assert_called_once_with(variable_pool=variable_pool, graph=mock_graph) - - # Configure again with same graph should be idempotent - state.configure(graph=mock_graph) - - other_graph = MagicMock() - with pytest.raises(ValueError): - state.attach_graph(other_graph) - - def test_read_only_wrapper_exposes_additional_state(self): - state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time()) - state.configure() - - wrapper = ReadOnlyGraphRuntimeStateWrapper(state) - - assert wrapper.ready_queue_size == 0 - assert wrapper.exceptions_count == 0 - - def test_read_only_wrapper_serializes_runtime_state(self): - state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time()) - state.total_tokens = 5 - state.set_output("result", {"success": True}) - state.ready_queue.put("node-1") - - wrapper = ReadOnlyGraphRuntimeStateWrapper(state) - - wrapper_snapshot = json.loads(wrapper.dumps()) - state_snapshot = json.loads(state.dumps()) - - assert wrapper_snapshot == state_snapshot - - def test_dumps_and_loads_roundtrip_with_response_coordinator(self): - variable_pool = VariablePool() - variable_pool.add(("node1", "value"), "payload") - - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - state.total_tokens = 10 - state.node_run_steps = 3 - state.set_output("final", {"result": True}) - usage = LLMUsage.from_metadata( - { - "prompt_tokens": 2, - "completion_tokens": 3, - "total_tokens": 5, - "total_price": "1.23", - "currency": "USD", - "latency": 0.5, - } - ) - state.llm_usage = usage - state.ready_queue.put("node-A") - - graph_execution = state.graph_execution - graph_execution.workflow_id = "wf-123" - graph_execution.exceptions_count = 4 - graph_execution.started = True - - mock_graph = MagicMock() - stub = StubCoordinator() - with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=stub, autospec=True): - state.attach_graph(mock_graph) - - stub.state = "configured" - - snapshot = state.dumps() - - restored = GraphRuntimeState.from_snapshot(snapshot) - - assert restored.total_tokens == 10 - assert restored.node_run_steps == 3 - assert restored.get_output("final") == {"result": True} - assert restored.llm_usage.total_tokens == usage.total_tokens - assert restored.ready_queue.qsize() == 1 - assert restored.ready_queue.get(timeout=0.01) == "node-A" - - restored_segment = restored.variable_pool.get(("node1", "value")) - assert restored_segment is not None - assert restored_segment.value == "payload" - - restored_execution = restored.graph_execution - assert restored_execution.workflow_id == "wf-123" - assert restored_execution.exceptions_count == 4 - assert restored_execution.started is True - - new_stub = StubCoordinator() - with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub, autospec=True): - restored.attach_graph(mock_graph) - - assert new_stub.state == "configured" - - def test_loads_rehydrates_existing_instance(self): - variable_pool = VariablePool() - variable_pool.add(("node", "key"), "value") - - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - state.total_tokens = 7 - state.node_run_steps = 2 - state.set_output("foo", "bar") - state.ready_queue.put("node-1") - - execution = state.graph_execution - execution.workflow_id = "wf-456" - execution.started = True - - mock_graph = MagicMock() - original_stub = StubCoordinator() - with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=original_stub, autospec=True): - state.attach_graph(mock_graph) - - original_stub.state = "configured" - snapshot = state.dumps() - - new_stub = StubCoordinator() - with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub, autospec=True): - restored = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0) - restored.attach_graph(mock_graph) - restored.loads(snapshot) - - assert restored.total_tokens == 7 - assert restored.node_run_steps == 2 - assert restored.get_output("foo") == "bar" - assert restored.ready_queue.qsize() == 1 - assert restored.ready_queue.get(timeout=0.01) == "node-1" - - restored_segment = restored.variable_pool.get(("node", "key")) - assert restored_segment is not None - assert restored_segment.value == "value" - - restored_execution = restored.graph_execution - assert restored_execution.workflow_id == "wf-456" - assert restored_execution.started is True - - assert new_stub.state == "configured" - - def test_snapshot_restore_preserves_updated_conversation_variable(self): - variable_pool = VariablePool( - conversation_variables=[StringVariable(name="session_name", value="before")], - ) - variable_pool.add((CONVERSATION_VARIABLE_NODE_ID, "session_name"), "after") - - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - snapshot = state.dumps() - restored = GraphRuntimeState.from_snapshot(snapshot) - - restored_value = restored.variable_pool.get((CONVERSATION_VARIABLE_NODE_ID, "session_name")) - assert restored_value is not None - assert restored_value.value == "after" diff --git a/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py b/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py deleted file mode 100644 index 158f7018b51..00000000000 --- a/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py +++ /dev/null @@ -1,88 +0,0 @@ -""" -Tests for PauseReason discriminated union serialization/deserialization. -""" - -import pytest -from pydantic import BaseModel, ValidationError - -from dify_graph.entities.pause_reason import ( - HumanInputRequired, - PauseReason, - SchedulingPause, -) - - -class _Holder(BaseModel): - """Helper model that embeds PauseReason for union tests.""" - - reason: PauseReason - - -class TestPauseReasonDiscriminator: - """Test suite for PauseReason union discriminator.""" - - @pytest.mark.parametrize( - ("dict_value", "expected"), - [ - pytest.param( - { - "reason": { - "TYPE": "human_input_required", - "form_id": "form_id", - "form_content": "form_content", - "node_id": "node_id", - "node_title": "node_title", - }, - }, - HumanInputRequired( - form_id="form_id", - form_content="form_content", - node_id="node_id", - node_title="node_title", - ), - id="HumanInputRequired", - ), - pytest.param( - { - "reason": { - "TYPE": "scheduled_pause", - "message": "Hold on", - } - }, - SchedulingPause(message="Hold on"), - id="SchedulingPause", - ), - ], - ) - def test_model_validate(self, dict_value, expected): - """Ensure scheduled pause payloads with lowercase TYPE deserialize.""" - holder = _Holder.model_validate(dict_value) - - assert type(holder.reason) == type(expected) - assert holder.reason == expected - - @pytest.mark.parametrize( - "reason", - [ - HumanInputRequired( - form_id="form_id", - form_content="form_content", - node_id="node_id", - node_title="node_title", - ), - SchedulingPause(message="Hold on"), - ], - ids=lambda x: type(x).__name__, - ) - def test_model_construct(self, reason): - holder = _Holder(reason=reason) - assert holder.reason == reason - - def test_model_construct_with_invalid_type(self): - with pytest.raises(ValidationError): - holder = _Holder(reason=object()) # type: ignore - - def test_unknown_type_fails_validation(self): - """Unknown TYPE values should raise a validation error.""" - with pytest.raises(ValidationError): - _Holder.model_validate({"reason": {"TYPE": "UNKNOWN"}}) diff --git a/api/tests/unit_tests/core/workflow/entities/test_template.py b/api/tests/unit_tests/core/workflow/entities/test_template.py deleted file mode 100644 index 2d4c7f7b77e..00000000000 --- a/api/tests/unit_tests/core/workflow/entities/test_template.py +++ /dev/null @@ -1,87 +0,0 @@ -"""Tests for template module.""" - -from dify_graph.nodes.base.template import Template, TextSegment, VariableSegment - - -class TestTemplate: - """Test Template class functionality.""" - - def test_from_answer_template_simple(self): - """Test parsing a simple answer template.""" - template_str = "Hello, {{#node1.name#}}!" - template = Template.from_answer_template(template_str) - - assert len(template.segments) == 3 - assert isinstance(template.segments[0], TextSegment) - assert template.segments[0].text == "Hello, " - assert isinstance(template.segments[1], VariableSegment) - assert template.segments[1].selector == ["node1", "name"] - assert isinstance(template.segments[2], TextSegment) - assert template.segments[2].text == "!" - - def test_from_answer_template_multiple_vars(self): - """Test parsing an answer template with multiple variables.""" - template_str = "Hello {{#node1.name#}}, your age is {{#node2.age#}}." - template = Template.from_answer_template(template_str) - - assert len(template.segments) == 5 - assert isinstance(template.segments[0], TextSegment) - assert template.segments[0].text == "Hello " - assert isinstance(template.segments[1], VariableSegment) - assert template.segments[1].selector == ["node1", "name"] - assert isinstance(template.segments[2], TextSegment) - assert template.segments[2].text == ", your age is " - assert isinstance(template.segments[3], VariableSegment) - assert template.segments[3].selector == ["node2", "age"] - assert isinstance(template.segments[4], TextSegment) - assert template.segments[4].text == "." - - def test_from_answer_template_no_vars(self): - """Test parsing an answer template with no variables.""" - template_str = "Hello, world!" - template = Template.from_answer_template(template_str) - - assert len(template.segments) == 1 - assert isinstance(template.segments[0], TextSegment) - assert template.segments[0].text == "Hello, world!" - - def test_from_end_outputs_single(self): - """Test creating template from End node outputs with single variable.""" - outputs_config = [{"variable": "text", "value_selector": ["node1", "text"]}] - template = Template.from_end_outputs(outputs_config) - - assert len(template.segments) == 1 - assert isinstance(template.segments[0], VariableSegment) - assert template.segments[0].selector == ["node1", "text"] - - def test_from_end_outputs_multiple(self): - """Test creating template from End node outputs with multiple variables.""" - outputs_config = [ - {"variable": "text", "value_selector": ["node1", "text"]}, - {"variable": "result", "value_selector": ["node2", "result"]}, - ] - template = Template.from_end_outputs(outputs_config) - - assert len(template.segments) == 3 - assert isinstance(template.segments[0], VariableSegment) - assert template.segments[0].selector == ["node1", "text"] - assert template.segments[0].variable_name == "text" - assert isinstance(template.segments[1], TextSegment) - assert template.segments[1].text == "\n" - assert isinstance(template.segments[2], VariableSegment) - assert template.segments[2].selector == ["node2", "result"] - assert template.segments[2].variable_name == "result" - - def test_from_end_outputs_empty(self): - """Test creating template from empty End node outputs.""" - outputs_config = [] - template = Template.from_end_outputs(outputs_config) - - assert len(template.segments) == 0 - - def test_template_str_representation(self): - """Test string representation of template.""" - template_str = "Hello, {{#node1.name#}}!" - template = Template.from_answer_template(template_str) - - assert str(template) == template_str diff --git a/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py b/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py deleted file mode 100644 index 6100ebede51..00000000000 --- a/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py +++ /dev/null @@ -1,136 +0,0 @@ -from dify_graph.runtime import VariablePool -from dify_graph.variables.segments import ( - BooleanSegment, - IntegerSegment, - NoneSegment, - StringSegment, -) - - -class TestVariablePoolGetAndNestedAttribute: - # - # _get_nested_attribute tests - # - def test__get_nested_attribute_existing_key(self): - pool = VariablePool.empty() - obj = {"a": 123} - segment = pool._get_nested_attribute(obj, "a") - assert segment is not None - assert segment.value == 123 - - def test__get_nested_attribute_missing_key(self): - pool = VariablePool.empty() - obj = {"a": 123} - segment = pool._get_nested_attribute(obj, "b") - assert segment is None - - def test__get_nested_attribute_non_dict(self): - pool = VariablePool.empty() - obj = ["not", "a", "dict"] - segment = pool._get_nested_attribute(obj, "a") - assert segment is None - - def test__get_nested_attribute_with_none_value(self): - pool = VariablePool.empty() - obj = {"a": None} - segment = pool._get_nested_attribute(obj, "a") - assert segment is not None - assert isinstance(segment, NoneSegment) - - def test__get_nested_attribute_with_empty_string(self): - pool = VariablePool.empty() - obj = {"a": ""} - segment = pool._get_nested_attribute(obj, "a") - assert segment is not None - assert isinstance(segment, StringSegment) - assert segment.value == "" - - # - # get tests - # - def test_get_simple_variable(self): - pool = VariablePool.empty() - pool.add(("node1", "var1"), "value1") - segment = pool.get(("node1", "var1")) - assert segment is not None - assert segment.value == "value1" - - def test_get_missing_variable(self): - pool = VariablePool.empty() - result = pool.get(("node1", "unknown")) - assert result is None - - def test_get_with_too_short_selector(self): - pool = VariablePool.empty() - result = pool.get(("only_node",)) - assert result is None - - def test_get_nested_object_attribute(self): - pool = VariablePool.empty() - obj_value = {"inner": "hello"} - pool.add(("node1", "obj"), obj_value) - - # simulate selector with nested attr - segment = pool.get(("node1", "obj", "inner")) - assert segment is not None - assert segment.value == "hello" - - def test_get_nested_object_missing_attribute(self): - pool = VariablePool.empty() - obj_value = {"inner": "hello"} - pool.add(("node1", "obj"), obj_value) - - result = pool.get(("node1", "obj", "not_exist")) - assert result is None - - def test_get_nested_object_attribute_with_falsy_values(self): - pool = VariablePool.empty() - obj_value = { - "inner_none": None, - "inner_empty": "", - "inner_zero": 0, - "inner_false": False, - } - pool.add(("node1", "obj"), obj_value) - - segment_none = pool.get(("node1", "obj", "inner_none")) - assert segment_none is not None - assert isinstance(segment_none, NoneSegment) - - segment_empty = pool.get(("node1", "obj", "inner_empty")) - assert segment_empty is not None - assert isinstance(segment_empty, StringSegment) - assert segment_empty.value == "" - - segment_zero = pool.get(("node1", "obj", "inner_zero")) - assert segment_zero is not None - assert isinstance(segment_zero, IntegerSegment) - assert segment_zero.value == 0 - - segment_false = pool.get(("node1", "obj", "inner_false")) - assert segment_false is not None - assert isinstance(segment_false, BooleanSegment) - assert segment_false.value is False - - -class TestVariablePoolGetNotModifyVariableDictionary: - _NODE_ID = "start" - _VAR_NAME = "name" - - def test_convert_to_template_should_not_introduce_extra_keys(self): - pool = VariablePool.empty() - pool.add([self._NODE_ID, self._VAR_NAME], 0) - pool.convert_template("The start.name is {{#start.name#}}") - assert "The start" not in pool.variable_dictionary - - def test_get_should_not_modify_variable_dictionary(self): - pool = VariablePool.empty() - pool.get([self._NODE_ID, self._VAR_NAME]) - assert len(pool.variable_dictionary) == 1 # only contains `sys` node id - assert "start" not in pool.variable_dictionary - - pool = VariablePool.empty() - pool.add([self._NODE_ID, self._VAR_NAME], "Joe") - pool.get([self._NODE_ID, "count"]) - start_subdict = pool.variable_dictionary[self._NODE_ID] - assert "count" not in start_subdict diff --git a/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py b/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py deleted file mode 100644 index 216e64db8d9..00000000000 --- a/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py +++ /dev/null @@ -1,225 +0,0 @@ -""" -Unit tests for WorkflowNodeExecution domain model, focusing on process_data truncation functionality. -""" - -from dataclasses import dataclass -from datetime import datetime -from typing import Any - -import pytest - -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecution -from dify_graph.enums import BuiltinNodeTypes - - -class TestWorkflowNodeExecutionProcessDataTruncation: - """Test process_data truncation functionality in WorkflowNodeExecution domain model.""" - - def create_workflow_node_execution( - self, - process_data: dict[str, Any] | None = None, - ) -> WorkflowNodeExecution: - """Create a WorkflowNodeExecution instance for testing.""" - return WorkflowNodeExecution( - id="test-execution-id", - workflow_id="test-workflow-id", - index=1, - node_id="test-node-id", - node_type=BuiltinNodeTypes.LLM, - title="Test Node", - process_data=process_data, - created_at=datetime.now(), - ) - - def test_initial_process_data_truncated_state(self): - """Test that process_data_truncated returns False initially.""" - execution = self.create_workflow_node_execution() - - assert execution.process_data_truncated is False - assert execution.get_truncated_process_data() is None - - def test_set_and_get_truncated_process_data(self): - """Test setting and getting truncated process_data.""" - execution = self.create_workflow_node_execution() - test_truncated_data = {"truncated": True, "key": "value"} - - execution.set_truncated_process_data(test_truncated_data) - - assert execution.process_data_truncated is True - assert execution.get_truncated_process_data() == test_truncated_data - - def test_set_truncated_process_data_to_none(self): - """Test setting truncated process_data to None.""" - execution = self.create_workflow_node_execution() - - # First set some data - execution.set_truncated_process_data({"key": "value"}) - assert execution.process_data_truncated is True - - # Then set to None - execution.set_truncated_process_data(None) - assert execution.process_data_truncated is False - assert execution.get_truncated_process_data() is None - - def test_get_response_process_data_with_no_truncation(self): - """Test get_response_process_data when no truncation is set.""" - original_data = {"original": True, "data": "value"} - execution = self.create_workflow_node_execution(process_data=original_data) - - response_data = execution.get_response_process_data() - - assert response_data == original_data - assert execution.process_data_truncated is False - - def test_get_response_process_data_with_truncation(self): - """Test get_response_process_data when truncation is set.""" - original_data = {"original": True, "large_data": "x" * 10000} - truncated_data = {"original": True, "large_data": "[TRUNCATED]"} - - execution = self.create_workflow_node_execution(process_data=original_data) - execution.set_truncated_process_data(truncated_data) - - response_data = execution.get_response_process_data() - - # Should return truncated data, not original - assert response_data == truncated_data - assert response_data != original_data - assert execution.process_data_truncated is True - - def test_get_response_process_data_with_none_process_data(self): - """Test get_response_process_data when process_data is None.""" - execution = self.create_workflow_node_execution(process_data=None) - - response_data = execution.get_response_process_data() - - assert response_data is None - assert execution.process_data_truncated is False - - def test_consistency_with_inputs_outputs_pattern(self): - """Test that process_data truncation follows the same pattern as inputs/outputs.""" - execution = self.create_workflow_node_execution() - - # Test that all truncation methods exist and behave consistently - test_data = {"test": "data"} - - # Test inputs truncation - execution.set_truncated_inputs(test_data) - assert execution.inputs_truncated is True - assert execution.get_truncated_inputs() == test_data - - # Test outputs truncation - execution.set_truncated_outputs(test_data) - assert execution.outputs_truncated is True - assert execution.get_truncated_outputs() == test_data - - # Test process_data truncation - execution.set_truncated_process_data(test_data) - assert execution.process_data_truncated is True - assert execution.get_truncated_process_data() == test_data - - @pytest.mark.parametrize( - "test_data", - [ - {"simple": "value"}, - {"nested": {"key": "value"}}, - {"list": [1, 2, 3]}, - {"mixed": {"string": "value", "number": 42, "list": [1, 2]}}, - {}, # empty dict - ], - ) - def test_truncated_process_data_with_various_data_types(self, test_data): - """Test that truncated process_data works with various data types.""" - execution = self.create_workflow_node_execution() - - execution.set_truncated_process_data(test_data) - - assert execution.process_data_truncated is True - assert execution.get_truncated_process_data() == test_data - assert execution.get_response_process_data() == test_data - - -@dataclass -class ProcessDataScenario: - """Test scenario data for process_data functionality.""" - - name: str - original_data: dict[str, Any] | None - truncated_data: dict[str, Any] | None - expected_truncated_flag: bool - expected_response_data: dict[str, Any] | None - - -class TestWorkflowNodeExecutionProcessDataScenarios: - """Test various scenarios for process_data handling.""" - - def get_process_data_scenarios(self) -> list[ProcessDataScenario]: - """Create test scenarios for process_data functionality.""" - return [ - ProcessDataScenario( - name="no_process_data", - original_data=None, - truncated_data=None, - expected_truncated_flag=False, - expected_response_data=None, - ), - ProcessDataScenario( - name="process_data_without_truncation", - original_data={"small": "data"}, - truncated_data=None, - expected_truncated_flag=False, - expected_response_data={"small": "data"}, - ), - ProcessDataScenario( - name="process_data_with_truncation", - original_data={"large": "x" * 10000, "metadata": "info"}, - truncated_data={"large": "[TRUNCATED]", "metadata": "info"}, - expected_truncated_flag=True, - expected_response_data={"large": "[TRUNCATED]", "metadata": "info"}, - ), - ProcessDataScenario( - name="empty_process_data", - original_data={}, - truncated_data=None, - expected_truncated_flag=False, - expected_response_data={}, - ), - ProcessDataScenario( - name="complex_nested_data_with_truncation", - original_data={ - "config": {"setting": "value"}, - "logs": ["log1", "log2"] * 1000, # Large list - "status": "running", - }, - truncated_data={"config": {"setting": "value"}, "logs": "[TRUNCATED: 2000 items]", "status": "running"}, - expected_truncated_flag=True, - expected_response_data={ - "config": {"setting": "value"}, - "logs": "[TRUNCATED: 2000 items]", - "status": "running", - }, - ), - ] - - @pytest.mark.parametrize( - "scenario", - get_process_data_scenarios(None), - ids=[scenario.name for scenario in get_process_data_scenarios(None)], - ) - def test_process_data_scenarios(self, scenario: ProcessDataScenario): - """Test various process_data scenarios.""" - execution = WorkflowNodeExecution( - id="test-execution-id", - workflow_id="test-workflow-id", - index=1, - node_id="test-node-id", - node_type=BuiltinNodeTypes.LLM, - title="Test Node", - process_data=scenario.original_data, - created_at=datetime.now(), - ) - - if scenario.truncated_data is not None: - execution.set_truncated_process_data(scenario.truncated_data) - - assert execution.process_data_truncated == scenario.expected_truncated_flag - assert execution.get_response_process_data() == scenario.expected_response_data diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph.py b/api/tests/unit_tests/core/workflow/graph/test_graph.py deleted file mode 100644 index 24bd9ccbed6..00000000000 --- a/api/tests/unit_tests/core/workflow/graph/test_graph.py +++ /dev/null @@ -1,281 +0,0 @@ -"""Unit tests for Graph class methods.""" - -from unittest.mock import Mock - -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeState -from dify_graph.graph.edge import Edge -from dify_graph.graph.graph import Graph -from dify_graph.nodes.base.node import Node - - -def create_mock_node(node_id: str, execution_type: NodeExecutionType, state: NodeState = NodeState.UNKNOWN) -> Node: - """Create a mock node for testing.""" - node = Mock(spec=Node) - node.id = node_id - node.execution_type = execution_type - node.state = state - node.node_type = BuiltinNodeTypes.START - return node - - -class TestMarkInactiveRootBranches: - """Test cases for _mark_inactive_root_branches method.""" - - def test_single_root_no_marking(self): - """Test that single root graph doesn't mark anything as skipped.""" - nodes = { - "root1": create_mock_node("root1", NodeExecutionType.ROOT), - "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE), - } - - edges = { - "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"), - } - - in_edges = {"child1": ["edge1"]} - out_edges = {"root1": ["edge1"]} - - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") - - assert nodes["root1"].state == NodeState.UNKNOWN - assert nodes["child1"].state == NodeState.UNKNOWN - assert edges["edge1"].state == NodeState.UNKNOWN - - def test_multiple_roots_mark_inactive(self): - """Test marking inactive root branches with multiple root nodes.""" - nodes = { - "root1": create_mock_node("root1", NodeExecutionType.ROOT), - "root2": create_mock_node("root2", NodeExecutionType.ROOT), - "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE), - "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE), - } - - edges = { - "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"), - "edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"), - } - - in_edges = {"child1": ["edge1"], "child2": ["edge2"]} - out_edges = {"root1": ["edge1"], "root2": ["edge2"]} - - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") - - assert nodes["root1"].state == NodeState.UNKNOWN - assert nodes["root2"].state == NodeState.SKIPPED - assert nodes["child1"].state == NodeState.UNKNOWN - assert nodes["child2"].state == NodeState.SKIPPED - assert edges["edge1"].state == NodeState.UNKNOWN - assert edges["edge2"].state == NodeState.SKIPPED - - def test_shared_downstream_node(self): - """Test that shared downstream nodes are not skipped if at least one path is active.""" - nodes = { - "root1": create_mock_node("root1", NodeExecutionType.ROOT), - "root2": create_mock_node("root2", NodeExecutionType.ROOT), - "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE), - "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE), - "shared": create_mock_node("shared", NodeExecutionType.EXECUTABLE), - } - - edges = { - "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"), - "edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"), - "edge3": Edge(id="edge3", tail="child1", head="shared", source_handle="source"), - "edge4": Edge(id="edge4", tail="child2", head="shared", source_handle="source"), - } - - in_edges = { - "child1": ["edge1"], - "child2": ["edge2"], - "shared": ["edge3", "edge4"], - } - out_edges = { - "root1": ["edge1"], - "root2": ["edge2"], - "child1": ["edge3"], - "child2": ["edge4"], - } - - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") - - assert nodes["root1"].state == NodeState.UNKNOWN - assert nodes["root2"].state == NodeState.SKIPPED - assert nodes["child1"].state == NodeState.UNKNOWN - assert nodes["child2"].state == NodeState.SKIPPED - assert nodes["shared"].state == NodeState.UNKNOWN # Not skipped because edge3 is active - assert edges["edge1"].state == NodeState.UNKNOWN - assert edges["edge2"].state == NodeState.SKIPPED - assert edges["edge3"].state == NodeState.UNKNOWN - assert edges["edge4"].state == NodeState.SKIPPED - - def test_deep_branch_marking(self): - """Test marking deep branches with multiple levels.""" - nodes = { - "root1": create_mock_node("root1", NodeExecutionType.ROOT), - "root2": create_mock_node("root2", NodeExecutionType.ROOT), - "level1_a": create_mock_node("level1_a", NodeExecutionType.EXECUTABLE), - "level1_b": create_mock_node("level1_b", NodeExecutionType.EXECUTABLE), - "level2_a": create_mock_node("level2_a", NodeExecutionType.EXECUTABLE), - "level2_b": create_mock_node("level2_b", NodeExecutionType.EXECUTABLE), - "level3": create_mock_node("level3", NodeExecutionType.EXECUTABLE), - } - - edges = { - "edge1": Edge(id="edge1", tail="root1", head="level1_a", source_handle="source"), - "edge2": Edge(id="edge2", tail="root2", head="level1_b", source_handle="source"), - "edge3": Edge(id="edge3", tail="level1_a", head="level2_a", source_handle="source"), - "edge4": Edge(id="edge4", tail="level1_b", head="level2_b", source_handle="source"), - "edge5": Edge(id="edge5", tail="level2_b", head="level3", source_handle="source"), - } - - in_edges = { - "level1_a": ["edge1"], - "level1_b": ["edge2"], - "level2_a": ["edge3"], - "level2_b": ["edge4"], - "level3": ["edge5"], - } - out_edges = { - "root1": ["edge1"], - "root2": ["edge2"], - "level1_a": ["edge3"], - "level1_b": ["edge4"], - "level2_b": ["edge5"], - } - - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") - - assert nodes["root1"].state == NodeState.UNKNOWN - assert nodes["root2"].state == NodeState.SKIPPED - assert nodes["level1_a"].state == NodeState.UNKNOWN - assert nodes["level1_b"].state == NodeState.SKIPPED - assert nodes["level2_a"].state == NodeState.UNKNOWN - assert nodes["level2_b"].state == NodeState.SKIPPED - assert nodes["level3"].state == NodeState.SKIPPED - assert edges["edge1"].state == NodeState.UNKNOWN - assert edges["edge2"].state == NodeState.SKIPPED - assert edges["edge3"].state == NodeState.UNKNOWN - assert edges["edge4"].state == NodeState.SKIPPED - assert edges["edge5"].state == NodeState.SKIPPED - - def test_non_root_execution_type(self): - """Test that nodes with non-ROOT execution type are not treated as root nodes.""" - nodes = { - "root1": create_mock_node("root1", NodeExecutionType.ROOT), - "non_root": create_mock_node("non_root", NodeExecutionType.EXECUTABLE), - "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE), - "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE), - } - - edges = { - "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"), - "edge2": Edge(id="edge2", tail="non_root", head="child2", source_handle="source"), - } - - in_edges = {"child1": ["edge1"], "child2": ["edge2"]} - out_edges = {"root1": ["edge1"], "non_root": ["edge2"]} - - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") - - assert nodes["root1"].state == NodeState.UNKNOWN - assert nodes["non_root"].state == NodeState.UNKNOWN # Not marked as skipped - assert nodes["child1"].state == NodeState.UNKNOWN - assert nodes["child2"].state == NodeState.UNKNOWN - assert edges["edge1"].state == NodeState.UNKNOWN - assert edges["edge2"].state == NodeState.UNKNOWN - - def test_empty_graph(self): - """Test handling of empty graph structures.""" - nodes = {} - edges = {} - in_edges = {} - out_edges = {} - - # Should not raise any errors - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "non_existent") - - def test_three_roots_mark_two_inactive(self): - """Test with three root nodes where two should be marked inactive.""" - nodes = { - "root1": create_mock_node("root1", NodeExecutionType.ROOT), - "root2": create_mock_node("root2", NodeExecutionType.ROOT), - "root3": create_mock_node("root3", NodeExecutionType.ROOT), - "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE), - "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE), - "child3": create_mock_node("child3", NodeExecutionType.EXECUTABLE), - } - - edges = { - "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"), - "edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"), - "edge3": Edge(id="edge3", tail="root3", head="child3", source_handle="source"), - } - - in_edges = { - "child1": ["edge1"], - "child2": ["edge2"], - "child3": ["edge3"], - } - out_edges = { - "root1": ["edge1"], - "root2": ["edge2"], - "root3": ["edge3"], - } - - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root2") - - assert nodes["root1"].state == NodeState.SKIPPED - assert nodes["root2"].state == NodeState.UNKNOWN # Active root - assert nodes["root3"].state == NodeState.SKIPPED - assert nodes["child1"].state == NodeState.SKIPPED - assert nodes["child2"].state == NodeState.UNKNOWN - assert nodes["child3"].state == NodeState.SKIPPED - assert edges["edge1"].state == NodeState.SKIPPED - assert edges["edge2"].state == NodeState.UNKNOWN - assert edges["edge3"].state == NodeState.SKIPPED - - def test_convergent_paths(self): - """Test convergent paths where multiple inactive branches lead to same node.""" - nodes = { - "root1": create_mock_node("root1", NodeExecutionType.ROOT), - "root2": create_mock_node("root2", NodeExecutionType.ROOT), - "root3": create_mock_node("root3", NodeExecutionType.ROOT), - "mid1": create_mock_node("mid1", NodeExecutionType.EXECUTABLE), - "mid2": create_mock_node("mid2", NodeExecutionType.EXECUTABLE), - "convergent": create_mock_node("convergent", NodeExecutionType.EXECUTABLE), - } - - edges = { - "edge1": Edge(id="edge1", tail="root1", head="mid1", source_handle="source"), - "edge2": Edge(id="edge2", tail="root2", head="mid2", source_handle="source"), - "edge3": Edge(id="edge3", tail="root3", head="convergent", source_handle="source"), - "edge4": Edge(id="edge4", tail="mid1", head="convergent", source_handle="source"), - "edge5": Edge(id="edge5", tail="mid2", head="convergent", source_handle="source"), - } - - in_edges = { - "mid1": ["edge1"], - "mid2": ["edge2"], - "convergent": ["edge3", "edge4", "edge5"], - } - out_edges = { - "root1": ["edge1"], - "root2": ["edge2"], - "root3": ["edge3"], - "mid1": ["edge4"], - "mid2": ["edge5"], - } - - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") - - assert nodes["root1"].state == NodeState.UNKNOWN - assert nodes["root2"].state == NodeState.SKIPPED - assert nodes["root3"].state == NodeState.SKIPPED - assert nodes["mid1"].state == NodeState.UNKNOWN - assert nodes["mid2"].state == NodeState.SKIPPED - assert nodes["convergent"].state == NodeState.UNKNOWN # Not skipped due to active path from root1 - assert edges["edge1"].state == NodeState.UNKNOWN - assert edges["edge2"].state == NodeState.SKIPPED - assert edges["edge3"].state == NodeState.SKIPPED - assert edges["edge4"].state == NodeState.UNKNOWN - assert edges["edge5"].state == NodeState.SKIPPED diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py b/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py deleted file mode 100644 index 64c2eee7766..00000000000 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py +++ /dev/null @@ -1,59 +0,0 @@ -from unittest.mock import MagicMock - -import pytest - -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.graph import Graph -from dify_graph.nodes.base.node import Node - - -def _make_node(node_id: str, node_type: NodeType = BuiltinNodeTypes.START) -> Node: - node = MagicMock(spec=Node) - node.id = node_id - node.node_type = node_type - node.execution_type = None # attribute not used in builder path - return node - - -def test_graph_builder_creates_linear_graph(): - builder = Graph.new() - root = _make_node("root", BuiltinNodeTypes.START) - mid = _make_node("mid", BuiltinNodeTypes.LLM) - end = _make_node("end", BuiltinNodeTypes.END) - - graph = builder.add_root(root).add_node(mid).add_node(end).build() - - assert graph.root_node is root - assert graph.nodes == {"root": root, "mid": mid, "end": end} - assert len(graph.edges) == 2 - first_edge = next(iter(graph.edges.values())) - assert first_edge.tail == "root" - assert first_edge.head == "mid" - assert graph.out_edges["mid"] == [edge_id for edge_id, edge in graph.edges.items() if edge.tail == "mid"] - - -def test_graph_builder_supports_custom_predecessor(): - builder = Graph.new() - root = _make_node("root") - branch = _make_node("branch") - other = _make_node("other") - - graph = builder.add_root(root).add_node(branch).add_node(other, from_node_id="root").build() - - outgoing_root = graph.out_edges["root"] - assert len(outgoing_root) == 2 - edge_targets = {graph.edges[eid].head for eid in outgoing_root} - assert edge_targets == {"branch", "other"} - - -def test_graph_builder_validates_usage(): - builder = Graph.new() - node = _make_node("node") - - with pytest.raises(ValueError, match="Root node"): - builder.add_node(node) - - builder.add_root(node) - duplicate = _make_node("node") - with pytest.raises(ValueError, match="Duplicate"): - builder.add_node(duplicate) diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py deleted file mode 100644 index 75de07bd8b8..00000000000 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py +++ /dev/null @@ -1,118 +0,0 @@ -from __future__ import annotations - -from typing import Any - -import pytest - -from core.workflow.node_factory import DifyNodeFactory -from dify_graph.graph import Graph -from dify_graph.graph.validation import GraphValidationError -from dify_graph.nodes import BuiltinNodeTypes -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from tests.workflow_test_utils import build_test_graph_init_params - - -def _build_iteration_graph(node_id: str) -> dict[str, Any]: - return { - "nodes": [ - { - "id": node_id, - "data": { - "type": "iteration", - "title": "Iteration", - "iterator_selector": ["start", "items"], - "output_selector": [node_id, "output"], - }, - } - ], - "edges": [], - } - - -def _build_loop_graph(node_id: str) -> dict[str, Any]: - return { - "nodes": [ - { - "id": node_id, - "data": { - "type": "loop", - "title": "Loop", - "loop_count": 1, - "break_conditions": [], - "logical_operator": "and", - "loop_variables": [], - "outputs": {}, - }, - } - ], - "edges": [], - } - - -def _make_factory(graph_config: dict[str, Any]) -> DifyNodeFactory: - graph_init_params = build_test_graph_init_params( - workflow_id="workflow", - graph_config=graph_config, - tenant_id="tenant", - app_id="app", - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool( - system_variables=SystemVariable.default(), - user_inputs={}, - environment_variables=[], - ), - start_at=0.0, - ) - return DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state) - - -def test_iteration_root_requires_skip_validation(): - node_id = "iteration-node" - graph_config = _build_iteration_graph(node_id) - node_factory = _make_factory(graph_config) - - with pytest.raises(GraphValidationError): - Graph.init( - graph_config=graph_config, - node_factory=node_factory, - root_node_id=node_id, - ) - - graph = Graph.init( - graph_config=graph_config, - node_factory=node_factory, - root_node_id=node_id, - skip_validation=True, - ) - - assert graph.root_node.id == node_id - assert graph.root_node.node_type == BuiltinNodeTypes.ITERATION - - -def test_loop_root_requires_skip_validation(): - node_id = "loop-node" - graph_config = _build_loop_graph(node_id) - node_factory = _make_factory(graph_config) - - with pytest.raises(GraphValidationError): - Graph.init( - graph_config=graph_config, - node_factory=node_factory, - root_node_id=node_id, - ) - - graph = Graph.init( - graph_config=graph_config, - node_factory=node_factory, - root_node_id=node_id, - skip_validation=True, - ) - - assert graph.root_node.id == node_id - assert graph.root_node.node_type == BuiltinNodeTypes.LOOP diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py deleted file mode 100644 index e94ad74eb01..00000000000 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py +++ /dev/null @@ -1,219 +0,0 @@ -from __future__ import annotations - -import time -from collections.abc import Mapping -from dataclasses import dataclass - -import pytest - -from dify_graph.entities import GraphInitParams -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, ErrorStrategy, NodeExecutionType, NodeType -from dify_graph.graph import Graph -from dify_graph.graph.validation import GraphValidationError -from dify_graph.nodes.base.node import Node -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from tests.workflow_test_utils import build_test_graph_init_params - - -class _TestNodeData(BaseNodeData): - type: NodeType | None = None - execution_type: NodeExecutionType | str | None = None - - -class _TestNode(Node[_TestNodeData]): - node_type = BuiltinNodeTypes.ANSWER - execution_type = NodeExecutionType.EXECUTABLE - - @classmethod - def version(cls) -> str: - return "1" - - def __init__( - self, - *, - id: str, - config: Mapping[str, object], - graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - node_type_value = self.data.get("type") - if isinstance(node_type_value, str): - self.node_type = node_type_value - - def _run(self): - raise NotImplementedError - - def post_init(self) -> None: - super().post_init() - self._maybe_override_execution_type() - self.data = dict(self.node_data.model_dump()) - - def _maybe_override_execution_type(self) -> None: - execution_type_value = self.node_data.execution_type - if execution_type_value is None: - return - if isinstance(execution_type_value, NodeExecutionType): - self.execution_type = execution_type_value - else: - self.execution_type = NodeExecutionType(execution_type_value) - - -@dataclass(slots=True) -class _SimpleNodeFactory: - graph_init_params: GraphInitParams - graph_runtime_state: GraphRuntimeState - - def create_node(self, node_config: Mapping[str, object]) -> _TestNode: - node_id = str(node_config["id"]) - node = _TestNode( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - ) - return node - - -@pytest.fixture -def graph_init_dependencies() -> tuple[_SimpleNodeFactory, dict[str, object]]: - graph_config: dict[str, object] = {"edges": [], "nodes": []} - init_params = build_test_graph_init_params( - workflow_id="workflow", - graph_config=graph_config, - tenant_id="tenant", - app_id="app", - user_id="user", - user_from="account", - invoke_from="service-api", - call_depth=0, - ) - variable_pool = VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}) - runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - factory = _SimpleNodeFactory(graph_init_params=init_params, graph_runtime_state=runtime_state) - return factory, graph_config - - -def test_graph_initialization_runs_default_validators( - graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], -): - node_factory, graph_config = graph_init_dependencies - graph_config["nodes"] = [ - { - "id": "start", - "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}, - }, - {"id": "answer", "data": {"type": BuiltinNodeTypes.ANSWER, "title": "Answer"}}, - ] - graph_config["edges"] = [ - {"source": "start", "target": "answer", "sourceHandle": "success"}, - ] - - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - assert graph.root_node.id == "start" - assert "answer" in graph.nodes - - -def test_graph_validation_fails_for_unknown_edge_targets( - graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], -) -> None: - node_factory, graph_config = graph_init_dependencies - graph_config["nodes"] = [ - { - "id": "start", - "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}, - }, - ] - graph_config["edges"] = [ - {"source": "start", "target": "missing", "sourceHandle": "success"}, - ] - - with pytest.raises(GraphValidationError) as exc: - Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - assert any(issue.code == "MISSING_NODE" for issue in exc.value.issues) - - -def test_graph_promotes_fail_branch_nodes_to_branch_execution_type( - graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], -) -> None: - node_factory, graph_config = graph_init_dependencies - graph_config["nodes"] = [ - { - "id": "start", - "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}, - }, - { - "id": "branch", - "data": { - "type": BuiltinNodeTypes.IF_ELSE, - "title": "Branch", - "error_strategy": ErrorStrategy.FAIL_BRANCH, - }, - }, - ] - graph_config["edges"] = [ - {"source": "start", "target": "branch", "sourceHandle": "success"}, - ] - - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - assert graph.nodes["branch"].execution_type == NodeExecutionType.BRANCH - - -def test_graph_init_ignores_custom_note_nodes_before_node_data_validation( - graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], -) -> None: - node_factory, graph_config = graph_init_dependencies - graph_config["nodes"] = [ - { - "id": "start", - "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}, - }, - {"id": "answer", "data": {"type": BuiltinNodeTypes.ANSWER, "title": "Answer"}}, - { - "id": "note", - "type": "custom-note", - "data": { - "type": "", - "title": "", - "desc": "", - "text": "{}", - "theme": "blue", - }, - }, - ] - graph_config["edges"] = [ - {"source": "start", "target": "answer", "sourceHandle": "success"}, - ] - - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - assert graph.root_node.id == "start" - assert "answer" in graph.nodes - assert "note" not in graph.nodes - - -def test_graph_init_fails_for_unknown_root_node_id( - graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], -) -> None: - node_factory, graph_config = graph_init_dependencies - graph_config["nodes"] = [ - { - "id": "start", - "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}, - }, - ] - graph_config["edges"] = [] - - with pytest.raises(ValueError, match="Root node id missing not found in the graph"): - Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="missing") diff --git a/api/tests/unit_tests/core/workflow/graph_engine/README.md b/api/tests/unit_tests/core/workflow/graph_engine/README.md index 40ed61eb029..dd419f0810f 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/README.md +++ b/api/tests/unit_tests/core/workflow/graph_engine/README.md @@ -1,441 +1,30 @@ -# Graph Engine Testing Framework +# Workflow Graph Engine Smoke Tests -## Overview +This directory now keeps only a small Dify-owned smoke layer around the external +`graphon` package. -This directory contains a comprehensive testing framework for the Graph Engine, including: +Retained coverage focuses on: -1. **TableTestRunner** - Advanced table-driven test framework for workflow testing -1. **Auto-Mock System** - Powerful mocking framework for testing without external dependencies +1. Dify workflow layers: + - `layers/test_llm_quota.py` + - `layers/test_observability.py` +2. Human-input resume integration: + - `test_parallel_human_input_join_resume.py` +3. One mocked tool/chatflow smoke path: + - `test_tool_in_chatflow.py` -## TableTestRunner Framework +The helper modules below remain only because the retained smoke tests use them: -The TableTestRunner (`test_table_runner.py`) provides a robust table-driven testing framework for GraphEngine workflows. +1. `test_mock_config.py` +2. `test_mock_factory.py` +3. `test_mock_nodes.py` +4. `test_table_runner.py` -### Features - -- **Table-driven testing** - Define test cases as structured data -- **Parallel test execution** - Run tests concurrently for faster execution -- **Property-based testing** - Integration with Hypothesis for fuzzing -- **Event sequence validation** - Verify correct event ordering -- **Mock configuration** - Seamless integration with the auto-mock system -- **Performance metrics** - Track execution times and bottlenecks -- **Detailed error reporting** - Comprehensive failure diagnostics - -### Basic Usage - -```python -from test_table_runner import TableTestRunner, WorkflowTestCase - -# Create test runner -runner = TableTestRunner() - -# Define test case -test_case = WorkflowTestCase( - fixture_path="simple_workflow", - inputs={"query": "Hello"}, - expected_outputs={"result": "World"}, - description="Basic workflow test", -) - -# Run single test -result = runner.run_test_case(test_case) -assert result.success -``` - -### Advanced Features - -#### Parallel Execution - -```python -runner = TableTestRunner(max_workers=8) - -test_cases = [ - WorkflowTestCase(...), - WorkflowTestCase(...), - # ... more test cases -] - -# Run tests in parallel -suite_result = runner.run_table_tests( - test_cases, - parallel=True, - fail_fast=False -) - -print(f"Success rate: {suite_result.success_rate:.1f}%") -``` - -#### Event Sequence Validation - -```python -from dify_graph.graph_events import ( - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, -) - -test_case = WorkflowTestCase( - fixture_path="workflow", - inputs={}, - expected_outputs={}, - expected_event_sequence=[ - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ] -) -``` - -### Test Suite Reports - -```python -# Run test suite -suite_result = runner.run_table_tests(test_cases) - -# Generate detailed report -report = runner.generate_report(suite_result) -print(report) - -# Access specific results -failed_results = suite_result.get_failed_results() -for result in failed_results: - print(f"Failed: {result.test_case.description}") - print(f" Error: {result.error}") -``` - -### Performance Testing - -```python -# Enable logging for performance insights -runner = TableTestRunner( - enable_logging=True, - log_level="DEBUG" -) - -# Run tests and analyze performance -suite_result = runner.run_table_tests(test_cases) - -# Get slowest tests -sorted_results = sorted( - suite_result.results, - key=lambda r: r.execution_time, - reverse=True -) - -print("Slowest tests:") -for result in sorted_results[:5]: - print(f" {result.test_case.description}: {result.execution_time:.2f}s") -``` - -## Integration: TableTestRunner + Auto-Mock System - -The TableTestRunner seamlessly integrates with the auto-mock system for comprehensive workflow testing: - -```python -from test_table_runner import TableTestRunner, WorkflowTestCase -from test_mock_config import MockConfigBuilder - -# Configure mocks -mock_config = (MockConfigBuilder() - .with_llm_response("Mocked LLM response") - .with_tool_response({"result": "mocked"}) - .with_delays(True) # Simulate realistic delays - .build()) - -# Create test case with mocking -test_case = WorkflowTestCase( - fixture_path="complex_workflow", - inputs={"query": "test"}, - expected_outputs={"answer": "Mocked LLM response"}, - use_auto_mock=True, # Enable auto-mocking - mock_config=mock_config, - description="Test with mocked services", -) - -# Run test -runner = TableTestRunner() -result = runner.run_test_case(test_case) -``` - -## Auto-Mock System - -The auto-mock system provides a powerful framework for testing workflows that contain nodes requiring third-party services (LLM, APIs, tools, etc.) without making actual external calls. This enables: - -- **Fast test execution** - No network latency or API rate limits -- **Deterministic results** - Consistent outputs for reliable testing -- **Cost savings** - No API usage charges during testing -- **Offline testing** - Tests can run without internet connectivity -- **Error simulation** - Test error handling without triggering real failures - -## Architecture - -The auto-mock system consists of three main components: - -### 1. MockNodeFactory (`test_mock_factory.py`) - -- Extends `DifyNodeFactory` to intercept node creation -- Automatically detects nodes requiring third-party services -- Returns mock node implementations instead of real ones -- Supports registration of custom mock implementations - -### 2. Mock Node Implementations (`test_mock_nodes.py`) - -- `MockLLMNode` - Mocks LLM API calls (OpenAI, Anthropic, etc.) -- `MockAgentNode` - Mocks agent execution -- `MockToolNode` - Mocks tool invocations -- `MockKnowledgeRetrievalNode` - Mocks knowledge base queries -- `MockHttpRequestNode` - Mocks HTTP requests -- `MockParameterExtractorNode` - Mocks parameter extraction -- `MockDocumentExtractorNode` - Mocks document processing -- `MockQuestionClassifierNode` - Mocks question classification - -### 3. Mock Configuration (`test_mock_config.py`) - -- `MockConfig` - Global configuration for mock behavior -- `NodeMockConfig` - Node-specific mock configuration -- `MockConfigBuilder` - Fluent interface for building configurations - -## Usage - -### Basic Example - -```python -from test_graph_engine import TableTestRunner, WorkflowTestCase -from test_mock_config import MockConfigBuilder - -# Create test runner -runner = TableTestRunner() - -# Configure mock responses -mock_config = (MockConfigBuilder() - .with_llm_response("Mocked LLM response") - .build()) - -# Define test case -test_case = WorkflowTestCase( - fixture_path="llm-simple", - inputs={"query": "Hello"}, - expected_outputs={"answer": "Mocked LLM response"}, - use_auto_mock=True, # Enable auto-mocking - mock_config=mock_config, -) - -# Run test -result = runner.run_test_case(test_case) -assert result.success -``` - -### Custom Node Outputs - -```python -# Configure specific outputs for individual nodes -mock_config = MockConfig() -mock_config.set_node_outputs("llm_node_123", { - "text": "Custom response for this specific node", - "usage": {"total_tokens": 50}, - "finish_reason": "stop", -}) -``` - -### Error Simulation - -```python -# Simulate node failures for error handling tests -mock_config = MockConfig() -mock_config.set_node_error("http_node", "Connection timeout") -``` - -### Simulated Delays - -```python -# Add realistic execution delays -from test_mock_config import NodeMockConfig - -node_config = NodeMockConfig( - node_id="llm_node", - outputs={"text": "Response"}, - delay=1.5, # 1.5 second delay -) -mock_config.set_node_config("llm_node", node_config) -``` - -### Custom Handlers - -```python -# Define custom logic for mock outputs -def custom_handler(node): - # Access node state and return dynamic outputs - return { - "text": f"Processed: {node.graph_runtime_state.variable_pool.get('query')}", - } - -node_config = NodeMockConfig( - node_id="llm_node", - custom_handler=custom_handler, -) -``` - -## Node Types Automatically Mocked - -The following node types are automatically mocked when `use_auto_mock=True`: - -- `LLM` - Language model nodes -- `AGENT` - Agent execution nodes -- `TOOL` - Tool invocation nodes -- `KNOWLEDGE_RETRIEVAL` - Knowledge base query nodes -- `HTTP_REQUEST` - HTTP request nodes -- `PARAMETER_EXTRACTOR` - Parameter extraction nodes -- `DOCUMENT_EXTRACTOR` - Document processing nodes -- `QUESTION_CLASSIFIER` - Question classification nodes - -## Advanced Features - -### Registering Custom Mock Implementations - -```python -from test_mock_factory import MockNodeFactory - -# Create custom mock implementation -class CustomMockNode(BaseNode): - def _run(self): - # Custom mock logic - pass - -# Register for a specific node type -factory = MockNodeFactory(...) -factory.register_mock_node_type(NodeType.CUSTOM, CustomMockNode) -``` - -### Default Configurations by Node Type - -```python -# Set defaults for all nodes of a specific type -mock_config.set_default_config(NodeType.LLM, { - "temperature": 0.7, - "max_tokens": 100, -}) -``` - -### MockConfigBuilder Fluent API - -```python -config = (MockConfigBuilder() - .with_llm_response("LLM response") - .with_agent_response("Agent response") - .with_tool_response({"result": "data"}) - .with_retrieval_response("Retrieved content") - .with_http_response({"status_code": 200, "body": "{}"}) - .with_node_output("node_id", {"output": "value"}) - .with_node_error("error_node", "Error message") - .with_delays(True) - .build()) -``` - -## Testing Workflows - -### 1. Create Workflow Fixture - -Create a YAML fixture file in `api/tests/fixtures/workflow/` directory defining your workflow graph. - -### 2. Configure Mocks - -Set up mock configurations for nodes that need third-party services. - -### 3. Define Test Cases - -Create `WorkflowTestCase` instances with inputs, expected outputs, and mock config. - -### 4. Run Tests - -Use `TableTestRunner` to execute test cases and validate results. - -## Best Practices - -1. **Use descriptive mock responses** - Make it clear in outputs that they are mocked -1. **Test both success and failure paths** - Use error simulation to test error handling -1. **Keep mock configs close to tests** - Define mocks in the same test file for clarity -1. **Use custom handlers sparingly** - Only when dynamic behavior is needed -1. **Document mock behavior** - Comment why specific mock values are chosen -1. **Validate mock accuracy** - Ensure mocks reflect real service behavior - -## Examples - -See `test_mock_example.py` for comprehensive examples including: - -- Basic LLM workflow testing -- Custom node outputs -- HTTP and tool workflow testing -- Error simulation -- Performance testing with delays - -## Running Tests - -### TableTestRunner Tests +Examples: ```bash -# Run graph engine tests (includes property-based tests) -uv run pytest api/tests/unit_tests/dify_graph/graph_engine/test_graph_engine.py - -# Run with specific test patterns -uv run pytest api/tests/unit_tests/dify_graph/graph_engine/test_graph_engine.py -k "test_echo" - -# Run with verbose output -uv run pytest api/tests/unit_tests/dify_graph/graph_engine/test_graph_engine.py -v +uv run --project api --dev pytest api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py +uv run --project api --dev pytest api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py +uv run --project api --dev pytest api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py +uv run --project api --dev pytest api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py ``` - -### Mock System Tests - -```bash -# Run auto-mock system tests -uv run pytest api/tests/unit_tests/dify_graph/graph_engine/test_auto_mock_system.py - -# Run examples -uv run python api/tests/unit_tests/dify_graph/graph_engine/test_mock_example.py - -# Run simple validation -uv run python api/tests/unit_tests/dify_graph/graph_engine/test_mock_simple.py -``` - -### All Tests - -```bash -# Run all graph engine tests -uv run pytest api/tests/unit_tests/dify_graph/graph_engine/ - -# Run with coverage -uv run pytest api/tests/unit_tests/dify_graph/graph_engine/ --cov=dify_graph.graph_engine - -# Run in parallel -uv run pytest api/tests/unit_tests/dify_graph/graph_engine/ -n auto -``` - -## Troubleshooting - -### Issue: Mock not being applied - -- Ensure `use_auto_mock=True` in `WorkflowTestCase` -- Verify node ID matches in mock config -- Check that node type is in the auto-mock list - -### Issue: Unexpected outputs - -- Debug by printing `result.actual_outputs` -- Check if custom handler is overriding expected outputs -- Verify mock config is properly built - -### Issue: Import errors - -- Ensure all mock modules are in the correct path -- Check that required dependencies are installed - -## Future Enhancements - -Potential improvements to the auto-mock system: - -1. **Recording and playback** - Record real API responses for replay in tests -1. **Mock templates** - Pre-defined mock configurations for common scenarios -1. **Async support** - Better support for async node execution -1. **Mock validation** - Validate mock outputs against node schemas -1. **Performance profiling** - Built-in performance metrics for mocked workflows diff --git a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py b/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py deleted file mode 100644 index 4dec618e49c..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py +++ /dev/null @@ -1,315 +0,0 @@ -"""Tests for Redis command channel implementation.""" - -import json -from unittest.mock import MagicMock - -from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel -from dify_graph.graph_engine.entities.commands import ( - AbortCommand, - CommandType, - GraphEngineCommand, - UpdateVariablesCommand, - VariableUpdate, -) -from dify_graph.variables import IntegerVariable, StringVariable - - -class TestRedisChannel: - """Test suite for RedisChannel functionality.""" - - def test_init(self): - """Test RedisChannel initialization.""" - mock_redis = MagicMock() - channel_key = "test:channel:key" - ttl = 7200 - - channel = RedisChannel(mock_redis, channel_key, ttl) - - assert channel._redis == mock_redis - assert channel._key == channel_key - assert channel._command_ttl == ttl - - def test_init_default_ttl(self): - """Test RedisChannel initialization with default TTL.""" - mock_redis = MagicMock() - channel_key = "test:channel:key" - - channel = RedisChannel(mock_redis, channel_key) - - assert channel._command_ttl == 3600 # Default TTL - - def test_send_command(self): - """Test sending a command to Redis.""" - mock_redis = MagicMock() - mock_pipe = MagicMock() - context = MagicMock() - context.__enter__.return_value = mock_pipe - context.__exit__.return_value = None - mock_redis.pipeline.return_value = context - - channel = RedisChannel(mock_redis, "test:key", 3600) - - pending_key = "test:key:pending" - - # Create a test command - command = GraphEngineCommand(command_type=CommandType.ABORT) - - # Send the command - channel.send_command(command) - - # Verify pipeline was used - mock_redis.pipeline.assert_called_once() - - # Verify rpush was called with correct data - expected_json = json.dumps(command.model_dump()) - mock_pipe.rpush.assert_called_once_with("test:key", expected_json) - - # Verify expire was set - mock_pipe.expire.assert_called_once_with("test:key", 3600) - mock_pipe.set.assert_called_once_with(pending_key, "1", ex=3600) - - # Verify execute was called - mock_pipe.execute.assert_called_once() - - def test_fetch_commands_empty(self): - """Test fetching commands when Redis list is empty.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context] - - # No pending marker - pending_pipe.execute.return_value = [None, 0] - mock_redis.llen.return_value = 0 - - channel = RedisChannel(mock_redis, "test:key") - commands = channel.fetch_commands() - - assert commands == [] - mock_redis.pipeline.assert_called_once() - fetch_pipe.lrange.assert_not_called() - fetch_pipe.delete.assert_not_called() - - def test_fetch_commands_with_abort_command(self): - """Test fetching abort commands from Redis.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - # Create abort command data - abort_command = AbortCommand() - command_json = json.dumps(abort_command.model_dump()) - - # Simulate Redis returning one command - pending_pipe.execute.return_value = [b"1", 1] - fetch_pipe.execute.return_value = [[command_json.encode()], 1] - - channel = RedisChannel(mock_redis, "test:key") - commands = channel.fetch_commands() - - assert len(commands) == 1 - assert isinstance(commands[0], AbortCommand) - assert commands[0].command_type == CommandType.ABORT - - def test_fetch_commands_multiple(self): - """Test fetching multiple commands from Redis.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - # Create multiple commands - command1 = GraphEngineCommand(command_type=CommandType.ABORT) - command2 = AbortCommand() - - command1_json = json.dumps(command1.model_dump()) - command2_json = json.dumps(command2.model_dump()) - - # Simulate Redis returning multiple commands - pending_pipe.execute.return_value = [b"1", 1] - fetch_pipe.execute.return_value = [[command1_json.encode(), command2_json.encode()], 1] - - channel = RedisChannel(mock_redis, "test:key") - commands = channel.fetch_commands() - - assert len(commands) == 2 - assert commands[0].command_type == CommandType.ABORT - assert isinstance(commands[1], AbortCommand) - - def test_fetch_commands_with_update_variables_command(self): - """Test fetching update variables command from Redis.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - update_command = UpdateVariablesCommand( - updates=[ - VariableUpdate( - value=StringVariable(name="foo", value="bar", selector=["node1", "foo"]), - ), - VariableUpdate( - value=IntegerVariable(name="baz", value=123, selector=["node2", "baz"]), - ), - ] - ) - command_json = json.dumps(update_command.model_dump()) - - pending_pipe.execute.return_value = [b"1", 1] - fetch_pipe.execute.return_value = [[command_json.encode()], 1] - - channel = RedisChannel(mock_redis, "test:key") - commands = channel.fetch_commands() - - assert len(commands) == 1 - assert isinstance(commands[0], UpdateVariablesCommand) - assert isinstance(commands[0].updates[0].value, StringVariable) - assert list(commands[0].updates[0].value.selector) == ["node1", "foo"] - assert commands[0].updates[0].value.value == "bar" - - def test_fetch_commands_skips_invalid_json(self): - """Test that invalid JSON commands are skipped.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - # Mix valid and invalid JSON - valid_command = AbortCommand() - valid_json = json.dumps(valid_command.model_dump()) - invalid_json = b"invalid json {" - - # Simulate Redis returning mixed valid/invalid commands - pending_pipe.execute.return_value = [b"1", 1] - fetch_pipe.execute.return_value = [[invalid_json, valid_json.encode()], 1] - - channel = RedisChannel(mock_redis, "test:key") - commands = channel.fetch_commands() - - # Should only return the valid command - assert len(commands) == 1 - assert isinstance(commands[0], AbortCommand) - - def test_deserialize_command_abort(self): - """Test deserializing an abort command.""" - channel = RedisChannel(MagicMock(), "test:key") - - abort_data = {"command_type": CommandType.ABORT} - command = channel._deserialize_command(abort_data) - - assert isinstance(command, AbortCommand) - assert command.command_type == CommandType.ABORT - - def test_deserialize_command_generic(self): - """Test deserializing a generic command.""" - channel = RedisChannel(MagicMock(), "test:key") - - # For now, only ABORT is supported, but test generic handling - generic_data = {"command_type": CommandType.ABORT} - command = channel._deserialize_command(generic_data) - - assert command is not None - assert command.command_type == CommandType.ABORT - - def test_deserialize_command_invalid(self): - """Test deserializing invalid command data.""" - channel = RedisChannel(MagicMock(), "test:key") - - # Missing command_type - invalid_data = {"some_field": "value"} - command = channel._deserialize_command(invalid_data) - - assert command is None - - def test_deserialize_command_invalid_type(self): - """Test deserializing command with invalid type.""" - channel = RedisChannel(MagicMock(), "test:key") - - # Invalid command type - invalid_data = {"command_type": "INVALID_TYPE"} - command = channel._deserialize_command(invalid_data) - - assert command is None - - def test_atomic_fetch_and_clear(self): - """Test that fetch_commands atomically fetches and clears the list.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - command = AbortCommand() - command_json = json.dumps(command.model_dump()) - pending_pipe.execute.return_value = [b"1", 1] - fetch_pipe.execute.return_value = [[command_json.encode()], 1] - - channel = RedisChannel(mock_redis, "test:key") - - # First fetch should return the command - commands = channel.fetch_commands() - assert len(commands) == 1 - - # Verify both lrange and delete were called in the pipeline - assert fetch_pipe.lrange.call_count == 1 - assert fetch_pipe.delete.call_count == 1 - fetch_pipe.lrange.assert_called_with("test:key", 0, -1) - fetch_pipe.delete.assert_called_with("test:key") - - def test_fetch_commands_without_pending_marker_returns_empty(self): - """Ensure we avoid unnecessary list reads when pending flag is missing.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - # Pending flag absent - pending_pipe.execute.return_value = [None, 0] - channel = RedisChannel(mock_redis, "test:key") - commands = channel.fetch_commands() - - assert commands == [] - mock_redis.llen.assert_not_called() - assert mock_redis.pipeline.call_count == 1 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py deleted file mode 100644 index 6f821ba7991..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py +++ /dev/null @@ -1,119 +0,0 @@ -"""Tests for graph engine event handlers.""" - -from __future__ import annotations - -from dify_graph.entities.base_node_data import RetryConfig -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeState, WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.graph_engine.domain.graph_execution import GraphExecution -from dify_graph.graph_engine.event_management.event_handlers import EventHandler -from dify_graph.graph_engine.event_management.event_manager import EventManager -from dify_graph.graph_engine.graph_state_manager import GraphStateManager -from dify_graph.graph_engine.ready_queue.in_memory import InMemoryReadyQueue -from dify_graph.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator -from dify_graph.graph_events import NodeRunRetryEvent, NodeRunStartedEvent -from dify_graph.node_events import NodeRunResult -from dify_graph.runtime import GraphRuntimeState, VariablePool -from libs.datetime_utils import naive_utc_now - - -class _StubEdgeProcessor: - """Minimal edge processor stub for tests.""" - - -class _StubErrorHandler: - """Minimal error handler stub for tests.""" - - -class _StubNode: - """Simple node stub exposing the attributes needed by the state manager.""" - - def __init__(self, node_id: str) -> None: - self.id = node_id - self.state = NodeState.UNKNOWN - self.title = "Stub Node" - self.execution_type = NodeExecutionType.EXECUTABLE - self.error_strategy = None - self.retry_config = RetryConfig() - self.retry = False - - -def _build_event_handler(node_id: str) -> tuple[EventHandler, EventManager, GraphExecution]: - """Construct an EventHandler with in-memory dependencies for testing.""" - - node = _StubNode(node_id) - graph = Graph(nodes={node_id: node}, edges={}, in_edges={}, out_edges={}, root_node=node) - - variable_pool = VariablePool() - runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) - graph_execution = GraphExecution(workflow_id="test-workflow") - - event_manager = EventManager() - state_manager = GraphStateManager(graph=graph, ready_queue=InMemoryReadyQueue()) - response_coordinator = ResponseStreamCoordinator(variable_pool=variable_pool, graph=graph) - - handler = EventHandler( - graph=graph, - graph_runtime_state=runtime_state, - graph_execution=graph_execution, - response_coordinator=response_coordinator, - event_collector=event_manager, - edge_processor=_StubEdgeProcessor(), - state_manager=state_manager, - error_handler=_StubErrorHandler(), - ) - - return handler, event_manager, graph_execution - - -def test_retry_does_not_emit_additional_start_event() -> None: - """Ensure retry attempts do not produce duplicate start events.""" - - node_id = "test-node" - handler, event_manager, graph_execution = _build_event_handler(node_id) - - execution_id = "exec-1" - node_type = BuiltinNodeTypes.CODE - start_time = naive_utc_now() - - start_event = NodeRunStartedEvent( - id=execution_id, - node_id=node_id, - node_type=node_type, - node_title="Stub Node", - start_at=start_time, - ) - handler.dispatch(start_event) - - retry_event = NodeRunRetryEvent( - id=execution_id, - node_id=node_id, - node_type=node_type, - node_title="Stub Node", - start_at=start_time, - error="boom", - retry_index=1, - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error="boom", - error_type="TestError", - ), - ) - handler.dispatch(retry_event) - - # Simulate the node starting execution again after retry - second_start_event = NodeRunStartedEvent( - id=execution_id, - node_id=node_id, - node_type=node_type, - node_title="Stub Node", - start_at=start_time, - ) - handler.dispatch(second_start_event) - - collected_types = [type(event) for event in event_manager._events] # type: ignore[attr-defined] - - assert collected_types == [NodeRunStartedEvent, NodeRunRetryEvent] - - node_execution = graph_execution.get_or_create_node_execution(node_id) - assert node_execution.retry_count == 1 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py deleted file mode 100644 index 25494dc647e..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py +++ /dev/null @@ -1,39 +0,0 @@ -"""Tests for the EventManager.""" - -from __future__ import annotations - -import logging - -from dify_graph.graph_engine.event_management.event_manager import EventManager -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events import GraphEngineEvent - - -class _FaultyLayer(GraphEngineLayer): - """Layer that raises from on_event to test error handling.""" - - def on_graph_start(self) -> None: # pragma: no cover - not used in tests - pass - - def on_event(self, event: GraphEngineEvent) -> None: - raise RuntimeError("boom") - - def on_graph_end(self, error: Exception | None) -> None: # pragma: no cover - not used in tests - pass - - -def test_event_manager_logs_layer_errors(caplog) -> None: - """Ensure errors raised by layers are logged when collecting events.""" - - event_manager = EventManager() - event_manager.set_layers([_FaultyLayer()]) - - with caplog.at_level(logging.ERROR): - event_manager.collect(GraphEngineEvent()) - - error_logs = [record for record in caplog.records if "Error in layer on_event" in record.getMessage()] - assert error_logs, "Expected layer errors to be logged" - - log_record = error_logs[0] - assert log_record.exc_info is not None - assert isinstance(log_record.exc_info[1], RuntimeError) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/__init__.py b/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/__init__.py deleted file mode 100644 index cf8811dc2b5..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Tests for graph traversal components.""" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py b/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py deleted file mode 100644 index 73d59ea4e98..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py +++ /dev/null @@ -1,307 +0,0 @@ -"""Unit tests for skip propagator.""" - -from unittest.mock import MagicMock, create_autospec - -from dify_graph.graph import Edge, Graph -from dify_graph.graph_engine.graph_state_manager import GraphStateManager -from dify_graph.graph_engine.graph_traversal.skip_propagator import SkipPropagator - - -class TestSkipPropagator: - """Test suite for SkipPropagator.""" - - def test_propagate_skip_from_edge_with_unknown_edges_stops_processing(self) -> None: - """When there are unknown incoming edges, propagation should stop.""" - # Arrange - mock_graph = create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - # Create a mock edge - mock_edge = MagicMock(spec=Edge) - mock_edge.id = "edge_1" - mock_edge.head = "node_2" - - # Setup graph edges dict - mock_graph.edges = {"edge_1": mock_edge} - - # Setup incoming edges - incoming_edges = [MagicMock(spec=Edge), MagicMock(spec=Edge)] - mock_graph.get_incoming_edges.return_value = incoming_edges - - # Setup state manager to return has_unknown=True - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": True, - "has_taken": False, - "all_skipped": False, - } - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - mock_graph.get_incoming_edges.assert_called_once_with("node_2") - mock_state_manager.analyze_edge_states.assert_called_once_with(incoming_edges) - # Should not call any other state manager methods - mock_state_manager.enqueue_node.assert_not_called() - mock_state_manager.start_execution.assert_not_called() - mock_state_manager.mark_node_skipped.assert_not_called() - - def test_propagate_skip_from_edge_with_taken_edge_enqueues_node(self) -> None: - """When there is at least one taken edge, node should be enqueued.""" - # Arrange - mock_graph = create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - # Create a mock edge - mock_edge = MagicMock(spec=Edge) - mock_edge.id = "edge_1" - mock_edge.head = "node_2" - - mock_graph.edges = {"edge_1": mock_edge} - incoming_edges = [MagicMock(spec=Edge)] - mock_graph.get_incoming_edges.return_value = incoming_edges - - # Setup state manager to return has_taken=True - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": False, - "has_taken": True, - "all_skipped": False, - } - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - mock_state_manager.enqueue_node.assert_called_once_with("node_2") - mock_state_manager.start_execution.assert_called_once_with("node_2") - mock_state_manager.mark_node_skipped.assert_not_called() - - def test_propagate_skip_from_edge_with_all_skipped_propagates_to_node(self) -> None: - """When all incoming edges are skipped, should propagate skip to node.""" - # Arrange - mock_graph = create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - # Create a mock edge - mock_edge = MagicMock(spec=Edge) - mock_edge.id = "edge_1" - mock_edge.head = "node_2" - - mock_graph.edges = {"edge_1": mock_edge} - incoming_edges = [MagicMock(spec=Edge)] - mock_graph.get_incoming_edges.return_value = incoming_edges - - # Setup state manager to return all_skipped=True - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": False, - "has_taken": False, - "all_skipped": True, - } - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - mock_state_manager.mark_node_skipped.assert_called_once_with("node_2") - mock_state_manager.enqueue_node.assert_not_called() - mock_state_manager.start_execution.assert_not_called() - - def test_propagate_skip_to_node_marks_node_and_outgoing_edges_skipped(self) -> None: - """_propagate_skip_to_node should mark node and all outgoing edges as skipped.""" - # Arrange - mock_graph = create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - # Create outgoing edges - edge1 = MagicMock(spec=Edge) - edge1.id = "edge_2" - edge1.head = "node_downstream_1" # Set head for propagate_skip_from_edge - - edge2 = MagicMock(spec=Edge) - edge2.id = "edge_3" - edge2.head = "node_downstream_2" - - # Setup graph edges dict for propagate_skip_from_edge - mock_graph.edges = {"edge_2": edge1, "edge_3": edge2} - mock_graph.get_outgoing_edges.return_value = [edge1, edge2] - - # Setup get_incoming_edges to return empty list to stop recursion - mock_graph.get_incoming_edges.return_value = [] - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Use mock to call private method - # Act - propagator._propagate_skip_to_node("node_1") - - # Assert - mock_state_manager.mark_node_skipped.assert_called_once_with("node_1") - mock_state_manager.mark_edge_skipped.assert_any_call("edge_2") - mock_state_manager.mark_edge_skipped.assert_any_call("edge_3") - assert mock_state_manager.mark_edge_skipped.call_count == 2 - # Should recursively propagate from each edge - # Since propagate_skip_from_edge is called, we need to verify it was called - # But we can't directly verify due to recursion. We'll trust the logic. - - def test_skip_branch_paths_marks_unselected_edges_and_propagates(self) -> None: - """skip_branch_paths should mark all unselected edges as skipped and propagate.""" - # Arrange - mock_graph = create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - # Create unselected edges - edge1 = MagicMock(spec=Edge) - edge1.id = "edge_1" - edge1.head = "node_downstream_1" - - edge2 = MagicMock(spec=Edge) - edge2.id = "edge_2" - edge2.head = "node_downstream_2" - - unselected_edges = [edge1, edge2] - - # Setup graph edges dict - mock_graph.edges = {"edge_1": edge1, "edge_2": edge2} - # Setup get_incoming_edges to return empty list to stop recursion - mock_graph.get_incoming_edges.return_value = [] - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Act - propagator.skip_branch_paths(unselected_edges) - - # Assert - mock_state_manager.mark_edge_skipped.assert_any_call("edge_1") - mock_state_manager.mark_edge_skipped.assert_any_call("edge_2") - assert mock_state_manager.mark_edge_skipped.call_count == 2 - # propagate_skip_from_edge should be called for each edge - # We can't directly verify due to the mock, but the logic is covered - - def test_propagate_skip_from_edge_recursively_propagates_through_graph(self) -> None: - """Skip propagation should recursively propagate through the graph.""" - # Arrange - mock_graph = create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - # Create edge chain: edge_1 -> node_2 -> edge_3 -> node_4 - edge1 = MagicMock(spec=Edge) - edge1.id = "edge_1" - edge1.head = "node_2" - - edge3 = MagicMock(spec=Edge) - edge3.id = "edge_3" - edge3.head = "node_4" - - mock_graph.edges = {"edge_1": edge1, "edge_3": edge3} - - # Setup get_incoming_edges to return different values based on node - def get_incoming_edges_side_effect(node_id): - if node_id == "node_2": - return [edge1] - elif node_id == "node_4": - return [edge3] - return [] - - mock_graph.get_incoming_edges.side_effect = get_incoming_edges_side_effect - - # Setup get_outgoing_edges to return different values based on node - def get_outgoing_edges_side_effect(node_id): - if node_id == "node_2": - return [edge3] - elif node_id == "node_4": - return [] # No outgoing edges, stops recursion - return [] - - mock_graph.get_outgoing_edges.side_effect = get_outgoing_edges_side_effect - - # Setup state manager to return all_skipped for both nodes - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": False, - "has_taken": False, - "all_skipped": True, - } - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - # Should mark node_2 as skipped - mock_state_manager.mark_node_skipped.assert_any_call("node_2") - # Should mark edge_3 as skipped - mock_state_manager.mark_edge_skipped.assert_any_call("edge_3") - # Should propagate to node_4 - mock_state_manager.mark_node_skipped.assert_any_call("node_4") - assert mock_state_manager.mark_node_skipped.call_count == 2 - - def test_propagate_skip_from_edge_with_mixed_edge_states_handles_correctly(self) -> None: - """Test with mixed edge states (some unknown, some taken, some skipped).""" - # Arrange - mock_graph = create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - mock_edge = MagicMock(spec=Edge) - mock_edge.id = "edge_1" - mock_edge.head = "node_2" - - mock_graph.edges = {"edge_1": mock_edge} - incoming_edges = [MagicMock(spec=Edge), MagicMock(spec=Edge), MagicMock(spec=Edge)] - mock_graph.get_incoming_edges.return_value = incoming_edges - - # Test 1: has_unknown=True, has_taken=False, all_skipped=False - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": True, - "has_taken": False, - "all_skipped": False, - } - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - should stop processing - mock_state_manager.enqueue_node.assert_not_called() - mock_state_manager.mark_node_skipped.assert_not_called() - - # Reset mocks for next test - mock_state_manager.reset_mock() - mock_graph.reset_mock() - - # Test 2: has_unknown=False, has_taken=True, all_skipped=False - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": False, - "has_taken": True, - "all_skipped": False, - } - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - should enqueue node - mock_state_manager.enqueue_node.assert_called_once_with("node_2") - mock_state_manager.start_execution.assert_called_once_with("node_2") - - # Reset mocks for next test - mock_state_manager.reset_mock() - mock_graph.reset_mock() - - # Test 3: has_unknown=False, has_taken=False, all_skipped=True - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": False, - "has_taken": False, - "all_skipped": True, - } - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - should propagate skip - mock_state_manager.mark_node_skipped.assert_called_once_with("node_2") diff --git a/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py b/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py deleted file mode 100644 index fc8133f5e1b..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py +++ /dev/null @@ -1,131 +0,0 @@ -"""Utilities for testing HumanInputNode without database dependencies.""" - -from __future__ import annotations - -from collections.abc import Mapping -from dataclasses import dataclass -from datetime import datetime, timedelta -from typing import Any - -from dify_graph.nodes.human_input.enums import HumanInputFormStatus -from dify_graph.repositories.human_input_form_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRecipientEntity, - HumanInputFormRepository, -) -from libs.datetime_utils import naive_utc_now - - -class _InMemoryFormRecipient(HumanInputFormRecipientEntity): - """Minimal recipient entity required by the repository interface.""" - - def __init__(self, recipient_id: str, token: str) -> None: - self._id = recipient_id - self._token = token - - @property - def id(self) -> str: - return self._id - - @property - def token(self) -> str: - return self._token - - -@dataclass -class _InMemoryFormEntity(HumanInputFormEntity): - form_id: str - rendered: str - token: str | None = None - action_id: str | None = None - data: Mapping[str, Any] | None = None - is_submitted: bool = False - status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING - expiration: datetime = naive_utc_now() - - @property - def id(self) -> str: - return self.form_id - - @property - def web_app_token(self) -> str | None: - return self.token - - @property - def recipients(self) -> list[HumanInputFormRecipientEntity]: - return [] - - @property - def rendered_content(self) -> str: - return self.rendered - - @property - def selected_action_id(self) -> str | None: - return self.action_id - - @property - def submitted_data(self) -> Mapping[str, Any] | None: - return self.data - - @property - def submitted(self) -> bool: - return self.is_submitted - - @property - def status(self) -> HumanInputFormStatus: - return self.status_value - - @property - def expiration_time(self) -> datetime: - return self.expiration - - -class InMemoryHumanInputFormRepository(HumanInputFormRepository): - """Pure in-memory repository used by workflow graph engine tests.""" - - def __init__(self) -> None: - self._form_counter = 0 - self.created_params: list[FormCreateParams] = [] - self.created_forms: list[_InMemoryFormEntity] = [] - self._forms_by_key: dict[tuple[str, str], _InMemoryFormEntity] = {} - - def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: - self.created_params.append(params) - self._form_counter += 1 - form_id = f"form-{self._form_counter}" - token = f"console-{form_id}" if params.console_recipient_required else f"token-{form_id}" - entity = _InMemoryFormEntity( - form_id=form_id, - rendered=params.rendered_content, - token=token, - ) - self.created_forms.append(entity) - self._forms_by_key[(params.workflow_execution_id, params.node_id)] = entity - return entity - - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: - return self._forms_by_key.get((workflow_execution_id, node_id)) - - # Convenience helpers for tests ------------------------------------- - - def set_submission(self, *, action_id: str, form_data: Mapping[str, Any] | None = None) -> None: - """Simulate a human submission for the next repository lookup.""" - - if not self.created_forms: - raise AssertionError("no form has been created to attach submission data") - entity = self.created_forms[-1] - entity.action_id = action_id - entity.data = form_data or {} - entity.is_submitted = True - entity.status_value = HumanInputFormStatus.SUBMITTED - entity.expiration = naive_utc_now() + timedelta(days=1) - - def clear_submission(self) -> None: - if not self.created_forms: - return - for form in self.created_forms: - form.action_id = None - form.data = None - form.is_submitted = False - form.status_value = HumanInputFormStatus.WAITING diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py index 9e7b3654b72..41627f5e0be 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py @@ -5,13 +5,12 @@ Shared fixtures for ObservabilityLayer tests. from unittest.mock import MagicMock, patch import pytest +from graphon.enums import BuiltinNodeTypes from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from opentelemetry.trace import set_tracer_provider -from dify_graph.enums import BuiltinNodeTypes - @pytest.fixture def memory_span_exporter(): @@ -62,8 +61,9 @@ def mock_llm_node(): @pytest.fixture def mock_tool_node(): """Create a mock Tool Node with tool-specific attributes.""" + from graphon.nodes.tool.entities import ToolNodeData + from core.tools.entities.tool_entities import ToolProviderType - from dify_graph.nodes.tool.entities import ToolNodeData node = MagicMock() node.id = "test-tool-node-id" @@ -117,8 +117,8 @@ def mock_result_event(): """Create a mock result event with NodeRunResult.""" from datetime import datetime - from dify_graph.graph_events.node import NodeRunSucceededEvent - from dify_graph.node_events.base import NodeRunResult + from graphon.graph_events import NodeRunSucceededEvent + from graphon.node_events import NodeRunResult node_run_result = NodeRunResult( inputs={"query": "test query"}, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py deleted file mode 100644 index db325278490..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py +++ /dev/null @@ -1,57 +0,0 @@ -from __future__ import annotations - -import pytest - -from dify_graph.graph_engine import GraphEngine, GraphEngineConfig -from dify_graph.graph_engine.command_channels import InMemoryChannel -from dify_graph.graph_engine.layers.base import ( - GraphEngineLayer, - GraphEngineLayerNotInitializedError, -) -from dify_graph.graph_events import GraphEngineEvent - -from ..test_table_runner import WorkflowRunner - - -class LayerForTest(GraphEngineLayer): - def on_graph_start(self) -> None: - pass - - def on_event(self, event: GraphEngineEvent) -> None: - pass - - def on_graph_end(self, error: Exception | None) -> None: - pass - - -def test_layer_runtime_state_raises_when_uninitialized() -> None: - layer = LayerForTest() - - with pytest.raises(GraphEngineLayerNotInitializedError): - _ = layer.graph_runtime_state - - -def test_layer_runtime_state_available_after_engine_layer() -> None: - runner = WorkflowRunner() - fixture_data = runner.load_fixture("simple_passthrough_workflow") - graph, graph_runtime_state = runner.create_graph_from_fixture( - fixture_data, - inputs={"query": "test layer state"}, - ) - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - layer = LayerForTest() - engine.layer(layer) - - outputs = layer.graph_runtime_state.outputs - ready_queue_size = layer.graph_runtime_state.ready_queue_size - - assert outputs == {} - assert isinstance(ready_queue_size, int) - assert ready_queue_size >= 0 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py index 2a36f712fd7..99d131737e9 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py @@ -1,14 +1,28 @@ import threading from datetime import datetime +from types import SimpleNamespace from unittest.mock import MagicMock, patch +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.graph_engine.entities.commands import CommandType +from graphon.graph_events import NodeRunSucceededEvent +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.node_events import NodeRunResult + +from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom from core.app.workflow.layers.llm_quota import LLMQuotaLayer from core.errors.error import QuotaExceededError -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.graph_engine.entities.commands import CommandType -from dify_graph.graph_events.node import NodeRunSucceededEvent -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.node_events import NodeRunResult +from core.model_manager import ModelInstance + + +def _build_dify_context() -> DifyRunContext: + return DifyRunContext( + tenant_id="tenant-id", + app_id="app-id", + user_id="user-id", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) def _build_succeeded_event() -> NodeRunSucceededEvent: @@ -25,6 +39,11 @@ def _build_succeeded_event() -> NodeRunSucceededEvent: ) +def _build_wrapped_model_instance() -> tuple[SimpleNamespace, ModelInstance]: + raw_model_instance = ModelInstance.__new__(ModelInstance) + return SimpleNamespace(_model_instance=raw_model_instance), raw_model_instance + + def test_deduct_quota_called_for_successful_llm_node() -> None: layer = LLMQuotaLayer() node = MagicMock() @@ -32,8 +51,8 @@ def test_deduct_quota_called_for_successful_llm_node() -> None: node.execution_id = "execution-id" node.node_type = BuiltinNodeTypes.LLM node.tenant_id = "tenant-id" - node.require_dify_context.return_value.tenant_id = "tenant-id" - node.model_instance = object() + node.require_run_context_value.return_value = _build_dify_context() + node.model_instance, raw_model_instance = _build_wrapped_model_instance() result_event = _build_succeeded_event() with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct: @@ -41,7 +60,7 @@ def test_deduct_quota_called_for_successful_llm_node() -> None: mock_deduct.assert_called_once_with( tenant_id="tenant-id", - model_instance=node.model_instance, + model_instance=raw_model_instance, usage=result_event.node_run_result.llm_usage, ) @@ -53,8 +72,8 @@ def test_deduct_quota_called_for_question_classifier_node() -> None: node.execution_id = "execution-id" node.node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER node.tenant_id = "tenant-id" - node.require_dify_context.return_value.tenant_id = "tenant-id" - node.model_instance = object() + node.require_run_context_value.return_value = _build_dify_context() + node.model_instance, raw_model_instance = _build_wrapped_model_instance() result_event = _build_succeeded_event() with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct: @@ -62,7 +81,7 @@ def test_deduct_quota_called_for_question_classifier_node() -> None: mock_deduct.assert_called_once_with( tenant_id="tenant-id", - model_instance=node.model_instance, + model_instance=raw_model_instance, usage=result_event.node_run_result.llm_usage, ) @@ -74,7 +93,7 @@ def test_non_llm_node_is_ignored() -> None: node.execution_id = "execution-id" node.node_type = BuiltinNodeTypes.START node.tenant_id = "tenant-id" - node.require_dify_context.return_value.tenant_id = "tenant-id" + node.require_run_context_value.return_value = _build_dify_context() node._model_instance = object() result_event = _build_succeeded_event() @@ -91,7 +110,7 @@ def test_quota_error_is_handled_in_layer() -> None: node.execution_id = "execution-id" node.node_type = BuiltinNodeTypes.LLM node.tenant_id = "tenant-id" - node.require_dify_context.return_value.tenant_id = "tenant-id" + node.require_run_context_value.return_value = _build_dify_context() node.model_instance = object() result_event = _build_succeeded_event() @@ -113,8 +132,8 @@ def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None: node.execution_id = "execution-id" node.node_type = BuiltinNodeTypes.LLM node.tenant_id = "tenant-id" - node.require_dify_context.return_value.tenant_id = "tenant-id" - node.model_instance = object() + node.require_run_context_value.return_value = _build_dify_context() + node.model_instance, _ = _build_wrapped_model_instance() node.graph_runtime_state = MagicMock() node.graph_runtime_state.stop_event = stop_event @@ -141,7 +160,7 @@ def test_quota_precheck_failure_aborts_workflow_immediately() -> None: node = MagicMock() node.id = "llm-node-id" node.node_type = BuiltinNodeTypes.LLM - node.model_instance = object() + node.model_instance, _ = _build_wrapped_model_instance() node.graph_runtime_state = MagicMock() node.graph_runtime_state.stop_event = stop_event @@ -167,7 +186,7 @@ def test_quota_precheck_passes_without_abort() -> None: node = MagicMock() node.id = "llm-node-id" node.node_type = BuiltinNodeTypes.LLM - node.model_instance = object() + node.model_instance, raw_model_instance = _build_wrapped_model_instance() node.graph_runtime_state = MagicMock() node.graph_runtime_state.stop_event = stop_event @@ -175,5 +194,5 @@ def test_quota_precheck_passes_without_abort() -> None: layer.on_node_run_start(node) assert not stop_event.is_set() - mock_check.assert_called_once_with(model_instance=node.model_instance) + mock_check.assert_called_once_with(model_instance=raw_model_instance) layer.command_channel.send_command.assert_not_called() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py index 478a2b592e4..9cf72763ee2 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py @@ -13,10 +13,10 @@ Test coverage: from unittest.mock import patch import pytest +from graphon.enums import BuiltinNodeTypes from opentelemetry.trace import StatusCode from core.app.workflow.layers.observability import ObservabilityLayer -from dify_graph.enums import BuiltinNodeTypes class TestObservabilityLayerInitialization: @@ -144,7 +144,7 @@ class TestObservabilityLayerParserIntegration: self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node, mock_result_event ): """Test that LLM parser is used for LLM nodes and extracts LLM-specific attributes.""" - from dify_graph.node_events.base import NodeRunResult + from graphon.node_events import NodeRunResult mock_result_event.node_run_result = NodeRunResult( inputs={}, @@ -182,7 +182,7 @@ class TestObservabilityLayerParserIntegration: self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_retrieval_node, mock_result_event ): """Test that retrieval parser is used for retrieval nodes and extracts retrieval-specific attributes.""" - from dify_graph.node_events.base import NodeRunResult + from graphon.node_events import NodeRunResult mock_result_event.node_run_result = NodeRunResult( inputs={"query": "test query"}, @@ -210,7 +210,7 @@ class TestObservabilityLayerParserIntegration: self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_start_node, mock_result_event ): """Test that result_event parameter allows parsers to extract inputs and outputs.""" - from dify_graph.node_events.base import NodeRunResult + from graphon.node_events import NodeRunResult mock_result_event.node_run_result = NodeRunResult( inputs={"input_key": "input_value"}, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py b/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py deleted file mode 100644 index 548c10ce8d5..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py +++ /dev/null @@ -1,189 +0,0 @@ -"""Tests for dispatcher command checking behavior.""" - -from __future__ import annotations - -import queue -from unittest import mock - -from dify_graph.entities.pause_reason import SchedulingPause -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.graph_engine.event_management.event_handlers import EventHandler -from dify_graph.graph_engine.orchestration.dispatcher import Dispatcher -from dify_graph.graph_engine.orchestration.execution_coordinator import ExecutionCoordinator -from dify_graph.graph_events import ( - GraphNodeEventBase, - NodeRunPauseRequestedEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, -) -from dify_graph.node_events import NodeRunResult -from libs.datetime_utils import naive_utc_now - - -def test_dispatcher_should_consume_remains_events_after_pause(): - event_queue = queue.Queue() - event_queue.put( - GraphNodeEventBase( - id="test", - node_id="test", - node_type=BuiltinNodeTypes.START, - ) - ) - event_handler = mock.Mock(spec=EventHandler) - execution_coordinator = mock.Mock(spec=ExecutionCoordinator) - execution_coordinator.paused.return_value = True - dispatcher = Dispatcher( - event_queue=event_queue, - event_handler=event_handler, - execution_coordinator=execution_coordinator, - ) - dispatcher._dispatcher_loop() - assert event_queue.empty() - - -class _StubExecutionCoordinator: - """Stub execution coordinator that tracks command checks.""" - - def __init__(self) -> None: - self.command_checks = 0 - self.scaling_checks = 0 - self.execution_complete = False - self.failed = False - self._paused = False - - def process_commands(self) -> None: - self.command_checks += 1 - - def check_scaling(self) -> None: - self.scaling_checks += 1 - - @property - def paused(self) -> bool: - return self._paused - - @property - def aborted(self) -> bool: - return False - - def mark_complete(self) -> None: - self.execution_complete = True - - def mark_failed(self, error: Exception) -> None: # pragma: no cover - defensive, not triggered in tests - self.failed = True - - -class _StubEventHandler: - """Minimal event handler that marks execution complete after handling an event.""" - - def __init__(self, coordinator: _StubExecutionCoordinator) -> None: - self._coordinator = coordinator - self.events = [] - - def dispatch(self, event) -> None: - self.events.append(event) - self._coordinator.mark_complete() - - -def _run_dispatcher_for_event(event) -> int: - """Run the dispatcher loop for a single event and return command check count.""" - event_queue: queue.Queue = queue.Queue() - event_queue.put(event) - - coordinator = _StubExecutionCoordinator() - event_handler = _StubEventHandler(coordinator) - - dispatcher = Dispatcher( - event_queue=event_queue, - event_handler=event_handler, - execution_coordinator=coordinator, - ) - - dispatcher._dispatcher_loop() - - return coordinator.command_checks - - -def _make_started_event() -> NodeRunStartedEvent: - return NodeRunStartedEvent( - id="start-event", - node_id="node-1", - node_type=BuiltinNodeTypes.CODE, - node_title="Test Node", - start_at=naive_utc_now(), - ) - - -def _make_succeeded_event() -> NodeRunSucceededEvent: - return NodeRunSucceededEvent( - id="success-event", - node_id="node-1", - node_type=BuiltinNodeTypes.CODE, - node_title="Test Node", - start_at=naive_utc_now(), - node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED), - ) - - -def test_dispatcher_checks_commands_during_idle_and_on_completion() -> None: - """Dispatcher polls commands when idle and after completion events.""" - started_checks = _run_dispatcher_for_event(_make_started_event()) - succeeded_checks = _run_dispatcher_for_event(_make_succeeded_event()) - - assert started_checks == 2 - assert succeeded_checks == 3 - - -class _PauseStubEventHandler: - """Minimal event handler that marks execution complete after handling an event.""" - - def __init__(self, coordinator: _StubExecutionCoordinator) -> None: - self._coordinator = coordinator - self.events = [] - - def dispatch(self, event) -> None: - self.events.append(event) - if isinstance(event, NodeRunPauseRequestedEvent): - self._coordinator.mark_complete() - - -def test_dispatcher_drain_event_queue(): - events = [ - NodeRunStartedEvent( - id="start-event", - node_id="node-1", - node_type=BuiltinNodeTypes.CODE, - node_title="Code", - start_at=naive_utc_now(), - ), - NodeRunPauseRequestedEvent( - id="pause-event", - node_id="node-1", - node_type=BuiltinNodeTypes.CODE, - reason=SchedulingPause(message="test pause"), - ), - NodeRunSucceededEvent( - id="success-event", - node_id="node-1", - node_type=BuiltinNodeTypes.CODE, - start_at=naive_utc_now(), - node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED), - ), - ] - - event_queue: queue.Queue = queue.Queue() - for e in events: - event_queue.put(e) - - coordinator = _StubExecutionCoordinator() - event_handler = _PauseStubEventHandler(coordinator) - - dispatcher = Dispatcher( - event_queue=event_queue, - event_handler=event_handler, - execution_coordinator=coordinator, - ) - - dispatcher._dispatcher_loop() - - # ensure all events are drained. - assert event_queue.empty() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_answer_end_with_text.py b/api/tests/unit_tests/core/workflow/graph_engine/test_answer_end_with_text.py deleted file mode 100644 index 7af6b26d87e..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_answer_end_with_text.py +++ /dev/null @@ -1,37 +0,0 @@ -from dify_graph.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) - -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_answer_end_with_text(): - fixture_name = "answer_end_with_text" - case = WorkflowTestCase( - fixture_name, - query="Hello, AI!", - expected_outputs={"answer": "prefixHello, AI!suffix"}, - expected_event_sequence=[ - GraphRunStartedEvent, - # Start - NodeRunStartedEvent, - # The chunks are now emitted as the Answer node processes them - # since sys.query is a special selector that gets attributed to - # the active response node - NodeRunStreamChunkEvent, # prefix - NodeRunStreamChunkEvent, # sys.query - NodeRunStreamChunkEvent, # suffix - NodeRunSucceededEvent, - # Answer - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ], - ) - runner = TableTestRunner() - result = runner.run_test_case(case) - assert result.success, f"Test failed: {result.error}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_answer_order_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_answer_order_workflow.py deleted file mode 100644 index 6569439b568..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_answer_order_workflow.py +++ /dev/null @@ -1,28 +0,0 @@ -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - -LLM_NODE_ID = "1759052580454" - - -def test_answer_nodes_emit_in_order() -> None: - mock_config = ( - MockConfigBuilder() - .with_llm_response("unused default") - .with_node_output(LLM_NODE_ID, {"text": "mocked llm text"}) - .build() - ) - - expected_answer = "--- answer 1 ---\n\nfoo\n--- answer 2 ---\n\nmocked llm text\n" - - case = WorkflowTestCase( - fixture_path="test-answer-order", - query="", - expected_outputs={"answer": expected_answer}, - use_auto_mock=True, - mock_config=mock_config, - ) - - runner = TableTestRunner() - result = runner.run_test_case(case) - - assert result.success, result.error diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_array_iteration_formatting_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_array_iteration_formatting_workflow.py deleted file mode 100644 index 05ec565def6..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_array_iteration_formatting_workflow.py +++ /dev/null @@ -1,24 +0,0 @@ -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_array_iteration_formatting_workflow(): - """ - Validate Iteration node processes [1,2,3] into formatted strings. - - Fixture description expects: - {"output": ["output: 1", "output: 2", "output: 3"]} - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="array_iteration_formatting_workflow", - inputs={}, - expected_outputs={"output": ["output: 1", "output: 2", "output: 3"]}, - description="Iteration formats numbers into strings", - use_auto_mock=True, - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Iteration workflow failed: {result.error}" - assert result.actual_outputs == test_case.expected_outputs diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py deleted file mode 100644 index fc0d22f7396..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py +++ /dev/null @@ -1,392 +0,0 @@ -""" -Tests for the auto-mock system. - -This module contains tests that validate the auto-mock functionality -for workflows containing nodes that require third-party services. -""" - -import pytest - -from dify_graph.enums import BuiltinNodeTypes -from tests.workflow_test_utils import build_test_graph_init_params - -from .test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_simple_llm_workflow_with_auto_mock(): - """Test that a simple LLM workflow runs successfully with auto-mocking.""" - runner = TableTestRunner() - - # Create mock configuration - mock_config = MockConfigBuilder().with_llm_response("This is a test response from mocked LLM").build() - - test_case = WorkflowTestCase( - fixture_path="basic_llm_chat_workflow", - inputs={"query": "Hello, how are you?"}, - expected_outputs={"answer": "This is a test response from mocked LLM"}, - description="Simple LLM workflow with auto-mock", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Workflow failed: {result.error}" - assert result.actual_outputs is not None - assert "answer" in result.actual_outputs - assert result.actual_outputs["answer"] == "This is a test response from mocked LLM" - - -def test_llm_workflow_with_custom_node_output(): - """Test LLM workflow with custom output for specific node.""" - runner = TableTestRunner() - - # Create mock configuration with custom output for specific node - mock_config = MockConfig() - mock_config.set_node_outputs( - "llm_node", - { - "text": "Custom response for this specific node", - "usage": { - "prompt_tokens": 20, - "completion_tokens": 10, - "total_tokens": 30, - }, - "finish_reason": "stop", - }, - ) - - test_case = WorkflowTestCase( - fixture_path="basic_llm_chat_workflow", - inputs={"query": "Test query"}, - expected_outputs={"answer": "Custom response for this specific node"}, - description="LLM workflow with custom node output", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Workflow failed: {result.error}" - assert result.actual_outputs is not None - assert result.actual_outputs["answer"] == "Custom response for this specific node" - - -def test_http_tool_workflow_with_auto_mock(): - """Test workflow with HTTP request and tool nodes using auto-mock.""" - runner = TableTestRunner() - - # Create mock configuration - mock_config = MockConfig() - mock_config.set_node_outputs( - "http_node", - { - "status_code": 200, - "body": '{"key": "value", "number": 42}', - "headers": {"content-type": "application/json"}, - }, - ) - mock_config.set_node_outputs( - "tool_node", - { - "result": {"key": "value", "number": 42}, - }, - ) - - test_case = WorkflowTestCase( - fixture_path="http_request_with_json_tool_workflow", - inputs={"url": "https://api.example.com/data"}, - expected_outputs={ - "status_code": 200, - "parsed_data": {"key": "value", "number": 42}, - }, - description="HTTP and Tool workflow with auto-mock", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Workflow failed: {result.error}" - assert result.actual_outputs is not None - assert result.actual_outputs["status_code"] == 200 - assert result.actual_outputs["parsed_data"] == {"key": "value", "number": 42} - - -def test_workflow_with_simulated_node_error(): - """Test that workflows handle simulated node errors correctly.""" - runner = TableTestRunner() - - # Create mock configuration with error - mock_config = MockConfig() - mock_config.set_node_error("llm_node", "Simulated LLM API error") - - test_case = WorkflowTestCase( - fixture_path="basic_llm_chat_workflow", - inputs={"query": "This should fail"}, - expected_outputs={}, # We expect failure, so no outputs - description="LLM workflow with simulated error", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - # The workflow should fail due to the simulated error - assert not result.success - assert result.error is not None - - -def test_workflow_with_mock_delays(): - """Test that mock delays work correctly.""" - runner = TableTestRunner() - - # Create mock configuration with delays - mock_config = MockConfig(simulate_delays=True) - node_config = NodeMockConfig( - node_id="llm_node", - outputs={"text": "Response after delay"}, - delay=0.1, # 100ms delay - ) - mock_config.set_node_config("llm_node", node_config) - - test_case = WorkflowTestCase( - fixture_path="basic_llm_chat_workflow", - inputs={"query": "Test with delay"}, - expected_outputs={"answer": "Response after delay"}, - description="LLM workflow with simulated delay", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Workflow failed: {result.error}" - # Execution time should be at least the delay - assert result.execution_time >= 0.1 - - -def test_mock_config_builder(): - """Test the MockConfigBuilder fluent interface.""" - config = ( - MockConfigBuilder() - .with_llm_response("LLM response") - .with_agent_response("Agent response") - .with_tool_response({"tool": "output"}) - .with_retrieval_response("Retrieval content") - .with_http_response({"status_code": 201, "body": "created"}) - .with_node_output("node1", {"output": "value"}) - .with_node_error("node2", "error message") - .with_delays(True) - .build() - ) - - assert config.default_llm_response == "LLM response" - assert config.default_agent_response == "Agent response" - assert config.default_tool_response == {"tool": "output"} - assert config.default_retrieval_response == "Retrieval content" - assert config.default_http_response == {"status_code": 201, "body": "created"} - assert config.simulate_delays is True - - node1_config = config.get_node_config("node1") - assert node1_config is not None - assert node1_config.outputs == {"output": "value"} - - node2_config = config.get_node_config("node2") - assert node2_config is not None - assert node2_config.error == "error message" - - -def test_mock_factory_node_type_detection(): - """Test that MockNodeFactory correctly identifies nodes to mock.""" - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from dify_graph.runtime import GraphRuntimeState, VariablePool - - from .test_mock_factory import MockNodeFactory - - graph_init_params = build_test_graph_init_params( - workflow_id="test", - graph_config={}, - tenant_id="test", - app_id="test", - user_id="test", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, - ) - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), - start_at=0, - total_tokens=0, - node_run_steps=0, - ) - factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=None, - ) - - # Test that third-party service nodes are identified for mocking - assert factory.should_mock_node(BuiltinNodeTypes.LLM) - assert factory.should_mock_node(BuiltinNodeTypes.AGENT) - assert factory.should_mock_node(BuiltinNodeTypes.TOOL) - assert factory.should_mock_node(BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL) - assert factory.should_mock_node(BuiltinNodeTypes.HTTP_REQUEST) - assert factory.should_mock_node(BuiltinNodeTypes.PARAMETER_EXTRACTOR) - assert factory.should_mock_node(BuiltinNodeTypes.DOCUMENT_EXTRACTOR) - - # Test that CODE and TEMPLATE_TRANSFORM are mocked (they require SSRF proxy) - assert factory.should_mock_node(BuiltinNodeTypes.CODE) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - # Test that non-service nodes are not mocked - assert not factory.should_mock_node(BuiltinNodeTypes.START) - assert not factory.should_mock_node(BuiltinNodeTypes.END) - assert not factory.should_mock_node(BuiltinNodeTypes.IF_ELSE) - assert not factory.should_mock_node(BuiltinNodeTypes.VARIABLE_AGGREGATOR) - - -def test_custom_mock_handler(): - """Test using a custom handler function for mock outputs.""" - runner = TableTestRunner() - - # Custom handler that modifies output based on input - def custom_llm_handler(node) -> dict: - # In a real scenario, we could access node.graph_runtime_state.variable_pool - # to get the actual inputs - return { - "text": "Custom handler response", - "usage": { - "prompt_tokens": 5, - "completion_tokens": 3, - "total_tokens": 8, - }, - "finish_reason": "stop", - } - - mock_config = MockConfig() - node_config = NodeMockConfig( - node_id="llm_node", - custom_handler=custom_llm_handler, - ) - mock_config.set_node_config("llm_node", node_config) - - test_case = WorkflowTestCase( - fixture_path="basic_llm_chat_workflow", - inputs={"query": "Test custom handler"}, - expected_outputs={"answer": "Custom handler response"}, - description="LLM workflow with custom handler", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Workflow failed: {result.error}" - assert result.actual_outputs["answer"] == "Custom handler response" - - -def test_workflow_without_auto_mock(): - """Test that workflows work normally without auto-mock enabled.""" - runner = TableTestRunner() - - # This test uses the echo workflow which doesn't need external services - test_case = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "Test without mock"}, - expected_outputs={"query": "Test without mock"}, - description="Echo workflow without auto-mock", - use_auto_mock=False, # Auto-mock disabled - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Workflow failed: {result.error}" - assert result.actual_outputs["query"] == "Test without mock" - - -def test_register_custom_mock_node(): - """Test registering a custom mock implementation for a node type.""" - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from dify_graph.nodes.template_transform import TemplateTransformNode - from dify_graph.runtime import GraphRuntimeState, VariablePool - - from .test_mock_factory import MockNodeFactory - - # Create a custom mock for TemplateTransformNode - class MockTemplateTransformNode(TemplateTransformNode): - def _run(self): - # Custom mock implementation - pass - - graph_init_params = build_test_graph_init_params( - workflow_id="test", - graph_config={}, - tenant_id="test", - app_id="test", - user_id="test", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, - ) - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), - start_at=0, - total_tokens=0, - node_run_steps=0, - ) - factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=None, - ) - - # TEMPLATE_TRANSFORM is mocked by default (requires SSRF proxy) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - # Unregister mock - factory.unregister_mock_node_type(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - assert not factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - # Re-register custom mock - factory.register_mock_node_type(BuiltinNodeTypes.TEMPLATE_TRANSFORM, MockTemplateTransformNode) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - -def test_default_config_by_node_type(): - """Test setting default configurations by node type.""" - mock_config = MockConfig() - - # Set default config for all LLM nodes - mock_config.set_default_config( - BuiltinNodeTypes.LLM, - { - "default_response": "Default LLM response for all nodes", - "temperature": 0.7, - }, - ) - - # Set default config for all HTTP nodes - mock_config.set_default_config( - BuiltinNodeTypes.HTTP_REQUEST, - { - "default_status": 200, - "default_timeout": 30, - }, - ) - - llm_config = mock_config.get_default_config(BuiltinNodeTypes.LLM) - assert llm_config["default_response"] == "Default LLM response for all nodes" - assert llm_config["temperature"] == 0.7 - - http_config = mock_config.get_default_config(BuiltinNodeTypes.HTTP_REQUEST) - assert http_config["default_status"] == 200 - assert http_config["default_timeout"] == 30 - - # Non-configured node type should return empty dict - tool_config = mock_config.get_default_config(BuiltinNodeTypes.TOOL) - assert tool_config == {} - - -if __name__ == "__main__": - # Run all tests - pytest.main([__file__, "-v"]) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_basic_chatflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_basic_chatflow.py deleted file mode 100644 index 30acbdaf3db..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_basic_chatflow.py +++ /dev/null @@ -1,41 +0,0 @@ -from dify_graph.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) - -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_basic_chatflow(): - fixture_name = "basic_chatflow" - mock_config = MockConfigBuilder().with_llm_response("mocked llm response").build() - case = WorkflowTestCase( - fixture_path=fixture_name, - use_auto_mock=True, - mock_config=mock_config, - expected_outputs={"answer": "mocked llm response"}, - expected_event_sequence=[ - GraphRunStartedEvent, - # START - NodeRunStartedEvent, - NodeRunSucceededEvent, - # LLM - NodeRunStartedEvent, - ] - + [NodeRunStreamChunkEvent] * ("mocked llm response".count(" ") + 2) - + [ - NodeRunSucceededEvent, - # ANSWER - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ], - ) - - runner = TableTestRunner() - result = runner.run_test_case(case) - assert result.success, f"Test failed: {result.error}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py deleted file mode 100644 index 765c4deba32..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py +++ /dev/null @@ -1,267 +0,0 @@ -"""Test the command system for GraphEngine control.""" - -import time -from unittest.mock import MagicMock - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams -from dify_graph.entities.pause_reason import SchedulingPause -from dify_graph.graph import Graph -from dify_graph.graph_engine import GraphEngine, GraphEngineConfig -from dify_graph.graph_engine.command_channels import InMemoryChannel -from dify_graph.graph_engine.entities.commands import ( - AbortCommand, - CommandType, - PauseCommand, - UpdateVariablesCommand, - VariableUpdate, -) -from dify_graph.graph_events import GraphRunAbortedEvent, GraphRunPausedEvent, GraphRunStartedEvent -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.variables import IntegerVariable, StringVariable - - -def test_abort_command(): - """Test that GraphEngine properly handles abort commands.""" - - # Create shared GraphRuntimeState - shared_runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) - - # Create a minimal mock graph - mock_graph = MagicMock(spec=Graph) - mock_graph.nodes = {} - mock_graph.edges = {} - mock_graph.root_node = MagicMock() - mock_graph.root_node.id = "start" - - # Create mock nodes with required attributes - using shared runtime state - start_node = StartNode( - id="start", - config={"id": "start", "data": {"title": "start", "variables": []}}, - graph_init_params=GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ), - graph_runtime_state=shared_runtime_state, - ) - mock_graph.nodes["start"] = start_node - - # Mock graph methods - mock_graph.get_outgoing_edges = MagicMock(return_value=[]) - mock_graph.get_incoming_edges = MagicMock(return_value=[]) - - # Create command channel - command_channel = InMemoryChannel() - - # Create GraphEngine with same shared runtime state - engine = GraphEngine( - workflow_id="test_workflow", - graph=mock_graph, - graph_runtime_state=shared_runtime_state, # Use shared instance - command_channel=command_channel, - config=GraphEngineConfig(), - ) - - # Send abort command before starting - abort_command = AbortCommand(reason="Test abort") - command_channel.send_command(abort_command) - - # Run engine and collect events - events = list(engine.run()) - - # Verify we get start and abort events - assert any(isinstance(e, GraphRunStartedEvent) for e in events) - assert any(isinstance(e, GraphRunAbortedEvent) for e in events) - - # Find the abort event and check its reason - abort_events = [e for e in events if isinstance(e, GraphRunAbortedEvent)] - assert len(abort_events) == 1 - assert abort_events[0].reason is not None - assert "aborted: test abort" in abort_events[0].reason.lower() - - -def test_redis_channel_serialization(): - """Test that Redis channel properly serializes and deserializes commands.""" - import json - from unittest.mock import MagicMock - - # Mock redis client - mock_redis = MagicMock() - mock_pipeline = MagicMock() - mock_redis.pipeline.return_value.__enter__ = MagicMock(return_value=mock_pipeline) - mock_redis.pipeline.return_value.__exit__ = MagicMock(return_value=None) - - from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel - - # Create channel with a specific key - channel = RedisChannel(mock_redis, channel_key="workflow:123:commands") - - # Test sending a command - abort_command = AbortCommand(reason="Test abort") - channel.send_command(abort_command) - - # Verify redis methods were called - mock_pipeline.rpush.assert_called_once() - mock_pipeline.expire.assert_called_once() - - # Verify the serialized data - call_args = mock_pipeline.rpush.call_args - key = call_args[0][0] - command_json = call_args[0][1] - - assert key == "workflow:123:commands" - - # Verify JSON structure - command_data = json.loads(command_json) - assert command_data["command_type"] == "abort" - assert command_data["reason"] == "Test abort" - - # Test pause command serialization - pause_command = PauseCommand(reason="User requested pause") - channel.send_command(pause_command) - - assert len(mock_pipeline.rpush.call_args_list) == 2 - second_call_args = mock_pipeline.rpush.call_args_list[1] - pause_command_json = second_call_args[0][1] - pause_command_data = json.loads(pause_command_json) - assert pause_command_data["command_type"] == CommandType.PAUSE.value - assert pause_command_data["reason"] == "User requested pause" - - -def test_pause_command(): - """Test that GraphEngine properly handles pause commands.""" - - shared_runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) - - mock_graph = MagicMock(spec=Graph) - mock_graph.nodes = {} - mock_graph.edges = {} - mock_graph.root_node = MagicMock() - mock_graph.root_node.id = "start" - - start_node = StartNode( - id="start", - config={"id": "start", "data": {"title": "start", "variables": []}}, - graph_init_params=GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ), - graph_runtime_state=shared_runtime_state, - ) - mock_graph.nodes["start"] = start_node - - mock_graph.get_outgoing_edges = MagicMock(return_value=[]) - mock_graph.get_incoming_edges = MagicMock(return_value=[]) - - command_channel = InMemoryChannel() - - engine = GraphEngine( - workflow_id="test_workflow", - graph=mock_graph, - graph_runtime_state=shared_runtime_state, - command_channel=command_channel, - config=GraphEngineConfig(), - ) - - pause_command = PauseCommand(reason="User requested pause") - command_channel.send_command(pause_command) - - events = list(engine.run()) - - assert any(isinstance(e, GraphRunStartedEvent) for e in events) - pause_events = [e for e in events if isinstance(e, GraphRunPausedEvent)] - assert len(pause_events) == 1 - assert pause_events[0].reasons == [SchedulingPause(message="User requested pause")] - - graph_execution = engine.graph_runtime_state.graph_execution - assert graph_execution.pause_reasons == [SchedulingPause(message="User requested pause")] - - -def test_update_variables_command_updates_pool(): - """Test that GraphEngine updates variable pool via update variables command.""" - - shared_runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) - shared_runtime_state.variable_pool.add(("node1", "foo"), "old value") - - mock_graph = MagicMock(spec=Graph) - mock_graph.nodes = {} - mock_graph.edges = {} - mock_graph.root_node = MagicMock() - mock_graph.root_node.id = "start" - - start_node = StartNode( - id="start", - config={"id": "start", "data": {"title": "start", "variables": []}}, - graph_init_params=GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ), - graph_runtime_state=shared_runtime_state, - ) - mock_graph.nodes["start"] = start_node - - mock_graph.get_outgoing_edges = MagicMock(return_value=[]) - mock_graph.get_incoming_edges = MagicMock(return_value=[]) - - command_channel = InMemoryChannel() - - engine = GraphEngine( - workflow_id="test_workflow", - graph=mock_graph, - graph_runtime_state=shared_runtime_state, - command_channel=command_channel, - config=GraphEngineConfig(), - ) - - update_command = UpdateVariablesCommand( - updates=[ - VariableUpdate( - value=StringVariable(name="foo", value="new value", selector=["node1", "foo"]), - ), - VariableUpdate( - value=IntegerVariable(name="bar", value=123, selector=["node2", "bar"]), - ), - ] - ) - command_channel.send_command(update_command) - - list(engine.run()) - - updated_existing = shared_runtime_state.variable_pool.get(["node1", "foo"]) - added_new = shared_runtime_state.variable_pool.get(["node2", "bar"]) - - assert updated_existing is not None - assert updated_existing.value == "new value" - assert added_new is not None - assert added_new.value == 123 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py deleted file mode 100644 index 3a9a0b18bcb..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py +++ /dev/null @@ -1,134 +0,0 @@ -""" -Test suite for complex branch workflow with parallel execution and conditional routing. - -This test suite validates the behavior of a workflow that: -1. Executes nodes in parallel (IF/ELSE and LLM branches) -2. Routes based on conditional logic (query containing 'hello') -3. Handles multiple answer nodes with different outputs -""" - -from dify_graph.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, -) - -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -class TestComplexBranchWorkflow: - """Test suite for complex branch workflow with parallel execution.""" - - def setup_method(self): - """Set up test environment before each test method.""" - self.runner = TableTestRunner() - self.fixture_path = "test_complex_branch" - - def test_hello_branch_with_llm(self): - """ - Test when query contains 'hello' - should trigger true branch. - Both IF/ELSE and LLM should execute in parallel. - """ - mock_text_1 = "This is a mocked LLM response for hello world" - test_cases = [ - WorkflowTestCase( - fixture_path=self.fixture_path, - query="hello world", - expected_outputs={ - "answer": f"contains 'hello'{mock_text_1}", - }, - description="Basic hello case with parallel LLM execution", - use_auto_mock=True, - mock_config=(MockConfigBuilder().with_node_output("1755502777322", {"text": mock_text_1}).build()), - ), - WorkflowTestCase( - fixture_path=self.fixture_path, - query="say hello to everyone", - expected_outputs={ - "answer": "contains 'hello'Mocked response for greeting", - }, - description="Hello in middle of sentence", - use_auto_mock=True, - mock_config=( - MockConfigBuilder() - .with_node_output("1755502777322", {"text": "Mocked response for greeting"}) - .build() - ), - ), - ] - - suite_result = self.runner.run_table_tests(test_cases) - - for result in suite_result.results: - assert result.success, f"Test '{result.test_case.description}' failed: {result.error}" - assert result.actual_outputs - assert any(isinstance(event, GraphRunStartedEvent) for event in result.events) - assert any(isinstance(event, GraphRunSucceededEvent) for event in result.events) - - start_index = next( - idx for idx, event in enumerate(result.events) if isinstance(event, GraphRunStartedEvent) - ) - success_index = max( - idx for idx, event in enumerate(result.events) if isinstance(event, GraphRunSucceededEvent) - ) - assert start_index < success_index - - started_node_ids = {event.node_id for event in result.events if isinstance(event, NodeRunStartedEvent)} - assert {"1755502773326", "1755502777322"}.issubset(started_node_ids), ( - f"Branch or LLM nodes missing in events: {started_node_ids}" - ) - - assert any(isinstance(event, NodeRunStreamChunkEvent) for event in result.events), ( - "Expected streaming chunks from LLM execution" - ) - - llm_start_index = next( - idx - for idx, event in enumerate(result.events) - if isinstance(event, NodeRunStartedEvent) and event.node_id == "1755502777322" - ) - assert any( - idx > llm_start_index and isinstance(event, NodeRunStreamChunkEvent) - for idx, event in enumerate(result.events) - ), "Streaming chunks should follow LLM node start" - - def test_non_hello_branch_with_llm(self): - """ - Test when query doesn't contain 'hello' - should trigger false branch. - LLM output should be used as the final answer. - """ - test_cases = [ - WorkflowTestCase( - fixture_path=self.fixture_path, - query="goodbye world", - expected_outputs={ - "answer": "Mocked LLM response for goodbye", - }, - description="Goodbye case - false branch with LLM output", - use_auto_mock=True, - mock_config=( - MockConfigBuilder() - .with_node_output("1755502777322", {"text": "Mocked LLM response for goodbye"}) - .build() - ), - ), - WorkflowTestCase( - fixture_path=self.fixture_path, - query="test message", - expected_outputs={ - "answer": "Mocked response for test", - }, - description="Regular message - false branch", - use_auto_mock=True, - mock_config=( - MockConfigBuilder().with_node_output("1755502777322", {"text": "Mocked response for test"}).build() - ), - ), - ] - - suite_result = self.runner.run_table_tests(test_cases) - - for result in suite_result.results: - assert result.success, f"Test '{result.test_case.description}' failed: {result.error}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py deleted file mode 100644 index 76bf179f330..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py +++ /dev/null @@ -1,220 +0,0 @@ -""" -Test for streaming output workflow behavior. - -This test validates that: -- When blocking == 1: No NodeRunStreamChunkEvent (flow through Template node) -- When blocking != 1: NodeRunStreamChunkEvent present (direct LLM to End output) -""" - -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.graph_engine import GraphEngine, GraphEngineConfig -from dify_graph.graph_engine.command_channels import InMemoryChannel -from dify_graph.graph_events import ( - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) - -from .test_table_runner import TableTestRunner - - -def test_streaming_output_with_blocking_equals_one(): - """ - Test workflow when blocking == 1 (LLM → Template → End). - - Template node doesn't produce streaming output, so no NodeRunStreamChunkEvent should be present. - This test should FAIL according to requirements. - """ - runner = TableTestRunner() - - # Load the workflow configuration - fixture_data = runner.workflow_runner.load_fixture("conditional_streaming_vs_template_workflow") - - # Create graph from fixture with auto-mock enabled - graph, graph_runtime_state = runner.workflow_runner.create_graph_from_fixture( - fixture_data=fixture_data, - inputs={"query": "Hello, how are you?", "blocking": 1}, - use_mock_factory=True, - ) - - # Create and run the engine - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - # Execute the workflow - events = list(engine.run()) - - # Check for successful completion - success_events = [e for e in events if isinstance(e, GraphRunSucceededEvent)] - assert len(success_events) > 0, "Workflow should complete successfully" - - # Check for streaming events - stream_chunk_events = [e for e in events if isinstance(e, NodeRunStreamChunkEvent)] - stream_chunk_count = len(stream_chunk_events) - - # According to requirements, we expect exactly 3 streaming events from the End node - # 1. User query - # 2. Newline - # 3. Template output (which contains the LLM response) - assert stream_chunk_count == 3, f"Expected 3 streaming events when blocking=1, but got {stream_chunk_count}" - - first_chunk, second_chunk, third_chunk = stream_chunk_events[0], stream_chunk_events[1], stream_chunk_events[2] - assert first_chunk.chunk == "Hello, how are you?", ( - f"Expected first chunk to be user input, but got {first_chunk.chunk}" - ) - assert second_chunk.chunk == "\n", f"Expected second chunk to be newline, but got {second_chunk.chunk}" - # Third chunk will be the template output with the mock LLM response - assert isinstance(third_chunk.chunk, str), f"Expected third chunk to be string, but got {type(third_chunk.chunk)}" - - # Find indices of first LLM success event and first stream chunk event - llm2_start_index = next( - ( - i - for i, e in enumerate(events) - if isinstance(e, NodeRunSucceededEvent) and e.node_type == BuiltinNodeTypes.LLM - ), - -1, - ) - first_chunk_index = next( - (i for i, e in enumerate(events) if isinstance(e, NodeRunStreamChunkEvent)), - -1, - ) - - assert first_chunk_index < llm2_start_index, ( - f"Expected first chunk before LLM2 start, but got {first_chunk_index} and {llm2_start_index}" - ) - - # Check that NodeRunStreamChunkEvent contains 'query' should has same id with Start NodeRunStartedEvent - start_node_id = graph.root_node.id - start_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_id == start_node_id] - assert len(start_events) == 1, f"Expected 1 start event for node {start_node_id}, but got {len(start_events)}" - start_event = start_events[0] - query_chunk_events = [e for e in stream_chunk_events if e.chunk == "Hello, how are you?"] - assert all(e.id == start_event.id for e in query_chunk_events), "Expected all query chunk events to have same id" - - # Check all Template's NodeRunStreamChunkEvent should has same id with Template's NodeRunStartedEvent - start_events = [ - e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == BuiltinNodeTypes.TEMPLATE_TRANSFORM - ] - template_chunk_events = [e for e in stream_chunk_events if e.node_type == BuiltinNodeTypes.TEMPLATE_TRANSFORM] - assert len(template_chunk_events) == 1, f"Expected 1 template chunk event, but got {len(template_chunk_events)}" - assert all(e.id in [se.id for se in start_events] for e in template_chunk_events), ( - "Expected all Template chunk events to have same id with Template's NodeRunStartedEvent" - ) - - # Check that NodeRunStreamChunkEvent contains '\n' is from the End node - end_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == BuiltinNodeTypes.END] - assert len(end_events) == 1, f"Expected 1 end event, but got {len(end_events)}" - newline_chunk_events = [e for e in stream_chunk_events if e.chunk == "\n"] - assert len(newline_chunk_events) == 1, f"Expected 1 newline chunk event, but got {len(newline_chunk_events)}" - # The newline chunk should be from the End node (check node_id, not execution id) - assert all(e.node_id == end_events[0].node_id for e in newline_chunk_events), ( - "Expected all newline chunk events to be from End node" - ) - - -def test_streaming_output_with_blocking_not_equals_one(): - """ - Test workflow when blocking != 1 (LLM → End directly). - - End node should produce streaming output with NodeRunStreamChunkEvent. - This test should PASS according to requirements. - """ - runner = TableTestRunner() - - # Load the workflow configuration - fixture_data = runner.workflow_runner.load_fixture("conditional_streaming_vs_template_workflow") - - # Create graph from fixture with auto-mock enabled - graph, graph_runtime_state = runner.workflow_runner.create_graph_from_fixture( - fixture_data=fixture_data, - inputs={"query": "Hello, how are you?", "blocking": 2}, - use_mock_factory=True, - ) - - # Create and run the engine - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - # Execute the workflow - events = list(engine.run()) - - # Check for successful completion - success_events = [e for e in events if isinstance(e, GraphRunSucceededEvent)] - assert len(success_events) > 0, "Workflow should complete successfully" - - # Check for streaming events - expecting streaming events - stream_chunk_events = [e for e in events if isinstance(e, NodeRunStreamChunkEvent)] - stream_chunk_count = len(stream_chunk_events) - - # This assertion should PASS according to requirements - assert stream_chunk_count > 0, f"Expected streaming events when blocking!=1, but got {stream_chunk_count}" - - # We should have at least 2 chunks (query and newline) - assert stream_chunk_count >= 2, f"Expected at least 2 streaming events, but got {stream_chunk_count}" - - first_chunk, second_chunk = stream_chunk_events[0], stream_chunk_events[1] - assert first_chunk.chunk == "Hello, how are you?", ( - f"Expected first chunk to be user input, but got {first_chunk.chunk}" - ) - assert second_chunk.chunk == "\n", f"Expected second chunk to be newline, but got {second_chunk.chunk}" - - # Find indices of first LLM success event and first stream chunk event - llm2_start_index = next( - ( - i - for i, e in enumerate(events) - if isinstance(e, NodeRunSucceededEvent) and e.node_type == BuiltinNodeTypes.LLM - ), - -1, - ) - first_chunk_index = next( - (i for i, e in enumerate(events) if isinstance(e, NodeRunStreamChunkEvent)), - -1, - ) - - assert first_chunk_index < llm2_start_index, ( - f"Expected first chunk before LLM2 start, but got {first_chunk_index} and {llm2_start_index}" - ) - - # With auto-mock, the LLM will produce mock responses - just verify we have streaming chunks - # and they are strings - for chunk_event in stream_chunk_events[2:]: - assert isinstance(chunk_event.chunk, str), f"Expected chunk to be string, but got {type(chunk_event.chunk)}" - - # Check that NodeRunStreamChunkEvent contains 'query' should has same id with Start NodeRunStartedEvent - start_node_id = graph.root_node.id - start_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_id == start_node_id] - assert len(start_events) == 1, f"Expected 1 start event for node {start_node_id}, but got {len(start_events)}" - start_event = start_events[0] - query_chunk_events = [e for e in stream_chunk_events if e.chunk == "Hello, how are you?"] - assert all(e.id == start_event.id for e in query_chunk_events), "Expected all query chunk events to have same id" - - # Check all LLM's NodeRunStreamChunkEvent should be from LLM nodes - start_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == BuiltinNodeTypes.LLM] - llm_chunk_events = [e for e in stream_chunk_events if e.node_type == BuiltinNodeTypes.LLM] - llm_node_ids = {se.node_id for se in start_events} - assert all(e.node_id in llm_node_ids for e in llm_chunk_events), ( - "Expected all LLM chunk events to be from LLM nodes" - ) - - # Check that NodeRunStreamChunkEvent contains '\n' is from the End node - end_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == BuiltinNodeTypes.END] - assert len(end_events) == 1, f"Expected 1 end event, but got {len(end_events)}" - newline_chunk_events = [e for e in stream_chunk_events if e.chunk == "\n"] - assert len(newline_chunk_events) == 1, f"Expected 1 newline chunk event, but got {len(newline_chunk_events)}" - # The newline chunk should be from the End node (check node_id, not execution id) - assert all(e.node_id == end_events[0].node_id for e in newline_chunk_events), ( - "Expected all newline chunk events to be from End node" - ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_database_utils.py b/api/tests/unit_tests/core/workflow/graph_engine/test_database_utils.py deleted file mode 100644 index ae7dd48bb16..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_database_utils.py +++ /dev/null @@ -1,46 +0,0 @@ -""" -Utilities for detecting if database service is available for workflow tests. -""" - -import psycopg2 -import pytest - -from configs import dify_config - - -def is_database_available() -> bool: - """ - Check if the database service is available by attempting to connect to it. - - Returns: - True if database is available, False otherwise. - """ - try: - # Try to establish a database connection using a context manager - with psycopg2.connect( - host=dify_config.DB_HOST, - port=dify_config.DB_PORT, - database=dify_config.DB_DATABASE, - user=dify_config.DB_USERNAME, - password=dify_config.DB_PASSWORD, - connect_timeout=2, # 2 second timeout - ) as conn: - pass # Connection established and will be closed automatically - return True - except (psycopg2.OperationalError, psycopg2.Error): - return False - - -def skip_if_database_unavailable(): - """ - Pytest skip decorator that skips tests when database service is unavailable. - - Usage: - @skip_if_database_unavailable() - def test_my_workflow(): - ... - """ - return pytest.mark.skipif( - not is_database_available(), - reason="Database service is not available (connection refused or authentication failed)", - ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py b/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py deleted file mode 100644 index 778dad59527..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py +++ /dev/null @@ -1,72 +0,0 @@ -import queue -from datetime import datetime - -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.graph_engine.orchestration.dispatcher import Dispatcher -from dify_graph.graph_events import NodeRunSucceededEvent -from dify_graph.node_events import NodeRunResult - - -class StubExecutionCoordinator: - def __init__(self, paused: bool) -> None: - self._paused = paused - self.mark_complete_called = False - self.failed_error: Exception | None = None - - @property - def aborted(self) -> bool: - return False - - @property - def paused(self) -> bool: - return self._paused - - @property - def execution_complete(self) -> bool: - return False - - def check_scaling(self) -> None: - return None - - def process_commands(self) -> None: - return None - - def mark_complete(self) -> None: - self.mark_complete_called = True - - def mark_failed(self, error: Exception) -> None: - self.failed_error = error - - -class StubEventHandler: - def __init__(self) -> None: - self.events: list[object] = [] - - def dispatch(self, event: object) -> None: - self.events.append(event) - - -def test_dispatcher_drains_events_when_paused() -> None: - event_queue: queue.Queue = queue.Queue() - event = NodeRunSucceededEvent( - id="exec-1", - node_id="node-1", - node_type=BuiltinNodeTypes.START, - start_at=datetime.utcnow(), - node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED), - ) - event_queue.put(event) - - handler = StubEventHandler() - coordinator = StubExecutionCoordinator(paused=True) - dispatcher = Dispatcher( - event_queue=event_queue, - event_handler=handler, - execution_coordinator=coordinator, - event_emitter=None, - ) - - dispatcher._dispatcher_loop() - - assert handler.events == [event] - assert coordinator.mark_complete_called is True diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py b/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py deleted file mode 100644 index c87dc75b95e..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py +++ /dev/null @@ -1,60 +0,0 @@ -""" -Test case for end node without value_type field (backward compatibility). - -This test validates that end nodes work correctly even when the value_type -field is missing from the output configuration, ensuring backward compatibility -with older workflow definitions. -""" - -from dify_graph.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) - -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_end_node_without_value_type_field(): - """ - Test that end node works without explicit value_type field. - - The fixture implements a simple workflow that: - 1. Takes a query input from start node - 2. Passes it directly to end node - 3. End node outputs the value without specifying value_type - 4. Should correctly infer the type and output the value - - This ensures backward compatibility with workflow definitions - created before value_type became a required field. - """ - fixture_name = "end_node_without_value_type_field_workflow" - - case = WorkflowTestCase( - fixture_path=fixture_name, - inputs={"query": "test query"}, - expected_outputs={"query": "test query"}, - expected_event_sequence=[ - # Graph start - GraphRunStartedEvent, - # Start node - NodeRunStartedEvent, - NodeRunStreamChunkEvent, # Start node streams the input value - NodeRunSucceededEvent, - # End node - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Graph end - GraphRunSucceededEvent, - ], - description="End node without value_type field should work correctly", - ) - - runner = TableTestRunner() - result = runner.run_test_case(case) - assert result.success, f"Test failed: {result.error}" - assert result.actual_outputs == {"query": "test query"}, ( - f"Expected output to be {{'query': 'test query'}}, got {result.actual_outputs}" - ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py b/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py deleted file mode 100644 index 35406997edf..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py +++ /dev/null @@ -1,62 +0,0 @@ -"""Unit tests for the execution coordinator orchestration logic.""" - -from unittest.mock import MagicMock - -import pytest - -from dify_graph.graph_engine.command_processing.command_processor import CommandProcessor -from dify_graph.graph_engine.domain.graph_execution import GraphExecution -from dify_graph.graph_engine.graph_state_manager import GraphStateManager -from dify_graph.graph_engine.orchestration.execution_coordinator import ExecutionCoordinator -from dify_graph.graph_engine.worker_management.worker_pool import WorkerPool - - -def _build_coordinator(graph_execution: GraphExecution) -> tuple[ExecutionCoordinator, MagicMock, MagicMock]: - command_processor = MagicMock(spec=CommandProcessor) - state_manager = MagicMock(spec=GraphStateManager) - worker_pool = MagicMock(spec=WorkerPool) - - coordinator = ExecutionCoordinator( - graph_execution=graph_execution, - state_manager=state_manager, - command_processor=command_processor, - worker_pool=worker_pool, - ) - return coordinator, state_manager, worker_pool - - -def test_handle_pause_stops_workers_and_clears_state() -> None: - """Paused execution should stop workers and clear executing state.""" - graph_execution = GraphExecution(workflow_id="workflow") - graph_execution.start() - graph_execution.pause("Awaiting human input") - - coordinator, state_manager, worker_pool = _build_coordinator(graph_execution) - - coordinator.handle_pause_if_needed() - - worker_pool.stop.assert_called_once_with() - state_manager.clear_executing.assert_called_once_with() - - -def test_handle_pause_noop_when_execution_running() -> None: - """Running execution should not trigger pause handling.""" - graph_execution = GraphExecution(workflow_id="workflow") - graph_execution.start() - - coordinator, state_manager, worker_pool = _build_coordinator(graph_execution) - - coordinator.handle_pause_if_needed() - - worker_pool.stop.assert_not_called() - state_manager.clear_executing.assert_not_called() - - -def test_has_executing_nodes_requires_pause() -> None: - graph_execution = GraphExecution(workflow_id="workflow") - graph_execution.start() - - coordinator, _, _ = _build_coordinator(graph_execution) - - with pytest.raises(AssertionError): - coordinator.has_executing_nodes() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py deleted file mode 100644 index 4e13177d2b9..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ /dev/null @@ -1,770 +0,0 @@ -""" -Table-driven test framework for GraphEngine workflows. - -This file contains property-based tests and specific workflow tests. -The core test framework is in test_table_runner.py. -""" - -import time - -from hypothesis import HealthCheck, given, settings -from hypothesis import strategies as st - -from dify_graph.entities.base_node_data import DefaultValue, DefaultValueType -from dify_graph.enums import ErrorStrategy -from dify_graph.graph_engine import GraphEngine, GraphEngineConfig -from dify_graph.graph_engine.command_channels import InMemoryChannel -from dify_graph.graph_events import ( - GraphRunPartialSucceededEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, -) - -# Import the test framework from the new module -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowRunner, WorkflowTestCase - - -# Property-based fuzzing tests for the start-end workflow -@given(query_input=st.text()) -@settings(max_examples=50, deadline=30000, suppress_health_check=[HealthCheck.too_slow]) -def test_echo_workflow_property_basic_strings(query_input): - """ - Property-based test: Echo workflow should return exactly what was input. - - This tests the fundamental property that for any string input, - the start-end workflow should echo it back unchanged. - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": query_input}, - expected_outputs={"query": query_input}, - description=f"Fuzzing test with input: {repr(query_input)[:50]}...", - ) - - result = runner.run_test_case(test_case) - - # Property: The workflow should complete successfully - assert result.success, f"Workflow failed with input {repr(query_input)}: {result.error}" - - # Property: Output should equal input (echo behavior) - assert result.actual_outputs - assert result.actual_outputs == {"query": query_input}, ( - f"Echo property violated. Input: {repr(query_input)}, " - f"Expected: {repr(query_input)}, Got: {repr(result.actual_outputs.get('query'))}" - ) - - -@given(query_input=st.text(min_size=0, max_size=1000)) -@settings(max_examples=30, deadline=20000) -def test_echo_workflow_property_bounded_strings(query_input): - """ - Property-based test with size bounds to test edge cases more efficiently. - - Tests strings up to 1000 characters to balance thoroughness with performance. - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": query_input}, - expected_outputs={"query": query_input}, - description=f"Bounded fuzzing test (len={len(query_input)})", - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Workflow failed with bounded input: {result.error}" - assert result.actual_outputs == {"query": query_input} - - -@given( - query_input=st.one_of( - st.text(alphabet=st.characters(whitelist_categories=["Lu", "Ll", "Nd", "Po"])), # Letters, digits, punctuation - st.text(alphabet="🎉🌟💫⭐🔥💯🚀🎯"), # Emojis - st.text(alphabet="αβγδεζηθικλμνξοπρστυφχψω"), # Greek letters - st.text(alphabet="中文测试한국어日本語العربية"), # International characters - st.just(""), # Empty string - st.just(" " * 100), # Whitespace only - st.just("\n\t\r\f\v"), # Special whitespace chars - st.just('{"json": "like", "data": [1, 2, 3]}'), # JSON-like string - st.just("SELECT * FROM users; DROP TABLE users;--"), # SQL injection attempt - st.just(""), # XSS attempt - st.just("../../etc/passwd"), # Path traversal attempt - ) -) -@settings(max_examples=40, deadline=25000) -def test_echo_workflow_property_diverse_inputs(query_input): - """ - Property-based test with diverse input types including edge cases and security payloads. - - Tests various categories of potentially problematic inputs: - - Unicode characters from different languages - - Emojis and special symbols - - Whitespace variations - - Malicious payloads (SQL injection, XSS, path traversal) - - JSON-like structures - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": query_input}, - expected_outputs={"query": query_input}, - description=f"Diverse input fuzzing: {type(query_input).__name__}", - ) - - result = runner.run_test_case(test_case) - - # Property: System should handle all inputs gracefully (no crashes) - assert result.success, f"Workflow failed with diverse input {repr(query_input)}: {result.error}" - - # Property: Echo behavior must be preserved regardless of input type - assert result.actual_outputs == {"query": query_input} - - -@given(query_input=st.text(min_size=1000, max_size=5000)) -@settings(max_examples=10, deadline=60000) -def test_echo_workflow_property_large_inputs(query_input): - """ - Property-based test for large inputs to test memory and performance boundaries. - - Tests the system's ability to handle larger payloads efficiently. - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": query_input}, - expected_outputs={"query": query_input}, - description=f"Large input test (size: {len(query_input)} chars)", - timeout=45.0, # Longer timeout for large inputs - ) - - start_time = time.perf_counter() - result = runner.run_test_case(test_case) - execution_time = time.perf_counter() - start_time - - # Property: Large inputs should still work - assert result.success, f"Large input workflow failed: {result.error}" - - # Property: Echo behavior preserved for large inputs - assert result.actual_outputs == {"query": query_input} - - # Property: Performance should be reasonable even for large inputs - assert execution_time < 30.0, f"Large input took too long: {execution_time:.2f}s" - - -def test_echo_workflow_robustness_smoke_test(): - """ - Smoke test to ensure the basic workflow functionality works before fuzzing. - - This test uses a simple, known-good input to verify the test infrastructure - is working correctly before running the fuzzing tests. - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "smoke test"}, - expected_outputs={"query": "smoke test"}, - description="Smoke test for basic functionality", - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Smoke test failed: {result.error}" - assert result.actual_outputs == {"query": "smoke test"} - assert result.execution_time > 0 - - -def test_if_else_workflow_true_branch(): - """ - Test if-else workflow when input contains 'hello' (true branch). - - Should output {"true": input_query} when query contains "hello". - """ - runner = TableTestRunner() - - test_cases = [ - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "hello world"}, - expected_outputs={"true": "hello world"}, - description="Basic hello case", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "say hello to everyone"}, - expected_outputs={"true": "say hello to everyone"}, - description="Hello in middle of sentence", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "hello"}, - expected_outputs={"true": "hello"}, - description="Just hello", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "hellohello"}, - expected_outputs={"true": "hellohello"}, - description="Multiple hello occurrences", - ), - ] - - suite_result = runner.run_table_tests(test_cases) - - for result in suite_result.results: - assert result.success, f"Test case '{result.test_case.description}' failed: {result.error}" - # Check that outputs contain ONLY the expected key (true branch) - assert result.actual_outputs == result.test_case.expected_outputs, ( - f"Expected only 'true' key in outputs for {result.test_case.description}. " - f"Expected: {result.test_case.expected_outputs}, Got: {result.actual_outputs}" - ) - - -def test_if_else_workflow_false_branch(): - """ - Test if-else workflow when input does not contain 'hello' (false branch). - - Should output {"false": input_query} when query does not contain "hello". - """ - runner = TableTestRunner() - - test_cases = [ - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "goodbye world"}, - expected_outputs={"false": "goodbye world"}, - description="Basic goodbye case", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "hi there"}, - expected_outputs={"false": "hi there"}, - description="Simple greeting without hello", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": ""}, - expected_outputs={"false": ""}, - description="Empty string", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "test message"}, - expected_outputs={"false": "test message"}, - description="Regular message", - ), - ] - - suite_result = runner.run_table_tests(test_cases) - - for result in suite_result.results: - assert result.success, f"Test case '{result.test_case.description}' failed: {result.error}" - # Check that outputs contain ONLY the expected key (false branch) - assert result.actual_outputs == result.test_case.expected_outputs, ( - f"Expected only 'false' key in outputs for {result.test_case.description}. " - f"Expected: {result.test_case.expected_outputs}, Got: {result.actual_outputs}" - ) - - -def test_if_else_workflow_edge_cases(): - """ - Test if-else workflow edge cases and case sensitivity. - - Tests various edge cases including case sensitivity, similar words, etc. - """ - runner = TableTestRunner() - - test_cases = [ - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "Hello world"}, - expected_outputs={"false": "Hello world"}, - description="Capitalized Hello (case sensitive test)", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "HELLO"}, - expected_outputs={"false": "HELLO"}, - description="All caps HELLO (case sensitive test)", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "helllo"}, - expected_outputs={"false": "helllo"}, - description="Typo: helllo (with extra l)", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "helo"}, - expected_outputs={"false": "helo"}, - description="Typo: helo (missing l)", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "hello123"}, - expected_outputs={"true": "hello123"}, - description="Hello with numbers", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "hello!@#"}, - expected_outputs={"true": "hello!@#"}, - description="Hello with special characters", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": " hello "}, - expected_outputs={"true": " hello "}, - description="Hello with surrounding spaces", - ), - ] - - suite_result = runner.run_table_tests(test_cases) - - for result in suite_result.results: - assert result.success, f"Test case '{result.test_case.description}' failed: {result.error}" - # Check that outputs contain ONLY the expected key - assert result.actual_outputs == result.test_case.expected_outputs, ( - f"Expected exact match for {result.test_case.description}. " - f"Expected: {result.test_case.expected_outputs}, Got: {result.actual_outputs}" - ) - - -@given(query_input=st.text()) -@settings(max_examples=50, deadline=30000, suppress_health_check=[HealthCheck.too_slow]) -def test_if_else_workflow_property_basic_strings(query_input): - """ - Property-based test: If-else workflow should output correct branch based on 'hello' content. - - This tests the fundamental property that for any string input: - - If input contains "hello", output should be {"true": input} - - If input doesn't contain "hello", output should be {"false": input} - """ - runner = TableTestRunner() - - # Determine expected output based on whether input contains "hello" - contains_hello = "hello" in query_input - expected_key = "true" if contains_hello else "false" - expected_outputs = {expected_key: query_input} - - test_case = WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": query_input}, - expected_outputs=expected_outputs, - description=f"Property test with input: {repr(query_input)[:50]}...", - ) - - result = runner.run_test_case(test_case) - - # Property: The workflow should complete successfully - assert result.success, f"Workflow failed with input {repr(query_input)}: {result.error}" - - # Property: Output should contain ONLY the expected key with correct value - assert result.actual_outputs == expected_outputs, ( - f"If-else property violated. Input: {repr(query_input)}, " - f"Expected: {expected_outputs}, Got: {result.actual_outputs}" - ) - - -@given(query_input=st.text(min_size=0, max_size=1000)) -@settings(max_examples=30, deadline=20000) -def test_if_else_workflow_property_bounded_strings(query_input): - """ - Property-based test with size bounds for if-else workflow. - - Tests strings up to 1000 characters to balance thoroughness with performance. - """ - runner = TableTestRunner() - - contains_hello = "hello" in query_input - expected_key = "true" if contains_hello else "false" - expected_outputs = {expected_key: query_input} - - test_case = WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": query_input}, - expected_outputs=expected_outputs, - description=f"Bounded if-else test (len={len(query_input)}, contains_hello={contains_hello})", - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Workflow failed with bounded input: {result.error}" - assert result.actual_outputs == expected_outputs - - -@given( - query_input=st.one_of( - st.text(alphabet=st.characters(whitelist_categories=["Lu", "Ll", "Nd", "Po"])), # Letters, digits, punctuation - st.text(alphabet="hello"), # Strings that definitely contain hello - st.text(alphabet="xyz"), # Strings that definitely don't contain hello - st.just("hello world"), # Known true case - st.just("goodbye world"), # Known false case - st.just(""), # Empty string - st.just("Hello"), # Case sensitivity test - st.just("HELLO"), # Case sensitivity test - st.just("hello" * 10), # Multiple hello occurrences - st.just("say hello to everyone"), # Hello in middle - st.text(alphabet="🎉🌟💫⭐🔥💯🚀🎯"), # Emojis - st.text(alphabet="中文测试한국어日本語العربية"), # International characters - ) -) -@settings(max_examples=40, deadline=25000) -def test_if_else_workflow_property_diverse_inputs(query_input): - """ - Property-based test with diverse input types for if-else workflow. - - Tests various categories including: - - Known true/false cases - - Case sensitivity scenarios - - Unicode characters from different languages - - Emojis and special symbols - - Multiple hello occurrences - """ - runner = TableTestRunner() - - contains_hello = "hello" in query_input - expected_key = "true" if contains_hello else "false" - expected_outputs = {expected_key: query_input} - - test_case = WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": query_input}, - expected_outputs=expected_outputs, - description=f"Diverse if-else test: {type(query_input).__name__} (contains_hello={contains_hello})", - ) - - result = runner.run_test_case(test_case) - - # Property: System should handle all inputs gracefully (no crashes) - assert result.success, f"Workflow failed with diverse input {repr(query_input)}: {result.error}" - - # Property: Correct branch logic must be preserved regardless of input type - assert result.actual_outputs == expected_outputs, ( - f"Branch logic violated. Input: {repr(query_input)}, " - f"Contains 'hello': {contains_hello}, Expected: {expected_outputs}, Got: {result.actual_outputs}" - ) - - -# Tests for the Layer system -def test_layer_system_basic(): - """Test basic layer functionality with DebugLoggingLayer.""" - from dify_graph.graph_engine.layers import DebugLoggingLayer - - runner = WorkflowRunner() - - # Load a simple echo workflow - fixture_data = runner.load_fixture("simple_passthrough_workflow") - graph, graph_runtime_state = runner.create_graph_from_fixture(fixture_data, inputs={"query": "test layer system"}) - - # Create engine with layer - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - # Add debug logging layer - debug_layer = DebugLoggingLayer(level="DEBUG", include_inputs=True, include_outputs=True) - engine.layer(debug_layer) - - # Run workflow - events = list(engine.run()) - - # Verify events were generated - assert len(events) > 0 - assert isinstance(events[0], GraphRunStartedEvent) - assert isinstance(events[-1], GraphRunSucceededEvent) - - # Verify layer received context - assert debug_layer.graph_runtime_state is not None - assert debug_layer.command_channel is not None - - # Verify layer tracked execution stats - assert debug_layer.node_count > 0 - assert debug_layer.success_count > 0 - - -def test_layer_chaining(): - """Test chaining multiple layers.""" - from dify_graph.graph_engine.layers import DebugLoggingLayer, GraphEngineLayer - - # Create a custom test layer - class TestLayer(GraphEngineLayer): - def __init__(self): - super().__init__() - self.events_received = [] - self.graph_started = False - self.graph_ended = False - - def on_graph_start(self): - self.graph_started = True - - def on_event(self, event): - self.events_received.append(event.__class__.__name__) - - def on_graph_end(self, error): - self.graph_ended = True - - runner = WorkflowRunner() - - # Load workflow - fixture_data = runner.load_fixture("simple_passthrough_workflow") - graph, graph_runtime_state = runner.create_graph_from_fixture(fixture_data, inputs={"query": "test chaining"}) - - # Create engine - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - # Chain multiple layers - test_layer = TestLayer() - debug_layer = DebugLoggingLayer(level="INFO") - - engine.layer(test_layer).layer(debug_layer) - - # Run workflow - events = list(engine.run()) - - # Verify both layers received events - assert test_layer.graph_started - assert test_layer.graph_ended - assert len(test_layer.events_received) > 0 - - # Verify debug layer also worked - assert debug_layer.node_count > 0 - - -def test_layer_error_handling(): - """Test that layer errors don't crash the engine.""" - from dify_graph.graph_engine.layers import GraphEngineLayer - - # Create a layer that throws errors - class FaultyLayer(GraphEngineLayer): - def on_graph_start(self): - raise RuntimeError("Intentional error in on_graph_start") - - def on_event(self, event): - raise RuntimeError("Intentional error in on_event") - - def on_graph_end(self, error): - raise RuntimeError("Intentional error in on_graph_end") - - runner = WorkflowRunner() - - # Load workflow - fixture_data = runner.load_fixture("simple_passthrough_workflow") - graph, graph_runtime_state = runner.create_graph_from_fixture(fixture_data, inputs={"query": "test error handling"}) - - # Create engine with faulty layer - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - # Add faulty layer - engine.layer(FaultyLayer()) - - # Run workflow - should not crash despite layer errors - events = list(engine.run()) - - # Verify workflow still completed successfully - assert len(events) > 0 - assert isinstance(events[-1], GraphRunSucceededEvent) - assert events[-1].outputs == {"query": "test error handling"} - - -def test_event_sequence_validation(): - """Test the new event sequence validation feature.""" - from dify_graph.graph_events import NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent - - runner = TableTestRunner() - - # Test 1: Successful event sequence validation - test_case_success = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test event sequence"}, - expected_outputs={"query": "test event sequence"}, - expected_event_sequence=[ - GraphRunStartedEvent, - NodeRunStartedEvent, # Start node begins - NodeRunStreamChunkEvent, # Start node streaming - NodeRunSucceededEvent, # Start node completes - NodeRunStartedEvent, # End node begins - NodeRunSucceededEvent, # End node completes - GraphRunSucceededEvent, # Graph completes - ], - description="Test with correct event sequence", - ) - - result = runner.run_test_case(test_case_success) - assert result.success, f"Test should pass with correct event sequence. Error: {result.event_mismatch_details}" - assert result.event_sequence_match is True - assert result.event_mismatch_details is None - - # Test 2: Failed event sequence validation - wrong order - test_case_wrong_order = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test wrong order"}, - expected_outputs={"query": "test wrong order"}, - expected_event_sequence=[ - GraphRunStartedEvent, - NodeRunSucceededEvent, # Wrong: expecting success before start - NodeRunStreamChunkEvent, - NodeRunStartedEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ], - description="Test with incorrect event order", - ) - - result = runner.run_test_case(test_case_wrong_order) - assert not result.success, "Test should fail with incorrect event sequence" - assert result.event_sequence_match is False - assert result.event_mismatch_details is not None - assert "Event mismatch at position" in result.event_mismatch_details - - # Test 3: Failed event sequence validation - wrong count - test_case_wrong_count = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test wrong count"}, - expected_outputs={"query": "test wrong count"}, - expected_event_sequence=[ - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Missing the second node's events - GraphRunSucceededEvent, - ], - description="Test with incorrect event count", - ) - - result = runner.run_test_case(test_case_wrong_count) - assert not result.success, "Test should fail with incorrect event count" - assert result.event_sequence_match is False - assert result.event_mismatch_details is not None - assert "Event count mismatch" in result.event_mismatch_details - - # Test 4: No event sequence validation (backward compatibility) - test_case_no_validation = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test no validation"}, - expected_outputs={"query": "test no validation"}, - # No expected_event_sequence provided - description="Test without event sequence validation", - ) - - result = runner.run_test_case(test_case_no_validation) - assert result.success, "Test should pass when no event sequence is provided" - assert result.event_sequence_match is None - assert result.event_mismatch_details is None - - -def test_event_sequence_validation_with_table_tests(): - """Test event sequence validation with table-driven tests.""" - from dify_graph.graph_events import NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent - - runner = TableTestRunner() - - test_cases = [ - WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test1"}, - expected_outputs={"query": "test1"}, - expected_event_sequence=[ - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ], - description="Table test 1: Valid sequence", - ), - WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test2"}, - expected_outputs={"query": "test2"}, - # No event sequence validation for this test - description="Table test 2: No sequence validation", - ), - WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test3"}, - expected_outputs={"query": "test3"}, - expected_event_sequence=[ - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ], - description="Table test 3: Valid sequence", - ), - ] - - suite_result = runner.run_table_tests(test_cases) - - # Check all tests passed - for i, result in enumerate(suite_result.results): - if i == 1: # Test 2 has no event sequence validation - assert result.event_sequence_match is None - else: - assert result.event_sequence_match is True - assert result.success, f"Test {i + 1} failed: {result.event_mismatch_details or result.error}" - - -def test_graph_run_emits_partial_success_when_node_failure_recovered(): - runner = TableTestRunner() - - fixture_data = runner.workflow_runner.load_fixture("basic_chatflow") - mock_config = MockConfigBuilder().with_node_error("llm", "mock llm failure").build() - - graph, graph_runtime_state = runner.workflow_runner.create_graph_from_fixture( - fixture_data=fixture_data, - query="hello", - use_mock_factory=True, - mock_config=mock_config, - ) - - llm_node = graph.nodes["llm"] - base_node_data = llm_node.node_data - base_node_data.error_strategy = ErrorStrategy.DEFAULT_VALUE - base_node_data.default_value = [DefaultValue(key="text", value="fallback response", type=DefaultValueType.STRING)] - - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - events = list(engine.run()) - - assert isinstance(events[-1], GraphRunPartialSucceededEvent) - - partial_event = next(event for event in events if isinstance(event, GraphRunPartialSucceededEvent)) - assert partial_event.exceptions_count == 1 - assert partial_event.outputs.get("answer") == "fallback response" - - assert not any(isinstance(event, GraphRunSucceededEvent) for event in events) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py deleted file mode 100644 index 255784b77d2..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py +++ /dev/null @@ -1,196 +0,0 @@ -"""Unit tests for GraphExecution serialization helpers.""" - -from __future__ import annotations - -import json -from collections import deque -from unittest.mock import MagicMock - -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeState -from dify_graph.graph_engine.domain import GraphExecution -from dify_graph.graph_engine.response_coordinator import ResponseStreamCoordinator -from dify_graph.graph_engine.response_coordinator.path import Path -from dify_graph.graph_engine.response_coordinator.session import ResponseSession -from dify_graph.graph_events import NodeRunStreamChunkEvent -from dify_graph.nodes.base.template import Template, TextSegment, VariableSegment - - -class CustomGraphExecutionError(Exception): - """Custom exception used to verify error serialization.""" - - -def test_graph_execution_serialization_round_trip() -> None: - """GraphExecution serialization restores full aggregate state.""" - # Arrange - execution = GraphExecution(workflow_id="wf-1") - execution.start() - node_a = execution.get_or_create_node_execution("node-a") - node_a.mark_started(execution_id="exec-1") - node_a.increment_retry() - node_a.mark_failed("boom") - node_b = execution.get_or_create_node_execution("node-b") - node_b.mark_skipped() - execution.fail(CustomGraphExecutionError("serialization failure")) - - # Act - serialized = execution.dumps() - payload = json.loads(serialized) - restored = GraphExecution(workflow_id="wf-1") - restored.loads(serialized) - - # Assert - assert payload["type"] == "GraphExecution" - assert payload["version"] == "1.0" - assert restored.workflow_id == "wf-1" - assert restored.started is True - assert restored.completed is True - assert restored.aborted is False - assert isinstance(restored.error, CustomGraphExecutionError) - assert str(restored.error) == "serialization failure" - assert set(restored.node_executions) == {"node-a", "node-b"} - restored_node_a = restored.node_executions["node-a"] - assert restored_node_a.state is NodeState.TAKEN - assert restored_node_a.retry_count == 1 - assert restored_node_a.execution_id == "exec-1" - assert restored_node_a.error == "boom" - restored_node_b = restored.node_executions["node-b"] - assert restored_node_b.state is NodeState.SKIPPED - assert restored_node_b.retry_count == 0 - assert restored_node_b.execution_id is None - assert restored_node_b.error is None - - -def test_graph_execution_loads_replaces_existing_state() -> None: - """loads replaces existing runtime data with serialized snapshot.""" - # Arrange - source = GraphExecution(workflow_id="wf-2") - source.start() - source_node = source.get_or_create_node_execution("node-source") - source_node.mark_taken() - serialized = source.dumps() - - target = GraphExecution(workflow_id="wf-2") - target.start() - target.abort("pre-existing abort") - temp_node = target.get_or_create_node_execution("node-temp") - temp_node.increment_retry() - temp_node.mark_failed("temp error") - - # Act - target.loads(serialized) - - # Assert - assert target.aborted is False - assert target.error is None - assert target.started is True - assert target.completed is False - assert set(target.node_executions) == {"node-source"} - restored_node = target.node_executions["node-source"] - assert restored_node.state is NodeState.TAKEN - assert restored_node.retry_count == 0 - assert restored_node.execution_id is None - assert restored_node.error is None - - -def test_response_stream_coordinator_serialization_round_trip(monkeypatch) -> None: - """ResponseStreamCoordinator serialization restores coordinator internals.""" - - template_main = Template(segments=[TextSegment(text="Hi "), VariableSegment(selector=["node-source", "text"])]) - template_secondary = Template(segments=[TextSegment(text="secondary")]) - - class DummyNode: - def __init__(self, node_id: str, template: Template, execution_type: NodeExecutionType) -> None: - self.id = node_id - self.node_type = ( - BuiltinNodeTypes.ANSWER if execution_type == NodeExecutionType.RESPONSE else BuiltinNodeTypes.LLM - ) - self.execution_type = execution_type - self.state = NodeState.UNKNOWN - self.title = node_id - self.template = template - - def blocks_variable_output(self, *_args) -> bool: - return False - - response_node1 = DummyNode("response-1", template_main, NodeExecutionType.RESPONSE) - response_node2 = DummyNode("response-2", template_main, NodeExecutionType.RESPONSE) - response_node3 = DummyNode("response-3", template_main, NodeExecutionType.RESPONSE) - source_node = DummyNode("node-source", template_secondary, NodeExecutionType.EXECUTABLE) - - class DummyGraph: - def __init__(self) -> None: - self.nodes = { - response_node1.id: response_node1, - response_node2.id: response_node2, - response_node3.id: response_node3, - source_node.id: source_node, - } - self.edges: dict[str, object] = {} - self.root_node = response_node1 - - def get_outgoing_edges(self, _node_id: str): # pragma: no cover - not exercised - return [] - - def get_incoming_edges(self, _node_id: str): # pragma: no cover - not exercised - return [] - - graph = DummyGraph() - - def fake_from_node(cls, node: DummyNode) -> ResponseSession: - return ResponseSession(node_id=node.id, template=node.template) - - monkeypatch.setattr(ResponseSession, "from_node", classmethod(fake_from_node)) - - coordinator = ResponseStreamCoordinator(variable_pool=MagicMock(), graph=graph) # type: ignore[arg-type] - coordinator._response_nodes = {"response-1", "response-2", "response-3"} - coordinator._paths_maps = { - "response-1": [Path(edges=["edge-1"])], - "response-2": [Path(edges=[])], - "response-3": [Path(edges=["edge-2", "edge-3"])], - } - - active_session = ResponseSession(node_id="response-1", template=response_node1.template) - active_session.index = 1 - coordinator._active_session = active_session - waiting_session = ResponseSession(node_id="response-2", template=response_node2.template) - coordinator._waiting_sessions = deque([waiting_session]) - pending_session = ResponseSession(node_id="response-3", template=response_node3.template) - pending_session.index = 2 - coordinator._response_sessions = {"response-3": pending_session} - - coordinator._node_execution_ids = {"response-1": "exec-1"} - event = NodeRunStreamChunkEvent( - id="exec-1", - node_id="response-1", - node_type=BuiltinNodeTypes.ANSWER, - selector=["node-source", "text"], - chunk="chunk-1", - is_final=False, - ) - coordinator._stream_buffers = {("node-source", "text"): [event]} - coordinator._stream_positions = {("node-source", "text"): 1} - coordinator._closed_streams = {("node-source", "text")} - - serialized = coordinator.dumps() - - restored = ResponseStreamCoordinator(variable_pool=MagicMock(), graph=graph) # type: ignore[arg-type] - monkeypatch.setattr(ResponseSession, "from_node", classmethod(fake_from_node)) - restored.loads(serialized) - - assert restored._response_nodes == {"response-1", "response-2", "response-3"} - assert restored._paths_maps["response-1"][0].edges == ["edge-1"] - assert restored._active_session is not None - assert restored._active_session.node_id == "response-1" - assert restored._active_session.index == 1 - waiting_restored = list(restored._waiting_sessions) - assert len(waiting_restored) == 1 - assert waiting_restored[0].node_id == "response-2" - assert waiting_restored[0].index == 0 - assert set(restored._response_sessions) == {"response-3"} - assert restored._response_sessions["response-3"].index == 2 - assert restored._node_execution_ids == {"response-1": "exec-1"} - assert ("node-source", "text") in restored._stream_buffers - restored_event = restored._stream_buffers[("node-source", "text")][0] - assert restored_event.chunk == "chunk-1" - assert restored._stream_positions[("node-source", "text")] == 1 - assert ("node-source", "text") in restored._closed_streams diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py deleted file mode 100644 index d54f0be1904..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py +++ /dev/null @@ -1,190 +0,0 @@ -import time -from collections.abc import Mapping - -from dify_graph.entities import GraphInitParams -from dify_graph.enums import NodeState -from dify_graph.graph import Graph -from dify_graph.graph_engine.graph_state_manager import GraphStateManager -from dify_graph.graph_engine.ready_queue import InMemoryReadyQueue -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole -from dify_graph.nodes.end.end_node import EndNode -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.llm.entities import ( - ContextConfig, - LLMNodeChatModelMessage, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from tests.workflow_test_utils import build_test_graph_init_params - -from .test_mock_config import MockConfig -from .test_mock_nodes import MockLLMNode - - -def _build_runtime_state() -> GraphRuntimeState: - variable_pool = VariablePool( - system_variables=SystemVariable( - user_id="user", - app_id="app", - workflow_id="workflow", - workflow_execution_id="exec-1", - ), - user_inputs={}, - conversation_variables=[], - ) - return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - -def _build_llm_node( - *, - node_id: str, - runtime_state: GraphRuntimeState, - graph_init_params: GraphInitParams, - mock_config: MockConfig, -) -> MockLLMNode: - llm_data = LLMNodeData( - title=f"LLM {node_id}", - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text=f"Prompt {node_id}", - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - ) - llm_config = {"id": node_id, "data": llm_data.model_dump()} - return MockLLMNode( - id=llm_config["id"], - config=llm_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - mock_config=mock_config, - ) - - -def _build_graph(runtime_state: GraphRuntimeState) -> Graph: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = build_test_graph_init_params( - workflow_id="workflow", - graph_config=graph_config, - tenant_id="tenant", - app_id="app", - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - - start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} - start_node = StartNode( - id=start_config["id"], - config=start_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - ) - - mock_config = MockConfig() - llm_a = _build_llm_node( - node_id="llm_a", - runtime_state=runtime_state, - graph_init_params=graph_init_params, - mock_config=mock_config, - ) - llm_b = _build_llm_node( - node_id="llm_b", - runtime_state=runtime_state, - graph_init_params=graph_init_params, - mock_config=mock_config, - ) - - end_data = EndNodeData(title="End", outputs=[], desc=None) - end_config = {"id": "end", "data": end_data.model_dump()} - end_node = EndNode( - id=end_config["id"], - config=end_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - ) - - builder = ( - Graph.new() - .add_root(start_node) - .add_node(llm_a, from_node_id="start") - .add_node(llm_b, from_node_id="start") - .add_node(end_node, from_node_id="llm_a") - ) - return builder.connect(tail="llm_b", head="end").build() - - -def _edge_state_map(graph: Graph) -> Mapping[tuple[str, str, str], NodeState]: - return {(edge.tail, edge.head, edge.source_handle): edge.state for edge in graph.edges.values()} - - -def test_runtime_state_snapshot_restores_graph_states() -> None: - runtime_state = _build_runtime_state() - graph = _build_graph(runtime_state) - runtime_state.attach_graph(graph) - - graph.nodes["llm_a"].state = NodeState.TAKEN - graph.nodes["llm_b"].state = NodeState.SKIPPED - - for edge in graph.edges.values(): - if edge.tail == "start" and edge.head == "llm_a": - edge.state = NodeState.TAKEN - elif edge.tail == "start" and edge.head == "llm_b": - edge.state = NodeState.SKIPPED - elif edge.head == "end" and edge.tail == "llm_a": - edge.state = NodeState.TAKEN - elif edge.head == "end" and edge.tail == "llm_b": - edge.state = NodeState.SKIPPED - - snapshot = runtime_state.dumps() - - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - resumed_graph = _build_graph(resumed_state) - resumed_state.attach_graph(resumed_graph) - - assert resumed_graph.nodes["llm_a"].state == NodeState.TAKEN - assert resumed_graph.nodes["llm_b"].state == NodeState.SKIPPED - assert _edge_state_map(resumed_graph) == _edge_state_map(graph) - - -def test_join_readiness_uses_restored_edge_states() -> None: - runtime_state = _build_runtime_state() - graph = _build_graph(runtime_state) - runtime_state.attach_graph(graph) - - ready_queue = InMemoryReadyQueue() - state_manager = GraphStateManager(graph, ready_queue) - - for edge in graph.get_incoming_edges("end"): - if edge.tail == "llm_a": - edge.state = NodeState.TAKEN - if edge.tail == "llm_b": - edge.state = NodeState.UNKNOWN - - assert state_manager.is_node_ready("end") is False - - for edge in graph.get_incoming_edges("end"): - if edge.tail == "llm_b": - edge.state = NodeState.TAKEN - - assert state_manager.is_node_ready("end") is True - - snapshot = runtime_state.dumps() - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - resumed_graph = _build_graph(resumed_state) - resumed_state.attach_graph(resumed_graph) - - resumed_state_manager = GraphStateManager(resumed_graph, InMemoryReadyQueue()) - assert resumed_state_manager.is_node_ready("end") is True diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py deleted file mode 100644 index 538f53c6039..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py +++ /dev/null @@ -1,387 +0,0 @@ -import datetime -import time -from collections.abc import Iterable -from unittest import mock -from unittest.mock import MagicMock - -from dify_graph.graph import Graph -from dify_graph.graph_events import ( - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunPauseRequestedEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) -from dify_graph.graph_events.node import NodeRunHumanInputFormFilledEvent -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole -from dify_graph.nodes.base.entities import OutputVariableEntity, OutputVariableType -from dify_graph.nodes.end.end_node import EndNode -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.nodes.llm.entities import ( - ContextConfig, - LLMNodeChatModelMessage, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from libs.datetime_utils import naive_utc_now -from tests.workflow_test_utils import build_test_graph_init_params - -from .test_mock_config import MockConfig -from .test_mock_nodes import MockLLMNode -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def _build_branching_graph( - mock_config: MockConfig, - form_repository: HumanInputFormRepository, - graph_runtime_state: GraphRuntimeState | None = None, -) -> tuple[Graph, GraphRuntimeState]: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = build_test_graph_init_params( - workflow_id="workflow", - graph_config=graph_config, - tenant_id="tenant", - app_id="app", - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - - if graph_runtime_state is None: - variable_pool = VariablePool( - system_variables=SystemVariable( - user_id="user", - app_id="app", - workflow_id="workflow", - workflow_execution_id="test-execution-id", - ), - user_inputs={}, - conversation_variables=[], - ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} - start_node = StartNode( - id=start_config["id"], - config=start_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode: - llm_data = LLMNodeData( - title=title, - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text=prompt_text, - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - ) - llm_config = {"id": node_id, "data": llm_data.model_dump()} - llm_node = MockLLMNode( - id=node_id, - config=llm_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - credentials_provider=mock.Mock(), - model_factory=mock.Mock(), - ) - return llm_node - - llm_initial = _create_llm_node("llm_initial", "Initial LLM", "Initial stream") - - human_data = HumanInputNodeData( - title="Human Input", - form_content="Human input required", - inputs=[], - user_actions=[ - UserAction(id="primary", title="Primary"), - UserAction(id="secondary", title="Secondary"), - ], - ) - - human_config = {"id": "human", "data": human_data.model_dump()} - human_node = HumanInputNode( - id=human_config["id"], - config=human_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - form_repository=form_repository, - ) - - llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output") - llm_secondary = _create_llm_node("llm_secondary", "Secondary LLM", "Secondary") - - end_primary_data = EndNodeData( - title="End Primary", - outputs=[ - OutputVariableEntity( - variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"] - ), - OutputVariableEntity( - variable="primary_text", value_type=OutputVariableType.STRING, value_selector=["llm_primary", "text"] - ), - ], - desc=None, - ) - end_primary_config = {"id": "end_primary", "data": end_primary_data.model_dump()} - end_primary = EndNode( - id=end_primary_config["id"], - config=end_primary_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - end_secondary_data = EndNodeData( - title="End Secondary", - outputs=[ - OutputVariableEntity( - variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"] - ), - OutputVariableEntity( - variable="secondary_text", - value_type=OutputVariableType.STRING, - value_selector=["llm_secondary", "text"], - ), - ], - desc=None, - ) - end_secondary_config = {"id": "end_secondary", "data": end_secondary_data.model_dump()} - end_secondary = EndNode( - id=end_secondary_config["id"], - config=end_secondary_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - graph = ( - Graph.new() - .add_root(start_node) - .add_node(llm_initial) - .add_node(human_node) - .add_node(llm_primary, from_node_id="human", source_handle="primary") - .add_node(end_primary, from_node_id="llm_primary") - .add_node(llm_secondary, from_node_id="human", source_handle="secondary") - .add_node(end_secondary, from_node_id="llm_secondary") - .build() - ) - return graph, graph_runtime_state - - -def _expected_mock_llm_chunks(text: str) -> list[str]: - chunks: list[str] = [] - for index, word in enumerate(text.split(" ")): - chunk = word if index == 0 else f" {word}" - chunks.append(chunk) - chunks.append("") - return chunks - - -def _assert_stream_chunk_sequence( - chunk_events: Iterable[NodeRunStreamChunkEvent], - expected_nodes: list[str], - expected_chunks: list[str], -) -> None: - actual_nodes = [event.node_id for event in chunk_events] - actual_chunks = [event.chunk for event in chunk_events] - assert actual_nodes == expected_nodes - assert actual_chunks == expected_chunks - - -def test_human_input_llm_streaming_across_multiple_branches() -> None: - mock_config = MockConfig() - mock_config.set_node_outputs("llm_initial", {"text": "Initial stream"}) - mock_config.set_node_outputs("llm_primary", {"text": "Primary stream output"}) - mock_config.set_node_outputs("llm_secondary", {"text": "Secondary"}) - - branch_scenarios = [ - { - "handle": "primary", - "resume_llm": "llm_primary", - "end_node": "end_primary", - "expected_pre_chunks": [ - ("llm_initial", _expected_mock_llm_chunks("Initial stream")), # cached output before branch completes - ("end_primary", ["\n"]), # literal segment emitted when end_primary session activates - ], - "expected_post_chunks": [ - ("llm_primary", _expected_mock_llm_chunks("Primary stream output")), # live stream from chosen branch - ], - }, - { - "handle": "secondary", - "resume_llm": "llm_secondary", - "end_node": "end_secondary", - "expected_pre_chunks": [ - ("llm_initial", _expected_mock_llm_chunks("Initial stream")), # cached output before branch completes - ("end_secondary", ["\n"]), # literal segment emitted when end_secondary session activates - ], - "expected_post_chunks": [ - ("llm_secondary", _expected_mock_llm_chunks("Secondary")), # live stream from chosen branch - ], - }, - ] - - for scenario in branch_scenarios: - runner = TableTestRunner() - - mock_create_repo = MagicMock(spec=HumanInputFormRepository) - mock_create_repo.get_form.return_value = None - mock_form_entity = MagicMock(spec=HumanInputFormEntity) - mock_form_entity.id = "test_form_id" - mock_form_entity.web_app_token = "test_web_app_token" - mock_form_entity.recipients = [] - mock_form_entity.rendered_content = "rendered" - mock_form_entity.submitted = False - mock_create_repo.create_form.return_value = mock_form_entity - - def initial_graph_factory(mock_create_repo=mock_create_repo) -> tuple[Graph, GraphRuntimeState]: - return _build_branching_graph(mock_config, mock_create_repo) - - initial_case = WorkflowTestCase( - description="HumanInput pause before branching decision", - graph_factory=initial_graph_factory, - expected_event_sequence=[ - GraphRunStartedEvent, # initial run: graph execution starts - NodeRunStartedEvent, # start node begins execution - NodeRunSucceededEvent, # start node completes - NodeRunStartedEvent, # llm_initial starts streaming - NodeRunSucceededEvent, # llm_initial completes streaming - NodeRunStartedEvent, # human node begins and issues pause - NodeRunPauseRequestedEvent, # human node requests pause awaiting input - GraphRunPausedEvent, # graph run pauses awaiting resume - ], - ) - - initial_result = runner.run_test_case(initial_case) - - assert initial_result.success, initial_result.event_mismatch_details - assert not any(isinstance(event, NodeRunStreamChunkEvent) for event in initial_result.events) - - pre_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_pre_chunks"]) - post_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_post_chunks"]) - expected_pre_chunk_events_in_resumption = [ - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunHumanInputFormFilledEvent, - ] - - expected_resume_sequence: list[type] = ( - expected_pre_chunk_events_in_resumption - + [NodeRunStreamChunkEvent] * pre_chunk_count - + [ - NodeRunSucceededEvent, - NodeRunStartedEvent, - ] - + [NodeRunStreamChunkEvent] * post_chunk_count - + [ - NodeRunSucceededEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ] - ) - - mock_get_repo = MagicMock(spec=HumanInputFormRepository) - submitted_form = MagicMock(spec=HumanInputFormEntity) - submitted_form.id = mock_form_entity.id - submitted_form.web_app_token = mock_form_entity.web_app_token - submitted_form.recipients = [] - submitted_form.rendered_content = mock_form_entity.rendered_content - submitted_form.submitted = True - submitted_form.selected_action_id = scenario["handle"] - submitted_form.submitted_data = {} - submitted_form.expiration_time = naive_utc_now() + datetime.timedelta(days=1) - mock_get_repo.get_form.return_value = submitted_form - - def resume_graph_factory( - initial_result=initial_result, mock_get_repo=mock_get_repo - ) -> tuple[Graph, GraphRuntimeState]: - assert initial_result.graph_runtime_state is not None - serialized_runtime_state = initial_result.graph_runtime_state.dumps() - resume_runtime_state = GraphRuntimeState.from_snapshot(serialized_runtime_state) - return _build_branching_graph(mock_config, mock_get_repo, resume_runtime_state) - - resume_case = WorkflowTestCase( - description=f"HumanInput resumes via {scenario['handle']} branch", - graph_factory=resume_graph_factory, - expected_event_sequence=expected_resume_sequence, - ) - - resume_result = runner.run_test_case(resume_case) - - assert resume_result.success, resume_result.event_mismatch_details - - resume_events = resume_result.events - - chunk_events = [event for event in resume_events if isinstance(event, NodeRunStreamChunkEvent)] - assert len(chunk_events) == pre_chunk_count + post_chunk_count - - pre_chunk_events = chunk_events[:pre_chunk_count] - post_chunk_events = chunk_events[pre_chunk_count:] - - expected_pre_nodes: list[str] = [] - expected_pre_chunks: list[str] = [] - for node_id, chunks in scenario["expected_pre_chunks"]: - expected_pre_nodes.extend([node_id] * len(chunks)) - expected_pre_chunks.extend(chunks) - _assert_stream_chunk_sequence(pre_chunk_events, expected_pre_nodes, expected_pre_chunks) - - expected_post_nodes: list[str] = [] - expected_post_chunks: list[str] = [] - for node_id, chunks in scenario["expected_post_chunks"]: - expected_post_nodes.extend([node_id] * len(chunks)) - expected_post_chunks.extend(chunks) - _assert_stream_chunk_sequence(post_chunk_events, expected_post_nodes, expected_post_chunks) - - human_success_index = next( - index - for index, event in enumerate(resume_events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == "human" - ) - pre_indices = [ - index - for index, event in enumerate(resume_events) - if isinstance(event, NodeRunStreamChunkEvent) and index < human_success_index - ] - expected_pre_chunk_events_count_in_resumption = len(expected_pre_chunk_events_in_resumption) - assert pre_indices == list(range(expected_pre_chunk_events_count_in_resumption, human_success_index)) - - resume_chunk_indices = [ - index - for index, event in enumerate(resume_events) - if isinstance(event, NodeRunStreamChunkEvent) and event.node_id == scenario["resume_llm"] - ] - assert resume_chunk_indices, "Expected streaming output from the selected branch" - resume_start_index = next( - index - for index, event in enumerate(resume_events) - if isinstance(event, NodeRunStartedEvent) and event.node_id == scenario["resume_llm"] - ) - resume_success_index = next( - index - for index, event in enumerate(resume_events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == scenario["resume_llm"] - ) - assert resume_start_index < min(resume_chunk_indices) - assert max(resume_chunk_indices) < resume_success_index - - started_nodes = [event.node_id for event in resume_events if isinstance(event, NodeRunStartedEvent)] - assert started_nodes == ["human", scenario["resume_llm"], scenario["end_node"]] diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py deleted file mode 100644 index 36bba6deb66..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py +++ /dev/null @@ -1,344 +0,0 @@ -import datetime -import time -from unittest import mock -from unittest.mock import MagicMock - -from dify_graph.graph import Graph -from dify_graph.graph_events import ( - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunPauseRequestedEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) -from dify_graph.graph_events.node import NodeRunHumanInputFormFilledEvent -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole -from dify_graph.nodes.base.entities import OutputVariableEntity, OutputVariableType -from dify_graph.nodes.end.end_node import EndNode -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.nodes.llm.entities import ( - ContextConfig, - LLMNodeChatModelMessage, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from libs.datetime_utils import naive_utc_now -from tests.workflow_test_utils import build_test_graph_init_params - -from .test_mock_config import MockConfig -from .test_mock_nodes import MockLLMNode -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def _build_llm_human_llm_graph( - mock_config: MockConfig, - form_repository: HumanInputFormRepository, - graph_runtime_state: GraphRuntimeState | None = None, -) -> tuple[Graph, GraphRuntimeState]: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = build_test_graph_init_params( - workflow_id="workflow", - graph_config=graph_config, - tenant_id="tenant", - app_id="app", - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - - if graph_runtime_state is None: - variable_pool = VariablePool( - system_variables=SystemVariable( - user_id="user", app_id="app", workflow_id="workflow", workflow_execution_id="test-execution-id," - ), - user_inputs={}, - conversation_variables=[], - ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} - start_node = StartNode( - id=start_config["id"], - config=start_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode: - llm_data = LLMNodeData( - title=title, - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text=prompt_text, - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - ) - llm_config = {"id": node_id, "data": llm_data.model_dump()} - llm_node = MockLLMNode( - id=node_id, - config=llm_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - credentials_provider=mock.Mock(), - model_factory=mock.Mock(), - ) - return llm_node - - llm_first = _create_llm_node("llm_initial", "Initial LLM", "Initial prompt") - - human_data = HumanInputNodeData( - title="Human Input", - form_content="Human input required", - inputs=[], - user_actions=[ - UserAction(id="accept", title="Accept"), - UserAction(id="reject", title="Reject"), - ], - ) - - human_config = {"id": "human", "data": human_data.model_dump()} - human_node = HumanInputNode( - id=human_config["id"], - config=human_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - form_repository=form_repository, - ) - - llm_second = _create_llm_node("llm_resume", "Follow-up LLM", "Follow-up prompt") - - end_data = EndNodeData( - title="End", - outputs=[ - OutputVariableEntity( - variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"] - ), - OutputVariableEntity( - variable="resume_text", value_type=OutputVariableType.STRING, value_selector=["llm_resume", "text"] - ), - ], - desc=None, - ) - end_config = {"id": "end", "data": end_data.model_dump()} - end_node = EndNode( - id=end_config["id"], - config=end_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - graph = ( - Graph.new() - .add_root(start_node) - .add_node(llm_first) - .add_node(human_node) - .add_node(llm_second, source_handle="accept") - .add_node(end_node) - .build() - ) - return graph, graph_runtime_state - - -def _expected_mock_llm_chunks(text: str) -> list[str]: - chunks: list[str] = [] - for index, word in enumerate(text.split(" ")): - chunk = word if index == 0 else f" {word}" - chunks.append(chunk) - chunks.append("") - return chunks - - -def test_human_input_llm_streaming_order_across_pause() -> None: - runner = TableTestRunner() - - initial_text = "Hello, pause" - resume_text = "Welcome back!" - - mock_config = MockConfig() - mock_config.set_node_outputs("llm_initial", {"text": initial_text}) - mock_config.set_node_outputs("llm_resume", {"text": resume_text}) - - expected_initial_sequence: list[type] = [ - GraphRunStartedEvent, # graph run begins - NodeRunStartedEvent, # start node begins - NodeRunSucceededEvent, # start node completes - NodeRunStartedEvent, # llm_initial begins streaming - NodeRunSucceededEvent, # llm_initial completes streaming - NodeRunStartedEvent, # human node begins and requests pause - NodeRunPauseRequestedEvent, # human node pause requested - GraphRunPausedEvent, # graph run pauses awaiting resume - ] - - mock_create_repo = MagicMock(spec=HumanInputFormRepository) - mock_create_repo.get_form.return_value = None - mock_form_entity = MagicMock(spec=HumanInputFormEntity) - mock_form_entity.id = "test_form_id" - mock_form_entity.web_app_token = "test_web_app_token" - mock_form_entity.recipients = [] - mock_form_entity.rendered_content = "rendered" - mock_form_entity.submitted = False - mock_create_repo.create_form.return_value = mock_form_entity - - def graph_factory() -> tuple[Graph, GraphRuntimeState]: - return _build_llm_human_llm_graph(mock_config, mock_create_repo) - - initial_case = WorkflowTestCase( - description="HumanInput pause preserves LLM streaming order", - graph_factory=graph_factory, - expected_event_sequence=expected_initial_sequence, - ) - - initial_result = runner.run_test_case(initial_case) - - assert initial_result.success, initial_result.event_mismatch_details - - initial_events = initial_result.events - initial_chunks = _expected_mock_llm_chunks(initial_text) - - initial_stream_chunk_events = [event for event in initial_events if isinstance(event, NodeRunStreamChunkEvent)] - assert initial_stream_chunk_events == [] - - pause_index = next(i for i, event in enumerate(initial_events) if isinstance(event, GraphRunPausedEvent)) - llm_succeeded_index = next( - i - for i, event in enumerate(initial_events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == "llm_initial" - ) - assert llm_succeeded_index < pause_index - - graph_runtime_state = initial_result.graph_runtime_state - graph = initial_result.graph - assert graph_runtime_state is not None - assert graph is not None - - coordinator = graph_runtime_state.response_coordinator - stream_buffers = coordinator._stream_buffers # Tests may access internals for assertions - assert ("llm_initial", "text") in stream_buffers - initial_stream_chunks = [event.chunk for event in stream_buffers[("llm_initial", "text")]] - assert initial_stream_chunks == initial_chunks - assert ("llm_resume", "text") not in stream_buffers - - resume_chunks = _expected_mock_llm_chunks(resume_text) - expected_resume_sequence: list[type] = [ - GraphRunStartedEvent, # resumed graph run begins - NodeRunStartedEvent, # human node restarts - # Form Filled should be generated first, then the node execution ends and stream chunk is generated. - NodeRunHumanInputFormFilledEvent, - NodeRunStreamChunkEvent, # cached llm_initial chunk 1 - NodeRunStreamChunkEvent, # cached llm_initial chunk 2 - NodeRunStreamChunkEvent, # cached llm_initial final chunk - NodeRunStreamChunkEvent, # end node emits combined template separator - NodeRunSucceededEvent, # human node finishes instantly after input - NodeRunStartedEvent, # llm_resume begins streaming - NodeRunStreamChunkEvent, # llm_resume chunk 1 - NodeRunStreamChunkEvent, # llm_resume chunk 2 - NodeRunStreamChunkEvent, # llm_resume final chunk - NodeRunSucceededEvent, # llm_resume completes streaming - NodeRunStartedEvent, # end node starts - NodeRunSucceededEvent, # end node finishes - GraphRunSucceededEvent, # graph run succeeds after resume - ] - - mock_get_repo = MagicMock(spec=HumanInputFormRepository) - submitted_form = MagicMock(spec=HumanInputFormEntity) - submitted_form.id = mock_form_entity.id - submitted_form.web_app_token = mock_form_entity.web_app_token - submitted_form.recipients = [] - submitted_form.rendered_content = mock_form_entity.rendered_content - submitted_form.submitted = True - submitted_form.selected_action_id = "accept" - submitted_form.submitted_data = {} - submitted_form.expiration_time = naive_utc_now() + datetime.timedelta(days=1) - mock_get_repo.get_form.return_value = submitted_form - - def resume_graph_factory() -> tuple[Graph, GraphRuntimeState]: - # restruct the graph runtime state - serialized_runtime_state = initial_result.graph_runtime_state.dumps() - resume_runtime_state = GraphRuntimeState.from_snapshot(serialized_runtime_state) - return _build_llm_human_llm_graph( - mock_config, - mock_get_repo, - resume_runtime_state, - ) - - resume_case = WorkflowTestCase( - description="HumanInput resume continues LLM streaming order", - graph_factory=resume_graph_factory, - expected_event_sequence=expected_resume_sequence, - ) - - resume_result = runner.run_test_case(resume_case) - - assert resume_result.success, resume_result.event_mismatch_details - - resume_events = resume_result.events - - success_index = next(i for i, event in enumerate(resume_events) if isinstance(event, GraphRunSucceededEvent)) - llm_resume_succeeded_index = next( - i - for i, event in enumerate(resume_events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == "llm_resume" - ) - assert llm_resume_succeeded_index < success_index - - resume_chunk_events = [event for event in resume_events if isinstance(event, NodeRunStreamChunkEvent)] - assert [event.node_id for event in resume_chunk_events[:3]] == ["llm_initial"] * 3 - assert [event.chunk for event in resume_chunk_events[:3]] == initial_chunks - assert resume_chunk_events[3].node_id == "end" - assert resume_chunk_events[3].chunk == "\n" - assert [event.node_id for event in resume_chunk_events[4:]] == ["llm_resume"] * 3 - assert [event.chunk for event in resume_chunk_events[4:]] == resume_chunks - - human_success_index = next( - i - for i, event in enumerate(resume_events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == "human" - ) - cached_chunk_indices = [ - i - for i, event in enumerate(resume_events) - if isinstance(event, NodeRunStreamChunkEvent) and event.node_id in {"llm_initial", "end"} - ] - assert all(index < human_success_index for index in cached_chunk_indices) - - llm_resume_start_index = next( - i - for i, event in enumerate(resume_events) - if isinstance(event, NodeRunStartedEvent) and event.node_id == "llm_resume" - ) - llm_resume_success_index = next( - i - for i, event in enumerate(resume_events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == "llm_resume" - ) - llm_resume_chunk_indices = [ - i - for i, event in enumerate(resume_events) - if isinstance(event, NodeRunStreamChunkEvent) and event.node_id == "llm_resume" - ] - assert llm_resume_chunk_indices - first_resume_chunk_index = min(llm_resume_chunk_indices) - last_resume_chunk_index = max(llm_resume_chunk_indices) - assert llm_resume_start_index < first_resume_chunk_index - assert last_resume_chunk_index < llm_resume_success_index - - started_nodes = [event.node_id for event in resume_events if isinstance(event, NodeRunStartedEvent)] - assert started_nodes == ["human", "llm_resume", "end"] diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py b/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py deleted file mode 100644 index 8da179c15ea..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py +++ /dev/null @@ -1,324 +0,0 @@ -import time -from unittest import mock - -from dify_graph.graph import Graph -from dify_graph.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole -from dify_graph.nodes.base.entities import OutputVariableEntity, OutputVariableType -from dify_graph.nodes.end.end_node import EndNode -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.if_else.entities import IfElseNodeData -from dify_graph.nodes.if_else.if_else_node import IfElseNode -from dify_graph.nodes.llm.entities import ( - ContextConfig, - LLMNodeChatModelMessage, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.utils.condition.entities import Condition -from tests.workflow_test_utils import build_test_graph_init_params - -from .test_mock_config import MockConfig -from .test_mock_nodes import MockLLMNode -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = build_test_graph_init_params( - graph_config=graph_config, - user_from="account", - invoke_from="debugger", - ) - - variable_pool = VariablePool( - system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"), - user_inputs={}, - conversation_variables=[], - ) - variable_pool.add(("branch", "value"), branch_value) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} - start_node = StartNode( - id=start_config["id"], - config=start_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode: - llm_data = LLMNodeData( - title=title, - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text=prompt_text, - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - ) - llm_config = {"id": node_id, "data": llm_data.model_dump()} - llm_node = MockLLMNode( - id=node_id, - config=llm_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - credentials_provider=mock.Mock(), - model_factory=mock.Mock(), - ) - return llm_node - - llm_initial = _create_llm_node("llm_initial", "Initial LLM", "Initial stream") - - if_else_data = IfElseNodeData( - title="IfElse", - cases=[ - IfElseNodeData.Case( - case_id="primary", - logical_operator="and", - conditions=[ - Condition(variable_selector=["branch", "value"], comparison_operator="is", value="primary") - ], - ), - IfElseNodeData.Case( - case_id="secondary", - logical_operator="and", - conditions=[ - Condition(variable_selector=["branch", "value"], comparison_operator="is", value="secondary") - ], - ), - ], - ) - if_else_config = {"id": "if_else", "data": if_else_data.model_dump()} - if_else_node = IfElseNode( - id=if_else_config["id"], - config=if_else_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output") - llm_secondary = _create_llm_node("llm_secondary", "Secondary LLM", "Secondary") - - end_primary_data = EndNodeData( - title="End Primary", - outputs=[ - OutputVariableEntity( - variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"] - ), - OutputVariableEntity( - variable="primary_text", value_type=OutputVariableType.STRING, value_selector=["llm_primary", "text"] - ), - ], - desc=None, - ) - end_primary_config = {"id": "end_primary", "data": end_primary_data.model_dump()} - end_primary = EndNode( - id=end_primary_config["id"], - config=end_primary_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - end_secondary_data = EndNodeData( - title="End Secondary", - outputs=[ - OutputVariableEntity( - variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"] - ), - OutputVariableEntity( - variable="secondary_text", - value_type=OutputVariableType.STRING, - value_selector=["llm_secondary", "text"], - ), - ], - desc=None, - ) - end_secondary_config = {"id": "end_secondary", "data": end_secondary_data.model_dump()} - end_secondary = EndNode( - id=end_secondary_config["id"], - config=end_secondary_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - graph = ( - Graph.new() - .add_root(start_node) - .add_node(llm_initial) - .add_node(if_else_node) - .add_node(llm_primary, from_node_id="if_else", source_handle="primary") - .add_node(end_primary, from_node_id="llm_primary") - .add_node(llm_secondary, from_node_id="if_else", source_handle="secondary") - .add_node(end_secondary, from_node_id="llm_secondary") - .build() - ) - return graph, graph_runtime_state - - -def _expected_mock_llm_chunks(text: str) -> list[str]: - chunks: list[str] = [] - for index, word in enumerate(text.split(" ")): - chunk = word if index == 0 else f" {word}" - chunks.append(chunk) - chunks.append("") - return chunks - - -def test_if_else_llm_streaming_order() -> None: - mock_config = MockConfig() - mock_config.set_node_outputs("llm_initial", {"text": "Initial stream"}) - mock_config.set_node_outputs("llm_primary", {"text": "Primary stream output"}) - mock_config.set_node_outputs("llm_secondary", {"text": "Secondary"}) - - scenarios = [ - { - "branch": "primary", - "resume_llm": "llm_primary", - "end_node": "end_primary", - "expected_sequence": [ - GraphRunStartedEvent, # graph run begins - NodeRunStartedEvent, # start node begins execution - NodeRunSucceededEvent, # start node completes - NodeRunStartedEvent, # llm_initial starts and streams - NodeRunSucceededEvent, # llm_initial completes streaming - NodeRunStartedEvent, # if_else evaluates conditions - NodeRunStreamChunkEvent, # cached llm_initial chunk 1 flushed - NodeRunStreamChunkEvent, # cached llm_initial chunk 2 flushed - NodeRunStreamChunkEvent, # cached llm_initial final chunk flushed - NodeRunStreamChunkEvent, # template literal newline emitted - NodeRunSucceededEvent, # if_else completes branch selection - NodeRunStartedEvent, # llm_primary begins streaming - NodeRunStreamChunkEvent, # llm_primary chunk 1 - NodeRunStreamChunkEvent, # llm_primary chunk 2 - NodeRunStreamChunkEvent, # llm_primary chunk 3 - NodeRunStreamChunkEvent, # llm_primary final chunk - NodeRunSucceededEvent, # llm_primary completes streaming - NodeRunStartedEvent, # end_primary node starts - NodeRunSucceededEvent, # end_primary finishes aggregation - GraphRunSucceededEvent, # graph run succeeds - ], - "expected_chunks": [ - ("llm_initial", _expected_mock_llm_chunks("Initial stream")), - ("end_primary", ["\n"]), - ("llm_primary", _expected_mock_llm_chunks("Primary stream output")), - ], - }, - { - "branch": "secondary", - "resume_llm": "llm_secondary", - "end_node": "end_secondary", - "expected_sequence": [ - GraphRunStartedEvent, # graph run begins - NodeRunStartedEvent, # start node begins execution - NodeRunSucceededEvent, # start node completes - NodeRunStartedEvent, # llm_initial starts and streams - NodeRunSucceededEvent, # llm_initial completes streaming - NodeRunStartedEvent, # if_else evaluates conditions - NodeRunStreamChunkEvent, # cached llm_initial chunk 1 flushed - NodeRunStreamChunkEvent, # cached llm_initial chunk 2 flushed - NodeRunStreamChunkEvent, # cached llm_initial final chunk flushed - NodeRunStreamChunkEvent, # template literal newline emitted - NodeRunSucceededEvent, # if_else completes branch selection - NodeRunStartedEvent, # llm_secondary begins streaming - NodeRunStreamChunkEvent, # llm_secondary chunk 1 - NodeRunStreamChunkEvent, # llm_secondary final chunk - NodeRunSucceededEvent, # llm_secondary completes - NodeRunStartedEvent, # end_secondary node starts - NodeRunSucceededEvent, # end_secondary finishes aggregation - GraphRunSucceededEvent, # graph run succeeds - ], - "expected_chunks": [ - ("llm_initial", _expected_mock_llm_chunks("Initial stream")), - ("end_secondary", ["\n"]), - ("llm_secondary", _expected_mock_llm_chunks("Secondary")), - ], - }, - ] - - for scenario in scenarios: - runner = TableTestRunner() - - def graph_factory( - branch_value: str = scenario["branch"], - cfg: MockConfig = mock_config, - ) -> tuple[Graph, GraphRuntimeState]: - return _build_if_else_graph(branch_value, cfg) - - test_case = WorkflowTestCase( - description=f"IfElse streaming via {scenario['branch']} branch", - graph_factory=graph_factory, - expected_event_sequence=scenario["expected_sequence"], - ) - - result = runner.run_test_case(test_case) - - assert result.success, result.event_mismatch_details - - chunk_events = [event for event in result.events if isinstance(event, NodeRunStreamChunkEvent)] - expected_nodes: list[str] = [] - expected_chunks: list[str] = [] - for node_id, chunks in scenario["expected_chunks"]: - expected_nodes.extend([node_id] * len(chunks)) - expected_chunks.extend(chunks) - assert [event.node_id for event in chunk_events] == expected_nodes - assert [event.chunk for event in chunk_events] == expected_chunks - - branch_node_index = next( - index - for index, event in enumerate(result.events) - if isinstance(event, NodeRunStartedEvent) and event.node_id == "if_else" - ) - branch_success_index = next( - index - for index, event in enumerate(result.events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == "if_else" - ) - pre_branch_chunk_indices = [ - index - for index, event in enumerate(result.events) - if isinstance(event, NodeRunStreamChunkEvent) and index < branch_success_index - ] - assert len(pre_branch_chunk_indices) == len(_expected_mock_llm_chunks("Initial stream")) + 1 - assert min(pre_branch_chunk_indices) == branch_node_index + 1 - assert max(pre_branch_chunk_indices) < branch_success_index - - resume_chunk_indices = [ - index - for index, event in enumerate(result.events) - if isinstance(event, NodeRunStreamChunkEvent) and event.node_id == scenario["resume_llm"] - ] - assert resume_chunk_indices - resume_start_index = next( - index - for index, event in enumerate(result.events) - if isinstance(event, NodeRunStartedEvent) and event.node_id == scenario["resume_llm"] - ) - resume_success_index = next( - index - for index, event in enumerate(result.events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == scenario["resume_llm"] - ) - assert resume_start_index < min(resume_chunk_indices) - assert max(resume_chunk_indices) < resume_success_index - - started_nodes = [event.node_id for event in result.events if isinstance(event, NodeRunStartedEvent)] - assert started_nodes == ["start", "llm_initial", "if_else", scenario["resume_llm"], scenario["end_node"]] diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_iteration_flatten_output.py b/api/tests/unit_tests/core/workflow/graph_engine/test_iteration_flatten_output.py deleted file mode 100644 index b9bf4be13a0..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_iteration_flatten_output.py +++ /dev/null @@ -1,126 +0,0 @@ -""" -Test cases for the Iteration node's flatten_output functionality. - -This module tests the iteration node's ability to: -1. Flatten array outputs when flatten_output=True (default) -2. Preserve nested array structure when flatten_output=False -""" - -from .test_database_utils import skip_if_database_unavailable -from .test_mock_config import MockConfigBuilder, NodeMockConfig -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def _create_iteration_mock_config(): - """Helper to create a mock config for iteration tests.""" - - def code_inner_handler(node): - pool = node.graph_runtime_state.variable_pool - item_seg = pool.get(["iteration_node", "item"]) - if item_seg is not None: - item = item_seg.to_object() - return {"result": [item, item * 2]} - # This fallback is likely unreachable, but if it is, - # it doesn't simulate iteration with different values as the comment suggests. - return {"result": [1, 2]} - - return ( - MockConfigBuilder() - .with_node_output("code_node", {"result": [1, 2, 3]}) - .with_node_config(NodeMockConfig(node_id="code_inner_node", custom_handler=code_inner_handler)) - .build() - ) - - -@skip_if_database_unavailable() -def test_iteration_with_flatten_output_enabled(): - """ - Test iteration node with flatten_output=True (default behavior). - - The fixture implements an iteration that: - 1. Iterates over [1, 2, 3] - 2. For each item, outputs [item, item*2] - 3. With flatten_output=True, should output [1, 2, 2, 4, 3, 6] - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="iteration_flatten_output_enabled_workflow", - inputs={}, - expected_outputs={"output": [1, 2, 2, 4, 3, 6]}, - description="Iteration with flatten_output=True flattens nested arrays", - use_auto_mock=True, # Use auto-mock to avoid sandbox service - mock_config=_create_iteration_mock_config(), - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Test failed: {result.error}" - assert result.actual_outputs is not None, "Should have outputs" - assert result.actual_outputs == {"output": [1, 2, 2, 4, 3, 6]}, ( - f"Expected flattened output [1, 2, 2, 4, 3, 6], got {result.actual_outputs}" - ) - - -@skip_if_database_unavailable() -def test_iteration_with_flatten_output_disabled(): - """ - Test iteration node with flatten_output=False. - - The fixture implements an iteration that: - 1. Iterates over [1, 2, 3] - 2. For each item, outputs [item, item*2] - 3. With flatten_output=False, should output [[1, 2], [2, 4], [3, 6]] - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="iteration_flatten_output_disabled_workflow", - inputs={}, - expected_outputs={"output": [[1, 2], [2, 4], [3, 6]]}, - description="Iteration with flatten_output=False preserves nested structure", - use_auto_mock=True, # Use auto-mock to avoid sandbox service - mock_config=_create_iteration_mock_config(), - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Test failed: {result.error}" - assert result.actual_outputs is not None, "Should have outputs" - assert result.actual_outputs == {"output": [[1, 2], [2, 4], [3, 6]]}, ( - f"Expected nested output [[1, 2], [2, 4], [3, 6]], got {result.actual_outputs}" - ) - - -@skip_if_database_unavailable() -def test_iteration_flatten_output_comparison(): - """ - Run both flatten_output configurations in parallel to verify the difference. - """ - runner = TableTestRunner() - - test_cases = [ - WorkflowTestCase( - fixture_path="iteration_flatten_output_enabled_workflow", - inputs={}, - expected_outputs={"output": [1, 2, 2, 4, 3, 6]}, - description="flatten_output=True: Flattened output", - use_auto_mock=True, # Use auto-mock to avoid sandbox service - mock_config=_create_iteration_mock_config(), - ), - WorkflowTestCase( - fixture_path="iteration_flatten_output_disabled_workflow", - inputs={}, - expected_outputs={"output": [[1, 2], [2, 4], [3, 6]]}, - description="flatten_output=False: Nested output", - use_auto_mock=True, # Use auto-mock to avoid sandbox service - mock_config=_create_iteration_mock_config(), - ), - ] - - suite_result = runner.run_table_tests(test_cases, parallel=True) - - # Assert all tests passed - assert suite_result.passed_tests == 2, f"Expected 2 passed tests, got {suite_result.passed_tests}" - assert suite_result.failed_tests == 0, f"Expected 0 failed tests, got {suite_result.failed_tests}" - assert suite_result.success_rate == 100.0, f"Expected 100% success rate, got {suite_result.success_rate}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py deleted file mode 100644 index 733fd53bc8a..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py +++ /dev/null @@ -1,85 +0,0 @@ -""" -Test case for loop with inner answer output error scenario. - -This test validates the behavior of a loop containing an answer node -inside the loop that may produce output errors. -""" - -from dify_graph.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) - -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_loop_contains_answer(): - """ - Test loop with inner answer node that may have output errors. - - The fixture implements a loop that: - 1. Iterates 4 times (index 0-3) - 2. Contains an inner answer node that outputs index and item values - 3. Has a break condition when index equals 4 - 4. Tests error handling for answer nodes within loops - """ - fixture_name = "loop_contains_answer" - mock_config = MockConfigBuilder().build() - - case = WorkflowTestCase( - fixture_path=fixture_name, - use_auto_mock=True, - mock_config=mock_config, - query="1", - expected_outputs={"answer": "1\n2\n1 + 2"}, - expected_event_sequence=[ - # Graph start - GraphRunStartedEvent, - # Start - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Loop start - NodeRunStartedEvent, - NodeRunLoopStartedEvent, - # Variable assigner - NodeRunStartedEvent, - NodeRunStreamChunkEvent, # 1 - NodeRunStreamChunkEvent, # \n - NodeRunSucceededEvent, - # Answer - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Loop next - NodeRunLoopNextEvent, - # Variable assigner - NodeRunStartedEvent, - NodeRunStreamChunkEvent, # 2 - NodeRunStreamChunkEvent, # \n - NodeRunSucceededEvent, - # Answer - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Loop end - NodeRunLoopSucceededEvent, - NodeRunStreamChunkEvent, # 1 - NodeRunStreamChunkEvent, # + - NodeRunStreamChunkEvent, # 2 - NodeRunSucceededEvent, - # Answer - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Graph end - GraphRunSucceededEvent, - ], - ) - - runner = TableTestRunner() - result = runner.run_test_case(case) - assert result.success, f"Test failed: {result.error}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_node.py b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_node.py deleted file mode 100644 index ad8d777ea69..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_node.py +++ /dev/null @@ -1,41 +0,0 @@ -""" -Test cases for the Loop node functionality using TableTestRunner. - -This module tests the loop node's ability to: -1. Execute iterations with loop variables -2. Handle break conditions correctly -3. Update and propagate loop variables between iterations -4. Output the final loop variable value -""" - -from tests.unit_tests.core.workflow.graph_engine.test_table_runner import ( - TableTestRunner, - WorkflowTestCase, -) - - -def test_loop_with_break_condition(): - """ - Test loop node with break condition. - - The increment_loop_with_break_condition_workflow.yml fixture implements a loop that: - 1. Starts with num=1 - 2. Increments num by 1 each iteration - 3. Breaks when num >= 5 - 4. Should output {"num": 5} - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="increment_loop_with_break_condition_workflow", - inputs={}, # No inputs needed for this test - expected_outputs={"num": 5}, - description="Loop with break condition when num >= 5", - ) - - result = runner.run_test_case(test_case) - - # Assert the test passed - assert result.success, f"Test failed: {result.error}" - assert result.actual_outputs is not None, "Should have outputs" - assert result.actual_outputs == {"num": 5}, f"Expected {{'num': 5}}, got {result.actual_outputs}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py deleted file mode 100644 index 6ff2722f78a..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py +++ /dev/null @@ -1,67 +0,0 @@ -from dify_graph.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) - -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_loop_with_tool(): - fixture_name = "search_dify_from_2023_to_2025" - mock_config = ( - MockConfigBuilder() - .with_tool_response( - { - "text": "mocked search result", - } - ) - .build() - ) - case = WorkflowTestCase( - fixture_path=fixture_name, - use_auto_mock=True, - mock_config=mock_config, - expected_outputs={ - "answer": """- mocked search result -- mocked search result""" - }, - expected_event_sequence=[ - GraphRunStartedEvent, - # START - NodeRunStartedEvent, - NodeRunSucceededEvent, - # LOOP START - NodeRunStartedEvent, - NodeRunLoopStartedEvent, - # 2023 - NodeRunStartedEvent, - NodeRunSucceededEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - NodeRunLoopNextEvent, - # 2024 - NodeRunStartedEvent, - NodeRunSucceededEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - # LOOP END - NodeRunLoopSucceededEvent, - NodeRunStreamChunkEvent, # loop.res - NodeRunSucceededEvent, - # ANSWER - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ], - ) - - runner = TableTestRunner() - result = runner.run_test_case(case) - assert result.success, f"Test failed: {result.error}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_example.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_example.py deleted file mode 100644 index c511548749c..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_example.py +++ /dev/null @@ -1,281 +0,0 @@ -""" -Example demonstrating the auto-mock system for testing workflows. - -This example shows how to test workflows with third-party service nodes -without making actual API calls. -""" - -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def example_test_llm_workflow(): - """ - Example: Testing a workflow with an LLM node. - - This demonstrates how to test a workflow that uses an LLM service - without making actual API calls to OpenAI, Anthropic, etc. - """ - print("\n=== Example: Testing LLM Workflow ===\n") - - # Initialize the test runner - runner = TableTestRunner() - - # Configure mock responses - mock_config = MockConfigBuilder().with_llm_response("I'm a helpful AI assistant. How can I help you today?").build() - - # Define the test case - test_case = WorkflowTestCase( - fixture_path="llm-simple", - inputs={"query": "Hello, AI!"}, - expected_outputs={"answer": "I'm a helpful AI assistant. How can I help you today?"}, - description="Testing LLM workflow with mocked response", - use_auto_mock=True, # Enable auto-mocking - mock_config=mock_config, - ) - - # Run the test - result = runner.run_test_case(test_case) - - if result.success: - print("✅ Test passed!") - print(f" Input: {test_case.inputs['query']}") - print(f" Output: {result.actual_outputs['answer']}") - print(f" Execution time: {result.execution_time:.2f}s") - else: - print(f"❌ Test failed: {result.error}") - - return result.success - - -def example_test_with_custom_outputs(): - """ - Example: Testing with custom outputs for specific nodes. - - This shows how to provide different mock outputs for specific node IDs, - useful when testing complex workflows with multiple LLM/tool nodes. - """ - print("\n=== Example: Custom Node Outputs ===\n") - - runner = TableTestRunner() - - # Configure mock with specific outputs for different nodes - mock_config = MockConfigBuilder().build() - - # Set custom output for a specific LLM node - mock_config.set_node_outputs( - "llm_node", - { - "text": "This is a custom response for the specific LLM node", - "usage": { - "prompt_tokens": 50, - "completion_tokens": 20, - "total_tokens": 70, - }, - "finish_reason": "stop", - }, - ) - - test_case = WorkflowTestCase( - fixture_path="llm-simple", - inputs={"query": "Tell me about custom outputs"}, - expected_outputs={"answer": "This is a custom response for the specific LLM node"}, - description="Testing with custom node outputs", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - if result.success: - print("✅ Test with custom outputs passed!") - print(f" Custom output: {result.actual_outputs['answer']}") - else: - print(f"❌ Test failed: {result.error}") - - return result.success - - -def example_test_http_and_tool_workflow(): - """ - Example: Testing a workflow with HTTP request and tool nodes. - - This demonstrates mocking external HTTP calls and tool executions. - """ - print("\n=== Example: HTTP and Tool Workflow ===\n") - - runner = TableTestRunner() - - # Configure mocks for HTTP and Tool nodes - mock_config = MockConfigBuilder().build() - - # Mock HTTP response - mock_config.set_node_outputs( - "http_node", - { - "status_code": 200, - "body": '{"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]}', - "headers": {"content-type": "application/json"}, - }, - ) - - # Mock tool response (e.g., JSON parser) - mock_config.set_node_outputs( - "tool_node", - { - "result": {"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]}, - }, - ) - - test_case = WorkflowTestCase( - fixture_path="http-tool-workflow", - inputs={"url": "https://api.example.com/users"}, - expected_outputs={ - "status_code": 200, - "parsed_data": {"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]}, - }, - description="Testing HTTP and Tool workflow", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - if result.success: - print("✅ HTTP and Tool workflow test passed!") - print(f" HTTP Status: {result.actual_outputs['status_code']}") - print(f" Parsed Data: {result.actual_outputs['parsed_data']}") - else: - print(f"❌ Test failed: {result.error}") - - return result.success - - -def example_test_error_simulation(): - """ - Example: Simulating errors in specific nodes. - - This shows how to test error handling in workflows by simulating - failures in specific nodes. - """ - print("\n=== Example: Error Simulation ===\n") - - runner = TableTestRunner() - - # Configure mock to simulate an error - mock_config = MockConfigBuilder().build() - mock_config.set_node_error("llm_node", "API rate limit exceeded") - - test_case = WorkflowTestCase( - fixture_path="llm-simple", - inputs={"query": "This will fail"}, - expected_outputs={}, # We expect failure - description="Testing error handling", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - if not result.success: - print("✅ Error simulation worked as expected!") - print(f" Simulated error: {result.error}") - else: - print("❌ Expected failure but test succeeded") - - return not result.success # Success means we got the expected error - - -def example_test_with_delays(): - """ - Example: Testing with simulated execution delays. - - This demonstrates how to simulate realistic execution times - for performance testing. - """ - print("\n=== Example: Simulated Delays ===\n") - - runner = TableTestRunner() - - # Configure mock with delays - mock_config = ( - MockConfigBuilder() - .with_delays(True) # Enable delay simulation - .with_llm_response("Response after delay") - .build() - ) - - # Add specific delay for the LLM node - from .test_mock_config import NodeMockConfig - - node_config = NodeMockConfig( - node_id="llm_node", - outputs={"text": "Response after delay"}, - delay=0.5, # 500ms delay - ) - mock_config.set_node_config("llm_node", node_config) - - test_case = WorkflowTestCase( - fixture_path="llm-simple", - inputs={"query": "Test with delay"}, - expected_outputs={"answer": "Response after delay"}, - description="Testing with simulated delays", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - if result.success: - print("✅ Delay simulation test passed!") - print(f" Execution time: {result.execution_time:.2f}s") - print(" (Should be >= 0.5s due to simulated delay)") - else: - print(f"❌ Test failed: {result.error}") - - return result.success and result.execution_time >= 0.5 - - -def run_all_examples(): - """Run all example tests.""" - print("\n" + "=" * 50) - print("AUTO-MOCK SYSTEM EXAMPLES") - print("=" * 50) - - examples = [ - example_test_llm_workflow, - example_test_with_custom_outputs, - example_test_http_and_tool_workflow, - example_test_error_simulation, - example_test_with_delays, - ] - - results = [] - for example in examples: - try: - results.append(example()) - except Exception as e: - print(f"\n❌ Example failed with exception: {e}") - results.append(False) - - print("\n" + "=" * 50) - print("SUMMARY") - print("=" * 50) - - passed = sum(results) - total = len(results) - print(f"\n✅ Passed: {passed}/{total}") - - if passed == total: - print("\n🎉 All examples passed successfully!") - else: - print(f"\n⚠️ {total - passed} example(s) failed") - - return passed == total - - -if __name__ == "__main__": - import sys - - success = run_all_examples() - sys.exit(0 if success else 1) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py index 93010eea542..88989db8565 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py @@ -7,10 +7,11 @@ requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request from typing import TYPE_CHECKING, Any +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.base.node import Node + from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base.node import Node from .test_mock_nodes import ( MockAgentNode, @@ -28,8 +29,8 @@ from .test_mock_nodes import ( ) if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState from .test_mock_config import MockConfig @@ -111,7 +112,7 @@ class MockNodeFactory(DifyNodeFactory): mock_config=self.mock_config, http_request_config=self._http_request_config, http_client=self._http_request_http_client, - tool_file_manager_factory=self._http_request_tool_file_manager_factory, + tool_file_manager_factory=self._bound_tool_file_manager_factory, file_manager=self._http_request_file_manager, ) elif node_type in { diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py deleted file mode 100644 index 3e4247f33f5..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py +++ /dev/null @@ -1,199 +0,0 @@ -""" -Simple test to verify MockNodeFactory works with iteration nodes. -""" - -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import BuiltinNodeTypes -from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfigBuilder -from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory - - -def test_mock_factory_registers_iteration_node(): - """Test that MockNodeFactory has iteration node registered.""" - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool - - # Create a MockNodeFactory instance - graph_init_params = GraphInitParams( - workflow_id="test", - graph_config={"nodes": [], "edges": []}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test", - "app_id": "test", - "user_id": "test", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.SERVICE_API, - } - }, - call_depth=0, - ) - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), - start_at=0, - total_tokens=0, - node_run_steps=0, - ) - factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=None, - ) - - # Check that iteration node is registered - assert BuiltinNodeTypes.ITERATION in factory._mock_node_types - print("✓ Iteration node is registered in MockNodeFactory") - - # Check that loop node is registered - assert BuiltinNodeTypes.LOOP in factory._mock_node_types - print("✓ Loop node is registered in MockNodeFactory") - - # Check the class types - from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockIterationNode, MockLoopNode - - assert factory._mock_node_types[BuiltinNodeTypes.ITERATION] == MockIterationNode - print("✓ Iteration node maps to MockIterationNode class") - - assert factory._mock_node_types[BuiltinNodeTypes.LOOP] == MockLoopNode - print("✓ Loop node maps to MockLoopNode class") - - -def test_mock_iteration_node_preserves_config(): - """Test that MockIterationNode preserves mock configuration.""" - - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool - from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockIterationNode - - # Create mock config - mock_config = MockConfigBuilder().with_llm_response("Test response").build() - - # Create minimal graph init params - graph_init_params = GraphInitParams( - workflow_id="test", - graph_config={"nodes": [], "edges": []}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test", - "app_id": "test", - "user_id": "test", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.SERVICE_API, - } - }, - call_depth=0, - ) - - # Create minimal runtime state - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), - start_at=0, - total_tokens=0, - node_run_steps=0, - ) - - # Create mock iteration node - node_config = { - "id": "iter1", - "data": { - "type": "iteration", - "title": "Test", - "iterator_selector": ["start", "items"], - "output_selector": ["node", "text"], - "start_node_id": "node1", - }, - } - - mock_node = MockIterationNode( - id="iter1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - ) - - # Verify the mock config is preserved - assert mock_node.mock_config == mock_config - print("✓ MockIterationNode preserves mock configuration") - - # Check that _create_graph_engine method exists and is overridden - assert hasattr(mock_node, "_create_graph_engine") - assert MockIterationNode._create_graph_engine != MockIterationNode.__bases__[1]._create_graph_engine - print("✓ MockIterationNode overrides _create_graph_engine method") - - -def test_mock_loop_node_preserves_config(): - """Test that MockLoopNode preserves mock configuration.""" - - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool - from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockLoopNode - - # Create mock config - mock_config = MockConfigBuilder().with_http_response({"status": 200}).build() - - # Create minimal graph init params - graph_init_params = GraphInitParams( - workflow_id="test", - graph_config={"nodes": [], "edges": []}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test", - "app_id": "test", - "user_id": "test", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.SERVICE_API, - } - }, - call_depth=0, - ) - - # Create minimal runtime state - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), - start_at=0, - total_tokens=0, - node_run_steps=0, - ) - - # Create mock loop node - node_config = { - "id": "loop1", - "data": { - "type": "loop", - "title": "Test", - "loop_count": 3, - "start_node_id": "node1", - "loop_variables": [], - "outputs": {}, - "break_conditions": [], - "logical_operator": "and", - }, - } - - mock_node = MockLoopNode( - id="loop1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - ) - - # Verify the mock config is preserved - assert mock_node.mock_config == mock_config - print("✓ MockLoopNode preserves mock configuration") - - # Check that _create_graph_engine method exists and is overridden - assert hasattr(mock_node, "_create_graph_engine") - assert MockLoopNode._create_graph_engine != MockLoopNode.__bases__[1]._create_graph_engine - print("✓ MockLoopNode overrides _create_graph_engine method") - - -if __name__ == "__main__": - test_mock_factory_registers_iteration_node() - test_mock_iteration_node_preserves_config() - test_mock_loop_node_preserves_config() - print("\n✅ All tests passed! MockNodeFactory now supports iteration and loop nodes.") diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index 454263bef9a..8b7fbd1b303 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -10,30 +10,31 @@ from collections.abc import Generator, Mapping from typing import TYPE_CHECKING, Any, Optional from unittest.mock import MagicMock +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent +from graphon.nodes.code import CodeNode +from graphon.nodes.document_extractor import DocumentExtractorNode +from graphon.nodes.http_request import HttpRequestNode +from graphon.nodes.llm import LLMNode +from graphon.nodes.llm.file_saver import LLMFileSaver +from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory +from graphon.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol +from graphon.nodes.parameter_extractor import ParameterExtractorNode +from graphon.nodes.protocols import FileReferenceFactoryProtocol, HttpClientProtocol, ToolFileManagerProtocol +from graphon.nodes.question_classifier import QuestionClassifierNode +from graphon.nodes.template_transform import TemplateTransformNode +from graphon.nodes.tool import ToolNode +from graphon.template_rendering import Jinja2TemplateRenderer, TemplateRenderError + from core.model_manager import ModelInstance +from core.workflow.node_runtime import DifyToolNodeRuntime from core.workflow.nodes.agent import AgentNode from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode -from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent -from dify_graph.nodes.code import CodeNode -from dify_graph.nodes.document_extractor import DocumentExtractorNode -from dify_graph.nodes.http_request import HttpRequestNode -from dify_graph.nodes.llm import LLMNode -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer -from dify_graph.nodes.parameter_extractor import ParameterExtractorNode -from dify_graph.nodes.protocols import HttpClientProtocol, ToolFileManagerProtocol -from dify_graph.nodes.question_classifier import QuestionClassifierNode -from dify_graph.nodes.template_transform import TemplateTransformNode -from dify_graph.nodes.template_transform.template_renderer import ( - Jinja2TemplateRenderer, - TemplateRenderError, -) -from dify_graph.nodes.tool import ToolNode if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState from .test_mock_config import MockConfig @@ -66,20 +67,26 @@ class MockNodeMixin: kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider)) kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory)) kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance)) + kwargs.setdefault("prompt_message_serializer", MagicMock(spec=PromptMessageSerializerProtocol)) # LLM-like nodes now require an http_client; provide a mock by default for tests. kwargs.setdefault("http_client", MagicMock(spec=HttpClientProtocol)) - if isinstance(self, (LLMNode, QuestionClassifierNode)): - kwargs.setdefault("template_renderer", MagicMock(spec=TemplateRenderer)) + + if isinstance(self, (LLMNode, QuestionClassifierNode)): + kwargs.setdefault("llm_file_saver", MagicMock(spec=LLMFileSaver)) + + if isinstance(self, HttpRequestNode): + kwargs.setdefault("file_reference_factory", MagicMock(spec=FileReferenceFactoryProtocol)) # Ensure TemplateTransformNode receives a renderer now required by constructor if isinstance(self, TemplateTransformNode): - kwargs.setdefault("template_renderer", _TestJinja2Renderer()) + kwargs.setdefault("jinja2_template_renderer", _TestJinja2Renderer()) # Provide default tool_file_manager_factory for ToolNode subclasses - from dify_graph.nodes.tool import ToolNode as _ToolNode # local import to avoid cycles + from graphon.nodes.tool import ToolNode as _ToolNode # local import to avoid cycles if isinstance(self, _ToolNode): kwargs.setdefault("tool_file_manager_factory", MagicMock(spec=ToolFileManagerProtocol)) + kwargs.setdefault("runtime", DifyToolNodeRuntime(graph_init_params.run_context)) if isinstance(self, AgentNode): presentation_provider = MagicMock() @@ -596,8 +603,8 @@ class MockDocumentExtractorNode(MockNodeMixin, DocumentExtractorNode): ) -from dify_graph.nodes.iteration import IterationNode -from dify_graph.nodes.loop import LoopNode +from graphon.nodes.iteration import IterationNode +from graphon.nodes.loop import LoopNode class MockIterationNode(MockNodeMixin, IterationNode): @@ -611,11 +618,11 @@ class MockIterationNode(MockNodeMixin, IterationNode): def _create_graph_engine(self, index: int, item: Any): """Create a graph engine with MockNodeFactory instead of DifyNodeFactory.""" # Import dependencies - from dify_graph.entities import GraphInitParams - from dify_graph.graph import Graph - from dify_graph.graph_engine import GraphEngine, GraphEngineConfig - from dify_graph.graph_engine.command_channels import InMemoryChannel - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.graph import Graph + from graphon.graph_engine import GraphEngine, GraphEngineConfig + from graphon.graph_engine.command_channels import InMemoryChannel + from graphon.runtime import GraphRuntimeState # Import our MockNodeFactory instead of DifyNodeFactory from .test_mock_factory import MockNodeFactory @@ -656,7 +663,7 @@ class MockIterationNode(MockNodeMixin, IterationNode): ) if not iteration_graph: - from dify_graph.nodes.iteration.exc import IterationGraphNotFoundError + from graphon.nodes.iteration.exc import IterationGraphNotFoundError raise IterationGraphNotFoundError("iteration graph not found") @@ -683,11 +690,11 @@ class MockLoopNode(MockNodeMixin, LoopNode): def _create_graph_engine(self, start_at, root_node_id: str): """Create a graph engine with MockNodeFactory instead of DifyNodeFactory.""" # Import dependencies - from dify_graph.entities import GraphInitParams - from dify_graph.graph import Graph - from dify_graph.graph_engine import GraphEngine, GraphEngineConfig - from dify_graph.graph_engine.command_channels import InMemoryChannel - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.graph import Graph + from graphon.graph_engine import GraphEngine, GraphEngineConfig + from graphon.graph_engine.command_channels import InMemoryChannel + from graphon.runtime import GraphRuntimeState # Import our MockNodeFactory instead of DifyNodeFactory from .test_mock_factory import MockNodeFactory diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py deleted file mode 100644 index a8398e8f79f..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py +++ /dev/null @@ -1,670 +0,0 @@ -""" -Test cases for Mock Template Transform and Code nodes. - -This module tests the functionality of MockTemplateTransformNode and MockCodeNode -to ensure they work correctly with the TableTestRunner. -""" - -from configs import dify_config -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.nodes.code.limits import CodeNodeLimits -from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig -from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory -from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockCodeNode, MockTemplateTransformNode - -DEFAULT_CODE_LIMITS = CodeNodeLimits( - max_string_length=dify_config.CODE_MAX_STRING_LENGTH, - max_number=dify_config.CODE_MAX_NUMBER, - min_number=dify_config.CODE_MIN_NUMBER, - max_precision=dify_config.CODE_MAX_PRECISION, - max_depth=dify_config.CODE_MAX_DEPTH, - max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH, - max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH, - max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH, -) - - -class _NoopCodeExecutor: - def execute(self, *, language: object, code: str, inputs: dict[str, object]) -> dict[str, object]: - _ = (language, code, inputs) - return {} - - def is_execution_error(self, error: Exception) -> bool: - _ = error - return False - - -class TestMockTemplateTransformNode: - """Test cases for MockTemplateTransformNode.""" - - def test_mock_template_transform_node_default_output(self): - """Test that MockTemplateTransformNode processes templates with Jinja2.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config - mock_config = MockConfig() - - # Create node config - node_config = { - "id": "template_node_1", - "data": { - "type": "template-transform", - "title": "Test Template Transform", - "variables": [], - "template": "Hello {{ name }}", - }, - } - - # Create mock node - mock_node = MockTemplateTransformNode( - id="template_node_1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert "output" in result.outputs - # The template "Hello {{ name }}" with no name variable renders as "Hello " - assert result.outputs["output"] == "Hello " - - def test_mock_template_transform_node_custom_output(self): - """Test that MockTemplateTransformNode returns custom configured output.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config with custom output - mock_config = ( - MockConfigBuilder().with_node_output("template_node_1", {"output": "Custom template output"}).build() - ) - - # Create node config - node_config = { - "id": "template_node_1", - "data": { - "type": "template-transform", - "title": "Test Template Transform", - "variables": [], - "template": "Hello {{ name }}", - }, - } - - # Create mock node - mock_node = MockTemplateTransformNode( - id="template_node_1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert "output" in result.outputs - assert result.outputs["output"] == "Custom template output" - - def test_mock_template_transform_node_error_simulation(self): - """Test that MockTemplateTransformNode can simulate errors.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config with error - mock_config = MockConfigBuilder().with_node_error("template_node_1", "Simulated template error").build() - - # Create node config - node_config = { - "id": "template_node_1", - "data": { - "type": "template-transform", - "title": "Test Template Transform", - "variables": [], - "template": "Hello {{ name }}", - }, - } - - # Create mock node - mock_node = MockTemplateTransformNode( - id="template_node_1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.FAILED - assert result.error == "Simulated template error" - - def test_mock_template_transform_node_with_variables(self): - """Test that MockTemplateTransformNode processes templates with variables.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool - from dify_graph.variables import StringVariable - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - ) - - # Add a variable to the pool - variable_pool.add(["test", "name"], StringVariable(name="name", value="World", selector=["test", "name"])) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config - mock_config = MockConfig() - - # Create node config with a variable - node_config = { - "id": "template_node_1", - "data": { - "type": "template-transform", - "title": "Test Template Transform", - "variables": [{"variable": "name", "value_selector": ["test", "name"]}], - "template": "Hello {{ name }}!", - }, - } - - # Create mock node - mock_node = MockTemplateTransformNode( - id="template_node_1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert "output" in result.outputs - assert result.outputs["output"] == "Hello World!" - - -class TestMockCodeNode: - """Test cases for MockCodeNode.""" - - def test_mock_code_node_default_output(self): - """Test that MockCodeNode returns default output.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config - mock_config = MockConfig() - - # Create node config - node_config = { - "id": "code_node_1", - "data": { - "type": "code", - "title": "Test Code", - "variables": [], - "code_language": "python3", - "code": "result = 'test'", - "outputs": {}, # Empty outputs for default case - }, - } - - # Create mock node - mock_node = MockCodeNode( - id="code_node_1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - code_executor=_NoopCodeExecutor(), - code_limits=DEFAULT_CODE_LIMITS, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert "result" in result.outputs - assert result.outputs["result"] == "mocked code execution result" - - def test_mock_code_node_with_output_schema(self): - """Test that MockCodeNode generates outputs based on schema.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config - mock_config = MockConfig() - - # Create node config with output schema - node_config = { - "id": "code_node_1", - "data": { - "type": "code", - "title": "Test Code", - "variables": [], - "code_language": "python3", - "code": "name = 'test'\ncount = 42\nitems = ['a', 'b']", - "outputs": { - "name": {"type": "string"}, - "count": {"type": "number"}, - "items": {"type": "array[string]"}, - }, - }, - } - - # Create mock node - mock_node = MockCodeNode( - id="code_node_1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - code_executor=_NoopCodeExecutor(), - code_limits=DEFAULT_CODE_LIMITS, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert "name" in result.outputs - assert result.outputs["name"] == "mocked_name" - assert "count" in result.outputs - assert result.outputs["count"] == 42 - assert "items" in result.outputs - assert result.outputs["items"] == ["item1", "item2"] - - def test_mock_code_node_custom_output(self): - """Test that MockCodeNode returns custom configured output.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config with custom output - mock_config = ( - MockConfigBuilder() - .with_node_output("code_node_1", {"result": "Custom code result", "status": "success"}) - .build() - ) - - # Create node config - node_config = { - "id": "code_node_1", - "data": { - "type": "code", - "title": "Test Code", - "variables": [], - "code_language": "python3", - "code": "result = 'test'", - "outputs": {}, # Empty outputs for default case - }, - } - - # Create mock node - mock_node = MockCodeNode( - id="code_node_1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - code_executor=_NoopCodeExecutor(), - code_limits=DEFAULT_CODE_LIMITS, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert "result" in result.outputs - assert result.outputs["result"] == "Custom code result" - assert "status" in result.outputs - assert result.outputs["status"] == "success" - - -class TestMockNodeFactory: - """Test cases for MockNodeFactory with new node types.""" - - def test_code_and_template_nodes_mocked_by_default(self): - """Test that CODE and TEMPLATE_TRANSFORM nodes are mocked by default (they require SSRF proxy).""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create factory - factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - # Verify that CODE and TEMPLATE_TRANSFORM ARE mocked by default (they require SSRF proxy) - assert factory.should_mock_node(BuiltinNodeTypes.CODE) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - # Verify that other third-party service nodes ARE also mocked by default - assert factory.should_mock_node(BuiltinNodeTypes.LLM) - assert factory.should_mock_node(BuiltinNodeTypes.AGENT) - - def test_factory_creates_mock_template_transform_node(self): - """Test that MockNodeFactory creates MockTemplateTransformNode for template-transform type.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create factory - factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - # Create node config - node_config = { - "id": "template_node_1", - "data": { - "type": "template-transform", - "title": "Test Template", - "variables": [], - "template": "Hello {{ name }}", - }, - } - - # Create node through factory - node = factory.create_node(node_config) - - # Verify the correct mock type was created - assert isinstance(node, MockTemplateTransformNode) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - def test_factory_creates_mock_code_node(self): - """Test that MockNodeFactory creates MockCodeNode for code type.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create factory - factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - # Create node config - node_config = { - "id": "code_node_1", - "data": { - "type": "code", - "title": "Test Code", - "variables": [], - "code_language": "python3", - "code": "result = 42", - "outputs": {}, # Required field for CodeNodeData - }, - } - - # Create node through factory - node = factory.create_node(node_config) - - # Verify the correct mock type was created - assert isinstance(node, MockCodeNode) - assert factory.should_mock_node(BuiltinNodeTypes.CODE) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py deleted file mode 100644 index 5b35b3310a0..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py +++ /dev/null @@ -1,231 +0,0 @@ -""" -Simple test to validate the auto-mock system without external dependencies. -""" - -import sys - -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import BuiltinNodeTypes -from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig -from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory - - -def test_mock_config_builder(): - """Test the MockConfigBuilder fluent interface.""" - print("Testing MockConfigBuilder...") - - config = ( - MockConfigBuilder() - .with_llm_response("LLM response") - .with_agent_response("Agent response") - .with_tool_response({"tool": "output"}) - .with_retrieval_response("Retrieval content") - .with_http_response({"status_code": 201, "body": "created"}) - .with_node_output("node1", {"output": "value"}) - .with_node_error("node2", "error message") - .with_delays(True) - .build() - ) - - assert config.default_llm_response == "LLM response" - assert config.default_agent_response == "Agent response" - assert config.default_tool_response == {"tool": "output"} - assert config.default_retrieval_response == "Retrieval content" - assert config.default_http_response == {"status_code": 201, "body": "created"} - assert config.simulate_delays is True - - node1_config = config.get_node_config("node1") - assert node1_config is not None - assert node1_config.outputs == {"output": "value"} - - node2_config = config.get_node_config("node2") - assert node2_config is not None - assert node2_config.error == "error message" - - print("✓ MockConfigBuilder test passed") - - -def test_mock_config_operations(): - """Test MockConfig operations.""" - print("Testing MockConfig operations...") - - config = MockConfig() - - # Test setting node outputs - config.set_node_outputs("test_node", {"result": "test_value"}) - node_config = config.get_node_config("test_node") - assert node_config is not None - assert node_config.outputs == {"result": "test_value"} - - # Test setting node error - config.set_node_error("error_node", "Test error") - error_config = config.get_node_config("error_node") - assert error_config is not None - assert error_config.error == "Test error" - - # Test default configs by node type - config.set_default_config(BuiltinNodeTypes.LLM, {"temperature": 0.7}) - llm_config = config.get_default_config(BuiltinNodeTypes.LLM) - assert llm_config == {"temperature": 0.7} - - print("✓ MockConfig operations test passed") - - -def test_node_mock_config(): - """Test NodeMockConfig.""" - print("Testing NodeMockConfig...") - - # Test with custom handler - def custom_handler(node): - return {"custom": "output"} - - node_config = NodeMockConfig( - node_id="test_node", outputs={"text": "test"}, error=None, delay=0.5, custom_handler=custom_handler - ) - - assert node_config.node_id == "test_node" - assert node_config.outputs == {"text": "test"} - assert node_config.delay == 0.5 - assert node_config.custom_handler is not None - - # Test custom handler - result = node_config.custom_handler(None) - assert result == {"custom": "output"} - - print("✓ NodeMockConfig test passed") - - -def test_mock_factory_detection(): - """Test MockNodeFactory node type detection.""" - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool - - print("Testing MockNodeFactory detection...") - - graph_init_params = GraphInitParams( - workflow_id="test", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test", - "app_id": "test", - "user_id": "test", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.SERVICE_API, - } - }, - call_depth=0, - ) - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), - start_at=0, - total_tokens=0, - node_run_steps=0, - ) - factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=None, - ) - - # Test that third-party service nodes are identified for mocking - assert factory.should_mock_node(BuiltinNodeTypes.LLM) - assert factory.should_mock_node(BuiltinNodeTypes.AGENT) - assert factory.should_mock_node(BuiltinNodeTypes.TOOL) - assert factory.should_mock_node(BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL) - assert factory.should_mock_node(BuiltinNodeTypes.HTTP_REQUEST) - assert factory.should_mock_node(BuiltinNodeTypes.PARAMETER_EXTRACTOR) - assert factory.should_mock_node(BuiltinNodeTypes.DOCUMENT_EXTRACTOR) - - # Test that CODE and TEMPLATE_TRANSFORM are mocked (they require SSRF proxy) - assert factory.should_mock_node(BuiltinNodeTypes.CODE) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - # Test that non-service nodes are not mocked - assert not factory.should_mock_node(BuiltinNodeTypes.START) - assert not factory.should_mock_node(BuiltinNodeTypes.END) - assert not factory.should_mock_node(BuiltinNodeTypes.IF_ELSE) - assert not factory.should_mock_node(BuiltinNodeTypes.VARIABLE_AGGREGATOR) - - print("✓ MockNodeFactory detection test passed") - - -def test_mock_factory_registration(): - """Test registering and unregistering mock node types.""" - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool - - print("Testing MockNodeFactory registration...") - - graph_init_params = GraphInitParams( - workflow_id="test", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test", - "app_id": "test", - "user_id": "test", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.SERVICE_API, - } - }, - call_depth=0, - ) - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), - start_at=0, - total_tokens=0, - node_run_steps=0, - ) - factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=None, - ) - - # TEMPLATE_TRANSFORM is mocked by default (requires SSRF proxy) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - # Unregister mock - factory.unregister_mock_node_type(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - assert not factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - # Register custom mock (using a dummy class for testing) - class DummyMockNode: - pass - - factory.register_mock_node_type(BuiltinNodeTypes.TEMPLATE_TRANSFORM, DummyMockNode) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - print("✓ MockNodeFactory registration test passed") - - -def run_all_tests(): - """Run all tests.""" - print("\n=== Running Auto-Mock System Tests ===\n") - - try: - test_mock_config_builder() - test_mock_config_operations() - test_node_mock_config() - test_mock_factory_detection() - test_mock_factory_registration() - - print("\n=== All tests passed! ✅ ===\n") - return True - except AssertionError as e: - print(f"\n❌ Test failed: {e}") - return False - except Exception as e: - print(f"\n❌ Unexpected error: {e}") - import traceback - - traceback.print_exc() - return False - - -if __name__ == "__main__": - success = run_all_tests() - sys.exit(0 if success else 1) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py index e681b39cc75..8311a1e847a 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py @@ -4,32 +4,33 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any, Protocol -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.graph import Graph -from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from dify_graph.graph_engine.config import GraphEngineConfig -from dify_graph.graph_engine.graph_engine import GraphEngine -from dify_graph.graph_events import ( +from graphon.entities import WorkflowStartReason +from graphon.graph import Graph +from graphon.graph_engine import GraphEngine, GraphEngineConfig +from graphon.graph_engine.command_channels import InMemoryChannel +from graphon.graph_events import ( GraphRunPausedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunSucceededEvent, ) -from dify_graph.nodes.base.entities import OutputVariableEntity -from dify_graph.nodes.end.end_node import EndNode -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction -from dify_graph.nodes.human_input.enums import HumanInputFormStatus -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import ( +from graphon.nodes.base.entities import OutputVariableEntity +from graphon.nodes.end.end_node import EndNode +from graphon.nodes.end.entities import EndNodeData +from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction +from graphon.nodes.human_input.enums import HumanInputFormStatus +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.nodes.start.entities import StartNodeData +from graphon.nodes.start.start_node import StartNode +from graphon.runtime import GraphRuntimeState, VariablePool + +from core.repositories.human_input_repository import ( FormCreateParams, HumanInputFormEntity, HumanInputFormRepository, ) -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables from libs.datetime_utils import naive_utc_now from tests.workflow_test_utils import build_test_graph_init_params @@ -67,7 +68,7 @@ class StaticForm(HumanInputFormEntity): return self.form_id @property - def web_app_token(self) -> str | None: + def submission_token(self) -> str | None: return "token" @property @@ -103,7 +104,7 @@ class StaticRepo(HumanInputFormRepository): def __init__(self, forms_by_node_id: Mapping[str, HumanInputFormEntity]) -> None: self._forms_by_node_id = dict(forms_by_node_id) - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: + def get_form(self, node_id: str) -> HumanInputFormEntity | None: return self._forms_by_node_id.get(node_id) def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: @@ -112,7 +113,7 @@ class StaticRepo(HumanInputFormRepository): def _build_runtime_state() -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -159,6 +160,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) human_b_config = {"id": "human_b", "data": human_data.model_dump()} @@ -168,6 +170,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) end_data = EndNodeData( diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py deleted file mode 100644 index 60167c0441a..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py +++ /dev/null @@ -1,333 +0,0 @@ -import time -from collections.abc import Mapping -from dataclasses import dataclass -from datetime import datetime, timedelta -from typing import Any - -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.graph import Graph -from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from dify_graph.graph_engine.config import GraphEngineConfig -from dify_graph.graph_engine.graph_engine import GraphEngine -from dify_graph.graph_events import ( - GraphRunPausedEvent, - GraphRunStartedEvent, - NodeRunPauseRequestedEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, -) -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole -from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction -from dify_graph.nodes.human_input.enums import HumanInputFormStatus -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.nodes.llm.entities import ( - ContextConfig, - LLMNodeChatModelMessage, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from libs.datetime_utils import naive_utc_now -from tests.workflow_test_utils import build_test_graph_init_params - -from .test_mock_config import MockConfig, NodeMockConfig -from .test_mock_nodes import MockLLMNode - - -@dataclass -class StaticForm(HumanInputFormEntity): - form_id: str - rendered: str - is_submitted: bool - action_id: str | None = None - data: Mapping[str, Any] | None = None - status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING - expiration: datetime = naive_utc_now() + timedelta(days=1) - - @property - def id(self) -> str: - return self.form_id - - @property - def web_app_token(self) -> str | None: - return "token" - - @property - def recipients(self) -> list: - return [] - - @property - def rendered_content(self) -> str: - return self.rendered - - @property - def selected_action_id(self) -> str | None: - return self.action_id - - @property - def submitted_data(self) -> Mapping[str, Any] | None: - return self.data - - @property - def submitted(self) -> bool: - return self.is_submitted - - @property - def status(self) -> HumanInputFormStatus: - return self.status_value - - @property - def expiration_time(self) -> datetime: - return self.expiration - - -class StaticRepo(HumanInputFormRepository): - def __init__(self, forms_by_node_id: Mapping[str, HumanInputFormEntity]) -> None: - self._forms_by_node_id = dict(forms_by_node_id) - - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: - return self._forms_by_node_id.get(node_id) - - def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: - raise AssertionError("create_form should not be called in resume scenario") - - -class DelayedHumanInputNode(HumanInputNode): - def __init__(self, delay_seconds: float, **kwargs: Any) -> None: - super().__init__(**kwargs) - self._delay_seconds = delay_seconds - - def _run(self): - if self._delay_seconds > 0: - time.sleep(self._delay_seconds) - yield from super()._run() - - -def _build_runtime_state() -> GraphRuntimeState: - variable_pool = VariablePool( - system_variables=SystemVariable( - user_id="user", - app_id="app", - workflow_id="workflow", - workflow_execution_id="exec-1", - ), - user_inputs={}, - conversation_variables=[], - ) - return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - -def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = build_test_graph_init_params( - workflow_id="workflow", - graph_config=graph_config, - tenant_id="tenant", - app_id="app", - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - - start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} - start_node = StartNode( - id=start_config["id"], - config=start_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - ) - - human_data = HumanInputNodeData( - title="Human Input", - form_content="Human input required", - inputs=[], - user_actions=[UserAction(id="approve", title="Approve")], - ) - - human_a_config = {"id": "human_a", "data": human_data.model_dump()} - human_a = HumanInputNode( - id=human_a_config["id"], - config=human_a_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - form_repository=repo, - ) - - human_b_config = {"id": "human_b", "data": human_data.model_dump()} - human_b = DelayedHumanInputNode( - id=human_b_config["id"], - config=human_b_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - form_repository=repo, - delay_seconds=0.2, - ) - - llm_data = LLMNodeData( - title="LLM A", - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text="Prompt A", - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - structured_output_enabled=False, - ) - llm_config = {"id": "llm_a", "data": llm_data.model_dump()} - llm_a = MockLLMNode( - id=llm_config["id"], - config=llm_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - mock_config=mock_config, - ) - - return ( - Graph.new() - .add_root(start_node) - .add_node(human_a, from_node_id="start") - .add_node(human_b, from_node_id="start") - .add_node(llm_a, from_node_id="human_a", source_handle="approve") - .build() - ) - - -def test_parallel_human_input_pause_preserves_node_finished() -> None: - runtime_state = _build_runtime_state() - - runtime_state.graph_execution.start() - runtime_state.register_paused_node("human_a") - runtime_state.register_paused_node("human_b") - - submitted = StaticForm( - form_id="form-a", - rendered="rendered", - is_submitted=True, - action_id="approve", - data={}, - status_value=HumanInputFormStatus.SUBMITTED, - ) - pending = StaticForm( - form_id="form-b", - rendered="rendered", - is_submitted=False, - action_id=None, - data=None, - status_value=HumanInputFormStatus.WAITING, - ) - repo = StaticRepo({"human_a": submitted, "human_b": pending}) - - mock_config = MockConfig() - mock_config.simulate_delays = True - mock_config.set_node_config( - "llm_a", - NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5), - ) - - graph = _build_graph(runtime_state, repo, mock_config) - engine = GraphEngine( - workflow_id="workflow", - graph=graph, - graph_runtime_state=runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig( - min_workers=2, - max_workers=2, - scale_up_threshold=1, - scale_down_idle_time=30.0, - ), - ) - - events = list(engine.run()) - - llm_started = any(isinstance(e, NodeRunStartedEvent) and e.node_id == "llm_a" for e in events) - llm_succeeded = any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in events) - human_b_pause = any(isinstance(e, NodeRunPauseRequestedEvent) and e.node_id == "human_b" for e in events) - graph_paused = any(isinstance(e, GraphRunPausedEvent) for e in events) - graph_started = any(isinstance(e, GraphRunStartedEvent) for e in events) - - assert graph_started - assert graph_paused - assert human_b_pause - assert llm_started - assert llm_succeeded - - -def test_parallel_human_input_pause_preserves_node_finished_after_snapshot_resume() -> None: - base_state = _build_runtime_state() - base_state.graph_execution.start() - base_state.register_paused_node("human_a") - base_state.register_paused_node("human_b") - snapshot = base_state.dumps() - - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - - submitted = StaticForm( - form_id="form-a", - rendered="rendered", - is_submitted=True, - action_id="approve", - data={}, - status_value=HumanInputFormStatus.SUBMITTED, - ) - pending = StaticForm( - form_id="form-b", - rendered="rendered", - is_submitted=False, - action_id=None, - data=None, - status_value=HumanInputFormStatus.WAITING, - ) - repo = StaticRepo({"human_a": submitted, "human_b": pending}) - - mock_config = MockConfig() - mock_config.simulate_delays = True - mock_config.set_node_config( - "llm_a", - NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5), - ) - - graph = _build_graph(resumed_state, repo, mock_config) - engine = GraphEngine( - workflow_id="workflow", - graph=graph, - graph_runtime_state=resumed_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig( - min_workers=2, - max_workers=2, - scale_up_threshold=1, - scale_down_idle_time=30.0, - ), - ) - - events = list(engine.run()) - - start_event = next(e for e in events if isinstance(e, GraphRunStartedEvent)) - assert start_event.reason is WorkflowStartReason.RESUMPTION - - llm_started = any(isinstance(e, NodeRunStartedEvent) and e.node_id == "llm_a" for e in events) - llm_succeeded = any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in events) - human_b_pause = any(isinstance(e, NodeRunPauseRequestedEvent) and e.node_id == "human_b" for e in events) - graph_paused = any(isinstance(e, GraphRunPausedEvent) for e in events) - - assert graph_paused - assert human_b_pause - assert llm_started - assert llm_succeeded diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py deleted file mode 100644 index b954a4faaca..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py +++ /dev/null @@ -1,286 +0,0 @@ -""" -Test for parallel streaming workflow behavior. - -This test validates that: -- LLM 1 always speaks English -- LLM 2 always speaks Chinese -- 2 LLMs run parallel, but LLM 2 will output before LLM 1 -- All chunks should be sent before Answer Node started -""" - -import time -from unittest.mock import MagicMock, patch -from uuid import uuid4 - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.model_manager import ModelInstance -from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.graph_engine import GraphEngine, GraphEngineConfig -from dify_graph.graph_engine.command_channels import InMemoryChannel -from dify_graph.graph_events import ( - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) -from dify_graph.node_events import NodeRunResult, StreamCompletedEvent -from dify_graph.nodes.llm.node import LLMNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from tests.workflow_test_utils import build_test_graph_init_params - -from .test_table_runner import TableTestRunner - - -def create_llm_generator_with_delay(chunks: list[str], delay: float = 0.1): - """Create a generator that simulates LLM streaming output with delay""" - - def llm_generator(self): - for i, chunk in enumerate(chunks): - time.sleep(delay) # Simulate network delay - yield NodeRunStreamChunkEvent( - id=str(uuid4()), - node_id=self.id, - node_type=self.node_type, - selector=[self.id, "text"], - chunk=chunk, - is_final=i == len(chunks) - 1, - ) - - # Complete response - full_text = "".join(chunks) - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"text": full_text}, - ) - ) - - return llm_generator - - -def test_parallel_streaming_workflow(): - """ - Test parallel streaming workflow to verify: - 1. All chunks from LLM 2 are output before LLM 1 - 2. At least one chunk from LLM 2 is output before LLM 1 completes (Success) - 3. At least one chunk from LLM 1 is output before LLM 2 completes (EXPECTED TO FAIL) - 4. All chunks are output before End begins - 5. The final output content matches the order defined in the Answer - - Test setup: - - LLM 1 outputs English (slower) - - LLM 2 outputs Chinese (faster) - - Both run in parallel - - This test is expected to FAIL because chunks are currently buffered - until after node completion instead of streaming during execution. - """ - runner = TableTestRunner() - - # Load the workflow configuration - fixture_data = runner.workflow_runner.load_fixture("multilingual_parallel_llm_streaming_workflow") - workflow_config = fixture_data.get("workflow", {}) - graph_config = workflow_config.get("graph", {}) - - # Create graph initialization parameters - init_params = build_test_graph_init_params( - workflow_id="test_workflow", - graph_config=graph_config, - tenant_id="test_tenant", - app_id="test_app", - user_id="test_user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.WEB_APP, - call_depth=0, - ) - - # Create variable pool with system variables - system_variables = SystemVariable( - user_id="test_user", - app_id="test_app", - workflow_id=init_params.workflow_id, - files=[], - query="Tell me about yourself", # User query - ) - variable_pool = VariablePool( - system_variables=system_variables, - user_inputs={}, - ) - - # Create graph runtime state - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - # Create node factory and graph - node_factory = DifyNodeFactory(graph_init_params=init_params, graph_runtime_state=graph_runtime_state) - with patch.object( - DifyNodeFactory, "_build_model_instance_for_llm_node", return_value=MagicMock(spec=ModelInstance), autospec=True - ): - graph = Graph.init( - graph_config=graph_config, - node_factory=node_factory, - root_node_id=get_default_root_node_id(graph_config), - ) - - # Create the graph engine - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - # Define LLM outputs - llm1_chunks = ["Hello", ", ", "I", " ", "am", " ", "an", " ", "AI", " ", "assistant", "."] # English (slower) - llm2_chunks = ["你好", ",", "我", "是", "AI", "助手", "。"] # Chinese (faster) - - # Create generators with different delays (LLM 2 is faster) - llm1_generator = create_llm_generator_with_delay(llm1_chunks, delay=0.05) # Slower - llm2_generator = create_llm_generator_with_delay(llm2_chunks, delay=0.01) # Faster - - # Track which LLM node is being called - llm_call_order = [] - generators = { - "1754339718571": llm1_generator, # LLM 1 node ID - "1754339725656": llm2_generator, # LLM 2 node ID - } - - def mock_llm_run(self): - llm_call_order.append(self.id) - generator = generators.get(self.id) - if generator: - yield from generator(self) - else: - raise Exception(f"Unexpected LLM node ID: {self.id}") - - # Execute with mocked LLMs - with patch.object(LLMNode, "_run", new=mock_llm_run): - events = list(engine.run()) - - # Check for successful completion - success_events = [e for e in events if isinstance(e, GraphRunSucceededEvent)] - assert len(success_events) > 0, "Workflow should complete successfully" - - # Get all streaming chunk events - stream_chunk_events = [e for e in events if isinstance(e, NodeRunStreamChunkEvent)] - - # Get Answer node start event - answer_start_events = [ - e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == BuiltinNodeTypes.ANSWER - ] - assert len(answer_start_events) == 1, f"Expected 1 Answer node start event, got {len(answer_start_events)}" - answer_start_event = answer_start_events[0] - - # Find the index of Answer node start - answer_start_index = events.index(answer_start_event) - - # Collect chunk events by node - llm1_chunks_events = [e for e in stream_chunk_events if e.node_id == "1754339718571"] - llm2_chunks_events = [e for e in stream_chunk_events if e.node_id == "1754339725656"] - - # Verify both LLMs produced chunks - assert len(llm1_chunks_events) == len(llm1_chunks), ( - f"Expected {len(llm1_chunks)} chunks from LLM 1, got {len(llm1_chunks_events)}" - ) - assert len(llm2_chunks_events) == len(llm2_chunks), ( - f"Expected {len(llm2_chunks)} chunks from LLM 2, got {len(llm2_chunks_events)}" - ) - - # 1. Verify chunk ordering based on actual implementation - llm1_chunk_indices = [events.index(e) for e in llm1_chunks_events] - llm2_chunk_indices = [events.index(e) for e in llm2_chunks_events] - - # In the current implementation, chunks may be interleaved or in a specific order - # Update this based on actual behavior observed - if llm1_chunk_indices and llm2_chunk_indices: - # Check the actual ordering - if LLM 2 chunks come first (as seen in debug) - assert max(llm2_chunk_indices) < min(llm1_chunk_indices), ( - f"All LLM 2 chunks should be output before LLM 1 chunks. " - f"LLM 2 chunk indices: {llm2_chunk_indices}, LLM 1 chunk indices: {llm1_chunk_indices}" - ) - - # Get indices of all chunk events - chunk_indices = [events.index(e) for e in stream_chunk_events if e in llm1_chunks_events + llm2_chunks_events] - - # 4. Verify all chunks were sent before Answer node started - assert all(idx < answer_start_index for idx in chunk_indices), ( - "All LLM chunks should be sent before Answer node starts" - ) - - # The test has successfully verified: - # 1. Both LLMs run in parallel (they start at the same time) - # 2. LLM 2 (Chinese) outputs all its chunks before LLM 1 (English) due to faster processing - # 3. All LLM chunks are sent before the Answer node starts - - # Get LLM completion events - llm_completed_events = [ - (i, e) - for i, e in enumerate(events) - if isinstance(e, NodeRunSucceededEvent) and e.node_type == BuiltinNodeTypes.LLM - ] - - # Check LLM completion order - in the current implementation, LLMs run sequentially - # LLM 1 completes first, then LLM 2 runs and completes - assert len(llm_completed_events) == 2, f"Expected 2 LLM completion events, got {len(llm_completed_events)}" - llm2_complete_idx = next((i for i, e in llm_completed_events if e.node_id == "1754339725656"), None) - llm1_complete_idx = next((i for i, e in llm_completed_events if e.node_id == "1754339718571"), None) - assert llm2_complete_idx is not None, "LLM 2 completion event not found" - assert llm1_complete_idx is not None, "LLM 1 completion event not found" - # In the actual implementation, LLM 1 completes before LLM 2 (sequential execution) - assert llm1_complete_idx < llm2_complete_idx, ( - f"LLM 1 should complete before LLM 2 in sequential execution, but LLM 1 completed at {llm1_complete_idx} " - f"and LLM 2 completed at {llm2_complete_idx}" - ) - - # 2. In sequential execution, LLM 2 chunks appear AFTER LLM 1 completes - if llm2_chunk_indices: - # LLM 1 completes first, then LLM 2 starts streaming - assert min(llm2_chunk_indices) > llm1_complete_idx, ( - f"LLM 2 chunks should appear after LLM 1 completes in sequential execution. " - f"First LLM 2 chunk at index {min(llm2_chunk_indices)}, LLM 1 completed at index {llm1_complete_idx}" - ) - - # 3. In the current implementation, LLM 1 chunks appear after LLM 2 completes - # This is because chunks are buffered and output after both nodes complete - if llm1_chunk_indices and llm2_complete_idx: - # Check if LLM 1 chunks exist and where they appear relative to LLM 2 completion - # In current behavior, LLM 1 chunks typically appear after LLM 2 completes - pass # Skipping this check as the chunk ordering is implementation-dependent - - # CURRENT BEHAVIOR: Chunks are buffered and appear after node completion - # In the sequential execution, LLM 1 completes first without streaming, - # then LLM 2 streams its chunks - assert stream_chunk_events, "Expected streaming events, but got none" - - first_chunk_index = events.index(stream_chunk_events[0]) - llm_success_indices = [i for i, e in llm_completed_events] - - # Current implementation: LLM 1 completes first, then chunks start appearing - # This is the actual behavior we're testing - if llm_success_indices: - # At least one LLM (LLM 1) completes before any chunks appear - assert min(llm_success_indices) < first_chunk_index, ( - f"In current implementation, LLM 1 completes before chunks start streaming. " - f"First chunk at index {first_chunk_index}, LLM 1 completed at index {min(llm_success_indices)}" - ) - - # 5. Verify final output content matches the order defined in Answer node - # According to Answer node configuration: '{{#1754339725656.text#}}{{#1754339718571.text#}}' - # This means LLM 2 output should come first, then LLM 1 output - answer_complete_events = [ - e for e in events if isinstance(e, NodeRunSucceededEvent) and e.node_type == BuiltinNodeTypes.ANSWER - ] - assert len(answer_complete_events) == 1, f"Expected 1 Answer completion event, got {len(answer_complete_events)}" - - answer_outputs = answer_complete_events[0].node_run_result.outputs - expected_answer_text = "你好,我是AI助手。Hello, I am an AI assistant." - - if "answer" in answer_outputs: - actual_answer_text = answer_outputs["answer"] - assert actual_answer_text == expected_answer_text, ( - f"Answer content should match the order defined in Answer node. " - f"Expected: '{expected_answer_text}', Got: '{actual_answer_text}'" - ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py deleted file mode 100644 index 7328ce443f2..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py +++ /dev/null @@ -1,309 +0,0 @@ -import time -from collections.abc import Mapping -from dataclasses import dataclass -from datetime import datetime, timedelta -from typing import Any - -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.graph import Graph -from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from dify_graph.graph_engine.config import GraphEngineConfig -from dify_graph.graph_engine.graph_engine import GraphEngine -from dify_graph.graph_events import ( - GraphRunPausedEvent, - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, -) -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole -from dify_graph.nodes.end.end_node import EndNode -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction -from dify_graph.nodes.human_input.enums import HumanInputFormStatus -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.nodes.llm.entities import ( - ContextConfig, - LLMNodeChatModelMessage, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from libs.datetime_utils import naive_utc_now -from tests.workflow_test_utils import build_test_graph_init_params - -from .test_mock_config import MockConfig, NodeMockConfig -from .test_mock_nodes import MockLLMNode - - -@dataclass -class StaticForm(HumanInputFormEntity): - form_id: str - rendered: str - is_submitted: bool - action_id: str | None = None - data: Mapping[str, Any] | None = None - status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING - expiration: datetime = naive_utc_now() + timedelta(days=1) - - @property - def id(self) -> str: - return self.form_id - - @property - def web_app_token(self) -> str | None: - return "token" - - @property - def recipients(self) -> list: - return [] - - @property - def rendered_content(self) -> str: - return self.rendered - - @property - def selected_action_id(self) -> str | None: - return self.action_id - - @property - def submitted_data(self) -> Mapping[str, Any] | None: - return self.data - - @property - def submitted(self) -> bool: - return self.is_submitted - - @property - def status(self) -> HumanInputFormStatus: - return self.status_value - - @property - def expiration_time(self) -> datetime: - return self.expiration - - -class StaticRepo(HumanInputFormRepository): - def __init__(self, form: HumanInputFormEntity) -> None: - self._form = form - - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: - if node_id != "human_pause": - return None - return self._form - - def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: - raise AssertionError("create_form should not be called in this test") - - -def _build_runtime_state() -> GraphRuntimeState: - variable_pool = VariablePool( - system_variables=SystemVariable( - user_id="user", - app_id="app", - workflow_id="workflow", - workflow_execution_id="exec-1", - ), - user_inputs={}, - conversation_variables=[], - ) - return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - -def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = build_test_graph_init_params( - workflow_id="workflow", - graph_config=graph_config, - tenant_id="tenant", - app_id="app", - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - - start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} - start_node = StartNode( - id=start_config["id"], - config=start_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - ) - - llm_a_data = LLMNodeData( - title="LLM A", - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text="Prompt A", - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - structured_output_enabled=False, - ) - llm_a_config = {"id": "llm_a", "data": llm_a_data.model_dump()} - llm_a = MockLLMNode( - id=llm_a_config["id"], - config=llm_a_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - mock_config=mock_config, - ) - - llm_b_data = LLMNodeData( - title="LLM B", - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text="Prompt B", - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - structured_output_enabled=False, - ) - llm_b_config = {"id": "llm_b", "data": llm_b_data.model_dump()} - llm_b = MockLLMNode( - id=llm_b_config["id"], - config=llm_b_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - mock_config=mock_config, - ) - - human_data = HumanInputNodeData( - title="Human Input", - form_content="Pause here", - inputs=[], - user_actions=[UserAction(id="approve", title="Approve")], - ) - human_config = {"id": "human_pause", "data": human_data.model_dump()} - human_node = HumanInputNode( - id=human_config["id"], - config=human_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - form_repository=repo, - ) - - end_human_data = EndNodeData(title="End Human", outputs=[], desc=None) - end_human_config = {"id": "end_human", "data": end_human_data.model_dump()} - end_human = EndNode( - id=end_human_config["id"], - config=end_human_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - ) - - return ( - Graph.new() - .add_root(start_node) - .add_node(llm_a, from_node_id="start") - .add_node(human_node, from_node_id="start") - .add_node(llm_b, from_node_id="llm_a") - .add_node(end_human, from_node_id="human_pause", source_handle="approve") - .build() - ) - - -def _get_node_started_event(events: list[object], node_id: str) -> NodeRunStartedEvent | None: - for event in events: - if isinstance(event, NodeRunStartedEvent) and event.node_id == node_id: - return event - return None - - -def test_pause_defers_ready_nodes_until_resume() -> None: - runtime_state = _build_runtime_state() - - paused_form = StaticForm( - form_id="form-pause", - rendered="rendered", - is_submitted=False, - status_value=HumanInputFormStatus.WAITING, - ) - pause_repo = StaticRepo(paused_form) - - mock_config = MockConfig() - mock_config.simulate_delays = True - mock_config.set_node_config( - "llm_a", - NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5), - ) - mock_config.set_node_config( - "llm_b", - NodeMockConfig(node_id="llm_b", outputs={"text": "LLM B output"}, delay=0.0), - ) - - graph = _build_graph(runtime_state, pause_repo, mock_config) - engine = GraphEngine( - workflow_id="workflow", - graph=graph, - graph_runtime_state=runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig( - min_workers=2, - max_workers=2, - scale_up_threshold=1, - scale_down_idle_time=30.0, - ), - ) - - paused_events = list(engine.run()) - - assert any(isinstance(e, GraphRunPausedEvent) for e in paused_events) - assert any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in paused_events) - assert _get_node_started_event(paused_events, "llm_b") is None - - snapshot = runtime_state.dumps() - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - - submitted_form = StaticForm( - form_id="form-pause", - rendered="rendered", - is_submitted=True, - action_id="approve", - data={}, - status_value=HumanInputFormStatus.SUBMITTED, - ) - resume_repo = StaticRepo(submitted_form) - - resumed_graph = _build_graph(resumed_state, resume_repo, mock_config) - resumed_engine = GraphEngine( - workflow_id="workflow", - graph=resumed_graph, - graph_runtime_state=resumed_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig( - min_workers=2, - max_workers=2, - scale_up_threshold=1, - scale_down_idle_time=30.0, - ), - ) - - resumed_events = list(resumed_engine.run()) - - start_event = next(e for e in resumed_events if isinstance(e, GraphRunStartedEvent)) - assert start_event.reason is WorkflowStartReason.RESUMPTION - - llm_b_started = _get_node_started_event(resumed_events, "llm_b") - assert llm_b_started is not None - assert any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_b" for e in resumed_events) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py deleted file mode 100644 index 15a7de3c521..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py +++ /dev/null @@ -1,217 +0,0 @@ -import datetime -import time -from typing import Any -from unittest.mock import MagicMock - -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.graph import Graph -from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from dify_graph.graph_engine.graph_engine import GraphEngine -from dify_graph.graph_events import ( - GraphEngineEvent, - GraphRunPausedEvent, - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, -) -from dify_graph.graph_events.graph import GraphRunStartedEvent -from dify_graph.nodes.base.entities import OutputVariableEntity -from dify_graph.nodes.end.end_node import EndNode -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import ( - HumanInputFormEntity, - HumanInputFormRepository, -) -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from libs.datetime_utils import naive_utc_now -from tests.workflow_test_utils import build_test_graph_init_params - - -def _build_runtime_state() -> GraphRuntimeState: - variable_pool = VariablePool( - system_variables=SystemVariable( - user_id="user", - app_id="app", - workflow_id="workflow", - workflow_execution_id="test-execution-id", - ), - user_inputs={}, - conversation_variables=[], - ) - return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - -def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepository: - repo = MagicMock(spec=HumanInputFormRepository) - form_entity = MagicMock(spec=HumanInputFormEntity) - form_entity.id = "test-form-id" - form_entity.web_app_token = "test-form-token" - form_entity.recipients = [] - form_entity.rendered_content = "rendered" - form_entity.submitted = True - form_entity.selected_action_id = action_id - form_entity.submitted_data = {} - form_entity.expiration_time = naive_utc_now() + datetime.timedelta(days=1) - repo.get_form.return_value = form_entity - return repo - - -def _mock_form_repository_without_submission() -> HumanInputFormRepository: - repo = MagicMock(spec=HumanInputFormRepository) - form_entity = MagicMock(spec=HumanInputFormEntity) - form_entity.id = "test-form-id" - form_entity.web_app_token = "test-form-token" - form_entity.recipients = [] - form_entity.rendered_content = "rendered" - form_entity.submitted = False - repo.create_form.return_value = form_entity - repo.get_form.return_value = None - return repo - - -def _build_human_input_graph( - runtime_state: GraphRuntimeState, - form_repository: HumanInputFormRepository, -) -> Graph: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - params = build_test_graph_init_params( - workflow_id="workflow", - graph_config=graph_config, - tenant_id="tenant", - app_id="app", - user_id="user", - user_from="account", - invoke_from="service-api", - call_depth=0, - ) - - start_data = StartNodeData(title="start", variables=[]) - start_node = StartNode( - id="start", - config={"id": "start", "data": start_data.model_dump()}, - graph_init_params=params, - graph_runtime_state=runtime_state, - ) - - human_data = HumanInputNodeData( - title="human", - form_content="Awaiting human input", - inputs=[], - user_actions=[ - UserAction(id="continue", title="Continue"), - ], - ) - human_node = HumanInputNode( - id="human", - config={"id": "human", "data": human_data.model_dump()}, - graph_init_params=params, - graph_runtime_state=runtime_state, - form_repository=form_repository, - ) - - end_data = EndNodeData( - title="end", - outputs=[ - OutputVariableEntity(variable="result", value_selector=["human", "action_id"]), - ], - desc=None, - ) - end_node = EndNode( - id="end", - config={"id": "end", "data": end_data.model_dump()}, - graph_init_params=params, - graph_runtime_state=runtime_state, - ) - - return ( - Graph.new() - .add_root(start_node) - .add_node(human_node) - .add_node(end_node, from_node_id="human", source_handle="continue") - .build() - ) - - -def _run_graph(graph: Graph, runtime_state: GraphRuntimeState) -> list[GraphEngineEvent]: - engine = GraphEngine( - workflow_id="workflow", - graph=graph, - graph_runtime_state=runtime_state, - command_channel=InMemoryChannel(), - ) - return list(engine.run()) - - -def _node_successes(events: list[GraphEngineEvent]) -> list[str]: - return [event.node_id for event in events if isinstance(event, NodeRunSucceededEvent)] - - -def _node_start_event(events: list[GraphEngineEvent], node_id: str) -> NodeRunStartedEvent | None: - for event in events: - if isinstance(event, NodeRunStartedEvent) and event.node_id == node_id: - return event - return None - - -def _segment_value(variable_pool: VariablePool, selector: tuple[str, str]) -> Any: - segment = variable_pool.get(selector) - assert segment is not None - return getattr(segment, "value", segment) - - -def test_engine_resume_restores_state_and_completion(): - # Baseline run without pausing - baseline_state = _build_runtime_state() - baseline_repo = _mock_form_repository_with_submission(action_id="continue") - baseline_graph = _build_human_input_graph(baseline_state, baseline_repo) - baseline_events = _run_graph(baseline_graph, baseline_state) - assert baseline_events - first_paused_event = baseline_events[0] - assert isinstance(first_paused_event, GraphRunStartedEvent) - assert first_paused_event.reason is WorkflowStartReason.INITIAL - assert isinstance(baseline_events[-1], GraphRunSucceededEvent) - baseline_success_nodes = _node_successes(baseline_events) - - # Run with pause - paused_state = _build_runtime_state() - pause_repo = _mock_form_repository_without_submission() - paused_graph = _build_human_input_graph(paused_state, pause_repo) - paused_events = _run_graph(paused_graph, paused_state) - assert paused_events - first_paused_event = paused_events[0] - assert isinstance(first_paused_event, GraphRunStartedEvent) - assert first_paused_event.reason is WorkflowStartReason.INITIAL - assert isinstance(paused_events[-1], GraphRunPausedEvent) - snapshot = paused_state.dumps() - - # Resume from snapshot - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - resume_repo = _mock_form_repository_with_submission(action_id="continue") - resumed_graph = _build_human_input_graph(resumed_state, resume_repo) - resumed_events = _run_graph(resumed_graph, resumed_state) - assert resumed_events - first_resumed_event = resumed_events[0] - assert isinstance(first_resumed_event, GraphRunStartedEvent) - assert first_resumed_event.reason is WorkflowStartReason.RESUMPTION - assert isinstance(resumed_events[-1], GraphRunSucceededEvent) - - combined_success_nodes = _node_successes(paused_events) + _node_successes(resumed_events) - assert combined_success_nodes == baseline_success_nodes - - paused_human_started = _node_start_event(paused_events, "human") - resumed_human_started = _node_start_event(resumed_events, "human") - assert paused_human_started is not None - assert resumed_human_started is not None - assert paused_human_started.id == resumed_human_started.id - - assert baseline_state.outputs == resumed_state.outputs - assert _segment_value(baseline_state.variable_pool, ("human", "__action_id")) == _segment_value( - resumed_state.variable_pool, ("human", "__action_id") - ) - assert baseline_state.graph_execution.completed - assert resumed_state.graph_execution.completed diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py b/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py deleted file mode 100644 index 9c84f42db6a..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py +++ /dev/null @@ -1,268 +0,0 @@ -""" -Unit tests for Redis-based stop functionality in GraphEngine. - -Tests the integration of Redis command channel for stopping workflows -without user permission checks. -""" - -import json -from unittest.mock import MagicMock, Mock, patch - -import pytest -import redis - -from core.app.apps.base_app_queue_manager import AppQueueManager -from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel -from dify_graph.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand -from dify_graph.graph_engine.manager import GraphEngineManager - - -class TestRedisStopIntegration: - """Test suite for Redis-based workflow stop functionality.""" - - def test_graph_engine_manager_sends_abort_command(self): - """Test that GraphEngineManager correctly sends abort command through Redis.""" - # Setup - task_id = "test-task-123" - expected_channel_key = f"workflow:{task_id}:commands" - - # Mock redis client - mock_redis = MagicMock() - mock_pipeline = MagicMock() - mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline) - mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None) - - manager = GraphEngineManager(mock_redis) - - # Execute - manager.send_stop_command(task_id, reason="Test stop") - - # Verify - mock_redis.pipeline.assert_called_once() - - # Check that rpush was called with correct arguments - calls = mock_pipeline.rpush.call_args_list - assert len(calls) == 1 - - # Verify the channel key - assert calls[0][0][0] == expected_channel_key - - # Verify the command data - command_json = calls[0][0][1] - command_data = json.loads(command_json) - assert command_data["command_type"] == CommandType.ABORT - assert command_data["reason"] == "Test stop" - - def test_graph_engine_manager_sends_pause_command(self): - """Test that GraphEngineManager correctly sends pause command through Redis.""" - task_id = "test-task-pause-123" - expected_channel_key = f"workflow:{task_id}:commands" - - mock_redis = MagicMock() - mock_pipeline = MagicMock() - mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline) - mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None) - - manager = GraphEngineManager(mock_redis) - manager.send_pause_command(task_id, reason="Awaiting resources") - - mock_redis.pipeline.assert_called_once() - calls = mock_pipeline.rpush.call_args_list - assert len(calls) == 1 - assert calls[0][0][0] == expected_channel_key - - command_json = calls[0][0][1] - command_data = json.loads(command_json) - assert command_data["command_type"] == CommandType.PAUSE.value - assert command_data["reason"] == "Awaiting resources" - - def test_graph_engine_manager_handles_redis_failure_gracefully(self): - """Test that GraphEngineManager handles Redis failures without raising exceptions.""" - task_id = "test-task-456" - - # Mock redis client to raise exception - mock_redis = MagicMock() - mock_redis.pipeline.side_effect = redis.ConnectionError("Redis connection failed") - manager = GraphEngineManager(mock_redis) - - # Should not raise exception - try: - manager.send_stop_command(task_id) - except Exception as e: - pytest.fail(f"GraphEngineManager.send_stop_command raised {e} unexpectedly") - - def test_app_queue_manager_no_user_check(self): - """Test that AppQueueManager.set_stop_flag_no_user_check works without user validation.""" - task_id = "test-task-789" - expected_cache_key = f"generate_task_stopped:{task_id}" - - # Mock redis client - mock_redis = MagicMock() - - with patch("core.app.apps.base_app_queue_manager.redis_client", mock_redis): - # Execute - AppQueueManager.set_stop_flag_no_user_check(task_id) - - # Verify - mock_redis.setex.assert_called_once_with(expected_cache_key, 600, 1) - - def test_app_queue_manager_no_user_check_with_empty_task_id(self): - """Test that AppQueueManager.set_stop_flag_no_user_check handles empty task_id.""" - # Mock redis client - mock_redis = MagicMock() - - with patch("core.app.apps.base_app_queue_manager.redis_client", mock_redis): - # Execute with empty task_id - AppQueueManager.set_stop_flag_no_user_check("") - - # Verify redis was not called - mock_redis.setex.assert_not_called() - - def test_redis_channel_send_abort_command(self): - """Test RedisChannel correctly serializes and sends AbortCommand.""" - # Setup - mock_redis = MagicMock() - mock_pipeline = MagicMock() - mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline) - mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None) - - channel_key = "workflow:test:commands" - channel = RedisChannel(mock_redis, channel_key) - - # Create commands - abort_command = AbortCommand(reason="User requested stop") - pause_command = PauseCommand(reason="User requested pause") - - # Execute - channel.send_command(abort_command) - channel.send_command(pause_command) - - # Verify - mock_redis.pipeline.assert_called() - - # Check rpush was called - calls = mock_pipeline.rpush.call_args_list - assert len(calls) == 2 - assert calls[0][0][0] == channel_key - assert calls[1][0][0] == channel_key - - # Verify serialized commands - abort_command_json = calls[0][0][1] - abort_command_data = json.loads(abort_command_json) - assert abort_command_data["command_type"] == CommandType.ABORT.value - assert abort_command_data["reason"] == "User requested stop" - - pause_command_json = calls[1][0][1] - pause_command_data = json.loads(pause_command_json) - assert pause_command_data["command_type"] == CommandType.PAUSE.value - assert pause_command_data["reason"] == "User requested pause" - - # Check expire was set for each - assert mock_pipeline.expire.call_count == 2 - mock_pipeline.expire.assert_any_call(channel_key, 3600) - - def test_redis_channel_fetch_commands(self): - """Test RedisChannel correctly fetches and deserializes commands.""" - # Setup - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - # Mock command data - abort_command_json = json.dumps( - {"command_type": CommandType.ABORT.value, "reason": "Test abort", "payload": None} - ) - pause_command_json = json.dumps( - {"command_type": CommandType.PAUSE.value, "reason": "Pause requested", "payload": None} - ) - - # Mock pipeline execute to return commands - pending_pipe.execute.return_value = [b"1", 1] - fetch_pipe.execute.return_value = [ - [abort_command_json.encode(), pause_command_json.encode()], # lrange result - True, # delete result - ] - - channel_key = "workflow:test:commands" - channel = RedisChannel(mock_redis, channel_key) - - # Execute - commands = channel.fetch_commands() - - # Verify - assert len(commands) == 2 - assert isinstance(commands[0], AbortCommand) - assert commands[0].command_type == CommandType.ABORT - assert commands[0].reason == "Test abort" - assert isinstance(commands[1], PauseCommand) - assert commands[1].command_type == CommandType.PAUSE - assert commands[1].reason == "Pause requested" - - # Verify Redis operations - pending_pipe.get.assert_called_once_with(f"{channel_key}:pending") - pending_pipe.delete.assert_called_once_with(f"{channel_key}:pending") - fetch_pipe.lrange.assert_called_once_with(channel_key, 0, -1) - fetch_pipe.delete.assert_called_once_with(channel_key) - assert mock_redis.pipeline.call_count == 2 - - def test_redis_channel_fetch_commands_handles_invalid_json(self): - """Test RedisChannel gracefully handles invalid JSON in commands.""" - # Setup - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - # Mock invalid command data - pending_pipe.execute.return_value = [b"1", 1] - fetch_pipe.execute.return_value = [ - [b"invalid json", b'{"command_type": "invalid_type"}'], # lrange result - True, # delete result - ] - - channel_key = "workflow:test:commands" - channel = RedisChannel(mock_redis, channel_key) - - # Execute - commands = channel.fetch_commands() - - # Should return empty list due to invalid commands - assert len(commands) == 0 - - def test_dual_stop_mechanism_compatibility(self): - """Test that both stop mechanisms can work together.""" - task_id = "test-task-dual" - - # Mock redis client - mock_redis = MagicMock() - mock_pipeline = MagicMock() - mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline) - mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None) - - with patch("core.app.apps.base_app_queue_manager.redis_client", mock_redis): - # Execute both stop mechanisms - AppQueueManager.set_stop_flag_no_user_check(task_id) - GraphEngineManager(mock_redis).send_stop_command(task_id) - - # Verify legacy stop flag was set - expected_stop_flag_key = f"generate_task_stopped:{task_id}" - mock_redis.setex.assert_called_once_with(expected_stop_flag_key, 600, 1) - - # Verify command was sent through Redis channel - mock_redis.pipeline.assert_called() - calls = mock_pipeline.rpush.call_args_list - assert len(calls) == 1 - assert calls[0][0][0] == f"workflow:{task_id}:commands" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_response_session.py b/api/tests/unit_tests/core/workflow/graph_engine/test_response_session.py deleted file mode 100644 index cd9d56f683d..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_response_session.py +++ /dev/null @@ -1,55 +0,0 @@ -"""Unit tests for response session creation.""" - -from __future__ import annotations - -import pytest - -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeState, NodeType -from dify_graph.graph_engine.response_coordinator.session import ResponseSession -from dify_graph.nodes.base.template import Template, TextSegment - - -class DummyResponseNode: - """Minimal response-capable node for session tests.""" - - def __init__(self, *, node_id: str, node_type: NodeType, template: Template) -> None: - self.id = node_id - self.node_type = node_type - self.execution_type = NodeExecutionType.RESPONSE - self.state = NodeState.UNKNOWN - self._template = template - - def get_streaming_template(self) -> Template: - return self._template - - -class DummyNodeWithoutStreamingTemplate: - """Minimal node that violates the response-session contract.""" - - def __init__(self, *, node_id: str, node_type: NodeType) -> None: - self.id = node_id - self.node_type = node_type - self.execution_type = NodeExecutionType.RESPONSE - self.state = NodeState.UNKNOWN - - -def test_response_session_from_node_accepts_nodes_outside_previous_allowlist() -> None: - """Session creation depends on the streaming-template contract rather than node type.""" - node = DummyResponseNode( - node_id="llm-node", - node_type=BuiltinNodeTypes.LLM, - template=Template(segments=[TextSegment(text="hello")]), - ) - - session = ResponseSession.from_node(node) - - assert session.node_id == "llm-node" - assert session.template.segments == [TextSegment(text="hello")] - - -def test_response_session_from_node_requires_streaming_template_method() -> None: - """Allowed node types still need to implement the streaming-template contract.""" - node = DummyNodeWithoutStreamingTemplate(node_id="answer-node", node_type=BuiltinNodeTypes.ANSWER) - - with pytest.raises(TypeError, match="get_streaming_template"): - ResponseSession.from_node(node) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py b/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py deleted file mode 100644 index 4f1741d4fb8..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py +++ /dev/null @@ -1,77 +0,0 @@ -from dify_graph.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) - -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_streaming_conversation_variables(): - fixture_name = "test_streaming_conversation_variables" - - # The test expects the workflow to output the input query - # Since the workflow assigns sys.query to conversation variable "str" and then answers with it - input_query = "Hello, this is my test query" - - mock_config = MockConfigBuilder().build() - - case = WorkflowTestCase( - fixture_path=fixture_name, - use_auto_mock=False, # Don't use auto mock since we want to test actual variable assignment - mock_config=mock_config, - query=input_query, # Pass query as the sys.query value - inputs={}, # No additional inputs needed - expected_outputs={"answer": input_query}, # Expecting the input query to be output - expected_event_sequence=[ - GraphRunStartedEvent, - # START node - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Variable Assigner node - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - # ANSWER node - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ], - ) - - runner = TableTestRunner() - result = runner.run_test_case(case) - assert result.success, f"Test failed: {result.error}" - - -def test_streaming_conversation_variables_v1_overwrite_waits_for_assignment(): - fixture_name = "test_streaming_conversation_variables_v1_overwrite" - input_query = "overwrite-value" - - case = WorkflowTestCase( - fixture_path=fixture_name, - use_auto_mock=False, - mock_config=MockConfigBuilder().build(), - query=input_query, - inputs={}, - expected_outputs={"answer": f"Current Value Of `conv_var` is:{input_query}"}, - ) - - runner = TableTestRunner() - result = runner.run_test_case(case) - assert result.success, f"Test failed: {result.error}" - - events = result.events - conv_var_chunk_events = [ - event - for event in events - if isinstance(event, NodeRunStreamChunkEvent) and tuple(event.selector) == ("conversation", "conv_var") - ] - - assert conv_var_chunk_events, "Expected conversation variable chunk events to be emitted" - assert all(event.chunk == input_query for event in conv_var_chunk_events), ( - "Expected streamed conversation variable value to match the input query" - ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py index ab8fb346b8f..b11f9576777 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py @@ -12,29 +12,24 @@ This module provides a robust table-driven testing framework with support for: import logging import time -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Callable, Sequence from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, field from functools import lru_cache from pathlib import Path -from typing import Any, cast +from typing import Any -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.tools.utils.yaml_utils import _load_yaml_file -from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams -from dify_graph.graph import Graph -from dify_graph.graph_engine import GraphEngine, GraphEngineConfig -from dify_graph.graph_engine.command_channels import InMemoryChannel -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events import ( +from graphon.entities import GraphInitParams +from graphon.graph import Graph +from graphon.graph_engine import GraphEngine, GraphEngineConfig +from graphon.graph_engine.command_channels import InMemoryChannel +from graphon.graph_events import ( GraphEngineEvent, GraphRunStartedEvent, GraphRunSucceededEvent, ) -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables import ( +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables import ( ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, @@ -44,6 +39,12 @@ from dify_graph.variables import ( StringVariable, ) +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from core.tools.utils.yaml_utils import _load_yaml_file +from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool + from .test_mock_config import MockConfig from .test_mock_factory import MockNodeFactory @@ -60,20 +61,28 @@ class _TableTestChildEngineBuilder: *, workflow_id: str, graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - graph_config: Mapping[str, Any], + parent_graph_runtime_state: GraphRuntimeState, root_node_id: str, - layers: Sequence[object] = (), + variable_pool: VariablePool | None = None, ) -> GraphEngine: + child_graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool if variable_pool is not None else parent_graph_runtime_state.variable_pool, + start_at=time.perf_counter(), + execution_context=parent_graph_runtime_state.execution_context, + ) if self._use_mock_factory: node_factory = MockNodeFactory( graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, + graph_runtime_state=child_graph_runtime_state, mock_config=self._mock_config, ) else: - node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state) + node_factory = DifyNodeFactory( + graph_init_params=graph_init_params, + graph_runtime_state=child_graph_runtime_state, + ) + graph_config = graph_init_params.graph_config child_graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=root_node_id) if not child_graph: raise ValueError("child graph not found") @@ -81,13 +90,11 @@ class _TableTestChildEngineBuilder: child_engine = GraphEngine( workflow_id=workflow_id, graph=child_graph, - graph_runtime_state=graph_runtime_state, + graph_runtime_state=child_graph_runtime_state, command_channel=InMemoryChannel(), config=GraphEngineConfig(), child_engine_builder=self, ) - for layer in layers: - child_engine.layer(cast(GraphEngineLayer, layer)) return child_engine @@ -206,14 +213,15 @@ class WorkflowRunner: call_depth=0, ) - system_variables = SystemVariable( + system_variables = build_system_variables( user_id="test_user", app_id="test_app", workflow_id=graph_init_params.workflow_id, files=[], query=query, ) - user_inputs = inputs if inputs is not None else {} + root_node_inputs = dict(inputs or {}) + root_node_inputs.setdefault("query", query) # Extract conversation variables from workflow config conversation_variables = [] @@ -242,11 +250,16 @@ class WorkflowRunner: ) conversation_variables.append(var) - variable_pool = VariablePool( - system_variables=system_variables, - user_inputs=user_inputs, - conversation_variables=conversation_variables, + root_node_id = get_default_root_node_id(graph_config) + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_variables, + conversation_variables=conversation_variables, + ), ) + add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=root_node_inputs) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) @@ -260,7 +273,7 @@ class WorkflowRunner: graph = Graph.init( graph_config=graph_config, node_factory=node_factory, - root_node_id=get_default_root_node_id(graph_config), + root_node_id=root_node_id, ) return graph, graph_runtime_state diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py index 7f26bc11a72..12aec6edf24 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py @@ -1,6 +1,6 @@ -from dify_graph.graph_engine import GraphEngine, GraphEngineConfig -from dify_graph.graph_engine.command_channels import InMemoryChannel -from dify_graph.graph_events import ( +from graphon.graph_engine import GraphEngine, GraphEngineConfig +from graphon.graph_engine.command_channels import InMemoryChannel +from graphon.graph_events import ( GraphRunSucceededEvent, NodeRunStreamChunkEvent, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_update_conversation_variable_iteration.py b/api/tests/unit_tests/core/workflow/graph_engine/test_update_conversation_variable_iteration.py deleted file mode 100644 index a7309f64de9..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_update_conversation_variable_iteration.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Validate conversation variable updates inside an iteration workflow. - -This test uses the ``update-conversation-variable-in-iteration`` fixture, which -routes ``sys.query`` into the conversation variable ``answer`` from within an -iteration container. The workflow should surface that updated conversation -variable in the final answer output. - -Code nodes in the fixture are mocked because their concrete outputs are not -relevant to verifying variable propagation semantics. -""" - -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_update_conversation_variable_in_iteration(): - fixture_name = "update-conversation-variable-in-iteration" - user_query = "ensure conversation variable syncs" - - mock_config = ( - MockConfigBuilder() - .with_node_output("1759032363865", {"result": [1]}) - .with_node_output("1759032476318", {"result": ""}) - .build() - ) - - case = WorkflowTestCase( - fixture_path=fixture_name, - use_auto_mock=True, - mock_config=mock_config, - query=user_query, - expected_outputs={"answer": user_query}, - description="Conversation variable updated within iteration should flow to answer output.", - ) - - runner = TableTestRunner() - result = runner.run_test_case(case) - - assert result.success, f"Workflow execution failed: {result.error}" - assert result.actual_outputs is not None - assert result.actual_outputs.get("answer") == user_query diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py b/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py deleted file mode 100644 index f63e8ff4ce5..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py +++ /dev/null @@ -1,58 +0,0 @@ -from unittest.mock import patch - -import pytest - -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode - -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -class TestVariableAggregator: - """Test cases for the variable aggregator workflow.""" - - @pytest.mark.parametrize( - ("switch1", "switch2", "expected_group1", "expected_group2", "description"), - [ - (0, 0, "switch 1 off", "switch 2 off", "Both switches off"), - (0, 1, "switch 1 off", "switch 2 on", "Switch1 off, Switch2 on"), - (1, 0, "switch 1 on", "switch 2 off", "Switch1 on, Switch2 off"), - (1, 1, "switch 1 on", "switch 2 on", "Both switches on"), - ], - ) - def test_variable_aggregator_combinations( - self, - switch1: int, - switch2: int, - expected_group1: str, - expected_group2: str, - description: str, - ) -> None: - """Test all four combinations of switch1 and switch2.""" - - def mock_template_transform_run(self): - """Mock the TemplateTransformNode._run() method to return results based on node title.""" - title = self._node_data.title - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={}, outputs={"output": title}) - - with patch.object( - TemplateTransformNode, - "_run", - mock_template_transform_run, - ): - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="dual_switch_variable_aggregator_workflow", - inputs={"switch1": switch1, "switch2": switch2}, - expected_outputs={"group1": expected_group1, "group2": expected_group2}, - description=description, - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Test failed: {result.error}" - assert result.actual_outputs == test_case.expected_outputs, ( - f"Output mismatch: expected {test_case.expected_outputs}, got {result.actual_outputs}" - ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_worker.py b/api/tests/unit_tests/core/workflow/graph_engine/test_worker.py deleted file mode 100644 index bc00b49fba7..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_worker.py +++ /dev/null @@ -1,145 +0,0 @@ -import queue -from collections.abc import Generator -from datetime import UTC, datetime, timedelta -from types import SimpleNamespace -from unittest.mock import MagicMock, patch - -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.graph_engine.ready_queue import InMemoryReadyQueue -from dify_graph.graph_engine.worker import Worker -from dify_graph.graph_events import NodeRunFailedEvent, NodeRunStartedEvent - - -def test_build_fallback_failure_event_uses_naive_utc_and_failed_node_run_result(mocker) -> None: - fixed_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None) - mocker.patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=fixed_time) - - worker = Worker( - ready_queue=InMemoryReadyQueue(), - event_queue=queue.Queue(), - graph=MagicMock(), - layers=[], - ) - node = SimpleNamespace( - execution_id="exec-1", - id="node-1", - node_type=BuiltinNodeTypes.LLM, - ) - - event = worker._build_fallback_failure_event(node, RuntimeError("boom")) - - assert event.start_at == fixed_time - assert event.finished_at == fixed_time - assert event.error == "boom" - assert event.node_run_result.status == WorkflowNodeExecutionStatus.FAILED - assert event.node_run_result.error == "boom" - assert event.node_run_result.error_type == "RuntimeError" - - -def test_worker_fallback_failure_event_reuses_observed_start_time() -> None: - start_at = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None) - failure_time = start_at + timedelta(seconds=5) - captured_events: list[NodeRunFailedEvent | NodeRunStartedEvent] = [] - - class FakeNode: - execution_id = "exec-1" - id = "node-1" - node_type = BuiltinNodeTypes.LLM - - def ensure_execution_id(self) -> str: - return self.execution_id - - def run(self) -> Generator[NodeRunStartedEvent, None, None]: - yield NodeRunStartedEvent( - id=self.execution_id, - node_id=self.id, - node_type=self.node_type, - node_title="LLM", - start_at=start_at, - ) - - worker = Worker( - ready_queue=MagicMock(), - event_queue=MagicMock(), - graph=MagicMock(nodes={"node-1": FakeNode()}), - layers=[], - ) - - worker._ready_queue.get.side_effect = ["node-1"] - - def put_side_effect(event: NodeRunFailedEvent | NodeRunStartedEvent) -> None: - captured_events.append(event) - if len(captured_events) == 1: - raise RuntimeError("queue boom") - worker.stop() - - worker._event_queue.put.side_effect = put_side_effect - - with patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=failure_time): - worker.run() - - fallback_event = captured_events[-1] - - assert isinstance(fallback_event, NodeRunFailedEvent) - assert fallback_event.start_at == start_at - assert fallback_event.finished_at == failure_time - assert fallback_event.error == "queue boom" - assert fallback_event.node_run_result.status == WorkflowNodeExecutionStatus.FAILED - - -def test_worker_fallback_failure_event_ignores_nested_iteration_child_start_times() -> None: - parent_start = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None) - child_start = parent_start + timedelta(seconds=3) - failure_time = parent_start + timedelta(seconds=5) - captured_events: list[NodeRunFailedEvent | NodeRunStartedEvent] = [] - - class FakeIterationNode: - execution_id = "iteration-exec" - id = "iteration-node" - node_type = BuiltinNodeTypes.ITERATION - - def ensure_execution_id(self) -> str: - return self.execution_id - - def run(self) -> Generator[NodeRunStartedEvent, None, None]: - yield NodeRunStartedEvent( - id=self.execution_id, - node_id=self.id, - node_type=self.node_type, - node_title="Iteration", - start_at=parent_start, - ) - yield NodeRunStartedEvent( - id="child-exec", - node_id="child-node", - node_type=BuiltinNodeTypes.LLM, - node_title="LLM", - start_at=child_start, - in_iteration_id=self.id, - ) - - worker = Worker( - ready_queue=MagicMock(), - event_queue=MagicMock(), - graph=MagicMock(nodes={"iteration-node": FakeIterationNode()}), - layers=[], - ) - - worker._ready_queue.get.side_effect = ["iteration-node"] - - def put_side_effect(event: NodeRunFailedEvent | NodeRunStartedEvent) -> None: - captured_events.append(event) - if len(captured_events) == 2: - raise RuntimeError("queue boom") - worker.stop() - - worker._event_queue.put.side_effect = put_side_effect - - with patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=failure_time): - worker.run() - - fallback_event = captured_events[-1] - - assert isinstance(fallback_event, NodeRunFailedEvent) - assert fallback_event.start_at == parent_start - assert fallback_event.finished_at == failure_time diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py b/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py new file mode 100644 index 00000000000..cbc920705ca --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py @@ -0,0 +1,34 @@ +from unittest.mock import patch + +from graphon.enums import BuiltinNodeTypes + +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.workflow.nodes.agent.message_transformer import AgentMessageTransformer + + +def test_transform_passes_conversation_id_to_tool_file_message_transformer() -> None: + messages = iter(()) + transformer = AgentMessageTransformer() + + with patch.object(ToolFileMessageTransformer, "transform_tool_invoke_messages", return_value=iter(())) as transform: + result = list( + transformer.transform( + messages=messages, + tool_info={}, + parameters_for_log={}, + user_id="user-id", + tenant_id="tenant-id", + conversation_id="conversation-id", + node_type=BuiltinNodeTypes.AGENT, + node_id="node-id", + node_execution_id="execution-id", + ) + ) + + assert len(result) == 2 + transform.assert_called_once_with( + messages=messages, + user_id="user-id", + tenant_id="tenant-id", + conversation_id="conversation-id", + ) diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py b/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py new file mode 100644 index 00000000000..59dd763b59d --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py @@ -0,0 +1,50 @@ +from types import SimpleNamespace +from unittest.mock import Mock, patch + +from graphon.model_runtime.entities.model_entities import ModelType + +from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport + + +def test_fetch_model_reuses_single_model_assembly(): + provider_configuration = SimpleNamespace( + get_current_credentials=Mock(return_value={"api_key": "x"}), + provider=SimpleNamespace(provider="openai"), + ) + model_type_instance = SimpleNamespace(get_model_schema=Mock(return_value="schema")) + provider_model_bundle = SimpleNamespace( + configuration=provider_configuration, + model_type_instance=model_type_instance, + ) + model_instance = Mock() + assembly = SimpleNamespace( + provider_manager=Mock(), + model_manager=Mock(), + ) + assembly.provider_manager.get_provider_model_bundle.return_value = provider_model_bundle + assembly.model_manager.get_model_instance.return_value = model_instance + + with patch( + "core.workflow.nodes.agent.runtime_support.create_plugin_model_assembly", + return_value=assembly, + ) as mock_assembly: + resolved_instance, resolved_schema = AgentRuntimeSupport().fetch_model( + tenant_id="tenant-1", + user_id="user-1", + value={"provider": "openai", "model": "gpt-4o-mini", "model_type": "llm"}, + ) + + assert resolved_instance is model_instance + assert resolved_schema == "schema" + mock_assembly.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") + assembly.provider_manager.get_provider_model_bundle.assert_called_once_with( + tenant_id="tenant-1", + provider="openai", + model_type=ModelType.LLM, + ) + assembly.model_manager.get_model_instance.assert_called_once_with( + tenant_id="tenant-1", + provider="openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + ) diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index fd563d1be2f..7195471eb6b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -2,13 +2,14 @@ import time import uuid from unittest.mock import MagicMock +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.nodes.answer.answer_node import AnswerNode +from graphon.runtime import GraphRuntimeState, VariablePool + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.nodes.answer.answer_node import AnswerNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables from extensions.ext_database import db from tests.workflow_test_utils import build_test_graph_init_params @@ -48,7 +49,7 @@ def test_execute_answer(): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py index 81d3f5be9c8..343bcd39193 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py @@ -1,9 +1,9 @@ import pytest +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.base.node import Node from core.workflow.node_factory import get_node_type_classes_mapping -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base.node import Node # Ensures that all production node classes are imported and registered. _ = get_node_type_classes_mapping() diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py index 972a945ca0f..b9371a34f44 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py @@ -1,19 +1,20 @@ import types from collections.abc import Mapping -from core.workflow.node_factory import get_node_type_classes_mapping -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base.node import Node +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.base.node import Node # Import concrete nodes we will assert on (numeric version path) -from dify_graph.nodes.variable_assigner.v1.node import ( +from graphon.nodes.variable_assigner.v1.node import ( VariableAssignerNode as VariableAssignerV1, ) -from dify_graph.nodes.variable_assigner.v2.node import ( +from graphon.nodes.variable_assigner.v2.node import ( VariableAssignerNode as VariableAssignerV2, ) +from core.workflow.node_factory import get_node_type_classes_mapping + def test_variable_assigner_latest_prefers_highest_numeric_version(): # Act diff --git a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py index 784e08edd24..d155124c501 100644 --- a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py @@ -1,13 +1,14 @@ -from configs import dify_config -from dify_graph.nodes.code.code_node import CodeNode -from dify_graph.nodes.code.entities import CodeLanguage, CodeNodeData -from dify_graph.nodes.code.exc import ( +from graphon.nodes.code.code_node import CodeNode +from graphon.nodes.code.entities import CodeLanguage, CodeNodeData +from graphon.nodes.code.exc import ( CodeNodeError, DepthLimitError, OutputValidationError, ) -from dify_graph.nodes.code.limits import CodeNodeLimits -from dify_graph.variables.types import SegmentType +from graphon.nodes.code.limits import CodeNodeLimits +from graphon.variables.types import SegmentType + +from configs import dify_config CodeNode._limits = CodeNodeLimits( max_string_length=dify_config.CODE_MAX_STRING_LENGTH, diff --git a/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py deleted file mode 100644 index de7ed0815eb..00000000000 --- a/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py +++ /dev/null @@ -1,352 +0,0 @@ -import pytest -from pydantic import ValidationError - -from dify_graph.nodes.code.entities import CodeLanguage, CodeNodeData -from dify_graph.variables.types import SegmentType - - -class TestCodeNodeDataOutput: - """Test suite for CodeNodeData.Output model.""" - - def test_output_with_string_type(self): - """Test Output with STRING type.""" - output = CodeNodeData.Output(type=SegmentType.STRING) - - assert output.type == SegmentType.STRING - assert output.children is None - - def test_output_with_number_type(self): - """Test Output with NUMBER type.""" - output = CodeNodeData.Output(type=SegmentType.NUMBER) - - assert output.type == SegmentType.NUMBER - assert output.children is None - - def test_output_with_boolean_type(self): - """Test Output with BOOLEAN type.""" - output = CodeNodeData.Output(type=SegmentType.BOOLEAN) - - assert output.type == SegmentType.BOOLEAN - - def test_output_with_object_type(self): - """Test Output with OBJECT type.""" - output = CodeNodeData.Output(type=SegmentType.OBJECT) - - assert output.type == SegmentType.OBJECT - - def test_output_with_array_string_type(self): - """Test Output with ARRAY_STRING type.""" - output = CodeNodeData.Output(type=SegmentType.ARRAY_STRING) - - assert output.type == SegmentType.ARRAY_STRING - - def test_output_with_array_number_type(self): - """Test Output with ARRAY_NUMBER type.""" - output = CodeNodeData.Output(type=SegmentType.ARRAY_NUMBER) - - assert output.type == SegmentType.ARRAY_NUMBER - - def test_output_with_array_object_type(self): - """Test Output with ARRAY_OBJECT type.""" - output = CodeNodeData.Output(type=SegmentType.ARRAY_OBJECT) - - assert output.type == SegmentType.ARRAY_OBJECT - - def test_output_with_array_boolean_type(self): - """Test Output with ARRAY_BOOLEAN type.""" - output = CodeNodeData.Output(type=SegmentType.ARRAY_BOOLEAN) - - assert output.type == SegmentType.ARRAY_BOOLEAN - - def test_output_with_nested_children(self): - """Test Output with nested children for OBJECT type.""" - child_output = CodeNodeData.Output(type=SegmentType.STRING) - parent_output = CodeNodeData.Output( - type=SegmentType.OBJECT, - children={"name": child_output}, - ) - - assert parent_output.type == SegmentType.OBJECT - assert parent_output.children is not None - assert "name" in parent_output.children - assert parent_output.children["name"].type == SegmentType.STRING - - def test_output_with_deeply_nested_children(self): - """Test Output with deeply nested children.""" - inner_child = CodeNodeData.Output(type=SegmentType.NUMBER) - middle_child = CodeNodeData.Output( - type=SegmentType.OBJECT, - children={"value": inner_child}, - ) - outer_output = CodeNodeData.Output( - type=SegmentType.OBJECT, - children={"nested": middle_child}, - ) - - assert outer_output.children is not None - assert outer_output.children["nested"].children is not None - assert outer_output.children["nested"].children["value"].type == SegmentType.NUMBER - - def test_output_with_multiple_children(self): - """Test Output with multiple children.""" - output = CodeNodeData.Output( - type=SegmentType.OBJECT, - children={ - "name": CodeNodeData.Output(type=SegmentType.STRING), - "age": CodeNodeData.Output(type=SegmentType.NUMBER), - "active": CodeNodeData.Output(type=SegmentType.BOOLEAN), - }, - ) - - assert output.children is not None - assert len(output.children) == 3 - assert output.children["name"].type == SegmentType.STRING - assert output.children["age"].type == SegmentType.NUMBER - assert output.children["active"].type == SegmentType.BOOLEAN - - def test_output_rejects_invalid_type(self): - """Test Output rejects invalid segment types.""" - with pytest.raises(ValidationError): - CodeNodeData.Output(type=SegmentType.FILE) - - def test_output_rejects_array_file_type(self): - """Test Output rejects ARRAY_FILE type.""" - with pytest.raises(ValidationError): - CodeNodeData.Output(type=SegmentType.ARRAY_FILE) - - -class TestCodeNodeDataDependency: - """Test suite for CodeNodeData.Dependency model.""" - - def test_dependency_basic(self): - """Test Dependency with name and version.""" - dependency = CodeNodeData.Dependency(name="numpy", version="1.24.0") - - assert dependency.name == "numpy" - assert dependency.version == "1.24.0" - - def test_dependency_with_complex_version(self): - """Test Dependency with complex version string.""" - dependency = CodeNodeData.Dependency(name="pandas", version=">=2.0.0,<3.0.0") - - assert dependency.name == "pandas" - assert dependency.version == ">=2.0.0,<3.0.0" - - def test_dependency_with_empty_version(self): - """Test Dependency with empty version.""" - dependency = CodeNodeData.Dependency(name="requests", version="") - - assert dependency.name == "requests" - assert dependency.version == "" - - -class TestCodeNodeData: - """Test suite for CodeNodeData model.""" - - def test_code_node_data_python3(self): - """Test CodeNodeData with Python3 language.""" - data = CodeNodeData( - title="Test Code Node", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="def main(): return {'result': 42}", - outputs={"result": CodeNodeData.Output(type=SegmentType.NUMBER)}, - ) - - assert data.title == "Test Code Node" - assert data.code_language == CodeLanguage.PYTHON3 - assert data.code == "def main(): return {'result': 42}" - assert "result" in data.outputs - assert data.dependencies is None - - def test_code_node_data_javascript(self): - """Test CodeNodeData with JavaScript language.""" - data = CodeNodeData( - title="JS Code Node", - variables=[], - code_language=CodeLanguage.JAVASCRIPT, - code="function main() { return { result: 'hello' }; }", - outputs={"result": CodeNodeData.Output(type=SegmentType.STRING)}, - ) - - assert data.code_language == CodeLanguage.JAVASCRIPT - assert "result" in data.outputs - assert data.outputs["result"].type == SegmentType.STRING - - def test_code_node_data_with_dependencies(self): - """Test CodeNodeData with dependencies.""" - data = CodeNodeData( - title="Code with Deps", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="import numpy as np\ndef main(): return {'sum': 10}", - outputs={"sum": CodeNodeData.Output(type=SegmentType.NUMBER)}, - dependencies=[ - CodeNodeData.Dependency(name="numpy", version="1.24.0"), - CodeNodeData.Dependency(name="pandas", version="2.0.0"), - ], - ) - - assert data.dependencies is not None - assert len(data.dependencies) == 2 - assert data.dependencies[0].name == "numpy" - assert data.dependencies[1].name == "pandas" - - def test_code_node_data_with_multiple_outputs(self): - """Test CodeNodeData with multiple outputs.""" - data = CodeNodeData( - title="Multi Output", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="def main(): return {'name': 'test', 'count': 5, 'items': ['a', 'b']}", - outputs={ - "name": CodeNodeData.Output(type=SegmentType.STRING), - "count": CodeNodeData.Output(type=SegmentType.NUMBER), - "items": CodeNodeData.Output(type=SegmentType.ARRAY_STRING), - }, - ) - - assert len(data.outputs) == 3 - assert data.outputs["name"].type == SegmentType.STRING - assert data.outputs["count"].type == SegmentType.NUMBER - assert data.outputs["items"].type == SegmentType.ARRAY_STRING - - def test_code_node_data_with_object_output(self): - """Test CodeNodeData with nested object output.""" - data = CodeNodeData( - title="Object Output", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="def main(): return {'user': {'name': 'John', 'age': 30}}", - outputs={ - "user": CodeNodeData.Output( - type=SegmentType.OBJECT, - children={ - "name": CodeNodeData.Output(type=SegmentType.STRING), - "age": CodeNodeData.Output(type=SegmentType.NUMBER), - }, - ), - }, - ) - - assert data.outputs["user"].type == SegmentType.OBJECT - assert data.outputs["user"].children is not None - assert len(data.outputs["user"].children) == 2 - - def test_code_node_data_with_array_object_output(self): - """Test CodeNodeData with array of objects output.""" - data = CodeNodeData( - title="Array Object Output", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="def main(): return {'users': [{'name': 'A'}, {'name': 'B'}]}", - outputs={ - "users": CodeNodeData.Output( - type=SegmentType.ARRAY_OBJECT, - children={ - "name": CodeNodeData.Output(type=SegmentType.STRING), - }, - ), - }, - ) - - assert data.outputs["users"].type == SegmentType.ARRAY_OBJECT - assert data.outputs["users"].children is not None - - def test_code_node_data_empty_code(self): - """Test CodeNodeData with empty code.""" - data = CodeNodeData( - title="Empty Code", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="", - outputs={}, - ) - - assert data.code == "" - assert len(data.outputs) == 0 - - def test_code_node_data_multiline_code(self): - """Test CodeNodeData with multiline code.""" - multiline_code = """ -def main(): - result = 0 - for i in range(10): - result += i - return {'sum': result} -""" - data = CodeNodeData( - title="Multiline Code", - variables=[], - code_language=CodeLanguage.PYTHON3, - code=multiline_code, - outputs={"sum": CodeNodeData.Output(type=SegmentType.NUMBER)}, - ) - - assert "for i in range(10)" in data.code - assert "result += i" in data.code - - def test_code_node_data_with_special_characters_in_code(self): - """Test CodeNodeData with special characters in code.""" - code_with_special = "def main(): return {'msg': 'Hello\\nWorld\\t!'}" - data = CodeNodeData( - title="Special Chars", - variables=[], - code_language=CodeLanguage.PYTHON3, - code=code_with_special, - outputs={"msg": CodeNodeData.Output(type=SegmentType.STRING)}, - ) - - assert "\\n" in data.code - assert "\\t" in data.code - - def test_code_node_data_with_unicode_in_code(self): - """Test CodeNodeData with unicode characters in code.""" - unicode_code = "def main(): return {'greeting': '你好世界'}" - data = CodeNodeData( - title="Unicode Code", - variables=[], - code_language=CodeLanguage.PYTHON3, - code=unicode_code, - outputs={"greeting": CodeNodeData.Output(type=SegmentType.STRING)}, - ) - - assert "你好世界" in data.code - - def test_code_node_data_empty_dependencies_list(self): - """Test CodeNodeData with empty dependencies list.""" - data = CodeNodeData( - title="No Deps", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="def main(): return {}", - outputs={}, - dependencies=[], - ) - - assert data.dependencies is not None - assert len(data.dependencies) == 0 - - def test_code_node_data_with_boolean_array_output(self): - """Test CodeNodeData with boolean array output.""" - data = CodeNodeData( - title="Boolean Array", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="def main(): return {'flags': [True, False, True]}", - outputs={"flags": CodeNodeData.Output(type=SegmentType.ARRAY_BOOLEAN)}, - ) - - assert data.outputs["flags"].type == SegmentType.ARRAY_BOOLEAN - - def test_code_node_data_with_number_array_output(self): - """Test CodeNodeData with number array output.""" - data = CodeNodeData( - title="Number Array", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="def main(): return {'values': [1, 2, 3, 4, 5]}", - outputs={"values": CodeNodeData.Output(type=SegmentType.ARRAY_NUMBER)}, - ) - - assert data.outputs["values"].type == SegmentType.ARRAY_NUMBER diff --git a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py index 859115ceb3f..fb03ae9998d 100644 --- a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py @@ -1,7 +1,8 @@ +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from core.workflow.nodes.datasource.datasource_node import DatasourceNode -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent class _VarSeg: diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_config.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_config.py deleted file mode 100644 index cd822a6f895..00000000000 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_config.py +++ /dev/null @@ -1,33 +0,0 @@ -from dify_graph.nodes.http_request import build_http_request_config - - -def test_build_http_request_config_uses_literal_defaults(): - config = build_http_request_config() - - assert config.max_connect_timeout == 10 - assert config.max_read_timeout == 600 - assert config.max_write_timeout == 600 - assert config.max_binary_size == 10 * 1024 * 1024 - assert config.max_text_size == 1 * 1024 * 1024 - assert config.ssl_verify is True - assert config.ssrf_default_max_retries == 3 - - -def test_build_http_request_config_supports_explicit_overrides(): - config = build_http_request_config( - max_connect_timeout=5, - max_read_timeout=30, - max_write_timeout=40, - max_binary_size=2048, - max_text_size=1024, - ssl_verify=False, - ssrf_default_max_retries=8, - ) - - assert config.max_connect_timeout == 5 - assert config.max_read_timeout == 30 - assert config.max_write_timeout == 40 - assert config.max_binary_size == 2048 - assert config.max_text_size == 1024 - assert config.ssl_verify is False - assert config.ssrf_default_max_retries == 8 diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py deleted file mode 100644 index fec6ad90eb1..00000000000 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py +++ /dev/null @@ -1,233 +0,0 @@ -import json -from unittest.mock import Mock, PropertyMock, patch - -import httpx -import pytest - -from dify_graph.nodes.http_request.entities import Response - - -@pytest.fixture -def mock_response(): - response = Mock(spec=httpx.Response) - response.headers = {} - return response - - -def test_is_file_with_attachment_disposition(mock_response): - """Test is_file when content-disposition header contains 'attachment'""" - mock_response.headers = {"content-disposition": "attachment; filename=test.pdf", "content-type": "application/pdf"} - response = Response(mock_response) - assert response.is_file - - -def test_is_file_with_filename_disposition(mock_response): - """Test is_file when content-disposition header contains filename parameter""" - mock_response.headers = {"content-disposition": "inline; filename=test.pdf", "content-type": "application/pdf"} - response = Response(mock_response) - assert response.is_file - - -@pytest.mark.parametrize("content_type", ["application/pdf", "image/jpeg", "audio/mp3", "video/mp4"]) -def test_is_file_with_file_content_types(mock_response, content_type): - """Test is_file with various file content types""" - mock_response.headers = {"content-type": content_type} - # Mock binary content - type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512)) - response = Response(mock_response) - assert response.is_file, f"Content type {content_type} should be identified as a file" - - -@pytest.mark.parametrize( - "content_type", - [ - "application/json", - "application/xml", - "application/javascript", - "application/x-www-form-urlencoded", - "application/yaml", - "application/graphql", - ], -) -def test_text_based_application_types(mock_response, content_type): - """Test common text-based application types are not identified as files""" - mock_response.headers = {"content-type": content_type} - response = Response(mock_response) - assert not response.is_file, f"Content type {content_type} should not be identified as a file" - - -@pytest.mark.parametrize( - ("content", "content_type"), - [ - (b'{"key": "value"}', "application/octet-stream"), - (b"[1, 2, 3]", "application/unknown"), - (b"function test() {}", "application/x-unknown"), - (b"test", "application/binary"), - (b"var x = 1;", "application/data"), - ], -) -def test_content_based_detection(mock_response, content, content_type): - """Test content-based detection for text-like content""" - mock_response.headers = {"content-type": content_type} - type(mock_response).content = PropertyMock(return_value=content) - response = Response(mock_response) - assert not response.is_file, f"Content {content} with type {content_type} should not be identified as a file" - - -@pytest.mark.parametrize( - ("content", "content_type"), - [ - (bytes([0x00, 0xFF] * 512), "application/octet-stream"), - (bytes([0x89, 0x50, 0x4E, 0x47]), "application/unknown"), # PNG magic numbers - (bytes([0xFF, 0xD8, 0xFF]), "application/binary"), # JPEG magic numbers - ], -) -def test_binary_content_detection(mock_response, content, content_type): - """Test content-based detection for binary content""" - mock_response.headers = {"content-type": content_type} - type(mock_response).content = PropertyMock(return_value=content) - response = Response(mock_response) - assert response.is_file, f"Binary content with type {content_type} should be identified as a file" - - -@pytest.mark.parametrize( - ("content_type", "expected_main_type"), - [ - ("x-world/x-vrml", "model"), # VRML 3D model - ("font/ttf", "application"), # TrueType font - ("text/csv", "text"), # CSV text file - ("unknown/xyz", None), # Unknown type - ], -) -def test_mimetype_based_detection(mock_response, content_type, expected_main_type): - """Test detection using mimetypes.guess_type for non-application content types""" - mock_response.headers = {"content-type": content_type} - type(mock_response).content = PropertyMock(return_value=bytes([0x00])) # Dummy content - - with patch("dify_graph.nodes.http_request.entities.mimetypes.guess_type") as mock_guess_type: - # Mock the return value based on expected_main_type - if expected_main_type: - mock_guess_type.return_value = (f"{expected_main_type}/subtype", None) - else: - mock_guess_type.return_value = (None, None) - - response = Response(mock_response) - - # Check if the result matches our expectation - if expected_main_type in ("application", "image", "audio", "video"): - assert response.is_file, f"Content type {content_type} should be identified as a file" - else: - assert not response.is_file, f"Content type {content_type} should not be identified as a file" - - # Verify that guess_type was called - mock_guess_type.assert_called_once() - - -def test_is_file_with_inline_disposition(mock_response): - """Test is_file when content-disposition is 'inline'""" - mock_response.headers = {"content-disposition": "inline", "content-type": "application/pdf"} - # Mock binary content - type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512)) - response = Response(mock_response) - assert response.is_file - - -def test_is_file_with_no_content_disposition(mock_response): - """Test is_file when no content-disposition header is present""" - mock_response.headers = {"content-type": "application/pdf"} - # Mock binary content - type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512)) - response = Response(mock_response) - assert response.is_file - - -# UTF-8 Encoding Tests -@pytest.mark.parametrize( - ("content_bytes", "expected_text", "description"), - [ - # Chinese UTF-8 bytes - ( - b'{"message": "\xe4\xbd\xa0\xe5\xa5\xbd\xe4\xb8\x96\xe7\x95\x8c"}', - '{"message": "你好世界"}', - "Chinese characters UTF-8", - ), - # Japanese UTF-8 bytes - ( - b'{"message": "\xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf"}', - '{"message": "こんにちは"}', - "Japanese characters UTF-8", - ), - # Korean UTF-8 bytes - ( - b'{"message": "\xec\x95\x88\xeb\x85\x95\xed\x95\x98\xec\x84\xb8\xec\x9a\x94"}', - '{"message": "안녕하세요"}', - "Korean characters UTF-8", - ), - # Arabic UTF-8 - (b'{"text": "\xd9\x85\xd8\xb1\xd8\xad\xd8\xa8\xd8\xa7"}', '{"text": "مرحبا"}', "Arabic characters UTF-8"), - # European characters UTF-8 - (b'{"text": "Caf\xc3\xa9 M\xc3\xbcnchen"}', '{"text": "Café München"}', "European accented characters"), - # Simple ASCII - (b'{"text": "Hello World"}', '{"text": "Hello World"}', "Simple ASCII text"), - ], -) -def test_text_property_utf8_decoding(mock_response, content_bytes, expected_text, description): - """Test that Response.text properly decodes UTF-8 content with charset_normalizer""" - mock_response.headers = {"content-type": "application/json; charset=utf-8"} - type(mock_response).content = PropertyMock(return_value=content_bytes) - # Mock httpx response.text to return something different (simulating potential encoding issues) - mock_response.text = "incorrect-fallback-text" # To ensure we are not falling back to httpx's text property - - response = Response(mock_response) - - # Our enhanced text property should decode properly using charset_normalizer - assert response.text == expected_text, ( - f"Failed for {description}: got {repr(response.text)}, expected {repr(expected_text)}" - ) - - -def test_text_property_fallback_to_httpx(mock_response): - """Test that Response.text falls back to httpx.text when charset_normalizer fails""" - mock_response.headers = {"content-type": "application/json"} - - # Create malformed UTF-8 bytes - malformed_bytes = b'{"text": "\xff\xfe\x00\x00 invalid"}' - type(mock_response).content = PropertyMock(return_value=malformed_bytes) - - # Mock httpx.text to return some fallback value - fallback_text = '{"text": "fallback"}' - mock_response.text = fallback_text - - response = Response(mock_response) - - # Should fall back to httpx's text when charset_normalizer fails - assert response.text == fallback_text - - -@pytest.mark.parametrize( - ("json_content", "description"), - [ - # JSON with escaped Unicode (like Flask jsonify()) - ('{"message": "\\u4f60\\u597d\\u4e16\\u754c"}', "JSON with escaped Unicode"), - # JSON with mixed escape sequences and UTF-8 - ('{"mixed": "Hello \\u4f60\\u597d"}', "Mixed escaped and regular text"), - # JSON with complex escape sequences - ('{"complex": "\\ud83d\\ude00\\u4f60\\u597d"}', "Emoji and Chinese escapes"), - ], -) -def test_text_property_with_escaped_unicode(mock_response, json_content, description): - """Test Response.text with JSON containing Unicode escape sequences""" - mock_response.headers = {"content-type": "application/json"} - - content_bytes = json_content.encode("utf-8") - type(mock_response).content = PropertyMock(return_value=content_bytes) - mock_response.text = json_content # httpx would return the same for valid UTF-8 - - response = Response(mock_response) - - # Should preserve the escape sequences (valid JSON) - assert response.text == json_content, f"Failed for {description}" - - # The text should be valid JSON that can be parsed back to proper Unicode - parsed = json.loads(response.text) - assert isinstance(parsed, dict), f"Invalid JSON for {description}" diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py index cea71954171..a5026b40cf6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py @@ -1,20 +1,20 @@ import pytest - -from configs import dify_config -from core.helper.ssrf_proxy import ssrf_proxy -from dify_graph.file.file_manager import file_manager -from dify_graph.nodes.http_request import ( +from graphon.file.file_manager import file_manager +from graphon.nodes.http_request import ( BodyData, HttpRequestNodeAuthorization, HttpRequestNodeBody, HttpRequestNodeConfig, HttpRequestNodeData, ) -from dify_graph.nodes.http_request.entities import HttpRequestNodeTimeout -from dify_graph.nodes.http_request.exc import AuthorizationConfigError -from dify_graph.nodes.http_request.executor import Executor -from dify_graph.runtime import VariablePool -from dify_graph.system_variable import SystemVariable +from graphon.nodes.http_request.entities import HttpRequestNodeTimeout +from graphon.nodes.http_request.exc import AuthorizationConfigError +from graphon.nodes.http_request.executor import Executor +from graphon.runtime import VariablePool + +from configs import dify_config +from core.helper.ssrf_proxy import ssrf_proxy +from core.workflow.system_variables import default_system_variables HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, @@ -30,7 +30,7 @@ HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( def test_executor_with_json_body_and_number_variable(): # Prepare the variable pool variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "number"], 42) @@ -86,7 +86,7 @@ def test_executor_with_json_body_and_number_variable(): def test_executor_with_json_body_and_object_variable(): # Prepare the variable pool variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"}) @@ -144,7 +144,7 @@ def test_executor_with_json_body_and_object_variable(): def test_executor_with_json_body_and_nested_object_variable(): # Prepare the variable pool variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"}) @@ -201,7 +201,7 @@ def test_executor_with_json_body_and_nested_object_variable(): def test_extract_selectors_from_template_with_newline(): - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) variable_pool.add(("node_id", "custom_query"), "line1\nline2") node_data = HttpRequestNodeData( title="Test JSON Body with Nested Object Variable", @@ -231,7 +231,7 @@ def test_extract_selectors_from_template_with_newline(): def test_executor_with_form_data(): # Prepare the variable pool variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "text_field"], "Hello, World!") @@ -320,7 +320,7 @@ def test_init_headers(): node_data=node_data, timeout=timeout, http_request_config=HTTP_REQUEST_CONFIG, - variable_pool=VariablePool(system_variables=SystemVariable.default()), + variable_pool=VariablePool(system_variables=default_system_variables()), http_client=ssrf_proxy, file_manager=file_manager, ) @@ -357,7 +357,7 @@ def test_init_params(): node_data=node_data, timeout=timeout, http_request_config=HTTP_REQUEST_CONFIG, - variable_pool=VariablePool(system_variables=SystemVariable.default()), + variable_pool=VariablePool(system_variables=default_system_variables()), http_client=ssrf_proxy, file_manager=file_manager, ) @@ -390,7 +390,7 @@ def test_init_params(): def test_empty_api_key_raises_error_bearer(): """Test that empty API key raises AuthorizationConfigError for bearer auth.""" - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -417,7 +417,7 @@ def test_empty_api_key_raises_error_bearer(): def test_empty_api_key_raises_error_basic(): """Test that empty API key raises AuthorizationConfigError for basic auth.""" - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -444,7 +444,7 @@ def test_empty_api_key_raises_error_basic(): def test_empty_api_key_raises_error_custom(): """Test that empty API key raises AuthorizationConfigError for custom auth.""" - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -471,7 +471,7 @@ def test_empty_api_key_raises_error_custom(): def test_whitespace_only_api_key_raises_error(): """Test that whitespace-only API key raises AuthorizationConfigError.""" - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -498,7 +498,7 @@ def test_whitespace_only_api_key_raises_error(): def test_valid_api_key_works(): """Test that valid API key works correctly for bearer auth.""" - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -537,7 +537,7 @@ def test_executor_with_json_body_and_unquoted_uuid_variable(): test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2" variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "uuid"], test_uuid) @@ -584,7 +584,7 @@ def test_executor_with_json_body_and_unquoted_uuid_with_newlines(): test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2" variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "uuid"], test_uuid) @@ -625,7 +625,7 @@ def test_executor_with_json_body_and_unquoted_uuid_with_newlines(): def test_executor_with_json_body_preserves_numbers_and_strings(): """Test that numbers are preserved and string values are properly quoted.""" variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["node", "count"], 42) diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index 5e34bf1d94c..4705b3f76ec 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -3,16 +3,17 @@ from typing import Any import httpx import pytest +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file.file_manager import file_manager +from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig +from graphon.nodes.http_request.entities import HttpRequestNodeTimeout, Response +from graphon.runtime import GraphRuntimeState, VariablePool from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.helper.ssrf_proxy import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.file.file_manager import file_manager -from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig -from dify_graph.nodes.http_request.entities import HttpRequestNodeTimeout, Response -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.node_runtime import DifyFileReferenceFactory +from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( @@ -109,7 +110,7 @@ def _build_http_node( call_depth=0, ) graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}), + variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}), start_at=time.perf_counter(), ) return HttpRequestNode( @@ -121,6 +122,7 @@ def _build_http_node( http_client=ssrf_proxy, tool_file_manager_factory=ToolFileManager, file_manager=file_manager, + file_reference_factory=DifyFileReferenceFactory(graph_init_params.run_context), ) @@ -161,7 +163,7 @@ def test_run_passes_node_data_ssl_verify_to_executor(monkeypatch: pytest.MonkeyP ) ) - monkeypatch.setattr("dify_graph.nodes.http_request.node.Executor", FakeExecutor) + monkeypatch.setattr("graphon.nodes.http_request.node.Executor", FakeExecutor) result = node._run() diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py index d52dfa2a65e..d16e1233ac9 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py @@ -1,5 +1,6 @@ -from dify_graph.nodes.human_input.entities import EmailDeliveryConfig, EmailRecipients -from dify_graph.runtime import VariablePool +from graphon.runtime import VariablePool + +from core.workflow.human_input_compat import EmailDeliveryConfig, EmailRecipients def test_render_body_template_replaces_variable_values(): diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py index 55aa62a1c01..a2cdbbf132e 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py @@ -2,42 +2,138 @@ Unit tests for human input node entities. """ +from collections.abc import Mapping +from dataclasses import dataclass, field +from datetime import datetime, timedelta from types import SimpleNamespace +from typing import Any from unittest.mock import MagicMock import pytest -from pydantic import ValidationError - -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.node_events import PauseRequestedEvent -from dify_graph.node_events.node import StreamCompletedEvent -from dify_graph.nodes.human_input.entities import ( - EmailDeliveryConfig, - EmailDeliveryMethod, - EmailRecipients, - ExternalRecipient, +from graphon.entities import GraphInitParams +from graphon.node_events import PauseRequestedEvent +from graphon.node_events.node import StreamCompletedEvent +from graphon.nodes.human_input.entities import ( FormInput, FormInputDefault, HumanInputNodeData, - MemberRecipient, UserAction, - WebAppDeliveryMethod, - _WebAppDeliveryConfig, ) -from dify_graph.nodes.human_input.enums import ( +from graphon.nodes.human_input.enums import ( ButtonStyle, - DeliveryMethodType, - EmailRecipientType, FormInputType, + HumanInputFormStatus, PlaceholderType, TimeoutUnit, ) -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.repositories.human_input_form_repository import HumanInputFormRepository -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from tests.unit_tests.core.workflow.graph_engine.human_input_test_utils import InMemoryHumanInputFormRepository +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.runtime import GraphRuntimeState, VariablePool +from pydantic import ValidationError + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY +from core.repositories.human_input_repository import ( + FormCreateParams, + HumanInputFormEntity, + HumanInputFormRecipientEntity, + HumanInputFormRepository, +) +from core.workflow.human_input_compat import ( + DeliveryMethodType, + EmailDeliveryConfig, + EmailDeliveryMethod, + EmailRecipients, + EmailRecipientType, + ExternalRecipient, + MemberRecipient, + WebAppDeliveryMethod, + _WebAppDeliveryConfig, +) +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables +from libs.datetime_utils import naive_utc_now + + +@dataclass +class _InMemoryFormEntity(HumanInputFormEntity): + form_id: str + rendered: str + token: str | None = None + action_id: str | None = None + data: Mapping[str, Any] | None = None + is_submitted: bool = False + status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING + expiration: datetime = field(default_factory=lambda: naive_utc_now() + timedelta(days=1)) + + @property + def id(self) -> str: + return self.form_id + + @property + def submission_token(self) -> str | None: + return self.token + + @property + def recipients(self) -> list[HumanInputFormRecipientEntity]: + return [] + + @property + def rendered_content(self) -> str: + return self.rendered + + @property + def selected_action_id(self) -> str | None: + return self.action_id + + @property + def submitted_data(self) -> Mapping[str, Any] | None: + return self.data + + @property + def submitted(self) -> bool: + return self.is_submitted + + @property + def status(self) -> HumanInputFormStatus: + return self.status_value + + @property + def expiration_time(self) -> datetime: + return self.expiration + + +class InMemoryHumanInputFormRepository(HumanInputFormRepository): + """Minimal in-memory repository for Dify-owned HumanInputNode behavior tests.""" + + def __init__(self) -> None: + self._form_counter = 0 + self.created_params: list[FormCreateParams] = [] + self.created_forms: list[_InMemoryFormEntity] = [] + self._forms_by_node_id: dict[str, _InMemoryFormEntity] = {} + + def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: + self.created_params.append(params) + self._form_counter += 1 + form_id = f"form-{self._form_counter}" + entity = _InMemoryFormEntity( + form_id=form_id, + rendered=params.rendered_content, + token=f"token-{form_id}", + ) + self.created_forms.append(entity) + self._forms_by_node_id[params.node_id] = entity + return entity + + def get_form(self, node_id: str) -> HumanInputFormEntity | None: + return self._forms_by_node_id.get(node_id) + + def set_submission(self, *, action_id: str, form_data: Mapping[str, Any] | None = None) -> None: + if not self.created_forms: + raise AssertionError("no form has been created to attach submission data") + entity = self.created_forms[-1] + entity.action_id = action_id + entity.data = form_data or {} + entity.is_submitted = True + entity.status_value = HumanInputFormStatus.SUBMITTED class TestDeliveryMethod: @@ -54,9 +150,9 @@ class TestDeliveryMethod: def test_email_delivery_method(self): """Test email delivery method creation.""" recipients = EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(type=EmailRecipientType.MEMBER, user_id="test-user-123"), + MemberRecipient(type=EmailRecipientType.MEMBER, reference_id="test-user-123"), ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="test@example.com"), ], ) @@ -193,7 +289,7 @@ class TestHumanInputNodeData: EmailDeliveryMethod( enabled=False, # Disabled method should be fine config=EmailDeliveryConfig( - subject="Hi there", body="", recipients=EmailRecipients(whole_workspace=True) + subject="Hi there", body="", recipients=EmailRecipients(include_bound_group=True) ), ), ] @@ -212,7 +308,7 @@ class TestHumanInputNodeData: assert node_data.title == "Test Node" assert node_data.desc is None - assert node_data.delivery_methods == [] + assert node_data.model_dump().get("delivery_methods") is None assert node_data.form_content == "" assert node_data.inputs == [] assert node_data.user_actions == [] @@ -261,10 +357,10 @@ class TestRecipients: def test_member_recipient(self): """Test member recipient creation.""" - recipient = MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123") + recipient = MemberRecipient(type=EmailRecipientType.MEMBER, reference_id="user-123") assert recipient.type == EmailRecipientType.MEMBER - assert recipient.user_id == "user-123" + assert recipient.reference_id == "user-123" def test_external_recipient(self): """Test external recipient creation.""" @@ -273,37 +369,46 @@ class TestRecipients: assert recipient.type == EmailRecipientType.EXTERNAL assert recipient.email == "test@example.com" - def test_email_recipients_whole_workspace(self): - """Test email recipients with whole workspace enabled.""" + def test_email_recipients_bound_group(self): + """Test email recipients with the bound group enabled.""" recipients = EmailRecipients( - whole_workspace=True, items=[MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123")] + include_bound_group=True, + items=[MemberRecipient(type=EmailRecipientType.MEMBER, reference_id="user-123")], ) - assert recipients.whole_workspace is True - assert len(recipients.items) == 1 # Items are preserved even when whole_workspace is True + assert recipients.include_bound_group is True + assert len(recipients.items) == 1 # Items are preserved even when include_bound_group is True def test_email_recipients_specific_users(self): """Test email recipients with specific users.""" recipients = EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123"), + MemberRecipient(type=EmailRecipientType.MEMBER, reference_id="user-123"), ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="external@example.com"), ], ) - assert recipients.whole_workspace is False + assert recipients.include_bound_group is False assert len(recipients.items) == 2 - assert recipients.items[0].user_id == "user-123" + assert recipients.items[0].reference_id == "user-123" assert recipients.items[1].email == "external@example.com" + def test_legacy_recipient_keys_are_rejected(self): + with pytest.raises(ValidationError): + MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123") + + recipients = EmailRecipients(whole_workspace=True, items=[]) + assert recipients.include_bound_group is True + assert recipients.items == [] + class TestHumanInputNodeVariableResolution: """Tests for resolving variable-based defaults in HumanInputNode.""" def test_resolves_variable_defaults(self): variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -353,17 +458,19 @@ class TestHumanInputNodeVariableResolution: mock_repo.create_form.return_value = SimpleNamespace( id="form-1", rendered_content="Provide your name", - web_app_token="token", + submission_token="token", recipients=[], submitted=False, ) + runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) + runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] node = HumanInputNode( id=config["id"], config=config, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, - form_repository=mock_repo, + runtime=runtime, ) run_result = node._run() @@ -378,7 +485,7 @@ class TestHumanInputNodeVariableResolution: def test_debugger_falls_back_to_recipient_token_when_webapp_disabled(self): variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -416,28 +523,96 @@ class TestHumanInputNodeVariableResolution: mock_repo.create_form.return_value = SimpleNamespace( id="form-2", rendered_content="Provide your name", - web_app_token="console-token", + submission_token="console-token", recipients=[SimpleNamespace(token="recipient-token")], submitted=False, ) + runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) + runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] node = HumanInputNode( id=config["id"], config=config, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, - form_repository=mock_repo, + runtime=runtime, ) run_result = node._run() pause_event = next(run_result) assert isinstance(pause_event, PauseRequestedEvent) - assert pause_event.reason.form_token == "console-token" + assert not hasattr(pause_event.reason, "form_token") + + def test_webapp_runtime_keeps_form_visible_in_ui_when_webapp_delivery_is_enabled(self): + variable_pool = VariablePool( + system_variables=build_system_variables( + user_id="user", + app_id="app", + workflow_id="workflow", + workflow_execution_id="exec-4", + ), + user_inputs={}, + conversation_variables=[], + ) + runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) + graph_init_params = GraphInitParams( + workflow_id="workflow", + graph_config={"nodes": [], "edges": []}, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "end-user-1", + "user_from": "end-user", + "invoke_from": "web-app", + } + }, + call_depth=0, + ) + + config = { + "id": "human", + "data": { + "type": "human-input", + "title": "Human Input", + "form_content": "Provide your name", + "inputs": [], + "user_actions": [{"id": "submit", "title": "Submit"}], + "delivery_methods": [{"enabled": True, "type": "webapp", "config": {}}], + }, + } + + mock_repo = MagicMock(spec=HumanInputFormRepository) + mock_repo.get_form.return_value = None + mock_repo.create_form.return_value = SimpleNamespace( + id="form-4", + rendered_content="Provide your name", + submission_token="token", + recipients=[], + submitted=False, + ) + + runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) + runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] + node = HumanInputNode( + id=config["id"], + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + runtime=runtime, + ) + + run_result = node._run() + pause_event = next(run_result) + + assert isinstance(pause_event, PauseRequestedEvent) + params = mock_repo.create_form.call_args.args[0] + assert params.display_in_ui is True def test_debugger_debug_mode_overrides_email_recipients(self): variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user-123", app_id="app", workflow_id="workflow", @@ -472,7 +647,7 @@ class TestHumanInputNodeVariableResolution: enabled=True, config=EmailDeliveryConfig( recipients=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="target@example.com")], ), subject="Subject", @@ -489,17 +664,19 @@ class TestHumanInputNodeVariableResolution: mock_repo.create_form.return_value = SimpleNamespace( id="form-3", rendered_content="Provide your name", - web_app_token="token", + submission_token="token", recipients=[], submitted=False, ) + runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) + runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] node = HumanInputNode( id=config["id"], config=config, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, - form_repository=mock_repo, + runtime=runtime, ) run_result = node._run() @@ -511,11 +688,11 @@ class TestHumanInputNodeVariableResolution: method = params.delivery_methods[0] assert isinstance(method, EmailDeliveryMethod) assert method.config.debug_mode is True - assert method.config.recipients.whole_workspace is False + assert method.config.recipients.include_bound_group is False assert len(method.config.recipients.items) == 1 recipient = method.config.recipients.items[0] assert isinstance(recipient, MemberRecipient) - assert recipient.user_id == "user-123" + assert recipient.reference_id == "user-123" class TestValidation: @@ -552,7 +729,7 @@ class TestHumanInputNodeRenderedContent: def test_replaces_outputs_placeholders_after_submission(self): variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -591,12 +768,14 @@ class TestHumanInputNodeRenderedContent: config = {"id": "human", "data": node_data.model_dump()} form_repository = InMemoryHumanInputFormRepository() + runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) + runtime._build_form_repository = MagicMock(return_value=form_repository) # type: ignore[attr-defined] node = HumanInputNode( id=config["id"], config=config, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, - form_repository=form_repository, + runtime=runtime, ) pause_gen = node._run() diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py index b0ed47158da..52802c7ce1e 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py @@ -1,18 +1,20 @@ import datetime from types import SimpleNamespace -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.graph_events import ( +from graphon.entities import GraphInitParams +from graphon.enums import BuiltinNodeTypes +from graphon.graph_events import ( NodeRunHumanInputFormFilledEvent, NodeRunHumanInputFormTimeoutEvent, NodeRunStartedEvent, ) -from dify_graph.nodes.human_input.enums import HumanInputFormStatus -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from graphon.nodes.human_input.enums import HumanInputFormStatus +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.runtime import GraphRuntimeState, VariablePool + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import default_system_variables from libs.datetime_utils import naive_utc_now @@ -25,7 +27,7 @@ class _FakeFormRepository: def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#}}") -> HumanInputNode: - system_variables = SystemVariable.default() + system_variables = default_system_variables() graph_runtime_state = GraphRuntimeState( variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]), start_at=0.0, @@ -85,11 +87,12 @@ def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name# graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) def _build_timeout_node() -> HumanInputNode: - system_variables = SystemVariable.default() + system_variables = default_system_variables() graph_runtime_state = GraphRuntimeState( variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]), start_at=0.0, @@ -149,6 +152,7 @@ def _build_timeout_node() -> HumanInputNode: graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py b/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py deleted file mode 100644 index 93c199514e6..00000000000 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py +++ /dev/null @@ -1,339 +0,0 @@ -from dify_graph.nodes.iteration.entities import ( - ErrorHandleMode, - IterationNodeData, - IterationStartNodeData, - IterationState, -) - - -class TestErrorHandleMode: - """Test suite for ErrorHandleMode enum.""" - - def test_terminated_value(self): - """Test TERMINATED enum value.""" - assert ErrorHandleMode.TERMINATED == "terminated" - assert ErrorHandleMode.TERMINATED.value == "terminated" - - def test_continue_on_error_value(self): - """Test CONTINUE_ON_ERROR enum value.""" - assert ErrorHandleMode.CONTINUE_ON_ERROR == "continue-on-error" - assert ErrorHandleMode.CONTINUE_ON_ERROR.value == "continue-on-error" - - def test_remove_abnormal_output_value(self): - """Test REMOVE_ABNORMAL_OUTPUT enum value.""" - assert ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT == "remove-abnormal-output" - assert ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT.value == "remove-abnormal-output" - - def test_error_handle_mode_is_str_enum(self): - """Test ErrorHandleMode is a string enum.""" - assert isinstance(ErrorHandleMode.TERMINATED, str) - assert isinstance(ErrorHandleMode.CONTINUE_ON_ERROR, str) - assert isinstance(ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT, str) - - def test_error_handle_mode_comparison(self): - """Test ErrorHandleMode can be compared with strings.""" - assert ErrorHandleMode.TERMINATED == "terminated" - assert ErrorHandleMode.CONTINUE_ON_ERROR == "continue-on-error" - - def test_all_error_handle_modes(self): - """Test all ErrorHandleMode values are accessible.""" - modes = list(ErrorHandleMode) - - assert len(modes) == 3 - assert ErrorHandleMode.TERMINATED in modes - assert ErrorHandleMode.CONTINUE_ON_ERROR in modes - assert ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT in modes - - -class TestIterationNodeData: - """Test suite for IterationNodeData model.""" - - def test_iteration_node_data_basic(self): - """Test IterationNodeData with basic configuration.""" - data = IterationNodeData( - title="Test Iteration", - iterator_selector=["node1", "output"], - output_selector=["iteration", "result"], - ) - - assert data.title == "Test Iteration" - assert data.iterator_selector == ["node1", "output"] - assert data.output_selector == ["iteration", "result"] - - def test_iteration_node_data_default_values(self): - """Test IterationNodeData default values.""" - data = IterationNodeData( - title="Default Test", - iterator_selector=["start", "items"], - output_selector=["iter", "out"], - ) - - assert data.parent_loop_id is None - assert data.is_parallel is False - assert data.parallel_nums == 10 - assert data.error_handle_mode == ErrorHandleMode.TERMINATED - assert data.flatten_output is True - - def test_iteration_node_data_parallel_mode(self): - """Test IterationNodeData with parallel mode enabled.""" - data = IterationNodeData( - title="Parallel Iteration", - iterator_selector=["node", "list"], - output_selector=["iter", "output"], - is_parallel=True, - parallel_nums=5, - ) - - assert data.is_parallel is True - assert data.parallel_nums == 5 - - def test_iteration_node_data_custom_parallel_nums(self): - """Test IterationNodeData with custom parallel numbers.""" - data = IterationNodeData( - title="Custom Parallel", - iterator_selector=["a", "b"], - output_selector=["c", "d"], - parallel_nums=20, - ) - - assert data.parallel_nums == 20 - - def test_iteration_node_data_continue_on_error(self): - """Test IterationNodeData with continue on error mode.""" - data = IterationNodeData( - title="Continue Error", - iterator_selector=["x", "y"], - output_selector=["z", "w"], - error_handle_mode=ErrorHandleMode.CONTINUE_ON_ERROR, - ) - - assert data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR - - def test_iteration_node_data_remove_abnormal_output(self): - """Test IterationNodeData with remove abnormal output mode.""" - data = IterationNodeData( - title="Remove Abnormal", - iterator_selector=["input", "array"], - output_selector=["output", "result"], - error_handle_mode=ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT, - ) - - assert data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT - - def test_iteration_node_data_flatten_output_disabled(self): - """Test IterationNodeData with flatten output disabled.""" - data = IterationNodeData( - title="No Flatten", - iterator_selector=["a"], - output_selector=["b"], - flatten_output=False, - ) - - assert data.flatten_output is False - - def test_iteration_node_data_with_parent_loop_id(self): - """Test IterationNodeData with parent loop ID.""" - data = IterationNodeData( - title="Nested Loop", - iterator_selector=["parent", "items"], - output_selector=["child", "output"], - parent_loop_id="parent_loop_123", - ) - - assert data.parent_loop_id == "parent_loop_123" - - def test_iteration_node_data_complex_selectors(self): - """Test IterationNodeData with complex selectors.""" - data = IterationNodeData( - title="Complex Selectors", - iterator_selector=["node1", "output", "data", "items"], - output_selector=["iteration", "result", "value"], - ) - - assert len(data.iterator_selector) == 4 - assert len(data.output_selector) == 3 - - def test_iteration_node_data_all_options(self): - """Test IterationNodeData with all options configured.""" - data = IterationNodeData( - title="Full Config", - iterator_selector=["start", "list"], - output_selector=["end", "result"], - parent_loop_id="outer_loop", - is_parallel=True, - parallel_nums=15, - error_handle_mode=ErrorHandleMode.CONTINUE_ON_ERROR, - flatten_output=False, - ) - - assert data.title == "Full Config" - assert data.parent_loop_id == "outer_loop" - assert data.is_parallel is True - assert data.parallel_nums == 15 - assert data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR - assert data.flatten_output is False - - -class TestIterationStartNodeData: - """Test suite for IterationStartNodeData model.""" - - def test_iteration_start_node_data_basic(self): - """Test IterationStartNodeData basic creation.""" - data = IterationStartNodeData(title="Iteration Start") - - assert data.title == "Iteration Start" - - def test_iteration_start_node_data_with_description(self): - """Test IterationStartNodeData with description.""" - data = IterationStartNodeData( - title="Start Node", - desc="This is the start of iteration", - ) - - assert data.title == "Start Node" - assert data.desc == "This is the start of iteration" - - -class TestIterationState: - """Test suite for IterationState model.""" - - def test_iteration_state_default_values(self): - """Test IterationState default values.""" - state = IterationState() - - assert state.outputs == [] - assert state.current_output is None - - def test_iteration_state_with_outputs(self): - """Test IterationState with outputs.""" - state = IterationState(outputs=["result1", "result2", "result3"]) - - assert len(state.outputs) == 3 - assert state.outputs[0] == "result1" - assert state.outputs[2] == "result3" - - def test_iteration_state_with_current_output(self): - """Test IterationState with current output.""" - state = IterationState(current_output="current_value") - - assert state.current_output == "current_value" - - def test_iteration_state_get_last_output_with_outputs(self): - """Test get_last_output with outputs present.""" - state = IterationState(outputs=["first", "second", "last"]) - - result = state.get_last_output() - - assert result == "last" - - def test_iteration_state_get_last_output_empty(self): - """Test get_last_output with empty outputs.""" - state = IterationState(outputs=[]) - - result = state.get_last_output() - - assert result is None - - def test_iteration_state_get_last_output_single(self): - """Test get_last_output with single output.""" - state = IterationState(outputs=["only_one"]) - - result = state.get_last_output() - - assert result == "only_one" - - def test_iteration_state_get_current_output(self): - """Test get_current_output method.""" - state = IterationState(current_output={"key": "value"}) - - result = state.get_current_output() - - assert result == {"key": "value"} - - def test_iteration_state_get_current_output_none(self): - """Test get_current_output when None.""" - state = IterationState() - - result = state.get_current_output() - - assert result is None - - def test_iteration_state_with_complex_outputs(self): - """Test IterationState with complex output types.""" - state = IterationState( - outputs=[ - {"id": 1, "name": "first"}, - {"id": 2, "name": "second"}, - [1, 2, 3], - "string_output", - ] - ) - - assert len(state.outputs) == 4 - assert state.outputs[0] == {"id": 1, "name": "first"} - assert state.outputs[2] == [1, 2, 3] - - def test_iteration_state_with_none_outputs(self): - """Test IterationState with None values in outputs.""" - state = IterationState(outputs=["value1", None, "value3"]) - - assert len(state.outputs) == 3 - assert state.outputs[1] is None - - def test_iteration_state_get_last_output_with_none(self): - """Test get_last_output when last output is None.""" - state = IterationState(outputs=["first", None]) - - result = state.get_last_output() - - assert result is None - - def test_iteration_state_metadata_class(self): - """Test IterationState.MetaData class.""" - metadata = IterationState.MetaData(iterator_length=10) - - assert metadata.iterator_length == 10 - - def test_iteration_state_metadata_different_lengths(self): - """Test IterationState.MetaData with different lengths.""" - metadata1 = IterationState.MetaData(iterator_length=0) - metadata2 = IterationState.MetaData(iterator_length=100) - metadata3 = IterationState.MetaData(iterator_length=1000000) - - assert metadata1.iterator_length == 0 - assert metadata2.iterator_length == 100 - assert metadata3.iterator_length == 1000000 - - def test_iteration_state_outputs_modification(self): - """Test modifying IterationState outputs.""" - state = IterationState(outputs=[]) - - state.outputs.append("new_output") - state.outputs.append("another_output") - - assert len(state.outputs) == 2 - assert state.get_last_output() == "another_output" - - def test_iteration_state_current_output_update(self): - """Test updating current_output.""" - state = IterationState() - - state.current_output = "first_value" - assert state.get_current_output() == "first_value" - - state.current_output = "updated_value" - assert state.get_current_output() == "updated_value" - - def test_iteration_state_with_numeric_outputs(self): - """Test IterationState with numeric outputs.""" - state = IterationState(outputs=[1, 2, 3, 4, 5]) - - assert state.get_last_output() == 5 - assert len(state.outputs) == 5 - - def test_iteration_state_with_boolean_outputs(self): - """Test IterationState with boolean outputs.""" - state = IterationState(outputs=[True, False, True]) - - assert state.get_last_output() is True - assert state.outputs[1] is False diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py deleted file mode 100644 index fdf5f4d1f80..00000000000 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py +++ /dev/null @@ -1,438 +0,0 @@ -from dify_graph.entities.graph_config import NodeConfigDictAdapter -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData -from dify_graph.nodes.iteration.exc import ( - InvalidIteratorValueError, - IterationGraphNotFoundError, - IterationIndexNotFoundError, - IterationNodeError, - IteratorVariableNotFoundError, - StartNodeIdNotFoundError, -) -from dify_graph.nodes.iteration.iteration_node import IterationNode - - -class TestIterationNodeExceptions: - """Test suite for iteration node exceptions.""" - - def test_iteration_node_error_is_value_error(self): - """Test IterationNodeError inherits from ValueError.""" - error = IterationNodeError("test error") - - assert isinstance(error, ValueError) - assert str(error) == "test error" - - def test_iterator_variable_not_found_error(self): - """Test IteratorVariableNotFoundError.""" - error = IteratorVariableNotFoundError("Iterator variable not found") - - assert isinstance(error, IterationNodeError) - assert isinstance(error, ValueError) - assert "Iterator variable not found" in str(error) - - def test_invalid_iterator_value_error(self): - """Test InvalidIteratorValueError.""" - error = InvalidIteratorValueError("Invalid iterator value") - - assert isinstance(error, IterationNodeError) - assert "Invalid iterator value" in str(error) - - def test_start_node_id_not_found_error(self): - """Test StartNodeIdNotFoundError.""" - error = StartNodeIdNotFoundError("Start node ID not found") - - assert isinstance(error, IterationNodeError) - assert "Start node ID not found" in str(error) - - def test_iteration_graph_not_found_error(self): - """Test IterationGraphNotFoundError.""" - error = IterationGraphNotFoundError("Iteration graph not found") - - assert isinstance(error, IterationNodeError) - assert "Iteration graph not found" in str(error) - - def test_iteration_index_not_found_error(self): - """Test IterationIndexNotFoundError.""" - error = IterationIndexNotFoundError("Iteration index not found") - - assert isinstance(error, IterationNodeError) - assert "Iteration index not found" in str(error) - - def test_exception_with_empty_message(self): - """Test exception with empty message.""" - error = IterationNodeError("") - - assert str(error) == "" - - def test_exception_with_detailed_message(self): - """Test exception with detailed message.""" - error = IteratorVariableNotFoundError("Variable 'items' not found in node 'start_node'") - - assert "items" in str(error) - assert "start_node" in str(error) - - def test_all_exceptions_inherit_from_base(self): - """Test all exceptions inherit from IterationNodeError.""" - exceptions = [ - IteratorVariableNotFoundError("test"), - InvalidIteratorValueError("test"), - StartNodeIdNotFoundError("test"), - IterationGraphNotFoundError("test"), - IterationIndexNotFoundError("test"), - ] - - for exc in exceptions: - assert isinstance(exc, IterationNodeError) - assert isinstance(exc, ValueError) - - -class TestIterationNodeClassAttributes: - """Test suite for IterationNode class attributes.""" - - def test_node_type(self): - """Test IterationNode node_type attribute.""" - assert IterationNode.node_type == BuiltinNodeTypes.ITERATION - - def test_version(self): - """Test IterationNode version method.""" - version = IterationNode.version() - - assert version == "1" - - -class TestIterationNodeDefaultConfig: - """Test suite for IterationNode get_default_config.""" - - def test_get_default_config_returns_dict(self): - """Test get_default_config returns a dictionary.""" - config = IterationNode.get_default_config() - - assert isinstance(config, dict) - - def test_get_default_config_type(self): - """Test get_default_config includes type.""" - config = IterationNode.get_default_config() - - assert config.get("type") == "iteration" - - def test_get_default_config_has_config_section(self): - """Test get_default_config has config section.""" - config = IterationNode.get_default_config() - - assert "config" in config - assert isinstance(config["config"], dict) - - def test_get_default_config_is_parallel_default(self): - """Test get_default_config is_parallel default value.""" - config = IterationNode.get_default_config() - - assert config["config"]["is_parallel"] is False - - def test_get_default_config_parallel_nums_default(self): - """Test get_default_config parallel_nums default value.""" - config = IterationNode.get_default_config() - - assert config["config"]["parallel_nums"] == 10 - - def test_get_default_config_error_handle_mode_default(self): - """Test get_default_config error_handle_mode default value.""" - config = IterationNode.get_default_config() - - assert config["config"]["error_handle_mode"] == ErrorHandleMode.TERMINATED - - def test_get_default_config_flatten_output_default(self): - """Test get_default_config flatten_output default value.""" - config = IterationNode.get_default_config() - - assert config["config"]["flatten_output"] is True - - def test_get_default_config_with_none_filters(self): - """Test get_default_config with None filters.""" - config = IterationNode.get_default_config(filters=None) - - assert config is not None - assert "type" in config - - def test_get_default_config_with_empty_filters(self): - """Test get_default_config with empty filters.""" - config = IterationNode.get_default_config(filters={}) - - assert config is not None - - -class TestIterationNodeInitialization: - """Test suite for IterationNode initialization.""" - - def test_init_node_data_basic(self): - """Test init_node_data with basic configuration.""" - node = IterationNode.__new__(IterationNode) - data = { - "title": "Test Iteration", - "iterator_selector": ["start", "items"], - "output_selector": ["iteration", "result"], - } - - node.init_node_data(data) - - assert node._node_data.title == "Test Iteration" - assert node._node_data.iterator_selector == ["start", "items"] - - def test_init_node_data_with_parallel(self): - """Test init_node_data with parallel configuration.""" - node = IterationNode.__new__(IterationNode) - data = { - "title": "Parallel Iteration", - "iterator_selector": ["node", "list"], - "output_selector": ["out", "result"], - "is_parallel": True, - "parallel_nums": 5, - } - - node.init_node_data(data) - - assert node._node_data.is_parallel is True - assert node._node_data.parallel_nums == 5 - - def test_init_node_data_with_error_handle_mode(self): - """Test init_node_data with error handle mode.""" - node = IterationNode.__new__(IterationNode) - data = { - "title": "Error Handle Test", - "iterator_selector": ["a", "b"], - "output_selector": ["c", "d"], - "error_handle_mode": "continue-on-error", - } - - node.init_node_data(data) - - assert node._node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR - - def test_get_title(self): - """Test _get_title method.""" - node = IterationNode.__new__(IterationNode) - node._node_data = IterationNodeData( - title="My Iteration", - iterator_selector=["x"], - output_selector=["y"], - ) - - assert node._get_title() == "My Iteration" - - def test_get_description_none(self): - """Test _get_description returns None when not set.""" - node = IterationNode.__new__(IterationNode) - node._node_data = IterationNodeData( - title="Test", - iterator_selector=["a"], - output_selector=["b"], - ) - - assert node._get_description() is None - - def test_get_description_with_value(self): - """Test _get_description with value.""" - node = IterationNode.__new__(IterationNode) - node._node_data = IterationNodeData( - title="Test", - desc="This is a description", - iterator_selector=["a"], - output_selector=["b"], - ) - - assert node._get_description() == "This is a description" - - def test_node_data_property(self): - """Test node_data property returns node data.""" - node = IterationNode.__new__(IterationNode) - node._node_data = IterationNodeData( - title="Base Test", - iterator_selector=["x"], - output_selector=["y"], - ) - - result = node.node_data - - assert result == node._node_data - - -class TestIterationNodeDataValidation: - """Test suite for IterationNodeData validation scenarios.""" - - def test_valid_iteration_node_data(self): - """Test valid IterationNodeData creation.""" - data = IterationNodeData( - title="Valid Iteration", - iterator_selector=["start", "items"], - output_selector=["end", "result"], - ) - - assert data.title == "Valid Iteration" - - def test_iteration_node_data_with_all_error_modes(self): - """Test IterationNodeData with all error handle modes.""" - modes = [ - ErrorHandleMode.TERMINATED, - ErrorHandleMode.CONTINUE_ON_ERROR, - ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT, - ] - - for mode in modes: - data = IterationNodeData( - title=f"Test {mode}", - iterator_selector=["a"], - output_selector=["b"], - error_handle_mode=mode, - ) - assert data.error_handle_mode == mode - - def test_iteration_node_data_parallel_configuration(self): - """Test IterationNodeData parallel configuration combinations.""" - configs = [ - (False, 10), - (True, 1), - (True, 5), - (True, 20), - (True, 100), - ] - - for is_parallel, parallel_nums in configs: - data = IterationNodeData( - title="Parallel Test", - iterator_selector=["x"], - output_selector=["y"], - is_parallel=is_parallel, - parallel_nums=parallel_nums, - ) - assert data.is_parallel == is_parallel - assert data.parallel_nums == parallel_nums - - def test_iteration_node_data_flatten_output_options(self): - """Test IterationNodeData flatten_output options.""" - data_flatten = IterationNodeData( - title="Flatten True", - iterator_selector=["a"], - output_selector=["b"], - flatten_output=True, - ) - - data_no_flatten = IterationNodeData( - title="Flatten False", - iterator_selector=["a"], - output_selector=["b"], - flatten_output=False, - ) - - assert data_flatten.flatten_output is True - assert data_no_flatten.flatten_output is False - - def test_iteration_node_data_complex_selectors(self): - """Test IterationNodeData with complex selectors.""" - data = IterationNodeData( - title="Complex", - iterator_selector=["node1", "output", "data", "items", "list"], - output_selector=["iteration", "result", "value", "final"], - ) - - assert len(data.iterator_selector) == 5 - assert len(data.output_selector) == 4 - - def test_iteration_node_data_single_element_selectors(self): - """Test IterationNodeData with single element selectors.""" - data = IterationNodeData( - title="Single", - iterator_selector=["items"], - output_selector=["result"], - ) - - assert len(data.iterator_selector) == 1 - assert len(data.output_selector) == 1 - - -class TestIterationNodeErrorStrategies: - """Test suite for IterationNode error strategies.""" - - def test_get_error_strategy_default(self): - """Test _get_error_strategy with default value.""" - node = IterationNode.__new__(IterationNode) - node._node_data = IterationNodeData( - title="Test", - iterator_selector=["a"], - output_selector=["b"], - ) - - result = node._get_error_strategy() - - assert result is None or result == node._node_data.error_strategy - - def test_get_retry_config(self): - """Test _get_retry_config method.""" - node = IterationNode.__new__(IterationNode) - node._node_data = IterationNodeData( - title="Test", - iterator_selector=["a"], - output_selector=["b"], - ) - - result = node._get_retry_config() - - assert result is not None - - def test_get_default_value_dict(self): - """Test _get_default_value_dict method.""" - node = IterationNode.__new__(IterationNode) - node._node_data = IterationNodeData( - title="Test", - iterator_selector=["a"], - output_selector=["b"], - ) - - result = node._get_default_value_dict() - - assert isinstance(result, dict) - - -def test_extract_variable_selector_to_variable_mapping_validates_child_node_configs(monkeypatch) -> None: - seen_configs: list[object] = [] - original_validate_python = NodeConfigDictAdapter.validate_python - - def record_validate_python(value: object): - seen_configs.append(value) - return original_validate_python(value) - - monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python) - - child_node_config = { - "id": "answer-node", - "data": { - "type": "answer", - "title": "Answer", - "answer": "", - "iteration_id": "iteration-node", - }, - } - - IterationNode._extract_variable_selector_to_variable_mapping( - graph_config={ - "nodes": [ - { - "id": "iteration-node", - "data": { - "type": "iteration", - "title": "Iteration", - "iterator_selector": ["start", "items"], - "output_selector": ["iteration", "result"], - }, - }, - child_node_config, - ], - "edges": [], - }, - node_id="iteration-node", - node_data=IterationNodeData( - title="Iteration", - iterator_selector=["start", "items"], - output_selector=["iteration", "result"], - ), - ) - - assert seen_configs == [child_node_config] diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py index 2eb4feef5f5..bbfe350f7e4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py @@ -1,18 +1,18 @@ -from collections.abc import Mapping, Sequence +from collections.abc import Mapping from typing import Any import pytest - -from dify_graph.entities import GraphInitParams -from dify_graph.nodes.iteration.exc import IterationGraphNotFoundError -from dify_graph.nodes.iteration.iteration_node import IterationNode -from dify_graph.runtime import ( +from graphon.entities import GraphInitParams +from graphon.nodes.iteration.exc import IterationGraphNotFoundError +from graphon.nodes.iteration.iteration_node import IterationNode +from graphon.runtime import ( ChildEngineBuilderNotConfiguredError, ChildGraphNotFoundError, GraphRuntimeState, VariablePool, ) -from dify_graph.system_variable import SystemVariable + +from core.workflow.system_variables import default_system_variables from tests.workflow_test_utils import build_test_graph_init_params @@ -22,17 +22,16 @@ class _MissingGraphBuilder: *, workflow_id: str, graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - graph_config: Mapping[str, Any], + parent_graph_runtime_state: GraphRuntimeState, root_node_id: str, - layers: Sequence[object] = (), + variable_pool: VariablePool | None = None, ) -> object: raise ChildGraphNotFoundError(f"child graph root node '{root_node_id}' not found") def _build_runtime_state() -> GraphRuntimeState: return GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable.default(), user_inputs={}), + variable_pool=VariablePool(system_variables=default_system_variables(), user_inputs={}), start_at=0.0, ) @@ -69,8 +68,6 @@ def test_graph_runtime_state_raises_specific_error_when_child_builder_is_missing runtime_state.create_child_engine( workflow_id="workflow", graph_init_params=graph_init_params, - graph_runtime_state=_build_runtime_state(), - graph_config={}, root_node_id="root", ) diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py deleted file mode 100644 index 8660449032b..00000000000 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py +++ /dev/null @@ -1,63 +0,0 @@ -import time -from contextlib import nullcontext -from datetime import UTC, datetime - -import pytest - -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.graph_events import NodeRunSucceededEvent -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData -from dify_graph.nodes.iteration.iteration_node import IterationNode - - -def test_parallel_iteration_duration_map_uses_worker_measured_time() -> None: - node = IterationNode.__new__(IterationNode) - node._node_data = IterationNodeData( - title="Parallel Iteration", - iterator_selector=["start", "items"], - output_selector=["iteration", "output"], - is_parallel=True, - parallel_nums=2, - error_handle_mode=ErrorHandleMode.TERMINATED, - ) - node._capture_execution_context = lambda: nullcontext() - node._sync_conversation_variables_from_snapshot = lambda snapshot: None - node._merge_usage = lambda current, new: new if current.total_tokens == 0 else current.plus(new) - - def fake_execute_single_iteration_parallel(*, index: int, item: object, execution_context: object): - return ( - 0.1 + (index * 0.1), - [ - NodeRunSucceededEvent( - id=f"exec-{index}", - node_id=f"llm-{index}", - node_type=BuiltinNodeTypes.LLM, - start_at=datetime.now(UTC).replace(tzinfo=None), - ), - ], - f"output-{item}", - {}, - LLMUsage.empty_usage(), - ) - - node._execute_single_iteration_parallel = fake_execute_single_iteration_parallel - - outputs: list[object] = [] - iter_run_map: dict[str, float] = {} - usage_accumulator = [LLMUsage.empty_usage()] - - generator = node._execute_parallel_iterations( - iterator_list_value=["a", "b"], - outputs=outputs, - iter_run_map=iter_run_map, - usage_accumulator=usage_accumulator, - ) - - for _ in generator: - # Simulate a slow consumer replaying buffered events. - time.sleep(0.02) - - assert outputs == ["output-a", "output-b"] - assert iter_run_map["0"] == pytest.approx(0.1) - assert iter_run_map["1"] == pytest.approx(0.2) diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py index 33f7ace5ab3..f8802138b58 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py @@ -3,8 +3,12 @@ import uuid from unittest.mock import Mock import pytest +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables.segments import StringSegment from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.workflow.nodes.knowledge_index.entities import KnowledgeIndexNodeData from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError from core.workflow.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode @@ -14,10 +18,7 @@ from core.workflow.nodes.knowledge_index.protocols import ( PreviewItem, SummaryIndexServiceProtocol, ) -from dify_graph.enums import SystemVariableKey, WorkflowNodeExecutionStatus -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables.segments import StringSegment +from core.workflow.system_variables import SystemVariableKey, build_system_variables from tests.workflow_test_utils import build_test_graph_init_params @@ -40,7 +41,7 @@ def mock_graph_init_params(): def mock_graph_runtime_state(): """Create mock GraphRuntimeState.""" variable_pool = VariablePool( - system_variables=SystemVariable(user_id=str(uuid.uuid4()), files=[]), + system_variables=build_system_variables(user_id=str(uuid.uuid4()), files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -78,7 +79,7 @@ def sample_node_data(): type="knowledge-index", chunk_structure="general_structure", index_chunk_variable_selector=["start", "chunks"], - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, summary_index_setting=None, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py index 99997db6b29..ab64be59ad2 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py @@ -3,6 +3,10 @@ import uuid from unittest.mock import Mock import pytest +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables import StringSegment from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.nodes.knowledge_retrieval.entities import ( @@ -16,11 +20,7 @@ from core.workflow.nodes.knowledge_retrieval.entities import ( from core.workflow.nodes.knowledge_retrieval.exc import RateLimitExceededError from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode from core.workflow.nodes.knowledge_retrieval.retrieval import RAGRetrievalProtocol, Source -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables import StringSegment +from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params @@ -43,7 +43,7 @@ def mock_graph_init_params(): def mock_graph_runtime_state(): """Create mock GraphRuntimeState.""" variable_pool = VariablePool( - system_variables=SystemVariable(user_id=str(uuid.uuid4()), files=[]), + system_variables=build_system_variables(user_id=str(uuid.uuid4()), files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -157,7 +157,7 @@ class TestKnowledgeRetrievalNode: ): """Test _run with query variable in single mode.""" # Arrange - from dify_graph.nodes.llm.entities import ModelConfig + from graphon.nodes.llm.entities import ModelConfig query = "What is Python?" query_selector = ["start", "query"] @@ -441,7 +441,7 @@ class TestFetchDatasetRetriever: ): """Test _fetch_dataset_retriever in single mode.""" # Arrange - from dify_graph.nodes.llm.entities import ModelConfig + from graphon.nodes.llm.entities import ModelConfig query = "What is Python?" variables = {"query": query} diff --git a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py index d71e0921c1e..fdf1706765a 100644 --- a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py @@ -1,13 +1,13 @@ from unittest.mock import MagicMock import pytest +from graphon.entities import GraphInitParams +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.nodes.list_operator.node import ListOperatorNode +from graphon.runtime import GraphRuntimeState +from graphon.variables import ArrayNumberSegment, ArrayStringSegment -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.nodes.list_operator.node import ListOperatorNode -from dify_graph.runtime import GraphRuntimeState -from dify_graph.variables import ArrayNumberSegment, ArrayStringSegment +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY class TestListOperatorNode: diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py deleted file mode 100644 index b0f0fd428b6..00000000000 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py +++ /dev/null @@ -1,196 +0,0 @@ -import uuid -from typing import NamedTuple -from unittest import mock -from unittest.mock import MagicMock - -import httpx -import pytest - -from core.helper import ssrf_proxy -from core.tools import signature -from core.tools.tool_file_manager import ToolFileManager -from dify_graph.file import FileTransferMethod, FileType, models -from dify_graph.nodes.llm.file_saver import ( - FileSaverImpl, - _extract_content_type_and_extension, - _get_extension, - _validate_extension_override, -) -from models import ToolFile - -_PNG_DATA = b"\x89PNG\r\n\x1a\n" - - -def _gen_id(): - return str(uuid.uuid4()) - - -class TestFileSaverImpl: - def test_save_binary_string(self, monkeypatch: pytest.MonkeyPatch): - user_id = _gen_id() - tenant_id = _gen_id() - file_type = FileType.IMAGE - mime_type = "image/png" - mock_signed_url = "https://example.com/image.png" - mock_tool_file = ToolFile( - user_id=user_id, - tenant_id=tenant_id, - conversation_id=None, - file_key="test-file-key", - mimetype=mime_type, - original_url=None, - name=f"{_gen_id()}.png", - size=len(_PNG_DATA), - ) - mock_tool_file.id = _gen_id() - mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager) - - mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file - monkeypatch.setattr(FileSaverImpl, "_get_tool_file_manager", lambda _: mocked_tool_file_manager) - # Since `File.generate_url` used `ToolFileManager.sign_file` directly, we also need to patch it here. - mocked_sign_file = mock.MagicMock(spec=signature.sign_tool_file) - # Since `File.generate_url` used `signature.sign_tool_file` directly, we also need to patch it here. - monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file) - mocked_sign_file.return_value = mock_signed_url - http_client = MagicMock() - - storage_file_manager = FileSaverImpl( - user_id=user_id, - tenant_id=tenant_id, - http_client=http_client, - ) - - file = storage_file_manager.save_binary_string(_PNG_DATA, mime_type, file_type) - assert file.tenant_id == tenant_id - assert file.type == file_type - assert file.transfer_method == FileTransferMethod.TOOL_FILE - assert file.extension == ".png" - assert file.mime_type == mime_type - assert file.size == len(_PNG_DATA) - assert file.related_id == mock_tool_file.id - - assert file.generate_url() == mock_signed_url - - mocked_tool_file_manager.create_file_by_raw.assert_called_once_with( - user_id=user_id, - tenant_id=tenant_id, - conversation_id=None, - file_binary=_PNG_DATA, - mimetype=mime_type, - ) - mocked_sign_file.assert_called_once_with(tool_file_id=mock_tool_file.id, extension=".png", for_external=True) - - def test_save_remote_url_request_failed(self, monkeypatch: pytest.MonkeyPatch): - _TEST_URL = "https://example.com/image.png" - mock_request = httpx.Request("GET", _TEST_URL) - mock_response = httpx.Response( - status_code=401, - request=mock_request, - ) - http_client = MagicMock() - http_client.get.return_value = mock_response - - file_saver = FileSaverImpl( - user_id=_gen_id(), - tenant_id=_gen_id(), - http_client=http_client, - ) - - with pytest.raises(httpx.HTTPStatusError) as exc: - file_saver.save_remote_url(_TEST_URL, FileType.IMAGE) - http_client.get.assert_called_once_with(_TEST_URL) - assert exc.value.response.status_code == 401 - - def test_save_remote_url_success(self, monkeypatch: pytest.MonkeyPatch): - _TEST_URL = "https://example.com/image.png" - mime_type = "image/png" - user_id = _gen_id() - tenant_id = _gen_id() - - mock_request = httpx.Request("GET", _TEST_URL) - mock_response = httpx.Response( - status_code=200, - content=b"test-data", - headers={"Content-Type": mime_type}, - request=mock_request, - ) - http_client = MagicMock() - http_client.get.return_value = mock_response - - file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id, http_client=http_client) - mock_tool_file = ToolFile( - user_id=user_id, - tenant_id=tenant_id, - conversation_id=None, - file_key="test-file-key", - mimetype=mime_type, - original_url=None, - name=f"{_gen_id()}.png", - size=len(_PNG_DATA), - ) - mock_tool_file.id = _gen_id() - mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response) - monkeypatch.setattr(ssrf_proxy, "get", mock_get) - mock_save_binary_string = mock.MagicMock(spec=file_saver.save_binary_string, return_value=mock_tool_file) - monkeypatch.setattr(file_saver, "save_binary_string", mock_save_binary_string) - - file = file_saver.save_remote_url(_TEST_URL, FileType.IMAGE) - mock_save_binary_string.assert_called_once_with( - mock_response.content, - mime_type, - FileType.IMAGE, - extension_override=".png", - ) - assert file == mock_tool_file - - -def test_validate_extension_override(): - class TestCase(NamedTuple): - extension_override: str | None - expected: str | None - - cases = [TestCase(None, None), TestCase("", ""), ".png", ".png", ".tar.gz", ".tar.gz"] - - for valid_ext_override in [None, "", ".png", ".tar.gz"]: - assert valid_ext_override == _validate_extension_override(valid_ext_override) - - for invalid_ext_override in ["png", "tar.gz"]: - with pytest.raises(ValueError) as exc: - _validate_extension_override(invalid_ext_override) - - -class TestExtractContentTypeAndExtension: - def test_with_both_content_type_and_extension(self): - content_type, extension = _extract_content_type_and_extension("https://example.com/image.jpg", "image/png") - assert content_type == "image/png" - assert extension == ".png" - - def test_url_with_file_extension(self): - for content_type in [None, ""]: - content_type, extension = _extract_content_type_and_extension("https://example.com/image.png", content_type) - assert content_type == "image/png" - assert extension == ".png" - - def test_response_with_content_type(self): - content_type, extension = _extract_content_type_and_extension("https://example.com/image", "image/png") - assert content_type == "image/png" - assert extension == ".png" - - def test_no_content_type_and_no_extension(self): - for content_type in [None, ""]: - content_type, extension = _extract_content_type_and_extension("https://example.com/image", content_type) - assert content_type == "application/octet-stream" - assert extension == ".bin" - - -class TestGetExtension: - def test_with_extension_override(self): - mime_type = "image/png" - for override in [".jpg", ""]: - extension = _get_extension(mime_type, override) - assert extension == override - - def test_without_extension_override(self): - mime_type = "image/png" - extension = _get_extension(mime_type) - assert extension == ".png" diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py index 618a4986593..c784f805c01 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py @@ -1,14 +1,95 @@ from unittest import mock import pytest +from graphon.file import File, FileTransferMethod, FileType +from graphon.model_runtime.entities import ( + ImagePromptMessageContent, + PromptMessageRole, + TextPromptMessageContent, +) +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelFeature, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, +) +from graphon.nodes.base.entities import VariableSelector +from graphon.nodes.llm import llm_utils +from graphon.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, MemoryConfig +from graphon.nodes.llm.exc import ( + InvalidVariableTypeError, + MemoryRolePrefixRequiredError, + NoPromptFoundError, + TemplateTypeNotSupportError, +) +from graphon.runtime import VariablePool +from graphon.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment from core.model_manager import ModelInstance -from dify_graph.model_runtime.entities import ImagePromptMessageContent, PromptMessageRole, TextPromptMessageContent -from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage -from dify_graph.nodes.llm import llm_utils -from dify_graph.nodes.llm.entities import LLMNodeChatModelMessage -from dify_graph.nodes.llm.exc import NoPromptFoundError -from dify_graph.runtime import VariablePool + + +def _build_model_schema( + *, + features: list[ModelFeature] | None = None, + model_properties: dict[ModelPropertyKey, object] | None = None, + parameter_rules: list[ParameterRule] | None = None, +) -> AIModelEntity: + return AIModelEntity( + model="gpt-3.5-turbo", + label={"en_US": "GPT-3.5 Turbo"}, + model_type=ModelType.LLM, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + features=features, + model_properties=model_properties or {}, + parameter_rules=parameter_rules or [], + ) + + +def _build_model_instance(*, model_schema: AIModelEntity | None = None) -> mock.MagicMock: + model_instance = mock.MagicMock(spec=ModelInstance) + model_instance.model_name = "gpt-3.5-turbo" + model_instance.parameters = {} + model_instance.get_model_schema.return_value = model_schema or _build_model_schema(features=[]) + model_instance.get_llm_num_tokens.return_value = 0 + return model_instance + + +def _build_image_file( + *, + file_id: str, + related_id: str, + remote_url: str, + extension: str = ".png", + mime_type: str = "image/png", +) -> File: + return File( + id=file_id, + type=FileType.IMAGE, + filename=f"{file_id}{extension}", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=remote_url, + related_id=related_id, + extension=extension, + mime_type=mime_type, + storage_key="", + ) + + +@pytest.fixture +def variable_pool() -> VariablePool: + pool = VariablePool.empty() + pool.add(["node1", "output"], "resolved_value") + pool.add(["node2", "text"], "hello world") + pool.add(["start", "user_input"], "dynamic_param") + return pool def _fetch_prompt_messages_with_mocked_content(content): @@ -24,15 +105,15 @@ def _fetch_prompt_messages_with_mocked_content(content): with ( mock.patch( - "dify_graph.nodes.llm.llm_utils.fetch_model_schema", + "graphon.nodes.llm.llm_utils.fetch_model_schema", return_value=mock.MagicMock(features=[]), ), mock.patch( - "dify_graph.nodes.llm.llm_utils.handle_list_messages", + "graphon.nodes.llm.llm_utils.handle_list_messages", return_value=[SystemPromptMessage(content=content)], ), mock.patch( - "dify_graph.nodes.llm.llm_utils.handle_memory_chat_mode", + "graphon.nodes.llm.llm_utils.handle_memory_chat_mode", return_value=[], ), ): @@ -53,6 +134,159 @@ def _fetch_prompt_messages_with_mocked_content(content): ) +class TestTypeCoercionViaResolve: + """Type coercion is tested through the public resolve_completion_params_variables API.""" + + def test_numeric_string_coerced_to_float(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "0.7") + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] == 0.7 + + def test_integer_string_coerced_to_int(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "1024") + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] == 1024 + + def test_boolean_string_coerced_to_bool(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "true") + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] is True + + def test_plain_string_stays_string(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "json_object") + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] == "json_object" + + def test_json_object_string_stays_string(self): + pool = VariablePool.empty() + pool.add(["n", "v"], '{"key": "val"}') + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] == '{"key": "val"}' + + def test_mixed_text_and_variable_stays_string(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "0.7") + result = llm_utils.resolve_completion_params_variables({"p": "val={{#n.v#}}"}, pool) + assert result["p"] == "val=0.7" + + +class TestResolveCompletionParamsVariables: + def test_plain_string_values_unchanged(self, variable_pool: VariablePool): + params = {"response_format": "json", "custom_param": "static_value"} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"response_format": "json", "custom_param": "static_value"} + + def test_numeric_values_unchanged(self, variable_pool: VariablePool): + params = {"temperature": 0.7, "top_p": 0.9, "max_tokens": 1024} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"temperature": 0.7, "top_p": 0.9, "max_tokens": 1024} + + def test_boolean_values_unchanged(self, variable_pool: VariablePool): + params = {"stream": True, "echo": False} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"stream": True, "echo": False} + + def test_list_values_unchanged(self, variable_pool: VariablePool): + params = {"stop": ["Human:", "Assistant:"]} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"stop": ["Human:", "Assistant:"]} + + def test_single_variable_reference_resolved(self, variable_pool: VariablePool): + params = {"response_format": "{{#node1.output#}}"} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"response_format": "resolved_value"} + + def test_multiple_variable_references_resolved(self, variable_pool: VariablePool): + params = { + "param_a": "{{#node1.output#}}", + "param_b": "{{#node2.text#}}", + } + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"param_a": "resolved_value", "param_b": "hello world"} + + def test_mixed_text_and_variable_resolved(self, variable_pool: VariablePool): + params = {"prompt_prefix": "prefix_{{#node1.output#}}_suffix"} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"prompt_prefix": "prefix_resolved_value_suffix"} + + def test_mixed_params_types(self, variable_pool: VariablePool): + """Non-string params pass through; string params with variables get resolved.""" + params = { + "temperature": 0.7, + "response_format": "{{#node1.output#}}", + "custom_string": "no_vars_here", + "max_tokens": 512, + "stop": ["\n"], + } + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == { + "temperature": 0.7, + "response_format": "resolved_value", + "custom_string": "no_vars_here", + "max_tokens": 512, + "stop": ["\n"], + } + + def test_empty_params(self, variable_pool: VariablePool): + result = llm_utils.resolve_completion_params_variables({}, variable_pool) + + assert result == {} + + def test_unresolvable_variable_keeps_selector_text(self): + """When a referenced variable doesn't exist in the pool, convert_template + falls back to the raw selector path (e.g. 'nonexistent.var').""" + pool = VariablePool.empty() + params = {"format": "{{#nonexistent.var#}}"} + + result = llm_utils.resolve_completion_params_variables(params, pool) + + assert result["format"] == "nonexistent.var" + + def test_multiple_variables_in_single_value(self, variable_pool: VariablePool): + params = {"combined": "{{#node1.output#}} and {{#node2.text#}}"} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"combined": "resolved_value and hello world"} + + def test_original_params_not_mutated(self, variable_pool: VariablePool): + original = {"response_format": "{{#node1.output#}}", "temperature": 0.5} + original_copy = dict(original) + + _ = llm_utils.resolve_completion_params_variables(original, variable_pool) + + assert original == original_copy + + def test_long_value_truncated(self): + pool = VariablePool.empty() + pool.add(["node1", "big"], "x" * 2000) + params = {"param": "{{#node1.big#}}"} + + result = llm_utils.resolve_completion_params_variables(params, pool) + + assert len(result["param"]) == llm_utils.MAX_RESOLVED_VALUE_LENGTH + + def test_fetch_prompt_messages_skips_messages_when_all_contents_are_filtered_out(): with pytest.raises(NoPromptFoundError): _fetch_prompt_messages_with_mocked_content( @@ -104,3 +338,700 @@ def test_fetch_prompt_messages_keeps_list_content_when_multiple_supported_items_ ] ) ] + + +def test_fetch_model_schema_raises_when_model_schema_is_missing(): + model_instance = _build_model_instance() + model_instance.get_model_schema.return_value = None + + with pytest.raises(ValueError, match="Model schema not found for gpt-3.5-turbo"): + llm_utils.fetch_model_schema(model_instance=model_instance) + + +def test_fetch_files_supports_known_segments_and_rejects_invalid_types(): + file = _build_image_file(file_id="image", related_id="image-related", remote_url="https://example.com/image.png") + variable_pool = VariablePool.empty() + variable_pool.add(["input", "file"], file) + variable_pool.add(["input", "files"], ArrayFileSegment(value=[file])) + variable_pool.add(["input", "none"], NoneSegment()) + variable_pool.add(["input", "empty"], ArrayAnySegment(value=[])) + variable_pool.add(["input", "invalid"], {"a": 1}) + + assert llm_utils.fetch_files(variable_pool, ["input", "file"]) == [file] + assert llm_utils.fetch_files(variable_pool, ["input", "files"]) == [file] + assert llm_utils.fetch_files(variable_pool, ["input", "none"]) == [] + assert llm_utils.fetch_files(variable_pool, ["input", "empty"]) == [] + + with pytest.raises(InvalidVariableTypeError, match="Invalid variable type"): + llm_utils.fetch_files(variable_pool, ["input", "invalid"]) + + +def test_fetch_files_returns_empty_for_missing_variable(): + assert llm_utils.fetch_files(VariablePool.empty(), ["input", "missing"]) == [] + + +def test_convert_history_messages_to_text_skips_system_messages_and_formats_images(): + history_text = llm_utils.convert_history_messages_to_text( + history_messages=[ + SystemPromptMessage(content="skip"), + UserPromptMessage( + content=[ + TextPromptMessageContent(data="Question"), + ImagePromptMessageContent( + format="png", + url="https://example.com/image.png", + mime_type="image/png", + ), + ] + ), + AssistantPromptMessage(content="Answer"), + ], + human_prefix="Human", + ai_prefix="Assistant", + ) + + assert history_text == "Human: Question\n[image]\nAssistant: Answer" + + +def test_fetch_memory_text_uses_prompt_memory_interface(): + memory = mock.MagicMock() + memory.get_history_prompt_messages.return_value = [UserPromptMessage(content="Question")] + + memory_text = llm_utils.fetch_memory_text( + memory=memory, + max_token_limit=321, + message_limit=2, + human_prefix="Human", + ai_prefix="Assistant", + ) + + assert memory_text == "Human: Question" + memory.get_history_prompt_messages.assert_called_once_with(max_token_limit=321, message_limit=2) + + +def test_handle_list_messages_renders_jinja2_messages(): + variable_pool = VariablePool.empty() + variable_pool.add(["input", "name"], "Dify") + renderer = mock.MagicMock() + renderer.render_template.return_value = "Hello Dify" + + prompt_messages = llm_utils.handle_list_messages( + messages=[ + LLMNodeChatModelMessage( + text="ignored", + jinja2_text="Hello {{ name }}", + role=PromptMessageRole.SYSTEM, + edition_type="jinja2", + ) + ], + context="", + jinja2_variables=[VariableSelector(variable="name", value_selector=["input", "name"])], + variable_pool=variable_pool, + vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH, + template_renderer=renderer, + ) + + assert prompt_messages == [SystemPromptMessage(content=[TextPromptMessageContent(data="Hello Dify")])] + renderer.render_template.assert_called_once_with("Hello {{ name }}", {"name": "Dify"}) + + +def test_handle_list_messages_splits_text_and_file_content(): + variable_pool = VariablePool.empty() + image_file = _build_image_file( + file_id="image-file", + related_id="image-related", + remote_url="https://example.com/file.png", + ) + variable_pool.add(["input", "image"], image_file) + + with mock.patch( + "graphon.nodes.llm.llm_utils.file_manager.to_prompt_message_content", + return_value=ImagePromptMessageContent( + format="png", + url="https://example.com/file.png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ) as mock_to_prompt: + prompt_messages = llm_utils.handle_list_messages( + messages=[ + LLMNodeChatModelMessage( + text="Analyze {{#input.image#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + context="", + jinja2_variables=[], + variable_pool=variable_pool, + vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH, + ) + + assert prompt_messages == [ + UserPromptMessage(content=[TextPromptMessageContent(data="Analyze ")]), + UserPromptMessage( + content=[ + ImagePromptMessageContent( + format="png", + url="https://example.com/file.png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ) + ] + ), + ] + mock_to_prompt.assert_called_once() + + +def test_handle_list_messages_supports_array_file_segments(): + variable_pool = VariablePool.empty() + first_file = _build_image_file(file_id="first", related_id="first-related", remote_url="https://example.com/1.png") + second_file = _build_image_file( + file_id="second", + related_id="second-related", + remote_url="https://example.com/2.png", + ) + variable_pool.add(["input", "images"], ArrayFileSegment(value=[first_file, second_file])) + + first_prompt = ImagePromptMessageContent( + format="png", + url="https://example.com/1.png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ) + second_prompt = ImagePromptMessageContent( + format="png", + url="https://example.com/2.png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ) + + with mock.patch( + "graphon.nodes.llm.llm_utils.file_manager.to_prompt_message_content", + side_effect=[first_prompt, second_prompt], + ): + prompt_messages = llm_utils.handle_list_messages( + messages=[ + LLMNodeChatModelMessage( + text="{{#input.images#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + context="", + jinja2_variables=[], + variable_pool=variable_pool, + vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH, + ) + + assert prompt_messages == [UserPromptMessage(content=[first_prompt, second_prompt])] + + +def test_render_jinja2_message_handles_empty_template_success_and_missing_renderer(): + variable_pool = VariablePool.empty() + variable_pool.add(["input", "name"], "Dify") + variables = [VariableSelector(variable="name", value_selector=["input", "name"])] + + assert ( + llm_utils.render_jinja2_message( + template="", + jinja2_variables=variables, + variable_pool=variable_pool, + template_renderer=None, + ) + == "" + ) + + with pytest.raises(ValueError, match="template_renderer is required"): + llm_utils.render_jinja2_message( + template="Hello {{ name }}", + jinja2_variables=variables, + variable_pool=variable_pool, + template_renderer=None, + ) + + renderer = mock.MagicMock() + renderer.render_template.return_value = "Hello Dify" + assert ( + llm_utils.render_jinja2_message( + template="Hello {{ name }}", + jinja2_variables=variables, + variable_pool=variable_pool, + template_renderer=renderer, + ) + == "Hello Dify" + ) + + +def test_handle_completion_template_supports_basic_and_jinja2_templates(): + variable_pool = VariablePool.empty() + variable_pool.add(["input", "name"], "Dify") + renderer = mock.MagicMock() + renderer.render_template.return_value = "Hello Dify" + + basic_messages = llm_utils.handle_completion_template( + template=LLMNodeCompletionModelPromptTemplate( + text="Summarize {{#context#}}", + edition_type="basic", + ), + context="the docs", + jinja2_variables=[], + variable_pool=variable_pool, + ) + jinja_messages = llm_utils.handle_completion_template( + template=LLMNodeCompletionModelPromptTemplate( + text="ignored", + jinja2_text="Hello {{ name }}", + edition_type="jinja2", + ), + context="", + jinja2_variables=[VariableSelector(variable="name", value_selector=["input", "name"])], + variable_pool=variable_pool, + template_renderer=renderer, + ) + + assert basic_messages == [ + UserPromptMessage(content=[TextPromptMessageContent(data="Summarize the docs")]), + ] + assert jinja_messages == [ + UserPromptMessage(content=[TextPromptMessageContent(data="Hello Dify")]), + ] + + +def test_combine_message_content_with_role_handles_all_supported_roles(): + contents = [TextPromptMessageContent(data="hello")] + + assert llm_utils.combine_message_content_with_role(contents=contents, role=PromptMessageRole.USER) == ( + UserPromptMessage(content=contents) + ) + assert llm_utils.combine_message_content_with_role(contents=contents, role=PromptMessageRole.ASSISTANT) == ( + AssistantPromptMessage(content=contents) + ) + assert llm_utils.combine_message_content_with_role(contents=contents, role=PromptMessageRole.SYSTEM) == ( + SystemPromptMessage(content=contents) + ) + + with pytest.raises(NotImplementedError, match="Role custom is not supported"): + llm_utils.combine_message_content_with_role(contents=contents, role="custom") # type: ignore[arg-type] + + +def test_calculate_rest_token_uses_context_size_and_template_alias(): + model_instance = _build_model_instance( + model_schema=_build_model_schema( + model_properties={ModelPropertyKey.CONTEXT_SIZE: 4096}, + parameter_rules=[ + ParameterRule( + name="output_limit", + use_template="max_tokens", + label={"en_US": "Output Limit"}, + type=ParameterType.INT, + ) + ], + ) + ) + model_instance.parameters = {"max_tokens": 512} + model_instance.get_llm_num_tokens.return_value = 256 + + assert ( + llm_utils.calculate_rest_token( + prompt_messages=[UserPromptMessage(content="hello")], + model_instance=model_instance, + ) + == 3328 + ) + + +def test_handle_memory_chat_mode_returns_empty_without_memory_and_uses_window_when_present(): + model_instance = _build_model_instance() + memory = mock.MagicMock() + memory.get_history_prompt_messages.return_value = [UserPromptMessage(content="Question")] + + assert ( + llm_utils.handle_memory_chat_mode( + memory=None, + memory_config=None, + model_instance=model_instance, + ) + == [] + ) + + with mock.patch("graphon.nodes.llm.llm_utils.calculate_rest_token", return_value=123) as mock_rest: + messages = llm_utils.handle_memory_chat_mode( + memory=memory, + memory_config=MemoryConfig(window=MemoryConfig.WindowConfig(enabled=True, size=2)), + model_instance=model_instance, + ) + + assert messages == [UserPromptMessage(content="Question")] + mock_rest.assert_called_once() + memory.get_history_prompt_messages.assert_called_once_with(max_token_limit=123, message_limit=2) + + +def test_handle_memory_completion_mode_validates_role_prefix_and_formats_history(): + model_instance = _build_model_instance() + memory = mock.MagicMock() + memory.get_history_prompt_messages.return_value = [ + UserPromptMessage(content="Question"), + AssistantPromptMessage(content="Answer"), + ] + + assert ( + llm_utils.handle_memory_completion_mode( + memory=None, + memory_config=None, + model_instance=model_instance, + ) + == "" + ) + + with ( + mock.patch("graphon.nodes.llm.llm_utils.calculate_rest_token", return_value=456), + pytest.raises(MemoryRolePrefixRequiredError, match="Memory role prefix is required"), + ): + llm_utils.handle_memory_completion_mode( + memory=memory, + memory_config=MemoryConfig(window=MemoryConfig.WindowConfig(enabled=True, size=2)), + model_instance=model_instance, + ) + + with mock.patch("graphon.nodes.llm.llm_utils.calculate_rest_token", return_value=456): + history_text = llm_utils.handle_memory_completion_mode( + memory=memory, + memory_config=MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=False), + ), + model_instance=model_instance, + ) + + assert history_text == "Human: Question\nAssistant: Answer" + memory.get_history_prompt_messages.assert_called_with(max_token_limit=456, message_limit=None) + + +def test_append_file_prompts_merges_with_existing_user_content_or_appends_new_message(): + file = _build_image_file(file_id="image", related_id="image-related", remote_url="https://example.com/image.png") + file_prompt = ImagePromptMessageContent( + format="png", + url="https://example.com/image.png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ) + prompt_messages = [UserPromptMessage(content=[TextPromptMessageContent(data="Question")])] + + with mock.patch( + "graphon.nodes.llm.llm_utils.file_manager.to_prompt_message_content", + return_value=file_prompt, + ): + llm_utils._append_file_prompts( + prompt_messages=prompt_messages, + files=[file], + vision_enabled=True, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + ) + + assert prompt_messages == [ + UserPromptMessage(content=[file_prompt, TextPromptMessageContent(data="Question")]), + ] + + prompt_messages = [SystemPromptMessage(content="System prompt")] + with mock.patch( + "graphon.nodes.llm.llm_utils.file_manager.to_prompt_message_content", + return_value=file_prompt, + ): + llm_utils._append_file_prompts( + prompt_messages=prompt_messages, + files=[file], + vision_enabled=True, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + ) + + assert prompt_messages[-1] == UserPromptMessage(content=[file_prompt]) + + +def test_fetch_prompt_messages_chat_mode_includes_query_memory_and_supported_files(): + model_instance = _build_model_instance(model_schema=_build_model_schema(features=[ModelFeature.VISION])) + memory = mock.MagicMock() + memory.get_history_prompt_messages.return_value = [AssistantPromptMessage(content="history")] + sys_file = _build_image_file(file_id="sys", related_id="sys-related", remote_url="https://example.com/sys.png") + context_file = _build_image_file( + file_id="context", + related_id="context-related", + remote_url="https://example.com/context.png", + ) + file_prompts = [ + ImagePromptMessageContent( + format="png", + url="https://example.com/sys.png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ImagePromptMessageContent( + format="png", + url="https://example.com/context.png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ] + + with mock.patch( + "graphon.nodes.llm.llm_utils.file_manager.to_prompt_message_content", + side_effect=file_prompts, + ): + prompt_messages, stop = llm_utils.fetch_prompt_messages( + sys_query="current question", + sys_files=[sys_file], + context="", + memory=memory, + model_instance=model_instance, + prompt_template=[ + LLMNodeChatModelMessage( + text="Before query", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + stop=("STOP",), + memory_config=MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)), + vision_enabled=True, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + context_files=[context_file], + ) + + assert stop == ("STOP",) + assert prompt_messages[0] == UserPromptMessage(content="Before query") + assert prompt_messages[1] == AssistantPromptMessage(content="history") + assert prompt_messages[2] == UserPromptMessage( + content=[ + file_prompts[1], + file_prompts[0], + TextPromptMessageContent(data="current question"), + ] + ) + + +def test_fetch_prompt_messages_completion_mode_updates_list_content_with_histories_and_query(): + model_instance = _build_model_instance(model_schema=_build_model_schema(features=[])) + memory = mock.MagicMock() + memory.get_history_prompt_messages.return_value = [ + UserPromptMessage(content="previous question"), + AssistantPromptMessage(content="previous answer"), + ] + + prompt_messages, stop = llm_utils.fetch_prompt_messages( + sys_query="latest question", + sys_files=[], + context="", + memory=memory, + model_instance=model_instance, + prompt_template=LLMNodeCompletionModelPromptTemplate( + text="Prompt header\n#histories#", + edition_type="basic", + ), + stop=("HALT",), + memory_config=MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=True, size=2), + ), + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + assert stop == ("HALT",) + assert prompt_messages == [ + UserPromptMessage( + content="latest question\nPrompt header\nHuman: previous question\nAssistant: previous answer" + ) + ] + + memory.get_history_prompt_messages.return_value = [ + UserPromptMessage(content="another question"), + AssistantPromptMessage(content="another answer"), + ] + + prompt_messages, _ = llm_utils.fetch_prompt_messages( + sys_query="latest question", + sys_files=[], + context="", + memory=memory, + model_instance=model_instance, + prompt_template=LLMNodeCompletionModelPromptTemplate( + text="Prompt header", + edition_type="basic", + ), + stop=None, + memory_config=MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=True, size=2), + ), + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + assert prompt_messages == [ + UserPromptMessage(content="latest question\nHuman: another question\nAssistant: another answer\nPrompt header") + ] + + +def test_fetch_prompt_messages_filters_content_unsupported_by_model_features(): + model_instance = _build_model_instance(model_schema=_build_model_schema(features=[ModelFeature.DOCUMENT])) + prompt_template = [ + LLMNodeChatModelMessage( + text="You are a classifier.", + role=PromptMessageRole.SYSTEM, + edition_type="basic", + ) + ] + + with ( + mock.patch( + "graphon.nodes.llm.llm_utils.handle_list_messages", + return_value=[ + SystemPromptMessage( + content=[ + TextPromptMessageContent(data="You are a classifier."), + ImagePromptMessageContent( + format="png", + url="https://example.com/image.png", + mime_type="image/png", + ), + ] + ) + ], + ), + mock.patch("graphon.nodes.llm.llm_utils.handle_memory_chat_mode", return_value=[]), + ): + prompt_messages, stop = llm_utils.fetch_prompt_messages( + sys_query=None, + sys_files=[], + context="", + memory=None, + model_instance=model_instance, + prompt_template=prompt_template, + stop=("END",), + memory_config=None, + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + assert stop == ("END",) + assert prompt_messages == [SystemPromptMessage(content="You are a classifier.")] + + +def test_fetch_prompt_messages_completion_mode_supports_string_content_and_invalid_template_type(): + model_instance = _build_model_instance(model_schema=_build_model_schema(features=[])) + + with ( + mock.patch( + "graphon.nodes.llm.llm_utils.handle_completion_template", + return_value=[UserPromptMessage(content="Prefix #histories# and #sys.query#")], + ), + mock.patch( + "graphon.nodes.llm.llm_utils.handle_memory_completion_mode", + return_value="history text", + ), + ): + prompt_messages, stop = llm_utils.fetch_prompt_messages( + sys_query="latest question", + sys_files=[], + context="", + memory=None, + model_instance=model_instance, + prompt_template=LLMNodeCompletionModelPromptTemplate( + text="ignored", + edition_type="basic", + ), + stop=("HALT",), + memory_config=None, + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + assert stop == ("HALT",) + assert prompt_messages == [UserPromptMessage(content="Prefix history text and latest question")] + + with pytest.raises(TemplateTypeNotSupportError): + llm_utils.fetch_prompt_messages( + sys_query=None, + sys_files=[], + context="", + memory=None, + model_instance=model_instance, + prompt_template=object(), # type: ignore[arg-type] + stop=None, + memory_config=None, + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + invalid_prompt = mock.MagicMock() + invalid_prompt.content = object() + with ( + mock.patch( + "graphon.nodes.llm.llm_utils.handle_completion_template", + return_value=[invalid_prompt], + ), + mock.patch( + "graphon.nodes.llm.llm_utils.handle_memory_completion_mode", + return_value="history text", + ), + pytest.raises(ValueError, match="Invalid prompt content type"), + ): + llm_utils.fetch_prompt_messages( + sys_query="latest question", + sys_files=[], + context="", + memory=None, + model_instance=model_instance, + prompt_template=LLMNodeCompletionModelPromptTemplate( + text="ignored", + edition_type="basic", + ), + stop=None, + memory_config=None, + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + with ( + mock.patch( + "graphon.nodes.llm.llm_utils.handle_completion_template", + return_value=[UserPromptMessage(content="Prefix only")], + ), + mock.patch( + "graphon.nodes.llm.llm_utils.handle_memory_completion_mode", + return_value="history text", + ), + ): + prompt_messages, _ = llm_utils.fetch_prompt_messages( + sys_query=None, + sys_files=[], + context="", + memory=None, + model_instance=model_instance, + prompt_template=LLMNodeCompletionModelPromptTemplate( + text="ignored", + edition_type="basic", + ), + stop=None, + memory_config=None, + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + assert prompt_messages == [UserPromptMessage(content="history text\nPrefix only")] diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index fc96088af12..a215e9d350d 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -4,41 +4,81 @@ from collections.abc import Sequence from unittest import mock import pytest - -from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity, UserFrom -from core.app.llm.model_access import DifyCredentialsProvider, DifyModelFactory, fetch_model_config -from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle -from core.entities.provider_entities import CustomConfiguration, SystemConfiguration -from core.model_manager import ModelInstance -from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from dify_graph.entities import GraphInitParams -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.entities import GraphInitParams +from graphon.file import File, FileTransferMethod, FileType +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.llm_entities import ( + LLMResultChunk, + LLMResultChunkDelta, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, + LLMUsage, +) +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessage, PromptMessageRole, + SystemPromptMessage, TextPromptMessageContent, UserPromptMessage, ) -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory -from dify_graph.nodes.llm import llm_utils -from dify_graph.nodes.llm.entities import ( +from graphon.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelFeature, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, +) +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from graphon.node_events import ModelInvokeCompletedEvent, RunRetrieverResourceEvent, StreamChunkEvent +from graphon.nodes.base.entities import VariableSelector +from graphon.nodes.llm import llm_utils +from graphon.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, + LLMNodeCompletionModelPromptTemplate, LLMNodeData, ModelConfig, + PromptConfig, VisionConfig, VisionConfigOptions, ) -from dify_graph.nodes.llm.file_saver import LLMFileSaver -from dify_graph.nodes.llm.node import LLMNode -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment +from graphon.nodes.llm.exc import ( + InvalidContextStructureError, + LLMNodeError, + NoPromptFoundError, + VariableNotFoundError, +) +from graphon.nodes.llm.file_saver import LLMFileSaver +from graphon.nodes.llm.node import ( + LLMNode, + _calculate_rest_token, + _handle_completion_template, + _handle_memory_chat_mode, + _handle_memory_completion_mode, + _render_jinja2_message, +) +from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory +from graphon.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.template_rendering import TemplateRenderError +from graphon.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment + +from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, ModelConfigWithCredentialsEntity, UserFrom +from core.app.llm.model_access import ( + DifyCredentialsProvider, + DifyModelFactory, + build_dify_model_access, + fetch_model_config, +) +from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle +from core.entities.provider_entities import CustomConfiguration, SystemConfiguration +from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime +from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from core.workflow.system_variables import default_system_variables from models.provider import ProviderType from tests.workflow_test_utils import build_test_graph_init_params @@ -55,6 +95,62 @@ class MockTokenBufferMemory: return self.history_messages +def _build_prepared_llm_mock() -> mock.MagicMock: + model_instance = mock.MagicMock() + model_instance.provider = "openai" + model_instance.model_name = "gpt-3.5-turbo" + model_instance.parameters = {} + model_instance.stop = () + model_instance.get_llm_num_tokens.return_value = 0 + model_instance.get_model_schema.return_value = AIModelEntity( + model="gpt-3.5-turbo", + label=I18nObject(en_US="GPT-3.5 Turbo"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={}, + ) + model_instance.is_structured_output_parse_error.return_value = False + return model_instance + + +def _build_model_schema( + *, + features: list[ModelFeature] | None = None, + model_properties: dict[ModelPropertyKey, object] | None = None, + parameter_rules: list[ParameterRule] | None = None, +) -> AIModelEntity: + return AIModelEntity( + model="gpt-3.5-turbo", + label=I18nObject(en_US="GPT-3.5 Turbo"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + features=features, + model_properties=model_properties or {}, + parameter_rules=parameter_rules or [], + ) + + +def _build_image_file( + *, + file_id: str, + related_id: str, + remote_url: str, + extension: str = ".png", + mime_type: str = "image/png", +) -> File: + return File( + id=file_id, + type=FileType.IMAGE, + filename=f"{file_id}{extension}", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=remote_url, + related_id=related_id, + extension=extension, + mime_type=mime_type, + storage_key="", + ) + + @pytest.fixture def llm_node_data() -> LLMNodeData: return LLMNodeData( @@ -91,7 +187,7 @@ def graph_init_params() -> GraphInitParams: @pytest.fixture def graph_runtime_state() -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) return GraphRuntimeState( @@ -107,7 +203,7 @@ def llm_node( mock_file_saver = mock.MagicMock(spec=LLMFileSaver) mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider) mock_model_factory = mock.MagicMock(spec=ModelFactory) - mock_template_renderer = mock.MagicMock(spec=TemplateRenderer) + mock_prompt_message_serializer = mock.MagicMock(spec=PromptMessageSerializerProtocol) node_config = { "id": "1", "data": llm_node_data.model_dump(), @@ -120,9 +216,9 @@ def llm_node( graph_runtime_state=graph_runtime_state, credentials_provider=mock_credentials_provider, model_factory=mock_model_factory, - model_instance=mock.MagicMock(spec=ModelInstance), + model_instance=_build_prepared_llm_mock(), llm_file_saver=mock_file_saver, - template_renderer=mock_template_renderer, + prompt_message_serializer=mock_prompt_message_serializer, http_client=http_client, ) return node @@ -132,28 +228,31 @@ def llm_node( def model_config(monkeypatch): from tests.integration_tests.model_runtime.__mock.plugin_model import MockModelClass - def mock_plugin_model_providers(_self): - providers = MockModelClass().fetch_model_providers("test") - for provider in providers: - provider.declaration.provider = f"{provider.plugin_id}/{provider.declaration.provider}" + def mock_model_providers(_self): + providers = [] + for provider in MockModelClass().fetch_model_providers("test"): + provider_schema = provider.declaration.model_copy(deep=True) + provider_schema.provider = f"{provider.plugin_id}/{provider.provider}" + provider_schema.provider_name = provider.provider + providers.append(provider_schema) return providers monkeypatch.setattr( ModelProviderFactory, - "get_plugin_model_providers", - mock_plugin_model_providers, + "get_model_providers", + mock_model_providers, ) # Create actual provider and model type instances - model_provider_factory = ModelProviderFactory(tenant_id="test") - provider_instance = model_provider_factory.get_plugin_model_provider("openai") + model_provider_factory = ModelProviderFactory(model_runtime=create_plugin_model_runtime(tenant_id="test")) + provider_instance = model_provider_factory.get_model_provider("openai") model_type_instance = model_provider_factory.get_model_type_instance("openai", ModelType.LLM) # Create a ProviderModelBundle provider_model_bundle = ProviderModelBundle( configuration=ProviderConfiguration( tenant_id="1", - provider=provider_instance.declaration, + provider=provider_instance, preferred_provider_type=ProviderType.CUSTOM, using_provider_type=ProviderType.CUSTOM, system_configuration=SystemConfiguration(enabled=False), @@ -181,13 +280,18 @@ def model_config(monkeypatch): ) -def test_fetch_model_config_uses_ports(model_config: ModelConfigWithCredentialsEntity): +def test_fetch_model_config_hydrates_model_instance_runtime_settings(model_config: ModelConfigWithCredentialsEntity): mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider) - mock_model_factory = mock.MagicMock(spec=ModelFactory) + mock_model_factory = mock.MagicMock(spec=DifyModelFactory) provider_model_bundle = model_config.provider_model_bundle model_type_instance = provider_model_bundle.model_type_instance provider_model = mock.MagicMock() + completion_params = { + "temperature": 0.7, + "max_tokens": 256, + "stop": ["Observation:", "Human:"], + } model_instance = mock.MagicMock( model_type_instance=model_type_instance, @@ -208,12 +312,36 @@ def test_fetch_model_config_uses_ports(model_config: ModelConfigWithCredentialsE model_type_instance.__class__, "get_model_schema", return_value=model_config.model_schema, autospec=True ), ): - fetch_model_config( - node_data_model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), + hydrated_model_instance, model_config_with_credentials = fetch_model_config( + node_data_model=ModelConfig( + provider="openai", + name="gpt-3.5-turbo", + mode="chat", + completion_params=completion_params, + ), credentials_provider=mock_credentials_provider, model_factory=mock_model_factory, ) + assert hydrated_model_instance is model_instance + assert hydrated_model_instance.provider == "openai" + assert hydrated_model_instance.model_name == "gpt-3.5-turbo" + assert hydrated_model_instance.credentials == {"api_key": "test"} + assert hydrated_model_instance.parameters == { + "temperature": 0.7, + "max_tokens": 256, + } + assert hydrated_model_instance.stop == ("Observation:", "Human:") + assert model_config_with_credentials.parameters == { + "temperature": 0.7, + "max_tokens": 256, + } + assert model_config_with_credentials.stop == ["Observation:", "Human:"] + assert completion_params == { + "temperature": 0.7, + "max_tokens": 256, + "stop": ["Observation:", "Human:"], + } mock_credentials_provider.fetch.assert_called_once_with("openai", "gpt-3.5-turbo") mock_model_factory.init_model_instance.assert_called_once_with("openai", "gpt-3.5-turbo") provider_model.raise_for_status.assert_called_once() @@ -230,12 +358,20 @@ def test_dify_model_access_adapters_call_managers(): mock_provider_configuration.get_provider_model.return_value = mock_provider_model mock_provider_configuration.get_current_credentials.return_value = {"api_key": "test"} - credentials_provider = DifyCredentialsProvider( + run_context = DifyRunContext( tenant_id="tenant", + app_id="app", + user_id="user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + + credentials_provider = DifyCredentialsProvider( + run_context=run_context, provider_manager=mock_provider_manager, ) model_factory = DifyModelFactory( - tenant_id="tenant", + run_context=run_context, model_manager=mock_model_manager, ) @@ -255,18 +391,18 @@ def test_dify_model_access_adapters_call_managers(): model="gpt-3.5-turbo", ) mock_provider_model.raise_for_status.assert_called_once() - mock_model_manager.get_model_instance.assert_called_once_with( - tenant_id="tenant", - provider="openai", - model_type=ModelType.LLM, - model="gpt-3.5-turbo", - ) + mock_model_manager.get_model_instance.assert_called_once() + assert mock_model_manager.get_model_instance.call_args.kwargs == { + "tenant_id": "tenant", + "provider": "openai", + "model_type": ModelType.LLM, + "model": "gpt-3.5-turbo", + } def test_fetch_files_with_file_segment(): file = File( id="1", - tenant_id="test", type=FileType.IMAGE, filename="test.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, @@ -284,7 +420,6 @@ def test_fetch_files_with_array_file_segment(): files = [ File( id="1", - tenant_id="test", type=FileType.IMAGE, filename="test1.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, @@ -293,7 +428,6 @@ def test_fetch_files_with_array_file_segment(): ), File( id="2", - tenant_id="test", type=FileType.IMAGE, filename="test2.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, @@ -343,7 +477,6 @@ def test_fetch_files_with_non_existent_variable(): # files = [ # File( # id="1", -# tenant_id="test", # type=FileType.IMAGE, # filename="test1.jpg", # transfer_method=FileTransferMethod.REMOTE_URL, @@ -448,7 +581,6 @@ def test_fetch_files_with_non_existent_variable(): # sys_query=fake_query, # sys_files=[ # File( -# tenant_id="test", # type=FileType.IMAGE, # filename="test1.jpg", # transfer_method=FileTransferMethod.REMOTE_URL, @@ -524,7 +656,6 @@ def test_fetch_files_with_non_existent_variable(): # + [UserPromptMessage(content=fake_query)], # file_variables={ # "input.image": File( -# tenant_id="test", # type=FileType.IMAGE, # filename="test1.jpg", # transfer_method=FileTransferMethod.REMOTE_URL, @@ -569,7 +700,7 @@ def test_fetch_files_with_non_existent_variable(): def test_handle_list_messages_basic(llm_node): messages = [ LLMNodeChatModelMessage( - text="Hello, {#context#}", + text="Hello, {{#context#}}", role=PromptMessageRole.USER, edition_type="basic", ) @@ -592,32 +723,414 @@ def test_handle_list_messages_basic(llm_node): assert result[0].content == [TextPromptMessageContent(data="Hello, world")] -def test_handle_list_messages_jinja2_uses_template_renderer(llm_node): - llm_node._template_renderer.render_jinja2.return_value = "Hello, world" +def test_handle_list_messages_replaces_double_brace_context_placeholder(llm_node): messages = [ LLMNodeChatModelMessage( - text="", - jinja2_text="Hello, {{ name }}", - role=PromptMessageRole.USER, - edition_type="jinja2", + text="Answer user's question with the following context:\n\n{{#context#}}", + role=PromptMessageRole.SYSTEM, + edition_type="basic", ) ] + context = "## Overview\nSends a JSON request." result = llm_node.handle_list_messages( messages=messages, - context=None, + context=context, jinja2_variables=[], variable_pool=llm_node.graph_runtime_state.variable_pool, vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH, - template_renderer=llm_node._template_renderer, ) - assert result == [UserPromptMessage(content=[TextPromptMessageContent(data="Hello, world")])] - llm_node._template_renderer.render_jinja2.assert_called_once_with( - template="Hello, {{ name }}", - inputs={}, + assert len(result) == 1 + assert isinstance(result[0].content, list) + assert result[0].content == [ + TextPromptMessageContent( + data="Answer user's question with the following context:\n\n## Overview\nSends a JSON request." + ) + ] + + +def test_handle_list_messages_renders_jinja2_messages(llm_node): + llm_node.graph_runtime_state.variable_pool.add(["input", "name"], "Dify") + renderer = mock.MagicMock() + renderer.render_template.return_value = "Hello Dify" + + prompt_messages = llm_node.handle_list_messages( + messages=[ + LLMNodeChatModelMessage( + text="ignored", + jinja2_text="Hello {{ name }}", + role=PromptMessageRole.SYSTEM, + edition_type="jinja2", + ) + ], + context="", + jinja2_variables=[VariableSelector(variable="name", value_selector=["input", "name"])], + variable_pool=llm_node.graph_runtime_state.variable_pool, + vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH, + jinja2_template_renderer=renderer, ) + assert prompt_messages == [ + SystemPromptMessage(content=[TextPromptMessageContent(data="Hello Dify")]), + ] + renderer.render_template.assert_called_once_with("Hello {{ name }}", {"name": "Dify"}) + + +def test_transform_chat_messages_prefers_jinja2_text(llm_node): + completion_template = LLMNodeCompletionModelPromptTemplate( + text="ignored", + jinja2_text="completion prompt", + edition_type="jinja2", + ) + chat_messages = [ + LLMNodeChatModelMessage( + text="ignored", + jinja2_text="chat prompt", + role=PromptMessageRole.USER, + edition_type="jinja2", + ), + LLMNodeChatModelMessage( + text="keep original", + role=PromptMessageRole.SYSTEM, + edition_type="basic", + ), + ] + + transformed_completion = llm_node._transform_chat_messages(completion_template) + transformed_messages = llm_node._transform_chat_messages(chat_messages) + + assert transformed_completion.text == "completion prompt" + assert transformed_messages[0].text == "chat prompt" + assert transformed_messages[1].text == "keep original" + + +def test_fetch_jinja_inputs_serializes_supported_segment_types(llm_node): + llm_node.graph_runtime_state.variable_pool.add( + ["input", "items"], + ["alpha", {"metadata": {"_source": "knowledge"}, "content": "beta"}, 3], + ) + llm_node.graph_runtime_state.variable_pool.add( + ["input", "context_doc"], + {"metadata": {"_source": "knowledge"}, "content": "context body"}, + ) + llm_node.graph_runtime_state.variable_pool.add(["input", "payload"], {"a": 1}) + + node_data = llm_node.node_data.model_copy( + update={ + "prompt_config": PromptConfig( + jinja2_variables=[ + VariableSelector(variable="items", value_selector=["input", "items"]), + VariableSelector(variable="context_doc", value_selector=["input", "context_doc"]), + VariableSelector(variable="payload", value_selector=["input", "payload"]), + ] + ) + } + ) + + assert llm_node._fetch_jinja_inputs(node_data) == { + "items": "alpha\nbeta\n3", + "context_doc": "context body", + "payload": '{"a": 1}', + } + + +def test_fetch_jinja_inputs_raises_for_missing_variable(llm_node): + node_data = llm_node.node_data.model_copy( + update={ + "prompt_config": PromptConfig( + jinja2_variables=[VariableSelector(variable="missing", value_selector=["input", "missing"])] + ) + } + ) + + with pytest.raises(VariableNotFoundError, match="Variable missing not found"): + llm_node._fetch_jinja_inputs(node_data) + + +def test_fetch_inputs_collects_prompt_and_memory_variables(llm_node): + llm_node.graph_runtime_state.variable_pool.add(["input", "name"], "Dify") + llm_node.graph_runtime_state.variable_pool.add(["input", "payload"], {"active": True}) + + node_data = llm_node.node_data.model_copy( + update={ + "prompt_template": [ + LLMNodeChatModelMessage( + text="Hello {{#input.name#}} with {{#input.payload#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + "memory": MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=True, size=1), + query_prompt_template="Repeat {{#input.name#}}", + ), + } + ) + + assert llm_node._fetch_inputs(node_data) == { + "#input.name#": "Dify", + "#input.payload#": {"active": True}, + } + + +def test_fetch_context_emits_string_context_event(llm_node): + llm_node.graph_runtime_state.variable_pool.add(["context", "value"], "retrieved context") + node_data = llm_node.node_data.model_copy( + update={"context": ContextConfig(enabled=True, variable_selector=["context", "value"])} + ) + + events = list(llm_node._fetch_context(node_data)) + + assert events == [ + RunRetrieverResourceEvent(retriever_resources=[], context="retrieved context", context_files=[]), + ] + + +def test_fetch_context_collects_retriever_resources_and_attachments(llm_node): + attachment = _build_image_file( + file_id="attachment", + related_id="attachment-related", + remote_url="https://example.com/attachment.png", + ) + llm_node._retriever_attachment_loader = mock.MagicMock() + llm_node._retriever_attachment_loader.load.return_value = [attachment] + + llm_node.graph_runtime_state.variable_pool.add( + ["context", "value"], + [ + { + "content": "chunk body", + "summary": "chunk summary", + "files": [{"id": "file-1"}], + "metadata": { + "_source": "knowledge", + "dataset_id": "dataset-1", + "segment_id": "segment-1", + "segment_word_count": 12, + }, + }, + "tail text", + ], + ) + node_data = llm_node.node_data.model_copy( + update={"context": ContextConfig(enabled=True, variable_selector=["context", "value"])} + ) + + events = list(llm_node._fetch_context(node_data)) + + assert len(events) == 1 + event = events[0] + assert event.context == "chunk summary\nchunk body\ntail text" + assert event.context_files == [attachment] + assert event.retriever_resources == [ + { + "position": None, + "dataset_id": "dataset-1", + "dataset_name": None, + "document_id": None, + "document_name": None, + "data_source_type": None, + "segment_id": "segment-1", + "retriever_from": None, + "score": None, + "hit_count": None, + "word_count": 12, + "segment_position": None, + "index_node_hash": None, + "content": "chunk body", + "page": None, + "doc_metadata": None, + "files": [{"id": "file-1"}], + "summary": "chunk summary", + } + ] + llm_node._retriever_attachment_loader.load.assert_called_once_with(segment_id="segment-1") + + +def test_fetch_context_rejects_invalid_context_structure(llm_node): + llm_node.graph_runtime_state.variable_pool.add(["context", "value"], [{"summary": "missing content"}]) + node_data = llm_node.node_data.model_copy( + update={"context": ContextConfig(enabled=True, variable_selector=["context", "value"])} + ) + + with pytest.raises(InvalidContextStructureError, match="Invalid context structure"): + list(llm_node._fetch_context(node_data)) + + +def test_fetch_prompt_messages_chat_mode_appends_memory_query_and_files(): + model_instance = _build_prepared_llm_mock() + model_instance.get_model_schema.return_value = _build_model_schema(features=[ModelFeature.VISION]) + + memory = mock.MagicMock(spec=MockTokenBufferMemory) + memory.get_history_prompt_messages.return_value = [AssistantPromptMessage(content="history answer")] + + sys_file = _build_image_file(file_id="sys-file", related_id="sys-related", remote_url="https://example.com/sys.png") + context_file = _build_image_file( + file_id="context-file", + related_id="context-related", + remote_url="https://example.com/context.png", + ) + + prompt_content_side_effect = [ + ImagePromptMessageContent( + url="https://example.com/sys.png", + format="png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ImagePromptMessageContent( + url="https://example.com/context.png", + format="png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ] + + with mock.patch("graphon.nodes.llm.node.file_manager.to_prompt_message_content") as mock_to_prompt: + mock_to_prompt.side_effect = prompt_content_side_effect + prompt_messages, stop = LLMNode.fetch_prompt_messages( + sys_query="current question", + sys_files=[sys_file], + context="", + memory=memory, + model_instance=model_instance, + prompt_template=[ + LLMNodeChatModelMessage( + text="Before query", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + stop=("STOP",), + memory_config=MemoryConfig( + window=MemoryConfig.WindowConfig(enabled=False), + ), + vision_enabled=True, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + context_files=[context_file], + ) + + assert stop == ("STOP",) + assert prompt_messages[0] == UserPromptMessage(content="Before query") + assert prompt_messages[1] == AssistantPromptMessage(content="history answer") + assert isinstance(prompt_messages[2], UserPromptMessage) + assert isinstance(prompt_messages[2].content, list) + assert isinstance(prompt_messages[2].content[0], ImagePromptMessageContent) + assert isinstance(prompt_messages[2].content[1], ImagePromptMessageContent) + assert isinstance(prompt_messages[2].content[2], TextPromptMessageContent) + assert prompt_messages[2].content[0].url == "https://example.com/context.png" + assert prompt_messages[2].content[1].url == "https://example.com/sys.png" + assert prompt_messages[2].content[2].data == "current question" + memory.get_history_prompt_messages.assert_called_once_with(max_token_limit=2000, message_limit=None) + + +def test_fetch_prompt_messages_completion_mode_injects_histories_and_query(): + model_instance = _build_prepared_llm_mock() + model_instance.get_model_schema.return_value = _build_model_schema(features=[]) + + memory = mock.MagicMock(spec=MockTokenBufferMemory) + memory.get_history_prompt_messages.return_value = [ + UserPromptMessage(content="previous question"), + AssistantPromptMessage(content="previous answer"), + ] + + prompt_messages, stop = LLMNode.fetch_prompt_messages( + sys_query="latest question", + sys_files=[], + context="", + memory=memory, + model_instance=model_instance, + prompt_template=LLMNodeCompletionModelPromptTemplate( + text="Prompt header\n#histories#", + edition_type="basic", + ), + stop=("HALT",), + memory_config=MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=True, size=2), + ), + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + assert stop == ("HALT",) + assert prompt_messages == [ + UserPromptMessage( + content="latest question\nPrompt header\nHuman: previous question\nAssistant: previous answer" + ) + ] + + +def test_fetch_prompt_messages_raises_when_only_unsupported_content_remains(): + model_instance = _build_prepared_llm_mock() + model_instance.get_model_schema.return_value = _build_model_schema(features=[]) + + variable_pool = VariablePool.empty() + variable_pool.add( + ["input", "image"], + _build_image_file(file_id="image-file", related_id="image-related", remote_url="https://example.com/file.png"), + ) + + with ( + mock.patch( + "graphon.nodes.llm.node.file_manager.to_prompt_message_content", + return_value=ImagePromptMessageContent( + url="https://example.com/file.png", + format="png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ), + pytest.raises(NoPromptFoundError, match="No prompt found"), + ): + LLMNode.fetch_prompt_messages( + sys_query=None, + sys_files=[], + context="", + memory=None, + model_instance=model_instance, + prompt_template=[ + LLMNodeChatModelMessage( + text="{{#input.image#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + stop=None, + memory_config=None, + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=variable_pool, + jinja2_variables=[], + ) + + +def test_handle_completion_template_replaces_double_brace_context_placeholder(llm_node): + prompt_messages = _handle_completion_template( + template=LLMNodeCompletionModelPromptTemplate( + text="Summarize the following context:\n{{#context#}}", + edition_type="basic", + ), + context="## Overview\nSends a JSON request.", + jinja2_variables=[], + variable_pool=llm_node.graph_runtime_state.variable_pool, + jinja2_template_renderer=None, + ) + + assert prompt_messages == [ + UserPromptMessage( + content=[ + TextPromptMessageContent(data="Summarize the following context:\n## Overview\nSends a JSON request.") + ] + ) + ] + def test_handle_memory_completion_mode_uses_prompt_message_interface(): memory = mock.MagicMock(spec=MockTokenBufferMemory) @@ -635,15 +1148,15 @@ def test_handle_memory_completion_mode_uses_prompt_message_interface(): AssistantPromptMessage(content="first answer"), ] - model_instance = mock.MagicMock(spec=ModelInstance) + model_instance = _build_prepared_llm_mock() memory_config = MemoryConfig( role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), window=MemoryConfig.WindowConfig(enabled=True, size=3), ) - with mock.patch("dify_graph.nodes.llm.llm_utils.calculate_rest_token", return_value=2000) as mock_rest_token: - memory_text = llm_utils.handle_memory_completion_mode( + with mock.patch("graphon.nodes.llm.node._calculate_rest_token", return_value=2000) as mock_rest_token: + memory_text = _handle_memory_completion_mode( memory=memory, memory_config=memory_config, model_instance=model_instance, @@ -659,7 +1172,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver) mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider) mock_model_factory = mock.MagicMock(spec=ModelFactory) - mock_template_renderer = mock.MagicMock(spec=TemplateRenderer) + mock_prompt_message_serializer = mock.MagicMock(spec=PromptMessageSerializerProtocol) node_config = { "id": "1", "data": llm_node_data.model_dump(), @@ -672,9 +1185,9 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat graph_runtime_state=graph_runtime_state, credentials_provider=mock_credentials_provider, model_factory=mock_model_factory, - model_instance=mock.MagicMock(spec=ModelInstance), + model_instance=_build_prepared_llm_mock(), llm_file_saver=mock_file_saver, - template_renderer=mock_template_renderer, + prompt_message_serializer=mock_prompt_message_serializer, http_client=http_client, ) return node, mock_file_saver @@ -690,7 +1203,6 @@ class TestLLMNodeSaveMultiModalImageOutput: ) mock_file = File( id=str(uuid.uuid4()), - tenant_id="1", type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id=str(uuid.uuid4()), @@ -721,7 +1233,6 @@ class TestLLMNodeSaveMultiModalImageOutput: ) mock_file = File( id=str(uuid.uuid4()), - tenant_id="1", type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id=str(uuid.uuid4()), @@ -776,7 +1287,6 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown: mock_saved_file = File( id=str(uuid.uuid4()), - tenant_id="1", type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, filename="test.png", @@ -906,3 +1416,322 @@ class TestReasoningFormat: assert clean_text == text_with_think assert reasoning_content == "" + + +@pytest.mark.parametrize( + ("structured_output_enabled", "structured_output"), + [ + (False, None), + (True, {"schema": {"type": "object", "properties": {"answer": {"type": "string"}}}}), + ], +) +def test_invoke_llm_dispatches_to_expected_model_method(structured_output_enabled, structured_output): + model_instance = _build_prepared_llm_mock() + prompt_messages = [UserPromptMessage(content="hello")] + file_saver = mock.MagicMock(spec=LLMFileSaver) + + model_instance.invoke_llm.return_value = iter([]) + model_instance.invoke_llm_with_structured_output.return_value = iter([]) + + with ( + mock.patch.object(LLMNode, "handle_invoke_result", return_value=iter(["handled"])) as mock_handle, + mock.patch("graphon.nodes.llm.node.time.perf_counter", return_value=10.0), + ): + result = list( + LLMNode.invoke_llm( + model_instance=model_instance, + prompt_messages=prompt_messages, + stop=("STOP",), + structured_output_enabled=structured_output_enabled, + structured_output=structured_output, + file_saver=file_saver, + file_outputs=[], + node_id="node-1", + node_type=LLMNode.node_type, + reasoning_format="separated", + ) + ) + + assert result == ["handled"] + if structured_output_enabled: + model_instance.invoke_llm_with_structured_output.assert_called_once_with( + prompt_messages=prompt_messages, + json_schema={"type": "object", "properties": {"answer": {"type": "string"}}}, + model_parameters={}, + stop=("STOP",), + stream=True, + ) + model_instance.invoke_llm.assert_not_called() + else: + model_instance.invoke_llm.assert_called_once_with( + prompt_messages=prompt_messages, + model_parameters={}, + tools=None, + stop=("STOP",), + stream=True, + ) + model_instance.invoke_llm_with_structured_output.assert_not_called() + + assert mock_handle.call_args.kwargs["request_start_time"] == 10.0 + + +def test_handle_invoke_result_streaming_collects_text_metrics_and_structured_output(): + usage = LLMUsage.from_metadata({"prompt_tokens": 12, "completion_tokens": 4, "total_tokens": 16}) + first_chunk = LLMResultChunkWithStructuredOutput( + model="gpt-3.5-turbo", + prompt_messages=[], + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=[TextPromptMessageContent(data="plan")]), + ), + structured_output={"draft": True}, + ) + final_chunk = LLMResultChunk( + model="gpt-3.5-turbo", + prompt_messages=[], + delta=LLMResultChunkDelta( + index=1, + message=AssistantPromptMessage(content=[TextPromptMessageContent(data="answer")]), + usage=usage, + finish_reason="stop", + ), + ) + + with mock.patch("graphon.nodes.llm.node.time.perf_counter", side_effect=[2.0, 5.0]): + events = list( + LLMNode.handle_invoke_result( + invoke_result=iter([first_chunk, final_chunk]), + file_saver=mock.MagicMock(spec=LLMFileSaver), + file_outputs=[], + node_id="node-1", + node_type=LLMNode.node_type, + model_instance=_build_prepared_llm_mock(), + reasoning_format="separated", + request_start_time=1.0, + ) + ) + + assert events[0] == first_chunk + assert events[1] == StreamChunkEvent(selector=["node-1", "text"], chunk="plan", is_final=False) + assert events[2] == StreamChunkEvent(selector=["node-1", "text"], chunk="answer", is_final=False) + + completed = events[3] + assert isinstance(completed, ModelInvokeCompletedEvent) + assert completed.text == "answer" + assert completed.reasoning_content == "plan" + assert completed.structured_output == {"draft": True} + assert completed.finish_reason == "stop" + assert completed.usage.total_tokens == 16 + assert completed.usage.latency == 4.0 + assert completed.usage.time_to_first_token == 1.0 + assert completed.usage.time_to_generate == 3.0 + + +def test_handle_invoke_result_wraps_structured_output_parse_errors(): + model_instance = _build_prepared_llm_mock() + model_instance.is_structured_output_parse_error.return_value = True + + def broken_stream(): + raise ValueError("bad json") + yield + + with pytest.raises(LLMNodeError, match="Failed to parse structured output: bad json"): + list( + LLMNode.handle_invoke_result( + invoke_result=broken_stream(), + file_saver=mock.MagicMock(spec=LLMFileSaver), + file_outputs=[], + node_id="node-1", + node_type=LLMNode.node_type, + model_instance=model_instance, + ) + ) + + +def test_handle_blocking_result_extracts_reasoning_and_structured_output(): + invoke_result = LLMResultWithStructuredOutput( + model="gpt-3.5-turbo", + prompt_messages=[], + message=AssistantPromptMessage(content="reasoningfinal answer"), + usage=LLMUsage.empty_usage(), + structured_output={"answer": "final answer"}, + ) + + event = LLMNode.handle_blocking_result( + invoke_result=invoke_result, + saver=mock.MagicMock(spec=LLMFileSaver), + file_outputs=[], + reasoning_format="separated", + request_latency=1.2345, + ) + + assert event.text == "final answer" + assert event.reasoning_content == "reasoning" + assert event.structured_output == {"answer": "final answer"} + assert event.usage.latency == 1.234 + + +def test_fetch_structured_output_schema_validates_payload(): + assert LLMNode.fetch_structured_output_schema(structured_output={"schema": {"type": "object"}}) == { + "type": "object" + } + + with pytest.raises(LLMNodeError, match="Please provide a valid structured output schema"): + LLMNode.fetch_structured_output_schema(structured_output={}) + + with pytest.raises(LLMNodeError, match="structured_output_schema must be a JSON object"): + LLMNode.fetch_structured_output_schema(structured_output={"schema": ["not", "an", "object"]}) + + +def test_extract_variable_selector_to_variable_mapping_includes_runtime_selectors(): + node_data = LLMNodeData( + title="Test LLM", + model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), + prompt_template=[ + LLMNodeChatModelMessage( + text="Hello {{#input.name#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ), + LLMNodeChatModelMessage( + text="ignored", + jinja2_text="Hello {{ name }}", + role=PromptMessageRole.SYSTEM, + edition_type="jinja2", + ), + ], + prompt_config=PromptConfig( + jinja2_variables=[VariableSelector(variable="name", value_selector=["input", "name"])] + ), + memory=MemoryConfig( + window=MemoryConfig.WindowConfig(enabled=True, size=1), + query_prompt_template="Repeat {{#sys.query#}}", + ), + context=ContextConfig(enabled=True, variable_selector=["context", "value"]), + vision=VisionConfig(enabled=True), + ) + + mapping = LLMNode._extract_variable_selector_to_variable_mapping( + graph_config={}, + node_id="llm-1", + node_data=node_data, + ) + + assert mapping == { + "llm-1.#input.name#": ["input", "name"], + "llm-1.#sys.query#": ["sys", "query"], + "llm-1.#context#": ["context", "value"], + "llm-1.#files#": ["sys", "files"], + "llm-1.name": ["input", "name"], + } + + +def test_render_jinja2_message_requires_renderer_and_passes_inputs(): + variable_pool = VariablePool.empty() + variable_pool.add(["input", "name"], "Dify") + variables = [VariableSelector(variable="name", value_selector=["input", "name"])] + + with pytest.raises( + TemplateRenderError, + match="LLMNode requires an injected jinja2_template_renderer for jinja2 prompts", + ): + _render_jinja2_message( + template="Hello {{ name }}", + jinja2_variables=variables, + variable_pool=variable_pool, + jinja2_template_renderer=None, + ) + + renderer = mock.MagicMock() + renderer.render_template.return_value = "Hello Dify" + + assert ( + _render_jinja2_message( + template="Hello {{ name }}", + jinja2_variables=variables, + variable_pool=variable_pool, + jinja2_template_renderer=renderer, + ) + == "Hello Dify" + ) + renderer.render_template.assert_called_once_with("Hello {{ name }}", {"name": "Dify"}) + + +def test_calculate_rest_token_uses_context_size_and_max_tokens(): + model_instance = _build_prepared_llm_mock() + model_instance.parameters = {"max_tokens": 512} + model_instance.get_model_schema.return_value = _build_model_schema( + model_properties={ModelPropertyKey.CONTEXT_SIZE: 4096}, + parameter_rules=[ + ParameterRule( + name="max_tokens", + label=I18nObject(en_US="Max Tokens"), + type=ParameterType.INT, + ) + ], + ) + model_instance.get_llm_num_tokens.return_value = 1000 + + assert ( + _calculate_rest_token( + prompt_messages=[UserPromptMessage(content="hello")], + model_instance=model_instance, + ) + == 2584 + ) + + +def test_handle_memory_chat_mode_uses_calculated_token_budget(): + memory = mock.MagicMock(spec=MockTokenBufferMemory) + history = [UserPromptMessage(content="question")] + memory.get_history_prompt_messages.return_value = history + + with mock.patch("graphon.nodes.llm.node._calculate_rest_token", return_value=321) as mock_rest_token: + result = _handle_memory_chat_mode( + memory=memory, + memory_config=MemoryConfig(window=MemoryConfig.WindowConfig(enabled=True, size=2)), + model_instance=_build_prepared_llm_mock(), + ) + + assert result == history + mock_rest_token.assert_called_once() + memory.get_history_prompt_messages.assert_called_once_with(max_token_limit=321, message_limit=2) + + +def test_dify_model_access_adapters_skip_runtime_build_when_managers_are_injected(): + run_context = DifyRunContext( + tenant_id="tenant", + app_id="app", + user_id="user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + + with mock.patch("core.app.llm.model_access.create_plugin_provider_manager") as mock_provider_manager_factory: + DifyCredentialsProvider(run_context=run_context, provider_manager=mock.MagicMock()) + DifyModelFactory(run_context=run_context, model_manager=mock.MagicMock()) + + mock_provider_manager_factory.assert_not_called() + + +def test_build_dify_model_access_binds_run_context_user_id_once(): + run_context = DifyRunContext( + tenant_id="tenant", + app_id="app", + user_id="user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + + with mock.patch("core.app.llm.model_access.create_plugin_provider_manager") as mock_provider_manager: + build_dify_model_access(run_context) + + mock_provider_manager.assert_called_once_with(tenant_id="tenant", user_id="user") + + +def test_dify_model_access_requires_run_context_argument(): + with pytest.raises(TypeError): + DifyCredentialsProvider() + + with pytest.raises(TypeError): + DifyModelFactory() diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py deleted file mode 100644 index e40d565ef57..00000000000 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py +++ /dev/null @@ -1,25 +0,0 @@ -from collections.abc import Mapping, Sequence - -from pydantic import BaseModel, Field - -from dify_graph.file import File -from dify_graph.model_runtime.entities.message_entities import PromptMessage -from dify_graph.model_runtime.entities.model_entities import ModelFeature -from dify_graph.nodes.llm.entities import LLMNodeChatModelMessage - - -class LLMNodeTestScenario(BaseModel): - """Test scenario for LLM node testing.""" - - description: str = Field(..., description="Description of the test scenario") - sys_query: str = Field(..., description="User query input") - sys_files: Sequence[File] = Field(default_factory=list, description="List of user files") - vision_enabled: bool = Field(default=False, description="Whether vision is enabled") - vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled") - features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features") - window_size: int = Field(..., description="Window size for memory") - prompt_template: Sequence[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages") - file_variables: Mapping[str, File | Sequence[File]] = Field( - default_factory=dict, description="List of file variables" - ) - expected_messages: Sequence[PromptMessage] = Field(..., description="Expected messages after processing") diff --git a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py deleted file mode 100644 index fd48edc58c5..00000000000 --- a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py +++ /dev/null @@ -1,27 +0,0 @@ -from dify_graph.nodes.parameter_extractor.entities import ParameterConfig -from dify_graph.variables.types import SegmentType - - -class TestParameterConfig: - def test_select_type(self): - data = { - "name": "yes_or_no", - "type": "select", - "options": ["yes", "no"], - "description": "a simple select made of `yes` and `no`", - "required": True, - } - - pc = ParameterConfig.model_validate(data) - assert pc.type == SegmentType.STRING - assert pc.options == data["options"] - - def test_validate_bool_type(self): - data = { - "name": "boolean", - "type": "bool", - "description": "a simple boolean parameter", - "required": True, - } - pc = ParameterConfig.model_validate(data) - assert pc.type == SegmentType.BOOLEAN diff --git a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py index 7eca531b623..1c362a0a037 100644 --- a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py @@ -6,18 +6,18 @@ from dataclasses import dataclass from typing import Any import pytest - -from dify_graph.model_runtime.entities import LLMMode -from dify_graph.nodes.llm import ModelConfig, VisionConfig -from dify_graph.nodes.parameter_extractor.entities import ParameterConfig, ParameterExtractorNodeData -from dify_graph.nodes.parameter_extractor.exc import ( +from graphon.model_runtime.entities import LLMMode +from graphon.nodes.llm import ModelConfig, VisionConfig +from graphon.nodes.parameter_extractor.entities import ParameterConfig, ParameterExtractorNodeData +from graphon.nodes.parameter_extractor.exc import ( InvalidNumberOfParametersError, InvalidSelectValueError, InvalidValueTypeError, RequiredParameterMissingError, ) -from dify_graph.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode -from dify_graph.variables.types import SegmentType +from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from graphon.variables.types import SegmentType + from factories.variable_factory import build_segment_with_type diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py deleted file mode 100644 index e57ebbd83ee..00000000000 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py +++ /dev/null @@ -1,225 +0,0 @@ -import pytest -from pydantic import ValidationError - -from dify_graph.enums import ErrorStrategy -from dify_graph.nodes.template_transform.entities import TemplateTransformNodeData - - -class TestTemplateTransformNodeData: - """Test suite for TemplateTransformNodeData entity.""" - - def test_valid_template_transform_node_data(self): - """Test creating valid TemplateTransformNodeData.""" - data = { - "title": "Template Transform", - "desc": "Transform data using Jinja2 template", - "variables": [ - {"variable": "name", "value_selector": ["sys", "user_name"]}, - {"variable": "age", "value_selector": ["sys", "user_age"]}, - ], - "template": "Hello {{ name }}, you are {{ age }} years old!", - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert node_data.title == "Template Transform" - assert node_data.desc == "Transform data using Jinja2 template" - assert len(node_data.variables) == 2 - assert node_data.variables[0].variable == "name" - assert node_data.variables[0].value_selector == ["sys", "user_name"] - assert node_data.variables[1].variable == "age" - assert node_data.variables[1].value_selector == ["sys", "user_age"] - assert node_data.template == "Hello {{ name }}, you are {{ age }} years old!" - - def test_template_transform_node_data_with_empty_variables(self): - """Test TemplateTransformNodeData with no variables.""" - data = { - "title": "Static Template", - "variables": [], - "template": "This is a static template with no variables.", - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert node_data.title == "Static Template" - assert len(node_data.variables) == 0 - assert node_data.template == "This is a static template with no variables." - - def test_template_transform_node_data_with_complex_template(self): - """Test TemplateTransformNodeData with complex Jinja2 template.""" - data = { - "title": "Complex Template", - "variables": [ - {"variable": "items", "value_selector": ["sys", "item_list"]}, - {"variable": "total", "value_selector": ["sys", "total_count"]}, - ], - "template": ( - "{% for item in items %}{{ item }}{% if not loop.last %}, {% endif %}{% endfor %}. Total: {{ total }}" - ), - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert node_data.title == "Complex Template" - assert len(node_data.variables) == 2 - assert "{% for item in items %}" in node_data.template - assert "{{ total }}" in node_data.template - - def test_template_transform_node_data_with_error_strategy(self): - """Test TemplateTransformNodeData with error handling strategy.""" - data = { - "title": "Template with Error Handling", - "variables": [{"variable": "value", "value_selector": ["sys", "input"]}], - "template": "{{ value }}", - "error_strategy": "fail-branch", - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert node_data.error_strategy == ErrorStrategy.FAIL_BRANCH - - def test_template_transform_node_data_with_retry_config(self): - """Test TemplateTransformNodeData with retry configuration.""" - data = { - "title": "Template with Retry", - "variables": [{"variable": "data", "value_selector": ["sys", "data"]}], - "template": "{{ data }}", - "retry_config": {"enabled": True, "max_retries": 3, "retry_interval": 1000}, - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert node_data.retry_config.enabled is True - assert node_data.retry_config.max_retries == 3 - assert node_data.retry_config.retry_interval == 1000 - - def test_template_transform_node_data_missing_required_fields(self): - """Test that missing required fields raises ValidationError.""" - data = { - "title": "Incomplete Template", - # Missing 'variables' and 'template' - } - - with pytest.raises(ValidationError) as exc_info: - TemplateTransformNodeData.model_validate(data) - - errors = exc_info.value.errors() - assert len(errors) >= 2 - error_fields = {error["loc"][0] for error in errors} - assert "variables" in error_fields - assert "template" in error_fields - - def test_template_transform_node_data_invalid_variable_selector(self): - """Test that invalid variable selector format raises ValidationError.""" - data = { - "title": "Invalid Variable", - "variables": [ - {"variable": "name", "value_selector": "invalid_format"} # Should be list - ], - "template": "{{ name }}", - } - - with pytest.raises(ValidationError): - TemplateTransformNodeData.model_validate(data) - - def test_template_transform_node_data_with_default_value_dict(self): - """Test TemplateTransformNodeData with default value dictionary.""" - data = { - "title": "Template with Defaults", - "variables": [ - {"variable": "name", "value_selector": ["sys", "user_name"]}, - {"variable": "greeting", "value_selector": ["sys", "greeting"]}, - ], - "template": "{{ greeting }} {{ name }}!", - "default_value_dict": {"greeting": "Hello", "name": "Guest"}, - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert node_data.default_value_dict == {"greeting": "Hello", "name": "Guest"} - - def test_template_transform_node_data_with_nested_selectors(self): - """Test TemplateTransformNodeData with nested variable selectors.""" - data = { - "title": "Nested Selectors", - "variables": [ - {"variable": "user_info", "value_selector": ["sys", "user", "profile", "name"]}, - {"variable": "settings", "value_selector": ["sys", "config", "app", "theme"]}, - ], - "template": "User: {{ user_info }}, Theme: {{ settings }}", - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert len(node_data.variables) == 2 - assert node_data.variables[0].value_selector == ["sys", "user", "profile", "name"] - assert node_data.variables[1].value_selector == ["sys", "config", "app", "theme"] - - def test_template_transform_node_data_with_multiline_template(self): - """Test TemplateTransformNodeData with multiline template.""" - data = { - "title": "Multiline Template", - "variables": [ - {"variable": "title", "value_selector": ["sys", "title"]}, - {"variable": "content", "value_selector": ["sys", "content"]}, - ], - "template": """ -# {{ title }} - -{{ content }} - ---- -Generated by Template Transform Node - """, - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert "# {{ title }}" in node_data.template - assert "{{ content }}" in node_data.template - assert "Generated by Template Transform Node" in node_data.template - - def test_template_transform_node_data_serialization(self): - """Test that TemplateTransformNodeData can be serialized and deserialized.""" - original_data = { - "title": "Serialization Test", - "desc": "Test serialization", - "variables": [{"variable": "test", "value_selector": ["sys", "test"]}], - "template": "{{ test }}", - } - - node_data = TemplateTransformNodeData.model_validate(original_data) - serialized = node_data.model_dump() - deserialized = TemplateTransformNodeData.model_validate(serialized) - - assert deserialized.title == node_data.title - assert deserialized.desc == node_data.desc - assert len(deserialized.variables) == len(node_data.variables) - assert deserialized.template == node_data.template - - def test_template_transform_node_data_with_special_characters(self): - """Test TemplateTransformNodeData with special characters in template.""" - data = { - "title": "Special Characters", - "variables": [{"variable": "text", "value_selector": ["sys", "input"]}], - "template": "Special: {{ text }} | Symbols: @#$%^&*() | Unicode: 你好 🎉", - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert "@#$%^&*()" in node_data.template - assert "你好" in node_data.template - assert "🎉" in node_data.template - - def test_template_transform_node_data_empty_template(self): - """Test TemplateTransformNodeData with empty template string.""" - data = { - "title": "Empty Template", - "variables": [], - "template": "", - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert node_data.template == "" - assert len(node_data.variables) == 0 diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py index 332a8761f97..d86e0efe023 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py @@ -1,13 +1,15 @@ from unittest.mock import MagicMock import pytest +from graphon.enums import BuiltinNodeTypes, ErrorStrategy, WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.nodes.base.entities import VariableSelector +from graphon.nodes.template_transform.entities import TemplateTransformNodeData +from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode +from graphon.runtime import GraphRuntimeState +from graphon.template_rendering import TemplateRenderError from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.enums import BuiltinNodeTypes, ErrorStrategy, WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.nodes.template_transform.template_renderer import TemplateRenderError -from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode -from dify_graph.runtime import GraphRuntimeState from tests.workflow_test_utils import build_test_graph_init_params @@ -62,7 +64,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) assert node.node_type == BuiltinNodeTypes.TEMPLATE_TRANSFORM @@ -78,7 +80,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) assert node._get_title() == "Template Transform" @@ -91,7 +93,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) assert node._get_description() == "Transform data using template" @@ -111,7 +113,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) assert node._get_error_strategy() == ErrorStrategy.FAIL_BRANCH @@ -130,6 +132,26 @@ class TestTemplateTransformNode: """Test version class method.""" assert TemplateTransformNode.version() == "1" + @pytest.mark.parametrize("max_output_length", [0, -1]) + def test_node_initialization_rejects_non_positive_max_output_length( + self, + basic_node_data, + mock_graph_runtime_state, + graph_init_params, + max_output_length, + ): + mock_renderer = MagicMock() + + with pytest.raises(ValueError, match="max_output_length must be a positive integer"): + TemplateTransformNode( + id="test_node", + config={"id": "test_node", "data": basic_node_data}, + graph_init_params=graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + jinja2_template_renderer=mock_renderer, + max_output_length=max_output_length, + ) + def test_run_simple_template(self, basic_node_data, mock_graph_runtime_state, graph_init_params): """Test _run with simple template transformation using injected renderer.""" # Setup mock variable pool @@ -153,7 +175,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -181,7 +203,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -201,7 +223,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -221,7 +243,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, max_output_length=10, ) @@ -230,6 +252,28 @@ class TestTemplateTransformNode: assert result.status == WorkflowNodeExecutionStatus.FAILED assert "Output length exceeds" in result.error + def test_run_output_length_equal_to_limit_succeeds( + self, basic_node_data, mock_graph_runtime_state, graph_init_params + ): + mock_graph_runtime_state.variable_pool.get.return_value = MagicMock() + + mock_renderer = MagicMock() + mock_renderer.render_template.return_value = "1234567890" + + node = TemplateTransformNode( + id="test_node", + config={"id": "test_node", "data": basic_node_data}, + graph_init_params=graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + jinja2_template_renderer=mock_renderer, + max_output_length=10, + ) + + result = node._run() + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs["output"] == "1234567890" + def test_run_with_complex_jinja2_template(self, mock_graph_runtime_state, graph_init_params): """Test _run with complex Jinja2 template including loops and conditions.""" node_data = { @@ -263,7 +307,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -291,6 +335,69 @@ class TestTemplateTransformNode: assert mapping["node_123.var1"] == ["sys", "input1"] assert mapping["node_123.var2"] == ["sys", "input2"] + def test_extract_variable_selector_to_variable_mapping_accepts_validated_node_data(self): + node_data = TemplateTransformNodeData( + title="Test", + variables=[VariableSelector(variable="var1", value_selector=["sys", "input1"])], + template="{{ var1 }}", + ) + + mapping = TemplateTransformNode._extract_variable_selector_to_variable_mapping( + graph_config={}, node_id="node_123", node_data=node_data + ) + + assert mapping == {"node_123.var1": ["sys", "input1"]} + + def test_extract_variable_selector_to_variable_mapping_returns_empty_mapping_without_variables(self): + node_data = { + "title": "Test", + "template": "{{ missing }}", + } + + mapping = TemplateTransformNode._extract_variable_selector_to_variable_mapping( + graph_config={}, node_id="node_123", node_data=node_data + ) + + assert mapping == {} + + def test_extract_variable_selector_to_variable_mapping_accepts_sequence_value_selectors(self): + node_data = { + "title": "Test", + "variables": [ + {"variable": "var1", "value_selector": ("sys", "input1")}, + {"variable": "empty_selector", "value_selector": ()}, + ], + "template": "{{ var1 }}", + } + + mapping = TemplateTransformNode._extract_variable_selector_to_variable_mapping( + graph_config={}, node_id="node_123", node_data=node_data + ) + + assert mapping == { + "node_123.var1": ["sys", "input1"], + "node_123.empty_selector": [], + } + + def test_extract_variable_selector_to_variable_mapping_ignores_invalid_entries(self): + node_data = { + "title": "Test", + "variables": [ + {"variable": "var1", "value_selector": ["sys", "input1"]}, + {"variable": "missing_selector"}, + ["not", "a", "mapping"], + {"variable": 1, "value_selector": ["sys", "input2"]}, + {"variable": "invalid_selector", "value_selector": ["sys", 2]}, + ], + "template": "{{ var1 }}", + } + + mapping = TemplateTransformNode._extract_variable_selector_to_variable_mapping( + graph_config={}, node_id="node_123", node_data=node_data + ) + + assert mapping == {"node_123.var1": ["sys", "input1"]} + def test_run_with_empty_variables(self, mock_graph_runtime_state, graph_init_params): """Test _run with no variables (static template).""" node_data = { @@ -307,7 +414,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -346,7 +453,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -375,7 +482,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -405,7 +512,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py new file mode 100644 index 00000000000..bd22a8e318c --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py @@ -0,0 +1,74 @@ +from unittest.mock import MagicMock + +import pytest +from graphon.nodes.base.entities import VariableSelector +from graphon.nodes.template_transform.template_transform_node import ( + DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH, + TemplateTransformNode, +) +from graphon.runtime import GraphRuntimeState + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from tests.workflow_test_utils import build_test_graph_init_params + +from .template_transform_node_spec import TestTemplateTransformNode # noqa: F401 + + +@pytest.fixture +def graph_init_params(): + return build_test_graph_init_params( + workflow_id="test_workflow", + graph_config={}, + tenant_id="test_tenant", + app_id="test_app", + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + +@pytest.fixture +def mock_graph_runtime_state(): + mock_state = MagicMock(spec=GraphRuntimeState) + mock_state.variable_pool = MagicMock() + return mock_state + + +def test_node_uses_default_max_output_length_when_not_overridden(graph_init_params, mock_graph_runtime_state): + node = TemplateTransformNode( + id="test_node", + config={ + "id": "test_node", + "data": { + "title": "Template Transform", + "variables": [], + "template": "hello", + }, + }, + graph_init_params=graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + jinja2_template_renderer=MagicMock(), + ) + + assert node._max_output_length == DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH + + +def test_extract_variable_selector_to_variable_mapping_accepts_mixed_valid_entries(): + mapping = TemplateTransformNode._extract_variable_selector_to_variable_mapping( + graph_config={"ignored": True}, + node_id="node_123", + node_data={ + "variables": [ + VariableSelector(variable="validated", value_selector=["sys", "input1"]), + {"variable": "raw", "value_selector": ("sys", "input2")}, + {"variable": "invalid_selector", "value_selector": ["sys", 3]}, + ["not", "a", "mapping"], + ] + }, + ) + + assert mapping == { + "node_123.validated": ["sys", "input1"], + "node_123.raw": ["sys", "input2"], + } diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py index 2b0205fb7b7..e11ebf6eb8b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py @@ -1,15 +1,16 @@ from collections.abc import Mapping import pytest +from graphon.entities import GraphInitParams +from graphon.entities.base_node_data import BaseNodeData +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.base.node import Node +from graphon.runtime import GraphRuntimeState, VariablePool from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.entities import GraphInitParams -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.nodes.base.node import Node -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.node_runtime import resolve_dify_run_context +from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params @@ -35,7 +36,7 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, invoke_from="debugger", ) runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}), + variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}), start_at=0.0, ) return init_params, runtime_state @@ -67,7 +68,7 @@ def test_node_hydrates_data_during_initialization(): assert node.node_data.foo == "bar" assert node.title == "Sample" - dify_ctx = node.require_dify_context() + dify_ctx = resolve_dify_run_context(node.run_context) assert dify_ctx.user_from == "account" assert dify_ctx.invoke_from == "debugger" @@ -80,7 +81,7 @@ def test_node_accepts_invoke_from_enum(): invoke_from=InvokeFrom.DEBUGGER, ) runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}), + variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}), start_at=0.0, ) @@ -91,7 +92,7 @@ def test_node_accepts_invoke_from_enum(): graph_runtime_state=runtime_state, ) - dify_ctx = node.require_dify_context() + dify_ctx = resolve_dify_run_context(node.run_context) assert dify_ctx.user_from == UserFrom.ACCOUNT assert dify_ctx.invoke_from == InvokeFrom.DEBUGGER assert node.get_run_context_value("missing") is None @@ -127,3 +128,29 @@ def test_base_node_data_keeps_dict_style_access_compatibility(): assert node_data["foo"] == "bar" assert node_data.get("foo") == "bar" assert node_data.get("missing", "fallback") == "fallback" + + +def test_node_hydration_preserves_compatibility_extra_fields(): + graph_config: dict[str, object] = {} + init_params, runtime_state = _build_context(graph_config) + node_config = NodeConfigDictAdapter.validate_python( + { + "id": "node-1", + "data": { + "type": BuiltinNodeTypes.ANSWER, + "title": "Sample", + "foo": "bar", + "compat_flag": True, + }, + } + ) + + node = _SampleNode( + id="node-1", + config=node_config, + graph_init_params=init_params, + graph_runtime_state=runtime_state, + ) + + assert node.node_data.foo == "bar" + assert node.node_data.get("compat_flag") is True diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py index 40754974c13..555ff0c9452 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -4,23 +4,23 @@ from unittest.mock import Mock, patch import pandas as pd import pytest from docx.oxml.text.paragraph import CT_P - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.entities import GraphInitParams -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.file import File, FileTransferMethod -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData -from dify_graph.nodes.document_extractor.node import ( +from graphon.entities import GraphInitParams +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod +from graphon.node_events import NodeRunResult +from graphon.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData +from graphon.nodes.document_extractor.node import ( _extract_text_from_docx, _extract_text_from_excel, _extract_text_from_pdf, _extract_text_from_plain_text, _normalize_docx_zip, ) -from dify_graph.variables import ArrayFileSegment -from dify_graph.variables.segments import ArrayStringSegment -from dify_graph.variables.variables import StringVariable +from graphon.variables import ArrayFileSegment +from graphon.variables.segments import ArrayStringSegment +from graphon.variables.variables import StringVariable + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from tests.workflow_test_utils import build_test_graph_init_params @@ -183,14 +183,14 @@ def test_run_extract_text( mock_response.raise_for_status = Mock() document_extractor_node._http_client.get = Mock(return_value=mock_response) - monkeypatch.setattr("dify_graph.file.file_manager.download", mock_download) + monkeypatch.setattr("graphon.file.file_manager.download", mock_download) if mime_type == "application/pdf": mock_pdf_extract = Mock(return_value=expected_text[0]) - monkeypatch.setattr("dify_graph.nodes.document_extractor.node._extract_text_from_pdf", mock_pdf_extract) + monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_from_pdf", mock_pdf_extract) elif mime_type.startswith("application/vnd.openxmlformats"): mock_docx_extract = Mock(return_value=expected_text[0]) - monkeypatch.setattr("dify_graph.nodes.document_extractor.node._extract_text_from_docx", mock_docx_extract) + monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_from_docx", mock_docx_extract) result = document_extractor_node._run() diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index c746a945fed..1b14f0ab133 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -3,19 +3,18 @@ import uuid from unittest.mock import MagicMock, Mock import pytest +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType +from graphon.graph import Graph +from graphon.nodes.if_else.entities import IfElseNodeData +from graphon.nodes.if_else.if_else_node import IfElseNode +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.utils.condition.entities import Condition, SubCondition, SubVariableCondition +from graphon.variables import ArrayFileSegment -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.graph import Graph -from dify_graph.nodes.if_else.entities import IfElseNodeData -from dify_graph.nodes.if_else.if_else_node import IfElseNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.utils.condition.entities import Condition, SubCondition, SubVariableCondition -from dify_graph.variables import ArrayFileSegment +from core.workflow.system_variables import build_system_variables from extensions.ext_database import db from tests.workflow_test_utils import build_test_graph_init_params @@ -35,7 +34,7 @@ def test_execute_if_else_result_true(): ) # construct variable pool - pool = VariablePool(system_variables=SystemVariable(user_id="aaa", files=[]), user_inputs={}) + pool = VariablePool(system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}) pool.add(["start", "array_contains"], ["ab", "def"]) pool.add(["start", "array_not_contains"], ["ac", "def"]) pool.add(["start", "contains"], "cabcde") @@ -142,7 +141,7 @@ def test_execute_if_else_result_false(): # construct variable pool pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], ) @@ -253,7 +252,6 @@ def test_array_file_contains_file_name(): node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment( value=[ File( - tenant_id="1", type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1", @@ -316,7 +314,7 @@ def test_execute_if_else_boolean_conditions(condition: Condition): # construct variable pool with boolean values pool = VariablePool( - system_variables=SystemVariable(files=[], user_id="aaa"), + system_variables=build_system_variables(files=[], user_id="aaa"), ) pool.add(["start", "bool_true"], True) pool.add(["start", "bool_false"], False) @@ -371,7 +369,7 @@ def test_execute_if_else_boolean_false_conditions(): # construct variable pool with boolean values pool = VariablePool( - system_variables=SystemVariable(files=[], user_id="aaa"), + system_variables=build_system_variables(files=[], user_id="aaa"), ) pool.add(["start", "bool_true"], True) pool.add(["start", "bool_false"], False) @@ -440,7 +438,7 @@ def test_execute_if_else_boolean_cases_structure(): # construct variable pool with boolean values pool = VariablePool( - system_variables=SystemVariable(files=[], user_id="aaa"), + system_variables=build_system_variables(files=[], user_id="aaa"), ) pool.add(["start", "bool_true"], True) pool.add(["start", "bool_false"], False) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index 6ca72b64b2d..d28c3e01e5f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -1,12 +1,9 @@ from unittest.mock import MagicMock import pytest - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.nodes.list_operator.entities import ( +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType +from graphon.nodes.list_operator.entities import ( ExtractConfig, FilterBy, FilterCondition, @@ -15,9 +12,11 @@ from dify_graph.nodes.list_operator.entities import ( Order, OrderByConfig, ) -from dify_graph.nodes.list_operator.exc import InvalidKeyError -from dify_graph.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func -from dify_graph.variables import ArrayFileSegment +from graphon.nodes.list_operator.exc import InvalidKeyError +from graphon.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func +from graphon.variables import ArrayFileSegment + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom @pytest.fixture @@ -72,7 +71,6 @@ def test_filter_files_by_type(list_operator_node): File( filename="image1.jpg", type=FileType.IMAGE, - tenant_id="tenant1", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related1", storage_key="", @@ -80,7 +78,6 @@ def test_filter_files_by_type(list_operator_node): File( filename="document1.pdf", type=FileType.DOCUMENT, - tenant_id="tenant1", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related2", storage_key="", @@ -88,7 +85,6 @@ def test_filter_files_by_type(list_operator_node): File( filename="image2.png", type=FileType.IMAGE, - tenant_id="tenant1", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related3", storage_key="", @@ -96,7 +92,6 @@ def test_filter_files_by_type(list_operator_node): File( filename="audio1.mp3", type=FileType.AUDIO, - tenant_id="tenant1", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related4", storage_key="", @@ -120,14 +115,12 @@ def test_filter_files_by_type(list_operator_node): { "filename": "document1.pdf", "type": FileType.DOCUMENT, - "tenant_id": "tenant1", "transfer_method": FileTransferMethod.LOCAL_FILE, "related_id": "related2", }, { "filename": "image2.png", "type": FileType.IMAGE, - "tenant_id": "tenant1", "transfer_method": FileTransferMethod.LOCAL_FILE, "related_id": "related3", }, @@ -136,7 +129,6 @@ def test_filter_files_by_type(list_operator_node): for expected_file, result_file in zip(expected_files, result.outputs["result"].value): assert expected_file["filename"] == result_file.filename assert expected_file["type"] == result_file.type - assert expected_file["tenant_id"] == result_file.tenant_id assert expected_file["transfer_method"] == result_file.transfer_method assert expected_file["related_id"] == result_file.related_id @@ -144,7 +136,6 @@ def test_filter_files_by_type(list_operator_node): def test_get_file_extract_string_func(): # Create a File object file = File( - tenant_id="test_tenant", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, filename="test_file.txt", @@ -165,7 +156,6 @@ def test_get_file_extract_string_func(): # Test with empty values empty_file = File( - tenant_id="test_tenant", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, filename=None, diff --git a/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py b/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py deleted file mode 100644 index 63725838390..00000000000 --- a/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py +++ /dev/null @@ -1,52 +0,0 @@ -from dify_graph.entities.graph_config import NodeConfigDictAdapter -from dify_graph.nodes.loop.entities import LoopNodeData -from dify_graph.nodes.loop.loop_node import LoopNode - - -def test_extract_variable_selector_to_variable_mapping_validates_child_node_configs(monkeypatch) -> None: - seen_configs: list[object] = [] - original_validate_python = NodeConfigDictAdapter.validate_python - - def record_validate_python(value: object): - seen_configs.append(value) - return original_validate_python(value) - - monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python) - - child_node_config = { - "id": "answer-node", - "data": { - "type": "answer", - "title": "Answer", - "answer": "", - "loop_id": "loop-node", - }, - } - - LoopNode._extract_variable_selector_to_variable_mapping( - graph_config={ - "nodes": [ - { - "id": "loop-node", - "data": { - "type": "loop", - "title": "Loop", - "loop_count": 1, - "break_conditions": [], - "logical_operator": "and", - }, - }, - child_node_config, - ], - "edges": [], - }, - node_id="loop-node", - node_data=LoopNodeData( - title="Loop", - loop_count=1, - break_conditions=[], - logical_operator="and", - ), - ) - - assert seen_configs == [child_node_config] diff --git a/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py b/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py deleted file mode 100644 index c5a02e87e48..00000000000 --- a/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py +++ /dev/null @@ -1,125 +0,0 @@ -from types import SimpleNamespace -from unittest.mock import MagicMock - -from dify_graph.model_runtime.entities import ImagePromptMessageContent -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer -from dify_graph.nodes.protocols import HttpClientProtocol -from dify_graph.nodes.question_classifier import ( - QuestionClassifierNode, - QuestionClassifierNodeData, -) -from tests.workflow_test_utils import build_test_graph_init_params - - -def test_init_question_classifier_node_data(): - data = { - "title": "test classifier node", - "query_variable_selector": ["id", "name"], - "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}}, - "classes": [{"id": "1", "name": "class 1"}], - "instruction": "This is a test instruction", - "memory": { - "role_prefix": {"user": "Human:", "assistant": "AI:"}, - "window": {"enabled": True, "size": 5}, - "query_prompt_template": "Previous conversation:\n{history}\n\nHuman: {query}\nAI:", - }, - "vision": {"enabled": True, "configs": {"variable_selector": ["image"], "detail": "low"}}, - } - - node_data = QuestionClassifierNodeData.model_validate(data) - - assert node_data.query_variable_selector == ["id", "name"] - assert node_data.model.provider == "openai" - assert node_data.classes[0].id == "1" - assert node_data.instruction == "This is a test instruction" - assert node_data.memory is not None - assert node_data.memory.role_prefix is not None - assert node_data.memory.role_prefix.user == "Human:" - assert node_data.memory.role_prefix.assistant == "AI:" - assert node_data.memory.window.enabled == True - assert node_data.memory.window.size == 5 - assert node_data.memory.query_prompt_template == "Previous conversation:\n{history}\n\nHuman: {query}\nAI:" - assert node_data.vision.enabled == True - assert node_data.vision.configs.variable_selector == ["image"] - assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.LOW - - -def test_init_question_classifier_node_data_without_vision_config(): - data = { - "title": "test classifier node", - "query_variable_selector": ["id", "name"], - "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}}, - "classes": [{"id": "1", "name": "class 1"}], - "instruction": "This is a test instruction", - "memory": { - "role_prefix": {"user": "Human:", "assistant": "AI:"}, - "window": {"enabled": True, "size": 5}, - "query_prompt_template": "Previous conversation:\n{history}\n\nHuman: {query}\nAI:", - }, - } - - node_data = QuestionClassifierNodeData.model_validate(data) - - assert node_data.query_variable_selector == ["id", "name"] - assert node_data.model.provider == "openai" - assert node_data.classes[0].id == "1" - assert node_data.instruction == "This is a test instruction" - assert node_data.memory is not None - assert node_data.memory.role_prefix is not None - assert node_data.memory.role_prefix.user == "Human:" - assert node_data.memory.role_prefix.assistant == "AI:" - assert node_data.memory.window.enabled == True - assert node_data.memory.window.size == 5 - assert node_data.memory.query_prompt_template == "Previous conversation:\n{history}\n\nHuman: {query}\nAI:" - assert node_data.vision.enabled == False - assert node_data.vision.configs.variable_selector == ["sys", "files"] - assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.HIGH - - -def test_question_classifier_calculate_rest_token_uses_shared_prompt_builder(monkeypatch): - node_data = QuestionClassifierNodeData.model_validate( - { - "title": "test classifier node", - "query_variable_selector": ["id", "name"], - "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}}, - "classes": [{"id": "1", "name": "class 1"}], - "instruction": "This is a test instruction", - } - ) - template_renderer = MagicMock(spec=TemplateRenderer) - node = QuestionClassifierNode( - id="node-id", - config={"id": "node-id", "data": node_data.model_dump(mode="json")}, - graph_init_params=build_test_graph_init_params( - workflow_id="workflow-id", - graph_config={}, - tenant_id="tenant-id", - app_id="app-id", - user_id="user-id", - ), - graph_runtime_state=SimpleNamespace(variable_pool=MagicMock()), - credentials_provider=MagicMock(spec=CredentialsProvider), - model_factory=MagicMock(spec=ModelFactory), - model_instance=MagicMock(), - http_client=MagicMock(spec=HttpClientProtocol), - llm_file_saver=MagicMock(), - template_renderer=template_renderer, - ) - fetch_prompt_messages = MagicMock(return_value=([], None)) - monkeypatch.setattr( - "dify_graph.nodes.question_classifier.question_classifier_node.llm_utils.fetch_prompt_messages", - fetch_prompt_messages, - ) - monkeypatch.setattr( - "dify_graph.nodes.question_classifier.question_classifier_node.llm_utils.fetch_model_schema", - MagicMock(return_value=SimpleNamespace(model_properties={}, parameter_rules=[])), - ) - - node._calculate_rest_token( - node_data=node_data, - query="hello", - model_instance=MagicMock(stop=(), parameters={}), - context="", - ) - - assert fetch_prompt_messages.call_args.kwargs["template_renderer"] is template_renderer diff --git a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py index b8f0e25e91c..833c3030521 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py @@ -2,21 +2,24 @@ import json import time import pytest +from graphon.nodes.start.entities import StartNodeData +from graphon.nodes.start.start_node import StartNode +from graphon.runtime import GraphRuntimeState +from graphon.variables import build_segment, segment_to_variable +from graphon.variables.input_entities import VariableEntity, VariableEntityType +from graphon.variables.variables import Variable from pydantic import ValidationError as PydanticValidationError -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType -from tests.workflow_test_utils import build_test_graph_init_params +from core.workflow.system_variables import build_system_variables +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID +from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool def make_start_node(user_inputs, variables): - variable_pool = VariablePool( - system_variables=SystemVariable(), - user_inputs=user_inputs, - conversation_variables=[], + variable_pool = build_test_variable_pool( + variables=build_system_variables(), + node_id="start", + inputs=user_inputs, ) config = { @@ -232,3 +235,64 @@ def test_json_object_optional_variable_not_provided(): # Current implementation raises a validation error even when the variable is optional with pytest.raises(ValueError, match="profile is required in input form"): node._run() + + +def test_start_node_outputs_full_variable_pool_snapshot(): + variable_pool = build_test_variable_pool( + variables=[ + *build_system_variables(query="hello", workflow_run_id="run-123"), + _build_prefixed_variable(ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY", "secret"), + _build_prefixed_variable(CONVERSATION_VARIABLE_NODE_ID, "session_id", "conversation-1"), + ], + node_id="start", + inputs={"profile": {"age": 20, "name": "Tom"}}, + ) + + config = { + "id": "start", + "data": StartNodeData( + title="Start", + variables=[ + VariableEntity( + variable="profile", + label="profile", + type=VariableEntityType.JSON_OBJECT, + required=True, + ) + ], + ).model_dump(), + } + + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + node = StartNode( + id="start", + config=config, + graph_init_params=build_test_graph_init_params( + workflow_id="wf", + graph_config={}, + tenant_id="tenant", + app_id="app", + user_id="u", + user_from="account", + invoke_from="debugger", + call_depth=0, + ), + graph_runtime_state=graph_runtime_state, + ) + + result = node._run() + + assert result.inputs == {"profile": {"age": 20, "name": "Tom"}} + assert result.outputs["profile"] == {"age": 20, "name": "Tom"} + assert result.outputs["sys.query"] == "hello" + assert result.outputs["sys.workflow_run_id"] == "run-123" + assert result.outputs["env.API_KEY"] == "secret" + assert result.outputs["conversation.session_id"] == "conversation-1" + + +def _build_prefixed_variable(node_id: str, name: str, value: object) -> Variable: + return segment_to_variable( + segment=build_segment(value), + selector=(node_id, name), + name=name, + ) diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index 3cbd96dfef0..15870148027 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -3,23 +3,55 @@ from __future__ import annotations import sys import types from collections.abc import Generator +from types import SimpleNamespace from typing import TYPE_CHECKING, Any -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest +from graphon.file import File, FileTransferMethod, FileType +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.node_events import StreamChunkEvent, StreamCompletedEvent +from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables.segments import ArrayFileSegment -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.utils.message_transformer import ToolFileMessageTransformer -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables.segments import ArrayFileSegment +from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params if TYPE_CHECKING: # pragma: no cover - imported for type checking only - from dify_graph.nodes.tool.tool_node import ToolNode + from graphon.nodes.tool.tool_node import ToolNode + + +class _StubToolRuntime: + def get_runtime(self, *, node_id: str, node_data: Any, variable_pool: Any) -> ToolRuntimeHandle: + raise NotImplementedError + + def get_runtime_parameters(self, *, tool_runtime: ToolRuntimeHandle) -> list[Any]: + return [] + + def invoke( + self, + *, + tool_runtime: ToolRuntimeHandle, + tool_parameters: dict[str, Any], + workflow_call_depth: int, + provider_name: str, + ) -> Generator[ToolRuntimeMessage, None, None]: + yield from () + + def get_usage(self, *, tool_runtime: ToolRuntimeHandle) -> LLMUsage: + return LLMUsage.empty_usage() + + def build_file_reference(self, *, mapping: dict[str, Any]) -> Any: + return mapping + + def resolve_provider_icons( + self, + *, + provider_name: str, + default_icon: str | None = None, + ) -> tuple[str | None, str | None]: + return default_icon, None @pytest.fixture @@ -31,8 +63,8 @@ def tool_node(monkeypatch) -> ToolNode: ops_stub.TraceTask = object # pragma: no cover - stub attribute monkeypatch.setitem(sys.modules, module_name, ops_stub) - from dify_graph.nodes.protocols import ToolFileManagerProtocol - from dify_graph.nodes.tool.tool_node import ToolNode + from graphon.nodes.protocols import ToolFileManagerProtocol + from graphon.nodes.tool.tool_node import ToolNode graph_config: dict[str, Any] = { "nodes": [ @@ -66,13 +98,14 @@ def tool_node(monkeypatch) -> ToolNode: call_depth=0, ) - variable_pool = VariablePool(system_variables=SystemVariable(user_id="user-id")) + variable_pool = VariablePool(system_variables=build_system_variables(user_id="user-id")) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) config = graph_config["nodes"][0] # Provide a stub ToolFileManager to satisfy the updated ToolNode constructor tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol) + runtime = _StubToolRuntime() node = ToolNode( id="node-instance", @@ -80,6 +113,7 @@ def tool_node(monkeypatch) -> ToolNode: graph_init_params=init_params, graph_runtime_state=graph_runtime_state, tool_file_manager_factory=tool_file_manager_factory, + runtime=runtime, ) return node @@ -93,29 +127,19 @@ def _collect_events(generator: Generator) -> tuple[list[Any], LLMUsage]: return events, stop.value -def _run_transform(tool_node: ToolNode, message: ToolInvokeMessage) -> tuple[list[Any], LLMUsage]: - def _identity_transform(messages, *_args, **_kwargs): - return messages - - tool_runtime = MagicMock() - with patch.object( - ToolFileMessageTransformer, "transform_tool_invoke_messages", side_effect=_identity_transform, autospec=True - ): - generator = tool_node._transform_message( - messages=iter([message]), - tool_info={"provider_type": "builtin", "provider_id": "provider"}, - parameters_for_log={}, - user_id="user-id", - tenant_id="tenant-id", - node_id=tool_node._node_id, - tool_runtime=tool_runtime, - ) - return _collect_events(generator) +def _run_transform(tool_node: ToolNode, message: ToolRuntimeMessage) -> tuple[list[Any], LLMUsage]: + generator = tool_node._transform_message( + messages=iter([message]), + tool_info={"provider_type": "builtin", "provider_id": "provider"}, + parameters_for_log={}, + node_id=tool_node._node_id, + tool_runtime=ToolRuntimeHandle(raw=object()), + ) + return _collect_events(generator) def test_link_messages_with_file_populate_files_output(tool_node: ToolNode): file_obj = File( - tenant_id="tenant-id", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.TOOL_FILE, related_id="file-id", @@ -125,9 +149,9 @@ def test_link_messages_with_file_populate_files_output(tool_node: ToolNode): size=123, storage_key="file-key", ) - message = ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.LINK, - message=ToolInvokeMessage.TextMessage(text="/files/tools/file-id.pdf"), + message = ToolRuntimeMessage( + type=ToolRuntimeMessage.MessageType.LINK, + message=ToolRuntimeMessage.TextMessage(text="/files/tools/file-id.pdf"), meta={"file": file_obj}, ) @@ -150,9 +174,9 @@ def test_link_messages_with_file_populate_files_output(tool_node: ToolNode): def test_plain_link_messages_remain_links(tool_node: ToolNode): - message = ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.LINK, - message=ToolInvokeMessage.TextMessage(text="https://dify.ai"), + message = ToolRuntimeMessage( + type=ToolRuntimeMessage.MessageType.LINK, + message=ToolRuntimeMessage.TextMessage(text="https://dify.ai"), meta=None, ) @@ -167,3 +191,35 @@ def test_plain_link_messages_remain_links(tool_node: ToolNode): files_segment = completed_events[0].node_run_result.outputs["files"] assert isinstance(files_segment, ArrayFileSegment) assert files_segment.value == [] + + +def test_image_link_messages_use_tool_file_id_metadata(tool_node: ToolNode): + file_obj = File( + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.TOOL_FILE, + related_id="file-id", + filename="demo.pdf", + extension=".pdf", + mime_type="application/pdf", + size=123, + storage_key="file-key", + ) + tool_node._tool_file_manager_factory.get_file_generator_by_tool_file_id.return_value = ( + None, + SimpleNamespace(mime_type="application/pdf"), + ) + tool_node._runtime.build_file_reference = MagicMock(return_value=file_obj) + message = ToolRuntimeMessage( + type=ToolRuntimeMessage.MessageType.IMAGE_LINK, + message=ToolRuntimeMessage.TextMessage(text="/files/tools/file-id.pdf"), + meta={"tool_file_id": "file-id"}, + ) + + events, _ = _run_transform(tool_node, message) + + tool_node._tool_file_manager_factory.get_file_generator_by_tool_file_id.assert_called_once_with("file-id") + completed_events = [event for event in events if isinstance(event, StreamCompletedEvent)] + assert len(completed_events) == 1 + files_segment = completed_events[0].node_run_result.outputs["files"] + assert isinstance(files_segment, ArrayFileSegment) + assert files_segment.value == [file_obj] diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py new file mode 100644 index 00000000000..c4dfc5a1792 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py @@ -0,0 +1,276 @@ +from __future__ import annotations + +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.nodes.tool.entities import ToolNodeData, ToolProviderType +from graphon.nodes.tool.exc import ToolRuntimeInvocationError +from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage +from graphon.runtime import VariablePool + +from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler +from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.entities.tool_entities import ToolProviderType as CoreToolProviderType +from core.tools.errors import ToolInvokeError +from core.tools.tool_engine import ToolEngine +from core.tools.tool_manager import ToolManager +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.workflow.node_runtime import DifyToolNodeRuntime +from core.workflow.system_variables import build_system_variables +from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool + + +@pytest.fixture +def runtime(monkeypatch) -> DifyToolNodeRuntime: + module_name = "core.ops.ops_trace_manager" + if module_name not in sys.modules: + ops_stub = types.ModuleType(module_name) + ops_stub.TraceQueueManager = object # pragma: no cover - stub attribute + ops_stub.TraceTask = object # pragma: no cover - stub attribute + monkeypatch.setitem(sys.modules, module_name, ops_stub) + + init_params = build_test_graph_init_params( + workflow_id="workflow-id", + graph_config={"nodes": [], "edges": []}, + tenant_id="tenant-id", + app_id="app-id", + user_id="user-id", + user_from="account", + invoke_from="debugger", + call_depth=0, + ) + return DifyToolNodeRuntime(init_params.run_context) + + +def _build_tool_node_data() -> ToolNodeData: + return ToolNodeData.model_validate( + { + "type": "tool", + "title": "Tool", + "provider_id": "provider", + "provider_type": ToolProviderType.BUILT_IN, + "provider_name": "provider", + "tool_name": "lookup", + "tool_label": "Lookup", + "tool_configurations": {}, + "tool_parameters": {}, + } + ) + + +def test_invoke_creates_callback_and_converts_messages(runtime: DifyToolNodeRuntime) -> None: + core_message = ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.LINK, + message=ToolInvokeMessage.TextMessage(text="https://dify.ai"), + meta=None, + ) + variable_pool: VariablePool = build_test_variable_pool( + variables=build_system_variables(conversation_id="conversation-id") + ) + workflow_tool = MagicMock() + + with ( + patch.object(ToolManager, "get_workflow_tool_runtime", return_value=workflow_tool), + patch.object(ToolEngine, "generic_invoke", return_value=iter([core_message])) as generic_invoke_mock, + patch.object( + ToolFileMessageTransformer, + "transform_tool_invoke_messages", + side_effect=lambda *, messages, **_: messages, + ) as transform_tool_messages, + ): + tool_runtime = runtime.get_runtime( + node_id="node-id", + node_data=_build_tool_node_data(), + variable_pool=variable_pool, + ) + messages = list( + runtime.invoke( + tool_runtime=tool_runtime, + tool_parameters={}, + workflow_call_depth=0, + provider_name="provider", + ) + ) + + assert not hasattr(tool_runtime, "conversation_id") + assert len(messages) == 1 + graph_message = messages[0] + assert graph_message.type == ToolRuntimeMessage.MessageType.LINK + assert isinstance(graph_message.message, ToolRuntimeMessage.TextMessage) + assert graph_message.message.text == "https://dify.ai" + + callback = generic_invoke_mock.call_args.kwargs["workflow_tool_callback"] + assert isinstance(callback, DifyWorkflowCallbackHandler) + assert generic_invoke_mock.call_args.kwargs["conversation_id"] == "conversation-id" + + transform_kwargs = transform_tool_messages.call_args.kwargs + assert transform_kwargs["conversation_id"] == "conversation-id" + + +def test_invoke_maps_plugin_errors_to_graph_errors(runtime: DifyToolNodeRuntime) -> None: + invoke_error = PluginInvokeError('{"error_type":"RateLimit","message":"too many"}') + + with patch.object(ToolEngine, "generic_invoke", side_effect=invoke_error): + with pytest.raises(ToolRuntimeInvocationError, match="An error occurred in the provider"): + runtime.invoke( + tool_runtime=ToolRuntimeHandle(raw=MagicMock()), + tool_parameters={}, + workflow_call_depth=0, + provider_name="provider", + ) + + +def test_get_usage_normalizes_dict_payload(runtime: DifyToolNodeRuntime) -> None: + usage_payload = LLMUsage.empty_usage().model_dump() + usage_payload["total_tokens"] = 42 + + usage = runtime.get_usage( + tool_runtime=ToolRuntimeHandle(raw=SimpleNamespace(latest_usage=usage_payload)), + ) + + assert usage.total_tokens == 42 + + +def test_get_runtime_converts_graph_provider_type_for_tool_manager(runtime: DifyToolNodeRuntime) -> None: + node_data = _build_tool_node_data() + + with patch.object(ToolManager, "get_workflow_tool_runtime", return_value=MagicMock()) as runtime_mock: + tool_runtime = runtime.get_runtime(node_id="node-id", node_data=node_data, variable_pool=None) + + assert not hasattr(tool_runtime, "conversation_id") + workflow_tool = runtime_mock.call_args.args[3] + assert workflow_tool.provider_type == CoreToolProviderType.BUILT_IN + + +def test_get_runtime_parameters_reads_required_flags(runtime: DifyToolNodeRuntime) -> None: + tool_runtime = ToolRuntimeHandle( + raw=SimpleNamespace( + get_merged_runtime_parameters=MagicMock( + return_value=[ + SimpleNamespace(name="city", required=True), + SimpleNamespace(name="country", required=False), + ] + ) + ) + ) + + parameters = runtime.get_runtime_parameters(tool_runtime=tool_runtime) + + assert [(parameter.name, parameter.required) for parameter in parameters] == [ + ("city", True), + ("country", False), + ] + + +def test_get_usage_returns_empty_usage_when_tool_has_no_usage(runtime: DifyToolNodeRuntime) -> None: + usage = runtime.get_usage(tool_runtime=ToolRuntimeHandle(raw=SimpleNamespace(latest_usage=None))) + + assert usage == LLMUsage.empty_usage() + + +@pytest.mark.parametrize( + ("payload", "expected_type"), + [ + (ToolInvokeMessage.JsonMessage(json_object={"ok": True}, suppress_output=True), ToolRuntimeMessage.JsonMessage), + (ToolInvokeMessage.BlobMessage(blob=b"bytes"), ToolRuntimeMessage.BlobMessage), + ( + ToolInvokeMessage.BlobChunkMessage( + id="blob-id", + sequence=1, + total_length=5, + blob=b"hello", + end=True, + ), + ToolRuntimeMessage.BlobChunkMessage, + ), + (ToolInvokeMessage.FileMessage(file_marker="marker"), ToolRuntimeMessage.FileMessage), + ( + ToolInvokeMessage.VariableMessage(variable_name="city", variable_value="Tokyo", stream=True), + ToolRuntimeMessage.VariableMessage, + ), + ( + ToolInvokeMessage.LogMessage( + id="log-id", + label="lookup", + status=ToolInvokeMessage.LogMessage.LogStatus.SUCCESS, + data={"count": 1}, + metadata={"source": "tool"}, + ), + ToolRuntimeMessage.LogMessage, + ), + ], +) +def test_convert_message_payload_supports_runtime_message_types( + runtime: DifyToolNodeRuntime, + payload: object, + expected_type: type[object], +) -> None: + message = runtime._convert_message_payload(payload) + + assert isinstance(message, expected_type) + + +def test_convert_message_payload_rejects_unknown_types(runtime: DifyToolNodeRuntime) -> None: + with pytest.raises(TypeError, match="unsupported tool message payload"): + runtime._convert_message_payload(object()) + + +def test_resolve_provider_icons_prefers_builtin_tool_icons(runtime: DifyToolNodeRuntime) -> None: + plugin = SimpleNamespace( + plugin_id="langgenius/tools", + name="search", + declaration=SimpleNamespace(icon={"plugin": "icon"}), + ) + builtin_tool = SimpleNamespace( + name="langgenius/tools/search", + icon={"builtin": "icon"}, + icon_dark={"builtin": "dark"}, + ) + + with ( + patch("core.workflow.node_runtime.PluginInstaller") as installer_cls, + patch("core.workflow.node_runtime.BuiltinToolManageService.list_builtin_tools", return_value=[builtin_tool]), + ): + installer_cls.return_value.list_plugins.return_value = [plugin] + + icon, icon_dark = runtime.resolve_provider_icons(provider_name="langgenius/tools/search") + + assert icon == {"builtin": "icon"} + assert icon_dark == {"builtin": "dark"} + + +def test_resolve_provider_icons_returns_default_when_provider_is_unknown(runtime: DifyToolNodeRuntime) -> None: + with ( + patch("core.workflow.node_runtime.PluginInstaller") as installer_cls, + patch("core.workflow.node_runtime.BuiltinToolManageService.list_builtin_tools", return_value=[]), + ): + installer_cls.return_value.list_plugins.return_value = [] + + icon, icon_dark = runtime.resolve_provider_icons(provider_name="unknown", default_icon="fallback") + + assert icon == "fallback" + assert icon_dark is None + + +@pytest.mark.parametrize( + ("exc", "message"), + [ + (PluginDaemonClientSideError("bad request"), "Failed to invoke tool, error: bad request"), + (ToolInvokeError("broken"), "Failed to invoke tool provider: broken"), + (RuntimeError("unexpected"), "unexpected"), + ], +) +def test_map_invocation_exception_normalizes_runtime_errors( + runtime: DifyToolNodeRuntime, + exc: Exception, + message: str, +) -> None: + error = runtime._map_invocation_exception(exc, provider_name="provider") + + assert isinstance(error, ToolRuntimeInvocationError) + assert str(error) == message diff --git a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py index 9aeab0409e4..952e798430f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py @@ -1,13 +1,14 @@ from collections.abc import Mapping +from graphon.entities import GraphInitParams +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState + from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from tests.workflow_test_utils import build_test_graph_init_params +from core.workflow.system_variables import build_system_variables +from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, GraphRuntimeState]: @@ -17,9 +18,10 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, invoke_from="debugger", ) runtime_state = GraphRuntimeState( - variable_pool=VariablePool( - system_variables=SystemVariable(user_id="user", files=[]), - user_inputs={"payload": "value"}, + variable_pool=build_test_variable_pool( + variables=build_system_variables(user_id="user", files=[]), + node_id="node-1", + inputs={"payload": "value"}, ), start_at=0.0, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py deleted file mode 100644 index e69c05dc0bb..00000000000 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py +++ /dev/null @@ -1,308 +0,0 @@ -import time -import uuid -from uuid import uuid4 - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.graph import Graph -from dify_graph.graph_events.node import NodeRunSucceededEvent -from dify_graph.nodes.variable_assigner.common import helpers as common_helpers -from dify_graph.nodes.variable_assigner.v1 import VariableAssignerNode -from dify_graph.nodes.variable_assigner.v1.node_data import WriteMode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables import ArrayStringVariable, StringVariable - -DEFAULT_NODE_ID = "node_id" - - -def test_overwrite_string_variable(): - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": { - "type": "assigner", - "title": "Variable Assigner", - "assigned_variable_selector": ["conversation", "test_conversation_variable"], - "write_mode": "over-write", - "input_variable_selector": ["node_id", "test_string_variable"], - }, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = StringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value="the first value", - ) - - input_variable = StringVariable( - id=str(uuid4()), - name="test_string_variable", - value="the second value", - ) - conversation_id = str(uuid.uuid4()) - - # construct variable pool - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id=conversation_id), - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) - - variable_pool.add( - [DEFAULT_NODE_ID, input_variable.name], - input_variable, - ) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": "node_id", - "data": { - "title": "test", - "assigned_variable_selector": ["conversation", conversation_variable.name], - "write_mode": WriteMode.OVER_WRITE, - "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - events = list(node.run()) - succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent)) - updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) - assert updated_variables is not None - assert updated_variables[0].name == conversation_variable.name - assert updated_variables[0].new_value == input_variable.value - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.value == "the second value" - assert got.to_object() == "the second value" - - -def test_append_variable_to_array(): - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": { - "type": "assigner", - "title": "Variable Assigner", - "assigned_variable_selector": ["conversation", "test_conversation_variable"], - "write_mode": "append", - "input_variable_selector": ["node_id", "test_string_variable"], - }, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = ArrayStringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value=["the first value"], - ) - - input_variable = StringVariable( - id=str(uuid4()), - name="test_string_variable", - value="the second value", - ) - conversation_id = str(uuid.uuid4()) - - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id=conversation_id), - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) - variable_pool.add( - [DEFAULT_NODE_ID, input_variable.name], - input_variable, - ) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": "node_id", - "data": { - "title": "test", - "assigned_variable_selector": ["conversation", conversation_variable.name], - "write_mode": WriteMode.APPEND, - "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - events = list(node.run()) - succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent)) - updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) - assert updated_variables is not None - assert updated_variables[0].name == conversation_variable.name - assert updated_variables[0].new_value == ["the first value", "the second value"] - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == ["the first value", "the second value"] - - -def test_clear_array(): - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": { - "type": "assigner", - "title": "Variable Assigner", - "assigned_variable_selector": ["conversation", "test_conversation_variable"], - "write_mode": "clear", - "input_variable_selector": [], - }, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = ArrayStringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value=["the first value"], - ) - - conversation_id = str(uuid.uuid4()) - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id=conversation_id), - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": "node_id", - "data": { - "title": "test", - "assigned_variable_selector": ["conversation", conversation_variable.name], - "write_mode": WriteMode.CLEAR, - "input_variable_selector": [], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - events = list(node.run()) - succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent)) - updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) - assert updated_variables is not None - assert updated_variables[0].name == conversation_variable.name - assert updated_variables[0].new_value == [] - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == [] diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py deleted file mode 100644 index a7673c5a148..00000000000 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py +++ /dev/null @@ -1,22 +0,0 @@ -from dify_graph.nodes.variable_assigner.v2.enums import Operation -from dify_graph.nodes.variable_assigner.v2.helpers import is_input_value_valid -from dify_graph.variables import SegmentType - - -def test_is_input_value_valid_overwrite_array_string(): - # Valid cases - assert is_input_value_valid( - variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=["hello", "world"] - ) - assert is_input_value_valid(variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=[]) - - # Invalid cases - assert not is_input_value_valid( - variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value="not an array" - ) - assert not is_input_value_valid( - variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=[1, 2, 3] - ) - assert not is_input_value_valid( - variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=["valid", 123, "invalid"] - ) diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py deleted file mode 100644 index 6874f3fef13..00000000000 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py +++ /dev/null @@ -1,451 +0,0 @@ -import time -import uuid -from uuid import uuid4 - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.graph import Graph -from dify_graph.nodes.variable_assigner.v2 import VariableAssignerNode -from dify_graph.nodes.variable_assigner.v2.enums import InputType, Operation -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables import ArrayStringVariable - -DEFAULT_NODE_ID = "node_id" - - -def test_handle_item_directly(): - """Test the _handle_item method directly for remove operations.""" - # Create variables - variable1 = ArrayStringVariable( - id=str(uuid4()), - name="test_variable1", - value=["first", "second", "third"], - ) - - variable2 = ArrayStringVariable( - id=str(uuid4()), - name="test_variable2", - value=["first", "second", "third"], - ) - - # Create a mock class with just the _handle_item method - class MockNode: - def _handle_item(self, *, variable, operation, value): - match operation: - case Operation.REMOVE_FIRST: - if not variable.value: - return variable.value - return variable.value[1:] - case Operation.REMOVE_LAST: - if not variable.value: - return variable.value - return variable.value[:-1] - - node = MockNode() - - # Test remove-first - result1 = node._handle_item( - variable=variable1, - operation=Operation.REMOVE_FIRST, - value=None, - ) - - # Test remove-last - result2 = node._handle_item( - variable=variable2, - operation=Operation.REMOVE_LAST, - value=None, - ) - - # Check the results - assert result1 == ["second", "third"] - assert result2 == ["first", "second"] - - -def test_remove_first_from_array(): - """Test removing the first element from an array.""" - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []}, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = ArrayStringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value=["first", "second", "third"], - selector=["conversation", "test_conversation_variable"], - ) - - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id="conversation_id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": "node_id", - "data": { - "title": "test", - "version": "2", - "items": [ - { - "variable_selector": ["conversation", conversation_variable.name], - "input_type": InputType.VARIABLE, - "operation": Operation.REMOVE_FIRST, - "value": None, - } - ], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - # Run the node - result = list(node.run()) - - # Completed run - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == ["second", "third"] - - -def test_remove_last_from_array(): - """Test removing the last element from an array.""" - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []}, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = ArrayStringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value=["first", "second", "third"], - selector=["conversation", "test_conversation_variable"], - ) - - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id="conversation_id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": "node_id", - "data": { - "title": "test", - "version": "2", - "items": [ - { - "variable_selector": ["conversation", conversation_variable.name], - "input_type": InputType.VARIABLE, - "operation": Operation.REMOVE_LAST, - "value": None, - } - ], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - list(node.run()) - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == ["first", "second"] - - -def test_remove_first_from_empty_array(): - """Test removing the first element from an empty array (should do nothing).""" - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []}, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = ArrayStringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value=[], - selector=["conversation", "test_conversation_variable"], - ) - - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id="conversation_id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": "node_id", - "data": { - "title": "test", - "version": "2", - "items": [ - { - "variable_selector": ["conversation", conversation_variable.name], - "input_type": InputType.VARIABLE, - "operation": Operation.REMOVE_FIRST, - "value": None, - } - ], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - list(node.run()) - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == [] - - -def test_remove_last_from_empty_array(): - """Test removing the last element from an empty array (should do nothing).""" - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []}, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = ArrayStringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value=[], - selector=["conversation", "test_conversation_variable"], - ) - - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id="conversation_id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": "node_id", - "data": { - "title": "test", - "version": "2", - "items": [ - { - "variable_selector": ["conversation", conversation_variable.name], - "input_type": InputType.VARIABLE, - "operation": Operation.REMOVE_LAST, - "value": None, - } - ], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - list(node.run()) - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == [] - - -def test_node_factory_creates_variable_assigner_node(): - graph_config = { - "edges": [], - "nodes": [ - { - "data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []}, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id="conversation_id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[], - ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - - node = node_factory.create_node(graph_config["nodes"][0]) - - assert isinstance(node, VariableAssignerNode) diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py index 6be5bb23e86..be18391b2c2 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py @@ -324,7 +324,7 @@ def test_webhook_body_parameter_edge_cases(): def test_webhook_data_inheritance(): """Test WebhookData inherits from BaseNodeData correctly.""" - from dify_graph.entities.base_node_data import BaseNodeData + from graphon.entities.base_node_data import BaseNodeData # Test that WebhookData is a subclass of BaseNodeData assert issubclass(WebhookData, BaseNodeData) diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py index ddf1af5a59f..f1132af02b5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py @@ -1,4 +1,5 @@ import pytest +from graphon.entities.exc import BaseNodeError from core.workflow.nodes.trigger_webhook.exc import ( WebhookConfigError, @@ -6,7 +7,6 @@ from core.workflow.nodes.trigger_webhook.exc import ( WebhookNotFoundError, WebhookTimeoutError, ) -from dify_graph.entities.exc import BaseNodeError def test_webhook_node_error_inheritance(): diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py index 78dd7ce0f3e..cccd3fb6767 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py @@ -8,7 +8,11 @@ when passing files to downstream LLM nodes. from unittest.mock import Mock, patch -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from graphon.entities import GraphInitParams +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.workflow.nodes.trigger_webhook.entities import ( ContentType, Method, @@ -16,11 +20,8 @@ from core.workflow.nodes.trigger_webhook.entities import ( WebhookData, ) from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.runtime.graph_runtime_state import GraphRuntimeState -from dify_graph.runtime.variable_pool import VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import default_system_variables +from tests.workflow_test_utils import build_test_variable_pool def create_webhook_node( @@ -96,6 +97,18 @@ def create_test_file_dict( } +def build_webhook_variable_pool(inputs: dict) -> VariablePool: + return build_test_variable_pool( + variables=default_system_variables(), + node_id="webhook-node-1", + inputs=inputs, + ) + + +def expected_factory_mapping(file_dict: dict) -> dict: + return {**file_dict, "upload_file_id": file_dict["related_id"]} + + def test_webhook_node_file_conversion_to_file_variable(): """Test that webhook node converts file dictionaries to FileVariable objects.""" # Create test file dictionary (as it comes from webhook service) @@ -111,9 +124,8 @@ def test_webhook_node_file_conversion_to_file_variable(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -122,14 +134,14 @@ def test_webhook_node_file_conversion_to_file_variable(): "image_upload": file_dict, }, } - }, + } ) node = create_webhook_node(data, variable_pool) - # Mock the file factory and variable factory + # Mock the file reference boundary and variable factory with ( - patch("factories.file_factory.build_from_mapping") as mock_file_factory, + patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory, patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, ): @@ -153,8 +165,7 @@ def test_webhook_node_file_conversion_to_file_variable(): # Verify file factory was called with correct parameters mock_file_factory.assert_called_once_with( - mapping=file_dict, - tenant_id="test-tenant", + mapping=expected_factory_mapping(file_dict), ) # Verify segment factory was called to create FileSegment @@ -184,16 +195,15 @@ def test_webhook_node_file_conversion_with_missing_files(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, "body": {}, "files": {}, # No files } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -219,9 +229,8 @@ def test_webhook_node_file_conversion_with_none_file(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -230,7 +239,7 @@ def test_webhook_node_file_conversion_with_none_file(): "file": None, }, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -256,9 +265,8 @@ def test_webhook_node_file_conversion_with_non_dict_file(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -267,7 +275,7 @@ def test_webhook_node_file_conversion_with_non_dict_file(): "file": "not_a_dict", # Wrapped to match node expectation }, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -300,9 +308,8 @@ def test_webhook_node_file_conversion_mixed_parameters(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -315,13 +322,13 @@ def test_webhook_node_file_conversion_mixed_parameters(): "file_param": file_dict, }, } - }, + } ) node = create_webhook_node(data, variable_pool) with ( - patch("factories.file_factory.build_from_mapping") as mock_file_factory, + patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory, patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, ): @@ -350,8 +357,7 @@ def test_webhook_node_file_conversion_mixed_parameters(): # Verify file conversion was called mock_file_factory.assert_called_once_with( - mapping=file_dict, - tenant_id="test-tenant", + mapping=expected_factory_mapping(file_dict), ) @@ -370,9 +376,8 @@ def test_webhook_node_different_file_types(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -383,13 +388,13 @@ def test_webhook_node_different_file_types(): "video": create_test_file_dict("video.mp4", "video"), }, } - }, + } ) node = create_webhook_node(data, variable_pool) with ( - patch("factories.file_factory.build_from_mapping") as mock_file_factory, + patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory, patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, ): @@ -430,9 +435,8 @@ def test_webhook_node_file_conversion_with_non_dict_wrapper(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -441,7 +445,7 @@ def test_webhook_node_file_conversion_with_non_dict_wrapper(): "file": "just a string", }, } - }, + } ) node = create_webhook_node(data, variable_pool) diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py index 139f65d6c3f..34c66a4f9f3 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py @@ -1,8 +1,13 @@ from unittest.mock import patch import pytest +from graphon.entities import GraphInitParams +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables import FileVariable, StringVariable -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE from core.workflow.nodes.trigger_webhook.entities import ( ContentType, @@ -12,13 +17,8 @@ from core.workflow.nodes.trigger_webhook.entities import ( WebhookParameter, ) from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.runtime.graph_runtime_state import GraphRuntimeState -from dify_graph.runtime.variable_pool import VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables import FileVariable, StringVariable +from core.workflow.system_variables import default_system_variables +from tests.workflow_test_utils import build_test_variable_pool def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) -> TriggerWebhookNode: @@ -62,6 +62,14 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) return node +def build_webhook_variable_pool(inputs: dict) -> VariablePool: + return build_test_variable_pool( + variables=default_system_variables(), + node_id="1", + inputs=inputs, + ) + + def test_webhook_node_basic_initialization(): """Test basic webhook node initialization and configuration.""" data = WebhookData( @@ -76,10 +84,7 @@ def test_webhook_node_basic_initialization(): timeout=30, ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={}, - ) + variable_pool = build_webhook_variable_pool({}) node = create_webhook_node(data, variable_pool) @@ -119,9 +124,8 @@ def test_webhook_node_run_with_headers(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": { "Authorization": "Bearer token123", @@ -132,7 +136,7 @@ def test_webhook_node_run_with_headers(): "body": {}, "files": {}, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -155,9 +159,8 @@ def test_webhook_node_run_with_query_params(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": { @@ -167,7 +170,7 @@ def test_webhook_node_run_with_query_params(): "body": {}, "files": {}, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -191,9 +194,8 @@ def test_webhook_node_run_with_body_params(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -205,7 +207,7 @@ def test_webhook_node_run_with_body_params(): }, "files": {}, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -222,7 +224,6 @@ def test_webhook_node_run_with_file_params(): """Test webhook node execution with file parameter extraction.""" # Create mock file objects file1 = File( - tenant_id="1", type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file1", @@ -232,7 +233,6 @@ def test_webhook_node_run_with_file_params(): ) file2 = File( - tenant_id="1", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file2", @@ -250,9 +250,8 @@ def test_webhook_node_run_with_file_params(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -262,14 +261,14 @@ def test_webhook_node_run_with_file_params(): "document": file2.to_dict(), }, } - }, + } ) node = create_webhook_node(data, variable_pool) - # Mock the file factory to avoid DB-dependent validation on upload_file_id - with patch("factories.file_factory.build_from_mapping") as mock_file_factory: + # Mock the node's file reference boundary to avoid DB-dependent validation on upload_file_id + with patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory: - def _to_file(mapping, tenant_id, config=None, strict_type_validation=False): + def _to_file(*, mapping): return File.model_validate(mapping) mock_file_factory.side_effect = _to_file @@ -284,7 +283,6 @@ def test_webhook_node_run_with_file_params(): def test_webhook_node_run_mixed_parameters(): """Test webhook node execution with mixed parameter types.""" file_obj = File( - tenant_id="1", type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file1", @@ -303,23 +301,22 @@ def test_webhook_node_run_mixed_parameters(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {"Authorization": "Bearer token"}, "query_params": {"version": "v1"}, "body": {"message": "Test message"}, "files": {"upload": file_obj.to_dict()}, } - }, + } ) node = create_webhook_node(data, variable_pool) - # Mock the file factory to avoid DB-dependent validation on upload_file_id - with patch("factories.file_factory.build_from_mapping") as mock_file_factory: + # Mock the node's file reference boundary to avoid DB-dependent validation on upload_file_id + with patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory: - def _to_file(mapping, tenant_id, config=None, strict_type_validation=False): + def _to_file(*, mapping): return File.model_validate(mapping) mock_file_factory.side_effect = _to_file @@ -343,10 +340,7 @@ def test_webhook_node_run_empty_webhook_data(): body=[WebhookBodyParameter(name="message", type="string", required=False)], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={}, # No webhook_data - ) + variable_pool = build_webhook_variable_pool({}) # No webhook_data node = create_webhook_node(data, variable_pool) result = node._run() @@ -369,9 +363,8 @@ def test_webhook_node_run_case_insensitive_headers(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": { "content-type": "application/json", # lowercase @@ -382,7 +375,7 @@ def test_webhook_node_run_case_insensitive_headers(): "body": {}, "files": {}, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -399,12 +392,11 @@ def test_webhook_node_variable_pool_user_inputs(): data = WebhookData(title="Test Webhook") # Add some additional variables to the pool - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": {"headers": {}, "query_params": {}, "body": {}, "files": {}}, "other_var": "should_be_included", - }, + } ) variable_pool.add(["node1", "extra"], StringVariable(name="extra", value="extra_value")) @@ -430,16 +422,15 @@ def test_webhook_node_different_methods(method): method=method, ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, "body": {}, "files": {}, } - }, + } ) node = create_webhook_node(data, variable_pool) diff --git a/api/tests/unit_tests/core/workflow/test_enums.py b/api/tests/unit_tests/core/workflow/test_enums.py deleted file mode 100644 index e8ce6f60f7d..00000000000 --- a/api/tests/unit_tests/core/workflow/test_enums.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Tests for workflow pause related enums and constants.""" - -from dify_graph.enums import ( - WorkflowExecutionStatus, -) - - -class TestWorkflowExecutionStatus: - """Test WorkflowExecutionStatus enum.""" - - def test_is_ended_method(self): - """Test is_ended method for different statuses.""" - # Test ended statuses - ended_statuses = [ - WorkflowExecutionStatus.SUCCEEDED, - WorkflowExecutionStatus.FAILED, - WorkflowExecutionStatus.PARTIAL_SUCCEEDED, - WorkflowExecutionStatus.STOPPED, - ] - - for status in ended_statuses: - assert status.is_ended(), f"{status} should be considered ended" - - # Test non-ended statuses - non_ended_statuses = [ - WorkflowExecutionStatus.SCHEDULED, - WorkflowExecutionStatus.RUNNING, - WorkflowExecutionStatus.PAUSED, - ] - - for status in non_ended_statuses: - assert not status.is_ended(), f"{status} should not be considered ended" - - def test_ended_values(self): - """Test ended_values returns the expected status values.""" - assert set(WorkflowExecutionStatus.ended_values()) == { - WorkflowExecutionStatus.SUCCEEDED.value, - WorkflowExecutionStatus.FAILED.value, - WorkflowExecutionStatus.PARTIAL_SUCCEEDED.value, - WorkflowExecutionStatus.STOPPED.value, - } diff --git a/api/tests/unit_tests/core/workflow/test_human_input_compat.py b/api/tests/unit_tests/core/workflow/test_human_input_compat.py new file mode 100644 index 00000000000..cd41c43e4ad --- /dev/null +++ b/api/tests/unit_tests/core/workflow/test_human_input_compat.py @@ -0,0 +1,184 @@ +from types import SimpleNamespace + +from graphon.enums import BuiltinNodeTypes +from pydantic import BaseModel + +from core.workflow.human_input_compat import ( + DeliveryMethodType, + EmailDeliveryConfig, + EmailDeliveryMethod, + EmailRecipients, + WebAppDeliveryMethod, + _WebAppDeliveryConfig, + is_human_input_webapp_enabled, + normalize_human_input_node_data_for_graph, + normalize_node_config_for_graph, + normalize_node_data_for_graph, + parse_human_input_delivery_methods, +) + + +def test_email_delivery_config_helpers_render_and_sanitize_text() -> None: + variable_pool = SimpleNamespace( + convert_template=lambda body: SimpleNamespace(text=body.replace("{{#node.value#}}", "42")) + ) + + rendered = EmailDeliveryConfig.render_body_template( + body="Open {{#url#}} and use {{#node.value#}}", + url="https://example.com", + variable_pool=variable_pool, + ) + sanitized = EmailDeliveryConfig.sanitize_subject("Hello\r\n Team") + html = EmailDeliveryConfig.render_markdown_body( + "**Hello** [mail](mailto:test@example.com)" + ) + + assert rendered == "Open https://example.com and use 42" + assert sanitized == "Hello alert(1) Team" + assert "Hello" in html + assert "", - "'; DROP TABLE users; --", - "../../../etc/passwd", - "\\x00\\x00", # null bytes - "A" * 10000, # very long input - ], - ) - def test_validate_api_key_auth_args_malicious_input(self, malicious_input): - """Test API key auth args validation - malicious input""" - args = self.mock_args.copy() - args["category"] = malicious_input - - # Verify parameter validator doesn't crash on malicious input - # Should validate normally rather than raising security-related exceptions - ApiKeyAuthService.validate_api_key_auth_args(args) - - @patch("services.auth.api_key_auth_service.db.session") - @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") - @patch("services.auth.api_key_auth_service.encrypter") - def test_create_provider_auth_database_error_handling(self, mock_encrypter, mock_factory, mock_session): - """Test create provider auth - database error handling""" - # Mock successful auth validation - mock_auth_instance = Mock() - mock_auth_instance.validate_credentials.return_value = True - mock_factory.return_value = mock_auth_instance - - # Mock encryption - mock_encrypter.encrypt_token.return_value = "encrypted_key" - - # Mock database error - mock_session.commit.side_effect = Exception("Database error") - - with pytest.raises(Exception, match="Database error"): - ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args) - - @patch("services.auth.api_key_auth_service.db.session") - def test_get_auth_credentials_invalid_json(self, mock_session): - """Test get auth credentials - invalid JSON""" - # Mock database returning invalid JSON - mock_binding = Mock() - mock_binding.credentials = "invalid json content" - mock_session.query.return_value.where.return_value.first.return_value = mock_binding - - with pytest.raises(json.JSONDecodeError): - ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider) - - @patch("services.auth.api_key_auth_service.db.session") - @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") - def test_create_provider_auth_factory_exception(self, mock_factory, mock_session): - """Test create provider auth - factory exception""" - # Mock factory raising exception - mock_factory.side_effect = Exception("Factory error") - - with pytest.raises(Exception, match="Factory error"): - ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args) - - @patch("services.auth.api_key_auth_service.db.session") - @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") - @patch("services.auth.api_key_auth_service.encrypter") - def test_create_provider_auth_encryption_exception(self, mock_encrypter, mock_factory, mock_session): - """Test create provider auth - encryption exception""" - # Mock successful auth validation - mock_auth_instance = Mock() - mock_auth_instance.validate_credentials.return_value = True - mock_factory.return_value = mock_auth_instance - - # Mock encryption exception - mock_encrypter.encrypt_token.side_effect = Exception("Encryption error") - - with pytest.raises(Exception, match="Encryption error"): - ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args) - - def test_validate_api_key_auth_args_none_input(self): - """Test API key auth args validation - None input""" - with pytest.raises(TypeError): - ApiKeyAuthService.validate_api_key_auth_args(None) - - def test_validate_api_key_auth_args_dict_credentials_with_list_auth_type(self): - """Test API key auth args validation - dict credentials with list auth_type""" - args = self.mock_args.copy() - args["credentials"]["auth_type"] = ["api_key"] - - # Current implementation checks if auth_type exists and is truthy, list ["api_key"] is truthy - # So this should not raise exception, this test should pass - ApiKeyAuthService.validate_api_key_auth_args(args) diff --git a/api/tests/unit_tests/services/auth/test_auth_integration.py b/api/tests/unit_tests/services/auth/test_auth_integration.py deleted file mode 100644 index 3832a0b8b20..00000000000 --- a/api/tests/unit_tests/services/auth/test_auth_integration.py +++ /dev/null @@ -1,231 +0,0 @@ -""" -API Key Authentication System Integration Tests -""" - -import json -from concurrent.futures import ThreadPoolExecutor -from unittest.mock import Mock, patch - -import httpx -import pytest - -from services.auth.api_key_auth_factory import ApiKeyAuthFactory -from services.auth.api_key_auth_service import ApiKeyAuthService -from services.auth.auth_type import AuthType - - -class TestAuthIntegration: - def setup_method(self): - self.tenant_id_1 = "tenant_123" - self.tenant_id_2 = "tenant_456" # For multi-tenant isolation testing - self.category = "search" - - # Realistic authentication configurations - self.firecrawl_credentials = {"auth_type": "bearer", "config": {"api_key": "fc_test_key_123"}} - self.jina_credentials = {"auth_type": "bearer", "config": {"api_key": "jina_test_key_456"}} - self.watercrawl_credentials = {"auth_type": "x-api-key", "config": {"api_key": "wc_test_key_789"}} - - @patch("services.auth.api_key_auth_service.db.session") - @patch("services.auth.firecrawl.firecrawl.httpx.post") - @patch("services.auth.api_key_auth_service.encrypter.encrypt_token") - def test_end_to_end_auth_flow(self, mock_encrypt, mock_http, mock_session): - """Test complete authentication flow: request → validation → encryption → storage""" - mock_http.return_value = self._create_success_response() - mock_encrypt.return_value = "encrypted_fc_test_key_123" - mock_session.add = Mock() - mock_session.commit = Mock() - - args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials} - ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args) - - mock_http.assert_called_once() - call_args = mock_http.call_args - assert "https://api.firecrawl.dev/v1/crawl" in call_args[0][0] - assert call_args[1]["headers"]["Authorization"] == "Bearer fc_test_key_123" - - mock_encrypt.assert_called_once_with(self.tenant_id_1, "fc_test_key_123") - mock_session.add.assert_called_once() - mock_session.commit.assert_called_once() - - @patch("services.auth.firecrawl.firecrawl.httpx.post") - def test_cross_component_integration(self, mock_http): - """Test factory → provider → HTTP call integration""" - mock_http.return_value = self._create_success_response() - factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, self.firecrawl_credentials) - result = factory.validate_credentials() - - assert result is True - mock_http.assert_called_once() - - @patch("services.auth.api_key_auth_service.db.session") - def test_multi_tenant_isolation(self, mock_session): - """Ensure complete tenant data isolation""" - tenant1_binding = self._create_mock_binding(self.tenant_id_1, AuthType.FIRECRAWL, self.firecrawl_credentials) - tenant2_binding = self._create_mock_binding(self.tenant_id_2, AuthType.JINA, self.jina_credentials) - - mock_session.scalars.return_value.all.return_value = [tenant1_binding] - result1 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_1) - - mock_session.scalars.return_value.all.return_value = [tenant2_binding] - result2 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_2) - - assert len(result1) == 1 - assert result1[0].tenant_id == self.tenant_id_1 - assert len(result2) == 1 - assert result2[0].tenant_id == self.tenant_id_2 - - @patch("services.auth.api_key_auth_service.db.session") - def test_cross_tenant_access_prevention(self, mock_session): - """Test prevention of cross-tenant credential access""" - mock_session.query.return_value.where.return_value.first.return_value = None - - result = ApiKeyAuthService.get_auth_credentials(self.tenant_id_2, self.category, AuthType.FIRECRAWL) - - assert result is None - - def test_sensitive_data_protection(self): - """Ensure API keys don't leak to logs""" - credentials_with_secrets = { - "auth_type": "bearer", - "config": {"api_key": "super_secret_key_do_not_log", "secret": "another_secret"}, - } - - factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, credentials_with_secrets) - factory_str = str(factory) - - assert "super_secret_key_do_not_log" not in factory_str - assert "another_secret" not in factory_str - - @patch("services.auth.api_key_auth_service.db.session") - @patch("services.auth.firecrawl.firecrawl.httpx.post") - @patch("services.auth.api_key_auth_service.encrypter.encrypt_token") - def test_concurrent_creation_safety(self, mock_encrypt, mock_http, mock_session): - """Test concurrent authentication creation safety""" - mock_http.return_value = self._create_success_response() - mock_encrypt.return_value = "encrypted_key" - mock_session.add = Mock() - mock_session.commit = Mock() - - args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials} - - results = [] - exceptions = [] - - def create_auth(): - try: - ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args) - results.append("success") - except Exception as e: - exceptions.append(e) - - with ThreadPoolExecutor(max_workers=5) as executor: - futures = [executor.submit(create_auth) for _ in range(5)] - for future in futures: - future.result() - - assert len(results) == 5 - assert len(exceptions) == 0 - assert mock_session.add.call_count == 5 - assert mock_session.commit.call_count == 5 - - @pytest.mark.parametrize( - "invalid_input", - [ - None, # Null input - {}, # Empty dictionary - missing required fields - {"auth_type": "bearer"}, # Missing config section - {"auth_type": "bearer", "config": {}}, # Missing api_key - ], - ) - def test_invalid_input_boundary(self, invalid_input): - """Test boundary handling for invalid inputs""" - with pytest.raises((ValueError, KeyError, TypeError, AttributeError)): - ApiKeyAuthFactory(AuthType.FIRECRAWL, invalid_input) - - @patch("services.auth.firecrawl.firecrawl.httpx.post") - def test_http_error_handling(self, mock_http): - """Test proper HTTP error handling""" - mock_response = Mock() - mock_response.status_code = 401 - mock_response.text = '{"error": "Unauthorized"}' - mock_response.raise_for_status.side_effect = httpx.HTTPError("Unauthorized") - mock_http.return_value = mock_response - - # PT012: Split into single statement for pytest.raises - factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, self.firecrawl_credentials) - with pytest.raises((httpx.HTTPError, Exception)): - factory.validate_credentials() - - @patch("services.auth.api_key_auth_service.db.session") - @patch("services.auth.firecrawl.firecrawl.httpx.post") - def test_network_failure_recovery(self, mock_http, mock_session): - """Test system recovery from network failures""" - mock_http.side_effect = httpx.RequestError("Network timeout") - mock_session.add = Mock() - mock_session.commit = Mock() - - args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials} - - with pytest.raises(httpx.RequestError): - ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args) - - mock_session.commit.assert_not_called() - - @pytest.mark.parametrize( - ("provider", "credentials"), - [ - (AuthType.FIRECRAWL, {"auth_type": "bearer", "config": {"api_key": "fc_key"}}), - (AuthType.JINA, {"auth_type": "bearer", "config": {"api_key": "jina_key"}}), - (AuthType.WATERCRAWL, {"auth_type": "x-api-key", "config": {"api_key": "wc_key"}}), - ], - ) - def test_all_providers_factory_creation(self, provider, credentials): - """Test factory creation for all supported providers""" - auth_class = ApiKeyAuthFactory.get_apikey_auth_factory(provider) - assert auth_class is not None - - factory = ApiKeyAuthFactory(provider, credentials) - assert factory.auth is not None - - def _create_success_response(self, status_code=200): - """Create successful HTTP response mock""" - mock_response = Mock() - mock_response.status_code = status_code - mock_response.json.return_value = {"status": "success"} - mock_response.raise_for_status.return_value = None - return mock_response - - def _create_mock_binding(self, tenant_id: str, provider: str, credentials: dict) -> Mock: - """Create realistic database binding mock""" - mock_binding = Mock() - mock_binding.id = f"binding_{provider}_{tenant_id}" - mock_binding.tenant_id = tenant_id - mock_binding.category = self.category - mock_binding.provider = provider - mock_binding.credentials = json.dumps(credentials, ensure_ascii=False) - mock_binding.disabled = False - - mock_binding.created_at = Mock() - mock_binding.created_at.timestamp.return_value = 1640995200 - mock_binding.updated_at = Mock() - mock_binding.updated_at.timestamp.return_value = 1640995200 - - return mock_binding - - def test_integration_coverage_validation(self): - """Validate integration test coverage meets quality standards""" - core_scenarios = { - "business_logic": ["end_to_end_auth_flow", "cross_component_integration"], - "security": ["multi_tenant_isolation", "cross_tenant_access_prevention", "sensitive_data_protection"], - "reliability": ["concurrent_creation_safety", "network_failure_recovery"], - "compatibility": ["all_providers_factory_creation"], - "boundaries": ["invalid_input_boundary", "http_error_handling"], - } - - total_scenarios = sum(len(scenarios) for scenarios in core_scenarios.values()) - assert total_scenarios >= 10 - - security_tests = core_scenarios["security"] - assert "multi_tenant_isolation" in security_tests - assert "sensitive_data_protection" in security_tests - assert True diff --git a/api/tests/unit_tests/services/dataset_service_test_helpers.py b/api/tests/unit_tests/services/dataset_service_test_helpers.py new file mode 100644 index 00000000000..ef73bc0e01b --- /dev/null +++ b/api/tests/unit_tests/services/dataset_service_test_helpers.py @@ -0,0 +1,455 @@ +"""Shared helpers for dataset_service unit tests. + +These factories and lightweight builders are reused across the dataset, +document, and segment service test modules that exercise +``api/services/dataset_service.py``. +""" + +import json +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, create_autospec, patch + +import pytest +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType +from werkzeug.exceptions import Forbidden, NotFound + +from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError +from core.rag.index_processor.constant.built_in_field import BuiltInField +from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.retrieval.retrieval_methods import RetrievalMethod +from enums.cloud_plan import CloudPlan +from models import Account, TenantAccountRole +from models.dataset import ( + ChildChunk, + Dataset, + DatasetPermissionEnum, + DatasetProcessRule, + Document, + DocumentSegment, +) +from models.model import UploadFile +from services.dataset_service import ( + DatasetCollectionBindingService, + DatasetPermissionService, + DatasetService, + DocumentService, + SegmentService, +) +from services.entities.knowledge_entities.knowledge_entities import ( + ChildChunkUpdateArgs, + DataSource, + FileInfo, + InfoList, + KnowledgeConfig, + NotionIcon, + NotionInfo, + NotionPage, + PreProcessingRule, + ProcessRule, + RerankingModel, + RetrievalModel, + Rule, + Segmentation, + SegmentUpdateArgs, + WebsiteInfo, +) +from services.entities.knowledge_entities.rag_pipeline_entities import ( + IconInfo as PipelineIconInfo, +) +from services.entities.knowledge_entities.rag_pipeline_entities import ( + KnowledgeConfiguration, + RagPipelineDatasetCreateEntity, +) +from services.entities.knowledge_entities.rag_pipeline_entities import ( + RerankingModelConfig as RagPipelineRerankingModelConfig, +) +from services.entities.knowledge_entities.rag_pipeline_entities import ( + RetrievalSetting as RagPipelineRetrievalSetting, +) +from services.errors.account import NoPermissionError +from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError +from services.errors.dataset import DatasetNameDuplicateError +from services.errors.document import DocumentIndexingError +from services.errors.file import FileNotExistsError + +__all__ = [ + "Account", + "BuiltInField", + "ChildChunk", + "ChildChunkDeleteIndexError", + "ChildChunkIndexingError", + "ChildChunkUpdateArgs", + "CloudPlan", + "DataSource", + "Dataset", + "DatasetCollectionBindingService", + "DatasetNameDuplicateError", + "DatasetPermissionEnum", + "DatasetPermissionService", + "DatasetProcessRule", + "DatasetService", + "DatasetServiceUnitDataFactory", + "Document", + "DocumentIndexingError", + "DocumentSegment", + "DocumentService", + "FileInfo", + "FileNotExistsError", + "Forbidden", + "IndexStructureType", + "InfoList", + "KnowledgeConfig", + "KnowledgeConfiguration", + "LLMBadRequestError", + "MagicMock", + "Mock", + "ModelFeature", + "ModelType", + "NoPermissionError", + "NotFound", + "NotionIcon", + "NotionInfo", + "NotionPage", + "PipelineIconInfo", + "PreProcessingRule", + "ProcessRule", + "ProviderTokenNotInitError", + "RagPipelineDatasetCreateEntity", + "RagPipelineRerankingModelConfig", + "RagPipelineRetrievalSetting", + "RerankingModel", + "RetrievalMethod", + "RetrievalModel", + "Rule", + "SegmentService", + "SegmentUpdateArgs", + "Segmentation", + "SimpleNamespace", + "TenantAccountRole", + "WebsiteInfo", + "_make_child_chunk", + "_make_dataset", + "_make_document", + "_make_features", + "_make_knowledge_configuration", + "_make_lock_context", + "_make_retrieval_model", + "_make_segment", + "_make_session_context", + "_make_upload_knowledge_config", + "create_autospec", + "json", + "patch", + "pytest", +] + + +def _make_session_context(session: MagicMock) -> MagicMock: + """Wrap a mocked session in a context manager.""" + context_manager = MagicMock() + context_manager.__enter__.return_value = session + context_manager.__exit__.return_value = False + return context_manager + + +class DatasetServiceUnitDataFactory: + """Factory for lightweight doubles used across dataset service tests.""" + + @staticmethod + def create_dataset_mock( + dataset_id: str = "dataset-123", + tenant_id: str = "tenant-123", + *, + permission: str = DatasetPermissionEnum.ALL_TEAM, + created_by: str = "user-123", + indexing_technique: str = "economy", + embedding_model_provider: str = "provider", + embedding_model: str = "model", + built_in_field_enabled: bool = False, + doc_form: str | None = "text_model", + enable_api: bool = False, + summary_index_setting: dict | None = None, + **kwargs, + ) -> Mock: + dataset = Mock(spec=Dataset) + dataset.id = dataset_id + dataset.tenant_id = tenant_id + dataset.permission = permission + dataset.created_by = created_by + dataset.indexing_technique = indexing_technique + dataset.embedding_model_provider = embedding_model_provider + dataset.embedding_model = embedding_model + dataset.built_in_field_enabled = built_in_field_enabled + dataset.doc_form = doc_form + dataset.enable_api = enable_api + dataset.updated_by = None + dataset.updated_at = None + dataset.summary_index_setting = summary_index_setting + for key, value in kwargs.items(): + setattr(dataset, key, value) + return dataset + + @staticmethod + def create_user_mock( + user_id: str = "user-123", + tenant_id: str = "tenant-123", + role: str = TenantAccountRole.OWNER, + **kwargs, + ) -> SimpleNamespace: + user = SimpleNamespace( + id=user_id, + current_tenant_id=tenant_id, + current_role=role, + ) + for key, value in kwargs.items(): + setattr(user, key, value) + return user + + @staticmethod + def create_document_mock( + document_id: str = "doc-123", + dataset_id: str = "dataset-123", + tenant_id: str = "tenant-123", + *, + indexing_status: str = "completed", + is_paused: bool = False, + archived: bool = False, + enabled: bool = True, + data_source_type: str = "upload_file", + data_source_info_dict: dict | None = None, + data_source_info: str | None = None, + doc_form: str = "text_model", + need_summary: bool = True, + position: int = 0, + doc_metadata: dict | None = None, + name: str = "Document", + **kwargs, + ) -> Mock: + document = Mock(spec=Document) + document.id = document_id + document.dataset_id = dataset_id + document.tenant_id = tenant_id + document.indexing_status = indexing_status + document.is_paused = is_paused + document.paused_by = None + document.paused_at = None + document.archived = archived + document.enabled = enabled + document.data_source_type = data_source_type + document.data_source_info_dict = data_source_info_dict or {} + document.data_source_info = data_source_info + document.doc_form = doc_form + document.need_summary = need_summary + document.position = position + document.doc_metadata = doc_metadata + document.name = name + for key, value in kwargs.items(): + setattr(document, key, value) + return document + + @staticmethod + def create_upload_file_mock(file_id: str = "file-123", name: str = "upload.txt") -> Mock: + upload_file = Mock(spec=UploadFile) + upload_file.id = file_id + upload_file.name = name + return upload_file + + +_UNSET = object() + + +def _make_lock_context() -> MagicMock: + context_manager = MagicMock() + context_manager.__enter__.return_value = None + context_manager.__exit__.return_value = False + return context_manager + + +def _make_features(*, enabled: bool, plan: str = CloudPlan.PROFESSIONAL) -> SimpleNamespace: + return SimpleNamespace( + billing=SimpleNamespace( + enabled=enabled, + subscription=SimpleNamespace(plan=plan), + ), + documents_upload_quota=SimpleNamespace(limit=1000, size=0), + ) + + +def _make_dataset( + *, + dataset_id: str = "dataset-1", + tenant_id: str = "tenant-1", + data_source_type: str | None = None, + indexing_technique: str | None = "economy", + latest_process_rule=None, +) -> Mock: + dataset = Mock(spec=Dataset) + dataset.id = dataset_id + dataset.tenant_id = tenant_id + dataset.data_source_type = data_source_type + dataset.indexing_technique = indexing_technique + dataset.latest_process_rule = latest_process_rule + dataset.embedding_model_provider = "provider" + dataset.embedding_model = "embedding-model" + dataset.summary_index_setting = None + dataset.retrieval_model = None + dataset.collection_binding_id = None + return dataset + + +def _make_document( + *, + document_id: str = "doc-1", + dataset_id: str = "dataset-1", + tenant_id: str = "tenant-1", + batch: str = "batch-1", + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, + word_count: int = 0, + name: str = "Document 1", + enabled: bool = True, + archived: bool = False, + indexing_status: str = "completed", + display_status: str = "available", +) -> Mock: + document = Mock(spec=Document) + document.id = document_id + document.dataset_id = dataset_id + document.tenant_id = tenant_id + document.batch = batch + document.doc_form = doc_form + document.word_count = word_count + document.name = name + document.enabled = enabled + document.archived = archived + document.indexing_status = indexing_status + document.display_status = display_status + document.data_source_type = "upload_file" + document.data_source_info = "{}" + document.completed_at = SimpleNamespace() + document.processing_started_at = "started" + document.parsing_completed_at = "parsed" + document.cleaning_completed_at = "cleaned" + document.splitting_completed_at = "split" + document.updated_at = None + document.created_from = None + document.dataset_process_rule_id = "process-rule-1" + return document + + +def _make_segment( + *, + segment_id: str = "segment-1", + content: str = "segment content", + word_count: int = 15, + enabled: bool = True, + keywords: list[str] | None = None, + index_node_id: str = "node-1", + dataset_id: str = "dataset-1", + document_id: str = "doc-1", +) -> Mock: + segment = Mock(spec=DocumentSegment) + segment.id = segment_id + segment.dataset_id = dataset_id + segment.document_id = document_id + segment.content = content + segment.word_count = word_count + segment.enabled = enabled + segment.keywords = keywords or [] + segment.answer = None + segment.index_node_id = index_node_id + segment.disabled_at = None + segment.disabled_by = None + segment.status = "completed" + segment.error = None + return segment + + +def _make_child_chunk() -> ChildChunk: + return ChildChunk( + id="child-a", + tenant_id="tenant-1", + dataset_id="dataset-1", + document_id="doc-1", + segment_id="segment-1", + position=1, + content="old content", + word_count=11, + created_by="user-1", + ) + + +def _make_upload_knowledge_config( + *, + original_document_id: str | None = None, + file_ids: list[str] | None = None, + process_rule: ProcessRule | None = None, + data_source: DataSource | object | None = _UNSET, +) -> KnowledgeConfig: + if data_source is _UNSET: + info_list = InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=file_ids) if file_ids is not None else None, + ) + data_source = DataSource(info_list=info_list) + + return KnowledgeConfig( + original_document_id=original_document_id, + indexing_technique="economy", + data_source=data_source, + process_rule=process_rule, + doc_form=IndexStructureType.PARAGRAPH_INDEX, + doc_language="English", + ) + + +def _make_retrieval_model( + *, + reranking_provider_name: str = "rerank-provider", + reranking_model_name: str = "rerank-model", +) -> RetrievalModel: + return RetrievalModel( + search_method=RetrievalMethod.SEMANTIC_SEARCH, + reranking_enable=True, + reranking_model=RerankingModel( + reranking_provider_name=reranking_provider_name, + reranking_model_name=reranking_model_name, + ), + reranking_mode="reranking_model", + top_k=4, + score_threshold_enabled=False, + ) + + +def _make_rag_pipeline_retrieval_setting() -> RagPipelineRetrievalSetting: + return RagPipelineRetrievalSetting( + search_method=RetrievalMethod.SEMANTIC_SEARCH, + top_k=4, + score_threshold=0.5, + score_threshold_enabled=True, + reranking_mode="reranking_model", + reranking_enable=True, + reranking_model=RagPipelineRerankingModelConfig( + reranking_provider_name="rerank-provider", + reranking_model_name="rerank-model", + ), + ) + + +def _make_knowledge_configuration( + *, + chunk_structure: str = "paragraph", + indexing_technique: str = "high_quality", + embedding_model_provider: str = "provider", + embedding_model: str = "embedding-model", + keyword_number: int = 8, + summary_index_setting: dict | None = None, +) -> KnowledgeConfiguration: + return KnowledgeConfiguration( + chunk_structure=chunk_structure, + indexing_technique=indexing_technique, + embedding_model_provider=embedding_model_provider, + embedding_model=embedding_model, + keyword_number=keyword_number, + retrieval_model=_make_rag_pipeline_retrieval_setting(), + summary_index_setting=summary_index_setting, + ) diff --git a/api/tests/unit_tests/services/dataset_service_update_delete.py b/api/tests/unit_tests/services/dataset_service_update_delete.py index c805dd98e29..62c39f96d35 100644 --- a/api/tests/unit_tests/services/dataset_service_update_delete.py +++ b/api/tests/unit_tests/services/dataset_service_update_delete.py @@ -97,6 +97,7 @@ from unittest.mock import Mock, create_autospec, patch import pytest from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models import Account, TenantAccountRole from models.dataset import ( AppDatasetJoin, @@ -149,7 +150,7 @@ class DatasetUpdateDeleteTestDataFactory: name: str = "Test Dataset", description: str = "Test description", tenant_id: str = "tenant-123", - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, embedding_model_provider: str | None = "openai", embedding_model: str | None = "text-embedding-ada-002", collection_binding_id: str | None = "binding-123", @@ -237,7 +238,7 @@ class DatasetUpdateDeleteTestDataFactory: @staticmethod def create_knowledge_configuration_mock( chunk_structure: str = "tree", - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, embedding_model_provider: str = "openai", embedding_model: str = "text-embedding-ada-002", keyword_number: int = 10, @@ -591,7 +592,7 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings: patch( "services.dataset_service.current_user", create_autospec(Account, instance=True) ) as mock_current_user, - patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, patch( "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" ) as mock_get_binding, @@ -630,12 +631,12 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings: dataset_id="dataset-123", runtime_mode="rag_pipeline", chunk_structure="tree", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock( chunk_structure="list", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", ) @@ -671,7 +672,7 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings: # Assert assert dataset.chunk_structure == "list" - assert dataset.indexing_technique == "high_quality" + assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert dataset.embedding_model == "text-embedding-ada-002" assert dataset.embedding_model_provider == "openai" assert dataset.collection_binding_id == "binding-123" @@ -698,12 +699,12 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings: dataset_id="dataset-123", runtime_mode="rag_pipeline", chunk_structure="tree", # Existing structure - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock( chunk_structure="list", # Different structure - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) mock_session.merge.return_value = dataset @@ -735,11 +736,11 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings: dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock( dataset_id="dataset-123", runtime_mode="rag_pipeline", - indexing_technique="high_quality", # Current technique + indexing_technique=IndexTechniqueType.HIGH_QUALITY, # Current technique ) knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock( - indexing_technique="economy", # Trying to change to economy + indexing_technique=IndexTechniqueType.ECONOMY, # Trying to change to economy ) mock_session.merge.return_value = dataset diff --git a/api/tests/unit_tests/services/document_service_validation.py b/api/tests/unit_tests/services/document_service_validation.py index 68296915078..7c36e9d9602 100644 --- a/api/tests/unit_tests/services/document_service_validation.py +++ b/api/tests/unit_tests/services/document_service_validation.py @@ -109,9 +109,10 @@ This test suite follows a comprehensive testing strategy that covers: from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.entities.model_entities import ModelType from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError -from dify_graph.model_runtime.entities.model_entities import ModelType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.dataset import Dataset, DatasetProcessRule, Document from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import ( @@ -153,7 +154,7 @@ class DocumentValidationTestDataFactory: dataset_id: str = "dataset-123", tenant_id: str = "tenant-123", doc_form: str | None = None, - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, embedding_model_provider: str = "openai", embedding_model: str = "text-embedding-ada-002", **kwargs, @@ -188,8 +189,8 @@ class DocumentValidationTestDataFactory: def create_knowledge_config_mock( data_source: DataSource | None = None, process_rule: ProcessRule | None = None, - doc_form: str = "text_model", - indexing_technique: str = "high_quality", + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, **kwargs, ) -> Mock: """ @@ -326,8 +327,8 @@ class TestDatasetServiceCheckDocForm: - Validation logic works correctly """ # Arrange - dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form="text_model") - doc_form = "text_model" + dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form=IndexStructureType.PARAGRAPH_INDEX) + doc_form = IndexStructureType.PARAGRAPH_INDEX # Act (should not raise) DatasetService.check_doc_form(dataset, doc_form) @@ -349,7 +350,7 @@ class TestDatasetServiceCheckDocForm: """ # Arrange dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form=None) - doc_form = "text_model" + doc_form = IndexStructureType.PARAGRAPH_INDEX # Act (should not raise) DatasetService.check_doc_form(dataset, doc_form) @@ -370,8 +371,8 @@ class TestDatasetServiceCheckDocForm: - Error type is correct """ # Arrange - dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form="text_model") - doc_form = "table_model" # Different form + dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form=IndexStructureType.PARAGRAPH_INDEX) + doc_form = IndexStructureType.PARENT_CHILD_INDEX # Different form # Act & Assert with pytest.raises(ValueError, match="doc_form is different from the dataset doc_form"): @@ -390,7 +391,7 @@ class TestDatasetServiceCheckDocForm: """ # Arrange dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form="knowledge_card") - doc_form = "text_model" # Different form + doc_form = IndexStructureType.PARAGRAPH_INDEX # Different form # Act & Assert with pytest.raises(ValueError, match="doc_form is different from the dataset doc_form"): @@ -430,7 +431,7 @@ class TestDatasetServiceCheckDatasetModelSetting: Provides a mocked ModelManager that can be used to verify model instance retrieval and error handling. """ - with patch("services.dataset_service.ModelManager") as mock_manager: + with patch("services.dataset_service.ModelManager.for_tenant") as mock_manager: yield mock_manager def test_check_dataset_model_setting_high_quality_success(self, mock_model_manager): @@ -447,7 +448,7 @@ class TestDatasetServiceCheckDatasetModelSetting: """ # Arrange dataset = DocumentValidationTestDataFactory.create_dataset_mock( - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", ) @@ -480,7 +481,7 @@ class TestDatasetServiceCheckDatasetModelSetting: - No errors are raised """ # Arrange - dataset = DocumentValidationTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = DocumentValidationTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) # Act (should not raise) DatasetService.check_dataset_model_setting(dataset) @@ -502,7 +503,7 @@ class TestDatasetServiceCheckDatasetModelSetting: """ # Arrange dataset = DocumentValidationTestDataFactory.create_dataset_mock( - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="invalid-model", ) @@ -532,7 +533,7 @@ class TestDatasetServiceCheckDatasetModelSetting: """ # Arrange dataset = DocumentValidationTestDataFactory.create_dataset_mock( - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", ) @@ -579,7 +580,7 @@ class TestDatasetServiceCheckEmbeddingModelSetting: Provides a mocked ModelManager that can be used to verify model instance retrieval and error handling. """ - with patch("services.dataset_service.ModelManager") as mock_manager: + with patch("services.dataset_service.ModelManager.for_tenant") as mock_manager: yield mock_manager def test_check_embedding_model_setting_success(self, mock_model_manager): @@ -701,7 +702,7 @@ class TestDatasetServiceCheckRerankingModelSetting: Provides a mocked ModelManager that can be used to verify model instance retrieval and error handling. """ - with patch("services.dataset_service.ModelManager") as mock_manager: + with patch("services.dataset_service.ModelManager.for_tenant") as mock_manager: yield mock_manager def test_check_reranking_model_setting_success(self, mock_model_manager): diff --git a/api/tests/unit_tests/services/enterprise/test_account_deletion_sync.py b/api/tests/unit_tests/services/enterprise/test_account_deletion_sync.py deleted file mode 100644 index b66111902c1..00000000000 --- a/api/tests/unit_tests/services/enterprise/test_account_deletion_sync.py +++ /dev/null @@ -1,276 +0,0 @@ -"""Unit tests for account deletion synchronization. - -This test module verifies the enterprise account deletion sync functionality, -including Redis queuing, error handling, and community vs enterprise behavior. -""" - -from unittest.mock import MagicMock, patch - -import pytest -from redis import RedisError - -from services.enterprise.account_deletion_sync import ( - _queue_task, - sync_account_deletion, - sync_workspace_member_removal, -) - - -class TestQueueTask: - """Unit tests for the _queue_task helper function.""" - - @pytest.fixture - def mock_redis_client(self): - """Mock redis_client for testing.""" - with patch("services.enterprise.account_deletion_sync.redis_client") as mock_redis: - yield mock_redis - - @pytest.fixture - def mock_uuid(self): - """Mock UUID generation for predictable task IDs.""" - with patch("services.enterprise.account_deletion_sync.uuid.uuid4") as mock_uuid_gen: - mock_uuid_gen.return_value = MagicMock(hex="test-task-id-1234") - yield mock_uuid_gen - - def test_queue_task_success(self, mock_redis_client, mock_uuid): - """Test successful task queueing to Redis.""" - # Arrange - workspace_id = "ws-123" - member_id = "member-456" - source = "test_source" - - # Act - result = _queue_task(workspace_id=workspace_id, member_id=member_id, source=source) - - # Assert - assert result is True - mock_redis_client.lpush.assert_called_once() - - # Verify the task payload structure - call_args = mock_redis_client.lpush.call_args[0] - assert call_args[0] == "enterprise:member:sync:queue" - - import json - - task_data = json.loads(call_args[1]) - assert task_data["workspace_id"] == workspace_id - assert task_data["member_id"] == member_id - assert task_data["source"] == source - assert task_data["type"] == "sync_member_deletion_from_workspace" - assert task_data["retry_count"] == 0 - assert "task_id" in task_data - assert "created_at" in task_data - - def test_queue_task_redis_error(self, mock_redis_client, caplog): - """Test handling of Redis connection errors.""" - # Arrange - mock_redis_client.lpush.side_effect = RedisError("Connection failed") - - # Act - result = _queue_task(workspace_id="ws-123", member_id="member-456", source="test_source") - - # Assert - assert result is False - assert "Failed to queue account deletion sync" in caplog.text - - def test_queue_task_type_error(self, mock_redis_client, caplog): - """Test handling of JSON serialization errors.""" - # Arrange - mock_redis_client.lpush.side_effect = TypeError("Cannot serialize") - - # Act - result = _queue_task(workspace_id="ws-123", member_id="member-456", source="test_source") - - # Assert - assert result is False - assert "Failed to queue account deletion sync" in caplog.text - - -class TestSyncWorkspaceMemberRemoval: - """Unit tests for sync_workspace_member_removal function.""" - - @pytest.fixture - def mock_queue_task(self): - """Mock _queue_task for testing.""" - with patch("services.enterprise.account_deletion_sync._queue_task") as mock_queue: - mock_queue.return_value = True - yield mock_queue - - def test_sync_workspace_member_removal_enterprise_enabled(self, mock_queue_task): - """Test sync when ENTERPRISE_ENABLED is True.""" - # Arrange - workspace_id = "ws-123" - member_id = "member-456" - source = "workspace_member_removed" - - with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: - mock_config.ENTERPRISE_ENABLED = True - - # Act - result = sync_workspace_member_removal(workspace_id=workspace_id, member_id=member_id, source=source) - - # Assert - assert result is True - mock_queue_task.assert_called_once_with(workspace_id=workspace_id, member_id=member_id, source=source) - - def test_sync_workspace_member_removal_enterprise_disabled(self, mock_queue_task): - """Test sync when ENTERPRISE_ENABLED is False (community edition).""" - # Arrange - with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: - mock_config.ENTERPRISE_ENABLED = False - - # Act - result = sync_workspace_member_removal(workspace_id="ws-123", member_id="member-456", source="test_source") - - # Assert - assert result is True - mock_queue_task.assert_not_called() - - def test_sync_workspace_member_removal_queue_failure(self, mock_queue_task): - """Test handling of queue task failures.""" - # Arrange - mock_queue_task.return_value = False - - with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: - mock_config.ENTERPRISE_ENABLED = True - - # Act - result = sync_workspace_member_removal(workspace_id="ws-123", member_id="member-456", source="test_source") - - # Assert - assert result is False - - -class TestSyncAccountDeletion: - """Unit tests for sync_account_deletion function.""" - - @pytest.fixture - def mock_db_session(self): - """Mock database session for testing.""" - with patch("services.enterprise.account_deletion_sync.db.session") as mock_session: - yield mock_session - - @pytest.fixture - def mock_queue_task(self): - """Mock _queue_task for testing.""" - with patch("services.enterprise.account_deletion_sync._queue_task") as mock_queue: - mock_queue.return_value = True - yield mock_queue - - def test_sync_account_deletion_enterprise_disabled(self, mock_db_session, mock_queue_task): - """Test sync when ENTERPRISE_ENABLED is False (community edition).""" - # Arrange - with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: - mock_config.ENTERPRISE_ENABLED = False - - # Act - result = sync_account_deletion(account_id="acc-123", source="account_deleted") - - # Assert - assert result is True - mock_db_session.query.assert_not_called() - mock_queue_task.assert_not_called() - - def test_sync_account_deletion_multiple_workspaces(self, mock_db_session, mock_queue_task): - """Test sync for account with multiple workspace memberships.""" - # Arrange - account_id = "acc-123" - - # Mock workspace joins - mock_join1 = MagicMock() - mock_join1.tenant_id = "tenant-1" - mock_join2 = MagicMock() - mock_join2.tenant_id = "tenant-2" - mock_join3 = MagicMock() - mock_join3.tenant_id = "tenant-3" - - mock_query = MagicMock() - mock_query.filter_by.return_value.all.return_value = [mock_join1, mock_join2, mock_join3] - mock_db_session.query.return_value = mock_query - - with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: - mock_config.ENTERPRISE_ENABLED = True - - # Act - result = sync_account_deletion(account_id=account_id, source="account_deleted") - - # Assert - assert result is True - assert mock_queue_task.call_count == 3 - - # Verify each workspace was queued - mock_queue_task.assert_any_call(workspace_id="tenant-1", member_id=account_id, source="account_deleted") - mock_queue_task.assert_any_call(workspace_id="tenant-2", member_id=account_id, source="account_deleted") - mock_queue_task.assert_any_call(workspace_id="tenant-3", member_id=account_id, source="account_deleted") - - def test_sync_account_deletion_no_workspaces(self, mock_db_session, mock_queue_task): - """Test sync for account with no workspace memberships.""" - # Arrange - mock_query = MagicMock() - mock_query.filter_by.return_value.all.return_value = [] - mock_db_session.query.return_value = mock_query - - with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: - mock_config.ENTERPRISE_ENABLED = True - - # Act - result = sync_account_deletion(account_id="acc-123", source="account_deleted") - - # Assert - assert result is True - mock_queue_task.assert_not_called() - - def test_sync_account_deletion_partial_failure(self, mock_db_session, mock_queue_task): - """Test sync when some tasks fail to queue.""" - # Arrange - account_id = "acc-123" - - # Mock workspace joins - mock_join1 = MagicMock() - mock_join1.tenant_id = "tenant-1" - mock_join2 = MagicMock() - mock_join2.tenant_id = "tenant-2" - mock_join3 = MagicMock() - mock_join3.tenant_id = "tenant-3" - - mock_query = MagicMock() - mock_query.filter_by.return_value.all.return_value = [mock_join1, mock_join2, mock_join3] - mock_db_session.query.return_value = mock_query - - # Mock queue_task to fail for second workspace - def queue_side_effect(workspace_id, member_id, source): - return workspace_id != "tenant-2" - - mock_queue_task.side_effect = queue_side_effect - - with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: - mock_config.ENTERPRISE_ENABLED = True - - # Act - result = sync_account_deletion(account_id=account_id, source="account_deleted") - - # Assert - assert result is False # Should return False if any task fails - assert mock_queue_task.call_count == 3 - - def test_sync_account_deletion_all_failures(self, mock_db_session, mock_queue_task): - """Test sync when all tasks fail to queue.""" - # Arrange - mock_join = MagicMock() - mock_join.tenant_id = "tenant-1" - - mock_query = MagicMock() - mock_query.filter_by.return_value.all.return_value = [mock_join] - mock_db_session.query.return_value = mock_query - - mock_queue_task.return_value = False - - with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: - mock_config.ENTERPRISE_ENABLED = True - - # Act - result = sync_account_deletion(account_id="acc-123", source="account_deleted") - - # Assert - assert result is False - mock_queue_task.assert_called_once() diff --git a/api/tests/unit_tests/services/external_dataset_service.py b/api/tests/unit_tests/services/external_dataset_service.py index afc3b29fcae..a8ef35a0d0b 100644 --- a/api/tests/unit_tests/services/external_dataset_service.py +++ b/api/tests/unit_tests/services/external_dataset_service.py @@ -545,7 +545,7 @@ class TestExternalDatasetServiceProcessExternalApi: params={}, ) - from dify_graph.nodes.http_request.exc import InvalidHttpMethodError + from graphon.nodes.http_request.exc import InvalidHttpMethodError with pytest.raises(InvalidHttpMethodError): ExternalDatasetService.process_external_api(settings, files=None) diff --git a/api/tests/unit_tests/services/plugin/test_oauth_service.py b/api/tests/unit_tests/services/plugin/test_oauth_service.py index 27df4556bc1..eee65b3a18f 100644 --- a/api/tests/unit_tests/services/plugin/test_oauth_service.py +++ b/api/tests/unit_tests/services/plugin/test_oauth_service.py @@ -13,6 +13,10 @@ import pytest from services.plugin.oauth_service import OAuthProxyService +def _oauth_proxy_setex_calls(redis_client) -> list: + return [call for call in redis_client.setex.call_args_list if call.args[0].startswith("oauth_proxy_context:")] + + class TestCreateProxyContext: def test_stores_context_in_redis_with_ttl(self): context_id = OAuthProxyService.create_proxy_context( @@ -22,8 +26,9 @@ class TestCreateProxyContext: assert context_id # non-empty UUID string from extensions.ext_redis import redis_client - redis_client.setex.assert_called_once() - call_args = redis_client.setex.call_args + oauth_calls = _oauth_proxy_setex_calls(redis_client) + assert len(oauth_calls) == 1 + call_args = oauth_calls[0] key = call_args[0][0] ttl = call_args[0][1] stored_data = json.loads(call_args[0][2]) @@ -88,3 +93,20 @@ class TestUseProxyContext: assert result == stored expected_key = "oauth_proxy_context:valid-id" redis_client.delete.assert_called_once_with(expected_key) + + def test_returns_context_with_credential_id(self): + from extensions.ext_redis import redis_client + + stored = { + "user_id": "u1", + "tenant_id": "t1", + "plugin_id": "p1", + "provider": "github", + "credential_id": "cred-42", + } + redis_client.get.return_value = json.dumps(stored).encode() + + result = OAuthProxyService.use_proxy_context("ctx-with-cred") + + assert result["credential_id"] == "cred-42" + assert result["tenant_id"] == "t1" diff --git a/api/tests/unit_tests/services/recommend_app/test_database_retrieval.py b/api/tests/unit_tests/services/recommend_app/test_database_retrieval.py deleted file mode 100644 index 5d21665f75a..00000000000 --- a/api/tests/unit_tests/services/recommend_app/test_database_retrieval.py +++ /dev/null @@ -1,145 +0,0 @@ -from types import SimpleNamespace -from unittest.mock import MagicMock, patch - -from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval -from services.recommend_app.recommend_app_type import RecommendAppType - - -class TestDatabaseRecommendAppRetrieval: - def test_get_type(self): - assert DatabaseRecommendAppRetrieval().get_type() == RecommendAppType.DATABASE - - def test_get_recommended_apps_delegates(self): - with patch.object( - DatabaseRecommendAppRetrieval, - "fetch_recommended_apps_from_db", - return_value={"recommended_apps": [], "categories": []}, - ) as mock_fetch: - result = DatabaseRecommendAppRetrieval().get_recommended_apps_and_categories("en-US") - mock_fetch.assert_called_once_with("en-US") - assert result == {"recommended_apps": [], "categories": []} - - def test_get_recommend_app_detail_delegates(self): - with patch.object( - DatabaseRecommendAppRetrieval, - "fetch_recommended_app_detail_from_db", - return_value={"id": "app-1"}, - ) as mock_fetch: - result = DatabaseRecommendAppRetrieval().get_recommend_app_detail("app-1") - mock_fetch.assert_called_once_with("app-1") - assert result == {"id": "app-1"} - - -class TestFetchRecommendedAppsFromDb: - def _make_recommended_app(self, app_id, category, is_public=True, has_site=True): - site = ( - SimpleNamespace( - description="desc", - copyright="copy", - privacy_policy="pp", - custom_disclaimer="cd", - ) - if has_site - else None - ) - app = ( - SimpleNamespace(is_public=is_public, site=site) - if is_public - else SimpleNamespace(is_public=False, site=site) - ) - return SimpleNamespace( - id=f"rec-{app_id}", - app=app, - app_id=app_id, - category=category, - position=1, - is_listed=True, - ) - - @patch("services.recommend_app.database.database_retrieval.db") - def test_returns_apps_and_sorted_categories(self, mock_db): - rec1 = self._make_recommended_app("a1", "writing") - rec2 = self._make_recommended_app("a2", "assistant") - mock_db.session.scalars.return_value.all.return_value = [rec1, rec2] - - result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US") - - assert len(result["recommended_apps"]) == 2 - assert result["categories"] == ["assistant", "writing"] - - @patch("services.recommend_app.database.database_retrieval.db") - def test_falls_back_to_default_language_when_empty(self, mock_db): - mock_db.session.scalars.return_value.all.side_effect = [ - [], - [self._make_recommended_app("a1", "chat")], - ] - - result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("fr-FR") - - assert len(result["recommended_apps"]) == 1 - assert mock_db.session.scalars.call_count == 2 - - @patch("services.recommend_app.database.database_retrieval.db") - def test_skips_non_public_apps(self, mock_db): - rec = self._make_recommended_app("a1", "chat", is_public=False) - mock_db.session.scalars.return_value.all.return_value = [rec] - - result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US") - - assert result["recommended_apps"] == [] - - @patch("services.recommend_app.database.database_retrieval.db") - def test_skips_apps_without_site(self, mock_db): - rec = self._make_recommended_app("a1", "chat", has_site=False) - mock_db.session.scalars.return_value.all.return_value = [rec] - - result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US") - - assert result["recommended_apps"] == [] - - -class TestFetchRecommendedAppDetailFromDb: - @patch("services.recommend_app.database.database_retrieval.db") - def test_returns_none_when_not_listed(self, mock_db): - mock_db.session.query.return_value.where.return_value.first.return_value = None - - result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db("app-1") - - assert result is None - - @patch("services.recommend_app.database.database_retrieval.AppDslService") - @patch("services.recommend_app.database.database_retrieval.db") - def test_returns_none_when_app_not_public(self, mock_db, mock_dsl): - rec_chain = MagicMock() - rec_chain.where.return_value.first.return_value = SimpleNamespace(app_id="app-1") - app_chain = MagicMock() - app_chain.where.return_value.first.return_value = SimpleNamespace(id="app-1", is_public=False) - mock_db.session.query.side_effect = [rec_chain, app_chain] - - result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db("app-1") - - assert result is None - - @patch("services.recommend_app.database.database_retrieval.AppDslService") - @patch("services.recommend_app.database.database_retrieval.db") - def test_returns_detail_on_success(self, mock_db, mock_dsl): - app_model = SimpleNamespace( - id="app-1", - name="My App", - icon="icon.png", - icon_background="#fff", - mode="chat", - is_public=True, - ) - rec_chain = MagicMock() - rec_chain.where.return_value.first.return_value = SimpleNamespace(app_id="app-1") - app_chain = MagicMock() - app_chain.where.return_value.first.return_value = app_model - mock_db.session.query.side_effect = [rec_chain, app_chain] - mock_dsl.export_dsl.return_value = "exported_yaml" - - result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db("app-1") - - assert result["id"] == "app-1" - assert result["name"] == "My App" - assert result["export_data"] == "exported_yaml" diff --git a/api/tests/unit_tests/services/retention/conversation/test_messages_clean_service.py b/api/tests/unit_tests/services/retention/conversation/test_messages_clean_service.py deleted file mode 100644 index f9d901fca24..00000000000 --- a/api/tests/unit_tests/services/retention/conversation/test_messages_clean_service.py +++ /dev/null @@ -1,311 +0,0 @@ -import datetime -from unittest.mock import MagicMock, patch - -import pytest - -from services.retention.conversation.messages_clean_policy import ( - BillingDisabledPolicy, -) -from services.retention.conversation.messages_clean_service import MessagesCleanService - - -class TestMessagesCleanService: - @pytest.fixture(autouse=True) - def mock_db_engine(self): - with patch("services.retention.conversation.messages_clean_service.db") as mock_db: - mock_db.engine = MagicMock() - yield mock_db.engine - - @pytest.fixture - def mock_db_session(self, mock_db_engine): - with patch("services.retention.conversation.messages_clean_service.Session") as mock_session_cls: - mock_session = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session - yield mock_session - - @pytest.fixture - def mock_policy(self): - policy = MagicMock(spec=BillingDisabledPolicy) - return policy - - def test_run_calls_clean_messages(self, mock_policy): - service = MessagesCleanService( - policy=mock_policy, - end_before=datetime.datetime.now(), - batch_size=10, - ) - with patch.object(service, "_clean_messages_by_time_range") as mock_clean: - mock_clean.return_value = {"total_deleted": 5} - result = service.run() - assert result == {"total_deleted": 5} - mock_clean.assert_called_once() - - def test_clean_messages_by_time_range_basic(self, mock_db_session, mock_policy): - # Arrange - end_before = datetime.datetime(2024, 1, 1, 12, 0, 0) - service = MessagesCleanService( - policy=mock_policy, - end_before=end_before, - batch_size=10, - ) - - mock_db_session.execute.side_effect = [ - MagicMock(all=lambda: [("msg1", "app1", datetime.datetime(2024, 1, 1, 10, 0, 0))]), # messages - MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]), # apps - MagicMock( - rowcount=1 - ), # delete relations (this is wrong, relations delete doesn't use rowcount here, but execute) - MagicMock(rowcount=1), # delete relations - MagicMock(rowcount=1), # delete relations - MagicMock(rowcount=1), # delete relations - MagicMock(rowcount=1), # delete relations - MagicMock(rowcount=1), # delete relations - MagicMock(rowcount=1), # delete relations - MagicMock(rowcount=1), # delete relations - MagicMock(rowcount=1), # delete messages - MagicMock(all=lambda: []), # next batch empty - ] - - # Reset side_effect to be more robust - # The service calls session.execute for: - # 1. Fetch messages - # 2. Fetch apps - # 3. Batch delete relations (8 calls if IDs exist) - # 4. Delete messages - - mock_returns = [ - MagicMock(all=lambda: [("msg1", "app1", datetime.datetime(2024, 1, 1, 10, 0, 0))]), # fetch messages - MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]), # fetch apps - ] - # 8 deletes for relations - mock_returns.extend([MagicMock() for _ in range(8)]) - # 1 delete for messages - mock_returns.append(MagicMock(rowcount=1)) - # Final fetch messages (empty) - mock_returns.append(MagicMock(all=lambda: [])) - - mock_db_session.execute.side_effect = mock_returns - mock_policy.filter_message_ids.return_value = ["msg1"] - - # Act - with patch("services.retention.conversation.messages_clean_service.time.sleep"): - stats = service.run() - - # Assert - assert stats["total_messages"] == 1 - assert stats["total_deleted"] == 1 - assert stats["batches"] == 2 - - def test_clean_messages_by_time_range_with_start_from(self, mock_db_session, mock_policy): - start_from = datetime.datetime(2024, 1, 1, 0, 0, 0) - end_before = datetime.datetime(2024, 1, 1, 12, 0, 0) - service = MessagesCleanService( - policy=mock_policy, - start_from=start_from, - end_before=end_before, - batch_size=10, - ) - - mock_db_session.execute.side_effect = [ - MagicMock(all=lambda: []), # No messages - ] - - stats = service.run() - assert stats["total_messages"] == 0 - - def test_clean_messages_by_time_range_with_cursor(self, mock_db_session, mock_policy): - # Test pagination with cursor - end_before = datetime.datetime(2024, 1, 1, 12, 0, 0) - service = MessagesCleanService( - policy=mock_policy, - end_before=end_before, - batch_size=1, - ) - - msg1_time = datetime.datetime(2024, 1, 1, 10, 0, 0) - msg2_time = datetime.datetime(2024, 1, 1, 11, 0, 0) - - mock_returns = [] - # Batch 1 - mock_returns.append(MagicMock(all=lambda: [("msg1", "app1", msg1_time)])) - mock_returns.append(MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")])) - mock_returns.extend([MagicMock() for _ in range(8)]) # relations - mock_returns.append(MagicMock(rowcount=1)) # messages - - # Batch 2 - mock_returns.append(MagicMock(all=lambda: [("msg2", "app1", msg2_time)])) - mock_returns.append(MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")])) - mock_returns.extend([MagicMock() for _ in range(8)]) # relations - mock_returns.append(MagicMock(rowcount=1)) # messages - - # Batch 3 - mock_returns.append(MagicMock(all=lambda: [])) - - mock_db_session.execute.side_effect = mock_returns - mock_policy.filter_message_ids.return_value = ["msg1"] # Simplified - - with patch("services.retention.conversation.messages_clean_service.time.sleep"): - stats = service.run() - - assert stats["batches"] == 3 - assert stats["total_messages"] == 2 - - def test_clean_messages_by_time_range_dry_run(self, mock_db_session, mock_policy): - service = MessagesCleanService( - policy=mock_policy, - end_before=datetime.datetime.now(), - batch_size=10, - dry_run=True, - ) - - mock_db_session.execute.side_effect = [ - MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]), # messages - MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]), # apps - MagicMock(all=lambda: []), # next batch empty - ] - mock_policy.filter_message_ids.return_value = ["msg1"] - - with patch("services.retention.conversation.messages_clean_service.random.sample") as mock_sample: - mock_sample.return_value = ["msg1"] - stats = service.run() - assert stats["filtered_messages"] == 1 - assert stats["total_deleted"] == 0 # Dry run - mock_sample.assert_called() - - def test_clean_messages_by_time_range_no_apps_found(self, mock_db_session, mock_policy): - service = MessagesCleanService( - policy=mock_policy, - end_before=datetime.datetime.now(), - batch_size=10, - ) - - mock_db_session.execute.side_effect = [ - MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]), # messages - MagicMock(all=lambda: []), # apps NOT found - MagicMock(all=lambda: []), # next batch empty - ] - - stats = service.run() - assert stats["total_messages"] == 1 - assert stats["total_deleted"] == 0 - - def test_clean_messages_by_time_range_no_app_ids(self, mock_db_session, mock_policy): - service = MessagesCleanService( - policy=mock_policy, - end_before=datetime.datetime.now(), - batch_size=10, - ) - - mock_db_session.execute.side_effect = [ - MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]), # messages - MagicMock(all=lambda: []), # next batch empty - ] - - # We need to successfully execute line 228 and 229, then return empty at 251. - # line 228: raw_messages = list(session.execute(msg_stmt).all()) - # line 251: app_ids = list({msg.app_id for msg in messages}) - - calls = [] - - def list_side_effect(arg): - calls.append(arg) - if len(calls) == 2: # This is the second call to list() in the loop - return [] - return list(arg) - - with patch("services.retention.conversation.messages_clean_service.list", side_effect=list_side_effect): - stats = service.run() - assert stats["batches"] == 2 - assert stats["total_messages"] == 1 - - def test_from_time_range_validation(self, mock_policy): - now = datetime.datetime.now() - # Test start_from >= end_before - with pytest.raises(ValueError, match="start_from .* must be less than end_before"): - MessagesCleanService.from_time_range(mock_policy, now, now) - - # Test batch_size <= 0 - with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): - MessagesCleanService.from_time_range(mock_policy, now - datetime.timedelta(days=1), now, batch_size=0) - - def test_from_time_range_success(self, mock_policy): - start = datetime.datetime(2024, 1, 1) - end = datetime.datetime(2024, 2, 1) - # Mock logger to avoid actual logging if needed, though it's fine - service = MessagesCleanService.from_time_range(mock_policy, start, end) - assert service._start_from == start - assert service._end_before == end - - def test_from_days_validation(self, mock_policy): - # Test days < 0 - with pytest.raises(ValueError, match="days .* must be greater than or equal to 0"): - MessagesCleanService.from_days(mock_policy, days=-1) - - # Test batch_size <= 0 - with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): - MessagesCleanService.from_days(mock_policy, days=30, batch_size=0) - - def test_from_days_success(self, mock_policy): - with patch("services.retention.conversation.messages_clean_service.naive_utc_now") as mock_now: - fixed_now = datetime.datetime(2024, 6, 1) - mock_now.return_value = fixed_now - - service = MessagesCleanService.from_days(mock_policy, days=10) - assert service._start_from is None - assert service._end_before == fixed_now - datetime.timedelta(days=10) - - def test_clean_messages_by_time_range_no_messages_to_delete(self, mock_db_session, mock_policy): - service = MessagesCleanService( - policy=mock_policy, - end_before=datetime.datetime.now(), - batch_size=10, - ) - - mock_db_session.execute.side_effect = [ - MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]), # messages - MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]), # apps - MagicMock(all=lambda: []), # next batch empty - ] - mock_policy.filter_message_ids.return_value = [] # Policy says NO - - stats = service.run() - assert stats["total_messages"] == 1 - assert stats["filtered_messages"] == 0 - assert stats["total_deleted"] == 0 - - def test_batch_delete_message_relations_empty(self, mock_db_session): - MessagesCleanService._batch_delete_message_relations(mock_db_session, []) - mock_db_session.execute.assert_not_called() - - def test_batch_delete_message_relations_with_ids(self, mock_db_session): - MessagesCleanService._batch_delete_message_relations(mock_db_session, ["msg1", "msg2"]) - assert mock_db_session.execute.call_count == 8 # 8 tables to clean up - - def test_clean_messages_interval_from_env(self, mock_db_session, mock_policy): - service = MessagesCleanService( - policy=mock_policy, - end_before=datetime.datetime.now(), - batch_size=10, - ) - - mock_returns = [ - MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]), # messages - MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]), # apps - ] - mock_returns.extend([MagicMock() for _ in range(8)]) # relations - mock_returns.append(MagicMock(rowcount=1)) # messages - mock_returns.append(MagicMock(all=lambda: [])) # next batch empty - - mock_db_session.execute.side_effect = mock_returns - mock_policy.filter_message_ids.return_value = ["msg1"] - - with patch( - "services.retention.conversation.messages_clean_service.dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL", - 500, - ): - with patch("services.retention.conversation.messages_clean_service.time.sleep") as mock_sleep: - with patch("services.retention.conversation.messages_clean_service.random.uniform") as mock_uniform: - mock_uniform.return_value = 300.0 - service.run() - mock_uniform.assert_called_with(0, 500) - mock_sleep.assert_called_with(0.3) diff --git a/api/tests/unit_tests/services/retention/workflow_run/test_delete_archived_workflow_run.py b/api/tests/unit_tests/services/retention/workflow_run/test_delete_archived_workflow_run.py deleted file mode 100644 index 9fe153c1534..00000000000 --- a/api/tests/unit_tests/services/retention/workflow_run/test_delete_archived_workflow_run.py +++ /dev/null @@ -1,216 +0,0 @@ -from datetime import datetime -from unittest.mock import MagicMock, patch - -import pytest -from sqlalchemy.orm import Session - -from models.workflow import WorkflowRun -from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion, DeleteResult - - -class TestArchivedWorkflowRunDeletion: - @pytest.fixture - def mock_db(self): - with patch("services.retention.workflow_run.delete_archived_workflow_run.db") as mock_db: - mock_db.engine = MagicMock() - yield mock_db - - @pytest.fixture - def mock_sessionmaker(self): - with patch("services.retention.workflow_run.delete_archived_workflow_run.sessionmaker") as mock_sm: - mock_session = MagicMock(spec=Session) - mock_sm.return_value.return_value.__enter__.return_value = mock_session - yield mock_sm, mock_session - - @pytest.fixture - def mock_workflow_run_repo(self): - with patch( - "services.retention.workflow_run.delete_archived_workflow_run.APIWorkflowRunRepository" - ) as mock_repo_cls: - mock_repo = MagicMock() - yield mock_repo - - def test_delete_by_run_id_success(self, mock_db, mock_sessionmaker): - mock_sm, mock_session = mock_sessionmaker - run_id = "run-123" - tenant_id = "tenant-456" - - mock_run = MagicMock(spec=WorkflowRun) - mock_run.id = run_id - mock_run.tenant_id = tenant_id - mock_session.get.return_value = mock_run - - deletion = ArchivedWorkflowRunDeletion() - - with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: - mock_repo = MagicMock() - mock_get_repo.return_value = mock_repo - mock_repo.get_archived_run_ids.return_value = [run_id] - - with patch.object(deletion, "_delete_run") as mock_delete_run: - expected_result = DeleteResult(run_id=run_id, tenant_id=tenant_id, success=True) - mock_delete_run.return_value = expected_result - - result = deletion.delete_by_run_id(run_id) - - assert result == expected_result - mock_session.get.assert_called_once_with(WorkflowRun, run_id) - mock_repo.get_archived_run_ids.assert_called_once() - mock_delete_run.assert_called_once_with(mock_run) - - def test_delete_by_run_id_not_found(self, mock_db, mock_sessionmaker): - mock_sm, mock_session = mock_sessionmaker - run_id = "run-123" - mock_session.get.return_value = None - - deletion = ArchivedWorkflowRunDeletion() - with patch.object(deletion, "_get_workflow_run_repo"): - result = deletion.delete_by_run_id(run_id) - - assert result.success is False - assert "not found" in result.error - assert result.run_id == run_id - - def test_delete_by_run_id_not_archived(self, mock_db, mock_sessionmaker): - mock_sm, mock_session = mock_sessionmaker - run_id = "run-123" - - mock_run = MagicMock(spec=WorkflowRun) - mock_run.id = run_id - mock_session.get.return_value = mock_run - - deletion = ArchivedWorkflowRunDeletion() - with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: - mock_repo = MagicMock() - mock_get_repo.return_value = mock_repo - mock_repo.get_archived_run_ids.return_value = [] - - result = deletion.delete_by_run_id(run_id) - - assert result.success is False - assert "is not archived" in result.error - - def test_delete_batch(self, mock_db, mock_sessionmaker): - mock_sm, mock_session = mock_sessionmaker - deletion = ArchivedWorkflowRunDeletion() - - mock_run1 = MagicMock(spec=WorkflowRun) - mock_run1.id = "run-1" - mock_run2 = MagicMock(spec=WorkflowRun) - mock_run2.id = "run-2" - - with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: - mock_repo = MagicMock() - mock_get_repo.return_value = mock_repo - mock_repo.get_archived_runs_by_time_range.return_value = [mock_run1, mock_run2] - - with patch.object(deletion, "_delete_run") as mock_delete_run: - mock_delete_run.side_effect = [ - DeleteResult(run_id="run-1", tenant_id="t1", success=True), - DeleteResult(run_id="run-2", tenant_id="t1", success=True), - ] - - results = deletion.delete_batch(tenant_ids=["t1"], start_date=datetime.now(), end_date=datetime.now()) - - assert len(results) == 2 - assert results[0].run_id == "run-1" - assert results[1].run_id == "run-2" - assert mock_delete_run.call_count == 2 - - def test_delete_run_dry_run(self): - deletion = ArchivedWorkflowRunDeletion(dry_run=True) - mock_run = MagicMock(spec=WorkflowRun) - mock_run.id = "run-123" - mock_run.tenant_id = "tenant-456" - - result = deletion._delete_run(mock_run) - - assert result.success is True - assert result.run_id == "run-123" - - def test_delete_run_success(self): - deletion = ArchivedWorkflowRunDeletion(dry_run=False) - mock_run = MagicMock(spec=WorkflowRun) - mock_run.id = "run-123" - mock_run.tenant_id = "tenant-456" - - with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: - mock_repo = MagicMock() - mock_get_repo.return_value = mock_repo - mock_repo.delete_runs_with_related.return_value = {"workflow_runs": 1} - - result = deletion._delete_run(mock_run) - - assert result.success is True - assert result.deleted_counts == {"workflow_runs": 1} - - def test_delete_run_exception(self): - deletion = ArchivedWorkflowRunDeletion(dry_run=False) - mock_run = MagicMock(spec=WorkflowRun) - mock_run.id = "run-123" - - with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: - mock_repo = MagicMock() - mock_get_repo.return_value = mock_repo - mock_repo.delete_runs_with_related.side_effect = Exception("Database error") - - result = deletion._delete_run(mock_run) - - assert result.success is False - assert result.error == "Database error" - - def test_delete_trigger_logs(self): - mock_session = MagicMock(spec=Session) - run_ids = ["run-1", "run-2"] - - with patch( - "services.retention.workflow_run.delete_archived_workflow_run.SQLAlchemyWorkflowTriggerLogRepository" - ) as mock_repo_cls: - mock_repo = MagicMock() - mock_repo_cls.return_value = mock_repo - mock_repo.delete_by_run_ids.return_value = 5 - - count = ArchivedWorkflowRunDeletion._delete_trigger_logs(mock_session, run_ids) - - assert count == 5 - mock_repo_cls.assert_called_once_with(mock_session) - mock_repo.delete_by_run_ids.assert_called_once_with(run_ids) - - def test_delete_node_executions(self): - mock_session = MagicMock(spec=Session) - mock_run = MagicMock(spec=WorkflowRun) - mock_run.id = "run-1" - runs = [mock_run] - - with patch( - "repositories.factory.DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository" - ) as mock_create_repo: - mock_repo = MagicMock() - mock_create_repo.return_value = mock_repo - mock_repo.delete_by_runs.return_value = (1, 2) - - with patch("services.retention.workflow_run.delete_archived_workflow_run.sessionmaker") as mock_sm: - result = ArchivedWorkflowRunDeletion._delete_node_executions(mock_session, runs) - - assert result == (1, 2) - mock_create_repo.assert_called_once() - mock_repo.delete_by_runs.assert_called_once_with(mock_session, ["run-1"]) - - def test_get_workflow_run_repo(self, mock_db): - deletion = ArchivedWorkflowRunDeletion() - - with patch( - "repositories.factory.DifyAPIRepositoryFactory.create_api_workflow_run_repository" - ) as mock_create_repo: - mock_repo = MagicMock() - mock_create_repo.return_value = mock_repo - - # First call - repo1 = deletion._get_workflow_run_repo() - assert repo1 == mock_repo - assert deletion.workflow_run_repo == mock_repo - - # Second call (should return cached) - repo2 = deletion._get_workflow_run_repo() - assert repo2 == mock_repo - mock_create_repo.assert_called_once() diff --git a/api/tests/unit_tests/services/retention/workflow_run/test_restore_archived_workflow_run.py b/api/tests/unit_tests/services/retention/workflow_run/test_restore_archived_workflow_run.py index 4bfdba87a04..628e4e594dd 100644 --- a/api/tests/unit_tests/services/retention/workflow_run/test_restore_archived_workflow_run.py +++ b/api/tests/unit_tests/services/retention/workflow_run/test_restore_archived_workflow_run.py @@ -13,6 +13,7 @@ from datetime import datetime from unittest.mock import Mock, create_autospec, patch import pytest +from pydantic import ValidationError from sqlalchemy import Column, Integer, MetaData, String, Table from libs.archive_storage import ArchiveStorageNotConfiguredError @@ -292,7 +293,7 @@ class TestLoadManifestFromZip: zip_buffer.seek(0) with zipfile.ZipFile(zip_buffer, "r") as archive: - with pytest.raises(json.JSONDecodeError): + with pytest.raises(ValidationError): WorkflowRunRestore._load_manifest_from_zip(archive) diff --git a/api/tests/unit_tests/services/segment_service.py b/api/tests/unit_tests/services/segment_service.py index affbc8d0b51..f0a66a00d44 100644 --- a/api/tests/unit_tests/services/segment_service.py +++ b/api/tests/unit_tests/services/segment_service.py @@ -2,8 +2,10 @@ from unittest.mock import MagicMock, Mock, patch import pytest +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.account import Account from models.dataset import ChildChunk, Dataset, Document, DocumentSegment +from models.enums import SegmentType from services.dataset_service import SegmentService from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError @@ -77,7 +79,7 @@ class SegmentTestDataFactory: chunk.word_count = word_count chunk.index_node_id = f"node-{chunk_id}" chunk.index_node_hash = "hash-123" - chunk.type = "automatic" + chunk.type = SegmentType.AUTOMATIC chunk.created_by = "user-123" chunk.updated_by = None chunk.updated_at = None @@ -90,7 +92,7 @@ class SegmentTestDataFactory: document_id: str = "doc-123", dataset_id: str = "dataset-123", tenant_id: str = "tenant-123", - doc_form: str = "text_model", + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, word_count: int = 100, **kwargs, ) -> Mock: @@ -109,7 +111,7 @@ class SegmentTestDataFactory: def create_dataset_mock( dataset_id: str = "dataset-123", tenant_id: str = "tenant-123", - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, embedding_model: str = "text-embedding-ada-002", embedding_model_provider: str = "openai", **kwargs, @@ -161,7 +163,7 @@ class TestSegmentServiceCreateSegment: """Test successful creation of a segment.""" # Arrange document = SegmentTestDataFactory.create_document_mock(word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) args = {"content": "New segment content", "keywords": ["test", "segment"]} mock_query = MagicMock() @@ -209,8 +211,8 @@ class TestSegmentServiceCreateSegment: def test_create_segment_with_qa_model(self, mock_db_session, mock_current_user): """Test creation of segment with QA model (requires answer).""" # Arrange - document = SegmentTestDataFactory.create_document_mock(doc_form="qa_model", word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy") + document = SegmentTestDataFactory.create_document_mock(doc_form=IndexStructureType.QA_INDEX, word_count=100) + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) args = {"content": "What is AI?", "answer": "AI is Artificial Intelligence", "keywords": ["ai"]} mock_query = MagicMock() @@ -245,7 +247,7 @@ class TestSegmentServiceCreateSegment: """Test creation of segment with high quality indexing technique.""" # Arrange document = SegmentTestDataFactory.create_document_mock(word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) args = {"content": "New segment content", "keywords": ["test"]} mock_query = MagicMock() @@ -265,7 +267,7 @@ class TestSegmentServiceCreateSegment: patch( "services.dataset_service.VectorService.create_segments_vector", autospec=True ) as mock_vector_service, - patch("services.dataset_service.ModelManager", autospec=True) as mock_model_manager_class, + patch("services.dataset_service.ModelManager.for_tenant", autospec=True) as mock_model_manager_class, patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash, patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now, ): @@ -287,7 +289,7 @@ class TestSegmentServiceCreateSegment: """Test segment creation when vector indexing fails.""" # Arrange document = SegmentTestDataFactory.create_document_mock(word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) args = {"content": "New segment content", "keywords": ["test"]} mock_query = MagicMock() @@ -340,7 +342,7 @@ class TestSegmentServiceUpdateSegment: # Arrange segment = SegmentTestDataFactory.create_segment_mock(enabled=True, word_count=10) document = SegmentTestDataFactory.create_document_mock(word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) args = SegmentUpdateArgs(content="Updated content", keywords=["updated"]) mock_db_session.query.return_value.where.return_value.first.return_value = segment @@ -428,8 +430,8 @@ class TestSegmentServiceUpdateSegment: """Test update segment with QA model (includes answer).""" # Arrange segment = SegmentTestDataFactory.create_segment_mock(enabled=True, word_count=10) - document = SegmentTestDataFactory.create_document_mock(doc_form="qa_model", word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy") + document = SegmentTestDataFactory.create_document_mock(doc_form=IndexStructureType.QA_INDEX, word_count=100) + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) args = SegmentUpdateArgs(content="Updated question", answer="Updated answer", keywords=["qa"]) mock_db_session.query.return_value.where.return_value.first.return_value = segment diff --git a/api/tests/unit_tests/services/test_advanced_prompt_template_service.py b/api/tests/unit_tests/services/test_advanced_prompt_template_service.py deleted file mode 100644 index a6bc79e82b1..00000000000 --- a/api/tests/unit_tests/services/test_advanced_prompt_template_service.py +++ /dev/null @@ -1,214 +0,0 @@ -""" -Unit tests for services.advanced_prompt_template_service -""" - -import copy - -from core.prompt.prompt_templates.advanced_prompt_templates import ( - BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, - BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, - BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, - BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, - BAICHUAN_CONTEXT, - CHAT_APP_CHAT_PROMPT_CONFIG, - CHAT_APP_COMPLETION_PROMPT_CONFIG, - COMPLETION_APP_CHAT_PROMPT_CONFIG, - COMPLETION_APP_COMPLETION_PROMPT_CONFIG, - CONTEXT, -) -from models.model import AppMode -from services.advanced_prompt_template_service import AdvancedPromptTemplateService - - -class TestAdvancedPromptTemplateService: - """Test suite for AdvancedPromptTemplateService.""" - - def test_get_prompt_should_use_baichuan_prompt_when_model_name_contains_baichuan(self) -> None: - """Test baichuan model names use baichuan context prompt.""" - # Arrange - args = { - "app_mode": AppMode.CHAT, - "model_mode": "chat", - "model_name": "Baichuan2-13B", - "has_context": "true", - } - - # Act - result = AdvancedPromptTemplateService.get_prompt(args) - - # Assert - assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(BAICHUAN_CONTEXT) - - def test_get_prompt_should_use_common_prompt_when_model_name_not_baichuan(self) -> None: - """Test non-baichuan model names use common prompt.""" - # Arrange - args = { - "app_mode": AppMode.CHAT, - "model_mode": "completion", - "model_name": "gpt-4", - "has_context": "false", - } - original_config = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_prompt(args) - - # Assert - assert result == original_config - assert original_config == CHAT_APP_COMPLETION_PROMPT_CONFIG - - def test_get_common_prompt_should_return_empty_dict_when_app_mode_invalid(self) -> None: - """Test invalid app mode returns empty dict.""" - # Arrange - app_mode = "invalid" - model_mode = "chat" - - # Act - result = AdvancedPromptTemplateService.get_common_prompt(app_mode, model_mode, "true") - - # Assert - assert result == {} - - def test_get_common_prompt_should_prepend_context_for_completion_prompt(self) -> None: - """Test context is prepended for completion prompt when has_context is true.""" - # Arrange - original_config = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "completion", "true") - - # Assert - assert result["completion_prompt_config"]["prompt"]["text"].startswith(CONTEXT) - assert original_config == CHAT_APP_COMPLETION_PROMPT_CONFIG - - def test_get_common_prompt_should_prepend_context_for_chat_prompt(self) -> None: - """Test context is prepended for chat prompt when has_context is true.""" - # Arrange - original_config = copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "chat", "true") - - # Assert - assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(CONTEXT) - assert original_config == COMPLETION_APP_CHAT_PROMPT_CONFIG - - def test_get_common_prompt_should_return_chat_prompt_without_context_when_has_context_false(self) -> None: - """Test chat prompt remains unchanged when has_context is false.""" - # Arrange - original_config = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "chat", "false") - - # Assert - assert result == original_config - assert original_config == CHAT_APP_CHAT_PROMPT_CONFIG - - def test_get_common_prompt_should_return_completion_prompt_for_completion_app_mode(self) -> None: - """Test completion app mode with completion model returns completion prompt.""" - # Arrange - original_config = copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "completion", "false") - - # Assert - assert result == original_config - assert original_config == COMPLETION_APP_COMPLETION_PROMPT_CONFIG - - def test_get_common_prompt_should_return_empty_dict_when_model_mode_invalid(self) -> None: - """Test invalid model mode returns empty dict.""" - # Arrange - app_mode = AppMode.CHAT - model_mode = "invalid" - - # Act - result = AdvancedPromptTemplateService.get_common_prompt(app_mode, model_mode, "false") - - # Assert - assert result == {} - - def test_get_completion_prompt_should_not_prepend_context_when_has_context_false(self) -> None: - """Test helper keeps completion prompt unchanged when context is disabled.""" - # Arrange - prompt_template = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG) - original_text = prompt_template["completion_prompt_config"]["prompt"]["text"] - - # Act - result = AdvancedPromptTemplateService.get_completion_prompt(prompt_template, "false", CONTEXT) - - # Assert - assert result["completion_prompt_config"]["prompt"]["text"] == original_text - - def test_get_chat_prompt_should_not_prepend_context_when_has_context_false(self) -> None: - """Test helper keeps chat prompt unchanged when context is disabled.""" - # Arrange - prompt_template = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG) - original_text = prompt_template["chat_prompt_config"]["prompt"][0]["text"] - - # Act - result = AdvancedPromptTemplateService.get_chat_prompt(prompt_template, "false", CONTEXT) - - # Assert - assert result["chat_prompt_config"]["prompt"][0]["text"] == original_text - - def test_get_baichuan_prompt_should_return_chat_completion_config_when_chat_completion(self) -> None: - """Test baichuan chat/completion returns the expected config.""" - # Arrange - original_config = copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "completion", "false") - - # Assert - assert result == original_config - assert original_config == BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG - - def test_get_baichuan_prompt_should_return_completion_chat_config_when_completion_chat(self) -> None: - """Test baichuan completion/chat returns the expected config.""" - # Arrange - original_config = copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "chat", "false") - - # Assert - assert result == original_config - assert original_config == BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG - - def test_get_baichuan_prompt_should_return_completion_completion_config_when_enabled_context(self) -> None: - """Test baichuan completion/completion prepends baichuan context when enabled.""" - # Arrange - original_config = copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "completion", "true") - - # Assert - assert result["completion_prompt_config"]["prompt"]["text"].startswith(BAICHUAN_CONTEXT) - assert original_config == BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG - - def test_get_baichuan_prompt_should_return_chat_chat_config_when_enabled_context(self) -> None: - """Test baichuan chat/chat prepends baichuan context when enabled.""" - # Arrange - original_config = copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "chat", "true") - - # Assert - assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(BAICHUAN_CONTEXT) - assert original_config == BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG - - def test_get_baichuan_prompt_should_return_empty_dict_when_invalid_inputs(self) -> None: - """Test invalid baichuan mode combinations return empty dict.""" - # Arrange - app_mode = "invalid" - model_mode = "invalid" - - # Act - result = AdvancedPromptTemplateService.get_baichuan_prompt(app_mode, model_mode, "true") - - # Assert - assert result == {} diff --git a/api/tests/unit_tests/services/test_agent_service.py b/api/tests/unit_tests/services/test_agent_service.py deleted file mode 100644 index 7ce3d7ef7bf..00000000000 --- a/api/tests/unit_tests/services/test_agent_service.py +++ /dev/null @@ -1,346 +0,0 @@ -""" -Unit tests for services.agent_service -""" - -from collections.abc import Callable -from datetime import datetime -from unittest.mock import MagicMock, patch - -import pytest -import pytz - -from core.plugin.impl.exc import PluginDaemonClientSideError -from models import Account -from models.model import App, Conversation, EndUser, Message, MessageAgentThought -from services.agent_service import AgentService - - -def _make_current_user_account(timezone: str = "UTC") -> Account: - account = Account(name="Test User", email="test@example.com") - account.timezone = timezone - return account - - -def _make_app_model(app_model_config: MagicMock | None) -> MagicMock: - app_model = MagicMock(spec=App) - app_model.id = "app-123" - app_model.tenant_id = "tenant-123" - app_model.app_model_config = app_model_config - return app_model - - -def _make_conversation(from_end_user_id: str | None, from_account_id: str | None) -> MagicMock: - conversation = MagicMock(spec=Conversation) - conversation.id = "conv-123" - conversation.app_id = "app-123" - conversation.from_end_user_id = from_end_user_id - conversation.from_account_id = from_account_id - return conversation - - -def _make_message(agent_thoughts: list[MessageAgentThought]) -> MagicMock: - message = MagicMock(spec=Message) - message.id = "msg-123" - message.conversation_id = "conv-123" - message.created_at = datetime(2024, 1, 1, tzinfo=pytz.UTC) - message.provider_response_latency = 1.23 - message.answer_tokens = 4 - message.message_tokens = 6 - message.agent_thoughts = agent_thoughts - message.message_files = ["file-a.txt"] - return message - - -def _make_agent_thought() -> MagicMock: - agent_thought = MagicMock(spec=MessageAgentThought) - agent_thought.tokens = 3 - agent_thought.tool_input = "raw-input" - agent_thought.observation = "raw-output" - agent_thought.thought = "thinking" - agent_thought.created_at = datetime(2024, 1, 1, tzinfo=pytz.UTC) - agent_thought.files = [] - agent_thought.tools = ["tool_a", "dataset_tool"] - agent_thought.tool_labels = {"tool_a": "Tool A"} - agent_thought.tool_meta = { - "tool_a": { - "tool_config": { - "tool_provider_type": "custom", - "tool_provider": "provider-1", - }, - "tool_parameters": {"param": "value"}, - "time_cost": 2.5, - }, - "dataset_tool": { - "tool_config": { - "tool_provider_type": "dataset-retrieval", - "tool_provider": "dataset-provider", - } - }, - } - agent_thought.tool_inputs_dict = {"tool_a": {"q": "hello"}, "dataset_tool": {"k": "v"}} - agent_thought.tool_outputs_dict = {"tool_a": {"result": "ok"}} - return agent_thought - - -def _build_query_side_effect( - conversation: Conversation | None, - message: Message | None, - executor: EndUser | Account | None, -) -> Callable[..., MagicMock]: - def _query_side_effect(*args: object, **kwargs: object) -> MagicMock: - query = MagicMock() - query.where.return_value = query - if any(arg is Conversation for arg in args): - query.first.return_value = conversation - elif any(arg is Message for arg in args): - query.first.return_value = message - elif any(arg is EndUser for arg in args) or any(arg is Account for arg in args): - query.first.return_value = executor - return query - - return _query_side_effect - - -class TestAgentServiceGetAgentLogs: - """Test suite for AgentService.get_agent_logs.""" - - def test_get_agent_logs_should_raise_when_conversation_missing(self) -> None: - """Test missing conversation raises ValueError.""" - # Arrange - app_model = _make_app_model(MagicMock()) - with patch("services.agent_service.db") as mock_db: - query = MagicMock() - query.where.return_value = query - query.first.return_value = None - mock_db.session.query.return_value = query - - # Act & Assert - with pytest.raises(ValueError): - AgentService.get_agent_logs(app_model, "missing-conv", "msg-1") - - def test_get_agent_logs_should_raise_when_message_missing(self) -> None: - """Test missing message raises ValueError.""" - # Arrange - app_model = _make_app_model(MagicMock()) - conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) - with patch("services.agent_service.db") as mock_db: - conversation_query = MagicMock() - conversation_query.where.return_value = conversation_query - conversation_query.first.return_value = conversation - - message_query = MagicMock() - message_query.where.return_value = message_query - message_query.first.return_value = None - - mock_db.session.query.side_effect = [conversation_query, message_query] - - # Act & Assert - with pytest.raises(ValueError): - AgentService.get_agent_logs(app_model, conversation.id, "missing-msg") - - def test_get_agent_logs_should_raise_when_app_model_config_missing(self) -> None: - """Test missing app model config raises ValueError.""" - # Arrange - app_model = _make_app_model(None) - conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) - message = _make_message([]) - current_user = _make_current_user_account() - - with patch("services.agent_service.db") as mock_db, patch("services.agent_service.current_user", current_user): - mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, MagicMock()) - - # Act & Assert - with pytest.raises(ValueError): - AgentService.get_agent_logs(app_model, conversation.id, message.id) - - def test_get_agent_logs_should_raise_when_agent_config_missing(self) -> None: - """Test missing agent config raises ValueError.""" - # Arrange - app_model_config = MagicMock() - app_model_config.agent_mode_dict = {"strategy": "react"} - app_model_config.to_dict.return_value = {"tools": []} - app_model = _make_app_model(app_model_config) - conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) - message = _make_message([]) - current_user = _make_current_user_account() - - with ( - patch("services.agent_service.db") as mock_db, - patch("services.agent_service.AgentConfigManager.convert", return_value=None), - patch("services.agent_service.current_user", current_user), - ): - mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, MagicMock()) - - # Act & Assert - with pytest.raises(ValueError): - AgentService.get_agent_logs(app_model, conversation.id, message.id) - - def test_get_agent_logs_should_return_logs_for_end_user_executor(self) -> None: - """Test agent logs returned for end-user executor with tool icons.""" - # Arrange - agent_thought = _make_agent_thought() - message = _make_message([agent_thought]) - conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) - executor = MagicMock(spec=EndUser) - executor.name = "End User" - app_model_config = MagicMock() - app_model_config.agent_mode_dict = {"strategy": "react"} - app_model_config.to_dict.return_value = {"tools": []} - app_model = _make_app_model(app_model_config) - current_user = _make_current_user_account() - agent_tool = MagicMock() - agent_tool.tool_name = "tool_a" - agent_tool.provider_type = "custom" - agent_tool.provider_id = "provider-2" - agent_config = MagicMock() - agent_config.tools = [agent_tool] - - with ( - patch("services.agent_service.db") as mock_db, - patch("services.agent_service.AgentConfigManager.convert", return_value=agent_config) as mock_convert, - patch("services.agent_service.ToolManager.get_tool_icon") as mock_get_icon, - patch("services.agent_service.current_user", current_user), - ): - mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, executor) - mock_get_icon.side_effect = [None, "icon-a"] - - # Act - result = AgentService.get_agent_logs(app_model, conversation.id, message.id) - - # Assert - assert result["meta"]["status"] == "success" - assert result["meta"]["executor"] == "End User" - assert result["meta"]["total_tokens"] == 10 - assert result["meta"]["agent_mode"] == "react" - assert result["meta"]["iterations"] == 1 - assert result["files"] == ["file-a.txt"] - assert len(result["iterations"]) == 1 - tool_calls = result["iterations"][0]["tool_calls"] - assert tool_calls[0]["tool_name"] == "tool_a" - assert tool_calls[0]["tool_icon"] == "icon-a" - assert tool_calls[1]["tool_name"] == "dataset_tool" - assert tool_calls[1]["tool_icon"] == "" - mock_convert.assert_called_once() - - def test_get_agent_logs_should_return_account_executor_when_no_end_user(self) -> None: - """Test agent logs fall back to account executor when end user is missing.""" - # Arrange - agent_thought = _make_agent_thought() - message = _make_message([agent_thought]) - conversation = _make_conversation(from_end_user_id=None, from_account_id="account-1") - executor = MagicMock(spec=Account) - executor.name = "Account User" - app_model_config = MagicMock() - app_model_config.agent_mode_dict = {"strategy": "react"} - app_model_config.to_dict.return_value = {"tools": []} - app_model = _make_app_model(app_model_config) - current_user = _make_current_user_account() - agent_config = MagicMock() - agent_config.tools = [] - - with ( - patch("services.agent_service.db") as mock_db, - patch("services.agent_service.AgentConfigManager.convert", return_value=agent_config), - patch("services.agent_service.ToolManager.get_tool_icon", return_value=""), - patch("services.agent_service.current_user", current_user), - ): - mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, executor) - - # Act - result = AgentService.get_agent_logs(app_model, conversation.id, message.id) - - # Assert - assert result["meta"]["executor"] == "Account User" - - def test_get_agent_logs_should_use_defaults_when_executor_and_tool_data_missing(self) -> None: - """Test unknown executor and missing tool details fall back to defaults.""" - # Arrange - agent_thought = _make_agent_thought() - agent_thought.tool_labels = {} - agent_thought.tool_inputs_dict = {} - agent_thought.tool_outputs_dict = None - agent_thought.tool_meta = {"tool_a": {"error": "failed"}} - agent_thought.tools = ["tool_a"] - - message = _make_message([agent_thought]) - conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) - app_model_config = MagicMock() - app_model_config.agent_mode_dict = {} - app_model_config.to_dict.return_value = {"tools": []} - app_model = _make_app_model(app_model_config) - current_user = _make_current_user_account() - agent_config = MagicMock() - agent_config.tools = [] - - with ( - patch("services.agent_service.db") as mock_db, - patch("services.agent_service.AgentConfigManager.convert", return_value=agent_config), - patch("services.agent_service.ToolManager.get_tool_icon", return_value=None), - patch("services.agent_service.current_user", current_user), - ): - mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, None) - - # Act - result = AgentService.get_agent_logs(app_model, conversation.id, message.id) - - # Assert - assert result["meta"]["executor"] == "Unknown" - assert result["meta"]["agent_mode"] == "react" - tool_call = result["iterations"][0]["tool_calls"][0] - assert tool_call["status"] == "error" - assert tool_call["error"] == "failed" - assert tool_call["tool_label"] == "tool_a" - assert tool_call["tool_input"] == {} - assert tool_call["tool_output"] == {} - assert tool_call["time_cost"] == 0 - assert tool_call["tool_parameters"] == {} - assert tool_call["tool_icon"] is None - - -class TestAgentServiceProviders: - """Test suite for AgentService provider methods.""" - - def test_list_agent_providers_should_delegate_to_plugin_client(self) -> None: - """Test list_agent_providers delegates to PluginAgentClient.""" - # Arrange - tenant_id = "tenant-1" - expected = [{"name": "provider"}] - with patch("services.agent_service.PluginAgentClient") as mock_client: - mock_client.return_value.fetch_agent_strategy_providers.return_value = expected - - # Act - result = AgentService.list_agent_providers("user-1", tenant_id) - - # Assert - assert result == expected - mock_client.return_value.fetch_agent_strategy_providers.assert_called_once_with(tenant_id) - - def test_get_agent_provider_should_return_provider_when_successful(self) -> None: - """Test get_agent_provider returns provider when successful.""" - # Arrange - tenant_id = "tenant-1" - provider_name = "provider-a" - expected = {"name": provider_name} - with patch("services.agent_service.PluginAgentClient") as mock_client: - mock_client.return_value.fetch_agent_strategy_provider.return_value = expected - - # Act - result = AgentService.get_agent_provider("user-1", tenant_id, provider_name) - - # Assert - assert result == expected - mock_client.return_value.fetch_agent_strategy_provider.assert_called_once_with(tenant_id, provider_name) - - def test_get_agent_provider_should_raise_value_error_on_plugin_error(self) -> None: - """Test get_agent_provider wraps PluginDaemonClientSideError into ValueError.""" - # Arrange - tenant_id = "tenant-1" - provider_name = "provider-a" - with patch("services.agent_service.PluginAgentClient") as mock_client: - mock_client.return_value.fetch_agent_strategy_provider.side_effect = PluginDaemonClientSideError( - "plugin error" - ) - - # Act & Assert - with pytest.raises(ValueError): - AgentService.get_agent_provider("user-1", tenant_id, provider_name) diff --git a/api/tests/unit_tests/services/test_api_based_extension_service.py b/api/tests/unit_tests/services/test_api_based_extension_service.py deleted file mode 100644 index 7f4b5fdaa37..00000000000 --- a/api/tests/unit_tests/services/test_api_based_extension_service.py +++ /dev/null @@ -1,421 +0,0 @@ -""" -Comprehensive unit tests for services/api_based_extension_service.py - -Covers: -- APIBasedExtensionService.get_all_by_tenant_id -- APIBasedExtensionService.save -- APIBasedExtensionService.delete -- APIBasedExtensionService.get_with_tenant_id -- APIBasedExtensionService._validation (new record & existing record branches) -- APIBasedExtensionService._ping_connection (pong success, wrong response, exception) -""" - -from unittest.mock import MagicMock, patch - -import pytest - -from services.api_based_extension_service import APIBasedExtensionService - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _make_extension( - *, - id_: str | None = None, - tenant_id: str = "tenant-001", - name: str = "my-ext", - api_endpoint: str = "https://example.com/hook", - api_key: str = "secret-key-123", -) -> MagicMock: - """Return a lightweight mock that mimics APIBasedExtension.""" - ext = MagicMock() - ext.id = id_ - ext.tenant_id = tenant_id - ext.name = name - ext.api_endpoint = api_endpoint - ext.api_key = api_key - return ext - - -# --------------------------------------------------------------------------- -# Tests: get_all_by_tenant_id -# --------------------------------------------------------------------------- - - -class TestGetAllByTenantId: - """Tests for APIBasedExtensionService.get_all_by_tenant_id.""" - - @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") - @patch("services.api_based_extension_service.db") - def test_returns_extensions_with_decrypted_keys(self, mock_db, mock_decrypt): - """Each api_key is decrypted and the list is returned.""" - ext1 = _make_extension(id_="id-1", api_key="enc-key-1") - ext2 = _make_extension(id_="id-2", api_key="enc-key-2") - - mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [ - ext1, - ext2, - ] - - result = APIBasedExtensionService.get_all_by_tenant_id("tenant-001") - - assert result == [ext1, ext2] - assert ext1.api_key == "decrypted-key" - assert ext2.api_key == "decrypted-key" - assert mock_decrypt.call_count == 2 - - @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") - @patch("services.api_based_extension_service.db") - def test_returns_empty_list_when_no_extensions(self, mock_db, mock_decrypt): - """Returns an empty list gracefully when no records exist.""" - mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [] - - result = APIBasedExtensionService.get_all_by_tenant_id("tenant-001") - - assert result == [] - mock_decrypt.assert_not_called() - - @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") - @patch("services.api_based_extension_service.db") - def test_calls_query_with_correct_tenant_id(self, mock_db, mock_decrypt): - """Verifies the DB is queried with the supplied tenant_id.""" - mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [] - - APIBasedExtensionService.get_all_by_tenant_id("tenant-xyz") - - mock_db.session.query.return_value.filter_by.assert_called_once_with(tenant_id="tenant-xyz") - - -# --------------------------------------------------------------------------- -# Tests: save -# --------------------------------------------------------------------------- - - -class TestSave: - """Tests for APIBasedExtensionService.save.""" - - @patch("services.api_based_extension_service.encrypt_token", return_value="encrypted-key") - @patch("services.api_based_extension_service.db") - @patch.object(APIBasedExtensionService, "_validation") - def test_save_new_record_encrypts_key_and_commits(self, mock_validation, mock_db, mock_encrypt): - """Happy path: validation passes, key is encrypted, record is added and committed.""" - ext = _make_extension(id_=None, api_key="plain-key-123") - - result = APIBasedExtensionService.save(ext) - - mock_validation.assert_called_once_with(ext) - mock_encrypt.assert_called_once_with(ext.tenant_id, "plain-key-123") - assert ext.api_key == "encrypted-key" - mock_db.session.add.assert_called_once_with(ext) - mock_db.session.commit.assert_called_once() - assert result is ext - - @patch("services.api_based_extension_service.encrypt_token", return_value="encrypted-key") - @patch("services.api_based_extension_service.db") - @patch.object(APIBasedExtensionService, "_validation", side_effect=ValueError("name must not be empty")) - def test_save_raises_when_validation_fails(self, mock_validation, mock_db, mock_encrypt): - """If _validation raises, save should propagate the error without touching the DB.""" - ext = _make_extension(name="") - - with pytest.raises(ValueError, match="name must not be empty"): - APIBasedExtensionService.save(ext) - - mock_db.session.add.assert_not_called() - mock_db.session.commit.assert_not_called() - - -# --------------------------------------------------------------------------- -# Tests: delete -# --------------------------------------------------------------------------- - - -class TestDelete: - """Tests for APIBasedExtensionService.delete.""" - - @patch("services.api_based_extension_service.db") - def test_delete_removes_record_and_commits(self, mock_db): - """delete() must call session.delete with the extension and then commit.""" - ext = _make_extension(id_="delete-me") - - APIBasedExtensionService.delete(ext) - - mock_db.session.delete.assert_called_once_with(ext) - mock_db.session.commit.assert_called_once() - - -# --------------------------------------------------------------------------- -# Tests: get_with_tenant_id -# --------------------------------------------------------------------------- - - -class TestGetWithTenantId: - """Tests for APIBasedExtensionService.get_with_tenant_id.""" - - @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") - @patch("services.api_based_extension_service.db") - def test_returns_extension_with_decrypted_key(self, mock_db, mock_decrypt): - """Found extension has its api_key decrypted before being returned.""" - ext = _make_extension(id_="ext-123", api_key="enc-key") - - (mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value) = ext - - result = APIBasedExtensionService.get_with_tenant_id("tenant-001", "ext-123") - - assert result is ext - assert ext.api_key == "decrypted-key" - mock_decrypt.assert_called_once_with(ext.tenant_id, "enc-key") - - @patch("services.api_based_extension_service.db") - def test_raises_value_error_when_not_found(self, mock_db): - """Raises ValueError when no matching extension exists.""" - (mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value) = None - - with pytest.raises(ValueError, match="API based extension is not found"): - APIBasedExtensionService.get_with_tenant_id("tenant-001", "non-existent") - - @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") - @patch("services.api_based_extension_service.db") - def test_queries_with_correct_tenant_and_extension_id(self, mock_db, mock_decrypt): - """Verifies both tenant_id and extension id are used in the query.""" - ext = _make_extension(id_="ext-abc") - chain = mock_db.session.query.return_value - chain.filter_by.return_value.filter_by.return_value.first.return_value = ext - - APIBasedExtensionService.get_with_tenant_id("tenant-002", "ext-abc") - - # First filter_by call uses tenant_id - chain.filter_by.assert_called_once_with(tenant_id="tenant-002") - # Second filter_by call uses id - chain.filter_by.return_value.filter_by.assert_called_once_with(id="ext-abc") - - -# --------------------------------------------------------------------------- -# Tests: _validation (new record — id is falsy) -# --------------------------------------------------------------------------- - - -class TestValidationNewRecord: - """Tests for _validation() with a brand-new record (no id).""" - - def _build_mock_db(self, name_exists: bool = False): - mock_db = MagicMock() - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = ( - MagicMock() if name_exists else None - ) - return mock_db - - @patch.object(APIBasedExtensionService, "_ping_connection") - @patch("services.api_based_extension_service.db") - def test_valid_new_extension_passes(self, mock_db, mock_ping): - """A new record with all valid fields should pass without exceptions.""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None - ext = _make_extension(id_=None, name="valid-ext", api_key="longenoughkey") - - # Should not raise - APIBasedExtensionService._validation(ext) - mock_ping.assert_called_once_with(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_name_is_empty(self, mock_db): - """Empty name raises ValueError.""" - ext = _make_extension(id_=None, name="") - with pytest.raises(ValueError, match="name must not be empty"): - APIBasedExtensionService._validation(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_name_is_none(self, mock_db): - """None name raises ValueError.""" - ext = _make_extension(id_=None, name=None) - with pytest.raises(ValueError, match="name must not be empty"): - APIBasedExtensionService._validation(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_name_already_exists_for_new_record(self, mock_db): - """A new record whose name already exists raises ValueError.""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = ( - MagicMock() - ) - ext = _make_extension(id_=None, name="duplicate-name") - - with pytest.raises(ValueError, match="name must be unique, it is already existed"): - APIBasedExtensionService._validation(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_api_endpoint_is_empty(self, mock_db): - """Empty api_endpoint raises ValueError.""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None - ext = _make_extension(id_=None, api_endpoint="") - - with pytest.raises(ValueError, match="api_endpoint must not be empty"): - APIBasedExtensionService._validation(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_api_endpoint_is_none(self, mock_db): - """None api_endpoint raises ValueError.""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None - ext = _make_extension(id_=None, api_endpoint=None) - - with pytest.raises(ValueError, match="api_endpoint must not be empty"): - APIBasedExtensionService._validation(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_api_key_is_empty(self, mock_db): - """Empty api_key raises ValueError.""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None - ext = _make_extension(id_=None, api_key="") - - with pytest.raises(ValueError, match="api_key must not be empty"): - APIBasedExtensionService._validation(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_api_key_is_none(self, mock_db): - """None api_key raises ValueError.""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None - ext = _make_extension(id_=None, api_key=None) - - with pytest.raises(ValueError, match="api_key must not be empty"): - APIBasedExtensionService._validation(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_api_key_too_short(self, mock_db): - """api_key shorter than 5 characters raises ValueError.""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None - ext = _make_extension(id_=None, api_key="abc") - - with pytest.raises(ValueError, match="api_key must be at least 5 characters"): - APIBasedExtensionService._validation(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_api_key_exactly_four_chars(self, mock_db): - """api_key with exactly 4 characters raises ValueError (boundary condition).""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None - ext = _make_extension(id_=None, api_key="1234") - - with pytest.raises(ValueError, match="api_key must be at least 5 characters"): - APIBasedExtensionService._validation(ext) - - @patch.object(APIBasedExtensionService, "_ping_connection") - @patch("services.api_based_extension_service.db") - def test_api_key_exactly_five_chars_is_accepted(self, mock_db, mock_ping): - """api_key with exactly 5 characters should pass (boundary condition).""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None - ext = _make_extension(id_=None, api_key="12345") - - # Should not raise - APIBasedExtensionService._validation(ext) - - -# --------------------------------------------------------------------------- -# Tests: _validation (existing record — id is truthy) -# --------------------------------------------------------------------------- - - -class TestValidationExistingRecord: - """Tests for _validation() with an existing record (id is set).""" - - @patch.object(APIBasedExtensionService, "_ping_connection") - @patch("services.api_based_extension_service.db") - def test_valid_existing_extension_passes(self, mock_db, mock_ping): - """An existing record whose name is unique (excluding self) should pass.""" - # .where(...).first() → None means no *other* record has that name - ( - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.where.return_value.first.return_value - ) = None - ext = _make_extension(id_="existing-id", name="unique-name", api_key="longenoughkey") - - # Should not raise - APIBasedExtensionService._validation(ext) - mock_ping.assert_called_once_with(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_existing_record_name_conflicts_with_another(self, mock_db): - """Existing record cannot use a name already owned by a different record.""" - ( - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.where.return_value.first.return_value - ) = MagicMock() - ext = _make_extension(id_="existing-id", name="taken-name") - - with pytest.raises(ValueError, match="name must be unique, it is already existed"): - APIBasedExtensionService._validation(ext) - - -# --------------------------------------------------------------------------- -# Tests: _ping_connection -# --------------------------------------------------------------------------- - - -class TestPingConnection: - """Tests for APIBasedExtensionService._ping_connection.""" - - @patch("services.api_based_extension_service.APIBasedExtensionRequestor") - def test_successful_ping_returns_pong(self, mock_requestor_class): - """When the endpoint returns {"result": "pong"}, no exception is raised.""" - mock_client = MagicMock() - mock_client.request.return_value = {"result": "pong"} - mock_requestor_class.return_value = mock_client - - ext = _make_extension(api_endpoint="https://ok.example.com", api_key="secret-key") - # Should not raise - APIBasedExtensionService._ping_connection(ext) - - mock_requestor_class.assert_called_once_with(ext.api_endpoint, ext.api_key) - - @patch("services.api_based_extension_service.APIBasedExtensionRequestor") - def test_wrong_ping_response_raises_value_error(self, mock_requestor_class): - """When the response is not {"result": "pong"}, a ValueError is raised.""" - mock_client = MagicMock() - mock_client.request.return_value = {"result": "error"} - mock_requestor_class.return_value = mock_client - - ext = _make_extension() - with pytest.raises(ValueError, match="connection error"): - APIBasedExtensionService._ping_connection(ext) - - @patch("services.api_based_extension_service.APIBasedExtensionRequestor") - def test_network_exception_wraps_in_value_error(self, mock_requestor_class): - """Any exception raised during request is wrapped in a ValueError.""" - mock_client = MagicMock() - mock_client.request.side_effect = ConnectionError("network failure") - mock_requestor_class.return_value = mock_client - - ext = _make_extension() - with pytest.raises(ValueError, match="connection error: network failure"): - APIBasedExtensionService._ping_connection(ext) - - @patch("services.api_based_extension_service.APIBasedExtensionRequestor") - def test_requestor_constructor_exception_wraps_in_value_error(self, mock_requestor_class): - """Exception raised by the requestor constructor itself is wrapped.""" - mock_requestor_class.side_effect = RuntimeError("bad config") - - ext = _make_extension() - with pytest.raises(ValueError, match="connection error: bad config"): - APIBasedExtensionService._ping_connection(ext) - - @patch("services.api_based_extension_service.APIBasedExtensionRequestor") - def test_missing_result_key_raises_value_error(self, mock_requestor_class): - """A response dict without a 'result' key does not equal 'pong' → raises.""" - mock_client = MagicMock() - mock_client.request.return_value = {} # no 'result' key - mock_requestor_class.return_value = mock_client - - ext = _make_extension() - with pytest.raises(ValueError, match="connection error"): - APIBasedExtensionService._ping_connection(ext) - - @patch("services.api_based_extension_service.APIBasedExtensionRequestor") - def test_uses_ping_extension_point(self, mock_requestor_class): - """The PING extension point is passed to the client.request call.""" - from models.api_based_extension import APIBasedExtensionPoint - - mock_client = MagicMock() - mock_client.request.return_value = {"result": "pong"} - mock_requestor_class.return_value = mock_client - - ext = _make_extension() - APIBasedExtensionService._ping_connection(ext) - - call_kwargs = mock_client.request.call_args - assert call_kwargs.kwargs["point"] == APIBasedExtensionPoint.PING - assert call_kwargs.kwargs["params"] == {} diff --git a/api/tests/unit_tests/services/test_app_dsl_service.py b/api/tests/unit_tests/services/test_app_dsl_service.py index 4f7d1840460..179518a5fad 100644 --- a/api/tests/unit_tests/services/test_app_dsl_service.py +++ b/api/tests/unit_tests/services/test_app_dsl_service.py @@ -4,13 +4,13 @@ from unittest.mock import MagicMock import pytest import yaml +from graphon.enums import BuiltinNodeTypes from core.trigger.constants import ( TRIGGER_PLUGIN_NODE_TYPE, TRIGGER_SCHEDULE_NODE_TYPE, TRIGGER_WEBHOOK_NODE_TYPE, ) -from dify_graph.enums import BuiltinNodeTypes from models import Account, AppMode from models.model import IconType from services import app_dsl_service @@ -211,6 +211,7 @@ def test_import_app_overwrite_only_allows_workflow_and_advanced_chat(monkeypatch def test_import_app_pending_stores_import_info_in_redis(): service = AppDslService(MagicMock()) + app_dsl_service.redis_client.setex.reset_mock() result = service.import_app( account=_account_mock(), import_mode=ImportMode.YAML_CONTENT, @@ -375,10 +376,13 @@ def test_confirm_import_success_deletes_redis_key(monkeypatch): created_app = SimpleNamespace(id="confirmed-app", mode=AppMode.WORKFLOW.value, tenant_id="tenant-1") monkeypatch.setattr(AppDslService, "_create_or_update_app", lambda *_args, **_kwargs: created_app) + app_dsl_service.redis_client.delete.reset_mock() result = service.confirm_import(import_id="import-1", account=_account_mock()) assert result.status == ImportStatus.COMPLETED assert result.app_id == "confirmed-app" - app_dsl_service.redis_client.delete.assert_called_once() + app_dsl_service.redis_client.delete.assert_called_once_with( + f"{app_dsl_service.IMPORT_INFO_REDIS_KEY_PREFIX}import-1" + ) def test_confirm_import_exception_returns_failed(monkeypatch): diff --git a/api/tests/unit_tests/services/test_app_service.py b/api/tests/unit_tests/services/test_app_service.py deleted file mode 100644 index bff8dc92c6e..00000000000 --- a/api/tests/unit_tests/services/test_app_service.py +++ /dev/null @@ -1,609 +0,0 @@ -"""Unit tests for services.app_service.""" - -import json -from types import SimpleNamespace -from typing import cast -from unittest.mock import MagicMock, patch - -import pytest - -from core.errors.error import ProviderTokenNotInitError -from models import Account, Tenant -from models.model import App, AppMode -from services.app_service import AppService - - -@pytest.fixture -def service() -> AppService: - """Provide AppService instance.""" - return AppService() - - -@pytest.fixture -def account() -> Account: - """Create account object for create_app tests.""" - tenant = Tenant(name="Tenant") - tenant.id = "tenant-1" - result = Account(name="Account User", email="account@example.com") - result.id = "acc-1" - result._current_tenant = tenant - return result - - -@pytest.fixture -def default_args() -> dict: - """Create default create_app args.""" - return { - "name": "Test App", - "mode": AppMode.CHAT.value, - "icon": "🤖", - "icon_background": "#FFFFFF", - } - - -@pytest.fixture -def app_template() -> dict: - """Create basic app template for create_app tests.""" - return { - AppMode.CHAT: { - "app": {}, - "model_config": { - "model": { - "provider": "provider-a", - "name": "model-a", - "mode": "chat", - "completion_params": {}, - } - }, - } - } - - -def _make_current_user() -> Account: - user = Account(name="Tester", email="tester@example.com") - user.id = "user-1" - tenant = Tenant(name="Tenant") - tenant.id = "tenant-1" - user._current_tenant = tenant - return user - - -class TestAppServicePagination: - """Test suite for get_paginate_apps.""" - - def test_get_paginate_apps_should_return_none_when_tag_filter_empty(self, service: AppService) -> None: - """Test pagination returns None when tag filter has no targets.""" - # Arrange - args = {"mode": "chat", "page": 1, "limit": 20, "tag_ids": ["tag-1"]} - - with patch("services.app_service.TagService.get_target_ids_by_tag_ids", return_value=[]): - # Act - result = service.get_paginate_apps("user-1", "tenant-1", args) - - # Assert - assert result is None - - def test_get_paginate_apps_should_delegate_to_db_paginate(self, service: AppService) -> None: - """Test pagination delegates to db.paginate when filters are valid.""" - # Arrange - args = { - "mode": "workflow", - "is_created_by_me": True, - "name": "My_App%", - "tag_ids": ["tag-1"], - "page": 2, - "limit": 10, - } - expected_pagination = MagicMock() - - with ( - patch("services.app_service.TagService.get_target_ids_by_tag_ids", return_value=["app-1"]), - patch("libs.helper.escape_like_pattern", return_value="escaped"), - patch("services.app_service.db") as mock_db, - ): - mock_db.paginate.return_value = expected_pagination - - # Act - result = service.get_paginate_apps("user-1", "tenant-1", args) - - # Assert - assert result is expected_pagination - mock_db.paginate.assert_called_once() - - -class TestAppServiceCreate: - """Test suite for create_app.""" - - def test_create_app_should_create_with_matching_default_model( - self, - service: AppService, - account: Account, - default_args: dict, - app_template: dict, - ) -> None: - """Test create_app uses matching default model and persists app config.""" - # Arrange - app_instance = SimpleNamespace(id="app-1", tenant_id="tenant-1") - app_model_config = SimpleNamespace(id="cfg-1") - model_instance = SimpleNamespace( - model_name="model-a", - provider="provider-a", - model_type_instance=MagicMock(), - credentials={"k": "v"}, - ) - - with ( - patch("services.app_service.default_app_templates", app_template), - patch("services.app_service.App", return_value=app_instance), - patch("services.app_service.AppModelConfig", return_value=app_model_config), - patch("services.app_service.ModelManager") as mock_model_manager, - patch("services.app_service.db") as mock_db, - patch("services.app_service.app_was_created") as mock_event, - patch("services.app_service.FeatureService.get_system_features") as mock_features, - patch("services.app_service.BillingService") as mock_billing, - patch("services.app_service.dify_config") as mock_config, - ): - manager = mock_model_manager.return_value - manager.get_default_model_instance.return_value = model_instance - mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) - mock_config.BILLING_ENABLED = True - - # Act - result = service.create_app("tenant-1", default_args, account) - - # Assert - assert result is app_instance - assert app_instance.app_model_config_id == "cfg-1" - mock_db.session.add.assert_any_call(app_instance) - mock_db.session.add.assert_any_call(app_model_config) - assert mock_db.session.flush.call_count == 2 - mock_db.session.commit.assert_called_once() - mock_event.send.assert_called_once_with(app_instance, account=account) - mock_billing.clean_billing_info_cache.assert_called_once_with("tenant-1") - - def test_create_app_should_raise_when_model_schema_missing( - self, - service: AppService, - account: Account, - default_args: dict, - app_template: dict, - ) -> None: - """Test create_app raises ValueError when non-matching model has no schema.""" - # Arrange - app_instance = SimpleNamespace(id="app-1") - model_instance = SimpleNamespace( - model_name="model-b", - provider="provider-b", - model_type_instance=MagicMock(), - credentials={"k": "v"}, - ) - model_instance.model_type_instance.get_model_schema.return_value = None - - with ( - patch("services.app_service.default_app_templates", app_template), - patch("services.app_service.App", return_value=app_instance), - patch("services.app_service.ModelManager") as mock_model_manager, - patch("services.app_service.db") as mock_db, - ): - manager = mock_model_manager.return_value - manager.get_default_model_instance.return_value = model_instance - - # Act & Assert - with pytest.raises(ValueError, match="model schema not found"): - service.create_app("tenant-1", default_args, account) - mock_db.session.commit.assert_not_called() - - def test_create_app_should_fallback_to_default_provider_when_model_missing( - self, - service: AppService, - account: Account, - default_args: dict, - app_template: dict, - ) -> None: - """Test create_app falls back to provider/model name when no default model instance is available.""" - # Arrange - app_instance = SimpleNamespace(id="app-1", tenant_id="tenant-1") - app_model_config = SimpleNamespace(id="cfg-1") - - with ( - patch("services.app_service.default_app_templates", app_template), - patch("services.app_service.App", return_value=app_instance), - patch("services.app_service.AppModelConfig", return_value=app_model_config), - patch("services.app_service.ModelManager") as mock_model_manager, - patch("services.app_service.db") as mock_db, - patch("services.app_service.app_was_created") as mock_event, - patch("services.app_service.FeatureService.get_system_features") as mock_features, - patch("services.app_service.EnterpriseService") as mock_enterprise, - patch("services.app_service.dify_config") as mock_config, - ): - manager = mock_model_manager.return_value - manager.get_default_model_instance.side_effect = ProviderTokenNotInitError("not ready") - manager.get_default_provider_model_name.return_value = ("fallback-provider", "fallback-model") - mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True)) - mock_config.BILLING_ENABLED = False - - # Act - result = service.create_app("tenant-1", default_args, account) - - # Assert - assert result is app_instance - mock_event.send.assert_called_once_with(app_instance, account=account) - mock_db.session.commit.assert_called_once() - mock_enterprise.WebAppAuth.update_app_access_mode.assert_called_once_with("app-1", "private") - - def test_create_app_should_log_and_fallback_on_unexpected_model_error( - self, - service: AppService, - account: Account, - default_args: dict, - app_template: dict, - ) -> None: - """Test unexpected model manager errors are logged and fallback provider is used.""" - # Arrange - app_instance = SimpleNamespace(id="app-1", tenant_id="tenant-1") - app_model_config = SimpleNamespace(id="cfg-1") - - with ( - patch("services.app_service.default_app_templates", app_template), - patch("services.app_service.App", return_value=app_instance), - patch("services.app_service.AppModelConfig", return_value=app_model_config), - patch("services.app_service.ModelManager") as mock_model_manager, - patch("services.app_service.db"), - patch("services.app_service.app_was_created"), - patch( - "services.app_service.FeatureService.get_system_features", - return_value=SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)), - ), - patch("services.app_service.dify_config", new=SimpleNamespace(BILLING_ENABLED=False)), - patch("services.app_service.logger") as mock_logger, - ): - manager = mock_model_manager.return_value - manager.get_default_model_instance.side_effect = RuntimeError("boom") - manager.get_default_provider_model_name.return_value = ("fallback-provider", "fallback-model") - - # Act - result = service.create_app("tenant-1", default_args, account) - - # Assert - assert result is app_instance - mock_logger.exception.assert_called_once() - - -class TestAppServiceGetAndUpdate: - """Test suite for app retrieval and update methods.""" - - def test_get_app_should_return_original_when_not_agent_app(self, service: AppService) -> None: - """Test get_app returns original app for non-agent modes.""" - # Arrange - app = MagicMock() - app.mode = AppMode.CHAT - app.is_agent = False - - with patch("services.app_service.current_user", _make_current_user()): - # Act - result = service.get_app(app) - - # Assert - assert result is app - - def test_get_app_should_return_original_when_model_config_missing(self, service: AppService) -> None: - """Test get_app returns app when agent mode has no model config.""" - # Arrange - app = MagicMock() - app.id = "app-1" - app.mode = AppMode.AGENT_CHAT - app.is_agent = False - app.app_model_config = None - - with patch("services.app_service.current_user", _make_current_user()): - # Act - result = service.get_app(app) - - # Assert - assert result is app - - def test_get_app_should_mask_tool_parameters_for_agent_tools(self, service: AppService) -> None: - """Test get_app decrypts and masks secret tool parameters.""" - # Arrange - tool = { - "provider_type": "builtin", - "provider_id": "provider-1", - "tool_name": "tool-a", - "tool_parameters": {"secret": "encrypted"}, - "extra": True, - } - model_config = MagicMock() - model_config.agent_mode_dict = {"tools": [tool, {"skip": True}]} - - app = MagicMock() - app.id = "app-1" - app.mode = AppMode.AGENT_CHAT - app.is_agent = False - app.app_model_config = model_config - - manager = MagicMock() - manager.decrypt_tool_parameters.return_value = {"secret": "decrypted"} - manager.mask_tool_parameters.return_value = {"secret": "***"} - - with ( - patch("services.app_service.current_user", _make_current_user()), - patch("services.app_service.ToolManager.get_agent_tool_runtime", return_value=MagicMock()), - patch("services.app_service.ToolParameterConfigurationManager", return_value=manager), - ): - # Act - result = service.get_app(app) - - # Assert - assert result.app_model_config is model_config - assert tool["tool_parameters"] == {"secret": "***"} - assert json.loads(model_config.agent_mode)["tools"][0]["tool_parameters"] == {"secret": "***"} - - def test_get_app_should_continue_when_tool_parameter_masking_fails(self, service: AppService) -> None: - """Test get_app logs and continues when masking fails.""" - # Arrange - tool = { - "provider_type": "builtin", - "provider_id": "provider-1", - "tool_name": "tool-a", - "tool_parameters": {"secret": "encrypted"}, - "extra": True, - } - model_config = MagicMock() - model_config.agent_mode_dict = {"tools": [tool]} - - app = MagicMock() - app.id = "app-1" - app.mode = AppMode.AGENT_CHAT - app.is_agent = False - app.app_model_config = model_config - - with ( - patch("services.app_service.current_user", _make_current_user()), - patch("services.app_service.ToolManager.get_agent_tool_runtime", side_effect=RuntimeError("mask-failed")), - patch("services.app_service.logger") as mock_logger, - ): - # Act - result = service.get_app(app) - - # Assert - assert result.app_model_config is model_config - mock_logger.exception.assert_called_once() - - def test_update_methods_should_mutate_app_and_commit(self, service: AppService) -> None: - """Test update methods set fields and commit changes.""" - # Arrange - app = cast( - App, - SimpleNamespace( - name="old", - description="old", - icon_type="emoji", - icon="a", - icon_background="#111", - enable_site=True, - enable_api=True, - ), - ) - args = { - "name": "new", - "description": "new-desc", - "icon_type": "image", - "icon": "new-icon", - "icon_background": "#222", - "use_icon_as_answer_icon": True, - "max_active_requests": 5, - } - user = SimpleNamespace(id="user-1") - - with ( - patch("services.app_service.current_user", user), - patch("services.app_service.db") as mock_db, - patch("services.app_service.naive_utc_now", return_value="now"), - ): - # Act - updated = service.update_app(app, args) - renamed = service.update_app_name(app, "rename") - iconed = service.update_app_icon(app, "icon-2", "#333") - site_same = service.update_app_site_status(app, app.enable_site) - api_same = service.update_app_api_status(app, app.enable_api) - site_changed = service.update_app_site_status(app, False) - api_changed = service.update_app_api_status(app, False) - - # Assert - assert updated is app - assert renamed is app - assert iconed is app - assert site_same is app - assert api_same is app - assert site_changed is app - assert api_changed is app - assert mock_db.session.commit.call_count >= 5 - - -class TestAppServiceDeleteAndMeta: - """Test suite for delete and metadata methods.""" - - def test_delete_app_should_cleanup_and_enqueue_task(self, service: AppService) -> None: - """Test delete_app removes app, runs cleanup, and triggers async deletion task.""" - # Arrange - app = cast(App, SimpleNamespace(id="app-1", tenant_id="tenant-1")) - - with ( - patch("services.app_service.db") as mock_db, - patch( - "services.app_service.FeatureService.get_system_features", - return_value=SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True)), - ), - patch("services.app_service.EnterpriseService") as mock_enterprise, - patch( - "services.app_service.dify_config", - new=SimpleNamespace(BILLING_ENABLED=True, CONSOLE_API_URL="https://console.example"), - ), - patch("services.app_service.BillingService") as mock_billing, - patch("services.app_service.remove_app_and_related_data_task") as mock_task, - ): - # Act - service.delete_app(app) - - # Assert - mock_db.session.delete.assert_called_once_with(app) - mock_db.session.commit.assert_called_once() - mock_enterprise.WebAppAuth.cleanup_webapp.assert_called_once_with("app-1") - mock_billing.clean_billing_info_cache.assert_called_once_with("tenant-1") - mock_task.delay.assert_called_once_with(tenant_id="tenant-1", app_id="app-1") - - def test_get_app_meta_should_handle_workflow_and_tool_provider_icons(self, service: AppService) -> None: - """Test get_app_meta extracts builtin and API tool icons from workflow graph.""" - # Arrange - workflow = SimpleNamespace( - graph_dict={ - "nodes": [ - { - "data": { - "type": "tool", - "provider_type": "builtin", - "provider_id": "builtin-provider", - "tool_name": "tool_builtin", - } - }, - { - "data": { - "type": "tool", - "provider_type": "api", - "provider_id": "api-provider-id", - "tool_name": "tool_api", - } - }, - ] - } - ) - app = cast( - App, - SimpleNamespace( - mode=AppMode.WORKFLOW.value, - workflow=workflow, - app_model_config=None, - tenant_id="tenant-1", - icon_type="emoji", - icon_background="#fff", - ), - ) - - provider = SimpleNamespace(icon=json.dumps({"background": "#000", "content": "A"})) - - with ( - patch("services.app_service.dify_config", new=SimpleNamespace(CONSOLE_API_URL="https://console.example")), - patch("services.app_service.db") as mock_db, - ): - query = MagicMock() - query.where.return_value = query - query.first.return_value = provider - mock_db.session.query.return_value = query - - # Act - meta = service.get_app_meta(app) - - # Assert - assert meta["tool_icons"]["tool_builtin"].endswith("/builtin-provider/icon") - assert meta["tool_icons"]["tool_api"] == {"background": "#000", "content": "A"} - - def test_get_app_meta_should_use_default_api_icon_on_lookup_error(self, service: AppService) -> None: - """Test get_app_meta falls back to default icon when API provider lookup fails.""" - # Arrange - app_model_config = SimpleNamespace( - agent_mode_dict={ - "tools": [{"provider_type": "api", "provider_id": "x", "tool_name": "t", "tool_parameters": {}}] - } - ) - app = cast(App, SimpleNamespace(mode=AppMode.CHAT.value, app_model_config=app_model_config, workflow=None)) - - with ( - patch("services.app_service.dify_config", new=SimpleNamespace(CONSOLE_API_URL="https://console.example")), - patch("services.app_service.db") as mock_db, - ): - query = MagicMock() - query.where.return_value = query - query.first.return_value = None - mock_db.session.query.return_value = query - - # Act - meta = service.get_app_meta(app) - - # Assert - assert meta["tool_icons"]["t"] == {"background": "#252525", "content": "\ud83d\ude01"} - - def test_get_app_meta_should_return_empty_when_required_data_missing(self, service: AppService) -> None: - """Test get_app_meta returns empty metadata when workflow/model config is absent.""" - # Arrange - workflow_app = cast(App, SimpleNamespace(mode=AppMode.WORKFLOW.value, workflow=None)) - chat_app = cast(App, SimpleNamespace(mode=AppMode.CHAT.value, app_model_config=None)) - - # Act - workflow_meta = service.get_app_meta(workflow_app) - chat_meta = service.get_app_meta(chat_app) - - # Assert - assert workflow_meta == {"tool_icons": {}} - assert chat_meta == {"tool_icons": {}} - - -class TestAppServiceCodeLookup: - """Test suite for app code lookup methods.""" - - def test_get_app_code_by_id_should_raise_when_site_missing(self) -> None: - """Test get_app_code_by_id raises when site is missing.""" - # Arrange - with patch("services.app_service.db") as mock_db: - query = MagicMock() - query.where.return_value = query - query.first.return_value = None - mock_db.session.query.return_value = query - - # Act & Assert - with pytest.raises(ValueError, match="not found"): - AppService.get_app_code_by_id("app-1") - - def test_get_app_code_by_id_should_return_code(self) -> None: - """Test get_app_code_by_id returns site code.""" - # Arrange - site = SimpleNamespace(code="code-1") - with patch("services.app_service.db") as mock_db: - query = MagicMock() - query.where.return_value = query - query.first.return_value = site - mock_db.session.query.return_value = query - - # Act - result = AppService.get_app_code_by_id("app-1") - - # Assert - assert result == "code-1" - - def test_get_app_id_by_code_should_raise_when_site_missing(self) -> None: - """Test get_app_id_by_code raises when code does not exist.""" - # Arrange - with patch("services.app_service.db") as mock_db: - query = MagicMock() - query.where.return_value = query - query.first.return_value = None - mock_db.session.query.return_value = query - - # Act & Assert - with pytest.raises(ValueError, match="not found"): - AppService.get_app_id_by_code("missing") - - def test_get_app_id_by_code_should_return_app_id(self) -> None: - """Test get_app_id_by_code returns linked app id.""" - # Arrange - site = SimpleNamespace(app_id="app-1") - with patch("services.app_service.db") as mock_db: - query = MagicMock() - query.where.return_value = query - query.first.return_value = site - mock_db.session.query.return_value = query - - # Act - result = AppService.get_app_id_by_code("code-1") - - # Assert - assert result == "app-1" diff --git a/api/tests/unit_tests/services/test_attachment_service.py b/api/tests/unit_tests/services/test_attachment_service.py deleted file mode 100644 index 88be20bc41e..00000000000 --- a/api/tests/unit_tests/services/test_attachment_service.py +++ /dev/null @@ -1,73 +0,0 @@ -import base64 -from unittest.mock import MagicMock, patch - -import pytest -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker -from werkzeug.exceptions import NotFound - -import services.attachment_service as attachment_service_module -from models.model import UploadFile -from services.attachment_service import AttachmentService - - -class TestAttachmentService: - def test_should_initialize_with_sessionmaker_when_sessionmaker_is_provided(self): - """Test that AttachmentService keeps the provided sessionmaker instance.""" - session_factory = sessionmaker() - - service = AttachmentService(session_factory=session_factory) - - assert service._session_maker is session_factory - - def test_should_initialize_with_bound_sessionmaker_when_engine_is_provided(self): - """Test that AttachmentService builds a sessionmaker bound to the provided engine.""" - engine = create_engine("sqlite:///:memory:") - - service = AttachmentService(session_factory=engine) - session = service._session_maker() - try: - assert session.bind == engine - finally: - session.close() - engine.dispose() - - @pytest.mark.parametrize("invalid_session_factory", [None, "not-a-session-factory", 1]) - def test_should_raise_assertion_error_when_session_factory_type_is_invalid(self, invalid_session_factory): - """Test that invalid session_factory types are rejected.""" - with pytest.raises(AssertionError, match="must be a sessionmaker or an Engine."): - AttachmentService(session_factory=invalid_session_factory) - - def test_should_return_base64_encoded_blob_when_file_exists(self): - """Test that existing files are loaded from storage and returned as base64.""" - service = AttachmentService(session_factory=sessionmaker()) - upload_file = MagicMock(spec=UploadFile) - upload_file.key = "upload-file-key" - - session = MagicMock() - session.query.return_value.where.return_value.first.return_value = upload_file - service._session_maker = MagicMock(return_value=session) - - with patch.object(attachment_service_module.storage, "load_once", return_value=b"binary-content") as mock_load: - result = service.get_file_base64("file-123") - - assert result == base64.b64encode(b"binary-content").decode() - service._session_maker.assert_called_once_with(expire_on_commit=False) - session.query.assert_called_once_with(UploadFile) - mock_load.assert_called_once_with("upload-file-key") - - def test_should_raise_not_found_when_file_does_not_exist(self): - """Test that missing files raise NotFound and never call storage.""" - service = AttachmentService(session_factory=sessionmaker()) - - session = MagicMock() - session.query.return_value.where.return_value.first.return_value = None - service._session_maker = MagicMock(return_value=session) - - with patch.object(attachment_service_module.storage, "load_once") as mock_load: - with pytest.raises(NotFound, match="File not found"): - service.get_file_base64("missing-file") - - service._session_maker.assert_called_once_with(expire_on_commit=False) - session.query.assert_called_once_with(UploadFile) - mock_load.assert_not_called() diff --git a/api/tests/unit_tests/services/test_audio_service.py b/api/tests/unit_tests/services/test_audio_service.py index 5d674691057..175fd3ee016 100644 --- a/api/tests/unit_tests/services/test_audio_service.py +++ b/api/tests/unit_tests/services/test_audio_service.py @@ -214,7 +214,7 @@ def factory(): class TestAudioServiceASR: """Test speech-to-text (ASR) operations.""" - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_asr_success_chat_mode(self, mock_model_manager_class, factory): """Test successful ASR transcription in CHAT mode.""" # Arrange @@ -237,10 +237,9 @@ class TestAudioServiceASR: # Assert assert result == {"text": "Transcribed text"} mock_model_instance.invoke_speech2text.assert_called_once() - call_args = mock_model_instance.invoke_speech2text.call_args - assert call_args.kwargs["user"] == "user-123" + mock_model_manager_class.assert_called_once_with(tenant_id=app.tenant_id, user_id="user-123") - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_asr_success_advanced_chat_mode(self, mock_model_manager_class, factory): """Test successful ASR transcription in ADVANCED_CHAT mode.""" # Arrange @@ -347,7 +346,7 @@ class TestAudioServiceASR: with pytest.raises(AudioTooLargeServiceError, match="Audio size larger than 30 mb"): AudioService.transcript_asr(app_model=app, file=file) - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_asr_raises_error_when_no_model_instance(self, mock_model_manager_class, factory): """Test that ASR raises error when no model instance is available.""" # Arrange @@ -370,7 +369,7 @@ class TestAudioServiceASR: class TestAudioServiceTTS: """Test text-to-speech (TTS) operations.""" - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_with_text_success(self, mock_model_manager_class, factory): """Test successful TTS with text input.""" # Arrange @@ -398,15 +397,14 @@ class TestAudioServiceTTS: # Assert assert result == b"audio data" + mock_model_manager_class.assert_called_once_with(tenant_id=app.tenant_id, user_id="user-123") mock_model_instance.invoke_tts.assert_called_once_with( content_text="Hello world", - user="user-123", - tenant_id=app.tenant_id, voice="en-US-Neural", ) @patch("services.audio_service.db.session", autospec=True) - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_with_message_id_success(self, mock_model_manager_class, mock_db_session, factory): """Test successful TTS with message ID.""" # Arrange @@ -445,7 +443,7 @@ class TestAudioServiceTTS: assert result == b"audio from message" mock_model_instance.invoke_tts.assert_called_once() - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_with_default_voice(self, mock_model_manager_class, factory): """Test TTS uses default voice when none specified.""" # Arrange @@ -475,7 +473,7 @@ class TestAudioServiceTTS: call_args = mock_model_instance.invoke_tts.call_args assert call_args.kwargs["voice"] == "default-voice" - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_gets_first_available_voice_when_none_configured(self, mock_model_manager_class, factory): """Test TTS gets first available voice when none is configured.""" # Arrange @@ -506,7 +504,7 @@ class TestAudioServiceTTS: assert call_args.kwargs["voice"] == "auto-voice" @patch("services.audio_service.WorkflowService", autospec=True) - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_workflow_mode_with_draft( self, mock_model_manager_class, mock_workflow_service_class, factory ): @@ -549,7 +547,7 @@ class TestAudioServiceTTS: with pytest.raises(ValueError, match="Text is required"): AudioService.transcript_tts(app_model=app, text=None) - @patch("services.audio_service.db.session", autospec=True) + @patch("services.audio_service.db.session") def test_transcript_tts_returns_none_for_invalid_message_id(self, mock_db_session, factory): """Test that TTS returns None for invalid message ID format.""" # Arrange @@ -564,7 +562,7 @@ class TestAudioServiceTTS: # Assert assert result is None - @patch("services.audio_service.db.session", autospec=True) + @patch("services.audio_service.db.session") def test_transcript_tts_returns_none_for_nonexistent_message(self, mock_db_session, factory): """Test that TTS returns None when message doesn't exist.""" # Arrange @@ -585,7 +583,7 @@ class TestAudioServiceTTS: # Assert assert result is None - @patch("services.audio_service.db.session", autospec=True) + @patch("services.audio_service.db.session") def test_transcript_tts_returns_none_for_empty_message_answer(self, mock_db_session, factory): """Test that TTS returns None when message answer is empty.""" # Arrange @@ -611,7 +609,7 @@ class TestAudioServiceTTS: # Assert assert result is None - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_raises_error_when_no_voices_available(self, mock_model_manager_class, factory): """Test that TTS raises error when no voices are available.""" # Arrange @@ -637,7 +635,7 @@ class TestAudioServiceTTS: class TestAudioServiceTTSVoices: """Test TTS voice listing operations.""" - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_voices_success(self, mock_model_manager_class, factory): """Test successful retrieval of TTS voices.""" # Arrange @@ -662,7 +660,7 @@ class TestAudioServiceTTSVoices: assert result == expected_voices mock_model_instance.get_tts_voices.assert_called_once_with(language) - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_voices_raises_error_when_no_model_instance(self, mock_model_manager_class, factory): """Test that TTS voices raises error when no model instance is available.""" # Arrange @@ -677,7 +675,7 @@ class TestAudioServiceTTSVoices: with pytest.raises(ProviderNotSupportTextToSpeechServiceError): AudioService.transcript_tts_voices(tenant_id=tenant_id, language=language) - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_voices_propagates_exceptions(self, mock_model_manager_class, factory): """Test that TTS voices propagates exceptions from model instance.""" # Arrange diff --git a/api/tests/unit_tests/services/test_billing_service.py b/api/tests/unit_tests/services/test_billing_service.py index eecb3c7672d..316381f0ca1 100644 --- a/api/tests/unit_tests/services/test_billing_service.py +++ b/api/tests/unit_tests/services/test_billing_service.py @@ -1303,6 +1303,24 @@ class TestBillingServiceSubscriptionOperations: # Assert assert result == {} + def test_get_plan_bulk_converts_string_expiration_date_to_int(self, mock_send_request): + """Test bulk plan retrieval converts string expiration_date to int.""" + # Arrange + tenant_ids = ["tenant-1"] + mock_send_request.return_value = { + "data": { + "tenant-1": {"plan": "sandbox", "expiration_date": "1735689600"}, + } + } + + # Act + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert + assert "tenant-1" in result + assert isinstance(result["tenant-1"]["expiration_date"], int) + assert result["tenant-1"]["expiration_date"] == 1735689600 + def test_get_plan_bulk_with_invalid_tenant_plan_skipped(self, mock_send_request): """Test bulk plan retrieval when one tenant has invalid plan data (should skip that tenant).""" # Arrange diff --git a/api/tests/unit_tests/services/test_conversation_service.py b/api/tests/unit_tests/services/test_conversation_service.py index 75551531a23..1bf4c0e1721 100644 --- a/api/tests/unit_tests/services/test_conversation_service.py +++ b/api/tests/unit_tests/services/test_conversation_service.py @@ -6,15 +6,17 @@ Tests are organized by functionality and include edge cases, error handling, and both positive and negative test scenarios. """ -from datetime import datetime, timedelta +from datetime import timedelta from unittest.mock import MagicMock, Mock, create_autospec, patch import pytest from sqlalchemy import asc, desc from core.app.entities.app_invoke_entities import InvokeFrom +from libs.datetime_utils import naive_utc_now from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account, ConversationVariable +from models.enums import ConversationFromSource from models.model import App, Conversation, EndUser, Message from services.conversation_service import ConversationService from services.errors.conversation import ( @@ -121,8 +123,8 @@ class ConversationServiceTestDataFactory: conversation.is_deleted = kwargs.get("is_deleted", False) conversation.name = kwargs.get("name", "Test Conversation") conversation.status = kwargs.get("status", "normal") - conversation.created_at = kwargs.get("created_at", datetime.utcnow()) - conversation.updated_at = kwargs.get("updated_at", datetime.utcnow()) + conversation.created_at = kwargs.get("created_at", naive_utc_now()) + conversation.updated_at = kwargs.get("updated_at", naive_utc_now()) for key, value in kwargs.items(): setattr(conversation, key, value) return conversation @@ -151,7 +153,7 @@ class ConversationServiceTestDataFactory: message.conversation_id = conversation_id message.app_id = app_id message.query = kwargs.get("query", "Test message content") - message.created_at = kwargs.get("created_at", datetime.utcnow()) + message.created_at = kwargs.get("created_at", naive_utc_now()) for key, value in kwargs.items(): setattr(message, key, value) return message @@ -180,8 +182,8 @@ class ConversationServiceTestDataFactory: variable.conversation_id = conversation_id variable.app_id = app_id variable.data = {"name": kwargs.get("name", "test_var"), "value": kwargs.get("value", "test_value")} - variable.created_at = kwargs.get("created_at", datetime.utcnow()) - variable.updated_at = kwargs.get("updated_at", datetime.utcnow()) + variable.created_at = kwargs.get("created_at", naive_utc_now()) + variable.updated_at = kwargs.get("updated_at", naive_utc_now()) # Mock to_variable method mock_variable = Mock() @@ -301,7 +303,7 @@ class TestConversationServiceHelpers: """ # Arrange mock_conversation = ConversationServiceTestDataFactory.create_conversation_mock() - mock_conversation.updated_at = datetime.utcnow() + mock_conversation.updated_at = naive_utc_now() # Act condition = ConversationService._build_filter_condition( @@ -322,7 +324,7 @@ class TestConversationServiceHelpers: """ # Arrange mock_conversation = ConversationServiceTestDataFactory.create_conversation_mock() - mock_conversation.created_at = datetime.utcnow() + mock_conversation.created_at = naive_utc_now() # Act condition = ConversationService._build_filter_condition( @@ -350,7 +352,7 @@ class TestConversationServiceGetConversation: app_model = ConversationServiceTestDataFactory.create_app_mock() user = ConversationServiceTestDataFactory.create_account_mock() conversation = ConversationServiceTestDataFactory.create_conversation_mock( - from_account_id=user.id, from_source="console" + from_account_id=user.id, from_source=ConversationFromSource.CONSOLE ) mock_query = mock_db_session.query.return_value @@ -374,7 +376,7 @@ class TestConversationServiceGetConversation: app_model = ConversationServiceTestDataFactory.create_app_mock() user = ConversationServiceTestDataFactory.create_end_user_mock() conversation = ConversationServiceTestDataFactory.create_conversation_mock( - from_end_user_id=user.id, from_source="api" + from_end_user_id=user.id, from_source=ConversationFromSource.API ) mock_query = mock_db_session.query.return_value @@ -667,9 +669,9 @@ class TestConversationServiceConversationalVariable: mock_session_factory.create_session.return_value.__enter__.return_value = mock_session last_variable = ConversationServiceTestDataFactory.create_conversation_variable_mock( - created_at=datetime.utcnow() - timedelta(hours=1) + created_at=naive_utc_now() - timedelta(hours=1) ) - variable = ConversationServiceTestDataFactory.create_conversation_variable_mock(created_at=datetime.utcnow()) + variable = ConversationServiceTestDataFactory.create_conversation_variable_mock(created_at=naive_utc_now()) mock_session.scalar.return_value = last_variable mock_session.scalars.return_value.all.return_value = [variable] @@ -1111,7 +1113,7 @@ class TestConversationServiceEdgeCases: mock_session_factory.create_session.return_value.__enter__.return_value = mock_session conversation = ConversationServiceTestDataFactory.create_conversation_mock( - from_source="api", from_end_user_id="user-123" + from_source=ConversationFromSource.API, from_end_user_id="user-123" ) mock_session.scalars.return_value.all.return_value = [conversation] @@ -1143,7 +1145,7 @@ class TestConversationServiceEdgeCases: mock_session_factory.create_session.return_value.__enter__.return_value = mock_session conversation = ConversationServiceTestDataFactory.create_conversation_mock( - from_source="console", from_account_id="account-123" + from_source=ConversationFromSource.CONSOLE, from_account_id="account-123" ) mock_session.scalars.return_value.all.return_value = [conversation] diff --git a/api/tests/unit_tests/services/test_conversation_variable_updater.py b/api/tests/unit_tests/services/test_conversation_variable_updater.py deleted file mode 100644 index 20f7caa78e9..00000000000 --- a/api/tests/unit_tests/services/test_conversation_variable_updater.py +++ /dev/null @@ -1,75 +0,0 @@ -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest - -from dify_graph.variables import StringVariable -from services.conversation_variable_updater import ConversationVariableNotFoundError, ConversationVariableUpdater - - -class TestConversationVariableUpdater: - def test_should_update_conversation_variable_data_and_commit(self): - """Test update persists serialized variable data when the row exists.""" - conversation_id = "conv-123" - variable = StringVariable( - id="var-123", - name="topic", - value="new value", - ) - expected_json = variable.model_dump_json() - - row = SimpleNamespace(data="old value") - session = MagicMock() - session.scalar.return_value = row - - session_context = MagicMock() - session_context.__enter__.return_value = session - session_context.__exit__.return_value = None - - session_maker = MagicMock(return_value=session_context) - updater = ConversationVariableUpdater(session_maker) - - updater.update(conversation_id=conversation_id, variable=variable) - - session_maker.assert_called_once_with() - session.scalar.assert_called_once() - stmt = session.scalar.call_args.args[0] - compiled_params = stmt.compile().params - assert variable.id in compiled_params.values() - assert conversation_id in compiled_params.values() - assert row.data == expected_json - session.commit.assert_called_once() - - def test_should_raise_not_found_error_when_conversation_variable_missing(self): - """Test update raises ConversationVariableNotFoundError when no matching row exists.""" - conversation_id = "conv-404" - variable = StringVariable( - id="var-404", - name="topic", - value="value", - ) - - session = MagicMock() - session.scalar.return_value = None - - session_context = MagicMock() - session_context.__enter__.return_value = session - session_context.__exit__.return_value = None - - session_maker = MagicMock(return_value=session_context) - updater = ConversationVariableUpdater(session_maker) - - with pytest.raises(ConversationVariableNotFoundError, match="conversation variable not found in the database"): - updater.update(conversation_id=conversation_id, variable=variable) - - session.commit.assert_not_called() - - def test_should_do_nothing_when_flush_is_called(self): - """Test flush currently behaves as a no-op and returns None.""" - session_maker = MagicMock() - updater = ConversationVariableUpdater(session_maker) - - result = updater.flush() - - assert result is None - session_maker.assert_not_called() diff --git a/api/tests/unit_tests/services/test_credit_pool_service.py b/api/tests/unit_tests/services/test_credit_pool_service.py deleted file mode 100644 index 9ef314cb9ed..00000000000 --- a/api/tests/unit_tests/services/test_credit_pool_service.py +++ /dev/null @@ -1,157 +0,0 @@ -from types import SimpleNamespace -from unittest.mock import MagicMock, patch - -import pytest - -import services.credit_pool_service as credit_pool_service_module -from core.errors.error import QuotaExceededError -from models import TenantCreditPool -from services.credit_pool_service import CreditPoolService - - -@pytest.fixture -def mock_credit_deduction_setup(): - """Fixture providing common setup for credit deduction tests.""" - pool = SimpleNamespace(remaining_credits=50) - fake_engine = MagicMock() - session = MagicMock() - session_context = MagicMock() - session_context.__enter__.return_value = session - session_context.__exit__.return_value = None - - mock_get_pool = patch.object(CreditPoolService, "get_pool", return_value=pool) - mock_db = patch.object(credit_pool_service_module, "db", new=SimpleNamespace(engine=fake_engine)) - mock_session = patch.object(credit_pool_service_module, "Session", return_value=session_context) - - return { - "pool": pool, - "fake_engine": fake_engine, - "session": session, - "session_context": session_context, - "patches": (mock_get_pool, mock_db, mock_session), - } - - -class TestCreditPoolService: - def test_should_create_default_pool_with_trial_type_and_configured_quota(self): - """Test create_default_pool persists a trial pool using configured hosted credits.""" - tenant_id = "tenant-123" - hosted_pool_credits = 5000 - - with ( - patch.object(credit_pool_service_module.dify_config, "HOSTED_POOL_CREDITS", hosted_pool_credits), - patch.object(credit_pool_service_module, "db") as mock_db, - ): - pool = CreditPoolService.create_default_pool(tenant_id) - - assert isinstance(pool, TenantCreditPool) - assert pool.tenant_id == tenant_id - assert pool.pool_type == "trial" - assert pool.quota_limit == hosted_pool_credits - assert pool.quota_used == 0 - mock_db.session.add.assert_called_once_with(pool) - mock_db.session.commit.assert_called_once() - - def test_should_return_first_pool_from_query_when_get_pool_called(self): - """Test get_pool queries by tenant and pool_type and returns first result.""" - tenant_id = "tenant-123" - pool_type = "enterprise" - expected_pool = MagicMock(spec=TenantCreditPool) - - with patch.object(credit_pool_service_module, "db") as mock_db: - query = mock_db.session.query.return_value - filtered_query = query.filter_by.return_value - filtered_query.first.return_value = expected_pool - - result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type=pool_type) - - assert result == expected_pool - mock_db.session.query.assert_called_once_with(TenantCreditPool) - query.filter_by.assert_called_once_with(tenant_id=tenant_id, pool_type=pool_type) - filtered_query.first.assert_called_once() - - def test_should_return_false_when_pool_not_found_in_check_credits_available(self): - """Test check_credits_available returns False when tenant has no pool.""" - with patch.object(CreditPoolService, "get_pool", return_value=None) as mock_get_pool: - result = CreditPoolService.check_credits_available(tenant_id="tenant-123", credits_required=10) - - assert result is False - mock_get_pool.assert_called_once_with("tenant-123", "trial") - - def test_should_return_true_when_remaining_credits_cover_required_amount(self): - """Test check_credits_available returns True when remaining credits are sufficient.""" - pool = SimpleNamespace(remaining_credits=100) - - with patch.object(CreditPoolService, "get_pool", return_value=pool) as mock_get_pool: - result = CreditPoolService.check_credits_available(tenant_id="tenant-123", credits_required=60) - - assert result is True - mock_get_pool.assert_called_once_with("tenant-123", "trial") - - def test_should_return_false_when_remaining_credits_are_insufficient(self): - """Test check_credits_available returns False when required credits exceed remaining credits.""" - pool = SimpleNamespace(remaining_credits=30) - - with patch.object(CreditPoolService, "get_pool", return_value=pool): - result = CreditPoolService.check_credits_available(tenant_id="tenant-123", credits_required=60) - - assert result is False - - def test_should_raise_quota_exceeded_when_pool_not_found_in_check_and_deduct(self): - """Test check_and_deduct_credits raises when tenant credit pool does not exist.""" - with patch.object(CreditPoolService, "get_pool", return_value=None): - with pytest.raises(QuotaExceededError, match="Credit pool not found"): - CreditPoolService.check_and_deduct_credits(tenant_id="tenant-123", credits_required=10) - - def test_should_raise_quota_exceeded_when_pool_has_no_remaining_credits(self): - """Test check_and_deduct_credits raises when remaining credits are zero or negative.""" - pool = SimpleNamespace(remaining_credits=0) - - with patch.object(CreditPoolService, "get_pool", return_value=pool): - with pytest.raises(QuotaExceededError, match="No credits remaining"): - CreditPoolService.check_and_deduct_credits(tenant_id="tenant-123", credits_required=10) - - def test_should_deduct_minimum_of_required_and_remaining_credits(self, mock_credit_deduction_setup): - """Test check_and_deduct_credits updates quota_used by the actual deducted amount.""" - tenant_id = "tenant-123" - pool_type = "trial" - credits_required = 200 - remaining_credits = 120 - expected_deducted_credits = 120 - - mock_credit_deduction_setup["pool"].remaining_credits = remaining_credits - patches = mock_credit_deduction_setup["patches"] - session = mock_credit_deduction_setup["session"] - - with patches[0], patches[1], patches[2]: - result = CreditPoolService.check_and_deduct_credits( - tenant_id=tenant_id, - credits_required=credits_required, - pool_type=pool_type, - ) - - assert result == expected_deducted_credits - session.execute.assert_called_once() - session.commit.assert_called_once() - - stmt = session.execute.call_args.args[0] - compiled_params = stmt.compile().params - assert tenant_id in compiled_params.values() - assert pool_type in compiled_params.values() - assert expected_deducted_credits in compiled_params.values() - - def test_should_raise_quota_exceeded_when_deduction_update_fails(self, mock_credit_deduction_setup): - """Test check_and_deduct_credits translates DB update failures to QuotaExceededError.""" - mock_credit_deduction_setup["pool"].remaining_credits = 50 - mock_credit_deduction_setup["session"].execute.side_effect = Exception("db failure") - session = mock_credit_deduction_setup["session"] - - patches = mock_credit_deduction_setup["patches"] - mock_logger = patch.object(credit_pool_service_module, "logger") - - with patches[0], patches[1], patches[2], mock_logger as mock_logger_obj: - with pytest.raises(QuotaExceededError, match="Failed to deduct credits"): - CreditPoolService.check_and_deduct_credits(tenant_id="tenant-123", credits_required=10) - - session.commit.assert_not_called() - mock_logger_obj.exception.assert_called_once() diff --git a/api/tests/unit_tests/services/test_dataset_permission.py b/api/tests/unit_tests/services/test_dataset_permission.py deleted file mode 100644 index 4974d6c1ef5..00000000000 --- a/api/tests/unit_tests/services/test_dataset_permission.py +++ /dev/null @@ -1,305 +0,0 @@ -from unittest.mock import Mock, patch - -import pytest - -from models.account import Account, TenantAccountRole -from models.dataset import Dataset, DatasetPermission, DatasetPermissionEnum -from services.dataset_service import DatasetService -from services.errors.account import NoPermissionError - - -class DatasetPermissionTestDataFactory: - """Factory class for creating test data and mock objects for dataset permission tests.""" - - @staticmethod - def create_dataset_mock( - dataset_id: str = "dataset-123", - tenant_id: str = "test-tenant-123", - created_by: str = "creator-456", - permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME, - **kwargs, - ) -> Mock: - """Create a mock dataset with specified attributes.""" - dataset = Mock(spec=Dataset) - dataset.id = dataset_id - dataset.tenant_id = tenant_id - dataset.created_by = created_by - dataset.permission = permission - for key, value in kwargs.items(): - setattr(dataset, key, value) - return dataset - - @staticmethod - def create_user_mock( - user_id: str = "user-789", - tenant_id: str = "test-tenant-123", - role: TenantAccountRole = TenantAccountRole.NORMAL, - **kwargs, - ) -> Mock: - """Create a mock user with specified attributes.""" - user = Mock(spec=Account) - user.id = user_id - user.current_tenant_id = tenant_id - user.current_role = role - for key, value in kwargs.items(): - setattr(user, key, value) - return user - - @staticmethod - def create_dataset_permission_mock( - dataset_id: str = "dataset-123", - account_id: str = "user-789", - **kwargs, - ) -> Mock: - """Create a mock dataset permission record.""" - permission = Mock(spec=DatasetPermission) - permission.dataset_id = dataset_id - permission.account_id = account_id - for key, value in kwargs.items(): - setattr(permission, key, value) - return permission - - -class TestDatasetPermissionService: - """ - Comprehensive unit tests for DatasetService.check_dataset_permission method. - - This test suite covers all permission scenarios including: - - Cross-tenant access restrictions - - Owner privilege checks - - Different permission levels (ONLY_ME, ALL_TEAM, PARTIAL_TEAM) - - Explicit permission checks for PARTIAL_TEAM - - Error conditions and logging - """ - - @pytest.fixture - def mock_dataset_service_dependencies(self): - """Common mock setup for dataset service dependencies.""" - with patch("services.dataset_service.db.session") as mock_session: - yield { - "db_session": mock_session, - } - - @pytest.fixture - def mock_logging_dependencies(self): - """Mock setup for logging tests.""" - with patch("services.dataset_service.logger") as mock_logging: - yield { - "logging": mock_logging, - } - - def _assert_permission_check_passes(self, dataset: Mock, user: Mock): - """Helper method to verify that permission check passes without raising exceptions.""" - # Should not raise any exception - DatasetService.check_dataset_permission(dataset, user) - - def _assert_permission_check_fails( - self, dataset: Mock, user: Mock, expected_message: str = "You do not have permission to access this dataset." - ): - """Helper method to verify that permission check fails with expected error.""" - with pytest.raises(NoPermissionError, match=expected_message): - DatasetService.check_dataset_permission(dataset, user) - - def _assert_database_query_called(self, mock_session: Mock, dataset_id: str, account_id: str): - """Helper method to verify database query calls for permission checks.""" - mock_session.query().filter_by.assert_called_with(dataset_id=dataset_id, account_id=account_id) - - def _assert_database_query_not_called(self, mock_session: Mock): - """Helper method to verify that database query was not called.""" - mock_session.query.assert_not_called() - - # ==================== Cross-Tenant Access Tests ==================== - - def test_permission_check_different_tenant_should_fail(self): - """Test that users from different tenants cannot access dataset regardless of other permissions.""" - # Create dataset and user from different tenants - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - tenant_id="tenant-123", permission=DatasetPermissionEnum.ALL_TEAM - ) - user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="user-789", tenant_id="different-tenant-456", role=TenantAccountRole.EDITOR - ) - - # Should fail due to different tenant - self._assert_permission_check_fails(dataset, user) - - # ==================== Owner Privilege Tests ==================== - - def test_owner_can_access_any_dataset(self): - """Test that tenant owners can access any dataset regardless of permission level.""" - # Create dataset with restrictive permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.ONLY_ME) - - # Create owner user - owner_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="owner-999", role=TenantAccountRole.OWNER - ) - - # Owner should have access regardless of dataset permission - self._assert_permission_check_passes(dataset, owner_user) - - # ==================== ONLY_ME Permission Tests ==================== - - def test_only_me_permission_creator_can_access(self): - """Test ONLY_ME permission allows only the dataset creator to access.""" - # Create dataset with ONLY_ME permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - created_by="creator-456", permission=DatasetPermissionEnum.ONLY_ME - ) - - # Create creator user - creator_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="creator-456", role=TenantAccountRole.EDITOR - ) - - # Creator should be able to access - self._assert_permission_check_passes(dataset, creator_user) - - def test_only_me_permission_others_cannot_access(self): - """Test ONLY_ME permission denies access to non-creators.""" - # Create dataset with ONLY_ME permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - created_by="creator-456", permission=DatasetPermissionEnum.ONLY_ME - ) - - # Create normal user (not the creator) - normal_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="normal-789", role=TenantAccountRole.NORMAL - ) - - # Non-creator should be denied access - self._assert_permission_check_fails(dataset, normal_user) - - # ==================== ALL_TEAM Permission Tests ==================== - - def test_all_team_permission_allows_access(self): - """Test ALL_TEAM permission allows any team member to access the dataset.""" - # Create dataset with ALL_TEAM permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.ALL_TEAM) - - # Create different types of team members - normal_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="normal-789", role=TenantAccountRole.NORMAL - ) - editor_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="editor-456", role=TenantAccountRole.EDITOR - ) - - # All team members should have access - self._assert_permission_check_passes(dataset, normal_user) - self._assert_permission_check_passes(dataset, editor_user) - - # ==================== PARTIAL_TEAM Permission Tests ==================== - - def test_partial_team_permission_creator_can_access(self, mock_dataset_service_dependencies): - """Test PARTIAL_TEAM permission allows creator to access without database query.""" - # Create dataset with PARTIAL_TEAM permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM - ) - - # Create creator user - creator_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="creator-456", role=TenantAccountRole.EDITOR - ) - - # Creator should have access without database query - self._assert_permission_check_passes(dataset, creator_user) - self._assert_database_query_not_called(mock_dataset_service_dependencies["db_session"]) - - def test_partial_team_permission_with_explicit_permission(self, mock_dataset_service_dependencies): - """Test PARTIAL_TEAM permission allows users with explicit permission records.""" - # Create dataset with PARTIAL_TEAM permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM) - - # Create normal user (not the creator) - normal_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="normal-789", role=TenantAccountRole.NORMAL - ) - - # Mock database query to return a permission record - mock_permission = DatasetPermissionTestDataFactory.create_dataset_permission_mock( - dataset_id=dataset.id, account_id=normal_user.id - ) - mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = mock_permission - - # User with explicit permission should have access - self._assert_permission_check_passes(dataset, normal_user) - self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, normal_user.id) - - def test_partial_team_permission_without_explicit_permission(self, mock_dataset_service_dependencies): - """Test PARTIAL_TEAM permission denies users without explicit permission records.""" - # Create dataset with PARTIAL_TEAM permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM) - - # Create normal user (not the creator) - normal_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="normal-789", role=TenantAccountRole.NORMAL - ) - - # Mock database query to return None (no permission record) - mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None - - # User without explicit permission should be denied access - self._assert_permission_check_fails(dataset, normal_user) - self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, normal_user.id) - - def test_partial_team_permission_non_creator_without_permission_fails(self, mock_dataset_service_dependencies): - """Test that non-creators without explicit permission are denied access to PARTIAL_TEAM datasets.""" - # Create dataset with PARTIAL_TEAM permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM - ) - - # Create a different user (not the creator) - other_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="other-user-123", role=TenantAccountRole.NORMAL - ) - - # Mock database query to return None (no permission record) - mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None - - # Non-creator without explicit permission should be denied access - self._assert_permission_check_fails(dataset, other_user) - self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, other_user.id) - - # ==================== Enum Usage Tests ==================== - - def test_partial_team_permission_uses_correct_enum(self): - """Test that the method correctly uses DatasetPermissionEnum.PARTIAL_TEAM instead of string literals.""" - # Create dataset with PARTIAL_TEAM permission using enum - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM - ) - - # Create creator user - creator_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="creator-456", role=TenantAccountRole.EDITOR - ) - - # Creator should always have access regardless of permission level - self._assert_permission_check_passes(dataset, creator_user) - - # ==================== Logging Tests ==================== - - def test_permission_denied_logs_debug_message(self, mock_dataset_service_dependencies, mock_logging_dependencies): - """Test that permission denied events are properly logged for debugging purposes.""" - # Create dataset with PARTIAL_TEAM permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM) - - # Create normal user (not the creator) - normal_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="normal-789", role=TenantAccountRole.NORMAL - ) - - # Mock database query to return None (no permission record) - mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None - - # Attempt permission check (should fail) - with pytest.raises(NoPermissionError): - DatasetService.check_dataset_permission(dataset, normal_user) - - # Verify debug message was logged with correct user and dataset information - mock_logging_dependencies["logging"].debug.assert_called_with( - "User %s does not have permission to access dataset %s", normal_user.id, dataset.id - ) diff --git a/api/tests/unit_tests/services/test_dataset_service.py b/api/tests/unit_tests/services/test_dataset_service.py deleted file mode 100644 index a1d2f6410ce..00000000000 --- a/api/tests/unit_tests/services/test_dataset_service.py +++ /dev/null @@ -1,129 +0,0 @@ -"""Unit tests for non-SQL DocumentService orchestration behaviors. - -This file intentionally keeps only collaborator-oriented document indexing -orchestration tests. SQL-backed dataset lifecycle cases are covered by -integration tests under testcontainers. -""" - -from unittest.mock import Mock, patch - -import pytest - -from models.dataset import Document -from services.errors.document import DocumentIndexingError - - -class DatasetServiceUnitDataFactory: - """Factory for creating lightweight document doubles used in unit tests.""" - - @staticmethod - def create_document_mock( - document_id: str = "doc-123", - dataset_id: str = "dataset-123", - indexing_status: str = "completed", - is_paused: bool = False, - ) -> Mock: - """Create a document-shaped mock for DocumentService orchestration tests.""" - document = Mock(spec=Document) - document.id = document_id - document.dataset_id = dataset_id - document.indexing_status = indexing_status - document.is_paused = is_paused - document.paused_by = None - document.paused_at = None - return document - - -class TestDatasetServiceDocumentIndexing: - """Unit tests for pause/recover/retry orchestration without SQL assertions.""" - - @pytest.fixture - def mock_document_service_dependencies(self): - """Patch non-SQL collaborators used by DocumentService methods.""" - with ( - patch("services.dataset_service.redis_client") as mock_redis, - patch("services.dataset_service.db.session") as mock_db, - patch("services.dataset_service.current_user") as mock_current_user, - ): - mock_current_user.id = "user-123" - yield { - "redis_client": mock_redis, - "db_session": mock_db, - "current_user": mock_current_user, - } - - def test_pause_document_success(self, mock_document_service_dependencies): - """Pause a document that is currently in an indexable status.""" - # Arrange - document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="indexing") - - # Act - from services.dataset_service import DocumentService - - DocumentService.pause_document(document) - - # Assert - assert document.is_paused is True - assert document.paused_by == "user-123" - mock_document_service_dependencies["db_session"].add.assert_called_once_with(document) - mock_document_service_dependencies["db_session"].commit.assert_called_once() - mock_document_service_dependencies["redis_client"].setnx.assert_called_once_with( - f"document_{document.id}_is_paused", - "True", - ) - - def test_pause_document_invalid_status_error(self, mock_document_service_dependencies): - """Raise DocumentIndexingError when pausing a completed document.""" - # Arrange - document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="completed") - - # Act / Assert - from services.dataset_service import DocumentService - - with pytest.raises(DocumentIndexingError): - DocumentService.pause_document(document) - - def test_recover_document_success(self, mock_document_service_dependencies): - """Recover a paused document and dispatch the recover indexing task.""" - # Arrange - document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="indexing", is_paused=True) - - # Act - with patch("services.dataset_service.recover_document_indexing_task") as recover_task: - from services.dataset_service import DocumentService - - DocumentService.recover_document(document) - - # Assert - assert document.is_paused is False - assert document.paused_by is None - assert document.paused_at is None - mock_document_service_dependencies["db_session"].add.assert_called_once_with(document) - mock_document_service_dependencies["db_session"].commit.assert_called_once() - mock_document_service_dependencies["redis_client"].delete.assert_called_once_with( - f"document_{document.id}_is_paused" - ) - recover_task.delay.assert_called_once_with(document.dataset_id, document.id) - - def test_retry_document_indexing_success(self, mock_document_service_dependencies): - """Reset documents to waiting state and dispatch retry indexing task.""" - # Arrange - dataset_id = "dataset-123" - documents = [ - DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-1", indexing_status="error"), - DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-2", indexing_status="error"), - ] - mock_document_service_dependencies["redis_client"].get.return_value = None - - # Act - with patch("services.dataset_service.retry_document_indexing_task") as retry_task: - from services.dataset_service import DocumentService - - DocumentService.retry_document(dataset_id, documents) - - # Assert - assert all(document.indexing_status == "waiting" for document in documents) - assert mock_document_service_dependencies["db_session"].add.call_count == 2 - assert mock_document_service_dependencies["db_session"].commit.call_count == 2 - assert mock_document_service_dependencies["redis_client"].setex.call_count == 2 - retry_task.delay.assert_called_once_with(dataset_id, ["doc-1", "doc-2"], "user-123") diff --git a/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py b/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py deleted file mode 100644 index abff48347e8..00000000000 --- a/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py +++ /dev/null @@ -1,100 +0,0 @@ -import datetime -from unittest.mock import Mock, patch - -import pytest - -from models.dataset import Dataset, Document -from services.dataset_service import DocumentService -from tests.unit_tests.conftest import redis_mock - - -class DocumentBatchUpdateTestDataFactory: - """Factory class for creating test data and mock objects for document batch update tests.""" - - @staticmethod - def create_dataset_mock(dataset_id: str = "dataset-123", tenant_id: str = "tenant-456") -> Mock: - """Create a mock dataset with specified attributes.""" - dataset = Mock(spec=Dataset) - dataset.id = dataset_id - dataset.tenant_id = tenant_id - return dataset - - @staticmethod - def create_user_mock(user_id: str = "user-789") -> Mock: - """Create a mock user.""" - user = Mock() - user.id = user_id - return user - - @staticmethod - def create_document_mock( - document_id: str = "doc-1", - name: str = "test_document.pdf", - enabled: bool = True, - archived: bool = False, - indexing_status: str = "completed", - completed_at: datetime.datetime | None = None, - **kwargs, - ) -> Mock: - """Create a mock document with specified attributes.""" - document = Mock(spec=Document) - document.id = document_id - document.name = name - document.enabled = enabled - document.archived = archived - document.indexing_status = indexing_status - document.completed_at = completed_at or datetime.datetime.now() - - document.disabled_at = None - document.disabled_by = None - document.archived_at = None - document.archived_by = None - document.updated_at = None - - for key, value in kwargs.items(): - setattr(document, key, value) - return document - - -class TestDatasetServiceBatchUpdateDocumentStatus: - """Unit tests for non-SQL path in DocumentService.batch_update_document_status.""" - - @pytest.fixture - def mock_document_service_dependencies(self): - """Common mock setup for document service dependencies.""" - with ( - patch("services.dataset_service.DocumentService.get_document") as mock_get_doc, - patch("extensions.ext_database.db.session") as mock_db, - patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now, - ): - current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) - mock_naive_utc_now.return_value = current_time - - yield { - "get_document": mock_get_doc, - "db_session": mock_db, - "naive_utc_now": mock_naive_utc_now, - "current_time": current_time, - } - - def test_batch_update_invalid_action_error(self, mock_document_service_dependencies): - """Test that ValueError is raised when an invalid action is provided.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True) - mock_document_service_dependencies["get_document"].return_value = doc - - redis_mock.reset_mock() - redis_mock.get.return_value = None - - invalid_action = "invalid_action" - with pytest.raises(ValueError) as exc_info: - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1"], action=invalid_action, user=user - ) - - assert invalid_action in str(exc_info.value) - assert "Invalid action" in str(exc_info.value) - - redis_mock.setex.assert_not_called() diff --git a/api/tests/unit_tests/services/test_dataset_service_create_dataset.py b/api/tests/unit_tests/services/test_dataset_service_create_dataset.py deleted file mode 100644 index f8c52706561..00000000000 --- a/api/tests/unit_tests/services/test_dataset_service_create_dataset.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Unit tests for non-SQL validation paths in DatasetService dataset creation.""" - -from unittest.mock import Mock, patch -from uuid import uuid4 - -import pytest - -from services.dataset_service import DatasetService -from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity - - -class TestDatasetServiceCreateRagPipelineDatasetNonSQL: - """Unit coverage for non-SQL validation in create_empty_rag_pipeline_dataset.""" - - @pytest.fixture - def mock_rag_pipeline_dependencies(self): - """Patch database session and current_user for validation-only unit coverage.""" - with ( - patch("services.dataset_service.db.session") as mock_db, - patch("services.dataset_service.current_user") as mock_current_user, - ): - yield { - "db_session": mock_db, - "current_user_mock": mock_current_user, - } - - def test_create_rag_pipeline_dataset_missing_current_user_error(self, mock_rag_pipeline_dependencies): - """Raise ValueError when current_user.id is unavailable before SQL persistence.""" - # Arrange - tenant_id = str(uuid4()) - mock_rag_pipeline_dependencies["current_user_mock"].id = None - - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query - - icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") - entity = RagPipelineDatasetCreateEntity( - name="Test Dataset", - description="", - icon_info=icon_info, - permission="only_me", - ) - - # Act / Assert - with pytest.raises(ValueError, match="Current user or current user id not found"): - DatasetService.create_empty_rag_pipeline_dataset( - tenant_id=tenant_id, - rag_pipeline_dataset_create_entity=entity, - ) diff --git a/api/tests/unit_tests/services/test_dataset_service_dataset.py b/api/tests/unit_tests/services/test_dataset_service_dataset.py new file mode 100644 index 00000000000..92aed7c30a8 --- /dev/null +++ b/api/tests/unit_tests/services/test_dataset_service_dataset.py @@ -0,0 +1,1760 @@ +"""Unit tests for DatasetService and dataset-related collaborators.""" + +from .dataset_service_test_helpers import ( + CloudPlan, + Dataset, + DatasetCollectionBindingService, + DatasetNameDuplicateError, + DatasetPermissionEnum, + DatasetPermissionService, + DatasetProcessRule, + DatasetService, + DatasetServiceUnitDataFactory, + DocumentIndexingError, + DocumentService, + LLMBadRequestError, + MagicMock, + Mock, + ModelFeature, + ModelType, + NoPermissionError, + NotFound, + PipelineIconInfo, + ProviderTokenNotInitError, + RagPipelineDatasetCreateEntity, + SimpleNamespace, + TenantAccountRole, + _make_knowledge_configuration, + _make_retrieval_model, + _make_session_context, + json, + patch, + pytest, +) + + +class TestDatasetServiceQueries: + """Unit tests for DatasetService query composition and fallback branches.""" + + @pytest.fixture + def mock_dataset_query_dependencies(self): + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.helper.escape_like_pattern", return_value="escaped-search") as escape_like, + patch("services.dataset_service.TagService.get_target_ids_by_tag_ids") as get_target_ids, + ): + mock_db.paginate.return_value = SimpleNamespace(items=["dataset"], total=1) + yield { + "db": mock_db, + "escape_like_pattern": escape_like, + "get_target_ids": get_target_ids, + } + + def test_get_datasets_returns_paginated_results_for_public_view(self, mock_dataset_query_dependencies): + items, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id="tenant-1") + + assert items == ["dataset"] + assert total == 1 + mock_dataset_query_dependencies["db"].paginate.assert_called_once() + mock_dataset_query_dependencies["escape_like_pattern"].assert_not_called() + + def test_get_datasets_short_circuits_for_dataset_operator_without_permissions( + self, mock_dataset_query_dependencies + ): + user = DatasetServiceUnitDataFactory.create_user_mock(role=TenantAccountRole.DATASET_OPERATOR) + mock_dataset_query_dependencies["db"].session.query.return_value.filter_by.return_value.all.return_value = [] + + items, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id="tenant-1", user=user) + + assert items == [] + assert total == 0 + mock_dataset_query_dependencies["db"].paginate.assert_not_called() + + def test_get_datasets_short_circuits_when_tag_lookup_returns_no_target_ids(self, mock_dataset_query_dependencies): + mock_dataset_query_dependencies["get_target_ids"].return_value = [] + + items, total = DatasetService.get_datasets( + page=1, + per_page=20, + tenant_id="tenant-1", + tag_ids=["tag-1"], + ) + + assert items == [] + assert total == 0 + mock_dataset_query_dependencies["get_target_ids"].assert_called_once_with("knowledge", "tenant-1", ["tag-1"]) + mock_dataset_query_dependencies["db"].paginate.assert_not_called() + + def test_get_datasets_search_and_tag_filters_call_collaborators(self, mock_dataset_query_dependencies): + mock_dataset_query_dependencies["get_target_ids"].return_value = ["dataset-1"] + + items, total = DatasetService.get_datasets( + page=2, + per_page=10, + tenant_id="tenant-1", + search="report", + tag_ids=["tag-1"], + ) + + assert items == ["dataset"] + assert total == 1 + mock_dataset_query_dependencies["escape_like_pattern"].assert_called_once_with("report") + mock_dataset_query_dependencies["get_target_ids"].assert_called_once_with("knowledge", "tenant-1", ["tag-1"]) + mock_dataset_query_dependencies["db"].paginate.assert_called_once() + + def test_get_process_rules_returns_latest_rule_when_present(self): + dataset_process_rule = Mock(spec=DatasetProcessRule) + dataset_process_rule.mode = "automatic" + dataset_process_rule.rules_dict = {"delimiter": "\n"} + + with patch("services.dataset_service.db") as mock_db: + ( + mock_db.session.query.return_value.where.return_value.order_by.return_value.limit.return_value.one_or_none.return_value + ) = dataset_process_rule + + result = DatasetService.get_process_rules("dataset-1") + + assert result == {"mode": "automatic", "rules": {"delimiter": "\n"}} + + def test_get_process_rules_falls_back_to_default_rules_when_missing(self): + with patch("services.dataset_service.db") as mock_db: + ( + mock_db.session.query.return_value.where.return_value.order_by.return_value.limit.return_value.one_or_none.return_value + ) = None + + result = DatasetService.get_process_rules("dataset-1") + + assert result == { + "mode": DocumentService.DEFAULT_RULES["mode"], + "rules": DocumentService.DEFAULT_RULES["rules"], + } + + def test_get_datasets_by_ids_returns_empty_for_missing_ids(self): + with patch("services.dataset_service.db") as mock_db: + items, total = DatasetService.get_datasets_by_ids([], "tenant-1") + + assert items == [] + assert total == 0 + mock_db.paginate.assert_not_called() + + def test_get_datasets_by_ids_uses_paginate_for_non_empty_input(self): + with patch("services.dataset_service.db") as mock_db: + mock_db.paginate.return_value = SimpleNamespace(items=["dataset-1"], total=1) + + items, total = DatasetService.get_datasets_by_ids(["dataset-1"], "tenant-1") + + assert items == ["dataset-1"] + assert total == 1 + mock_db.paginate.assert_called_once() + + def test_get_dataset_returns_first_match(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock() + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.filter_by.return_value.first.return_value = dataset + + result = DatasetService.get_dataset(dataset.id) + + assert result is dataset + + +class TestDatasetServiceValidation: + """Unit tests for DatasetService validation helpers.""" + + @pytest.mark.parametrize( + ("dataset_doc_form", "incoming_doc_form"), + [(None, "text_model"), ("text_model", "text_model")], + ) + def test_check_doc_form_allows_matching_or_missing_dataset_doc_form(self, dataset_doc_form, incoming_doc_form): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(doc_form=dataset_doc_form) + + DatasetService.check_doc_form(dataset, incoming_doc_form) + + def test_check_doc_form_rejects_mismatched_doc_form(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(doc_form="qa_model") + + with pytest.raises(ValueError, match="doc_form is different"): + DatasetService.check_doc_form(dataset, "text_model") + + def test_check_dataset_model_setting_skips_non_high_quality_datasets(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="economy") + + with patch("services.dataset_service.ModelManager") as model_manager_cls: + DatasetService.check_dataset_model_setting(dataset) + + model_manager_cls.assert_not_called() + + def test_check_dataset_model_setting_validates_embedding_model_for_high_quality_dataset(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="high_quality") + + with patch("services.dataset_service.ModelManager") as model_manager_cls: + DatasetService.check_dataset_model_setting(dataset) + + model_manager_cls.for_tenant.return_value.get_model_instance.assert_called_once_with( + tenant_id=dataset.tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + + def test_check_dataset_model_setting_wraps_llm_bad_request_error(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="high_quality") + + with patch("services.dataset_service.ModelManager") as model_manager_cls: + model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = LLMBadRequestError() + + with pytest.raises(ValueError, match="No Embedding Model available"): + DatasetService.check_dataset_model_setting(dataset) + + def test_check_dataset_model_setting_wraps_provider_token_error(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="high_quality") + + with patch("services.dataset_service.ModelManager") as model_manager_cls: + model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError( + "token missing" + ) + + with pytest.raises(ValueError, match="The dataset is unavailable, due to: token missing"): + DatasetService.check_dataset_model_setting(dataset) + + def test_check_embedding_model_setting_wraps_provider_token_error_description(self): + with patch("services.dataset_service.ModelManager") as model_manager_cls: + model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError( + "provider setup" + ) + + with pytest.raises(ValueError, match="provider setup"): + DatasetService.check_embedding_model_setting("tenant-1", "provider", "embedding-model") + + def test_check_reranking_model_setting_uses_rerank_model_type(self): + with patch("services.dataset_service.ModelManager") as model_manager_cls: + DatasetService.check_reranking_model_setting("tenant-1", "provider", "reranker") + + model_manager_cls.for_tenant.return_value.get_model_instance.assert_called_once_with( + tenant_id="tenant-1", + provider="provider", + model_type=ModelType.RERANK, + model="reranker", + ) + + def test_check_reranking_model_setting_wraps_bad_request(self): + with patch("services.dataset_service.ModelManager") as model_manager_cls: + model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = LLMBadRequestError() + + with pytest.raises(ValueError, match="No Rerank Model available"): + DatasetService.check_reranking_model_setting("tenant-1", "provider", "reranker") + + def test_check_is_multimodal_model_returns_true_when_model_supports_vision(self): + model_schema = SimpleNamespace(features=[ModelFeature.VISION]) + model_type_instance = MagicMock() + model_type_instance.get_model_schema.return_value = model_schema + model_instance = SimpleNamespace( + model_type_instance=model_type_instance, + model_name="embedding-model", + credentials={"api_key": "secret"}, + ) + + with patch("services.dataset_service.ModelManager") as model_manager_cls: + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = model_instance + + result = DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model") + + assert result is True + + def test_check_is_multimodal_model_returns_false_when_vision_feature_is_absent(self): + model_schema = SimpleNamespace(features=[]) + model_type_instance = MagicMock() + model_type_instance.get_model_schema.return_value = model_schema + model_instance = SimpleNamespace( + model_type_instance=model_type_instance, + model_name="embedding-model", + credentials={"api_key": "secret"}, + ) + + with patch("services.dataset_service.ModelManager") as model_manager_cls: + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = model_instance + + result = DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model") + + assert result is False + + def test_check_is_multimodal_model_raises_when_schema_is_missing(self): + model_type_instance = MagicMock() + model_type_instance.get_model_schema.return_value = None + model_instance = SimpleNamespace( + model_type_instance=model_type_instance, + model_name="embedding-model", + credentials={"api_key": "secret"}, + ) + + with patch("services.dataset_service.ModelManager") as model_manager_cls: + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = model_instance + + with pytest.raises(ValueError, match="Model schema not found"): + DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model") + + def test_check_is_multimodal_model_wraps_bad_request_error(self): + with patch("services.dataset_service.ModelManager") as model_manager_cls: + model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = LLMBadRequestError() + + with pytest.raises(ValueError, match="No Model available"): + DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model") + + +class TestDatasetServiceCreationAndUpdate: + """Unit tests for dataset creation and update helpers.""" + + def test_create_empty_dataset_raises_when_name_already_exists(self): + account = SimpleNamespace(id="user-1") + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.filter_by.return_value.first.return_value = object() + + with pytest.raises(DatasetNameDuplicateError, match="Dataset with name Dataset already exists"): + DatasetService.create_empty_dataset("tenant-1", "Dataset", None, "economy", account) + + def test_create_empty_dataset_uses_default_embedding_model_for_high_quality_dataset(self): + account = SimpleNamespace(id="user-1") + default_embedding_model = SimpleNamespace(provider="provider", model_name="default-embedding") + + with ( + patch("services.dataset_service.db") as mock_db, + patch( + "services.dataset_service.Dataset", + side_effect=lambda **kwargs: SimpleNamespace(id="dataset-1", **kwargs), + ), + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch.object(DatasetService, "check_embedding_model_setting") as check_embedding, + ): + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + model_manager_cls.for_tenant.return_value.get_default_model_instance.return_value = default_embedding_model + + dataset = DatasetService.create_empty_dataset( + tenant_id="tenant-1", + name="Dataset", + description="Description", + indexing_technique="high_quality", + account=account, + ) + + assert dataset.embedding_model_provider == "provider" + assert dataset.embedding_model == "default-embedding" + assert dataset.permission == DatasetPermissionEnum.ONLY_ME + assert dataset.provider == "vendor" + model_manager_cls.for_tenant.return_value.get_default_model_instance.assert_called_once_with( + tenant_id="tenant-1", + model_type=ModelType.TEXT_EMBEDDING, + ) + check_embedding.assert_not_called() + mock_db.session.commit.assert_called_once() + + def test_create_empty_dataset_creates_external_binding_for_high_quality_dataset(self): + account = SimpleNamespace(id="user-1") + retrieval_model = _make_retrieval_model() + embedding_model = SimpleNamespace(provider="provider", model_name="embedding-model") + + with ( + patch("services.dataset_service.db") as mock_db, + patch( + "services.dataset_service.Dataset", + side_effect=lambda **kwargs: SimpleNamespace(id="dataset-1", **kwargs), + ), + patch( + "services.dataset_service.ExternalKnowledgeBindings", + side_effect=lambda **kwargs: SimpleNamespace(**kwargs), + ) as binding_cls, + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch("services.dataset_service.ExternalDatasetService.get_external_knowledge_api", return_value=object()), + patch.object(DatasetService, "check_embedding_model_setting") as check_embedding, + patch.object(DatasetService, "check_reranking_model_setting") as check_reranking, + ): + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model + + dataset = DatasetService.create_empty_dataset( + tenant_id="tenant-1", + name="External Dataset", + description="Description", + indexing_technique="high_quality", + account=account, + permission=DatasetPermissionEnum.ALL_TEAM, + provider="external", + external_knowledge_api_id="api-1", + external_knowledge_id="knowledge-1", + embedding_model_provider="provider", + embedding_model_name="embedding-model", + retrieval_model=retrieval_model, + summary_index_setting={"enable": True}, + ) + + assert dataset.embedding_model_provider == "provider" + assert dataset.embedding_model == "embedding-model" + assert dataset.retrieval_model == retrieval_model.model_dump() + assert dataset.summary_index_setting == {"enable": True} + check_embedding.assert_called_once_with("tenant-1", "provider", "embedding-model") + check_reranking.assert_called_once_with("tenant-1", "rerank-provider", "rerank-model") + binding_cls.assert_called_once_with( + tenant_id="tenant-1", + dataset_id="dataset-1", + external_knowledge_api_id="api-1", + external_knowledge_id="knowledge-1", + created_by="user-1", + ) + assert mock_db.session.add.call_count == 2 + mock_db.session.commit.assert_called_once() + + def test_create_empty_rag_pipeline_dataset_raises_for_duplicate_name(self): + entity = RagPipelineDatasetCreateEntity( + name="Existing Dataset", + description="Description", + icon_info=PipelineIconInfo(icon="book", icon_background="#fff"), + permission=DatasetPermissionEnum.ALL_TEAM, + ) + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.filter_by.return_value.first.return_value = object() + + with pytest.raises(DatasetNameDuplicateError, match="Existing Dataset already exists"): + DatasetService.create_empty_rag_pipeline_dataset("tenant-1", entity) + + def test_create_empty_rag_pipeline_dataset_generates_name_and_creates_dataset(self): + entity = RagPipelineDatasetCreateEntity( + name="", + description="Description", + icon_info=PipelineIconInfo(icon="book", icon_background="#fff"), + permission=DatasetPermissionEnum.ALL_TEAM, + ) + pipeline = SimpleNamespace(id="pipeline-1") + + def pipeline_factory(**kwargs): + pipeline.__dict__.update(kwargs) + return pipeline + + def dataset_factory(**kwargs): + return SimpleNamespace(id="dataset-1", **kwargs) + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.current_user", SimpleNamespace(id="user-1")), + patch("services.dataset_service.generate_incremental_name", return_value="Untitled 2") as generate_name, + patch("services.dataset_service.Pipeline", side_effect=pipeline_factory), + patch("services.dataset_service.Dataset", side_effect=dataset_factory), + ): + mock_db.session.query.return_value.filter_by.return_value.all.return_value = [ + SimpleNamespace(name="Untitled"), + SimpleNamespace(name="Untitled 1"), + ] + + dataset = DatasetService.create_empty_rag_pipeline_dataset("tenant-1", entity) + + assert entity.name == "Untitled 2" + assert dataset.pipeline_id == "pipeline-1" + assert dataset.runtime_mode == "rag_pipeline" + generate_name.assert_called_once_with(["Untitled", "Untitled 1"], "Untitled") + mock_db.session.commit.assert_called_once() + + def test_create_empty_rag_pipeline_dataset_requires_current_user_id(self): + entity = RagPipelineDatasetCreateEntity( + name="Dataset", + description="Description", + icon_info=PipelineIconInfo(icon="book", icon_background="#fff"), + permission=DatasetPermissionEnum.ALL_TEAM, + ) + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.current_user", SimpleNamespace(id=None)), + ): + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + + with pytest.raises(ValueError, match="Current user or current user id not found"): + DatasetService.create_empty_rag_pipeline_dataset("tenant-1", entity) + + def test_update_dataset_raises_when_dataset_is_missing(self): + with patch.object(DatasetService, "get_dataset", return_value=None): + with pytest.raises(ValueError, match="Dataset not found"): + DatasetService.update_dataset("dataset-1", {}, SimpleNamespace(id="user-1")) + + def test_update_dataset_raises_when_new_name_conflicts(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1", tenant_id="tenant-1") + dataset.name = "Old Dataset" + + with ( + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch.object(DatasetService, "_has_dataset_same_name", return_value=True), + ): + with pytest.raises(ValueError, match="Dataset name already exists"): + DatasetService.update_dataset("dataset-1", {"name": "New Dataset"}, SimpleNamespace(id="user-1")) + + def test_update_dataset_routes_external_datasets_to_external_helper(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1", tenant_id="tenant-1") + dataset.provider = "external" + user = DatasetServiceUnitDataFactory.create_user_mock() + + with ( + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch.object(DatasetService, "check_dataset_permission") as check_permission, + patch.object(DatasetService, "_update_external_dataset", return_value="updated") as update_external, + ): + result = DatasetService.update_dataset("dataset-1", {"name": dataset.name}, user) + + assert result == "updated" + check_permission.assert_called_once_with(dataset, user) + update_external.assert_called_once_with(dataset, {"name": dataset.name}, user) + + def test_update_dataset_routes_internal_datasets_to_internal_helper(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1", tenant_id="tenant-1") + dataset.provider = "vendor" + user = DatasetServiceUnitDataFactory.create_user_mock() + + with ( + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch.object(DatasetService, "check_dataset_permission") as check_permission, + patch.object(DatasetService, "_update_internal_dataset", return_value="updated") as update_internal, + ): + result = DatasetService.update_dataset("dataset-1", {"name": dataset.name}, user) + + assert result == "updated" + check_permission.assert_called_once_with(dataset, user) + update_internal.assert_called_once_with(dataset, {"name": dataset.name}, user) + + def test_has_dataset_same_name_returns_true_when_query_matches(self): + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.first.return_value = object() + + result = DatasetService._has_dataset_same_name("tenant-1", "dataset-1", "Dataset") + + assert result is True + + def test_update_external_dataset_updates_dataset_and_binding(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + user = SimpleNamespace(id="user-1") + now = object() + + with ( + patch.object(DatasetService, "_update_external_knowledge_binding") as update_binding, + patch("services.dataset_service.naive_utc_now", return_value=now), + patch("services.dataset_service.db") as mock_db, + ): + result = DatasetService._update_external_dataset( + dataset, + { + "external_retrieval_model": {"top_k": 3}, + "summary_index_setting": {"enable": True}, + "name": "Updated Dataset", + "description": "Updated description", + "permission": DatasetPermissionEnum.PARTIAL_TEAM, + "external_knowledge_id": "knowledge-1", + "external_knowledge_api_id": "api-1", + }, + user, + ) + + assert result is dataset + assert dataset.retrieval_model == {"top_k": 3} + assert dataset.summary_index_setting == {"enable": True} + assert dataset.name == "Updated Dataset" + assert dataset.description == "Updated description" + assert dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM + assert dataset.updated_by == "user-1" + assert dataset.updated_at is now + update_binding.assert_called_once_with("dataset-1", "knowledge-1", "api-1") + mock_db.session.add.assert_called_once_with(dataset) + mock_db.session.commit.assert_called_once() + + @pytest.mark.parametrize( + ("payload", "message"), + [ + ({"external_knowledge_api_id": "api-1"}, "External knowledge id is required"), + ({"external_knowledge_id": "knowledge-1"}, "External knowledge api id is required"), + ], + ) + def test_update_external_dataset_requires_external_binding_fields(self, payload, message): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + + with pytest.raises(ValueError, match=message): + DatasetService._update_external_dataset(dataset, payload, SimpleNamespace(id="user-1")) + + def test_update_external_knowledge_binding_updates_changed_binding_values(self): + binding = SimpleNamespace(external_knowledge_id="old-knowledge", external_knowledge_api_id="old-api") + session = MagicMock() + session.query.return_value.filter_by.return_value.first.return_value = binding + session_context = _make_session_context(session) + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.Session", return_value=session_context), + ): + DatasetService._update_external_knowledge_binding("dataset-1", "new-knowledge", "new-api") + + assert binding.external_knowledge_id == "new-knowledge" + assert binding.external_knowledge_api_id == "new-api" + mock_db.session.add.assert_called_once_with(binding) + + def test_update_external_knowledge_binding_raises_for_missing_binding(self): + session = MagicMock() + session.query.return_value.filter_by.return_value.first.return_value = None + session_context = _make_session_context(session) + + with ( + patch("services.dataset_service.db"), + patch("services.dataset_service.Session", return_value=session_context), + ): + with pytest.raises(ValueError, match="External knowledge binding not found"): + DatasetService._update_external_knowledge_binding("dataset-1", "knowledge-1", "api-1") + + def test_update_internal_dataset_updates_fields_and_dispatches_regeneration_tasks(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + user = SimpleNamespace(id="user-1") + now = object() + update_payload = { + "name": "Updated Dataset", + "description": None, + "partial_member_list": [{"user_id": "member-1"}], + "external_knowledge_api_id": "api-1", + "external_knowledge_id": "knowledge-1", + "external_retrieval_model": {"top_k": 2}, + "retrieval_model": {"top_k": 4}, + "summary_index_setting": {"enable": True}, + "icon_info": {"icon": "book"}, + } + + with ( + patch.object(DatasetService, "_handle_indexing_technique_change", return_value="update"), + patch.object(DatasetService, "_update_pipeline_knowledge_base_node_data") as update_pipeline, + patch("services.dataset_service.naive_utc_now", return_value=now), + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.deal_dataset_vector_index_task") as vector_task, + patch("services.dataset_service.regenerate_summary_index_task") as regenerate_task, + ): + result = DatasetService._update_internal_dataset(dataset, update_payload.copy(), user) + + assert result is dataset + updated_values = mock_db.session.query.return_value.filter_by.return_value.update.call_args.args[0] + assert updated_values["name"] == "Updated Dataset" + assert updated_values["description"] is None + assert updated_values["retrieval_model"] == {"top_k": 4} + assert updated_values["summary_index_setting"] == {"enable": True} + assert updated_values["icon_info"] == {"icon": "book"} + assert updated_values["updated_by"] == "user-1" + assert updated_values["updated_at"] is now + assert "partial_member_list" not in updated_values + assert "external_knowledge_api_id" not in updated_values + assert "external_knowledge_id" not in updated_values + assert "external_retrieval_model" not in updated_values + mock_db.session.commit.assert_called_once() + mock_db.session.refresh.assert_called_once_with(dataset) + update_pipeline.assert_called_once_with(dataset, "user-1") + vector_task.delay.assert_called_once_with("dataset-1", "update") + regenerate_task.delay.assert_called_once_with( + "dataset-1", + regenerate_reason="embedding_model_changed", + regenerate_vectors_only=True, + ) + + def test_update_pipeline_knowledge_base_node_data_returns_early_for_non_pipeline_dataset(self): + dataset = SimpleNamespace(runtime_mode="workflow", pipeline_id="pipeline-1") + + with patch("services.dataset_service.db") as mock_db: + DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1") + + mock_db.session.query.assert_not_called() + + def test_update_pipeline_knowledge_base_node_data_returns_when_pipeline_is_missing(self): + dataset = SimpleNamespace(runtime_mode="rag_pipeline", pipeline_id="pipeline-1") + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + + DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1") + + mock_db.session.commit.assert_not_called() + + def test_update_pipeline_knowledge_base_node_data_updates_published_and_draft_workflows(self): + dataset = SimpleNamespace( + id="dataset-1", + runtime_mode="rag_pipeline", + pipeline_id="pipeline-1", + embedding_model="embedding-model", + embedding_model_provider="provider", + retrieval_model={"top_k": 5}, + chunk_structure="paragraph", + indexing_technique="high_quality", + keyword_number=8, + summary_index_setting={"enable": True}, + ) + pipeline = SimpleNamespace(id="pipeline-1", tenant_id="tenant-1") + published_workflow = SimpleNamespace( + graph=json.dumps({"nodes": [{"data": {"type": "knowledge-index"}}, {"data": {"type": "start"}}]}), + type="chat", + features={"feature": True}, + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + draft_workflow = SimpleNamespace(graph=json.dumps({"nodes": [{"data": {"type": "knowledge-index"}}]})) + new_workflow = SimpleNamespace(id="workflow-1") + rag_pipeline_service = MagicMock() + rag_pipeline_service.get_published_workflow.return_value = published_workflow + rag_pipeline_service.get_draft_workflow.return_value = draft_workflow + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.RagPipelineService", return_value=rag_pipeline_service), + patch("services.dataset_service.Workflow.new", return_value=new_workflow) as workflow_new, + ): + mock_db.session.query.return_value.filter_by.return_value.first.return_value = pipeline + + DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1") + + published_graph = json.loads(workflow_new.call_args.kwargs["graph"]) + assert published_graph["nodes"][0]["data"]["embedding_model"] == "embedding-model" + assert published_graph["nodes"][0]["data"]["summary_index_setting"] == {"enable": True} + assert json.loads(draft_workflow.graph)["nodes"][0]["data"]["embedding_model_provider"] == "provider" + mock_db.session.add.assert_any_call(new_workflow) + mock_db.session.add.assert_any_call(draft_workflow) + mock_db.session.commit.assert_called_once() + + def test_update_pipeline_knowledge_base_node_data_rolls_back_when_update_fails(self): + dataset = SimpleNamespace(runtime_mode="rag_pipeline", pipeline_id="pipeline-1") + pipeline = SimpleNamespace(id="pipeline-1", tenant_id="tenant-1") + rag_pipeline_service = MagicMock() + rag_pipeline_service.get_published_workflow.side_effect = RuntimeError("boom") + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.RagPipelineService", return_value=rag_pipeline_service), + ): + mock_db.session.query.return_value.filter_by.return_value.first.return_value = pipeline + + with pytest.raises(RuntimeError, match="boom"): + DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1") + + mock_db.session.rollback.assert_called_once() + + def test_handle_indexing_technique_change_returns_none_without_indexing_technique(self): + filtered_data: dict[str, object] = {} + dataset = SimpleNamespace(indexing_technique="economy") + + result = DatasetService._handle_indexing_technique_change(dataset, {}, filtered_data) + + assert result is None + assert filtered_data == {} + + def test_handle_indexing_technique_change_switches_to_economy(self): + filtered_data: dict[str, object] = {} + dataset = SimpleNamespace(indexing_technique="high_quality") + + result = DatasetService._handle_indexing_technique_change( + dataset, + {"indexing_technique": "economy"}, + filtered_data, + ) + + assert result == "remove" + assert filtered_data == { + "embedding_model": None, + "embedding_model_provider": None, + "collection_binding_id": None, + } + + def test_handle_indexing_technique_change_switches_to_high_quality(self): + filtered_data: dict[str, object] = {} + dataset = SimpleNamespace(indexing_technique="economy") + + with patch.object(DatasetService, "_configure_embedding_model_for_high_quality") as configure_embedding: + result = DatasetService._handle_indexing_technique_change( + dataset, + {"indexing_technique": "high_quality"}, + filtered_data, + ) + + assert result == "add" + configure_embedding.assert_called_once_with({"indexing_technique": "high_quality"}, filtered_data) + + def test_handle_indexing_technique_change_delegates_when_technique_is_unchanged(self): + filtered_data: dict[str, object] = {} + dataset = SimpleNamespace(indexing_technique="high_quality") + + with patch.object( + DatasetService, + "_handle_embedding_model_update_when_technique_unchanged", + return_value="update", + ) as update_embedding: + result = DatasetService._handle_indexing_technique_change( + dataset, + {"indexing_technique": "high_quality"}, + filtered_data, + ) + + assert result == "update" + update_embedding.assert_called_once_with(dataset, {"indexing_technique": "high_quality"}, filtered_data) + + def test_configure_embedding_model_for_high_quality_updates_filtered_data(self): + class FakeAccount: + pass + + current_user = FakeAccount() + current_user.current_tenant_id = "tenant-1" + embedding_model = SimpleNamespace(provider="provider", model_name="embedding-model") + filtered_data: dict[str, object] = {} + + with ( + patch("services.dataset_service.Account", FakeAccount), + patch("services.dataset_service.current_user", current_user), + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch( + "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding", + return_value=SimpleNamespace(id="binding-1"), + ), + ): + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model + + DatasetService._configure_embedding_model_for_high_quality( + {"embedding_model_provider": "provider", "embedding_model": "embedding-model"}, + filtered_data, + ) + + assert filtered_data == { + "embedding_model": "embedding-model", + "embedding_model_provider": "provider", + "collection_binding_id": "binding-1", + } + + @pytest.mark.parametrize( + ("error", "message"), + [ + (LLMBadRequestError(), "No Embedding Model available"), + (ProviderTokenNotInitError("provider setup"), "provider setup"), + ], + ) + def test_configure_embedding_model_for_high_quality_wraps_model_errors(self, error, message): + class FakeAccount: + pass + + current_user = FakeAccount() + current_user.current_tenant_id = "tenant-1" + + with ( + patch("services.dataset_service.Account", FakeAccount), + patch("services.dataset_service.current_user", current_user), + patch("services.dataset_service.ModelManager") as model_manager_cls, + ): + model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = error + + with pytest.raises(ValueError, match=message): + DatasetService._configure_embedding_model_for_high_quality( + {"embedding_model_provider": "provider", "embedding_model": "embedding-model"}, + {}, + ) + + def test_handle_embedding_model_update_when_technique_unchanged_preserves_existing_settings(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + embedding_model_provider="provider", + embedding_model="embedding-model", + ) + filtered_data: dict[str, object] = {} + + with patch.object(DatasetService, "_preserve_existing_embedding_settings") as preserve_settings: + result = DatasetService._handle_embedding_model_update_when_technique_unchanged( + dataset, + {}, + filtered_data, + ) + + assert result is None + preserve_settings.assert_called_once_with(dataset, filtered_data) + + def test_handle_embedding_model_update_when_technique_unchanged_updates_when_model_is_provided(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + embedding_model_provider="provider", + embedding_model="embedding-model", + ) + + with patch.object(DatasetService, "_update_embedding_model_settings", return_value="update") as update_settings: + result = DatasetService._handle_embedding_model_update_when_technique_unchanged( + dataset, + {"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"}, + {}, + ) + + assert result == "update" + update_settings.assert_called_once() + + def test_preserve_existing_embedding_settings_keeps_current_binding(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + embedding_model_provider="provider", + embedding_model="embedding-model", + collection_binding_id="binding-1", + ) + filtered_data = {"embedding_model_provider": "", "embedding_model": ""} + + DatasetService._preserve_existing_embedding_settings(dataset, filtered_data) + + assert filtered_data == { + "embedding_model_provider": "provider", + "embedding_model": "embedding-model", + "collection_binding_id": "binding-1", + } + + def test_preserve_existing_embedding_settings_removes_empty_placeholders_without_existing_values(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + embedding_model_provider=None, + embedding_model=None, + collection_binding_id=None, + ) + filtered_data = {"embedding_model_provider": "", "embedding_model": ""} + + DatasetService._preserve_existing_embedding_settings(dataset, filtered_data) + + assert filtered_data == {} + + def test_update_embedding_model_settings_returns_update_for_changed_values(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + embedding_model_provider="provider", + embedding_model="embedding-model", + ) + + with patch.object(DatasetService, "_apply_new_embedding_settings") as apply_settings: + result = DatasetService._update_embedding_model_settings( + dataset, + {"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"}, + {}, + ) + + assert result == "update" + apply_settings.assert_called_once() + + def test_update_embedding_model_settings_returns_none_for_unchanged_values(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + embedding_model_provider="provider", + embedding_model="embedding-model", + ) + + result = DatasetService._update_embedding_model_settings( + dataset, + {"embedding_model_provider": "provider", "embedding_model": "embedding-model"}, + {}, + ) + + assert result is None + + def test_update_embedding_model_settings_wraps_bad_request_errors(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + embedding_model_provider="provider", + embedding_model="embedding-model", + ) + + with patch.object(DatasetService, "_apply_new_embedding_settings", side_effect=LLMBadRequestError()): + with pytest.raises(ValueError, match="No Embedding Model available"): + DatasetService._update_embedding_model_settings( + dataset, + {"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"}, + {}, + ) + + def test_apply_new_embedding_settings_updates_binding_for_new_model(self): + class FakeAccount: + pass + + current_user = FakeAccount() + current_user.current_tenant_id = "tenant-1" + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(collection_binding_id="binding-1") + filtered_data: dict[str, object] = {} + + with ( + patch("services.dataset_service.Account", FakeAccount), + patch("services.dataset_service.current_user", current_user), + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch( + "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding", + return_value=SimpleNamespace(id="binding-2"), + ), + ): + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = SimpleNamespace( + provider="provider-two", + model_name="embedding-model-two", + ) + + DatasetService._apply_new_embedding_settings( + dataset, + {"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"}, + filtered_data, + ) + + assert filtered_data == { + "embedding_model": "embedding-model-two", + "embedding_model_provider": "provider-two", + "collection_binding_id": "binding-2", + } + + def test_apply_new_embedding_settings_preserves_existing_values_when_provider_token_is_missing(self): + class FakeAccount: + pass + + current_user = FakeAccount() + current_user.current_tenant_id = "tenant-1" + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + embedding_model_provider="provider", + embedding_model="embedding-model", + collection_binding_id="binding-1", + ) + filtered_data: dict[str, object] = {} + + with ( + patch("services.dataset_service.Account", FakeAccount), + patch("services.dataset_service.current_user", current_user), + patch("services.dataset_service.ModelManager") as model_manager_cls, + ): + model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError( + "token missing" + ) + + DatasetService._apply_new_embedding_settings( + dataset, + {"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"}, + filtered_data, + ) + + assert filtered_data == { + "embedding_model_provider": "provider", + "embedding_model": "embedding-model", + "collection_binding_id": "binding-1", + } + + @pytest.mark.parametrize( + ("summary_index_setting", "expected"), + [ + (None, False), + ({"enable": False}, False), + ({"enable": True, "model_name": "old-model", "model_provider_name": "provider"}, False), + ({"enable": True, "model_name": "new-model", "model_provider_name": "provider-two"}, True), + ], + ) + def test_check_summary_index_setting_model_changed(self, summary_index_setting, expected): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + dataset_id="dataset-1", + summary_index_setting={"enable": True, "model_name": "old-model", "model_provider_name": "provider"}, + ) + + result = DatasetService._check_summary_index_setting_model_changed( + dataset, + {"summary_index_setting": summary_index_setting} if summary_index_setting is not None else {}, + ) + + assert result is expected + + +class TestDatasetServiceRagPipelineSettings: + """Unit tests for rag-pipeline dataset setting updates.""" + + def test_update_rag_pipeline_dataset_settings_requires_current_tenant(self): + session = MagicMock() + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + knowledge_configuration = _make_knowledge_configuration() + + with patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id=None)): + with pytest.raises(ValueError, match="Current user or current tenant not found"): + DatasetService.update_rag_pipeline_dataset_settings(session, dataset, knowledge_configuration) + + def test_update_rag_pipeline_dataset_settings_without_published_high_quality_updates_embedding_settings(self): + session = MagicMock() + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + session.merge.return_value = dataset + knowledge_configuration = _make_knowledge_configuration(summary_index_setting={"enable": True}) + embedding_model = SimpleNamespace(provider="provider", model_name="embedding-model") + + with ( + patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")), + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch.object(DatasetService, "check_is_multimodal_model", return_value=True) as check_multimodal, + patch( + "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding", + return_value=SimpleNamespace(id="binding-1"), + ), + ): + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model + + DatasetService.update_rag_pipeline_dataset_settings(session, dataset, knowledge_configuration) + + assert dataset.chunk_structure == "paragraph" + assert dataset.indexing_technique == "high_quality" + assert dataset.embedding_model == "embedding-model" + assert dataset.embedding_model_provider == "provider" + assert dataset.collection_binding_id == "binding-1" + assert dataset.is_multimodal is True + assert dataset.retrieval_model == knowledge_configuration.retrieval_model.model_dump() + assert dataset.summary_index_setting == {"enable": True} + check_multimodal.assert_called_once_with("tenant-1", "provider", "embedding-model") + session.add.assert_called_once_with(dataset) + session.commit.assert_not_called() + + def test_update_rag_pipeline_dataset_settings_without_published_economy_updates_keyword_number(self): + session = MagicMock() + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + session.merge.return_value = dataset + knowledge_configuration = _make_knowledge_configuration( + indexing_technique="economy", + embedding_model_provider="", + embedding_model="", + keyword_number=12, + ) + + with patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")): + DatasetService.update_rag_pipeline_dataset_settings(session, dataset, knowledge_configuration) + + assert dataset.indexing_technique == "economy" + assert dataset.keyword_number == 12 + assert dataset.retrieval_model == knowledge_configuration.retrieval_model.model_dump() + session.add.assert_called_once_with(dataset) + + def test_update_rag_pipeline_dataset_settings_with_published_rejects_chunk_structure_changes(self): + session = MagicMock() + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + dataset.chunk_structure = "paragraph" + session.merge.return_value = dataset + knowledge_configuration = _make_knowledge_configuration(chunk_structure="sentence") + + with patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")): + with pytest.raises(ValueError, match="Chunk structure is not allowed to be updated"): + DatasetService.update_rag_pipeline_dataset_settings( + session, + dataset, + knowledge_configuration, + has_published=True, + ) + + def test_update_rag_pipeline_dataset_settings_with_published_rejects_switch_to_economy(self): + session = MagicMock() + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + dataset.chunk_structure = "paragraph" + dataset.indexing_technique = "high_quality" + session.merge.return_value = dataset + knowledge_configuration = _make_knowledge_configuration( + indexing_technique="economy", + embedding_model_provider="", + embedding_model="", + ) + + with patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")): + with pytest.raises( + ValueError, + match="Knowledge base indexing technique is not allowed to be updated to economy", + ): + DatasetService.update_rag_pipeline_dataset_settings( + session, + dataset, + knowledge_configuration, + has_published=True, + ) + + def test_update_rag_pipeline_dataset_settings_with_published_adds_high_quality_index(self): + session = MagicMock() + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + dataset.chunk_structure = "paragraph" + dataset.indexing_technique = "economy" + session.merge.return_value = dataset + knowledge_configuration = _make_knowledge_configuration() + embedding_model = SimpleNamespace(provider="provider", model_name="embedding-model") + + with ( + patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")), + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch.object(DatasetService, "check_is_multimodal_model", return_value=False), + patch( + "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding", + return_value=SimpleNamespace(id="binding-1"), + ), + patch("services.dataset_service.deal_dataset_index_update_task") as update_task, + ): + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model + + DatasetService.update_rag_pipeline_dataset_settings( + session, + dataset, + knowledge_configuration, + has_published=True, + ) + + assert dataset.indexing_technique == "high_quality" + assert dataset.embedding_model == "embedding-model" + assert dataset.embedding_model_provider == "provider" + assert dataset.collection_binding_id == "binding-1" + assert dataset.is_multimodal is False + assert dataset.retrieval_model == knowledge_configuration.retrieval_model.model_dump() + session.add.assert_called_once_with(dataset) + session.commit.assert_called_once() + update_task.delay.assert_called_once_with("dataset-1", "add") + + def test_update_rag_pipeline_dataset_settings_with_published_updates_changed_embedding_model(self): + session = MagicMock() + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + dataset.chunk_structure = "paragraph" + dataset.indexing_technique = "high_quality" + dataset.embedding_model_provider = "provider" + dataset.embedding_model = "embedding-model" + session.merge.return_value = dataset + knowledge_configuration = _make_knowledge_configuration( + embedding_model_provider="provider-two", + embedding_model="embedding-model-two", + summary_index_setting={"enable": True}, + ) + + with ( + patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")), + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch.object(DatasetService, "check_is_multimodal_model", return_value=True), + patch( + "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding", + return_value=SimpleNamespace(id="binding-2"), + ), + patch("services.dataset_service.deal_dataset_index_update_task") as update_task, + ): + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = SimpleNamespace( + provider="provider-two", + model_name="embedding-model-two", + ) + + DatasetService.update_rag_pipeline_dataset_settings( + session, + dataset, + knowledge_configuration, + has_published=True, + ) + + assert dataset.embedding_model_provider == "provider-two" + assert dataset.embedding_model == "embedding-model-two" + assert dataset.collection_binding_id == "binding-2" + assert dataset.is_multimodal is True + assert dataset.summary_index_setting == {"enable": True} + session.add.assert_called_once_with(dataset) + session.commit.assert_called_once() + update_task.delay.assert_called_once_with("dataset-1", "update") + + def test_update_rag_pipeline_dataset_settings_with_published_skips_embedding_update_when_token_is_missing(self): + session = MagicMock() + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + dataset.chunk_structure = "paragraph" + dataset.indexing_technique = "high_quality" + dataset.embedding_model_provider = "provider" + dataset.embedding_model = "embedding-model" + session.merge.return_value = dataset + knowledge_configuration = _make_knowledge_configuration( + embedding_model_provider="provider-two", + embedding_model="embedding-model-two", + ) + + with ( + patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")), + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch("services.dataset_service.deal_dataset_index_update_task") as update_task, + ): + model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError( + "token missing" + ) + + DatasetService.update_rag_pipeline_dataset_settings( + session, + dataset, + knowledge_configuration, + has_published=True, + ) + + assert dataset.embedding_model_provider == "provider" + assert dataset.embedding_model == "embedding-model" + assert dataset.retrieval_model == knowledge_configuration.retrieval_model.model_dump() + session.add.assert_called_once_with(dataset) + session.commit.assert_called_once() + update_task.delay.assert_called_once_with("dataset-1", "update") + + def test_update_rag_pipeline_dataset_settings_with_published_updates_economy_keyword_number(self): + session = MagicMock() + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + dataset.chunk_structure = "paragraph" + dataset.indexing_technique = "economy" + dataset.keyword_number = 5 + session.merge.return_value = dataset + knowledge_configuration = _make_knowledge_configuration( + indexing_technique="economy", + embedding_model_provider="", + embedding_model="", + keyword_number=9, + ) + + with ( + patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")), + patch("services.dataset_service.deal_dataset_index_update_task") as update_task, + ): + DatasetService.update_rag_pipeline_dataset_settings( + session, + dataset, + knowledge_configuration, + has_published=True, + ) + + assert dataset.keyword_number == 9 + assert dataset.retrieval_model == knowledge_configuration.retrieval_model.model_dump() + session.add.assert_called_once_with(dataset) + session.commit.assert_called_once() + update_task.delay.assert_not_called() + + +class TestDatasetServicePermissionsAndLifecycle: + """Unit tests for dataset permissions, deletion, and metadata helpers.""" + + def test_delete_dataset_returns_false_when_dataset_is_missing(self): + with patch.object(DatasetService, "get_dataset", return_value=None): + result = DatasetService.delete_dataset("dataset-1", user=SimpleNamespace(id="user-1")) + + assert result is False + + def test_delete_dataset_checks_permission_and_deletes_dataset(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock() + + with ( + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch.object(DatasetService, "check_dataset_permission") as check_permission, + patch("services.dataset_service.dataset_was_deleted.send") as send_deleted_signal, + patch("services.dataset_service.db") as mock_db, + ): + result = DatasetService.delete_dataset(dataset.id, user=SimpleNamespace(id="user-1")) + + assert result is True + check_permission.assert_called_once_with(dataset, SimpleNamespace(id="user-1")) + send_deleted_signal.assert_called_once_with(dataset) + mock_db.session.delete.assert_called_once_with(dataset) + mock_db.session.commit.assert_called_once() + + def test_dataset_use_check_returns_scalar_result(self): + with patch("services.dataset_service.db") as mock_db: + mock_db.session.execute.return_value.scalar_one.return_value = True + + result = DatasetService.dataset_use_check("dataset-1") + + assert result is True + + def test_check_dataset_permission_rejects_cross_tenant_access(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(tenant_id="tenant-a") + user = DatasetServiceUnitDataFactory.create_user_mock(tenant_id="tenant-b") + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_permission(dataset, user) + + def test_check_dataset_permission_rejects_only_me_dataset_for_non_creator(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + permission=DatasetPermissionEnum.ONLY_ME, + created_by="owner-1", + ) + user = DatasetServiceUnitDataFactory.create_user_mock( + user_id="member-1", + role=TenantAccountRole.EDITOR, + ) + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_permission(dataset, user) + + def test_check_dataset_permission_rejects_partial_team_user_without_binding(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + permission=DatasetPermissionEnum.PARTIAL_TEAM, + created_by="owner-1", + ) + user = DatasetServiceUnitDataFactory.create_user_mock( + user_id="member-1", + role=TenantAccountRole.EDITOR, + ) + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_permission(dataset, user) + + def test_check_dataset_permission_allows_partial_team_creator_without_lookup(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + permission=DatasetPermissionEnum.PARTIAL_TEAM, + created_by="creator-1", + ) + user = DatasetServiceUnitDataFactory.create_user_mock( + user_id="creator-1", + role=TenantAccountRole.EDITOR, + ) + + with patch("services.dataset_service.db") as mock_db: + DatasetService.check_dataset_permission(dataset, user) + + mock_db.session.query.assert_not_called() + + def test_check_dataset_permission_allows_partial_team_member_with_binding(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + permission=DatasetPermissionEnum.PARTIAL_TEAM, + created_by="owner-1", + ) + user = DatasetServiceUnitDataFactory.create_user_mock( + user_id="member-1", + role=TenantAccountRole.EDITOR, + ) + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.filter_by.return_value.first.return_value = object() + + DatasetService.check_dataset_permission(dataset, user) + + def test_check_dataset_operator_permission_validates_required_arguments(self): + with pytest.raises(ValueError, match="Dataset not found"): + DatasetService.check_dataset_operator_permission(user=SimpleNamespace(id="user-1"), dataset=None) + + with pytest.raises(ValueError, match="User not found"): + DatasetService.check_dataset_operator_permission(user=None, dataset=SimpleNamespace(id="dataset-1")) + + def test_check_dataset_operator_permission_rejects_only_me_for_non_creator(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + permission=DatasetPermissionEnum.ONLY_ME, + created_by="owner-1", + ) + user = DatasetServiceUnitDataFactory.create_user_mock( + user_id="member-1", + role=TenantAccountRole.EDITOR, + ) + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_operator_permission(user=user, dataset=dataset) + + def test_check_dataset_operator_permission_rejects_partial_team_without_binding(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM) + user = DatasetServiceUnitDataFactory.create_user_mock( + user_id="member-1", + role=TenantAccountRole.EDITOR, + ) + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.filter_by.return_value.all.return_value = [] + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_operator_permission(user=user, dataset=dataset) + + def test_get_dataset_queries_delegates_to_paginate(self): + with patch("services.dataset_service.db") as mock_db: + mock_db.desc.side_effect = lambda column: column + mock_db.paginate.return_value = SimpleNamespace(items=["query"], total=1) + + items, total = DatasetService.get_dataset_queries("dataset-1", page=1, per_page=20) + + assert items == ["query"] + assert total == 1 + mock_db.paginate.assert_called_once() + + def test_get_related_apps_returns_ordered_query_results(self): + with patch("services.dataset_service.db") as mock_db: + mock_db.desc.side_effect = lambda column: column + mock_db.session.query.return_value.where.return_value.order_by.return_value.all.return_value = [ + "relation-1" + ] + + result = DatasetService.get_related_apps("dataset-1") + + assert result == ["relation-1"] + + def test_update_dataset_api_status_raises_not_found_for_missing_dataset(self): + with patch.object(DatasetService, "get_dataset", return_value=None): + with pytest.raises(NotFound, match="Dataset not found"): + DatasetService.update_dataset_api_status("dataset-1", True) + + def test_update_dataset_api_status_requires_current_user_id(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(enable_api=False) + + with ( + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch("services.dataset_service.current_user", SimpleNamespace(id=None)), + ): + with pytest.raises(ValueError, match="Current user or current user id not found"): + DatasetService.update_dataset_api_status(dataset.id, True) + + def test_update_dataset_api_status_updates_fields_and_commits(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(enable_api=False) + now = object() + + with ( + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch("services.dataset_service.current_user", SimpleNamespace(id="user-1")), + patch("services.dataset_service.naive_utc_now", return_value=now), + patch("services.dataset_service.db") as mock_db, + ): + DatasetService.update_dataset_api_status(dataset.id, True) + + assert dataset.enable_api is True + assert dataset.updated_by == "user-1" + assert dataset.updated_at is now + mock_db.session.commit.assert_called_once() + + def test_get_dataset_auto_disable_logs_returns_empty_when_billing_is_disabled(self): + class FakeAccount: + pass + + current_user = FakeAccount() + current_user.current_tenant_id = "tenant-1" + + features = SimpleNamespace( + billing=SimpleNamespace(enabled=False, subscription=SimpleNamespace(plan=CloudPlan.PROFESSIONAL)) + ) + + with ( + patch("services.dataset_service.Account", FakeAccount), + patch("services.dataset_service.current_user", current_user), + patch("services.dataset_service.FeatureService.get_features", return_value=features), + patch("services.dataset_service.db") as mock_db, + ): + result = DatasetService.get_dataset_auto_disable_logs("dataset-1") + + assert result == {"document_ids": [], "count": 0} + mock_db.session.scalars.assert_not_called() + + def test_get_dataset_auto_disable_logs_returns_recent_document_ids(self): + class FakeAccount: + pass + + current_user = FakeAccount() + current_user.current_tenant_id = "tenant-1" + logs = [SimpleNamespace(document_id="doc-1"), SimpleNamespace(document_id="doc-2")] + features = SimpleNamespace( + billing=SimpleNamespace(enabled=True, subscription=SimpleNamespace(plan=CloudPlan.PROFESSIONAL)) + ) + + with ( + patch("services.dataset_service.Account", FakeAccount), + patch("services.dataset_service.current_user", current_user), + patch("services.dataset_service.FeatureService.get_features", return_value=features), + patch("services.dataset_service.db") as mock_db, + ): + mock_db.session.scalars.return_value.all.return_value = logs + + result = DatasetService.get_dataset_auto_disable_logs("dataset-1") + + assert result == {"document_ids": ["doc-1", "doc-2"], "count": 2} + + +class TestDatasetServiceDocumentIndexing: + """Unit tests for pause/recover/retry orchestration without SQL assertions.""" + + @pytest.fixture + def mock_document_service_dependencies(self): + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db.session") as mock_db_session, + patch("services.dataset_service.current_user") as mock_current_user, + ): + mock_current_user.id = "user-123" + yield { + "redis_client": mock_redis, + "db_session": mock_db_session, + "current_user": mock_current_user, + } + + def test_pause_document_success(self, mock_document_service_dependencies): + document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="indexing") + + DocumentService.pause_document(document) + + assert document.is_paused is True + assert document.paused_by == "user-123" + mock_document_service_dependencies["db_session"].add.assert_called_once_with(document) + mock_document_service_dependencies["db_session"].commit.assert_called_once() + mock_document_service_dependencies["redis_client"].setnx.assert_called_once_with( + f"document_{document.id}_is_paused", + "True", + ) + + def test_pause_document_invalid_status_error(self, mock_document_service_dependencies): + document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="completed") + + with pytest.raises(DocumentIndexingError): + DocumentService.pause_document(document) + + def test_recover_document_success(self, mock_document_service_dependencies): + document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="indexing", is_paused=True) + + with patch("services.dataset_service.recover_document_indexing_task") as recover_task: + DocumentService.recover_document(document) + + assert document.is_paused is False + assert document.paused_by is None + assert document.paused_at is None + mock_document_service_dependencies["db_session"].add.assert_called_once_with(document) + mock_document_service_dependencies["db_session"].commit.assert_called_once() + mock_document_service_dependencies["redis_client"].delete.assert_called_once_with( + f"document_{document.id}_is_paused" + ) + recover_task.delay.assert_called_once_with(document.dataset_id, document.id) + + def test_retry_document_indexing_success(self, mock_document_service_dependencies): + dataset_id = "dataset-123" + documents = [ + DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-1", indexing_status="error"), + DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-2", indexing_status="error"), + ] + mock_document_service_dependencies["redis_client"].get.return_value = None + + with patch("services.dataset_service.retry_document_indexing_task") as retry_task: + DocumentService.retry_document(dataset_id, documents) + + assert all(document.indexing_status == "waiting" for document in documents) + assert mock_document_service_dependencies["db_session"].add.call_count == 2 + assert mock_document_service_dependencies["db_session"].commit.call_count == 2 + assert mock_document_service_dependencies["redis_client"].setex.call_count == 2 + retry_task.delay.assert_called_once_with(dataset_id, ["doc-1", "doc-2"], "user-123") + + +class TestDatasetCollectionBindingService: + """Unit tests for dataset collection binding lookups and creation.""" + + def test_get_dataset_collection_binding_returns_existing_binding(self): + binding = SimpleNamespace(id="binding-1") + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = binding + + result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "model") + + assert result is binding + mock_db.session.add.assert_not_called() + + def test_get_dataset_collection_binding_creates_binding_when_missing(self): + created_binding = SimpleNamespace(id="binding-2") + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.DatasetCollectionBinding", return_value=created_binding) as binding_cls, + patch.object(Dataset, "gen_collection_name_by_id", return_value="generated-collection"), + ): + mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = None + + result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "model", "dataset") + + assert result is created_binding + binding_cls.assert_called_once_with( + provider_name="provider", + model_name="model", + collection_name="generated-collection", + type="dataset", + ) + mock_db.session.add.assert_called_once_with(created_binding) + mock_db.session.commit.assert_called_once() + + def test_get_dataset_collection_binding_by_id_and_type_raises_when_missing(self): + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = None + + with pytest.raises(ValueError, match="Dataset collection binding not found"): + DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type("binding-1") + + def test_get_dataset_collection_binding_by_id_and_type_returns_binding(self): + binding = SimpleNamespace(id="binding-1") + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = binding + + result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type("binding-1") + + assert result is binding + + +class TestDatasetPermissionService: + """Unit tests for dataset partial-member management helpers.""" + + def test_get_dataset_partial_member_list_returns_scalar_results(self): + with patch("services.dataset_service.db") as mock_db: + mock_db.session.scalars.return_value.all.return_value = ["user-1", "user-2"] + + result = DatasetPermissionService.get_dataset_partial_member_list("dataset-1") + + assert result == ["user-1", "user-2"] + + def test_update_partial_member_list_replaces_permissions_and_commits(self): + with patch("services.dataset_service.db") as mock_db: + DatasetPermissionService.update_partial_member_list( + "tenant-1", + "dataset-1", + [{"user_id": "user-1"}, {"user_id": "user-2"}], + ) + + mock_db.session.query.return_value.where.return_value.delete.assert_called_once() + mock_db.session.add_all.assert_called_once() + mock_db.session.commit.assert_called_once() + + def test_update_partial_member_list_rolls_back_on_exception(self): + with patch("services.dataset_service.db") as mock_db: + mock_db.session.add_all.side_effect = RuntimeError("boom") + + with pytest.raises(RuntimeError, match="boom"): + DatasetPermissionService.update_partial_member_list( + "tenant-1", + "dataset-1", + [{"user_id": "user-1"}], + ) + + mock_db.session.rollback.assert_called_once() + + def test_check_permission_requires_dataset_editor(self): + user = SimpleNamespace(is_dataset_editor=False, is_dataset_operator=False) + dataset = DatasetServiceUnitDataFactory.create_dataset_mock() + + with pytest.raises(NoPermissionError, match="does not have permission"): + DatasetPermissionService.check_permission(user, dataset, "all_team", []) + + def test_check_permission_prevents_dataset_operator_from_changing_permission_mode(self): + user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(permission="all_team") + + with pytest.raises(NoPermissionError, match="cannot change the dataset permissions"): + DatasetPermissionService.check_permission(user, dataset, "only_me", []) + + def test_check_permission_requires_partial_member_list_for_partial_members_mode(self): + user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(permission="partial_members") + + with pytest.raises(ValueError, match="Partial member list is required"): + DatasetPermissionService.check_permission(user, dataset, "partial_members", []) + + def test_check_permission_rejects_dataset_operator_member_list_changes(self): + user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + dataset_id="dataset-1", permission="partial_members" + ) + + with patch.object(DatasetPermissionService, "get_dataset_partial_member_list", return_value=["user-1"]): + with pytest.raises(ValueError, match="cannot change the dataset permissions"): + DatasetPermissionService.check_permission( + user, + dataset, + "partial_members", + [{"user_id": "user-2"}], + ) + + def test_check_permission_allows_dataset_operator_when_member_list_is_unchanged(self): + user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + dataset_id="dataset-1", permission="partial_members" + ) + + with patch.object(DatasetPermissionService, "get_dataset_partial_member_list", return_value=["user-1"]): + DatasetPermissionService.check_permission( + user, + dataset, + "partial_members", + [{"user_id": "user-1"}], + ) + + def test_clear_partial_member_list_deletes_permissions_and_commits(self): + with patch("services.dataset_service.db") as mock_db: + DatasetPermissionService.clear_partial_member_list("dataset-1") + + mock_db.session.query.return_value.where.return_value.delete.assert_called_once() + mock_db.session.commit.assert_called_once() + + def test_clear_partial_member_list_rolls_back_on_exception(self): + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.delete.side_effect = RuntimeError("boom") + + with pytest.raises(RuntimeError, match="boom"): + DatasetPermissionService.clear_partial_member_list("dataset-1") + + mock_db.session.rollback.assert_called_once() diff --git a/api/tests/unit_tests/services/test_dataset_service_document.py b/api/tests/unit_tests/services/test_dataset_service_document.py new file mode 100644 index 00000000000..c8036487ab4 --- /dev/null +++ b/api/tests/unit_tests/services/test_dataset_service_document.py @@ -0,0 +1,2078 @@ +"""Unit tests for DocumentService behaviors in dataset_service.""" + +from .dataset_service_test_helpers import ( + Account, + BuiltInField, + CloudPlan, + DatasetProcessRule, + DatasetService, + DatasetServiceUnitDataFactory, + DataSource, + DocumentIndexingError, + DocumentService, + FileInfo, + FileNotExistsError, + Forbidden, + IndexStructureType, + InfoList, + KnowledgeConfig, + MagicMock, + NoPermissionError, + NotFound, + NotionIcon, + NotionInfo, + NotionPage, + PreProcessingRule, + ProcessRule, + RerankingModel, + RetrievalMethod, + RetrievalModel, + Rule, + Segmentation, + SimpleNamespace, + WebsiteInfo, + _make_dataset, + _make_document, + _make_features, + _make_lock_context, + _make_session_context, + _make_upload_knowledge_config, + create_autospec, + json, + patch, + pytest, +) + + +class TestDocumentServiceDisplayStatus: + """Unit tests for DocumentService display-status helpers.""" + + @pytest.mark.parametrize( + ("raw_status", "expected"), + [ + ("enabled", "available"), + ("AVAILABLE", "available"), + ("paused", "paused"), + ("unknown", None), + (None, None), + ], + ) + def test_normalize_display_status(self, raw_status, expected): + assert DocumentService.normalize_display_status(raw_status) == expected + + def test_build_display_status_filters_returns_empty_tuple_for_unknown_status(self): + assert DocumentService.build_display_status_filters("missing") == () + + def test_apply_display_status_filter_returns_original_query_for_unknown_status(self): + query = MagicMock() + + result = DocumentService.apply_display_status_filter(query, "missing") + + assert result is query + query.where.assert_not_called() + + def test_apply_display_status_filter_applies_where_for_known_status(self): + query = MagicMock() + filtered_query = MagicMock() + query.where.return_value = filtered_query + + result = DocumentService.apply_display_status_filter(query, "enabled") + + assert result is filtered_query + query.where.assert_called_once() + + +class TestDocumentServiceQueryAndDownloadHelpers: + """Unit tests for DocumentService query helpers and download flows.""" + + def test_get_document_returns_none_when_document_id_is_missing(self): + with patch("services.dataset_service.db") as mock_db: + result = DocumentService.get_document("dataset-1", None) + + assert result is None + mock_db.session.query.assert_not_called() + + def test_get_document_queries_by_dataset_and_document_id(self): + document = DatasetServiceUnitDataFactory.create_document_mock() + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.first.return_value = document + + result = DocumentService.get_document("dataset-1", "doc-1") + + assert result is document + + def test_get_documents_by_ids_returns_empty_for_empty_input(self): + with patch("services.dataset_service.db") as mock_db: + result = DocumentService.get_documents_by_ids("dataset-1", []) + + assert result == [] + mock_db.session.scalars.assert_not_called() + + def test_get_documents_by_ids_uses_single_batch_query(self): + document = DatasetServiceUnitDataFactory.create_document_mock() + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.scalars.return_value.all.return_value = [document] + + result = DocumentService.get_documents_by_ids("dataset-1", ["doc-1"]) + + assert result == [document] + mock_db.session.scalars.assert_called_once() + + def test_update_documents_need_summary_returns_zero_for_empty_input(self): + with patch("services.dataset_service.session_factory") as session_factory_mock: + result = DocumentService.update_documents_need_summary("dataset-1", []) + + assert result == 0 + session_factory_mock.create_session.assert_not_called() + + def test_update_documents_need_summary_updates_matching_documents_and_commits(self): + session = MagicMock() + session.query.return_value.filter.return_value.update.return_value = 2 + + with patch("services.dataset_service.session_factory") as session_factory_mock: + session_factory_mock.create_session.return_value = _make_session_context(session) + + result = DocumentService.update_documents_need_summary( + "dataset-1", + ["doc-1", "doc-2"], + need_summary=False, + ) + + assert result == 2 + session.commit.assert_called_once() + + def test_get_document_download_url_uses_upload_file_lookup_and_signed_url_helper(self): + upload_file = DatasetServiceUnitDataFactory.create_upload_file_mock(file_id="file-1") + document = DatasetServiceUnitDataFactory.create_document_mock() + + with ( + patch.object(DocumentService, "_get_upload_file_for_upload_file_document", return_value=upload_file), + patch("services.dataset_service.file_helpers.get_signed_file_url", return_value="signed-url") as get_url, + ): + result = DocumentService.get_document_download_url(document) + + assert result == "signed-url" + get_url.assert_called_once_with(upload_file_id="file-1", as_attachment=True) + + def test_get_upload_file_id_for_upload_file_document_rejects_invalid_source_type(self): + document = DatasetServiceUnitDataFactory.create_document_mock(data_source_type="not-upload-file") + + with pytest.raises(NotFound, match="invalid source"): + DocumentService._get_upload_file_id_for_upload_file_document( + document, + invalid_source_message="invalid source", + missing_file_message="missing file", + ) + + def test_get_upload_file_id_for_upload_file_document_rejects_missing_upload_file_id(self): + document = DatasetServiceUnitDataFactory.create_document_mock(data_source_info_dict={}) + + with pytest.raises(NotFound, match="missing file"): + DocumentService._get_upload_file_id_for_upload_file_document( + document, + invalid_source_message="invalid source", + missing_file_message="missing file", + ) + + def test_get_upload_file_id_for_upload_file_document_returns_string_id(self): + document = DatasetServiceUnitDataFactory.create_document_mock(data_source_info_dict={"upload_file_id": 99}) + + result = DocumentService._get_upload_file_id_for_upload_file_document( + document, + invalid_source_message="invalid source", + missing_file_message="missing file", + ) + + assert result == "99" + + def test_get_upload_file_for_upload_file_document_raises_when_file_service_returns_nothing(self): + document = DatasetServiceUnitDataFactory.create_document_mock( + tenant_id="tenant-1", + data_source_info_dict={"upload_file_id": "file-1"}, + ) + + with patch("services.dataset_service.FileService.get_upload_files_by_ids", return_value={}): + with pytest.raises(NotFound, match="Uploaded file not found"): + DocumentService._get_upload_file_for_upload_file_document(document) + + def test_get_upload_file_for_upload_file_document_returns_upload_file(self): + document = DatasetServiceUnitDataFactory.create_document_mock( + tenant_id="tenant-1", + data_source_info_dict={"upload_file_id": "file-1"}, + ) + upload_file = DatasetServiceUnitDataFactory.create_upload_file_mock(file_id="file-1") + + with patch( + "services.dataset_service.FileService.get_upload_files_by_ids", return_value={"file-1": upload_file} + ): + result = DocumentService._get_upload_file_for_upload_file_document(document) + + assert result is upload_file + + def test_enrich_documents_with_summary_index_status_skips_lookup_when_summary_is_disabled(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(summary_index_setting={"enable": False}) + documents = [ + DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-1", need_summary=True), + DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-2", need_summary=False), + ] + + DocumentService.enrich_documents_with_summary_index_status(documents, dataset, tenant_id="tenant-1") + + assert documents[0].summary_index_status is None + assert documents[1].summary_index_status is None + + def test_enrich_documents_with_summary_index_status_applies_summary_status_map(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + dataset_id="dataset-1", + summary_index_setting={"enable": True}, + ) + documents = [ + DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-1", need_summary=True), + DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-2", need_summary=True), + DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-3", need_summary=False), + ] + + with patch( + "services.summary_index_service.SummaryIndexService.get_documents_summary_index_status", + return_value={"doc-1": "completed", "doc-2": None}, + ) as get_status_map: + DocumentService.enrich_documents_with_summary_index_status(documents, dataset, tenant_id="tenant-1") + + get_status_map.assert_called_once_with( + document_ids=["doc-1", "doc-2"], + dataset_id="dataset-1", + tenant_id="tenant-1", + ) + assert documents[0].summary_index_status == "completed" + assert documents[1].summary_index_status is None + assert documents[2].summary_index_status is None + + def test_generate_document_batch_download_zip_filename_uses_zip_extension(self): + fake_uuid = SimpleNamespace(hex="archive-id") + + with patch("services.dataset_service.uuid.uuid4", return_value=fake_uuid): + result = DocumentService._generate_document_batch_download_zip_filename() + + assert result == "archive-id.zip" + + def test_get_upload_files_by_document_id_for_zip_download_raises_for_missing_documents(self): + with patch.object(DocumentService, "get_documents_by_ids", return_value=[]): + with pytest.raises(NotFound, match="Document not found"): + DocumentService._get_upload_files_by_document_id_for_zip_download( + dataset_id="dataset-1", + document_ids=["doc-1"], + tenant_id="tenant-1", + ) + + def test_get_upload_files_by_document_id_for_zip_download_rejects_cross_tenant_access(self): + document = DatasetServiceUnitDataFactory.create_document_mock( + document_id="doc-1", + tenant_id="tenant-other", + data_source_info_dict={"upload_file_id": "file-1"}, + ) + + with patch.object(DocumentService, "get_documents_by_ids", return_value=[document]): + with pytest.raises(Forbidden, match="No permission"): + DocumentService._get_upload_files_by_document_id_for_zip_download( + dataset_id="dataset-1", + document_ids=["doc-1"], + tenant_id="tenant-1", + ) + + def test_get_upload_files_by_document_id_for_zip_download_rejects_missing_upload_files(self): + document = DatasetServiceUnitDataFactory.create_document_mock( + document_id="doc-1", + tenant_id="tenant-1", + data_source_info_dict={"upload_file_id": "file-1"}, + ) + + with ( + patch.object(DocumentService, "get_documents_by_ids", return_value=[document]), + patch("services.dataset_service.FileService.get_upload_files_by_ids", return_value={}), + ): + with pytest.raises(NotFound, match="Only uploaded-file documents can be downloaded as ZIP"): + DocumentService._get_upload_files_by_document_id_for_zip_download( + dataset_id="dataset-1", + document_ids=["doc-1"], + tenant_id="tenant-1", + ) + + def test_get_upload_files_by_document_id_for_zip_download_returns_document_keyed_mapping(self): + document_a = DatasetServiceUnitDataFactory.create_document_mock( + document_id="doc-1", + tenant_id="tenant-1", + data_source_info_dict={"upload_file_id": "file-1"}, + ) + document_b = DatasetServiceUnitDataFactory.create_document_mock( + document_id="doc-2", + tenant_id="tenant-1", + data_source_info_dict={"upload_file_id": "file-2"}, + ) + upload_file_a = DatasetServiceUnitDataFactory.create_upload_file_mock(file_id="file-1") + upload_file_b = DatasetServiceUnitDataFactory.create_upload_file_mock(file_id="file-2") + + with ( + patch.object(DocumentService, "get_documents_by_ids", return_value=[document_a, document_b]), + patch( + "services.dataset_service.FileService.get_upload_files_by_ids", + return_value={"file-1": upload_file_a, "file-2": upload_file_b}, + ), + ): + result = DocumentService._get_upload_files_by_document_id_for_zip_download( + dataset_id="dataset-1", + document_ids=["doc-1", "doc-2"], + tenant_id="tenant-1", + ) + + assert result == {"doc-1": upload_file_a, "doc-2": upload_file_b} + + def test_prepare_document_batch_download_zip_raises_not_found_for_missing_dataset(self): + user = DatasetServiceUnitDataFactory.create_user_mock() + + with patch.object(DatasetService, "get_dataset", return_value=None): + with pytest.raises(NotFound, match="Dataset not found"): + DocumentService.prepare_document_batch_download_zip( + dataset_id="dataset-1", + document_ids=["doc-1"], + tenant_id="tenant-1", + current_user=user, + ) + + def test_prepare_document_batch_download_zip_translates_permission_error_to_forbidden(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock() + user = DatasetServiceUnitDataFactory.create_user_mock() + + with ( + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch.object(DatasetService, "check_dataset_permission", side_effect=NoPermissionError("blocked")), + ): + with pytest.raises(Forbidden, match="blocked"): + DocumentService.prepare_document_batch_download_zip( + dataset_id=dataset.id, + document_ids=["doc-1"], + tenant_id="tenant-1", + current_user=user, + ) + + def test_prepare_document_batch_download_zip_returns_upload_files_in_requested_order(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock() + user = DatasetServiceUnitDataFactory.create_user_mock() + upload_file_a = DatasetServiceUnitDataFactory.create_upload_file_mock(file_id="file-a") + upload_file_b = DatasetServiceUnitDataFactory.create_upload_file_mock(file_id="file-b") + + with ( + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch.object(DatasetService, "check_dataset_permission"), + patch.object( + DocumentService, + "_get_upload_files_by_document_id_for_zip_download", + return_value={"doc-1": upload_file_a, "doc-2": upload_file_b}, + ), + patch.object(DocumentService, "_generate_document_batch_download_zip_filename", return_value="archive.zip"), + ): + upload_files, download_name = DocumentService.prepare_document_batch_download_zip( + dataset_id=dataset.id, + document_ids=["doc-2", "doc-1"], + tenant_id="tenant-1", + current_user=user, + ) + + assert upload_files == [upload_file_b, upload_file_a] + assert download_name == "archive.zip" + + def test_get_document_by_dataset_id_returns_enabled_documents(self): + document = DatasetServiceUnitDataFactory.create_document_mock(enabled=True) + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.scalars.return_value.all.return_value = [document] + + result = DocumentService.get_document_by_dataset_id("dataset-1") + + assert result == [document] + + def test_get_working_documents_by_dataset_id_returns_scalars_result(self): + document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="completed", archived=False) + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.scalars.return_value.all.return_value = [document] + + result = DocumentService.get_working_documents_by_dataset_id("dataset-1") + + assert result == [document] + + def test_get_error_documents_by_dataset_id_returns_scalars_result(self): + document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="error") + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.scalars.return_value.all.return_value = [document] + + result = DocumentService.get_error_documents_by_dataset_id("dataset-1") + + assert result == [document] + + def test_get_batch_documents_filters_by_current_user_tenant(self): + class FakeAccount: + pass + + current_user = FakeAccount() + current_user.current_tenant_id = "tenant-1" + document = DatasetServiceUnitDataFactory.create_document_mock() + + with ( + patch("services.dataset_service.Account", FakeAccount), + patch("services.dataset_service.current_user", current_user), + patch("services.dataset_service.db") as mock_db, + ): + mock_db.session.scalars.return_value.all.return_value = [document] + + result = DocumentService.get_batch_documents("dataset-1", "batch-1") + + assert result == [document] + + def test_get_document_file_detail_returns_one_or_none(self): + upload_file = DatasetServiceUnitDataFactory.create_upload_file_mock() + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.one_or_none.return_value = upload_file + + result = DocumentService.get_document_file_detail(upload_file.id) + + assert result is upload_file + + +class TestDocumentServiceMutations: + """Unit tests for DocumentService mutation and orchestration helpers.""" + + @pytest.fixture + def rename_account_context(self): + class FakeAccount: + pass + + current_user = FakeAccount() + current_user.id = "user-123" + current_user.current_tenant_id = "tenant-123" + + with ( + patch("services.dataset_service.Account", FakeAccount), + patch("services.dataset_service.current_user", current_user), + ): + yield current_user + + @pytest.mark.parametrize(("archived", "expected"), [(True, True), (False, False)]) + def test_check_archived_returns_boolean_status(self, archived, expected): + document = DatasetServiceUnitDataFactory.create_document_mock(archived=archived) + + assert DocumentService.check_archived(document) is expected + + def test_delete_document_emits_signal_and_commits(self): + document = DatasetServiceUnitDataFactory.create_document_mock( + data_source_type="upload_file", + data_source_info='{"upload_file_id": "file-1"}', + data_source_info_dict={"upload_file_id": "file-1"}, + ) + + with ( + patch("services.dataset_service.document_was_deleted.send") as send_deleted_signal, + patch("services.dataset_service.db") as mock_db, + ): + DocumentService.delete_document(document) + + send_deleted_signal.assert_called_once_with( + document.id, + dataset_id=document.dataset_id, + doc_form=document.doc_form, + file_id="file-1", + ) + mock_db.session.delete.assert_called_once_with(document) + mock_db.session.commit.assert_called_once() + + def test_delete_documents_ignores_empty_input(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock() + + with patch("services.dataset_service.db") as mock_db: + DocumentService.delete_documents(dataset, []) + + mock_db.session.scalars.assert_not_called() + + def test_delete_documents_deletes_rows_and_dispatches_cleanup_task(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(doc_form="text_model") + document_a = DatasetServiceUnitDataFactory.create_document_mock( + document_id="doc-1", + data_source_type="upload_file", + data_source_info_dict={"upload_file_id": "file-1"}, + ) + document_b = DatasetServiceUnitDataFactory.create_document_mock( + document_id="doc-2", + data_source_type="upload_file", + data_source_info_dict={"upload_file_id": "file-2"}, + ) + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.batch_clean_document_task") as clean_task, + ): + mock_db.session.scalars.return_value.all.return_value = [document_a, document_b] + + DocumentService.delete_documents(dataset, ["doc-1", "doc-2"]) + + assert mock_db.session.delete.call_count == 2 + mock_db.session.commit.assert_called_once() + clean_task.delay.assert_called_once_with(["doc-1", "doc-2"], dataset.id, dataset.doc_form, ["file-1", "file-2"]) + + def test_rename_document_raises_when_dataset_is_missing(self, rename_account_context): + with patch.object(DatasetService, "get_dataset", return_value=None): + with pytest.raises(ValueError, match="Dataset not found"): + DocumentService.rename_document("dataset-1", "doc-1", "New Name") + + def test_rename_document_raises_when_document_is_missing(self, rename_account_context): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock() + + with ( + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch.object(DocumentService, "get_document", return_value=None), + ): + with pytest.raises(ValueError, match="Document not found"): + DocumentService.rename_document(dataset.id, "doc-1", "New Name") + + def test_rename_document_rejects_cross_tenant_access(self, rename_account_context): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock() + document = DatasetServiceUnitDataFactory.create_document_mock(tenant_id="tenant-other") + + with ( + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch.object(DocumentService, "get_document", return_value=document), + ): + with pytest.raises(ValueError, match="No permission"): + DocumentService.rename_document(dataset.id, document.id, "New Name") + + def test_rename_document_updates_document_metadata_and_upload_file_name(self, rename_account_context): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + built_in_field_enabled=True, + tenant_id="tenant-1", + ) + document = DatasetServiceUnitDataFactory.create_document_mock( + tenant_id="tenant-1", + doc_metadata={"title": "Old"}, + data_source_info_dict={"upload_file_id": "file-1"}, + ) + rename_account_context.current_tenant_id = "tenant-1" + + with ( + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch.object(DocumentService, "get_document", return_value=document), + patch("services.dataset_service.db") as mock_db, + ): + result = DocumentService.rename_document(dataset.id, document.id, "New Name") + + assert result is document + assert document.name == "New Name" + assert document.doc_metadata[BuiltInField.document_name] == "New Name" + mock_db.session.add.assert_called_once_with(document) + mock_db.session.query.return_value.where.return_value.update.assert_called_once() + mock_db.session.commit.assert_called_once() + + def test_recover_document_raises_when_document_is_not_paused(self): + document = DatasetServiceUnitDataFactory.create_document_mock(is_paused=False) + + with pytest.raises(DocumentIndexingError): + DocumentService.recover_document(document) + + def test_retry_document_raises_when_retry_flag_is_already_set(self): + document = DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-1") + + with patch("services.dataset_service.redis_client") as mock_redis: + mock_redis.get.return_value = "1" + + with pytest.raises(ValueError, match="being retried"): + DocumentService.retry_document("dataset-1", [document]) + + def test_sync_website_document_raises_when_sync_flag_exists(self): + document = DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-1") + + with patch("services.dataset_service.redis_client") as mock_redis: + mock_redis.get.return_value = "1" + + with pytest.raises(ValueError, match="being synced"): + DocumentService.sync_website_document("dataset-1", document) + + def test_sync_website_document_updates_status_sets_cache_and_dispatches_task(self): + document = DatasetServiceUnitDataFactory.create_document_mock( + document_id="doc-1", + data_source_info_dict={"mode": "crawl"}, + ) + document.data_source_info = "{}" + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.sync_website_document_indexing_task") as sync_task, + ): + mock_redis.get.return_value = None + + DocumentService.sync_website_document("dataset-1", document) + + assert document.indexing_status == "waiting" + assert '"mode": "scrape"' in document.data_source_info + mock_db.session.add.assert_called_once_with(document) + mock_db.session.commit.assert_called_once() + mock_redis.setex.assert_called_once_with("document_doc-1_is_sync", 600, 1) + sync_task.delay.assert_called_once_with("dataset-1", "doc-1") + + def test_get_documents_position_returns_next_position_when_documents_exist(self): + document = DatasetServiceUnitDataFactory.create_document_mock(position=7) + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.first.return_value = ( + document + ) + + result = DocumentService.get_documents_position("dataset-1") + + assert result == 8 + + def test_get_documents_position_defaults_to_one_when_dataset_is_empty(self): + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.first.return_value = None + + result = DocumentService.get_documents_position("dataset-1") + + assert result == 1 + + +class TestDocumentServiceSaveDocumentWithoutDatasetId: + """Unit tests for dataset creation around save_document_without_dataset_id.""" + + @pytest.fixture + def account_context(self): + account = create_autospec(Account, instance=True) + account.id = "user-1" + account.current_tenant_id = "tenant-1" + + with patch("services.dataset_service.current_user", account): + yield account + + def test_save_document_without_dataset_id_creates_high_quality_dataset_with_default_retrieval_model( + self, account_context + ): + knowledge_config = KnowledgeConfig( + indexing_technique="high_quality", + data_source=DataSource( + info_list=InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=["file-1"]), + ) + ), + embedding_model="embedding-model", + embedding_model_provider="provider", + summary_index_setting={"enable": True}, + is_multimodal=True, + ) + created_dataset = SimpleNamespace( + id="dataset-1", + tenant_id="tenant-1", + name="", + description=None, + ) + first_document = SimpleNamespace(name="VeryLongDocumentNameForDataset.txt") + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch( + "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding", + return_value=SimpleNamespace(id="binding-1"), + ), + patch( + "services.dataset_service.Dataset", + side_effect=lambda **kwargs: created_dataset.__dict__.update(kwargs) or created_dataset, + ) as dataset_cls, + patch.object( + DocumentService, "save_document_with_dataset_id", return_value=([first_document], "batch-1") + ) as save_document, + patch("services.dataset_service.db") as mock_db, + ): + dataset, documents, batch = DocumentService.save_document_without_dataset_id( + tenant_id="tenant-1", + knowledge_config=knowledge_config, + account=account_context, + ) + + assert dataset is created_dataset + assert documents == [first_document] + assert batch == "batch-1" + assert created_dataset.collection_binding_id == "binding-1" + assert created_dataset.retrieval_model["search_method"] == RetrievalMethod.SEMANTIC_SEARCH + assert created_dataset.retrieval_model["top_k"] == 4 + assert created_dataset.summary_index_setting == {"enable": True} + assert created_dataset.is_multimodal is True + assert created_dataset.name == first_document.name[:18] + "..." + assert ( + created_dataset.description + == "useful for when you want to answer queries about the VeryLongDocumentNameForDataset.txt" + ) + dataset_cls.assert_called_once() + save_document.assert_called_once_with(created_dataset, knowledge_config, account_context) + assert mock_db.session.commit.call_count == 1 + + def test_save_document_without_dataset_id_uses_provided_retrieval_model(self, account_context): + retrieval_model = RetrievalModel( + search_method=RetrievalMethod.SEMANTIC_SEARCH, + reranking_enable=True, + reranking_model=RerankingModel( + reranking_provider_name="rerank-provider", + reranking_model_name="rerank-model", + ), + top_k=9, + score_threshold_enabled=True, + score_threshold=0.6, + ) + knowledge_config = KnowledgeConfig( + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=["file-1"]), + ) + ), + retrieval_model=retrieval_model, + ) + created_dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1", name="", description=None) + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch( + "services.dataset_service.Dataset", + side_effect=lambda **kwargs: created_dataset.__dict__.update(kwargs) or created_dataset, + ), + patch.object( + DocumentService, + "save_document_with_dataset_id", + return_value=([SimpleNamespace(name="Doc")], "batch-1"), + ), + patch("services.dataset_service.db"), + ): + DocumentService.save_document_without_dataset_id("tenant-1", knowledge_config, account_context) + + assert created_dataset.retrieval_model == retrieval_model.model_dump() + assert created_dataset.collection_binding_id is None + + def test_save_document_without_dataset_id_rejects_sandbox_batch_upload(self, account_context): + knowledge_config = KnowledgeConfig( + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=["file-1", "file-2"]), + ) + ), + ) + + with ( + patch( + "services.dataset_service.FeatureService.get_features", + return_value=_make_features(enabled=True, plan=CloudPlan.SANDBOX), + ), + patch.object(DocumentService, "check_documents_upload_quota") as check_quota, + ): + with pytest.raises(ValueError, match="does not support batch upload"): + DocumentService.save_document_without_dataset_id("tenant-1", knowledge_config, account_context) + + check_quota.assert_not_called() + + +class TestDocumentServiceUpdateDocumentWithDatasetId: + """Unit tests for the document-update orchestration path.""" + + @pytest.fixture + def account_context(self): + account = create_autospec(Account, instance=True) + account.id = "user-1" + account.current_tenant_id = "tenant-1" + + with patch("services.dataset_service.current_user", account): + yield account + + def test_update_document_with_dataset_id_raises_when_document_is_missing(self, account_context): + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") + document_data = KnowledgeConfig( + original_document_id="doc-1", + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=["file-1"]), + ) + ), + ) + + with ( + patch.object(DocumentService, "get_document", return_value=None), + patch.object(DatasetService, "check_dataset_model_setting") as check_model_setting, + ): + with pytest.raises(NotFound, match="Document not found"): + DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + + check_model_setting.assert_called_once_with(dataset) + + def test_update_document_with_dataset_id_rejects_non_available_documents(self, account_context): + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") + document = SimpleNamespace(display_status="indexing") + document_data = KnowledgeConfig( + original_document_id="doc-1", + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=["file-1"]), + ) + ), + ) + + with ( + patch.object(DocumentService, "get_document", return_value=document), + patch.object(DatasetService, "check_dataset_model_setting"), + ): + with pytest.raises(ValueError, match="Document is not available"): + DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + + def test_update_document_with_dataset_id_upload_file_process_rule_and_name_override(self, account_context): + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") + document = _make_document() + document.dataset_process_rule_id = "old-rule" + document_data = KnowledgeConfig( + original_document_id="doc-1", + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=["file-1"]), + ) + ), + process_rule=ProcessRule( + mode="custom", + rules=Rule( + pre_processing_rules=[PreProcessingRule(id="remove_stopwords", enabled=True)], + segmentation=Segmentation(separator="\n", max_tokens=128), + ), + ), + name="Renamed document", + doc_form=IndexStructureType.QA_INDEX, + ) + created_process_rule = SimpleNamespace(id="rule-2") + + with ( + patch.object(DocumentService, "get_document", return_value=document), + patch.object(DatasetService, "check_dataset_model_setting"), + patch("services.dataset_service.DatasetProcessRule", return_value=created_process_rule), + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.document_indexing_update_task") as update_task, + ): + upload_query = MagicMock() + upload_query.where.return_value.first.return_value = SimpleNamespace(id="file-1", name="upload.txt") + segment_query = MagicMock() + segment_query.filter_by.return_value.update.return_value = 3 + mock_db.session.query.side_effect = [upload_query, segment_query] + + result = DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + + assert result is document + assert document.dataset_process_rule_id == "rule-2" + assert document.data_source_type == "upload_file" + assert document.data_source_info == '{"upload_file_id": "file-1"}' + assert document.name == "Renamed document" + assert document.indexing_status == "waiting" + assert document.completed_at is None + assert document.processing_started_at is None + assert document.parsing_completed_at is None + assert document.cleaning_completed_at is None + assert document.splitting_completed_at is None + assert document.updated_at == "now" + assert document.created_from == "web" + assert document.doc_form == IndexStructureType.QA_INDEX + assert mock_db.session.commit.call_count == 3 + segment_query.filter_by.return_value.update.assert_called_once() + update_task.delay.assert_called_once_with(document.dataset_id, document.id) + + def test_update_document_with_dataset_id_notion_import_requires_binding(self, account_context): + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") + document = SimpleNamespace(display_status="available", id="doc-1", dataset_id="dataset-1") + document_data = KnowledgeConfig( + original_document_id="doc-1", + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="notion_import", + notion_info_list=[ + NotionInfo( + credential_id="credential-1", + workspace_id="workspace-1", + pages=[NotionPage(page_id="page-1", page_name="Page 1", page_icon=None, type="page")], + ) + ], + ) + ), + ) + + with ( + patch.object(DocumentService, "get_document", return_value=document), + patch.object(DatasetService, "check_dataset_model_setting"), + patch("services.dataset_service.db") as mock_db, + ): + binding_query = MagicMock() + binding_query.where.return_value.first.return_value = None + mock_db.session.query.return_value = binding_query + + with pytest.raises(ValueError, match="Data source binding not found"): + DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + + def test_update_document_with_dataset_id_website_crawl_updates_segments_and_dispatches_task(self, account_context): + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") + document = _make_document() + document_data = KnowledgeConfig( + original_document_id="doc-1", + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="website_crawl", + website_info_list=WebsiteInfo( + provider="firecrawl", + job_id="job-1", + urls=["https://example.com"], + only_main_content=False, + ), + ) + ), + doc_form=IndexStructureType.PARENT_CHILD_INDEX, + ) + + with ( + patch.object(DocumentService, "get_document", return_value=document), + patch.object(DatasetService, "check_dataset_model_setting"), + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.document_indexing_update_task") as update_task, + ): + segment_query = MagicMock() + segment_query.filter_by.return_value.update.return_value = 2 + mock_db.session.query.return_value = segment_query + + result = DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + + assert result is document + assert document.data_source_type == "website_crawl" + assert document.data_source_info == ( + '{"url": "https://example.com", "provider": "firecrawl", "job_id": "job-1", ' + '"only_main_content": false, "mode": "crawl"}' + ) + assert document.name == "" + assert document.doc_form == IndexStructureType.PARENT_CHILD_INDEX + segment_query.filter_by.return_value.update.assert_called_once() + update_task.delay.assert_called_once_with("dataset-1", "doc-1") + + +class TestDocumentServiceCreateValidation: + """Unit tests for document creation validation helpers.""" + + def test_document_create_args_validate_requires_data_source_or_process_rule(self): + knowledge_config = SimpleNamespace(data_source=None, process_rule=None) + + with pytest.raises(ValueError, match="Data source or Process rule is required"): + DocumentService.document_create_args_validate(knowledge_config) + + def test_document_create_args_validate_delegates_to_sub_validators(self): + knowledge_config = SimpleNamespace(data_source=object(), process_rule=object()) + + with ( + patch.object(DocumentService, "data_source_args_validate") as validate_data_source, + patch.object(DocumentService, "process_rule_args_validate") as validate_process_rule, + ): + DocumentService.document_create_args_validate(knowledge_config) + + validate_data_source.assert_called_once_with(knowledge_config) + validate_process_rule.assert_called_once_with(knowledge_config) + + def test_data_source_args_validate_rejects_invalid_type(self): + knowledge_config = SimpleNamespace( + data_source=SimpleNamespace( + info_list=SimpleNamespace( + data_source_type="bad-source", + file_info_list=None, + notion_info_list=None, + website_info_list=None, + ) + ) + ) + + with pytest.raises(ValueError, match="Data source type is invalid"): + DocumentService.data_source_args_validate(knowledge_config) + + @pytest.mark.parametrize( + ("data_source_type", "field_name", "message"), + [ + ("upload_file", "file_info_list", "File source info is required"), + ("notion_import", "notion_info_list", "Notion source info is required"), + ("website_crawl", "website_info_list", "Website source info is required"), + ], + ) + def test_data_source_args_validate_requires_source_specific_info(self, data_source_type, field_name, message): + info_list = SimpleNamespace( + data_source_type=data_source_type, + file_info_list=object(), + notion_info_list=object(), + website_info_list=object(), + ) + setattr(info_list, field_name, None) + knowledge_config = SimpleNamespace(data_source=SimpleNamespace(info_list=info_list)) + + with pytest.raises(ValueError, match=message): + DocumentService.data_source_args_validate(knowledge_config) + + def test_process_rule_args_validate_clears_rules_for_automatic_mode(self): + knowledge_config = KnowledgeConfig( + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=["file-1"]), + ) + ), + process_rule=ProcessRule( + mode="automatic", + rules=Rule( + pre_processing_rules=[PreProcessingRule(id="remove_stopwords", enabled=True)], + segmentation=Segmentation(separator="\n", max_tokens=128), + ), + ), + ) + + DocumentService.process_rule_args_validate(knowledge_config) + + assert knowledge_config.process_rule is not None + assert knowledge_config.process_rule.rules is None + + def test_process_rule_args_validate_deduplicates_rules_and_skips_max_tokens_for_full_doc_hierarchical(self): + knowledge_config = KnowledgeConfig( + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=["file-1"]), + ) + ), + process_rule=ProcessRule( + mode="hierarchical", + rules=Rule( + pre_processing_rules=[ + PreProcessingRule(id="remove_stopwords", enabled=True), + PreProcessingRule(id="remove_stopwords", enabled=False), + ], + segmentation=Segmentation(separator="\n", max_tokens=0), + parent_mode="full-doc", + ), + ), + ) + + DocumentService.process_rule_args_validate(knowledge_config) + + assert knowledge_config.process_rule is not None + assert knowledge_config.process_rule.rules is not None + assert len(knowledge_config.process_rule.rules.pre_processing_rules) == 1 + assert knowledge_config.process_rule.rules.pre_processing_rules[0].enabled is False + + +class TestDocumentServiceSaveDocumentWithDatasetId: + """Unit tests for non-SQL validation branches in save_document_with_dataset_id.""" + + @pytest.fixture + def account_context(self): + account = create_autospec(Account, instance=True) + account.id = "user-1" + account.current_tenant_id = "tenant-1" + + with ( + patch("services.dataset_service.current_user", account), + patch.object(DatasetService, "check_doc_form"), + ): + yield account + + def test_save_document_with_dataset_id_requires_file_info_for_upload_source(self, account_context): + dataset = _make_dataset() + knowledge_config = _make_upload_knowledge_config(file_ids=None) + + with patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=True)): + with pytest.raises(ValueError, match="File source info is required"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + def test_save_document_with_dataset_id_blocks_batch_upload_for_sandbox_plan(self, account_context): + dataset = _make_dataset() + knowledge_config = _make_upload_knowledge_config(file_ids=["file-1", "file-2"]) + + with ( + patch( + "services.dataset_service.FeatureService.get_features", + return_value=_make_features(enabled=True, plan=CloudPlan.SANDBOX), + ), + patch.object(DocumentService, "check_documents_upload_quota") as check_quota, + ): + with pytest.raises(ValueError, match="does not support batch upload"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + check_quota.assert_not_called() + + def test_save_document_with_dataset_id_enforces_batch_upload_limit(self, account_context): + dataset = _make_dataset() + knowledge_config = _make_upload_knowledge_config(file_ids=["file-1", "file-2"]) + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=True)), + patch("services.dataset_service.dify_config.BATCH_UPLOAD_LIMIT", 1), + patch.object(DocumentService, "check_documents_upload_quota") as check_quota, + ): + with pytest.raises(ValueError, match="batch upload limit of 1"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + check_quota.assert_not_called() + + def test_save_document_with_dataset_id_updates_existing_document_and_data_source_type(self, account_context): + dataset = _make_dataset(data_source_type=None) + knowledge_config = _make_upload_knowledge_config(original_document_id="doc-1", file_ids=["file-1"]) + updated_document = _make_document(document_id="doc-1", batch="batch-existing") + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch.object( + DocumentService, "update_document_with_dataset_id", return_value=updated_document + ) as update_document, + ): + documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + assert dataset.data_source_type == "upload_file" + assert documents == [updated_document] + assert batch == "batch-existing" + update_document.assert_called_once_with(dataset, knowledge_config, account_context) + + def test_save_document_with_dataset_id_requires_data_source_for_new_documents(self, account_context): + dataset = _make_dataset() + knowledge_config = _make_upload_knowledge_config(data_source=None) + + with patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)): + with pytest.raises(ValueError, match="Data source is required when creating new documents"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + def test_save_document_with_dataset_id_requires_existing_process_rule_for_custom_mode(self, account_context): + dataset = _make_dataset(latest_process_rule=None) + knowledge_config = _make_upload_knowledge_config( + file_ids=["file-1"], + process_rule=ProcessRule(mode="custom"), + ) + + with patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)): + with pytest.raises(ValueError, match="No process rule found"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + def test_save_document_with_dataset_id_rejects_invalid_indexing_technique(self, account_context): + dataset = _make_dataset(indexing_technique=None) + knowledge_config = SimpleNamespace( + doc_form=IndexStructureType.PARAGRAPH_INDEX, + original_document_id=None, + data_source=None, + indexing_technique="broken-technique", + ) + + with patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)): + with pytest.raises(ValueError, match="Indexing technique is invalid"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + def test_save_document_with_dataset_id_returns_empty_for_invalid_process_rule_mode(self, account_context): + dataset = _make_dataset() + knowledge_config = _make_upload_knowledge_config(file_ids=["file-1"]) + knowledge_config.process_rule = SimpleNamespace(mode="unsupported-mode", rules=None) + + with patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)): + documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + assert documents == [] + assert batch == "" + + def test_save_document_with_dataset_id_upload_file_creates_and_reindexes_documents(self, account_context): + dataset = _make_dataset() + dataset_process_rule = SimpleNamespace(id="rule-1") + knowledge_config = _make_upload_knowledge_config(file_ids=["file-1", "file-2"]) + duplicate_document = _make_document(document_id="doc-duplicate", name="existing.txt") + created_document = _make_document(document_id="doc-created", name="new.txt") + upload_file_a = SimpleNamespace(id="file-1", name="existing.txt") + upload_file_b = SimpleNamespace(id="file-2", name="new.txt") + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch.object(DocumentService, "get_documents_position", return_value=4), + patch.object(DocumentService, "build_document", return_value=created_document) as build_document, + patch("services.dataset_service.DocumentIndexingTaskProxy") as document_proxy_cls, + patch("services.dataset_service.DuplicateDocumentIndexingTaskProxy") as duplicate_proxy_cls, + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.time.strftime", return_value="20260101010101"), + patch("services.dataset_service.secrets.randbelow", return_value=23), + ): + mock_redis.lock.return_value = _make_lock_context() + upload_query = MagicMock() + upload_query.where.return_value.all.return_value = [upload_file_a, upload_file_b] + existing_documents_query = MagicMock() + existing_documents_query.where.return_value.all.return_value = [duplicate_document] + mock_db.session.query.side_effect = [upload_query, existing_documents_query] + + documents, batch = DocumentService.save_document_with_dataset_id( + dataset, + knowledge_config, + account_context, + dataset_process_rule=dataset_process_rule, + ) + + assert documents == [duplicate_document, created_document] + assert batch == "20260101010101100023" + assert duplicate_document.dataset_process_rule_id == "rule-1" + assert duplicate_document.updated_at == "now" + assert duplicate_document.batch == batch + assert duplicate_document.indexing_status == "waiting" + build_document.assert_called_once_with( + dataset, + "rule-1", + "upload_file", + IndexStructureType.PARAGRAPH_INDEX, + "English", + {"upload_file_id": "file-2"}, + "web", + 4, + account_context, + "new.txt", + batch, + ) + document_proxy_cls.assert_called_once_with(dataset.tenant_id, dataset.id, ["doc-created"]) + document_proxy_cls.return_value.delay.assert_called_once() + duplicate_proxy_cls.assert_called_once_with(dataset.tenant_id, dataset.id, ["doc-duplicate"]) + duplicate_proxy_cls.return_value.delay.assert_called_once() + + def test_save_document_with_dataset_id_notion_import_truncates_names_and_cleans_removed_pages( + self, account_context + ): + dataset = _make_dataset() + dataset_process_rule = SimpleNamespace(id="rule-1") + notion_page_name = "a" * 300 + knowledge_config = KnowledgeConfig( + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="notion_import", + notion_info_list=[ + NotionInfo( + credential_id="credential-1", + workspace_id="workspace-1", + pages=[ + NotionPage(page_id="page-keep", page_name="Keep page", type="page"), + NotionPage( + page_id="page-new", + page_name=notion_page_name, + page_icon=NotionIcon(type="emoji", emoji="page"), + type="page", + ), + ], + ) + ], + ) + ), + doc_form=IndexStructureType.PARAGRAPH_INDEX, + doc_language="English", + ) + existing_keep = _make_document(document_id="doc-keep") + existing_keep.data_source_info = json.dumps({"notion_page_id": "page-keep"}) + existing_remove = _make_document(document_id="doc-remove") + existing_remove.data_source_info = json.dumps({"notion_page_id": "page-remove"}) + created_document = _make_document(document_id="doc-new") + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch.object(DocumentService, "get_documents_position", return_value=1), + patch.object(DocumentService, "build_document", return_value=created_document) as build_document, + patch("services.dataset_service.clean_notion_document_task") as clean_task, + patch("services.dataset_service.DocumentIndexingTaskProxy") as document_proxy_cls, + ): + mock_redis.lock.return_value = _make_lock_context() + notion_documents_query = MagicMock() + notion_documents_query.filter_by.return_value.all.return_value = [existing_keep, existing_remove] + mock_db.session.query.return_value = notion_documents_query + + documents, _ = DocumentService.save_document_with_dataset_id( + dataset, + knowledge_config, + account_context, + dataset_process_rule=dataset_process_rule, + ) + + assert created_document in documents + assert len(build_document.call_args.args[9]) == 255 + clean_task.delay.assert_called_once_with(["doc-remove"], dataset.id) + document_proxy_cls.assert_called_once_with(dataset.tenant_id, dataset.id, ["doc-new"]) + document_proxy_cls.return_value.delay.assert_called_once() + + def test_save_document_with_dataset_id_website_crawl_truncates_long_urls(self, account_context): + dataset = _make_dataset() + dataset_process_rule = SimpleNamespace(id="rule-1") + long_url = "https://example.com/" + ("a" * 260) + short_url = "https://example.com/short" + knowledge_config = KnowledgeConfig( + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="website_crawl", + website_info_list=WebsiteInfo( + provider="firecrawl", + job_id="job-1", + urls=[long_url, short_url], + only_main_content=True, + ), + ) + ), + doc_form=IndexStructureType.PARAGRAPH_INDEX, + doc_language="English", + ) + first_document = _make_document(document_id="doc-1") + second_document = _make_document(document_id="doc-2") + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch.object(DocumentService, "get_documents_position", return_value=2), + patch.object( + DocumentService, + "build_document", + side_effect=[first_document, second_document], + ) as build_document, + patch("services.dataset_service.DocumentIndexingTaskProxy") as document_proxy_cls, + ): + mock_redis.lock.return_value = _make_lock_context() + + documents, _ = DocumentService.save_document_with_dataset_id( + dataset, + knowledge_config, + account_context, + dataset_process_rule=dataset_process_rule, + ) + + assert documents == [first_document, second_document] + assert build_document.call_args_list[0].args[9] == long_url[:200] + "..." + assert build_document.call_args_list[1].args[9] == short_url + document_proxy_cls.assert_called_once_with(dataset.tenant_id, dataset.id, ["doc-1", "doc-2"]) + document_proxy_cls.return_value.delay.assert_called_once() + + +class TestDocumentServiceBatchUpdateStatus: + """Unit tests for batch_update_document_status orchestration and helper branches.""" + + def test_prepare_disable_update_requires_completed_document(self): + document = _make_document(indexing_status="waiting") + document.completed_at = None + + with pytest.raises(DocumentIndexingError, match="is not completed"): + DocumentService._prepare_disable_update(document, user=SimpleNamespace(id="user-1"), now="now") + + def test_prepare_archive_update_sets_async_task_for_enabled_document(self): + document = _make_document(enabled=True, archived=False) + + result = DocumentService._prepare_archive_update(document, user=SimpleNamespace(id="user-1"), now="now") + + assert result is not None + assert result["updates"]["archived"] is True + assert result["set_cache"] is True + assert result["async_task"]["args"] == [document.id] + + def test_prepare_unarchive_update_sets_async_task_for_enabled_document(self): + document = _make_document(enabled=True, archived=True) + + result = DocumentService._prepare_unarchive_update(document, now="now") + + assert result is not None + assert result["updates"]["archived"] is False + assert result["set_cache"] is True + assert result["async_task"]["args"] == [document.id] + + def test_batch_update_document_status_rejects_indexing_documents(self): + dataset = _make_dataset() + document = _make_document(name="Busy document") + + with ( + patch.object(DocumentService, "get_document", return_value=document), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + ): + mock_redis.get.return_value = "1" + + with pytest.raises(DocumentIndexingError, match="Busy document is being indexed"): + DocumentService.batch_update_document_status( + dataset, [document.id], "archive", SimpleNamespace(id="user-1") + ) + + mock_db.session.commit.assert_not_called() + + def test_batch_update_document_status_rolls_back_when_commit_fails(self): + dataset = _make_dataset() + document = _make_document(enabled=False) + + with ( + patch.object(DocumentService, "get_document", return_value=document), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + ): + mock_redis.get.return_value = None + mock_db.session.commit.side_effect = RuntimeError("commit failed") + + with pytest.raises(RuntimeError, match="commit failed"): + DocumentService.batch_update_document_status( + dataset, [document.id], "enable", SimpleNamespace(id="user-1") + ) + + mock_db.session.rollback.assert_called_once() + + def test_batch_update_document_status_raises_async_task_error_after_commit(self): + dataset = _make_dataset() + document = _make_document(enabled=False) + + with ( + patch.object(DocumentService, "get_document", return_value=document), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.add_document_to_index_task") as add_task, + ): + mock_redis.get.return_value = None + add_task.delay.side_effect = RuntimeError("task failed") + + with pytest.raises(RuntimeError, match="task failed"): + DocumentService.batch_update_document_status( + dataset, [document.id], "enable", SimpleNamespace(id="user-1") + ) + + mock_db.session.commit.assert_called_once() + mock_redis.setex.assert_called_once_with(f"document_{document.id}_indexing", 600, 1) + + +class TestDocumentServiceTenantAndUpdateEdges: + """Unit tests for tenant-count and update edge cases.""" + + @pytest.fixture + def account_context(self): + account = create_autospec(Account, instance=True) + account.id = "user-1" + account.current_tenant_id = "tenant-1" + + with patch("services.dataset_service.current_user", account): + yield account + + def test_get_tenant_documents_count_returns_query_count(self, account_context): + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.count.return_value = 12 + + result = DocumentService.get_tenant_documents_count() + + assert result == 12 + mock_db.session.query.return_value.where.return_value.count.assert_called_once() + + def test_update_document_with_dataset_id_uses_automatic_process_rule_payload(self, account_context): + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") + document = _make_document() + document_data = KnowledgeConfig( + original_document_id="doc-1", + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=["file-1"]), + ) + ), + process_rule=ProcessRule( + mode="automatic", + rules=Rule( + pre_processing_rules=[PreProcessingRule(id="remove_stopwords", enabled=True)], + segmentation=Segmentation(separator="\n", max_tokens=128), + ), + ), + doc_form=IndexStructureType.PARAGRAPH_INDEX, + ) + created_process_rule = SimpleNamespace(id="rule-2") + + with ( + patch.object(DocumentService, "get_document", return_value=document), + patch("services.dataset_service.DatasetProcessRule") as process_rule_cls, + patch.object(DatasetService, "check_dataset_model_setting"), + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.document_indexing_update_task") as update_task, + ): + process_rule_cls.AUTOMATIC_RULES = DatasetProcessRule.AUTOMATIC_RULES + process_rule_cls.return_value = created_process_rule + upload_query = MagicMock() + upload_query.where.return_value.first.return_value = SimpleNamespace(id="file-1", name="upload.txt") + segment_query = MagicMock() + segment_query.filter_by.return_value.update.return_value = 1 + mock_db.session.query.side_effect = [upload_query, segment_query] + + result = DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + + assert result is document + assert document.dataset_process_rule_id == "rule-2" + assert document.name == "upload.txt" + assert process_rule_cls.call_args.kwargs == { + "dataset_id": "dataset-1", + "mode": "automatic", + "rules": json.dumps(DatasetProcessRule.AUTOMATIC_RULES), + "created_by": "user-1", + } + assert mock_db.session.commit.call_count == 3 + update_task.delay.assert_called_once_with("dataset-1", "doc-1") + + def test_update_document_with_dataset_id_requires_upload_file_info(self, account_context): + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") + document_data = KnowledgeConfig( + original_document_id="doc-1", + indexing_technique="economy", + data_source=DataSource(info_list=InfoList(data_source_type="upload_file")), + ) + + with ( + patch.object(DocumentService, "get_document", return_value=_make_document()), + patch.object(DatasetService, "check_dataset_model_setting"), + ): + with pytest.raises(ValueError, match="No file info list found"): + DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + + def test_update_document_with_dataset_id_raises_when_upload_file_is_missing(self, account_context): + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") + document_data = KnowledgeConfig( + original_document_id="doc-1", + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=["file-1"]), + ) + ), + ) + + with ( + patch.object(DocumentService, "get_document", return_value=_make_document()), + patch.object(DatasetService, "check_dataset_model_setting"), + patch("services.dataset_service.db") as mock_db, + ): + mock_db.session.query.return_value.where.return_value.first.return_value = None + + with pytest.raises(FileNotExistsError): + DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + + def test_update_document_with_dataset_id_requires_notion_info_list(self, account_context): + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") + document_data = KnowledgeConfig( + original_document_id="doc-1", + indexing_technique="economy", + data_source=DataSource(info_list=InfoList(data_source_type="notion_import")), + ) + + with ( + patch.object(DocumentService, "get_document", return_value=_make_document()), + patch.object(DatasetService, "check_dataset_model_setting"), + ): + with pytest.raises(ValueError, match="No notion info list found"): + DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + + def test_update_document_with_dataset_id_notion_import_updates_page_info(self, account_context): + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") + document = _make_document() + document_data = KnowledgeConfig( + original_document_id="doc-1", + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="notion_import", + notion_info_list=[ + NotionInfo( + credential_id="credential-1", + workspace_id="workspace-1", + pages=[ + NotionPage(page_id="page-1", page_name="Page 1", page_icon=None, type="page"), + NotionPage(page_id="page-2", page_name="Page 2", page_icon=None, type="database"), + ], + ) + ], + ) + ), + doc_form=IndexStructureType.PARAGRAPH_INDEX, + ) + + with ( + patch.object(DocumentService, "get_document", return_value=document), + patch.object(DatasetService, "check_dataset_model_setting"), + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.document_indexing_update_task") as update_task, + ): + binding_query = MagicMock() + binding_query.where.return_value.first.return_value = SimpleNamespace(id="binding-1") + segment_query = MagicMock() + segment_query.filter_by.return_value.update.return_value = 1 + mock_db.session.query.side_effect = [binding_query, segment_query] + + result = DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + + assert result is document + assert document.data_source_type == "notion_import" + assert document.name == "" + assert document.data_source_info == json.dumps( + { + "credential_id": "credential-1", + "notion_workspace_id": "workspace-1", + "notion_page_id": "page-2", + "notion_page_icon": None, + "type": "database", + } + ) + update_task.delay.assert_called_once_with("dataset-1", "doc-1") + + +class TestDocumentServiceSaveWithoutDatasetBilling: + """Unit tests for batch-count and quota branches in save_document_without_dataset_id.""" + + @pytest.fixture + def account_context(self): + account = create_autospec(Account, instance=True) + account.id = "user-1" + account.current_tenant_id = "tenant-1" + + with patch("services.dataset_service.current_user", account): + yield account + + def test_save_document_without_dataset_id_counts_notion_pages_for_quota(self, account_context): + knowledge_config = KnowledgeConfig( + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="notion_import", + notion_info_list=[ + NotionInfo( + credential_id="credential-1", + workspace_id="workspace-1", + pages=[ + NotionPage(page_id="page-1", page_name="Page 1", page_icon=None, type="page"), + NotionPage(page_id="page-2", page_name="Page 2", page_icon=None, type="page"), + ], + ), + NotionInfo( + credential_id="credential-2", + workspace_id="workspace-2", + pages=[NotionPage(page_id="page-3", page_name="Page 3", page_icon=None, type="page")], + ), + ], + ) + ), + ) + created_dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1", name="", description=None) + features = _make_features(enabled=True) + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=features), + patch("services.dataset_service.dify_config.BATCH_UPLOAD_LIMIT", "10"), + patch.object(DocumentService, "check_documents_upload_quota") as check_quota, + patch( + "services.dataset_service.Dataset", + side_effect=lambda **kwargs: created_dataset.__dict__.update(kwargs) or created_dataset, + ), + patch.object( + DocumentService, + "save_document_with_dataset_id", + return_value=([SimpleNamespace(name="Doc")], "batch-1"), + ), + patch("services.dataset_service.db"), + ): + DocumentService.save_document_without_dataset_id("tenant-1", knowledge_config, account_context) + + check_quota.assert_called_once_with(3, features) + + def test_save_document_without_dataset_id_enforces_batch_limit_for_website_urls(self, account_context): + knowledge_config = KnowledgeConfig( + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="website_crawl", + website_info_list=WebsiteInfo( + provider="firecrawl", + job_id="job-1", + urls=["https://example.com/a", "https://example.com/b"], + only_main_content=True, + ), + ) + ), + ) + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=True)), + patch("services.dataset_service.dify_config.BATCH_UPLOAD_LIMIT", "1"), + patch.object(DocumentService, "check_documents_upload_quota") as check_quota, + ): + with pytest.raises(ValueError, match="batch upload limit of 1"): + DocumentService.save_document_without_dataset_id("tenant-1", knowledge_config, account_context) + + check_quota.assert_not_called() + + +class TestDocumentServiceEstimateValidation: + """Unit tests for estimate_args_validate branches.""" + + def test_estimate_args_validate_rejects_missing_info_list(self): + with pytest.raises(ValueError, match="Data source info is required"): + DocumentService.estimate_args_validate({}) + + def test_estimate_args_validate_sets_empty_rules_for_automatic_mode(self): + args = { + "info_list": {"data_source_type": "upload_file"}, + "process_rule": {"mode": "automatic", "rules": {"ignored": True}}, + } + + DocumentService.estimate_args_validate(args) + + assert args["process_rule"]["rules"] == {} + + def test_estimate_args_validate_rejects_unknown_pre_processing_rule_id(self): + args = { + "info_list": {"data_source_type": "upload_file"}, + "process_rule": { + "mode": "custom", + "rules": { + "pre_processing_rules": [{"id": "unknown", "enabled": True}], + "segmentation": {"separator": "\n", "max_tokens": 128}, + }, + }, + } + + with pytest.raises(ValueError, match="pre_processing_rules id is invalid"): + DocumentService.estimate_args_validate(args) + + def test_estimate_args_validate_deduplicates_rules_for_custom_mode(self): + args = { + "info_list": {"data_source_type": "upload_file"}, + "process_rule": { + "mode": "custom", + "rules": { + "pre_processing_rules": [ + {"id": "remove_stopwords", "enabled": True}, + {"id": "remove_stopwords", "enabled": False}, + ], + "segmentation": {"separator": "\n", "max_tokens": 128}, + }, + }, + } + + DocumentService.estimate_args_validate(args) + + assert args["process_rule"]["rules"]["pre_processing_rules"] == [{"id": "remove_stopwords", "enabled": False}] + + def test_estimate_args_validate_requires_summary_index_provider_name(self): + args = { + "info_list": {"data_source_type": "upload_file"}, + "process_rule": { + "mode": "custom", + "rules": { + "pre_processing_rules": [{"id": "remove_stopwords", "enabled": True}], + "segmentation": {"separator": "\n", "max_tokens": 128}, + }, + "summary_index_setting": {"enable": True, "model_name": "summary-model"}, + }, + } + + with pytest.raises(ValueError, match="Summary index model provider name is required"): + DocumentService.estimate_args_validate(args) + + +class TestDocumentServiceSaveDocumentAdditionalBranches: + """Additional unit tests for dataset bootstrap and process-rule branches.""" + + @pytest.fixture + def account_context(self): + account = create_autospec(Account, instance=True) + account.id = "user-1" + account.current_tenant_id = "tenant-1" + + with ( + patch("services.dataset_service.current_user", account), + patch.object(DatasetService, "check_doc_form"), + ): + yield account + + def test_save_document_with_dataset_id_initializes_high_quality_dataset_from_default_embedding_model( + self, account_context + ): + dataset = _make_dataset(data_source_type=None, indexing_technique=None) + knowledge_config = _make_upload_knowledge_config(original_document_id="doc-1", file_ids=["file-1"]) + knowledge_config.indexing_technique = "high_quality" + knowledge_config.embedding_model = None + knowledge_config.embedding_model_provider = None + updated_document = _make_document(batch="batch-existing") + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch( + "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding", + return_value=SimpleNamespace(id="binding-1"), + ) as get_binding, + patch.object(DocumentService, "update_document_with_dataset_id", return_value=updated_document), + ): + model_manager_cls.for_tenant.return_value.get_default_model_instance.return_value = SimpleNamespace( + model_name="default-embedding", + provider="default-provider", + ) + + documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + assert documents == [updated_document] + assert batch == "batch-existing" + assert dataset.data_source_type == "upload_file" + assert dataset.indexing_technique == "high_quality" + assert dataset.embedding_model == "default-embedding" + assert dataset.embedding_model_provider == "default-provider" + assert dataset.collection_binding_id == "binding-1" + assert dataset.retrieval_model == { + "search_method": "semantic_search", + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 4, + "score_threshold_enabled": False, + } + get_binding.assert_called_once_with("default-provider", "default-embedding") + + def test_save_document_with_dataset_id_uses_explicit_embedding_and_retrieval_model(self, account_context): + dataset = _make_dataset(indexing_technique=None) + knowledge_config = _make_upload_knowledge_config(original_document_id="doc-1", file_ids=["file-1"]) + knowledge_config.indexing_technique = "high_quality" + knowledge_config.embedding_model = "explicit-model" + knowledge_config.embedding_model_provider = "explicit-provider" + knowledge_config.retrieval_model = RetrievalModel( + search_method="semantic_search", + reranking_enable=True, + reranking_model=RerankingModel( + reranking_provider_name="rerank-provider", + reranking_model_name="rerank-model", + ), + top_k=7, + score_threshold_enabled=True, + score_threshold=0.3, + ) + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch( + "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding", + return_value=SimpleNamespace(id="binding-2"), + ) as get_binding, + patch.object(DocumentService, "update_document_with_dataset_id", return_value=_make_document()), + ): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + model_manager_cls.for_tenant.return_value.get_default_model_instance.assert_not_called() + get_binding.assert_called_once_with("explicit-provider", "explicit-model") + assert dataset.embedding_model == "explicit-model" + assert dataset.embedding_model_provider == "explicit-provider" + assert dataset.retrieval_model == knowledge_config.retrieval_model.model_dump() + + def test_save_document_with_dataset_id_creates_custom_process_rule_for_new_upload_document(self, account_context): + dataset = _make_dataset() + knowledge_config = _make_upload_knowledge_config( + file_ids=["file-1"], + process_rule=ProcessRule( + mode="custom", + rules=Rule( + pre_processing_rules=[PreProcessingRule(id="remove_stopwords", enabled=True)], + segmentation=Segmentation(separator="\n", max_tokens=128), + ), + ), + ) + created_process_rule = SimpleNamespace(id="rule-custom") + created_document = _make_document(document_id="doc-created", name="file.txt") + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.DatasetProcessRule") as process_rule_cls, + patch.object(DocumentService, "get_documents_position", return_value=3), + patch.object(DocumentService, "build_document", return_value=created_document), + patch("services.dataset_service.DocumentIndexingTaskProxy") as document_proxy_cls, + patch("services.dataset_service.time.strftime", return_value="20260101010101"), + patch("services.dataset_service.secrets.randbelow", return_value=23), + ): + mock_redis.lock.return_value = _make_lock_context() + process_rule_cls.return_value = created_process_rule + upload_query = MagicMock() + upload_query.where.return_value.all.return_value = [SimpleNamespace(id="file-1", name="file.txt")] + existing_documents_query = MagicMock() + existing_documents_query.where.return_value.all.return_value = [] + mock_db.session.query.side_effect = [upload_query, existing_documents_query] + + documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + assert documents == [created_document] + assert batch == "20260101010101100023" + assert process_rule_cls.call_args.kwargs == { + "dataset_id": "dataset-1", + "mode": "custom", + "rules": knowledge_config.process_rule.rules.model_dump_json(), + "created_by": "user-1", + } + document_proxy_cls.assert_called_once_with("tenant-1", "dataset-1", ["doc-created"]) + document_proxy_cls.return_value.delay.assert_called_once() + + def test_save_document_with_dataset_id_creates_automatic_process_rule_for_new_upload_document( + self, account_context + ): + dataset = _make_dataset() + knowledge_config = _make_upload_knowledge_config( + file_ids=["file-1"], + process_rule=ProcessRule(mode="automatic"), + ) + created_process_rule = SimpleNamespace(id="rule-auto") + created_document = _make_document(document_id="doc-created", name="file.txt") + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.DatasetProcessRule") as process_rule_cls, + patch.object(DocumentService, "get_documents_position", return_value=1), + patch.object(DocumentService, "build_document", return_value=created_document), + patch("services.dataset_service.DocumentIndexingTaskProxy"), + patch("services.dataset_service.time.strftime", return_value="20260101010101"), + patch("services.dataset_service.secrets.randbelow", return_value=23), + ): + mock_redis.lock.return_value = _make_lock_context() + process_rule_cls.AUTOMATIC_RULES = DatasetProcessRule.AUTOMATIC_RULES + process_rule_cls.return_value = created_process_rule + upload_query = MagicMock() + upload_query.where.return_value.all.return_value = [SimpleNamespace(id="file-1", name="file.txt")] + existing_documents_query = MagicMock() + existing_documents_query.where.return_value.all.return_value = [] + mock_db.session.query.side_effect = [upload_query, existing_documents_query] + + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + assert process_rule_cls.call_args.kwargs == { + "dataset_id": "dataset-1", + "mode": "automatic", + "rules": json.dumps(DatasetProcessRule.AUTOMATIC_RULES), + "created_by": "user-1", + } + assert mock_db.session.flush.call_count >= 2 + + def test_save_document_with_dataset_id_creates_fallback_automatic_process_rule_when_latest_is_missing( + self, account_context + ): + dataset = _make_dataset(latest_process_rule=None) + knowledge_config = _make_upload_knowledge_config(file_ids=["file-1"], process_rule=None) + created_process_rule = SimpleNamespace(id="rule-fallback") + created_document = _make_document(document_id="doc-created", name="file.txt") + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.DatasetProcessRule") as process_rule_cls, + patch.object(DocumentService, "get_documents_position", return_value=1), + patch.object(DocumentService, "build_document", return_value=created_document), + patch("services.dataset_service.DocumentIndexingTaskProxy"), + patch("services.dataset_service.time.strftime", return_value="20260101010101"), + patch("services.dataset_service.secrets.randbelow", return_value=23), + ): + mock_redis.lock.return_value = _make_lock_context() + process_rule_cls.AUTOMATIC_RULES = DatasetProcessRule.AUTOMATIC_RULES + process_rule_cls.return_value = created_process_rule + upload_query = MagicMock() + upload_query.where.return_value.all.return_value = [SimpleNamespace(id="file-1", name="file.txt")] + existing_documents_query = MagicMock() + existing_documents_query.where.return_value.all.return_value = [] + mock_db.session.query.side_effect = [upload_query, existing_documents_query] + + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + assert process_rule_cls.call_args.kwargs == { + "dataset_id": "dataset-1", + "mode": "automatic", + "rules": json.dumps(DatasetProcessRule.AUTOMATIC_RULES), + "created_by": "user-1", + } + + def test_save_document_with_dataset_id_raises_when_upload_file_lookup_is_incomplete(self, account_context): + dataset = _make_dataset() + knowledge_config = _make_upload_knowledge_config(file_ids=["file-1", "file-2"]) + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch.object(DocumentService, "get_documents_position", return_value=1), + patch("services.dataset_service.time.strftime", return_value="20260101010101"), + patch("services.dataset_service.secrets.randbelow", return_value=23), + ): + mock_redis.lock.return_value = _make_lock_context() + upload_query = MagicMock() + upload_query.where.return_value.all.return_value = [SimpleNamespace(id="file-1", name="file.txt")] + mock_db.session.query.return_value = upload_query + + with pytest.raises(FileNotExistsError, match="One or more files not found"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + def test_save_document_with_dataset_id_requires_notion_info_list_for_notion_import(self, account_context): + dataset = _make_dataset() + knowledge_config = KnowledgeConfig( + indexing_technique="economy", + data_source=DataSource(info_list=InfoList(data_source_type="notion_import")), + doc_form=IndexStructureType.PARAGRAPH_INDEX, + doc_language="English", + ) + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch("services.dataset_service.redis_client") as mock_redis, + patch.object(DocumentService, "get_documents_position", return_value=1), + ): + mock_redis.lock.return_value = _make_lock_context() + with pytest.raises(ValueError, match="No notion info list found"): + DocumentService.save_document_with_dataset_id( + dataset, + knowledge_config, + account_context, + dataset_process_rule=SimpleNamespace(id="rule-1"), + ) + + def test_save_document_with_dataset_id_requires_website_info_list_for_website_crawl(self, account_context): + dataset = _make_dataset() + knowledge_config = KnowledgeConfig( + indexing_technique="economy", + data_source=DataSource(info_list=InfoList(data_source_type="website_crawl")), + doc_form=IndexStructureType.PARAGRAPH_INDEX, + doc_language="English", + ) + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch("services.dataset_service.redis_client") as mock_redis, + patch.object(DocumentService, "get_documents_position", return_value=1), + ): + mock_redis.lock.return_value = _make_lock_context() + with pytest.raises(ValueError, match="No website info list found"): + DocumentService.save_document_with_dataset_id( + dataset, + knowledge_config, + account_context, + dataset_process_rule=SimpleNamespace(id="rule-1"), + ) diff --git a/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py b/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py index bd226f7536d..9a513c3fe64 100644 --- a/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py +++ b/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py @@ -4,6 +4,7 @@ from unittest.mock import Mock, create_autospec import pytest from redis.exceptions import LockNotOwnedError +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.account import Account from models.dataset import Dataset, Document from services.dataset_service import DocumentService, SegmentService @@ -70,16 +71,16 @@ def test_save_document_with_dataset_id_ignores_lock_not_owned( dataset.id = "ds-1" dataset.tenant_id = fake_current_user.current_tenant_id dataset.data_source_type = "upload_file" - dataset.indexing_technique = "high_quality" # so we skip re-initialization branch + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY # so we skip re-initialization branch # Minimal knowledge_config stub that satisfies pre-lock code info_list = types.SimpleNamespace(data_source_type="upload_file") data_source = types.SimpleNamespace(info_list=info_list) knowledge_config = types.SimpleNamespace( - doc_form="qa_model", + doc_form=IndexStructureType.QA_INDEX, original_document_id=None, # go into "new document" branch data_source=data_source, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model=None, embedding_model_provider=None, retrieval_model=None, @@ -125,13 +126,13 @@ def test_add_segment_ignores_lock_not_owned( dataset = create_autospec(Dataset, instance=True) dataset.id = "ds-1" dataset.tenant_id = fake_current_user.current_tenant_id - dataset.indexing_technique = "economy" # skip embedding/token calculation branch + dataset.indexing_technique = IndexTechniqueType.ECONOMY # skip embedding/token calculation branch document = create_autospec(Document, instance=True) document.id = "doc-1" document.dataset_id = dataset.id document.word_count = 0 - document.doc_form = "qa_model" + document.doc_form = IndexStructureType.QA_INDEX # Minimal args required by add_segment args = { @@ -168,10 +169,10 @@ def test_multi_create_segment_ignores_lock_not_owned( dataset = create_autospec(Dataset, instance=True) dataset.id = "ds-1" dataset.tenant_id = fake_current_user.current_tenant_id - dataset.indexing_technique = "economy" # again, skip high_quality path + dataset.indexing_technique = IndexTechniqueType.ECONOMY # again, skip high_quality path document = create_autospec(Document, instance=True) document.id = "doc-1" document.dataset_id = dataset.id document.word_count = 0 - document.doc_form = "qa_model" + document.doc_form = IndexStructureType.QA_INDEX diff --git a/api/tests/unit_tests/services/test_dataset_service_segment.py b/api/tests/unit_tests/services/test_dataset_service_segment.py new file mode 100644 index 00000000000..2f8ae14a8e3 --- /dev/null +++ b/api/tests/unit_tests/services/test_dataset_service_segment.py @@ -0,0 +1,1017 @@ +"""Unit tests for SegmentService behaviors in dataset_service.""" + +from .dataset_service_test_helpers import ( + Account, + ChildChunk, + ChildChunkDeleteIndexError, + ChildChunkIndexingError, + ChildChunkUpdateArgs, + DocumentSegment, + IndexStructureType, + MagicMock, + ModelType, + SegmentService, + SegmentUpdateArgs, + SimpleNamespace, + _make_child_chunk, + _make_dataset, + _make_document, + _make_lock_context, + _make_segment, + create_autospec, + patch, + pytest, +) + + +class TestSegmentServiceChildChunks: + """Unit tests for child-chunk CRUD helpers.""" + + @pytest.fixture + def account_context(self): + account = create_autospec(Account, instance=True) + account.id = "user-1" + account.current_tenant_id = "tenant-1" + + with patch("services.dataset_service.current_user", account): + yield account + + def test_create_child_chunk_assigns_next_position_and_commits(self, account_context): + dataset = SimpleNamespace(id="dataset-1") + document = _make_document() + segment = _make_segment() + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.uuid.uuid4", return_value="node-1"), + patch("services.dataset_service.helper.generate_text_hash", return_value="hash-1"), + patch("services.dataset_service.VectorService") as vector_service, + ): + mock_redis.lock.return_value = _make_lock_context() + mock_db.session.query.return_value.where.return_value.scalar.return_value = 2 + + child_chunk = SegmentService.create_child_chunk("child content", segment, document, dataset) + + assert isinstance(child_chunk, ChildChunk) + assert child_chunk.position == 3 + assert child_chunk.index_node_id == "node-1" + assert child_chunk.index_node_hash == "hash-1" + assert child_chunk.word_count == len("child content") + mock_db.session.add.assert_called_once_with(child_chunk) + vector_service.create_child_chunk_vector.assert_called_once_with(child_chunk, dataset) + mock_db.session.commit.assert_called_once() + + def test_create_child_chunk_rolls_back_and_raises_on_vector_failure(self, account_context): + dataset = SimpleNamespace(id="dataset-1") + document = _make_document() + segment = _make_segment() + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.uuid.uuid4", return_value="node-1"), + patch("services.dataset_service.helper.generate_text_hash", return_value="hash-1"), + patch("services.dataset_service.VectorService") as vector_service, + ): + mock_redis.lock.return_value = _make_lock_context() + mock_db.session.query.return_value.where.return_value.scalar.return_value = None + vector_service.create_child_chunk_vector.side_effect = RuntimeError("vector failed") + + with pytest.raises(ChildChunkIndexingError, match="vector failed"): + SegmentService.create_child_chunk("child content", segment, document, dataset) + + mock_db.session.rollback.assert_called_once() + mock_db.session.commit.assert_not_called() + + def test_update_child_chunks_updates_deletes_and_creates_records(self, account_context): + dataset = SimpleNamespace(id="dataset-1") + document = _make_document() + segment = _make_segment() + existing_a = ChildChunk( + id="child-a", + tenant_id="tenant-1", + dataset_id="dataset-1", + document_id="doc-1", + segment_id="segment-1", + position=1, + content="old content", + word_count=11, + created_by="user-1", + ) + existing_b = ChildChunk( + id="child-b", + tenant_id="tenant-1", + dataset_id="dataset-1", + document_id="doc-1", + segment_id="segment-1", + position=2, + content="remove me", + word_count=9, + created_by="user-1", + ) + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.uuid.uuid4", return_value="node-new"), + patch("services.dataset_service.helper.generate_text_hash", return_value="hash-new"), + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.VectorService") as vector_service, + ): + mock_db.session.scalars.return_value.all.return_value = [existing_a, existing_b] + + result = SegmentService.update_child_chunks( + [ + ChildChunkUpdateArgs(id="child-a", content="updated content"), + ChildChunkUpdateArgs(content="brand new"), + ], + segment, + document, + dataset, + ) + + assert [chunk.position for chunk in result] == [1, 3] + assert existing_a.content == "updated content" + assert existing_a.updated_by == account_context.id + assert existing_a.updated_at == "now" + mock_db.session.bulk_save_objects.assert_called_once_with([existing_a]) + mock_db.session.delete.assert_called_once_with(existing_b) + new_chunk = result[1] + assert isinstance(new_chunk, ChildChunk) + assert new_chunk.position == 3 + assert new_chunk.index_node_id == "node-new" + vector_service.update_child_chunk_vector.assert_called_once_with( + [new_chunk], [existing_a], [existing_b], dataset + ) + mock_db.session.commit.assert_called_once() + + def test_update_child_chunks_rolls_back_on_vector_failure(self, account_context): + dataset = SimpleNamespace(id="dataset-1") + document = _make_document() + segment = _make_segment() + existing_chunk = _make_child_chunk() + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.VectorService") as vector_service, + ): + mock_db.session.scalars.return_value.all.return_value = [existing_chunk] + vector_service.update_child_chunk_vector.side_effect = RuntimeError("vector failed") + + with pytest.raises(ChildChunkIndexingError, match="vector failed"): + SegmentService.update_child_chunks( + [ChildChunkUpdateArgs(id="child-a", content="updated content")], + segment, + document, + dataset, + ) + + mock_db.session.rollback.assert_called_once() + + def test_update_child_chunk_updates_vector_and_commits(self, account_context): + dataset = SimpleNamespace(id="dataset-1") + child_chunk = _make_child_chunk() + + with ( + patch("services.dataset_service.current_user", SimpleNamespace(id="user-1")), + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.VectorService") as vector_service, + ): + result = SegmentService.update_child_chunk( + "new content", child_chunk, _make_segment(), _make_document(), dataset + ) + + assert result is child_chunk + assert child_chunk.content == "new content" + assert child_chunk.word_count == len("new content") + assert child_chunk.updated_by == "user-1" + assert child_chunk.updated_at == "now" + mock_db.session.add.assert_called_once_with(child_chunk) + vector_service.update_child_chunk_vector.assert_called_once_with([], [child_chunk], [], dataset) + mock_db.session.commit.assert_called_once() + + def test_delete_child_chunk_raises_delete_index_error_on_vector_failure(self): + dataset = SimpleNamespace(id="dataset-1") + child_chunk = _make_child_chunk() + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.VectorService") as vector_service, + ): + vector_service.delete_child_chunk_vector.side_effect = RuntimeError("delete failed") + + with pytest.raises(ChildChunkDeleteIndexError, match="delete failed"): + SegmentService.delete_child_chunk(child_chunk, dataset) + + mock_db.session.delete.assert_called_once_with(child_chunk) + mock_db.session.rollback.assert_called_once() + + +class TestSegmentServiceQueries: + """Unit tests for child-chunk and segment query helpers.""" + + @pytest.fixture + def account_context(self): + account = create_autospec(Account, instance=True) + account.id = "user-1" + account.current_tenant_id = "tenant-1" + + with patch("services.dataset_service.current_user", account): + yield account + + def test_get_child_chunks_applies_keyword_filter_and_paginate(self, account_context): + paginated = SimpleNamespace(items=["chunk"], total=1) + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.helper.escape_like_pattern", return_value="escaped") as escape_like, + ): + mock_db.paginate.return_value = paginated + + result = SegmentService.get_child_chunks( + segment_id="segment-1", + document_id="doc-1", + dataset_id="dataset-1", + page=2, + limit=10, + keyword="needle", + ) + + assert result is paginated + escape_like.assert_called_once_with("needle") + mock_db.paginate.assert_called_once() + + def test_get_child_chunk_by_id_returns_only_child_chunk_instances(self): + child_chunk = _make_child_chunk() + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.first.return_value = child_chunk + result = SegmentService.get_child_chunk_by_id("child-a", "tenant-1") + + assert result is child_chunk + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace() + result = SegmentService.get_child_chunk_by_id("child-a", "tenant-1") + + assert result is None + + def test_get_segments_uses_status_and_keyword_filters(self): + paginated = SimpleNamespace(items=["segment"], total=1) + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.helper.escape_like_pattern", return_value="escaped") as escape_like, + ): + mock_db.paginate.return_value = paginated + + items, total = SegmentService.get_segments( + document_id="doc-1", + tenant_id="tenant-1", + status_list=["completed"], + keyword="needle", + page=1, + limit=20, + ) + + assert items == ["segment"] + assert total == 1 + escape_like.assert_called_once_with("needle") + mock_db.paginate.assert_called_once() + + def test_get_segment_by_id_returns_only_document_segment_instances(self): + segment = DocumentSegment( + id="segment-1", + tenant_id="tenant-1", + dataset_id="dataset-1", + document_id="doc-1", + position=1, + content="segment", + word_count=7, + tokens=2, + created_by="user-1", + ) + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.first.return_value = segment + result = SegmentService.get_segment_by_id("segment-1", "tenant-1") + + assert result is segment + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace() + result = SegmentService.get_segment_by_id("segment-1", "tenant-1") + + assert result is None + + def test_get_segments_by_document_and_dataset_returns_scalars_result(self): + segment = DocumentSegment( + id="segment-1", + tenant_id="tenant-1", + dataset_id="dataset-1", + document_id="doc-1", + position=1, + content="segment", + word_count=7, + tokens=2, + created_by="user-1", + ) + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.scalars.return_value.all.return_value = [segment] + + result = SegmentService.get_segments_by_document_and_dataset( + document_id="doc-1", + dataset_id="dataset-1", + status="completed", + enabled=True, + ) + + assert result == [segment] + mock_db.session.scalars.assert_called_once() + + +class TestSegmentServiceValidation: + """Unit tests for segment-create argument validation.""" + + def test_segment_create_args_validate_requires_answer_for_qa_model(self): + document = _make_document(doc_form=IndexStructureType.QA_INDEX) + + with pytest.raises(ValueError, match="Answer is required"): + SegmentService.segment_create_args_validate({"content": "question"}, document) + + def test_segment_create_args_validate_requires_non_empty_content(self): + document = _make_document(doc_form=IndexStructureType.PARAGRAPH_INDEX) + + with pytest.raises(ValueError, match="Content is empty"): + SegmentService.segment_create_args_validate({"content": " "}, document) + + def test_segment_create_args_validate_enforces_attachment_limit(self): + document = _make_document(doc_form=IndexStructureType.PARAGRAPH_INDEX) + args = {"content": "hello", "attachment_ids": ["a-1", "a-2"]} + + with patch("services.dataset_service.dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT", 1): + with pytest.raises(ValueError, match="Exceeded maximum attachment limit of 1"): + SegmentService.segment_create_args_validate(args, document) + + def test_segment_create_args_validate_requires_attachment_ids_list(self): + document = _make_document(doc_form=IndexStructureType.PARAGRAPH_INDEX) + + with pytest.raises(ValueError, match="Attachment IDs is invalid"): + SegmentService.segment_create_args_validate({"content": "hello", "attachment_ids": "bad-type"}, document) + + +class TestSegmentServiceMutations: + """Unit tests for segment create, update, delete, and bulk status flows.""" + + @pytest.fixture + def account_context(self): + account = create_autospec(Account, instance=True) + account.id = "user-1" + account.current_tenant_id = "tenant-1" + + with patch("services.dataset_service.current_user", account): + yield account + + def test_create_segment_creates_bindings_and_marks_segment_error_on_vector_failure(self, account_context): + dataset = _make_dataset(indexing_technique="economy") + document = _make_document( + dataset_id=dataset.id, + tenant_id=dataset.tenant_id, + doc_form=IndexStructureType.QA_INDEX, + word_count=0, + ) + refreshed_segment = SimpleNamespace(id="segment-1") + args = { + "content": "question", + "answer": "answer", + "keywords": ["kw-1"], + "attachment_ids": ["att-1", "att-2"], + } + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.VectorService") as vector_service, + patch("services.dataset_service.helper.generate_text_hash", return_value="hash-1"), + patch("services.dataset_service.uuid.uuid4", return_value="node-1"), + patch("services.dataset_service.naive_utc_now", return_value="now"), + ): + mock_redis.lock.return_value = _make_lock_context() + + max_position_query = MagicMock() + max_position_query.where.return_value.scalar.return_value = 2 + refresh_query = MagicMock() + refresh_query.where.return_value.first.return_value = refreshed_segment + mock_db.session.query.side_effect = [max_position_query, refresh_query] + + def add_side_effect(obj): + if obj.__class__.__name__ == "DocumentSegment" and getattr(obj, "id", None) is None: + obj.id = "segment-1" + + mock_db.session.add.side_effect = add_side_effect + vector_service.create_segments_vector.side_effect = RuntimeError("vector failed") + + result = SegmentService.create_segment(args=args, document=document, dataset=dataset) + + created_segment = vector_service.create_segments_vector.call_args.args[1][0] + attachment_bindings = [ + call.args[0] + for call in mock_db.session.add.call_args_list + if call.args and call.args[0].__class__.__name__ == "SegmentAttachmentBinding" + ] + + assert result is refreshed_segment + assert created_segment.position == 3 + assert created_segment.answer == "answer" + assert created_segment.word_count == len("question") + len("answer") + assert created_segment.status == "error" + assert created_segment.enabled is False + assert created_segment.error == "vector failed" + assert document.word_count == len("question") + len("answer") + assert len(attachment_bindings) == 2 + assert {binding.attachment_id for binding in attachment_bindings} == {"att-1", "att-2"} + assert mock_db.session.commit.call_count == 3 + + def test_multi_create_segment_high_quality_marks_segments_error_when_vector_creation_fails(self, account_context): + dataset = _make_dataset(indexing_technique="high_quality") + document = _make_document( + dataset_id=dataset.id, + tenant_id=dataset.tenant_id, + doc_form=IndexStructureType.QA_INDEX, + word_count=5, + ) + segments = [ + {"content": "question-1", "answer": "answer-1", "keywords": ["k1"]}, + {"content": "question-2", "answer": "answer-2"}, + ] + embedding_model = MagicMock() + embedding_model.get_text_embedding_num_tokens.side_effect = [[11], [13]] + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch("services.dataset_service.VectorService") as vector_service, + patch("services.dataset_service.helper.generate_text_hash", side_effect=["hash-1", "hash-2"]), + patch("services.dataset_service.uuid.uuid4", side_effect=["node-1", "node-2"]), + patch("services.dataset_service.naive_utc_now", return_value="now"), + ): + mock_redis.lock.return_value = _make_lock_context() + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model + mock_db.session.query.return_value.where.return_value.scalar.return_value = 1 + vector_service.create_segments_vector.side_effect = RuntimeError("vector failed") + + result = SegmentService.multi_create_segment(segments, document, dataset) + + assert len(result) == 2 + assert [segment.position for segment in result] == [2, 3] + assert [segment.tokens for segment in result] == [11, 13] + assert all(segment.status == "error" for segment in result) + assert all(segment.enabled is False for segment in result) + assert all(segment.error == "vector failed" for segment in result) + assert document.word_count == 5 + sum(len(item["content"]) + len(item["answer"]) for item in segments) + vector_service.create_segments_vector.assert_called_once_with( + [["k1"], None], result, dataset, document.doc_form + ) + mock_db.session.commit.assert_called_once() + + def test_update_segment_disables_enabled_segment_and_dispatches_index_cleanup(self, account_context): + segment = _make_segment(enabled=True) + document = _make_document() + dataset = _make_dataset() + args = SegmentUpdateArgs(enabled=False) + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.disable_segment_from_index_task") as disable_task, + ): + mock_redis.get.return_value = None + + result = SegmentService.update_segment(args, segment, document, dataset) + + assert result is segment + assert segment.enabled is False + assert segment.disabled_at == "now" + assert segment.disabled_by == account_context.id + mock_db.session.add.assert_called_once_with(segment) + mock_db.session.commit.assert_called_once() + mock_redis.setex.assert_called_once_with(f"segment_{segment.id}_indexing", 600, 1) + disable_task.delay.assert_called_once_with(segment.id) + + def test_update_segment_rejects_updates_for_disabled_segment(self, account_context): + segment = _make_segment(enabled=False) + document = _make_document() + dataset = _make_dataset() + + with patch("services.dataset_service.redis_client") as mock_redis: + mock_redis.get.return_value = None + + with pytest.raises(ValueError, match="Can't update disabled segment"): + SegmentService.update_segment(SegmentUpdateArgs(content="new content"), segment, document, dataset) + + def test_update_segment_rejects_when_indexing_cache_exists(self, account_context): + segment = _make_segment(enabled=True) + document = _make_document() + dataset = _make_dataset() + + with patch("services.dataset_service.redis_client") as mock_redis: + mock_redis.get.return_value = "1" + + with pytest.raises(ValueError, match="Segment is indexing"): + SegmentService.update_segment(SegmentUpdateArgs(content="new content"), segment, document, dataset) + + def test_update_segment_updates_keywords_for_same_content_segment(self, account_context): + segment = _make_segment(content="same content", keywords=["old"]) + document = _make_document(doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=20) + dataset = _make_dataset() + refreshed_segment = SimpleNamespace(id=segment.id) + args = SegmentUpdateArgs(content="same content", keywords=["new"]) + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.VectorService") as vector_service, + ): + mock_redis.get.return_value = None + mock_db.session.query.return_value.where.return_value.first.return_value = refreshed_segment + + result = SegmentService.update_segment(args, segment, document, dataset) + + assert result is refreshed_segment + assert segment.keywords == ["new"] + vector_service.update_segment_vector.assert_called_once_with(["new"], segment, dataset) + vector_service.update_multimodel_vector.assert_called_once_with(segment, [], dataset) + + def test_update_segment_regenerates_child_chunks_and_updates_manual_summary(self, account_context): + segment = _make_segment(content="same content", word_count=len("same content")) + document = _make_document( + doc_form=IndexStructureType.PARENT_CHILD_INDEX, + word_count=20, + ) + dataset = _make_dataset(indexing_technique="high_quality") + refreshed_segment = SimpleNamespace(id=segment.id) + processing_rule = SimpleNamespace(id=document.dataset_process_rule_id) + existing_summary = SimpleNamespace(summary_content="old summary") + embedding_model_instance = object() + args = SegmentUpdateArgs( + content="same content", + regenerate_child_chunks=True, + summary="new summary", + ) + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch("services.dataset_service.VectorService") as vector_service, + patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary, + ): + mock_redis.get.return_value = None + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model_instance + + processing_rule_query = MagicMock() + processing_rule_query.where.return_value.first.return_value = processing_rule + summary_query = MagicMock() + summary_query.where.return_value.first.return_value = existing_summary + refreshed_query = MagicMock() + refreshed_query.where.return_value.first.return_value = refreshed_segment + mock_db.session.query.side_effect = [processing_rule_query, summary_query, refreshed_query] + + result = SegmentService.update_segment(args, segment, document, dataset) + + assert result is refreshed_segment + vector_service.generate_child_chunks.assert_called_once_with( + segment, + document, + dataset, + embedding_model_instance, + processing_rule, + True, + ) + update_summary.assert_called_once_with(segment, dataset, "new summary") + vector_service.update_multimodel_vector.assert_called_once_with(segment, [], dataset) + + def test_update_segment_auto_regenerates_summary_after_content_change(self, account_context): + segment = _make_segment(content="old", word_count=3) + document = _make_document(doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=10) + dataset = _make_dataset(indexing_technique="high_quality") + dataset.summary_index_setting = {"enable": True} + refreshed_segment = SimpleNamespace(id=segment.id) + existing_summary = SimpleNamespace(summary_content="old summary") + embedding_model = MagicMock() + embedding_model.get_text_embedding_num_tokens.return_value = [9] + args = SegmentUpdateArgs(content="new content", keywords=["kw-1"]) + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch("services.dataset_service.VectorService") as vector_service, + patch("services.dataset_service.helper.generate_text_hash", return_value="hash-1"), + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch( + "services.summary_index_service.SummaryIndexService.generate_and_vectorize_summary" + ) as generate_summary, + ): + mock_redis.get.return_value = None + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model + + summary_query = MagicMock() + summary_query.where.return_value.first.return_value = existing_summary + refreshed_query = MagicMock() + refreshed_query.where.return_value.first.return_value = refreshed_segment + mock_db.session.query.side_effect = [summary_query, refreshed_query] + + result = SegmentService.update_segment(args, segment, document, dataset) + + assert result is refreshed_segment + assert segment.content == "new content" + assert segment.index_node_hash == "hash-1" + assert segment.tokens == 9 + assert document.word_count == 18 + vector_service.update_segment_vector.assert_called_once_with(["kw-1"], segment, dataset) + generate_summary.assert_called_once_with(segment, dataset, {"enable": True}) + vector_service.update_multimodel_vector.assert_called_once_with(segment, [], dataset) + + def test_update_segment_regenerates_summary_when_manual_summary_is_unchanged(self, account_context): + segment = _make_segment(content="old", word_count=3) + document = _make_document(doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=10) + dataset = _make_dataset(indexing_technique="high_quality") + dataset.summary_index_setting = {"enable": True} + refreshed_segment = SimpleNamespace(id=segment.id) + existing_summary = SimpleNamespace(summary_content="same summary") + embedding_model = MagicMock() + embedding_model.get_text_embedding_num_tokens.return_value = [7] + args = SegmentUpdateArgs(content="new text", summary="same summary") + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch("services.dataset_service.VectorService") as vector_service, + patch("services.dataset_service.helper.generate_text_hash", return_value="hash-2"), + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch( + "services.summary_index_service.SummaryIndexService.generate_and_vectorize_summary" + ) as generate_summary, + patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary, + ): + mock_redis.get.return_value = None + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model + + summary_query = MagicMock() + summary_query.where.return_value.first.return_value = existing_summary + refreshed_query = MagicMock() + refreshed_query.where.return_value.first.return_value = refreshed_segment + mock_db.session.query.side_effect = [summary_query, refreshed_query] + + result = SegmentService.update_segment(args, segment, document, dataset) + + assert result is refreshed_segment + generate_summary.assert_called_once_with(segment, dataset, {"enable": True}) + update_summary.assert_not_called() + vector_service.update_multimodel_vector.assert_called_once_with(segment, [], dataset) + + def test_delete_segment_removes_index_and_updates_document_word_count(self): + segment = _make_segment(word_count=4, index_node_id="parent-node") + document = _make_document(word_count=10) + dataset = _make_dataset() + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.delete_segment_from_index_task") as delete_task, + ): + mock_redis.get.return_value = None + mock_db.session.query.return_value.where.return_value.all.return_value = [("child-1",), ("child-2",)] + + SegmentService.delete_segment(segment, document, dataset) + + assert document.word_count == 6 + mock_redis.setex.assert_called_once_with(f"segment_{segment.id}_delete_indexing", 600, 1) + delete_task.delay.assert_called_once_with( + ["parent-node"], + dataset.id, + document.id, + [segment.id], + ["child-1", "child-2"], + ) + mock_db.session.delete.assert_called_once_with(segment) + mock_db.session.add.assert_called_once_with(document) + mock_db.session.commit.assert_called_once() + + def test_delete_segment_rejects_when_delete_is_already_in_progress(self): + segment = _make_segment() + document = _make_document() + dataset = _make_dataset() + + with patch("services.dataset_service.redis_client") as mock_redis: + mock_redis.get.return_value = "1" + + with pytest.raises(ValueError, match="Segment is deleting"): + SegmentService.delete_segment(segment, document, dataset) + + def test_delete_segments_removes_records_and_clamps_document_word_count(self): + dataset = _make_dataset() + document = _make_document(word_count=3) + current_user = SimpleNamespace(current_tenant_id="tenant-1") + + with ( + patch("services.dataset_service.current_user", current_user), + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.delete_segment_from_index_task") as delete_task, + ): + segments_query = MagicMock() + segments_query.with_entities.return_value.where.return_value.all.return_value = [ + ("node-1", "segment-1", 2), + ("node-2", "segment-2", 5), + ] + child_query = MagicMock() + child_query.where.return_value.all.return_value = [("child-1",)] + delete_query = MagicMock() + delete_query.where.return_value.delete.return_value = 2 + mock_db.session.query.side_effect = [segments_query, child_query, delete_query] + + SegmentService.delete_segments(["segment-1", "segment-2"], document, dataset) + + assert document.word_count == 0 + mock_db.session.add.assert_called_once_with(document) + delete_task.delay.assert_called_once_with( + ["node-1", "node-2"], + dataset.id, + document.id, + ["segment-1", "segment-2"], + ["child-1"], + ) + delete_query.where.return_value.delete.assert_called_once() + mock_db.session.commit.assert_called_once() + + def test_update_segments_status_enables_only_segments_without_indexing_cache(self): + dataset = _make_dataset() + document = _make_document() + segment_a = _make_segment(segment_id="segment-a", enabled=False) + segment_b = _make_segment(segment_id="segment-b", enabled=False) + current_user = SimpleNamespace(id="user-1", current_tenant_id="tenant-1") + + with ( + patch("services.dataset_service.current_user", current_user), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.enable_segments_to_index_task") as enable_task, + ): + mock_db.session.scalars.return_value.all.return_value = [segment_a, segment_b] + mock_redis.get.side_effect = [None, "1"] + + SegmentService.update_segments_status(["segment-a", "segment-b"], "enable", dataset, document) + + assert segment_a.enabled is True + assert segment_a.disabled_at is None + assert segment_a.disabled_by is None + assert segment_b.enabled is False + mock_db.session.add.assert_called_once_with(segment_a) + mock_db.session.commit.assert_called_once() + enable_task.delay.assert_called_once_with(["segment-a"], dataset.id, document.id) + + def test_update_segments_status_disables_only_segments_without_indexing_cache(self): + dataset = _make_dataset() + document = _make_document() + segment_a = _make_segment(segment_id="segment-a", enabled=True) + segment_b = _make_segment(segment_id="segment-b", enabled=True) + current_user = SimpleNamespace(id="user-1", current_tenant_id="tenant-1") + + with ( + patch("services.dataset_service.current_user", current_user), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.disable_segments_from_index_task") as disable_task, + ): + mock_db.session.scalars.return_value.all.return_value = [segment_a, segment_b] + mock_redis.get.side_effect = [None, "1"] + + SegmentService.update_segments_status(["segment-a", "segment-b"], "disable", dataset, document) + + assert segment_a.enabled is False + assert segment_a.disabled_at == "now" + assert segment_a.disabled_by == current_user.id + assert segment_b.enabled is True + mock_db.session.add.assert_called_once_with(segment_a) + mock_db.session.commit.assert_called_once() + disable_task.delay.assert_called_once_with(["segment-a"], dataset.id, document.id) + + +class TestSegmentServiceChildChunkTailHelpers: + """Unit tests for the remaining child-chunk helper branches.""" + + def test_update_child_chunk_rolls_back_on_vector_failure(self): + dataset = SimpleNamespace(id="dataset-1") + child_chunk = _make_child_chunk() + + with ( + patch("services.dataset_service.current_user", SimpleNamespace(id="user-1")), + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.VectorService") as vector_service, + ): + vector_service.update_child_chunk_vector.side_effect = RuntimeError("vector failed") + + with pytest.raises(ChildChunkIndexingError, match="vector failed"): + SegmentService.update_child_chunk( + "new content", child_chunk, SimpleNamespace(), SimpleNamespace(), dataset + ) + + mock_db.session.rollback.assert_called_once() + mock_db.session.commit.assert_not_called() + + def test_delete_child_chunk_commits_after_successful_vector_delete(self): + dataset = SimpleNamespace(id="dataset-1") + child_chunk = _make_child_chunk() + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.VectorService") as vector_service, + ): + SegmentService.delete_child_chunk(child_chunk, dataset) + + mock_db.session.delete.assert_called_once_with(child_chunk) + vector_service.delete_child_chunk_vector.assert_called_once_with(child_chunk, dataset) + mock_db.session.commit.assert_called_once() + + +class TestSegmentServiceAdditionalRegenerationBranches: + """Additional unit tests for segment update and regeneration edge cases.""" + + @pytest.fixture + def account_context(self): + account = create_autospec(Account, instance=True) + account.id = "user-1" + account.current_tenant_id = "tenant-1" + + with patch("services.dataset_service.current_user", account): + yield account + + def test_update_segment_same_content_updates_answer_and_document_word_count_for_qa_segments(self, account_context): + segment = _make_segment(content="question", word_count=8) + document = _make_document(doc_form=IndexStructureType.QA_INDEX, word_count=20) + dataset = _make_dataset() + refreshed_segment = SimpleNamespace(id=segment.id) + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.VectorService") as vector_service, + ): + mock_redis.get.return_value = None + mock_db.session.query.return_value.where.return_value.first.return_value = refreshed_segment + + result = SegmentService.update_segment( + SegmentUpdateArgs(content="question", answer="new answer"), + segment, + document, + dataset, + ) + + assert result is refreshed_segment + assert segment.answer == "new answer" + assert segment.word_count == len("question") + len("new answer") + assert document.word_count == 20 + (len("question") + len("new answer") - 8) + vector_service.update_segment_vector.assert_not_called() + vector_service.update_multimodel_vector.assert_called_once_with(segment, [], dataset) + + def test_update_segment_content_change_uses_answer_when_counting_tokens_for_qa_segments(self, account_context): + segment = _make_segment(content="old", word_count=3) + document = _make_document(doc_form=IndexStructureType.QA_INDEX, word_count=10) + dataset = _make_dataset(indexing_technique="high_quality") + refreshed_segment = SimpleNamespace(id=segment.id) + embedding_model = MagicMock() + embedding_model.get_text_embedding_num_tokens.return_value = [21] + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch("services.dataset_service.VectorService") as vector_service, + patch("services.dataset_service.helper.generate_text_hash", return_value="hash-qa"), + patch("services.dataset_service.naive_utc_now", return_value="now"), + ): + mock_redis.get.return_value = None + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model + summary_query = MagicMock() + summary_query.where.return_value.first.return_value = None + refreshed_query = MagicMock() + refreshed_query.where.return_value.first.return_value = refreshed_segment + mock_db.session.query.side_effect = [summary_query, refreshed_query] + + result = SegmentService.update_segment( + SegmentUpdateArgs(content="new question", answer="new answer", keywords=["kw-1"]), + segment, + document, + dataset, + ) + + assert result is refreshed_segment + embedding_model.get_text_embedding_num_tokens.assert_called_once_with(texts=["new questionnew answer"]) + assert segment.answer == "new answer" + assert segment.tokens == 21 + assert segment.word_count == len("new question") + len("new answer") + vector_service.update_segment_vector.assert_called_once_with(["kw-1"], segment, dataset) + vector_service.update_multimodel_vector.assert_called_once_with(segment, [], dataset) + + def test_update_segment_content_change_parent_child_uses_default_embedding_and_ignores_summary_failures( + self, account_context + ): + segment = _make_segment(content="old", word_count=3) + document = _make_document( + doc_form=IndexStructureType.PARENT_CHILD_INDEX, + word_count=10, + ) + dataset = _make_dataset(indexing_technique="high_quality") + dataset.embedding_model_provider = None + refreshed_segment = SimpleNamespace(id=segment.id) + processing_rule = SimpleNamespace(id=document.dataset_process_rule_id) + existing_summary = SimpleNamespace(summary_content="old summary") + embedding_model_instance = object() + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch("services.dataset_service.VectorService") as vector_service, + patch("services.dataset_service.helper.generate_text_hash", return_value="hash-parent"), + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary, + ): + mock_redis.get.return_value = None + model_manager_cls.for_tenant.return_value.get_default_model_instance.return_value = embedding_model_instance + update_summary.side_effect = RuntimeError("summary failed") + + processing_rule_query = MagicMock() + processing_rule_query.where.return_value.first.return_value = processing_rule + summary_query = MagicMock() + summary_query.where.return_value.first.return_value = existing_summary + refreshed_query = MagicMock() + refreshed_query.where.return_value.first.return_value = refreshed_segment + mock_db.session.query.side_effect = [processing_rule_query, summary_query, refreshed_query] + + result = SegmentService.update_segment( + SegmentUpdateArgs(content="new parent content", regenerate_child_chunks=True, summary="new summary"), + segment, + document, + dataset, + ) + + assert result is refreshed_segment + model_manager_cls.for_tenant.return_value.get_default_model_instance.assert_called_once_with( + tenant_id="tenant-1", + model_type=ModelType.TEXT_EMBEDDING, + ) + vector_service.generate_child_chunks.assert_called_once_with( + segment, + document, + dataset, + embedding_model_instance, + processing_rule, + True, + ) + update_summary.assert_called_once_with(segment, dataset, "new summary") + vector_service.update_multimodel_vector.assert_called_once_with(segment, [], dataset) + + def test_update_segment_same_content_parent_child_marks_segment_error_for_non_high_quality_dataset( + self, account_context + ): + segment = _make_segment(content="same content", word_count=len("same content")) + document = _make_document( + doc_form=IndexStructureType.PARENT_CHILD_INDEX, + word_count=20, + ) + dataset = _make_dataset(indexing_technique="economy") + refreshed_segment = SimpleNamespace(id=segment.id) + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.VectorService") as vector_service, + ): + mock_redis.get.return_value = None + mock_db.session.query.return_value.where.return_value.first.return_value = refreshed_segment + + result = SegmentService.update_segment( + SegmentUpdateArgs(content="same content", regenerate_child_chunks=True), + segment, + document, + dataset, + ) + + assert result is refreshed_segment + assert segment.enabled is False + assert segment.disabled_at == "now" + assert segment.status == "error" + assert segment.error == "The knowledge base index technique is not high quality!" + vector_service.update_multimodel_vector.assert_not_called() diff --git a/api/tests/unit_tests/services/test_datasource_provider_service.py b/api/tests/unit_tests/services/test_datasource_provider_service.py index 105ef7ba482..3df7d500cf2 100644 --- a/api/tests/unit_tests/services/test_datasource_provider_service.py +++ b/api/tests/unit_tests/services/test_datasource_provider_service.py @@ -1,10 +1,10 @@ from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.entities.provider_entities import FormType from sqlalchemy.orm import Session from core.plugin.entities.plugin_daemon import CredentialType -from dify_graph.model_runtime.entities.provider_entities import FormType from models.account import Account from models.model import EndUser from models.oauth import DatasourceProvider diff --git a/api/tests/unit_tests/services/test_delete_archived_workflow_run.py b/api/tests/unit_tests/services/test_delete_archived_workflow_run.py deleted file mode 100644 index a7e1a011f67..00000000000 --- a/api/tests/unit_tests/services/test_delete_archived_workflow_run.py +++ /dev/null @@ -1,57 +0,0 @@ -""" -Unit tests for archived workflow run deletion service. -""" - -from unittest.mock import MagicMock, patch - - -class TestArchivedWorkflowRunDeletion: - def test_delete_by_run_id_calls_delete_run(self): - from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion - - deleter = ArchivedWorkflowRunDeletion() - repo = MagicMock() - repo.get_archived_run_ids.return_value = {"run-1"} - run = MagicMock() - run.id = "run-1" - run.tenant_id = "tenant-1" - - session = MagicMock() - session.get.return_value = run - - session_maker = MagicMock() - session_maker.return_value.__enter__.return_value = session - session_maker.return_value.__exit__.return_value = None - mock_db = MagicMock() - mock_db.engine = MagicMock() - - with ( - patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db), - patch( - "services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", - return_value=session_maker, - autospec=True, - ), - patch.object(deleter, "_get_workflow_run_repo", return_value=repo, autospec=True), - patch.object( - deleter, "_delete_run", return_value=MagicMock(success=True), autospec=True - ) as mock_delete_run, - ): - result = deleter.delete_by_run_id("run-1") - - assert result.success is True - mock_delete_run.assert_called_once_with(run) - - def test_delete_run_dry_run(self): - from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion - - deleter = ArchivedWorkflowRunDeletion(dry_run=True) - run = MagicMock() - run.id = "run-1" - run.tenant_id = "tenant-1" - - with patch.object(deleter, "_get_workflow_run_repo", autospec=True) as mock_get_repo: - result = deleter._delete_run(run) - - assert result.success is True - mock_get_repo.assert_not_called() diff --git a/api/tests/unit_tests/services/test_document_service_display_status.py b/api/tests/unit_tests/services/test_document_service_display_status.py deleted file mode 100644 index cb2e2940c88..00000000000 --- a/api/tests/unit_tests/services/test_document_service_display_status.py +++ /dev/null @@ -1,8 +0,0 @@ -from services.dataset_service import DocumentService - - -def test_normalize_display_status_alias_mapping(): - assert DocumentService.normalize_display_status("ACTIVE") == "available" - assert DocumentService.normalize_display_status("enabled") == "available" - assert DocumentService.normalize_display_status("archived") == "archived" - assert DocumentService.normalize_display_status("unknown") is None diff --git a/api/tests/unit_tests/services/test_end_user_service.py b/api/tests/unit_tests/services/test_end_user_service.py deleted file mode 100644 index a3b1f46436c..00000000000 --- a/api/tests/unit_tests/services/test_end_user_service.py +++ /dev/null @@ -1,841 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest - -from core.app.entities.app_invoke_entities import InvokeFrom -from models.model import App, DefaultEndUserSessionID, EndUser -from services.end_user_service import EndUserService - - -class TestEndUserServiceFactory: - """Factory class for creating test data and mock objects for end user service tests.""" - - @staticmethod - def create_app_mock( - app_id: str = "app-123", - tenant_id: str = "tenant-456", - name: str = "Test App", - ) -> MagicMock: - """Create a mock App object.""" - app = MagicMock(spec=App) - app.id = app_id - app.tenant_id = tenant_id - app.name = name - return app - - @staticmethod - def create_end_user_mock( - user_id: str = "user-789", - tenant_id: str = "tenant-456", - app_id: str = "app-123", - session_id: str = "session-001", - type: InvokeFrom = InvokeFrom.SERVICE_API, - is_anonymous: bool = False, - ) -> MagicMock: - """Create a mock EndUser object.""" - end_user = MagicMock(spec=EndUser) - end_user.id = user_id - end_user.tenant_id = tenant_id - end_user.app_id = app_id - end_user.session_id = session_id - end_user.type = type - end_user.is_anonymous = is_anonymous - end_user.external_user_id = session_id - return end_user - - -class TestEndUserServiceGetEndUserById: - """Unit tests for EndUserService.get_end_user_by_id method.""" - - @pytest.fixture - def factory(self): - """Provide test data factory.""" - return TestEndUserServiceFactory() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_get_end_user_by_id_success(self, mock_db, mock_session_class, factory): - """Test successful retrieval of end user by ID.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - end_user_id = "user-789" - - mock_end_user = factory.create_end_user_mock(user_id=end_user_id, tenant_id=tenant_id, app_id=app_id) - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = mock_end_user - - # Act - result = EndUserService.get_end_user_by_id(tenant_id=tenant_id, app_id=app_id, end_user_id=end_user_id) - - # Assert - assert result == mock_end_user - mock_session.query.assert_called_once_with(EndUser) - mock_query.where.assert_called_once() - mock_query.first.assert_called_once() - mock_context.__enter__.assert_called_once() - mock_context.__exit__.assert_called_once() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_get_end_user_by_id_not_found(self, mock_db, mock_session_class): - """Test retrieval of non-existent end user returns None.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - end_user_id = "user-789" - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None - - # Act - result = EndUserService.get_end_user_by_id(tenant_id=tenant_id, app_id=app_id, end_user_id=end_user_id) - - # Assert - assert result is None - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_get_end_user_by_id_query_parameters(self, mock_db, mock_session_class): - """Test that query parameters are correctly applied.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - end_user_id = "user-789" - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None - - # Act - EndUserService.get_end_user_by_id(tenant_id=tenant_id, app_id=app_id, end_user_id=end_user_id) - - # Assert - # Verify the where clause was called with the correct conditions - call_args = mock_query.where.call_args[0] - assert len(call_args) == 3 - # Check that the conditions match the expected filters - # (We can't easily test the exact conditions without importing SQLAlchemy) - - -class TestEndUserServiceGetOrCreateEndUser: - """Unit tests for EndUserService.get_or_create_end_user method.""" - - @pytest.fixture - def factory(self): - """Provide test data factory.""" - return TestEndUserServiceFactory() - - @patch("services.end_user_service.EndUserService.get_or_create_end_user_by_type") - def test_get_or_create_end_user_with_user_id(self, mock_get_or_create_by_type, factory): - """Test get_or_create_end_user with specific user_id.""" - # Arrange - app_mock = factory.create_app_mock() - user_id = "user-123" - expected_end_user = factory.create_end_user_mock() - mock_get_or_create_by_type.return_value = expected_end_user - - # Act - result = EndUserService.get_or_create_end_user(app_mock, user_id) - - # Assert - assert result == expected_end_user - mock_get_or_create_by_type.assert_called_once_with( - InvokeFrom.SERVICE_API, app_mock.tenant_id, app_mock.id, user_id - ) - - @patch("services.end_user_service.EndUserService.get_or_create_end_user_by_type") - def test_get_or_create_end_user_without_user_id(self, mock_get_or_create_by_type, factory): - """Test get_or_create_end_user without user_id (None).""" - # Arrange - app_mock = factory.create_app_mock() - expected_end_user = factory.create_end_user_mock() - mock_get_or_create_by_type.return_value = expected_end_user - - # Act - result = EndUserService.get_or_create_end_user(app_mock, None) - - # Assert - assert result == expected_end_user - mock_get_or_create_by_type.assert_called_once_with( - InvokeFrom.SERVICE_API, app_mock.tenant_id, app_mock.id, None - ) - - -class TestEndUserServiceGetOrCreateEndUserByType: - """ - Unit tests for EndUserService.get_or_create_end_user_by_type method. - - This test suite covers: - - Creating end users with different InvokeFrom types - - Type migration for legacy users - - Query ordering and prioritization - - Session management - """ - - @pytest.fixture - def factory(self): - """Provide test data factory.""" - return TestEndUserServiceFactory() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_new_end_user_with_user_id(self, mock_db, mock_session_class, factory): - """Test creating a new end user with specific user_id.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.first.return_value = None # No existing user - - # Act - result = EndUserService.get_or_create_end_user_by_type( - type=type_enum, tenant_id=tenant_id, app_id=app_id, user_id=user_id - ) - - # Assert - # Verify new EndUser was created with correct parameters - mock_session.add.assert_called_once() - mock_session.commit.assert_called_once() - added_user = mock_session.add.call_args[0][0] - assert added_user.tenant_id == tenant_id - assert added_user.app_id == app_id - assert added_user.type == type_enum - assert added_user.session_id == user_id - assert added_user.external_user_id == user_id - assert added_user._is_anonymous is False - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_new_end_user_default_session(self, mock_db, mock_session_class, factory): - """Test creating a new end user with default session ID.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - user_id = None - type_enum = InvokeFrom.WEB_APP - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.first.return_value = None # No existing user - - # Act - result = EndUserService.get_or_create_end_user_by_type( - type=type_enum, tenant_id=tenant_id, app_id=app_id, user_id=user_id - ) - - # Assert - added_user = mock_session.add.call_args[0][0] - assert added_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID - assert added_user.external_user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID - assert added_user._is_anonymous is True - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - @patch("services.end_user_service.logger") - def test_existing_user_same_type(self, mock_logger, mock_db, mock_session_class, factory): - """Test retrieving existing user with same type.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - existing_user = factory.create_end_user_mock( - tenant_id=tenant_id, app_id=app_id, session_id=user_id, type=type_enum - ) - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.first.return_value = existing_user - - # Act - result = EndUserService.get_or_create_end_user_by_type( - type=type_enum, tenant_id=tenant_id, app_id=app_id, user_id=user_id - ) - - # Assert - assert result == existing_user - mock_session.add.assert_not_called() - mock_session.commit.assert_not_called() - mock_logger.info.assert_not_called() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - @patch("services.end_user_service.logger") - def test_existing_user_different_type_upgrade(self, mock_logger, mock_db, mock_session_class, factory): - """Test upgrading existing user with different type.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - user_id = "user-789" - old_type = InvokeFrom.WEB_APP - new_type = InvokeFrom.SERVICE_API - - existing_user = factory.create_end_user_mock( - tenant_id=tenant_id, app_id=app_id, session_id=user_id, type=old_type - ) - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.first.return_value = existing_user - - # Act - result = EndUserService.get_or_create_end_user_by_type( - type=new_type, tenant_id=tenant_id, app_id=app_id, user_id=user_id - ) - - # Assert - assert result == existing_user - assert existing_user.type == new_type - mock_session.commit.assert_called_once() - mock_logger.info.assert_called_once() - logger_call_args = mock_logger.info.call_args[0] - assert "Upgrading legacy EndUser" in logger_call_args[0] - # The old and new types are passed as separate arguments - assert mock_logger.info.call_args[0][1] == existing_user.id - assert mock_logger.info.call_args[0][2] == old_type - assert mock_logger.info.call_args[0][3] == new_type - assert mock_logger.info.call_args[0][4] == user_id - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_query_ordering_prioritizes_exact_type_match(self, mock_db, mock_session_class, factory): - """Test that query ordering prioritizes exact type matches.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - user_id = "user-789" - target_type = InvokeFrom.SERVICE_API - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.first.return_value = None - - # Act - EndUserService.get_or_create_end_user_by_type( - type=target_type, tenant_id=tenant_id, app_id=app_id, user_id=user_id - ) - - # Assert - mock_query.order_by.assert_called_once() - # Verify that case statement is used for ordering - order_by_call = mock_query.order_by.call_args[0][0] - # The exact structure depends on SQLAlchemy's case implementation - # but we can verify it was called - - # Test 10: Session context manager properly closes - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_session_context_manager_closes(self, mock_db, mock_session_class, factory): - """Test that Session context manager is properly used.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - user_id = "user-789" - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.first.return_value = None - - # Act - EndUserService.get_or_create_end_user_by_type( - type=InvokeFrom.SERVICE_API, - tenant_id=tenant_id, - app_id=app_id, - user_id=user_id, - ) - - # Assert - # Verify context manager was entered and exited - mock_context.__enter__.assert_called_once() - mock_context.__exit__.assert_called_once() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_all_invokefrom_types_supported(self, mock_db, mock_session_class): - """Test that all InvokeFrom enum values are supported.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - user_id = "user-789" - - for invoke_type in InvokeFrom: - with patch("services.end_user_service.Session") as mock_session_class: - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.first.return_value = None - - # Act - result = EndUserService.get_or_create_end_user_by_type( - type=invoke_type, tenant_id=tenant_id, app_id=app_id, user_id=user_id - ) - - # Assert - added_user = mock_session.add.call_args[0][0] - assert added_user.type == invoke_type - - -class TestEndUserServiceCreateEndUserBatch: - """Unit tests for EndUserService.create_end_user_batch method.""" - - @pytest.fixture - def factory(self): - """Provide test data factory.""" - return TestEndUserServiceFactory() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_empty_app_ids(self, mock_db, mock_session_class): - """Test batch creation with empty app_ids list.""" - # Arrange - tenant_id = "tenant-123" - app_ids: list[str] = [] - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - # Act - result = EndUserService.create_end_user_batch( - type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id - ) - - # Assert - assert result == {} - mock_session_class.assert_not_called() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_default_session_id(self, mock_db, mock_session_class): - """Test batch creation with empty user_id (uses default session).""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456", "app-789"] - user_id = "" - type_enum = InvokeFrom.SERVICE_API - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [] # No existing users - - # Act - result = EndUserService.create_end_user_batch( - type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id - ) - - # Assert - assert len(result) == 2 - for app_id, end_user in result.items(): - assert end_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID - assert end_user.external_user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID - assert end_user._is_anonymous is True - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_deduplicate_app_ids(self, mock_db, mock_session_class): - """Test that duplicate app_ids are deduplicated while preserving order.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456", "app-789", "app-456", "app-123", "app-789"] - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [] # No existing users - - # Act - result = EndUserService.create_end_user_batch( - type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id - ) - - # Assert - # Should have 3 unique app_ids in original order - assert len(result) == 3 - assert "app-456" in result - assert "app-789" in result - assert "app-123" in result - - # Verify the order is preserved - added_users = mock_session.add_all.call_args[0][0] - assert len(added_users) == 3 - assert added_users[0].app_id == "app-456" - assert added_users[1].app_id == "app-789" - assert added_users[2].app_id == "app-123" - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_all_existing_users(self, mock_db, mock_session_class, factory): - """Test batch creation when all users already exist.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456", "app-789"] - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - existing_user1 = factory.create_end_user_mock( - tenant_id=tenant_id, app_id="app-456", session_id=user_id, type=type_enum - ) - existing_user2 = factory.create_end_user_mock( - tenant_id=tenant_id, app_id="app-789", session_id=user_id, type=type_enum - ) - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [existing_user1, existing_user2] - - # Act - result = EndUserService.create_end_user_batch( - type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id - ) - - # Assert - assert len(result) == 2 - assert result["app-456"] == existing_user1 - assert result["app-789"] == existing_user2 - mock_session.add_all.assert_not_called() - mock_session.commit.assert_not_called() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_partial_existing_users(self, mock_db, mock_session_class, factory): - """Test batch creation with some existing and some new users.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456", "app-789", "app-123"] - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - existing_user1 = factory.create_end_user_mock( - tenant_id=tenant_id, app_id="app-456", session_id=user_id, type=type_enum - ) - # app-789 and app-123 don't exist - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [existing_user1] - - # Act - result = EndUserService.create_end_user_batch( - type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id - ) - - # Assert - assert len(result) == 3 - assert result["app-456"] == existing_user1 - assert "app-789" in result - assert "app-123" in result - - # Should create 2 new users - mock_session.add_all.assert_called_once() - added_users = mock_session.add_all.call_args[0][0] - assert len(added_users) == 2 - - mock_session.commit.assert_called_once() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_handles_duplicates_in_existing(self, mock_db, mock_session_class, factory): - """Test batch creation handles duplicates in existing users gracefully.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456"] - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - # Simulate duplicate records in database - existing_user1 = factory.create_end_user_mock( - user_id="user-1", tenant_id=tenant_id, app_id="app-456", session_id=user_id, type=type_enum - ) - existing_user2 = factory.create_end_user_mock( - user_id="user-2", tenant_id=tenant_id, app_id="app-456", session_id=user_id, type=type_enum - ) - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [existing_user1, existing_user2] - - # Act - result = EndUserService.create_end_user_batch( - type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id - ) - - # Assert - assert len(result) == 1 - # Should prefer the first one found - assert result["app-456"] == existing_user1 - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_all_invokefrom_types(self, mock_db, mock_session_class): - """Test batch creation with all InvokeFrom types.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456"] - user_id = "user-789" - - for invoke_type in InvokeFrom: - with patch("services.end_user_service.Session") as mock_session_class: - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [] # No existing users - - # Act - result = EndUserService.create_end_user_batch( - type=invoke_type, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id - ) - - # Assert - added_user = mock_session.add_all.call_args[0][0][0] - assert added_user.type == invoke_type - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_single_app_id(self, mock_db, mock_session_class, factory): - """Test batch creation with single app_id.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456"] - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [] # No existing users - - # Act - result = EndUserService.create_end_user_batch( - type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id - ) - - # Assert - assert len(result) == 1 - assert "app-456" in result - mock_session.add_all.assert_called_once() - added_users = mock_session.add_all.call_args[0][0] - assert len(added_users) == 1 - assert added_users[0].app_id == "app-456" - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_anonymous_vs_authenticated(self, mock_db, mock_session_class): - """Test batch creation correctly sets anonymous flag.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456", "app-789"] - - # Test with regular user ID - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [] # No existing users - - # Act - authenticated user - result = EndUserService.create_end_user_batch( - type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id="user-789" - ) - - # Assert - added_users = mock_session.add_all.call_args[0][0] - for user in added_users: - assert user._is_anonymous is False - - # Test with default session ID - mock_session.reset_mock() - mock_query.reset_mock() - mock_query.all.return_value = [] - - # Act - anonymous user - result = EndUserService.create_end_user_batch( - type=InvokeFrom.SERVICE_API, - tenant_id=tenant_id, - app_ids=app_ids, - user_id=DefaultEndUserSessionID.DEFAULT_SESSION_ID, - ) - - # Assert - added_users = mock_session.add_all.call_args[0][0] - for user in added_users: - assert user._is_anonymous is True - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_efficient_single_query(self, mock_db, mock_session_class): - """Test that batch creation uses efficient single query for existing users.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456", "app-789", "app-123"] - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [] # No existing users - - # Act - EndUserService.create_end_user_batch(type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id) - - # Assert - # Should make exactly one query to check for existing users - mock_session.query.assert_called_once_with(EndUser) - mock_query.where.assert_called_once() - mock_query.all.assert_called_once() - - # Verify the where clause uses .in_() for app_ids - where_call = mock_query.where.call_args[0] - # The exact structure depends on SQLAlchemy implementation - # but we can verify it was called with the right parameters - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_session_context_manager(self, mock_db, mock_session_class): - """Test that batch creation properly uses session context manager.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456"] - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [] # No existing users - - # Act - EndUserService.create_end_user_batch(type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id) - - # Assert - mock_context.__enter__.assert_called_once() - mock_context.__exit__.assert_called_once() - mock_session.commit.assert_called_once() diff --git a/api/tests/unit_tests/services/test_feedback_service.py b/api/tests/unit_tests/services/test_feedback_service.py deleted file mode 100644 index 1f70839ee22..00000000000 --- a/api/tests/unit_tests/services/test_feedback_service.py +++ /dev/null @@ -1,626 +0,0 @@ -import csv -import io -import json -from datetime import datetime -from unittest.mock import MagicMock, patch - -import pytest - -from services.feedback_service import FeedbackService - - -class TestFeedbackServiceFactory: - """Factory class for creating test data and mock objects for feedback service tests.""" - - @staticmethod - def create_feedback_mock( - feedback_id: str = "feedback-123", - app_id: str = "app-456", - conversation_id: str = "conv-789", - message_id: str = "msg-001", - rating: str = "like", - content: str | None = "Great response!", - from_source: str = "user", - from_account_id: str | None = None, - from_end_user_id: str | None = "end-user-001", - created_at: datetime | None = None, - ) -> MagicMock: - """Create a mock MessageFeedback object.""" - feedback = MagicMock() - feedback.id = feedback_id - feedback.app_id = app_id - feedback.conversation_id = conversation_id - feedback.message_id = message_id - feedback.rating = rating - feedback.content = content - feedback.from_source = from_source - feedback.from_account_id = from_account_id - feedback.from_end_user_id = from_end_user_id - feedback.created_at = created_at or datetime.now() - return feedback - - @staticmethod - def create_message_mock( - message_id: str = "msg-001", - query: str = "What is AI?", - answer: str = "AI stands for Artificial Intelligence.", - inputs: dict | None = None, - created_at: datetime | None = None, - ): - """Create a mock Message object.""" - - # Create a simple object with instance attributes - # Using a class with __init__ ensures attributes are instance attributes - class Message: - def __init__(self): - self.id = message_id - self.query = query - self.answer = answer - self.inputs = inputs - self.created_at = created_at or datetime.now() - - return Message() - - @staticmethod - def create_conversation_mock( - conversation_id: str = "conv-789", - name: str | None = "Test Conversation", - ) -> MagicMock: - """Create a mock Conversation object.""" - conversation = MagicMock() - conversation.id = conversation_id - conversation.name = name - return conversation - - @staticmethod - def create_app_mock( - app_id: str = "app-456", - name: str = "Test App", - ) -> MagicMock: - """Create a mock App object.""" - app = MagicMock() - app.id = app_id - app.name = name - return app - - @staticmethod - def create_account_mock( - account_id: str = "account-123", - name: str = "Test Admin", - ) -> MagicMock: - """Create a mock Account object.""" - account = MagicMock() - account.id = account_id - account.name = name - return account - - -class TestFeedbackService: - """ - Comprehensive unit tests for FeedbackService. - - This test suite covers: - - CSV and JSON export formats - - All filter combinations - - Edge cases and error handling - - Response validation - """ - - @pytest.fixture - def factory(self): - """Provide test data factory.""" - return TestFeedbackServiceFactory() - - @pytest.fixture - def sample_feedback_data(self, factory): - """Create sample feedback data for testing.""" - feedback = factory.create_feedback_mock( - rating="like", - content="Excellent answer!", - from_source="user", - ) - message = factory.create_message_mock( - query="What is Python?", - answer="Python is a programming language.", - ) - conversation = factory.create_conversation_mock(name="Python Discussion") - app = factory.create_app_mock(name="AI Assistant") - account = factory.create_account_mock(name="Admin User") - - return [(feedback, message, conversation, app, account)] - - # Test 01: CSV Export - Basic Functionality - @patch("services.feedback_service.db") - def test_export_feedbacks_csv_basic(self, mock_db, factory, sample_feedback_data): - """Test basic CSV export with single feedback record.""" - # Arrange - mock_query = MagicMock() - # Configure the mock to return itself for all chaining methods - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = sample_feedback_data - - # Set up the session.query to return our mock - mock_db.session.query.return_value = mock_query - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="csv") - - # Assert - assert response.mimetype == "text/csv" - assert "charset=utf-8-sig" in response.content_type - assert "attachment" in response.headers["Content-Disposition"] - assert "dify_feedback_export_app-456" in response.headers["Content-Disposition"] - - # Verify CSV content - csv_content = response.get_data(as_text=True) - reader = csv.DictReader(io.StringIO(csv_content)) - rows = list(reader) - - assert len(rows) == 1 - assert rows[0]["feedback_rating"] == "👍" - assert rows[0]["feedback_rating_raw"] == "like" - assert rows[0]["feedback_comment"] == "Excellent answer!" - assert rows[0]["user_query"] == "What is Python?" - assert rows[0]["ai_response"] == "Python is a programming language." - - # Test 02: JSON Export - Basic Functionality - @patch("services.feedback_service.db") - def test_export_feedbacks_json_basic(self, mock_db, factory, sample_feedback_data): - """Test basic JSON export with metadata structure.""" - # Arrange - mock_query = MagicMock() - # Configure the mock to return itself for all chaining methods - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = sample_feedback_data - - # Set up the session.query to return our mock - mock_db.session.query.return_value = mock_query - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json") - - # Assert - assert response.mimetype == "application/json" - assert "charset=utf-8" in response.content_type - assert "attachment" in response.headers["Content-Disposition"] - - # Verify JSON structure - json_content = json.loads(response.get_data(as_text=True)) - assert "export_info" in json_content - assert "feedback_data" in json_content - assert json_content["export_info"]["app_id"] == "app-456" - assert json_content["export_info"]["total_records"] == 1 - assert len(json_content["feedback_data"]) == 1 - - # Test 03: Filter by from_source - @patch("services.feedback_service.db") - def test_export_feedbacks_filter_from_source(self, mock_db, factory): - """Test filtering by feedback source (user/admin).""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Act - FeedbackService.export_feedbacks(app_id="app-456", from_source="admin") - - # Assert - mock_query.filter.assert_called() - - # Test 04: Filter by rating - @patch("services.feedback_service.db") - def test_export_feedbacks_filter_rating(self, mock_db, factory): - """Test filtering by rating (like/dislike).""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Act - FeedbackService.export_feedbacks(app_id="app-456", rating="dislike") - - # Assert - mock_query.filter.assert_called() - - # Test 05: Filter by has_comment (True) - @patch("services.feedback_service.db") - def test_export_feedbacks_filter_has_comment_true(self, mock_db, factory): - """Test filtering for feedback with comments.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Act - FeedbackService.export_feedbacks(app_id="app-456", has_comment=True) - - # Assert - mock_query.filter.assert_called() - - # Test 06: Filter by has_comment (False) - @patch("services.feedback_service.db") - def test_export_feedbacks_filter_has_comment_false(self, mock_db, factory): - """Test filtering for feedback without comments.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Act - FeedbackService.export_feedbacks(app_id="app-456", has_comment=False) - - # Assert - mock_query.filter.assert_called() - - # Test 07: Filter by date range - @patch("services.feedback_service.db") - def test_export_feedbacks_filter_date_range(self, mock_db, factory): - """Test filtering by start and end dates.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Act - FeedbackService.export_feedbacks( - app_id="app-456", - start_date="2024-01-01", - end_date="2024-12-31", - ) - - # Assert - assert mock_query.filter.call_count >= 2 # Called for both start and end dates - - # Test 08: Invalid date format - start_date - @patch("services.feedback_service.db") - def test_export_feedbacks_invalid_start_date(self, mock_db): - """Test error handling for invalid start_date format.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - - # Act & Assert - with pytest.raises(ValueError, match="Invalid start_date format"): - FeedbackService.export_feedbacks(app_id="app-456", start_date="invalid-date") - - # Test 09: Invalid date format - end_date - @patch("services.feedback_service.db") - def test_export_feedbacks_invalid_end_date(self, mock_db): - """Test error handling for invalid end_date format.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - - # Act & Assert - with pytest.raises(ValueError, match="Invalid end_date format"): - FeedbackService.export_feedbacks(app_id="app-456", end_date="2024-13-45") - - # Test 10: Unsupported format - def test_export_feedbacks_unsupported_format(self): - """Test error handling for unsupported export format.""" - # Act & Assert - with pytest.raises(ValueError, match="Unsupported format"): - FeedbackService.export_feedbacks(app_id="app-456", format_type="xml") - - # Test 11: Empty result set - CSV - @patch("services.feedback_service.db") - def test_export_feedbacks_empty_results_csv(self, mock_db): - """Test CSV export with no feedback records.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="csv") - - # Assert - csv_content = response.get_data(as_text=True) - reader = csv.DictReader(io.StringIO(csv_content)) - rows = list(reader) - assert len(rows) == 0 - # But headers should still be present - assert reader.fieldnames is not None - - # Test 12: Empty result set - JSON - @patch("services.feedback_service.db") - def test_export_feedbacks_empty_results_json(self, mock_db): - """Test JSON export with no feedback records.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json") - - # Assert - json_content = json.loads(response.get_data(as_text=True)) - assert json_content["export_info"]["total_records"] == 0 - assert len(json_content["feedback_data"]) == 0 - - # Test 13: Long response truncation - @patch("services.feedback_service.db") - def test_export_feedbacks_long_response_truncation(self, mock_db, factory): - """Test that long AI responses are truncated to 500 characters.""" - # Arrange - long_answer = "A" * 600 # 600 characters - feedback = factory.create_feedback_mock() - message = factory.create_message_mock(answer=long_answer) - conversation = factory.create_conversation_mock() - app = factory.create_app_mock() - account = factory.create_account_mock() - - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [(feedback, message, conversation, app, account)] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json") - - # Assert - json_content = json.loads(response.get_data(as_text=True)) - ai_response = json_content["feedback_data"][0]["ai_response"] - assert len(ai_response) == 503 # 500 + "..." - assert ai_response.endswith("...") - - # Test 14: Null account (end user feedback) - @patch("services.feedback_service.db") - def test_export_feedbacks_null_account(self, mock_db, factory): - """Test handling of feedback from end users (no account).""" - # Arrange - feedback = factory.create_feedback_mock(from_account_id=None) - message = factory.create_message_mock() - conversation = factory.create_conversation_mock() - app = factory.create_app_mock() - account = None # No account for end user - - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [(feedback, message, conversation, app, account)] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json") - - # Assert - json_content = json.loads(response.get_data(as_text=True)) - assert json_content["feedback_data"][0]["from_account_name"] == "" - - # Test 15: Null conversation name - @patch("services.feedback_service.db") - def test_export_feedbacks_null_conversation_name(self, mock_db, factory): - """Test handling of conversations without names.""" - # Arrange - feedback = factory.create_feedback_mock() - message = factory.create_message_mock() - conversation = factory.create_conversation_mock(name=None) - app = factory.create_app_mock() - account = factory.create_account_mock() - - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [(feedback, message, conversation, app, account)] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json") - - # Assert - json_content = json.loads(response.get_data(as_text=True)) - assert json_content["feedback_data"][0]["conversation_name"] == "" - - # Test 16: Dislike rating emoji - @patch("services.feedback_service.db") - def test_export_feedbacks_dislike_rating(self, mock_db, factory): - """Test that dislike rating shows thumbs down emoji.""" - # Arrange - feedback = factory.create_feedback_mock(rating="dislike") - message = factory.create_message_mock() - conversation = factory.create_conversation_mock() - app = factory.create_app_mock() - account = factory.create_account_mock() - - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [(feedback, message, conversation, app, account)] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json") - - # Assert - json_content = json.loads(response.get_data(as_text=True)) - assert json_content["feedback_data"][0]["feedback_rating"] == "👎" - assert json_content["feedback_data"][0]["feedback_rating_raw"] == "dislike" - - # Test 17: Combined filters - @patch("services.feedback_service.db") - def test_export_feedbacks_combined_filters(self, mock_db, factory): - """Test applying multiple filters simultaneously.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Act - FeedbackService.export_feedbacks( - app_id="app-456", - from_source="admin", - rating="like", - has_comment=True, - start_date="2024-01-01", - end_date="2024-12-31", - ) - - # Assert - # Should have called filter multiple times for each condition - assert mock_query.filter.call_count >= 4 - - # Test 18: Message query fallback to inputs - @patch("services.feedback_service.db") - def test_export_feedbacks_message_query_from_inputs(self, mock_db, factory): - """Test fallback to inputs.query when message.query is None.""" - # Arrange - feedback = factory.create_feedback_mock() - message = factory.create_message_mock(query=None, inputs={"query": "Query from inputs"}) - conversation = factory.create_conversation_mock() - app = factory.create_app_mock() - account = factory.create_account_mock() - - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [(feedback, message, conversation, app, account)] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json") - - # Assert - json_content = json.loads(response.get_data(as_text=True)) - assert json_content["feedback_data"][0]["user_query"] == "Query from inputs" - - # Test 19: Empty feedback content - @patch("services.feedback_service.db") - def test_export_feedbacks_empty_feedback_content(self, mock_db, factory): - """Test handling of feedback with empty/null content.""" - # Arrange - feedback = factory.create_feedback_mock(content=None) - message = factory.create_message_mock() - conversation = factory.create_conversation_mock() - app = factory.create_app_mock() - account = factory.create_account_mock() - - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [(feedback, message, conversation, app, account)] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json") - - # Assert - json_content = json.loads(response.get_data(as_text=True)) - assert json_content["feedback_data"][0]["feedback_comment"] == "" - assert json_content["feedback_data"][0]["has_comment"] == "No" - - # Test 20: CSV headers validation - @patch("services.feedback_service.db") - def test_export_feedbacks_csv_headers(self, mock_db, factory, sample_feedback_data): - """Test that CSV contains all expected headers.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = sample_feedback_data - - expected_headers = [ - "feedback_id", - "app_name", - "app_id", - "conversation_id", - "conversation_name", - "message_id", - "user_query", - "ai_response", - "feedback_rating", - "feedback_rating_raw", - "feedback_comment", - "feedback_source", - "feedback_date", - "message_date", - "from_account_name", - "from_end_user_id", - "has_comment", - ] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="csv") - - # Assert - csv_content = response.get_data(as_text=True) - reader = csv.DictReader(io.StringIO(csv_content)) - assert list(reader.fieldnames) == expected_headers diff --git a/api/tests/unit_tests/services/test_file_service_zip_and_lookup.py b/api/tests/unit_tests/services/test_file_service_zip_and_lookup.py deleted file mode 100644 index 7b4d349e33e..00000000000 --- a/api/tests/unit_tests/services/test_file_service_zip_and_lookup.py +++ /dev/null @@ -1,99 +0,0 @@ -""" -Unit tests for `services.file_service.FileService` helpers. - -We keep these tests focused on: -- ZIP tempfile building (sanitization + deduplication + content writes) -- tenant-scoped batch lookup behavior (`get_upload_files_by_ids`) -""" - -from __future__ import annotations - -from types import SimpleNamespace -from typing import Any -from zipfile import ZipFile - -import pytest - -import services.file_service as file_service_module -from services.file_service import FileService - - -def test_build_upload_files_zip_tempfile_sanitizes_and_dedupes_names(monkeypatch: pytest.MonkeyPatch) -> None: - """Ensure ZIP entry names are safe and unique while preserving extensions.""" - - # Arrange: three upload files that all sanitize down to the same basename ("b.txt"). - upload_files: list[Any] = [ - SimpleNamespace(name="a/b.txt", key="k1"), - SimpleNamespace(name="c/b.txt", key="k2"), - SimpleNamespace(name="../b.txt", key="k3"), - ] - - # Stream distinct bytes per key so we can verify content is written to the right entry. - data_by_key: dict[str, list[bytes]] = {"k1": [b"one"], "k2": [b"two"], "k3": [b"three"]} - - def _load(key: str, stream: bool = True) -> list[bytes]: - # Return the corresponding chunks for this key (the production code iterates chunks). - assert stream is True - return data_by_key[key] - - monkeypatch.setattr(file_service_module.storage, "load", _load) - - # Act: build zip in a tempfile. - with FileService.build_upload_files_zip_tempfile(upload_files=upload_files) as tmp: - with ZipFile(tmp, mode="r") as zf: - # Assert: names are sanitized (no directory components) and deduped with suffixes. - assert zf.namelist() == ["b.txt", "b (1).txt", "b (2).txt"] - - # Assert: each entry contains the correct bytes from storage. - assert zf.read("b.txt") == b"one" - assert zf.read("b (1).txt") == b"two" - assert zf.read("b (2).txt") == b"three" - - -def test_get_upload_files_by_ids_returns_empty_when_no_ids(monkeypatch: pytest.MonkeyPatch) -> None: - """Ensure empty input returns an empty mapping without hitting the database.""" - - class _Session: - def scalars(self, _stmt): # type: ignore[no-untyped-def] - raise AssertionError("db.session.scalars should not be called for empty id lists") - - monkeypatch.setattr(file_service_module, "db", SimpleNamespace(session=_Session())) - - assert FileService.get_upload_files_by_ids("tenant-1", []) == {} - - -def test_get_upload_files_by_ids_returns_id_keyed_mapping(monkeypatch: pytest.MonkeyPatch) -> None: - """Ensure batch lookup returns a dict keyed by stringified UploadFile ids.""" - - upload_files: list[Any] = [ - SimpleNamespace(id="file-1", tenant_id="tenant-1"), - SimpleNamespace(id="file-2", tenant_id="tenant-1"), - ] - - class _ScalarResult: - def __init__(self, items: list[Any]) -> None: - self._items = items - - def all(self) -> list[Any]: - return self._items - - class _Session: - def __init__(self, items: list[Any]) -> None: - self._items = items - self.calls: list[object] = [] - - def scalars(self, stmt): # type: ignore[no-untyped-def] - # Capture the statement so we can at least assert the query path is taken. - self.calls.append(stmt) - return _ScalarResult(self._items) - - session = _Session(upload_files) - monkeypatch.setattr(file_service_module, "db", SimpleNamespace(session=session)) - - # Provide duplicates to ensure callers can safely pass repeated ids. - result = FileService.get_upload_files_by_ids("tenant-1", ["file-1", "file-1", "file-2"]) - - assert set(result.keys()) == {"file-1", "file-2"} - assert result["file-1"].id == "file-1" - assert result["file-2"].id == "file-2" - assert len(session.calls) == 1 diff --git a/api/tests/unit_tests/services/test_human_input_service.py b/api/tests/unit_tests/services/test_human_input_service.py index 375e47d7fcf..9be475d043f 100644 --- a/api/tests/unit_tests/services/test_human_input_service.py +++ b/api/tests/unit_tests/services/test_human_input_service.py @@ -3,18 +3,19 @@ from datetime import datetime, timedelta from unittest.mock import MagicMock import pytest +from graphon.nodes.human_input.entities import ( + FormDefinition, + FormInput, + UserAction, +) +from graphon.nodes.human_input.enums import FormInputType, HumanInputFormKind, HumanInputFormStatus import services.human_input_service as human_input_service_module from core.repositories.human_input_repository import ( HumanInputFormRecord, HumanInputFormSubmissionRepository, ) -from dify_graph.nodes.human_input.entities import ( - FormDefinition, - FormInput, - UserAction, -) -from dify_graph.nodes.human_input.enums import FormInputType, HumanInputFormKind, HumanInputFormStatus +from libs.datetime_utils import naive_utc_now from models.human_input import RecipientType from services.human_input_service import ( Form, @@ -51,11 +52,11 @@ def sample_form_record(): inputs=[], user_actions=[UserAction(id="submit", title="Submit")], rendered_content="

hello

", - expiration_time=datetime.utcnow() + timedelta(hours=1), + expiration_time=naive_utc_now() + timedelta(hours=1), ), rendered_content="

hello

", - created_at=datetime.utcnow(), - expiration_time=datetime.utcnow() + timedelta(hours=1), + created_at=naive_utc_now(), + expiration_time=naive_utc_now() + timedelta(hours=1), status=HumanInputFormStatus.WAITING, selected_action_id=None, submitted_data=None, @@ -101,8 +102,8 @@ def test_ensure_form_active_respects_global_timeout(monkeypatch, sample_form_rec service = HumanInputService(session_factory) expired_record = dataclasses.replace( sample_form_record, - created_at=datetime.utcnow() - timedelta(hours=2), - expiration_time=datetime.utcnow() + timedelta(hours=2), + created_at=naive_utc_now() - timedelta(hours=2), + expiration_time=naive_utc_now() + timedelta(hours=2), ) monkeypatch.setattr(human_input_service_module.dify_config, "HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS", 3600) @@ -391,7 +392,7 @@ def test_ensure_form_active_errors(sample_form_record, mock_session_factory): service = HumanInputService(session_factory) # Submitted - submitted_record = dataclasses.replace(sample_form_record, submitted_at=datetime.utcnow()) + submitted_record = dataclasses.replace(sample_form_record, submitted_at=naive_utc_now()) with pytest.raises(human_input_service_module.FormSubmittedError): service.ensure_form_active(Form(submitted_record)) @@ -402,7 +403,7 @@ def test_ensure_form_active_errors(sample_form_record, mock_session_factory): # Expired time expired_time_record = dataclasses.replace( - sample_form_record, expiration_time=datetime.utcnow() - timedelta(minutes=1) + sample_form_record, expiration_time=naive_utc_now() - timedelta(minutes=1) ) with pytest.raises(FormExpiredError): service.ensure_form_active(Form(expired_time_record)) @@ -411,7 +412,7 @@ def test_ensure_form_active_errors(sample_form_record, mock_session_factory): def test_ensure_not_submitted_raises(sample_form_record, mock_session_factory): session_factory, _ = mock_session_factory service = HumanInputService(session_factory) - submitted_record = dataclasses.replace(sample_form_record, submitted_at=datetime.utcnow()) + submitted_record = dataclasses.replace(sample_form_record, submitted_at=naive_utc_now()) with pytest.raises(human_input_service_module.FormSubmittedError): service._ensure_not_submitted(Form(submitted_record)) diff --git a/api/tests/unit_tests/services/test_knowledge_service.py b/api/tests/unit_tests/services/test_knowledge_service.py index bc0caee0717..53c243ad716 100644 --- a/api/tests/unit_tests/services/test_knowledge_service.py +++ b/api/tests/unit_tests/services/test_knowledge_service.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch import pytest -from services.knowledge_service import ExternalDatasetTestService +from services.knowledge_service import BedrockRetrievalSetting, ExternalDatasetTestService class TestKnowledgeService: @@ -24,7 +24,7 @@ class TestKnowledgeService: mock_client = MagicMock() mock_boto_client.return_value = mock_client - retrieval_setting = {"top_k": 4, "score_threshold": 0.5} + retrieval_setting = BedrockRetrievalSetting(top_k=4, score_threshold=0.5) query = "test query" knowledge_id = "kb-123" @@ -87,7 +87,10 @@ class TestKnowledgeService: mock_client.retrieve.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}, "retrievalResults": []} # Act - result = cast(dict[str, Any], ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb")) + result = cast( + dict[str, Any], + ExternalDatasetTestService.knowledge_retrieval(BedrockRetrievalSetting(top_k=1), "query", "kb"), + ) # Assert assert result["records"] == [] @@ -104,7 +107,10 @@ class TestKnowledgeService: mock_client.retrieve.return_value = {"ResponseMetadata": {"HTTPStatusCode": 500}} # Act - result = cast(dict[str, Any], ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb")) + result = cast( + dict[str, Any], + ExternalDatasetTestService.knowledge_retrieval(BedrockRetrievalSetting(top_k=1), "query", "kb"), + ) # Assert assert result["records"] == [] @@ -114,7 +120,7 @@ class TestKnowledgeService: with patch("services.knowledge_service.boto3.client") as mock_boto: mock_boto.side_effect = Exception("client init failed") with pytest.raises(Exception) as exc_info: - ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb") + ExternalDatasetTestService.knowledge_retrieval(BedrockRetrievalSetting(top_k=1), "query", "kb") assert "client init failed" in str(exc_info.value) # ===== Edge Cases ===== @@ -139,7 +145,10 @@ class TestKnowledgeService: # Act # retrieval_setting missing "score_threshold" - result = cast(dict[str, Any], ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb")) + result = cast( + dict[str, Any], + ExternalDatasetTestService.knowledge_retrieval(BedrockRetrievalSetting(top_k=1), "query", "kb"), + ) # Assert assert len(result["records"]) == 1 diff --git a/api/tests/unit_tests/services/test_message_service.py b/api/tests/unit_tests/services/test_message_service.py index e7740ef93a4..101b9bff24d 100644 --- a/api/tests/unit_tests/services/test_message_service.py +++ b/api/tests/unit_tests/services/test_message_service.py @@ -933,7 +933,7 @@ class TestMessageServiceSuggestedQuestions: ) # Test 28: get_suggested_questions_after_answer - Advanced Chat success - @patch("services.message_service.ModelManager") + @patch("services.message_service.ModelManager.for_tenant") @patch("services.message_service.WorkflowService") @patch("services.message_service.AdvancedChatAppConfigManager") @patch("services.message_service.TokenBufferMemory") @@ -983,7 +983,7 @@ class TestMessageServiceSuggestedQuestions: # Test 29: get_suggested_questions_after_answer - Chat app success (no override) @patch("services.message_service.db") - @patch("services.message_service.ModelManager") + @patch("services.message_service.ModelManager.for_tenant") @patch("services.message_service.TokenBufferMemory") @patch("services.message_service.LLMGenerator") @patch("services.message_service.TraceQueueManager") diff --git a/api/tests/unit_tests/services/test_metadata_partial_update.py b/api/tests/unit_tests/services/test_metadata_partial_update.py deleted file mode 100644 index 60252784bc5..00000000000 --- a/api/tests/unit_tests/services/test_metadata_partial_update.py +++ /dev/null @@ -1,187 +0,0 @@ -import unittest -from unittest.mock import MagicMock, patch - -import pytest - -from models.dataset import Dataset, Document -from services.entities.knowledge_entities.knowledge_entities import ( - DocumentMetadataOperation, - MetadataDetail, - MetadataOperationData, -) -from services.metadata_service import MetadataService - - -class TestMetadataPartialUpdate(unittest.TestCase): - def setUp(self): - self.dataset = MagicMock(spec=Dataset) - self.dataset.id = "dataset_id" - self.dataset.built_in_field_enabled = False - - self.document = MagicMock(spec=Document) - self.document.id = "doc_id" - self.document.doc_metadata = {"existing_key": "existing_value"} - self.document.data_source_type = "upload_file" - - @patch("services.metadata_service.db") - @patch("services.metadata_service.DocumentService") - @patch("services.metadata_service.current_account_with_tenant") - @patch("services.metadata_service.redis_client") - def test_partial_update_merges_metadata(self, mock_redis, mock_current_account, mock_document_service, mock_db): - # Setup mocks - mock_redis.get.return_value = None - mock_document_service.get_document.return_value = self.document - mock_current_account.return_value = (MagicMock(id="user_id"), "tenant_id") - - # Mock DB query for existing bindings - - # No existing binding for new key - mock_db.session.query.return_value.filter_by.return_value.first.return_value = None - - # Input data - operation = DocumentMetadataOperation( - document_id="doc_id", - metadata_list=[MetadataDetail(id="new_meta_id", name="new_key", value="new_value")], - partial_update=True, - ) - metadata_args = MetadataOperationData(operation_data=[operation]) - - # Execute - MetadataService.update_documents_metadata(self.dataset, metadata_args) - - # Verify - # 1. Check that doc_metadata contains BOTH existing and new keys - expected_metadata = {"existing_key": "existing_value", "new_key": "new_value"} - assert self.document.doc_metadata == expected_metadata - - # 2. Check that existing bindings were NOT deleted - # The delete call in the original code: db.session.query(...).filter_by(...).delete() - # In partial update, this should NOT be called. - mock_db.session.query.return_value.filter_by.return_value.delete.assert_not_called() - - @patch("services.metadata_service.db") - @patch("services.metadata_service.DocumentService") - @patch("services.metadata_service.current_account_with_tenant") - @patch("services.metadata_service.redis_client") - def test_full_update_replaces_metadata(self, mock_redis, mock_current_account, mock_document_service, mock_db): - # Setup mocks - mock_redis.get.return_value = None - mock_document_service.get_document.return_value = self.document - mock_current_account.return_value = (MagicMock(id="user_id"), "tenant_id") - - # Input data (partial_update=False by default) - operation = DocumentMetadataOperation( - document_id="doc_id", - metadata_list=[MetadataDetail(id="new_meta_id", name="new_key", value="new_value")], - partial_update=False, - ) - metadata_args = MetadataOperationData(operation_data=[operation]) - - # Execute - MetadataService.update_documents_metadata(self.dataset, metadata_args) - - # Verify - # 1. Check that doc_metadata contains ONLY the new key - expected_metadata = {"new_key": "new_value"} - assert self.document.doc_metadata == expected_metadata - - # 2. Check that existing bindings WERE deleted - # In full update (default), we expect the existing bindings to be cleared. - mock_db.session.query.return_value.filter_by.return_value.delete.assert_called() - - @patch("services.metadata_service.db") - @patch("services.metadata_service.DocumentService") - @patch("services.metadata_service.current_account_with_tenant") - @patch("services.metadata_service.redis_client") - def test_partial_update_skips_existing_binding( - self, mock_redis, mock_current_account, mock_document_service, mock_db - ): - # Setup mocks - mock_redis.get.return_value = None - mock_document_service.get_document.return_value = self.document - mock_current_account.return_value = (MagicMock(id="user_id"), "tenant_id") - - # Mock DB query to return an existing binding - # This simulates that the document ALREADY has the metadata we are trying to add - mock_existing_binding = MagicMock() - mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_existing_binding - - # Input data - operation = DocumentMetadataOperation( - document_id="doc_id", - metadata_list=[MetadataDetail(id="existing_meta_id", name="existing_key", value="existing_value")], - partial_update=True, - ) - metadata_args = MetadataOperationData(operation_data=[operation]) - - # Execute - MetadataService.update_documents_metadata(self.dataset, metadata_args) - - # Verify - # We verify that db.session.add was NOT called for DatasetMetadataBinding - # Since we can't easily check "not called with specific type" on the generic add method without complex logic, - # we can check if the number of add calls is 1 (only for the document update) instead of 2 (document + binding) - - # Expected calls: - # 1. db.session.add(document) - # 2. NO db.session.add(binding) because it exists - - # Note: In the code, db.session.add is called for document. - # Then loop over metadata_list. - # If existing_binding found, continue. - # So binding add should be skipped. - - # Let's filter the calls to add to see what was added - add_calls = mock_db.session.add.call_args_list - added_objects = [call.args[0] for call in add_calls] - - # Check that no DatasetMetadataBinding was added - from models.dataset import DatasetMetadataBinding - - has_binding_add = any( - isinstance(obj, DatasetMetadataBinding) - or (isinstance(obj, MagicMock) and getattr(obj, "__class__", None) == DatasetMetadataBinding) - for obj in added_objects - ) - - # Since we mock everything, checking isinstance might be tricky if DatasetMetadataBinding - # is not the exact class used in the service (imports match). - # But we can check the count. - # If it were added, there would be 2 calls. If skipped, 1 call. - assert mock_db.session.add.call_count == 1 - - @patch("services.metadata_service.db") - @patch("services.metadata_service.DocumentService") - @patch("services.metadata_service.current_account_with_tenant") - @patch("services.metadata_service.redis_client") - def test_rollback_called_on_commit_failure(self, mock_redis, mock_current_account, mock_document_service, mock_db): - """When db.session.commit() raises, rollback must be called and the exception must propagate.""" - # Setup mocks - mock_redis.get.return_value = None - mock_document_service.get_document.return_value = self.document - mock_current_account.return_value = (MagicMock(id="user_id"), "tenant_id") - mock_db.session.query.return_value.filter_by.return_value.first.return_value = None - - # Make commit raise an exception - mock_db.session.commit.side_effect = RuntimeError("database connection lost") - - operation = DocumentMetadataOperation( - document_id="doc_id", - metadata_list=[MetadataDetail(id="meta_id", name="key", value="value")], - partial_update=True, - ) - metadata_args = MetadataOperationData(operation_data=[operation]) - - # Act & Assert: the exception must propagate - with pytest.raises(RuntimeError, match="database connection lost"): - MetadataService.update_documents_metadata(self.dataset, metadata_args) - - # Verify rollback was called - mock_db.session.rollback.assert_called_once() - - # Verify the lock key was cleaned up despite the failure - mock_redis.delete.assert_called_with("document_metadata_lock_doc_id") - - -if __name__ == "__main__": - unittest.main() diff --git a/api/tests/unit_tests/services/test_model_load_balancing_service.py b/api/tests/unit_tests/services/test_model_load_balancing_service.py new file mode 100644 index 00000000000..b43e79dff50 --- /dev/null +++ b/api/tests/unit_tests/services/test_model_load_balancing_service.py @@ -0,0 +1,815 @@ +from __future__ import annotations + +import json +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import MagicMock + +import pytest +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + CredentialFormSchema, + FieldModelSchema, + FormType, + ModelCredentialSchema, + ProviderCredentialSchema, +) +from pytest_mock import MockerFixture + +from constants import HIDDEN_VALUE +from models.provider import LoadBalancingModelConfig +from services.model_load_balancing_service import ModelLoadBalancingService + + +def _build_provider_credential_schema() -> ProviderCredentialSchema: + return ProviderCredentialSchema( + credential_form_schemas=[ + CredentialFormSchema(variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.SECRET_INPUT) + ] + ) + + +def _build_model_credential_schema() -> ModelCredentialSchema: + return ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model")), + credential_form_schemas=[ + CredentialFormSchema(variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.SECRET_INPUT) + ], + ) + + +def _build_provider_configuration( + *, + custom_provider: bool = False, + load_balancing_enabled: bool | None = None, + model_schema: ModelCredentialSchema | None = None, + provider_schema: ProviderCredentialSchema | None = None, +) -> MagicMock: + provider_configuration = MagicMock() + provider_configuration.provider = SimpleNamespace( + provider="openai", + model_credential_schema=model_schema, + provider_credential_schema=provider_schema, + ) + provider_configuration.custom_configuration = SimpleNamespace(provider=custom_provider) + provider_configuration.extract_secret_variables.return_value = ["api_key"] + provider_configuration.obfuscated_credentials.side_effect = lambda credentials, credential_form_schemas: credentials + provider_configuration.get_provider_model_setting.return_value = ( + None if load_balancing_enabled is None else SimpleNamespace(load_balancing_enabled=load_balancing_enabled) + ) + return provider_configuration + + +def _load_balancing_model_config(**kwargs: Any) -> LoadBalancingModelConfig: + return cast(LoadBalancingModelConfig, SimpleNamespace(**kwargs)) + + +@pytest.fixture +def service(mocker: MockerFixture) -> ModelLoadBalancingService: + # Arrange + provider_manager = MagicMock() + mocker.patch("services.model_load_balancing_service.create_plugin_provider_manager", return_value=provider_manager) + model_assembly = SimpleNamespace(provider_manager=provider_manager, model_provider_factory=MagicMock()) + mocker.patch("services.model_load_balancing_service.create_plugin_model_assembly", return_value=model_assembly) + svc = ModelLoadBalancingService() + svc.provider_manager = provider_manager + svc.model_assembly = model_assembly + svc._get_provider_manager = lambda _tenant_id: provider_manager # type: ignore[method-assign] + return svc + + +@pytest.fixture +def mock_db(mocker: MockerFixture) -> MagicMock: + # Arrange + mocked_db = mocker.patch("services.model_load_balancing_service.db") + mocked_db.session = MagicMock() + return mocked_db + + +@pytest.mark.parametrize( + ("method_name", "expected_provider_method"), + [ + ("enable_model_load_balancing", "enable_model_load_balancing"), + ("disable_model_load_balancing", "disable_model_load_balancing"), + ], +) +def test_enable_disable_model_load_balancing_should_call_provider_configuration_method_when_provider_exists( + method_name: str, + expected_provider_method: str, + service: ModelLoadBalancingService, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + + # Act + getattr(service, method_name)("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value) + + # Assert + getattr(provider_configuration, expected_provider_method).assert_called_once_with( + model="gpt-4o-mini", model_type=ModelType.LLM + ) + + +@pytest.mark.parametrize( + "method_name", + ["enable_model_load_balancing", "disable_model_load_balancing"], +) +def test_enable_disable_model_load_balancing_should_raise_value_error_when_provider_missing( + method_name: str, + service: ModelLoadBalancingService, +) -> None: + # Arrange + service.provider_manager.get_configurations.return_value = {} + + # Act + Assert + with pytest.raises(ValueError, match="Provider openai does not exist"): + getattr(service, method_name)("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value) + + +def test_get_load_balancing_configs_should_raise_value_error_when_provider_missing( + service: ModelLoadBalancingService, +) -> None: + # Arrange + service.provider_manager.get_configurations.return_value = {} + + # Act + Assert + with pytest.raises(ValueError, match="Provider openai does not exist"): + service.get_load_balancing_configs("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value) + + +def test_get_load_balancing_configs_should_insert_inherit_config_when_missing_for_custom_provider( + service: ModelLoadBalancingService, + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration( + custom_provider=True, + load_balancing_enabled=True, + provider_schema=_build_provider_credential_schema(), + ) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + config = SimpleNamespace( + id="cfg-1", + name="primary", + encrypted_config=json.dumps({"api_key": "encrypted-key"}), + credential_id="cred-1", + enabled=True, + ) + mock_db.session.query.return_value.where.return_value.order_by.return_value.all.return_value = [config] + mocker.patch( + "services.model_load_balancing_service.encrypter.get_decrypt_decoding", + return_value=("rsa", "cipher"), + ) + mocker.patch( + "services.model_load_balancing_service.encrypter.decrypt_token_with_decoding", + return_value="plain-key", + ) + mocker.patch( + "services.model_load_balancing_service.LBModelManager.get_config_in_cooldown_and_ttl", + return_value=(False, 0), + ) + + # Act + is_enabled, configs = service.get_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + ) + + # Assert + assert is_enabled is True + assert len(configs) == 2 + assert configs[0]["name"] == "__inherit__" + assert configs[1]["name"] == "primary" + assert configs[1]["credentials"] == {"api_key": "plain-key"} + assert mock_db.session.add.call_count == 1 + assert mock_db.session.commit.call_count == 1 + + +def test_get_load_balancing_configs_should_reorder_existing_inherit_and_tolerate_json_or_decrypt_errors( + service: ModelLoadBalancingService, + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration( + custom_provider=True, + load_balancing_enabled=None, + provider_schema=_build_provider_credential_schema(), + ) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + normal_config = SimpleNamespace( + id="cfg-1", + name="normal", + encrypted_config=json.dumps({"api_key": "bad-encrypted"}), + credential_id="cred-1", + enabled=True, + ) + inherit_config = SimpleNamespace( + id="cfg-2", + name="__inherit__", + encrypted_config="not-json", + credential_id=None, + enabled=False, + ) + mock_db.session.query.return_value.where.return_value.order_by.return_value.all.return_value = [ + normal_config, + inherit_config, + ] + mocker.patch( + "services.model_load_balancing_service.encrypter.get_decrypt_decoding", + return_value=("rsa", "cipher"), + ) + mocker.patch( + "services.model_load_balancing_service.encrypter.decrypt_token_with_decoding", + side_effect=ValueError("cannot decrypt"), + ) + mocker.patch( + "services.model_load_balancing_service.LBModelManager.get_config_in_cooldown_and_ttl", + return_value=(True, 15), + ) + + # Act + is_enabled, configs = service.get_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + config_from="predefined-model", + ) + + # Assert + assert is_enabled is False + assert configs[0]["name"] == "__inherit__" + assert configs[0]["credentials"] == {} + assert configs[1]["credentials"] == {"api_key": "bad-encrypted"} + assert configs[1]["in_cooldown"] is True + assert configs[1]["ttl"] == 15 + + +def test_get_load_balancing_config_should_raise_value_error_when_provider_missing( + service: ModelLoadBalancingService, +) -> None: + # Arrange + service.provider_manager.get_configurations.return_value = {} + + # Act + Assert + with pytest.raises(ValueError, match="Provider openai does not exist"): + service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1") + + +def test_get_load_balancing_config_should_return_none_when_config_not_found( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + mock_db.session.query.return_value.where.return_value.first.return_value = None + + # Act + result = service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1") + + # Assert + assert result is None + + +def test_get_load_balancing_config_should_return_obfuscated_payload_when_config_exists( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + provider_configuration.obfuscated_credentials.side_effect = lambda credentials, credential_form_schemas: { + "masked": credentials.get("api_key", "") + } + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + config = SimpleNamespace(id="cfg-1", name="primary", encrypted_config="not-json", enabled=True) + mock_db.session.query.return_value.where.return_value.first.return_value = config + + # Act + result = service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1") + + # Assert + assert result == { + "id": "cfg-1", + "name": "primary", + "credentials": {"masked": ""}, + "enabled": True, + } + + +def test_init_inherit_config_should_create_and_persist_inherit_configuration( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + model_type = ModelType.LLM + + # Act + inherit_config = service._init_inherit_config("tenant-1", "openai", "gpt-4o-mini", model_type) + + # Assert + assert inherit_config.tenant_id == "tenant-1" + assert inherit_config.provider_name == "openai" + assert inherit_config.model_name == "gpt-4o-mini" + assert inherit_config.model_type == "text-generation" + assert inherit_config.name == "__inherit__" + mock_db.session.add.assert_called_once_with(inherit_config) + mock_db.session.commit.assert_called_once() + + +def test_update_load_balancing_configs_should_raise_value_error_when_provider_missing( + service: ModelLoadBalancingService, +) -> None: + # Arrange + service.provider_manager.get_configurations.return_value = {} + + # Act + Assert + with pytest.raises(ValueError, match="Provider openai does not exist"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [], + "custom-model", + ) + + +def test_update_load_balancing_configs_should_raise_value_error_when_configs_is_not_list( + service: ModelLoadBalancingService, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + + # Act + Assert + with pytest.raises(ValueError, match="Invalid load balancing configs"): + service.update_load_balancing_configs( # type: ignore[arg-type] + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + cast(list[dict[str, object]], "invalid-configs"), + "custom-model", + ) + + +def test_update_load_balancing_configs_should_raise_value_error_when_config_item_is_not_dict( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + mock_db.session.scalars.return_value.all.return_value = [] + + # Act + Assert + with pytest.raises(ValueError, match="Invalid load balancing config"): + service.update_load_balancing_configs( # type: ignore[list-item] + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + cast(list[dict[str, object]], ["bad-item"]), + "custom-model", + ) + + +def test_update_load_balancing_configs_should_raise_value_error_when_credential_id_not_found( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + mock_db.session.scalars.return_value.all.return_value = [] + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="Provider credential with id cred-1 not found"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"credential_id": "cred-1", "enabled": True}], + "predefined-model", + ) + + +def test_update_load_balancing_configs_should_raise_value_error_when_name_or_enabled_is_invalid( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + mock_db.session.scalars.return_value.all.return_value = [] + + # Act + Assert + with pytest.raises(ValueError, match="Invalid load balancing config name"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"enabled": True}], + "custom-model", + ) + + with pytest.raises(ValueError, match="Invalid load balancing config enabled"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"name": "cfg-without-enabled"}], + "custom-model", + ) + + +def test_update_load_balancing_configs_should_raise_value_error_when_existing_config_id_is_invalid( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + current_config = SimpleNamespace(id="cfg-1") + mock_db.session.scalars.return_value.all.return_value = [current_config] + + # Act + Assert + with pytest.raises(ValueError, match="Invalid load balancing config id: cfg-2"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"id": "cfg-2", "name": "invalid", "enabled": True}], + "custom-model", + ) + + +def test_update_load_balancing_configs_should_raise_value_error_when_credentials_are_invalid_for_update_or_create( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + existing_config = SimpleNamespace(id="cfg-1", name="old", enabled=True, encrypted_config=None, updated_at=None) + mock_db.session.scalars.return_value.all.return_value = [existing_config] + + # Act + Assert + with pytest.raises(ValueError, match="Invalid load balancing config credentials"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"id": "cfg-1", "name": "new", "enabled": True, "credentials": "bad"}], + "custom-model", + ) + + with pytest.raises(ValueError, match="Invalid load balancing config credentials"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"name": "new-config", "enabled": True, "credentials": "bad"}], + "custom-model", + ) + + +def test_update_load_balancing_configs_should_update_existing_create_new_and_delete_removed_configs( + service: ModelLoadBalancingService, + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + existing_config_1 = SimpleNamespace( + id="cfg-1", + name="existing-one", + enabled=True, + encrypted_config=json.dumps({"api_key": "old"}), + updated_at=None, + ) + existing_config_2 = SimpleNamespace( + id="cfg-2", + name="existing-two", + enabled=True, + encrypted_config=None, + updated_at=None, + ) + mock_db.session.scalars.return_value.all.return_value = [existing_config_1, existing_config_2] + mocker.patch.object(service, "_custom_credentials_validate", return_value={"api_key": "encrypted"}) + mock_clear_cache = mocker.patch.object(service, "_clear_credentials_cache") + + # Act + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [ + {"id": "cfg-1", "name": "updated-name", "enabled": False, "credentials": {"api_key": "plain"}}, + {"name": "new-config", "enabled": True, "credentials": {"api_key": "plain"}}, + ], + "custom-model", + ) + + # Assert + assert existing_config_1.name == "updated-name" + assert existing_config_1.enabled is False + assert json.loads(existing_config_1.encrypted_config) == {"api_key": "encrypted"} + assert mock_db.session.add.call_count == 1 + mock_db.session.delete.assert_called_once_with(existing_config_2) + assert mock_db.session.commit.call_count >= 3 + mock_clear_cache.assert_any_call("tenant-1", "cfg-1") + mock_clear_cache.assert_any_call("tenant-1", "cfg-2") + + +def test_update_load_balancing_configs_should_raise_value_error_for_invalid_new_config_name_or_missing_credentials( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + mock_db.session.scalars.return_value.all.return_value = [] + + # Act + Assert + with pytest.raises(ValueError, match="Invalid load balancing config name"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"name": "__inherit__", "enabled": True, "credentials": {"api_key": "x"}}], + "custom-model", + ) + + with pytest.raises(ValueError, match="Invalid load balancing config credentials"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"name": "new", "enabled": True}], + "custom-model", + ) + + +def test_update_load_balancing_configs_should_create_from_existing_provider_credential_when_credential_id_provided( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + mock_db.session.scalars.return_value.all.return_value = [] + credential_record = SimpleNamespace(credential_name="Main Credential", encrypted_config='{"api_key":"enc"}') + mock_db.session.query.return_value.filter_by.return_value.first.return_value = credential_record + + # Act + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"credential_id": "cred-1", "enabled": True}], + "predefined-model", + ) + + # Assert + created_config = mock_db.session.add.call_args.args[0] + assert created_config.name == "Main Credential" + assert created_config.credential_id == "cred-1" + assert created_config.credential_source_type == "provider" + assert created_config.encrypted_config == '{"api_key":"enc"}' + mock_db.session.commit.assert_called() + + +def test_validate_load_balancing_credentials_should_raise_value_error_when_provider_missing( + service: ModelLoadBalancingService, +) -> None: + # Arrange + service.provider_manager.get_configurations.return_value = {} + + # Act + Assert + with pytest.raises(ValueError, match="Provider openai does not exist"): + service.validate_load_balancing_credentials( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + {"api_key": "plain"}, + ) + + +def test_validate_load_balancing_credentials_should_raise_value_error_when_config_id_is_invalid( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + mock_db.session.query.return_value.where.return_value.first.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="Load balancing config cfg-1 does not exist"): + service.validate_load_balancing_credentials( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + {"api_key": "plain"}, + config_id="cfg-1", + ) + + +def test_validate_load_balancing_credentials_should_delegate_to_custom_validate_with_or_without_config( + service: ModelLoadBalancingService, + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + existing_config = SimpleNamespace(id="cfg-1") + mock_db.session.query.return_value.where.return_value.first.return_value = existing_config + mock_validate = mocker.patch.object(service, "_custom_credentials_validate") + + # Act + service.validate_load_balancing_credentials( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + {"api_key": "plain"}, + config_id="cfg-1", + ) + service.validate_load_balancing_credentials( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + {"api_key": "plain"}, + ) + + # Assert + assert mock_validate.call_count == 2 + assert mock_validate.call_args_list[0].kwargs["load_balancing_model_config"] is existing_config + assert mock_validate.call_args_list[1].kwargs["load_balancing_model_config"] is None + shared_model_provider_factory = service.model_assembly.model_provider_factory + assert mock_validate.call_args_list[0].kwargs["model_provider_factory"] is shared_model_provider_factory + assert mock_validate.call_args_list[1].kwargs["model_provider_factory"] is shared_model_provider_factory + + +def test_custom_credentials_validate_should_replace_hidden_secret_with_original_value_and_encrypt( + service: ModelLoadBalancingService, + mocker: MockerFixture, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + load_balancing_model_config = _load_balancing_model_config( + encrypted_config=json.dumps({"api_key": "old-encrypted-token"}) + ) + mocker.patch("services.model_load_balancing_service.encrypter.decrypt_token", return_value="old-plain-value") + mock_encrypt = mocker.patch( + "services.model_load_balancing_service.encrypter.encrypt_token", + side_effect=lambda tenant_id, value: f"enc:{value}", + ) + + # Act + result = service._custom_credentials_validate( + tenant_id="tenant-1", + provider_configuration=provider_configuration, + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": HIDDEN_VALUE, "region": "us"}, + load_balancing_model_config=load_balancing_model_config, + validate=False, + ) + + # Assert + assert result == {"api_key": "enc:old-plain-value", "region": "us"} + mock_encrypt.assert_called_once_with("tenant-1", "old-plain-value") + + +def test_custom_credentials_validate_should_handle_invalid_original_json_and_validate_with_model_schema( + service: ModelLoadBalancingService, + mocker: MockerFixture, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(model_schema=_build_model_credential_schema()) + load_balancing_model_config = _load_balancing_model_config(encrypted_config="not-json") + mock_factory = MagicMock() + mock_factory.model_credentials_validate.return_value = {"api_key": "validated"} + mock_encrypt = mocker.patch( + "services.model_load_balancing_service.encrypter.encrypt_token", + side_effect=lambda tenant_id, value: f"enc:{value}", + ) + + # Act + result = service._custom_credentials_validate( + tenant_id="tenant-1", + provider_configuration=provider_configuration, + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "plain"}, + load_balancing_model_config=load_balancing_model_config, + model_provider_factory=mock_factory, + validate=True, + ) + + # Assert + assert result == {"api_key": "enc:validated"} + mock_factory.model_credentials_validate.assert_called_once() + mock_factory.provider_credentials_validate.assert_not_called() + mock_encrypt.assert_called_once_with("tenant-1", "validated") + + +def test_custom_credentials_validate_should_validate_with_provider_schema_when_model_schema_absent( + service: ModelLoadBalancingService, + mocker: MockerFixture, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + mock_factory = MagicMock() + mock_factory.provider_credentials_validate.return_value = {"api_key": "provider-validated"} + mocker.patch( + "services.model_load_balancing_service.encrypter.encrypt_token", + side_effect=lambda tenant_id, value: f"enc:{value}", + ) + + # Act + result = service._custom_credentials_validate( + tenant_id="tenant-1", + provider_configuration=provider_configuration, + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "plain"}, + model_provider_factory=mock_factory, + validate=True, + ) + + # Assert + assert result == {"api_key": "enc:provider-validated"} + mock_factory.provider_credentials_validate.assert_called_once() + mock_factory.model_credentials_validate.assert_not_called() + + +def test_get_credential_schema_should_return_model_schema_or_provider_schema_or_raise( + service: ModelLoadBalancingService, +) -> None: + # Arrange + model_schema = _build_model_credential_schema() + provider_schema = _build_provider_credential_schema() + provider_configuration_with_model = _build_provider_configuration(model_schema=model_schema) + provider_configuration_with_provider = _build_provider_configuration(provider_schema=provider_schema) + provider_configuration_without_schema = _build_provider_configuration() + + # Act + schema_from_model = service._get_credential_schema(provider_configuration_with_model) + schema_from_provider = service._get_credential_schema(provider_configuration_with_provider) + + # Assert + assert schema_from_model is model_schema + assert schema_from_provider is provider_schema + with pytest.raises(ValueError, match="No credential schema found"): + service._get_credential_schema(provider_configuration_without_schema) + + +def test_clear_credentials_cache_should_delete_load_balancing_cache_entry( + service: ModelLoadBalancingService, + mocker: MockerFixture, +) -> None: + # Arrange + mock_cache_instance = MagicMock() + mock_cache_cls = mocker.patch( + "services.model_load_balancing_service.ProviderCredentialsCache", + return_value=mock_cache_instance, + ) + + # Act + service._clear_credentials_cache("tenant-1", "cfg-1") + + # Assert + mock_cache_cls.assert_called_once() + assert mock_cache_cls.call_args.kwargs == { + "tenant_id": "tenant-1", + "identity_id": "cfg-1", + "cache_type": mocker.ANY, + } + assert mock_cache_cls.call_args.kwargs["cache_type"].name == "LOAD_BALANCING_MODEL" + mock_cache_instance.delete.assert_called_once() diff --git a/api/tests/unit_tests/services/test_model_provider_service_sanitization.py b/api/tests/unit_tests/services/test_model_provider_service_sanitization.py index 6a6b63f0037..1bd979b9ec2 100644 --- a/api/tests/unit_tests/services/test_model_provider_service_sanitization.py +++ b/api/tests/unit_tests/services/test_model_provider_service_sanitization.py @@ -1,11 +1,11 @@ import types import pytest +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ConfigurateMethod from core.entities.provider_entities import CredentialConfiguration, CustomModelConfiguration -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod from models.provider import ProviderType from services.model_provider_service import ModelProviderService @@ -71,7 +71,7 @@ def service_with_fake_configurations(): return _FakeConfigurations(fake_provider_configuration) svc = ModelProviderService() - svc.provider_manager = _FakeProviderManager() + svc._get_provider_manager = lambda tenant_id: _FakeProviderManager() return svc diff --git a/api/tests/unit_tests/services/test_restore_archived_workflow_run.py b/api/tests/unit_tests/services/test_restore_archived_workflow_run.py deleted file mode 100644 index a214ecf7284..00000000000 --- a/api/tests/unit_tests/services/test_restore_archived_workflow_run.py +++ /dev/null @@ -1,37 +0,0 @@ -""" -Unit tests for workflow run restore functionality. -""" - -from datetime import datetime - - -class TestWorkflowRunRestore: - """Tests for the WorkflowRunRestore class.""" - - def test_restore_initialization(self): - """Restore service should respect dry_run flag.""" - from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore - - restore = WorkflowRunRestore(dry_run=True) - - assert restore.dry_run is True - - def test_convert_datetime_fields(self): - """ISO datetime strings should be converted to datetime objects.""" - from models.workflow import WorkflowRun - from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore - - record = { - "id": "test-id", - "created_at": "2024-01-01T12:00:00", - "finished_at": "2024-01-01T12:05:00", - "name": "test", - } - - restore = WorkflowRunRestore() - result = restore._convert_datetime_fields(record, WorkflowRun) - - assert isinstance(result["created_at"], datetime) - assert result["created_at"].year == 2024 - assert result["created_at"].month == 1 - assert result["name"] == "test" diff --git a/api/tests/unit_tests/services/test_saved_message_service.py b/api/tests/unit_tests/services/test_saved_message_service.py deleted file mode 100644 index 87b946fe461..00000000000 --- a/api/tests/unit_tests/services/test_saved_message_service.py +++ /dev/null @@ -1,626 +0,0 @@ -""" -Comprehensive unit tests for SavedMessageService. - -This test suite provides complete coverage of saved message operations in Dify, -following TDD principles with the Arrange-Act-Assert pattern. - -## Test Coverage - -### 1. Pagination (TestSavedMessageServicePagination) -Tests saved message listing and pagination: -- Pagination with valid user (Account and EndUser) -- Pagination without user raises ValueError -- Pagination with last_id parameter -- Empty results when no saved messages exist -- Integration with MessageService pagination - -### 2. Save Operations (TestSavedMessageServiceSave) -Tests saving messages: -- Save message for Account user -- Save message for EndUser -- Save without user (no-op) -- Prevent duplicate saves (idempotent) -- Message validation through MessageService - -### 3. Delete Operations (TestSavedMessageServiceDelete) -Tests deleting saved messages: -- Delete saved message for Account user -- Delete saved message for EndUser -- Delete without user (no-op) -- Delete non-existent saved message (no-op) -- Proper database cleanup - -## Testing Approach - -- **Mocking Strategy**: All external dependencies (database, MessageService) are mocked - for fast, isolated unit tests -- **Factory Pattern**: SavedMessageServiceTestDataFactory provides consistent test data -- **Fixtures**: Mock objects are configured per test method -- **Assertions**: Each test verifies return values and side effects - (database operations, method calls) - -## Key Concepts - -**User Types:** -- Account: Workspace members (console users) -- EndUser: API users (end users) - -**Saved Messages:** -- Users can save messages for later reference -- Each user has their own saved message list -- Saving is idempotent (duplicate saves ignored) -- Deletion is safe (non-existent deletes ignored) -""" - -from datetime import UTC, datetime -from unittest.mock import MagicMock, Mock, create_autospec, patch - -import pytest - -from libs.infinite_scroll_pagination import InfiniteScrollPagination -from models import Account -from models.model import App, EndUser, Message -from models.web import SavedMessage -from services.saved_message_service import SavedMessageService - - -class SavedMessageServiceTestDataFactory: - """ - Factory for creating test data and mock objects. - - Provides reusable methods to create consistent mock objects for testing - saved message operations. - """ - - @staticmethod - def create_account_mock(account_id: str = "account-123", **kwargs) -> Mock: - """ - Create a mock Account object. - - Args: - account_id: Unique identifier for the account - **kwargs: Additional attributes to set on the mock - - Returns: - Mock Account object with specified attributes - """ - account = create_autospec(Account, instance=True) - account.id = account_id - for key, value in kwargs.items(): - setattr(account, key, value) - return account - - @staticmethod - def create_end_user_mock(user_id: str = "user-123", **kwargs) -> Mock: - """ - Create a mock EndUser object. - - Args: - user_id: Unique identifier for the end user - **kwargs: Additional attributes to set on the mock - - Returns: - Mock EndUser object with specified attributes - """ - user = create_autospec(EndUser, instance=True) - user.id = user_id - for key, value in kwargs.items(): - setattr(user, key, value) - return user - - @staticmethod - def create_app_mock(app_id: str = "app-123", tenant_id: str = "tenant-123", **kwargs) -> Mock: - """ - Create a mock App object. - - Args: - app_id: Unique identifier for the app - tenant_id: Tenant/workspace identifier - **kwargs: Additional attributes to set on the mock - - Returns: - Mock App object with specified attributes - """ - app = create_autospec(App, instance=True) - app.id = app_id - app.tenant_id = tenant_id - app.name = kwargs.get("name", "Test App") - app.mode = kwargs.get("mode", "chat") - for key, value in kwargs.items(): - setattr(app, key, value) - return app - - @staticmethod - def create_message_mock( - message_id: str = "msg-123", - app_id: str = "app-123", - **kwargs, - ) -> Mock: - """ - Create a mock Message object. - - Args: - message_id: Unique identifier for the message - app_id: Associated app identifier - **kwargs: Additional attributes to set on the mock - - Returns: - Mock Message object with specified attributes - """ - message = create_autospec(Message, instance=True) - message.id = message_id - message.app_id = app_id - message.query = kwargs.get("query", "Test query") - message.answer = kwargs.get("answer", "Test answer") - message.created_at = kwargs.get("created_at", datetime.now(UTC)) - for key, value in kwargs.items(): - setattr(message, key, value) - return message - - @staticmethod - def create_saved_message_mock( - saved_message_id: str = "saved-123", - app_id: str = "app-123", - message_id: str = "msg-123", - created_by: str = "user-123", - created_by_role: str = "account", - **kwargs, - ) -> Mock: - """ - Create a mock SavedMessage object. - - Args: - saved_message_id: Unique identifier for the saved message - app_id: Associated app identifier - message_id: Associated message identifier - created_by: User who saved the message - created_by_role: Role of the user ('account' or 'end_user') - **kwargs: Additional attributes to set on the mock - - Returns: - Mock SavedMessage object with specified attributes - """ - saved_message = create_autospec(SavedMessage, instance=True) - saved_message.id = saved_message_id - saved_message.app_id = app_id - saved_message.message_id = message_id - saved_message.created_by = created_by - saved_message.created_by_role = created_by_role - saved_message.created_at = kwargs.get("created_at", datetime.now(UTC)) - for key, value in kwargs.items(): - setattr(saved_message, key, value) - return saved_message - - -@pytest.fixture -def factory(): - """Provide the test data factory to all tests.""" - return SavedMessageServiceTestDataFactory - - -class TestSavedMessageServicePagination: - """Test saved message pagination operations.""" - - @patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True) - @patch("services.saved_message_service.db.session", autospec=True) - def test_pagination_with_account_user(self, mock_db_session, mock_message_pagination, factory): - """Test pagination with an Account user.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_account_mock() - - # Create saved messages for this user - saved_messages = [ - factory.create_saved_message_mock( - saved_message_id=f"saved-{i}", - app_id=app.id, - message_id=f"msg-{i}", - created_by=user.id, - created_by_role="account", - ) - for i in range(3) - ] - - # Mock database query - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = saved_messages - - # Mock MessageService pagination response - expected_pagination = InfiniteScrollPagination(data=[], limit=20, has_more=False) - mock_message_pagination.return_value = expected_pagination - - # Act - result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=None, limit=20) - - # Assert - assert result == expected_pagination - mock_db_session.query.assert_called_once_with(SavedMessage) - # Verify MessageService was called with correct message IDs - mock_message_pagination.assert_called_once_with( - app_model=app, - user=user, - last_id=None, - limit=20, - include_ids=["msg-0", "msg-1", "msg-2"], - ) - - @patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True) - @patch("services.saved_message_service.db.session", autospec=True) - def test_pagination_with_end_user(self, mock_db_session, mock_message_pagination, factory): - """Test pagination with an EndUser.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_end_user_mock() - - # Create saved messages for this end user - saved_messages = [ - factory.create_saved_message_mock( - saved_message_id=f"saved-{i}", - app_id=app.id, - message_id=f"msg-{i}", - created_by=user.id, - created_by_role="end_user", - ) - for i in range(2) - ] - - # Mock database query - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = saved_messages - - # Mock MessageService pagination response - expected_pagination = InfiniteScrollPagination(data=[], limit=10, has_more=False) - mock_message_pagination.return_value = expected_pagination - - # Act - result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=None, limit=10) - - # Assert - assert result == expected_pagination - # Verify correct role was used in query - mock_message_pagination.assert_called_once_with( - app_model=app, - user=user, - last_id=None, - limit=10, - include_ids=["msg-0", "msg-1"], - ) - - def test_pagination_without_user_raises_error(self, factory): - """Test that pagination without user raises ValueError.""" - # Arrange - app = factory.create_app_mock() - - # Act & Assert - with pytest.raises(ValueError, match="User is required"): - SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=20) - - @patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True) - @patch("services.saved_message_service.db.session", autospec=True) - def test_pagination_with_last_id(self, mock_db_session, mock_message_pagination, factory): - """Test pagination with last_id parameter.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_account_mock() - last_id = "msg-last" - - saved_messages = [ - factory.create_saved_message_mock( - message_id=f"msg-{i}", - app_id=app.id, - created_by=user.id, - ) - for i in range(5) - ] - - # Mock database query - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = saved_messages - - # Mock MessageService pagination response - expected_pagination = InfiniteScrollPagination(data=[], limit=10, has_more=True) - mock_message_pagination.return_value = expected_pagination - - # Act - result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=last_id, limit=10) - - # Assert - assert result == expected_pagination - # Verify last_id was passed to MessageService - mock_message_pagination.assert_called_once() - call_args = mock_message_pagination.call_args - assert call_args.kwargs["last_id"] == last_id - - @patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True) - @patch("services.saved_message_service.db.session", autospec=True) - def test_pagination_with_empty_saved_messages(self, mock_db_session, mock_message_pagination, factory): - """Test pagination when user has no saved messages.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_account_mock() - - # Mock database query returning empty list - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Mock MessageService pagination response - expected_pagination = InfiniteScrollPagination(data=[], limit=20, has_more=False) - mock_message_pagination.return_value = expected_pagination - - # Act - result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=None, limit=20) - - # Assert - assert result == expected_pagination - # Verify MessageService was called with empty include_ids - mock_message_pagination.assert_called_once_with( - app_model=app, - user=user, - last_id=None, - limit=20, - include_ids=[], - ) - - -class TestSavedMessageServiceSave: - """Test save message operations.""" - - @patch("services.saved_message_service.MessageService.get_message", autospec=True) - @patch("services.saved_message_service.db.session", autospec=True) - def test_save_message_for_account(self, mock_db_session, mock_get_message, factory): - """Test saving a message for an Account user.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_account_mock() - message = factory.create_message_mock(message_id="msg-123", app_id=app.id) - - # Mock database query - no existing saved message - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None - - # Mock MessageService.get_message - mock_get_message.return_value = message - - # Act - SavedMessageService.save(app_model=app, user=user, message_id=message.id) - - # Assert - mock_db_session.add.assert_called_once() - saved_message = mock_db_session.add.call_args[0][0] - assert saved_message.app_id == app.id - assert saved_message.message_id == message.id - assert saved_message.created_by == user.id - assert saved_message.created_by_role == "account" - mock_db_session.commit.assert_called_once() - - @patch("services.saved_message_service.MessageService.get_message", autospec=True) - @patch("services.saved_message_service.db.session", autospec=True) - def test_save_message_for_end_user(self, mock_db_session, mock_get_message, factory): - """Test saving a message for an EndUser.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_end_user_mock() - message = factory.create_message_mock(message_id="msg-456", app_id=app.id) - - # Mock database query - no existing saved message - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None - - # Mock MessageService.get_message - mock_get_message.return_value = message - - # Act - SavedMessageService.save(app_model=app, user=user, message_id=message.id) - - # Assert - mock_db_session.add.assert_called_once() - saved_message = mock_db_session.add.call_args[0][0] - assert saved_message.app_id == app.id - assert saved_message.message_id == message.id - assert saved_message.created_by == user.id - assert saved_message.created_by_role == "end_user" - mock_db_session.commit.assert_called_once() - - @patch("services.saved_message_service.db.session", autospec=True) - def test_save_without_user_does_nothing(self, mock_db_session, factory): - """Test that saving without user is a no-op.""" - # Arrange - app = factory.create_app_mock() - - # Act - SavedMessageService.save(app_model=app, user=None, message_id="msg-123") - - # Assert - mock_db_session.query.assert_not_called() - mock_db_session.add.assert_not_called() - mock_db_session.commit.assert_not_called() - - @patch("services.saved_message_service.MessageService.get_message", autospec=True) - @patch("services.saved_message_service.db.session", autospec=True) - def test_save_duplicate_message_is_idempotent(self, mock_db_session, mock_get_message, factory): - """Test that saving an already saved message is idempotent.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_account_mock() - message_id = "msg-789" - - # Mock database query - existing saved message found - existing_saved = factory.create_saved_message_mock( - app_id=app.id, - message_id=message_id, - created_by=user.id, - created_by_role="account", - ) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = existing_saved - - # Act - SavedMessageService.save(app_model=app, user=user, message_id=message_id) - - # Assert - no new saved message created - mock_db_session.add.assert_not_called() - mock_db_session.commit.assert_not_called() - mock_get_message.assert_not_called() - - @patch("services.saved_message_service.MessageService.get_message", autospec=True) - @patch("services.saved_message_service.db.session", autospec=True) - def test_save_validates_message_exists(self, mock_db_session, mock_get_message, factory): - """Test that save validates message exists through MessageService.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_account_mock() - message = factory.create_message_mock() - - # Mock database query - no existing saved message - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None - - # Mock MessageService.get_message - mock_get_message.return_value = message - - # Act - SavedMessageService.save(app_model=app, user=user, message_id=message.id) - - # Assert - MessageService.get_message was called for validation - mock_get_message.assert_called_once_with(app_model=app, user=user, message_id=message.id) - - -class TestSavedMessageServiceDelete: - """Test delete saved message operations.""" - - @patch("services.saved_message_service.db.session", autospec=True) - def test_delete_saved_message_for_account(self, mock_db_session, factory): - """Test deleting a saved message for an Account user.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_account_mock() - message_id = "msg-123" - - # Mock database query - existing saved message found - saved_message = factory.create_saved_message_mock( - app_id=app.id, - message_id=message_id, - created_by=user.id, - created_by_role="account", - ) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = saved_message - - # Act - SavedMessageService.delete(app_model=app, user=user, message_id=message_id) - - # Assert - mock_db_session.delete.assert_called_once_with(saved_message) - mock_db_session.commit.assert_called_once() - - @patch("services.saved_message_service.db.session", autospec=True) - def test_delete_saved_message_for_end_user(self, mock_db_session, factory): - """Test deleting a saved message for an EndUser.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_end_user_mock() - message_id = "msg-456" - - # Mock database query - existing saved message found - saved_message = factory.create_saved_message_mock( - app_id=app.id, - message_id=message_id, - created_by=user.id, - created_by_role="end_user", - ) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = saved_message - - # Act - SavedMessageService.delete(app_model=app, user=user, message_id=message_id) - - # Assert - mock_db_session.delete.assert_called_once_with(saved_message) - mock_db_session.commit.assert_called_once() - - @patch("services.saved_message_service.db.session", autospec=True) - def test_delete_without_user_does_nothing(self, mock_db_session, factory): - """Test that deleting without user is a no-op.""" - # Arrange - app = factory.create_app_mock() - - # Act - SavedMessageService.delete(app_model=app, user=None, message_id="msg-123") - - # Assert - mock_db_session.query.assert_not_called() - mock_db_session.delete.assert_not_called() - mock_db_session.commit.assert_not_called() - - @patch("services.saved_message_service.db.session", autospec=True) - def test_delete_non_existent_saved_message_does_nothing(self, mock_db_session, factory): - """Test that deleting a non-existent saved message is a no-op.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_account_mock() - message_id = "msg-nonexistent" - - # Mock database query - no saved message found - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None - - # Act - SavedMessageService.delete(app_model=app, user=user, message_id=message_id) - - # Assert - no deletion occurred - mock_db_session.delete.assert_not_called() - mock_db_session.commit.assert_not_called() - - @patch("services.saved_message_service.db.session", autospec=True) - def test_delete_only_affects_user_own_saved_messages(self, mock_db_session, factory): - """Test that delete only removes the user's own saved message.""" - # Arrange - app = factory.create_app_mock() - user1 = factory.create_account_mock(account_id="user-1") - message_id = "msg-shared" - - # Mock database query - finds user1's saved message - saved_message = factory.create_saved_message_mock( - app_id=app.id, - message_id=message_id, - created_by=user1.id, - created_by_role="account", - ) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = saved_message - - # Act - SavedMessageService.delete(app_model=app, user=user1, message_id=message_id) - - # Assert - only user1's saved message is deleted - mock_db_session.delete.assert_called_once_with(saved_message) - # Verify the query filters by user - assert mock_query.where.called diff --git a/api/tests/unit_tests/services/test_summary_index_service.py b/api/tests/unit_tests/services/test_summary_index_service.py index be64e431ba7..cbf3e121d8c 100644 --- a/api/tests/unit_tests/services/test_summary_index_service.py +++ b/api/tests/unit_tests/services/test_summary_index_service.py @@ -11,6 +11,7 @@ from unittest.mock import MagicMock import pytest import services.summary_index_service as summary_module +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.enums import SegmentStatus, SummaryStatus from services.summary_index_service import SummaryIndexService @@ -26,7 +27,7 @@ class _SessionContext: return None -def _dataset(*, indexing_technique: str = "high_quality") -> MagicMock: +def _dataset(*, indexing_technique: str = IndexTechniqueType.HIGH_QUALITY) -> MagicMock: dataset = MagicMock(name="dataset") dataset.id = "dataset-1" dataset.tenant_id = "tenant-1" @@ -48,7 +49,7 @@ def _segment(*, has_document: bool = True) -> MagicMock: if has_document: doc = MagicMock(name="document") doc.doc_language = "en" - doc.doc_form = "text_model" + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX segment.document = doc else: segment.document = None @@ -168,7 +169,8 @@ def test_create_summary_record_creates_new(monkeypatch: pytest.MonkeyPatch) -> N def test_vectorize_summary_skips_non_high_quality(monkeypatch: pytest.MonkeyPatch) -> None: vector_cls = MagicMock() monkeypatch.setattr(summary_module, "Vector", vector_cls) - SummaryIndexService.vectorize_summary(_summary_record(), _segment(), _dataset(indexing_technique="economy")) + dataset = _dataset(indexing_technique=IndexTechniqueType.ECONOMY) + SummaryIndexService.vectorize_summary(_summary_record(), _segment(), dataset) vector_cls.assert_not_called() @@ -189,7 +191,7 @@ def test_vectorize_summary_retries_connection_errors_then_succeeds(monkeypatch: embedding_model.get_text_embedding_num_tokens.return_value = [5] model_manager = MagicMock() model_manager.get_model_instance.return_value = embedding_model - monkeypatch.setattr(summary_module, "ModelManager", MagicMock(return_value=model_manager)) + monkeypatch.setattr(summary_module.ModelManager, "for_tenant", MagicMock(return_value=model_manager)) vector_instance = MagicMock() vector_instance.add_texts.side_effect = [RuntimeError("connection timeout"), None] @@ -228,7 +230,7 @@ def test_vectorize_summary_without_session_creates_record_when_missing(monkeypat model_manager = MagicMock() model_manager.get_model_instance.side_effect = RuntimeError("no model") - monkeypatch.setattr(summary_module, "ModelManager", MagicMock(return_value=model_manager)) + monkeypatch.setattr(summary_module.ModelManager, "for_tenant", MagicMock(return_value=model_manager)) # New session used after vectorization succeeds (record not found by id nor chunk_id). session = MagicMock(name="session") @@ -405,8 +407,8 @@ def test_vectorize_summary_updates_existing_record_found_by_chunk_id(monkeypatch vector_instance.add_texts.return_value = None monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) monkeypatch.setattr( - summary_module, - "ModelManager", + summary_module.ModelManager, + "for_tenant", MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))), ) @@ -439,8 +441,8 @@ def test_vectorize_summary_updates_existing_record_found_by_id(monkeypatch: pyte summary_module, "Vector", MagicMock(return_value=MagicMock(add_texts=MagicMock(return_value=None))) ) monkeypatch.setattr( - summary_module, - "ModelManager", + summary_module.ModelManager, + "for_tenant", MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))), ) @@ -472,8 +474,8 @@ def test_vectorize_summary_session_enter_returns_none_triggers_runtime_error(mon summary_module, "Vector", MagicMock(return_value=MagicMock(add_texts=MagicMock(return_value=None))) ) monkeypatch.setattr( - summary_module, - "ModelManager", + summary_module.ModelManager, + "for_tenant", MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))), ) @@ -508,8 +510,8 @@ def test_vectorize_summary_created_record_becomes_none_triggers_guard(monkeypatc summary_module, "Vector", MagicMock(return_value=MagicMock(add_texts=MagicMock(return_value=None))) ) monkeypatch.setattr( - summary_module, - "ModelManager", + summary_module.ModelManager, + "for_tenant", MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))), ) @@ -620,16 +622,16 @@ def test_generate_and_vectorize_summary_creates_missing_record_and_logs_usage(mo def test_generate_summaries_for_document_skip_conditions(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _dataset(indexing_technique="economy") + dataset = _dataset(indexing_technique=IndexTechniqueType.ECONOMY) document = MagicMock(spec=summary_module.DatasetDocument) document.id = "doc-1" - document.doc_form = "text_model" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": True}) == [] dataset = _dataset() assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": False}) == [] - document.doc_form = "qa_model" + document.doc_form = IndexStructureType.QA_INDEX assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": True}) == [] @@ -637,7 +639,7 @@ def test_generate_summaries_for_document_runs_and_handles_errors(monkeypatch: py dataset = _dataset() document = MagicMock(spec=summary_module.DatasetDocument) document.id = "doc-1" - document.doc_form = "text_model" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX seg1 = _segment() seg2 = _segment() @@ -673,7 +675,7 @@ def test_generate_summaries_for_document_no_segments_returns_empty(monkeypatch: dataset = _dataset() document = MagicMock(spec=summary_module.DatasetDocument) document.id = "doc-1" - document.doc_form = "text_model" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX session = MagicMock() query = MagicMock() @@ -696,7 +698,7 @@ def test_generate_summaries_for_document_applies_segment_ids_and_only_parent_chu dataset = _dataset() document = MagicMock(spec=summary_module.DatasetDocument) document.id = "doc-1" - document.doc_form = "text_model" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX seg = _segment() session = MagicMock() @@ -777,7 +779,7 @@ def test_disable_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.Mo def test_enable_summaries_for_segments_skips_non_high_quality() -> None: - SummaryIndexService.enable_summaries_for_segments(_dataset(indexing_technique="economy")) + SummaryIndexService.enable_summaries_for_segments(_dataset(indexing_technique=IndexTechniqueType.ECONOMY)) def test_enable_summaries_for_segments_revectorizes_and_enables(monkeypatch: pytest.MonkeyPatch) -> None: @@ -931,11 +933,10 @@ def test_delete_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.Mon def test_update_summary_for_segment_skip_conditions() -> None: - assert ( - SummaryIndexService.update_summary_for_segment(_segment(), _dataset(indexing_technique="economy"), "x") is None - ) + economy_dataset = _dataset(indexing_technique=IndexTechniqueType.ECONOMY) + assert SummaryIndexService.update_summary_for_segment(_segment(), economy_dataset, "x") is None seg = _segment(has_document=True) - seg.document.doc_form = "qa_model" + seg.document.doc_form = IndexStructureType.QA_INDEX assert SummaryIndexService.update_summary_for_segment(seg, _dataset(), "x") is None diff --git a/api/tests/unit_tests/services/test_tag_service.py b/api/tests/unit_tests/services/test_tag_service.py deleted file mode 100644 index 264eac4d77f..00000000000 --- a/api/tests/unit_tests/services/test_tag_service.py +++ /dev/null @@ -1,1335 +0,0 @@ -""" -Comprehensive unit tests for TagService. - -This test suite provides complete coverage of tag management operations in Dify, -following TDD principles with the Arrange-Act-Assert pattern. - -The TagService is responsible for managing tags that can be associated with -datasets (knowledge bases) and applications. Tags enable users to organize, -categorize, and filter their content effectively. - -## Test Coverage - -### 1. Tag Retrieval (TestTagServiceRetrieval) -Tests tag listing and filtering: -- Get tags with binding counts -- Filter tags by keyword (case-insensitive) -- Get tags by target ID (apps/datasets) -- Get tags by tag name -- Get target IDs by tag IDs -- Empty results handling - -### 2. Tag CRUD Operations (TestTagServiceCRUD) -Tests tag creation, update, and deletion: -- Create new tags -- Prevent duplicate tag names -- Update tag names -- Update with duplicate name validation -- Delete tags and cascade delete bindings -- Get tag binding counts -- NotFound error handling - -### 3. Tag Binding Operations (TestTagServiceBindings) -Tests tag-to-resource associations: -- Save tag bindings (apps/datasets) -- Prevent duplicate bindings (idempotent) -- Delete tag bindings -- Check target exists validation -- Batch binding operations - -## Testing Approach - -- **Mocking Strategy**: All external dependencies (database, current_user) are mocked - for fast, isolated unit tests -- **Factory Pattern**: TagServiceTestDataFactory provides consistent test data -- **Fixtures**: Mock objects are configured per test method -- **Assertions**: Each test verifies return values and side effects - (database operations, method calls) - -## Key Concepts - -**Tag Types:** -- knowledge: Tags for datasets/knowledge bases -- app: Tags for applications - -**Tag Bindings:** -- Many-to-many relationship between tags and resources -- Each binding links a tag to a specific app or dataset -- Bindings are tenant-scoped for multi-tenancy - -**Validation:** -- Tag names must be unique within tenant and type -- Target resources must exist before binding -- Cascade deletion of bindings when tag is deleted -""" - - -# ============================================================================ -# IMPORTS -# ============================================================================ - -from datetime import UTC, datetime -from unittest.mock import MagicMock, Mock, create_autospec, patch - -import pytest -from werkzeug.exceptions import NotFound - -from models.dataset import Dataset -from models.model import App, Tag, TagBinding -from services.tag_service import TagService - -# ============================================================================ -# TEST DATA FACTORY -# ============================================================================ - - -class TagServiceTestDataFactory: - """ - Factory for creating test data and mock objects. - - Provides reusable methods to create consistent mock objects for testing - tag-related operations. This factory ensures all test data follows the - same structure and reduces code duplication across tests. - - The factory pattern is used here to: - - Ensure consistent test data creation - - Reduce boilerplate code in individual tests - - Make tests more maintainable and readable - - Centralize mock object configuration - """ - - @staticmethod - def create_tag_mock( - tag_id: str = "tag-123", - name: str = "Test Tag", - tag_type: str = "app", - tenant_id: str = "tenant-123", - **kwargs, - ) -> Mock: - """ - Create a mock Tag object. - - This method creates a mock Tag instance with all required attributes - set to sensible defaults. Additional attributes can be passed via - kwargs to customize the mock for specific test scenarios. - - Args: - tag_id: Unique identifier for the tag - name: Tag name (e.g., "Frontend", "Backend", "Data Science") - tag_type: Type of tag ('app' or 'knowledge') - tenant_id: Tenant identifier for multi-tenancy isolation - **kwargs: Additional attributes to set on the mock - (e.g., created_by, created_at, etc.) - - Returns: - Mock Tag object with specified attributes - - Example: - >>> tag = factory.create_tag_mock( - ... tag_id="tag-456", - ... name="Machine Learning", - ... tag_type="knowledge" - ... ) - """ - # Create a mock that matches the Tag model interface - tag = create_autospec(Tag, instance=True) - - # Set core attributes - tag.id = tag_id - tag.name = name - tag.type = tag_type - tag.tenant_id = tenant_id - - # Set default optional attributes - tag.created_by = kwargs.pop("created_by", "user-123") - tag.created_at = kwargs.pop("created_at", datetime(2023, 1, 1, 0, 0, 0, tzinfo=UTC)) - - # Apply any additional attributes from kwargs - for key, value in kwargs.items(): - setattr(tag, key, value) - - return tag - - @staticmethod - def create_tag_binding_mock( - binding_id: str = "binding-123", - tag_id: str = "tag-123", - target_id: str = "target-123", - tenant_id: str = "tenant-123", - **kwargs, - ) -> Mock: - """ - Create a mock TagBinding object. - - TagBindings represent the many-to-many relationship between tags - and resources (datasets or apps). This method creates a mock - binding with the necessary attributes. - - Args: - binding_id: Unique identifier for the binding - tag_id: Associated tag identifier - target_id: Associated target (app/dataset) identifier - tenant_id: Tenant identifier for multi-tenancy isolation - **kwargs: Additional attributes to set on the mock - (e.g., created_by, etc.) - - Returns: - Mock TagBinding object with specified attributes - - Example: - >>> binding = factory.create_tag_binding_mock( - ... tag_id="tag-456", - ... target_id="dataset-789", - ... tenant_id="tenant-123" - ... ) - """ - # Create a mock that matches the TagBinding model interface - binding = create_autospec(TagBinding, instance=True) - - # Set core attributes - binding.id = binding_id - binding.tag_id = tag_id - binding.target_id = target_id - binding.tenant_id = tenant_id - - # Set default optional attributes - binding.created_by = kwargs.pop("created_by", "user-123") - - # Apply any additional attributes from kwargs - for key, value in kwargs.items(): - setattr(binding, key, value) - - return binding - - @staticmethod - def create_app_mock(app_id: str = "app-123", tenant_id: str = "tenant-123", **kwargs) -> Mock: - """ - Create a mock App object. - - This method creates a mock App instance for testing tag bindings - to applications. Apps are one of the two target types that tags - can be bound to (the other being datasets/knowledge bases). - - Args: - app_id: Unique identifier for the app - tenant_id: Tenant identifier for multi-tenancy isolation - **kwargs: Additional attributes to set on the mock - - Returns: - Mock App object with specified attributes - - Example: - >>> app = factory.create_app_mock( - ... app_id="app-456", - ... name="My Chat App" - ... ) - """ - # Create a mock that matches the App model interface - app = create_autospec(App, instance=True) - - # Set core attributes - app.id = app_id - app.tenant_id = tenant_id - app.name = kwargs.get("name", "Test App") - - # Apply any additional attributes from kwargs - for key, value in kwargs.items(): - setattr(app, key, value) - - return app - - @staticmethod - def create_dataset_mock(dataset_id: str = "dataset-123", tenant_id: str = "tenant-123", **kwargs) -> Mock: - """ - Create a mock Dataset object. - - This method creates a mock Dataset instance for testing tag bindings - to knowledge bases. Datasets (knowledge bases) are one of the two - target types that tags can be bound to (the other being apps). - - Args: - dataset_id: Unique identifier for the dataset - tenant_id: Tenant identifier for multi-tenancy isolation - **kwargs: Additional attributes to set on the mock - - Returns: - Mock Dataset object with specified attributes - - Example: - >>> dataset = factory.create_dataset_mock( - ... dataset_id="dataset-456", - ... name="My Knowledge Base" - ... ) - """ - # Create a mock that matches the Dataset model interface - dataset = create_autospec(Dataset, instance=True) - - # Set core attributes - dataset.id = dataset_id - dataset.tenant_id = tenant_id - dataset.name = kwargs.pop("name", "Test Dataset") - - # Apply any additional attributes from kwargs - for key, value in kwargs.items(): - setattr(dataset, key, value) - - return dataset - - -# ============================================================================ -# PYTEST FIXTURES -# ============================================================================ - - -@pytest.fixture -def factory(): - """ - Provide the test data factory to all tests. - - This fixture makes the TagServiceTestDataFactory available to all test - methods, allowing them to create consistent mock objects easily. - - Returns: - TagServiceTestDataFactory class - """ - return TagServiceTestDataFactory - - -# ============================================================================ -# TAG RETRIEVAL TESTS -# ============================================================================ - - -class TestTagServiceRetrieval: - """ - Test tag retrieval operations. - - This test class covers all methods related to retrieving and querying - tags from the system. These operations are read-only and do not modify - the database state. - - Methods tested: - - get_tags: Retrieve tags with optional keyword filtering - - get_target_ids_by_tag_ids: Get target IDs (datasets/apps) by tag IDs - - get_tag_by_tag_name: Find tags by exact name match - - get_tags_by_target_id: Get all tags bound to a specific target - """ - - @patch("services.tag_service.db.session", autospec=True) - def test_get_tags_with_binding_counts(self, mock_db_session, factory): - """ - Test retrieving tags with their binding counts. - - This test verifies that the get_tags method correctly retrieves - a list of tags along with the count of how many resources - (datasets/apps) are bound to each tag. - - The method should: - - Query tags filtered by type and tenant - - Include binding counts via a LEFT OUTER JOIN - - Return results ordered by creation date (newest first) - - Expected behavior: - - Returns a list of tuples containing (id, type, name, binding_count) - - Each tag includes its binding count - - Results are ordered by creation date descending - """ - # Arrange - # Set up test parameters - tenant_id = "tenant-123" - tag_type = "app" - - # Mock query results: tuples of (tag_id, type, name, binding_count) - # This simulates the SQL query result with aggregated binding counts - mock_results = [ - ("tag-1", "app", "Frontend", 5), # Frontend tag with 5 bindings - ("tag-2", "app", "Backend", 3), # Backend tag with 3 bindings - ("tag-3", "app", "API", 0), # API tag with no bindings - ] - - # Configure mock database session and query chain - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.outerjoin.return_value = mock_query # LEFT OUTER JOIN with TagBinding - mock_query.where.return_value = mock_query # WHERE clause for filtering - mock_query.group_by.return_value = mock_query # GROUP BY for aggregation - mock_query.order_by.return_value = mock_query # ORDER BY for sorting - mock_query.all.return_value = mock_results # Final result - - # Act - # Execute the method under test - results = TagService.get_tags(tag_type=tag_type, current_tenant_id=tenant_id) - - # Assert - # Verify the results match expectations - assert len(results) == 3, "Should return 3 tags" - - # Verify each tag's data structure - assert results[0] == ("tag-1", "app", "Frontend", 5), "First tag should match" - assert results[1] == ("tag-2", "app", "Backend", 3), "Second tag should match" - assert results[2] == ("tag-3", "app", "API", 0), "Third tag should match" - - # Verify database query was called - mock_db_session.query.assert_called_once() - - @patch("services.tag_service.db.session", autospec=True) - def test_get_tags_with_keyword_filter(self, mock_db_session, factory): - """ - Test retrieving tags filtered by keyword (case-insensitive). - - This test verifies that the get_tags method correctly filters tags - by keyword when a keyword parameter is provided. The filtering - should be case-insensitive and support partial matches. - - The method should: - - Apply an additional WHERE clause when keyword is provided - - Use ILIKE for case-insensitive pattern matching - - Support partial matches (e.g., "data" matches "Database" and "Data Science") - - Expected behavior: - - Returns only tags whose names contain the keyword - - Matching is case-insensitive - - Partial matches are supported - """ - # Arrange - # Set up test parameters - tenant_id = "tenant-123" - tag_type = "knowledge" - keyword = "data" - - # Mock query results filtered by keyword - mock_results = [ - ("tag-1", "knowledge", "Database", 2), - ("tag-2", "knowledge", "Data Science", 4), - ] - - # Configure mock database session and query chain - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.group_by.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = mock_results - - # Act - # Execute the method with keyword filter - results = TagService.get_tags(tag_type=tag_type, current_tenant_id=tenant_id, keyword=keyword) - - # Assert - # Verify filtered results - assert len(results) == 2, "Should return 2 matching tags" - - # Verify keyword filter was applied - # The where() method should be called at least twice: - # 1. Initial WHERE clause for type and tenant - # 2. Additional WHERE clause for keyword filtering - assert mock_query.where.call_count >= 2, "Keyword filter should add WHERE clause" - - @patch("services.tag_service.db.session", autospec=True) - def test_get_target_ids_by_tag_ids(self, mock_db_session, factory): - """ - Test retrieving target IDs by tag IDs. - - This test verifies that the get_target_ids_by_tag_ids method correctly - retrieves all target IDs (dataset/app IDs) that are bound to the - specified tags. This is useful for filtering datasets or apps by tags. - - The method should: - - First validate and filter tags by type and tenant - - Then find all bindings for those tags - - Return the target IDs from those bindings - - Expected behavior: - - Returns a list of target IDs (strings) - - Only includes targets bound to valid tags - - Respects tenant and type filtering - """ - # Arrange - # Set up test parameters - tenant_id = "tenant-123" - tag_type = "app" - tag_ids = ["tag-1", "tag-2"] - - # Create mock tag objects - tags = [ - factory.create_tag_mock(tag_id="tag-1", tenant_id=tenant_id, tag_type=tag_type), - factory.create_tag_mock(tag_id="tag-2", tenant_id=tenant_id, tag_type=tag_type), - ] - - # Mock target IDs that are bound to these tags - target_ids = ["app-1", "app-2", "app-3"] - - # Mock tag query (first scalars call) - mock_scalars_tags = MagicMock() - mock_scalars_tags.all.return_value = tags - - # Mock binding query (second scalars call) - mock_scalars_bindings = MagicMock() - mock_scalars_bindings.all.return_value = target_ids - - # Configure side_effect to return different mocks for each scalars() call - mock_db_session.scalars.side_effect = [mock_scalars_tags, mock_scalars_bindings] - - # Act - # Execute the method under test - results = TagService.get_target_ids_by_tag_ids(tag_type=tag_type, current_tenant_id=tenant_id, tag_ids=tag_ids) - - # Assert - # Verify results match expected target IDs - assert results == target_ids, "Should return all target IDs bound to tags" - - # Verify both queries were executed - assert mock_db_session.scalars.call_count == 2, "Should execute tag query and binding query" - - @patch("services.tag_service.db.session", autospec=True) - def test_get_target_ids_with_empty_tag_ids(self, mock_db_session, factory): - """ - Test that empty tag_ids returns empty list. - - This test verifies the edge case handling when an empty list of - tag IDs is provided. The method should return early without - executing any database queries. - - Expected behavior: - - Returns empty list immediately - - Does not execute any database queries - - Handles empty input gracefully - """ - # Arrange - # Set up test parameters with empty tag IDs - tenant_id = "tenant-123" - tag_type = "app" - - # Act - # Execute the method with empty tag IDs list - results = TagService.get_target_ids_by_tag_ids(tag_type=tag_type, current_tenant_id=tenant_id, tag_ids=[]) - - # Assert - # Verify empty result and no database queries - assert results == [], "Should return empty list for empty input" - mock_db_session.scalars.assert_not_called(), "Should not query database for empty input" - - @patch("services.tag_service.db.session", autospec=True) - def test_get_tag_by_tag_name(self, mock_db_session, factory): - """ - Test retrieving tags by name. - - This test verifies that the get_tag_by_tag_name method correctly - finds tags by their exact name. This is used for duplicate name - checking and tag lookup operations. - - The method should: - - Perform exact name matching (case-sensitive) - - Filter by type and tenant - - Return a list of matching tags (usually 0 or 1) - - Expected behavior: - - Returns list of tags with matching name - - Respects type and tenant filtering - - Returns empty list if no matches found - """ - # Arrange - # Set up test parameters - tenant_id = "tenant-123" - tag_type = "app" - tag_name = "Production" - - # Create mock tag with matching name - tags = [factory.create_tag_mock(name=tag_name, tag_type=tag_type, tenant_id=tenant_id)] - - # Configure mock database session - mock_scalars = MagicMock() - mock_scalars.all.return_value = tags - mock_db_session.scalars.return_value = mock_scalars - - # Act - # Execute the method under test - results = TagService.get_tag_by_tag_name(tag_type=tag_type, current_tenant_id=tenant_id, tag_name=tag_name) - - # Assert - # Verify tag was found - assert len(results) == 1, "Should find exactly one tag" - assert results[0].name == tag_name, "Tag name should match" - - @patch("services.tag_service.db.session", autospec=True) - def test_get_tag_by_tag_name_returns_empty_for_missing_params(self, mock_db_session, factory): - """ - Test that missing tag_type or tag_name returns empty list. - - This test verifies the input validation for the get_tag_by_tag_name - method. When either tag_type or tag_name is empty or missing, - the method should return early without querying the database. - - Expected behavior: - - Returns empty list for empty tag_type - - Returns empty list for empty tag_name - - Does not execute database queries for invalid input - """ - # Arrange - # Set up test parameters - tenant_id = "tenant-123" - - # Act & Assert - # Test with empty tag_type - assert TagService.get_tag_by_tag_name("", tenant_id, "name") == [], "Should return empty for empty type" - - # Test with empty tag_name - assert TagService.get_tag_by_tag_name("app", tenant_id, "") == [], "Should return empty for empty name" - - # Verify no database queries were executed - mock_db_session.scalars.assert_not_called(), "Should not query database for invalid input" - - @patch("services.tag_service.db.session", autospec=True) - def test_get_tags_by_target_id(self, mock_db_session, factory): - """ - Test retrieving tags associated with a specific target. - - This test verifies that the get_tags_by_target_id method correctly - retrieves all tags that are bound to a specific target (dataset or app). - This is useful for displaying tags associated with a resource. - - The method should: - - Join Tag and TagBinding tables - - Filter by target_id, tenant, and type - - Return all tags bound to the target - - Expected behavior: - - Returns list of Tag objects bound to the target - - Respects tenant and type filtering - - Returns empty list if no tags are bound - """ - # Arrange - # Set up test parameters - tenant_id = "tenant-123" - tag_type = "app" - target_id = "app-123" - - # Create mock tags that are bound to the target - tags = [ - factory.create_tag_mock(tag_id="tag-1", name="Frontend"), - factory.create_tag_mock(tag_id="tag-2", name="Production"), - ] - - # Configure mock database session and query chain - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.join.return_value = mock_query # JOIN with TagBinding - mock_query.where.return_value = mock_query # WHERE clause for filtering - mock_query.all.return_value = tags # Final result - - # Act - # Execute the method under test - results = TagService.get_tags_by_target_id(tag_type=tag_type, current_tenant_id=tenant_id, target_id=target_id) - - # Assert - # Verify tags were retrieved - assert len(results) == 2, "Should return 2 tags bound to target" - - # Verify tag names - assert results[0].name == "Frontend", "First tag name should match" - assert results[1].name == "Production", "Second tag name should match" - - -# ============================================================================ -# TAG CRUD OPERATIONS TESTS -# ============================================================================ - - -class TestTagServiceCRUD: - """ - Test tag CRUD operations. - - This test class covers all Create, Read, Update, and Delete operations - for tags. These operations modify the database state and require proper - transaction handling and validation. - - Methods tested: - - save_tags: Create new tags - - update_tags: Update existing tag names - - delete_tag: Delete tags and cascade delete bindings - - get_tag_binding_count: Get count of bindings for a tag - """ - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True) - @patch("services.tag_service.db.session", autospec=True) - @patch("services.tag_service.uuid.uuid4", autospec=True) - def test_save_tags(self, mock_uuid, mock_db_session, mock_get_tag_by_name, mock_current_user, factory): - """ - Test creating a new tag. - - This test verifies that the save_tags method correctly creates a new - tag in the database with all required attributes. The method should - validate uniqueness, generate a UUID, and persist the tag. - - The method should: - - Check for duplicate tag names (via get_tag_by_tag_name) - - Generate a unique UUID for the tag ID - - Set user and tenant information from current_user - - Persist the tag to the database - - Commit the transaction - - Expected behavior: - - Creates tag with correct attributes - - Assigns UUID to tag ID - - Sets created_by from current_user - - Sets tenant_id from current_user - - Commits to database - """ - # Arrange - # Configure mock current_user - mock_current_user.id = "user-123" - mock_current_user.current_tenant_id = "tenant-123" - - # Mock UUID generation - mock_uuid.return_value = "new-tag-id" - - # Mock no existing tag (duplicate check passes) - mock_get_tag_by_name.return_value = [] - - # Prepare tag creation arguments - args = {"name": "New Tag", "type": "app"} - - # Act - # Execute the method under test - result = TagService.save_tags(args) - - # Assert - # Verify tag was added to database session - mock_db_session.add.assert_called_once(), "Should add tag to session" - - # Verify transaction was committed - mock_db_session.commit.assert_called_once(), "Should commit transaction" - - # Verify tag attributes - added_tag = mock_db_session.add.call_args[0][0] - assert added_tag.name == "New Tag", "Tag name should match" - assert added_tag.type == "app", "Tag type should match" - assert added_tag.created_by == "user-123", "Created by should match current user" - assert added_tag.tenant_id == "tenant-123", "Tenant ID should match current tenant" - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True) - def test_save_tags_raises_error_for_duplicate_name(self, mock_get_tag_by_name, mock_current_user, factory): - """ - Test that creating a tag with duplicate name raises ValueError. - - This test verifies that the save_tags method correctly prevents - duplicate tag names within the same tenant and type. Tag names - must be unique per tenant and type combination. - - Expected behavior: - - Raises ValueError when duplicate name is detected - - Error message indicates "Tag name already exists" - - Does not create the tag - """ - # Arrange - # Configure mock current_user - mock_current_user.current_tenant_id = "tenant-123" - - # Mock existing tag with same name (duplicate detected) - existing_tag = factory.create_tag_mock(name="Existing Tag") - mock_get_tag_by_name.return_value = [existing_tag] - - # Prepare tag creation arguments with duplicate name - args = {"name": "Existing Tag", "type": "app"} - - # Act & Assert - # Verify ValueError is raised for duplicate name - with pytest.raises(ValueError, match="Tag name already exists"): - TagService.save_tags(args) - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True) - @patch("services.tag_service.db.session", autospec=True) - def test_update_tags(self, mock_db_session, mock_get_tag_by_name, mock_current_user, factory): - """ - Test updating a tag name. - - This test verifies that the update_tags method correctly updates - an existing tag's name while preserving other attributes. The method - should validate uniqueness of the new name and ensure the tag exists. - - The method should: - - Check for duplicate tag names (excluding the current tag) - - Find the tag by ID - - Update the tag name - - Commit the transaction - - Expected behavior: - - Updates tag name successfully - - Preserves other tag attributes - - Commits to database - """ - # Arrange - # Configure mock current_user - mock_current_user.current_tenant_id = "tenant-123" - - # Mock no duplicate name (update check passes) - mock_get_tag_by_name.return_value = [] - - # Create mock tag to be updated - tag = factory.create_tag_mock(tag_id="tag-123", name="Old Name") - - # Configure mock database session to return the tag - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = tag - - # Prepare update arguments - args = {"name": "New Name", "type": "app"} - - # Act - # Execute the method under test - result = TagService.update_tags(args, tag_id="tag-123") - - # Assert - # Verify tag name was updated - assert tag.name == "New Name", "Tag name should be updated" - - # Verify transaction was committed - mock_db_session.commit.assert_called_once(), "Should commit transaction" - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True) - @patch("services.tag_service.db.session", autospec=True) - def test_update_tags_raises_error_for_duplicate_name( - self, mock_db_session, mock_get_tag_by_name, mock_current_user, factory - ): - """ - Test that updating to a duplicate name raises ValueError. - - This test verifies that the update_tags method correctly prevents - updating a tag to a name that already exists for another tag - within the same tenant and type. - - Expected behavior: - - Raises ValueError when duplicate name is detected - - Error message indicates "Tag name already exists" - - Does not update the tag - """ - # Arrange - # Configure mock current_user - mock_current_user.current_tenant_id = "tenant-123" - - # Mock existing tag with the duplicate name - existing_tag = factory.create_tag_mock(name="Duplicate Name") - mock_get_tag_by_name.return_value = [existing_tag] - - # Prepare update arguments with duplicate name - args = {"name": "Duplicate Name", "type": "app"} - - # Act & Assert - # Verify ValueError is raised for duplicate name - with pytest.raises(ValueError, match="Tag name already exists"): - TagService.update_tags(args, tag_id="tag-123") - - @patch("services.tag_service.db.session", autospec=True) - def test_update_tags_raises_not_found_for_missing_tag(self, mock_db_session, factory): - """ - Test that updating a non-existent tag raises NotFound. - - This test verifies that the update_tags method correctly handles - the case when attempting to update a tag that does not exist. - This prevents silent failures and provides clear error feedback. - - Expected behavior: - - Raises NotFound exception - - Error message indicates "Tag not found" - - Does not attempt to update or commit - """ - # Arrange - # Configure mock database session to return None (tag not found) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None - - # Mock duplicate check and current_user - with patch("services.tag_service.TagService.get_tag_by_tag_name", return_value=[], autospec=True): - with patch("services.tag_service.current_user", autospec=True) as mock_user: - mock_user.current_tenant_id = "tenant-123" - args = {"name": "New Name", "type": "app"} - - # Act & Assert - # Verify NotFound is raised for non-existent tag - with pytest.raises(NotFound, match="Tag not found"): - TagService.update_tags(args, tag_id="nonexistent") - - @patch("services.tag_service.db.session", autospec=True) - def test_get_tag_binding_count(self, mock_db_session, factory): - """ - Test getting the count of bindings for a tag. - - This test verifies that the get_tag_binding_count method correctly - counts how many resources (datasets/apps) are bound to a specific tag. - This is useful for displaying tag usage statistics. - - The method should: - - Query TagBinding table filtered by tag_id - - Return the count of matching bindings - - Expected behavior: - - Returns integer count of bindings - - Returns 0 for tags with no bindings - """ - # Arrange - # Set up test parameters - tag_id = "tag-123" - expected_count = 5 - - # Configure mock database session - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.count.return_value = expected_count - - # Act - # Execute the method under test - result = TagService.get_tag_binding_count(tag_id) - - # Assert - # Verify count matches expectation - assert result == expected_count, "Binding count should match" - - @patch("services.tag_service.db.session", autospec=True) - def test_delete_tag(self, mock_db_session, factory): - """ - Test deleting a tag and its bindings. - - This test verifies that the delete_tag method correctly deletes - a tag along with all its associated bindings (cascade delete). - This ensures data integrity and prevents orphaned bindings. - - The method should: - - Find the tag by ID - - Delete the tag - - Find all bindings for the tag - - Delete all bindings (cascade delete) - - Commit the transaction - - Expected behavior: - - Deletes tag from database - - Deletes all associated bindings - - Commits transaction - """ - # Arrange - # Set up test parameters - tag_id = "tag-123" - - # Create mock tag to be deleted - tag = factory.create_tag_mock(tag_id=tag_id) - - # Create mock bindings that will be cascade deleted - bindings = [factory.create_tag_binding_mock(binding_id=f"binding-{i}", tag_id=tag_id) for i in range(3)] - - # Configure mock database session for tag query - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = tag - - # Configure mock database session for bindings query - mock_scalars = MagicMock() - mock_scalars.all.return_value = bindings - mock_db_session.scalars.return_value = mock_scalars - - # Act - # Execute the method under test - TagService.delete_tag(tag_id) - - # Assert - # Verify tag and bindings were deleted - mock_db_session.delete.assert_called(), "Should call delete method" - - # Verify delete was called 4 times (1 tag + 3 bindings) - assert mock_db_session.delete.call_count == 4, "Should delete tag and all bindings" - - # Verify transaction was committed - mock_db_session.commit.assert_called_once(), "Should commit transaction" - - @patch("services.tag_service.db.session", autospec=True) - def test_delete_tag_raises_not_found(self, mock_db_session, factory): - """ - Test that deleting a non-existent tag raises NotFound. - - This test verifies that the delete_tag method correctly handles - the case when attempting to delete a tag that does not exist. - This prevents silent failures and provides clear error feedback. - - Expected behavior: - - Raises NotFound exception - - Error message indicates "Tag not found" - - Does not attempt to delete or commit - """ - # Arrange - # Configure mock database session to return None (tag not found) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None - - # Act & Assert - # Verify NotFound is raised for non-existent tag - with pytest.raises(NotFound, match="Tag not found"): - TagService.delete_tag("nonexistent") - - -# ============================================================================ -# TAG BINDING OPERATIONS TESTS -# ============================================================================ - - -class TestTagServiceBindings: - """ - Test tag binding operations. - - This test class covers all operations related to binding tags to - resources (datasets and apps). Tag bindings create the many-to-many - relationship between tags and resources. - - Methods tested: - - save_tag_binding: Create bindings between tags and targets - - delete_tag_binding: Remove bindings between tags and targets - - check_target_exists: Validate target (dataset/app) existence - """ - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.TagService.check_target_exists", autospec=True) - @patch("services.tag_service.db.session", autospec=True) - def test_save_tag_binding(self, mock_db_session, mock_check_target, mock_current_user, factory): - """ - Test creating tag bindings. - - This test verifies that the save_tag_binding method correctly - creates bindings between tags and a target resource (dataset or app). - The method supports batch binding of multiple tags to a single target. - - The method should: - - Validate target exists (via check_target_exists) - - Check for existing bindings to avoid duplicates - - Create new bindings for tags that aren't already bound - - Commit the transaction - - Expected behavior: - - Validates target exists - - Creates bindings for each tag in tag_ids - - Skips tags that are already bound (idempotent) - - Commits transaction - """ - # Arrange - # Configure mock current_user - mock_current_user.id = "user-123" - mock_current_user.current_tenant_id = "tenant-123" - - # Configure mock database session (no existing bindings) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None # No existing bindings - - # Prepare binding arguments (batch binding) - args = {"type": "app", "target_id": "app-123", "tag_ids": ["tag-1", "tag-2"]} - - # Act - # Execute the method under test - TagService.save_tag_binding(args) - - # Assert - # Verify target existence was checked - mock_check_target.assert_called_once_with("app", "app-123"), "Should validate target exists" - - # Verify bindings were created (2 bindings for 2 tags) - assert mock_db_session.add.call_count == 2, "Should create 2 bindings" - - # Verify transaction was committed - mock_db_session.commit.assert_called_once(), "Should commit transaction" - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.TagService.check_target_exists", autospec=True) - @patch("services.tag_service.db.session", autospec=True) - def test_save_tag_binding_is_idempotent(self, mock_db_session, mock_check_target, mock_current_user, factory): - """ - Test that saving duplicate bindings is idempotent. - - This test verifies that the save_tag_binding method correctly handles - the case when attempting to create a binding that already exists. - The method should skip existing bindings and not create duplicates, - making the operation idempotent. - - Expected behavior: - - Checks for existing bindings - - Skips tags that are already bound - - Does not create duplicate bindings - - Still commits transaction - """ - # Arrange - # Configure mock current_user - mock_current_user.id = "user-123" - mock_current_user.current_tenant_id = "tenant-123" - - # Mock existing binding (duplicate detected) - existing_binding = factory.create_tag_binding_mock() - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = existing_binding # Binding already exists - - # Prepare binding arguments - args = {"type": "app", "target_id": "app-123", "tag_ids": ["tag-1"]} - - # Act - # Execute the method under test - TagService.save_tag_binding(args) - - # Assert - # Verify no new binding was added (idempotent) - mock_db_session.add.assert_not_called(), "Should not create duplicate binding" - - @patch("services.tag_service.TagService.check_target_exists", autospec=True) - @patch("services.tag_service.db.session", autospec=True) - def test_delete_tag_binding(self, mock_db_session, mock_check_target, factory): - """ - Test deleting a tag binding. - - This test verifies that the delete_tag_binding method correctly - removes a binding between a tag and a target resource. This - operation should be safe even if the binding doesn't exist. - - The method should: - - Validate target exists (via check_target_exists) - - Find the binding by tag_id and target_id - - Delete the binding if it exists - - Commit the transaction - - Expected behavior: - - Validates target exists - - Deletes the binding - - Commits transaction - """ - # Arrange - # Create mock binding to be deleted - binding = factory.create_tag_binding_mock() - - # Configure mock database session - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = binding - - # Prepare delete arguments - args = {"type": "app", "target_id": "app-123", "tag_id": "tag-1"} - - # Act - # Execute the method under test - TagService.delete_tag_binding(args) - - # Assert - # Verify target existence was checked - mock_check_target.assert_called_once_with("app", "app-123"), "Should validate target exists" - - # Verify binding was deleted - mock_db_session.delete.assert_called_once_with(binding), "Should delete the binding" - - # Verify transaction was committed - mock_db_session.commit.assert_called_once(), "Should commit transaction" - - @patch("services.tag_service.TagService.check_target_exists", autospec=True) - @patch("services.tag_service.db.session", autospec=True) - def test_delete_tag_binding_does_nothing_if_not_exists(self, mock_db_session, mock_check_target, factory): - """ - Test that deleting a non-existent binding is a no-op. - - This test verifies that the delete_tag_binding method correctly - handles the case when attempting to delete a binding that doesn't - exist. The method should not raise an error and should not commit - if there's nothing to delete. - - Expected behavior: - - Validates target exists - - Does not raise error for non-existent binding - - Does not call delete or commit if binding doesn't exist - """ - # Arrange - # Configure mock database session (binding not found) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None # Binding doesn't exist - - # Prepare delete arguments - args = {"type": "app", "target_id": "app-123", "tag_id": "tag-1"} - - # Act - # Execute the method under test - TagService.delete_tag_binding(args) - - # Assert - # Verify no delete operation was attempted - mock_db_session.delete.assert_not_called(), "Should not delete if binding doesn't exist" - - # Verify no commit was made (nothing changed) - mock_db_session.commit.assert_not_called(), "Should not commit if nothing to delete" - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.db.session", autospec=True) - def test_check_target_exists_for_dataset(self, mock_db_session, mock_current_user, factory): - """ - Test validating that a dataset target exists. - - This test verifies that the check_target_exists method correctly - validates the existence of a dataset (knowledge base) when the - target type is "knowledge". This validation ensures bindings - are only created for valid resources. - - The method should: - - Query Dataset table filtered by tenant and ID - - Raise NotFound if dataset doesn't exist - - Return normally if dataset exists - - Expected behavior: - - No exception raised when dataset exists - - Database query is executed - """ - # Arrange - # Configure mock current_user - mock_current_user.current_tenant_id = "tenant-123" - - # Create mock dataset - dataset = factory.create_dataset_mock() - - # Configure mock database session - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = dataset # Dataset exists - - # Act - # Execute the method under test - TagService.check_target_exists("knowledge", "dataset-123") - - # Assert - # Verify no exception was raised and query was executed - mock_db_session.query.assert_called_once(), "Should query database for dataset" - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.db.session", autospec=True) - def test_check_target_exists_for_app(self, mock_db_session, mock_current_user, factory): - """ - Test validating that an app target exists. - - This test verifies that the check_target_exists method correctly - validates the existence of an application when the target type is - "app". This validation ensures bindings are only created for valid - resources. - - The method should: - - Query App table filtered by tenant and ID - - Raise NotFound if app doesn't exist - - Return normally if app exists - - Expected behavior: - - No exception raised when app exists - - Database query is executed - """ - # Arrange - # Configure mock current_user - mock_current_user.current_tenant_id = "tenant-123" - - # Create mock app - app = factory.create_app_mock() - - # Configure mock database session - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = app # App exists - - # Act - # Execute the method under test - TagService.check_target_exists("app", "app-123") - - # Assert - # Verify no exception was raised and query was executed - mock_db_session.query.assert_called_once(), "Should query database for app" - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.db.session", autospec=True) - def test_check_target_exists_raises_not_found_for_missing_dataset( - self, mock_db_session, mock_current_user, factory - ): - """ - Test that missing dataset raises NotFound. - - This test verifies that the check_target_exists method correctly - raises a NotFound exception when attempting to validate a dataset - that doesn't exist. This prevents creating bindings for invalid - resources. - - Expected behavior: - - Raises NotFound exception - - Error message indicates "Dataset not found" - """ - # Arrange - # Configure mock current_user - mock_current_user.current_tenant_id = "tenant-123" - - # Configure mock database session (dataset not found) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None # Dataset doesn't exist - - # Act & Assert - # Verify NotFound is raised for non-existent dataset - with pytest.raises(NotFound, match="Dataset not found"): - TagService.check_target_exists("knowledge", "nonexistent") - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.db.session", autospec=True) - def test_check_target_exists_raises_not_found_for_missing_app(self, mock_db_session, mock_current_user, factory): - """ - Test that missing app raises NotFound. - - This test verifies that the check_target_exists method correctly - raises a NotFound exception when attempting to validate an app - that doesn't exist. This prevents creating bindings for invalid - resources. - - Expected behavior: - - Raises NotFound exception - - Error message indicates "App not found" - """ - # Arrange - # Configure mock current_user - mock_current_user.current_tenant_id = "tenant-123" - - # Configure mock database session (app not found) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None # App doesn't exist - - # Act & Assert - # Verify NotFound is raised for non-existent app - with pytest.raises(NotFound, match="App not found"): - TagService.check_target_exists("app", "nonexistent") - - def test_check_target_exists_raises_not_found_for_invalid_type(self, factory): - """ - Test that invalid binding type raises NotFound. - - This test verifies that the check_target_exists method correctly - raises a NotFound exception when an invalid target type is provided. - Only "knowledge" (for datasets) and "app" are valid target types. - - Expected behavior: - - Raises NotFound exception - - Error message indicates "Invalid binding type" - """ - # Act & Assert - # Verify NotFound is raised for invalid target type - with pytest.raises(NotFound, match="Invalid binding type"): - TagService.check_target_exists("invalid_type", "target-123") diff --git a/api/tests/unit_tests/services/test_trigger_provider_service.py b/api/tests/unit_tests/services/test_trigger_provider_service.py new file mode 100644 index 00000000000..81a3b181fdb --- /dev/null +++ b/api/tests/unit_tests/services/test_trigger_provider_service.py @@ -0,0 +1,1249 @@ +from __future__ import annotations + +import contextlib +import json +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from constants import HIDDEN_VALUE +from core.plugin.entities.plugin_daemon import CredentialType +from models.provider_ids import TriggerProviderID +from services.trigger.trigger_provider_service import TriggerProviderService + + +def _patch_redis_lock(mocker: MockerFixture) -> None: + mock_redis = mocker.patch("services.trigger.trigger_provider_service.redis_client") + mock_redis.lock.return_value = contextlib.nullcontext() + + +def _mock_get_trigger_provider(mocker: MockerFixture, provider: object | None) -> None: + mocker.patch( + "services.trigger.trigger_provider_service.TriggerManager.get_trigger_provider", + return_value=provider, + ) + + +def _encrypter_mock( + *, + decrypted: dict | None = None, + encrypted: dict | None = None, + masked: dict | None = None, +) -> MagicMock: + enc = MagicMock() + enc.decrypt.return_value = decrypted or {} + enc.encrypt.return_value = encrypted or {} + enc.mask_credentials.return_value = masked or {} + enc.mask_plugin_credentials.return_value = masked or {} + return enc + + +@pytest.fixture +def provider_id() -> TriggerProviderID: + # Arrange + return TriggerProviderID("langgenius/github/github") + + +@pytest.fixture(autouse=True) +def mock_db_engine(mocker: MockerFixture) -> SimpleNamespace: + # Arrange + mocked_db = SimpleNamespace(engine=object()) + mocker.patch("services.trigger.trigger_provider_service.db", mocked_db) + return mocked_db + + +@pytest.fixture +def mock_session(mocker: MockerFixture) -> MagicMock: + """Mocks the database session context manager used by TriggerProviderService.""" + # Arrange + mock_session_instance = MagicMock() + mock_session_cm = MagicMock() + mock_session_cm.__enter__.return_value = mock_session_instance + mock_session_cm.__exit__.return_value = False + mocker.patch("services.trigger.trigger_provider_service.Session", return_value=mock_session_cm) + return mock_session_instance + + +@pytest.fixture +def provider_controller() -> MagicMock: + # Arrange + controller = MagicMock() + controller.get_credential_schema_config.return_value = [] + controller.get_properties_schema.return_value = [] + controller.get_oauth_client_schema.return_value = [] + controller.plugin_unique_identifier = "langgenius/github:0.0.1" + return controller + + +def test_get_trigger_provider_should_return_api_entity_from_manager( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, +) -> None: + # Arrange + provider = MagicMock() + provider.to_api_entity.return_value = {"provider": "ok"} + _mock_get_trigger_provider(mocker, provider) + + # Act + result = TriggerProviderService.get_trigger_provider("tenant-1", provider_id) + + # Assert + assert result == {"provider": "ok"} + + +def test_list_trigger_providers_should_return_api_entities_from_manager(mocker: MockerFixture) -> None: + # Arrange + provider_a = MagicMock() + provider_b = MagicMock() + provider_a.to_api_entity.return_value = {"id": "a"} + provider_b.to_api_entity.return_value = {"id": "b"} + mocker.patch( + "services.trigger.trigger_provider_service.TriggerManager.list_all_trigger_providers", + return_value=[provider_a, provider_b], + ) + + # Act + result = TriggerProviderService.list_trigger_providers("tenant-1") + + # Assert + assert result == [{"id": "a"}, {"id": "b"}] + + +def test_list_trigger_provider_subscriptions_should_return_empty_list_when_no_subscriptions( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, +) -> None: + # Arrange + query = MagicMock() + query.filter_by.return_value.order_by.return_value.all.return_value = [] + mock_session.query.return_value = query + + # Act + result = TriggerProviderService.list_trigger_provider_subscriptions("tenant-1", provider_id) + + # Assert + assert result == [] + + +def test_list_trigger_provider_subscriptions_should_mask_fields_and_attach_workflow_counts( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + api_sub = SimpleNamespace( + id="sub-1", + credentials={"token": "enc"}, + properties={"hook": "enc"}, + parameters={"event": "push"}, + workflows_in_use=0, + ) + db_sub = SimpleNamespace(to_api_entity=lambda: api_sub) + usage_row = SimpleNamespace(subscription_id="sub-1", app_count=2) + + query_subs = MagicMock() + query_subs.filter_by.return_value.order_by.return_value.all.return_value = [db_sub] + query_usage = MagicMock() + query_usage.filter.return_value.group_by.return_value.all.return_value = [usage_row] + mock_session.query.side_effect = [query_subs, query_usage] + + _mock_get_trigger_provider(mocker, provider_controller) + cred_enc = _encrypter_mock(decrypted={"token": "plain"}, masked={"token": "****"}) + prop_enc = _encrypter_mock(decrypted={"hook": "plain"}, masked={"hook": "****"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_subscription", + return_value=(cred_enc, MagicMock()), + ) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_properties", + return_value=(prop_enc, MagicMock()), + ) + + # Act + result = TriggerProviderService.list_trigger_provider_subscriptions("tenant-1", provider_id) + + # Assert + assert len(result) == 1 + assert result[0].credentials == {"token": "****"} + assert result[0].properties == {"hook": "****"} + assert result[0].workflows_in_use == 2 + + +def test_add_trigger_subscription_should_create_subscription_successfully_for_api_key( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + _patch_redis_lock(mocker) + query_count = MagicMock() + query_count.filter_by.return_value.count.return_value = 0 + query_existing = MagicMock() + query_existing.filter_by.return_value.first.return_value = None + mock_session.query.side_effect = [query_count, query_existing] + + _mock_get_trigger_provider(mocker, provider_controller) + cred_enc = _encrypter_mock(encrypted={"api_key": "enc"}) + prop_enc = _encrypter_mock(encrypted={"project": "enc"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_provider_encrypter", + side_effect=[(cred_enc, MagicMock()), (prop_enc, MagicMock())], + ) + + # Act + result = TriggerProviderService.add_trigger_subscription( + tenant_id="tenant-1", + user_id="user-1", + name="main", + provider_id=provider_id, + endpoint_id="endpoint-1", + credential_type=CredentialType.API_KEY, + parameters={"event": "push"}, + properties={"project": "demo"}, + credentials={"api_key": "plain"}, + ) + + # Assert + assert result["result"] == "success" + mock_session.add.assert_called_once() + mock_session.commit.assert_called_once() + + +def test_add_trigger_subscription_should_store_empty_credentials_for_unauthorized_type( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + _patch_redis_lock(mocker) + query_count = MagicMock() + query_count.filter_by.return_value.count.return_value = 0 + query_existing = MagicMock() + query_existing.filter_by.return_value.first.return_value = None + mock_session.query.side_effect = [query_count, query_existing] + + _mock_get_trigger_provider(mocker, provider_controller) + prop_enc = _encrypter_mock(encrypted={"p": "enc"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_provider_encrypter", + return_value=(prop_enc, MagicMock()), + ) + + # Act + result = TriggerProviderService.add_trigger_subscription( + tenant_id="tenant-1", + user_id="user-1", + name="main", + provider_id=provider_id, + endpoint_id="endpoint-1", + credential_type=CredentialType.UNAUTHORIZED, + parameters={}, + properties={"p": "v"}, + credentials={}, + subscription_id="sub-fixed", + ) + + # Assert + assert result == {"result": "success", "id": "sub-fixed"} + + +def test_add_trigger_subscription_should_raise_error_when_provider_limit_reached( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + _patch_redis_lock(mocker) + query_count = MagicMock() + query_count.filter_by.return_value.count.return_value = TriggerProviderService.__MAX_TRIGGER_PROVIDER_COUNT__ + mock_session.query.return_value = query_count + _mock_get_trigger_provider(mocker, provider_controller) + mock_logger = mocker.patch("services.trigger.trigger_provider_service.logger") + + # Act + Assert + with pytest.raises(ValueError, match="Maximum number of providers"): + TriggerProviderService.add_trigger_subscription( + tenant_id="tenant-1", + user_id="user-1", + name="main", + provider_id=provider_id, + endpoint_id="endpoint-1", + credential_type=CredentialType.API_KEY, + parameters={}, + properties={}, + credentials={}, + ) + mock_logger.exception.assert_called_once() + + +def test_add_trigger_subscription_should_raise_error_when_name_exists( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + _patch_redis_lock(mocker) + query_count = MagicMock() + query_count.filter_by.return_value.count.return_value = 0 + query_existing = MagicMock() + query_existing.filter_by.return_value.first.return_value = object() + mock_session.query.side_effect = [query_count, query_existing] + _mock_get_trigger_provider(mocker, provider_controller) + + # Act + Assert + with pytest.raises(ValueError, match="Credential name 'main' already exists"): + TriggerProviderService.add_trigger_subscription( + tenant_id="tenant-1", + user_id="user-1", + name="main", + provider_id=provider_id, + endpoint_id="endpoint-1", + credential_type=CredentialType.API_KEY, + parameters={}, + properties={}, + credentials={}, + ) + + +def test_update_trigger_subscription_should_raise_error_when_subscription_not_found( + mocker: MockerFixture, + mock_session: MagicMock, +) -> None: + # Arrange + _patch_redis_lock(mocker) + query_sub = MagicMock() + query_sub.filter_by.return_value.first.return_value = None + mock_session.query.return_value = query_sub + + # Act + Assert + with pytest.raises(ValueError, match="not found"): + TriggerProviderService.update_trigger_subscription("tenant-1", "sub-1") + + +def test_update_trigger_subscription_should_raise_error_when_name_conflicts( + mocker: MockerFixture, + mock_session: MagicMock, + provider_controller: MagicMock, +) -> None: + # Arrange + _patch_redis_lock(mocker) + subscription = SimpleNamespace( + id="sub-1", + name="old", + provider_id="langgenius/github/github", + credential_type=CredentialType.API_KEY.value, + ) + query_sub = MagicMock() + query_sub.filter_by.return_value.first.return_value = subscription + query_existing = MagicMock() + query_existing.filter_by.return_value.first.return_value = object() + mock_session.query.side_effect = [query_sub, query_existing] + _mock_get_trigger_provider(mocker, provider_controller) + + # Act + Assert + with pytest.raises(ValueError, match="already exists"): + TriggerProviderService.update_trigger_subscription("tenant-1", "sub-1", name="new-name") + + +def test_update_trigger_subscription_should_update_fields_and_clear_cache( + mocker: MockerFixture, + mock_session: MagicMock, + provider_controller: MagicMock, +) -> None: + # Arrange + _patch_redis_lock(mocker) + subscription = SimpleNamespace( + id="sub-1", + name="old", + tenant_id="tenant-1", + provider_id="langgenius/github/github", + properties={"project": "enc-old"}, + parameters={"event": "old"}, + credentials={"api_key": "enc-old"}, + credential_type=CredentialType.API_KEY.value, + credential_expires_at=0, + expires_at=0, + ) + query_sub = MagicMock() + query_sub.filter_by.return_value.first.return_value = subscription + query_existing = MagicMock() + query_existing.filter_by.return_value.first.return_value = None + mock_session.query.side_effect = [query_sub, query_existing] + + _mock_get_trigger_provider(mocker, provider_controller) + prop_enc = _encrypter_mock(decrypted={"project": "old-value"}, encrypted={"project": "new-value"}) + cred_enc = _encrypter_mock(encrypted={"api_key": "new-key"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_provider_encrypter", + side_effect=[(prop_enc, MagicMock()), (cred_enc, MagicMock())], + ) + mock_delete_cache = mocker.patch("services.trigger.trigger_provider_service.delete_cache_for_subscription") + + # Act + TriggerProviderService.update_trigger_subscription( + tenant_id="tenant-1", + subscription_id="sub-1", + name="new", + properties={"project": HIDDEN_VALUE, "region": "us"}, + parameters={"event": "new"}, + credentials={"api_key": "plain-key"}, + credential_expires_at=100, + expires_at=200, + ) + + # Assert + assert subscription.name == "new" + assert subscription.parameters == {"event": "new"} + assert subscription.credentials == {"api_key": "new-key"} + assert subscription.credential_expires_at == 100 + assert subscription.expires_at == 200 + mock_session.commit.assert_called_once() + mock_delete_cache.assert_called_once() + + +def test_get_subscription_by_id_should_return_none_when_missing(mocker: MockerFixture, mock_session: MagicMock) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + result = TriggerProviderService.get_subscription_by_id("tenant-1", "sub-1") + + # Assert + assert result is None + + +def test_get_subscription_by_id_should_decrypt_credentials_and_properties( + mocker: MockerFixture, + mock_session: MagicMock, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace( + id="sub-1", + tenant_id="tenant-1", + provider_id="langgenius/github/github", + credentials={"token": "enc"}, + properties={"project": "enc"}, + ) + mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + _mock_get_trigger_provider(mocker, provider_controller) + cred_enc = _encrypter_mock(decrypted={"token": "plain"}) + prop_enc = _encrypter_mock(decrypted={"project": "plain"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_subscription", + return_value=(cred_enc, MagicMock()), + ) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_properties", + return_value=(prop_enc, MagicMock()), + ) + + # Act + result = TriggerProviderService.get_subscription_by_id("tenant-1", "sub-1") + + # Assert + assert result is subscription + assert subscription.credentials == {"token": "plain"} + assert subscription.properties == {"project": "plain"} + + +def test_delete_trigger_provider_should_raise_error_when_subscription_missing( + mocker: MockerFixture, + mock_session: MagicMock, +) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="not found"): + TriggerProviderService.delete_trigger_provider(mock_session, "tenant-1", "sub-1") + + +def test_delete_trigger_provider_should_delete_and_clear_cache_even_if_unsubscribe_fails( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace( + id="sub-1", + user_id="user-1", + provider_id=str(provider_id), + credential_type=CredentialType.OAUTH2.value, + credentials={"token": "enc"}, + to_entity=lambda: SimpleNamespace(id="sub-1"), + ) + mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + _mock_get_trigger_provider(mocker, provider_controller) + cred_enc = _encrypter_mock(decrypted={"token": "plain"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_subscription", + return_value=(cred_enc, MagicMock()), + ) + mocker.patch( + "services.trigger.trigger_provider_service.TriggerManager.unsubscribe_trigger", + side_effect=RuntimeError("remote fail"), + ) + mock_delete_cache = mocker.patch("services.trigger.trigger_provider_service.delete_cache_for_subscription") + + # Act + TriggerProviderService.delete_trigger_provider(mock_session, "tenant-1", "sub-1") + + # Assert + mock_session.delete.assert_called_once_with(subscription) + mock_delete_cache.assert_called_once() + + +def test_delete_trigger_provider_should_skip_unsubscribe_for_unauthorized( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace( + id="sub-2", + user_id="user-1", + provider_id=str(provider_id), + credential_type=CredentialType.UNAUTHORIZED.value, + credentials={}, + to_entity=lambda: SimpleNamespace(id="sub-2"), + ) + mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + _mock_get_trigger_provider(mocker, provider_controller) + mock_unsubscribe = mocker.patch("services.trigger.trigger_provider_service.TriggerManager.unsubscribe_trigger") + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_subscription", + return_value=(_encrypter_mock(decrypted={}), MagicMock()), + ) + + # Act + TriggerProviderService.delete_trigger_provider(mock_session, "tenant-1", "sub-2") + + # Assert + mock_unsubscribe.assert_not_called() + mock_session.delete.assert_called_once_with(subscription) + + +def test_refresh_oauth_token_should_raise_error_when_subscription_missing( + mocker: MockerFixture, mock_session: MagicMock +) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="not found"): + TriggerProviderService.refresh_oauth_token("tenant-1", "sub-1") + + +def test_refresh_oauth_token_should_raise_error_for_non_oauth_credentials( + mocker: MockerFixture, mock_session: MagicMock +) -> None: + # Arrange + subscription = SimpleNamespace(credential_type=CredentialType.API_KEY.value) + mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + + # Act + Assert + with pytest.raises(ValueError, match="Only OAuth credentials can be refreshed"): + TriggerProviderService.refresh_oauth_token("tenant-1", "sub-1") + + +def test_refresh_oauth_token_should_refresh_and_persist_new_credentials( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace( + provider_id=str(provider_id), + user_id="user-1", + credential_type=CredentialType.OAUTH2.value, + credentials={"access_token": "enc"}, + credential_expires_at=0, + ) + mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + _mock_get_trigger_provider(mocker, provider_controller) + cache = MagicMock() + cred_enc = _encrypter_mock(decrypted={"access_token": "old"}, encrypted={"access_token": "new"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_provider_encrypter", + return_value=(cred_enc, cache), + ) + mocker.patch.object(TriggerProviderService, "get_oauth_client", return_value={"client_id": "id"}) + refreshed = SimpleNamespace(credentials={"access_token": "new"}, expires_at=12345) + oauth_handler = MagicMock() + oauth_handler.refresh_credentials.return_value = refreshed + mocker.patch("services.trigger.trigger_provider_service.OAuthHandler", return_value=oauth_handler) + + # Act + result = TriggerProviderService.refresh_oauth_token("tenant-1", "sub-1") + + # Assert + assert result == {"result": "success", "expires_at": 12345} + assert subscription.credentials == {"access_token": "new"} + assert subscription.credential_expires_at == 12345 + mock_session.commit.assert_called_once() + cache.delete.assert_called_once() + + +def test_refresh_subscription_should_raise_error_when_subscription_missing( + mocker: MockerFixture, mock_session: MagicMock +) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="not found"): + TriggerProviderService.refresh_subscription("tenant-1", "sub-1", now=100) + + +def test_refresh_subscription_should_skip_when_not_due(mocker: MockerFixture, mock_session: MagicMock) -> None: + # Arrange + subscription = SimpleNamespace(expires_at=200) + mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + + # Act + result = TriggerProviderService.refresh_subscription("tenant-1", "sub-1", now=100) + + # Assert + assert result == {"result": "skipped", "expires_at": 200} + + +def test_refresh_subscription_should_refresh_and_persist_properties( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace( + id="sub-1", + tenant_id="tenant-1", + endpoint_id="endpoint-1", + expires_at=50, + provider_id=str(provider_id), + parameters={"event": "push"}, + properties={"p": "enc"}, + credentials={"c": "enc"}, + credential_type=CredentialType.API_KEY.value, + ) + mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + _mock_get_trigger_provider(mocker, provider_controller) + cred_enc = _encrypter_mock(decrypted={"c": "plain"}) + prop_cache = MagicMock() + prop_enc = _encrypter_mock(decrypted={"p": "plain"}, encrypted={"p": "new-enc"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_subscription", + return_value=(cred_enc, MagicMock()), + ) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_properties", + return_value=(prop_enc, prop_cache), + ) + mocker.patch( + "services.trigger.trigger_provider_service.generate_plugin_trigger_endpoint_url", + return_value="https://endpoint", + ) + provider_controller.refresh_trigger.return_value = SimpleNamespace(properties={"p": "new"}, expires_at=999) + + # Act + result = TriggerProviderService.refresh_subscription("tenant-1", "sub-1", now=100) + + # Assert + assert result == {"result": "success", "expires_at": 999} + assert subscription.properties == {"p": "new-enc"} + assert subscription.expires_at == 999 + mock_session.commit.assert_called_once() + prop_cache.delete.assert_called_once() + + +def test_get_oauth_client_should_return_tenant_client_when_available( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + tenant_client = SimpleNamespace(oauth_params={"client_id": "enc"}) + system_client = None + query_tenant = MagicMock() + query_tenant.filter_by.return_value.first.return_value = tenant_client + mock_session.query.return_value = query_tenant + _mock_get_trigger_provider(mocker, provider_controller) + enc = _encrypter_mock(decrypted={"client_id": "plain"}) + mocker.patch("services.trigger.trigger_provider_service.create_provider_encrypter", return_value=(enc, MagicMock())) + + # Act + result = TriggerProviderService.get_oauth_client("tenant-1", provider_id) + + # Assert + assert result == {"client_id": "plain"} + + +def test_get_oauth_client_should_return_none_when_plugin_not_verified( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + query_tenant = MagicMock() + query_tenant.filter_by.return_value.first.return_value = None + query_system = MagicMock() + query_system.filter_by.return_value.first.return_value = None + mock_session.query.side_effect = [query_tenant, query_system] + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=False) + + # Act + result = TriggerProviderService.get_oauth_client("tenant-1", provider_id) + + # Assert + assert result is None + + +def test_get_oauth_client_should_return_decrypted_system_client_when_verified( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + query_tenant = MagicMock() + query_tenant.filter_by.return_value.first.return_value = None + query_system = MagicMock() + query_system.filter_by.return_value.first.return_value = SimpleNamespace(encrypted_oauth_params="enc") + mock_session.query.side_effect = [query_tenant, query_system] + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True) + mocker.patch( + "services.trigger.trigger_provider_service.decrypt_system_oauth_params", + return_value={"client_id": "system"}, + ) + + # Act + result = TriggerProviderService.get_oauth_client("tenant-1", provider_id) + + # Assert + assert result == {"client_id": "system"} + + +def test_get_oauth_client_should_raise_error_when_system_decryption_fails( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + query_tenant = MagicMock() + query_tenant.filter_by.return_value.first.return_value = None + query_system = MagicMock() + query_system.filter_by.return_value.first.return_value = SimpleNamespace(encrypted_oauth_params="enc") + mock_session.query.side_effect = [query_tenant, query_system] + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True) + mocker.patch( + "services.trigger.trigger_provider_service.decrypt_system_oauth_params", + side_effect=RuntimeError("bad data"), + ) + + # Act + Assert + with pytest.raises(ValueError, match="Error decrypting system oauth params"): + TriggerProviderService.get_oauth_client("tenant-1", provider_id) + + +def test_is_oauth_system_client_exists_should_return_false_when_unverified( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=False) + + # Act + result = TriggerProviderService.is_oauth_system_client_exists("tenant-1", provider_id) + + # Assert + assert result is False + + +@pytest.mark.parametrize("has_client", [True, False]) +def test_is_oauth_system_client_exists_should_reflect_database_record( + has_client: bool, + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.first.return_value = object() if has_client else None + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True) + + # Act + result = TriggerProviderService.is_oauth_system_client_exists("tenant-1", provider_id) + + # Assert + assert result is has_client + + +def test_save_custom_oauth_client_params_should_return_success_when_nothing_to_update( + provider_id: TriggerProviderID, +) -> None: + # Arrange + # Act + result = TriggerProviderService.save_custom_oauth_client_params("tenant-1", provider_id, None, None) + + # Assert + assert result == {"result": "success"} + + +def test_save_custom_oauth_client_params_should_create_record_and_clear_params_when_client_params_none( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + query = MagicMock() + query.filter_by.return_value.first.return_value = None + mock_session.query.return_value = query + _mock_get_trigger_provider(mocker, provider_controller) + fake_model = SimpleNamespace(encrypted_oauth_params="", enabled=False, oauth_params={}) + mocker.patch("services.trigger.trigger_provider_service.TriggerOAuthTenantClient", return_value=fake_model) + + # Act + result = TriggerProviderService.save_custom_oauth_client_params( + tenant_id="tenant-1", + provider_id=provider_id, + client_params=None, + enabled=True, + ) + + # Assert + assert result == {"result": "success"} + assert fake_model.encrypted_oauth_params == "{}" + assert fake_model.enabled is True + mock_session.add.assert_called_once_with(fake_model) + mock_session.commit.assert_called_once() + + +def test_save_custom_oauth_client_params_should_merge_hidden_values_and_delete_cache( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + custom_client = SimpleNamespace(oauth_params={"client_id": "enc-old"}, enabled=False) + mock_session.query.return_value.filter_by.return_value.first.return_value = custom_client + _mock_get_trigger_provider(mocker, provider_controller) + cache = MagicMock() + enc = _encrypter_mock(decrypted={"client_id": "old-id"}, encrypted={"client_id": "new-id"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_provider_encrypter", + return_value=(enc, cache), + ) + + # Act + result = TriggerProviderService.save_custom_oauth_client_params( + tenant_id="tenant-1", + provider_id=provider_id, + client_params={"client_id": HIDDEN_VALUE, "client_secret": "new"}, + enabled=None, + ) + + # Assert + assert result == {"result": "success"} + assert json.loads(custom_client.encrypted_oauth_params) == {"client_id": "new-id"} + cache.delete.assert_called_once() + mock_session.commit.assert_called_once() + + +def test_get_custom_oauth_client_params_should_return_empty_when_record_missing( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, +) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + result = TriggerProviderService.get_custom_oauth_client_params("tenant-1", provider_id) + + # Assert + assert result == {} + + +def test_get_custom_oauth_client_params_should_return_masked_decrypted_values( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + custom_client = SimpleNamespace(oauth_params={"client_id": "enc"}) + mock_session.query.return_value.filter_by.return_value.first.return_value = custom_client + _mock_get_trigger_provider(mocker, provider_controller) + enc = _encrypter_mock(decrypted={"client_id": "plain"}, masked={"client_id": "pl***id"}) + mocker.patch("services.trigger.trigger_provider_service.create_provider_encrypter", return_value=(enc, MagicMock())) + + # Act + result = TriggerProviderService.get_custom_oauth_client_params("tenant-1", provider_id) + + # Assert + assert result == {"client_id": "pl***id"} + + +def test_delete_custom_oauth_client_params_should_delete_record_and_commit( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, +) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.delete.return_value = 1 + + # Act + result = TriggerProviderService.delete_custom_oauth_client_params("tenant-1", provider_id) + + # Assert + assert result == {"result": "success"} + mock_session.commit.assert_called_once() + + +@pytest.mark.parametrize("exists", [True, False]) +def test_is_oauth_custom_client_enabled_should_return_expected_boolean( + exists: bool, + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, +) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.first.return_value = object() if exists else None + + # Act + result = TriggerProviderService.is_oauth_custom_client_enabled("tenant-1", provider_id) + + # Assert + assert result is exists + + +def test_get_subscription_by_endpoint_should_return_none_when_not_found( + mocker: MockerFixture, mock_session: MagicMock +) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + result = TriggerProviderService.get_subscription_by_endpoint("endpoint-1") + + # Assert + assert result is None + + +def test_get_subscription_by_endpoint_should_decrypt_credentials_and_properties( + mocker: MockerFixture, + mock_session: MagicMock, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace( + tenant_id="tenant-1", + provider_id="langgenius/github/github", + credentials={"token": "enc"}, + properties={"hook": "enc"}, + ) + mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_subscription", + return_value=(_encrypter_mock(decrypted={"token": "plain"}), MagicMock()), + ) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_properties", + return_value=(_encrypter_mock(decrypted={"hook": "plain"}), MagicMock()), + ) + + # Act + result = TriggerProviderService.get_subscription_by_endpoint("endpoint-1") + + # Assert + assert result is subscription + assert subscription.credentials == {"token": "plain"} + assert subscription.properties == {"hook": "plain"} + + +def test_verify_subscription_credentials_should_raise_when_provider_not_found( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, +) -> None: + # Arrange + _mock_get_trigger_provider(mocker, None) + + # Act + Assert + with pytest.raises(ValueError, match="Provider .* not found"): + TriggerProviderService.verify_subscription_credentials( + tenant_id="tenant-1", + user_id="user-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={}, + ) + + +def test_verify_subscription_credentials_should_raise_when_subscription_not_found( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch.object(TriggerProviderService, "get_subscription_by_id", return_value=None) + + # Act + Assert + with pytest.raises(ValueError, match="Subscription sub-1 not found"): + TriggerProviderService.verify_subscription_credentials( + tenant_id="tenant-1", + user_id="user-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={}, + ) + + +def test_verify_subscription_credentials_should_raise_when_api_key_validation_fails( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace(credential_type=CredentialType.API_KEY.value, credentials={"api_key": "old"}) + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch.object(TriggerProviderService, "get_subscription_by_id", return_value=subscription) + provider_controller.validate_credentials.side_effect = RuntimeError("bad credentials") + + # Act + Assert + with pytest.raises(ValueError, match="Invalid credentials: bad credentials"): + TriggerProviderService.verify_subscription_credentials( + tenant_id="tenant-1", + user_id="user-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={"api_key": HIDDEN_VALUE}, + ) + + +def test_verify_subscription_credentials_should_return_verified_when_api_key_validation_succeeds( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace(credential_type=CredentialType.API_KEY.value, credentials={"api_key": "old"}) + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch.object(TriggerProviderService, "get_subscription_by_id", return_value=subscription) + + # Act + result = TriggerProviderService.verify_subscription_credentials( + tenant_id="tenant-1", + user_id="user-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={"api_key": HIDDEN_VALUE}, + ) + + # Assert + assert result == {"verified": True} + + +def test_verify_subscription_credentials_should_return_verified_for_non_api_key_credentials( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace(credential_type=CredentialType.OAUTH2.value, credentials={}) + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch.object(TriggerProviderService, "get_subscription_by_id", return_value=subscription) + + # Act + result = TriggerProviderService.verify_subscription_credentials( + tenant_id="tenant-1", + user_id="user-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={}, + ) + + # Assert + assert result == {"verified": True} + + +def test_rebuild_trigger_subscription_should_raise_when_provider_not_found( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, +) -> None: + # Arrange + _mock_get_trigger_provider(mocker, None) + + # Act + Assert + with pytest.raises(ValueError, match="Provider .* not found"): + TriggerProviderService.rebuild_trigger_subscription( + tenant_id="tenant-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={}, + parameters={}, + ) + + +def test_rebuild_trigger_subscription_should_raise_when_subscription_not_found( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch.object(TriggerProviderService, "get_subscription_by_id", return_value=None) + + # Act + Assert + with pytest.raises(ValueError, match="Subscription sub-1 not found"): + TriggerProviderService.rebuild_trigger_subscription( + tenant_id="tenant-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={}, + parameters={}, + ) + + +def test_rebuild_trigger_subscription_should_raise_for_unsupported_credential_type( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace(credential_type=CredentialType.UNAUTHORIZED.value) + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch.object(TriggerProviderService, "get_subscription_by_id", return_value=subscription) + + # Act + Assert + with pytest.raises(ValueError, match="not supported for auto creation"): + TriggerProviderService.rebuild_trigger_subscription( + tenant_id="tenant-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={}, + parameters={}, + ) + + +def test_rebuild_trigger_subscription_should_raise_when_unsubscribe_fails( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace( + id="sub-1", + user_id="user-1", + endpoint_id="endpoint-1", + credential_type=CredentialType.API_KEY.value, + credentials={"api_key": "old"}, + to_entity=lambda: SimpleNamespace(id="sub-1"), + ) + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch.object(TriggerProviderService, "get_subscription_by_id", return_value=subscription) + mocker.patch( + "services.trigger.trigger_provider_service.TriggerManager.unsubscribe_trigger", + return_value=SimpleNamespace(success=False, message="remote error"), + ) + + # Act + Assert + with pytest.raises(ValueError, match="Failed to delete previous subscription"): + TriggerProviderService.rebuild_trigger_subscription( + tenant_id="tenant-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={}, + parameters={}, + ) + + +def test_rebuild_trigger_subscription_should_resubscribe_and_update_existing_subscription( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace( + id="sub-1", + user_id="user-1", + endpoint_id="endpoint-1", + credential_type=CredentialType.API_KEY.value, + credentials={"api_key": "old-key"}, + to_entity=lambda: SimpleNamespace(id="sub-1"), + ) + new_subscription = SimpleNamespace(properties={"project": "new"}, expires_at=888) + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch.object(TriggerProviderService, "get_subscription_by_id", return_value=subscription) + mocker.patch( + "services.trigger.trigger_provider_service.TriggerManager.unsubscribe_trigger", + return_value=SimpleNamespace(success=True, message="ok"), + ) + mock_subscribe = mocker.patch( + "services.trigger.trigger_provider_service.TriggerManager.subscribe_trigger", + return_value=new_subscription, + ) + mocker.patch( + "services.trigger.trigger_provider_service.generate_plugin_trigger_endpoint_url", + return_value="https://endpoint", + ) + mock_update = mocker.patch.object(TriggerProviderService, "update_trigger_subscription") + + # Act + TriggerProviderService.rebuild_trigger_subscription( + tenant_id="tenant-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={"api_key": HIDDEN_VALUE, "region": "us"}, + parameters={"event": "push"}, + name="updated", + ) + + # Assert + call_kwargs = mock_subscribe.call_args.kwargs + assert call_kwargs["credentials"]["api_key"] == "old-key" + assert call_kwargs["credentials"]["region"] == "us" + mock_update.assert_called_once_with( + tenant_id="tenant-1", + subscription_id="sub-1", + name="updated", + parameters={"event": "push"}, + credentials={"api_key": "old-key", "region": "us"}, + properties={"project": "new"}, + expires_at=888, + ) diff --git a/api/tests/unit_tests/services/test_variable_truncator.py b/api/tests/unit_tests/services/test_variable_truncator.py index c703ab64d06..9c231352256 100644 --- a/api/tests/unit_tests/services/test_variable_truncator.py +++ b/api/tests/unit_tests/services/test_variable_truncator.py @@ -16,10 +16,8 @@ from typing import Any from uuid import uuid4 import pytest - -from dify_graph.file.enums import FileTransferMethod, FileType -from dify_graph.file.models import File -from dify_graph.variables.segments import ( +from graphon.file import File, FileTransferMethod, FileType +from graphon.variables.segments import ( ArrayFileSegment, ArrayNumberSegment, ArraySegment, @@ -30,6 +28,7 @@ from dify_graph.variables.segments import ( ObjectSegment, StringSegment, ) + from services.variable_truncator import ( DummyVariableTruncator, MaxDepthExceededError, diff --git a/api/tests/unit_tests/services/test_vector_service.py b/api/tests/unit_tests/services/test_vector_service.py index 7b0103a2a14..598ff3fc3a4 100644 --- a/api/tests/unit_tests/services/test_vector_service.py +++ b/api/tests/unit_tests/services/test_vector_service.py @@ -9,6 +9,7 @@ from unittest.mock import MagicMock import pytest import services.vector_service as vector_service_module +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from services.vector_service import VectorService @@ -31,8 +32,8 @@ class _ParentDocStub: def _make_dataset( *, - indexing_technique: str = "high_quality", - doc_form: str = "text_model", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, tenant_id: str = "tenant-1", dataset_id: str = "dataset-1", is_multimodal: bool = False, @@ -106,7 +107,7 @@ def test_create_segments_vector_regular_indexing_loads_documents_and_keywords(mo factory_instance.init_index_processor.return_value = index_processor monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) - VectorService.create_segments_vector([["k1"]], [segment], dataset, "text_model") + VectorService.create_segments_vector([["k1"]], [segment], dataset, IndexStructureType.PARAGRAPH_INDEX) index_processor.load.assert_called_once() args, kwargs = index_processor.load.call_args @@ -131,7 +132,7 @@ def test_create_segments_vector_regular_indexing_loads_multimodal_documents(monk factory_instance.init_index_processor.return_value = index_processor monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) - VectorService.create_segments_vector([["k1"]], [segment], dataset, "text_model") + VectorService.create_segments_vector([["k1"]], [segment], dataset, IndexStructureType.PARAGRAPH_INDEX) assert index_processor.load.call_count == 2 first_args, first_kwargs = index_processor.load.call_args_list[0] @@ -153,7 +154,7 @@ def test_create_segments_vector_with_no_segments_does_not_load(monkeypatch: pyte factory_instance.init_index_processor.return_value = index_processor monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) - VectorService.create_segments_vector(None, [], dataset, "text_model") + VectorService.create_segments_vector(None, [], dataset, IndexStructureType.PARAGRAPH_INDEX) index_processor.load.assert_not_called() @@ -191,7 +192,7 @@ def test_create_segments_vector_parent_child_calls_generate_child_chunks_with_ex dataset = _make_dataset( doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX, embedding_model_provider="openai", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) segment = _make_segment() @@ -213,7 +214,9 @@ def test_create_segments_vector_parent_child_calls_generate_child_chunks_with_ex embedding_model_instance = MagicMock(name="embedding_model_instance") model_manager_instance = MagicMock(name="model_manager_instance") model_manager_instance.get_model_instance.return_value = embedding_model_instance - monkeypatch.setattr(vector_service_module, "ModelManager", MagicMock(return_value=model_manager_instance)) + monkeypatch.setattr( + vector_service_module.ModelManager, "for_tenant", MagicMock(return_value=model_manager_instance) + ) generate_child_chunks_mock = MagicMock() monkeypatch.setattr(VectorService, "generate_child_chunks", generate_child_chunks_mock) @@ -240,7 +243,7 @@ def test_create_segments_vector_parent_child_uses_default_embedding_model_when_p dataset = _make_dataset( doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX, embedding_model_provider=None, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) segment = _make_segment() @@ -261,7 +264,9 @@ def test_create_segments_vector_parent_child_uses_default_embedding_model_when_p embedding_model_instance = MagicMock() model_manager_instance = MagicMock() model_manager_instance.get_default_model_instance.return_value = embedding_model_instance - monkeypatch.setattr(vector_service_module, "ModelManager", MagicMock(return_value=model_manager_instance)) + monkeypatch.setattr( + vector_service_module.ModelManager, "for_tenant", MagicMock(return_value=model_manager_instance) + ) generate_child_chunks_mock = MagicMock() monkeypatch.setattr(VectorService, "generate_child_chunks", generate_child_chunks_mock) @@ -328,7 +333,7 @@ def test_create_segments_vector_parent_child_missing_processing_rule_raises(monk def test_create_segments_vector_parent_child_non_high_quality_raises(monkeypatch: pytest.MonkeyPatch) -> None: dataset = _make_dataset( doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX, - indexing_technique="economy", + indexing_technique=IndexTechniqueType.ECONOMY, ) segment = _make_segment() dataset_document = MagicMock() @@ -347,7 +352,7 @@ def test_create_segments_vector_parent_child_non_high_quality_raises(monkeypatch def test_update_segment_vector_high_quality_uses_vector(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="high_quality") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY) segment = _make_segment() vector_instance = MagicMock() @@ -363,7 +368,7 @@ def test_update_segment_vector_high_quality_uses_vector(monkeypatch: pytest.Monk def test_update_segment_vector_economy_uses_keyword_with_keywords_list(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="economy") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY) segment = _make_segment() keyword_instance = MagicMock() @@ -379,7 +384,7 @@ def test_update_segment_vector_economy_uses_keyword_with_keywords_list(monkeypat def test_update_segment_vector_economy_uses_keyword_without_keywords_list(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="economy") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY) segment = _make_segment() keyword_instance = MagicMock() @@ -392,7 +397,7 @@ def test_update_segment_vector_economy_uses_keyword_without_keywords_list(monkey def test_generate_child_chunks_regenerate_cleans_then_saves_children(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(doc_form="text_model", tenant_id="tenant-1", dataset_id="dataset-1") + dataset = _make_dataset(doc_form=IndexStructureType.PARAGRAPH_INDEX, tenant_id="tenant-1", dataset_id="dataset-1") segment = _make_segment(segment_id="seg-1") dataset_document = MagicMock() @@ -439,7 +444,7 @@ def test_generate_child_chunks_regenerate_cleans_then_saves_children(monkeypatch def test_generate_child_chunks_commits_even_when_no_children(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(doc_form="text_model") + dataset = _make_dataset(doc_form=IndexStructureType.PARAGRAPH_INDEX) segment = _make_segment() dataset_document = MagicMock() dataset_document.doc_language = "en" @@ -472,7 +477,7 @@ def test_generate_child_chunks_commits_even_when_no_children(monkeypatch: pytest def test_create_child_chunk_vector_high_quality_adds_texts(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="high_quality") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY) child_chunk = MagicMock() child_chunk.content = "child" child_chunk.index_node_id = "id" @@ -488,7 +493,7 @@ def test_create_child_chunk_vector_high_quality_adds_texts(monkeypatch: pytest.M def test_create_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="economy") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY) vector_cls = MagicMock() monkeypatch.setattr(vector_service_module, "Vector", vector_cls) @@ -504,7 +509,7 @@ def test_create_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch) def test_update_child_chunk_vector_high_quality_updates_vector(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="high_quality") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY) new_chunk = MagicMock() new_chunk.content = "n" @@ -535,7 +540,7 @@ def test_update_child_chunk_vector_high_quality_updates_vector(monkeypatch: pyte def test_update_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="economy") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY) vector_cls = MagicMock() monkeypatch.setattr(vector_service_module, "Vector", vector_cls) VectorService.update_child_chunk_vector([], [], [], dataset) @@ -560,7 +565,7 @@ def test_delete_child_chunk_vector_deletes_by_id(monkeypatch: pytest.MonkeyPatch def test_update_multimodel_vector_returns_when_not_high_quality(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="economy", is_multimodal=True) + dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY, is_multimodal=True) segment = _make_segment(tenant_id="t", attachments=[{"id": "a"}]) vector_cls = MagicMock() @@ -574,7 +579,7 @@ def test_update_multimodel_vector_returns_when_not_high_quality(monkeypatch: pyt def test_update_multimodel_vector_returns_when_no_actual_change(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True) segment = _make_segment(tenant_id="t", attachments=[{"id": "a"}, {"id": "b"}]) vector_cls = MagicMock() @@ -590,7 +595,7 @@ def test_update_multimodel_vector_returns_when_no_actual_change(monkeypatch: pyt def test_update_multimodel_vector_deletes_bindings_and_commits_on_empty_new_ids( monkeypatch: pytest.MonkeyPatch, ) -> None: - dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True) segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}, {"id": "old-2"}]) vector_instance = MagicMock(name="vector_instance") @@ -611,7 +616,7 @@ def test_update_multimodel_vector_deletes_bindings_and_commits_on_empty_new_ids( def test_update_multimodel_vector_commits_when_no_upload_files_found(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True) segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}]) vector_instance = MagicMock() @@ -629,7 +634,7 @@ def test_update_multimodel_vector_commits_when_no_upload_files_found(monkeypatch def test_update_multimodel_vector_adds_bindings_and_vectors_and_skips_missing_upload_files( monkeypatch: pytest.MonkeyPatch, ) -> None: - dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True) segment = _make_segment(segment_id="seg-1", tenant_id="tenant-1", attachments=[{"id": "old-1"}]) vector_instance = MagicMock() @@ -662,7 +667,7 @@ def test_update_multimodel_vector_adds_bindings_and_vectors_and_skips_missing_up def test_update_multimodel_vector_updates_bindings_without_multimodal_vector_ops( monkeypatch: pytest.MonkeyPatch, ) -> None: - dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=False) + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=False) segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}]) vector_instance = MagicMock() @@ -682,7 +687,7 @@ def test_update_multimodel_vector_updates_bindings_without_multimodal_vector_ops def test_update_multimodel_vector_rolls_back_and_reraises_on_error(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True) segment = _make_segment(segment_id="seg-1", tenant_id="tenant-1", attachments=[{"id": "old-1"}]) vector_instance = MagicMock() diff --git a/api/tests/unit_tests/services/test_workflow_run_service_pause.py b/api/tests/unit_tests/services/test_workflow_run_service_pause.py index 27664c7e294..a62c9f45556 100644 --- a/api/tests/unit_tests/services/test_workflow_run_service_pause.py +++ b/api/tests/unit_tests/services/test_workflow_run_service_pause.py @@ -13,10 +13,10 @@ from datetime import datetime from unittest.mock import MagicMock, create_autospec, patch import pytest +from graphon.enums import WorkflowExecutionStatus from sqlalchemy import Engine from sqlalchemy.orm import Session, sessionmaker -from dify_graph.enums import WorkflowExecutionStatus from models.workflow import WorkflowPause from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity diff --git a/api/tests/unit_tests/services/test_workflow_service.py b/api/tests/unit_tests/services/test_workflow_service.py index 57c0464dc62..cd71981bcf1 100644 --- a/api/tests/unit_tests/services/test_workflow_service.py +++ b/api/tests/unit_tests/services/test_workflow_service.py @@ -10,18 +10,37 @@ This test suite covers: """ import json -from unittest.mock import MagicMock, patch +import uuid +from typing import Any, cast +from unittest.mock import ANY, MagicMock, patch import pytest +from graphon.entities import WorkflowNodeExecution +from graphon.enums import ( + BuiltinNodeTypes, + ErrorStrategy, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.errors import WorkflowNodeRunFailedError +from graphon.graph_events import NodeRunFailedEvent, NodeRunSucceededEvent +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.node_events import NodeRunResult +from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig +from graphon.variables.input_entities import VariableEntityType -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig from libs.datetime_utils import naive_utc_now +from models.human_input import RecipientType from models.model import App, AppMode from models.workflow import Workflow, WorkflowType from services.errors.app import IsDraftWorkflowError, TriggerNodeLimitExceededError, WorkflowHashNotEqualError from services.errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError -from services.workflow_service import WorkflowService +from services.workflow_service import ( + WorkflowService, + _rebuild_file_for_user_inputs_in_start_node, + _rebuild_single_file, + _setup_variable_pool, +) class TestWorkflowAssociatedDataFactory: @@ -544,6 +563,89 @@ class TestWorkflowService: conversation_variables=[], ) + def test_restore_published_workflow_to_draft_keeps_source_features_unmodified( + self, workflow_service, mock_db_session + ): + app = TestWorkflowAssociatedDataFactory.create_app_mock() + account = TestWorkflowAssociatedDataFactory.create_account_mock() + legacy_features = { + "file_upload": { + "image": { + "enabled": True, + "number_limits": 6, + "transfer_methods": ["remote_url", "local_file"], + } + }, + "opening_statement": "", + "retriever_resource": {"enabled": True}, + "sensitive_word_avoidance": {"enabled": False}, + "speech_to_text": {"enabled": False}, + "suggested_questions": [], + "suggested_questions_after_answer": {"enabled": False}, + "text_to_speech": {"enabled": False, "language": "", "voice": ""}, + } + normalized_features = { + "file_upload": { + "enabled": True, + "allowed_file_types": ["image"], + "allowed_file_extensions": [], + "allowed_file_upload_methods": ["remote_url", "local_file"], + "number_limits": 6, + }, + "opening_statement": "", + "retriever_resource": {"enabled": True}, + "sensitive_word_avoidance": {"enabled": False}, + "speech_to_text": {"enabled": False}, + "suggested_questions": [], + "suggested_questions_after_answer": {"enabled": False}, + "text_to_speech": {"enabled": False, "language": "", "voice": ""}, + } + source_workflow = Workflow( + id="published-workflow-id", + tenant_id=app.tenant_id, + app_id=app.id, + type=WorkflowType.WORKFLOW.value, + version="2026-03-19T00:00:00", + graph=json.dumps(TestWorkflowAssociatedDataFactory.create_valid_workflow_graph()), + features=json.dumps(legacy_features), + created_by=account.id, + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + draft_workflow = Workflow( + id="draft-workflow-id", + tenant_id=app.tenant_id, + app_id=app.id, + type=WorkflowType.WORKFLOW.value, + version=Workflow.VERSION_DRAFT, + graph=json.dumps({"nodes": [], "edges": []}), + features=json.dumps({}), + created_by=account.id, + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + + with ( + patch.object(workflow_service, "get_published_workflow_by_id", return_value=source_workflow), + patch.object(workflow_service, "get_draft_workflow", return_value=draft_workflow), + patch.object(workflow_service, "validate_graph_structure"), + patch.object(workflow_service, "validate_features_structure") as mock_validate_features, + patch("services.workflow_service.app_draft_workflow_was_synced"), + ): + result = workflow_service.restore_published_workflow_to_draft( + app_model=app, + workflow_id=source_workflow.id, + account=account, + ) + + mock_validate_features.assert_called_once_with(app_model=app, features=normalized_features) + assert result is draft_workflow + assert source_workflow.serialized_features == json.dumps(legacy_features) + assert draft_workflow.serialized_features == json.dumps(legacy_features) + mock_db_session.session.commit.assert_called_once() + # ==================== Workflow Validation Tests ==================== # These tests verify graph structure and feature configuration validation @@ -1226,3 +1328,1460 @@ class TestWorkflowService: with pytest.raises(ValueError, match="not supported convert to workflow"): workflow_service.convert_to_workflow(app, account, args) + + +# =========================================================================== +# TestWorkflowServiceCredentialValidation +# Tests for _validate_workflow_credentials and related private helpers +# =========================================================================== + + +class TestWorkflowServiceCredentialValidation: + """ + Tests for the private credential-validation helpers on WorkflowService. + + These helpers gate `publish_workflow` when `PluginManager` is enabled. + Each test focuses on a distinct branch inside `_validate_workflow_credentials`, + `_validate_llm_model_config`, `_check_default_tool_credential`, and the + load-balancing path. + """ + + @pytest.fixture + def service(self) -> WorkflowService: + with patch("services.workflow_service.db"): + return WorkflowService() + + @staticmethod + def _make_workflow(nodes: list[dict]) -> MagicMock: + wf = MagicMock(spec=Workflow) + wf.tenant_id = "tenant-1" + wf.app_id = "app-1" + wf.graph_dict = {"nodes": nodes} + return wf + + # --- _validate_workflow_credentials: tool node (with credential_id) --- + + def test_validate_workflow_credentials_should_check_tool_credential_when_credential_id_present( + self, service: WorkflowService + ) -> None: + # Arrange + nodes = [ + { + "id": "tool-node", + "data": { + "type": "tool", + "provider_id": "my-provider", + "credential_id": "cred-123", + }, + } + ] + workflow = self._make_workflow(nodes) + + # Act + Assert + with patch("core.helper.credential_utils.check_credential_policy_compliance") as mock_check: + # Should not raise; mock allows the call + service._validate_workflow_credentials(workflow) + mock_check.assert_called_once() + + def test_validate_workflow_credentials_should_check_default_credential_when_no_credential_id( + self, service: WorkflowService + ) -> None: + # Arrange + nodes = [ + { + "id": "tool-node", + "data": { + "type": "tool", + "provider_id": "my-provider", + # No credential_id — should fall back to default + }, + } + ] + workflow = self._make_workflow(nodes) + + # Act + with patch.object(service, "_check_default_tool_credential") as mock_default: + service._validate_workflow_credentials(workflow) + + # Assert + mock_default.assert_called_once_with("tenant-1", "my-provider") + + def test_validate_workflow_credentials_should_skip_tool_node_without_provider( + self, service: WorkflowService + ) -> None: + """Tool nodes without a provider_id should be silently skipped.""" + # Arrange + nodes = [{"id": "tool-node", "data": {"type": "tool"}}] + workflow = self._make_workflow(nodes) + + # Act + Assert (no error raised) + with patch.object(service, "_check_default_tool_credential") as mock_default: + service._validate_workflow_credentials(workflow) + mock_default.assert_not_called() + + def test_validate_workflow_credentials_should_validate_llm_node_with_model_config( + self, service: WorkflowService + ) -> None: + # Arrange + nodes = [ + { + "id": "llm-node", + "data": { + "type": "llm", + "model": {"provider": "openai", "name": "gpt-4"}, + }, + } + ] + workflow = self._make_workflow(nodes) + + # Act + with ( + patch.object(service, "_validate_llm_model_config") as mock_llm, + patch.object(service, "_validate_load_balancing_credentials"), + ): + service._validate_workflow_credentials(workflow) + + # Assert + mock_llm.assert_called_once_with("tenant-1", "openai", "gpt-4") + + def test_validate_workflow_credentials_should_raise_for_llm_node_missing_model( + self, service: WorkflowService + ) -> None: + """LLM nodes without provider AND name should raise ValueError.""" + # Arrange + nodes = [ + { + "id": "llm-node", + "data": {"type": "llm", "model": {"provider": "openai"}}, # name missing + } + ] + workflow = self._make_workflow(nodes) + + # Act + Assert + with pytest.raises(ValueError, match="Missing provider or model configuration"): + service._validate_workflow_credentials(workflow) + + def test_validate_workflow_credentials_should_wrap_unexpected_exception_in_value_error( + self, service: WorkflowService + ) -> None: + """Non-ValueError exceptions from validation must be re-raised as ValueError.""" + # Arrange + nodes = [ + { + "id": "llm-node", + "data": { + "type": "llm", + "model": {"provider": "openai", "name": "gpt-4"}, + }, + } + ] + workflow = self._make_workflow(nodes) + + # Act + Assert + with patch.object(service, "_validate_llm_model_config", side_effect=RuntimeError("boom")): + with pytest.raises(ValueError, match="boom"): + service._validate_workflow_credentials(workflow) + + def test_validate_workflow_credentials_should_validate_agent_node_model(self, service: WorkflowService) -> None: + # Arrange + nodes = [ + { + "id": "agent-node", + "data": { + "type": "agent", + "agent_parameters": { + "model": {"value": {"provider": "openai", "model": "gpt-4"}}, + "tools": {"value": []}, + }, + }, + } + ] + workflow = self._make_workflow(nodes) + + # Act + with ( + patch.object(service, "_validate_llm_model_config") as mock_llm, + patch.object(service, "_validate_load_balancing_credentials"), + ): + service._validate_workflow_credentials(workflow) + + # Assert + mock_llm.assert_called_once_with("tenant-1", "openai", "gpt-4") + + def test_validate_workflow_credentials_should_validate_agent_tools(self, service: WorkflowService) -> None: + """Each agent tool with a provider should be checked for credential compliance.""" + # Arrange + nodes = [ + { + "id": "agent-node", + "data": { + "type": "agent", + "agent_parameters": { + "model": {"value": {}}, # no model config + "tools": { + "value": [ + {"provider_name": "provider-a", "credential_id": "cred-a"}, + {"provider_name": "provider-b"}, # uses default + ] + }, + }, + }, + } + ] + workflow = self._make_workflow(nodes) + + # Act + with ( + patch("core.helper.credential_utils.check_credential_policy_compliance") as mock_check, + patch.object(service, "_check_default_tool_credential") as mock_default, + ): + service._validate_workflow_credentials(workflow) + + # Assert + mock_check.assert_called_once() # provider-a has credential_id + mock_default.assert_called_once_with("tenant-1", "provider-b") + + # --- _validate_llm_model_config --- + + def test_validate_llm_model_config_should_raise_value_error_on_failure(self, service: WorkflowService) -> None: + """If ModelManager raises any exception it must be wrapped into ValueError.""" + # Arrange + assembly = MagicMock() + assembly.model_manager.get_model_instance.side_effect = RuntimeError("no key") + + with patch("services.workflow_service.create_plugin_model_assembly", return_value=assembly): + # Act + Assert + with pytest.raises(ValueError, match="Failed to validate LLM model configuration"): + service._validate_llm_model_config("tenant-1", "openai", "gpt-4") + + def test_validate_llm_model_config_success(self, service: WorkflowService) -> None: + """Test success path with ProviderManager and Model entities.""" + mock_model = MagicMock() + mock_model.model = "gpt-4" + mock_model.provider.provider = "openai" + + mock_configs = MagicMock() + mock_configs.get_models.return_value = [mock_model] + assembly = MagicMock() + assembly.provider_manager.get_configurations.return_value = mock_configs + + with patch("services.workflow_service.create_plugin_model_assembly", return_value=assembly): + # Act + service._validate_llm_model_config("tenant-1", "openai", "gpt-4") + + # Assert + mock_model.raise_for_status.assert_called_once() + assembly.model_manager.get_model_instance.assert_called_once_with( + tenant_id="tenant-1", + provider="openai", + model_type=ModelType.LLM, + model="gpt-4", + ) + + def test_validate_llm_model_config_model_not_found(self, service: WorkflowService) -> None: + """Test ValueError when model is not found in provider configurations.""" + mock_configs = MagicMock() + mock_configs.get_models.return_value = [] # No models + assembly = MagicMock() + assembly.provider_manager.get_configurations.return_value = mock_configs + + with patch("services.workflow_service.create_plugin_model_assembly", return_value=assembly): + # Act + Assert + with pytest.raises(ValueError, match="Model gpt-4 not found for provider openai"): + service._validate_llm_model_config("tenant-1", "openai", "gpt-4") + + # --- _check_default_tool_credential --- + + def test_check_default_tool_credential_should_silently_pass_when_no_provider_found( + self, service: WorkflowService + ) -> None: + """Missing BuiltinToolProvider → plugin requires no credentials → no error.""" + # Arrange + with patch("services.workflow_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = None + + # Act + Assert (should NOT raise) + service._check_default_tool_credential("tenant-1", "some-provider") + + def test_check_default_tool_credential_should_raise_when_compliance_fails(self, service: WorkflowService) -> None: + # Arrange + mock_provider = MagicMock() + mock_provider.id = "builtin-cred-id" + with ( + patch("services.workflow_service.db") as mock_db, + patch("core.helper.credential_utils.check_credential_policy_compliance", side_effect=Exception("denied")), + ): + mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = ( + mock_provider + ) + + # Act + Assert + with pytest.raises(ValueError, match="Failed to validate default credential"): + service._check_default_tool_credential("tenant-1", "some-provider") + + # --- _is_load_balancing_enabled --- + + def test_is_load_balancing_enabled_should_return_false_when_provider_not_found( + self, service: WorkflowService + ) -> None: + # Arrange + with patch("services.workflow_service.db"): + service_instance = WorkflowService() + + with patch("core.provider_manager.ProviderManager.get_configurations") as mock_get_configs: + mock_configs = MagicMock() + mock_configs.get.return_value = None # provider not found + mock_get_configs.return_value = mock_configs + + # Act + result = service_instance._is_load_balancing_enabled("tenant-1", "openai", "gpt-4") + + # Assert + assert result is False + + def test_is_load_balancing_enabled_should_return_true_when_setting_enabled(self, service: WorkflowService) -> None: + # Arrange + with patch("core.provider_manager.ProviderManager.get_configurations") as mock_get_configs: + mock_provider_config = MagicMock() + mock_provider_model_setting = MagicMock() + mock_provider_model_setting.load_balancing_enabled = True + mock_provider_config.get_provider_model_setting.return_value = mock_provider_model_setting + + mock_configs = MagicMock() + mock_configs.get.return_value = mock_provider_config + mock_get_configs.return_value = mock_configs + + # Act + result = service._is_load_balancing_enabled("tenant-1", "openai", "gpt-4") + + # Assert + assert result is True + + def test_is_load_balancing_enabled_should_return_false_on_exception(self, service: WorkflowService) -> None: + """Any exception should be swallowed and return False.""" + # Arrange + with patch("core.provider_manager.ProviderManager.get_configurations", side_effect=RuntimeError("db down")): + # Act + result = service._is_load_balancing_enabled("tenant-1", "openai", "gpt-4") + + # Assert + assert result is False + + # --- _get_load_balancing_configs --- + + def test_get_load_balancing_configs_should_return_empty_list_on_exception(self, service: WorkflowService) -> None: + """Any exception during LB config retrieval should return an empty list.""" + # Arrange + with patch( + "services.model_load_balancing_service.ModelLoadBalancingService.get_load_balancing_configs", + side_effect=RuntimeError("fail"), + ): + # Act + result = service._get_load_balancing_configs("tenant-1", "openai", "gpt-4") + + # Assert + assert result == [] + + def test_get_load_balancing_configs_should_merge_predefined_and_custom(self, service: WorkflowService) -> None: + # Arrange + predefined = [{"credential_id": "cred-a"}, {"credential_id": None}] + custom = [{"credential_id": "cred-b"}] + with patch( + "services.model_load_balancing_service.ModelLoadBalancingService.get_load_balancing_configs", + side_effect=[ + (None, predefined), # first call: predefined-model + (None, custom), # second call: custom-model + ], + ): + # Act + result = service._get_load_balancing_configs("tenant-1", "openai", "gpt-4") + + # Assert — only entries with a credential_id should be returned + assert len(result) == 2 + assert all(c["credential_id"] for c in result) + + # --- _validate_load_balancing_credentials --- + + def test_validate_load_balancing_credentials_should_skip_when_no_model_config( + self, service: WorkflowService + ) -> None: + """Missing provider or model in node_data should be a no-op.""" + # Arrange + workflow = self._make_workflow([]) + node_data: dict = {} # no model key + + # Act + Assert (no error expected) + service._validate_load_balancing_credentials(workflow, node_data, "node-1") + + def test_validate_load_balancing_credentials_should_skip_when_lb_not_enabled( + self, service: WorkflowService + ) -> None: + # Arrange + workflow = self._make_workflow([]) + node_data = {"model": {"provider": "openai", "name": "gpt-4"}} + + # Act + Assert (no error expected) + with patch.object(service, "_is_load_balancing_enabled", return_value=False): + service._validate_load_balancing_credentials(workflow, node_data, "node-1") + + def test_validate_load_balancing_credentials_should_raise_when_compliance_fails( + self, service: WorkflowService + ) -> None: + # Arrange + workflow = self._make_workflow([]) + node_data = {"model": {"provider": "openai", "name": "gpt-4"}} + lb_configs = [{"credential_id": "cred-lb-1"}] + + # Act + Assert + with ( + patch.object(service, "_is_load_balancing_enabled", return_value=True), + patch.object(service, "_get_load_balancing_configs", return_value=lb_configs), + patch( + "core.helper.credential_utils.check_credential_policy_compliance", + side_effect=Exception("policy violation"), + ), + ): + with pytest.raises(ValueError, match="Invalid load balancing credentials"): + service._validate_load_balancing_credentials(workflow, node_data, "node-1") + + +# =========================================================================== +# TestWorkflowServiceExecutionHelpers +# Tests for _apply_error_strategy, _populate_execution_result, _execute_node_safely +# =========================================================================== + + +class TestWorkflowServiceExecutionHelpers: + """ + Tests for the private execution-result handling methods: + _apply_error_strategy, _populate_execution_result, _execute_node_safely. + """ + + @pytest.fixture + def service(self) -> WorkflowService: + with patch("services.workflow_service.db"): + return WorkflowService() + + # --- _apply_error_strategy --- + + def test_apply_error_strategy_should_return_exception_status_noderunresult(self, service: WorkflowService) -> None: + # Arrange + node = MagicMock() + node.error_strategy = ErrorStrategy.FAIL_BRANCH + node.default_value_dict = {} + original = NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error="something went wrong", + error_type="SomeError", + inputs={"x": 1}, + outputs={}, + ) + + # Act + result = service._apply_error_strategy(node, original) + + # Assert + assert result.status == WorkflowNodeExecutionStatus.EXCEPTION + assert result.error == "something went wrong" + assert result.metadata[WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY] == ErrorStrategy.FAIL_BRANCH + + def test_apply_error_strategy_should_include_default_values_for_default_value_strategy( + self, service: WorkflowService + ) -> None: + # Arrange + node = MagicMock() + node.error_strategy = ErrorStrategy.DEFAULT_VALUE + node.default_value_dict = {"output_key": "fallback"} + original = NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error="err", + ) + + # Act + result = service._apply_error_strategy(node, original) + + # Assert + assert result.outputs.get("output_key") == "fallback" + assert result.status == WorkflowNodeExecutionStatus.EXCEPTION + + # --- _populate_execution_result --- + + def test_populate_execution_result_should_set_succeeded_fields_when_run_succeeded( + self, service: WorkflowService + ) -> None: + # Arrange + node_execution = MagicMock(error=None) + node_run_result = NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"q": "hello"}, + process_data={"steps": 3}, + outputs={"answer": "hi"}, + metadata={WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 10}, + ) + + # Act + with patch("services.workflow_service.WorkflowEntry.handle_special_values", side_effect=lambda x: x): + service._populate_execution_result(node_execution, node_run_result, True, None) + + # Assert + assert node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert node_execution.outputs == {"answer": "hi"} + assert node_execution.error is None # SUCCEEDED status doesn't set error + + def test_populate_execution_result_should_set_failed_status_and_error_when_not_succeeded( + self, service: WorkflowService + ) -> None: + # Arrange + node_execution = MagicMock(error=None) + + # Act + service._populate_execution_result(node_execution, None, False, "catastrophic failure") + + # Assert + assert node_execution.status == WorkflowNodeExecutionStatus.FAILED + assert node_execution.error == "catastrophic failure" + + def test_populate_execution_result_should_set_error_field_for_exception_status( + self, service: WorkflowService + ) -> None: + """A succeeded=True result with EXCEPTION status should still populate the error field.""" + # Arrange + node_execution = MagicMock() + node_run_result = NodeRunResult( + status=WorkflowNodeExecutionStatus.EXCEPTION, + error="constraint violated", + ) + + # Act + with patch("services.workflow_service.WorkflowEntry.handle_special_values", side_effect=lambda x: x): + service._populate_execution_result(node_execution, node_run_result, True, None) + + # Assert + assert node_execution.status == WorkflowNodeExecutionStatus.EXCEPTION + assert node_execution.error == "constraint violated" + + # --- _execute_node_safely --- + + def test_execute_node_safely_should_return_succeeded_result_on_happy_path(self, service: WorkflowService) -> None: + # Arrange + node = MagicMock() + node.error_strategy = None + node_run_result = MagicMock() + node_run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED + node_run_result.error = None + + succeeded_event = MagicMock(spec=NodeRunSucceededEvent) + succeeded_event.node_run_result = node_run_result + + def invoke_fn(): + def _gen(): + yield succeeded_event + + return node, _gen() + + # Act + out_node, out_result, run_succeeded, error = service._execute_node_safely(invoke_fn) + + # Assert + assert out_node is node + assert run_succeeded is True + assert error is None + + def test_execute_node_safely_should_return_failed_result_on_failed_event(self, service: WorkflowService) -> None: + # Arrange + node = MagicMock() + node.error_strategy = None + node_run_result = MagicMock() + node_run_result.status = WorkflowNodeExecutionStatus.FAILED + node_run_result.error = "node exploded" + + failed_event = MagicMock(spec=NodeRunFailedEvent) + failed_event.node_run_result = node_run_result + + def invoke_fn(): + def _gen(): + yield failed_event + + return node, _gen() + + # Act + _, _, run_succeeded, error = service._execute_node_safely(invoke_fn) + + # Assert + assert run_succeeded is False + assert error == "node exploded" + + def test_execute_node_safely_should_handle_workflow_node_run_failed_error(self, service: WorkflowService) -> None: + # Arrange + node = MagicMock() + exc = WorkflowNodeRunFailedError(node, "runtime failure") + + def invoke_fn(): + raise exc + + # Act + out_node, out_result, run_succeeded, error = service._execute_node_safely(invoke_fn) + + # Assert + assert out_node is node + assert out_result is None + assert run_succeeded is False + assert error == "runtime failure" + + def test_execute_node_safely_should_raise_when_no_result_event(self, service: WorkflowService) -> None: + """If the generator produces no NodeRunSucceededEvent/NodeRunFailedEvent, ValueError is expected.""" + # Arrange + node = MagicMock() + node.error_strategy = None + + def invoke_fn(): + def _gen(): + yield from [] + + return node, _gen() + + # Act + Assert + with pytest.raises(ValueError, match="no result returned"): + service._execute_node_safely(invoke_fn) + + # --- _apply_error_strategy with FAIL_BRANCH strategy --- + + def test_execute_node_safely_should_apply_error_strategy_on_failed_status(self, service: WorkflowService) -> None: + # Arrange + node = MagicMock() + node.error_strategy = ErrorStrategy.FAIL_BRANCH + node.default_value_dict = {} + + original_result = MagicMock() + original_result.status = WorkflowNodeExecutionStatus.FAILED + original_result.error = "oops" + original_result.error_type = "ValueError" + original_result.inputs = {} + + failed_event = MagicMock(spec=NodeRunFailedEvent) + failed_event.node_run_result = original_result + + def invoke_fn(): + def _gen(): + yield failed_event + + return node, _gen() + + # Act + _, result, run_succeeded, _ = service._execute_node_safely(invoke_fn) + + # Assert — after applying error strategy status becomes EXCEPTION + assert result is not None + assert result.status == WorkflowNodeExecutionStatus.EXCEPTION + # run_succeeded should be True because EXCEPTION is in the succeeded set + assert run_succeeded is True + + +# =========================================================================== +# TestWorkflowServiceGetNodeLastRun +# Tests for get_node_last_run delegation to repository +# =========================================================================== + + +class TestWorkflowServiceGetNodeLastRun: + @pytest.fixture + def service(self) -> WorkflowService: + with patch("services.workflow_service.db"): + return WorkflowService() + + def test_get_node_last_run_should_delegate_to_repository(self, service: WorkflowService) -> None: + # Arrange + app = MagicMock(spec=App) + app.tenant_id = "tenant-1" + app.id = "app-1" + workflow = MagicMock(spec=Workflow) + workflow.id = "wf-1" + expected = MagicMock() + + service._node_execution_service_repo = MagicMock() + service._node_execution_service_repo.get_node_last_execution.return_value = expected + + # Act + result = service.get_node_last_run(app, workflow, "node-42") + + # Assert + assert result is expected + service._node_execution_service_repo.get_node_last_execution.assert_called_once_with( + tenant_id="tenant-1", + app_id="app-1", + workflow_id="wf-1", + node_id="node-42", + ) + + def test_get_node_last_run_should_return_none_when_repository_returns_none(self, service: WorkflowService) -> None: + # Arrange + app = MagicMock(spec=App) + app.tenant_id = "t" + app.id = "a" + workflow = MagicMock(spec=Workflow) + workflow.id = "w" + service._node_execution_service_repo = MagicMock() + service._node_execution_service_repo.get_node_last_execution.return_value = None + + # Act + result = service.get_node_last_run(app, workflow, "node-x") + + # Assert + assert result is None + + +# =========================================================================== +# TestWorkflowServiceModuleLevelHelpers +# Tests for module-level helper functions exported from workflow_service +# =========================================================================== + + +class TestSetupVariablePool: + """ + Tests for the module-level `_setup_variable_pool` function. + This helper initialises the VariablePool used for single-step workflow execution. + """ + + def _make_workflow(self, workflow_type: str = WorkflowType.WORKFLOW.value) -> MagicMock: + wf = MagicMock(spec=Workflow) + wf.app_id = "app-1" + wf.id = "wf-1" + wf.type = workflow_type + wf.environment_variables = [] + return wf + + def test_setup_variable_pool_should_use_full_system_variables_for_start_node( + self, + ) -> None: + # Arrange + workflow = self._make_workflow() + + # Act + with ( + patch("services.workflow_service.VariablePool") as MockPool, + patch("services.workflow_service.build_system_variables") as mock_build_system_variables, + patch("services.workflow_service.build_bootstrap_variables") as mock_build_bootstrap_variables, + patch("services.workflow_service.add_variables_to_pool") as mock_add_variables_to_pool, + patch("services.workflow_service.add_node_inputs_to_pool") as mock_add_node_inputs_to_pool, + ): + _setup_variable_pool( + query="hello", + files=[], + user_id="u-1", + user_inputs={"k": "v"}, + workflow=workflow, + node_id="start-node", + node_type=BuiltinNodeTypes.START, + conversation_id="conv-1", + conversation_variables=[], + ) + + # Assert — start nodes should build bootstrap variables and attach node inputs. + MockPool.assert_called_once_with() + mock_build_system_variables.assert_called_once() + mock_add_variables_to_pool.assert_called_once_with( + MockPool.return_value, + mock_build_bootstrap_variables.return_value, + ) + mock_add_node_inputs_to_pool.assert_called_once_with( + MockPool.return_value, + node_id="start-node", + inputs={"k": "v"}, + ) + + def test_setup_variable_pool_should_use_default_system_variables_for_non_start_node( + self, + ) -> None: + # Arrange + workflow = self._make_workflow() + + # Act + with ( + patch("services.workflow_service.VariablePool") as MockPool, + patch("services.workflow_service.default_system_variables") as mock_default_system_variables, + patch("services.workflow_service.build_bootstrap_variables") as mock_build_bootstrap_variables, + patch("services.workflow_service.add_variables_to_pool") as mock_add_variables_to_pool, + patch("services.workflow_service.add_node_inputs_to_pool") as mock_add_node_inputs_to_pool, + ): + _setup_variable_pool( + query="", + files=[], + user_id="u-1", + user_inputs={}, + workflow=workflow, + node_id="llm-node", + node_type=BuiltinNodeTypes.LLM, # not a start/trigger node + conversation_id="conv-1", + conversation_variables=[], + ) + + # Assert — default system variables should be used and node inputs should not be added. + mock_default_system_variables.assert_called_once() + MockPool.assert_called_once_with() + mock_add_variables_to_pool.assert_called_once_with( + MockPool.return_value, + mock_build_bootstrap_variables.return_value, + ) + mock_add_node_inputs_to_pool.assert_not_called() + + def test_setup_variable_pool_should_set_chatflow_specifics_for_non_workflow_type( + self, + ) -> None: + """For ADVANCED_CHAT workflows on a START node, query/conversation_id/dialogue_count should be set.""" + from models.workflow import WorkflowType + + # Arrange + workflow = self._make_workflow(workflow_type=WorkflowType.CHAT.value) + + # Act + with ( + patch("services.workflow_service.VariablePool") as MockPool, + patch("services.workflow_service.build_system_variables") as mock_build_system_variables, + patch("services.workflow_service.build_bootstrap_variables"), + patch("services.workflow_service.add_variables_to_pool"), + patch("services.workflow_service.add_node_inputs_to_pool"), + ): + _setup_variable_pool( + query="what is AI?", + files=[], + user_id="u-1", + user_inputs={}, + workflow=workflow, + node_id="start-node", + node_type=BuiltinNodeTypes.START, + conversation_id="conv-abc", + conversation_variables=[], + ) + + # Assert — chatflow system variables should include query, conversation_id and dialogue_count. + MockPool.assert_called_once_with() + system_variable_values = mock_build_system_variables.call_args.args[0] + assert system_variable_values["query"] == "what is AI?" + assert system_variable_values["conversation_id"] == "conv-abc" + assert system_variable_values["dialogue_count"] == 1 + + +class TestRebuildSingleFile: + """ + Tests for the module-level `_rebuild_single_file` function. + Ensures correct delegation to `build_from_mapping` / `build_from_mappings`. + """ + + def test_rebuild_single_file_should_call_build_from_mapping_for_file_type( + self, + ) -> None: + # Arrange + tenant_id = "tenant-1" + value = {"url": "https://example.com/file.pdf", "type": "document"} + mock_file = MagicMock() + + # Act + with patch("services.workflow_service.build_from_mapping", return_value=mock_file) as mock_build: + result = _rebuild_single_file(tenant_id, value, VariableEntityType.FILE) + + # Assert + assert result is mock_file + mock_build.assert_called_once_with(mapping=value, tenant_id=tenant_id, access_controller=ANY) + + def test_rebuild_single_file_should_raise_when_file_value_not_dict( + self, + ) -> None: + # Arrange + Act + Assert + with pytest.raises(ValueError, match="expected dict for file object"): + _rebuild_single_file("tenant-1", "not-a-dict", VariableEntityType.FILE) + + def test_rebuild_single_file_should_call_build_from_mappings_for_file_list( + self, + ) -> None: + # Arrange + tenant_id = "tenant-1" + value = [{"url": "https://example.com/a.pdf"}, {"url": "https://example.com/b.pdf"}] + mock_files = [MagicMock(), MagicMock()] + + # Act + with patch("services.workflow_service.build_from_mappings", return_value=mock_files) as mock_build: + result = _rebuild_single_file(tenant_id, value, VariableEntityType.FILE_LIST) + + # Assert + assert result is mock_files + mock_build.assert_called_once_with(mappings=value, tenant_id=tenant_id, access_controller=ANY) + + def test_rebuild_single_file_should_raise_when_file_list_value_not_list( + self, + ) -> None: + # Arrange + Act + Assert + with pytest.raises(ValueError, match="expected list for file list object"): + _rebuild_single_file("tenant-1", "not-a-list", VariableEntityType.FILE_LIST) + + def test_rebuild_single_file_should_return_empty_list_for_empty_file_list( + self, + ) -> None: + # Arrange + Act + result = _rebuild_single_file("tenant-1", [], VariableEntityType.FILE_LIST) + + # Assert + assert result == [] + + def test_rebuild_single_file_should_raise_when_first_element_not_dict( + self, + ) -> None: + # Arrange + Act + Assert + with pytest.raises(ValueError, match="expected dict for first element"): + _rebuild_single_file("tenant-1", ["not-a-dict"], VariableEntityType.FILE_LIST) + + +class TestRebuildFileForUserInputsInStartNode: + """ + Tests for the module-level `_rebuild_file_for_user_inputs_in_start_node` function. + """ + + def _make_start_node_data(self, variables: list) -> MagicMock: + start_data = MagicMock() + start_data.variables = variables + return start_data + + def _make_variable(self, name: str, var_type: VariableEntityType) -> MagicMock: + var = MagicMock() + var.variable = name + var.type = var_type + return var + + def test_rebuild_should_pass_through_non_file_variables( + self, + ) -> None: + # Arrange + text_var = self._make_variable("query", VariableEntityType.TEXT_INPUT) + start_data = self._make_start_node_data([text_var]) + user_inputs = {"query": "hello world"} + + # Act + result = _rebuild_file_for_user_inputs_in_start_node( + tenant_id="tenant-1", + start_node_data=start_data, + user_inputs=user_inputs, + ) + + # Assert — non-file inputs are untouched + assert result["query"] == "hello world" + + def test_rebuild_should_rebuild_file_variable( + self, + ) -> None: + # Arrange + file_var = self._make_variable("attachment", VariableEntityType.FILE) + start_data = self._make_start_node_data([file_var]) + file_value = {"url": "https://example.com/file.pdf"} + user_inputs = {"attachment": file_value} + mock_file = MagicMock() + + # Act + with patch("services.workflow_service.build_from_mapping", return_value=mock_file): + result = _rebuild_file_for_user_inputs_in_start_node( + tenant_id="tenant-1", + start_node_data=start_data, + user_inputs=user_inputs, + ) + + # Assert — the dict value should be replaced by the rebuilt File object + assert result["attachment"] is mock_file + + def test_rebuild_should_skip_variable_not_in_inputs( + self, + ) -> None: + # Arrange + file_var = self._make_variable("attachment", VariableEntityType.FILE) + start_data = self._make_start_node_data([file_var]) + user_inputs: dict = {} # attachment not provided + + # Act + result = _rebuild_file_for_user_inputs_in_start_node( + tenant_id="tenant-1", + start_node_data=start_data, + user_inputs=user_inputs, + ) + + # Assert — no key should be added for missing inputs + assert "attachment" not in result + + +class TestWorkflowServiceResolveDeliveryMethod: + """ + Tests for the static helper `_resolve_human_input_delivery_method`. + """ + + def _make_method(self, method_id) -> MagicMock: + m = MagicMock() + m.id = method_id + return m + + def test_resolve_delivery_method_should_return_method_when_id_matches(self) -> None: + # Arrange + method_a = self._make_method("method-1") + method_b = self._make_method("method-2") + + # Act + with patch("services.workflow_service.parse_human_input_delivery_methods", return_value=[method_a, method_b]): + result = WorkflowService._resolve_human_input_delivery_method( + node_data=MagicMock(), delivery_method_id="method-2" + ) + + # Assert + assert result is method_b + + def test_resolve_delivery_method_should_return_none_when_no_match(self) -> None: + # Arrange + method_a = self._make_method("method-1") + + # Act + with patch("services.workflow_service.parse_human_input_delivery_methods", return_value=[method_a]): + result = WorkflowService._resolve_human_input_delivery_method( + node_data=MagicMock(), delivery_method_id="does-not-exist" + ) + + # Assert + assert result is None + + def test_resolve_delivery_method_should_return_none_for_empty_methods(self) -> None: + # Act + with patch("services.workflow_service.parse_human_input_delivery_methods", return_value=[]): + result = WorkflowService._resolve_human_input_delivery_method( + node_data=MagicMock(), delivery_method_id="method-1" + ) + + # Assert + assert result is None + + +# =========================================================================== +# TestWorkflowServiceDraftExecution +# Tests for run_draft_workflow_node +# =========================================================================== + + +class TestWorkflowServiceDraftExecution: + @pytest.fixture + def service(self) -> WorkflowService: + with patch("services.workflow_service.db"): + return WorkflowService() + + def test_run_draft_workflow_node_should_execute_start_node_successfully(self, service: WorkflowService) -> None: + # Arrange + app = MagicMock(spec=App) + app.id = "app-1" + app.tenant_id = "tenant-1" + account = MagicMock() + account.id = "user-1" + + draft_workflow = MagicMock(spec=Workflow) + draft_workflow.id = "wf-1" + draft_workflow.tenant_id = "tenant-1" + draft_workflow.app_id = "app-1" + draft_workflow.graph_dict = {"nodes": []} + + node_id = "start-node" + node_config = {"id": node_id, "data": MagicMock(type=BuiltinNodeTypes.START)} + draft_workflow.get_node_config_by_id.return_value = node_config + draft_workflow.get_enclosing_node_type_and_id.return_value = None + + service.get_draft_workflow = MagicMock(return_value=draft_workflow) + + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "exec-1" + node_execution.process_data = {} + + # Mocking complex dependencies + with ( + patch("services.workflow_service.db"), + patch("services.workflow_service.Session"), + patch("services.workflow_service.WorkflowDraftVariableService"), + patch("services.workflow_service.StartNodeData") as mock_start_data, + patch( + "services.workflow_service._rebuild_file_for_user_inputs_in_start_node", + side_effect=lambda **kwargs: kwargs["user_inputs"], + ), + patch("services.workflow_service._setup_variable_pool"), + patch("services.workflow_service.DraftVarLoader"), + patch("services.workflow_service.WorkflowEntry.single_step_run") as mock_run, + patch("services.workflow_service.DifyCoreRepositoryFactory") as mock_repo_factory, + patch("services.workflow_service.DraftVariableSaver") as mock_saver_cls, + patch("services.workflow_service.storage"), + ): + mock_node = MagicMock() + mock_node.node_type = BuiltinNodeTypes.START + mock_node.title = "Start Node" + mock_run_result = NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={}, outputs={"result": "ok"} + ) + mock_event = NodeRunSucceededEvent( + id=str(uuid.uuid4()), + node_id="start-node", + node_type=BuiltinNodeTypes.START, + node_run_result=mock_run_result, + start_at=naive_utc_now(), + ) + mock_run.return_value = (mock_node, [mock_event]) + + mock_repo = MagicMock() + mock_repo_factory.create_workflow_node_execution_repository.return_value = mock_repo + + service._node_execution_service_repo = MagicMock() + mock_execution_record = MagicMock() + mock_execution_record.node_type = "start" + mock_execution_record.node_id = "start-node" + mock_execution_record.load_full_outputs.return_value = {} + service._node_execution_service_repo.get_execution_by_id.return_value = mock_execution_record + + # Act + result = service.run_draft_workflow_node( + app_model=app, + draft_workflow=draft_workflow, + account=account, + node_id=node_id, + user_inputs={"key": "val"}, + query="hi", + files=[], + ) + + # Assert + assert result is not None + mock_run.assert_called_once() + mock_repo.save.assert_called_once() + mock_saver_cls.return_value.save.assert_called_once() + + def test_run_draft_workflow_node_should_execute_non_start_node_successfully(self, service: WorkflowService) -> None: + # Arrange + app = MagicMock(spec=App) + account = MagicMock() + draft_workflow = MagicMock(spec=Workflow) + draft_workflow.graph_dict = {"nodes": []} + node_id = "llm-node" + node_config = {"id": node_id, "data": MagicMock(type=BuiltinNodeTypes.LLM)} + draft_workflow.get_node_config_by_id.return_value = node_config + draft_workflow.get_enclosing_node_type_and_id.return_value = None + service.get_draft_workflow = MagicMock(return_value=draft_workflow) + + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "exec-1" + node_execution.process_data = {} + + with ( + patch("services.workflow_service.db"), + patch("services.workflow_service.Session"), + patch("services.workflow_service.WorkflowDraftVariableService"), + patch("services.workflow_service.VariablePool") as mock_pool_cls, + patch("services.workflow_service.default_system_variables") as mock_default_system_variables, + patch("services.workflow_service.build_bootstrap_variables") as mock_build_bootstrap_variables, + patch("services.workflow_service.add_variables_to_pool") as mock_add_variables_to_pool, + patch("services.workflow_service.DraftVarLoader"), + patch("services.workflow_service.WorkflowEntry.single_step_run") as mock_run, + patch("services.workflow_service.DifyCoreRepositoryFactory"), + patch("services.workflow_service.DraftVariableSaver"), + patch("services.workflow_service.storage"), + ): + mock_node = MagicMock() + mock_node.node_type = BuiltinNodeTypes.LLM + mock_node.title = "LLM Node" + mock_run_result = NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={}, outputs={"result": "ok"} + ) + mock_event = NodeRunSucceededEvent( + id=str(uuid.uuid4()), + node_id="llm-node", + node_type=BuiltinNodeTypes.LLM, + node_run_result=mock_run_result, + start_at=naive_utc_now(), + ) + mock_run.return_value = (mock_node, [mock_event]) + + service._node_execution_service_repo = MagicMock() + mock_execution_record = MagicMock() + mock_execution_record.node_type = "llm" + mock_execution_record.node_id = "llm-node" + mock_execution_record.load_full_outputs.return_value = {"answer": "hello"} + service._node_execution_service_repo.get_execution_by_id.return_value = mock_execution_record + + # Act + service.run_draft_workflow_node( + app_model=app, + draft_workflow=draft_workflow, + account=account, + node_id=node_id, + user_inputs={}, + query="", + files=None, + ) + + # Assert + # For non-start nodes, bootstrap variables should be loaded into an empty pool. + mock_pool_cls.assert_called_once_with() + mock_default_system_variables.assert_called_once() + mock_build_bootstrap_variables.assert_called_once_with( + system_variables=mock_default_system_variables.return_value, + environment_variables=draft_workflow.environment_variables, + ) + mock_add_variables_to_pool.assert_called_once_with( + mock_pool_cls.return_value, mock_build_bootstrap_variables.return_value + ) + + +# =========================================================================== +# TestWorkflowServiceHumanInputOperations +# Tests for Human Input related methods +# =========================================================================== + + +class TestWorkflowServiceHumanInputOperations: + @pytest.fixture + def service(self) -> WorkflowService: + with patch("services.workflow_service.db"): + return WorkflowService() + + def test_get_human_input_form_preview_should_raise_if_workflow_not_init(self, service: WorkflowService) -> None: + service.get_draft_workflow = MagicMock(return_value=None) + with pytest.raises(ValueError, match="Workflow not initialized"): + service.get_human_input_form_preview(app_model=MagicMock(), account=MagicMock(), node_id="node-1") + + def test_get_human_input_form_preview_should_raise_if_wrong_node_type(self, service: WorkflowService) -> None: + draft = MagicMock() + draft.get_node_config_by_id.return_value = {"data": {"type": "llm"}} + service.get_draft_workflow = MagicMock(return_value=draft) + with patch("models.workflow.Workflow.get_node_type_from_node_config", return_value=BuiltinNodeTypes.LLM): + with pytest.raises(ValueError, match="Node type must be human-input"): + service.get_human_input_form_preview(app_model=MagicMock(), account=MagicMock(), node_id="node-1") + + def test_get_human_input_form_preview_success(self, service: WorkflowService) -> None: + app_model = MagicMock(spec=App) + app_model.id = "app-1" + app_model.tenant_id = "tenant-1" + + account = MagicMock() + account.id = "user-1" + + draft = MagicMock() + draft.id = "wf-1" + draft.tenant_id = "tenant-1" + draft.app_id = "app-1" + draft.graph_dict = {"nodes": []} + draft.get_node_config_by_id.return_value = { + "id": "node-1", + "data": MagicMock(type=BuiltinNodeTypes.HUMAN_INPUT), + } + service.get_draft_workflow = MagicMock(return_value=draft) + + mock_node = MagicMock() + mock_node.render_form_content_before_submission.return_value = "rendered" + mock_node.resolve_default_values.return_value = {"def": 1} + mock_node.title = "Form Title" + mock_node.node_data = MagicMock() + + with ( + patch("services.workflow_service.db"), + patch("services.workflow_service.WorkflowDraftVariableService"), + patch("models.workflow.Workflow.get_node_type_from_node_config", return_value=BuiltinNodeTypes.HUMAN_INPUT), + patch.object(service, "_build_human_input_variable_pool"), + patch("services.workflow_service.HumanInputNode", return_value=mock_node), + patch("services.workflow_service.HumanInputRequired") as mock_required_cls, + ): + service.get_human_input_form_preview(app_model=app_model, account=account, node_id="node-1") + mock_node.render_form_content_before_submission.assert_called_once() + mock_required_cls.return_value.model_dump.assert_called_once() + + def test_submit_human_input_form_preview_success(self, service: WorkflowService) -> None: + app_model = MagicMock(spec=App) + app_model.id = "app-1" + app_model.tenant_id = "tenant-1" + + account = MagicMock() + account.id = "user-1" + + draft = MagicMock() + draft.id = "wf-1" + draft.tenant_id = "tenant-1" + draft.app_id = "app-1" + draft.graph_dict = {"nodes": []} + draft.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": "human-input"}} + service.get_draft_workflow = MagicMock(return_value=draft) + + mock_node = MagicMock() + mock_node.node_data = MagicMock() + mock_node.node_data.outputs_field_names.return_value = ["field1"] + + with ( + patch("services.workflow_service.db"), + patch("services.workflow_service.WorkflowDraftVariableService"), + patch("models.workflow.Workflow.get_node_type_from_node_config", return_value=BuiltinNodeTypes.HUMAN_INPUT), + patch.object(service, "_build_human_input_variable_pool"), + patch("services.workflow_service.HumanInputNode", return_value=mock_node), + patch("services.workflow_service.validate_human_input_submission"), + patch("services.workflow_service.Session"), + patch("services.workflow_service.DraftVariableSaver") as mock_saver_cls, + ): + result = service.submit_human_input_form_preview( + app_model=app_model, account=account, node_id="node-1", form_inputs={"field1": "val1"}, action="submit" + ) + assert result["__action_id"] == "submit" + mock_saver_cls.return_value.save.assert_called_once() + + def test_test_human_input_delivery_success(self, service: WorkflowService) -> None: + draft = MagicMock() + draft.get_node_config_by_id.return_value = {"data": {"type": "human-input"}} + service.get_draft_workflow = MagicMock(return_value=draft) + + with ( + patch("models.workflow.Workflow.get_node_type_from_node_config", return_value=BuiltinNodeTypes.HUMAN_INPUT), + patch("services.workflow_service.HumanInputNodeData.model_validate"), + patch.object(service, "_resolve_human_input_delivery_method") as mock_resolve, + patch("services.workflow_service.apply_dify_debug_email_recipient"), + patch.object(service, "_build_human_input_variable_pool"), + patch.object(service, "_build_human_input_node"), + patch.object(service, "_create_human_input_delivery_test_form", return_value=("form-1", [])), + patch("services.workflow_service.HumanInputDeliveryTestService") as mock_test_srv, + ): + mock_resolve.return_value = MagicMock() + service.test_human_input_delivery( + app_model=MagicMock(), account=MagicMock(), node_id="node-1", delivery_method_id="method-1" + ) + mock_test_srv.return_value.send_test.assert_called_once() + + def test_test_human_input_delivery_failure_cases(self, service: WorkflowService) -> None: + draft = MagicMock() + draft.get_node_config_by_id.return_value = {"data": {"type": "human-input"}} + service.get_draft_workflow = MagicMock(return_value=draft) + + with ( + patch("models.workflow.Workflow.get_node_type_from_node_config", return_value=BuiltinNodeTypes.HUMAN_INPUT), + patch("services.workflow_service.HumanInputNodeData.model_validate"), + patch.object(service, "_resolve_human_input_delivery_method", return_value=None), + ): + with pytest.raises(ValueError, match="Delivery method not found"): + service.test_human_input_delivery( + app_model=MagicMock(), account=MagicMock(), node_id="node-1", delivery_method_id="none" + ) + + def test_load_email_recipients_parsing_failure(self, service: WorkflowService) -> None: + # Arrange + mock_recipient = MagicMock() + mock_recipient.recipient_payload = "invalid json" + mock_recipient.recipient_type = RecipientType.EMAIL_MEMBER + + with ( + patch("services.workflow_service.db"), + patch("services.workflow_service.WorkflowDraftVariableService"), + patch("services.workflow_service.Session") as mock_session_cls, + patch("services.workflow_service.select"), + patch("services.workflow_service.json.loads", side_effect=ValueError("bad json")), + ): + mock_session = mock_session_cls.return_value.__enter__.return_value + # sqlalchemy assertions check for .bind + mock_session.bind = MagicMock() # removed spec=Engine to avoid import issues for now + mock_session.scalars.return_value.all.return_value = [mock_recipient] + + # Act + # _load_email_recipients(form_id: str) is a static method + result = WorkflowService._load_email_recipients("form-1") + + # Assert + assert result == [] # Should fall back to empty list on parsing error + + def test_build_human_input_variable_pool(self, service: WorkflowService) -> None: + workflow = MagicMock() + workflow.environment_variables = [] + workflow.graph_dict = {} + + with ( + patch("services.workflow_service.db"), + patch("services.workflow_service.Session"), + patch("services.workflow_service.WorkflowDraftVariableService"), + patch("services.workflow_service.VariablePool") as mock_pool_cls, + patch("services.workflow_service.DraftVarLoader"), + patch("services.workflow_service.HumanInputNode.extract_variable_selector_to_variable_mapping"), + patch("services.workflow_service.load_into_variable_pool"), + patch("services.workflow_service.WorkflowEntry.mapping_user_inputs_to_variable_pool"), + ): + service._build_human_input_variable_pool( + app_model=MagicMock(), workflow=workflow, node_config={}, manual_inputs={}, user_id="user-1" + ) + mock_pool_cls.assert_called_once() + + +# =========================================================================== +# TestWorkflowServiceFreeNodeExecution +# Tests for run_free_workflow_node and handle_single_step_result +# =========================================================================== + + +class TestWorkflowServiceFreeNodeExecution: + @pytest.fixture + def service(self) -> WorkflowService: + with patch("services.workflow_service.db"): + return WorkflowService() + + def test_run_free_workflow_node_success(self, service: WorkflowService) -> None: + node_execution = MagicMock() + with ( + patch.object(service, "_handle_single_step_result", return_value=node_execution), + patch("services.workflow_service.WorkflowEntry.run_free_node"), + ): + result = service.run_free_workflow_node({}, "tenant-1", "user-1", "node-1", {}) + assert result == node_execution + + def test_validate_graph_structure_coexist_error(self, service: WorkflowService) -> None: + graph = { + "nodes": [ + {"data": {"type": "start"}}, + {"data": {"type": "trigger-webhook"}}, # is_trigger_node=True + ] + } + with pytest.raises(ValueError, match="Start node and trigger nodes cannot coexist"): + service.validate_graph_structure(graph) + + def test_validate_features_structure_success(self, service: WorkflowService) -> None: + app = MagicMock() + app.mode = "workflow" + features = {} + with patch("services.workflow_service.WorkflowAppConfigManager.config_validate") as mock_val: + service.validate_features_structure(app, features) + mock_val.assert_called_once() + + def test_validate_features_structure_invalid_mode(self, service: WorkflowService) -> None: + app = MagicMock() + app.mode = "invalid" + with pytest.raises(ValueError, match="Invalid app mode"): + service.validate_features_structure(app, {}) + + def test_validate_human_input_node_data_error(self, service: WorkflowService) -> None: + with patch( + "graphon.nodes.human_input.entities.HumanInputNodeData.model_validate", side_effect=Exception("error") + ): + with pytest.raises(ValueError, match="Invalid HumanInput node data"): + service._validate_human_input_node_data({}) + + def test_rebuild_single_file_unreachable(self) -> None: + # Test line 1523 (unreachable) + with pytest.raises(Exception, match="unreachable"): + _rebuild_single_file("tenant-1", {}, cast(Any, "invalid_type")) + + def test_build_human_input_node(self, service: WorkflowService) -> None: + """Cover _build_human_input_node (lines 1065-1088).""" + workflow = MagicMock() + workflow.id = "wf-1" + workflow.tenant_id = "t-1" + workflow.app_id = "app-1" + account = MagicMock() + account.id = "u-1" + node_config = {"id": "n-1"} + variable_pool = MagicMock() + + with ( + patch("services.workflow_service.GraphInitParams") as mock_graph_init_params, + patch("services.workflow_service.GraphRuntimeState"), + patch("services.workflow_service.build_dify_run_context"), + patch("services.workflow_service.DifyHumanInputNodeRuntime") as mock_runtime_cls, + patch("services.workflow_service.HumanInputNode") as mock_node_cls, + ): + node = service._build_human_input_node( + workflow=workflow, account=account, node_config=node_config, variable_pool=variable_pool + ) + assert node == mock_node_cls.return_value + mock_node_cls.assert_called_once() + mock_runtime_cls.assert_called_once_with(mock_graph_init_params.return_value.run_context) diff --git a/api/tests/unit_tests/services/tools/test_tools_transform_service.py b/api/tests/unit_tests/services/tools/test_tools_transform_service.py deleted file mode 100644 index 9616d2f1023..00000000000 --- a/api/tests/unit_tests/services/tools/test_tools_transform_service.py +++ /dev/null @@ -1,452 +0,0 @@ -from unittest.mock import Mock - -from core.tools.__base.tool import Tool -from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity -from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolParameter, ToolProviderType -from services.tools.tools_transform_service import ToolTransformService - - -class TestToolTransformService: - """Test cases for ToolTransformService.convert_tool_entity_to_api_entity method""" - - def test_convert_tool_with_parameter_override(self): - """Test that runtime parameters correctly override base parameters""" - # Create mock base parameters - base_param1 = Mock(spec=ToolParameter) - base_param1.name = "param1" - base_param1.form = ToolParameter.ToolParameterForm.FORM - base_param1.type = "string" - base_param1.label = "Base Param 1" - - base_param2 = Mock(spec=ToolParameter) - base_param2.name = "param2" - base_param2.form = ToolParameter.ToolParameterForm.FORM - base_param2.type = "string" - base_param2.label = "Base Param 2" - - # Create mock runtime parameters that override base parameters - runtime_param1 = Mock(spec=ToolParameter) - runtime_param1.name = "param1" - runtime_param1.form = ToolParameter.ToolParameterForm.FORM - runtime_param1.type = "string" - runtime_param1.label = "Runtime Param 1" # Different label to verify override - - # Create mock tool - mock_tool = Mock(spec=Tool) - mock_tool.entity = Mock() - mock_tool.entity.parameters = [base_param1, base_param2] - mock_tool.entity.identity = Mock() - mock_tool.entity.identity.author = "test_author" - mock_tool.entity.identity.name = "test_tool" - mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") - mock_tool.entity.description = Mock() - mock_tool.entity.description.human = I18nObject(en_US="Test description") - mock_tool.entity.output_schema = {} - mock_tool.get_runtime_parameters.return_value = [runtime_param1] - - # Mock fork_tool_runtime to return the same tool - mock_tool.fork_tool_runtime.return_value = mock_tool - - # Call the method - result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) - - # Verify the result - assert isinstance(result, ToolApiEntity) - assert result.author == "test_author" - assert result.name == "test_tool" - assert result.parameters is not None - assert len(result.parameters) == 2 - - # Find the overridden parameter - overridden_param = next((p for p in result.parameters if p.name == "param1"), None) - assert overridden_param is not None - assert overridden_param.label == "Runtime Param 1" # Should be runtime version - - # Find the non-overridden parameter - original_param = next((p for p in result.parameters if p.name == "param2"), None) - assert original_param is not None - assert original_param.label == "Base Param 2" # Should be base version - - def test_convert_tool_with_additional_runtime_parameters(self): - """Test that additional runtime parameters are added to the final list""" - # Create mock base parameters - base_param1 = Mock(spec=ToolParameter) - base_param1.name = "param1" - base_param1.form = ToolParameter.ToolParameterForm.FORM - base_param1.type = "string" - base_param1.label = "Base Param 1" - - # Create mock runtime parameters - one that overrides and one that's new - runtime_param1 = Mock(spec=ToolParameter) - runtime_param1.name = "param1" - runtime_param1.form = ToolParameter.ToolParameterForm.FORM - runtime_param1.type = "string" - runtime_param1.label = "Runtime Param 1" - - runtime_param2 = Mock(spec=ToolParameter) - runtime_param2.name = "runtime_only" - runtime_param2.form = ToolParameter.ToolParameterForm.FORM - runtime_param2.type = "string" - runtime_param2.label = "Runtime Only Param" - - # Create mock tool - mock_tool = Mock(spec=Tool) - mock_tool.entity = Mock() - mock_tool.entity.parameters = [base_param1] - mock_tool.entity.identity = Mock() - mock_tool.entity.identity.author = "test_author" - mock_tool.entity.identity.name = "test_tool" - mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") - mock_tool.entity.description = Mock() - mock_tool.entity.description.human = I18nObject(en_US="Test description") - mock_tool.entity.output_schema = {} - mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2] - - # Mock fork_tool_runtime to return the same tool - mock_tool.fork_tool_runtime.return_value = mock_tool - - # Call the method - result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) - - # Verify the result - assert isinstance(result, ToolApiEntity) - assert result.parameters is not None - assert len(result.parameters) == 2 - - # Check that both parameters are present - param_names = [p.name for p in result.parameters] - assert "param1" in param_names - assert "runtime_only" in param_names - - # Verify the overridden parameter has runtime version - overridden_param = next((p for p in result.parameters if p.name == "param1"), None) - assert overridden_param is not None - assert overridden_param.label == "Runtime Param 1" - - # Verify the new runtime parameter is included - new_param = next((p for p in result.parameters if p.name == "runtime_only"), None) - assert new_param is not None - assert new_param.label == "Runtime Only Param" - - def test_convert_tool_with_non_form_runtime_parameters(self): - """Test that non-FORM runtime parameters are not added as new parameters""" - # Create mock base parameters - base_param1 = Mock(spec=ToolParameter) - base_param1.name = "param1" - base_param1.form = ToolParameter.ToolParameterForm.FORM - base_param1.type = "string" - base_param1.label = "Base Param 1" - - # Create mock runtime parameters with different forms - runtime_param1 = Mock(spec=ToolParameter) - runtime_param1.name = "param1" - runtime_param1.form = ToolParameter.ToolParameterForm.FORM - runtime_param1.type = "string" - runtime_param1.label = "Runtime Param 1" - - runtime_param2 = Mock(spec=ToolParameter) - runtime_param2.name = "llm_param" - runtime_param2.form = ToolParameter.ToolParameterForm.LLM - runtime_param2.type = "string" - runtime_param2.label = "LLM Param" - - # Create mock tool - mock_tool = Mock(spec=Tool) - mock_tool.entity = Mock() - mock_tool.entity.parameters = [base_param1] - mock_tool.entity.identity = Mock() - mock_tool.entity.identity.author = "test_author" - mock_tool.entity.identity.name = "test_tool" - mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") - mock_tool.entity.description = Mock() - mock_tool.entity.description.human = I18nObject(en_US="Test description") - mock_tool.entity.output_schema = {} - mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2] - - # Mock fork_tool_runtime to return the same tool - mock_tool.fork_tool_runtime.return_value = mock_tool - - # Call the method - result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) - - # Verify the result - assert isinstance(result, ToolApiEntity) - assert result.parameters is not None - assert len(result.parameters) == 1 # Only the FORM parameter should be present - - # Check that only the FORM parameter is present - param_names = [p.name for p in result.parameters] - assert "param1" in param_names - assert "llm_param" not in param_names - - def test_convert_tool_with_empty_parameters(self): - """Test conversion with empty base and runtime parameters""" - # Create mock tool with no parameters - mock_tool = Mock(spec=Tool) - mock_tool.entity = Mock() - mock_tool.entity.parameters = [] - mock_tool.entity.identity = Mock() - mock_tool.entity.identity.author = "test_author" - mock_tool.entity.identity.name = "test_tool" - mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") - mock_tool.entity.description = Mock() - mock_tool.entity.description.human = I18nObject(en_US="Test description") - mock_tool.entity.output_schema = {} - mock_tool.get_runtime_parameters.return_value = [] - - # Mock fork_tool_runtime to return the same tool - mock_tool.fork_tool_runtime.return_value = mock_tool - - # Call the method - result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) - - # Verify the result - assert isinstance(result, ToolApiEntity) - assert result.parameters is not None - assert len(result.parameters) == 0 - - def test_convert_tool_with_none_parameters(self): - """Test conversion when base parameters is None""" - # Create mock tool with None parameters - mock_tool = Mock(spec=Tool) - mock_tool.entity = Mock() - mock_tool.entity.parameters = None - mock_tool.entity.identity = Mock() - mock_tool.entity.identity.author = "test_author" - mock_tool.entity.identity.name = "test_tool" - mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") - mock_tool.entity.description = Mock() - mock_tool.entity.description.human = I18nObject(en_US="Test description") - mock_tool.entity.output_schema = {} - mock_tool.get_runtime_parameters.return_value = [] - - # Mock fork_tool_runtime to return the same tool - mock_tool.fork_tool_runtime.return_value = mock_tool - - # Call the method - result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) - - # Verify the result - assert isinstance(result, ToolApiEntity) - assert result.parameters is not None - assert len(result.parameters) == 0 - - def test_convert_tool_parameter_order_preserved(self): - """Test that parameter order is preserved correctly""" - # Create mock base parameters in specific order - base_param1 = Mock(spec=ToolParameter) - base_param1.name = "param1" - base_param1.form = ToolParameter.ToolParameterForm.FORM - base_param1.type = "string" - base_param1.label = "Base Param 1" - - base_param2 = Mock(spec=ToolParameter) - base_param2.name = "param2" - base_param2.form = ToolParameter.ToolParameterForm.FORM - base_param2.type = "string" - base_param2.label = "Base Param 2" - - base_param3 = Mock(spec=ToolParameter) - base_param3.name = "param3" - base_param3.form = ToolParameter.ToolParameterForm.FORM - base_param3.type = "string" - base_param3.label = "Base Param 3" - - # Create runtime parameter that overrides middle parameter - runtime_param2 = Mock(spec=ToolParameter) - runtime_param2.name = "param2" - runtime_param2.form = ToolParameter.ToolParameterForm.FORM - runtime_param2.type = "string" - runtime_param2.label = "Runtime Param 2" - - # Create new runtime parameter - runtime_param4 = Mock(spec=ToolParameter) - runtime_param4.name = "param4" - runtime_param4.form = ToolParameter.ToolParameterForm.FORM - runtime_param4.type = "string" - runtime_param4.label = "Runtime Param 4" - - # Create mock tool - mock_tool = Mock(spec=Tool) - mock_tool.entity = Mock() - mock_tool.entity.parameters = [base_param1, base_param2, base_param3] - mock_tool.entity.identity = Mock() - mock_tool.entity.identity.author = "test_author" - mock_tool.entity.identity.name = "test_tool" - mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") - mock_tool.entity.description = Mock() - mock_tool.entity.description.human = I18nObject(en_US="Test description") - mock_tool.entity.output_schema = {} - mock_tool.get_runtime_parameters.return_value = [runtime_param2, runtime_param4] - - # Mock fork_tool_runtime to return the same tool - mock_tool.fork_tool_runtime.return_value = mock_tool - - # Call the method - result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) - - # Verify the result - assert isinstance(result, ToolApiEntity) - assert result.parameters is not None - assert len(result.parameters) == 4 - - # Check that order is maintained: base parameters first, then new runtime parameters - param_names = [p.name for p in result.parameters] - assert param_names == ["param1", "param2", "param3", "param4"] - - # Verify that param2 was overridden with runtime version - param2 = result.parameters[1] - assert param2.name == "param2" - assert param2.label == "Runtime Param 2" - - -class TestWorkflowProviderToUserProvider: - """Test cases for ToolTransformService.workflow_provider_to_user_provider method""" - - def test_workflow_provider_to_user_provider_with_workflow_app_id(self): - """Test that workflow_provider_to_user_provider correctly sets workflow_app_id.""" - from core.tools.workflow_as_tool.provider import WorkflowToolProviderController - - # Create mock workflow tool provider controller - workflow_app_id = "app_123" - provider_id = "provider_123" - mock_controller = Mock(spec=WorkflowToolProviderController) - mock_controller.provider_id = provider_id - mock_controller.entity = Mock() - mock_controller.entity.identity = Mock() - mock_controller.entity.identity.author = "test_author" - mock_controller.entity.identity.name = "test_workflow_tool" - mock_controller.entity.identity.description = I18nObject(en_US="Test description") - mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"} - mock_controller.entity.identity.icon_dark = None - mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool") - - # Call the method - result = ToolTransformService.workflow_provider_to_user_provider( - provider_controller=mock_controller, - labels=["label1", "label2"], - workflow_app_id=workflow_app_id, - ) - - # Verify the result - assert isinstance(result, ToolProviderApiEntity) - assert result.id == provider_id - assert result.author == "test_author" - assert result.name == "test_workflow_tool" - assert result.type == ToolProviderType.WORKFLOW - assert result.workflow_app_id == workflow_app_id - assert result.labels == ["label1", "label2"] - assert result.is_team_authorization is True - assert result.plugin_id is None - assert result.plugin_unique_identifier is None - assert result.tools == [] - - def test_workflow_provider_to_user_provider_without_workflow_app_id(self): - """Test that workflow_provider_to_user_provider works when workflow_app_id is not provided.""" - from core.tools.workflow_as_tool.provider import WorkflowToolProviderController - - # Create mock workflow tool provider controller - provider_id = "provider_123" - mock_controller = Mock(spec=WorkflowToolProviderController) - mock_controller.provider_id = provider_id - mock_controller.entity = Mock() - mock_controller.entity.identity = Mock() - mock_controller.entity.identity.author = "test_author" - mock_controller.entity.identity.name = "test_workflow_tool" - mock_controller.entity.identity.description = I18nObject(en_US="Test description") - mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"} - mock_controller.entity.identity.icon_dark = None - mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool") - - # Call the method without workflow_app_id - result = ToolTransformService.workflow_provider_to_user_provider( - provider_controller=mock_controller, - labels=["label1"], - ) - - # Verify the result - assert isinstance(result, ToolProviderApiEntity) - assert result.id == provider_id - assert result.workflow_app_id is None - assert result.labels == ["label1"] - - def test_workflow_provider_to_user_provider_workflow_app_id_none(self): - """Test that workflow_provider_to_user_provider handles None workflow_app_id explicitly.""" - from core.tools.workflow_as_tool.provider import WorkflowToolProviderController - - # Create mock workflow tool provider controller - provider_id = "provider_123" - mock_controller = Mock(spec=WorkflowToolProviderController) - mock_controller.provider_id = provider_id - mock_controller.entity = Mock() - mock_controller.entity.identity = Mock() - mock_controller.entity.identity.author = "test_author" - mock_controller.entity.identity.name = "test_workflow_tool" - mock_controller.entity.identity.description = I18nObject(en_US="Test description") - mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"} - mock_controller.entity.identity.icon_dark = None - mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool") - - # Call the method with explicit None values - result = ToolTransformService.workflow_provider_to_user_provider( - provider_controller=mock_controller, - labels=None, - workflow_app_id=None, - ) - - # Verify the result - assert isinstance(result, ToolProviderApiEntity) - assert result.id == provider_id - assert result.workflow_app_id is None - assert result.labels == [] - - def test_workflow_provider_to_user_provider_preserves_other_fields(self): - """Test that workflow_provider_to_user_provider preserves all other entity fields.""" - from core.tools.workflow_as_tool.provider import WorkflowToolProviderController - - # Create mock workflow tool provider controller with various fields - workflow_app_id = "app_456" - provider_id = "provider_456" - mock_controller = Mock(spec=WorkflowToolProviderController) - mock_controller.provider_id = provider_id - mock_controller.entity = Mock() - mock_controller.entity.identity = Mock() - mock_controller.entity.identity.author = "another_author" - mock_controller.entity.identity.name = "another_workflow_tool" - mock_controller.entity.identity.description = I18nObject( - en_US="Another description", zh_Hans="Another description" - ) - mock_controller.entity.identity.icon = {"type": "emoji", "content": "⚙️"} - mock_controller.entity.identity.icon_dark = {"type": "emoji", "content": "🔧"} - mock_controller.entity.identity.label = I18nObject( - en_US="Another Workflow Tool", zh_Hans="Another Workflow Tool" - ) - - # Call the method - result = ToolTransformService.workflow_provider_to_user_provider( - provider_controller=mock_controller, - labels=["automation", "workflow"], - workflow_app_id=workflow_app_id, - ) - - # Verify all fields are preserved correctly - assert isinstance(result, ToolProviderApiEntity) - assert result.id == provider_id - assert result.author == "another_author" - assert result.name == "another_workflow_tool" - assert result.description.en_US == "Another description" - assert result.description.zh_Hans == "Another description" - assert result.icon == {"type": "emoji", "content": "⚙️"} - assert result.icon_dark == {"type": "emoji", "content": "🔧"} - assert result.label.en_US == "Another Workflow Tool" - assert result.label.zh_Hans == "Another Workflow Tool" - assert result.type == ToolProviderType.WORKFLOW - assert result.workflow_app_id == workflow_app_id - assert result.labels == ["automation", "workflow"] - assert result.masked_credentials == {} - assert result.is_team_authorization is True - assert result.allow_delete is True - assert result.plugin_id is None - assert result.plugin_unique_identifier is None - assert result.tools == [] diff --git a/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py deleted file mode 100644 index ae59da0a3da..00000000000 --- a/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py +++ /dev/null @@ -1,162 +0,0 @@ -import json -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest - -from core.tools.entities.tool_entities import ToolParameter, WorkflowToolParameterConfiguration -from core.tools.errors import WorkflowToolHumanInputNotSupportedError -from models.model import App -from models.tools import WorkflowToolProvider -from services.tools import workflow_tools_manage_service - - -class DummyWorkflow: - def __init__(self, graph_dict: dict, version: str = "1.0.0") -> None: - self._graph_dict = graph_dict - self.version = version - - @property - def graph_dict(self) -> dict: - return self._graph_dict - - -class FakeQuery: - def __init__(self, result): - self._result = result - - def where(self, *args, **kwargs): - return self - - def first(self): - return self._result - - -class DummySession: - def __init__(self) -> None: - self.added: list[object] = [] - - def __enter__(self) -> "DummySession": - return self - - def __exit__(self, exc_type, exc, tb) -> bool: - return False - - def add(self, obj) -> None: - self.added.append(obj) - - def begin(self): - return DummyBegin(self) - - -class DummyBegin: - def __init__(self, session: DummySession) -> None: - self._session = session - - def __enter__(self) -> DummySession: - return self._session - - def __exit__(self, exc_type, exc, tb) -> bool: - return False - - -class DummySessionContext: - def __init__(self, session: DummySession) -> None: - self._session = session - - def __enter__(self) -> DummySession: - return self._session - - def __exit__(self, exc_type, exc, tb) -> bool: - return False - - -class DummySessionFactory: - def __init__(self, session: DummySession) -> None: - self._session = session - - def create_session(self) -> DummySessionContext: - return DummySessionContext(self._session) - - -def _build_fake_session(app) -> SimpleNamespace: - def query(model): - if model is WorkflowToolProvider: - return FakeQuery(None) - if model is App: - return FakeQuery(app) - return FakeQuery(None) - - return SimpleNamespace(query=query) - - -def _build_parameters() -> list[WorkflowToolParameterConfiguration]: - return [ - WorkflowToolParameterConfiguration(name="input", description="input", form=ToolParameter.ToolParameterForm.LLM), - ] - - -def test_create_workflow_tool_rejects_human_input_nodes(monkeypatch): - workflow = DummyWorkflow(graph_dict={"nodes": [{"id": "node_1", "data": {"type": "human-input"}}]}) - app = SimpleNamespace(workflow=workflow) - - fake_session = _build_fake_session(app) - monkeypatch.setattr(workflow_tools_manage_service.db, "session", fake_session) - - mock_from_db = MagicMock() - monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db) - mock_invalidate = MagicMock() - - with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info: - workflow_tools_manage_service.WorkflowToolManageService.create_workflow_tool( - user_id="user-id", - tenant_id="tenant-id", - workflow_app_id="app-id", - name="tool_name", - label="Tool", - icon={"type": "emoji", "emoji": "tool"}, - description="desc", - parameters=_build_parameters(), - ) - - assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" - mock_from_db.assert_not_called() - mock_invalidate.assert_not_called() - - -def test_create_workflow_tool_success(monkeypatch): - workflow = DummyWorkflow(graph_dict={"nodes": [{"id": "node_1", "data": {"type": "start"}}]}) - app = SimpleNamespace(workflow=workflow) - - fake_db = MagicMock() - fake_session = _build_fake_session(app) - fake_db.session = fake_session - monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) - - dummy_session = DummySession() - monkeypatch.setattr(workflow_tools_manage_service, "Session", lambda *_, **__: dummy_session) - - mock_from_db = MagicMock() - monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db) - - icon = {"type": "emoji", "emoji": "tool"} - - result = workflow_tools_manage_service.WorkflowToolManageService.create_workflow_tool( - user_id="user-id", - tenant_id="tenant-id", - workflow_app_id="app-id", - name="tool_name", - label="Tool", - icon=icon, - description="desc", - parameters=_build_parameters(), - ) - - assert result == {"result": "success"} - assert len(dummy_session.added) == 1 - created_provider = dummy_session.added[0] - assert created_provider.name == "tool_name" - assert created_provider.label == "Tool" - assert created_provider.icon == json.dumps(icon) - assert created_provider.version == workflow.version - mock_from_db.assert_called_once() diff --git a/api/tests/unit_tests/services/vector_service.py b/api/tests/unit_tests/services/vector_service.py index c99275c6b27..ee9ba1c6d68 100644 --- a/api/tests/unit_tests/services/vector_service.py +++ b/api/tests/unit_tests/services/vector_service.py @@ -121,6 +121,7 @@ import pytest from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.models.document import Document from models.dataset import ChildChunk, Dataset, DatasetDocument, DatasetProcessRule, DocumentSegment from services.vector_service import VectorService @@ -151,8 +152,8 @@ class VectorServiceTestDataFactory: def create_dataset_mock( dataset_id: str = "dataset-123", tenant_id: str = "tenant-123", - doc_form: str = "text_model", - indexing_technique: str = "high_quality", + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, embedding_model_provider: str = "openai", embedding_model: str = "text-embedding-ada-002", index_struct_dict: dict | None = None, @@ -493,7 +494,7 @@ class TestVectorService: """ # Arrange dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form="text_model", indexing_technique="high_quality" + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_technique=IndexTechniqueType.HIGH_QUALITY ) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -505,7 +506,7 @@ class TestVectorService: mock_index_processor_factory.return_value.init_index_processor.return_value = mock_index_processor # Act - VectorService.create_segments_vector(keywords_list, [segment], dataset, "text_model") + VectorService.create_segments_vector(keywords_list, [segment], dataset, IndexStructureType.PARAGRAPH_INDEX) # Assert mock_index_processor.load.assert_called_once() @@ -521,7 +522,7 @@ class TestVectorService: assert call_args[1]["keywords_list"] == keywords_list @patch("services.vector_service.VectorService.generate_child_chunks") - @patch("services.vector_service.ModelManager") + @patch("services.vector_service.ModelManager.for_tenant") @patch("services.vector_service.db") def test_create_segments_vector_parent_child_indexing( self, mock_db, mock_model_manager, mock_generate_child_chunks @@ -534,7 +535,7 @@ class TestVectorService: """ # Arrange dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form="parent_child_model", indexing_technique="high_quality" + doc_form="parent_child_model", indexing_technique=IndexTechniqueType.HIGH_QUALITY ) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -567,7 +568,7 @@ class TestVectorService: """ # Arrange dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form="parent_child_model", indexing_technique="high_quality" + doc_form="parent_child_model", indexing_technique=IndexTechniqueType.HIGH_QUALITY ) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -590,7 +591,7 @@ class TestVectorService: """ # Arrange dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form="parent_child_model", indexing_technique="high_quality" + doc_form="parent_child_model", indexing_technique=IndexTechniqueType.HIGH_QUALITY ) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -615,7 +616,7 @@ class TestVectorService: """ # Arrange dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form="parent_child_model", indexing_technique="economy" + doc_form="parent_child_model", indexing_technique=IndexTechniqueType.ECONOMY ) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -649,7 +650,7 @@ class TestVectorService: mock_index_processor_factory.return_value.init_index_processor.return_value = mock_index_processor # Act - VectorService.create_segments_vector(None, [], dataset, "text_model") + VectorService.create_segments_vector(None, [], dataset, IndexStructureType.PARAGRAPH_INDEX) # Assert mock_index_processor.load.assert_not_called() @@ -668,7 +669,7 @@ class TestVectorService: store when using high_quality indexing. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -694,7 +695,7 @@ class TestVectorService: index when using economy indexing with keywords. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -730,7 +731,7 @@ class TestVectorService: index when using economy indexing without keywords. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -894,7 +895,7 @@ class TestVectorService: when using high_quality indexing. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() @@ -922,7 +923,7 @@ class TestVectorService: using economy indexing. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() @@ -950,7 +951,7 @@ class TestVectorService: when there are new chunks, updated chunks, and deleted chunks. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) new_chunk = VectorServiceTestDataFactory.create_child_chunk_mock(chunk_id="new-chunk-1") @@ -992,7 +993,7 @@ class TestVectorService: add_texts is called, not delete_by_ids. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) new_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() @@ -1018,7 +1019,7 @@ class TestVectorService: delete_by_ids is called, not add_texts. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) delete_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() @@ -1044,7 +1045,7 @@ class TestVectorService: using economy indexing. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) new_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() @@ -1074,7 +1075,7 @@ class TestVectorService: when using high_quality indexing. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() @@ -1098,7 +1099,7 @@ class TestVectorService: using economy indexing. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() @@ -1754,7 +1755,7 @@ class TestVector: # ======================================================================== @patch("core.rag.datasource.vdb.vector_factory.CacheEmbedding") - @patch("core.rag.datasource.vdb.vector_factory.ModelManager") + @patch("core.rag.datasource.vdb.vector_factory.ModelManager.for_tenant") @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") def test_vector_get_embeddings(self, mock_init_vector, mock_model_manager, mock_cache_embedding): """ diff --git a/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py b/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py index f3391d6380c..8525672da8e 100644 --- a/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py +++ b/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py @@ -4,10 +4,12 @@ import json from unittest.mock import Mock, patch import pytest +from graphon.file import File, FileTransferMethod, FileType +from graphon.variables.segments import ObjectSegment, StringSegment +from graphon.variables.types import SegmentType from sqlalchemy import Engine -from dify_graph.variables.segments import ObjectSegment, StringSegment -from dify_graph.variables.types import SegmentType +from core.workflow.file_reference import build_file_reference from models.model import UploadFile from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile from services.workflow_draft_variable_service import DraftVarLoader @@ -54,25 +56,18 @@ class TestDraftVarLoaderSimple: with patch("services.workflow_draft_variable_service.storage") as mock_storage: mock_storage.load.return_value = test_content.encode() - with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable: - mock_variable = Mock() - mock_variable.id = "draft-var-id" - mock_variable.name = "test_variable" - mock_variable.value = StringSegment(value=test_content) - mock_segment_to_variable.return_value = mock_variable + # Execute the method + selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var) - # Execute the method - selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var) + # Verify results + assert selector_tuple == ("test-node-id", "test_variable") + assert variable.id == "draft-var-id" + assert variable.name == "test_variable" + assert variable.description == "test description" + assert variable.value == test_content - # Verify results - assert selector_tuple == ("test-node-id", "test_variable") - assert variable.id == "draft-var-id" - assert variable.name == "test_variable" - assert variable.description == "test description" - assert variable.value == test_content - - # Verify storage was called correctly - mock_storage.load.assert_called_once_with("storage/key/test.txt") + # Verify storage was called correctly + mock_storage.load.assert_called_once_with("storage/key/test.txt") def test_load_offloaded_variable_object_type_unit(self, draft_var_loader): """Test _load_offloaded_variable with object type - isolated unit test.""" @@ -97,31 +92,22 @@ class TestDraftVarLoaderSimple: with patch("services.workflow_draft_variable_service.storage") as mock_storage: mock_storage.load.return_value = test_json_content.encode() + mock_segment = ObjectSegment(value=test_object) + draft_var.build_segment_from_serialized_value.return_value = mock_segment - with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment: - mock_segment = ObjectSegment(value=test_object) - mock_build_segment.return_value = mock_segment + # Execute the method + selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var) - with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable: - mock_variable = Mock() - mock_variable.id = "draft-var-id" - mock_variable.name = "test_object" - mock_variable.value = mock_segment - mock_segment_to_variable.return_value = mock_variable + # Verify results + assert selector_tuple == ("test-node-id", "test_object") + assert variable.id == "draft-var-id" + assert variable.name == "test_object" + assert variable.description == "test description" + assert variable.value == test_object - # Execute the method - selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var) - - # Verify results - assert selector_tuple == ("test-node-id", "test_object") - assert variable.id == "draft-var-id" - assert variable.name == "test_object" - assert variable.description == "test description" - assert variable.value == test_object - - # Verify method calls - mock_storage.load.assert_called_once_with("storage/key/test.json") - mock_build_segment.assert_called_once_with(SegmentType.OBJECT, test_object) + # Verify method calls + mock_storage.load.assert_called_once_with("storage/key/test.json") + draft_var.build_segment_from_serialized_value.assert_called_once_with(SegmentType.OBJECT, test_object) def test_load_offloaded_variable_missing_variable_file_unit(self, draft_var_loader): """Test that assertion error is raised when variable_file is None.""" @@ -176,32 +162,23 @@ class TestDraftVarLoaderSimple: with patch("services.workflow_draft_variable_service.storage") as mock_storage: mock_storage.load.return_value = test_json_content.encode() + from graphon.variables.segments import FloatSegment - with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment: - from dify_graph.variables.segments import FloatSegment + mock_segment = FloatSegment(value=test_number) + draft_var.build_segment_from_serialized_value.return_value = mock_segment - mock_segment = FloatSegment(value=test_number) - mock_build_segment.return_value = mock_segment + # Execute the method + selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var) - with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable: - mock_variable = Mock() - mock_variable.id = "draft-var-id" - mock_variable.name = "test_number" - mock_variable.value = mock_segment - mock_segment_to_variable.return_value = mock_variable + # Verify results + assert selector_tuple == ("test-node-id", "test_number") + assert variable.id == "draft-var-id" + assert variable.name == "test_number" + assert variable.description == "test number description" - # Execute the method - selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var) - - # Verify results - assert selector_tuple == ("test-node-id", "test_number") - assert variable.id == "draft-var-id" - assert variable.name == "test_number" - assert variable.description == "test number description" - - # Verify method calls - mock_storage.load.assert_called_once_with("storage/key/test_number.json") - mock_build_segment.assert_called_once_with(SegmentType.NUMBER, test_number) + # Verify method calls + mock_storage.load.assert_called_once_with("storage/key/test_number.json") + draft_var.build_segment_from_serialized_value.assert_called_once_with(SegmentType.NUMBER, test_number) def test_load_offloaded_variable_array_type_unit(self, draft_var_loader): """Test _load_offloaded_variable with array type - isolated unit test.""" @@ -226,32 +203,83 @@ class TestDraftVarLoaderSimple: with patch("services.workflow_draft_variable_service.storage") as mock_storage: mock_storage.load.return_value = test_json_content.encode() + from graphon.variables.segments import ArrayAnySegment - with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment: - from dify_graph.variables.segments import ArrayAnySegment + mock_segment = ArrayAnySegment(value=test_array) + draft_var.build_segment_from_serialized_value.return_value = mock_segment - mock_segment = ArrayAnySegment(value=test_array) - mock_build_segment.return_value = mock_segment + # Execute the method + selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var) - with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable: - mock_variable = Mock() - mock_variable.id = "draft-var-id" - mock_variable.name = "test_array" - mock_variable.value = mock_segment - mock_segment_to_variable.return_value = mock_variable + # Verify results + assert selector_tuple == ("test-node-id", "test_array") + assert variable.id == "draft-var-id" + assert variable.name == "test_array" + assert variable.description == "test array description" - # Execute the method - selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var) + # Verify method calls + mock_storage.load.assert_called_once_with("storage/key/test_array.json") + draft_var.build_segment_from_serialized_value.assert_called_once_with(SegmentType.ARRAY_ANY, test_array) - # Verify results - assert selector_tuple == ("test-node-id", "test_array") - assert variable.id == "draft-var-id" - assert variable.name == "test_array" - assert variable.description == "test array description" + def test_load_offloaded_variable_file_type_rebuilds_storage_backed_payload(self, draft_var_loader): + upload_file = Mock(spec=UploadFile) + upload_file.key = "storage/key/test_file.json" - # Verify method calls - mock_storage.load.assert_called_once_with("storage/key/test_array.json") - mock_build_segment.assert_called_once_with(SegmentType.ARRAY_ANY, test_array) + variable_file = Mock(spec=WorkflowDraftVariableFile) + variable_file.value_type = SegmentType.FILE + variable_file.upload_file = upload_file + + draft_var = WorkflowDraftVariable() + draft_var.id = "draft-var-id" + draft_var.app_id = "app-1" + draft_var.node_id = "test-node-id" + draft_var.name = "test_file" + draft_var.description = "test file description" + draft_var._set_selector(["test-node-id", "test_file"]) + draft_var.variable_file = variable_file + + persisted_file = File( + id="file-1", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.LOCAL_FILE, + reference=build_file_reference(record_id="upload-1", storage_key="legacy-storage-key"), + filename="test.txt", + extension=".txt", + mime_type="text/plain", + size=12, + ) + rebuilt_file = File( + id="file-1", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.LOCAL_FILE, + reference=build_file_reference(record_id="upload-1"), + filename="test.txt", + extension=".txt", + mime_type="text/plain", + size=12, + storage_key="canonical-storage-key", + ) + + raw_file = { + **persisted_file.model_dump(mode="json"), + "tenant_id": "legacy-tenant", + } + + with ( + patch("services.workflow_draft_variable_service.storage") as mock_storage, + patch("models.workflow._resolve_workflow_app_tenant_id", return_value="tenant-1"), + patch("models.workflow.build_file_from_stored_mapping", return_value=rebuilt_file) as rebuild_file, + ): + mock_storage.load.return_value = json.dumps(raw_file).encode() + + selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var) + + assert selector_tuple == ("test-node-id", "test_file") + assert variable.id == "draft-var-id" + assert variable.name == "test_file" + assert variable.description == "test file description" + assert variable.value == rebuilt_file + rebuild_file.assert_called_once_with(file_mapping=raw_file, tenant_id="tenant-1") def test_load_variables_with_offloaded_variables_unit(self, draft_var_loader): """Test load_variables method with mix of regular and offloaded variables.""" diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py deleted file mode 100644 index a847c2b4d1a..00000000000 --- a/api/tests/unit_tests/services/workflow/test_workflow_converter.py +++ /dev/null @@ -1,431 +0,0 @@ -# test for api/services/workflow/workflow_converter.py -import json -from unittest.mock import MagicMock - -import pytest - -from core.app.app_config.entities import ( - AdvancedChatMessageEntity, - AdvancedChatPromptTemplateEntity, - AdvancedCompletionPromptTemplateEntity, - DatasetEntity, - DatasetRetrieveConfigEntity, - ExternalDataVariableEntity, - ModelConfigEntity, - PromptTemplateEntity, -) -from core.helper import encrypter -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType -from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint -from models.model import AppMode -from services.workflow.workflow_converter import WorkflowConverter - - -@pytest.fixture -def default_variables(): - value = [ - VariableEntity( - variable="text_input", - label="text-input", - type=VariableEntityType.TEXT_INPUT, - ), - VariableEntity( - variable="paragraph", - label="paragraph", - type=VariableEntityType.PARAGRAPH, - ), - VariableEntity( - variable="select", - label="select", - type=VariableEntityType.SELECT, - ), - ] - return value - - -def test__convert_to_start_node(default_variables): - # act - result = WorkflowConverter()._convert_to_start_node(default_variables) - - # assert - assert isinstance(result["data"]["variables"][0]["type"], str) - assert result["data"]["variables"][0]["type"] == "text-input" - assert result["data"]["variables"][0]["variable"] == "text_input" - assert result["data"]["variables"][1]["variable"] == "paragraph" - assert result["data"]["variables"][2]["variable"] == "select" - - -def test__convert_to_http_request_node_for_chatbot(default_variables): - """ - Test convert to http request nodes for chatbot - :return: - """ - app_model = MagicMock() - app_model.id = "app_id" - app_model.tenant_id = "tenant_id" - app_model.mode = AppMode.CHAT - - api_based_extension_id = "api_based_extension_id" - mock_api_based_extension = APIBasedExtension( - tenant_id="tenant_id", - name="api-1", - api_key="encrypted_api_key", - api_endpoint="https://dify.ai", - ) - - mock_api_based_extension.id = api_based_extension_id - workflow_converter = WorkflowConverter() - workflow_converter._get_api_based_extension = MagicMock(return_value=mock_api_based_extension) - - encrypter.decrypt_token = MagicMock(return_value="api_key") - - external_data_variables = [ - ExternalDataVariableEntity( - variable="external_variable", type="api", config={"api_based_extension_id": api_based_extension_id} - ) - ] - - nodes, _ = workflow_converter._convert_to_http_request_node( - app_model=app_model, variables=default_variables, external_data_variables=external_data_variables - ) - - assert len(nodes) == 2 - assert nodes[0]["data"]["type"] == "http-request" - - http_request_node = nodes[0] - - assert http_request_node["data"]["method"] == "post" - assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint - assert http_request_node["data"]["authorization"]["type"] == "api-key" - assert http_request_node["data"]["authorization"]["config"] == {"type": "bearer", "api_key": "api_key"} - assert http_request_node["data"]["body"]["type"] == "json" - - body_data = http_request_node["data"]["body"]["data"] - - assert body_data - - body_data_json = json.loads(body_data) - assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY - - body_params = body_data_json["params"] - assert body_params["app_id"] == app_model.id - assert body_params["tool_variable"] == external_data_variables[0].variable - assert len(body_params["inputs"]) == 3 - assert body_params["query"] == "{{#sys.query#}}" # for chatbot - - code_node = nodes[1] - assert code_node["data"]["type"] == "code" - - -def test__convert_to_http_request_node_for_workflow_app(default_variables): - """ - Test convert to http request nodes for workflow app - :return: - """ - app_model = MagicMock() - app_model.id = "app_id" - app_model.tenant_id = "tenant_id" - app_model.mode = AppMode.WORKFLOW - - api_based_extension_id = "api_based_extension_id" - mock_api_based_extension = APIBasedExtension( - tenant_id="tenant_id", - name="api-1", - api_key="encrypted_api_key", - api_endpoint="https://dify.ai", - ) - mock_api_based_extension.id = api_based_extension_id - - workflow_converter = WorkflowConverter() - workflow_converter._get_api_based_extension = MagicMock(return_value=mock_api_based_extension) - - encrypter.decrypt_token = MagicMock(return_value="api_key") - - external_data_variables = [ - ExternalDataVariableEntity( - variable="external_variable", type="api", config={"api_based_extension_id": api_based_extension_id} - ) - ] - - nodes, _ = workflow_converter._convert_to_http_request_node( - app_model=app_model, variables=default_variables, external_data_variables=external_data_variables - ) - - assert len(nodes) == 2 - assert nodes[0]["data"]["type"] == "http-request" - - http_request_node = nodes[0] - - assert http_request_node["data"]["method"] == "post" - assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint - assert http_request_node["data"]["authorization"]["type"] == "api-key" - assert http_request_node["data"]["authorization"]["config"] == {"type": "bearer", "api_key": "api_key"} - assert http_request_node["data"]["body"]["type"] == "json" - - body_data = http_request_node["data"]["body"]["data"] - - assert body_data - - body_data_json = json.loads(body_data) - assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY - - body_params = body_data_json["params"] - assert body_params["app_id"] == app_model.id - assert body_params["tool_variable"] == external_data_variables[0].variable - assert len(body_params["inputs"]) == 3 - assert body_params["query"] == "" - - code_node = nodes[1] - assert code_node["data"]["type"] == "code" - - -def test__convert_to_knowledge_retrieval_node_for_chatbot(): - new_app_mode = AppMode.ADVANCED_CHAT - - dataset_config = DatasetEntity( - dataset_ids=["dataset_id_1", "dataset_id_2"], - retrieve_config=DatasetRetrieveConfigEntity( - retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, - top_k=5, - score_threshold=0.8, - reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"}, - reranking_enabled=True, - ), - ) - - model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[]) - - node = WorkflowConverter()._convert_to_knowledge_retrieval_node( - new_app_mode=new_app_mode, dataset_config=dataset_config, model_config=model_config - ) - assert node is not None - - assert node["data"]["type"] == "knowledge-retrieval" - assert node["data"]["query_variable_selector"] == ["sys", "query"] - assert node["data"]["dataset_ids"] == dataset_config.dataset_ids - assert node["data"]["retrieval_mode"] == dataset_config.retrieve_config.retrieve_strategy.value - assert node["data"]["multiple_retrieval_config"] == { - "top_k": dataset_config.retrieve_config.top_k, - "score_threshold": dataset_config.retrieve_config.score_threshold, - "reranking_model": dataset_config.retrieve_config.reranking_model, - } - - -def test__convert_to_knowledge_retrieval_node_for_workflow_app(): - new_app_mode = AppMode.WORKFLOW - - dataset_config = DatasetEntity( - dataset_ids=["dataset_id_1", "dataset_id_2"], - retrieve_config=DatasetRetrieveConfigEntity( - query_variable="query", - retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, - top_k=5, - score_threshold=0.8, - reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"}, - reranking_enabled=True, - ), - ) - - model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[]) - - node = WorkflowConverter()._convert_to_knowledge_retrieval_node( - new_app_mode=new_app_mode, dataset_config=dataset_config, model_config=model_config - ) - assert node is not None - - assert node["data"]["type"] == "knowledge-retrieval" - assert node["data"]["query_variable_selector"] == ["start", dataset_config.retrieve_config.query_variable] - assert node["data"]["dataset_ids"] == dataset_config.dataset_ids - assert node["data"]["retrieval_mode"] == dataset_config.retrieve_config.retrieve_strategy.value - assert node["data"]["multiple_retrieval_config"] == { - "top_k": dataset_config.retrieve_config.top_k, - "score_threshold": dataset_config.retrieve_config.score_threshold, - "reranking_model": dataset_config.retrieve_config.reranking_model, - } - - -def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables): - new_app_mode = AppMode.ADVANCED_CHAT - model = "gpt-4" - model_mode = LLMMode.CHAT - - workflow_converter = WorkflowConverter() - start_node = workflow_converter._convert_to_start_node(default_variables) - graph = { - "nodes": [start_node], - "edges": [], # no need - } - - model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = "openai" - model_config_mock.model = model - model_config_mock.mode = model_mode.value - model_config_mock.parameters = {} - model_config_mock.stop = [] - - prompt_template = PromptTemplateEntity( - prompt_type=PromptTemplateEntity.PromptType.SIMPLE, - simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}.", - ) - - llm_node = workflow_converter._convert_to_llm_node( - original_app_mode=AppMode.CHAT, - new_app_mode=new_app_mode, - model_config=model_config_mock, - graph=graph, - prompt_template=prompt_template, - ) - - assert llm_node["data"]["type"] == "llm" - assert llm_node["data"]["model"]["name"] == model - assert llm_node["data"]["model"]["mode"] == model_mode.value - template = prompt_template.simple_prompt_template - assert template is not None - for v in default_variables: - template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}") - assert llm_node["data"]["prompt_template"][0]["text"] == template + "\n" - assert llm_node["data"]["context"]["enabled"] is False - - -def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variables): - new_app_mode = AppMode.ADVANCED_CHAT - model = "gpt-3.5-turbo-instruct" - model_mode = LLMMode.COMPLETION - - workflow_converter = WorkflowConverter() - start_node = workflow_converter._convert_to_start_node(default_variables) - graph = { - "nodes": [start_node], - "edges": [], # no need - } - - model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = "openai" - model_config_mock.model = model - model_config_mock.mode = model_mode.value - model_config_mock.parameters = {} - model_config_mock.stop = [] - - prompt_template = PromptTemplateEntity( - prompt_type=PromptTemplateEntity.PromptType.SIMPLE, - simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}.", - ) - - llm_node = workflow_converter._convert_to_llm_node( - original_app_mode=AppMode.CHAT, - new_app_mode=new_app_mode, - model_config=model_config_mock, - graph=graph, - prompt_template=prompt_template, - ) - - assert llm_node["data"]["type"] == "llm" - assert llm_node["data"]["model"]["name"] == model - assert llm_node["data"]["model"]["mode"] == model_mode.value - template = prompt_template.simple_prompt_template - assert template is not None - for v in default_variables: - template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}") - assert llm_node["data"]["prompt_template"]["text"] == template + "\n" - assert llm_node["data"]["context"]["enabled"] is False - - -def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables): - new_app_mode = AppMode.ADVANCED_CHAT - model = "gpt-4" - model_mode = LLMMode.CHAT - - workflow_converter = WorkflowConverter() - start_node = workflow_converter._convert_to_start_node(default_variables) - graph = { - "nodes": [start_node], - "edges": [], # no need - } - - model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = "openai" - model_config_mock.model = model - model_config_mock.mode = model_mode.value - model_config_mock.parameters = {} - model_config_mock.stop = [] - - prompt_template = PromptTemplateEntity( - prompt_type=PromptTemplateEntity.PromptType.ADVANCED, - advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity( - messages=[ - AdvancedChatMessageEntity( - text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", - role=PromptMessageRole.SYSTEM, - ), - AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER), - AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT), - ] - ), - ) - - llm_node = workflow_converter._convert_to_llm_node( - original_app_mode=AppMode.CHAT, - new_app_mode=new_app_mode, - model_config=model_config_mock, - graph=graph, - prompt_template=prompt_template, - ) - - assert llm_node["data"]["type"] == "llm" - assert llm_node["data"]["model"]["name"] == model - assert llm_node["data"]["model"]["mode"] == model_mode.value - assert isinstance(llm_node["data"]["prompt_template"], list) - assert prompt_template.advanced_chat_prompt_template is not None - assert len(llm_node["data"]["prompt_template"]) == len(prompt_template.advanced_chat_prompt_template.messages) - template = prompt_template.advanced_chat_prompt_template.messages[0].text - for v in default_variables: - template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}") - assert llm_node["data"]["prompt_template"][0]["text"] == template - - -def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_variables): - new_app_mode = AppMode.ADVANCED_CHAT - model = "gpt-3.5-turbo-instruct" - model_mode = LLMMode.COMPLETION - - workflow_converter = WorkflowConverter() - start_node = workflow_converter._convert_to_start_node(default_variables) - graph = { - "nodes": [start_node], - "edges": [], # no need - } - - model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = "openai" - model_config_mock.model = model - model_config_mock.mode = model_mode.value - model_config_mock.parameters = {} - model_config_mock.stop = [] - - prompt_template = PromptTemplateEntity( - prompt_type=PromptTemplateEntity.PromptType.ADVANCED, - advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity( - prompt="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}\n\nHuman: hi\nAssistant: ", - role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(user="Human", assistant="Assistant"), - ), - ) - - llm_node = workflow_converter._convert_to_llm_node( - original_app_mode=AppMode.CHAT, - new_app_mode=new_app_mode, - model_config=model_config_mock, - graph=graph, - prompt_template=prompt_template, - ) - - assert llm_node["data"]["type"] == "llm" - assert llm_node["data"]["model"]["name"] == model - assert llm_node["data"]["model"]["mode"] == model_mode.value - assert isinstance(llm_node["data"]["prompt_template"], dict) - assert prompt_template.advanced_completion_prompt_template is not None - template = prompt_template.advanced_completion_prompt_template.prompt - for v in default_variables: - template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}") - assert llm_node["data"]["prompt_template"]["text"] == template diff --git a/api/tests/unit_tests/services/workflow/test_workflow_deletion.py b/api/tests/unit_tests/services/workflow/test_workflow_deletion.py deleted file mode 100644 index dfe325648d2..00000000000 --- a/api/tests/unit_tests/services/workflow/test_workflow_deletion.py +++ /dev/null @@ -1,127 +0,0 @@ -from unittest.mock import MagicMock - -import pytest -from sqlalchemy.orm import Session - -from models.model import App -from models.workflow import Workflow -from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService - - -@pytest.fixture -def workflow_setup(): - mock_session_maker = MagicMock() - workflow_service = WorkflowService(mock_session_maker) - session = MagicMock(spec=Session) - tenant_id = "test-tenant-id" - workflow_id = "test-workflow-id" - - # Mock workflow - workflow = MagicMock(spec=Workflow) - workflow.id = workflow_id - workflow.tenant_id = tenant_id - workflow.version = "1.0" # Not a draft - workflow.tool_published = False # Not published as a tool by default - - # Mock app - app = MagicMock(spec=App) - app.id = "test-app-id" - app.name = "Test App" - app.workflow_id = None # Not used by an app by default - - return { - "workflow_service": workflow_service, - "session": session, - "tenant_id": tenant_id, - "workflow_id": workflow_id, - "workflow": workflow, - "app": app, - } - - -def test_delete_workflow_success(workflow_setup): - # Setup mocks - - # Mock the tool provider query to return None (not published as a tool) - workflow_setup["session"].query.return_value.where.return_value.first.return_value = None - - workflow_setup["session"].scalar = MagicMock( - side_effect=[workflow_setup["workflow"], None] - ) # Return workflow first, then None for app - - # Call the method - result = workflow_setup["workflow_service"].delete_workflow( - session=workflow_setup["session"], - workflow_id=workflow_setup["workflow_id"], - tenant_id=workflow_setup["tenant_id"], - ) - - # Verify - assert result is True - workflow_setup["session"].delete.assert_called_once_with(workflow_setup["workflow"]) - - -def test_delete_workflow_draft_error(workflow_setup): - # Setup mocks - workflow_setup["workflow"].version = "draft" - workflow_setup["session"].scalar = MagicMock(return_value=workflow_setup["workflow"]) - - # Call the method and verify exception - with pytest.raises(DraftWorkflowDeletionError): - workflow_setup["workflow_service"].delete_workflow( - session=workflow_setup["session"], - workflow_id=workflow_setup["workflow_id"], - tenant_id=workflow_setup["tenant_id"], - ) - - # Verify - workflow_setup["session"].delete.assert_not_called() - - -def test_delete_workflow_in_use_by_app_error(workflow_setup): - # Setup mocks - workflow_setup["app"].workflow_id = workflow_setup["workflow_id"] - workflow_setup["session"].scalar = MagicMock( - side_effect=[workflow_setup["workflow"], workflow_setup["app"]] - ) # Return workflow first, then app - - # Call the method and verify exception - with pytest.raises(WorkflowInUseError) as excinfo: - workflow_setup["workflow_service"].delete_workflow( - session=workflow_setup["session"], - workflow_id=workflow_setup["workflow_id"], - tenant_id=workflow_setup["tenant_id"], - ) - - # Verify error message contains app name - assert "Cannot delete workflow that is currently in use by app" in str(excinfo.value) - - # Verify - workflow_setup["session"].delete.assert_not_called() - - -def test_delete_workflow_published_as_tool_error(workflow_setup): - # Setup mocks - from models.tools import WorkflowToolProvider - - # Mock the tool provider query - mock_tool_provider = MagicMock(spec=WorkflowToolProvider) - workflow_setup["session"].query.return_value.where.return_value.first.return_value = mock_tool_provider - - workflow_setup["session"].scalar = MagicMock( - side_effect=[workflow_setup["workflow"], None] - ) # Return workflow first, then None for app - - # Call the method and verify exception - with pytest.raises(WorkflowInUseError) as excinfo: - workflow_setup["workflow_service"].delete_workflow( - session=workflow_setup["session"], - workflow_id=workflow_setup["workflow_id"], - tenant_id=workflow_setup["tenant_id"], - ) - - # Verify error message - assert "Cannot delete workflow that is published as a tool" in str(excinfo.value) - - # Verify - workflow_setup["session"].delete.assert_not_called() diff --git a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py index 0c2be9c79fd..e7e72793a32 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py @@ -4,13 +4,19 @@ import uuid from unittest.mock import MagicMock, Mock, patch import pytest +from graphon.enums import BuiltinNodeTypes +from graphon.file import File, FileTransferMethod, FileType +from graphon.variables.segments import StringSegment +from graphon.variables.types import SegmentType from sqlalchemy import Engine from sqlalchemy.orm import Session -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from dify_graph.enums import BuiltinNodeTypes, SystemVariableKey -from dify_graph.variables.segments import StringSegment -from dify_graph.variables.types import SegmentType +from core.workflow.system_variables import SystemVariableKey +from core.workflow.variable_prefixes import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, + SYSTEM_VARIABLE_NODE_ID, +) from libs.uuid_utils import uuidv7 from models.account import Account from models.enums import DraftVariableType @@ -86,6 +92,20 @@ class TestDraftVariableSaver: expected_node_id=_NODE_ID, expected_name="start_input", ), + TestCase( + name="name with `env.` prefix should return the environment node_id", + input_node_id=_NODE_ID, + input_name="env.API_KEY", + expected_node_id=ENVIRONMENT_VARIABLE_NODE_ID, + expected_name="API_KEY", + ), + TestCase( + name="name with `conversation.` prefix should return the conversation node_id", + input_node_id=_NODE_ID, + input_name="conversation.session_id", + expected_node_id=CONVERSATION_VARIABLE_NODE_ID, + expected_name="session_id", + ), TestCase( name="dummy_variable should return the original input node_id", input_node_id=_NODE_ID, @@ -112,6 +132,47 @@ class TestDraftVariableSaver: assert node_id == c.expected_node_id, fail_msg assert name == c.expected_name, fail_msg + def test_build_variables_from_start_mapping_rebuilds_system_files(self): + mock_session = MagicMock(spec=Session) + mock_user = MagicMock(spec=Account) + mock_user.id = str(uuid.uuid4()) + saver = DraftVariableSaver( + session=mock_session, + app_id=self._get_test_app_id(), + node_id="start", + node_type=BuiltinNodeTypes.START, + node_execution_id="exec-1", + user=mock_user, + ) + rebuilt_file = File( + id="file-1", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.LOCAL_FILE, + reference="upload-1", + filename="test.txt", + extension=".txt", + mime_type="text/plain", + size=12, + storage_key="canonical-storage-key", + ) + raw_file = { + **rebuilt_file.model_dump(mode="json"), + "tenant_id": "legacy-tenant", + } + + with ( + patch.object(saver, "_resolve_app_tenant_id", return_value="tenant-1"), + patch( + "services.workflow_draft_variable_service.build_file_from_stored_mapping", + return_value=rebuilt_file, + ) as rebuild_file, + ): + draft_vars = saver._build_variables_from_start_mapping({"sys.files": [raw_file]}) + + sys_var = draft_vars[0] + assert sys_var.get_value().value[0] == rebuilt_file + rebuild_file.assert_called_once_with(file_mapping=raw_file, tenant_id="tenant-1") + @pytest.fixture def mock_session(self): """Mock SQLAlchemy session.""" @@ -218,6 +279,46 @@ class TestDraftVariableSaver: str(SystemVariableKey.WORKFLOW_EXECUTION_ID), } + @patch("services.workflow_draft_variable_service._batch_upsert_draft_variable", autospec=True) + def test_start_node_save_normalizes_reserved_prefix_outputs(self, mock_batch_upsert): + mock_session = MagicMock(spec=Session) + mock_user = MagicMock(spec=Account) + mock_user.id = "test-user-id" + mock_user.tenant_id = "test-tenant-id" + + saver = DraftVariableSaver( + session=mock_session, + app_id="test-app-id", + node_id="start-node-id", + node_type=BuiltinNodeTypes.START, + node_execution_id="exec-id", + user=mock_user, + ) + + saver.save( + outputs={ + "env.API_KEY": "secret", + "conversation.session_id": "conversation-1", + "sys.workflow_run_id": "run-id-123", + } + ) + + mock_batch_upsert.assert_called_once() + draft_vars = mock_batch_upsert.call_args[0][1] + + assert len(draft_vars) == 3 + + env_var = next(v for v in draft_vars if v.node_id == ENVIRONMENT_VARIABLE_NODE_ID) + assert env_var.name == "API_KEY" + assert env_var.editable is False + + conversation_var = next(v for v in draft_vars if v.node_id == CONVERSATION_VARIABLE_NODE_ID) + assert conversation_var.name == "session_id" + assert conversation_var.node_execution_id is None + + sys_var = next(v for v in draft_vars if v.node_id == SYSTEM_VARIABLE_NODE_ID) + assert sys_var.name == str(SystemVariableKey.WORKFLOW_EXECUTION_ID) + class TestWorkflowDraftVariableService: def _get_test_app_id(self): diff --git a/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py b/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py index 6c1adba2b8b..077a7c27a2b 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py @@ -8,13 +8,13 @@ from datetime import UTC, datetime from threading import Event import pytest +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus -from dify_graph.runtime import GraphRuntimeState, VariablePool from models.enums import CreatorUserRole from models.model import AppMode from models.workflow import WorkflowRun diff --git a/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py b/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py index c890ab6a65b..98d057e41fe 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py @@ -3,16 +3,16 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.human_input.entities import HumanInputNodeData from sqlalchemy.orm import sessionmaker -from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - HumanInputNodeData, MemberRecipient, ) from services import workflow_service as workflow_service_module @@ -23,7 +23,7 @@ def _make_service() -> WorkflowService: return WorkflowService(session_maker=sessionmaker()) -def _build_node_config(delivery_methods: list[EmailDeliveryMethod]) -> NodeConfigDict: +def _build_node_config(delivery_methods: list[EmailDeliveryMethod], *, legacy: bool = False) -> NodeConfigDict: node_data = HumanInputNodeData( title="Human Input", delivery_methods=delivery_methods, @@ -31,6 +31,14 @@ def _build_node_config(delivery_methods: list[EmailDeliveryMethod]) -> NodeConfi inputs=[], user_actions=[], ).model_dump(mode="json") + if legacy: + for delivery_method in node_data["delivery_methods"]: + recipients = delivery_method.get("config", {}).get("recipients", {}) + if "include_bound_group" in recipients: + recipients["whole_workspace"] = recipients.pop("include_bound_group") + for recipient in recipients.get("items", []): + if "reference_id" in recipient: + recipient["user_id"] = recipient.pop("reference_id") node_data["type"] = BuiltinNodeTypes.HUMAN_INPUT return NodeConfigDictAdapter.validate_python({"id": "node-1", "data": node_data}) @@ -41,7 +49,7 @@ def _make_email_method(enabled: bool = True, debug_mode: bool = False) -> EmailD enabled=enabled, config=EmailDeliveryConfig( recipients=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ExternalRecipient(email="tester@example.com")], ), subject="Test subject", @@ -69,7 +77,7 @@ def test_human_input_delivery_requires_draft_workflow(): def test_human_input_delivery_allows_disabled_method(monkeypatch: pytest.MonkeyPatch): service = _make_service() delivery_method = _make_email_method(enabled=False) - node_config = _build_node_config([delivery_method]) + node_config = _build_node_config([delivery_method], legacy=True) workflow = MagicMock() workflow.get_node_config_by_id.return_value = node_config service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] @@ -105,7 +113,7 @@ def test_human_input_delivery_allows_disabled_method(monkeypatch: pytest.MonkeyP def test_human_input_delivery_dispatches_to_test_service(monkeypatch: pytest.MonkeyPatch): service = _make_service() delivery_method = _make_email_method(enabled=True) - node_config = _build_node_config([delivery_method]) + node_config = _build_node_config([delivery_method], legacy=True) workflow = MagicMock() workflow.get_node_config_by_id.return_value = node_config service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] @@ -144,7 +152,7 @@ def test_human_input_delivery_dispatches_to_test_service(monkeypatch: pytest.Mon def test_human_input_delivery_debug_mode_overrides_recipients(monkeypatch: pytest.MonkeyPatch): service = _make_service() delivery_method = _make_email_method(enabled=True, debug_mode=True) - node_config = _build_node_config([delivery_method]) + node_config = _build_node_config([delivery_method], legacy=True) workflow = MagicMock() workflow.get_node_config_by_id.return_value = node_config service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] @@ -178,8 +186,8 @@ def test_human_input_delivery_debug_mode_overrides_recipients(monkeypatch: pytes sent_method = test_service_instance.send_test.call_args.kwargs["method"] assert isinstance(sent_method, EmailDeliveryMethod) assert sent_method.config.debug_mode is True - assert sent_method.config.recipients.whole_workspace is False + assert sent_method.config.recipients.include_bound_group is False assert len(sent_method.config.recipients.items) == 1 recipient = sent_method.config.recipients.items[0] assert isinstance(recipient, MemberRecipient) - assert recipient.user_id == account.id + assert recipient.reference_id == account.id diff --git a/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py deleted file mode 100644 index 79bf5e94c28..00000000000 --- a/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py +++ /dev/null @@ -1,30 +0,0 @@ -from unittest.mock import MagicMock - -import pytest - -from repositories.sqlalchemy_api_workflow_node_execution_repository import ( - DifyAPISQLAlchemyWorkflowNodeExecutionRepository, -) - - -class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: - @pytest.fixture - def repository(self): - mock_session_maker = MagicMock() - return DifyAPISQLAlchemyWorkflowNodeExecutionRepository(session_maker=mock_session_maker) - - def test_repository_implements_protocol(self, repository): - """Test that the repository implements the required protocol methods.""" - # Verify all protocol methods are implemented - assert hasattr(repository, "get_node_last_execution") - assert hasattr(repository, "get_executions_by_workflow_run") - assert hasattr(repository, "get_execution_by_id") - - # Verify methods are callable - assert callable(repository.get_node_last_execution) - assert callable(repository.get_executions_by_workflow_run) - assert callable(repository.get_execution_by_id) - assert callable(repository.delete_expired_executions) - assert callable(repository.delete_executions_by_app) - assert callable(repository.get_expired_executions_batch) - assert callable(repository.delete_executions_by_ids) diff --git a/api/tests/unit_tests/services/workflow/test_workflow_restore.py b/api/tests/unit_tests/services/workflow/test_workflow_restore.py new file mode 100644 index 00000000000..179361de452 --- /dev/null +++ b/api/tests/unit_tests/services/workflow/test_workflow_restore.py @@ -0,0 +1,77 @@ +import json +from types import SimpleNamespace + +from models.workflow import Workflow +from services.workflow_restore import apply_published_workflow_snapshot_to_draft + +LEGACY_FEATURES = { + "file_upload": { + "image": { + "enabled": True, + "number_limits": 6, + "transfer_methods": ["remote_url", "local_file"], + } + }, + "opening_statement": "", + "retriever_resource": {"enabled": True}, + "sensitive_word_avoidance": {"enabled": False}, + "speech_to_text": {"enabled": False}, + "suggested_questions": [], + "suggested_questions_after_answer": {"enabled": False}, + "text_to_speech": {"enabled": False, "language": "", "voice": ""}, +} + +NORMALIZED_FEATURES = { + "file_upload": { + "enabled": True, + "allowed_file_types": ["image"], + "allowed_file_extensions": [], + "allowed_file_upload_methods": ["remote_url", "local_file"], + "number_limits": 6, + }, + "opening_statement": "", + "retriever_resource": {"enabled": True}, + "sensitive_word_avoidance": {"enabled": False}, + "speech_to_text": {"enabled": False}, + "suggested_questions": [], + "suggested_questions_after_answer": {"enabled": False}, + "text_to_speech": {"enabled": False, "language": "", "voice": ""}, +} + + +def _create_workflow(*, workflow_id: str, version: str, features: dict[str, object]) -> Workflow: + return Workflow( + id=workflow_id, + tenant_id="tenant-id", + app_id="app-id", + type="workflow", + version=version, + graph=json.dumps({"nodes": [], "edges": []}), + features=json.dumps(features), + created_by="account-id", + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + + +def test_apply_published_workflow_snapshot_to_draft_copies_serialized_features_without_mutating_source() -> None: + source_workflow = _create_workflow( + workflow_id="published-workflow-id", + version="2026-03-19T00:00:00", + features=LEGACY_FEATURES, + ) + + draft_workflow, is_new_draft = apply_published_workflow_snapshot_to_draft( + tenant_id="tenant-id", + app_id="app-id", + source_workflow=source_workflow, + draft_workflow=None, + account=SimpleNamespace(id="account-id"), + updated_at_factory=lambda: source_workflow.updated_at, + ) + + assert is_new_draft is True + assert source_workflow.serialized_features == json.dumps(LEGACY_FEATURES) + assert source_workflow.normalized_features_dict == NORMALIZED_FEATURES + assert draft_workflow.serialized_features == json.dumps(LEGACY_FEATURES) diff --git a/api/tests/unit_tests/services/workflow/test_workflow_service.py b/api/tests/unit_tests/services/workflow/test_workflow_service.py deleted file mode 100644 index 538c1b3595c..00000000000 --- a/api/tests/unit_tests/services/workflow/test_workflow_service.py +++ /dev/null @@ -1,415 +0,0 @@ -from contextlib import nullcontext -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest - -from dify_graph.entities.graph_config import NodeConfigDictAdapter -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.nodes.human_input.entities import FormInput, HumanInputNodeData, UserAction -from dify_graph.nodes.human_input.enums import FormInputType -from models.model import App -from models.workflow import Workflow -from services import workflow_service as workflow_service_module -from services.workflow_service import WorkflowService - - -class TestWorkflowService: - @pytest.fixture - def workflow_service(self): - mock_session_maker = MagicMock() - return WorkflowService(mock_session_maker) - - @pytest.fixture - def mock_app(self): - app = MagicMock(spec=App) - app.id = "app-id-1" - app.workflow_id = "workflow-id-1" - app.tenant_id = "tenant-id-1" - return app - - @pytest.fixture - def mock_workflows(self): - workflows = [] - for i in range(5): - workflow = MagicMock(spec=Workflow) - workflow.id = f"workflow-id-{i}" - workflow.app_id = "app-id-1" - workflow.created_at = f"2023-01-0{5 - i}" # Descending date order - workflow.created_by = "user-id-1" if i % 2 == 0 else "user-id-2" - workflow.marked_name = f"Workflow {i}" if i % 2 == 0 else "" - workflows.append(workflow) - return workflows - - @pytest.fixture - def dummy_session_cls(self): - class DummySession: - def __init__(self, *args, **kwargs): - self.commit = MagicMock() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def begin(self): - return nullcontext() - - return DummySession - - def test_get_all_published_workflow_no_workflow_id(self, workflow_service, mock_app): - mock_app.workflow_id = None - mock_session = MagicMock() - - workflows, has_more = workflow_service.get_all_published_workflow( - session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None - ) - - assert workflows == [] - assert has_more is False - mock_session.scalars.assert_not_called() - - def test_get_all_published_workflow_basic(self, workflow_service, mock_app, mock_workflows): - mock_session = MagicMock() - mock_scalar_result = MagicMock() - mock_scalar_result.all.return_value = mock_workflows[:3] - mock_session.scalars.return_value = mock_scalar_result - - workflows, has_more = workflow_service.get_all_published_workflow( - session=mock_session, app_model=mock_app, page=1, limit=3, user_id=None - ) - - assert workflows == mock_workflows[:3] - assert has_more is False - mock_session.scalars.assert_called_once() - - def test_get_all_published_workflow_pagination(self, workflow_service, mock_app, mock_workflows): - mock_session = MagicMock() - mock_scalar_result = MagicMock() - # Return 4 items when limit is 3, which should indicate has_more=True - mock_scalar_result.all.return_value = mock_workflows[:4] - mock_session.scalars.return_value = mock_scalar_result - - workflows, has_more = workflow_service.get_all_published_workflow( - session=mock_session, app_model=mock_app, page=1, limit=3, user_id=None - ) - - # Should return only the first 3 items - assert len(workflows) == 3 - assert workflows == mock_workflows[:3] - assert has_more is True - - # Test page 2 - mock_scalar_result.all.return_value = mock_workflows[3:] - mock_session.scalars.return_value = mock_scalar_result - - workflows, has_more = workflow_service.get_all_published_workflow( - session=mock_session, app_model=mock_app, page=2, limit=3, user_id=None - ) - - assert len(workflows) == 2 - assert has_more is False - - def test_get_all_published_workflow_user_filter(self, workflow_service, mock_app, mock_workflows): - mock_session = MagicMock() - mock_scalar_result = MagicMock() - # Filter workflows for user-id-1 - filtered_workflows = [w for w in mock_workflows if w.created_by == "user-id-1"] - mock_scalar_result.all.return_value = filtered_workflows - mock_session.scalars.return_value = mock_scalar_result - - workflows, has_more = workflow_service.get_all_published_workflow( - session=mock_session, app_model=mock_app, page=1, limit=10, user_id="user-id-1" - ) - - assert workflows == filtered_workflows - assert has_more is False - mock_session.scalars.assert_called_once() - - # Verify that the select contains a user filter clause - args = mock_session.scalars.call_args[0][0] - assert "created_by" in str(args) - - def test_get_all_published_workflow_named_only(self, workflow_service, mock_app, mock_workflows): - mock_session = MagicMock() - mock_scalar_result = MagicMock() - # Filter workflows that have a marked_name - named_workflows = [w for w in mock_workflows if w.marked_name] - mock_scalar_result.all.return_value = named_workflows - mock_session.scalars.return_value = mock_scalar_result - - workflows, has_more = workflow_service.get_all_published_workflow( - session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None, named_only=True - ) - - assert workflows == named_workflows - assert has_more is False - mock_session.scalars.assert_called_once() - - # Verify that the select contains a named_only filter clause - args = mock_session.scalars.call_args[0][0] - assert "marked_name !=" in str(args) - - def test_get_all_published_workflow_combined_filters(self, workflow_service, mock_app, mock_workflows): - mock_session = MagicMock() - mock_scalar_result = MagicMock() - # Combined filter: user-id-1 and has marked_name - filtered_workflows = [w for w in mock_workflows if w.created_by == "user-id-1" and w.marked_name] - mock_scalar_result.all.return_value = filtered_workflows - mock_session.scalars.return_value = mock_scalar_result - - workflows, has_more = workflow_service.get_all_published_workflow( - session=mock_session, app_model=mock_app, page=1, limit=10, user_id="user-id-1", named_only=True - ) - - assert workflows == filtered_workflows - assert has_more is False - mock_session.scalars.assert_called_once() - - # Verify that both filters are applied - args = mock_session.scalars.call_args[0][0] - assert "created_by" in str(args) - assert "marked_name !=" in str(args) - - def test_get_all_published_workflow_empty_result(self, workflow_service, mock_app): - mock_session = MagicMock() - mock_scalar_result = MagicMock() - mock_scalar_result.all.return_value = [] - mock_session.scalars.return_value = mock_scalar_result - - workflows, has_more = workflow_service.get_all_published_workflow( - session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None - ) - - assert workflows == [] - assert has_more is False - mock_session.scalars.assert_called_once() - - def test_submit_human_input_form_preview_uses_rendered_content( - self, - workflow_service: WorkflowService, - monkeypatch: pytest.MonkeyPatch, - dummy_session_cls, - ) -> None: - service = workflow_service - node_data = HumanInputNodeData( - title="Human Input", - form_content="

{{#$output.name#}}

", - inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")], - user_actions=[UserAction(id="approve", title="Approve")], - ) - node = MagicMock() - node.node_data = node_data - node.render_form_content_before_submission.return_value = "

preview

" - node.render_form_content_with_outputs.return_value = "

rendered

" - - service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] - service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign] - - workflow = MagicMock() - node_config = NodeConfigDictAdapter.validate_python( - {"id": "node-1", "data": {"type": BuiltinNodeTypes.HUMAN_INPUT}} - ) - workflow.get_node_config_by_id.return_value = node_config - workflow.get_enclosing_node_type_and_id.return_value = None - service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] - - saved_outputs: dict[str, object] = {} - - class DummySaver: - def __init__(self, *args, **kwargs): - pass - - def save(self, outputs, process_data): - saved_outputs.update(outputs) - - monkeypatch.setattr(workflow_service_module, "Session", dummy_session_cls) - monkeypatch.setattr(workflow_service_module, "DraftVariableSaver", DummySaver) - monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock())) - - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") - account = SimpleNamespace(id="account-1") - - result = service.submit_human_input_form_preview( - app_model=app_model, - account=account, - node_id="node-1", - form_inputs={"name": "Ada", "extra": "ignored"}, - inputs={"#node-0.result#": "LLM output"}, - action="approve", - ) - - service._build_human_input_variable_pool.assert_called_once_with( - app_model=app_model, - workflow=workflow, - node_config=node_config, - manual_inputs={"#node-0.result#": "LLM output"}, - user_id="account-1", - ) - - node.render_form_content_with_outputs.assert_called_once() - called_args = node.render_form_content_with_outputs.call_args.args - assert called_args[0] == "

preview

" - assert called_args[2] == node_data.outputs_field_names() - rendered_outputs = called_args[1] - assert rendered_outputs["name"] == "Ada" - assert rendered_outputs["extra"] == "ignored" - assert "extra" in saved_outputs - assert "extra" in result - assert saved_outputs["name"] == "Ada" - assert result["name"] == "Ada" - assert result["__action_id"] == "approve" - assert "__rendered_content" in result - - def test_submit_human_input_form_preview_missing_inputs_message(self, workflow_service: WorkflowService) -> None: - service = workflow_service - node_data = HumanInputNodeData( - title="Human Input", - form_content="

{{#$output.name#}}

", - inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")], - user_actions=[UserAction(id="approve", title="Approve")], - ) - node = MagicMock() - node.node_data = node_data - node._render_form_content_before_submission.return_value = "

preview

" - node._render_form_content_with_outputs.return_value = "

rendered

" - - service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] - service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign] - - workflow = MagicMock() - workflow.get_node_config_by_id.return_value = NodeConfigDictAdapter.validate_python( - {"id": "node-1", "data": {"type": BuiltinNodeTypes.HUMAN_INPUT}} - ) - service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] - - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") - account = SimpleNamespace(id="account-1") - with pytest.raises(ValueError) as exc_info: - service.submit_human_input_form_preview( - app_model=app_model, - account=account, - node_id="node-1", - form_inputs={}, - inputs={}, - action="approve", - ) - - assert "Missing required inputs" in str(exc_info.value) - - def test_run_draft_workflow_node_successful_behavior( - self, workflow_service, mock_app, monkeypatch, dummy_session_cls - ): - """Behavior: When a basic workflow node runs, it correctly sets up context, - executes the node, and saves outputs.""" - service = workflow_service - account = SimpleNamespace(id="account-1") - mock_workflow = MagicMock() - mock_workflow.id = "wf-1" - mock_workflow.tenant_id = "tenant-1" - mock_workflow.environment_variables = [] - mock_workflow.conversation_variables = [] - - # Mock node config - mock_workflow.get_node_config_by_id.return_value = NodeConfigDictAdapter.validate_python( - {"id": "node-1", "data": {"type": BuiltinNodeTypes.LLM}} - ) - mock_workflow.get_enclosing_node_type_and_id.return_value = None - - # Mock class methods - monkeypatch.setattr(workflow_service_module, "WorkflowDraftVariableService", MagicMock()) - monkeypatch.setattr(workflow_service_module, "DraftVarLoader", MagicMock()) - - # Mock workflow entry execution - mock_node_exec = MagicMock() - mock_node_exec.id = "exec-1" - mock_node_exec.process_data = {} - mock_run = MagicMock() - monkeypatch.setattr(workflow_service_module.WorkflowEntry, "single_step_run", mock_run) - - # Mock execution handling - service._handle_single_step_result = MagicMock(return_value=mock_node_exec) - - # Mock repository - mock_repo = MagicMock() - mock_repo.get_execution_by_id.return_value = mock_node_exec - mock_repo_factory = MagicMock(return_value=mock_repo) - monkeypatch.setattr( - workflow_service_module.DifyCoreRepositoryFactory, - "create_workflow_node_execution_repository", - mock_repo_factory, - ) - service._node_execution_service_repo = mock_repo - - # Set up node execution service repo mock to return our exec node - mock_node_exec.load_full_outputs.return_value = {"output_var": "result_value"} - mock_node_exec.node_id = "node-1" - mock_node_exec.node_type = "llm" - - # Mock draft variable saver - mock_saver = MagicMock() - monkeypatch.setattr(workflow_service_module, "DraftVariableSaver", MagicMock(return_value=mock_saver)) - - # Mock DB - monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock())) - - monkeypatch.setattr(workflow_service_module, "Session", dummy_session_cls) - - # Act - result = service.run_draft_workflow_node( - app_model=mock_app, - draft_workflow=mock_workflow, - node_id="node-1", - user_inputs={"input_val": "test"}, - account=account, - ) - - # Assert - assert result == mock_node_exec - service._handle_single_step_result.assert_called_once() - mock_repo.save.assert_called_once_with(mock_node_exec) - mock_saver.save.assert_called_once_with(process_data={}, outputs={"output_var": "result_value"}) - - def test_run_draft_workflow_node_failure_behavior(self, workflow_service, mock_app, monkeypatch, dummy_session_cls): - """Behavior: If retrieving the saved execution fails, an appropriate error bubble matches expectations.""" - service = workflow_service - account = SimpleNamespace(id="account-1") - mock_workflow = MagicMock() - mock_workflow.tenant_id = "tenant-1" - mock_workflow.environment_variables = [] - mock_workflow.conversation_variables = [] - mock_workflow.get_node_config_by_id.return_value = NodeConfigDictAdapter.validate_python( - {"id": "node-1", "data": {"type": BuiltinNodeTypes.LLM}} - ) - mock_workflow.get_enclosing_node_type_and_id.return_value = None - - monkeypatch.setattr(workflow_service_module, "WorkflowDraftVariableService", MagicMock()) - monkeypatch.setattr(workflow_service_module, "DraftVarLoader", MagicMock()) - monkeypatch.setattr(workflow_service_module.WorkflowEntry, "single_step_run", MagicMock()) - - mock_node_exec = MagicMock() - mock_node_exec.id = "exec-invalid" - service._handle_single_step_result = MagicMock(return_value=mock_node_exec) - - mock_repo = MagicMock() - mock_repo_factory = MagicMock(return_value=mock_repo) - monkeypatch.setattr( - workflow_service_module.DifyCoreRepositoryFactory, - "create_workflow_node_execution_repository", - mock_repo_factory, - ) - service._node_execution_service_repo = mock_repo - - # Simulate failure to retrieve the saved execution - mock_repo.get_execution_by_id.return_value = None - - monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock())) - - monkeypatch.setattr(workflow_service_module, "Session", dummy_session_cls) - - # Act & Assert - with pytest.raises(ValueError, match="WorkflowNodeExecution with id exec-invalid not found after saving"): - service.run_draft_workflow_node( - app_model=mock_app, draft_workflow=mock_workflow, node_id="node-1", user_inputs={}, account=account - ) diff --git a/api/tests/unit_tests/tasks/test_clean_dataset_task.py b/api/tests/unit_tests/tasks/test_clean_dataset_task.py index 74ba7f9c340..936a10d6c5d 100644 --- a/api/tests/unit_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/unit_tests/tasks/test_clean_dataset_task.py @@ -16,6 +16,7 @@ from unittest.mock import MagicMock, patch import pytest +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.enums import DataSourceType from tasks.clean_dataset_task import clean_dataset_task @@ -183,10 +184,10 @@ class TestErrorHandling: clean_dataset_task( dataset_id=dataset_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) # Assert @@ -228,10 +229,10 @@ class TestPipelineAndWorkflowDeletion: clean_dataset_task( dataset_id=dataset_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, pipeline_id=pipeline_id, ) @@ -264,10 +265,10 @@ class TestPipelineAndWorkflowDeletion: clean_dataset_task( dataset_id=dataset_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, pipeline_id=None, ) @@ -320,10 +321,10 @@ class TestSegmentAttachmentCleanup: clean_dataset_task( dataset_id=dataset_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) # Assert @@ -365,10 +366,10 @@ class TestSegmentAttachmentCleanup: clean_dataset_task( dataset_id=dataset_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) # Assert - storage delete was attempted @@ -407,10 +408,10 @@ class TestEdgeCases: clean_dataset_task( dataset_id=dataset_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) # Assert @@ -444,7 +445,7 @@ class TestIndexProcessorParameters: - Dataset object with correct attributes is passed """ # Arrange - indexing_technique = "high_quality" + indexing_technique = IndexTechniqueType.HIGH_QUALITY index_struct = '{"type": "paragraph"}' # Act @@ -454,7 +455,7 @@ class TestIndexProcessorParameters: indexing_technique=indexing_technique, index_struct=index_struct, collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) # Assert diff --git a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py index 8a721124d64..0b189ebae29 100644 --- a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py @@ -15,6 +15,7 @@ from unittest.mock import MagicMock, Mock, patch import pytest from core.indexing_runner import DocumentIsPausedError +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from extensions.ext_redis import redis_client @@ -58,6 +59,11 @@ def mock_redis(): # Redis is already mocked globally in conftest.py # Reset it for each test redis_client.reset_mock() + redis_client.get.reset_mock() + redis_client.setex.reset_mock() + redis_client.delete.reset_mock() + redis_client.lpush.reset_mock() + redis_client.rpop.reset_mock() redis_client.get.return_value = None redis_client.setex.return_value = True redis_client.delete.return_value = True @@ -203,7 +209,7 @@ def mock_dataset(dataset_id, tenant_id): dataset = Mock(spec=Dataset) dataset.id = dataset_id dataset.tenant_id = tenant_id - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.embedding_model_provider = "openai" dataset.embedding_model = "text-embedding-ada-002" return dataset @@ -222,7 +228,7 @@ def mock_documents(document_ids, dataset_id): doc.stopped_at = None doc.processing_started_at = None # optional attribute used in some code paths - doc.doc_form = "text_model" + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX documents.append(doc) return documents diff --git a/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py index 3668416e36b..f49f4535af7 100644 --- a/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py @@ -11,6 +11,7 @@ from unittest.mock import MagicMock, Mock, patch import pytest +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document from tasks.document_indexing_sync_task import document_indexing_sync_task @@ -62,7 +63,7 @@ def mock_document(document_id, dataset_id, notion_workspace_id, notion_page_id, document.tenant_id = str(uuid.uuid4()) document.data_source_type = "notion_import" document.indexing_status = "completed" - document.doc_form = "text_model" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX document.data_source_info_dict = { "notion_workspace_id": notion_workspace_id, "notion_page_id": notion_page_id, diff --git a/api/tests/unit_tests/tasks/test_enterprise_telemetry_task.py b/api/tests/unit_tests/tasks/test_enterprise_telemetry_task.py new file mode 100644 index 00000000000..b48c69a146b --- /dev/null +++ b/api/tests/unit_tests/tasks/test_enterprise_telemetry_task.py @@ -0,0 +1,69 @@ +"""Unit tests for enterprise telemetry Celery task.""" + +import json +from unittest.mock import MagicMock, patch + +import pytest + +from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope +from tasks.enterprise_telemetry_task import process_enterprise_telemetry + + +@pytest.fixture +def sample_envelope_json(): + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_CREATED, + tenant_id="test-tenant", + event_id="test-event-123", + payload={"app_id": "app-123"}, + ) + return envelope.model_dump_json() + + +def test_process_enterprise_telemetry_success(sample_envelope_json): + with patch("tasks.enterprise_telemetry_task.EnterpriseMetricHandler") as mock_handler_class: + mock_handler = MagicMock() + mock_handler_class.return_value = mock_handler + + process_enterprise_telemetry(sample_envelope_json) + + mock_handler.handle.assert_called_once() + call_args = mock_handler.handle.call_args[0][0] + assert isinstance(call_args, TelemetryEnvelope) + assert call_args.case == TelemetryCase.APP_CREATED + assert call_args.tenant_id == "test-tenant" + assert call_args.event_id == "test-event-123" + + +def test_process_enterprise_telemetry_invalid_json(caplog): + invalid_json = "not valid json" + + process_enterprise_telemetry(invalid_json) + + assert "Failed to process enterprise telemetry envelope" in caplog.text + + +def test_process_enterprise_telemetry_handler_exception(sample_envelope_json, caplog): + with patch("tasks.enterprise_telemetry_task.EnterpriseMetricHandler") as mock_handler_class: + mock_handler = MagicMock() + mock_handler.handle.side_effect = Exception("Handler error") + mock_handler_class.return_value = mock_handler + + process_enterprise_telemetry(sample_envelope_json) + + assert "Failed to process enterprise telemetry envelope" in caplog.text + + +def test_process_enterprise_telemetry_validation_error(caplog): + invalid_envelope = json.dumps( + { + "case": "INVALID_CASE", + "tenant_id": "test-tenant", + "event_id": "test-event", + "payload": {}, + } + ) + + process_enterprise_telemetry(invalid_envelope) + + assert "Failed to process enterprise telemetry envelope" in caplog.text diff --git a/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py b/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py index bd0182a4029..7119217e94e 100644 --- a/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py +++ b/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py @@ -5,8 +5,8 @@ from types import SimpleNamespace from typing import Any import pytest +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus -from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from tasks import human_input_timeout_tasks as task_module diff --git a/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py b/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py index a223f0119e8..8cac696d98f 100644 --- a/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py +++ b/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py @@ -11,11 +11,11 @@ # import pytest -# from dify_graph.entities.workflow_node_execution import ( +# from graphon.entities.workflow_node_execution import ( # WorkflowNodeExecution, # WorkflowNodeExecutionStatus, # ) -# from dify_graph.enums import BuiltinNodeTypes +# from graphon.enums import BuiltinNodeTypes # from libs.datetime_utils import naive_utc_now # from models import WorkflowNodeExecutionModel # from models.enums import ExecutionOffLoadType diff --git a/api/tests/unit_tests/tools/test_mcp_tool.py b/api/tests/unit_tests/tools/test_mcp_tool.py index fa9c6af2874..68359ba078d 100644 --- a/api/tests/unit_tests/tools/test_mcp_tool.py +++ b/api/tests/unit_tests/tools/test_mcp_tool.py @@ -3,6 +3,7 @@ from decimal import Decimal from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.entities.llm_entities import LLMUsage from core.mcp.types import ( AudioContent, @@ -17,7 +18,6 @@ from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage from core.tools.mcp_tool.tool import MCPTool -from dify_graph.model_runtime.entities.llm_entities import LLMUsage def _make_mcp_tool(output_schema: dict | None = None) -> MCPTool: diff --git a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py index 7ec1343f980..ffa6833524d 100644 --- a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py +++ b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py @@ -2,10 +2,7 @@ from decimal import Decimal from unittest.mock import MagicMock, patch import pytest - -from core.llm_generator.output_parser.errors import OutputParserError -from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output -from dify_graph.model_runtime.entities.llm_entities import ( +from graphon.model_runtime.entities.llm_entities import ( LLMResult, LLMResultChunk, LLMResultChunkDelta, @@ -13,13 +10,16 @@ from dify_graph.model_runtime.entities.llm_entities import ( LLMResultWithStructuredOutput, LLMUsage, ) -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, SystemPromptMessage, TextPromptMessageContent, UserPromptMessage, ) -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType + +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output def create_mock_usage(prompt_tokens: int = 10, completion_tokens: int = 5) -> LLMUsage: @@ -336,7 +336,6 @@ def test_structured_output_parser(): json_schema=case["json_schema"], stream=case["stream"], model_parameters={"temperature": 0.7, "max_tokens": 100}, - user="test_user", ) if case["expected_result_type"] == "generator": @@ -367,7 +366,7 @@ def test_structured_output_parser(): call_args = model_instance.invoke_llm.call_args assert call_args.kwargs["stream"] == case["stream"] - assert call_args.kwargs["user"] == "test_user" + assert "user" not in call_args.kwargs assert "temperature" in call_args.kwargs["model_parameters"] assert "max_tokens" in call_args.kwargs["model_parameters"] diff --git a/api/tests/workflow_test_utils.py b/api/tests/workflow_test_utils.py index 1f0bf8ef376..d33ac2c7108 100644 --- a/api/tests/workflow_test_utils.py +++ b/api/tests/workflow_test_utils.py @@ -1,8 +1,12 @@ from collections.abc import Mapping from typing import Any +from graphon.entities import GraphInitParams +from graphon.runtime import VariablePool +from graphon.variables.variables import Variable + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context -from dify_graph.entities.graph_init_params import GraphInitParams +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool def build_test_run_context( @@ -51,3 +55,16 @@ def build_test_graph_init_params( ), call_depth=call_depth, ) + + +def build_test_variable_pool( + *, + variables: list[Variable] | tuple[Variable, ...] = (), + node_id: str | None = None, + inputs: Mapping[str, Any] | None = None, +) -> VariablePool: + variable_pool = VariablePool() + add_variables_to_pool(variable_pool, variables) + if node_id is not None and inputs is not None: + add_node_inputs_to_pool(variable_pool, node_id=node_id, inputs=inputs) + return variable_pool diff --git a/api/uv.lock b/api/uv.lock index ebfc6678fe4..3e8d794866a 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -169,12 +169,6 @@ version = "1.0.0" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/a0/87/1d7019d23891897cb076b2f7e3c81ab3c2ba91de3bb067196f675d60d34c/alibabacloud-credentials-api-1.0.0.tar.gz", hash = "sha256:8c340038d904f0218d7214a8f4088c31912bfcf279af2cbc7d9be4897a97dd2f", size = 2330, upload-time = "2025-01-13T05:53:04.931Z" } -[[package]] -name = "alibabacloud-endpoint-util" -version = "0.0.4" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/92/7d/8cc92a95c920e344835b005af6ea45a0db98763ad6ad19299d26892e6c8d/alibabacloud_endpoint_util-0.0.4.tar.gz", hash = "sha256:a593eb8ddd8168d5dc2216cd33111b144f9189fcd6e9ca20e48f358a739bbf90", size = 2813, upload-time = "2025-06-12T07:20:52.572Z" } - [[package]] name = "alibabacloud-gateway-spi" version = "0.0.3" @@ -186,69 +180,17 @@ sdist = { url = "https://files.pythonhosted.org/packages/ab/98/d7111245f17935bf7 [[package]] name = "alibabacloud-gpdb20160503" -version = "3.8.3" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "alibabacloud-endpoint-util" }, - { name = "alibabacloud-openapi-util" }, - { name = "alibabacloud-openplatform20191219" }, - { name = "alibabacloud-oss-sdk" }, - { name = "alibabacloud-oss-util" }, - { name = "alibabacloud-tea-fileform" }, - { name = "alibabacloud-tea-openapi" }, - { name = "alibabacloud-tea-util" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/15/6a/cc72e744e95c8f37fa6a84e66ae0b9b57a13ee97a0ef03d94c7127c31d75/alibabacloud_gpdb20160503-3.8.3.tar.gz", hash = "sha256:4dfcc0d9cff5a921d529d76f4bf97e2ceb9dc2fa53f00ab055f08509423d8e30", size = 155092, upload-time = "2024-07-18T17:09:42.438Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ab/36/bce41704b3bf59d607590ec73a42a254c5dea27c0f707aee11d20512a200/alibabacloud_gpdb20160503-3.8.3-py3-none-any.whl", hash = "sha256:06e1c46ce5e4e9d1bcae76e76e51034196c625799d06b2efec8d46a7df323fe8", size = 156097, upload-time = "2024-07-18T17:09:40.414Z" }, -] - -[[package]] -name = "alibabacloud-openapi-util" -version = "0.2.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "alibabacloud-tea-util" }, - { name = "cryptography" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/f6/50/5f41ab550d7874c623f6e992758429802c4b52a6804db437017e5387de33/alibabacloud_openapi_util-0.2.2.tar.gz", hash = "sha256:ebbc3906f554cb4bf8f513e43e8a33e8b6a3d4a0ef13617a0e14c3dda8ef52a8", size = 7201, upload-time = "2023-10-23T07:44:18.523Z" } - -[[package]] -name = "alibabacloud-openplatform20191219" -version = "2.0.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "alibabacloud-endpoint-util" }, - { name = "alibabacloud-openapi-util" }, - { name = "alibabacloud-tea-openapi" }, - { name = "alibabacloud-tea-util" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/4f/bf/f7fa2f3657ed352870f442434cb2f27b7f70dcd52a544a1f3998eeaf6d71/alibabacloud_openplatform20191219-2.0.0.tar.gz", hash = "sha256:e67f4c337b7542538746592c6a474bd4ae3a9edccdf62e11a32ca61fad3c9020", size = 5038, upload-time = "2022-09-21T06:16:10.683Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/94/e5/18c75213551eeca9db1f6b41ddcc0bd87b5b6508c75a67f05cd8671847b4/alibabacloud_openplatform20191219-2.0.0-py3-none-any.whl", hash = "sha256:873821c45bca72a6c6ec7a906c9cb21554c122e88893bbac3986934dab30dd36", size = 5204, upload-time = "2022-09-21T06:16:07.844Z" }, -] - -[[package]] -name = "alibabacloud-oss-sdk" -version = "0.1.1" +version = "5.1.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "alibabacloud-credentials" }, - { name = "alibabacloud-oss-util" }, - { name = "alibabacloud-tea-fileform" }, - { name = "alibabacloud-tea-util" }, - { name = "alibabacloud-tea-xml" }, + { name = "alibabacloud-tea-openapi" }, + { name = "darabonba-core" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/7e/d1/f442dd026908fcf55340ca694bb1d027aa91e119e76ae2fbea62f2bde4f4/alibabacloud_oss_sdk-0.1.1.tar.gz", hash = "sha256:f51a368020d0964fcc0978f96736006f49f5ab6a4a4bf4f0b8549e2c659e7358", size = 46434, upload-time = "2025-04-22T12:40:41.717Z" } - -[[package]] -name = "alibabacloud-oss-util" -version = "0.0.6" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "alibabacloud-tea" }, +sdist = { url = "https://files.pythonhosted.org/packages/b3/36/69333c7fb7fb5267f338371b14fdd8dbdd503717c97bbc7a6419d155ab4c/alibabacloud_gpdb20160503-5.1.0.tar.gz", hash = "sha256:086ec6d5e39b64f54d0e44bb3fd4fde1a4822a53eb9f6ff7464dff7d19b07b63", size = 295641, upload-time = "2026-03-19T10:09:02.444Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/68/7f/a91a2f9ad97c92fa9a6981587ea0ff789240cea05b17b17b7c244e5bac64/alibabacloud_gpdb20160503-5.1.0-py3-none-any.whl", hash = "sha256:580e4579285a54c7f04570782e0f60423a1997568684187fe88e4110acfb640e", size = 848784, upload-time = "2026-03-19T10:09:00.72Z" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/02/7c/d7e812b9968247a302573daebcfef95d0f9a718f7b4bfcca8d3d83e266be/alibabacloud_oss_util-0.0.6.tar.gz", hash = "sha256:d3ecec36632434bd509a113e8cf327dc23e830ac8d9dd6949926f4e334c8b5d6", size = 10008, upload-time = "2021-04-28T09:25:04.056Z" } [[package]] name = "alibabacloud-tea" @@ -260,18 +202,9 @@ dependencies = [ ] sdist = { url = "https://files.pythonhosted.org/packages/9a/7d/b22cb9a0d4f396ee0f3f9d7f26b76b9ed93d4101add7867a2c87ed2534f5/alibabacloud-tea-0.4.3.tar.gz", hash = "sha256:ec8053d0aa8d43ebe1deb632d5c5404339b39ec9a18a0707d57765838418504a", size = 8785, upload-time = "2025-03-24T07:34:42.958Z" } -[[package]] -name = "alibabacloud-tea-fileform" -version = "0.0.5" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "alibabacloud-tea" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/22/8a/ef8ddf5ee0350984cad2749414b420369fe943e15e6d96b79be45367630e/alibabacloud_tea_fileform-0.0.5.tar.gz", hash = "sha256:fd00a8c9d85e785a7655059e9651f9e91784678881831f60589172387b968ee8", size = 3961, upload-time = "2021-04-28T09:22:54.56Z" } - [[package]] name = "alibabacloud-tea-openapi" -version = "0.4.3" +version = "0.4.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "alibabacloud-credentials" }, @@ -280,9 +213,9 @@ dependencies = [ { name = "cryptography" }, { name = "darabonba-core" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/91/4f/b5288eea8f4d4b032c9a8f2cd1d926d5017977d10b874956f31e5343f299/alibabacloud_tea_openapi-0.4.3.tar.gz", hash = "sha256:12aef036ed993637b6f141abbd1de9d6199d5516f4a901588bb65d6a3768d41b", size = 21864, upload-time = "2026-01-15T07:55:16.744Z" } +sdist = { url = "https://files.pythonhosted.org/packages/30/93/138bcdc8fc596add73e37cf2073798f285284d1240bda9ee02f9384fc6be/alibabacloud_tea_openapi-0.4.4.tar.gz", hash = "sha256:1b0917bc03cd49417da64945e92731716d53e2eb8707b235f54e45b7473221ce", size = 21960, upload-time = "2026-03-26T10:16:16.792Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a5/37/48ee5468ecad19c6d44cf3b9629d77078e836ee3ec760f0366247f307b7c/alibabacloud_tea_openapi-0.4.3-py3-none-any.whl", hash = "sha256:d0b3a373b760ef6278b25fc128c73284301e07888977bf97519e7636d47bdf0a", size = 26159, upload-time = "2026-01-15T07:55:15.72Z" }, + { url = "https://files.pythonhosted.org/packages/f5/5a/6bfc4506438c1809c486f66217ad11eab78157192b3d5707b4e2f4212f6c/alibabacloud_tea_openapi-0.4.4-py3-none-any.whl", hash = "sha256:cea6bc1fe35b0319a8752cb99eb0ecb0dab7ca1a71b99c12970ba0867410995f", size = 26236, upload-time = "2026-03-26T10:16:15.861Z" }, ] [[package]] @@ -297,15 +230,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/72/9e/c394b4e2104766fb28a1e44e3ed36e4c7773b4d05c868e482be99d5635c9/alibabacloud_tea_util-0.3.14-py3-none-any.whl", hash = "sha256:10d3e5c340d8f7ec69dd27345eb2fc5a1dab07875742525edf07bbe86db93bfe", size = 6697, upload-time = "2025-11-19T06:01:07.355Z" }, ] -[[package]] -name = "alibabacloud-tea-xml" -version = "0.0.3" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "alibabacloud-tea" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/32/eb/5e82e419c3061823f3feae9b5681588762929dc4da0176667297c2784c1a/alibabacloud_tea_xml-0.0.3.tar.gz", hash = "sha256:979cb51fadf43de77f41c69fc69c12529728919f849723eb0cd24eb7b048a90c", size = 3466, upload-time = "2025-07-01T08:04:55.144Z" } - [[package]] name = "aliyun-log-python-sdk" version = "0.9.37" @@ -570,28 +494,29 @@ wheels = [ [[package]] name = "basedpyright" -version = "1.38.2" +version = "1.38.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nodejs-wheel-binaries" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/e4/a3/20aa7c4e83f2f614e0036300f3c352775dede0655c66814da16c37b661a9/basedpyright-1.38.2.tar.gz", hash = "sha256:b433b2b8ba745ed7520cdc79a29a03682f3fb00346d272ece5944e9e5e5daa92", size = 25277019, upload-time = "2026-02-26T11:18:43.594Z" } +sdist = { url = "https://files.pythonhosted.org/packages/08/b4/26cb812eaf8ab56909c792c005fe1690706aef6f21d61107639e46e9c54c/basedpyright-1.38.4.tar.gz", hash = "sha256:8e7d4f37ffb6106621e06b9355025009cdf5b48f71c592432dd2dd304bf55e70", size = 25354730, upload-time = "2026-03-25T13:50:44.353Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ac/12/736cab83626fea3fe65cdafb3ef3d2ee9480c56723f2fd33921537289a5e/basedpyright-1.38.2-py3-none-any.whl", hash = "sha256:153481d37fd19f9e3adedc8629d1d071b10c5f5e49321fb026b74444b7c70e24", size = 12312475, upload-time = "2026-02-26T11:18:40.373Z" }, + { url = "https://files.pythonhosted.org/packages/62/0b/3f95fd47def42479e61077523d3752086d5c12009192a7f1c9fd5507e687/basedpyright-1.38.4-py3-none-any.whl", hash = "sha256:90aa067cf3e8a3c17ad5836a72b9e1f046bc72a4ad57d928473d9368c9cd07a2", size = 12352258, upload-time = "2026-03-25T13:50:41.059Z" }, ] [[package]] name = "bce-python-sdk" -version = "0.9.63" +version = "0.9.67" source = { registry = "https://pypi.org/simple" } dependencies = [ + { name = "crc32c" }, { name = "future" }, { name = "pycryptodome" }, { name = "six" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/8e/ab/4c2927b01a97562af6a296b722eee79658335795f341a395a12742d5e1a3/bce_python_sdk-0.9.63.tar.gz", hash = "sha256:0c80bc3ac128a0a144bae3b8dff1f397f42c30b36f7677e3a39d8df8e77b1088", size = 284419, upload-time = "2026-03-06T14:54:06.592Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b5/b9/5140cc02832fe3a7394c52949796d43f8c1f635aa016100f857f504e0348/bce_python_sdk-0.9.67.tar.gz", hash = "sha256:2c673d757c5c8952f1be6611da4ab77a63ecabaa3ff22b11531f46845ac99e58", size = 295251, upload-time = "2026-03-24T14:10:07.086Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/67/a4/501e978776c7060aa8ba77e68536597e754d938bcdbe1826618acebfbddf/bce_python_sdk-0.9.63-py3-none-any.whl", hash = "sha256:ec66eee8807c6aa4036412592da7e8c9e2cd7fdec494190986288ac2195d8276", size = 400305, upload-time = "2026-03-06T14:53:52.887Z" }, + { url = "https://files.pythonhosted.org/packages/d4/a9/a58a63e2756e5d01901595af58c673f68de7621f28d71007479e00f45a6c/bce_python_sdk-0.9.67-py3-none-any.whl", hash = "sha256:3054879d098a92ceeb4b9ac1e64d2c658120a5a10e8e630f22410564b2170bf0", size = 410854, upload-time = "2026-03-24T14:09:54.29Z" }, ] [[package]] @@ -660,14 +585,14 @@ wheels = [ [[package]] name = "bleach" -version = "6.2.0" +version = "6.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "webencodings" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/76/9a/0e33f5054c54d349ea62c277191c020c2d6ef1d65ab2cb1993f91ec846d1/bleach-6.2.0.tar.gz", hash = "sha256:123e894118b8a599fd80d3ec1a6d4cc7ce4e5882b1317a7e1ba69b56e95f991f", size = 203083, upload-time = "2024-10-29T18:30:40.477Z" } +sdist = { url = "https://files.pythonhosted.org/packages/07/18/3c8523962314be6bf4c8989c79ad9531c825210dd13a8669f6b84336e8bd/bleach-6.3.0.tar.gz", hash = "sha256:6f3b91b1c0a02bb9a78b5a454c92506aa0fdf197e1d5e114d2e00c6f64306d22", size = 203533, upload-time = "2025-10-27T17:57:39.211Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/fc/55/96142937f66150805c25c4d0f31ee4132fd33497753400734f9dfdcbdc66/bleach-6.2.0-py3-none-any.whl", hash = "sha256:117d9c6097a7c3d22fd578fcd8d35ff1e125df6736f554da4e432fdd63f31e5e", size = 163406, upload-time = "2024-10-29T18:30:38.186Z" }, + { url = "https://files.pythonhosted.org/packages/cd/3a/577b549de0cc09d95f11087ee63c739bba856cd3952697eec4c4bb91350a/bleach-6.3.0-py3-none-any.whl", hash = "sha256:fe10ec77c93ddf3d13a73b035abaac7a9f5e436513864ccdad516693213c65d6", size = 164437, upload-time = "2025-10-27T17:57:37.538Z" }, ] [[package]] @@ -706,30 +631,30 @@ wheels = [ [[package]] name = "boto3" -version = "1.42.68" +version = "1.42.78" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "botocore" }, { name = "jmespath" }, { name = "s3transfer" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/06/ae/60c642aa5413e560b671da825329f510b29a77274ed0f580bde77562294d/boto3-1.42.68.tar.gz", hash = "sha256:3f349f967ab38c23425626d130962bcb363e75f042734fe856ea8c5a00eef03c", size = 112761, upload-time = "2026-03-13T19:32:17.137Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a8/2b/ebdad075934cf6bb78bf81fe31d83339bcd804ad6c856f7341376cbc88b6/boto3-1.42.78.tar.gz", hash = "sha256:cef2ebdb9be5c0e96822f8d3941ac4b816c90a5737a7ffb901d664c808964b63", size = 112789, upload-time = "2026-03-27T19:28:07.58Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/fb/f6/dc6e993479dbb597d68223fbf61cb026511737696b15bd7d2a33e9b2c24f/boto3-1.42.68-py3-none-any.whl", hash = "sha256:dbff353eb7dc93cbddd7926ed24793e0174c04adbe88860dfa639568442e4962", size = 140556, upload-time = "2026-03-13T19:32:14.951Z" }, + { url = "https://files.pythonhosted.org/packages/57/bb/1f6dade1f1e86858bef7bd332bc8106c445f2dbabec7b32ab5d7d118c9b6/boto3-1.42.78-py3-none-any.whl", hash = "sha256:480a34a077484a5ca60124dfd150ba3ea6517fc89963a679e45b30c6db614d26", size = 140556, upload-time = "2026-03-27T19:28:06.125Z" }, ] [[package]] name = "boto3-stubs" -version = "1.42.68" +version = "1.42.78" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "botocore-stubs" }, { name = "types-s3transfer" }, { name = "typing-extensions", marker = "python_full_version < '3.12'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/4c/8c/dd4b0c95ff008bed5a35ab411452ece121b355539d2a0b6dcd62a0c47be5/boto3_stubs-1.42.68.tar.gz", hash = "sha256:96ad1020735619483fb9b4da7a5e694b460bf2e18f84a34d5d175d0ffe8c4653", size = 101372, upload-time = "2026-03-13T19:49:54.867Z" } +sdist = { url = "https://files.pythonhosted.org/packages/03/16/4bdb3c1f69bf7b97dd8b22fe5b007e9da67ba3f00ed10e47146f5fd9d0ff/boto3_stubs-1.42.78.tar.gz", hash = "sha256:423335b8ce9a935e404054978589cdb98d9fa1d4bd46073d6821bf1c3fad8ca7", size = 101602, upload-time = "2026-03-27T19:35:51.149Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/68/15/3ca5848917214a168134512a5b45f856a56e913659888947a052e02031b5/boto3_stubs-1.42.68-py3-none-any.whl", hash = "sha256:ed7f98334ef7b2377fa8532190e63dc2c6d1dc895e3d7cb3d6d1c83771b81bf6", size = 70011, upload-time = "2026-03-13T19:49:42.801Z" }, + { url = "https://files.pythonhosted.org/packages/22/d5/bdedd4951c795899ac5a1f0b88d81b9e2c6333cb87457f2edd11ef3b7b7b/boto3_stubs-1.42.78-py3-none-any.whl", hash = "sha256:6ed07e734174751da8d01031d9ede8d81a88e4338d9e6b00ce7a6bc870075372", size = 70161, upload-time = "2026-03-27T19:35:46.336Z" }, ] [package.optional-dependencies] @@ -739,16 +664,16 @@ bedrock-runtime = [ [[package]] name = "botocore" -version = "1.42.68" +version = "1.42.78" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jmespath" }, { name = "python-dateutil" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/3f/22/87502d5fbbfa8189406a617b30b1e2a3dc0ab2669f7268e91b385c1c1c7a/botocore-1.42.68.tar.gz", hash = "sha256:3951c69e12ac871dda245f48dac5c7dd88ea1bfdd74a8879ec356cf2874b806a", size = 14994514, upload-time = "2026-03-13T19:32:03.577Z" } +sdist = { url = "https://files.pythonhosted.org/packages/67/8e/cdb34c8ca71216d214e049ada2148ee08bcda12b1ac72af3a720dea300ff/botocore-1.42.78.tar.gz", hash = "sha256:61cbd49728e23f68cfd945406ab40044d49abed143362f7ffa4a4f4bd4311791", size = 15023592, upload-time = "2026-03-27T19:27:57.122Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/3c/2a/1428f6594799780fe6ee845d8e6aeffafe026cd16a70c878684e2dcbbfc8/botocore-1.42.68-py3-none-any.whl", hash = "sha256:9df7da26374601f890e2f115bfa573d65bf15b25fe136bb3aac809f6145f52ab", size = 14668816, upload-time = "2026-03-13T19:31:58.572Z" }, + { url = "https://files.pythonhosted.org/packages/54/72/94bba1a375d45c685b00e051b56142359547837086a83861d76f6aec26f4/botocore-1.42.78-py3-none-any.whl", hash = "sha256:038ab63c7f898e8b5db58cb6a45e4da56c31dd984e7e995839a3540c735564ea", size = 14701729, upload-time = "2026-03-27T19:27:54.05Z" }, ] [[package]] @@ -1138,7 +1063,7 @@ wheels = [ [[package]] name = "clickhouse-connect" -version = "0.14.1" +version = "0.15.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "certifi" }, @@ -1147,24 +1072,24 @@ dependencies = [ { name = "urllib3" }, { name = "zstandard" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f5/0e/96958db88b6ce6e9d96dc7a836f12c7644934b3a436b04843f19eb8da2db/clickhouse_connect-0.14.1.tar.gz", hash = "sha256:dc107ae9ab7b86409049ae8abe21817543284b438291796d3dd639ad5496a1ab", size = 120093, upload-time = "2026-03-12T15:51:03.606Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ec/59/c0b0a2c2e4c204e5baeca4917a95cc95add651da3cec86ec464a8e54cfa0/clickhouse_connect-0.15.0.tar.gz", hash = "sha256:529fcf072df335d18ae16339d99389190f4bd543067dcdc174541c7a9c622ef5", size = 126344, upload-time = "2026-03-26T18:34:52.316Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/66/b0/04bc82ca70d4dcc35987c83e4ef04f6dec3c29d3cce4cda3523ebf4498dc/clickhouse_connect-0.14.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f2b1d1acb8f64c3cd9d922d9e8c0b6328238c4a38e084598c86cc95a0edbd8bd", size = 278797, upload-time = "2026-03-12T15:49:34.728Z" }, - { url = "https://files.pythonhosted.org/packages/97/03/f8434ed43946dcab2d8b4ccf8e90b1c6d69abea0fa8b8aaddb1dc9931657/clickhouse_connect-0.14.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:573f3e5a6b49135b711c086050f46510d4738cc09e5a354cc18ef26f8de5cd98", size = 271849, upload-time = "2026-03-12T15:49:35.881Z" }, - { url = "https://files.pythonhosted.org/packages/a0/db/b3665f4d855c780be8d00638d874fc0d62613d1f1c06ffcad7c11a333f06/clickhouse_connect-0.14.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:86b28932faab182a312779e5c3cf341abe19d31028a399bda9d8b06b3b9adab4", size = 1090975, upload-time = "2026-03-12T15:49:37.064Z" }, - { url = "https://files.pythonhosted.org/packages/ea/a2/7ba2d9669c5771734573397b034169653cdf3348dc4cc66bd66d8ab18910/clickhouse_connect-0.14.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bfc9650906ff96452c2b5676a7e68e8a77a5642504596f8482e0f3c0ccdffbf1", size = 1095899, upload-time = "2026-03-12T15:49:38.36Z" }, - { url = "https://files.pythonhosted.org/packages/e2/f4/0394af37b491ca832610f2ca7a129e85d8d857d40c94a42f2c2e6d3d9481/clickhouse_connect-0.14.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:b379749a962599f9d6ec81e773a3b907ac58b001f4a977e4ac397f6a76fedff2", size = 1077567, upload-time = "2026-03-12T15:49:40.027Z" }, - { url = "https://files.pythonhosted.org/packages/9a/b8/9279a88afac94c262b55cc75aadc6a3e83f7fa1641e618f9060d9d38415f/clickhouse_connect-0.14.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:43ccb5debd13d41b97af81940c0cac01e92d39f17131d984591bedee13439a5d", size = 1100264, upload-time = "2026-03-12T15:49:41.414Z" }, - { url = "https://files.pythonhosted.org/packages/19/36/20e19ab392c211b83c967e275eb46f663853e0b8ce4da89056fda8a35fc6/clickhouse_connect-0.14.1-cp311-cp311-win32.whl", hash = "sha256:13cbe46c04be8e49da4f6aed698f2570a5295d15f498dd5511b4f761d1ef0edc", size = 250488, upload-time = "2026-03-12T15:49:42.649Z" }, - { url = "https://files.pythonhosted.org/packages/9d/3b/74a07e692a21cad4692e72595cdefbd709bd74a9f778c7334d57a98ee548/clickhouse_connect-0.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:7038cf547c542a17a465e062cd837659f46f99c991efcb010a9ea08ce70960ab", size = 268730, upload-time = "2026-03-12T15:49:44.225Z" }, - { url = "https://files.pythonhosted.org/packages/58/9e/d84a14241967b3aa1e657bbbee83e2eee02d3d6df1ebe8edd4ed72cd8643/clickhouse_connect-0.14.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:97665169090889a8bc4dbae4a5fc758b91a23e49a8f8ddc1ae993f18f6d71e02", size = 280679, upload-time = "2026-03-12T15:49:45.497Z" }, - { url = "https://files.pythonhosted.org/packages/d8/29/80835a980be6298a7a2ae42d5a14aab0c9c066ecafe1763bc1958a6f6f0f/clickhouse_connect-0.14.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3ee6b513ca7d83e0f7b46d87bc2e48260316431cb466680e3540400379bcd1db", size = 271570, upload-time = "2026-03-12T15:49:46.721Z" }, - { url = "https://files.pythonhosted.org/packages/8b/bf/25c17cb91d72143742d2b060c6954e8000a7753c1fd21f7bf8b49ef2bd89/clickhouse_connect-0.14.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2a0e8a3f46aba99f1c574927d196e12f1ee689e31c41bf0caec86ad3e181abf3", size = 1115637, upload-time = "2026-03-12T15:49:47.921Z" }, - { url = "https://files.pythonhosted.org/packages/2d/5f/5d5df3585d98889aedc55c9eeb2ea90dba27ec4329eee392101619daf0c0/clickhouse_connect-0.14.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:25698cddcdd6c2e4ea12dc5c56d6035d77fc99c5d75e96a54123826c36fdd8ae", size = 1131995, upload-time = "2026-03-12T15:49:49.791Z" }, - { url = "https://files.pythonhosted.org/packages/ad/50/acc9f4c6a1d712f2ed11626f8451eff222e841cf0809655362f0e90454b6/clickhouse_connect-0.14.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:29ab49e5cac44b830b58de73d17a7d895f6c362bf67a50134ff405b428774f44", size = 1095380, upload-time = "2026-03-12T15:49:51.388Z" }, - { url = "https://files.pythonhosted.org/packages/08/18/1ef01beee93d243ec9d9c37f0ce62b3083478a5dd7f59cc13279600cd3a5/clickhouse_connect-0.14.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3cbf7d7a134692bacd68dd5f8661e87f5db94af60db9f3a74bd732596794910a", size = 1127217, upload-time = "2026-03-12T15:49:53.016Z" }, - { url = "https://files.pythonhosted.org/packages/18/e2/b4daee8287dc49eb9918c77b1e57f5644e47008f719b77281bf5fca63f6e/clickhouse_connect-0.14.1-cp312-cp312-win32.whl", hash = "sha256:6f295b66f3e2ed931dd0d3bb80e00ee94c6f4a584b2dc6d998872b2e0ceaa706", size = 250775, upload-time = "2026-03-12T15:49:54.639Z" }, - { url = "https://files.pythonhosted.org/packages/01/c7/7b55d346952fcd8f0f491faca4449f607a04764fd23cada846dc93facb9e/clickhouse_connect-0.14.1-cp312-cp312-win_amd64.whl", hash = "sha256:c6bb2cce37041c90f8a3b1b380665acbaf252f125e401c13ce8f8df105378f69", size = 269353, upload-time = "2026-03-12T15:49:55.854Z" }, + { url = "https://files.pythonhosted.org/packages/83/b0/bf4a169a1b4e5e19f5e884596937ce13855146a3f4b3225228a87701fd18/clickhouse_connect-0.15.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4f0928fdfb408d314c0e5151caf30b1c3bd56c2812ffdbc8d262fb60c0e7ab28", size = 284805, upload-time = "2026-03-26T18:33:18.659Z" }, + { url = "https://files.pythonhosted.org/packages/ec/d5/63dd572db91bd5e1231d7b7dc63591c52ffbbf653a57f9b8449681815976/clickhouse_connect-0.15.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6486b02825ac87f57811710e5a9a2da8531bb3c88bcb154fd5c7378742a33d66", size = 277846, upload-time = "2026-03-26T18:33:20.171Z" }, + { url = "https://files.pythonhosted.org/packages/e4/d6/192130a807de130945cc451e17c89ac6183625b8028026e5a4a7fc46fa59/clickhouse_connect-0.15.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9f2df9c2fd97b40c6493232e0cbf516d8ba268165c6161851ef15f4f1fd0456e", size = 1096969, upload-time = "2026-03-26T18:33:21.728Z" }, + { url = "https://files.pythonhosted.org/packages/32/46/f2895cc4240ef45a2a274d4323f6858c0860034efe6c9a1c7168f1d8cecd/clickhouse_connect-0.15.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a5a349d19c63abb49c884afe0a0387823045831f005451e85c09c032f953f1c1", size = 1101890, upload-time = "2026-03-26T18:33:23.038Z" }, + { url = "https://files.pythonhosted.org/packages/e8/69/dcecbca254b45525ad3fd8294441ac9cf8a8a8bd1fa8fd6b93e241b377a3/clickhouse_connect-0.15.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:4d80205cbdbface6d2f35fbd65a6f85caf2b59ec65f2e9dd190f11e335fe7316", size = 1083561, upload-time = "2026-03-26T18:33:24.64Z" }, + { url = "https://files.pythonhosted.org/packages/69/10/21f0cb98453d9710aaeb92f9a9e156e909c1ac72e57210a48b0f615916a7/clickhouse_connect-0.15.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c3c84dfebf49ec7a2cd9ac31c46986f7a81b43ea781d23ef7d607907fcc6de5d", size = 1106257, upload-time = "2026-03-26T18:33:26.257Z" }, + { url = "https://files.pythonhosted.org/packages/70/91/ae0f5c8df5dc650f1ab327d4b40cde7e18bf9e8b3507764dce320c328092/clickhouse_connect-0.15.0-cp311-cp311-win32.whl", hash = "sha256:d2bbdccf9cd838b990576d3f7d1e6a0ab5c3a5c8eb830394258b7b225531fe74", size = 256591, upload-time = "2026-03-26T18:33:27.869Z" }, + { url = "https://files.pythonhosted.org/packages/e6/7f/85673ff522554ef76e17b5d267816c199a731fde836ef957b0960655f251/clickhouse_connect-0.15.0-cp311-cp311-win_amd64.whl", hash = "sha256:1c4223d557bc0a3919cb7ce0d749d9091123b6e61341e028ffc09b7f9c847ac2", size = 274778, upload-time = "2026-03-26T18:33:29.02Z" }, + { url = "https://files.pythonhosted.org/packages/f5/be/86e149c60822caed29e4435acac4fc73e20fddfb0b56ea6452bc7a08ab10/clickhouse_connect-0.15.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d51f49694e9007564bfd8dac51a1f9e60b94d6c93a07eb4027113a2e62bbb384", size = 286680, upload-time = "2026-03-26T18:33:30.219Z" }, + { url = "https://files.pythonhosted.org/packages/aa/65/c38cc5028afa2ccd9e8ff65611434063c0c5c1b6edadc507dbbc80a09bfd/clickhouse_connect-0.15.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6a48fbad9ebc2b6d1cd01d1f9b5d6740081f1c84f1aacc9f91651be949f6b6ed", size = 277579, upload-time = "2026-03-26T18:33:31.474Z" }, + { url = "https://files.pythonhosted.org/packages/0a/ef/c8b2ef597fefd04e8b7c017c991552162cb89b7cb73bfdd6225b1c79e2fe/clickhouse_connect-0.15.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:36e1ae470b94cc56d270461c8626c8fd4dac16e6c1ffa8477f21c012462e22cf", size = 1121630, upload-time = "2026-03-26T18:33:32.983Z" }, + { url = "https://files.pythonhosted.org/packages/de/f7/1b71819e825d44582c014a489618170b03ccdac3c9b710dfd56445f1c017/clickhouse_connect-0.15.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fa97f0ae8eb069a451d8577342dffeef5dc308a0eac7dba1809008c761e720c7", size = 1137988, upload-time = "2026-03-26T18:33:34.585Z" }, + { url = "https://files.pythonhosted.org/packages/7f/1f/41002b8d5ff146dc2835dc6b6f690bc361bd9a94b6195872abcb922f3788/clickhouse_connect-0.15.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:b5b3baf70009174a4df9c8356c96d03e1c2dbf0d8b29f1b3270a641a59399b61", size = 1101376, upload-time = "2026-03-26T18:33:36.258Z" }, + { url = "https://files.pythonhosted.org/packages/2c/8a/bd090dab73fc9c47efcaaeb152a77610b9d233cd88ea73cf4535f9bac2a6/clickhouse_connect-0.15.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:af3fba93fd2efa8f856f3a88a6a710e06005fa48b6b6b0f116d462a4021957e2", size = 1133211, upload-time = "2026-03-26T18:33:38.003Z" }, + { url = "https://files.pythonhosted.org/packages/f1/8d/cf4eee7225bdee85a9b8a88c5bfff42ce48f37ee9277930ac8bc76f47126/clickhouse_connect-0.15.0-cp312-cp312-win32.whl", hash = "sha256:86ca76f8acaf7f3f6530e3e4139e174d54c4674910c69f4277d1b9cdf7c1cc98", size = 256767, upload-time = "2026-03-26T18:33:39.55Z" }, + { url = "https://files.pythonhosted.org/packages/26/6e/f5a2cb1e4624dfd77c1e226239360a9e3690db8056a0027bda2ab87d0085/clickhouse_connect-0.15.0-cp312-cp312-win_amd64.whl", hash = "sha256:5a471d9a9cf06f0a4e90784547b6a2acb066b0d8642dfea9866960c4bdde6959", size = 275404, upload-time = "2026-03-26T18:33:40.885Z" }, ] [[package]] @@ -1290,41 +1215,41 @@ wheels = [ [[package]] name = "coverage" -version = "7.13.4" +version = "7.13.5" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/24/56/95b7e30fa389756cb56630faa728da46a27b8c6eb46f9d557c68fff12b65/coverage-7.13.4.tar.gz", hash = "sha256:e5c8f6ed1e61a8b2dcdf31eb0b9bbf0130750ca79c1c49eb898e2ad86f5ccc91", size = 827239, upload-time = "2026-02-09T12:59:03.86Z" } +sdist = { url = "https://files.pythonhosted.org/packages/9d/e0/70553e3000e345daff267cec284ce4cbf3fc141b6da229ac52775b5428f1/coverage-7.13.5.tar.gz", hash = "sha256:c81f6515c4c40141f83f502b07bbfa5c240ba25bbe73da7b33f1e5b6120ff179", size = 915967, upload-time = "2026-03-17T10:33:18.341Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b4/ad/b59e5b451cf7172b8d1043dc0fa718f23aab379bc1521ee13d4bd9bfa960/coverage-7.13.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d490ba50c3f35dd7c17953c68f3270e7ccd1c6642e2d2afe2d8e720b98f5a053", size = 219278, upload-time = "2026-02-09T12:56:31.673Z" }, - { url = "https://files.pythonhosted.org/packages/f1/17/0cb7ca3de72e5f4ef2ec2fa0089beafbcaaaead1844e8b8a63d35173d77d/coverage-7.13.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:19bc3c88078789f8ef36acb014d7241961dbf883fd2533d18cb1e7a5b4e28b11", size = 219783, upload-time = "2026-02-09T12:56:33.104Z" }, - { url = "https://files.pythonhosted.org/packages/ab/63/325d8e5b11e0eaf6d0f6a44fad444ae58820929a9b0de943fa377fe73e85/coverage-7.13.4-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:3998e5a32e62fdf410c0dbd3115df86297995d6e3429af80b8798aad894ca7aa", size = 250200, upload-time = "2026-02-09T12:56:34.474Z" }, - { url = "https://files.pythonhosted.org/packages/76/53/c16972708cbb79f2942922571a687c52bd109a7bd51175aeb7558dff2236/coverage-7.13.4-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:8e264226ec98e01a8e1054314af91ee6cde0eacac4f465cc93b03dbe0bce2fd7", size = 252114, upload-time = "2026-02-09T12:56:35.749Z" }, - { url = "https://files.pythonhosted.org/packages/eb/c2/7ab36d8b8cc412bec9ea2d07c83c48930eb4ba649634ba00cb7e4e0f9017/coverage-7.13.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a3aa4e7b9e416774b21797365b358a6e827ffadaaca81b69ee02946852449f00", size = 254220, upload-time = "2026-02-09T12:56:37.796Z" }, - { url = "https://files.pythonhosted.org/packages/d6/4d/cf52c9a3322c89a0e6febdfbc83bb45c0ed3c64ad14081b9503adee702e7/coverage-7.13.4-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:71ca20079dd8f27fcf808817e281e90220475cd75115162218d0e27549f95fef", size = 256164, upload-time = "2026-02-09T12:56:39.016Z" }, - { url = "https://files.pythonhosted.org/packages/78/e9/eb1dd17bd6de8289df3580e967e78294f352a5df8a57ff4671ee5fc3dcd0/coverage-7.13.4-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:e2f25215f1a359ab17320b47bcdaca3e6e6356652e8256f2441e4ef972052903", size = 250325, upload-time = "2026-02-09T12:56:40.668Z" }, - { url = "https://files.pythonhosted.org/packages/71/07/8c1542aa873728f72267c07278c5cc0ec91356daf974df21335ccdb46368/coverage-7.13.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d65b2d373032411e86960604dc4edac91fdfb5dca539461cf2cbe78327d1e64f", size = 251913, upload-time = "2026-02-09T12:56:41.97Z" }, - { url = "https://files.pythonhosted.org/packages/74/d7/c62e2c5e4483a748e27868e4c32ad3daa9bdddbba58e1bc7a15e252baa74/coverage-7.13.4-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:94eb63f9b363180aff17de3e7c8760c3ba94664ea2695c52f10111244d16a299", size = 249974, upload-time = "2026-02-09T12:56:43.323Z" }, - { url = "https://files.pythonhosted.org/packages/98/9f/4c5c015a6e98ced54efd0f5cf8d31b88e5504ecb6857585fc0161bb1e600/coverage-7.13.4-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:e856bf6616714c3a9fbc270ab54103f4e685ba236fa98c054e8f87f266c93505", size = 253741, upload-time = "2026-02-09T12:56:45.155Z" }, - { url = "https://files.pythonhosted.org/packages/bd/59/0f4eef89b9f0fcd9633b5d350016f54126ab49426a70ff4c4e87446cabdc/coverage-7.13.4-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:65dfcbe305c3dfe658492df2d85259e0d79ead4177f9ae724b6fb245198f55d6", size = 249695, upload-time = "2026-02-09T12:56:46.636Z" }, - { url = "https://files.pythonhosted.org/packages/b5/2c/b7476f938deb07166f3eb281a385c262675d688ff4659ad56c6c6b8e2e70/coverage-7.13.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b507778ae8a4c915436ed5c2e05b4a6cecfa70f734e19c22a005152a11c7b6a9", size = 250599, upload-time = "2026-02-09T12:56:48.13Z" }, - { url = "https://files.pythonhosted.org/packages/b8/34/c3420709d9846ee3785b9f2831b4d94f276f38884032dca1457fa83f7476/coverage-7.13.4-cp311-cp311-win32.whl", hash = "sha256:784fc3cf8be001197b652d51d3fd259b1e2262888693a4636e18879f613a62a9", size = 221780, upload-time = "2026-02-09T12:56:50.479Z" }, - { url = "https://files.pythonhosted.org/packages/61/08/3d9c8613079d2b11c185b865de9a4c1a68850cfda2b357fae365cf609f29/coverage-7.13.4-cp311-cp311-win_amd64.whl", hash = "sha256:2421d591f8ca05b308cf0092807308b2facbefe54af7c02ac22548b88b95c98f", size = 222715, upload-time = "2026-02-09T12:56:51.815Z" }, - { url = "https://files.pythonhosted.org/packages/18/1a/54c3c80b2f056164cc0a6cdcb040733760c7c4be9d780fe655f356f433e4/coverage-7.13.4-cp311-cp311-win_arm64.whl", hash = "sha256:79e73a76b854d9c6088fe5d8b2ebe745f8681c55f7397c3c0a016192d681045f", size = 221385, upload-time = "2026-02-09T12:56:53.194Z" }, - { url = "https://files.pythonhosted.org/packages/d1/81/4ce2fdd909c5a0ed1f6dedb88aa57ab79b6d1fbd9b588c1ac7ef45659566/coverage-7.13.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:02231499b08dabbe2b96612993e5fc34217cdae907a51b906ac7fca8027a4459", size = 219449, upload-time = "2026-02-09T12:56:54.889Z" }, - { url = "https://files.pythonhosted.org/packages/5d/96/5238b1efc5922ddbdc9b0db9243152c09777804fb7c02ad1741eb18a11c0/coverage-7.13.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40aa8808140e55dc022b15d8aa7f651b6b3d68b365ea0398f1441e0b04d859c3", size = 219810, upload-time = "2026-02-09T12:56:56.33Z" }, - { url = "https://files.pythonhosted.org/packages/78/72/2f372b726d433c9c35e56377cf1d513b4c16fe51841060d826b95caacec1/coverage-7.13.4-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:5b856a8ccf749480024ff3bd7310adaef57bf31fd17e1bfc404b7940b6986634", size = 251308, upload-time = "2026-02-09T12:56:57.858Z" }, - { url = "https://files.pythonhosted.org/packages/5d/a0/2ea570925524ef4e00bb6c82649f5682a77fac5ab910a65c9284de422600/coverage-7.13.4-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:2c048ea43875fbf8b45d476ad79f179809c590ec7b79e2035c662e7afa3192e3", size = 254052, upload-time = "2026-02-09T12:56:59.754Z" }, - { url = "https://files.pythonhosted.org/packages/e8/ac/45dc2e19a1939098d783c846e130b8f862fbb50d09e0af663988f2f21973/coverage-7.13.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b7b38448866e83176e28086674fe7368ab8590e4610fb662b44e345b86d63ffa", size = 255165, upload-time = "2026-02-09T12:57:01.287Z" }, - { url = "https://files.pythonhosted.org/packages/2d/4d/26d236ff35abc3b5e63540d3386e4c3b192168c1d96da5cb2f43c640970f/coverage-7.13.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:de6defc1c9badbf8b9e67ae90fd00519186d6ab64e5cc5f3d21359c2a9b2c1d3", size = 257432, upload-time = "2026-02-09T12:57:02.637Z" }, - { url = "https://files.pythonhosted.org/packages/ec/55/14a966c757d1348b2e19caf699415a2a4c4f7feaa4bbc6326a51f5c7dd1b/coverage-7.13.4-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:7eda778067ad7ffccd23ecffce537dface96212576a07924cbf0d8799d2ded5a", size = 251716, upload-time = "2026-02-09T12:57:04.056Z" }, - { url = "https://files.pythonhosted.org/packages/77/33/50116647905837c66d28b2af1321b845d5f5d19be9655cb84d4a0ea806b4/coverage-7.13.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e87f6c587c3f34356c3759f0420693e35e7eb0e2e41e4c011cb6ec6ecbbf1db7", size = 253089, upload-time = "2026-02-09T12:57:05.503Z" }, - { url = "https://files.pythonhosted.org/packages/c2/b4/8efb11a46e3665d92635a56e4f2d4529de6d33f2cb38afd47d779d15fc99/coverage-7.13.4-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:8248977c2e33aecb2ced42fef99f2d319e9904a36e55a8a68b69207fb7e43edc", size = 251232, upload-time = "2026-02-09T12:57:06.879Z" }, - { url = "https://files.pythonhosted.org/packages/51/24/8cd73dd399b812cc76bb0ac260e671c4163093441847ffe058ac9fda1e32/coverage-7.13.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:25381386e80ae727608e662474db537d4df1ecd42379b5ba33c84633a2b36d47", size = 255299, upload-time = "2026-02-09T12:57:08.245Z" }, - { url = "https://files.pythonhosted.org/packages/03/94/0a4b12f1d0e029ce1ccc1c800944a9984cbe7d678e470bb6d3c6bc38a0da/coverage-7.13.4-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:ee756f00726693e5ba94d6df2bdfd64d4852d23b09bb0bc700e3b30e6f333985", size = 250796, upload-time = "2026-02-09T12:57:10.142Z" }, - { url = "https://files.pythonhosted.org/packages/73/44/6002fbf88f6698ca034360ce474c406be6d5a985b3fdb3401128031eef6b/coverage-7.13.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fdfc1e28e7c7cdce44985b3043bc13bbd9c747520f94a4d7164af8260b3d91f0", size = 252673, upload-time = "2026-02-09T12:57:12.197Z" }, - { url = "https://files.pythonhosted.org/packages/de/c6/a0279f7c00e786be75a749a5674e6fa267bcbd8209cd10c9a450c655dfa7/coverage-7.13.4-cp312-cp312-win32.whl", hash = "sha256:01d4cbc3c283a17fc1e42d614a119f7f438eabb593391283adca8dc86eff1246", size = 221990, upload-time = "2026-02-09T12:57:14.085Z" }, - { url = "https://files.pythonhosted.org/packages/77/4e/c0a25a425fcf5557d9abd18419c95b63922e897bc86c1f327f155ef234a9/coverage-7.13.4-cp312-cp312-win_amd64.whl", hash = "sha256:9401ebc7ef522f01d01d45532c68c5ac40fb27113019b6b7d8b208f6e9baa126", size = 222800, upload-time = "2026-02-09T12:57:15.944Z" }, - { url = "https://files.pythonhosted.org/packages/47/ac/92da44ad9a6f4e3a7debd178949d6f3769bedca33830ce9b1dcdab589a37/coverage-7.13.4-cp312-cp312-win_arm64.whl", hash = "sha256:b1ec7b6b6e93255f952e27ab58fbc68dcc468844b16ecbee881aeb29b6ab4d8d", size = 221415, upload-time = "2026-02-09T12:57:17.497Z" }, - { url = "https://files.pythonhosted.org/packages/0d/4a/331fe2caf6799d591109bb9c08083080f6de90a823695d412a935622abb2/coverage-7.13.4-py3-none-any.whl", hash = "sha256:1af1641e57cf7ba1bd67d677c9abdbcd6cc2ab7da3bca7fa1e2b7e50e65f2ad0", size = 211242, upload-time = "2026-02-09T12:59:02.032Z" }, + { url = "https://files.pythonhosted.org/packages/4b/37/d24c8f8220ff07b839b2c043ea4903a33b0f455abe673ae3c03bbdb7f212/coverage-7.13.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:66a80c616f80181f4d643b0f9e709d97bcea413ecd9631e1dedc7401c8e6695d", size = 219381, upload-time = "2026-03-17T10:30:14.68Z" }, + { url = "https://files.pythonhosted.org/packages/35/8b/cd129b0ca4afe886a6ce9d183c44d8301acbd4ef248622e7c49a23145605/coverage-7.13.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:145ede53ccbafb297c1c9287f788d1bc3efd6c900da23bf6931b09eafc931587", size = 219880, upload-time = "2026-03-17T10:30:16.231Z" }, + { url = "https://files.pythonhosted.org/packages/55/2f/e0e5b237bffdb5d6c530ce87cc1d413a5b7d7dfd60fb067ad6d254c35c76/coverage-7.13.5-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:0672854dc733c342fa3e957e0605256d2bf5934feeac328da9e0b5449634a642", size = 250303, upload-time = "2026-03-17T10:30:17.748Z" }, + { url = "https://files.pythonhosted.org/packages/92/be/b1afb692be85b947f3401375851484496134c5554e67e822c35f28bf2fbc/coverage-7.13.5-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:ec10e2a42b41c923c2209b846126c6582db5e43a33157e9870ba9fb70dc7854b", size = 252218, upload-time = "2026-03-17T10:30:19.804Z" }, + { url = "https://files.pythonhosted.org/packages/da/69/2f47bb6fa1b8d1e3e5d0c4be8ccb4313c63d742476a619418f85740d597b/coverage-7.13.5-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:be3d4bbad9d4b037791794ddeedd7d64a56f5933a2c1373e18e9e568b9141686", size = 254326, upload-time = "2026-03-17T10:30:21.321Z" }, + { url = "https://files.pythonhosted.org/packages/d5/d0/79db81da58965bd29dabc8f4ad2a2af70611a57cba9d1ec006f072f30a54/coverage-7.13.5-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:4d2afbc5cc54d286bfb54541aa50b64cdb07a718227168c87b9e2fb8f25e1743", size = 256267, upload-time = "2026-03-17T10:30:23.094Z" }, + { url = "https://files.pythonhosted.org/packages/e5/32/d0d7cc8168f91ddab44c0ce4806b969df5f5fdfdbb568eaca2dbc2a04936/coverage-7.13.5-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:3ad050321264c49c2fa67bb599100456fc51d004b82534f379d16445da40fb75", size = 250430, upload-time = "2026-03-17T10:30:25.311Z" }, + { url = "https://files.pythonhosted.org/packages/4d/06/a055311d891ddbe231cd69fdd20ea4be6e3603ffebddf8704b8ca8e10a3c/coverage-7.13.5-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:7300c8a6d13335b29bb76d7651c66af6bd8658517c43499f110ddc6717bfc209", size = 252017, upload-time = "2026-03-17T10:30:27.284Z" }, + { url = "https://files.pythonhosted.org/packages/d6/f6/d0fd2d21e29a657b5f77a2fe7082e1568158340dceb941954f776dce1b7b/coverage-7.13.5-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:eb07647a5738b89baab047f14edd18ded523de60f3b30e75c2acc826f79c839a", size = 250080, upload-time = "2026-03-17T10:30:29.481Z" }, + { url = "https://files.pythonhosted.org/packages/4e/ab/0d7fb2efc2e9a5eb7ddcc6e722f834a69b454b7e6e5888c3a8567ecffb31/coverage-7.13.5-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:9adb6688e3b53adffefd4a52d72cbd8b02602bfb8f74dcd862337182fd4d1a4e", size = 253843, upload-time = "2026-03-17T10:30:31.301Z" }, + { url = "https://files.pythonhosted.org/packages/ba/6f/7467b917bbf5408610178f62a49c0ed4377bb16c1657f689cc61470da8ce/coverage-7.13.5-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:7c8d4bc913dd70b93488d6c496c77f3aff5ea99a07e36a18f865bca55adef8bd", size = 249802, upload-time = "2026-03-17T10:30:33.358Z" }, + { url = "https://files.pythonhosted.org/packages/75/2c/1172fb689df92135f5bfbbd69fc83017a76d24ea2e2f3a1154007e2fb9f8/coverage-7.13.5-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0e3c426ffc4cd952f54ee9ffbdd10345709ecc78a3ecfd796a57236bfad0b9b8", size = 250707, upload-time = "2026-03-17T10:30:35.2Z" }, + { url = "https://files.pythonhosted.org/packages/67/21/9ac389377380a07884e3b48ba7a620fcd9dbfaf1d40565facdc6b36ec9ef/coverage-7.13.5-cp311-cp311-win32.whl", hash = "sha256:259b69bb83ad9894c4b25be2528139eecba9a82646ebdda2d9db1ba28424a6bf", size = 221880, upload-time = "2026-03-17T10:30:36.775Z" }, + { url = "https://files.pythonhosted.org/packages/af/7f/4cd8a92531253f9d7c1bbecd9fa1b472907fb54446ca768c59b531248dc5/coverage-7.13.5-cp311-cp311-win_amd64.whl", hash = "sha256:258354455f4e86e3e9d0d17571d522e13b4e1e19bf0f8596bcf9476d61e7d8a9", size = 222816, upload-time = "2026-03-17T10:30:38.891Z" }, + { url = "https://files.pythonhosted.org/packages/12/a6/1d3f6155fb0010ca68eba7fe48ca6c9da7385058b77a95848710ecf189b1/coverage-7.13.5-cp311-cp311-win_arm64.whl", hash = "sha256:bff95879c33ec8da99fc9b6fe345ddb5be6414b41d6d1ad1c8f188d26f36e028", size = 221483, upload-time = "2026-03-17T10:30:40.463Z" }, + { url = "https://files.pythonhosted.org/packages/a0/c3/a396306ba7db865bf96fc1fb3b7fd29bcbf3d829df642e77b13555163cd6/coverage-7.13.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:460cf0114c5016fa841214ff5564aa4864f11948da9440bc97e21ad1f4ba1e01", size = 219554, upload-time = "2026-03-17T10:30:42.208Z" }, + { url = "https://files.pythonhosted.org/packages/a6/16/a68a19e5384e93f811dccc51034b1fd0b865841c390e3c931dcc4699e035/coverage-7.13.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0e223ce4b4ed47f065bfb123687686512e37629be25cc63728557ae7db261422", size = 219908, upload-time = "2026-03-17T10:30:43.906Z" }, + { url = "https://files.pythonhosted.org/packages/29/72/20b917c6793af3a5ceb7fb9c50033f3ec7865f2911a1416b34a7cfa0813b/coverage-7.13.5-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:6e3370441f4513c6252bf042b9c36d22491142385049243253c7e48398a15a9f", size = 251419, upload-time = "2026-03-17T10:30:45.545Z" }, + { url = "https://files.pythonhosted.org/packages/8c/49/cd14b789536ac6a4778c453c6a2338bc0a2fb60c5a5a41b4008328b9acc1/coverage-7.13.5-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:03ccc709a17a1de074fb1d11f217342fb0d2b1582ed544f554fc9fc3f07e95f5", size = 254159, upload-time = "2026-03-17T10:30:47.204Z" }, + { url = "https://files.pythonhosted.org/packages/9d/00/7b0edcfe64e2ed4c0340dac14a52ad0f4c9bd0b8b5e531af7d55b703db7c/coverage-7.13.5-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3f4818d065964db3c1c66dc0fbdac5ac692ecbc875555e13374fdbe7eedb4376", size = 255270, upload-time = "2026-03-17T10:30:48.812Z" }, + { url = "https://files.pythonhosted.org/packages/93/89/7ffc4ba0f5d0a55c1e84ea7cee39c9fc06af7b170513d83fbf3bbefce280/coverage-7.13.5-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:012d5319e66e9d5a218834642d6c35d265515a62f01157a45bcc036ecf947256", size = 257538, upload-time = "2026-03-17T10:30:50.77Z" }, + { url = "https://files.pythonhosted.org/packages/81/bd/73ddf85f93f7e6fa83e77ccecb6162d9415c79007b4bc124008a4995e4a7/coverage-7.13.5-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:8dd02af98971bdb956363e4827d34425cb3df19ee550ef92855b0acb9c7ce51c", size = 251821, upload-time = "2026-03-17T10:30:52.5Z" }, + { url = "https://files.pythonhosted.org/packages/a0/81/278aff4e8dec4926a0bcb9486320752811f543a3ce5b602cc7a29978d073/coverage-7.13.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:f08fd75c50a760c7eb068ae823777268daaf16a80b918fa58eea888f8e3919f5", size = 253191, upload-time = "2026-03-17T10:30:54.543Z" }, + { url = "https://files.pythonhosted.org/packages/70/ee/fe1621488e2e0a58d7e94c4800f0d96f79671553488d401a612bebae324b/coverage-7.13.5-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:843ea8643cf967d1ac7e8ecd4bb00c99135adf4816c0c0593fdcc47b597fcf09", size = 251337, upload-time = "2026-03-17T10:30:56.663Z" }, + { url = "https://files.pythonhosted.org/packages/37/a6/f79fb37aa104b562207cc23cb5711ab6793608e246cae1e93f26b2236ed9/coverage-7.13.5-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:9d44d7aa963820b1b971dbecd90bfe5fe8f81cff79787eb6cca15750bd2f79b9", size = 255404, upload-time = "2026-03-17T10:30:58.427Z" }, + { url = "https://files.pythonhosted.org/packages/75/f0/ed15262a58ec81ce457ceb717b7f78752a1713556b19081b76e90896e8d4/coverage-7.13.5-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:7132bed4bd7b836200c591410ae7d97bf7ae8be6fc87d160b2bd881df929e7bf", size = 250903, upload-time = "2026-03-17T10:31:00.093Z" }, + { url = "https://files.pythonhosted.org/packages/0f/e9/9129958f20e7e9d4d56d51d42ccf708d15cac355ff4ac6e736e97a9393d2/coverage-7.13.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a698e363641b98843c517817db75373c83254781426e94ada3197cabbc2c919c", size = 252780, upload-time = "2026-03-17T10:31:01.916Z" }, + { url = "https://files.pythonhosted.org/packages/a4/d7/0ad9b15812d81272db94379fe4c6df8fd17781cc7671fdfa30c76ba5ff7b/coverage-7.13.5-cp312-cp312-win32.whl", hash = "sha256:bdba0a6b8812e8c7df002d908a9a2ea3c36e92611b5708633c50869e6d922fdf", size = 222093, upload-time = "2026-03-17T10:31:03.642Z" }, + { url = "https://files.pythonhosted.org/packages/29/3d/821a9a5799fac2556bcf0bd37a70d1d11fa9e49784b6d22e92e8b2f85f18/coverage-7.13.5-cp312-cp312-win_amd64.whl", hash = "sha256:d2c87e0c473a10bffe991502eac389220533024c8082ec1ce849f4218dded810", size = 222900, upload-time = "2026-03-17T10:31:05.651Z" }, + { url = "https://files.pythonhosted.org/packages/d4/fa/2238c2ad08e35cf4f020ea721f717e09ec3152aea75d191a7faf3ef009a8/coverage-7.13.5-cp312-cp312-win_arm64.whl", hash = "sha256:bf69236a9a81bdca3bff53796237aab096cdbf8d78a66ad61e992d9dac7eb2de", size = 221515, upload-time = "2026-03-17T10:31:07.293Z" }, + { url = "https://files.pythonhosted.org/packages/9e/ee/a4cf96b8ce1e566ed238f0659ac2d3f007ed1d14b181bcb684e19561a69a/coverage-7.13.5-py3-none-any.whl", hash = "sha256:34b02417cf070e173989b3db962f7ed56d2f644307b2cf9d5a0f258e13084a61", size = 211346, upload-time = "2026-03-17T10:33:15.691Z" }, ] [package.optional-dependencies] @@ -1384,43 +1309,47 @@ wheels = [ [[package]] name = "cryptography" -version = "44.0.3" +version = "46.0.6" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cffi", marker = "platform_python_implementation != 'PyPy'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/53/d6/1411ab4d6108ab167d06254c5be517681f1e331f90edf1379895bcb87020/cryptography-44.0.3.tar.gz", hash = "sha256:fe19d8bc5536a91a24a8133328880a41831b6c5df54599a8417b62fe015d3053", size = 711096, upload-time = "2025-05-02T19:36:04.667Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a4/ba/04b1bd4218cbc58dc90ce967106d51582371b898690f3ae0402876cc4f34/cryptography-46.0.6.tar.gz", hash = "sha256:27550628a518c5c6c903d84f637fbecf287f6cb9ced3804838a1295dc1fd0759", size = 750542, upload-time = "2026-03-25T23:34:53.396Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/08/53/c776d80e9d26441bb3868457909b4e74dd9ccabd182e10b2b0ae7a07e265/cryptography-44.0.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:962bc30480a08d133e631e8dfd4783ab71cc9e33d5d7c1e192f0b7c06397bb88", size = 6670281, upload-time = "2025-05-02T19:34:50.665Z" }, - { url = "https://files.pythonhosted.org/packages/6a/06/af2cf8d56ef87c77319e9086601bef621bedf40f6f59069e1b6d1ec498c5/cryptography-44.0.3-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ffc61e8f3bf5b60346d89cd3d37231019c17a081208dfbbd6e1605ba03fa137", size = 3959305, upload-time = "2025-05-02T19:34:53.042Z" }, - { url = "https://files.pythonhosted.org/packages/ae/01/80de3bec64627207d030f47bf3536889efee8913cd363e78ca9a09b13c8e/cryptography-44.0.3-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58968d331425a6f9eedcee087f77fd3c927c88f55368f43ff7e0a19891f2642c", size = 4171040, upload-time = "2025-05-02T19:34:54.675Z" }, - { url = "https://files.pythonhosted.org/packages/bd/48/bb16b7541d207a19d9ae8b541c70037a05e473ddc72ccb1386524d4f023c/cryptography-44.0.3-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:e28d62e59a4dbd1d22e747f57d4f00c459af22181f0b2f787ea83f5a876d7c76", size = 3963411, upload-time = "2025-05-02T19:34:56.61Z" }, - { url = "https://files.pythonhosted.org/packages/42/b2/7d31f2af5591d217d71d37d044ef5412945a8a8e98d5a2a8ae4fd9cd4489/cryptography-44.0.3-cp37-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:af653022a0c25ef2e3ffb2c673a50e5a0d02fecc41608f4954176f1933b12359", size = 3689263, upload-time = "2025-05-02T19:34:58.591Z" }, - { url = "https://files.pythonhosted.org/packages/25/50/c0dfb9d87ae88ccc01aad8eb93e23cfbcea6a6a106a9b63a7b14c1f93c75/cryptography-44.0.3-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:157f1f3b8d941c2bd8f3ffee0af9b049c9665c39d3da9db2dc338feca5e98a43", size = 4196198, upload-time = "2025-05-02T19:35:00.988Z" }, - { url = "https://files.pythonhosted.org/packages/66/c9/55c6b8794a74da652690c898cb43906310a3e4e4f6ee0b5f8b3b3e70c441/cryptography-44.0.3-cp37-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:c6cd67722619e4d55fdb42ead64ed8843d64638e9c07f4011163e46bc512cf01", size = 3966502, upload-time = "2025-05-02T19:35:03.091Z" }, - { url = "https://files.pythonhosted.org/packages/b6/f7/7cb5488c682ca59a02a32ec5f975074084db4c983f849d47b7b67cc8697a/cryptography-44.0.3-cp37-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:b424563394c369a804ecbee9b06dfb34997f19d00b3518e39f83a5642618397d", size = 4196173, upload-time = "2025-05-02T19:35:05.018Z" }, - { url = "https://files.pythonhosted.org/packages/d2/0b/2f789a8403ae089b0b121f8f54f4a3e5228df756e2146efdf4a09a3d5083/cryptography-44.0.3-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:c91fc8e8fd78af553f98bc7f2a1d8db977334e4eea302a4bfd75b9461c2d8904", size = 4087713, upload-time = "2025-05-02T19:35:07.187Z" }, - { url = "https://files.pythonhosted.org/packages/1d/aa/330c13655f1af398fc154089295cf259252f0ba5df93b4bc9d9c7d7f843e/cryptography-44.0.3-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:25cd194c39fa5a0aa4169125ee27d1172097857b27109a45fadc59653ec06f44", size = 4299064, upload-time = "2025-05-02T19:35:08.879Z" }, - { url = "https://files.pythonhosted.org/packages/10/a8/8c540a421b44fd267a7d58a1fd5f072a552d72204a3f08194f98889de76d/cryptography-44.0.3-cp37-abi3-win32.whl", hash = "sha256:3be3f649d91cb182c3a6bd336de8b61a0a71965bd13d1a04a0e15b39c3d5809d", size = 2773887, upload-time = "2025-05-02T19:35:10.41Z" }, - { url = "https://files.pythonhosted.org/packages/b9/0d/c4b1657c39ead18d76bbd122da86bd95bdc4095413460d09544000a17d56/cryptography-44.0.3-cp37-abi3-win_amd64.whl", hash = "sha256:3883076d5c4cc56dbef0b898a74eb6992fdac29a7b9013870b34efe4ddb39a0d", size = 3209737, upload-time = "2025-05-02T19:35:12.12Z" }, - { url = "https://files.pythonhosted.org/packages/34/a3/ad08e0bcc34ad436013458d7528e83ac29910943cea42ad7dd4141a27bbb/cryptography-44.0.3-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:5639c2b16764c6f76eedf722dbad9a0914960d3489c0cc38694ddf9464f1bb2f", size = 6673501, upload-time = "2025-05-02T19:35:13.775Z" }, - { url = "https://files.pythonhosted.org/packages/b1/f0/7491d44bba8d28b464a5bc8cc709f25a51e3eac54c0a4444cf2473a57c37/cryptography-44.0.3-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3ffef566ac88f75967d7abd852ed5f182da252d23fac11b4766da3957766759", size = 3960307, upload-time = "2025-05-02T19:35:15.917Z" }, - { url = "https://files.pythonhosted.org/packages/f7/c8/e5c5d0e1364d3346a5747cdcd7ecbb23ca87e6dea4f942a44e88be349f06/cryptography-44.0.3-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:192ed30fac1728f7587c6f4613c29c584abdc565d7417c13904708db10206645", size = 4170876, upload-time = "2025-05-02T19:35:18.138Z" }, - { url = "https://files.pythonhosted.org/packages/73/96/025cb26fc351d8c7d3a1c44e20cf9a01e9f7cf740353c9c7a17072e4b264/cryptography-44.0.3-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:7d5fe7195c27c32a64955740b949070f21cba664604291c298518d2e255931d2", size = 3964127, upload-time = "2025-05-02T19:35:19.864Z" }, - { url = "https://files.pythonhosted.org/packages/01/44/eb6522db7d9f84e8833ba3bf63313f8e257729cf3a8917379473fcfd6601/cryptography-44.0.3-cp39-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:3f07943aa4d7dad689e3bb1638ddc4944cc5e0921e3c227486daae0e31a05e54", size = 3689164, upload-time = "2025-05-02T19:35:21.449Z" }, - { url = "https://files.pythonhosted.org/packages/68/fb/d61a4defd0d6cee20b1b8a1ea8f5e25007e26aeb413ca53835f0cae2bcd1/cryptography-44.0.3-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:cb90f60e03d563ca2445099edf605c16ed1d5b15182d21831f58460c48bffb93", size = 4198081, upload-time = "2025-05-02T19:35:23.187Z" }, - { url = "https://files.pythonhosted.org/packages/1b/50/457f6911d36432a8811c3ab8bd5a6090e8d18ce655c22820994913dd06ea/cryptography-44.0.3-cp39-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:ab0b005721cc0039e885ac3503825661bd9810b15d4f374e473f8c89b7d5460c", size = 3967716, upload-time = "2025-05-02T19:35:25.426Z" }, - { url = "https://files.pythonhosted.org/packages/35/6e/dca39d553075980ccb631955c47b93d87d27f3596da8d48b1ae81463d915/cryptography-44.0.3-cp39-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:3bb0847e6363c037df8f6ede57d88eaf3410ca2267fb12275370a76f85786a6f", size = 4197398, upload-time = "2025-05-02T19:35:27.678Z" }, - { url = "https://files.pythonhosted.org/packages/9b/9d/d1f2fe681eabc682067c66a74addd46c887ebacf39038ba01f8860338d3d/cryptography-44.0.3-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:b0cc66c74c797e1db750aaa842ad5b8b78e14805a9b5d1348dc603612d3e3ff5", size = 4087900, upload-time = "2025-05-02T19:35:29.312Z" }, - { url = "https://files.pythonhosted.org/packages/c4/f5/3599e48c5464580b73b236aafb20973b953cd2e7b44c7c2533de1d888446/cryptography-44.0.3-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:6866df152b581f9429020320e5eb9794c8780e90f7ccb021940d7f50ee00ae0b", size = 4301067, upload-time = "2025-05-02T19:35:31.547Z" }, - { url = "https://files.pythonhosted.org/packages/a7/6c/d2c48c8137eb39d0c193274db5c04a75dab20d2f7c3f81a7dcc3a8897701/cryptography-44.0.3-cp39-abi3-win32.whl", hash = "sha256:c138abae3a12a94c75c10499f1cbae81294a6f983b3af066390adee73f433028", size = 2775467, upload-time = "2025-05-02T19:35:33.805Z" }, - { url = "https://files.pythonhosted.org/packages/c9/ad/51f212198681ea7b0deaaf8846ee10af99fba4e894f67b353524eab2bbe5/cryptography-44.0.3-cp39-abi3-win_amd64.whl", hash = "sha256:5d186f32e52e66994dce4f766884bcb9c68b8da62d61d9d215bfe5fb56d21334", size = 3210375, upload-time = "2025-05-02T19:35:35.369Z" }, - { url = "https://files.pythonhosted.org/packages/8d/4b/c11ad0b6c061902de5223892d680e89c06c7c4d606305eb8de56c5427ae6/cryptography-44.0.3-pp311-pypy311_pp73-macosx_10_9_x86_64.whl", hash = "sha256:896530bc9107b226f265effa7ef3f21270f18a2026bc09fed1ebd7b66ddf6375", size = 3390230, upload-time = "2025-05-02T19:35:49.062Z" }, - { url = "https://files.pythonhosted.org/packages/58/11/0a6bf45d53b9b2290ea3cec30e78b78e6ca29dc101e2e296872a0ffe1335/cryptography-44.0.3-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:9b4d4a5dbee05a2c390bf212e78b99434efec37b17a4bff42f50285c5c8c9647", size = 3895216, upload-time = "2025-05-02T19:35:51.351Z" }, - { url = "https://files.pythonhosted.org/packages/0a/27/b28cdeb7270e957f0077a2c2bfad1b38f72f1f6d699679f97b816ca33642/cryptography-44.0.3-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:02f55fb4f8b79c1221b0961488eaae21015b69b210e18c386b69de182ebb1259", size = 4115044, upload-time = "2025-05-02T19:35:53.044Z" }, - { url = "https://files.pythonhosted.org/packages/35/b0/ec4082d3793f03cb248881fecefc26015813199b88f33e3e990a43f79835/cryptography-44.0.3-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:dd3db61b8fe5be220eee484a17233287d0be6932d056cf5738225b9c05ef4fff", size = 3898034, upload-time = "2025-05-02T19:35:54.72Z" }, - { url = "https://files.pythonhosted.org/packages/0b/7f/adf62e0b8e8d04d50c9a91282a57628c00c54d4ae75e2b02a223bd1f2613/cryptography-44.0.3-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:978631ec51a6bbc0b7e58f23b68a8ce9e5f09721940933e9c217068388789fe5", size = 4114449, upload-time = "2025-05-02T19:35:57.139Z" }, - { url = "https://files.pythonhosted.org/packages/87/62/d69eb4a8ee231f4bf733a92caf9da13f1c81a44e874b1d4080c25ecbb723/cryptography-44.0.3-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:5d20cc348cca3a8aa7312f42ab953a56e15323800ca3ab0706b8cd452a3a056c", size = 3134369, upload-time = "2025-05-02T19:35:58.907Z" }, + { url = "https://files.pythonhosted.org/packages/47/23/9285e15e3bc57325b0a72e592921983a701efc1ee8f91c06c5f0235d86d9/cryptography-46.0.6-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:64235194bad039a10bb6d2d930ab3323baaec67e2ce36215fd0952fad0930ca8", size = 7176401, upload-time = "2026-03-25T23:33:22.096Z" }, + { url = "https://files.pythonhosted.org/packages/60/f8/e61f8f13950ab6195b31913b42d39f0f9afc7d93f76710f299b5ec286ae6/cryptography-46.0.6-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:26031f1e5ca62fcb9d1fcb34b2b60b390d1aacaa15dc8b895a9ed00968b97b30", size = 4275275, upload-time = "2026-03-25T23:33:23.844Z" }, + { url = "https://files.pythonhosted.org/packages/19/69/732a736d12c2631e140be2348b4ad3d226302df63ef64d30dfdb8db7ad1c/cryptography-46.0.6-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9a693028b9cbe51b5a1136232ee8f2bc242e4e19d456ded3fa7c86e43c713b4a", size = 4425320, upload-time = "2026-03-25T23:33:25.703Z" }, + { url = "https://files.pythonhosted.org/packages/d4/12/123be7292674abf76b21ac1fc0e1af50661f0e5b8f0ec8285faac18eb99e/cryptography-46.0.6-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:67177e8a9f421aa2d3a170c3e56eca4e0128883cf52a071a7cbf53297f18b175", size = 4278082, upload-time = "2026-03-25T23:33:27.423Z" }, + { url = "https://files.pythonhosted.org/packages/5b/ba/d5e27f8d68c24951b0a484924a84c7cdaed7502bac9f18601cd357f8b1d2/cryptography-46.0.6-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:d9528b535a6c4f8ff37847144b8986a9a143585f0540fbcb1a98115b543aa463", size = 4926514, upload-time = "2026-03-25T23:33:29.206Z" }, + { url = "https://files.pythonhosted.org/packages/34/71/1ea5a7352ae516d5512d17babe7e1b87d9db5150b21f794b1377eac1edc0/cryptography-46.0.6-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:22259338084d6ae497a19bae5d4c66b7ca1387d3264d1c2c0e72d9e9b6a77b97", size = 4457766, upload-time = "2026-03-25T23:33:30.834Z" }, + { url = "https://files.pythonhosted.org/packages/01/59/562be1e653accee4fdad92c7a2e88fced26b3fdfce144047519bbebc299e/cryptography-46.0.6-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:760997a4b950ff00d418398ad73fbc91aa2894b5c1db7ccb45b4f68b42a63b3c", size = 3986535, upload-time = "2026-03-25T23:33:33.02Z" }, + { url = "https://files.pythonhosted.org/packages/d6/8b/b1ebfeb788bf4624d36e45ed2662b8bd43a05ff62157093c1539c1288a18/cryptography-46.0.6-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:3dfa6567f2e9e4c5dceb8ccb5a708158a2a871052fa75c8b78cb0977063f1507", size = 4277618, upload-time = "2026-03-25T23:33:34.567Z" }, + { url = "https://files.pythonhosted.org/packages/dd/52/a005f8eabdb28df57c20f84c44d397a755782d6ff6d455f05baa2785bd91/cryptography-46.0.6-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:cdcd3edcbc5d55757e5f5f3d330dd00007ae463a7e7aa5bf132d1f22a4b62b19", size = 4890802, upload-time = "2026-03-25T23:33:37.034Z" }, + { url = "https://files.pythonhosted.org/packages/ec/4d/8e7d7245c79c617d08724e2efa397737715ca0ec830ecb3c91e547302555/cryptography-46.0.6-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:d4e4aadb7fc1f88687f47ca20bb7227981b03afaae69287029da08096853b738", size = 4457425, upload-time = "2026-03-25T23:33:38.904Z" }, + { url = "https://files.pythonhosted.org/packages/1d/5c/f6c3596a1430cec6f949085f0e1a970638d76f81c3ea56d93d564d04c340/cryptography-46.0.6-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:2b417edbe8877cda9022dde3a008e2deb50be9c407eef034aeeb3a8b11d9db3c", size = 4405530, upload-time = "2026-03-25T23:33:40.842Z" }, + { url = "https://files.pythonhosted.org/packages/7e/c9/9f9cea13ee2dbde070424e0c4f621c091a91ffcc504ffea5e74f0e1daeff/cryptography-46.0.6-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:380343e0653b1c9d7e1f55b52aaa2dbb2fdf2730088d48c43ca1c7c0abb7cc2f", size = 4667896, upload-time = "2026-03-25T23:33:42.781Z" }, + { url = "https://files.pythonhosted.org/packages/ad/b5/1895bc0821226f129bc74d00eccfc6a5969e2028f8617c09790bf89c185e/cryptography-46.0.6-cp311-abi3-win32.whl", hash = "sha256:bcb87663e1f7b075e48c3be3ecb5f0b46c8fc50b50a97cf264e7f60242dca3f2", size = 3026348, upload-time = "2026-03-25T23:33:45.021Z" }, + { url = "https://files.pythonhosted.org/packages/c3/f8/c9bcbf0d3e6ad288b9d9aa0b1dee04b063d19e8c4f871855a03ab3a297ab/cryptography-46.0.6-cp311-abi3-win_amd64.whl", hash = "sha256:6739d56300662c468fddb0e5e291f9b4d084bead381667b9e654c7dd81705124", size = 3483896, upload-time = "2026-03-25T23:33:46.649Z" }, + { url = "https://files.pythonhosted.org/packages/c4/cc/f330e982852403da79008552de9906804568ae9230da8432f7496ce02b71/cryptography-46.0.6-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:12cae594e9473bca1a7aceb90536060643128bb274fcea0fc459ab90f7d1ae7a", size = 7162776, upload-time = "2026-03-25T23:34:13.308Z" }, + { url = "https://files.pythonhosted.org/packages/49/b3/dc27efd8dcc4bff583b3f01d4a3943cd8b5821777a58b3a6a5f054d61b79/cryptography-46.0.6-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:639301950939d844a9e1c4464d7e07f902fe9a7f6b215bb0d4f28584729935d8", size = 4270529, upload-time = "2026-03-25T23:34:15.019Z" }, + { url = "https://files.pythonhosted.org/packages/e6/05/e8d0e6eb4f0d83365b3cb0e00eb3c484f7348db0266652ccd84632a3d58d/cryptography-46.0.6-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ed3775295fb91f70b4027aeba878d79b3e55c0b3e97eaa4de71f8f23a9f2eb77", size = 4414827, upload-time = "2026-03-25T23:34:16.604Z" }, + { url = "https://files.pythonhosted.org/packages/2f/97/daba0f5d2dc6d855e2dcb70733c812558a7977a55dd4a6722756628c44d1/cryptography-46.0.6-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:8927ccfbe967c7df312ade694f987e7e9e22b2425976ddbf28271d7e58845290", size = 4271265, upload-time = "2026-03-25T23:34:18.586Z" }, + { url = "https://files.pythonhosted.org/packages/89/06/fe1fce39a37ac452e58d04b43b0855261dac320a2ebf8f5260dd55b201a9/cryptography-46.0.6-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:b12c6b1e1651e42ab5de8b1e00dc3b6354fdfd778e7fa60541ddacc27cd21410", size = 4916800, upload-time = "2026-03-25T23:34:20.561Z" }, + { url = "https://files.pythonhosted.org/packages/ff/8a/b14f3101fe9c3592603339eb5d94046c3ce5f7fc76d6512a2d40efd9724e/cryptography-46.0.6-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:063b67749f338ca9c5a0b7fe438a52c25f9526b851e24e6c9310e7195aad3b4d", size = 4448771, upload-time = "2026-03-25T23:34:22.406Z" }, + { url = "https://files.pythonhosted.org/packages/01/b3/0796998056a66d1973fd52ee89dc1bb3b6581960a91ad4ac705f182d398f/cryptography-46.0.6-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:02fad249cb0e090b574e30b276a3da6a149e04ee2f049725b1f69e7b8351ec70", size = 3978333, upload-time = "2026-03-25T23:34:24.281Z" }, + { url = "https://files.pythonhosted.org/packages/c5/3d/db200af5a4ffd08918cd55c08399dc6c9c50b0bc72c00a3246e099d3a849/cryptography-46.0.6-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:7e6142674f2a9291463e5e150090b95a8519b2fb6e6aaec8917dd8d094ce750d", size = 4271069, upload-time = "2026-03-25T23:34:25.895Z" }, + { url = "https://files.pythonhosted.org/packages/d7/18/61acfd5b414309d74ee838be321c636fe71815436f53c9f0334bf19064fa/cryptography-46.0.6-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:456b3215172aeefb9284550b162801d62f5f264a081049a3e94307fe20792cfa", size = 4878358, upload-time = "2026-03-25T23:34:27.67Z" }, + { url = "https://files.pythonhosted.org/packages/8b/65/5bf43286d566f8171917cae23ac6add941654ccf085d739195a4eacf1674/cryptography-46.0.6-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:341359d6c9e68834e204ceaf25936dffeafea3829ab80e9503860dcc4f4dac58", size = 4448061, upload-time = "2026-03-25T23:34:29.375Z" }, + { url = "https://files.pythonhosted.org/packages/e0/25/7e49c0fa7205cf3597e525d156a6bce5b5c9de1fd7e8cb01120e459f205a/cryptography-46.0.6-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9a9c42a2723999a710445bc0d974e345c32adfd8d2fac6d8a251fa829ad31cfb", size = 4399103, upload-time = "2026-03-25T23:34:32.036Z" }, + { url = "https://files.pythonhosted.org/packages/44/46/466269e833f1c4718d6cd496ffe20c56c9c8d013486ff66b4f69c302a68d/cryptography-46.0.6-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:6617f67b1606dfd9fe4dbfa354a9508d4a6d37afe30306fe6c101b7ce3274b72", size = 4659255, upload-time = "2026-03-25T23:34:33.679Z" }, + { url = "https://files.pythonhosted.org/packages/0a/09/ddc5f630cc32287d2c953fc5d32705e63ec73e37308e5120955316f53827/cryptography-46.0.6-cp38-abi3-win32.whl", hash = "sha256:7f6690b6c55e9c5332c0b59b9c8a3fb232ebf059094c17f9019a51e9827df91c", size = 3010660, upload-time = "2026-03-25T23:34:35.418Z" }, + { url = "https://files.pythonhosted.org/packages/1b/82/ca4893968aeb2709aacfb57a30dec6fa2ab25b10fa9f064b8882ce33f599/cryptography-46.0.6-cp38-abi3-win_amd64.whl", hash = "sha256:79e865c642cfc5c0b3eb12af83c35c5aeff4fa5c672dc28c43721c2c9fdd2f0f", size = 3471160, upload-time = "2026-03-25T23:34:37.191Z" }, + { url = "https://files.pythonhosted.org/packages/2e/84/7ccff00ced5bac74b775ce0beb7d1be4e8637536b522b5df9b73ada42da2/cryptography-46.0.6-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:2ea0f37e9a9cf0df2952893ad145fd9627d326a59daec9b0802480fa3bcd2ead", size = 3475444, upload-time = "2026-03-25T23:34:38.944Z" }, + { url = "https://files.pythonhosted.org/packages/bc/1f/4c926f50df7749f000f20eede0c896769509895e2648db5da0ed55db711d/cryptography-46.0.6-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:a3e84d5ec9ba01f8fd03802b2147ba77f0c8f2617b2aff254cedd551844209c8", size = 4218227, upload-time = "2026-03-25T23:34:40.871Z" }, + { url = "https://files.pythonhosted.org/packages/c6/65/707be3ffbd5f786028665c3223e86e11c4cda86023adbc56bd72b1b6bab5/cryptography-46.0.6-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:12f0fa16cc247b13c43d56d7b35287ff1569b5b1f4c5e87e92cc4fcc00cd10c0", size = 4381399, upload-time = "2026-03-25T23:34:42.609Z" }, + { url = "https://files.pythonhosted.org/packages/f3/6d/73557ed0ef7d73d04d9aba745d2c8e95218213687ee5e76b7d236a5030fc/cryptography-46.0.6-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:50575a76e2951fe7dbd1f56d181f8c5ceeeb075e9ff88e7ad997d2f42af06e7b", size = 4217595, upload-time = "2026-03-25T23:34:44.205Z" }, + { url = "https://files.pythonhosted.org/packages/9e/c5/e1594c4eec66a567c3ac4400008108a415808be2ce13dcb9a9045c92f1a0/cryptography-46.0.6-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:90e5f0a7b3be5f40c3a0a0eafb32c681d8d2c181fc2a1bdabe9b3f611d9f6b1a", size = 4380912, upload-time = "2026-03-25T23:34:46.328Z" }, + { url = "https://files.pythonhosted.org/packages/1a/89/843b53614b47f97fe1abc13f9a86efa5ec9e275292c457af1d4a60dc80e0/cryptography-46.0.6-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:6728c49e3b2c180ef26f8e9f0a883a2c585638db64cf265b49c9ba10652d430e", size = 3409955, upload-time = "2026-03-25T23:34:48.465Z" }, ] [[package]] @@ -1533,7 +1462,7 @@ wheels = [ [[package]] name = "dify-api" -version = "1.13.2" +version = "1.13.3" source = { virtual = "." } dependencies = [ { name = "aliyun-log-python-sdk" }, @@ -1565,12 +1494,12 @@ dependencies = [ { name = "google-auth-httplib2" }, { name = "google-cloud-aiplatform" }, { name = "googleapis-common-protos" }, + { name = "graphon" }, { name = "gunicorn" }, { name = "httpx", extra = ["socks"] }, { name = "httpx-sse" }, { name = "jieba" }, { name = "json-repair" }, - { name = "jsonschema" }, { name = "langfuse" }, { name = "langsmith" }, { name = "litellm" }, @@ -1602,9 +1531,9 @@ dependencies = [ { name = "psycopg2-binary" }, { name = "pycryptodome" }, { name = "pydantic" }, - { name = "pydantic-extra-types" }, { name = "pydantic-settings" }, { name = "pyjwt" }, + { name = "pypandoc" }, { name = "pypdfium2" }, { name = "python-docx" }, { name = "python-dotenv" }, @@ -1622,7 +1551,6 @@ dependencies = [ { name = "unstructured", extra = ["docx", "epub", "md", "ppt", "pptx"] }, { name = "weave" }, { name = "weaviate-client" }, - { name = "webvtt-py" }, { name = "yarl" }, ] @@ -1665,7 +1593,6 @@ dev = [ { name = "types-greenlet" }, { name = "types-html5lib" }, { name = "types-jmespath" }, - { name = "types-jsonschema" }, { name = "types-markdown" }, { name = "types-oauthlib" }, { name = "types-objgraph" }, @@ -1743,8 +1670,8 @@ requires-dist = [ { name = "arize-phoenix-otel", specifier = "~=0.15.0" }, { name = "azure-identity", specifier = "==1.25.3" }, { name = "beautifulsoup4", specifier = "==4.14.3" }, - { name = "bleach", specifier = "~=6.2.0" }, - { name = "boto3", specifier = "==1.42.68" }, + { name = "bleach", specifier = "~=6.3.0" }, + { name = "boto3", specifier = "==1.42.78" }, { name = "bs4", specifier = "~=0.0.1" }, { name = "cachetools", specifier = "~=5.3.0" }, { name = "celery", specifier = "~=5.6.2" }, @@ -1762,41 +1689,41 @@ requires-dist = [ { name = "gevent", specifier = "~=25.9.1" }, { name = "gmpy2", specifier = "~=2.3.0" }, { name = "google-api-core", specifier = ">=2.19.1" }, - { name = "google-api-python-client", specifier = "==2.192.0" }, + { name = "google-api-python-client", specifier = "==2.193.0" }, { name = "google-auth", specifier = ">=2.47.0" }, { name = "google-auth-httplib2", specifier = "==0.3.0" }, { name = "google-cloud-aiplatform", specifier = ">=1.123.0" }, { name = "googleapis-common-protos", specifier = ">=1.65.0" }, - { name = "gunicorn", specifier = "~=25.1.0" }, + { name = "graphon", specifier = ">=0.1.2" }, + { name = "gunicorn", specifier = "~=25.3.0" }, { name = "httpx", extras = ["socks"], specifier = "~=0.28.0" }, { name = "httpx-sse", specifier = "~=0.4.0" }, { name = "jieba", specifier = "==0.42.1" }, { name = "json-repair", specifier = ">=0.55.1" }, - { name = "jsonschema", specifier = ">=4.25.1" }, - { name = "langfuse", specifier = "~=2.51.3" }, + { name = "langfuse", specifier = ">=3.0.0,<5.0.0" }, { name = "langsmith", specifier = "~=0.7.16" }, - { name = "litellm", specifier = "==1.82.2" }, + { name = "litellm", specifier = "==1.82.6" }, { name = "markdown", specifier = "~=3.10.2" }, { name = "mlflow-skinny", specifier = ">=3.0.0" }, { name = "numpy", specifier = "~=1.26.4" }, { name = "openpyxl", specifier = "~=3.1.5" }, - { name = "opentelemetry-api", specifier = "==1.28.0" }, - { name = "opentelemetry-distro", specifier = "==0.49b0" }, - { name = "opentelemetry-exporter-otlp", specifier = "==1.28.0" }, - { name = "opentelemetry-exporter-otlp-proto-common", specifier = "==1.28.0" }, - { name = "opentelemetry-exporter-otlp-proto-grpc", specifier = "==1.28.0" }, - { name = "opentelemetry-exporter-otlp-proto-http", specifier = "==1.28.0" }, - { name = "opentelemetry-instrumentation", specifier = "==0.49b0" }, - { name = "opentelemetry-instrumentation-celery", specifier = "==0.49b0" }, - { name = "opentelemetry-instrumentation-flask", specifier = "==0.49b0" }, - { name = "opentelemetry-instrumentation-httpx", specifier = "==0.49b0" }, - { name = "opentelemetry-instrumentation-redis", specifier = "==0.49b0" }, - { name = "opentelemetry-instrumentation-sqlalchemy", specifier = "==0.49b0" }, + { name = "opentelemetry-api", specifier = "==1.40.0" }, + { name = "opentelemetry-distro", specifier = "==0.61b0" }, + { name = "opentelemetry-exporter-otlp", specifier = "==1.40.0" }, + { name = "opentelemetry-exporter-otlp-proto-common", specifier = "==1.40.0" }, + { name = "opentelemetry-exporter-otlp-proto-grpc", specifier = "==1.40.0" }, + { name = "opentelemetry-exporter-otlp-proto-http", specifier = "==1.40.0" }, + { name = "opentelemetry-instrumentation", specifier = "==0.61b0" }, + { name = "opentelemetry-instrumentation-celery", specifier = "==0.61b0" }, + { name = "opentelemetry-instrumentation-flask", specifier = "==0.61b0" }, + { name = "opentelemetry-instrumentation-httpx", specifier = "==0.61b0" }, + { name = "opentelemetry-instrumentation-redis", specifier = "==0.61b0" }, + { name = "opentelemetry-instrumentation-sqlalchemy", specifier = "==0.61b0" }, { name = "opentelemetry-propagator-b3", specifier = "==1.40.0" }, - { name = "opentelemetry-proto", specifier = "==1.28.0" }, - { name = "opentelemetry-sdk", specifier = "==1.28.0" }, - { name = "opentelemetry-semantic-conventions", specifier = "==0.49b0" }, - { name = "opentelemetry-util-http", specifier = "==0.49b0" }, + { name = "opentelemetry-proto", specifier = "==1.40.0" }, + { name = "opentelemetry-sdk", specifier = "==1.40.0" }, + { name = "opentelemetry-semantic-conventions", specifier = "==0.61b0" }, + { name = "opentelemetry-util-http", specifier = "==0.61b0" }, { name = "opik", specifier = "~=1.10.37" }, { name = "packaging", specifier = "~=23.2" }, { name = "pandas", extras = ["excel", "output-formatting", "performance"], specifier = "~=3.0.1" }, @@ -1804,27 +1731,26 @@ requires-dist = [ { name = "psycopg2-binary", specifier = "~=2.9.6" }, { name = "pycryptodome", specifier = "==3.23.0" }, { name = "pydantic", specifier = "~=2.12.5" }, - { name = "pydantic-extra-types", specifier = "~=2.11.0" }, { name = "pydantic-settings", specifier = "~=2.13.1" }, { name = "pyjwt", specifier = "~=2.12.0" }, + { name = "pypandoc", specifier = "~=1.13" }, { name = "pypdfium2", specifier = "==5.6.0" }, { name = "python-docx", specifier = "~=1.2.0" }, { name = "python-dotenv", specifier = "==1.2.2" }, { name = "pyyaml", specifier = "~=6.0.1" }, { name = "readabilipy", specifier = "~=0.3.0" }, - { name = "redis", extras = ["hiredis"], specifier = "~=7.3.0" }, - { name = "resend", specifier = "~=2.23.0" }, + { name = "redis", extras = ["hiredis"], specifier = "~=7.4.0" }, + { name = "resend", specifier = "~=2.26.0" }, { name = "sendgrid", specifier = "~=6.12.3" }, - { name = "sentry-sdk", extras = ["flask"], specifier = "~=2.54.0" }, + { name = "sentry-sdk", extras = ["flask"], specifier = "~=2.55.0" }, { name = "sqlalchemy", specifier = "~=2.0.29" }, { name = "sseclient-py", specifier = "~=1.9.0" }, - { name = "starlette", specifier = "==0.52.1" }, + { name = "starlette", specifier = "==1.0.0" }, { name = "tiktoken", specifier = "~=0.12.0" }, { name = "transformers", specifier = "~=5.3.0" }, { name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.21.5" }, { name = "weave", specifier = ">=0.52.16" }, { name = "weaviate-client", specifier = "==4.20.4" }, - { name = "webvtt-py", specifier = "~=0.5.1" }, { name = "yarl", specifier = "~=1.23.0" }, ] @@ -1841,10 +1767,10 @@ dev = [ { name = "lxml-stubs", specifier = "~=0.5.1" }, { name = "mypy", specifier = "~=1.19.1" }, { name = "pandas-stubs", specifier = "~=3.0.0" }, - { name = "pyrefly", specifier = ">=0.55.0" }, + { name = "pyrefly", specifier = ">=0.57.1" }, { name = "pytest", specifier = "~=9.0.2" }, { name = "pytest-benchmark", specifier = "~=5.2.3" }, - { name = "pytest-cov", specifier = "~=7.0.0" }, + { name = "pytest-cov", specifier = "~=7.1.0" }, { name = "pytest-env", specifier = "~=1.6.0" }, { name = "pytest-mock", specifier = "~=3.15.1" }, { name = "pytest-timeout", specifier = ">=2.4.0" }, @@ -1867,7 +1793,6 @@ dev = [ { name = "types-greenlet", specifier = "~=3.3.0" }, { name = "types-html5lib", specifier = "~=1.1.11" }, { name = "types-jmespath", specifier = ">=1.0.2.20240106" }, - { name = "types-jsonschema", specifier = "~=4.26.0" }, { name = "types-markdown", specifier = "~=3.10.2" }, { name = "types-oauthlib", specifier = "~=3.3.0" }, { name = "types-objgraph", specifier = "~=3.6.0" }, @@ -1885,7 +1810,7 @@ dev = [ { name = "types-pywin32", specifier = "~=311.0.0" }, { name = "types-pyyaml", specifier = "~=6.0.12" }, { name = "types-redis", specifier = ">=4.6.0.20241004" }, - { name = "types-regex", specifier = "~=2026.2.28" }, + { name = "types-regex", specifier = "~=2026.3.32" }, { name = "types-setuptools", specifier = ">=80.9.0" }, { name = "types-shapely", specifier = "~=2.1.0" }, { name = "types-simplejson", specifier = ">=3.20.0" }, @@ -1910,10 +1835,10 @@ tools = [ { name = "nltk", specifier = "~=3.9.1" }, ] vdb = [ - { name = "alibabacloud-gpdb20160503", specifier = "~=3.8.0" }, + { name = "alibabacloud-gpdb20160503", specifier = "~=5.1.0" }, { name = "alibabacloud-tea-openapi", specifier = "~=0.4.3" }, { name = "chromadb", specifier = "==0.5.20" }, - { name = "clickhouse-connect", specifier = "~=0.14.1" }, + { name = "clickhouse-connect", specifier = "~=0.15.0" }, { name = "clickzetta-connector-python", specifier = ">=0.8.102" }, { name = "couchbase", specifier = "~=4.5.0" }, { name = "elasticsearch", specifier = "==8.14.0" }, @@ -1929,13 +1854,13 @@ vdb = [ { name = "pymochow", specifier = "==2.3.6" }, { name = "pyobvector", specifier = "~=0.2.17" }, { name = "qdrant-client", specifier = "==1.9.0" }, - { name = "tablestore", specifier = "==6.4.1" }, - { name = "tcvectordb", specifier = "~=2.0.0" }, + { name = "tablestore", specifier = "==6.4.2" }, + { name = "tcvectordb", specifier = "~=2.1.0" }, { name = "tidb-vector", specifier = "==0.0.15" }, { name = "upstash-vector", specifier = "==0.8.0" }, { name = "volcengine-compat", specifier = "~=1.0.0" }, { name = "weaviate-client", specifier = "==4.20.4" }, - { name = "xinference-client", specifier = "~=2.3.1" }, + { name = "xinference-client", specifier = "~=2.4.0" }, ] [[package]] @@ -2098,14 +2023,14 @@ wheels = [ [[package]] name = "faker" -version = "40.11.0" +version = "40.11.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "tzdata", marker = "sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/94/dc/b68e5378e5a7db0ab776efcdd53b6fe374b29d703e156fd5bb4c5437069e/faker-40.11.0.tar.gz", hash = "sha256:7c419299103b13126bd02ec14bd2b47b946edb5a5eedf305e66a193b25f9a734", size = 1957570, upload-time = "2026-03-13T14:36:11.844Z" } +sdist = { url = "https://files.pythonhosted.org/packages/fa/e5/b16bf568a2f20fe7423282db4a4059dbcadef70e9029c1c106836f8edd84/faker-40.11.1.tar.gz", hash = "sha256:61965046e79e8cfde4337d243eac04c0d31481a7c010033141103b43f603100c", size = 1957415, upload-time = "2026-03-23T14:05:50.233Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b1/fa/a86c6ba66f0308c95b9288b1e3eaccd934b545646f63494a86f1ec2f8c8e/faker-40.11.0-py3-none-any.whl", hash = "sha256:0e9816c950528d2a37d74863f3ef389ea9a3a936cbcde0b11b8499942e25bf90", size = 1989457, upload-time = "2026-03-13T14:36:09.792Z" }, + { url = "https://files.pythonhosted.org/packages/fc/ec/3c4b78eb0d2f6a81fb8cc9286745845bff661e6815741eff7a6ac5fcc9ea/faker-40.11.1-py3-none-any.whl", hash = "sha256:3af3a213ba8fb33ce6ba2af7aef2ac91363dae35d0cec0b2b0337d189e5bee2a", size = 1989484, upload-time = "2026-03-23T14:05:48.793Z" }, ] [[package]] @@ -2499,7 +2424,7 @@ grpc = [ [[package]] name = "google-api-python-client" -version = "2.192.0" +version = "2.193.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "google-api-core" }, @@ -2508,9 +2433,9 @@ dependencies = [ { name = "httplib2" }, { name = "uritemplate" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/85/d8/489052a40935e45b9b5b3d6accc14b041360c1507bdc659c2e1a19aaa3ff/google_api_python_client-2.192.0.tar.gz", hash = "sha256:d48cfa6078fadea788425481b007af33fe0ab6537b78f37da914fb6fc112eb27", size = 14209505, upload-time = "2026-03-05T15:17:01.598Z" } +sdist = { url = "https://files.pythonhosted.org/packages/90/f4/e14b6815d3b1885328dd209676a3a4c704882743ac94e18ef0093894f5c8/google_api_python_client-2.193.0.tar.gz", hash = "sha256:8f88d16e89d11341e0a8b199cafde0fb7e6b44260dffb88d451577cbd1bb5d33", size = 14281006, upload-time = "2026-03-17T18:25:29.415Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e0/76/ec4128f00fefb9011635ae2abc67d7dacd05c8559378f8f05f0c907c38d8/google_api_python_client-2.192.0-py3-none-any.whl", hash = "sha256:63a57d4457cd97df1d63eb89c5fda03c5a50588dcbc32c0115dd1433c08f4b62", size = 14783267, upload-time = "2026-03-05T15:16:58.804Z" }, + { url = "https://files.pythonhosted.org/packages/f0/6d/fe75167797790a56d17799b75e1129bb93f7ff061efc7b36e9731bd4be2b/google_api_python_client-2.193.0-py3-none-any.whl", hash = "sha256:c42aa324b822109901cfecab5dc4fc3915d35a7b376835233c916c70610322db", size = 14856490, upload-time = "2026-03-17T18:25:26.608Z" }, ] [[package]] @@ -2546,7 +2471,7 @@ wheels = [ [[package]] name = "google-cloud-aiplatform" -version = "1.141.0" +version = "1.143.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "docstring-parser" }, @@ -2562,9 +2487,9 @@ dependencies = [ { name = "pydantic" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ac/dc/1209c7aab43bd7233cf631165a3b1b4284d22fc7fe7387c66228d07868ab/google_cloud_aiplatform-1.141.0.tar.gz", hash = "sha256:e3b1cdb28865dd862aac9c685dfc5ac076488705aba0a5354016efadcddd59c6", size = 10152688, upload-time = "2026-03-10T22:20:08.692Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a7/08/939fb05870fdf155410a927e22f5b053d49f18e215618e102fba1d8bb147/google_cloud_aiplatform-1.143.0.tar.gz", hash = "sha256:1f0124a89795a6b473deb28724dd37d95334205df3a9c9c48d0b8d7a3d5d5cc4", size = 10215389, upload-time = "2026-03-25T18:30:15.444Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/6a/fc/428af69a69ff2e477e7f5e12d227b31fe5790f1a8234aacd54297f49c836/google_cloud_aiplatform-1.141.0-py2.py3-none-any.whl", hash = "sha256:6bd25b4d514c40b8181ca703e1b313ad6d0454ab8006fc9907fb3e9f672f31d1", size = 8358409, upload-time = "2026-03-10T22:20:04.871Z" }, + { url = "https://files.pythonhosted.org/packages/90/14/16323e604e79dc63b528268f97a841c2c29dd8eb16395de6bf530c1a5ebe/google_cloud_aiplatform-1.143.0-py2.py3-none-any.whl", hash = "sha256:78df97d044859f743a9cc48b89a260d33579b0d548b1589bb3ae9f4c2afc0c5a", size = 8392705, upload-time = "2026-03-25T18:30:11.496Z" }, ] [[package]] @@ -2617,7 +2542,7 @@ wheels = [ [[package]] name = "google-cloud-storage" -version = "3.9.0" +version = "3.10.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "google-api-core" }, @@ -2627,9 +2552,9 @@ dependencies = [ { name = "google-resumable-media" }, { name = "requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f7/b1/4f0798e88285b50dfc60ed3a7de071def538b358db2da468c2e0deecbb40/google_cloud_storage-3.9.0.tar.gz", hash = "sha256:f2d8ca7db2f652be757e92573b2196e10fbc09649b5c016f8b422ad593c641cc", size = 17298544, upload-time = "2026-02-02T13:36:34.119Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4c/47/205eb8e9a1739b5345843e5a425775cbdc472cc38e7eda082ba5b8d02450/google_cloud_storage-3.10.1.tar.gz", hash = "sha256:97db9aa4460727982040edd2bd13ff3d5e2260b5331ad22895802da1fc2a5286", size = 17309950, upload-time = "2026-03-23T09:35:23.409Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/46/0b/816a6ae3c9fd096937d2e5f9670558908811d57d59ddf69dd4b83b326fd1/google_cloud_storage-3.9.0-py3-none-any.whl", hash = "sha256:2dce75a9e8b3387078cbbdad44757d410ecdb916101f8ba308abf202b6968066", size = 321324, upload-time = "2026-02-02T13:36:32.271Z" }, + { url = "https://files.pythonhosted.org/packages/ad/ff/ca9ab2417fa913d75aae38bf40bf856bb2749a604b2e0f701b37cfcd23cc/google_cloud_storage-3.10.1-py3-none-any.whl", hash = "sha256:a72f656759b7b99bda700f901adcb3425a828d4a29f911bc26b3ea79c5b1217f", size = 324453, upload-time = "2026-03-23T09:35:21.368Z" }, ] [[package]] @@ -2687,14 +2612,14 @@ wheels = [ [[package]] name = "googleapis-common-protos" -version = "1.73.0" +version = "1.73.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "protobuf" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/99/96/a0205167fa0154f4a542fd6925bdc63d039d88dab3588b875078107e6f06/googleapis_common_protos-1.73.0.tar.gz", hash = "sha256:778d07cd4fbeff84c6f7c72102f0daf98fa2bfd3fa8bea426edc545588da0b5a", size = 147323, upload-time = "2026-03-06T21:53:09.727Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a1/c0/4a54c386282c13449eca8bbe2ddb518181dc113e78d240458a68856b4d69/googleapis_common_protos-1.73.1.tar.gz", hash = "sha256:13114f0e9d2391756a0194c3a8131974ed7bffb06086569ba193364af59163b6", size = 147506, upload-time = "2026-03-26T22:17:38.451Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/69/28/23eea8acd65972bbfe295ce3666b28ac510dfcb115fac089d3edb0feb00a/googleapis_common_protos-1.73.0-py3-none-any.whl", hash = "sha256:dfdaaa2e860f242046be561e6d6cb5c5f1541ae02cfbcb034371aadb2942b4e8", size = 297578, upload-time = "2026-03-06T21:52:33.933Z" }, + { url = "https://files.pythonhosted.org/packages/dc/82/fcb6520612bec0c39b973a6c0954b6a0d948aadfe8f7e9487f60ceb8bfa6/googleapis_common_protos-1.73.1-py3-none-any.whl", hash = "sha256:e51f09eb0a43a8602f5a915870972e6b4a394088415c79d79605a46d8e826ee8", size = 297556, upload-time = "2026-03-26T22:15:58.455Z" }, ] [package.optional-dependencies] @@ -2726,6 +2651,34 @@ requests = [ { name = "requests-toolbelt" }, ] +[[package]] +name = "graphon" +version = "0.1.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "charset-normalizer" }, + { name = "httpx" }, + { name = "json-repair" }, + { name = "jsonschema" }, + { name = "orjson" }, + { name = "pandas", extra = ["excel"] }, + { name = "pydantic" }, + { name = "pydantic-extra-types" }, + { name = "pypandoc" }, + { name = "pypdfium2" }, + { name = "python-docx" }, + { name = "pyyaml" }, + { name = "tiktoken" }, + { name = "transformers" }, + { name = "typing-extensions" }, + { name = "unstructured", extra = ["docx", "epub", "md", "ppt", "pptx"] }, + { name = "webvtt-py" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/33/fc/0a5342a1c29bc367c2254c170ef130a84a60d8cd1c9cc84a7a85e96c1042/graphon-0.1.2.tar.gz", hash = "sha256:a2210629f93258ad2e7cbe85b5d4c6826814f6c679aa2a23ca100511363b9240", size = 214744, upload-time = "2026-03-27T20:09:53.802Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d5/46/65b5e366ec2d7017b6d6448e2635b3772d86840a6f7297277471b1bfbfbd/graphon-0.1.2-py3-none-any.whl", hash = "sha256:79f0c7796de7b8642d070730bb8bdaf1c68ccdfcecac38e0b2282e0543f0a6db", size = 314398, upload-time = "2026-03-27T20:09:52.524Z" }, +] + [[package]] name = "graphql-core" version = "3.2.7" @@ -2917,14 +2870,14 @@ wheels = [ [[package]] name = "gunicorn" -version = "25.1.0" +version = "25.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "packaging" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/66/13/ef67f59f6a7896fdc2c1d62b5665c5219d6b0a9a1784938eb9a28e55e128/gunicorn-25.1.0.tar.gz", hash = "sha256:1426611d959fa77e7de89f8c0f32eed6aa03ee735f98c01efba3e281b1c47616", size = 594377, upload-time = "2026-02-13T11:09:58.989Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c4/f4/e78fa054248fab913e2eab0332c6c2cb07421fca1ce56d8fe43b6aef57a4/gunicorn-25.3.0.tar.gz", hash = "sha256:f74e1b2f9f76f6cd1ca01198968bd2dd65830edc24b6e8e4d78de8320e2fe889", size = 634883, upload-time = "2026-03-27T00:00:26.092Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/da/73/4ad5b1f6a2e21cf1e85afdaad2b7b1a933985e2f5d679147a1953aaa192c/gunicorn-25.1.0-py3-none-any.whl", hash = "sha256:d0b1236ccf27f72cfe14bce7caadf467186f19e865094ca84221424e839b8b8b", size = 197067, upload-time = "2026-02-13T11:09:57.146Z" }, + { url = "https://files.pythonhosted.org/packages/43/c8/8aaf447698c4d59aa853fd318eed300b5c9e44459f242ab8ead6c9c09792/gunicorn-25.3.0-py3-none-any.whl", hash = "sha256:cacea387dab08cd6776501621c295a904fe8e3b7aae9a1a3cbb26f4e7ed54660", size = 208403, upload-time = "2026-03-27T00:00:27.386Z" }, ] [[package]] @@ -3157,14 +3110,14 @@ wheels = [ [[package]] name = "hypothesis" -version = "6.151.9" +version = "6.151.10" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "sortedcontainers" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/19/e1/ef365ff480903b929d28e057f57b76cae51a30375943e33374ec9a165d9c/hypothesis-6.151.9.tar.gz", hash = "sha256:2f284428dda6c3c48c580de0e18470ff9c7f5ef628a647ee8002f38c3f9097ca", size = 463534, upload-time = "2026-02-16T22:59:23.09Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f5/dd/633e2cd62377333b7681628aee2ec1d88166f5bdf916b08c98b1e8288ad3/hypothesis-6.151.10.tar.gz", hash = "sha256:6c9565af8b4aa3a080b508f66ce9c2a77dd613c7e9073e27fc7e4ef9f45f8a27", size = 463762, upload-time = "2026-03-29T01:06:22.19Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c4/f7/5cc291d701094754a1d327b44d80a44971e13962881d9a400235726171da/hypothesis-6.151.9-py3-none-any.whl", hash = "sha256:7b7220585c67759b1b1ef839b1e6e9e3d82ed468cfc1ece43c67184848d7edd9", size = 529307, upload-time = "2026-02-16T22:59:20.443Z" }, + { url = "https://files.pythonhosted.org/packages/40/da/439bb2e451979f5e88c13bbebc3e9e17754429cfb528c93677b2bd81783b/hypothesis-6.151.10-py3-none-any.whl", hash = "sha256:b0d7728f0c8c2be009f89fcdd6066f70c5439aa0f94adbb06e98261d05f49b05", size = 529493, upload-time = "2026-03-29T01:06:19.161Z" }, ] [[package]] @@ -3440,25 +3393,27 @@ sdist = { url = "https://files.pythonhosted.org/packages/0e/72/a3add0e4eec4eb9e2 [[package]] name = "langfuse" -version = "2.51.5" +version = "4.0.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "anyio" }, { name = "backoff" }, { name = "httpx" }, - { name = "idna" }, + { name = "openai" }, + { name = "opentelemetry-api" }, + { name = "opentelemetry-exporter-otlp-proto-http" }, + { name = "opentelemetry-sdk" }, { name = "packaging" }, { name = "pydantic" }, { name = "wrapt" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/3c/e9/22c9c05d877ab85da6d9008aaa7360f2a9ad58787a8e36e00b1b5be9a990/langfuse-2.51.5.tar.gz", hash = "sha256:55bc37b5c5d3ae133c1a95db09117cfb3117add110ba02ebbf2ce45ac4395c5b", size = 117574, upload-time = "2024-10-09T00:59:15.016Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c9/94/ab00e21fa5977d6b9c68fb3a95de2aa1a1e586964ff2af3e37405bf65d9f/langfuse-4.0.1.tar.gz", hash = "sha256:40a6daf3ab505945c314246d5b577d48fcfde0a47e8c05267ea6bd494ae9608e", size = 272749, upload-time = "2026-03-19T14:03:34.508Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/03/f7/242a13ca094c78464b7d4df77dfe7d4c44ed77b15fed3d2e3486afa5d2e1/langfuse-2.51.5-py3-none-any.whl", hash = "sha256:b95401ca710ef94b521afa6541933b6f93d7cfd4a97523c8fc75bca4d6d219fb", size = 214281, upload-time = "2024-10-09T00:59:12.596Z" }, + { url = "https://files.pythonhosted.org/packages/27/8f/3145ef00940f9c29d7e0200fd040f35616eac21c6ab4610a1ba14f3a04c1/langfuse-4.0.1-py3-none-any.whl", hash = "sha256:e22f49ea31304f97fc31a97c014ba63baa8802d9568295d54f06b00b43c30524", size = 465049, upload-time = "2026-03-19T14:03:32.527Z" }, ] [[package]] name = "langsmith" -version = "0.7.17" +version = "0.7.22" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "httpx" }, @@ -3471,9 +3426,9 @@ dependencies = [ { name = "xxhash" }, { name = "zstandard" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/71/79/81041dde07a974e728db7def23c1c7255950b8874102925cc77093bc847d/langsmith-0.7.17.tar.gz", hash = "sha256:6c1b0c2863cdd6636d2a58b8d5b1b80060703d98cac2593f4233e09ac25b5a9d", size = 1132228, upload-time = "2026-03-12T20:41:10.808Z" } +sdist = { url = "https://files.pythonhosted.org/packages/be/2a/2d5e6c67396fd228670af278c4da7bd6db2b8d11deaf6f108490b6d3f561/langsmith-0.7.22.tar.gz", hash = "sha256:35bfe795d648b069958280760564632fd28ebc9921c04f3e209c0db6a6c7dc04", size = 1134923, upload-time = "2026-03-19T22:45:23.492Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/34/31/62689d57f4d25792bd6a3c05c868771899481be2f3e31f9e71d31e1ac4ab/langsmith-0.7.17-py3-none-any.whl", hash = "sha256:cbec10460cb6c6ecc94c18c807be88a9984838144ae6c4693c9f859f378d7d02", size = 359147, upload-time = "2026-03-12T20:41:08.758Z" }, + { url = "https://files.pythonhosted.org/packages/1a/94/1f5d72655ab6534129540843776c40eff757387b88e798d8b3bf7e313fd4/langsmith-0.7.22-py3-none-any.whl", hash = "sha256:6e9d5148314d74e86748cb9d3898632cad0320c9323d95f70f969e5bc078eee4", size = 359927, upload-time = "2026-03-19T22:45:21.603Z" }, ] [[package]] @@ -3521,7 +3476,7 @@ wheels = [ [[package]] name = "litellm" -version = "1.82.2" +version = "1.82.6" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, @@ -3537,9 +3492,9 @@ dependencies = [ { name = "tiktoken" }, { name = "tokenizers" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/60/12/010a86643f12ac0b004032d5927c260094299a84ed38b5ed20a8f8c7e3c4/litellm-1.82.2.tar.gz", hash = "sha256:f5f4c4049f344a88bf80b2e421bb927807687c99624515d7ff4152d533ec9dcb", size = 17353218, upload-time = "2026-03-13T21:24:24.5Z" } +sdist = { url = "https://files.pythonhosted.org/packages/29/75/1c537aa458426a9127a92bc2273787b2f987f4e5044e21f01f2eed5244fd/litellm-1.82.6.tar.gz", hash = "sha256:2aa1c2da21fe940c33613aa447119674a3ad4d2ad5eb064e4d5ce5ee42420136", size = 17414147, upload-time = "2026-03-22T06:36:00.452Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/96/e4/87e3ca82a8bf6e6bfffb42a539a1350dd6ced1b7169397bd439ba56fde10/litellm-1.82.2-py3-none-any.whl", hash = "sha256:641ed024774fa3d5b4dd9347f0efb1e31fa422fba2a6500aabedee085d1194cb", size = 15524224, upload-time = "2026-03-13T21:24:21.288Z" }, + { url = "https://files.pythonhosted.org/packages/02/6c/5327667e6dbe9e98cbfbd4261c8e91386a52e38f41419575854248bbab6a/litellm-1.82.6-py3-none-any.whl", hash = "sha256:164a3ef3e19f309e3cabc199bef3d2045212712fefdfa25fc7f75884a5b5b205", size = 15591595, upload-time = "2026-03-22T06:35:56.795Z" }, ] [[package]] @@ -3979,7 +3934,7 @@ wheels = [ [[package]] name = "nltk" -version = "3.9.3" +version = "3.9.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, @@ -3987,9 +3942,9 @@ dependencies = [ { name = "regex" }, { name = "tqdm" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/e1/8f/915e1c12df07c70ed779d18ab83d065718a926e70d3ea33eb0cd66ffb7c0/nltk-3.9.3.tar.gz", hash = "sha256:cb5945d6424a98d694c2b9a0264519fab4363711065a46aa0ae7a2195b92e71f", size = 2923673, upload-time = "2026-02-24T12:05:53.833Z" } +sdist = { url = "https://files.pythonhosted.org/packages/74/a1/b3b4adf15585a5bc4c357adde150c01ebeeb642173ded4d871e89468767c/nltk-3.9.4.tar.gz", hash = "sha256:ed03bc098a40481310320808b2db712d95d13ca65b27372f8a403949c8b523d0", size = 2946864, upload-time = "2026-03-24T06:13:40.641Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c2/7e/9af5a710a1236e4772de8dfcc6af942a561327bb9f42b5b4a24d0cf100fd/nltk-3.9.3-py3-none-any.whl", hash = "sha256:60b3db6e9995b3dd976b1f0fa7dec22069b2677e759c28eb69b62ddd44870522", size = 1525385, upload-time = "2026-02-24T12:05:46.54Z" }, + { url = "https://files.pythonhosted.org/packages/9d/91/04e965f8e717ba0ab4bdca5c112deeab11c9e750d94c4d4602f050295d39/nltk-3.9.4-py3-none-any.whl", hash = "sha256:f2fa301c3a12718ce4a0e9305c5675299da5ad9e26068218b69d692fda84828f", size = 1552087, upload-time = "2026-03-24T06:13:38.47Z" }, ] [[package]] @@ -4247,95 +4202,95 @@ wheels = [ [[package]] name = "opentelemetry-api" -version = "1.28.0" +version = "1.40.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "deprecated" }, { name = "importlib-metadata" }, + { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/79/36/260eaea0f74fdd0c0d8f22ed3a3031109ea1c85531f94f4fde266c29e29a/opentelemetry_api-1.28.0.tar.gz", hash = "sha256:578610bcb8aa5cdcb11169d136cc752958548fb6ccffb0969c1036b0ee9e5353", size = 62803, upload-time = "2024-11-05T19:14:45.497Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2c/1d/4049a9e8698361cc1a1aa03a6c59e4fa4c71e0c0f94a30f988a6876a2ae6/opentelemetry_api-1.40.0.tar.gz", hash = "sha256:159be641c0b04d11e9ecd576906462773eb97ae1b657730f0ecf64d32071569f", size = 70851, upload-time = "2026-03-04T14:17:21.555Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/22/e4/3b25d8b856791c04d8a62b1257b5fc09dc41a057800db06885af8ddcdce1/opentelemetry_api-1.28.0-py3-none-any.whl", hash = "sha256:8457cd2c59ea1bd0988560f021656cecd254ad7ef6be4ba09dbefeca2409ce52", size = 64314, upload-time = "2024-11-05T19:14:21.659Z" }, + { url = "https://files.pythonhosted.org/packages/5f/bf/93795954016c522008da367da292adceed71cca6ee1717e1d64c83089099/opentelemetry_api-1.40.0-py3-none-any.whl", hash = "sha256:82dd69331ae74b06f6a874704be0cfaa49a1650e1537d4a813b86ecef7d0ecf9", size = 68676, upload-time = "2026-03-04T14:17:01.24Z" }, ] [[package]] name = "opentelemetry-distro" -version = "0.49b0" +version = "0.61b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, { name = "opentelemetry-instrumentation" }, { name = "opentelemetry-sdk" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/4d/75/7cb7c33899e66bb366d40a889111a78c22df0951038b6699f1663e715a9f/opentelemetry_distro-0.49b0.tar.gz", hash = "sha256:1bafa274f9e83baa0d2a5d47ed02caffcf9bcca60107b389b145400d82b07513", size = 2560, upload-time = "2024-11-05T19:21:39.379Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f5/00/1f8acc51326956a596fefaf67751380001af36029132a7a07d4debce3c06/opentelemetry_distro-0.61b0.tar.gz", hash = "sha256:975b845f50181ad53753becf4fd4b123b54fa04df5a9d78812264436d6518981", size = 2590, upload-time = "2026-03-04T14:20:12.453Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/4c/db/806172b6a4933966eee518db814b375e620602f7fe776b74ef795690f135/opentelemetry_distro-0.49b0-py3-none-any.whl", hash = "sha256:1af4074702f605ea210753dd41947dc2fd61b39724f23cdcf15d5654867cd3c2", size = 3318, upload-time = "2024-11-05T19:20:34.065Z" }, + { url = "https://files.pythonhosted.org/packages/56/2c/efcc995cd7484e6e55b1d26bd7fa6c55ca96bd415ff94310b52c19f330b0/opentelemetry_distro-0.61b0-py3-none-any.whl", hash = "sha256:f21d1ac0627549795d75e332006dd068877f00e461b1b2e8fe4568d6eb7b9590", size = 3349, upload-time = "2026-03-04T14:18:57.788Z" }, ] [[package]] name = "opentelemetry-exporter-otlp" -version = "1.28.0" +version = "1.40.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-exporter-otlp-proto-grpc" }, { name = "opentelemetry-exporter-otlp-proto-http" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/eb/16/14e3fc163930ea68f0980a4cdd4ae5796e60aeb898965990e13263d64baf/opentelemetry_exporter_otlp-1.28.0.tar.gz", hash = "sha256:31ae7495831681dd3da34ac457f6970f147465ae4b9aae3a888d7a581c7cd868", size = 6170, upload-time = "2024-11-05T19:14:47.349Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d0/37/b6708e0eff5c5fb9aba2e0ea09f7f3bcbfd12a592d2a780241b5f6014df7/opentelemetry_exporter_otlp-1.40.0.tar.gz", hash = "sha256:7caa0870b95e2fcb59d64e16e2b639ecffb07771b6cd0000b5d12e5e4fef765a", size = 6152, upload-time = "2026-03-04T14:17:23.235Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c2/82/3f521b3c1f2a411ed60a24a8c9f486c1beeaf8c6c55337c87d3ae1642151/opentelemetry_exporter_otlp-1.28.0-py3-none-any.whl", hash = "sha256:1fd02d70f2c1b7ac5579c81e78de4594b188d3317c8ceb69e8b53900fb7b40fd", size = 7024, upload-time = "2024-11-05T19:14:24.534Z" }, + { url = "https://files.pythonhosted.org/packages/2d/fc/aea77c28d9f3ffef2fdafdc3f4a235aee4091d262ddabd25882f47ce5c5f/opentelemetry_exporter_otlp-1.40.0-py3-none-any.whl", hash = "sha256:48c87e539ec9afb30dc443775a1334cc5487de2f72a770a4c00b1610bf6c697d", size = 7023, upload-time = "2026-03-04T14:17:03.612Z" }, ] [[package]] name = "opentelemetry-exporter-otlp-proto-common" -version = "1.28.0" +version = "1.40.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-proto" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c2/8d/5d411084ac441052f4c9bae03a1aec65ae5d16b439fea7b9c5ac3842c013/opentelemetry_exporter_otlp_proto_common-1.28.0.tar.gz", hash = "sha256:5fa0419b0c8e291180b0fc8430a20dd44a3f3236f8e0827992145914f273ec4f", size = 18505, upload-time = "2024-11-05T19:14:48.204Z" } +sdist = { url = "https://files.pythonhosted.org/packages/51/bc/1559d46557fe6eca0b46c88d4c2676285f1f3be2e8d06bb5d15fbffc814a/opentelemetry_exporter_otlp_proto_common-1.40.0.tar.gz", hash = "sha256:1cbee86a4064790b362a86601ee7934f368b81cd4cc2f2e163902a6e7818a0fa", size = 20416, upload-time = "2026-03-04T14:17:23.801Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e1/72/3c44aabc74db325aaba09361b6a0d80f6d601f0ff86ecea8ee655c9538fc/opentelemetry_exporter_otlp_proto_common-1.28.0-py3-none-any.whl", hash = "sha256:467e6437d24e020156dffecece8c0a4471a8a60f6a34afeda7386df31a092410", size = 18403, upload-time = "2024-11-05T19:14:25.798Z" }, + { url = "https://files.pythonhosted.org/packages/8b/ca/8f122055c97a932311a3f640273f084e738008933503d0c2563cd5d591fc/opentelemetry_exporter_otlp_proto_common-1.40.0-py3-none-any.whl", hash = "sha256:7081ff453835a82417bf38dccf122c827c3cbc94f2079b03bba02a3165f25149", size = 18369, upload-time = "2026-03-04T14:17:04.796Z" }, ] [[package]] name = "opentelemetry-exporter-otlp-proto-grpc" -version = "1.28.0" +version = "1.40.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "deprecated" }, { name = "googleapis-common-protos" }, { name = "grpcio" }, { name = "opentelemetry-api" }, { name = "opentelemetry-exporter-otlp-proto-common" }, { name = "opentelemetry-proto" }, { name = "opentelemetry-sdk" }, + { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/43/4d/f215162e58041afb4bdf5dbd0d8faf0b7fc9bf7b3d3fc0e44e06f9e7e869/opentelemetry_exporter_otlp_proto_grpc-1.28.0.tar.gz", hash = "sha256:47a11c19dc7f4289e220108e113b7de90d59791cb4c37fc29f69a6a56f2c3735", size = 26237, upload-time = "2024-11-05T19:14:49.026Z" } +sdist = { url = "https://files.pythonhosted.org/packages/8f/7f/b9e60435cfcc7590fa87436edad6822240dddbc184643a2a005301cc31f4/opentelemetry_exporter_otlp_proto_grpc-1.40.0.tar.gz", hash = "sha256:bd4015183e40b635b3dab8da528b27161ba83bf4ef545776b196f0fb4ec47740", size = 25759, upload-time = "2026-03-04T14:17:24.4Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/1d/b5/afabc8106abc0f9cfeecf5b3e682622b3e04bba1d9b967dbfcd91b9c4ebe/opentelemetry_exporter_otlp_proto_grpc-1.28.0-py3-none-any.whl", hash = "sha256:edbdc53e7783f88d4535db5807cb91bd7b1ec9e9b9cdbfee14cd378f29a3b328", size = 18532, upload-time = "2024-11-05T19:14:26.853Z" }, + { url = "https://files.pythonhosted.org/packages/96/6f/7ee0980afcbdcd2d40362da16f7f9796bd083bf7f0b8e038abfbc0300f5d/opentelemetry_exporter_otlp_proto_grpc-1.40.0-py3-none-any.whl", hash = "sha256:2aa0ca53483fe0cf6405087a7491472b70335bc5c7944378a0a8e72e86995c52", size = 20304, upload-time = "2026-03-04T14:17:05.942Z" }, ] [[package]] name = "opentelemetry-exporter-otlp-proto-http" -version = "1.28.0" +version = "1.40.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "deprecated" }, { name = "googleapis-common-protos" }, { name = "opentelemetry-api" }, { name = "opentelemetry-exporter-otlp-proto-common" }, { name = "opentelemetry-proto" }, { name = "opentelemetry-sdk" }, { name = "requests" }, + { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f1/2a/555f2845928086cd51aa6941c7a546470805b68ed631ec139ce7d841763d/opentelemetry_exporter_otlp_proto_http-1.28.0.tar.gz", hash = "sha256:d83a9a03a8367ead577f02a64127d827c79567de91560029688dd5cfd0152a8e", size = 15051, upload-time = "2024-11-05T19:14:49.813Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2e/fa/73d50e2c15c56be4d000c98e24221d494674b0cc95524e2a8cb3856d95a4/opentelemetry_exporter_otlp_proto_http-1.40.0.tar.gz", hash = "sha256:db48f5e0f33217588bbc00274a31517ba830da576e59503507c839b38fa0869c", size = 17772, upload-time = "2026-03-04T14:17:25.324Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b2/ce/80d5adabbf7ab4a0ca7b5e0f4039b24d273be370c3ba85fc05b13794411c/opentelemetry_exporter_otlp_proto_http-1.28.0-py3-none-any.whl", hash = "sha256:e8f3f7961b747edb6b44d51de4901a61e9c01d50debd747b120a08c4996c7e7b", size = 17228, upload-time = "2024-11-05T19:14:28.613Z" }, + { url = "https://files.pythonhosted.org/packages/a0/3a/8865d6754e61c9fb170cdd530a124a53769ee5f740236064816eb0ca7301/opentelemetry_exporter_otlp_proto_http-1.40.0-py3-none-any.whl", hash = "sha256:a8d1dab28f504c5d96577d6509f80a8150e44e8f45f82cdbe0e34c99ab040069", size = 19960, upload-time = "2026-03-04T14:17:07.153Z" }, ] [[package]] name = "opentelemetry-instrumentation" -version = "0.49b0" +version = "0.61b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, @@ -4343,14 +4298,14 @@ dependencies = [ { name = "packaging" }, { name = "wrapt" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/de/6b/6c25b15063c92a011cf3f68375971e2c58a9c764690847edc97df2d94eeb/opentelemetry_instrumentation-0.49b0.tar.gz", hash = "sha256:398a93e0b9dc2d11cc8627e1761665c506fe08c6b2df252a2ab3ade53d751c46", size = 26478, upload-time = "2024-11-05T19:21:41.402Z" } +sdist = { url = "https://files.pythonhosted.org/packages/da/37/6bf8e66bfcee5d3c6515b79cb2ee9ad05fe573c20f7ceb288d0e7eeec28c/opentelemetry_instrumentation-0.61b0.tar.gz", hash = "sha256:cb21b48db738c9de196eba6b805b4ff9de3b7f187e4bbf9a466fa170514f1fc7", size = 32606, upload-time = "2026-03-04T14:20:16.825Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/93/61/e0d21e958d6072ce25c4f5e26a1d22835fc86f80836660adf6badb6038ce/opentelemetry_instrumentation-0.49b0-py3-none-any.whl", hash = "sha256:68364d73a1ff40894574cbc6138c5f98674790cae1f3b0865e21cf702f24dcb3", size = 30694, upload-time = "2024-11-05T19:20:38.584Z" }, + { url = "https://files.pythonhosted.org/packages/d8/3e/f6f10f178b6316de67f0dfdbbb699a24fbe8917cf1743c1595fb9dcdd461/opentelemetry_instrumentation-0.61b0-py3-none-any.whl", hash = "sha256:92a93a280e69788e8f88391247cc530fd81f16f2b011979d4d6398f805cfbc63", size = 33448, upload-time = "2026-03-04T14:19:02.447Z" }, ] [[package]] name = "opentelemetry-instrumentation-asgi" -version = "0.49b0" +version = "0.61b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "asgiref" }, @@ -4359,28 +4314,28 @@ dependencies = [ { name = "opentelemetry-semantic-conventions" }, { name = "opentelemetry-util-http" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/e8/55/693c3d0938ba5fead5c3aa4ac7022a992b4ff99a8e9979800d0feb843ff4/opentelemetry_instrumentation_asgi-0.49b0.tar.gz", hash = "sha256:959fd9b1345c92f20c6ef1d42f92ef6a76b3c3083fbc4104d59da6859b15b083", size = 24117, upload-time = "2024-11-05T19:21:46.769Z" } +sdist = { url = "https://files.pythonhosted.org/packages/00/3e/143cf5c034e58037307e6a24f06e0dd64b2c49ae60a965fc580027581931/opentelemetry_instrumentation_asgi-0.61b0.tar.gz", hash = "sha256:9d08e127244361dc33976d39dd4ca8f128b5aa5a7ae425208400a80a095019b5", size = 26691, upload-time = "2026-03-04T14:20:21.038Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2c/0b/7900c782a1dfaa584588d724bc3bbdf8405a32497537dd96b3fcbf8461b9/opentelemetry_instrumentation_asgi-0.49b0-py3-none-any.whl", hash = "sha256:722a90856457c81956c88f35a6db606cc7db3231046b708aae2ddde065723dbe", size = 16326, upload-time = "2024-11-05T19:20:46.176Z" }, + { url = "https://files.pythonhosted.org/packages/19/78/154470cf9d741a7487fbb5067357b87386475bbb77948a6707cae982e158/opentelemetry_instrumentation_asgi-0.61b0-py3-none-any.whl", hash = "sha256:e4b3ce6b66074e525e717efff20745434e5efd5d9df6557710856fba356da7a4", size = 16980, upload-time = "2026-03-04T14:19:10.894Z" }, ] [[package]] name = "opentelemetry-instrumentation-celery" -version = "0.49b0" +version = "0.61b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, { name = "opentelemetry-instrumentation" }, { name = "opentelemetry-semantic-conventions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/4c/8b/9b8a9dda3ed53354c6f707a45cdb7a4730e1c109b50fc1b413525493f811/opentelemetry_instrumentation_celery-0.49b0.tar.gz", hash = "sha256:afbaee97cc9c75f29bcc9784f16f8e37c415d4fe9b334748c5b90a3d30d12473", size = 14702, upload-time = "2024-11-05T19:21:53.672Z" } +sdist = { url = "https://files.pythonhosted.org/packages/8d/43/e79108a804d16b1dc8ff28edd0e94ac393cf6359a5adcd7cdd2ec4be85f4/opentelemetry_instrumentation_celery-0.61b0.tar.gz", hash = "sha256:0e352a567dc89ed8bc083fc635035ce3c5b96bbbd92831ffd676e93b87f8e94f", size = 14780, upload-time = "2026-03-04T14:20:27.776Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/21/8c/d7d4adb36abbc0e517a69f7a069f32742122ae22d6017202f64570d9f4c5/opentelemetry_instrumentation_celery-0.49b0-py3-none-any.whl", hash = "sha256:38d4a78c78f33020032ef77ef0ead756bdf7838bcfb603de10f5925d39f14929", size = 13749, upload-time = "2024-11-05T19:20:54.98Z" }, + { url = "https://files.pythonhosted.org/packages/a2/ed/c05f3c84b455654eb6c047474ffde61ed92efc24030f64213c98bca9d44b/opentelemetry_instrumentation_celery-0.61b0-py3-none-any.whl", hash = "sha256:01235733ff0cdf571cb03b270645abb14b9c8d830313dc5842097ec90146320b", size = 13856, upload-time = "2026-03-04T14:19:20.98Z" }, ] [[package]] name = "opentelemetry-instrumentation-fastapi" -version = "0.49b0" +version = "0.61b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, @@ -4389,14 +4344,14 @@ dependencies = [ { name = "opentelemetry-semantic-conventions" }, { name = "opentelemetry-util-http" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/fe/bf/8e6d2a4807360f2203192017eb4845f5628dbeaf0597adf3d141cc5c24e1/opentelemetry_instrumentation_fastapi-0.49b0.tar.gz", hash = "sha256:6d14935c41fd3e49328188b6a59dd4c37bd17a66b01c15b0c64afa9714a1f905", size = 19230, upload-time = "2024-11-05T19:21:59.361Z" } +sdist = { url = "https://files.pythonhosted.org/packages/37/35/aa727bb6e6ef930dcdc96a617b83748fece57b43c47d83ba8d83fbeca657/opentelemetry_instrumentation_fastapi-0.61b0.tar.gz", hash = "sha256:3a24f35b07c557ae1bbc483bf8412221f25d79a405f8b047de8b670722e2fa9f", size = 24800, upload-time = "2026-03-04T14:20:32.759Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b1/f4/0895b9410c10abf987c90dee1b7688a8f2214a284fe15e575648f6a1473a/opentelemetry_instrumentation_fastapi-0.49b0-py3-none-any.whl", hash = "sha256:646e1b18523cbe6860ae9711eb2c7b9c85466c3c7697cd6b8fb5180d85d3fe6e", size = 12101, upload-time = "2024-11-05T19:21:01.805Z" }, + { url = "https://files.pythonhosted.org/packages/91/05/acfeb2cccd434242a0a7d0ea29afaf077e04b42b35b485d89aee4e0d9340/opentelemetry_instrumentation_fastapi-0.61b0-py3-none-any.whl", hash = "sha256:a1a844d846540d687d377516b2ff698b51d87c781b59f47c214359c4a241047c", size = 13485, upload-time = "2026-03-04T14:19:30.351Z" }, ] [[package]] name = "opentelemetry-instrumentation-flask" -version = "0.49b0" +version = "0.61b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, @@ -4406,14 +4361,14 @@ dependencies = [ { name = "opentelemetry-util-http" }, { name = "packaging" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/17/12/dc72873fb1e35699941d8eb6a53ef25e8c5843dea37665dad33bd720f047/opentelemetry_instrumentation_flask-0.49b0.tar.gz", hash = "sha256:f7c5ab67753c4781a2e21c8f43dc5fc02ece74fdd819466c75d025db80aa7576", size = 19176, upload-time = "2024-11-05T19:22:00.816Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d9/33/d6852d8f2c3eef86f2f8c858d6f5315983c7063e07e595519e96d4c31c06/opentelemetry_instrumentation_flask-0.61b0.tar.gz", hash = "sha256:e9faf58dfd9860a1868442d180142645abdafc1a652dd73d469a5efd106a7d49", size = 24071, upload-time = "2026-03-04T14:20:33.437Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a2/fc/354da8f33ef0daebfc8e4eac995d342ae13a35097bbad512cfe0d2f3c61a/opentelemetry_instrumentation_flask-0.49b0-py3-none-any.whl", hash = "sha256:f3ef330c3cee3e2c161f27f1e7017c8800b9bfb6f9204f2f7bfb0b274874be0e", size = 14582, upload-time = "2024-11-05T19:21:02.793Z" }, + { url = "https://files.pythonhosted.org/packages/3e/41/619f3530324a58491f2d20f216a10dd7393629b29db4610dda642a27f4ed/opentelemetry_instrumentation_flask-0.61b0-py3-none-any.whl", hash = "sha256:e8ce474d7ce543bfbbb3e93f8a6f8263348af9d7b45502f387420cf3afa71253", size = 15996, upload-time = "2026-03-04T14:19:31.304Z" }, ] [[package]] name = "opentelemetry-instrumentation-httpx" -version = "0.49b0" +version = "0.61b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, @@ -4422,14 +4377,14 @@ dependencies = [ { name = "opentelemetry-util-http" }, { name = "wrapt" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a0/53/8b5e05e55a513d846ead5afb0509bec37a34a1c3e82f30b13d14156334b1/opentelemetry_instrumentation_httpx-0.49b0.tar.gz", hash = "sha256:07165b624f3e58638cee47ecf1c81939a8c2beb7e42ce9f69e25a9f21dc3f4cf", size = 17750, upload-time = "2024-11-05T19:22:02.911Z" } +sdist = { url = "https://files.pythonhosted.org/packages/cd/2a/e2becd55e33c29d1d9ef76e2579040ed1951cb33bacba259f6aff2fdd2a6/opentelemetry_instrumentation_httpx-0.61b0.tar.gz", hash = "sha256:6569ec097946c5551c2a4252f74c98666addd1bf047c1dde6b4ef426719ff8dd", size = 24104, upload-time = "2026-03-04T14:20:34.752Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/3b/9f/843391c6d645cd4f6914b27bc807fc1ff52b97f84cbe3ca675641976b23f/opentelemetry_instrumentation_httpx-0.49b0-py3-none-any.whl", hash = "sha256:e59e0d2fda5ef841630c68da1d78ff9192f63590a9099f12f0eab614abdf239a", size = 14110, upload-time = "2024-11-05T19:21:04.698Z" }, + { url = "https://files.pythonhosted.org/packages/af/88/dde310dce56e2d85cf1a09507f5888544955309edc4b8d22971d6d3d1417/opentelemetry_instrumentation_httpx-0.61b0-py3-none-any.whl", hash = "sha256:dee05c93a6593a5dc3ae5d9d5c01df8b4e2c5d02e49275e5558534ee46343d5e", size = 17198, upload-time = "2026-03-04T14:19:33.585Z" }, ] [[package]] name = "opentelemetry-instrumentation-redis" -version = "0.49b0" +version = "0.61b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, @@ -4437,14 +4392,14 @@ dependencies = [ { name = "opentelemetry-semantic-conventions" }, { name = "wrapt" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/19/5b/1398eb2f92fd76787ccec28d24dc4c7dfaaf97a7557e7729e2f7c2c05d84/opentelemetry_instrumentation_redis-0.49b0.tar.gz", hash = "sha256:922542c3bd192ad4ba74e2c7e0a253c7c58a5cefbd6f89da2aba4d193a974703", size = 11353, upload-time = "2024-11-05T19:22:12.822Z" } +sdist = { url = "https://files.pythonhosted.org/packages/cf/21/26205f89358a5f2be3ee5512d3d3bce16b622977f64aeaa9d3fa8887dd39/opentelemetry_instrumentation_redis-0.61b0.tar.gz", hash = "sha256:ae0fbb56be9a641e621d55b02a7d62977a2c77c5ee760addd79b9b266e46e523", size = 14781, upload-time = "2026-03-04T14:20:45.694Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/24/e4/4f258fef0759629f2e8a0210d5533cfef3ecad69ff35be044637a3e2783e/opentelemetry_instrumentation_redis-0.49b0-py3-none-any.whl", hash = "sha256:b7d8f758bac53e77b7e7ca98ce80f91230577502dacb619ebe8e8b6058042067", size = 12453, upload-time = "2024-11-05T19:21:18.534Z" }, + { url = "https://files.pythonhosted.org/packages/a5/e1/8f4c8e4194291dbe828aeabe779050a8497b379ad90040a5a0a7074b1d08/opentelemetry_instrumentation_redis-0.61b0-py3-none-any.whl", hash = "sha256:8d4e850bbb5f8eeafa44c0eac3a007990c7125de187bc9c3659e29ff7e091172", size = 15506, upload-time = "2026-03-04T14:19:48.588Z" }, ] [[package]] name = "opentelemetry-instrumentation-sqlalchemy" -version = "0.49b0" +version = "0.61b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, @@ -4453,14 +4408,14 @@ dependencies = [ { name = "packaging" }, { name = "wrapt" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a0/a7/24f6cce3808ae1802dd1b60d752fbab877db5655198929cf4ee8ea416923/opentelemetry_instrumentation_sqlalchemy-0.49b0.tar.gz", hash = "sha256:32658e520fc8b35823c722f5d8831d3a410b76dd2724adb2887befc041ddef04", size = 13194, upload-time = "2024-11-05T19:22:14.92Z" } +sdist = { url = "https://files.pythonhosted.org/packages/9e/4f/3a325b180944610697a0a926d49d782b41a86120050d44fefb2715b630ac/opentelemetry_instrumentation_sqlalchemy-0.61b0.tar.gz", hash = "sha256:13a3a159a2043a52f0180b3757fbaa26741b0e08abb50deddce4394c118956e6", size = 15343, upload-time = "2026-03-04T14:20:47.648Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ec/6b/a1a3685fed593282999cdc374ece15efbd56f8d774bd368bf7ff2cf5923c/opentelemetry_instrumentation_sqlalchemy-0.49b0-py3-none-any.whl", hash = "sha256:d854052d2b02cd0562e5628a514c8153fceada7f585137e173165dfd0a46ef6a", size = 13358, upload-time = "2024-11-05T19:21:23.654Z" }, + { url = "https://files.pythonhosted.org/packages/1f/97/b906a930c6a1a20c53ecc8b58cabc2cdd0ce560a2b5d44259084ffe4333e/opentelemetry_instrumentation_sqlalchemy-0.61b0-py3-none-any.whl", hash = "sha256:f115e0be54116ba4c327b8d7b68db4045ee18d44439d888ab8130a549c50d1c1", size = 14547, upload-time = "2026-03-04T14:19:53.088Z" }, ] [[package]] name = "opentelemetry-instrumentation-wsgi" -version = "0.49b0" +version = "0.61b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, @@ -4468,9 +4423,9 @@ dependencies = [ { name = "opentelemetry-semantic-conventions" }, { name = "opentelemetry-util-http" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/17/2b/91b022b004ac9e9ab0eefd10bc4257975291f88adc81b4ef2c601ddb1adf/opentelemetry_instrumentation_wsgi-0.49b0.tar.gz", hash = "sha256:0812a02e132f8fc3d5c897bba84e530c37b85c315b199bb97ca6508279e7eb23", size = 17733, upload-time = "2024-11-05T19:22:24.3Z" } +sdist = { url = "https://files.pythonhosted.org/packages/89/e5/189f2845362cfe78e356ba127eab21456309def411c6874aa4800c3de816/opentelemetry_instrumentation_wsgi-0.61b0.tar.gz", hash = "sha256:380f2ae61714e5303275a80b2e14c58571573cd1fddf496d8c39fb9551c5e532", size = 19898, upload-time = "2026-03-04T14:20:54.068Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/02/1d/59979665778ed8c85bc31c92b75571cd7afb8e3322fb513c87fe1bad6d78/opentelemetry_instrumentation_wsgi-0.49b0-py3-none-any.whl", hash = "sha256:8869ccf96611827e4448417718920e9eec6d25bffb5bf72c7952c7346ec33fbc", size = 13699, upload-time = "2024-11-05T19:21:35.039Z" }, + { url = "https://files.pythonhosted.org/packages/96/75/d6b42ba26f3c921be6d01b16561b7bb863f843bad7ac3a5011f62617bcab/opentelemetry_instrumentation_wsgi-0.61b0-py3-none-any.whl", hash = "sha256:bd33b0824166f24134a3400648805e8d2e6a7951f070241294e8b8866611d7fa", size = 14628, upload-time = "2026-03-04T14:20:03.934Z" }, ] [[package]] @@ -4488,55 +4443,55 @@ wheels = [ [[package]] name = "opentelemetry-proto" -version = "1.28.0" +version = "1.40.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "protobuf" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c9/63/ac4cef4d30ea0ca1d2153ad2fc62d91d1cf3b89b0e4e5cbd61a8c567885f/opentelemetry_proto-1.28.0.tar.gz", hash = "sha256:4a45728dfefa33f7908b828b9b7c9f2c6de42a05d5ec7b285662ddae71c4c870", size = 34331, upload-time = "2024-11-05T19:14:59.503Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4c/77/dd38991db037fdfce45849491cb61de5ab000f49824a00230afb112a4392/opentelemetry_proto-1.40.0.tar.gz", hash = "sha256:03f639ca129ba513f5819810f5b1f42bcb371391405d99c168fe6937c62febcd", size = 45667, upload-time = "2026-03-04T14:17:31.194Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/86/94/c0b43d16e1d96ee1e699373aa59f14a3aa2e7126af3f11d6adc5dcc531cd/opentelemetry_proto-1.28.0-py3-none-any.whl", hash = "sha256:d5ad31b997846543b8e15504657d9a8cf1ad3c71dcbbb6c4799b1ab29e38f7f9", size = 55832, upload-time = "2024-11-05T19:14:40.446Z" }, + { url = "https://files.pythonhosted.org/packages/b9/b2/189b2577dde745b15625b3214302605b1353436219d42b7912e77fa8dc24/opentelemetry_proto-1.40.0-py3-none-any.whl", hash = "sha256:266c4385d88923a23d63e353e9761af0f47a6ed0d486979777fe4de59dc9b25f", size = 72073, upload-time = "2026-03-04T14:17:16.673Z" }, ] [[package]] name = "opentelemetry-sdk" -version = "1.28.0" +version = "1.40.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, { name = "opentelemetry-semantic-conventions" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0c/5b/a509ccab93eacc6044591d5ec437d8266e76f893d0389bbf7e5592c7da32/opentelemetry_sdk-1.28.0.tar.gz", hash = "sha256:41d5420b2e3fb7716ff4981b510d551eff1fc60eb5a95cf7335b31166812a893", size = 156155, upload-time = "2024-11-05T19:15:00.451Z" } +sdist = { url = "https://files.pythonhosted.org/packages/58/fd/3c3125b20ba18ce2155ba9ea74acb0ae5d25f8cd39cfd37455601b7955cc/opentelemetry_sdk-1.40.0.tar.gz", hash = "sha256:18e9f5ec20d859d268c7cb3c5198c8d105d073714db3de50b593b8c1345a48f2", size = 184252, upload-time = "2026-03-04T14:17:31.87Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c3/fe/c8decbebb5660529f1d6ba65e50a45b1294022dfcba2968fc9c8697c42b2/opentelemetry_sdk-1.28.0-py3-none-any.whl", hash = "sha256:4b37da81d7fad67f6683c4420288c97f4ed0d988845d5886435f428ec4b8429a", size = 118692, upload-time = "2024-11-05T19:14:41.669Z" }, + { url = "https://files.pythonhosted.org/packages/2c/c5/6a852903d8bfac758c6dc6e9a68b015d3c33f2f1be5e9591e0f4b69c7e0a/opentelemetry_sdk-1.40.0-py3-none-any.whl", hash = "sha256:787d2154a71f4b3d81f20524a8ce061b7db667d24e46753f32a7bc48f1c1f3f1", size = 141951, upload-time = "2026-03-04T14:17:17.961Z" }, ] [[package]] name = "opentelemetry-semantic-conventions" -version = "0.49b0" +version = "0.61b0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "deprecated" }, { name = "opentelemetry-api" }, + { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ee/c8/433b0e54143f8c9369f5c4a7a83e73eec7eb2ee7d0b7e81a9243e78c8e80/opentelemetry_semantic_conventions-0.49b0.tar.gz", hash = "sha256:dbc7b28339e5390b6b28e022835f9bac4e134a80ebf640848306d3c5192557e8", size = 95227, upload-time = "2024-11-05T19:15:01.443Z" } +sdist = { url = "https://files.pythonhosted.org/packages/6d/c0/4ae7973f3c2cfd2b6e321f1675626f0dab0a97027cc7a297474c9c8f3d04/opentelemetry_semantic_conventions-0.61b0.tar.gz", hash = "sha256:072f65473c5d7c6dc0355b27d6c9d1a679d63b6d4b4b16a9773062cb7e31192a", size = 145755, upload-time = "2026-03-04T14:17:32.664Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/25/05/20104df4ef07d3bf5c3fd6bcc796ef70ab4ea4309378a9ba57bc4b4d01fa/opentelemetry_semantic_conventions-0.49b0-py3-none-any.whl", hash = "sha256:0458117f6ead0b12e3221813e3e511d85698c31901cac84682052adb9c17c7cd", size = 159214, upload-time = "2024-11-05T19:14:43.047Z" }, + { url = "https://files.pythonhosted.org/packages/b2/37/cc6a55e448deaa9b27377d087da8615a3416d8ad523d5960b78dbeadd02a/opentelemetry_semantic_conventions-0.61b0-py3-none-any.whl", hash = "sha256:fa530a96be229795f8cef353739b618148b0fe2b4b3f005e60e262926c4d38e2", size = 231621, upload-time = "2026-03-04T14:17:19.33Z" }, ] [[package]] name = "opentelemetry-util-http" -version = "0.49b0" +version = "0.61b0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a3/99/377ef446928808211b127b9ab31c348bc465c8da4514ebeec6e4a3de3d21/opentelemetry_util_http-0.49b0.tar.gz", hash = "sha256:02928496afcffd58a7c15baf99d2cedae9b8325a8ac52b0d0877b2e8f936dd1b", size = 7863, upload-time = "2024-11-05T19:22:26.973Z" } +sdist = { url = "https://files.pythonhosted.org/packages/57/3c/f0196223efc5c4ca19f8fad3d5462b171ac6333013335ce540c01af419e9/opentelemetry_util_http-0.61b0.tar.gz", hash = "sha256:1039cb891334ad2731affdf034d8fb8b48c239af9b6dd295e5fabd07f1c95572", size = 11361, upload-time = "2026-03-04T14:20:57.01Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/66/0e/ab0a89b315d0bacdd355a345bb69b20c50fc1f0804b52b56fe1c35a60e68/opentelemetry_util_http-0.49b0-py3-none-any.whl", hash = "sha256:8661bbd6aea1839badc44de067ec9c15c05eab05f729f496c856c50a1203caf1", size = 6945, upload-time = "2024-11-05T19:21:37.81Z" }, + { url = "https://files.pythonhosted.org/packages/0d/e5/c08aaaf2f64288d2b6ef65741d2de5454e64af3e050f34285fb1907492fe/opentelemetry_util_http-0.61b0-py3-none-any.whl", hash = "sha256:8e715e848233e9527ea47e275659ea60a57a75edf5206a3b937e236a6da5fc33", size = 9281, upload-time = "2026-03-04T14:20:08.364Z" }, ] [[package]] name = "opik" -version = "1.10.39" +version = "1.10.54" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "boto3-stubs", extra = ["bedrock-runtime"] }, @@ -4555,9 +4510,9 @@ dependencies = [ { name = "tqdm" }, { name = "uuid6" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/b5/0f/b1e00a18cac16b4f36bf6cecc2de962fda810a9416d1159c48f46b81f5ec/opik-1.10.39.tar.gz", hash = "sha256:4d808eb2137070fc5d92a3bed3c3100d9cccfb35f4f0b71ea9990733f293dbb2", size = 780312, upload-time = "2026-03-12T14:08:25.746Z" } +sdist = { url = "https://files.pythonhosted.org/packages/fa/c9/ecc68c5ae32bf5b1074bdc713cb1543b8e2a46c58c814bf150fecf50f272/opik-1.10.54.tar.gz", hash = "sha256:46e29abf4656bd80b9cb339659d24ecf97b61f37c3fde594de75e5f59953e9d3", size = 812757, upload-time = "2026-03-27T11:23:06.109Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e1/24/0f4404907a98b4aec4508504570a78a61a3a8b5e451c67326632695ba8e6/opik-1.10.39-py3-none-any.whl", hash = "sha256:a72d735b9afac62e5262294b2f704aca89ec31f5c9beda17504815f7423870c3", size = 1317833, upload-time = "2026-03-12T14:08:23.954Z" }, + { url = "https://files.pythonhosted.org/packages/58/91/1ae4e8a349da0620a6f0a4fc51cd00c3e75176939d022e8684379aee2928/opik-1.10.54-py3-none-any.whl", hash = "sha256:5f8ddabe5283ebe08d455e81b188d6e09ce1d1efa989f8b05567ef70f1e9aeda", size = 1379008, upload-time = "2026-03-27T11:23:04.582Z" }, ] [[package]] @@ -5273,15 +5228,15 @@ wheels = [ [[package]] name = "pydantic-extra-types" -version = "2.11.0" +version = "2.11.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pydantic" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/fd/35/2fee58b1316a73e025728583d3b1447218a97e621933fc776fb8c0f2ebdd/pydantic_extra_types-2.11.0.tar.gz", hash = "sha256:4e9991959d045b75feb775683437a97991d02c138e00b59176571db9ce634f0e", size = 157226, upload-time = "2025-12-31T16:18:27.944Z" } +sdist = { url = "https://files.pythonhosted.org/packages/66/71/dba38ee2651f84f7842206adbd2233d8bbdb59fb85e9fa14232486a8c471/pydantic_extra_types-2.11.1.tar.gz", hash = "sha256:46792d2307383859e923d8fcefa82108b1a141f8a9c0198982b3832ab5ef1049", size = 172002, upload-time = "2026-03-16T08:08:03.92Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/fe/17/fabd56da47096d240dd45ba627bead0333b0cf0ee8ada9bec579287dadf3/pydantic_extra_types-2.11.0-py3-none-any.whl", hash = "sha256:84b864d250a0fc62535b7ec591e36f2c5b4d1325fa0017eb8cda9aeb63b374a6", size = 74296, upload-time = "2025-12-31T16:18:26.38Z" }, + { url = "https://files.pythonhosted.org/packages/17/c1/3226e6d7f5a4f736f38ac11a6fbb262d701889802595cdb0f53a885ac2e0/pydantic_extra_types-2.11.1-py3-none-any.whl", hash = "sha256:1722ea2bddae5628ace25f2aa685b69978ef533123e5638cfbddb999e0100ec1", size = 79526, upload-time = "2026-03-16T08:08:02.533Z" }, ] [[package]] @@ -5323,7 +5278,7 @@ crypto = [ [[package]] name = "pymilvus" -version = "2.6.10" +version = "2.6.11" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cachetools" }, @@ -5335,9 +5290,9 @@ dependencies = [ { name = "requests" }, { name = "setuptools" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/9e/85/90362066ccda5ff6fec693a55693cde659fdcd36d08f1bd7012ae958248d/pymilvus-2.6.10.tar.gz", hash = "sha256:58a44ee0f1dddd7727ae830ef25325872d8946f029d801a37105164e6699f1b8", size = 1561042, upload-time = "2026-03-13T09:54:22.441Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c9/e6/0adc3b374f5c5d1eebd4f551b455c6865c449b170b17545001b208e2b153/pymilvus-2.6.11.tar.gz", hash = "sha256:a40c10322cde25184a8c3d84993a14dfb67ad2bdcfc5dff7e68b11a79ff8f6d8", size = 1583634, upload-time = "2026-03-27T06:25:46.023Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/88/10/fe7fbb6795aa20038afd55e9c653991e7c69fb24c741ebb39ba3b0aa5c13/pymilvus-2.6.10-py3-none-any.whl", hash = "sha256:a048b6f3ebad93742bca559beabf44fe578f0983555a109c4436b5fb2c1dbd40", size = 312797, upload-time = "2026-03-13T09:54:21.081Z" }, + { url = "https://files.pythonhosted.org/packages/9c/1c/bccb331d71f824738f80f11e9b8b4da47973c903826355526ae4fa2b762f/pymilvus-2.6.11-py3-none-any.whl", hash = "sha256:a11e1718b15045361c71ca671b959900cb7e2faae863c896f6b7e87bf2e4d10a", size = 315252, upload-time = "2026-03-27T06:25:44.215Z" }, ] [[package]] @@ -5380,6 +5335,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/7d/037401cecb34728d1c28ea05e196ea3c9d50a1ce0f2172e586e075ff55d8/pyobvector-0.2.25-py3-none-any.whl", hash = "sha256:ae0153f99bd0222783ed7e3951efc31a0d2b462d926b6f86ebd2033409aede8f", size = 64663, upload-time = "2026-03-10T07:18:29.789Z" }, ] +[[package]] +name = "pypandoc" +version = "1.17" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ea/d6/410615fc433e5d1eacc00db2044ae2a9c82302df0d35366fe2bd15de024d/pypandoc-1.17.tar.gz", hash = "sha256:51179abfd6e582a25ed03477541b48836b5bba5a4c3b282a547630793934d799", size = 69071, upload-time = "2026-03-14T22:39:07.21Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0c/86/e2ffa604eacfbec3f430b1d850e7e04c4101eca1a5828f9ae54bf51dfba4/pypandoc-1.17-py3-none-any.whl", hash = "sha256:01fdbffa61edb9f8e82e8faad6954efcb7b6f8f0634aead4d89e322a00225a67", size = 23554, upload-time = "2026-03-14T22:38:46.007Z" }, +] + [[package]] name = "pypandoc-binary" version = "1.17" @@ -5405,11 +5369,11 @@ wheels = [ [[package]] name = "pypdf" -version = "6.9.1" +version = "6.9.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f9/fb/dc2e8cb006e80b0020ed20d8649106fe4274e82d8e756ad3e24ade19c0df/pypdf-6.9.1.tar.gz", hash = "sha256:ae052407d33d34de0c86c5c729be6d51010bf36e03035a8f23ab449bca52377d", size = 5311551, upload-time = "2026-03-17T10:46:07.876Z" } +sdist = { url = "https://files.pythonhosted.org/packages/31/83/691bdb309306232362503083cb15777491045dd54f45393a317dc7d8082f/pypdf-6.9.2.tar.gz", hash = "sha256:7f850faf2b0d4ab936582c05da32c52214c2b089d61a316627b5bfb5b0dab46c", size = 5311837, upload-time = "2026-03-23T14:53:27.983Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f9/f4/75543fa802b86e72f87e9395440fe1a89a6d149887e3e55745715c3352ac/pypdf-6.9.1-py3-none-any.whl", hash = "sha256:f35a6a022348fae47e092a908339a8f3dc993510c026bb39a96718fc7185e89f", size = 333661, upload-time = "2026-03-17T10:46:06.286Z" }, + { url = "https://files.pythonhosted.org/packages/a5/7e/c85f41243086a8fe5d1baeba527cb26a1918158a565932b41e0f7c0b32e9/pypdf-6.9.2-py3-none-any.whl", hash = "sha256:662cf29bcb419a36a1365232449624ab40b7c2d0cfc28e54f42eeecd1fd7e844", size = 333744, upload-time = "2026-03-23T14:53:26.573Z" }, ] [[package]] @@ -5467,18 +5431,18 @@ wheels = [ [[package]] name = "pyrefly" -version = "0.55.0" +version = "0.57.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/bf/c4/76e0797215e62d007f81f86c9c4fb5d6202685a3f5e70810f3fd94294f92/pyrefly-0.55.0.tar.gz", hash = "sha256:434c3282532dd4525c4840f2040ed0eb79b0ec8224fe18d957956b15471f2441", size = 5135682, upload-time = "2026-03-03T00:46:38.122Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c9/c1/c17211e5bbd2b90a24447484713da7cc2cee4e9455e57b87016ffc69d426/pyrefly-0.57.1.tar.gz", hash = "sha256:b05f6f5ee3a6a5d502ca19d84cb9ab62d67f05083819964a48c1510f2993efc6", size = 5310800, upload-time = "2026-03-18T18:42:35.614Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/39/b0/16e50cf716784513648e23e726a24f71f9544aa4f86103032dcaa5ff71a2/pyrefly-0.55.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:49aafcefe5e2dd4256147db93e5b0ada42bff7d9a60db70e03d1f7055338eec9", size = 12210073, upload-time = "2026-03-03T00:46:15.51Z" }, - { url = "https://files.pythonhosted.org/packages/3a/ad/89500c01bac3083383011600370289fbc67700c5be46e781787392628a3a/pyrefly-0.55.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:2827426e6b28397c13badb93c0ede0fb0f48046a7a89e3d774cda04e8e2067cd", size = 11767474, upload-time = "2026-03-03T00:46:18.003Z" }, - { url = "https://files.pythonhosted.org/packages/78/68/4c66b260f817f304ead11176ff13985625f7c269e653304b4bdb546551af/pyrefly-0.55.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7346b2d64dc575bd61aa3bca854fbf8b5a19a471cbdb45e0ca1e09861b63488c", size = 33260395, upload-time = "2026-03-03T00:46:20.509Z" }, - { url = "https://files.pythonhosted.org/packages/47/09/10bd48c9f860064f29f412954126a827d60f6451512224912c265e26bbe6/pyrefly-0.55.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:233b861b4cff008b1aff62f4f941577ed752e4d0060834229eb9b6826e6973c9", size = 35848269, upload-time = "2026-03-03T00:46:23.418Z" }, - { url = "https://files.pythonhosted.org/packages/a9/39/bc65cdd5243eb2dfea25dd1321f9a5a93e8d9c3a308501c4c6c05d011585/pyrefly-0.55.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5aa85657d76da1d25d081a49f0e33c8fc3ec91c1a0f185a8ed393a5a3d9e178", size = 38449820, upload-time = "2026-03-03T00:46:26.309Z" }, - { url = "https://files.pythonhosted.org/packages/e5/64/58b38963b011af91209e87f868cc85cfc762ec49a4568ce610c45e7a5f40/pyrefly-0.55.0-py3-none-win32.whl", hash = "sha256:23f786a78536a56fed331b245b7d10ec8945bebee7b723491c8d66fdbc155fe6", size = 11259415, upload-time = "2026-03-03T00:46:30.875Z" }, - { url = "https://files.pythonhosted.org/packages/7a/0b/a4aa519ff632a1ea69eec942566951670b870b99b5c08407e1387b85b6a4/pyrefly-0.55.0-py3-none-win_amd64.whl", hash = "sha256:d465b49e999b50eeb069ad23f0f5710651cad2576f9452a82991bef557df91ee", size = 12043581, upload-time = "2026-03-03T00:46:33.674Z" }, - { url = "https://files.pythonhosted.org/packages/f1/51/89017636fbe1ffd166ad478990c6052df615b926182fa6d3c0842b407e89/pyrefly-0.55.0-py3-none-win_arm64.whl", hash = "sha256:732ff490e0e863b296e7c0b2471e08f8ba7952f9fa6e9de09d8347fd67dde77f", size = 11548076, upload-time = "2026-03-03T00:46:36.193Z" }, + { url = "https://files.pythonhosted.org/packages/b7/58/8af37856c8d45b365ece635a6728a14b0356b08d1ff1ac601d7120def1e0/pyrefly-0.57.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:91974bfbe951eebf5a7bc959c1f3921f0371c789cad84761511d695e9ab2265f", size = 12681847, upload-time = "2026-03-18T18:42:10.963Z" }, + { url = "https://files.pythonhosted.org/packages/5f/d7/fae6dd9d0355fc5b8df7793f1423b7433ca8e10b698ea934c35f0e4e6522/pyrefly-0.57.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:808087298537c70f5e7cdccb5bbaad482e7e056e947c0adf00fb612cbace9fdc", size = 12219634, upload-time = "2026-03-18T18:42:13.469Z" }, + { url = "https://files.pythonhosted.org/packages/29/8f/9511ae460f0690e837b9ba0f7e5e192079e16ff9a9ba8a272450e81f11f8/pyrefly-0.57.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b01f454fa5539e070c0cba17ddec46b3d2107d571d519bd8eca8f3142ba02a6", size = 34947757, upload-time = "2026-03-18T18:42:17.152Z" }, + { url = "https://files.pythonhosted.org/packages/07/43/f053bf9c65218f70e6a49561e9942c7233f8c3e4da8d42e5fe2aae50b3d2/pyrefly-0.57.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:02ad59ea722191f51635f23e37574662116b82ca9d814529f7cb5528f041f381", size = 37621018, upload-time = "2026-03-18T18:42:20.79Z" }, + { url = "https://files.pythonhosted.org/packages/0e/76/9cea46de01665bbc125e4f215340c9365c8d56cda6198ff238a563ea8e75/pyrefly-0.57.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:54bc0afe56776145e37733ff763e7e9679ee8a76c467b617dc3f227d4124a9e2", size = 40203649, upload-time = "2026-03-18T18:42:24.519Z" }, + { url = "https://files.pythonhosted.org/packages/fd/8b/2fb4a96d75e2a57df698a43e2970e441ba2704e3906cdc0386a055daa05a/pyrefly-0.57.1-py3-none-win32.whl", hash = "sha256:468e5839144b25bb0dce839bfc5fd879c9f38e68ebf5de561f30bed9ae19d8ca", size = 11732953, upload-time = "2026-03-18T18:42:27.379Z" }, + { url = "https://files.pythonhosted.org/packages/13/5a/4a197910fe2e9b102b15ae5e7687c45b7b5981275a11a564b41e185dd907/pyrefly-0.57.1-py3-none-win_amd64.whl", hash = "sha256:46db9c97093673c4fb7fab96d610e74d140661d54688a92d8e75ad885a56c141", size = 12537319, upload-time = "2026-03-18T18:42:30.196Z" }, + { url = "https://files.pythonhosted.org/packages/b5/c6/bc442874be1d9b63da1f9debb4f04b7d0c590a8dc4091921f3c288207242/pyrefly-0.57.1-py3-none-win_arm64.whl", hash = "sha256:feb1bbe3b0d8d5a70121dcdf1476e6a99cc056a26a49379a156f040729244dcb", size = 12013455, upload-time = "2026-03-18T18:42:32.928Z" }, ] [[package]] @@ -5512,16 +5476,16 @@ wheels = [ [[package]] name = "pytest-cov" -version = "7.0.0" +version = "7.1.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "coverage", extra = ["toml"] }, { name = "pluggy" }, { name = "pytest" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/5e/f7/c933acc76f5208b3b00089573cf6a2bc26dc80a8aece8f52bb7d6b1855ca/pytest_cov-7.0.0.tar.gz", hash = "sha256:33c97eda2e049a0c5298e91f519302a1334c26ac65c1a483d6206fd458361af1", size = 54328, upload-time = "2025-09-09T10:57:02.113Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b1/51/a849f96e117386044471c8ec2bd6cfebacda285da9525c9106aeb28da671/pytest_cov-7.1.0.tar.gz", hash = "sha256:30674f2b5f6351aa09702a9c8c364f6a01c27aae0c1366ae8016160d1efc56b2", size = 55592, upload-time = "2026-03-21T20:11:16.284Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ee/49/1377b49de7d0c1ce41292161ea0f721913fa8722c19fb9c1e3aa0367eecb/pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861", size = 22424, upload-time = "2025-09-09T10:57:00.695Z" }, + { url = "https://files.pythonhosted.org/packages/9d/7a/d968e294073affff457b041c2be9868a40c1c71f4a35fcc1e45e5493067b/pytest_cov-7.1.0-py3-none-any.whl", hash = "sha256:a0461110b7865f9a271aa1b51e516c9a95de9d696734a2f71e3e78f46e1d4678", size = 22876, upload-time = "2026-03-21T20:11:14.438Z" }, ] [[package]] @@ -5850,14 +5814,14 @@ wheels = [ [[package]] name = "redis" -version = "7.3.0" +version = "7.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "async-timeout", marker = "python_full_version < '3.11.3'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/da/82/4d1a5279f6c1251d3d2a603a798a1137c657de9b12cfc1fba4858232c4d2/redis-7.3.0.tar.gz", hash = "sha256:4d1b768aafcf41b01022410b3cc4f15a07d9b3d6fe0c66fc967da2c88e551034", size = 4928081, upload-time = "2026-03-06T18:18:16.287Z" } +sdist = { url = "https://files.pythonhosted.org/packages/7b/7f/3759b1d0d72b7c92f0d70ffd9dc962b7b7b5ee74e135f9d7d8ab06b8a318/redis-7.4.0.tar.gz", hash = "sha256:64a6ea7bf567ad43c964d2c30d82853f8df927c5c9017766c55a1d1ed95d18ad", size = 4943913, upload-time = "2026-03-24T09:14:37.53Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f0/28/84e57fce7819e81ec5aa1bd31c42b89607241f4fb1a3ea5b0d2dbeaea26c/redis-7.3.0-py3-none-any.whl", hash = "sha256:9d4fcb002a12a5e3c3fbe005d59c48a2cc231f87fbb2f6b70c2d89bb64fec364", size = 404379, upload-time = "2026-03-06T18:18:14.583Z" }, + { url = "https://files.pythonhosted.org/packages/74/3a/95deec7db1eb53979973ebd156f3369a72732208d1391cd2e5d127062a32/redis-7.4.0-py3-none-any.whl", hash = "sha256:a9c74a5c893a5ef8455a5adb793a31bb70feb821c86eccb62eebef5a19c429ec", size = 409772, upload-time = "2026-03-24T09:14:35.968Z" }, ] [package.optional-dependencies] @@ -5917,7 +5881,7 @@ wheels = [ [[package]] name = "requests" -version = "2.32.5" +version = "2.33.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "certifi" }, @@ -5925,9 +5889,9 @@ dependencies = [ { name = "idna" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c9/74/b3ff8e6c8446842c3f5c837e9c3dfcfe2018ea6ecef224c710c85ef728f4/requests-2.32.5.tar.gz", hash = "sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf", size = 134517, upload-time = "2025-08-18T20:46:02.573Z" } +sdist = { url = "https://files.pythonhosted.org/packages/34/64/8860370b167a9721e8956ae116825caff829224fbca0ca6e7bf8ddef8430/requests-2.33.0.tar.gz", hash = "sha256:c7ebc5e8b0f21837386ad0e1c8fe8b829fa5f544d8df3b2253bff14ef29d7652", size = 134232, upload-time = "2026-03-25T15:10:41.586Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" }, + { url = "https://files.pythonhosted.org/packages/56/5d/c814546c2333ceea4ba42262d8c4d55763003e767fa169adc693bd524478/requests-2.33.0-py3-none-any.whl", hash = "sha256:3324635456fa185245e24865e810cecec7b4caf933d7eb133dcde67d48cee69b", size = 65017, upload-time = "2026-03-25T15:10:40.382Z" }, ] [[package]] @@ -5957,15 +5921,15 @@ wheels = [ [[package]] name = "resend" -version = "2.23.0" +version = "2.26.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "requests" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/96/a3/20003e7d14604fef778bd30c69604df3560a657a95a5c29a9688610759b6/resend-2.23.0.tar.gz", hash = "sha256:df613827dcc40eb1c9de2e5ff600cd4081b89b206537dec8067af1a5016d23c7", size = 31416, upload-time = "2026-02-23T19:01:57.603Z" } +sdist = { url = "https://files.pythonhosted.org/packages/07/ff/6a4e5e758fc2145c6a7d8563934d8ee24bf96a0212d7ec7d1af1f155bb74/resend-2.26.0.tar.gz", hash = "sha256:957a6a59dc597ce27fbd6d5383220dd9cc497fab99d4f3d775c8a42a449a569e", size = 36238, upload-time = "2026-03-20T22:49:09.728Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e3/35/64df775b8cd95e89798fd7b1b7fcafa975b6b09f559c10c0650e65b33580/resend-2.23.0-py2.py3-none-any.whl", hash = "sha256:eca6d28a1ffd36c1fc489fa83cb6b511f384792c9f07465f7c92d96c8b4d5636", size = 52599, upload-time = "2026-02-23T19:01:55.962Z" }, + { url = "https://files.pythonhosted.org/packages/16/c2/f88d3299d97aa1d36a923d0846fe185fcf5355ca898c954b2e5a79f090b5/resend-2.26.0-py2.py3-none-any.whl", hash = "sha256:5e25a804a84a68df504f2ade5369ac37e0139e37788a1f20b66c88696595b4bc", size = 57699, upload-time = "2026-03-20T22:49:08.354Z" }, ] [[package]] @@ -6046,27 +6010,27 @@ wheels = [ [[package]] name = "ruff" -version = "0.15.6" +version = "0.15.8" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/51/df/f8629c19c5318601d3121e230f74cbee7a3732339c52b21daa2b82ef9c7d/ruff-0.15.6.tar.gz", hash = "sha256:8394c7bb153a4e3811a4ecdacd4a8e6a4fa8097028119160dffecdcdf9b56ae4", size = 4597916, upload-time = "2026-03-12T23:05:47.51Z" } +sdist = { url = "https://files.pythonhosted.org/packages/14/b0/73cf7550861e2b4824950b8b52eebdcc5adc792a00c514406556c5b80817/ruff-0.15.8.tar.gz", hash = "sha256:995f11f63597ee362130d1d5a327a87cb6f3f5eae3094c620bcc632329a4d26e", size = 4610921, upload-time = "2026-03-26T18:39:38.675Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/9e/2f/4e03a7e5ce99b517e98d3b4951f411de2b0fa8348d39cf446671adcce9a2/ruff-0.15.6-py3-none-linux_armv6l.whl", hash = "sha256:7c98c3b16407b2cf3d0f2b80c80187384bc92c6774d85fefa913ecd941256fff", size = 10508953, upload-time = "2026-03-12T23:05:17.246Z" }, - { url = "https://files.pythonhosted.org/packages/70/60/55bcdc3e9f80bcf39edf0cd272da6fa511a3d94d5a0dd9e0adf76ceebdb4/ruff-0.15.6-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:ee7dcfaad8b282a284df4aa6ddc2741b3f4a18b0555d626805555a820ea181c3", size = 10942257, upload-time = "2026-03-12T23:05:23.076Z" }, - { url = "https://files.pythonhosted.org/packages/e7/f9/005c29bd1726c0f492bfa215e95154cf480574140cb5f867c797c18c790b/ruff-0.15.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:3bd9967851a25f038fc8b9ae88a7fbd1b609f30349231dffaa37b6804923c4bb", size = 10322683, upload-time = "2026-03-12T23:05:33.738Z" }, - { url = "https://files.pythonhosted.org/packages/5f/74/2f861f5fd7cbb2146bddb5501450300ce41562da36d21868c69b7a828169/ruff-0.15.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:13f4594b04e42cd24a41da653886b04d2ff87adbf57497ed4f728b0e8a4866f8", size = 10660986, upload-time = "2026-03-12T23:05:53.245Z" }, - { url = "https://files.pythonhosted.org/packages/c1/a1/309f2364a424eccb763cdafc49df843c282609f47fe53aa83f38272389e0/ruff-0.15.6-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e2ed8aea2f3fe57886d3f00ea5b8aae5bf68d5e195f487f037a955ff9fbaac9e", size = 10332177, upload-time = "2026-03-12T23:05:56.145Z" }, - { url = "https://files.pythonhosted.org/packages/30/41/7ebf1d32658b4bab20f8ac80972fb19cd4e2c6b78552be263a680edc55ac/ruff-0.15.6-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:70789d3e7830b848b548aae96766431c0dc01a6c78c13381f423bf7076c66d15", size = 11170783, upload-time = "2026-03-12T23:06:01.742Z" }, - { url = "https://files.pythonhosted.org/packages/76/be/6d488f6adca047df82cd62c304638bcb00821c36bd4881cfca221561fdfc/ruff-0.15.6-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:542aaf1de3154cea088ced5a819ce872611256ffe2498e750bbae5247a8114e9", size = 12044201, upload-time = "2026-03-12T23:05:28.697Z" }, - { url = "https://files.pythonhosted.org/packages/71/68/e6f125df4af7e6d0b498f8d373274794bc5156b324e8ab4bf5c1b4fc0ec7/ruff-0.15.6-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1c22e6f02c16cfac3888aa636e9eba857254d15bbacc9906c9689fdecb1953ab", size = 11421561, upload-time = "2026-03-12T23:05:31.236Z" }, - { url = "https://files.pythonhosted.org/packages/f1/9f/f85ef5fd01a52e0b472b26dc1b4bd228b8f6f0435975442ffa4741278703/ruff-0.15.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98893c4c0aadc8e448cfa315bd0cc343a5323d740fe5f28ef8a3f9e21b381f7e", size = 11310928, upload-time = "2026-03-12T23:05:45.288Z" }, - { url = "https://files.pythonhosted.org/packages/8c/26/b75f8c421f5654304b89471ed384ae8c7f42b4dff58fa6ce1626d7f2b59a/ruff-0.15.6-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:70d263770d234912374493e8cc1e7385c5d49376e41dfa51c5c3453169dc581c", size = 11235186, upload-time = "2026-03-12T23:05:50.677Z" }, - { url = "https://files.pythonhosted.org/packages/fc/d4/d5a6d065962ff7a68a86c9b4f5500f7d101a0792078de636526c0edd40da/ruff-0.15.6-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:55a1ad63c5a6e54b1f21b7514dfadc0c7fb40093fa22e95143cf3f64ebdcd512", size = 10635231, upload-time = "2026-03-12T23:05:37.044Z" }, - { url = "https://files.pythonhosted.org/packages/d6/56/7c3acf3d50910375349016cf33de24be021532042afbed87942858992491/ruff-0.15.6-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:8dc473ba093c5ec238bb1e7429ee676dca24643c471e11fbaa8a857925b061c0", size = 10340357, upload-time = "2026-03-12T23:06:04.748Z" }, - { url = "https://files.pythonhosted.org/packages/06/54/6faa39e9c1033ff6a3b6e76b5df536931cd30caf64988e112bbf91ef5ce5/ruff-0.15.6-py3-none-musllinux_1_2_i686.whl", hash = "sha256:85b042377c2a5561131767974617006f99f7e13c63c111b998f29fc1e58a4cfb", size = 10860583, upload-time = "2026-03-12T23:05:58.978Z" }, - { url = "https://files.pythonhosted.org/packages/cb/1e/509a201b843b4dfb0b32acdedf68d951d3377988cae43949ba4c4133a96a/ruff-0.15.6-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:cef49e30bc5a86a6a92098a7fbf6e467a234d90b63305d6f3ec01225a9d092e0", size = 11410976, upload-time = "2026-03-12T23:05:39.955Z" }, - { url = "https://files.pythonhosted.org/packages/6c/25/3fc9114abf979a41673ce877c08016f8e660ad6cf508c3957f537d2e9fa9/ruff-0.15.6-py3-none-win32.whl", hash = "sha256:bbf67d39832404812a2d23020dda68fee7f18ce15654e96fb1d3ad21a5fe436c", size = 10616872, upload-time = "2026-03-12T23:05:42.451Z" }, - { url = "https://files.pythonhosted.org/packages/89/7a/09ece68445ceac348df06e08bf75db72d0e8427765b96c9c0ffabc1be1d9/ruff-0.15.6-py3-none-win_amd64.whl", hash = "sha256:aee25bc84c2f1007ecb5037dff75cef00414fdf17c23f07dc13e577883dca406", size = 11787271, upload-time = "2026-03-12T23:05:20.168Z" }, - { url = "https://files.pythonhosted.org/packages/7f/d0/578c47dd68152ddddddf31cd7fc67dc30b7cdf639a86275fda821b0d9d98/ruff-0.15.6-py3-none-win_arm64.whl", hash = "sha256:c34de3dd0b0ba203be50ae70f5910b17188556630e2178fd7d79fc030eb0d837", size = 11060497, upload-time = "2026-03-12T23:05:25.968Z" }, + { url = "https://files.pythonhosted.org/packages/4a/92/c445b0cd6da6e7ae51e954939cb69f97e008dbe750cfca89b8cedc081be7/ruff-0.15.8-py3-none-linux_armv6l.whl", hash = "sha256:cbe05adeba76d58162762d6b239c9056f1a15a55bd4b346cfd21e26cd6ad7bc7", size = 10527394, upload-time = "2026-03-26T18:39:41.566Z" }, + { url = "https://files.pythonhosted.org/packages/eb/92/f1c662784d149ad1414cae450b082cf736430c12ca78367f20f5ed569d65/ruff-0.15.8-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:d3e3d0b6ba8dca1b7ef9ab80a28e840a20070c4b62e56d675c24f366ef330570", size = 10905693, upload-time = "2026-03-26T18:39:30.364Z" }, + { url = "https://files.pythonhosted.org/packages/ca/f2/7a631a8af6d88bcef997eb1bf87cc3da158294c57044aafd3e17030613de/ruff-0.15.8-py3-none-macosx_11_0_arm64.whl", hash = "sha256:6ee3ae5c65a42f273f126686353f2e08ff29927b7b7e203b711514370d500de3", size = 10323044, upload-time = "2026-03-26T18:39:33.37Z" }, + { url = "https://files.pythonhosted.org/packages/67/18/1bf38e20914a05e72ef3b9569b1d5c70a7ef26cd188d69e9ca8ef588d5bf/ruff-0.15.8-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fdce027ada77baa448077ccc6ebb2fa9c3c62fd110d8659d601cf2f475858d94", size = 10629135, upload-time = "2026-03-26T18:39:44.142Z" }, + { url = "https://files.pythonhosted.org/packages/d2/e9/138c150ff9af60556121623d41aba18b7b57d95ac032e177b6a53789d279/ruff-0.15.8-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:12e617fc01a95e5821648a6df341d80456bd627bfab8a829f7cfc26a14a4b4a3", size = 10348041, upload-time = "2026-03-26T18:39:52.178Z" }, + { url = "https://files.pythonhosted.org/packages/02/f1/5bfb9298d9c323f842c5ddeb85f1f10ef51516ac7a34ba446c9347d898df/ruff-0.15.8-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:432701303b26416d22ba696c39f2c6f12499b89093b61360abc34bcc9bf07762", size = 11121987, upload-time = "2026-03-26T18:39:55.195Z" }, + { url = "https://files.pythonhosted.org/packages/10/11/6da2e538704e753c04e8d86b1fc55712fdbdcc266af1a1ece7a51fff0d10/ruff-0.15.8-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d910ae974b7a06a33a057cb87d2a10792a3b2b3b35e33d2699fdf63ec8f6b17a", size = 11951057, upload-time = "2026-03-26T18:39:19.18Z" }, + { url = "https://files.pythonhosted.org/packages/83/f0/c9208c5fd5101bf87002fed774ff25a96eea313d305f1e5d5744698dc314/ruff-0.15.8-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2033f963c43949d51e6fdccd3946633c6b37c484f5f98c3035f49c27395a8ab8", size = 11464613, upload-time = "2026-03-26T18:40:06.301Z" }, + { url = "https://files.pythonhosted.org/packages/f8/22/d7f2fabdba4fae9f3b570e5605d5eb4500dcb7b770d3217dca4428484b17/ruff-0.15.8-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f29b989a55572fb885b77464cf24af05500806ab4edf9a0fd8977f9759d85b1", size = 11257557, upload-time = "2026-03-26T18:39:57.972Z" }, + { url = "https://files.pythonhosted.org/packages/71/8c/382a9620038cf6906446b23ce8632ab8c0811b8f9d3e764f58bedd0c9a6f/ruff-0.15.8-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:ac51d486bf457cdc985a412fb1801b2dfd1bd8838372fc55de64b1510eff4bec", size = 11169440, upload-time = "2026-03-26T18:39:22.205Z" }, + { url = "https://files.pythonhosted.org/packages/4d/0d/0994c802a7eaaf99380085e4e40c845f8e32a562e20a38ec06174b52ef24/ruff-0.15.8-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:c9861eb959edab053c10ad62c278835ee69ca527b6dcd72b47d5c1e5648964f6", size = 10605963, upload-time = "2026-03-26T18:39:46.682Z" }, + { url = "https://files.pythonhosted.org/packages/19/aa/d624b86f5b0aad7cef6bbf9cd47a6a02dfdc4f72c92a337d724e39c9d14b/ruff-0.15.8-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:8d9a5b8ea13f26ae90838afc33f91b547e61b794865374f114f349e9036835fb", size = 10357484, upload-time = "2026-03-26T18:39:49.176Z" }, + { url = "https://files.pythonhosted.org/packages/35/c3/e0b7835d23001f7d999f3895c6b569927c4d39912286897f625736e1fd04/ruff-0.15.8-py3-none-musllinux_1_2_i686.whl", hash = "sha256:c2a33a529fb3cbc23a7124b5c6ff121e4d6228029cba374777bd7649cc8598b8", size = 10830426, upload-time = "2026-03-26T18:40:03.702Z" }, + { url = "https://files.pythonhosted.org/packages/f0/51/ab20b322f637b369383adc341d761eaaa0f0203d6b9a7421cd6e783d81b9/ruff-0.15.8-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:75e5cd06b1cf3f47a3996cfc999226b19aa92e7cce682dcd62f80d7035f98f49", size = 11345125, upload-time = "2026-03-26T18:39:27.799Z" }, + { url = "https://files.pythonhosted.org/packages/37/e6/90b2b33419f59d0f2c4c8a48a4b74b460709a557e8e0064cf33ad894f983/ruff-0.15.8-py3-none-win32.whl", hash = "sha256:bc1f0a51254ba21767bfa9a8b5013ca8149dcf38092e6a9eb704d876de94dc34", size = 10571959, upload-time = "2026-03-26T18:39:36.117Z" }, + { url = "https://files.pythonhosted.org/packages/1f/a2/ef467cb77099062317154c63f234b8a7baf7cb690b99af760c5b68b9ee7f/ruff-0.15.8-py3-none-win_amd64.whl", hash = "sha256:04f79eff02a72db209d47d665ba7ebcad609d8918a134f86cb13dd132159fc89", size = 11743893, upload-time = "2026-03-26T18:39:25.01Z" }, + { url = "https://files.pythonhosted.org/packages/15/e2/77be4fff062fa78d9b2a4dea85d14785dac5f1d0c1fb58ed52331f0ebe28/ruff-0.15.8-py3-none-win_arm64.whl", hash = "sha256:cf891fa8e3bb430c0e7fac93851a5978fc99c8fa2c053b57b118972866f8e5f2", size = 11048175, upload-time = "2026-03-26T18:40:01.06Z" }, ] [[package]] @@ -6105,14 +6069,14 @@ wheels = [ [[package]] name = "scipy-stubs" -version = "1.17.1.2" +version = "1.17.1.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "optype", extra = ["numpy"] }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c7/ab/43f681ffba42f363b7ed6b767fd215d1e26006578214ff8330586a11bf95/scipy_stubs-1.17.1.2.tar.gz", hash = "sha256:2ecadc8c87a3b61aaf7379d6d6b10f1038a829c53b9efe5b174fb97fc8b52237", size = 388354, upload-time = "2026-03-15T22:33:20.449Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a7/59/59c6cc3f9970154b9ed6b1aff42a0185cdd60cef54adc0404b9e77972221/scipy_stubs-1.17.1.3.tar.gz", hash = "sha256:5eb87a8d23d726706259b012ebe76a4a96a9ae9e141fc59bf55fc8eac2ed9e0f", size = 392185, upload-time = "2026-03-22T22:11:58.34Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/8c/0b/ec4fe720c1202d9df729a3e9d9b7e4d2da9f6e7f28bd2877b7d0769f4f75/scipy_stubs-1.17.1.2-py3-none-any.whl", hash = "sha256:f19e8f5273dbe3b7ee6a9554678c3973b9695fa66b91f29206d00830a1536c06", size = 594377, upload-time = "2026-03-15T22:33:18.684Z" }, + { url = "https://files.pythonhosted.org/packages/2c/d4/94304532c0a75a55526119043dd44a9bd1541a21e14483cbb54261c527d2/scipy_stubs-1.17.1.3-py3-none-any.whl", hash = "sha256:7b91d3f05aa47da06fbca14eb6c5bb4c28994e9245fd250cc847e375bab31297", size = 597933, upload-time = "2026-03-22T22:11:56.525Z" }, ] [[package]] @@ -6131,15 +6095,15 @@ wheels = [ [[package]] name = "sentry-sdk" -version = "2.54.0" +version = "2.55.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "certifi" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c8/e9/2e3a46c304e7fa21eaa70612f60354e32699c7102eb961f67448e222ad7c/sentry_sdk-2.54.0.tar.gz", hash = "sha256:2620c2575128d009b11b20f7feb81e4e4e8ae08ec1d36cbc845705060b45cc1b", size = 413813, upload-time = "2026-03-02T15:12:41.355Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e9/b8/285293dc60fc198fffc3fcdbc7c6d4e646e0f74e61461c355d40faa64ceb/sentry_sdk-2.55.0.tar.gz", hash = "sha256:3774c4d8820720ca4101548131b9c162f4c9426eb7f4d24aca453012a7470f69", size = 424505, upload-time = "2026-03-17T14:15:51.707Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/53/39/be412cc86bc6247b8f69e9383d7950711bd86f8d0a4a4b0fe8fad685bc21/sentry_sdk-2.54.0-py2.py3-none-any.whl", hash = "sha256:fd74e0e281dcda63afff095d23ebcd6e97006102cdc8e78a29f19ecdf796a0de", size = 439198, upload-time = "2026-03-02T15:12:39.546Z" }, + { url = "https://files.pythonhosted.org/packages/9a/66/20465097782d7e1e742d846407ea7262d338c6e876ddddad38ca8907b38f/sentry_sdk-2.55.0-py2.py3-none-any.whl", hash = "sha256:97026981cb15699394474a196b88503a393cbc58d182ece0d3abe12b9bd978d4", size = 449284, upload-time = "2026-03-17T14:15:49.604Z" }, ] [package.optional-dependencies] @@ -6375,15 +6339,15 @@ wheels = [ [[package]] name = "starlette" -version = "0.52.1" +version = "1.0.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c4/68/79977123bb7be889ad680d79a40f339082c1978b5cfcf62c2d8d196873ac/starlette-0.52.1.tar.gz", hash = "sha256:834edd1b0a23167694292e94f597773bc3f89f362be6effee198165a35d62933", size = 2653702, upload-time = "2026-01-18T13:34:11.062Z" } +sdist = { url = "https://files.pythonhosted.org/packages/81/69/17425771797c36cded50b7fe44e850315d039f28b15901ab44839e70b593/starlette-1.0.0.tar.gz", hash = "sha256:6a4beaf1f81bb472fd19ea9b918b50dc3a77a6f2e190a12954b25e6ed5eea149", size = 2655289, upload-time = "2026-03-22T18:29:46.779Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/81/0d/13d1d239a25cbfb19e740db83143e95c772a1fe10202dda4b76792b114dd/starlette-0.52.1-py3-none-any.whl", hash = "sha256:0029d43eb3d273bc4f83a08720b4912ea4b071087a3b48db01b7c839f7954d74", size = 74272, upload-time = "2026-01-18T13:34:09.188Z" }, + { url = "https://files.pythonhosted.org/packages/0b/c9/584bc9651441b4ba60cc4d557d8a547b5aff901af35bda3a4ee30c819b82/starlette-1.0.0-py3-none-any.whl", hash = "sha256:d3ec55e0bb321692d275455ddfd3df75fff145d009685eb40dc91fc66b03d38b", size = 72651, upload-time = "2026-03-22T18:29:45.111Z" }, ] [[package]] @@ -6467,7 +6431,7 @@ wheels = [ [[package]] name = "tablestore" -version = "6.4.1" +version = "6.4.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, @@ -6480,9 +6444,9 @@ dependencies = [ { name = "six" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/62/00/53f8eeb0016e7ad518f92b085de8855891d10581b42f86d15d1df7a56d33/tablestore-6.4.1.tar.gz", hash = "sha256:005c6939832f2ecd403e01220b7045de45f2e53f1ffaf0c2efc435810885fffb", size = 120319, upload-time = "2026-02-13T06:58:37.267Z" } +sdist = { url = "https://files.pythonhosted.org/packages/09/07/afa1d18521bab13bb813066892b73589937fcf68aea63a54b0b14dae17b5/tablestore-6.4.2.tar.gz", hash = "sha256:5251e14b7c7ebf3d49d37dde957b49c7dba04ee8715c2650109cc02f3b89cc77", size = 5071435, upload-time = "2026-03-26T15:39:06.498Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/cc/96/a132bdecb753dc9dc34124a53019da29672baaa34485c8c504895897ea96/tablestore-6.4.1-py3-none-any.whl", hash = "sha256:616898d294dfe22f0d427463c241c6788374cdb2ace9aaf85673ce2c2a18d7e0", size = 141556, upload-time = "2026-02-13T06:58:35.579Z" }, + { url = "https://files.pythonhosted.org/packages/c7/3f/5fb3e8e5de36934fe38986b4e861657cebb3a6dfd97d32224cd40fc66359/tablestore-6.4.2-py3-none-any.whl", hash = "sha256:98c4cffa5eace4a3ea6fc2425263e733093c2baa43537f25dbaaf02e2b7882d8", size = 5114987, upload-time = "2026-03-26T15:39:04.074Z" }, ] [[package]] @@ -6508,7 +6472,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/20/81/be13f417065200182 [[package]] name = "tcvectordb" -version = "2.0.0" +version = "2.1.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cachetools" }, @@ -6521,9 +6485,9 @@ dependencies = [ { name = "ujson" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/16/21/3bcd466df20ac69408c0228b1c5e793cf3283085238d3ef5d352c556b6ad/tcvectordb-2.0.0.tar.gz", hash = "sha256:38c6ed17931b9bd702138941ca6cfe10b2b60301424ffa36b64a3c2686318941", size = 82209, upload-time = "2025-12-27T07:55:27.376Z" } +sdist = { url = "https://files.pythonhosted.org/packages/34/4c/3510489c20823c045a4f84c3f656b1af00b3fbbfa36efc494cf01492521f/tcvectordb-2.1.0.tar.gz", hash = "sha256:382615573f2b6d3e21535b686feac8895169b8eb56078fc73abb020676a1622f", size = 85691, upload-time = "2026-03-25T12:55:27.509Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/af/10/e807b273348edef3b321194bc13b67d2cd4df64e22f0404b9e39082415c7/tcvectordb-2.0.0-py3-none-any.whl", hash = "sha256:1731d9c6c0d17a4199872747ddfb1dd3feb26f14ffe7a657f8a5ac3af4ddcdd1", size = 96256, upload-time = "2025-12-27T07:55:24.362Z" }, + { url = "https://files.pythonhosted.org/packages/99/cf/7f340b4dc30ed0d2758915d1c2a4b2e9f0c90ce4f322b7cf17e571c80a45/tcvectordb-2.1.0-py3-none-any.whl", hash = "sha256:afbfc5f82bda70480921b2308148cbd0c51c8b45b3eef6cea64ddd003c7577e9", size = 99615, upload-time = "2026-03-25T12:55:26.004Z" }, ] [[package]] @@ -6792,11 +6756,11 @@ wheels = [ [[package]] name = "types-cachetools" -version = "6.2.0.20251022" +version = "6.2.0.20260317" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/3b/a8/f9bcc7f1be63af43ef0170a773e2d88817bcc7c9d8769f2228c802826efe/types_cachetools-6.2.0.20251022.tar.gz", hash = "sha256:f1d3c736f0f741e89ec10f0e1b0138625023e21eb33603a930c149e0318c0cef", size = 9608, upload-time = "2025-10-22T03:03:58.16Z" } +sdist = { url = "https://files.pythonhosted.org/packages/8b/7f/16a4d8344c28193a5a74358028c2d2f753f0d9658dd98b9e1967c50045a2/types_cachetools-6.2.0.20260317.tar.gz", hash = "sha256:6d91855bcc944665897c125e720aa3c80aace929b77a64e796343701df4f61c6", size = 9812, upload-time = "2026-03-17T04:06:32.007Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/98/2d/8d821ed80f6c2c5b427f650bf4dc25b80676ed63d03388e4b637d2557107/types_cachetools-6.2.0.20251022-py3-none-any.whl", hash = "sha256:698eb17b8f16b661b90624708b6915f33dbac2d185db499ed57e4997e7962cad", size = 9341, upload-time = "2025-10-22T03:03:57.036Z" }, + { url = "https://files.pythonhosted.org/packages/17/9a/b00b23054934c4d569c19f7278c4fb32746cd36a64a175a216d3073a4713/types_cachetools-6.2.0.20260317-py3-none-any.whl", hash = "sha256:92fa9bc50e4629e31fca67ceb3fb1de71791e314fa16c0a0d2728724dc222c8b", size = 9346, upload-time = "2026-03-17T04:06:31.184Z" }, ] [[package]] @@ -6840,11 +6804,11 @@ wheels = [ [[package]] name = "types-docutils" -version = "0.22.3.20260316" +version = "0.22.3.20260322" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/9f/27/a7f16b3a2fad0a4ddd85a668319f9a1d0311c4bd9578894f6471c7e6c788/types_docutils-0.22.3.20260316.tar.gz", hash = "sha256:8ef27d565b9831ff094fe2eac75337a74151013e2d21ecabd445c2955f891564", size = 57263, upload-time = "2026-03-16T04:29:12.211Z" } +sdist = { url = "https://files.pythonhosted.org/packages/44/bb/243a87fc1605a4a94c2c343d6dbddbf0d7ef7c0b9550f360b8cda8e82c39/types_docutils-0.22.3.20260322.tar.gz", hash = "sha256:e2450bb997283c3141ec5db3e436b91f0aa26efe35eb9165178ca976ccb4930b", size = 57311, upload-time = "2026-03-22T04:08:44.064Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/70/60/c1f22b7cfc4837d5419e5a2d8702c7d65f03343f866364b71cccd8a73b79/types_docutils-0.22.3.20260316-py3-none-any.whl", hash = "sha256:083c7091b8072c242998ec51da1bf1492f0332387da81c3b085efbf5ca754c7d", size = 91968, upload-time = "2026-03-16T04:29:11.114Z" }, + { url = "https://files.pythonhosted.org/packages/c6/4a/22c090cd4615a16917dff817cbe7c5956da376c961e024c241cd962d2c3d/types_docutils-0.22.3.20260322-py3-none-any.whl", hash = "sha256:681d4510ce9b80a0c6a593f0f9843d81f8caa786db7b39ba04d9fd5480ac4442", size = 91978, upload-time = "2026-03-22T04:08:43.117Z" }, ] [[package]] @@ -6874,15 +6838,15 @@ wheels = [ [[package]] name = "types-gevent" -version = "25.9.0.20251228" +version = "25.9.0.20260322" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "types-greenlet" }, { name = "types-psutil" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/06/85/c5043c4472f82c8ee3d9e0673eb4093c7d16770a26541a137a53a1d096f6/types_gevent-25.9.0.20251228.tar.gz", hash = "sha256:423ef9891d25c5a3af236c3e9aace4c444c86ff773fe13ef22731bc61d59abef", size = 38063, upload-time = "2025-12-28T03:28:28.651Z" } +sdist = { url = "https://files.pythonhosted.org/packages/34/f0/14a99ddcaa69b559fa7cec8c9de880b792bebb0b848ae865d94ea9058533/types_gevent-25.9.0.20260322.tar.gz", hash = "sha256:91257920845762f09753c08aa20fad1743ac13d2de8bcf23f4b8fe967d803732", size = 38241, upload-time = "2026-03-22T04:08:55.213Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c8/b7/a2d6b652ab5a26318b68cafd58c46fafb9b15c5313d2d76a70b838febb4b/types_gevent-25.9.0.20251228-py3-none-any.whl", hash = "sha256:e2e225af4fface9241c16044983eb2fc3993f2d13d801f55c2932848649b7f2f", size = 55486, upload-time = "2025-12-28T03:28:27.382Z" }, + { url = "https://files.pythonhosted.org/packages/89/0f/964440b57eb4ddb4aca03479a4093852e1ce79010d1c5967234e6f5d6bd9/types_gevent-25.9.0.20260322-py3-none-any.whl", hash = "sha256:21b3c269b3a20ecb0e4668289c63b97d21694d84a004ab059c1e32ab970eacc2", size = 55500, upload-time = "2026-03-22T04:08:54.103Z" }, ] [[package]] @@ -6915,18 +6879,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/61/91/915c4a6e6e9bd2bca3ec0c21c1771b175c59e204b85e57f3f572370fe753/types_jmespath-1.1.0.20260124-py3-none-any.whl", hash = "sha256:ec387666d446b15624215aa9cbd2867ffd885b6c74246d357c65e830c7a138b3", size = 11509, upload-time = "2026-01-24T03:18:45.536Z" }, ] -[[package]] -name = "types-jsonschema" -version = "4.26.0.20260202" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "referencing" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/a1/07/68f63e715eb327ed2f5292e29e8be99785db0f72c7664d2c63bd4dbdc29d/types_jsonschema-4.26.0.20260202.tar.gz", hash = "sha256:29831baa4308865a9aec547a61797a06fc152b0dac8dddd531e002f32265cb07", size = 16168, upload-time = "2026-02-02T04:11:22.585Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c1/06/962d4f364f779d7389cd31a1bb581907b057f52f0ace2c119a8dd8409db6/types_jsonschema-4.26.0.20260202-py3-none-any.whl", hash = "sha256:41c95343abc4de9264e333a55e95dfb4d401e463856d0164eec9cb182e8746da", size = 15914, upload-time = "2026-02-02T04:11:21.61Z" }, -] - [[package]] name = "types-markdown" version = "3.10.2.20260211" @@ -6938,11 +6890,11 @@ wheels = [ [[package]] name = "types-oauthlib" -version = "3.3.0.20250822" +version = "3.3.0.20260324" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/6a/6e/d08033f562053c459322333c46baa8cf8d2d8c18f30d46dd898c8fd8df77/types_oauthlib-3.3.0.20250822.tar.gz", hash = "sha256:2cd41587dd80c199e4230e3f086777e9ae525e89579c64afe5e0039ab09be9de", size = 25700, upload-time = "2025-08-22T03:02:41.378Z" } +sdist = { url = "https://files.pythonhosted.org/packages/91/38/543938f86d81bd6a78b8c355fe81bb8da0a26e4c28addfe3443e38a683d2/types_oauthlib-3.3.0.20260324.tar.gz", hash = "sha256:3c4cc07fa33886f881682237c1e445c5f1778b44efea118f4c1e4ede82cb52f2", size = 26030, upload-time = "2026-03-24T04:06:30.898Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/18/4b/00593b8b5d055550e1fcb9af2c42fa11b0a90bf16a94759a77bc1c3c0c72/types_oauthlib-3.3.0.20250822-py3-none-any.whl", hash = "sha256:b7f4c9b9eed0e020f454e0af800b10e93dd2efd196da65744b76910cce7e70d6", size = 48800, upload-time = "2025-08-22T03:02:40.427Z" }, + { url = "https://files.pythonhosted.org/packages/0e/60/26f0ddade4b2bb17b3d8f3ebaac436e5487caec28831da3d7ea309fe93b9/types_oauthlib-3.3.0.20260324-py3-none-any.whl", hash = "sha256:d24662033b04f4d50a2f1fed04c1b43ff2554aa037c1dafa0424f87100a46ccd", size = 48984, upload-time = "2026-03-24T04:06:29.696Z" }, ] [[package]] @@ -6965,11 +6917,11 @@ wheels = [ [[package]] name = "types-openpyxl" -version = "3.1.5.20260316" +version = "3.1.5.20260322" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a1/38/32f8ee633dd66ca6d52b8853b9fd45dc3869490195a6ed435d5c868b9c2d/types_openpyxl-3.1.5.20260316.tar.gz", hash = "sha256:081dda9427ea1141e5649e3dcf630e7013a4cf254a5862a7e0a3f53c123b7ceb", size = 101318, upload-time = "2026-03-16T04:29:05.004Z" } +sdist = { url = "https://files.pythonhosted.org/packages/77/bf/15240de4d68192d2a1f385ef2f6f1ecb29b85d2f3791dd2e2d5b980be30f/types_openpyxl-3.1.5.20260322.tar.gz", hash = "sha256:a61d66ebe1e49697853c6db8e0929e1cda2c96755e71fb676ed7fc48dfdcf697", size = 101325, upload-time = "2026-03-22T04:08:40.426Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d5/df/b87ae6226ed7cc84b9e43119c489c7f053a9a25e209e0ebb5d84bc36fa37/types_openpyxl-3.1.5.20260316-py3-none-any.whl", hash = "sha256:38e7e125df520fb7eb72cb1129c9f024eb99ef9564aad2c27f68f080c26bcf2d", size = 166084, upload-time = "2026-03-16T04:29:03.657Z" }, + { url = "https://files.pythonhosted.org/packages/bf/b4/c14191b30bcb266365b124b2bb4e67ecd68425a78ba77ee026f33667daa9/types_openpyxl-3.1.5.20260322-py3-none-any.whl", hash = "sha256:2f515f0b0bbfb04bfb587de34f7522d90b5151a8da7bbbd11ecec4ca40f64238", size = 166102, upload-time = "2026-03-22T04:08:39.174Z" }, ] [[package]] @@ -7044,11 +6996,11 @@ wheels = [ [[package]] name = "types-python-dateutil" -version = "2.9.0.20260305" +version = "2.9.0.20260323" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/1d/c7/025c624f347e10476b439a6619a95f1d200250ea88e7ccea6e09e48a7544/types_python_dateutil-2.9.0.20260305.tar.gz", hash = "sha256:389717c9f64d8f769f36d55a01873915b37e97e52ce21928198d210fbd393c8b", size = 16885, upload-time = "2026-03-05T04:00:47.409Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e9/02/f72df9ef5ffc4f959b83cb80c8aa03eb8718a43e563ecd99ccffe265fa89/types_python_dateutil-2.9.0.20260323.tar.gz", hash = "sha256:a107aef5841db41ace381dbbbd7e4945220fc940f7a72172a0be5a92d9ab7164", size = 16897, upload-time = "2026-03-23T04:15:14.829Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/0a/77/8c0d1ec97f0d9707ad3d8fa270ab8964e7b31b076d2f641c94987395cc75/types_python_dateutil-2.9.0.20260305-py3-none-any.whl", hash = "sha256:a3be9ca444d38cadabd756cfbb29780d8b338ae2a3020e73c266a83cc3025dd7", size = 18419, upload-time = "2026-03-05T04:00:46.392Z" }, + { url = "https://files.pythonhosted.org/packages/92/c1/b661838b97453e699a215451f2e22cee750eaaf4ea4619b34bdaf01221a4/types_python_dateutil-2.9.0.20260323-py3-none-any.whl", hash = "sha256:a23a50a07f6eb87e729d4cb0c2eb511c81761eeb3f505db2c1413be94aae8335", size = 18433, upload-time = "2026-03-23T04:15:13.683Z" }, ] [[package]] @@ -7062,11 +7014,11 @@ wheels = [ [[package]] name = "types-pywin32" -version = "311.0.0.20260316" +version = "311.0.0.20260323" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/17/a8/b4652002a854fcfe5d272872a0ae2d5df0e9dc482e1a6dfb5e97b905b76f/types_pywin32-311.0.0.20260316.tar.gz", hash = "sha256:c136fa489fe6279a13bca167b750414e18d657169b7cf398025856dc363004e8", size = 329956, upload-time = "2026-03-16T04:28:57.366Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b5/cc/f03ddb7412ac2fc2238358b617c2d5919ba96812dff8d3081f3b2754bb83/types_pywin32-311.0.0.20260323.tar.gz", hash = "sha256:2e8dc6a59fedccbc51b241651ce1e8aa58488934f517debf23a9c6d0ff329b4b", size = 332263, upload-time = "2026-03-23T04:15:20.004Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f0/83/704698d93788cf1c2f5e236eae2b37f1b2152ef84dc66b4b83f6c7487b76/types_pywin32-311.0.0.20260316-py3-none-any.whl", hash = "sha256:abb643d50012386d697af49384cc0e6e475eab76b0ca2a7f93d480d0862b3692", size = 392959, upload-time = "2026-03-16T04:28:56.104Z" }, + { url = "https://files.pythonhosted.org/packages/dc/82/d786d5d8b846e3cbe1ee52da8945560b111c789b42c3771b2129b312ab94/types_pywin32-311.0.0.20260323-py3-none-any.whl", hash = "sha256:2f2b03fc72ae77ccbb0ee258da0f181c3a38bd8602f6e332e42587b3b0d5f095", size = 395435, upload-time = "2026-03-23T04:15:18.76Z" }, ] [[package]] @@ -7093,11 +7045,11 @@ wheels = [ [[package]] name = "types-regex" -version = "2026.2.28.20260301" +version = "2026.3.32.20260329" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/3a/ed/106958cb686316113b748ed4209fa363fd92b15759d5409c3930fed36606/types_regex-2026.2.28.20260301.tar.gz", hash = "sha256:644c231db3f368908320170c14905731a7ae5fabdac0f60f5d6d12ecdd3bc8dd", size = 13157, upload-time = "2026-03-01T04:11:13.559Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f2/d8/a3aca5775c573e56d201bbd76a827b84d851a4bce28e189e5acb9c7a0d15/types_regex-2026.3.32.20260329.tar.gz", hash = "sha256:12653e44694cb3e3ccdc39bab3d433d2a83fec1c01220e6871fd6f3cf434675c", size = 13111, upload-time = "2026-03-29T04:27:04.759Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c7/bb/9bc26fcf5155bd25efeca35f8ba6bffb8b3c9da2baac8bf40067606418f3/types_regex-2026.2.28.20260301-py3-none-any.whl", hash = "sha256:7da7a1fe67528238176a5844fd435ca90617cf605341308686afbc579fdea5c0", size = 11130, upload-time = "2026-03-01T04:11:11.454Z" }, + { url = "https://files.pythonhosted.org/packages/89/f4/a1db307e56753c49fb15fc88d70fadeb3f38897b28cab645cddd18054c79/types_regex-2026.3.32.20260329-py3-none-any.whl", hash = "sha256:861d0893bcfe08a57eb7486a502014e29dc2721d46dd5130798fbccafdb31cc0", size = 11128, upload-time = "2026-03-29T04:27:03.854Z" }, ] [[package]] @@ -7162,16 +7114,16 @@ wheels = [ [[package]] name = "types-tensorflow" -version = "2.18.0.20260224" +version = "2.18.0.20260322" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, { name = "types-protobuf" }, { name = "types-requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/af/cb/4914c2fbc1cf8a8d1ef2a7c727bb6f694879be85edeee880a0c88e696af8/types_tensorflow-2.18.0.20260224.tar.gz", hash = "sha256:9b0ccc91c79c88791e43d3f80d6c879748fa0361409c5ff23c7ffe3709be00f2", size = 258786, upload-time = "2026-02-24T04:06:45.613Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/cb/81dfaa2680031a6e087bcdfaf1c0556371098e229aee541e21c81a381065/types_tensorflow-2.18.0.20260322.tar.gz", hash = "sha256:135dc6ca06cc647a002e1bca5c5c99516fde51efd08e46c48a9b1916fc5df07f", size = 259030, upload-time = "2026-03-22T04:09:14.069Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d4/1d/a1c3c60f0eb1a204500dbdc66e3d18aafabc86ad07a8eca71ea05bc8c5a8/types_tensorflow-2.18.0.20260224-py3-none-any.whl", hash = "sha256:6a25f5f41f3e06f28c1f65c6e09f484d4ba0031d6d8df83a39df9d890245eefc", size = 329746, upload-time = "2026-02-24T04:06:44.4Z" }, + { url = "https://files.pythonhosted.org/packages/5b/0c/a178061450b640e53577e2c423ad22bf5d3f692f6bfeeb12156d02b531ef/types_tensorflow-2.18.0.20260322-py3-none-any.whl", hash = "sha256:d8776b6daacdb279e64f105f9dcbc0b8e3544b9a2f2eb71ec6ea5955081f65e6", size = 329771, upload-time = "2026-03-22T04:09:12.844Z" }, ] [[package]] @@ -7757,7 +7709,7 @@ wheels = [ [[package]] name = "xinference-client" -version = "2.3.1" +version = "2.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, @@ -7765,9 +7717,9 @@ dependencies = [ { name = "requests" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/bc/7a/33aeef9cffdc331de0046c25412622c5a16226d1b4e0cca9ed512ad00b9a/xinference_client-2.3.1.tar.gz", hash = "sha256:23ae225f47ff9adf4c6f7718c54993d1be8c704d727509f6e5cb670de3e02c4d", size = 58414, upload-time = "2026-03-15T05:53:23.994Z" } +sdist = { url = "https://files.pythonhosted.org/packages/0e/f2/7640528fd4f816df19afe91d52332a658ad2d2bacb13471b0a27dbd0cf46/xinference_client-2.4.0.tar.gz", hash = "sha256:59de6d58f89126c8ff05136818e0756108e534858255d7c4c0673b804fd2d01d", size = 58386, upload-time = "2026-03-29T05:10:58.533Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/74/8d/d9ab0a457718050a279b9bb6515b7245d114118dc5e275f190ef2628dd16/xinference_client-2.3.1-py3-none-any.whl", hash = "sha256:f7c4f0b56635b46be9cfd9b2affa8e15275491597ac9b958e14b13da5745133e", size = 40012, upload-time = "2026-03-15T05:53:22.797Z" }, + { url = "https://files.pythonhosted.org/packages/73/cf/9d27e0095cc28691c73ff186b33556790c7b87f046ca2ecd517c80272592/xinference_client-2.4.0-py3-none-any.whl", hash = "sha256:2f9478b00fe15643f281fe4c0643e74479c8b7837d377000ff120702cda81efc", size = 40012, upload-time = "2026-03-29T05:10:57.279Z" }, ] [[package]] diff --git a/dev/pytest/pytest_unit_tests.sh b/dev/pytest/pytest_unit_tests.sh index a034083304e..1d4ff4d86f5 100755 --- a/dev/pytest/pytest_unit_tests.sh +++ b/dev/pytest/pytest_unit_tests.sh @@ -13,4 +13,4 @@ PYTEST_XDIST_ARGS="${PYTEST_XDIST_ARGS:--n auto}" pytest --timeout "${PYTEST_TIMEOUT}" ${PYTEST_XDIST_ARGS} api/tests/unit_tests --ignore=api/tests/unit_tests/controllers # Run controller tests sequentially to avoid import race conditions -pytest --timeout "${PYTEST_TIMEOUT}" api/tests/unit_tests/controllers +pytest --timeout "${PYTEST_TIMEOUT}" --cov-append api/tests/unit_tests/controllers diff --git a/dev/setup b/dev/setup index 399c8f28a58..4236ff7fa74 100755 --- a/dev/setup +++ b/dev/setup @@ -24,5 +24,4 @@ cp "$MIDDLEWARE_ENV_EXAMPLE" "$MIDDLEWARE_ENV" cd "$ROOT/api" uv sync --group dev -cd "$ROOT/web" -pnpm install +pnpm --dir "$ROOT" install diff --git a/dev/start-web b/dev/start-web index f853f4a895b..baf008274b6 100755 --- a/dev/start-web +++ b/dev/start-web @@ -3,6 +3,6 @@ set -x SCRIPT_DIR="$(dirname "$(realpath "$0")")" -cd "$SCRIPT_DIR/../web" +ROOT_DIR="$(dirname "$SCRIPT_DIR")" -pnpm install && pnpm dev:inspect +pnpm --dir "$ROOT_DIR" install && pnpm --dir "$ROOT_DIR/web" dev:inspect diff --git a/docker/.env.example b/docker/.env.example index 9d6cd653187..9fbf9a9e72a 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -488,7 +488,8 @@ ALIYUN_OSS_REGION=ap-southeast-1 ALIYUN_OSS_AUTH_VERSION=v4 # Don't start with '/'. OSS doesn't support leading slash in object names. ALIYUN_OSS_PATH=your-path -ALIYUN_CLOUDBOX_ID=your-cloudbox-id +# Optional CloudBox ID for Aliyun OSS, DO NOT enable it if you are not using CloudBox. +#ALIYUN_CLOUDBOX_ID=your-cloudbox-id # Tencent COS Configuration # @@ -771,6 +772,9 @@ BAIDU_VECTOR_DB_SHARD=1 BAIDU_VECTOR_DB_REPLICAS=3 BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER=DEFAULT_ANALYZER BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE=COARSE_MODE +BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT=500 +BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO=0.05 +BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS=300 # VikingDB configurations, only available when VECTOR_STORE is `vikingdb` VIKINGDB_ACCESS_KEY=your-ak diff --git a/docker/dify-env-sync.py b/docker/dify-env-sync.py new file mode 100755 index 00000000000..d7c762748c5 --- /dev/null +++ b/docker/dify-env-sync.py @@ -0,0 +1,440 @@ +#!/usr/bin/env python3 + +# ================================================================ +# Dify Environment Variables Synchronization Script +# +# Features: +# - Synchronize latest settings from .env.example to .env +# - Preserve custom settings in existing .env +# - Add new environment variables +# - Detect removed environment variables +# - Create backup files +# ================================================================ + +import argparse +import re +import shutil +import sys +from datetime import datetime +from pathlib import Path + +# ANSI color codes +RED = "\033[0;31m" +GREEN = "\033[0;32m" +YELLOW = "\033[1;33m" +BLUE = "\033[0;34m" +NC = "\033[0m" # No Color + + +def supports_color() -> bool: + """Return True if the terminal supports ANSI color codes.""" + return hasattr(sys.stdout, "isatty") and sys.stdout.isatty() + + +def log_info(message: str) -> None: + """Print an informational message in blue.""" + if supports_color(): + print(f"{BLUE}[INFO]{NC} {message}") + else: + print(f"[INFO] {message}") + + +def log_success(message: str) -> None: + """Print a success message in green.""" + if supports_color(): + print(f"{GREEN}[SUCCESS]{NC} {message}") + else: + print(f"[SUCCESS] {message}") + + +def log_warning(message: str) -> None: + """Print a warning message in yellow to stderr.""" + if supports_color(): + print(f"{YELLOW}[WARNING]{NC} {message}", file=sys.stderr) + else: + print(f"[WARNING] {message}", file=sys.stderr) + + +def log_error(message: str) -> None: + """Print an error message in red to stderr.""" + if supports_color(): + print(f"{RED}[ERROR]{NC} {message}", file=sys.stderr) + else: + print(f"[ERROR] {message}", file=sys.stderr) + + +def parse_env_file(path: Path) -> dict[str, str]: + """Parse an .env-style file and return a mapping of key to raw value. + + Lines that are blank or start with '#' (after optional whitespace) are + skipped. Only lines containing '=' are considered variable definitions. + + Args: + path: Path to the .env file to parse. + + Returns: + Ordered dict mapping variable name to its value string. + """ + variables: dict[str, str] = {} + with path.open(encoding="utf-8") as fh: + for line in fh: + line = line.rstrip("\n") + # Skip blank lines and comment lines + stripped = line.strip() + if not stripped or stripped.startswith("#"): + continue + if "=" not in line: + continue + key, _, value = line.partition("=") + key = key.strip() + if key: + variables[key] = value.strip() + return variables + + +def check_files(work_dir: Path) -> None: + """Verify required files exist; create .env from .env.example if absent. + + Args: + work_dir: Directory that must contain .env.example (and optionally .env). + + Raises: + SystemExit: If .env.example does not exist. + """ + log_info("Checking required files...") + + example_file = work_dir / ".env.example" + env_file = work_dir / ".env" + + if not example_file.exists(): + log_error(".env.example file not found") + sys.exit(1) + + if not env_file.exists(): + log_warning(".env file does not exist. Creating from .env.example.") + shutil.copy2(example_file, env_file) + log_success(".env file created") + + log_success("Required files verified") + + +def create_backup(work_dir: Path) -> None: + """Create a timestamped backup of the current .env file. + + Backups are placed in ``/env-backup/`` with the filename + ``.env.backup_``. + + Args: + work_dir: Directory containing the .env file to back up. + """ + env_file = work_dir / ".env" + if not env_file.exists(): + return + + backup_dir = work_dir / "env-backup" + if not backup_dir.exists(): + backup_dir.mkdir(parents=True) + log_info(f"Created backup directory: {backup_dir}") + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + backup_file = backup_dir / f".env.backup_{timestamp}" + shutil.copy2(env_file, backup_file) + log_success(f"Backed up existing .env to {backup_file}") + + +def analyze_value_change(current: str, recommended: str) -> str | None: + """Analyse what kind of change occurred between two env values. + + Args: + current: Value currently set in .env. + recommended: Value present in .env.example. + + Returns: + A human-readable description string, or None when no analysis applies. + """ + use_colors = supports_color() + + def colorize(color: str, text: str) -> str: + return f"{color}{text}{NC}" if use_colors else text + + if not current and recommended: + return colorize(RED, " -> Setting from empty to recommended value") + if current and not recommended: + return colorize(RED, " -> Recommended value changed to empty") + + # Numeric comparison + if re.fullmatch(r"\d+", current) and re.fullmatch(r"\d+", recommended): + cur_int, rec_int = int(current), int(recommended) + if cur_int < rec_int: + return colorize(BLUE, f" -> Numeric increase ({current} < {recommended})") + if cur_int > rec_int: + return colorize(YELLOW, f" -> Numeric decrease ({current} > {recommended})") + return None + + # Boolean comparison + if current.lower() in {"true", "false"} and recommended.lower() in {"true", "false"}: + if current.lower() != recommended.lower(): + return colorize(BLUE, f" -> Boolean value change ({current} -> {recommended})") + return None + + # URL / endpoint + if current.startswith(("http://", "https://")) or recommended.startswith(("http://", "https://")): + return colorize(BLUE, " -> URL/endpoint change") + + # File path + if current.startswith("/") or recommended.startswith("/"): + return colorize(BLUE, " -> File path change") + + # String length + if len(current) != len(recommended): + return colorize(YELLOW, f" -> String length change ({len(current)} -> {len(recommended)} characters)") + + return None + + +def detect_differences(env_vars: dict[str, str], example_vars: dict[str, str]) -> dict[str, tuple[str, str]]: + """Find variables whose values differ between .env and .env.example. + + Only variables present in *both* files are compared; new or removed + variables are handled by separate functions. + + Args: + env_vars: Parsed key/value pairs from .env. + example_vars: Parsed key/value pairs from .env.example. + + Returns: + Mapping of key -> (env_value, example_value) for every key whose + values differ. + """ + log_info("Detecting differences between .env and .env.example...") + + diffs: dict[str, tuple[str, str]] = {} + for key, example_value in example_vars.items(): + if key in env_vars and env_vars[key] != example_value: + diffs[key] = (env_vars[key], example_value) + + if diffs: + log_success(f"Detected differences in {len(diffs)} environment variables") + show_differences_detail(diffs) + else: + log_info("No differences detected") + + return diffs + + +def show_differences_detail(diffs: dict[str, tuple[str, str]]) -> None: + """Print a formatted table of differing environment variables. + + Args: + diffs: Mapping of key -> (current_value, recommended_value). + """ + use_colors = supports_color() + + log_info("") + log_info("=== Environment Variable Differences ===") + + if not diffs: + log_info("No differences to display") + return + + for count, (key, (env_value, example_value)) in enumerate(diffs.items(), start=1): + print() + if use_colors: + print(f"{YELLOW}[{count}] {key}{NC}") + print(f" {GREEN}.env (current){NC} : {env_value}") + print(f" {BLUE}.env.example (recommended){NC} : {example_value}") + else: + print(f"[{count}] {key}") + print(f" .env (current) : {env_value}") + print(f" .env.example (recommended) : {example_value}") + + analysis = analyze_value_change(env_value, example_value) + if analysis: + print(analysis) + + print() + log_info("=== Difference Analysis Complete ===") + log_info("Note: Consider changing to the recommended values above.") + log_info("Current implementation preserves .env values.") + print() + + +def detect_removed_variables(env_vars: dict[str, str], example_vars: dict[str, str]) -> list[str]: + """Identify variables present in .env but absent from .env.example. + + Args: + env_vars: Parsed key/value pairs from .env. + example_vars: Parsed key/value pairs from .env.example. + + Returns: + Sorted list of variable names that no longer appear in .env.example. + """ + log_info("Detecting removed environment variables...") + + removed = sorted(set(env_vars) - set(example_vars)) + + if removed: + log_warning("The following environment variables have been removed from .env.example:") + for var in removed: + log_warning(f" - {var}") + log_warning("Consider manually removing these variables from .env") + else: + log_success("No removed environment variables found") + + return removed + + +def sync_env_file(work_dir: Path, env_vars: dict[str, str], diffs: dict[str, tuple[str, str]]) -> None: + """Rewrite .env based on .env.example while preserving custom values. + + The output file follows the exact line structure of .env.example + (preserving comments, blank lines, and ordering). For every variable + that exists in .env with a different value from the example, the + current .env value is kept. Variables that are new in .env.example + (not present in .env at all) are added with the example's default. + + Args: + work_dir: Directory containing .env and .env.example. + env_vars: Parsed key/value pairs from the original .env. + diffs: Keys whose .env values differ from .env.example (to preserve). + """ + log_info("Starting partial synchronization of .env file...") + + example_file = work_dir / ".env.example" + new_env_file = work_dir / ".env.new" + + # Keys whose current .env value should override the example default + preserved_keys: set[str] = set(diffs.keys()) + + preserved_count = 0 + updated_count = 0 + + env_var_pattern = re.compile(r"^([A-Za-z_][A-Za-z0-9_]*)\s*=") + + with example_file.open(encoding="utf-8") as src, new_env_file.open("w", encoding="utf-8") as dst: + for line in src: + raw_line = line.rstrip("\n") + match = env_var_pattern.match(raw_line) + if match: + key = match.group(1) + if key in preserved_keys: + # Write the preserved value from .env + dst.write(f"{key}={env_vars[key]}\n") + log_info(f" Preserved: {key} (.env value)") + preserved_count += 1 + else: + # Use the example value (covers new vars and unchanged ones) + dst.write(line if line.endswith("\n") else raw_line + "\n") + updated_count += 1 + else: + # Blank line, comment, or non-variable line — keep as-is + dst.write(line if line.endswith("\n") else raw_line + "\n") + + # Atomically replace the original .env + try: + new_env_file.replace(work_dir / ".env") + except OSError as exc: + log_error(f"Failed to replace .env file: {exc}") + new_env_file.unlink(missing_ok=True) + sys.exit(1) + + log_success("Successfully created new .env file") + log_success("Partial synchronization of .env file completed") + log_info(f" Preserved .env values: {preserved_count}") + log_info(f" Updated to .env.example values: {updated_count}") + + +def show_statistics(work_dir: Path) -> None: + """Print a summary of variable counts from both env files. + + Args: + work_dir: Directory containing .env and .env.example. + """ + log_info("Synchronization statistics:") + + example_file = work_dir / ".env.example" + env_file = work_dir / ".env" + + example_count = len(parse_env_file(example_file)) if example_file.exists() else 0 + env_count = len(parse_env_file(env_file)) if env_file.exists() else 0 + + log_info(f" .env.example environment variables: {example_count}") + log_info(f" .env environment variables: {env_count}") + + +def build_arg_parser() -> argparse.ArgumentParser: + """Build and return the CLI argument parser. + + Returns: + Configured ArgumentParser instance. + """ + parser = argparse.ArgumentParser( + prog="dify-env-sync", + description=( + "Synchronize .env with .env.example: add new variables, " + "preserve custom values, and report removed variables." + ), + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=( + "Examples:\n" + " # Run from the docker/ directory (default)\n" + " python dify-env-sync.py\n\n" + " # Specify a custom working directory\n" + " python dify-env-sync.py --dir /path/to/docker\n" + ), + ) + parser.add_argument( + "--dir", + metavar="DIRECTORY", + default=".", + help="Working directory containing .env and .env.example (default: current directory)", + ) + parser.add_argument( + "--no-backup", + action="store_true", + default=False, + help="Skip creating a timestamped backup of the existing .env file", + ) + return parser + + +def main() -> None: + """Orchestrate the complete environment variable synchronization process.""" + parser = build_arg_parser() + args = parser.parse_args() + + work_dir = Path(args.dir).resolve() + + log_info("=== Dify Environment Variables Synchronization Script ===") + log_info(f"Execution started: {datetime.now()}") + log_info(f"Working directory: {work_dir}") + + # 1. Verify prerequisites + check_files(work_dir) + + # 2. Backup existing .env + if not args.no_backup: + create_backup(work_dir) + + # 3. Parse both files + env_vars = parse_env_file(work_dir / ".env") + example_vars = parse_env_file(work_dir / ".env.example") + + # 4. Report differences (values that changed in the example) + diffs = detect_differences(env_vars, example_vars) + + # 5. Report variables removed from the example + detect_removed_variables(env_vars, example_vars) + + # 6. Rewrite .env + sync_env_file(work_dir, env_vars, diffs) + + # 7. Print summary statistics + show_statistics(work_dir) + + log_success("=== Synchronization process completed successfully ===") + log_info(f"Execution finished: {datetime.now()}") + + +if __name__ == "__main__": + main() diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 04bd2858ff4..e55cf942c32 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -21,7 +21,7 @@ services: # API service api: - image: langgenius/dify-api:1.13.2 + image: langgenius/dify-api:1.13.3 restart: always environment: # Use the shared environment variables. @@ -63,7 +63,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.13.2 + image: langgenius/dify-api:1.13.3 restart: always environment: # Use the shared environment variables. @@ -102,7 +102,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.13.2 + image: langgenius/dify-api:1.13.3 restart: always environment: # Use the shared environment variables. @@ -132,7 +132,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.13.2 + image: langgenius/dify-web:1.13.3 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -245,7 +245,7 @@ services: # The DifySandbox sandbox: - image: langgenius/dify-sandbox:0.2.12 + image: langgenius/dify-sandbox:0.2.14 restart: always environment: # The DifySandbox configurations @@ -269,12 +269,13 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.5.4-local + image: langgenius/dify-plugin-daemon:0.5.3-local restart: always environment: # Use the shared environment variables. <<: *shared-api-worker-env DB_DATABASE: ${DB_PLUGIN_DATABASE:-dify_plugin} + DB_SSL_MODE: ${DB_SSL_MODE:-disable} SERVER_PORT: ${PLUGIN_DAEMON_PORT:-5002} SERVER_KEY: ${PLUGIN_DAEMON_KEY:-lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi} MAX_PLUGIN_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index 2dca5819039..911da70a737 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -97,7 +97,7 @@ services: # The DifySandbox sandbox: - image: langgenius/dify-sandbox:0.2.12 + image: langgenius/dify-sandbox:0.2.14 restart: always env_file: - ./middleware.env @@ -123,10 +123,12 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.5.4-local + image: langgenius/dify-plugin-daemon:0.5.3-local restart: always env_file: - ./middleware.env + extra_hosts: + - "host.docker.internal:host-gateway" environment: # Use the shared environment variables. LOG_OUTPUT_FORMAT: ${LOG_OUTPUT_FORMAT:-text} diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index bf72a0f623e..737a62020ca 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -146,7 +146,6 @@ x-shared-env: &shared-api-worker-env ALIYUN_OSS_REGION: ${ALIYUN_OSS_REGION:-ap-southeast-1} ALIYUN_OSS_AUTH_VERSION: ${ALIYUN_OSS_AUTH_VERSION:-v4} ALIYUN_OSS_PATH: ${ALIYUN_OSS_PATH:-your-path} - ALIYUN_CLOUDBOX_ID: ${ALIYUN_CLOUDBOX_ID:-your-cloudbox-id} TENCENT_COS_BUCKET_NAME: ${TENCENT_COS_BUCKET_NAME:-your-bucket-name} TENCENT_COS_SECRET_KEY: ${TENCENT_COS_SECRET_KEY:-your-secret-key} TENCENT_COS_SECRET_ID: ${TENCENT_COS_SECRET_ID:-your-secret-id} @@ -345,6 +344,9 @@ x-shared-env: &shared-api-worker-env BAIDU_VECTOR_DB_REPLICAS: ${BAIDU_VECTOR_DB_REPLICAS:-3} BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER: ${BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER:-DEFAULT_ANALYZER} BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE: ${BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE:-COARSE_MODE} + BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT: ${BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT:-500} + BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO: ${BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO:-0.05} + BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS: ${BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS:-300} VIKINGDB_ACCESS_KEY: ${VIKINGDB_ACCESS_KEY:-your-ak} VIKINGDB_SECRET_KEY: ${VIKINGDB_SECRET_KEY:-your-sk} VIKINGDB_REGION: ${VIKINGDB_REGION:-cn-shanghai} @@ -728,7 +730,7 @@ services: # API service api: - image: langgenius/dify-api:1.13.2 + image: langgenius/dify-api:1.13.3 restart: always environment: # Use the shared environment variables. @@ -770,7 +772,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.13.2 + image: langgenius/dify-api:1.13.3 restart: always environment: # Use the shared environment variables. @@ -809,7 +811,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.13.2 + image: langgenius/dify-api:1.13.3 restart: always environment: # Use the shared environment variables. @@ -839,7 +841,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.13.2 + image: langgenius/dify-web:1.13.3 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -952,7 +954,7 @@ services: # The DifySandbox sandbox: - image: langgenius/dify-sandbox:0.2.12 + image: langgenius/dify-sandbox:0.2.14 restart: always environment: # The DifySandbox configurations @@ -976,12 +978,13 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.5.4-local + image: langgenius/dify-plugin-daemon:0.5.3-local restart: always environment: # Use the shared environment variables. <<: *shared-api-worker-env DB_DATABASE: ${DB_PLUGIN_DATABASE:-dify_plugin} + DB_SSL_MODE: ${DB_SSL_MODE:-disable} SERVER_PORT: ${PLUGIN_DAEMON_PORT:-5002} SERVER_KEY: ${PLUGIN_DAEMON_KEY:-lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi} MAX_PLUGIN_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} diff --git a/docker/ssrf_proxy/squid.conf.template b/docker/ssrf_proxy/squid.conf.template index 256e669c8db..fbe9ebc448b 100644 --- a/docker/ssrf_proxy/squid.conf.template +++ b/docker/ssrf_proxy/squid.conf.template @@ -28,6 +28,7 @@ http_access deny manager http_access allow localhost include /etc/squid/conf.d/*.conf http_access deny all +tcp_outgoing_address 0.0.0.0 ################################## Proxy Server ################################ http_port ${HTTP_PORT} diff --git a/docker/volumes/sandbox/conf/config.yaml b/docker/volumes/sandbox/conf/config.yaml index 8c1a1deb54e..3b4a6b84396 100644 --- a/docker/volumes/sandbox/conf/config.yaml +++ b/docker/volumes/sandbox/conf/config.yaml @@ -5,7 +5,8 @@ app: max_workers: 4 max_requests: 50 worker_timeout: 5 -python_path: /usr/local/bin/python3 +python_path: /opt/python/bin/python3 +nodejs_path: /usr/local/bin/node enable_network: True # please make sure there is no network risk in your environment allowed_syscalls: # please leave it empty if you have no idea how seccomp works proxy: diff --git a/docker/volumes/sandbox/conf/config.yaml.example b/docker/volumes/sandbox/conf/config.yaml.example index f92c19e51a2..365089cb9eb 100644 --- a/docker/volumes/sandbox/conf/config.yaml.example +++ b/docker/volumes/sandbox/conf/config.yaml.example @@ -5,7 +5,7 @@ app: max_workers: 4 max_requests: 50 worker_timeout: 5 -python_path: /usr/local/bin/python3 +python_path: /opt/python/bin/python3 python_lib_path: - /usr/local/lib/python3.10 - /usr/lib/python3.10 diff --git a/docs/ar-SA/README.md b/docs/ar-SA/README.md index 99e3e3567e2..af5a9bfdc66 100644 --- a/docs/ar-SA/README.md +++ b/docs/ar-SA/README.md @@ -53,7 +53,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

diff --git a/docs/bn-BD/README.md b/docs/bn-BD/README.md index f3fa68b4668..5dceacb1875 100644 --- a/docs/bn-BD/README.md +++ b/docs/bn-BD/README.md @@ -57,7 +57,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

ডিফাই একটি ওপেন-সোর্স LLM অ্যাপ ডেভেলপমেন্ট প্ল্যাটফর্ম। এটি ইন্টুইটিভ ইন্টারফেস, এজেন্টিক AI ওয়ার্কফ্লো, RAG পাইপলাইন, এজেন্ট ক্যাপাবিলিটি, মডেল ম্যানেজমেন্ট, মনিটরিং সুবিধা এবং আরও অনেক কিছু একত্রিত করে, যা দ্রুত প্রোটোটাইপ থেকে প্রোডাকশন পর্যন্ত নিয়ে যেতে সহায়তা করে। diff --git a/docs/de-DE/README.md b/docs/de-DE/README.md index c71a0bfccfa..1eab517a6d9 100644 --- a/docs/de-DE/README.md +++ b/docs/de-DE/README.md @@ -57,7 +57,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

Dify ist eine Open-Source-Plattform zur Entwicklung von LLM-Anwendungen. Ihre intuitive Benutzeroberfläche vereint agentenbasierte KI-Workflows, RAG-Pipelines, Agentenfunktionen, Modellverwaltung, Überwachungsfunktionen und mehr, sodass Sie schnell von einem Prototyp in die Produktion übergehen können. diff --git a/docs/es-ES/README.md b/docs/es-ES/README.md index da81b51d6a4..f4c60e3d8f9 100644 --- a/docs/es-ES/README.md +++ b/docs/es-ES/README.md @@ -53,7 +53,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

# diff --git a/docs/eu-ai-act-compliance.md b/docs/eu-ai-act-compliance.md new file mode 100644 index 00000000000..5fa29eed3f3 --- /dev/null +++ b/docs/eu-ai-act-compliance.md @@ -0,0 +1,186 @@ +# EU AI Act Compliance Guide for Dify Deployers + +Dify is an LLMOps platform for building RAG pipelines, agents, and AI workflows. If you deploy Dify in the EU — whether self-hosted or using a cloud provider — the EU AI Act applies to your deployment. This guide covers what the regulation requires and how Dify's architecture maps to those requirements. + +## Is your system in scope? + +The detailed obligations in Articles 12, 13, and 14 only apply to **high-risk AI systems** as defined in Annex III of the EU AI Act. A Dify application is high-risk if it is used for: + +- **Recruitment and HR** — screening candidates, evaluating employee performance, allocating tasks +- **Credit scoring and insurance** — assessing creditworthiness or setting premiums +- **Law enforcement** — profiling, criminal risk assessment, border control +- **Critical infrastructure** — managing energy, water, transport, or telecommunications systems +- **Education assessment** — grading students, determining admissions +- **Essential public services** — evaluating eligibility for benefits, housing, or emergency services + +Most Dify deployments (customer-facing chatbots, internal knowledge bases, content generation workflows) are **not** high-risk. If your Dify application does not fall into one of the categories above: + +- **Article 50** (end-user transparency) still applies if users interact with your application directly. See the [Article 50 section](#article-50-end-user-transparency) below. +- **GDPR** still applies if you process personal data. See the [GDPR section](#gdpr-considerations) below. +- The high-risk obligations (Articles 9-15) are less likely to apply, but risk classification is context-dependent. **Do not self-classify without legal review.** Focus on Article 50 (transparency) and GDPR (data protection) as your baseline obligations. + +If you are unsure whether your use case qualifies as high-risk, consult a qualified legal professional before proceeding. + +## Self-hosted vs cloud: different compliance profiles + +| Deployment | Your role | Dify's role | Who handles compliance? | +|-----------|----------|-------------|------------------------| +| **Self-hosted** | Provider and deployer | Framework provider — obligations under Article 25 apply only if Dify is placed on the market or put into service as part of a complete AI system bearing its name or trademark | You | +| **Dify Cloud** | Deployer | Provider and processor | Shared — Dify handles SOC 2 and GDPR for the platform; you handle AI Act obligations for your specific use case | + +Dify Cloud already has SOC 2 Type II and GDPR compliance for the platform itself. But the EU AI Act adds obligations specific to AI systems that SOC 2 does not cover: risk classification, technical documentation, transparency, and human oversight. + +## Supported providers and services + +Dify integrates with a broad range of AI providers and data stores. The following are the key ones relevant to compliance: + +- **AI providers:** HuggingFace (core), plus integrations with OpenAI, Anthropic, Google, and 100+ models via provider plugins +- **Model identifiers include:** gpt-4o, gpt-3.5-turbo, claude-3-opus, gemini-2.5-flash, whisper-1, and others +- **Vector database connections:** Extensive RAG infrastructure supporting numerous vector stores + +Dify's plugin architecture means actual provider usage depends on your configuration. Document which providers and models are active in your deployment. + +## Data flow diagram + +A typical Dify RAG deployment: + +```mermaid +graph LR + USER((User)) -->|query| DIFY[Dify Platform] + DIFY -->|prompts| LLM([LLM Provider]) + LLM -->|responses| DIFY + DIFY -->|documents| EMBED([Embedding Model]) + EMBED -->|vectors| DIFY + DIFY -->|store/retrieve| VS[(Vector Store)] + DIFY -->|knowledge| KB[(Knowledge Base)] + DIFY -->|response| USER + + classDef processor fill:#60a5fa,stroke:#1e40af,color:#000 + classDef controller fill:#4ade80,stroke:#166534,color:#000 + classDef app fill:#a78bfa,stroke:#5b21b6,color:#000 + classDef user fill:#f472b6,stroke:#be185d,color:#000 + + class USER user + class DIFY app + class LLM processor + class EMBED processor + class VS controller + class KB controller +``` + +**GDPR roles** (providers are typically processors for customer-submitted data, but the exact role depends on each provider's terms of service and processing purpose; deployers should review each provider's DPA): +- **Cloud LLM providers (OpenAI, Anthropic, Google)** typically act as processors — requires DPA. +- **Cloud embedding services** typically act as processors — requires DPA. +- **Self-hosted vector stores (Weaviate, Qdrant, pgvector):** Your organization remains the controller — no third-party transfer. +- **Cloud vector stores (Pinecone, Zilliz Cloud)** typically act as processors — requires DPA. +- **Knowledge base documents:** Your organization is the controller — stored in your infrastructure. + +## Article 11: Technical documentation + +High-risk systems need Annex IV documentation. For Dify deployments, key sections include: + +| Section | What Dify provides | What you must document | +|---------|-------------------|----------------------| +| General description | Platform capabilities, supported models | Your specific use case, intended users, deployment context | +| Development process | Dify's architecture, plugin system | Your RAG pipeline design, prompt engineering, knowledge base curation | +| Monitoring | Dify's built-in logging and analytics | Your monitoring plan, alert thresholds, incident response | +| Performance metrics | Dify's evaluation features | Your accuracy benchmarks, quality thresholds, bias testing | +| Risk management | — | Risk assessment for your specific use case | + +Some sections can be derived from Dify's architecture and your deployment configuration, as shown in the table above. The remaining sections require your input. + +## Article 12: Record-keeping + +Dify's built-in logging covers several Article 12 requirements: + +| Requirement | Dify Feature | Status | +|------------|-------------|--------| +| Conversation logs | Full conversation history with timestamps | **Covered** | +| Model tracking | Model name recorded per interaction | **Covered** | +| Token usage | Token counts per message | **Covered** | +| Cost tracking | Cost per conversation (if provider reports it) | **Partial** | +| Document retrieval | RAG source documents logged | **Covered** | +| User identification | User session tracking | **Covered** | +| Error logging | Failed generation logs | **Covered** | +| Data retention | Configurable | **Your responsibility** | + +**Retention periods:** The required retention period depends on your role under the Act. Article 18 requires **providers** of high-risk systems to retain logs and technical documentation for **10 years** after market placement. Article 26(6) requires **deployers** to retain logs for at least **6 months**. If you self-host Dify and have substantially modified the system, you may be classified as a provider rather than a deployer. Confirm the applicable retention period with legal counsel. + +## Article 13: Transparency to deployers + +Article 13 requires providers of high-risk AI systems to supply deployers with the information needed to understand and operate the system correctly. This is a **documentation obligation**, not a logging obligation. For Dify deployments, this means the upstream LLM and embedding providers must give you: + +- Instructions for use, including intended purpose and known limitations +- Accuracy metrics and performance benchmarks +- Known or foreseeable risks and residual risks after mitigation +- Technical specifications: input/output formats, training data characteristics, model architecture details + +As a deployer, collect model cards, system documentation, and accuracy reports from each AI provider your Dify application uses. Maintain these as part of your Annex IV technical documentation. + +Dify's platform features provide **supporting evidence** that can inform Article 13 documentation, but they do not satisfy Article 13 on their own: +- **Source attribution** — Dify's RAG citation feature shows which documents informed the response, supporting deployer-side auditing +- **Model identification** — Dify logs which LLM model generates responses, providing evidence for system documentation +- **Conversation logs** — execution history helps compile performance and behavior evidence + +You must independently produce system documentation covering how your specific Dify deployment uses AI, its intended purpose, performance characteristics, and residual risks. + +## Article 50: End-user transparency + +Article 50 requires deployers to inform end users that they are interacting with an AI system. This is a separate obligation from Article 13 and applies even to limited-risk systems. + +For Dify applications serving end users: + +1. **Disclose AI involvement** — tell users they are interacting with an AI system +2. **AI-generated content labeling** — identify AI-generated content as such (e.g., clear labeling in the UI) + +Dify's "citation" feature also supports end-user transparency by showing users which knowledge base documents informed the answer. + +> **Note:** Article 50 applies to chatbots and systems interacting directly with natural persons. It has a separate scope from the high-risk designation under Annex III — it applies even to limited-risk systems. + +## Article 14: Human oversight + +Article 14 requires that high-risk AI systems be designed so that natural persons can effectively oversee them. Dify provides **automated technical safeguards** that support human oversight, but they are not a substitute for it: + +| Dify Feature | What It Does | Oversight Role | +|-------------|-------------|----------------| +| Annotation/feedback system | Human review of AI outputs | **Direct oversight** — humans evaluate and correct AI responses | +| Content moderation | Built-in filtering before responses reach users | **Automated safeguard** — reduces harmful outputs but does not replace human judgment on edge cases | +| Rate limiting | Controls on API usage | **Automated safeguard** — bounds system behavior, supports overseer's ability to maintain control | +| Workflow control | Insert human review steps between AI generation and output | **Oversight enabler** — allows building approval gates into the pipeline | + +These automated controls are necessary building blocks, but Article 14 compliance requires **human oversight procedures** on top of them: +- **Escalation procedures** — define what happens when moderation triggers or edge cases arise (who is notified, what action is taken) +- **Human review pipeline** — for high-stakes decisions, route AI outputs to a qualified person before they take effect +- **Override mechanism** — a human must be able to halt AI responses or override the system's output +- **Competence requirements** — the human overseer must understand the system's capabilities, limitations, and the context of its outputs + +### Recommended pattern + +For high-risk use cases (HR, legal, medical), configure your Dify workflow to require human approval before the AI response is delivered to the end user or acted upon. + +## Knowledge base compliance + +Dify's knowledge base feature has specific compliance implications: + +1. **Data provenance:** Document where your knowledge base documents come from. Article 10 requires data governance for training data; knowledge bases are analogous. +2. **Update tracking:** When you add, remove, or update documents in the knowledge base, log the change. The AI system's behavior changes with its knowledge base. +3. **PII in documents:** If knowledge base documents contain personal data, GDPR applies to the entire RAG pipeline. Implement access controls and consider PII redaction before indexing. +4. **Copyright:** Ensure you have the right to use the documents in your knowledge base for AI-assisted generation. + +## GDPR considerations + +1. **Legal basis** (Article 6): Document why AI processing of user queries is necessary +2. **Data Processing Agreements** (Article 28): Required for each cloud LLM and embedding provider +3. **Data minimization:** Only include necessary context in prompts; avoid sending entire documents when a relevant excerpt suffices +4. **Right to erasure:** If a user requests deletion, ensure their conversations are removed from Dify's logs AND any vector store entries derived from their data +5. **Cross-border transfers:** Providers based outside the EEA — including US-based providers (OpenAI, Anthropic), and any other non-EEA providers you route to — require Standard Contractual Clauses (SCCs) or equivalent safeguards under Chapter V of the GDPR. Review each provider's transfer mechanism individually. + +## Resources + +- [EU AI Act full text](https://artificialintelligenceact.eu/) +- [Dify documentation](https://docs.dify.ai/) +- [Dify SOC 2 compliance](https://dify.ai/trust) + +--- + +*This is not legal advice. Consult a qualified professional for compliance decisions.* diff --git a/docs/fr-FR/README.md b/docs/fr-FR/README.md index 291c8dab40c..db8730b36bd 100644 --- a/docs/fr-FR/README.md +++ b/docs/fr-FR/README.md @@ -53,7 +53,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

# diff --git a/docs/hi-IN/README.md b/docs/hi-IN/README.md index bedeaa6246d..ad712046b54 100644 --- a/docs/hi-IN/README.md +++ b/docs/hi-IN/README.md @@ -58,6 +58,8 @@ README Tiếng Việt README in Deutsch README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা README in हिन्दी

diff --git a/docs/it-IT/README.md b/docs/it-IT/README.md index 2e96335d3e0..bca560b574c 100644 --- a/docs/it-IT/README.md +++ b/docs/it-IT/README.md @@ -58,7 +58,10 @@ README Tiếng Việt README in Deutsch README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

Dify è una piattaforma open-source per lo sviluppo di applicazioni LLM. La sua interfaccia intuitiva combina flussi di lavoro AI basati su agenti, pipeline RAG, funzionalità di agenti, gestione dei modelli, funzionalità di monitoraggio e altro ancora, permettendovi di passare rapidamente da un prototipo alla produzione. diff --git a/docs/ja-JP/README.md b/docs/ja-JP/README.md index 659ffbda515..298dcb95aa6 100644 --- a/docs/ja-JP/README.md +++ b/docs/ja-JP/README.md @@ -53,7 +53,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

# diff --git a/docs/ko-KR/README.md b/docs/ko-KR/README.md index 2f6c526ef27..2dcacaae8b5 100644 --- a/docs/ko-KR/README.md +++ b/docs/ko-KR/README.md @@ -53,7 +53,11 @@ Türkçe README README Tiếng Việt README in Deutsch - README in বাংলা + README in Italiano + README em Português do Brasil + README Slovenščina + README in বাংলা + README in हिन्दी

Dify는 오픈 소스 LLM 앱 개발 플랫폼입니다. 직관적인 인터페이스를 통해 AI 워크플로우, RAG 파이프라인, 에이전트 기능, 모델 관리, 관찰 기능 등을 결합하여 프로토타입에서 프로덕션까지 빠르게 전환할 수 있습니다. 주요 기능 목록은 다음과 같습니다:

diff --git a/docs/pt-BR/README.md b/docs/pt-BR/README.md index ed29ec02945..f7cc37a20f4 100644 --- a/docs/pt-BR/README.md +++ b/docs/pt-BR/README.md @@ -58,7 +58,10 @@ README em Vietnamita README em Português - BR README in Deutsch + README in Italiano + README Slovenščina README in বাংলা + README in हिन्दी

Dify é uma plataforma de desenvolvimento de aplicativos LLM de código aberto. Sua interface intuitiva combina workflow de IA, pipeline RAG, capacidades de agente, gerenciamento de modelos, recursos de observabilidade e muito mais, permitindo que você vá rapidamente do protótipo à produção. Aqui está uma lista das principais funcionalidades: diff --git a/docs/sl-SI/README.md b/docs/sl-SI/README.md index caef2c303c1..7b3fe76b5d8 100644 --- a/docs/sl-SI/README.md +++ b/docs/sl-SI/README.md @@ -53,9 +53,12 @@ README بالعربية Türkçe README README Tiếng Việt - README Slovenščina README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

Dify je odprtokodna platforma za razvoj aplikacij LLM. Njegov intuitivni vmesnik združuje agentski potek dela z umetno inteligenco, cevovod RAG, zmogljivosti agentov, upravljanje modelov, funkcije opazovanja in več, kar vam omogoča hiter prehod od prototipa do proizvodnje. diff --git a/docs/tlh/README.md b/docs/tlh/README.md index e2acd7734cb..a8e63026c80 100644 --- a/docs/tlh/README.md +++ b/docs/tlh/README.md @@ -53,7 +53,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

# diff --git a/docs/tr-TR/README.md b/docs/tr-TR/README.md index 6361ca5dd93..cecc1b189c3 100644 --- a/docs/tr-TR/README.md +++ b/docs/tr-TR/README.md @@ -53,7 +53,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

Dify, açık kaynaklı bir LLM uygulama geliştirme platformudur. Sezgisel arayüzü, AI iş akışı, RAG pipeline'ı, ajan yetenekleri, model yönetimi, gözlemlenebilirlik özellikleri ve daha fazlasını birleştirerek, prototipten üretime hızlıca geçmenizi sağlar. İşte temel özelliklerin bir listesi: diff --git a/docs/vi-VN/README.md b/docs/vi-VN/README.md index 3042a98d95a..5230d531101 100644 --- a/docs/vi-VN/README.md +++ b/docs/vi-VN/README.md @@ -53,7 +53,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

Dify là một nền tảng phát triển ứng dụng LLM mã nguồn mở. Giao diện trực quan kết hợp quy trình làm việc AI, mô hình RAG, khả năng tác nhân, quản lý mô hình, tính năng quan sát và hơn thế nữa, cho phép bạn nhanh chóng chuyển từ nguyên mẫu sang sản phẩm. Đây là danh sách các tính năng cốt lõi: diff --git a/docs/zh-CN/README.md b/docs/zh-CN/README.md index 15bb447ad8a..8ba8009959e 100644 --- a/docs/zh-CN/README.md +++ b/docs/zh-CN/README.md @@ -53,7 +53,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी
# diff --git a/docs/zh-TW/README.md b/docs/zh-TW/README.md index 14b343ba29d..de5bab8679c 100644 --- a/docs/zh-TW/README.md +++ b/docs/zh-TW/README.md @@ -57,6 +57,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina + README in বাংলা + README in हिन्दी

Dify 是一個開源的 LLM 應用程式開發平台。其直觀的界面結合了智能代理工作流程、RAG 管道、代理功能、模型管理、可觀察性功能等,讓您能夠快速從原型進展到生產環境。 diff --git a/e2e/.gitignore b/e2e/.gitignore new file mode 100644 index 00000000000..96c1e0f3a18 --- /dev/null +++ b/e2e/.gitignore @@ -0,0 +1,6 @@ +node_modules/ +.auth/ +playwright-report/ +test-results/ +cucumber-report/ +.logs/ diff --git a/e2e/AGENTS.md b/e2e/AGENTS.md new file mode 100644 index 00000000000..ae642768f5f --- /dev/null +++ b/e2e/AGENTS.md @@ -0,0 +1,167 @@ +# E2E + +This package contains the repository-level end-to-end tests for Dify. + +This file is the canonical package guide for `e2e/`. Keep detailed workflow, architecture, debugging, and reporting documentation here. Keep `README.md` as a minimal pointer to this file so the two documents do not drift. + +The suite uses Cucumber for scenario definitions and Playwright as the browser execution layer. + +It tests: + +- backend API started from source +- frontend served from the production artifact +- middleware services started from Docker + +## Prerequisites + +- Node.js `^22.22.1` +- `pnpm` +- `uv` +- Docker + +Run the following commands from the repository root. + +Install Playwright browsers once: + +```bash +pnpm install +pnpm -C e2e e2e:install +pnpm -C e2e check +``` + +`pnpm install` is resolved through the repository workspace and uses the shared root lockfile plus `pnpm-workspace.yaml`. + +Use `pnpm check` as the default local verification step after editing E2E TypeScript, Cucumber support code, or feature glue. It runs formatting, linting, and type checks for this package. + +Common commands: + +```bash +# authenticated-only regression (default excludes @fresh) +# expects backend API, frontend artifact, and middleware stack to already be running +pnpm -C e2e e2e + +# full reset + fresh install + authenticated scenarios +# starts required middleware/dependencies for you +pnpm -C e2e e2e:full + +# run a tagged subset +pnpm -C e2e e2e -- --tags @smoke + +# headed browser +pnpm -C e2e e2e:headed -- --tags @smoke + +# slow down browser actions for local debugging +E2E_SLOW_MO=500 pnpm -C e2e e2e:headed -- --tags @smoke +``` + +Frontend artifact behavior: + +- if `web/.next/BUILD_ID` exists, E2E reuses the existing build by default +- if you set `E2E_FORCE_WEB_BUILD=1`, E2E rebuilds the frontend before starting it + +## Lifecycle + +```mermaid +flowchart TD + A["Start E2E run"] --> B["run-cucumber.ts orchestrates setup/API/frontend"] + B --> C["support/web-server.ts starts or reuses frontend directly"] + C --> D["Cucumber loads config, steps, and support modules"] + D --> E["BeforeAll bootstraps shared auth state via /install"] + E --> F{"Which command is running?"} + F -->|`pnpm e2e`| G["Run config default tags: not @fresh and not @skip"] + F -->|`pnpm e2e:full*`| H["Override tags to not @skip"] + G --> I["Per-scenario BrowserContext from shared browser"] + H --> I + I --> J["Failure artifacts written to cucumber-report/artifacts"] +``` + +Ownership is split like this: + +- `scripts/setup.ts` is the single environment entrypoint for reset, middleware, backend, and frontend startup +- `run-cucumber.ts` orchestrates the E2E run and Cucumber invocation +- `support/web-server.ts` manages frontend reuse, startup, readiness, and shutdown +- `features/support/hooks.ts` manages auth bootstrap, scenario lifecycle, and diagnostics +- `features/support/world.ts` owns per-scenario typed context +- `features/step-definitions/` holds domain-oriented glue so the official VS Code Cucumber plugin works with default conventions when `e2e/` is opened as the workspace root + +Package layout: + +- `features/`: Gherkin scenarios grouped by capability +- `features/step-definitions/`: domain-oriented step definitions +- `features/support/hooks.ts`: suite lifecycle, auth-state bootstrap, diagnostics +- `features/support/world.ts`: shared scenario context +- `support/web-server.ts`: typed frontend startup/reuse logic +- `scripts/setup.ts`: reset and service lifecycle commands +- `scripts/run-cucumber.ts`: Cucumber orchestration entrypoint + +Behavior depends on instance state: + +- uninitialized instance: completes install and stores authenticated state +- initialized instance: signs in and reuses authenticated state + +Because of that, the `@fresh` install scenario only runs in the `pnpm e2e:full*` flows. The default `pnpm e2e*` flows exclude `@fresh` via Cucumber config tags so they can be re-run against an already initialized instance. + +Reset all persisted E2E state: + +```bash +pnpm -C e2e e2e:reset +``` + +This removes: + +- `docker/volumes/db/data` +- `docker/volumes/redis/data` +- `docker/volumes/weaviate` +- `docker/volumes/plugin_daemon` +- `e2e/.auth` +- `e2e/.logs` +- `e2e/cucumber-report` + +Start the full middleware stack: + +```bash +pnpm -C e2e e2e:middleware:up +``` + +Stop the full middleware stack: + +```bash +pnpm e2e:middleware:down +``` + +The middleware stack includes: + +- PostgreSQL +- Redis +- Weaviate +- Sandbox +- SSRF proxy +- Plugin daemon + +Fresh install verification: + +```bash +pnpm e2e:full +``` + +Run the Cucumber suite against an already running middleware stack: + +```bash +pnpm e2e:middleware:up +pnpm e2e +pnpm e2e:middleware:down +``` + +Artifacts and diagnostics: + +- `cucumber-report/report.html`: HTML report +- `cucumber-report/report.json`: JSON report +- `cucumber-report/artifacts/`: failure screenshots and HTML captures +- `.logs/cucumber-api.log`: backend startup log +- `.logs/cucumber-web.log`: frontend startup log + +Open the HTML report locally with: + +```bash +open cucumber-report/report.html +``` diff --git a/e2e/README.md b/e2e/README.md new file mode 100644 index 00000000000..9b4046eaff7 --- /dev/null +++ b/e2e/README.md @@ -0,0 +1,3 @@ +# E2E + +Canonical documentation for this package lives in [AGENTS.md](./AGENTS.md). diff --git a/e2e/cucumber.config.ts b/e2e/cucumber.config.ts new file mode 100644 index 00000000000..c162a6562e9 --- /dev/null +++ b/e2e/cucumber.config.ts @@ -0,0 +1,19 @@ +import type { IConfiguration } from '@cucumber/cucumber' + +const config = { + format: [ + 'progress-bar', + 'summary', + 'html:./cucumber-report/report.html', + 'json:./cucumber-report/report.json', + ], + import: ['features/**/*.ts'], + parallel: 1, + paths: ['features/**/*.feature'], + tags: process.env.E2E_CUCUMBER_TAGS || 'not @fresh and not @skip', + timeout: 60_000, +} satisfies Partial & { + timeout: number +} + +export default config diff --git a/e2e/features/apps/create-app.feature b/e2e/features/apps/create-app.feature new file mode 100644 index 00000000000..c0ca8ea4e00 --- /dev/null +++ b/e2e/features/apps/create-app.feature @@ -0,0 +1,10 @@ +@apps @authenticated +Feature: Create app + Scenario: Create a new blank app and redirect to the editor + Given I am signed in as the default E2E admin + When I open the apps console + And I start creating a blank app + And I enter a unique E2E app name + And I confirm app creation + Then I should land on the app editor + And I should see the "Orchestrate" text diff --git a/e2e/features/smoke/authenticated-entry.feature b/e2e/features/smoke/authenticated-entry.feature new file mode 100644 index 00000000000..3c1191a330d --- /dev/null +++ b/e2e/features/smoke/authenticated-entry.feature @@ -0,0 +1,8 @@ +@smoke @authenticated +Feature: Authenticated app console + Scenario: Open the apps console with the shared authenticated state + Given I am signed in as the default E2E admin + When I open the apps console + Then I should stay on the apps console + And I should see the "Create from Blank" button + And I should not see the "Sign in" button diff --git a/e2e/features/smoke/install.feature b/e2e/features/smoke/install.feature new file mode 100644 index 00000000000..39fc1f996b9 --- /dev/null +++ b/e2e/features/smoke/install.feature @@ -0,0 +1,7 @@ +@smoke @fresh +Feature: Fresh installation bootstrap + Scenario: Complete the initial installation bootstrap on a fresh instance + Given the last authentication bootstrap came from a fresh install + When I open the apps console + Then I should stay on the apps console + And I should see the "Create from Blank" button diff --git a/e2e/features/step-definitions/apps/create-app.steps.ts b/e2e/features/step-definitions/apps/create-app.steps.ts new file mode 100644 index 00000000000..b8e76c6f064 --- /dev/null +++ b/e2e/features/step-definitions/apps/create-app.steps.ts @@ -0,0 +1,29 @@ +import { Then, When } from '@cucumber/cucumber' +import { expect } from '@playwright/test' +import type { DifyWorld } from '../../support/world' + +When('I start creating a blank app', async function (this: DifyWorld) { + const page = this.getPage() + + await expect(page.getByRole('button', { name: 'Create from Blank' })).toBeVisible() + await page.getByRole('button', { name: 'Create from Blank' }).click() +}) + +When('I enter a unique E2E app name', async function (this: DifyWorld) { + const appName = `E2E App ${Date.now()}` + + await this.getPage().getByPlaceholder('Give your app a name').fill(appName) +}) + +When('I confirm app creation', async function (this: DifyWorld) { + const createButton = this.getPage() + .getByRole('button', { name: /^Create(?:\s|$)/ }) + .last() + + await expect(createButton).toBeEnabled() + await createButton.click() +}) + +Then('I should land on the app editor', async function (this: DifyWorld) { + await expect(this.getPage()).toHaveURL(/\/app\/[^/]+\/(workflow|configuration)(?:\?.*)?$/) +}) diff --git a/e2e/features/step-definitions/common/auth.steps.ts b/e2e/features/step-definitions/common/auth.steps.ts new file mode 100644 index 00000000000..bf03c2d8f43 --- /dev/null +++ b/e2e/features/step-definitions/common/auth.steps.ts @@ -0,0 +1,11 @@ +import { Given } from '@cucumber/cucumber' +import type { DifyWorld } from '../../support/world' + +Given('I am signed in as the default E2E admin', async function (this: DifyWorld) { + const session = await this.getAuthSession() + + this.attach( + `Authenticated as ${session.adminEmail} using ${session.mode} flow at ${session.baseURL}.`, + 'text/plain', + ) +}) diff --git a/e2e/features/step-definitions/common/navigation.steps.ts b/e2e/features/step-definitions/common/navigation.steps.ts new file mode 100644 index 00000000000..b18ff035fa6 --- /dev/null +++ b/e2e/features/step-definitions/common/navigation.steps.ts @@ -0,0 +1,23 @@ +import { Then, When } from '@cucumber/cucumber' +import { expect } from '@playwright/test' +import type { DifyWorld } from '../../support/world' + +When('I open the apps console', async function (this: DifyWorld) { + await this.getPage().goto('/apps') +}) + +Then('I should stay on the apps console', async function (this: DifyWorld) { + await expect(this.getPage()).toHaveURL(/\/apps(?:\?.*)?$/) +}) + +Then('I should see the {string} button', async function (this: DifyWorld, label: string) { + await expect(this.getPage().getByRole('button', { name: label })).toBeVisible() +}) + +Then('I should not see the {string} button', async function (this: DifyWorld, label: string) { + await expect(this.getPage().getByRole('button', { name: label })).not.toBeVisible() +}) + +Then('I should see the {string} text', async function (this: DifyWorld, text: string) { + await expect(this.getPage().getByText(text)).toBeVisible({ timeout: 30_000 }) +}) diff --git a/e2e/features/step-definitions/smoke/install.steps.ts b/e2e/features/step-definitions/smoke/install.steps.ts new file mode 100644 index 00000000000..857e01a9717 --- /dev/null +++ b/e2e/features/step-definitions/smoke/install.steps.ts @@ -0,0 +1,12 @@ +import { Given } from '@cucumber/cucumber' +import { expect } from '@playwright/test' +import type { DifyWorld } from '../../support/world' + +Given( + 'the last authentication bootstrap came from a fresh install', + async function (this: DifyWorld) { + const session = await this.getAuthSession() + + expect(session.mode).toBe('install') + }, +) diff --git a/e2e/features/support/hooks.ts b/e2e/features/support/hooks.ts new file mode 100644 index 00000000000..a6862d79f54 --- /dev/null +++ b/e2e/features/support/hooks.ts @@ -0,0 +1,90 @@ +import { After, AfterAll, Before, BeforeAll, Status, setDefaultTimeout } from '@cucumber/cucumber' +import { chromium, type Browser } from '@playwright/test' +import { mkdir, writeFile } from 'node:fs/promises' +import path from 'node:path' +import { fileURLToPath } from 'node:url' +import { ensureAuthenticatedState } from '../../fixtures/auth' +import { baseURL, cucumberHeadless, cucumberSlowMo } from '../../test-env' +import type { DifyWorld } from './world' + +const e2eRoot = fileURLToPath(new URL('../..', import.meta.url)) +const artifactsDir = path.join(e2eRoot, 'cucumber-report', 'artifacts') + +let browser: Browser | undefined + +setDefaultTimeout(60_000) + +const sanitizeForPath = (value: string) => + value.replaceAll(/[^a-zA-Z0-9_-]+/g, '-').replaceAll(/^-+|-+$/g, '') + +const writeArtifact = async ( + scenarioName: string, + extension: 'html' | 'png', + contents: Buffer | string, +) => { + const artifactPath = path.join( + artifactsDir, + `${Date.now()}-${sanitizeForPath(scenarioName || 'scenario')}.${extension}`, + ) + await writeFile(artifactPath, contents) + + return artifactPath +} + +BeforeAll(async () => { + await mkdir(artifactsDir, { recursive: true }) + + browser = await chromium.launch({ + headless: cucumberHeadless, + slowMo: cucumberSlowMo, + }) + + console.log(`[e2e] session cache bootstrap against ${baseURL}`) + await ensureAuthenticatedState(browser, baseURL) +}) + +Before(async function (this: DifyWorld, { pickle }) { + if (!browser) throw new Error('Shared Playwright browser is not available.') + + await this.startAuthenticatedSession(browser) + this.scenarioStartedAt = Date.now() + + const tags = pickle.tags.map((tag) => tag.name).join(' ') + console.log(`[e2e] start ${pickle.name}${tags ? ` ${tags}` : ''}`) +}) + +After(async function (this: DifyWorld, { pickle, result }) { + const elapsedMs = this.scenarioStartedAt ? Date.now() - this.scenarioStartedAt : undefined + + if (result?.status !== Status.PASSED && this.page) { + const screenshot = await this.page.screenshot({ + fullPage: true, + }) + const screenshotPath = await writeArtifact(pickle.name, 'png', screenshot) + this.attach(screenshot, 'image/png') + + const html = await this.page.content() + const htmlPath = await writeArtifact(pickle.name, 'html', html) + this.attach(html, 'text/html') + + if (this.consoleErrors.length > 0) + this.attach(`Console Errors:\n${this.consoleErrors.join('\n')}`, 'text/plain') + + if (this.pageErrors.length > 0) + this.attach(`Page Errors:\n${this.pageErrors.join('\n')}`, 'text/plain') + + this.attach(`Artifacts:\n${[screenshotPath, htmlPath].join('\n')}`, 'text/plain') + } + + const status = result?.status || 'UNKNOWN' + console.log( + `[e2e] end ${pickle.name} status=${status}${elapsedMs ? ` durationMs=${elapsedMs}` : ''}`, + ) + + await this.closeSession() +}) + +AfterAll(async () => { + await browser?.close() + browser = undefined +}) diff --git a/e2e/features/support/world.ts b/e2e/features/support/world.ts new file mode 100644 index 00000000000..15ab8daf168 --- /dev/null +++ b/e2e/features/support/world.ts @@ -0,0 +1,68 @@ +import { type IWorldOptions, World, setWorldConstructor } from '@cucumber/cucumber' +import type { Browser, BrowserContext, ConsoleMessage, Page } from '@playwright/test' +import { + authStatePath, + readAuthSessionMetadata, + type AuthSessionMetadata, +} from '../../fixtures/auth' +import { baseURL, defaultLocale } from '../../test-env' + +export class DifyWorld extends World { + context: BrowserContext | undefined + page: Page | undefined + consoleErrors: string[] = [] + pageErrors: string[] = [] + scenarioStartedAt: number | undefined + session: AuthSessionMetadata | undefined + + constructor(options: IWorldOptions) { + super(options) + this.resetScenarioState() + } + + resetScenarioState() { + this.consoleErrors = [] + this.pageErrors = [] + } + + async startAuthenticatedSession(browser: Browser) { + this.resetScenarioState() + this.context = await browser.newContext({ + baseURL, + locale: defaultLocale, + storageState: authStatePath, + }) + this.context.setDefaultTimeout(30_000) + this.page = await this.context.newPage() + this.page.setDefaultTimeout(30_000) + + this.page.on('console', (message: ConsoleMessage) => { + if (message.type() === 'error') this.consoleErrors.push(message.text()) + }) + this.page.on('pageerror', (error) => { + this.pageErrors.push(error.message) + }) + } + + getPage() { + if (!this.page) throw new Error('Playwright page has not been initialized for this scenario.') + + return this.page + } + + async getAuthSession() { + this.session ??= await readAuthSessionMetadata() + return this.session + } + + async closeSession() { + await this.context?.close() + this.context = undefined + this.page = undefined + this.session = undefined + this.scenarioStartedAt = undefined + this.resetScenarioState() + } +} + +setWorldConstructor(DifyWorld) diff --git a/e2e/fixtures/auth.ts b/e2e/fixtures/auth.ts new file mode 100644 index 00000000000..853bfff5ed8 --- /dev/null +++ b/e2e/fixtures/auth.ts @@ -0,0 +1,148 @@ +import type { Browser, Page } from '@playwright/test' +import { expect } from '@playwright/test' +import { mkdir, readFile, writeFile } from 'node:fs/promises' +import path from 'node:path' +import { fileURLToPath } from 'node:url' +import { defaultBaseURL, defaultLocale } from '../test-env' + +export type AuthSessionMetadata = { + adminEmail: string + baseURL: string + mode: 'install' | 'login' + usedInitPassword: boolean +} + +const WAIT_TIMEOUT_MS = 120_000 +const e2eRoot = fileURLToPath(new URL('..', import.meta.url)) + +export const authDir = path.join(e2eRoot, '.auth') +export const authStatePath = path.join(authDir, 'admin.json') +export const authMetadataPath = path.join(authDir, 'session.json') + +export const adminCredentials = { + email: process.env.E2E_ADMIN_EMAIL || 'e2e-admin@example.com', + name: process.env.E2E_ADMIN_NAME || 'E2E Admin', + password: process.env.E2E_ADMIN_PASSWORD || 'E2eAdmin12345', +} + +const initPassword = process.env.E2E_INIT_PASSWORD || 'E2eInit12345' + +export const resolveBaseURL = (configuredBaseURL?: string) => + configuredBaseURL || process.env.E2E_BASE_URL || defaultBaseURL + +export const readAuthSessionMetadata = async () => { + const content = await readFile(authMetadataPath, 'utf8') + return JSON.parse(content) as AuthSessionMetadata +} + +const escapeRegex = (value: string) => value.replaceAll(/[.*+?^${}()|[\]\\]/g, '\\$&') + +const appURL = (baseURL: string, pathname: string) => new URL(pathname, baseURL).toString() + +const waitForPageState = async (page: Page) => { + const installHeading = page.getByRole('heading', { name: 'Setting up an admin account' }) + const signInButton = page.getByRole('button', { name: 'Sign in' }) + const initPasswordField = page.getByLabel('Admin initialization password') + + const deadline = Date.now() + WAIT_TIMEOUT_MS + + while (Date.now() < deadline) { + if (await installHeading.isVisible().catch(() => false)) return 'install' as const + if (await signInButton.isVisible().catch(() => false)) return 'login' as const + if (await initPasswordField.isVisible().catch(() => false)) return 'init' as const + + await page.waitForTimeout(1_000) + } + + throw new Error(`Unable to determine auth page state for ${page.url()}`) +} + +const completeInitPasswordIfNeeded = async (page: Page) => { + const initPasswordField = page.getByLabel('Admin initialization password') + if (!(await initPasswordField.isVisible({ timeout: 3_000 }).catch(() => false))) return false + + await initPasswordField.fill(initPassword) + await page.getByRole('button', { name: 'Validate' }).click() + await expect(page.getByRole('heading', { name: 'Setting up an admin account' })).toBeVisible({ + timeout: WAIT_TIMEOUT_MS, + }) + + return true +} + +const completeInstall = async (page: Page, baseURL: string) => { + await expect(page.getByRole('heading', { name: 'Setting up an admin account' })).toBeVisible({ + timeout: WAIT_TIMEOUT_MS, + }) + + await page.getByLabel('Email address').fill(adminCredentials.email) + await page.getByLabel('Username').fill(adminCredentials.name) + await page.getByLabel('Password').fill(adminCredentials.password) + await page.getByRole('button', { name: 'Set up' }).click() + + await expect(page).toHaveURL(new RegExp(`^${escapeRegex(baseURL)}/apps(?:\\?.*)?$`), { + timeout: WAIT_TIMEOUT_MS, + }) +} + +const completeLogin = async (page: Page, baseURL: string) => { + await expect(page.getByRole('button', { name: 'Sign in' })).toBeVisible({ + timeout: WAIT_TIMEOUT_MS, + }) + + await page.getByLabel('Email address').fill(adminCredentials.email) + await page.getByLabel('Password').fill(adminCredentials.password) + await page.getByRole('button', { name: 'Sign in' }).click() + + await expect(page).toHaveURL(new RegExp(`^${escapeRegex(baseURL)}/apps(?:\\?.*)?$`), { + timeout: WAIT_TIMEOUT_MS, + }) +} + +export const ensureAuthenticatedState = async (browser: Browser, configuredBaseURL?: string) => { + const baseURL = resolveBaseURL(configuredBaseURL) + + await mkdir(authDir, { recursive: true }) + + const context = await browser.newContext({ + baseURL, + locale: defaultLocale, + }) + const page = await context.newPage() + + try { + await page.goto(appURL(baseURL, '/install'), { waitUntil: 'networkidle' }) + + let usedInitPassword = await completeInitPasswordIfNeeded(page) + let pageState = await waitForPageState(page) + + while (pageState === 'init') { + const completedInitPassword = await completeInitPasswordIfNeeded(page) + if (!completedInitPassword) + throw new Error(`Unable to validate initialization password for ${page.url()}`) + + usedInitPassword = true + pageState = await waitForPageState(page) + } + + if (pageState === 'install') await completeInstall(page, baseURL) + else await completeLogin(page, baseURL) + + await expect(page.getByRole('button', { name: 'Create from Blank' })).toBeVisible({ + timeout: WAIT_TIMEOUT_MS, + }) + + await context.storageState({ path: authStatePath }) + + const metadata: AuthSessionMetadata = { + adminEmail: adminCredentials.email, + baseURL, + mode: pageState, + usedInitPassword, + } + + await writeFile(authMetadataPath, `${JSON.stringify(metadata, null, 2)}\n`, 'utf8') + } finally { + await context.close() + } +} diff --git a/e2e/package.json b/e2e/package.json new file mode 100644 index 00000000000..0ee2afff7fb --- /dev/null +++ b/e2e/package.json @@ -0,0 +1,24 @@ +{ + "name": "dify-e2e", + "private": true, + "type": "module", + "scripts": { + "check": "vp check --fix", + "e2e": "tsx ./scripts/run-cucumber.ts", + "e2e:full": "tsx ./scripts/run-cucumber.ts --full", + "e2e:full:headed": "tsx ./scripts/run-cucumber.ts --full --headed", + "e2e:headed": "tsx ./scripts/run-cucumber.ts --headed", + "e2e:install": "playwright install --with-deps chromium", + "e2e:middleware:down": "tsx ./scripts/setup.ts middleware-down", + "e2e:middleware:up": "tsx ./scripts/setup.ts middleware-up", + "e2e:reset": "tsx ./scripts/setup.ts reset" + }, + "devDependencies": { + "@cucumber/cucumber": "catalog:", + "@playwright/test": "catalog:", + "@types/node": "catalog:", + "tsx": "catalog:", + "typescript": "catalog:", + "vite-plus": "catalog:" + } +} diff --git a/e2e/scripts/common.ts b/e2e/scripts/common.ts new file mode 100644 index 00000000000..bb82121079d --- /dev/null +++ b/e2e/scripts/common.ts @@ -0,0 +1,242 @@ +import { spawn, type ChildProcess } from 'node:child_process' +import { access, copyFile, readFile, writeFile } from 'node:fs/promises' +import net from 'node:net' +import path from 'node:path' +import { fileURLToPath, pathToFileURL } from 'node:url' +import { sleep } from '../support/process' + +type RunCommandOptions = { + command: string + args: string[] + cwd: string + env?: NodeJS.ProcessEnv + stdio?: 'inherit' | 'pipe' +} + +type RunCommandResult = { + exitCode: number + stdout: string + stderr: string +} + +type ForegroundProcessOptions = { + command: string + args: string[] + cwd: string + env?: NodeJS.ProcessEnv +} + +export const rootDir = fileURLToPath(new URL('../..', import.meta.url)) +export const e2eDir = path.join(rootDir, 'e2e') +export const apiDir = path.join(rootDir, 'api') +export const dockerDir = path.join(rootDir, 'docker') +export const webDir = path.join(rootDir, 'web') + +export const middlewareComposeFile = path.join(dockerDir, 'docker-compose.middleware.yaml') +export const middlewareEnvFile = path.join(dockerDir, 'middleware.env') +export const middlewareEnvExampleFile = path.join(dockerDir, 'middleware.env.example') +export const webEnvLocalFile = path.join(webDir, '.env.local') +export const webEnvExampleFile = path.join(webDir, '.env.example') +export const apiEnvExampleFile = path.join(apiDir, 'tests', 'integration_tests', '.env.example') + +const formatCommand = (command: string, args: string[]) => [command, ...args].join(' ') + +export const isMainModule = (metaUrl: string) => { + const entrypoint = process.argv[1] + if (!entrypoint) return false + + return pathToFileURL(entrypoint).href === metaUrl +} + +export const runCommand = async ({ + command, + args, + cwd, + env, + stdio = 'inherit', +}: RunCommandOptions): Promise => { + const childProcess = spawn(command, args, { + cwd, + env: { + ...process.env, + ...env, + }, + stdio: stdio === 'inherit' ? 'inherit' : 'pipe', + }) + + let stdout = '' + let stderr = '' + + if (stdio === 'pipe') { + childProcess.stdout?.on('data', (chunk: Buffer | string) => { + stdout += chunk.toString() + }) + childProcess.stderr?.on('data', (chunk: Buffer | string) => { + stderr += chunk.toString() + }) + } + + return await new Promise((resolve, reject) => { + childProcess.once('error', reject) + childProcess.once('exit', (code) => { + resolve({ + exitCode: code ?? 1, + stdout, + stderr, + }) + }) + }) +} + +export const runCommandOrThrow = async (options: RunCommandOptions) => { + const result = await runCommand(options) + + if (result.exitCode !== 0) { + throw new Error( + `Command failed (${result.exitCode}): ${formatCommand(options.command, options.args)}`, + ) + } + + return result +} + +const forwardSignalsToChild = (childProcess: ChildProcess) => { + const handleSignal = (signal: NodeJS.Signals) => { + if (childProcess.exitCode === null) childProcess.kill(signal) + } + + const onSigint = () => handleSignal('SIGINT') + const onSigterm = () => handleSignal('SIGTERM') + + process.on('SIGINT', onSigint) + process.on('SIGTERM', onSigterm) + + return () => { + process.off('SIGINT', onSigint) + process.off('SIGTERM', onSigterm) + } +} + +export const runForegroundProcess = async ({ + command, + args, + cwd, + env, +}: ForegroundProcessOptions) => { + const childProcess = spawn(command, args, { + cwd, + env: { + ...process.env, + ...env, + }, + stdio: 'inherit', + }) + + const cleanupSignals = forwardSignalsToChild(childProcess) + const exitCode = await new Promise((resolve, reject) => { + childProcess.once('error', reject) + childProcess.once('exit', (code) => { + resolve(code ?? 1) + }) + }) + + cleanupSignals() + process.exit(exitCode) +} + +export const ensureFileExists = async (filePath: string, exampleFilePath: string) => { + try { + await access(filePath) + } catch { + await copyFile(exampleFilePath, filePath) + } +} + +export const ensureLineInFile = async (filePath: string, line: string) => { + const fileContent = await readFile(filePath, 'utf8') + const lines = fileContent.split(/\r?\n/) + const assignmentPrefix = line.includes('=') ? `${line.slice(0, line.indexOf('='))}=` : null + + if (lines.includes(line)) return + + if (assignmentPrefix && lines.some((existingLine) => existingLine.startsWith(assignmentPrefix))) + return + + const normalizedContent = fileContent.endsWith('\n') ? fileContent : `${fileContent}\n` + await writeFile(filePath, `${normalizedContent}${line}\n`, 'utf8') +} + +export const ensureWebEnvLocal = async () => { + await ensureFileExists(webEnvLocalFile, webEnvExampleFile) + + const fileContent = await readFile(webEnvLocalFile, 'utf8') + const nextContent = fileContent.replaceAll('http://localhost:5001', 'http://127.0.0.1:5001') + + if (nextContent !== fileContent) await writeFile(webEnvLocalFile, nextContent, 'utf8') +} + +export const readSimpleDotenv = async (filePath: string) => { + const fileContent = await readFile(filePath, 'utf8') + const entries = fileContent + .split(/\r?\n/) + .map((line) => line.trim()) + .filter((line) => line && !line.startsWith('#')) + .map<[string, string]>((line) => { + const separatorIndex = line.indexOf('=') + const key = separatorIndex === -1 ? line : line.slice(0, separatorIndex).trim() + const rawValue = separatorIndex === -1 ? '' : line.slice(separatorIndex + 1).trim() + + if ( + (rawValue.startsWith('"') && rawValue.endsWith('"')) || + (rawValue.startsWith("'") && rawValue.endsWith("'")) + ) { + return [key, rawValue.slice(1, -1)] + } + + return [key, rawValue] + }) + + return Object.fromEntries(entries) +} + +export const waitForCondition = async ({ + check, + description, + intervalMs, + timeoutMs, +}: { + check: () => Promise | boolean + description: string + intervalMs: number + timeoutMs: number +}) => { + const deadline = Date.now() + timeoutMs + + while (Date.now() < deadline) { + if (await check()) return + + await sleep(intervalMs) + } + + throw new Error(`Timed out waiting for ${description} after ${timeoutMs}ms.`) +} + +export const isTcpPortReachable = async (host: string, port: number, timeoutMs = 1_000) => { + return await new Promise((resolve) => { + const socket = net.createConnection({ + host, + port, + }) + + const finish = (result: boolean) => { + socket.removeAllListeners() + socket.destroy() + resolve(result) + } + + socket.setTimeout(timeoutMs) + socket.once('connect', () => finish(true)) + socket.once('timeout', () => finish(false)) + socket.once('error', () => finish(false)) + }) +} diff --git a/e2e/scripts/run-cucumber.ts b/e2e/scripts/run-cucumber.ts new file mode 100644 index 00000000000..39e91579164 --- /dev/null +++ b/e2e/scripts/run-cucumber.ts @@ -0,0 +1,154 @@ +import { mkdir, rm } from 'node:fs/promises' +import path from 'node:path' +import { startWebServer, stopWebServer } from '../support/web-server' +import { waitForUrl, startLoggedProcess, stopManagedProcess } from '../support/process' +import { apiURL, baseURL, reuseExistingWebServer } from '../test-env' +import { e2eDir, isMainModule, runCommand } from './common' +import { resetState, startMiddleware, stopMiddleware } from './setup' + +type RunOptions = { + forwardArgs: string[] + full: boolean + headed: boolean +} + +const parseArgs = (argv: string[]): RunOptions => { + let full = false + let headed = false + const forwardArgs: string[] = [] + + for (let index = 0; index < argv.length; index += 1) { + const arg = argv[index] + + if (arg === '--') { + forwardArgs.push(...argv.slice(index + 1)) + break + } + + if (arg === '--full') { + full = true + continue + } + + if (arg === '--headed') { + headed = true + continue + } + + forwardArgs.push(arg) + } + + return { + forwardArgs, + full, + headed, + } +} + +const hasCustomTags = (forwardArgs: string[]) => + forwardArgs.some((arg) => arg === '--tags' || arg.startsWith('--tags=')) + +const main = async () => { + const { forwardArgs, full, headed } = parseArgs(process.argv.slice(2)) + const startMiddlewareForRun = full + const resetStateForRun = full + + if (resetStateForRun) await resetState() + + if (startMiddlewareForRun) await startMiddleware() + + const cucumberReportDir = path.join(e2eDir, 'cucumber-report') + const logDir = path.join(e2eDir, '.logs') + + await rm(cucumberReportDir, { force: true, recursive: true }) + await mkdir(logDir, { recursive: true }) + + const apiProcess = await startLoggedProcess({ + command: 'npx', + args: ['tsx', './scripts/setup.ts', 'api'], + cwd: e2eDir, + label: 'api server', + logFilePath: path.join(logDir, 'cucumber-api.log'), + }) + + let cleanupPromise: Promise | undefined + const cleanup = async () => { + if (!cleanupPromise) { + cleanupPromise = (async () => { + await stopWebServer() + await stopManagedProcess(apiProcess) + + if (startMiddlewareForRun) { + try { + await stopMiddleware() + } catch { + // Cleanup should continue even if middleware shutdown fails. + } + } + })() + } + + await cleanupPromise + } + + const onTerminate = () => { + void cleanup().finally(() => { + process.exit(1) + }) + } + + process.once('SIGINT', onTerminate) + process.once('SIGTERM', onTerminate) + + try { + try { + await waitForUrl(`${apiURL}/health`, 180_000, 1_000) + } catch { + throw new Error(`API did not become ready at ${apiURL}/health.`) + } + + await startWebServer({ + baseURL, + command: 'npx', + args: ['tsx', './scripts/setup.ts', 'web'], + cwd: e2eDir, + logFilePath: path.join(logDir, 'cucumber-web.log'), + reuseExistingServer: reuseExistingWebServer, + timeoutMs: 300_000, + }) + + const cucumberEnv: NodeJS.ProcessEnv = { + ...process.env, + CUCUMBER_HEADLESS: headed ? '0' : '1', + } + + if (startMiddlewareForRun && !hasCustomTags(forwardArgs)) + cucumberEnv.E2E_CUCUMBER_TAGS = 'not @skip' + + const result = await runCommand({ + command: 'npx', + args: [ + 'tsx', + './node_modules/@cucumber/cucumber/bin/cucumber.js', + '--config', + './cucumber.config.ts', + ...forwardArgs, + ], + cwd: e2eDir, + env: cucumberEnv, + }) + + process.exitCode = result.exitCode + } finally { + process.off('SIGINT', onTerminate) + process.off('SIGTERM', onTerminate) + await cleanup() + } +} + +if (isMainModule(import.meta.url)) { + void main().catch((error) => { + console.error(error instanceof Error ? error.message : String(error)) + process.exit(1) + }) +} diff --git a/e2e/scripts/setup.ts b/e2e/scripts/setup.ts new file mode 100644 index 00000000000..6f38598df4e --- /dev/null +++ b/e2e/scripts/setup.ts @@ -0,0 +1,306 @@ +import { access, mkdir, rm } from 'node:fs/promises' +import path from 'node:path' +import { waitForUrl } from '../support/process' +import { + apiDir, + apiEnvExampleFile, + dockerDir, + e2eDir, + ensureFileExists, + ensureLineInFile, + ensureWebEnvLocal, + isMainModule, + isTcpPortReachable, + middlewareComposeFile, + middlewareEnvExampleFile, + middlewareEnvFile, + readSimpleDotenv, + runCommand, + runCommandOrThrow, + runForegroundProcess, + waitForCondition, + webDir, +} from './common' + +const buildIdPath = path.join(webDir, '.next', 'BUILD_ID') + +const middlewareDataPaths = [ + path.join(dockerDir, 'volumes', 'db', 'data'), + path.join(dockerDir, 'volumes', 'plugin_daemon'), + path.join(dockerDir, 'volumes', 'redis', 'data'), + path.join(dockerDir, 'volumes', 'weaviate'), +] + +const e2eStatePaths = [ + path.join(e2eDir, '.auth'), + path.join(e2eDir, 'cucumber-report'), + path.join(e2eDir, '.logs'), + path.join(e2eDir, 'playwright-report'), + path.join(e2eDir, 'test-results'), +] + +const composeArgs = [ + 'compose', + '-f', + middlewareComposeFile, + '--profile', + 'postgresql', + '--profile', + 'weaviate', +] + +const getApiEnvironment = async () => { + const envFromExample = await readSimpleDotenv(apiEnvExampleFile) + + return { + ...envFromExample, + FLASK_APP: 'app.py', + } +} + +const getServiceContainerId = async (service: string) => { + const result = await runCommandOrThrow({ + command: 'docker', + args: ['compose', '-f', middlewareComposeFile, 'ps', '-q', service], + cwd: dockerDir, + stdio: 'pipe', + }) + + return result.stdout.trim() +} + +const getContainerHealth = async (containerId: string) => { + const result = await runCommand({ + command: 'docker', + args: ['inspect', '-f', '{{.State.Health.Status}}', containerId], + cwd: dockerDir, + stdio: 'pipe', + }) + + if (result.exitCode !== 0) return '' + + return result.stdout.trim() +} + +const printComposeLogs = async (services: string[]) => { + await runCommand({ + command: 'docker', + args: ['compose', '-f', middlewareComposeFile, 'logs', ...services], + cwd: dockerDir, + }) +} + +const waitForDependency = async ({ + description, + services, + wait, +}: { + description: string + services: string[] + wait: () => Promise +}) => { + console.log(`Waiting for ${description}...`) + + try { + await wait() + } catch (error) { + await printComposeLogs(services) + throw error + } +} + +export const ensureWebBuild = async () => { + await ensureWebEnvLocal() + + if (process.env.E2E_FORCE_WEB_BUILD === '1') { + await runCommandOrThrow({ + command: 'pnpm', + args: ['run', 'build'], + cwd: webDir, + }) + return + } + + try { + await access(buildIdPath) + console.log('Reusing existing web build artifact.') + } catch { + await runCommandOrThrow({ + command: 'pnpm', + args: ['run', 'build'], + cwd: webDir, + }) + } +} + +export const startWeb = async () => { + await ensureWebBuild() + + await runForegroundProcess({ + command: 'pnpm', + args: ['run', 'start'], + cwd: webDir, + env: { + HOSTNAME: '127.0.0.1', + PORT: '3000', + }, + }) +} + +export const startApi = async () => { + const env = await getApiEnvironment() + + await runCommandOrThrow({ + command: 'uv', + args: ['run', '--project', '.', 'flask', 'upgrade-db'], + cwd: apiDir, + env, + }) + + await runForegroundProcess({ + command: 'uv', + args: ['run', '--project', '.', 'flask', 'run', '--host', '127.0.0.1', '--port', '5001'], + cwd: apiDir, + env, + }) +} + +export const stopMiddleware = async () => { + await runCommandOrThrow({ + command: 'docker', + args: [...composeArgs, 'down', '--remove-orphans'], + cwd: dockerDir, + }) +} + +export const resetState = async () => { + console.log('Stopping middleware services...') + try { + await stopMiddleware() + } catch { + // Reset should continue even if middleware is already stopped. + } + + console.log('Removing persisted middleware data...') + await Promise.all( + middlewareDataPaths.map(async (targetPath) => { + await rm(targetPath, { force: true, recursive: true }) + await mkdir(targetPath, { recursive: true }) + }), + ) + + console.log('Removing E2E local state...') + await Promise.all( + e2eStatePaths.map((targetPath) => rm(targetPath, { force: true, recursive: true })), + ) + + console.log('E2E state reset complete.') +} + +export const startMiddleware = async () => { + await ensureFileExists(middlewareEnvFile, middlewareEnvExampleFile) + await ensureLineInFile(middlewareEnvFile, 'COMPOSE_PROFILES=postgresql,weaviate') + + console.log('Starting middleware services...') + await runCommandOrThrow({ + command: 'docker', + args: [ + ...composeArgs, + 'up', + '-d', + 'db_postgres', + 'redis', + 'weaviate', + 'sandbox', + 'ssrf_proxy', + 'plugin_daemon', + ], + cwd: dockerDir, + }) + + const [postgresContainerId, redisContainerId] = await Promise.all([ + getServiceContainerId('db_postgres'), + getServiceContainerId('redis'), + ]) + + await waitForDependency({ + description: 'PostgreSQL and Redis health checks', + services: ['db_postgres', 'redis'], + wait: () => + waitForCondition({ + check: async () => { + const [postgresStatus, redisStatus] = await Promise.all([ + getContainerHealth(postgresContainerId), + getContainerHealth(redisContainerId), + ]) + + return postgresStatus === 'healthy' && redisStatus === 'healthy' + }, + description: 'PostgreSQL and Redis health checks', + intervalMs: 2_000, + timeoutMs: 240_000, + }), + }) + + await waitForDependency({ + description: 'Weaviate readiness', + services: ['weaviate'], + wait: () => waitForUrl('http://127.0.0.1:8080/v1/.well-known/ready', 120_000, 2_000), + }) + + await waitForDependency({ + description: 'sandbox health', + services: ['sandbox', 'ssrf_proxy'], + wait: () => waitForUrl('http://127.0.0.1:8194/health', 120_000, 2_000), + }) + + await waitForDependency({ + description: 'plugin daemon port', + services: ['plugin_daemon'], + wait: () => + waitForCondition({ + check: async () => isTcpPortReachable('127.0.0.1', 5002), + description: 'plugin daemon port', + intervalMs: 2_000, + timeoutMs: 120_000, + }), + }) + + console.log('Full middleware stack is ready.') +} + +const printUsage = () => { + console.log('Usage: tsx ./scripts/setup.ts ') +} + +const main = async () => { + const command = process.argv[2] + + switch (command) { + case 'api': + await startApi() + return + case 'middleware-down': + await stopMiddleware() + return + case 'middleware-up': + await startMiddleware() + return + case 'reset': + await resetState() + return + case 'web': + await startWeb() + return + default: + printUsage() + process.exitCode = 1 + } +} + +if (isMainModule(import.meta.url)) { + void main().catch((error) => { + console.error(error instanceof Error ? error.message : String(error)) + process.exit(1) + }) +} diff --git a/e2e/support/process.ts b/e2e/support/process.ts new file mode 100644 index 00000000000..96273ef9312 --- /dev/null +++ b/e2e/support/process.ts @@ -0,0 +1,178 @@ +import type { ChildProcess } from 'node:child_process' +import { spawn } from 'node:child_process' +import { createWriteStream, type WriteStream } from 'node:fs' +import { mkdir } from 'node:fs/promises' +import net from 'node:net' +import { dirname } from 'node:path' + +type ManagedProcessOptions = { + command: string + args?: string[] + cwd: string + env?: NodeJS.ProcessEnv + label: string + logFilePath: string +} + +export type ManagedProcess = { + childProcess: ChildProcess + label: string + logFilePath: string + logStream: WriteStream +} + +export const sleep = (ms: number) => + new Promise((resolve) => { + setTimeout(resolve, ms) + }) + +export const isPortReachable = async (host: string, port: number, timeoutMs = 1_000) => { + return await new Promise((resolve) => { + const socket = net.createConnection({ + host, + port, + }) + + const finish = (result: boolean) => { + socket.removeAllListeners() + socket.destroy() + resolve(result) + } + + socket.setTimeout(timeoutMs) + socket.once('connect', () => finish(true)) + socket.once('timeout', () => finish(false)) + socket.once('error', () => finish(false)) + }) +} + +export const waitForUrl = async ( + url: string, + timeoutMs: number, + intervalMs = 1_000, + requestTimeoutMs = Math.max(intervalMs, 1_000), +) => { + const deadline = Date.now() + timeoutMs + + while (Date.now() < deadline) { + try { + const controller = new AbortController() + const timeout = setTimeout(() => controller.abort(), requestTimeoutMs) + + try { + const response = await fetch(url, { + signal: controller.signal, + }) + if (response.ok) return + } finally { + clearTimeout(timeout) + } + } catch { + // Keep polling until timeout. + } + + await sleep(intervalMs) + } + + throw new Error(`Timed out waiting for ${url} after ${timeoutMs}ms.`) +} + +export const startLoggedProcess = async ({ + command, + args = [], + cwd, + env, + label, + logFilePath, +}: ManagedProcessOptions): Promise => { + await mkdir(dirname(logFilePath), { recursive: true }) + + const logStream = createWriteStream(logFilePath, { flags: 'a' }) + const childProcess = spawn(command, args, { + cwd, + env: { + ...process.env, + ...env, + }, + detached: process.platform !== 'win32', + stdio: ['ignore', 'pipe', 'pipe'], + }) + + const formattedCommand = [command, ...args].join(' ') + logStream.write(`[${new Date().toISOString()}] Starting ${label}: ${formattedCommand}\n`) + childProcess.stdout?.pipe(logStream, { end: false }) + childProcess.stderr?.pipe(logStream, { end: false }) + + return { + childProcess, + label, + logFilePath, + logStream, + } +} + +const waitForProcessExit = (childProcess: ChildProcess, timeoutMs: number) => + new Promise((resolve) => { + if (childProcess.exitCode !== null) { + resolve() + return + } + + const timeout = setTimeout(() => { + cleanup() + resolve() + }, timeoutMs) + + const onExit = () => { + cleanup() + resolve() + } + + const cleanup = () => { + clearTimeout(timeout) + childProcess.off('exit', onExit) + } + + childProcess.once('exit', onExit) + }) + +const signalManagedProcess = (childProcess: ChildProcess, signal: NodeJS.Signals) => { + const { pid } = childProcess + if (!pid) return + + try { + if (process.platform !== 'win32') { + process.kill(-pid, signal) + return + } + + childProcess.kill(signal) + } catch { + // Best-effort shutdown. Cleanup continues even when the process is already gone. + } +} + +export const stopManagedProcess = async (managedProcess?: ManagedProcess) => { + if (!managedProcess) return + + const { childProcess, logStream } = managedProcess + + if (childProcess.exitCode === null) { + signalManagedProcess(childProcess, 'SIGTERM') + await waitForProcessExit(childProcess, 5_000) + } + + if (childProcess.exitCode === null) { + signalManagedProcess(childProcess, 'SIGKILL') + await waitForProcessExit(childProcess, 5_000) + } + + childProcess.stdout?.unpipe(logStream) + childProcess.stderr?.unpipe(logStream) + childProcess.stdout?.destroy() + childProcess.stderr?.destroy() + + await new Promise((resolve) => { + logStream.end(() => resolve()) + }) +} diff --git a/e2e/support/web-server.ts b/e2e/support/web-server.ts new file mode 100644 index 00000000000..ad5d5d916a1 --- /dev/null +++ b/e2e/support/web-server.ts @@ -0,0 +1,83 @@ +import type { ManagedProcess } from './process' +import { isPortReachable, startLoggedProcess, stopManagedProcess, waitForUrl } from './process' + +type WebServerStartOptions = { + baseURL: string + command: string + args?: string[] + cwd: string + logFilePath: string + reuseExistingServer: boolean + timeoutMs: number +} + +let activeProcess: ManagedProcess | undefined + +const getUrlHostAndPort = (url: string) => { + const parsedUrl = new URL(url) + const isHttps = parsedUrl.protocol === 'https:' + + return { + host: parsedUrl.hostname, + port: parsedUrl.port ? Number(parsedUrl.port) : isHttps ? 443 : 80, + } +} + +export const startWebServer = async ({ + baseURL, + command, + args = [], + cwd, + logFilePath, + reuseExistingServer, + timeoutMs, +}: WebServerStartOptions) => { + const { host, port } = getUrlHostAndPort(baseURL) + + if (reuseExistingServer && (await isPortReachable(host, port))) return + + activeProcess = await startLoggedProcess({ + command, + args, + cwd, + label: 'web server', + logFilePath, + }) + + let startupError: Error | undefined + activeProcess.childProcess.once('error', (error) => { + startupError = error + }) + activeProcess.childProcess.once('exit', (code, signal) => { + if (startupError) return + + startupError = new Error( + `Web server exited before readiness (code: ${code ?? 'unknown'}, signal: ${signal ?? 'none'}).`, + ) + }) + + const deadline = Date.now() + timeoutMs + while (Date.now() < deadline) { + if (startupError) { + await stopManagedProcess(activeProcess) + activeProcess = undefined + throw startupError + } + + try { + await waitForUrl(baseURL, 1_000, 250, 1_000) + return + } catch { + // Continue polling until timeout or child exit. + } + } + + await stopManagedProcess(activeProcess) + activeProcess = undefined + throw new Error(`Timed out waiting for web server readiness at ${baseURL} after ${timeoutMs}ms.`) +} + +export const stopWebServer = async () => { + await stopManagedProcess(activeProcess) + activeProcess = undefined +} diff --git a/e2e/test-env.ts b/e2e/test-env.ts new file mode 100644 index 00000000000..c0afc2a8c1f --- /dev/null +++ b/e2e/test-env.ts @@ -0,0 +1,12 @@ +export const defaultBaseURL = 'http://127.0.0.1:3000' +export const defaultApiURL = 'http://127.0.0.1:5001' +export const defaultLocale = 'en-US' + +export const baseURL = process.env.E2E_BASE_URL || defaultBaseURL +export const apiURL = process.env.E2E_API_URL || defaultApiURL + +export const cucumberHeadless = process.env.CUCUMBER_HEADLESS !== '0' +export const cucumberSlowMo = Number(process.env.E2E_SLOW_MO || 0) +export const reuseExistingWebServer = process.env.E2E_REUSE_WEB_SERVER + ? process.env.E2E_REUSE_WEB_SERVER !== '0' + : !process.env.CI diff --git a/e2e/tsconfig.json b/e2e/tsconfig.json new file mode 100644 index 00000000000..3976c12b667 --- /dev/null +++ b/e2e/tsconfig.json @@ -0,0 +1,25 @@ +{ + "compilerOptions": { + "target": "ES2023", + "lib": ["ES2023", "DOM"], + "module": "ESNext", + "moduleResolution": "Bundler", + "allowJs": false, + "resolveJsonModule": true, + "noEmit": true, + "strict": true, + "skipLibCheck": true, + "types": ["node", "@playwright/test", "@cucumber/cucumber"], + "isolatedModules": true, + "verbatimModuleSyntax": true + }, + "include": ["./**/*.ts"], + "exclude": [ + "./node_modules", + "./playwright-report", + "./test-results", + "./.auth", + "./cucumber-report", + "./.logs" + ] +} diff --git a/e2e/vite.config.ts b/e2e/vite.config.ts new file mode 100644 index 00000000000..98400d5b9b6 --- /dev/null +++ b/e2e/vite.config.ts @@ -0,0 +1,15 @@ +import { defineConfig } from 'vite-plus' + +export default defineConfig({ + lint: { + options: { + typeAware: true, + typeCheck: true, + denyWarnings: true, + }, + }, + fmt: { + singleQuote: true, + semi: false, + }, +}) diff --git a/package.json b/package.json new file mode 100644 index 00000000000..07f1e16153f --- /dev/null +++ b/package.json @@ -0,0 +1,11 @@ +{ + "name": "dify", + "private": true, + "engines": { + "node": "^22.22.1" + }, + "packageManager": "pnpm@10.33.0", + "devDependencies": { + "taze": "catalog:" + } +} diff --git a/web/pnpm-lock.yaml b/pnpm-lock.yaml similarity index 71% rename from web/pnpm-lock.yaml rename to pnpm-lock.yaml index 5c4ccfc5c82..01a96c5585c 100644 --- a/web/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -4,6 +4,564 @@ settings: autoInstallPeers: true excludeLinksFromLockfile: false +catalogs: + default: + '@amplitude/analytics-browser': + specifier: 2.38.0 + version: 2.38.0 + '@amplitude/plugin-session-replay-browser': + specifier: 1.27.5 + version: 1.27.5 + '@antfu/eslint-config': + specifier: 7.7.3 + version: 7.7.3 + '@base-ui/react': + specifier: 1.3.0 + version: 1.3.0 + '@chromatic-com/storybook': + specifier: 5.1.1 + version: 5.1.1 + '@cucumber/cucumber': + specifier: 12.7.0 + version: 12.7.0 + '@egoist/tailwindcss-icons': + specifier: 1.9.2 + version: 1.9.2 + '@emoji-mart/data': + specifier: 1.2.1 + version: 1.2.1 + '@eslint-react/eslint-plugin': + specifier: 3.0.0 + version: 3.0.0 + '@eslint/js': + specifier: ^10.0.1 + version: 10.0.1 + '@floating-ui/react': + specifier: 0.27.19 + version: 0.27.19 + '@formatjs/intl-localematcher': + specifier: 0.8.2 + version: 0.8.2 + '@headlessui/react': + specifier: 2.2.9 + version: 2.2.9 + '@heroicons/react': + specifier: 2.2.0 + version: 2.2.0 + '@hono/node-server': + specifier: 1.19.11 + version: 1.19.11 + '@iconify-json/heroicons': + specifier: 1.2.3 + version: 1.2.3 + '@iconify-json/ri': + specifier: 1.2.10 + version: 1.2.10 + '@lexical/link': + specifier: 0.42.0 + version: 0.42.0 + '@lexical/list': + specifier: 0.42.0 + version: 0.42.0 + '@lexical/react': + specifier: 0.42.0 + version: 0.42.0 + '@lexical/selection': + specifier: 0.42.0 + version: 0.42.0 + '@lexical/text': + specifier: 0.42.0 + version: 0.42.0 + '@lexical/utils': + specifier: 0.42.0 + version: 0.42.0 + '@mdx-js/loader': + specifier: 3.1.1 + version: 3.1.1 + '@mdx-js/react': + specifier: 3.1.1 + version: 3.1.1 + '@mdx-js/rollup': + specifier: 3.1.1 + version: 3.1.1 + '@monaco-editor/react': + specifier: 4.7.0 + version: 4.7.0 + '@next/eslint-plugin-next': + specifier: 16.2.1 + version: 16.2.1 + '@next/mdx': + specifier: 16.2.1 + version: 16.2.1 + '@orpc/client': + specifier: 1.13.13 + version: 1.13.13 + '@orpc/contract': + specifier: 1.13.13 + version: 1.13.13 + '@orpc/openapi-client': + specifier: 1.13.13 + version: 1.13.13 + '@orpc/tanstack-query': + specifier: 1.13.13 + version: 1.13.13 + '@playwright/test': + specifier: 1.58.2 + version: 1.58.2 + '@remixicon/react': + specifier: 4.9.0 + version: 4.9.0 + '@rgrove/parse-xml': + specifier: 4.2.0 + version: 4.2.0 + '@sentry/react': + specifier: 10.46.0 + version: 10.46.0 + '@storybook/addon-docs': + specifier: 10.3.3 + version: 10.3.3 + '@storybook/addon-links': + specifier: 10.3.3 + version: 10.3.3 + '@storybook/addon-onboarding': + specifier: 10.3.3 + version: 10.3.3 + '@storybook/addon-themes': + specifier: 10.3.3 + version: 10.3.3 + '@storybook/nextjs-vite': + specifier: 10.3.3 + version: 10.3.3 + '@storybook/react': + specifier: 10.3.3 + version: 10.3.3 + '@streamdown/math': + specifier: 1.0.2 + version: 1.0.2 + '@svgdotjs/svg.js': + specifier: 3.2.5 + version: 3.2.5 + '@t3-oss/env-nextjs': + specifier: 0.13.11 + version: 0.13.11 + '@tailwindcss/typography': + specifier: 0.5.19 + version: 0.5.19 + '@tanstack/eslint-plugin-query': + specifier: 5.95.2 + version: 5.95.2 + '@tanstack/react-devtools': + specifier: 0.10.0 + version: 0.10.0 + '@tanstack/react-form': + specifier: 1.28.5 + version: 1.28.5 + '@tanstack/react-form-devtools': + specifier: 0.2.19 + version: 0.2.19 + '@tanstack/react-query': + specifier: 5.95.2 + version: 5.95.2 + '@tanstack/react-query-devtools': + specifier: 5.95.2 + version: 5.95.2 + '@testing-library/dom': + specifier: 10.4.1 + version: 10.4.1 + '@testing-library/jest-dom': + specifier: 6.9.1 + version: 6.9.1 + '@testing-library/react': + specifier: 16.3.2 + version: 16.3.2 + '@testing-library/user-event': + specifier: 14.6.1 + version: 14.6.1 + '@tsslint/cli': + specifier: 3.0.2 + version: 3.0.2 + '@tsslint/compat-eslint': + specifier: 3.0.2 + version: 3.0.2 + '@tsslint/config': + specifier: 3.0.2 + version: 3.0.2 + '@types/js-cookie': + specifier: 3.0.6 + version: 3.0.6 + '@types/js-yaml': + specifier: 4.0.9 + version: 4.0.9 + '@types/negotiator': + specifier: 0.6.4 + version: 0.6.4 + '@types/node': + specifier: 25.5.0 + version: 25.5.0 + '@types/postcss-js': + specifier: 4.1.0 + version: 4.1.0 + '@types/qs': + specifier: 6.15.0 + version: 6.15.0 + '@types/react': + specifier: 19.2.14 + version: 19.2.14 + '@types/react-dom': + specifier: 19.2.3 + version: 19.2.3 + '@types/react-syntax-highlighter': + specifier: 15.5.13 + version: 15.5.13 + '@types/react-window': + specifier: 1.8.8 + version: 1.8.8 + '@types/sortablejs': + specifier: 1.15.9 + version: 1.15.9 + '@typescript-eslint/eslint-plugin': + specifier: ^8.57.2 + version: 8.57.2 + '@typescript-eslint/parser': + specifier: 8.57.2 + version: 8.57.2 + '@typescript/native-preview': + specifier: 7.0.0-dev.20260329.1 + version: 7.0.0-dev.20260329.1 + '@vitejs/plugin-react': + specifier: 6.0.1 + version: 6.0.1 + '@vitejs/plugin-rsc': + specifier: 0.5.21 + version: 0.5.21 + '@vitest/coverage-v8': + specifier: 4.1.2 + version: 4.1.2 + abcjs: + specifier: 6.6.2 + version: 6.6.2 + agentation: + specifier: 3.0.2 + version: 3.0.2 + ahooks: + specifier: 3.9.7 + version: 3.9.7 + autoprefixer: + specifier: 10.4.27 + version: 10.4.27 + axios: + specifier: ^1.14.0 + version: 1.14.0 + class-variance-authority: + specifier: 0.7.1 + version: 0.7.1 + clsx: + specifier: 2.1.1 + version: 2.1.1 + cmdk: + specifier: 1.1.1 + version: 1.1.1 + code-inspector-plugin: + specifier: 1.4.5 + version: 1.4.5 + copy-to-clipboard: + specifier: 3.3.3 + version: 3.3.3 + cron-parser: + specifier: 5.5.0 + version: 5.5.0 + dayjs: + specifier: 1.11.20 + version: 1.11.20 + decimal.js: + specifier: 10.6.0 + version: 10.6.0 + dompurify: + specifier: 3.3.3 + version: 3.3.3 + echarts: + specifier: 6.0.0 + version: 6.0.0 + echarts-for-react: + specifier: 3.0.6 + version: 3.0.6 + elkjs: + specifier: 0.11.1 + version: 0.11.1 + embla-carousel-autoplay: + specifier: 8.6.0 + version: 8.6.0 + embla-carousel-react: + specifier: 8.6.0 + version: 8.6.0 + emoji-mart: + specifier: 5.6.0 + version: 5.6.0 + es-toolkit: + specifier: 1.45.1 + version: 1.45.1 + eslint: + specifier: 10.1.0 + version: 10.1.0 + eslint-markdown: + specifier: 0.6.0 + version: 0.6.0 + eslint-plugin-better-tailwindcss: + specifier: 4.3.2 + version: 4.3.2 + eslint-plugin-hyoban: + specifier: 0.14.1 + version: 0.14.1 + eslint-plugin-markdown-preferences: + specifier: 0.40.3 + version: 0.40.3 + eslint-plugin-no-barrel-files: + specifier: 1.2.2 + version: 1.2.2 + eslint-plugin-react-hooks: + specifier: 7.0.1 + version: 7.0.1 + eslint-plugin-react-refresh: + specifier: 0.5.2 + version: 0.5.2 + eslint-plugin-sonarjs: + specifier: 4.0.2 + version: 4.0.2 + eslint-plugin-storybook: + specifier: 10.3.3 + version: 10.3.3 + fast-deep-equal: + specifier: 3.1.3 + version: 3.1.3 + foxact: + specifier: 0.3.0 + version: 0.3.0 + happy-dom: + specifier: 20.8.9 + version: 20.8.9 + hono: + specifier: 4.12.9 + version: 4.12.9 + html-entities: + specifier: 2.6.0 + version: 2.6.0 + html-to-image: + specifier: 1.11.13 + version: 1.11.13 + husky: + specifier: 9.1.7 + version: 9.1.7 + i18next: + specifier: 25.10.10 + version: 25.10.10 + i18next-resources-to-backend: + specifier: 1.2.1 + version: 1.2.1 + iconify-import-svg: + specifier: 0.1.2 + version: 0.1.2 + immer: + specifier: 11.1.4 + version: 11.1.4 + jotai: + specifier: 2.19.0 + version: 2.19.0 + js-audio-recorder: + specifier: 1.0.7 + version: 1.0.7 + js-cookie: + specifier: 3.0.5 + version: 3.0.5 + js-yaml: + specifier: 4.1.1 + version: 4.1.1 + jsonschema: + specifier: 1.5.0 + version: 1.5.0 + katex: + specifier: 0.16.44 + version: 0.16.44 + knip: + specifier: 6.1.0 + version: 6.1.0 + ky: + specifier: 1.14.3 + version: 1.14.3 + lamejs: + specifier: 1.2.1 + version: 1.2.1 + lexical: + specifier: 0.42.0 + version: 0.42.0 + lint-staged: + specifier: 16.4.0 + version: 16.4.0 + mermaid: + specifier: 11.13.0 + version: 11.13.0 + mime: + specifier: 4.1.0 + version: 4.1.0 + mitt: + specifier: 3.0.1 + version: 3.0.1 + negotiator: + specifier: 1.0.0 + version: 1.0.0 + next: + specifier: 16.2.1 + version: 16.2.1 + next-themes: + specifier: 0.4.6 + version: 0.4.6 + nuqs: + specifier: 2.8.9 + version: 2.8.9 + pinyin-pro: + specifier: 3.28.0 + version: 3.28.0 + postcss: + specifier: 8.5.8 + version: 8.5.8 + postcss-js: + specifier: 5.1.0 + version: 5.1.0 + qrcode.react: + specifier: 4.2.0 + version: 4.2.0 + qs: + specifier: 6.15.0 + version: 6.15.0 + react: + specifier: 19.2.4 + version: 19.2.4 + react-18-input-autosize: + specifier: 3.0.0 + version: 3.0.0 + react-dom: + specifier: 19.2.4 + version: 19.2.4 + react-easy-crop: + specifier: 5.5.7 + version: 5.5.7 + react-hotkeys-hook: + specifier: 5.2.4 + version: 5.2.4 + react-i18next: + specifier: 16.6.6 + version: 16.6.6 + react-multi-email: + specifier: 1.0.25 + version: 1.0.25 + react-papaparse: + specifier: 4.4.0 + version: 4.4.0 + react-pdf-highlighter: + specifier: 8.0.0-rc.0 + version: 8.0.0-rc.0 + react-server-dom-webpack: + specifier: 19.2.4 + version: 19.2.4 + react-sortablejs: + specifier: 6.1.4 + version: 6.1.4 + react-syntax-highlighter: + specifier: 15.6.6 + version: 15.6.6 + react-textarea-autosize: + specifier: 8.5.9 + version: 8.5.9 + react-window: + specifier: 1.8.11 + version: 1.8.11 + reactflow: + specifier: 11.11.4 + version: 11.11.4 + remark-breaks: + specifier: 4.0.0 + version: 4.0.0 + remark-directive: + specifier: 4.0.0 + version: 4.0.0 + sass: + specifier: 1.98.0 + version: 1.98.0 + scheduler: + specifier: 0.27.0 + version: 0.27.0 + sharp: + specifier: 0.34.5 + version: 0.34.5 + sortablejs: + specifier: 1.15.7 + version: 1.15.7 + std-semver: + specifier: 1.0.8 + version: 1.0.8 + storybook: + specifier: 10.3.3 + version: 10.3.3 + streamdown: + specifier: 2.5.0 + version: 2.5.0 + string-ts: + specifier: 2.3.1 + version: 2.3.1 + tailwind-merge: + specifier: 2.6.1 + version: 2.6.1 + tailwindcss: + specifier: 3.4.19 + version: 3.4.19 + taze: + specifier: 19.10.0 + version: 19.10.0 + tldts: + specifier: 7.0.27 + version: 7.0.27 + tsup: + specifier: ^8.5.1 + version: 8.5.1 + tsx: + specifier: 4.21.0 + version: 4.21.0 + typescript: + specifier: 5.9.3 + version: 5.9.3 + uglify-js: + specifier: 3.19.3 + version: 3.19.3 + unist-util-visit: + specifier: 5.1.0 + version: 5.1.0 + use-context-selector: + specifier: 2.0.0 + version: 2.0.0 + uuid: + specifier: 13.0.0 + version: 13.0.0 + vinext: + specifier: 0.0.38 + version: 0.0.38 + vite-plugin-inspect: + specifier: 12.0.0-beta.1 + version: 12.0.0-beta.1 + vite-plus: + specifier: 0.1.14 + version: 0.1.14 + vitest-canvas-mock: + specifier: 1.1.4 + version: 1.1.4 + zod: + specifier: 4.3.6 + version: 4.3.6 + zundo: + specifier: 2.3.0 + version: 2.3.0 + zustand: + specifier: 5.0.12 + version: 5.0.12 + overrides: '@lexical/code': npm:lexical-code-no-prism@0.41.0 '@monaco-editor/loader': 1.7.0 @@ -16,11 +574,12 @@ overrides: array.prototype.tosorted: npm:@nolyfill/array.prototype.tosorted@^1.0.44 assert: npm:@nolyfill/assert@^1.0.26 brace-expansion@<2.0.2: 2.0.2 - canvas: ^3.2.1 + canvas: ^3.2.2 devalue@<5.3.2: 5.3.2 dompurify@>=3.1.3 <=3.3.1: 3.3.2 es-iterator-helpers: npm:@nolyfill/es-iterator-helpers@^1.0.21 esbuild@<0.27.2: 0.27.2 + flatted@<=3.4.1: 3.4.2 glob@>=10.2.0 <10.5.0: 11.1.0 hasown: npm:@nolyfill/hasown@^1.0.44 is-arguments: npm:@nolyfill/is-arguments@^1.0.44 @@ -35,6 +594,8 @@ overrides: object.values: npm:@nolyfill/object.values@^1.0.44 pbkdf2: ~3.1.5 pbkdf2@<3.1.3: 3.1.3 + picomatch@<2.3.2: 2.3.2 + picomatch@>=4.0.0 <4.0.4: 4.0.4 prismjs: ~1.30 prismjs@<1.30.0: 1.30.0 rollup@>=4.0.0 <4.59.0: 4.59.0 @@ -42,6 +603,7 @@ overrides: safe-regex-test: npm:@nolyfill/safe-regex-test@^1.0.44 safer-buffer: npm:@nolyfill/safer-buffer@^1.0.44 side-channel: npm:@nolyfill/side-channel@^1.0.44 + smol-toml@<1.6.1: 1.6.1 solid-js: 1.9.11 string-width: ~8.2.0 string.prototype.includes: npm:@nolyfill/string.prototype.includes@^1.0.44 @@ -52,571 +614,621 @@ overrides: tar@<=7.5.10: 7.5.11 typed-array-buffer: npm:@nolyfill/typed-array-buffer@^1.0.44 undici@>=7.0.0 <7.24.0: 7.24.0 - vite: npm:@voidzero-dev/vite-plus-core@0.1.12 - vitest: npm:@voidzero-dev/vite-plus-test@0.1.12 + vite: npm:@voidzero-dev/vite-plus-core@0.1.14 + vitest: npm:@voidzero-dev/vite-plus-test@0.1.14 which-typed-array: npm:@nolyfill/which-typed-array@^1.0.44 + yaml@>=2.0.0 <2.8.3: 2.8.3 yauzl@<3.2.1: 3.2.1 importers: .: + devDependencies: + taze: + specifier: 'catalog:' + version: 19.10.0 + + e2e: + devDependencies: + '@cucumber/cucumber': + specifier: 'catalog:' + version: 12.7.0 + '@playwright/test': + specifier: 'catalog:' + version: 1.58.2 + '@types/node': + specifier: 'catalog:' + version: 25.5.0 + tsx: + specifier: 'catalog:' + version: 4.21.0 + typescript: + specifier: 'catalog:' + version: 5.9.3 + vite-plus: + specifier: 'catalog:' + version: 0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(vite@8.0.3(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1)(@types/node@25.5.0)(esbuild@0.27.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(yaml@2.8.3))(yaml@2.8.3) + + sdks/nodejs-client: + dependencies: + axios: + specifier: 'catalog:' + version: 1.14.0 + devDependencies: + '@eslint/js': + specifier: 'catalog:' + version: 10.0.1(eslint@10.1.0(jiti@2.6.1)) + '@types/node': + specifier: 'catalog:' + version: 25.5.0 + '@typescript-eslint/eslint-plugin': + specifier: 'catalog:' + version: 8.57.2(@typescript-eslint/parser@8.57.2(eslint@10.1.0(jiti@2.6.1))(typescript@5.9.3))(eslint@10.1.0(jiti@2.6.1))(typescript@5.9.3) + '@typescript-eslint/parser': + specifier: 'catalog:' + version: 8.57.2(eslint@10.1.0(jiti@2.6.1))(typescript@5.9.3) + '@vitest/coverage-v8': + specifier: 'catalog:' + version: 4.1.2(@voidzero-dev/vite-plus-test@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(vite@8.0.3(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1)(@types/node@25.5.0)(esbuild@0.27.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(yaml@2.8.3))(yaml@2.8.3)) + eslint: + specifier: 'catalog:' + version: 10.1.0(jiti@2.6.1) + tsup: + specifier: 'catalog:' + version: 8.5.1(jiti@2.6.1)(postcss@8.5.8)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3) + typescript: + specifier: 'catalog:' + version: 5.9.3 + vitest: + specifier: npm:@voidzero-dev/vite-plus-test@0.1.14 + version: '@voidzero-dev/vite-plus-test@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(vite@8.0.3(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1)(@types/node@25.5.0)(esbuild@0.27.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(yaml@2.8.3))(yaml@2.8.3)' + + web: dependencies: '@amplitude/analytics-browser': - specifier: 2.36.7 - version: 2.36.7 + specifier: 'catalog:' + version: 2.38.0 '@amplitude/plugin-session-replay-browser': - specifier: 1.26.4 - version: 1.26.4(@amplitude/rrweb@2.0.0-alpha.35)(rollup@4.59.0) + specifier: 'catalog:' + version: 1.27.5(@amplitude/rrweb@2.0.0-alpha.37)(rollup@4.59.0) '@base-ui/react': - specifier: 1.3.0 + specifier: 'catalog:' version: 1.3.0(@types/react@19.2.14)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) '@emoji-mart/data': - specifier: 1.2.1 + specifier: 'catalog:' version: 1.2.1 '@floating-ui/react': - specifier: 0.27.19 + specifier: 'catalog:' version: 0.27.19(react-dom@19.2.4(react@19.2.4))(react@19.2.4) '@formatjs/intl-localematcher': - specifier: 0.8.2 + specifier: 'catalog:' version: 0.8.2 '@headlessui/react': - specifier: 2.2.9 + specifier: 'catalog:' version: 2.2.9(react-dom@19.2.4(react@19.2.4))(react@19.2.4) '@heroicons/react': - specifier: 2.2.0 + specifier: 'catalog:' version: 2.2.0(react@19.2.4) - '@hono/node-server': - specifier: 1.19.11 - version: 1.19.11(hono@4.12.8) '@lexical/code': specifier: npm:lexical-code-no-prism@0.41.0 - version: lexical-code-no-prism@0.41.0(@lexical/utils@0.41.0)(lexical@0.41.0) + version: lexical-code-no-prism@0.41.0(@lexical/utils@0.42.0)(lexical@0.42.0) '@lexical/link': - specifier: 0.41.0 - version: 0.41.0 + specifier: 'catalog:' + version: 0.42.0 '@lexical/list': - specifier: 0.41.0 - version: 0.41.0 + specifier: 'catalog:' + version: 0.42.0 '@lexical/react': - specifier: 0.41.0 - version: 0.41.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(yjs@13.6.30) + specifier: 'catalog:' + version: 0.42.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(yjs@13.6.30) '@lexical/selection': - specifier: 0.41.0 - version: 0.41.0 + specifier: 'catalog:' + version: 0.42.0 '@lexical/text': - specifier: 0.41.0 - version: 0.41.0 + specifier: 'catalog:' + version: 0.42.0 '@lexical/utils': - specifier: 0.41.0 - version: 0.41.0 + specifier: 'catalog:' + version: 0.42.0 '@monaco-editor/react': - specifier: 4.7.0 + specifier: 'catalog:' version: 4.7.0(monaco-editor@0.55.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) - '@octokit/core': - specifier: 7.0.6 - version: 7.0.6 - '@octokit/request-error': - specifier: 7.1.0 - version: 7.1.0 '@orpc/client': - specifier: 1.13.8 - version: 1.13.8 + specifier: 'catalog:' + version: 1.13.13 '@orpc/contract': - specifier: 1.13.8 - version: 1.13.8 + specifier: 'catalog:' + version: 1.13.13 '@orpc/openapi-client': - specifier: 1.13.8 - version: 1.13.8 + specifier: 'catalog:' + version: 1.13.13 '@orpc/tanstack-query': - specifier: 1.13.8 - version: 1.13.8(@orpc/client@1.13.8)(@tanstack/query-core@5.91.0) + specifier: 'catalog:' + version: 1.13.13(@orpc/client@1.13.13)(@tanstack/query-core@5.95.2) '@remixicon/react': - specifier: 4.9.0 + specifier: 'catalog:' version: 4.9.0(react@19.2.4) '@sentry/react': - specifier: 10.44.0 - version: 10.44.0(react@19.2.4) + specifier: 'catalog:' + version: 10.46.0(react@19.2.4) '@streamdown/math': - specifier: 1.0.2 + specifier: 'catalog:' version: 1.0.2(react@19.2.4) '@svgdotjs/svg.js': - specifier: 3.2.5 + specifier: 'catalog:' version: 3.2.5 '@t3-oss/env-nextjs': - specifier: 0.13.10 - version: 0.13.10(typescript@5.9.3)(valibot@1.3.0(typescript@5.9.3))(zod@4.3.6) + specifier: 'catalog:' + version: 0.13.11(typescript@5.9.3)(valibot@1.3.1(typescript@5.9.3))(zod@4.3.6) '@tailwindcss/typography': - specifier: 0.5.19 - version: 0.5.19(tailwindcss@3.4.19(tsx@4.21.0)(yaml@2.8.2)) + specifier: 'catalog:' + version: 0.5.19(tailwindcss@3.4.19(tsx@4.21.0)(yaml@2.8.3)) '@tanstack/react-form': - specifier: 1.28.5 + specifier: 'catalog:' version: 1.28.5(react-dom@19.2.4(react@19.2.4))(react@19.2.4) '@tanstack/react-query': - specifier: 5.91.0 - version: 5.91.0(react@19.2.4) + specifier: 'catalog:' + version: 5.95.2(react@19.2.4) abcjs: - specifier: 6.6.2 + specifier: 'catalog:' version: 6.6.2 ahooks: - specifier: 3.9.6 - version: 3.9.6(react-dom@19.2.4(react@19.2.4))(react@19.2.4) + specifier: 'catalog:' + version: 3.9.7(react-dom@19.2.4(react@19.2.4))(react@19.2.4) class-variance-authority: - specifier: 0.7.1 + specifier: 'catalog:' version: 0.7.1 clsx: - specifier: 2.1.1 + specifier: 'catalog:' version: 2.1.1 cmdk: - specifier: 1.1.1 + specifier: 'catalog:' version: 1.1.1(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) copy-to-clipboard: - specifier: 3.3.3 + specifier: 'catalog:' version: 3.3.3 cron-parser: - specifier: 5.5.0 + specifier: 'catalog:' version: 5.5.0 dayjs: - specifier: 1.11.20 + specifier: 'catalog:' version: 1.11.20 decimal.js: - specifier: 10.6.0 + specifier: 'catalog:' version: 10.6.0 dompurify: - specifier: 3.3.3 + specifier: 'catalog:' version: 3.3.3 echarts: - specifier: 6.0.0 + specifier: 'catalog:' version: 6.0.0 echarts-for-react: - specifier: 3.0.6 + specifier: 'catalog:' version: 3.0.6(echarts@6.0.0)(react@19.2.4) elkjs: - specifier: 0.11.1 + specifier: 'catalog:' version: 0.11.1 embla-carousel-autoplay: - specifier: 8.6.0 + specifier: 'catalog:' version: 8.6.0(embla-carousel@8.6.0) embla-carousel-react: - specifier: 8.6.0 + specifier: 'catalog:' version: 8.6.0(react@19.2.4) emoji-mart: - specifier: 5.6.0 + specifier: 'catalog:' version: 5.6.0 es-toolkit: - specifier: 1.45.1 + specifier: 'catalog:' version: 1.45.1 fast-deep-equal: - specifier: 3.1.3 + specifier: 'catalog:' version: 3.1.3 foxact: - specifier: 0.3.0 + specifier: 'catalog:' version: 0.3.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4) - hono: - specifier: 4.12.8 - version: 4.12.8 html-entities: - specifier: 2.6.0 + specifier: 'catalog:' version: 2.6.0 html-to-image: - specifier: 1.11.13 + specifier: 'catalog:' version: 1.11.13 i18next: - specifier: 25.8.18 - version: 25.8.18(typescript@5.9.3) + specifier: 'catalog:' + version: 25.10.10(typescript@5.9.3) i18next-resources-to-backend: - specifier: 1.2.1 + specifier: 'catalog:' version: 1.2.1 immer: - specifier: 11.1.4 + specifier: 'catalog:' version: 11.1.4 jotai: - specifier: 2.18.1 - version: 2.18.1(@babel/core@7.29.0)(@babel/template@7.28.6)(@types/react@19.2.14)(react@19.2.4) + specifier: 'catalog:' + version: 2.19.0(@babel/core@7.29.0)(@babel/template@7.28.6)(@types/react@19.2.14)(react@19.2.4) js-audio-recorder: - specifier: 1.0.7 + specifier: 'catalog:' version: 1.0.7 js-cookie: - specifier: 3.0.5 + specifier: 'catalog:' version: 3.0.5 js-yaml: - specifier: 4.1.1 + specifier: 'catalog:' version: 4.1.1 jsonschema: - specifier: 1.5.0 + specifier: 'catalog:' version: 1.5.0 katex: - specifier: 0.16.38 - version: 0.16.38 + specifier: 'catalog:' + version: 0.16.44 ky: - specifier: 1.14.3 + specifier: 'catalog:' version: 1.14.3 lamejs: - specifier: 1.2.1 + specifier: 'catalog:' version: 1.2.1 lexical: - specifier: 0.41.0 - version: 0.41.0 + specifier: 'catalog:' + version: 0.42.0 mermaid: - specifier: 11.13.0 + specifier: 'catalog:' version: 11.13.0 mime: - specifier: 4.1.0 + specifier: 'catalog:' version: 4.1.0 mitt: - specifier: 3.0.1 + specifier: 'catalog:' version: 3.0.1 negotiator: - specifier: 1.0.0 + specifier: 'catalog:' version: 1.0.0 next: - specifier: 16.2.0 - version: 16.2.0(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0) + specifier: 'catalog:' + version: 16.2.1(@babel/core@7.29.0)(@playwright/test@1.58.2)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0) next-themes: - specifier: 0.4.6 + specifier: 'catalog:' version: 0.4.6(react-dom@19.2.4(react@19.2.4))(react@19.2.4) nuqs: - specifier: 2.8.9 - version: 2.8.9(next@16.2.0(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0))(react@19.2.4) + specifier: 'catalog:' + version: 2.8.9(next@16.2.1(@babel/core@7.29.0)(@playwright/test@1.58.2)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0))(react@19.2.4) pinyin-pro: - specifier: 3.28.0 + specifier: 'catalog:' version: 3.28.0 qrcode.react: - specifier: 4.2.0 + specifier: 'catalog:' version: 4.2.0(react@19.2.4) qs: - specifier: 6.15.0 + specifier: 'catalog:' version: 6.15.0 react: - specifier: 19.2.4 + specifier: 'catalog:' version: 19.2.4 react-18-input-autosize: - specifier: 3.0.0 + specifier: 'catalog:' version: 3.0.0(react@19.2.4) react-dom: - specifier: 19.2.4 + specifier: 'catalog:' version: 19.2.4(react@19.2.4) react-easy-crop: - specifier: 5.5.6 - version: 5.5.6(react-dom@19.2.4(react@19.2.4))(react@19.2.4) + specifier: 'catalog:' + version: 5.5.7(react-dom@19.2.4(react@19.2.4))(react@19.2.4) react-hotkeys-hook: - specifier: 5.2.4 + specifier: 'catalog:' version: 5.2.4(react-dom@19.2.4(react@19.2.4))(react@19.2.4) react-i18next: - specifier: 16.5.8 - version: 16.5.8(i18next@25.8.18(typescript@5.9.3))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3) + specifier: 'catalog:' + version: 16.6.6(i18next@25.10.10(typescript@5.9.3))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3) react-multi-email: - specifier: 1.0.25 + specifier: 'catalog:' version: 1.0.25(react-dom@19.2.4(react@19.2.4))(react@19.2.4) react-papaparse: - specifier: 4.4.0 + specifier: 'catalog:' version: 4.4.0 react-pdf-highlighter: - specifier: 8.0.0-rc.0 + specifier: 'catalog:' version: 8.0.0-rc.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4) - react-slider: - specifier: 2.0.6 - version: 2.0.6(react@19.2.4) react-sortablejs: - specifier: 6.1.4 + specifier: 'catalog:' version: 6.1.4(@types/sortablejs@1.15.9)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sortablejs@1.15.7) react-syntax-highlighter: - specifier: 15.6.6 + specifier: 'catalog:' version: 15.6.6(react@19.2.4) react-textarea-autosize: - specifier: 8.5.9 + specifier: 'catalog:' version: 8.5.9(@types/react@19.2.14)(react@19.2.4) react-window: - specifier: 1.8.11 + specifier: 'catalog:' version: 1.8.11(react-dom@19.2.4(react@19.2.4))(react@19.2.4) reactflow: - specifier: 11.11.4 + specifier: 'catalog:' version: 11.11.4(@types/react@19.2.14)(immer@11.1.4)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) remark-breaks: - specifier: 4.0.0 + specifier: 'catalog:' version: 4.0.0 remark-directive: - specifier: 4.0.0 + specifier: 'catalog:' version: 4.0.0 scheduler: - specifier: 0.27.0 + specifier: 'catalog:' version: 0.27.0 sharp: - specifier: 0.34.5 + specifier: 'catalog:' version: 0.34.5 sortablejs: - specifier: 1.15.7 + specifier: 'catalog:' version: 1.15.7 std-semver: - specifier: 1.0.8 + specifier: 'catalog:' version: 1.0.8 streamdown: - specifier: 2.5.0 + specifier: 'catalog:' version: 2.5.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4) string-ts: - specifier: 2.3.1 + specifier: 'catalog:' version: 2.3.1 tailwind-merge: - specifier: 2.6.1 + specifier: 'catalog:' version: 2.6.1 tldts: - specifier: 7.0.26 - version: 7.0.26 + specifier: 'catalog:' + version: 7.0.27 unist-util-visit: - specifier: 5.1.0 + specifier: 'catalog:' version: 5.1.0 use-context-selector: - specifier: 2.0.0 + specifier: 'catalog:' version: 2.0.0(react@19.2.4)(scheduler@0.27.0) uuid: - specifier: 13.0.0 + specifier: 'catalog:' version: 13.0.0 zod: - specifier: 4.3.6 + specifier: 'catalog:' version: 4.3.6 zundo: - specifier: 2.3.0 + specifier: 'catalog:' version: 2.3.0(zustand@5.0.12(@types/react@19.2.14)(immer@11.1.4)(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4))) zustand: - specifier: 5.0.12 + specifier: 'catalog:' version: 5.0.12(@types/react@19.2.14)(immer@11.1.4)(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)) devDependencies: '@antfu/eslint-config': - specifier: 7.7.3 - version: 7.7.3(@eslint-react/eslint-plugin@2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3))(@next/eslint-plugin-next@16.2.0)(@typescript-eslint/rule-tester@8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3))(@typescript-eslint/typescript-estree@8.57.1(typescript@5.9.3))(@typescript-eslint/utils@8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3))(@voidzero-dev/vite-plus-test@0.1.12(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.0(canvas@3.2.1))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(@vue/compiler-sfc@3.5.30)(eslint-plugin-react-hooks@7.0.1(eslint@10.0.3(jiti@1.21.7)))(eslint-plugin-react-refresh@0.5.2(eslint@10.0.3(jiti@1.21.7)))(eslint@10.0.3(jiti@1.21.7))(oxlint@1.55.0(oxlint-tsgolint@0.17.0))(typescript@5.9.3) + specifier: 'catalog:' + version: 7.7.3(@eslint-react/eslint-plugin@3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(@next/eslint-plugin-next@16.2.1)(@typescript-eslint/rule-tester@8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(@typescript-eslint/typescript-estree@8.57.2(typescript@5.9.3))(@typescript-eslint/utils@8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(@voidzero-dev/vite-plus-test@0.1.14(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(@vue/compiler-sfc@3.5.31)(eslint-plugin-react-hooks@7.0.1(eslint@10.1.0(jiti@1.21.7)))(eslint-plugin-react-refresh@0.5.2(eslint@10.1.0(jiti@1.21.7)))(eslint@10.1.0(jiti@1.21.7))(oxlint@1.57.0(oxlint-tsgolint@0.17.3))(typescript@5.9.3) '@chromatic-com/storybook': - specifier: 5.0.1 - version: 5.0.1(storybook@10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)) + specifier: 'catalog:' + version: 5.1.1(storybook@10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)) '@egoist/tailwindcss-icons': - specifier: 1.9.2 - version: 1.9.2(tailwindcss@3.4.19(tsx@4.21.0)(yaml@2.8.2)) + specifier: 'catalog:' + version: 1.9.2(tailwindcss@3.4.19(tsx@4.21.0)(yaml@2.8.3)) '@eslint-react/eslint-plugin': - specifier: 2.13.0 - version: 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) + specifier: 'catalog:' + version: 3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@hono/node-server': + specifier: 'catalog:' + version: 1.19.11(hono@4.12.9) '@iconify-json/heroicons': - specifier: 1.2.3 + specifier: 'catalog:' version: 1.2.3 '@iconify-json/ri': - specifier: 1.2.10 + specifier: 'catalog:' version: 1.2.10 '@mdx-js/loader': - specifier: 3.1.1 + specifier: 'catalog:' version: 3.1.1(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) '@mdx-js/react': - specifier: 3.1.1 + specifier: 'catalog:' version: 3.1.1(@types/react@19.2.14)(react@19.2.4) '@mdx-js/rollup': - specifier: 3.1.1 + specifier: 'catalog:' version: 3.1.1(rollup@4.59.0) '@next/eslint-plugin-next': - specifier: 16.2.0 - version: 16.2.0 + specifier: 'catalog:' + version: 16.2.1 '@next/mdx': - specifier: 16.2.0 - version: 16.2.0(@mdx-js/loader@3.1.1(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)))(@mdx-js/react@3.1.1(@types/react@19.2.14)(react@19.2.4)) + specifier: 'catalog:' + version: 16.2.1(@mdx-js/loader@3.1.1(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)))(@mdx-js/react@3.1.1(@types/react@19.2.14)(react@19.2.4)) '@rgrove/parse-xml': - specifier: 4.2.0 + specifier: 'catalog:' version: 4.2.0 '@storybook/addon-docs': - specifier: 10.3.0 - version: 10.3.0(@types/react@19.2.14)(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(rollup@4.59.0)(storybook@10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) + specifier: 'catalog:' + version: 10.3.3(@types/react@19.2.14)(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(rollup@4.59.0)(storybook@10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) '@storybook/addon-links': - specifier: 10.3.0 - version: 10.3.0(react@19.2.4)(storybook@10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)) + specifier: 'catalog:' + version: 10.3.3(react@19.2.4)(storybook@10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)) '@storybook/addon-onboarding': - specifier: 10.3.0 - version: 10.3.0(storybook@10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)) + specifier: 'catalog:' + version: 10.3.3(storybook@10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)) '@storybook/addon-themes': - specifier: 10.3.0 - version: 10.3.0(storybook@10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)) + specifier: 'catalog:' + version: 10.3.3(storybook@10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)) '@storybook/nextjs-vite': - specifier: 10.3.0 - version: 10.3.0(@babel/core@7.29.0)(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(next@16.2.0(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(rollup@4.59.0)(storybook@10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) + specifier: 'catalog:' + version: 10.3.3(@babel/core@7.29.0)(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(next@16.2.1(@babel/core@7.29.0)(@playwright/test@1.58.2)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(rollup@4.59.0)(storybook@10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) '@storybook/react': - specifier: 10.3.0 - version: 10.3.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(storybook@10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3) + specifier: 'catalog:' + version: 10.3.3(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(storybook@10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3) '@tanstack/eslint-plugin-query': - specifier: 5.91.5 - version: 5.91.5(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) + specifier: 'catalog:' + version: 5.95.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) '@tanstack/react-devtools': - specifier: 0.10.0 + specifier: 'catalog:' version: 0.10.0(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(csstype@3.2.3)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(solid-js@1.9.11) '@tanstack/react-form-devtools': - specifier: 0.2.19 + specifier: 'catalog:' version: 0.2.19(@types/react@19.2.14)(csstype@3.2.3)(react@19.2.4)(solid-js@1.9.11) '@tanstack/react-query-devtools': - specifier: 5.91.3 - version: 5.91.3(@tanstack/react-query@5.91.0(react@19.2.4))(react@19.2.4) + specifier: 'catalog:' + version: 5.95.2(@tanstack/react-query@5.95.2(react@19.2.4))(react@19.2.4) '@testing-library/dom': - specifier: 10.4.1 + specifier: 'catalog:' version: 10.4.1 '@testing-library/jest-dom': - specifier: 6.9.1 + specifier: 'catalog:' version: 6.9.1 '@testing-library/react': - specifier: 16.3.2 + specifier: 'catalog:' version: 16.3.2(@testing-library/dom@10.4.1)(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) '@testing-library/user-event': - specifier: 14.6.1 + specifier: 'catalog:' version: 14.6.1(@testing-library/dom@10.4.1) '@tsslint/cli': - specifier: 3.0.2 + specifier: 'catalog:' version: 3.0.2(@tsslint/compat-eslint@3.0.2(jiti@1.21.7)(typescript@5.9.3))(typescript@5.9.3) '@tsslint/compat-eslint': - specifier: 3.0.2 + specifier: 'catalog:' version: 3.0.2(jiti@1.21.7)(typescript@5.9.3) '@tsslint/config': - specifier: 3.0.2 + specifier: 'catalog:' version: 3.0.2(@tsslint/compat-eslint@3.0.2(jiti@1.21.7)(typescript@5.9.3))(typescript@5.9.3) '@types/js-cookie': - specifier: 3.0.6 + specifier: 'catalog:' version: 3.0.6 '@types/js-yaml': - specifier: 4.0.9 + specifier: 'catalog:' version: 4.0.9 '@types/negotiator': - specifier: 0.6.4 + specifier: 'catalog:' version: 0.6.4 '@types/node': - specifier: 25.5.0 + specifier: 'catalog:' version: 25.5.0 '@types/postcss-js': - specifier: 4.1.0 + specifier: 'catalog:' version: 4.1.0 '@types/qs': - specifier: 6.15.0 + specifier: 'catalog:' version: 6.15.0 '@types/react': - specifier: 19.2.14 + specifier: 'catalog:' version: 19.2.14 '@types/react-dom': - specifier: 19.2.3 + specifier: 'catalog:' version: 19.2.3(@types/react@19.2.14) - '@types/react-slider': - specifier: 1.3.6 - version: 1.3.6 '@types/react-syntax-highlighter': - specifier: 15.5.13 + specifier: 'catalog:' version: 15.5.13 '@types/react-window': - specifier: 1.8.8 + specifier: 'catalog:' version: 1.8.8 '@types/sortablejs': - specifier: 1.15.9 + specifier: 'catalog:' version: 1.15.9 '@typescript-eslint/parser': - specifier: 8.57.1 - version: 8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) + specifier: 'catalog:' + version: 8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) '@typescript/native-preview': - specifier: 7.0.0-dev.20260318.1 - version: 7.0.0-dev.20260318.1 + specifier: 'catalog:' + version: 7.0.0-dev.20260329.1 '@vitejs/plugin-react': - specifier: 6.0.1 - version: 6.0.1(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)) + specifier: 'catalog:' + version: 6.0.1(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)) '@vitejs/plugin-rsc': - specifier: 0.5.21 - version: 0.5.21(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(react-dom@19.2.4(react@19.2.4))(react-server-dom-webpack@19.2.4(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)))(react@19.2.4) + specifier: 'catalog:' + version: 0.5.21(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(react-dom@19.2.4(react@19.2.4))(react-server-dom-webpack@19.2.4(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)))(react@19.2.4) '@vitest/coverage-v8': - specifier: 4.1.0 - version: 4.1.0(@voidzero-dev/vite-plus-test@0.1.12(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.0(canvas@3.2.1))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)) + specifier: 'catalog:' + version: 4.1.2(@voidzero-dev/vite-plus-test@0.1.14(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)) agentation: - specifier: 2.3.3 - version: 2.3.3(react-dom@19.2.4(react@19.2.4))(react@19.2.4) + specifier: 'catalog:' + version: 3.0.2(react-dom@19.2.4(react@19.2.4))(react@19.2.4) autoprefixer: - specifier: 10.4.27 + specifier: 'catalog:' version: 10.4.27(postcss@8.5.8) code-inspector-plugin: - specifier: 1.4.4 - version: 1.4.4 + specifier: 'catalog:' + version: 1.4.5 eslint: - specifier: 10.0.3 - version: 10.0.3(jiti@1.21.7) + specifier: 'catalog:' + version: 10.1.0(jiti@1.21.7) + eslint-markdown: + specifier: 'catalog:' + version: 0.6.0(eslint@10.1.0(jiti@1.21.7)) eslint-plugin-better-tailwindcss: - specifier: 4.3.2 - version: 4.3.2(eslint@10.0.3(jiti@1.21.7))(oxlint@1.55.0(oxlint-tsgolint@0.17.0))(tailwindcss@3.4.19(tsx@4.21.0)(yaml@2.8.2))(typescript@5.9.3) + specifier: 'catalog:' + version: 4.3.2(eslint@10.1.0(jiti@1.21.7))(oxlint@1.57.0(oxlint-tsgolint@0.17.3))(tailwindcss@3.4.19(tsx@4.21.0)(yaml@2.8.3))(typescript@5.9.3) eslint-plugin-hyoban: - specifier: 0.14.1 - version: 0.14.1(eslint@10.0.3(jiti@1.21.7)) + specifier: 'catalog:' + version: 0.14.1(eslint@10.1.0(jiti@1.21.7)) + eslint-plugin-markdown-preferences: + specifier: 'catalog:' + version: 0.40.3(@eslint/markdown@7.5.1)(eslint@10.1.0(jiti@1.21.7)) + eslint-plugin-no-barrel-files: + specifier: 'catalog:' + version: 1.2.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) eslint-plugin-react-hooks: - specifier: 7.0.1 - version: 7.0.1(eslint@10.0.3(jiti@1.21.7)) + specifier: 'catalog:' + version: 7.0.1(eslint@10.1.0(jiti@1.21.7)) eslint-plugin-react-refresh: - specifier: 0.5.2 - version: 0.5.2(eslint@10.0.3(jiti@1.21.7)) + specifier: 'catalog:' + version: 0.5.2(eslint@10.1.0(jiti@1.21.7)) eslint-plugin-sonarjs: - specifier: 4.0.2 - version: 4.0.2(eslint@10.0.3(jiti@1.21.7)) + specifier: 'catalog:' + version: 4.0.2(eslint@10.1.0(jiti@1.21.7)) eslint-plugin-storybook: - specifier: 10.3.0 - version: 10.3.0(eslint@10.0.3(jiti@1.21.7))(storybook@10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3) + specifier: 'catalog:' + version: 10.3.3(eslint@10.1.0(jiti@1.21.7))(storybook@10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3) + happy-dom: + specifier: 'catalog:' + version: 20.8.9 + hono: + specifier: 'catalog:' + version: 4.12.9 husky: - specifier: 9.1.7 + specifier: 'catalog:' version: 9.1.7 iconify-import-svg: - specifier: 0.1.2 + specifier: 'catalog:' version: 0.1.2 - jsdom: - specifier: 29.0.0 - version: 29.0.0(canvas@3.2.1) - jsdom-testing-mocks: - specifier: 1.16.0 - version: 1.16.0 knip: - specifier: 5.88.0 - version: 5.88.0(@types/node@25.5.0)(typescript@5.9.3) + specifier: 'catalog:' + version: 6.1.0(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1) lint-staged: - specifier: 16.4.0 + specifier: 'catalog:' version: 16.4.0 - nock: - specifier: 14.0.11 - version: 14.0.11 postcss: - specifier: 8.5.8 + specifier: 'catalog:' version: 8.5.8 postcss-js: - specifier: 5.1.0 + specifier: 'catalog:' version: 5.1.0(postcss@8.5.8) react-server-dom-webpack: - specifier: 19.2.4 + specifier: 'catalog:' version: 19.2.4(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) sass: - specifier: 1.98.0 + specifier: 'catalog:' version: 1.98.0 storybook: - specifier: 10.3.0 - version: 10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) + specifier: 'catalog:' + version: 10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) tailwindcss: - specifier: 3.4.19 - version: 3.4.19(tsx@4.21.0)(yaml@2.8.2) - taze: - specifier: 19.10.0 - version: 19.10.0 + specifier: 'catalog:' + version: 3.4.19(tsx@4.21.0)(yaml@2.8.3) tsx: - specifier: 4.21.0 + specifier: 'catalog:' version: 4.21.0 typescript: - specifier: 5.9.3 + specifier: 'catalog:' version: 5.9.3 uglify-js: - specifier: 3.19.3 + specifier: 'catalog:' version: 3.19.3 vinext: - specifier: 0.0.31 - version: 0.0.31(d43efe4756ad5ea698dcdb002ea787ea) + specifier: 'catalog:' + version: 0.0.38(f5786d681f520e26604259e094ebaa46) vite: - specifier: npm:@voidzero-dev/vite-plus-core@0.1.12 - version: '@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' + specifier: npm:@voidzero-dev/vite-plus-core@0.1.14 + version: '@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' vite-plugin-inspect: - specifier: 11.3.3 - version: 11.3.3(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)) + specifier: 'catalog:' + version: 12.0.0-beta.1(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(typescript@5.9.3)(ws@8.20.0) vite-plus: - specifier: 0.1.12 - version: 0.1.12(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.0(canvas@3.2.1))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2) + specifier: 'catalog:' + version: 0.1.14(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3) vitest: - specifier: npm:@voidzero-dev/vite-plus-test@0.1.12 - version: '@voidzero-dev/vite-plus-test@0.1.12(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.0(canvas@3.2.1))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' + specifier: npm:@voidzero-dev/vite-plus-test@0.1.14 + version: '@voidzero-dev/vite-plus-test@0.1.14(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' vitest-canvas-mock: - specifier: 1.1.3 - version: 1.1.3(@voidzero-dev/vite-plus-test@0.1.12(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.0(canvas@3.2.1))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)) + specifier: 'catalog:' + version: 1.1.4(@voidzero-dev/vite-plus-test@0.1.14(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)) packages: @@ -627,17 +1239,17 @@ packages: resolution: {integrity: sha512-UrcABB+4bUrFABwbluTIBErXwvbsU/V7TZWfmbgJfbkwiBuziS9gxdODUyuiecfdGQ85jglMW6juS3+z5TsKLw==} engines: {node: '>=10'} - '@amplitude/analytics-browser@2.36.7': - resolution: {integrity: sha512-aqEakThBQI+nEV/ytMqyUhHvSisjqKv9g2hpMA8sQBy3MYWzATWCr63gVyml7U56QwoDvCIomt80HENBSsYVqg==} + '@amplitude/analytics-browser@2.38.0': + resolution: {integrity: sha512-MhqyEkr1gGAR4s4GSSflDhFVheIx9Nv3FfElQu9NlNrXB2Hh3BEOyVgdK7hgfi6NJwFyfw30+t5lym+njtA8hA==} - '@amplitude/analytics-client-common@2.4.37': - resolution: {integrity: sha512-mTJY7LXPdOPjUe3wTgSq9J/RX9+gsNpcKWQ3VUDSOgCSgSe5NW/2WopmHBbt8FLZN29OnrApz+WGyEwvMQt/NQ==} + '@amplitude/analytics-client-common@2.4.41': + resolution: {integrity: sha512-+GbvtvhsUROotPBwfAxbrqovKePhC0oQKXtxjbeNQleOHjBjsAs5jEOCHpJenCKtaRpucg/FuK3NVOS09MfW7Q==} '@amplitude/analytics-connector@1.6.4': resolution: {integrity: sha512-SpIv0IQMNIq6SH3UqFGiaZyGSc7PBZwRdq7lvP0pBxW8i4Ny+8zwI0pV+VMfMHQwWY3wdIbWw5WQphNjpdq1/Q==} - '@amplitude/analytics-core@2.41.7': - resolution: {integrity: sha512-6vb7kX/k64A9GzHtxLvm/PJf1kDgYRFxKOqSbKXi0z2N2OVfrrbPD6uFve8lLdT0iVuSGo3HVGG7V1bQ8rIfiQ==} + '@amplitude/analytics-core@2.44.0': + resolution: {integrity: sha512-z9QuTxLqEQ8KIeAT6Vmy6K48rP9TUmjnb4GwUMYoV/fxu3B9ClTaN18zqXQMmDw9HwUiIreHiVbwTb7OQRN5aA==} '@amplitude/analytics-types@2.11.1': resolution: {integrity: sha512-wFEgb0t99ly2uJKm5oZ28Lti0Kh5RecR5XBkwfUpDzn84IoCIZ8GJTsMw/nThu8FZFc7xFDA4UAt76zhZKrs9A==} @@ -645,52 +1257,61 @@ packages: '@amplitude/experiment-core@0.7.2': resolution: {integrity: sha512-Wc2NWvgQ+bLJLeF0A9wBSPIaw0XuqqgkPKsoNFQrmS7r5Djd56um75In05tqmVntPJZRvGKU46pAp8o5tdf4mA==} - '@amplitude/plugin-autocapture-browser@1.23.7': - resolution: {integrity: sha512-dTUpJEUNbHy9pXpBm/UmNk4wWBcCd14MwpZYFLsJNZzaDT5Iyo9MhE946TEoQG9LwE/wAQHiVUk3n0bIGaCzEQ==} + '@amplitude/plugin-autocapture-browser@1.25.0': + resolution: {integrity: sha512-YuWsz8XmJuKu3NlMxkvlhLey/5tGCeOwwfsROHficR0yDWO9gNG0WtHl7A0Pw1PUc9iaXjqfG2AjYumAtiq16Q==} - '@amplitude/plugin-network-capture-browser@1.9.7': - resolution: {integrity: sha512-HLHVlb2G9p7HABJvmJRigCO/h06oD6F9AqB47v0it671dstuytOtdSYP5ZBfFCEEshyd4gZzT4Qk5dd3foxtqw==} + '@amplitude/plugin-custom-enrichment-browser@0.1.2': + resolution: {integrity: sha512-ZX9BKqs1E1OI7l7QCGu9JnB/1kqLN+zqIePgM2tuEhZNFQJaw4NhAMUaMRqvNnaCkHlmpVRISzSj/4D3tWMRtA==} - '@amplitude/plugin-page-url-enrichment-browser@0.6.11': - resolution: {integrity: sha512-aRQb2GkW4g4X+Yyb4R5DXaBTUGq0NKIBOEQHL0ywsWMoNY/k3S1SHN0iqIjDLlIlOC23ZyhrBMPnriGiUPGZpA==} + '@amplitude/plugin-network-capture-browser@1.9.11': + resolution: {integrity: sha512-49o3zYnKUmRdrxgAEcr1iHnXR1um40e1icO0hzugSq04k19hs27zcl3zpEk9geO+nNKwO744ryE1q93gqVbHrQ==} - '@amplitude/plugin-page-view-tracking-browser@2.8.7': - resolution: {integrity: sha512-imsBOuSdeYu+CMy/RJl3uVL3NzJGf8IORecaCkZoCleeQWt7il8cAmtL5xO0EVPZaOWifZ/juVL+DUdBxpnJrw==} + '@amplitude/plugin-page-url-enrichment-browser@0.7.3': + resolution: {integrity: sha512-3UZq/zKg4lcsRgziWAPSEeaUsNsbyjjxmsAE9kSDi/hIj5RaWnwWhY6TGhv45UAReugTA4vVZyFRg9btf3c/Fg==} - '@amplitude/plugin-session-replay-browser@1.26.4': - resolution: {integrity: sha512-eJ783UPWvZtf2ThWs0pONZaHY/KtMPjMWIF48YQjkI2Z8e30qJ1kdE0++bNXJ0jchrOUs2UCl2wsWiPlaR0tAQ==} + '@amplitude/plugin-page-view-tracking-browser@2.9.4': + resolution: {integrity: sha512-J16zmEadnzNpkHSmzpTiQN2q9pGJ/4SkHONA9O8KxUsMU/MYTDgof3rAYY/w5B5rmvdxfMRCjqWtvnkizzgZ6w==} - '@amplitude/plugin-web-vitals-browser@1.1.22': - resolution: {integrity: sha512-DjjkWvxUYfR/axvxCcXJQOUXpSfd3nF6kg+a63nJo2pf6EGkSHhWkXz9p7O0/IqWi5P+i6pDWO2m8031+OnG+A==} + '@amplitude/plugin-session-replay-browser@1.27.5': + resolution: {integrity: sha512-tf0Ty1nNF8OJ5QQ5scEqdGfzdgIaqkRf2MSzQfHbGcTIoYuVmAKuCgn3yMLk62MKnwgG3IsTIugMdRRv7l85PA==} - '@amplitude/rrdom@2.0.0-alpha.35': - resolution: {integrity: sha512-W9ImCKtgFB8oBKd7td0TH7JKkQ/3iwu5bfLXcOvzxLj7+RSD1k1gfDyncooyobwBV8j4FMiTyj2N53tJ6rFgaw==} + '@amplitude/plugin-web-vitals-browser@1.1.26': + resolution: {integrity: sha512-wiD4vy+f2fepr+8Lnn26TYYjDEnWsmlGhJog99x+xfbZ/D+stGdaCIOz5AOjU1TpTRvxvamEu2XuOh+8EZOCSA==} - '@amplitude/rrweb-packer@2.0.0-alpha.35': - resolution: {integrity: sha512-A6BlcBuiAI8pHJ51mcQWu2Uddnddxj9MaYZMNjIzFm1FK+qYAyYafO1xcoVPXoMUHE/qqITUgAn9tUVWj8N8NQ==} + '@amplitude/rrdom@2.0.0-alpha.37': + resolution: {integrity: sha512-u4dSnBtlbJ8oU5P/Ywl2RLqvjqWbkl4ScMUbvQA7in4pWcx+0NRN+VVjLZXQcd8Fn7E/rcxjeUh7e7HfwvdasQ==} - '@amplitude/rrweb-plugin-console-record@2.0.0-alpha.35': - resolution: {integrity: sha512-8hstBoMHMSEA3FGoQ0LKidhpQypKchyT2sjEDdwTC77xZSg+3LwtjElOSMVdgjrEfxvN4V1g72v+Pwy7LBGUDA==} + '@amplitude/rrweb-packer@2.0.0-alpha.36': + resolution: {integrity: sha512-kqKg6OGoxHZvG4jwyO4kIjLdf8MkL6JcY5iLB09PQNP7O36ysnrH+ecJfa4V1Rld99kX25Pefkw4bzKmmFAqcg==} + + '@amplitude/rrweb-plugin-console-record@2.0.0-alpha.36': + resolution: {integrity: sha512-7VbXu36PpJA8dSOFxpfpMaoDTuPK5uy1C8mN+Wfdm0X4ROdmrvcTdlQj+jGzhLGeK+xbTixHEy23itCNUau7hQ==} peerDependencies: - '@amplitude/rrweb': ^2.0.0-alpha.35 + '@amplitude/rrweb': ^2.0.0-alpha.36 - '@amplitude/rrweb-record@2.0.0-alpha.35': - resolution: {integrity: sha512-C8lr6LLMXLDINWE3SaebDrc4sj1pSFKm9s+zlW5e8CkAuAv8XfA5Wjx5cevxG3LMkIwXdugvrrjYKmEVCODI1g==} + '@amplitude/rrweb-record@2.0.0-alpha.36': + resolution: {integrity: sha512-zSHvmG5NUG4jNgWNVM7Oj3+rJPagv+TiHlnSiJ1X0WWLIg1GbUnOoTqpincZS5QupqTxQchNQaUg9MNu0MM3sQ==} - '@amplitude/rrweb-snapshot@2.0.0-alpha.35': - resolution: {integrity: sha512-n55AdmlRNZ7XuOlCRmSjH2kyyHS1oe5haUS+buxqjfQcamUtam+dSnP+6N1E8dLxIDjynJnbrCOC+8xvenpl1A==} + '@amplitude/rrweb-snapshot@2.0.0-alpha.37': + resolution: {integrity: sha512-OPW2r8ESAguq+1R+z+WxGyzZzkMtojZ49Lpp6NrataNFyjdKaNXehDuLoNlEQkkUZGyDBiA7RSYvUw+JPSmmSQ==} - '@amplitude/rrweb-types@2.0.0-alpha.35': - resolution: {integrity: sha512-cR/xlN5fu7Cw6Zh9O6iEgNleqT92wJ3HO2mV19yQE6SRqLGKXXeDeTrUBd5FKCZnXvRsv3JtK+VR4u9vmZze3g==} + '@amplitude/rrweb-types@2.0.0-alpha.36': + resolution: {integrity: sha512-Bd2r3Bs0XIJt5fgPRWVl8bhvA9FCjJn8vQlDTO8ffPxilGPIzUXLQ06+xoLYkK9v+PDKJnCapOTL4A2LilDmgA==} - '@amplitude/rrweb-utils@2.0.0-alpha.35': - resolution: {integrity: sha512-/OpyKKHYGwoy2fvWDg5jiH1LzWag4wlFTQjd2DUgndxlXccQF1+yxYljCDdM+J1GBeZ7DaLZa9qe2JUUtoNOOw==} + '@amplitude/rrweb-types@2.0.0-alpha.37': + resolution: {integrity: sha512-LW9wQ85umaAW/qlemTrUC408WVoBx99hvFCjsNRnxAyUmRemWyYY7+o8xPyeUexoWGqizPMkkNnPEO8t1NFjtw==} - '@amplitude/rrweb@2.0.0-alpha.35': - resolution: {integrity: sha512-qFaZDNMkjolZUVv1OxrWngGl38FH0iF0jtybd/vhuOzvwohJjyKL9Tgoulj8osj21/4BUpGEhWweGeJygjoJJw==} + '@amplitude/rrweb-utils@2.0.0-alpha.36': + resolution: {integrity: sha512-w5RGROLU1Kyrq9j+trxcvvfkTp05MEKJ70Ig+YvHyZsE0nElh1PCF8PHAjV0/kji68+KqB03c0hoyaV99CDaDw==} - '@amplitude/session-replay-browser@1.33.1': - resolution: {integrity: sha512-5Mjd5rWq9VxVvDewH+l7m22fJhnBsTHYGXI0GxZTMZZGKP7PJm77O9oDaar+2cCCr8ckk7RnIfelgWRS1lmbDA==} + '@amplitude/rrweb-utils@2.0.0-alpha.37': + resolution: {integrity: sha512-40YvPj24ietFQ3BTLfvFRPriRqdNOp3DzGiPU+WDOZkI3KjInQrEsibaqNBSXzJ+kMWrm8/eEwcQ0FkLk7Achw==} + + '@amplitude/rrweb@2.0.0-alpha.37': + resolution: {integrity: sha512-jJkSpPYiVgOZB422pb2jOJJn3pvb5E5f9vKK8CEmUlk2mVAl6kPQzW98mb05M65OJFj5nn9tSe9h5r5+Cl93ag==} + + '@amplitude/session-replay-browser@1.35.0': + resolution: {integrity: sha512-aGqu807oC8UIMmP+g1jBYsgN+/VeR/ThtK6fpxuZCugEogx7EZ9sXDEeudUmyvkQQfWmD+nLmrhYPX8FpROT5w==} '@amplitude/targeting@0.2.0': resolution: {integrity: sha512-/50ywTrC4hfcfJVBbh5DFbqMPPfaIOivZeb5Gb+OGM03QrA+lsUqdvtnKLNuWtceD4H6QQ2KFzPJ5aAJLyzVDA==} @@ -770,17 +1391,6 @@ packages: '@antfu/utils@8.1.1': resolution: {integrity: sha512-Mex9nXf9vR6AhcXmMrlz/HVgYYZpVGJ6YlPgwl7UnaFpnshXs6EK/oa5Gpf3CzENMjkvEx2tQtntGnb7UtSTOQ==} - '@asamuzakjp/css-color@5.0.1': - resolution: {integrity: sha512-2SZFvqMyvboVV1d15lMf7XiI3m7SDqXUuKaTymJYLN6dSGadqp+fVojqJlVoMlbZnlTmu3S0TLwLTJpvBMO1Aw==} - engines: {node: ^20.19.0 || ^22.12.0 || >=24.0.0} - - '@asamuzakjp/dom-selector@7.0.3': - resolution: {integrity: sha512-Q6mU0Z6bfj6YvnX2k9n0JxiIwrCFN59x/nWmYQnAqP000ruX/yV+5bp/GRcF5T8ncvfwJQ7fgfP74DlpKExILA==} - engines: {node: ^20.19.0 || ^22.12.0 || >=24.0.0} - - '@asamuzakjp/nwsapi@2.3.9': - resolution: {integrity: sha512-n8GuYSrI9bF7FFZ/SjhwevlHc8xaVlb/7HmHelnc/PZXBD2ZR49NnN9sMMuDdEGPeeRQ5d0hqlSlEpgCX3Wl0Q==} - '@babel/code-frame@7.29.0': resolution: {integrity: sha512-9NhCeYjq9+3uxgdtp20LSiJXJvN0FeCtNGpJxuMFZ1Kv3cWUNb6DOhJwUvcVCzKGR66cw4njwM6hrJLqgOwbcw==} engines: {node: '>=6.9.0'} @@ -880,10 +1490,6 @@ packages: '@braintree/sanitize-url@7.1.2': resolution: {integrity: sha512-jigsZK+sMF/cuiB7sERuo9V7N9jx+dhmHHnQyDSVdpZwVutaBu7WvNYqMDLSgFgfB30n452TP3vjDAvFC973mA==} - '@bramus/specificity@2.4.2': - resolution: {integrity: sha512-ctxtJ/eA+t+6q2++vj5j7FYX3nRu311q1wfYH3xjlLOsczhlhxAg2FWNUXhpGvAw3BWo1xBcvOV6/YLc2r5FJw==} - hasBin: true - '@chevrotain/cst-dts-gen@11.1.2': resolution: {integrity: sha512-XTsjvDVB5nDZBQB8o0o/0ozNelQtn2KrUVteIHSlPd2VAV2utEb6JzyCJaJ8tGxACR4RiBNWy5uYUHX2eji88Q==} @@ -899,11 +1505,11 @@ packages: '@chevrotain/utils@11.1.2': resolution: {integrity: sha512-4mudFAQ6H+MqBTfqLmU7G1ZwRzCLfJEooL/fsF6rCX5eePMbGhoy5n4g+G4vlh2muDcsCTJtL+uKbOzWxs5LHA==} - '@chromatic-com/storybook@5.0.1': - resolution: {integrity: sha512-v80QBwVd8W6acH5NtDgFlUevIBaMZAh1pYpBiB40tuNzS242NTHeQHBDGYwIAbWKDnt1qfjJpcpL6pj5kAr4LA==} + '@chromatic-com/storybook@5.1.1': + resolution: {integrity: sha512-BPoAXHM71XgeCK2u0jKr9i8apeQMm/Z9IWGyndA2FMijfQG9m8ox45DdWh/pxFkK5ClhGgirv5QwMhFIeHmThg==} engines: {node: '>=20.0.0', yarn: '>=1.22.18'} peerDependencies: - storybook: ^0.0.0-0 || ^10.1.0 || ^10.1.0-0 || ^10.2.0-0 || ^10.3.0-0 + storybook: ^0.0.0-0 || ^10.1.0 || ^10.1.0-0 || ^10.2.0-0 || ^10.3.0-0 || ^10.4.0-0 '@clack/core@0.3.5': resolution: {integrity: sha512-5cfhQNH+1VQ2xLQlmzXMqUoiaH0lRBq9/CLW9lTyMbuKLC3+xEK01tHVvyut++mLOn5urSHmkm6I0Lg9MaJSTQ==} @@ -917,59 +1523,85 @@ packages: '@clack/prompts@1.1.0': resolution: {integrity: sha512-pkqbPGtohJAvm4Dphs2M8xE29ggupihHdy1x84HNojZuMtFsHiUlRvqD24tM2+XmI+61LlfNceM3Wr7U5QES5g==} - '@code-inspector/core@1.4.4': - resolution: {integrity: sha512-bQNcbiiTodOiVuJ9JQ/AgyArfc5rH9qexzDya3ugasIbUMfUNBPKCwoq6He4Y6/bwUx6mUqwTODwPtu13BR75Q==} + '@code-inspector/core@1.4.5': + resolution: {integrity: sha512-wskkSRX13TAqJG65d5sq0bRZ4kYktas/iE70xqXMOeqW/A6n2Zqhw5QRHANmEmlBvB9bP/bse+9iBkNN3Q2Skw==} - '@code-inspector/esbuild@1.4.4': - resolution: {integrity: sha512-quGKHsPiFRIPMGOhtHhSQhqDAdvC5aGvKKk4EAhvNvZG1TGxt0nXu99+O0shHdl6TQhlq1NgmPyTWqGyVM5s6g==} + '@code-inspector/esbuild@1.4.5': + resolution: {integrity: sha512-KBwq7waqZ3L1CW7N9ff7aS0HxzamrslR08i5ovkLQe1p6tH9Axe9zzCrBnvgmB0UZsT2r/5wKLOWyEpq5+VYKw==} - '@code-inspector/mako@1.4.4': - resolution: {integrity: sha512-SSs9oo3THS7vAFceAcICvVbbmaU9z6omwiXbCjIGhCxMvm7T6s/au4VHuOyU8Z3+floz+lDg/6W72VdBxWwVSg==} + '@code-inspector/mako@1.4.5': + resolution: {integrity: sha512-yrHgE5+b4ZL29Xt+y0H/9xrXSbRskq7dFhmE9GYFWCcgdWNCMD25hZd7xZVije94++H65Vw6Bu/abfqEx0peog==} - '@code-inspector/turbopack@1.4.4': - resolution: {integrity: sha512-ZK/sHPB4A+qcHXg+sR+0qCSFA2CYTfuPXaHC9GdnwwNdz6lhO3bkG7Ju0csKVxEp3LR8UVfMsKsRYbGSs8Ly8w==} + '@code-inspector/turbopack@1.4.5': + resolution: {integrity: sha512-IG39ikmQthdx/oAxhpV7zsIQZ3Jpycl88JzH+UXHq0ZpfHwa1KdNc/9erP3kFMY4+ANmkmerqBk57knmRTGMRQ==} - '@code-inspector/vite@1.4.4': - resolution: {integrity: sha512-UWnkaRTHwUDezKp1vXUrjr8Q93s91iYHbsyhfjOJGIiqBvmcaa3nqBlEAt7rzEi5hdaQVVeFdh+9q+4cVpK26A==} + '@code-inspector/vite@1.4.5': + resolution: {integrity: sha512-vBtH91afwYL7JV4zWcJJTFd65LJ4SZz5E9AwGgCF30/L1mdDx7U29D+M+JpaxSgsMB6monKSZh+ubbqYe0ixpQ==} - '@code-inspector/webpack@1.4.4': - resolution: {integrity: sha512-icYvkENomjUhlBXhYwkDFMtk62BPEWJCNsfYyHnQlGNJWW8SKuLU3AAbJQJMvA6Nmp++r9D/8xj1OJ2K1Y+/Dg==} + '@code-inspector/webpack@1.4.5': + resolution: {integrity: sha512-lwUv+X1FNSUWz+FKcUsE2dT2pg6VFRRXKt16hg/m+Lwtdet2adfi6BFLZmNz3OPIEGbRB5Kjx6bfaghZhbDCCg==} - '@csstools/color-helpers@6.0.2': - resolution: {integrity: sha512-LMGQLS9EuADloEFkcTBR3BwV/CGHV7zyDxVRtVDTwdI2Ca4it0CCVTT9wCkxSgokjE5Ho41hEPgb8OEUwoXr6Q==} - engines: {node: '>=20.19.0'} + '@colors/colors@1.5.0': + resolution: {integrity: sha512-ooWCrlZP11i8GImSjTHYHLkvFDP48nS4+204nGb1RiX/WXYHmJA2III9/e2DWVabCESdW7hBAEzHRqUn9OUVvQ==} + engines: {node: '>=0.1.90'} - '@csstools/css-calc@3.1.1': - resolution: {integrity: sha512-HJ26Z/vmsZQqs/o3a6bgKslXGFAungXGbinULZO3eMsOyNJHeBBZfup5FiZInOghgoM4Hwnmw+OgbJCNg1wwUQ==} - engines: {node: '>=20.19.0'} + '@cucumber/ci-environment@13.0.0': + resolution: {integrity: sha512-cs+3NzfNkGbcmHPddjEv4TKFiBpZRQ6WJEEufB9mw+ExS22V/4R/zpDSEG+fsJ/iSNCd6A2sATdY8PFOyY3YnA==} + + '@cucumber/cucumber-expressions@19.0.0': + resolution: {integrity: sha512-4FKoOQh2Uf6F6/Ln+1OxuK8LkTg6PyAqekhf2Ix8zqV2M54sH+m7XNJNLhOFOAW/t9nxzRbw2CcvXbCLjcvHZg==} + + '@cucumber/cucumber@12.7.0': + resolution: {integrity: sha512-7A/9CJpJDxv1SQ7hAZU0zPn2yRxx6XMR+LO4T94Enm3cYNWsEEj+RGX38NLX4INT+H6w5raX3Csb/qs4vUBsOA==} + engines: {node: 20 || 22 || >=24} + hasBin: true + + '@cucumber/gherkin-streams@6.0.0': + resolution: {integrity: sha512-HLSHMmdDH0vCr7vsVEURcDA4WwnRLdjkhqr6a4HQ3i4RFK1wiDGPjBGVdGJLyuXuRdJpJbFc6QxHvT8pU4t6jw==} + hasBin: true peerDependencies: - '@csstools/css-parser-algorithms': ^4.0.0 - '@csstools/css-tokenizer': ^4.0.0 + '@cucumber/gherkin': '>=22.0.0' + '@cucumber/message-streams': '>=4.0.0' + '@cucumber/messages': '>=17.1.1' - '@csstools/css-color-parser@4.0.2': - resolution: {integrity: sha512-0GEfbBLmTFf0dJlpsNU7zwxRIH0/BGEMuXLTCvFYxuL1tNhqzTbtnFICyJLTNK4a+RechKP75e7w42ClXSnJQw==} - engines: {node: '>=20.19.0'} + '@cucumber/gherkin-utils@11.0.0': + resolution: {integrity: sha512-LJ+s4+TepHTgdKWDR4zbPyT7rQjmYIcukTwNbwNwgqr6i8Gjcmzf6NmtbYDA19m1ZFg6kWbFsmHnj37ZuX+kZA==} + hasBin: true + + '@cucumber/gherkin@38.0.0': + resolution: {integrity: sha512-duEXK+KDfQUzu3vsSzXjkxQ2tirF5PRsc1Xrts6THKHJO6mjw4RjM8RV+vliuDasmhhrmdLcOcM7d9nurNTJKw==} + + '@cucumber/html-formatter@23.0.0': + resolution: {integrity: sha512-WwcRzdM8Ixy4e53j+Frm3fKM5rNuIyWUfy4HajEN+Xk/YcjA6yW0ACGTFDReB++VDZz/iUtwYdTlPRY36NbqJg==} peerDependencies: - '@csstools/css-parser-algorithms': ^4.0.0 - '@csstools/css-tokenizer': ^4.0.0 + '@cucumber/messages': '>=18' - '@csstools/css-parser-algorithms@4.0.0': - resolution: {integrity: sha512-+B87qS7fIG3L5h3qwJ/IFbjoVoOe/bpOdh9hAjXbvx0o8ImEmUsGXN0inFOnk2ChCFgqkkGFQ+TpM5rbhkKe4w==} - engines: {node: '>=20.19.0'} + '@cucumber/junit-xml-formatter@0.9.0': + resolution: {integrity: sha512-WF+A7pBaXpKMD1i7K59Nk5519zj4extxY4+4nSgv5XLsGXHDf1gJnb84BkLUzevNtp2o2QzMG0vWLwSm8V5blw==} peerDependencies: - '@csstools/css-tokenizer': ^4.0.0 + '@cucumber/messages': '*' - '@csstools/css-syntax-patches-for-csstree@1.1.1': - resolution: {integrity: sha512-BvqN0AMWNAnLk9G8jnUT77D+mUbY/H2b3uDTvg2isJkHaOufUE2R3AOwxWo7VBQKT1lOdwdvorddo2B/lk64+w==} + '@cucumber/message-streams@4.0.1': + resolution: {integrity: sha512-Kxap9uP5jD8tHUZVjTWgzxemi/0uOsbGjd4LBOSxcJoOCRbESFwemUzilJuzNTB8pcTQUh8D5oudUyxfkJOKmA==} peerDependencies: - css-tree: ^3.2.1 - peerDependenciesMeta: - css-tree: - optional: true + '@cucumber/messages': '>=17.1.1' - '@csstools/css-tokenizer@4.0.0': - resolution: {integrity: sha512-QxULHAm7cNu72w97JUNCBFODFaXpbDg+dP8b/oWFAZ2MTRppA3U00Y2L1HqaS4J6yBqxwa/Y3nMBaxVKbB/NsA==} - engines: {node: '>=20.19.0'} + '@cucumber/messages@32.0.1': + resolution: {integrity: sha512-1OSoW+GQvFUNAl6tdP2CTBexTXMNJF0094goVUcvugtQeXtJ0K8sCP0xbq7GGoiezs/eJAAOD03+zAPT64orHQ==} + + '@cucumber/pretty-formatter@1.0.1': + resolution: {integrity: sha512-A1lU4VVP0aUWdOTmpdzvXOyEYuPtBDI0xYwYJnmoMDplzxMdhcHk86lyyvYDoMoPzzq6OkOE3isuosvUU4X7IQ==} + peerDependencies: + '@cucumber/cucumber': '>=7.0.0' + '@cucumber/messages': '*' + + '@cucumber/query@14.7.0': + resolution: {integrity: sha512-fiqZ4gMEgYjmbuWproF/YeCdD5y+gD2BqgBIGbpihOsx6UlNsyzoDSfO+Tny0q65DxfK+pHo2UkPyEl7dO7wmQ==} + peerDependencies: + '@cucumber/messages': '*' + + '@cucumber/tag-expressions@9.1.0': + resolution: {integrity: sha512-bvHjcRFZ+J1TqIa9eFNO1wGHqwx4V9ZKV3hYgkuK/VahHx73uiP4rKV3JVrvWSMrwrFvJG6C8aEwnCWSvbyFdQ==} '@e18e/eslint-plugin@0.2.0': resolution: {integrity: sha512-mXgODVwhuDjTJ+UT+XSvmMmCidtGKfrV5nMIv1UtpWex2pYLsIM3RSpT8HWIMAebS9qANbXPKlSX4BE7ZvuCgA==} @@ -987,11 +1619,11 @@ packages: peerDependencies: tailwindcss: '*' - '@emnapi/core@1.9.0': - resolution: {integrity: sha512-0DQ98G9ZQZOxfUcQn1waV2yS8aWdZ6kJMbYCJB3oUBecjWYO1fqJ+a1DRfPF3O5JEkwqwP1A9QEN/9mYm2Yd0w==} + '@emnapi/core@1.9.1': + resolution: {integrity: sha512-mukuNALVsoix/w1BJwFzwXBN/dHeejQtuVzcDsfOEsdpCumXb/E9j8w11h5S54tT1xhifGfbbSm/ICrObRb3KA==} - '@emnapi/runtime@1.9.0': - resolution: {integrity: sha512-QN75eB0IH2ywSpRpNddCRfQIhmJYBCJ1x5Lb3IscKAL8bMnVAKnRg8dCoXbHzVLLH7P38N2Z3mtulB7W0J0FKw==} + '@emnapi/runtime@1.9.1': + resolution: {integrity: sha512-VYi5+ZVLhpgK4hQ0TAjiQiZ6ol0oe4mBx7mVv7IflsiEp0OWoVsp/+f9Vc1hOhE0TtkORVrI1GvzyreqpgWtkA==} '@emnapi/wasi-threads@1.2.0': resolution: {integrity: sha512-N10dEJNSsUx41Z6pZsXU8FjPjpBEplgH24sfkmITrBED1/U2Esum9F3lfLrMjKHHjmi557zQn7kR9R+XWXu5Rg==} @@ -1179,44 +1811,40 @@ packages: resolution: {integrity: sha512-EriSTlt5OC9/7SXkRSCAhfSxxoSUgBm33OH+IkwbdpgoqsSsUg7y3uh+IICI/Qg4BBWr3U2i39RpmycbxMq4ew==} engines: {node: ^12.0.0 || ^14.0.0 || >=16.0.0} - '@eslint-react/ast@2.13.0': - resolution: {integrity: sha512-43+5gmqV3MpatTzKnu/V2i/jXjmepvwhrb9MaGQvnXHQgq9J7/C7VVCCcwp6Rvp2QHAFquAAdvQDSL8IueTpeA==} - engines: {node: '>=20.19.0'} + '@eslint-react/ast@3.0.0': + resolution: {integrity: sha512-qBasEJqMhcof/pbxhKSgp52rW9TMUMVIYqv3SOgSzvDG3bed+saWFXOQ+YFMj/o5gr/e6Dsi3mAHqErPzJHelA==} + engines: {node: '>=22.0.0'} peerDependencies: - eslint: ^8.57.0 || ^9.0.0 || ^10.0.0 - typescript: '>=4.8.4 <6.0.0' + eslint: ^10.0.0 + typescript: '*' - '@eslint-react/core@2.13.0': - resolution: {integrity: sha512-m62XDzkf1hpzW4sBc7uh7CT+8rBG2xz/itSADuEntlsg4YA7Jhb8hjU6VHf3wRFDwyfx5VnbV209sbJ7Azey0Q==} - engines: {node: '>=20.19.0'} + '@eslint-react/core@3.0.0': + resolution: {integrity: sha512-PKa13GrqUAilcvcONJMN8BukuVg3dHuaTxjNBdKOHGxkMexCxDF9hjNHBILErJhFs1kGaJPBK9QUYQci8PV/TA==} + engines: {node: '>=22.0.0'} peerDependencies: - eslint: ^8.57.0 || ^9.0.0 || ^10.0.0 - typescript: '>=4.8.4 <6.0.0' + eslint: ^10.0.0 + typescript: '*' - '@eslint-react/eff@2.13.0': - resolution: {integrity: sha512-rEH2R8FQnUAblUW+v3ZHDU1wEhatbL1+U2B1WVuBXwSKqzF7BGaLqCPIU7o9vofumz5MerVfaCtJgI8jYe2Btg==} - engines: {node: '>=20.19.0'} - - '@eslint-react/eslint-plugin@2.13.0': - resolution: {integrity: sha512-iaMXpqnJCTW7317hg8L4wx7u5aIiPzZ+d1p59X8wXFgMHzFX4hNu4IfV8oygyjmWKdLsjKE9sEpv/UYWczlb+A==} - engines: {node: '>=20.19.0'} + '@eslint-react/eslint-plugin@3.0.0': + resolution: {integrity: sha512-OK8rBrsM/bUr0L918hQ1tWAufz22+m0L6gpSrW3Z/7NSg/imy17IiZHO8UVT99sgcx9euKYAT+QIx45sZUYf1g==} + engines: {node: '>=22.0.0'} peerDependencies: - eslint: ^8.57.0 || ^9.0.0 || ^10.0.0 - typescript: '>=4.8.4 <6.0.0' + eslint: ^10.0.0 + typescript: '*' - '@eslint-react/shared@2.13.0': - resolution: {integrity: sha512-IOloCqrZ7gGBT4lFf9+0/wn7TfzU7JBRjYwTSyb9SDngsbeRrtW95ZpgUpS8/jen1wUEm6F08duAooTZ2FtsWA==} - engines: {node: '>=20.19.0'} + '@eslint-react/shared@3.0.0': + resolution: {integrity: sha512-oHELwh3FghrMc5UX+4qVEdY7ZLZsO4bgKDVv5i6yk8+/997xe6LAY2wailbeljbIJxppcJSl6eXcRl2yv6ffig==} + engines: {node: '>=22.0.0'} peerDependencies: - eslint: ^8.57.0 || ^9.0.0 || ^10.0.0 - typescript: '>=4.8.4 <6.0.0' + eslint: ^10.0.0 + typescript: '*' - '@eslint-react/var@2.13.0': - resolution: {integrity: sha512-dM+QaeiHR16qPQoJYg205MkdHYSWVa2B7ore5OFpOPlSwqDV3tLW7I+475WjbK7potq5QNPTxRa7VLp9FGeQqA==} - engines: {node: '>=20.19.0'} + '@eslint-react/var@3.0.0': + resolution: {integrity: sha512-Af/7GEZfXtc9jV1i/Uqfko40Gr256YXDZR9CG6mxROOUOMRYIaBPf3K7oLCnwiKVZXiFJ5qYGLEs6HoG8Ifrjw==} + engines: {node: '>=22.0.0'} peerDependencies: - eslint: ^8.57.0 || ^9.0.0 || ^10.0.0 - typescript: '>=4.8.4 <6.0.0' + eslint: ^10.0.0 + typescript: '*' '@eslint/compat@2.0.3': resolution: {integrity: sha512-SjIJhGigp8hmd1YGIBwh7Ovri7Kisl42GYFjrOyHhtfYGGoLW6teYi/5p8W50KSsawUPpuLOSmsq1bD0NGQLBw==} @@ -1267,6 +1895,15 @@ packages: resolution: {integrity: sha512-4IlJx0X0qftVsN5E+/vGujTRIFtwuLbNsVUe7TO6zYPDR1O6nFwvwhIKEKSrl6dZchmYBITazxKoUYOjdtjlRg==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + '@eslint/js@10.0.1': + resolution: {integrity: sha512-zeR9k5pd4gxjZ0abRoIaxdc7I3nDktoXZk2qOv9gCNWx3mVwEn32VRhyLaRsDiJjTs0xq/T8mfPtyuXu7GWBcA==} + engines: {node: ^20.19.0 || ^22.13.0 || >=24} + peerDependencies: + eslint: ^10.0.0 + peerDependenciesMeta: + eslint: + optional: true + '@eslint/js@9.27.0': resolution: {integrity: sha512-G5JD9Tu5HJEu4z2Uo4aHY2sLV64B7CDMXxFzqzjl3NKd6RVzSXNoE80jk7Y0lJkTTkjiIhBAqmlYwjuBY3tvpA==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} @@ -1295,15 +1932,6 @@ packages: resolution: {integrity: sha512-iH1B076HoAshH1mLpHMgwdGeTs0CYwL0SPMkGuSebZrwBp16v415e9NZXg2jtrqPVQjf6IANe2Vtlr5KswtcZQ==} engines: {node: ^20.19.0 || ^22.13.0 || >=24} - '@exodus/bytes@1.15.0': - resolution: {integrity: sha512-UY0nlA+feH81UGSHv92sLEPLCeZFjXOuHhrIo0HQydScuQc8s0A7kL/UdgwgDq8g8ilksmuoF35YVTNphV2aBQ==} - engines: {node: ^20.19.0 || ^22.12.0 || >=24.0.0} - peerDependencies: - '@noble/hashes': ^1.8.0 || ^2.0.0 - peerDependenciesMeta: - '@noble/hashes': - optional: true - '@floating-ui/core@1.7.5': resolution: {integrity: sha512-1Ih4WTWyw0+lKyFMcBHGbb5U5FtuHJuujoyyr5zTaWS5EYMeT6Jb2AuDeftsCsEuchO+mM2ij5+q9crhydzLhQ==} @@ -1577,74 +2205,77 @@ packages: '@jridgewell/trace-mapping@0.3.31': resolution: {integrity: sha512-zzNR+SdQSDJzc8joaeP8QQoCQr8NuYx2dIIytl1QeBEZHJ9uW6hebsrYgbz8hJwUQao3TWCMtmfV8Nu1twOLAw==} - '@lexical/clipboard@0.41.0': - resolution: {integrity: sha512-Ex5lPkb4NBBX1DCPzOAIeHBJFH1bJcmATjREaqpnTfxCbuOeQkt44wchezUA0oDl+iAxNZ3+pLLWiUju9icoSA==} + '@lexical/clipboard@0.42.0': + resolution: {integrity: sha512-D3K2ID0zew/+CKpwxnUTTh/N46yU4IK8bFWV9Htz+g1vFhgUF9UnDOQCmqpJbdP7z+9U1F8rk3fzf9OmP2Fm2w==} - '@lexical/devtools-core@0.41.0': - resolution: {integrity: sha512-FzJtluBhBc8bKS11TUZe72KoZN/hnzIyiiM0SPJAsPwGpoXuM01jqpXQGybWf/1bWB+bmmhOae7O4Nywi/Csuw==} + '@lexical/code-core@0.42.0': + resolution: {integrity: sha512-vrZTUPWDJkHjAAvuV2+Qte4vYE80s7hIO7wxipiJmWojGx6lcmQjO+UqJ8AIrqI4Wjy8kXrK74kisApWmwxuCw==} + + '@lexical/devtools-core@0.42.0': + resolution: {integrity: sha512-8nP8eE9i8JImgSrvInkWFfMCmXVKp3w3VaOvbJysdlK/Zal6xd8EWJEi6elj0mUW5T/oycfipPs2Sfl7Z+n14A==} peerDependencies: react: '>=17.x' react-dom: '>=17.x' - '@lexical/dragon@0.41.0': - resolution: {integrity: sha512-gBEqkk8Q6ZPruvDaRcOdF1EK9suCVBODzOCcR+EnoJTaTjfDkCM7pkPAm4w90Wa1wCZEtFHvCfas+jU9MDSumg==} + '@lexical/dragon@0.42.0': + resolution: {integrity: sha512-/TQzP+7PLJMqq9+MlgQWiJsxS9GOOa8Gp0svCD8vNIOciYmXfd28TR1Go+ZnBWwr7k/2W++3XUYVQU2KUcQsDQ==} - '@lexical/extension@0.41.0': - resolution: {integrity: sha512-sF4SPiP72yXvIGchmmIZ7Yg2XZTxNLOpFEIIzdqG7X/1fa1Ham9P/T7VbrblWpF6Ei5LJtK9JgNVB0hb4l3o1g==} + '@lexical/extension@0.42.0': + resolution: {integrity: sha512-rkZq/h8d1BenKRqU4t/zQUVfY/RinMX1Tz7t+Ee3ss0sk+kzP4W+URXNAxpn7r39Vn6wrFBqmCziah3dLAIqPw==} - '@lexical/hashtag@0.41.0': - resolution: {integrity: sha512-tFWM74RW4KU0E/sj2aowfWl26vmLUTp331CgVESnhQKcZBfT40KJYd57HEqBDTfQKn4MUhylQCCA0hbpw6EeFQ==} + '@lexical/hashtag@0.42.0': + resolution: {integrity: sha512-WOg5nFOfhabNBXzEIutdWDj+TUHtJEezj6w8jyYDGqZ31gu0cgrXSeV8UIynz/1oj+rpzEeEB7P6ODnwgjt7qA==} - '@lexical/history@0.41.0': - resolution: {integrity: sha512-kGoVWsiOn62+RMjRolRa+NXZl8jFwxav6GNDiHH8yzivtoaH8n1SwUfLJELXCzeqzs81HySqD4q30VLJVTGoDg==} + '@lexical/history@0.42.0': + resolution: {integrity: sha512-YfCZ1ICUt6BCg2ncJWFMuS4yftnB7FEHFRf3qqTSTf6oGZ4IZfzabMNEy47xybUuf7FXBbdaCKJrc/zOM+wGxw==} - '@lexical/html@0.41.0': - resolution: {integrity: sha512-3RyZy+H/IDKz2D66rNN/NqYx87xVFrngfEbyu1OWtbY963RUFnopiVHCQvsge/8kT04QSZ7U/DzjVFqeNS6clg==} + '@lexical/html@0.42.0': + resolution: {integrity: sha512-KgBUDLXehufCsXW3w0XsuoI2xecIhouOishnaNOH4zIA7dAtnNAfdPN/kWrWs0s83gz44OrnqccP+Bprw3UDEQ==} - '@lexical/link@0.41.0': - resolution: {integrity: sha512-Rjtx5cGWAkKcnacncbVsZ1TqRnUB2Wm4eEVKpaAEG41+kHgqghzM2P+UGT15yROroxJu8KvAC9ISiYFiU4XE1w==} + '@lexical/link@0.42.0': + resolution: {integrity: sha512-cdeM/+f+kn7aGwW/3FIi6USjl1gBNdEEwg0/ZS+KlYcsy8gxx2e4cyVjsomBu/WU17Qxa0NC0paSr7qEJ/1Fig==} - '@lexical/list@0.41.0': - resolution: {integrity: sha512-RXvB+xcbzVoQLGRDOBRCacztG7V+bI95tdoTwl8pz5xvgPtAaRnkZWMDP+yMNzMJZsqEChdtpxbf0NgtMkun6g==} + '@lexical/list@0.42.0': + resolution: {integrity: sha512-TIezILnmIVuvfqEEbcMnsT4xQRlswI6ysHISqsvKL6l5EBhs1gqmNYjHa/Yrfzaq5y52TM1PAtxbFts+G7N6kg==} - '@lexical/mark@0.41.0': - resolution: {integrity: sha512-UO5WVs9uJAYIKHSlYh4Z1gHrBBchTOi21UCYBIZ7eAs4suK84hPzD+3/LAX5CB7ZltL6ke5Sly3FOwNXv/wfpA==} + '@lexical/mark@0.42.0': + resolution: {integrity: sha512-H1aGjbMEcL4B8GT7bm/ePHm7j3Wema+wIRNPmxMtXGMz5gpVN3gZlvg2UcUHHJb00SrBA95OUVT5I2nu/KP06w==} - '@lexical/markdown@0.41.0': - resolution: {integrity: sha512-bzI73JMXpjGFhqUWNV6KqfjWcgAWzwFT+J3RHtbCF5rysC8HLldBYojOgAAtPfXqfxyv2mDzsY7SoJ75s9uHZA==} + '@lexical/markdown@0.42.0': + resolution: {integrity: sha512-+mOxgBiumlgVX8Acna+9HjJfSOw1jywufGcAQq3/8S11wZ4gE0u13AaR8LMmU8ydVeOQg09y8PNzGNQ/avZJbg==} - '@lexical/offset@0.41.0': - resolution: {integrity: sha512-2RHBXZqC8gm3X9C0AyRb0M8w7zJu5dKiasrif+jSKzsxPjAUeF1m95OtIOsWs1XLNUgASOSUqGovDZxKJslZfA==} + '@lexical/offset@0.42.0': + resolution: {integrity: sha512-V+4af1KmTOnBZrR+kU3e6eD33W/g3QqMPPp3cpFwyXk/dKRc4K8HfyDsSDrjop1mPd9pl3lKSiEmX6uQG8K9XQ==} - '@lexical/overflow@0.41.0': - resolution: {integrity: sha512-Iy6ZiJip8X14EBYt1zKPOrXyQ4eG9JLBEoPoSVBTiSbVd+lYicdUvaOThT0k0/qeVTN9nqTaEltBjm56IrVKCQ==} + '@lexical/overflow@0.42.0': + resolution: {integrity: sha512-wlrHaM27rODJP5m+CTgfZGLg3qWlQ0ptGodcqoGdq6HSbV8nGFY6TvcLMaMtYQ1lm4v9G7Xe9LwjooR6xS3Gug==} - '@lexical/plain-text@0.41.0': - resolution: {integrity: sha512-HIsGgmFUYRUNNyvckun33UQfU7LRzDlxymHUq67+Bxd5bXqdZOrStEKJXuDX+LuLh/GXZbaWNbDLqwLBObfbQg==} + '@lexical/plain-text@0.42.0': + resolution: {integrity: sha512-YWvBwIxLltrIaZDcv0rK4s44P6Yt17yhOb0E+g3+tjF8GGPrrocox+Pglu0m2RHR+G7zULN3isolmWIm/HhWiw==} - '@lexical/react@0.41.0': - resolution: {integrity: sha512-7+GUdZUm6sofWm+zdsWAs6cFBwKNsvsHezZTrf6k8jrZxL461ZQmbz/16b4DvjCGL9r5P1fR7md9/LCmk8TiCg==} + '@lexical/react@0.42.0': + resolution: {integrity: sha512-ujWJXhvlFVVTpwDcnSgEYWRuqUbreZaMB+4bjIDT5r7hkAplUHQndlkeuFHKFiJBasSAreleV7zhXrLL5xa9eA==} peerDependencies: react: '>=17.x' react-dom: '>=17.x' - '@lexical/rich-text@0.41.0': - resolution: {integrity: sha512-yUcr7ZaaVTZNi8bow4CK1M8jy2qyyls1Vr+5dVjwBclVShOL/F/nFyzBOSb6RtXXRbd3Ahuk9fEleppX/RNIdw==} + '@lexical/rich-text@0.42.0': + resolution: {integrity: sha512-v4YgiM3oK3FZcRrfB+LetvLbQ5aee9MRO9tHf0EFweXg19XnSjHV0cfPAW7TyPxRELzB69+K0Q3AybRlTMjG4Q==} - '@lexical/selection@0.41.0': - resolution: {integrity: sha512-1s7/kNyRzcv5uaTwsUL28NpiisqTf5xZ1zNukLsCN1xY+TWbv9RE9OxIv+748wMm4pxNczQe/UbIBODkbeknLw==} + '@lexical/selection@0.42.0': + resolution: {integrity: sha512-iWTjLA5BSEuUnvWe9Xwu9FSdZFl3Yi0NqalabXKI+7KgCIlIVXE74y4NvWPUSLkSCB/Z1RPKiHmZqZ1vyu/yGQ==} - '@lexical/table@0.41.0': - resolution: {integrity: sha512-d3SPThBAr+oZ8O74TXU0iXM3rLbrAVC7/HcOnSAq7/AhWQW8yMutT51JQGN+0fMLP9kqoWSAojNtkdvzXfU/+A==} + '@lexical/table@0.42.0': + resolution: {integrity: sha512-GKiZyjQsHDXRckq5VBrOowyvds51WoVRECfDgcl8pqLMnKyEdCa58E7fkSJrr5LS80Scod+Cjn6SBRzOcdsrKg==} - '@lexical/text@0.41.0': - resolution: {integrity: sha512-gGA+Anc7ck110EXo4KVKtq6Ui3M7Vz3OpGJ4QE6zJHWW8nV5h273koUGSutAMeoZgRVb6t01Izh3ORoFt/j1CA==} + '@lexical/text@0.42.0': + resolution: {integrity: sha512-hT3EYVtBmONXyXe4TFVgtFcG1tf6JhLEuAf95+cOjgFGFSgvkZ/64BPbKLNTj2/9n6cU7EGPUNNwVigCSECJ2g==} - '@lexical/utils@0.41.0': - resolution: {integrity: sha512-Wlsokr5NQCq83D+7kxZ9qs5yQ3dU3Qaf2M+uXxLRoPoDaXqW8xTWZq1+ZFoEzsHzx06QoPa4Vu/40BZR91uQPg==} + '@lexical/utils@0.42.0': + resolution: {integrity: sha512-wGNdCW3QWEyVdFiSTLZfFPtiASPyYLcekIiYYZmoRVxVimT/jY+QPfnkO4JYgkO7Z70g/dsg9OhqyQSChQfvkQ==} - '@lexical/yjs@0.41.0': - resolution: {integrity: sha512-PaKTxSbVC4fpqUjQ7vUL9RkNF1PjL8TFl5jRe03PqoPYpE33buf3VXX6+cOUEfv9+uknSqLCPHoBS/4jN3a97w==} + '@lexical/yjs@0.42.0': + resolution: {integrity: sha512-DplzWnYhfFceGPR+UyDFpZdB287wF/vNOHFuDsBF/nGDdTezvr0Gf60opzyBEF3oXym6p3xTmGygxvO97LZ+vw==} peerDependencies: yjs: '>=13.5.22' @@ -1683,12 +2314,11 @@ packages: react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 - '@mswjs/interceptors@0.41.3': - resolution: {integrity: sha512-cXu86tF4VQVfwz8W1SPbhoRyHJkti6mjH/XJIxp40jhO4j2k1m4KYrEykxqWPkFF3vrK4rgQppBh//AwyGSXPA==} - engines: {node: '>=18'} - - '@napi-rs/wasm-runtime@1.1.1': - resolution: {integrity: sha512-p64ah1M1ld8xjWv3qbvFwHiFVWrq1yFvV4f7w+mzaqiR4IlSgkqhcRdHwsGgomwzBH51sRY4NEowLxnaBjcW/A==} + '@napi-rs/wasm-runtime@1.1.2': + resolution: {integrity: sha512-sNXv5oLJ7ob93xkZ1XnxisYhGYXfaG9f65/ZgYuAu3qt7b3NadcOEhLvx28hv31PgX8SZJRYrAIPQilQmFpLVw==} + peerDependencies: + '@emnapi/core': ^1.7.1 + '@emnapi/runtime': ^1.7.1 '@neoconfetti/react@1.0.0': resolution: {integrity: sha512-klcSooChXXOzIm+SE5IISIAn3bYzYfPjbX7D7HoqZL84oAfgREeSg5vSIaSFH+DaGzzvImTyWe1OyrJ67vik4A==} @@ -1696,14 +2326,14 @@ packages: '@next/env@16.0.0': resolution: {integrity: sha512-s5j2iFGp38QsG1LWRQaE2iUY3h1jc014/melHFfLdrsMJPqxqDQwWNwyQTcNoUSGZlCVZuM7t7JDMmSyRilsnA==} - '@next/env@16.2.0': - resolution: {integrity: sha512-OZIbODWWAi0epQRCRjNe1VO45LOFBzgiyqmTLzIqWq6u1wrxKnAyz1HH6tgY/Mc81YzIjRPoYsPAEr4QV4l9TA==} + '@next/env@16.2.1': + resolution: {integrity: sha512-n8P/HCkIWW+gVal2Z8XqXJ6aB3J0tuM29OcHpCsobWlChH/SITBs1DFBk/HajgrwDkqqBXPbuUuzgDvUekREPg==} - '@next/eslint-plugin-next@16.2.0': - resolution: {integrity: sha512-3D3pEMcGKfENC9Pzlkr67GOm+205+5hRdYPZvHuNIy5sr9k0ybSU8g+sxOO/R/RLEh/gWZ3UlY+5LmEyZ1xgXQ==} + '@next/eslint-plugin-next@16.2.1': + resolution: {integrity: sha512-r0epZGo24eT4g08jJlg2OEryBphXqO8aL18oajoTKLzHJ6jVr6P6FI58DLMug04MwD3j8Fj0YK0slyzneKVyzA==} - '@next/mdx@16.2.0': - resolution: {integrity: sha512-I+qgh34a9tNfZpz0TdMT8c6CjUEjatFx7njvQXKi3gbQtuRc5MyHYyyP7+GBtOpmtSUocnI+I+SaVQK/8UFIIw==} + '@next/mdx@16.2.1': + resolution: {integrity: sha512-w0YOkOc+WEnsTJ8uxzBOvpe3R+9BnJOxWCE7qcI/62CzJiUEd8JKtF25e3R8cW5BGsKyRW8p4zE2JLyXKa8xdw==} peerDependencies: '@mdx-js/loader': '>=0.15.0' '@mdx-js/react': '>=0.15.0' @@ -1713,54 +2343,54 @@ packages: '@mdx-js/react': optional: true - '@next/swc-darwin-arm64@16.2.0': - resolution: {integrity: sha512-/JZsqKzKt01IFoiLLAzlNqys7qk2F3JkcUhj50zuRhKDQkZNOz9E5N6wAQWprXdsvjRP4lTFj+/+36NSv5AwhQ==} + '@next/swc-darwin-arm64@16.2.1': + resolution: {integrity: sha512-BwZ8w8YTaSEr2HIuXLMLxIdElNMPvY9fLqb20LX9A9OMGtJilhHLbCL3ggyd0TwjmMcTxi0XXt+ur1vWUoxj2Q==} engines: {node: '>= 10'} cpu: [arm64] os: [darwin] - '@next/swc-darwin-x64@16.2.0': - resolution: {integrity: sha512-/hV8erWq4SNlVgglUiW5UmQ5Hwy5EW/AbbXlJCn6zkfKxTy/E/U3V8U1Ocm2YCTUoFgQdoMxRyRMOW5jYy4ygg==} + '@next/swc-darwin-x64@16.2.1': + resolution: {integrity: sha512-/vrcE6iQSJq3uL3VGVHiXeaKbn8Es10DGTGRJnRZlkNQQk3kaNtAJg8Y6xuAlrx/6INKVjkfi5rY0iEXorZ6uA==} engines: {node: '>= 10'} cpu: [x64] os: [darwin] - '@next/swc-linux-arm64-gnu@16.2.0': - resolution: {integrity: sha512-GkjL/Q7MWOwqWR9zoxu1TIHzkOI2l2BHCf7FzeQG87zPgs+6WDh+oC9Sw9ARuuL/FUk6JNCgKRkA6rEQYadUaw==} + '@next/swc-linux-arm64-gnu@16.2.1': + resolution: {integrity: sha512-uLn+0BK+C31LTVbQ/QU+UaVrV0rRSJQ8RfniQAHPghDdgE+SlroYqcmFnO5iNjNfVWCyKZHYrs3Nl0mUzWxbBw==} engines: {node: '>= 10'} cpu: [arm64] os: [linux] libc: [glibc] - '@next/swc-linux-arm64-musl@16.2.0': - resolution: {integrity: sha512-1ffhC6KY5qWLg5miMlKJp3dZbXelEfjuXt1qcp5WzSCQy36CV3y+JT7OC1WSFKizGQCDOcQbfkH/IjZP3cdRNA==} + '@next/swc-linux-arm64-musl@16.2.1': + resolution: {integrity: sha512-ssKq6iMRnHdnycGp9hCuGnXJZ0YPr4/wNwrfE5DbmvEcgl9+yv97/Kq3TPVDfYome1SW5geciLB9aiEqKXQjlQ==} engines: {node: '>= 10'} cpu: [arm64] os: [linux] libc: [musl] - '@next/swc-linux-x64-gnu@16.2.0': - resolution: {integrity: sha512-FmbDcZQ8yJRq93EJSL6xaE0KK/Rslraf8fj1uViGxg7K4CKBCRYSubILJPEhjSgZurpcPQq12QNOJQ0DRJl6Hg==} + '@next/swc-linux-x64-gnu@16.2.1': + resolution: {integrity: sha512-HQm7SrHRELJ30T1TSmT706IWovFFSRGxfgUkyWJZF/RKBMdbdRWJuFrcpDdE5vy9UXjFOx6L3mRdqH04Mmx0hg==} engines: {node: '>= 10'} cpu: [x64] os: [linux] libc: [glibc] - '@next/swc-linux-x64-musl@16.2.0': - resolution: {integrity: sha512-HzjIHVkmGAwRbh/vzvoBWWEbb8BBZPxBvVbDQDvzHSf3D8RP/4vjw7MNLDXFF9Q1WEzeQyEj2zdxBtVAHu5Oyw==} + '@next/swc-linux-x64-musl@16.2.1': + resolution: {integrity: sha512-aV2iUaC/5HGEpbBkE+4B8aHIudoOy5DYekAKOMSHoIYQ66y/wIVeaRx8MS2ZMdxe/HIXlMho4ubdZs/J8441Tg==} engines: {node: '>= 10'} cpu: [x64] os: [linux] libc: [musl] - '@next/swc-win32-arm64-msvc@16.2.0': - resolution: {integrity: sha512-UMiFNQf5H7+1ZsZPxEsA064WEuFbRNq/kEXyepbCnSErp4f5iut75dBA8UeerFIG3vDaQNOfCpevnERPp2V+nA==} + '@next/swc-win32-arm64-msvc@16.2.1': + resolution: {integrity: sha512-IXdNgiDHaSk0ZUJ+xp0OQTdTgnpx1RCfRTalhn3cjOP+IddTMINwA7DXZrwTmGDO8SUr5q2hdP/du4DcrB1GxA==} engines: {node: '>= 10'} cpu: [arm64] os: [win32] - '@next/swc-win32-x64-msvc@16.2.0': - resolution: {integrity: sha512-DRrNJKW+/eimrZgdhVN1uvkN1OI4j6Lpefwr44jKQ0YQzztlmOBUUzHuV5GxOMPK3nmodAYElUVCY8ZXo/IWeA==} + '@next/swc-win32-x64-msvc@16.2.1': + resolution: {integrity: sha512-qvU+3a39Hay+ieIztkGSbF7+mccbbg1Tk25hc4JDylf8IHjYmY/Zm64Qq1602yPyQqvie+vf5T/uPwNxDNIoeg==} engines: {node: '>= 10'} cpu: [x64] os: [win32] @@ -1777,6 +2407,10 @@ packages: resolution: {integrity: sha512-oGB+UxlgWcgQkgwo8GcEGwemoTFt3FIO9ababBmaGwXIoBKZ+GTy0pP185beGg7Llih/NSHSV2XAs1lnznocSg==} engines: {node: '>= 8'} + '@nolyfill/hasown@1.0.44': + resolution: {integrity: sha512-GA/21lkTr2PAQuT6jGnhLuBD5IFd/AEhBXJ/tf33+/bVxPxg+5ejKx9jGQGnyV/P0eSmdup5E+s8b2HL6lOrwQ==} + engines: {node: '>=12.4.0'} + '@nolyfill/is-core-module@1.0.39': resolution: {integrity: sha512-nn5ozdjYQpUCZlWGuxcJY/KpxkWQs4DcbMCmKojjyrYDEAGy4Ce19NN4v5MduafTwJlbKc99UA8YhSVqq9yPZA==} engines: {node: '>=12.4.0'} @@ -1789,87 +2423,178 @@ packages: resolution: {integrity: sha512-y3SvzjuY1ygnzWA4Krwx/WaJAsTMP11DN+e21A8Fa8PW1oDtVB5NSRW7LWurAiS2oKRkuCgcjTYMkBuBkcPCRg==} engines: {node: '>=12.4.0'} - '@octokit/auth-token@6.0.0': - resolution: {integrity: sha512-P4YJBPdPSpWTQ1NU4XYdvHvXJJDxM6YwpS0FZHRgP7YFkdVxsWcpWGy/NVqlAA7PcPCnMacXlRm1y2PFZRWL/w==} - engines: {node: '>= 20'} + '@orpc/client@1.13.13': + resolution: {integrity: sha512-jagx/Sa+9K4HEC5lBrUlMSrmR/06hvZctWh93/sKZc8GBk4zM0+71oT1kXQVw1oRYFV2XAq3xy3m6NdM6gfKYA==} - '@octokit/core@7.0.6': - resolution: {integrity: sha512-DhGl4xMVFGVIyMwswXeyzdL4uXD5OGILGX5N8Y+f6W7LhC1Ze2poSNrkF/fedpVDHEEZ+PHFW0vL14I+mm8K3Q==} - engines: {node: '>= 20'} + '@orpc/contract@1.13.13': + resolution: {integrity: sha512-md6iyrYkePBSJNs1VnVEEnAUORMDPHIf3JGRSHxyssIcNakev/iOjP0HvpH0Sx0MlTBhihAJo6uFL8Vpth58Nw==} - '@octokit/endpoint@11.0.3': - resolution: {integrity: sha512-FWFlNxghg4HrXkD3ifYbS/IdL/mDHjh9QcsNyhQjN8dplUoZbejsdpmuqdA76nxj2xoWPs7p8uX2SNr9rYu0Ag==} - engines: {node: '>= 20'} + '@orpc/openapi-client@1.13.13': + resolution: {integrity: sha512-k8od+bD7MqysKPPybAkxgfaNIaNseFPXtbidWkZAdCZ5w34SnDc7QPZJ0PQbyt9n9B+jOXSADNwQSTWSuGpjyA==} - '@octokit/graphql@9.0.3': - resolution: {integrity: sha512-grAEuupr/C1rALFnXTv6ZQhFuL1D8G5y8CN04RgrO4FIPMrtm+mcZzFG7dcBm+nq+1ppNixu+Jd78aeJOYxlGA==} - engines: {node: '>= 20'} - - '@octokit/openapi-types@27.0.0': - resolution: {integrity: sha512-whrdktVs1h6gtR+09+QsNk2+FO+49j6ga1c55YZudfEG+oKJVvJLQi3zkOm5JjiUXAagWK2tI2kTGKJ2Ys7MGA==} - - '@octokit/request-error@7.1.0': - resolution: {integrity: sha512-KMQIfq5sOPpkQYajXHwnhjCC0slzCNScLHs9JafXc4RAJI+9f+jNDlBNaIMTvazOPLgb4BnlhGJOTbnN0wIjPw==} - engines: {node: '>= 20'} - - '@octokit/request@10.0.8': - resolution: {integrity: sha512-SJZNwY9pur9Agf7l87ywFi14W+Hd9Jg6Ifivsd33+/bGUQIjNujdFiXII2/qSlN2ybqUHfp5xpekMEjIBTjlSw==} - engines: {node: '>= 20'} - - '@octokit/types@16.0.0': - resolution: {integrity: sha512-sKq+9r1Mm4efXW1FCk7hFSeJo4QKreL/tTbR0rz/qx/r1Oa2VV83LTA/H/MuCOX7uCIJmQVRKBcbmWoySjAnSg==} - - '@open-draft/deferred-promise@2.2.0': - resolution: {integrity: sha512-CecwLWx3rhxVQF6V4bAgPS5t+So2sTbPgAzafKkVizyi7tlwpcFpdFqq+wqF2OwNBmqFuu6tOyouTuxgpMfzmA==} - - '@open-draft/logger@0.3.0': - resolution: {integrity: sha512-X2g45fzhxH238HKO4xbSr7+wBS8Fvw6ixhTDuvLd5mqh6bJJCFAPwU9mPDxbcrRtfxv4u5IHCEH77BmxvXmmxQ==} - - '@open-draft/until@2.1.0': - resolution: {integrity: sha512-U69T3ItWHvLwGg5eJ0n3I62nWuE6ilHlmz7zM0npLBRvPRd7e6NYmg54vvRtP5mZG7kZqZCFVdsTWo7BPtBujg==} - - '@orpc/client@1.13.8': - resolution: {integrity: sha512-7B8NDjBjP17Mrrgc/YeZl9b0YBu2Sk9/lKyVeG3755tyrAPLiezWuwQEaP9T45S2/+g8LTzFmV2R504Wn5R5MQ==} - - '@orpc/contract@1.13.8': - resolution: {integrity: sha512-W8hjVYDnsHI63TgQUGB4bb+ldCqR5hdxL1o2b7ytkFEkXTft6HOrHHvv+ncmgK1c1XapD1ScsCj11zzxf5NUGQ==} - - '@orpc/openapi-client@1.13.8': - resolution: {integrity: sha512-Cg7oDhbiO9bPpseRaFeWIhZFoA1bCF2pPxAJZj6/YtHkh+VSDI8W1xzbzoKNp2YHnhhJfgpIuVsHD42tX73+Mw==} - - '@orpc/shared@1.13.8': - resolution: {integrity: sha512-d7bZW2F8/ov6JFuGEMeh7XYZtW4+zgjxW5DKBv5tNkWmZEC5JJQz8l6Ym9ZRe2VyRzQgo5JarJGsVQlmqVVvhw==} + '@orpc/shared@1.13.13': + resolution: {integrity: sha512-kNpYOBjHvmgKHla6munWOaEeA0utEfAvoiZpXjiRjjt1RxTibdwQvVHgxRIBNMXfQsb+ON3Q/wDkoaUhvvSnIw==} peerDependencies: '@opentelemetry/api': '>=1.9.0' peerDependenciesMeta: '@opentelemetry/api': optional: true - '@orpc/standard-server-fetch@1.13.8': - resolution: {integrity: sha512-g26Loo7GFTCF/S5QsM3Z6Xd9ZYs90K7jtRtEqbJh03YNrjecvZdpUKd/lTf/9kpJTBTQbhFxC9WCAJH4+8leFA==} + '@orpc/standard-server-fetch@1.13.13': + resolution: {integrity: sha512-Lffy26+WtCQkwOUacsrdyeJF1GNzrhm75O3LXKVFXqmSdyVVdyI6zuqLn/YKGODU2L9IqGxZ2CwsV2tE298SSA==} - '@orpc/standard-server-peer@1.13.8': - resolution: {integrity: sha512-ZyzWT6zZnLJkX15r04ecSDAJmkQ46PXTovORmK7RzOV47qIB7IryiRGR60U4WygBX0VDzZU8cgcXidZTx4v7oA==} + '@orpc/standard-server-peer@1.13.13': + resolution: {integrity: sha512-FeWAbXfnZDPYQRajM0hD6GJvHeC3DZILngAjdcLHy5zt3riu6nL2lLPSWDv5yNWWscmYU+CfKmXWd0Z01BOeWA==} - '@orpc/standard-server@1.13.8': - resolution: {integrity: sha512-/v72eRSPFzWt6SoHDC04cjZfwdW94z3aib7dMBat32aK3eXwfRZmwPPmfVBQO/ZlJYlq+5rSdPoMKkSoirG/5Q==} + '@orpc/standard-server@1.13.13': + resolution: {integrity: sha512-9pgS8XvauuRQElkyuD8F3om+nN0KBEnTkhblDHCBzkZERjWkmfirJmshQrWHoFaDTk+nnXHIaY6d7TBTxXdPRw==} - '@orpc/tanstack-query@1.13.8': - resolution: {integrity: sha512-ZUwwkAqoGPOCs8gBG7w6vVNxUOAJyTBVUuclmZoyTdbb5xgMVtUGCvyjiwaWOSoL4+N2urZBbvNdTbEMsuoqLQ==} + '@orpc/tanstack-query@1.13.13': + resolution: {integrity: sha512-6+Cheaiu+RDPdszdeRKoBINrF8MQp64zSeZB+L3gqgF43zlYDhLOgELZMzYa6U3U6bLk4rmIeubpk+i1kACfRg==} peerDependencies: - '@orpc/client': 1.13.8 + '@orpc/client': 1.13.13 '@tanstack/query-core': '>=5.80.2' '@ota-meshi/ast-token-store@0.3.0': resolution: {integrity: sha512-XRO0zi2NIUKq2lUk3T1ecFSld1fMWRKE6naRFGkgkdeosx7IslyUKNv5Dcb5PJTja9tHJoFu0v/7yEpAkrkrTg==} engines: {node: ^20.19.0 || ^22.13.0 || >=24} - '@oxc-project/runtime@0.115.0': - resolution: {integrity: sha512-Rg8Wlt5dCbXhQnsXPrkOjL1DTSvXLgb2R/KYfnf1/K+R0k6UMLEmbQXPM+kwrWqSmWA2t0B1EtHy2/3zikQpvQ==} + '@oxc-parser/binding-android-arm-eabi@0.121.0': + resolution: {integrity: sha512-n07FQcySwOlzap424/PLMtOkbS7xOu8nsJduKL8P3COGHKgKoDYXwoAHCbChfgFpHnviehrLWIPX0lKGtbEk/A==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm] + os: [android] + + '@oxc-parser/binding-android-arm64@0.121.0': + resolution: {integrity: sha512-/Dd1xIXboYAicw+twT2utxPD7bL8qh7d3ej0qvaYIMj3/EgIrGR+tSnjCUkiCT6g6uTC0neSS4JY8LxhdSU/sA==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [android] + + '@oxc-parser/binding-darwin-arm64@0.121.0': + resolution: {integrity: sha512-A0jNEvv7QMtCO1yk205t3DWU9sWUjQ2KNF0hSVO5W9R9r/R1BIvzG01UQAfmtC0dQm7sCrs5puixurKSfr2bRQ==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [darwin] + + '@oxc-parser/binding-darwin-x64@0.121.0': + resolution: {integrity: sha512-SsHzipdxTKUs3I9EOAPmnIimEeJOemqRlRDOp9LIj+96wtxZejF51gNibmoGq8KoqbT1ssAI5po/E3J+vEtXGA==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [darwin] + + '@oxc-parser/binding-freebsd-x64@0.121.0': + resolution: {integrity: sha512-v1APOTkCp+RWOIDAHRoaeW/UoaHF15a60E8eUL6kUQXh+i4K7PBwq2Wi7jm8p0ymID5/m/oC1w3W31Z/+r7HQw==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [freebsd] + + '@oxc-parser/binding-linux-arm-gnueabihf@0.121.0': + resolution: {integrity: sha512-PmqPQuqHZyFVWA4ycr0eu4VnTMmq9laOHZd+8R359w6kzuNZPvmmunmNJ8ybkm769A0nCoVp3TJ6dUz7B3FYIQ==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm] + os: [linux] + + '@oxc-parser/binding-linux-arm-musleabihf@0.121.0': + resolution: {integrity: sha512-vF24htj+MOH+Q7y9A8NuC6pUZu8t/C2Fr/kDOi2OcNf28oogr2xadBPXAbml802E8wRAVfbta6YLDQTearz+jw==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm] + os: [linux] + + '@oxc-parser/binding-linux-arm64-gnu@0.121.0': + resolution: {integrity: sha512-wjH8cIG2Lu/3d64iZpbYr73hREMgKAfu7fqpXjgM2S16y2zhTfDIp8EQjxO8vlDtKP5Rc7waZW72lh8nZtWrpA==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [linux] + libc: [glibc] + + '@oxc-parser/binding-linux-arm64-musl@0.121.0': + resolution: {integrity: sha512-qT663J/W8yQFw3dtscbEi9LKJevr20V7uWs2MPGTnvNZ3rm8anhhE16gXGpxDOHeg9raySaSHKhd4IGa3YZvuw==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [linux] + libc: [musl] + + '@oxc-parser/binding-linux-ppc64-gnu@0.121.0': + resolution: {integrity: sha512-mYNe4NhVvDBbPkAP8JaVS8lC1dsoJZWH5WCjpw5E+sjhk1R08wt3NnXYUzum7tIiWPfgQxbCMcoxgeemFASbRw==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [ppc64] + os: [linux] + libc: [glibc] + + '@oxc-parser/binding-linux-riscv64-gnu@0.121.0': + resolution: {integrity: sha512-+QiFoGxhAbaI/amqX567784cDyyuZIpinBrJNxUzb+/L2aBRX67mN6Jv40pqduHf15yYByI+K5gUEygCuv0z9w==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [riscv64] + os: [linux] + libc: [glibc] + + '@oxc-parser/binding-linux-riscv64-musl@0.121.0': + resolution: {integrity: sha512-9ykEgyTa5JD/Uhv2sttbKnCfl2PieUfOjyxJC/oDL2UO0qtXOtjPLl7H8Kaj5G7p3hIvFgu3YWvAxvE0sqY+hQ==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [riscv64] + os: [linux] + libc: [musl] + + '@oxc-parser/binding-linux-s390x-gnu@0.121.0': + resolution: {integrity: sha512-DB1EW5VHZdc1lIRjOI3bW/wV6R6y0xlfvdVrqj6kKi7Ayu2U3UqUBdq9KviVkcUGd5Oq+dROqvUEEFRXGAM7EQ==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [s390x] + os: [linux] + libc: [glibc] + + '@oxc-parser/binding-linux-x64-gnu@0.121.0': + resolution: {integrity: sha512-s4lfobX9p4kPTclvMiH3gcQUd88VlnkMTF6n2MTMDAyX5FPNRhhRSFZK05Ykhf8Zy5NibV4PbGR6DnK7FGNN6A==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [linux] + libc: [glibc] + + '@oxc-parser/binding-linux-x64-musl@0.121.0': + resolution: {integrity: sha512-P9KlyTpuBuMi3NRGpJO8MicuGZfOoqZVRP1WjOecwx8yk4L/+mrCRNc5egSi0byhuReblBF2oVoDSMgV9Bj4Hw==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [linux] + libc: [musl] + + '@oxc-parser/binding-openharmony-arm64@0.121.0': + resolution: {integrity: sha512-R+4jrWOfF2OAPPhj3Eb3U5CaKNAH9/btMveMULIrcNW/hjfysFQlF8wE0GaVBr81dWz8JLgQlsxwctoL78JwXw==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [openharmony] + + '@oxc-parser/binding-wasm32-wasi@0.121.0': + resolution: {integrity: sha512-5TFISkPTymKvsmIlKasPVTPuWxzCcrT8pM+p77+mtQbIZDd1UC8zww4CJcRI46kolmgrEX6QpKO8AvWMVZ+ifw==} + engines: {node: '>=14.0.0'} + cpu: [wasm32] + + '@oxc-parser/binding-win32-arm64-msvc@0.121.0': + resolution: {integrity: sha512-V0pxh4mql4XTt3aiEtRNUeBAUFOw5jzZNxPABLaOKAWrVzSr9+XUaB095lY7jqMf5t8vkfh8NManGB28zanYKw==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [win32] + + '@oxc-parser/binding-win32-ia32-msvc@0.121.0': + resolution: {integrity: sha512-4Ob1qvYMPnlF2N9rdmKdkQFdrq16QVcQwBsO8yiPZXof0fHKFF+LmQV501XFbi7lHyrKm8rlJRfQ/M8bZZPVLw==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [ia32] + os: [win32] + + '@oxc-parser/binding-win32-x64-msvc@0.121.0': + resolution: {integrity: sha512-BOp1KCzdboB1tPqoCPXgntgFs0jjeSyOXHzgxVFR7B/qfr3F8r4YDacHkTOUNXtDgM8YwKnkf3rE5gwALYX7NA==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [win32] + + '@oxc-project/runtime@0.121.0': + resolution: {integrity: sha512-p0bQukD8OEHxzY4T9OlANBbEFGnOnjo1CYi50HES7OD36UO2yPh6T+uOJKLtlg06eclxroipRCpQGMpeH8EJ/g==} engines: {node: ^20.19.0 || >=22.12.0} - '@oxc-project/types@0.115.0': - resolution: {integrity: sha512-4n91DKnebUS4yjUHl2g3/b2T+IUdCfmoZGhmwsovZCDaJSs+QkVAM+0AqqTxHSsHfeiMuueT75cZaZcT/m0pSw==} + '@oxc-project/types@0.121.0': + resolution: {integrity: sha512-CGtOARQb9tyv7ECgdAlFxi0Fv7lmzvmlm2rpD/RdijOO9rfk/JvB1CjT8EnoD+tjna/IYgKKw3IV7objRb+aYw==} + + '@oxc-project/types@0.122.0': + resolution: {integrity: sha512-oLAl5kBpV4w69UtFZ9xqcmTi+GENWOcPF7FCrczTiBbmC0ibXxCwyvZGbO39rCVEuLGAZM84DH0pUIyyv/YJzA==} '@oxc-resolver/binding-android-arm-eabi@11.19.1': resolution: {integrity: sha512-aUs47y+xyXHUKlbhqHUjBABjvycq6YSD7bpxSW7vplUmdzAlJ93yXY6ZR0c1o1x5A/QKbENCvs3+NlY8IpIVzg==} @@ -1979,276 +2704,276 @@ packages: cpu: [x64] os: [win32] - '@oxfmt/binding-android-arm-eabi@0.40.0': - resolution: {integrity: sha512-S6zd5r1w/HmqR8t0CTnGjFTBLDq2QKORPwriCHxo4xFNuhmOTABGjPaNvCJJVnrKBLsohOeiDX3YqQfJPF+FXw==} + '@oxfmt/binding-android-arm-eabi@0.42.0': + resolution: {integrity: sha512-dsqPTYsozeokRjlrt/b4E7Pj0z3eS3Eg74TWQuuKbjY4VttBmA88rB7d50Xrd+TZ986qdXCNeZRPEzZHAe+jow==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [arm] os: [android] - '@oxfmt/binding-android-arm64@0.40.0': - resolution: {integrity: sha512-/mbS9UUP/5Vbl2D6osIdcYiP0oie63LKMoTyGj5hyMCK/SFkl3EhtyRAfdjPvuvHC0SXdW6ePaTKkBSq1SNcIw==} + '@oxfmt/binding-android-arm64@0.42.0': + resolution: {integrity: sha512-t+aAjHxcr5eOBphFHdg1ouQU9qmZZoRxnX7UOJSaTwSoKsb6TYezNKO0YbWytGXCECObRqNcUxPoPr0KaraAIg==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [arm64] os: [android] - '@oxfmt/binding-darwin-arm64@0.40.0': - resolution: {integrity: sha512-wRt8fRdfLiEhnRMBonlIbKrJWixoEmn6KCjKE9PElnrSDSXETGZfPb8ee+nQNTobXkCVvVLytp2o0obAsxl78Q==} + '@oxfmt/binding-darwin-arm64@0.42.0': + resolution: {integrity: sha512-ulpSEYMKg61C5bRMZinFHrKJYRoKGVbvMEXA5zM1puX3O9T6Q4XXDbft20yrDijpYWeuG59z3Nabt+npeTsM1A==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [arm64] os: [darwin] - '@oxfmt/binding-darwin-x64@0.40.0': - resolution: {integrity: sha512-fzowhqbOE/NRy+AE5ob0+Y4X243WbWzDb00W+pKwD7d9tOqsAFbtWUwIyqqCoCLxj791m2xXIEeLH/3uz7zCCg==} + '@oxfmt/binding-darwin-x64@0.42.0': + resolution: {integrity: sha512-ttxLKhQYPdFiM8I/Ri37cvqChE4Xa562nNOsZFcv1CKTVLeEozXjKuYClNvxkXmNlcF55nzM80P+CQkdFBu+uQ==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [x64] os: [darwin] - '@oxfmt/binding-freebsd-x64@0.40.0': - resolution: {integrity: sha512-agZ9ITaqdBjcerRRFEHB8s0OyVcQW8F9ZxsszjxzeSthQ4fcN2MuOtQFWec1ed8/lDa50jSLHVE2/xPmTgtCfQ==} + '@oxfmt/binding-freebsd-x64@0.42.0': + resolution: {integrity: sha512-Og7QS3yI3tdIKYZ58SXik0rADxIk2jmd+/YvuHRyKULWpG4V2fR5V4hvKm624Mc0cQET35waPXiCQWvjQEjwYQ==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [x64] os: [freebsd] - '@oxfmt/binding-linux-arm-gnueabihf@0.40.0': - resolution: {integrity: sha512-ZM2oQ47p28TP1DVIp7HL1QoMUgqlBFHey0ksHct7tMXoU5BqjNvPWw7888azzMt25lnyPODVuye1wvNbvVUFOA==} + '@oxfmt/binding-linux-arm-gnueabihf@0.42.0': + resolution: {integrity: sha512-jwLOw/3CW4H6Vxcry4/buQHk7zm9Ne2YsidzTL1kpiMe4qqrRCwev3dkyWe2YkFmP+iZCQ7zku4KwjcLRoh8ew==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [arm] os: [linux] - '@oxfmt/binding-linux-arm-musleabihf@0.40.0': - resolution: {integrity: sha512-RBFPAxRAIsMisKM47Oe6Lwdv6agZYLz02CUhVCD1sOv5ajAcRMrnwCFBPWwGXpazToW2mjnZxFos8TuFjTU15A==} + '@oxfmt/binding-linux-arm-musleabihf@0.42.0': + resolution: {integrity: sha512-XwXu2vkMtiq2h7tfvN+WA/9/5/1IoGAVCFPiiQUvcAuG3efR97KNcRGM8BetmbYouFotQ2bDal3yyjUx6IPsTg==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [arm] os: [linux] - '@oxfmt/binding-linux-arm64-gnu@0.40.0': - resolution: {integrity: sha512-Nb2XbQ+wV3W2jSIihXdPj7k83eOxeSgYP3N/SRXvQ6ZYPIk6Q86qEh5Gl/7OitX3bQoQrESqm1yMLvZV8/J7dA==} + '@oxfmt/binding-linux-arm64-gnu@0.42.0': + resolution: {integrity: sha512-ea7s/XUJoT7ENAtUQDudFe3nkSM3e3Qpz4nJFRdzO2wbgXEcjnchKLEsV3+t4ev3r8nWxIYr9NRjPWtnyIFJVA==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [arm64] os: [linux] libc: [glibc] - '@oxfmt/binding-linux-arm64-musl@0.40.0': - resolution: {integrity: sha512-tGmWhLD/0YMotCdfezlT6tC/MJG/wKpo4vnQ3Cq+4eBk/BwNv7EmkD0VkD5F/dYkT3b8FNU01X2e8vvJuWoM1w==} + '@oxfmt/binding-linux-arm64-musl@0.42.0': + resolution: {integrity: sha512-+JA0YMlSdDqmacygGi2REp57c3fN+tzARD8nwsukx9pkCHK+6DkbAA9ojS4lNKsiBjIW8WWa0pBrBWhdZEqfuw==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [arm64] os: [linux] libc: [musl] - '@oxfmt/binding-linux-ppc64-gnu@0.40.0': - resolution: {integrity: sha512-rVbFyM3e7YhkVnp0IVYjaSHfrBWcTRWb60LEcdNAJcE2mbhTpbqKufx0FrhWfoxOrW/+7UJonAOShoFFLigDqQ==} + '@oxfmt/binding-linux-ppc64-gnu@0.42.0': + resolution: {integrity: sha512-VfnET0j4Y5mdfCzh5gBt0NK28lgn5DKx+8WgSMLYYeSooHhohdbzwAStLki9pNuGy51y4I7IoW8bqwAaCMiJQg==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [ppc64] os: [linux] libc: [glibc] - '@oxfmt/binding-linux-riscv64-gnu@0.40.0': - resolution: {integrity: sha512-3ZqBw14JtWeEoLiioJcXSJz8RQyPE+3jLARnYM1HdPzZG4vk+Ua8CUupt2+d+vSAvMyaQBTN2dZK+kbBS/j5mA==} + '@oxfmt/binding-linux-riscv64-gnu@0.42.0': + resolution: {integrity: sha512-gVlCbmBkB0fxBWbhBj9rcxezPydsQHf4MFKeHoTSPicOQ+8oGeTQgQ8EeesSybWeiFPVRx3bgdt4IJnH6nOjAA==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [riscv64] os: [linux] libc: [glibc] - '@oxfmt/binding-linux-riscv64-musl@0.40.0': - resolution: {integrity: sha512-JJ4PPSdcbGBjPvb+O7xYm2FmAsKCyuEMYhqatBAHMp/6TA6rVlf9Z/sYPa4/3Bommb+8nndm15SPFRHEPU5qFA==} + '@oxfmt/binding-linux-riscv64-musl@0.42.0': + resolution: {integrity: sha512-zN5OfstL0avgt/IgvRu0zjQzVh/EPkcLzs33E9LMAzpqlLWiPWeMDZyMGFlSRGOdDjuNmlZBCgj0pFnK5u32TQ==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [riscv64] os: [linux] libc: [musl] - '@oxfmt/binding-linux-s390x-gnu@0.40.0': - resolution: {integrity: sha512-Kp0zNJoX9Ik77wUya2tpBY3W9f40VUoMQLWVaob5SgCrblH/t2xr/9B2bWHfs0WCefuGmqXcB+t0Lq77sbBmZw==} + '@oxfmt/binding-linux-s390x-gnu@0.42.0': + resolution: {integrity: sha512-9X6+H2L0qMc2sCAgO9HS03bkGLMKvOFjmEdchaFlany3vNZOjnVui//D8k/xZAtQv2vaCs1reD5KAgPoIU4msA==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [s390x] os: [linux] libc: [glibc] - '@oxfmt/binding-linux-x64-gnu@0.40.0': - resolution: {integrity: sha512-7YTCNzleWTaQTqNGUNQ66qVjpoV6DjbCOea+RnpMBly2bpzrI/uu7Rr+2zcgRfNxyjXaFTVQKaRKjqVdeUfeVA==} + '@oxfmt/binding-linux-x64-gnu@0.42.0': + resolution: {integrity: sha512-BajxJ6KQvMMdpXGPWhBGyjb2Jvx4uec0w+wi6TJZ6Tv7+MzPwe0pO8g5h1U0jyFgoaF7mDl6yKPW3ykWcbUJRw==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [x64] os: [linux] libc: [glibc] - '@oxfmt/binding-linux-x64-musl@0.40.0': - resolution: {integrity: sha512-hWnSzJ0oegeOwfOEeejYXfBqmnRGHusgtHfCPzmvJvHTwy1s3Neo59UKc1CmpE3zxvrCzJoVHos0rr97GHMNPw==} + '@oxfmt/binding-linux-x64-musl@0.42.0': + resolution: {integrity: sha512-0wV284I6vc5f0AqAhgAbHU2935B4bVpncPoe5n/WzVZY/KnHgqxC8iSFGeSyLWEgstFboIcWkOPck7tqbdHkzA==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [x64] os: [linux] libc: [musl] - '@oxfmt/binding-openharmony-arm64@0.40.0': - resolution: {integrity: sha512-28sJC1lR4qtBJGzSRRbPnSW3GxU2+4YyQFE6rCmsUYqZ5XYH8jg0/w+CvEzQ8TuAQz5zLkcA25nFQGwoU0PT3Q==} + '@oxfmt/binding-openharmony-arm64@0.42.0': + resolution: {integrity: sha512-p4BG6HpGnhfgHk1rzZfyR6zcWkE7iLrWxyehHfXUy4Qa5j3e0roglFOdP/Nj5cJJ58MA3isQ5dlfkW2nNEpolw==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [arm64] os: [openharmony] - '@oxfmt/binding-win32-arm64-msvc@0.40.0': - resolution: {integrity: sha512-cDkRnyT0dqwF5oIX1Cv59HKCeZQFbWWdUpXa3uvnHFT2iwYSSZspkhgjXjU6iDp5pFPaAEAe9FIbMoTgkTmKPg==} + '@oxfmt/binding-win32-arm64-msvc@0.42.0': + resolution: {integrity: sha512-mn//WV60A+IetORDxYieYGAoQso4KnVRRjORDewMcod4irlRe0OSC7YPhhwaexYNPQz/GCFk+v9iUcZ2W22yxQ==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [arm64] os: [win32] - '@oxfmt/binding-win32-ia32-msvc@0.40.0': - resolution: {integrity: sha512-7rPemBJjqm5Gkv6ZRCPvK8lE6AqQ/2z31DRdWazyx2ZvaSgL7QGofHXHNouRpPvNsT9yxRNQJgigsWkc+0qg4w==} + '@oxfmt/binding-win32-ia32-msvc@0.42.0': + resolution: {integrity: sha512-3gWltUrvuz4LPJXWivoAxZ28Of2O4N7OGuM5/X3ubPXCEV8hmgECLZzjz7UYvSDUS3grfdccQwmjynm+51EFpw==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [ia32] os: [win32] - '@oxfmt/binding-win32-x64-msvc@0.40.0': - resolution: {integrity: sha512-/Zmj0yTYSvmha6TG1QnoLqVT7ZMRDqXvFXXBQpIjteEwx9qvUYMBH2xbiOFhDeMUJkGwC3D6fdKsFtaqUvkwNA==} + '@oxfmt/binding-win32-x64-msvc@0.42.0': + resolution: {integrity: sha512-Wg4TMAfQRL9J9AZevJ/ZNy3uyyDztDYQtGr4P8UyyzIhLhFrdSmz1J/9JT+rv0fiCDLaFOBQnj3f3K3+a5PzDQ==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [x64] os: [win32] - '@oxlint-tsgolint/darwin-arm64@0.17.0': - resolution: {integrity: sha512-z3XwCDuOAKgk7bO4y5tyH8Zogwr51G56R0XGKC3tlAbrAq8DecoxAd3qhRZqWBMG2Gzl5bWU3Ghu7lrxuLPzYw==} + '@oxlint-tsgolint/darwin-arm64@0.17.3': + resolution: {integrity: sha512-5aDl4mxXWs+Bj02pNrX6YY6v9KMZjLIytXoqolLEo0dfBNVeZUonZgJAa/w0aUmijwIRrBhxEzb42oLuUtfkGw==} cpu: [arm64] os: [darwin] - '@oxlint-tsgolint/darwin-x64@0.17.0': - resolution: {integrity: sha512-TZgVXy0MtI8nt0MYiceuZhHPwHcwlIZ/YwzFTAKrgdHiTvVzFbqHVdXi5wbZfT/o1nHGw9fbGWPlb6qKZ4uZ9Q==} + '@oxlint-tsgolint/darwin-x64@0.17.3': + resolution: {integrity: sha512-gPBy4DS5ueCgXzko20XsNZzDe/Cxde056B+QuPLGvz05CGEAtmRfpImwnyY2lAXXjPL+SmnC/OYexu8zI12yHQ==} cpu: [x64] os: [darwin] - '@oxlint-tsgolint/linux-arm64@0.17.0': - resolution: {integrity: sha512-IDfhFl/Y8bjidCvAP6QAxVyBsl78TmfCHlfjtEv2XtJXgYmIwzv6muO18XMp74SZ2qAyD4y2n2dUedrmghGHeA==} + '@oxlint-tsgolint/linux-arm64@0.17.3': + resolution: {integrity: sha512-+pkunvCfB6pB0G9qHVVXUao3nqzXQPo4O3DReIi+5nGa+bOU3J3Srgy+Zb8VyOL+WDsSMJ+U7+r09cKHWhz3hg==} cpu: [arm64] os: [linux] - '@oxlint-tsgolint/linux-x64@0.17.0': - resolution: {integrity: sha512-Bgdgqx/m8EnfjmmlRLEeYy9Yhdt1GdFrMr5mTu/NyLRGkB1C9VLAikdxB7U9QambAGTAmjMbHNFDFk8Vx69Huw==} + '@oxlint-tsgolint/linux-x64@0.17.3': + resolution: {integrity: sha512-/kW5oXtBThu4FjmgIBthdmMjWLzT3M1TEDQhxDu7hQU5xDeTd60CDXb2SSwKCbue9xu7MbiFoJu83LN0Z/d38g==} cpu: [x64] os: [linux] - '@oxlint-tsgolint/win32-arm64@0.17.0': - resolution: {integrity: sha512-dO6wyKMDqFWh1vwr+zNZS7/ovlfGgl4S3P1LDy4CKjP6V6NGtdmEwWkWax8j/I8RzGZdfXKnoUfb/qhVg5bx0w==} + '@oxlint-tsgolint/win32-arm64@0.17.3': + resolution: {integrity: sha512-NMELRvbz4Ed4dxg8WiqZxtu3k4OJEp2B9KInZW+BMfqEqbwZdEJY83tbqz2hD1EjKO2akrqBQ0GpRUJEkd8kKw==} cpu: [arm64] os: [win32] - '@oxlint-tsgolint/win32-x64@0.17.0': - resolution: {integrity: sha512-lPGYFp3yX2nh6hLTpIuMnJbZnt3Df42VkoA/fSkMYi2a/LXdDytQGpgZOrb5j47TICARd34RauKm0P3OA4Oxbw==} + '@oxlint-tsgolint/win32-x64@0.17.3': + resolution: {integrity: sha512-+pJ7r8J3SLPws5uoidVplZc8R/lpKyKPE6LoPGv9BME00Y1VjT6jWGx/dtUN8PWvcu3iTC6k+8u3ojFSJNmWTg==} cpu: [x64] os: [win32] - '@oxlint/binding-android-arm-eabi@1.55.0': - resolution: {integrity: sha512-NhvgAhncTSOhRahQSCnkK/4YIGPjTmhPurQQ2dwt2IvwCMTvZRW5vF2K10UBOxFve4GZDMw6LtXZdC2qeuYIVQ==} + '@oxlint/binding-android-arm-eabi@1.57.0': + resolution: {integrity: sha512-C7EiyfAJG4B70496eV543nKiq5cH0o/xIh/ufbjQz3SIvHhlDDsyn+mRFh+aW8KskTyUpyH2LGWL8p2oN6bl1A==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [arm] os: [android] - '@oxlint/binding-android-arm64@1.55.0': - resolution: {integrity: sha512-P9iWRh+Ugqhg+D7rkc7boHX8o3H2h7YPcZHQIgvVBgnua5tk4LR2L+IBlreZs58/95cd2x3/004p5VsQM9z4SA==} + '@oxlint/binding-android-arm64@1.57.0': + resolution: {integrity: sha512-9i80AresjZ/FZf5xK8tKFbhQnijD4s1eOZw6/FHUwD59HEZbVLRc2C88ADYJfLZrF5XofWDiRX/Ja9KefCLy7w==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [arm64] os: [android] - '@oxlint/binding-darwin-arm64@1.55.0': - resolution: {integrity: sha512-esakkJIt7WFAhT30P/Qzn96ehFpzdZ1mNuzpOb8SCW7lI4oB8VsyQnkSHREM671jfpuBb/o2ppzBCx5l0jpgMA==} + '@oxlint/binding-darwin-arm64@1.57.0': + resolution: {integrity: sha512-0eUfhRz5L2yKa9I8k3qpyl37XK3oBS5BvrgdVIx599WZK63P8sMbg+0s4IuxmIiZuBK68Ek+Z+gcKgeYf0otsg==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [arm64] os: [darwin] - '@oxlint/binding-darwin-x64@1.55.0': - resolution: {integrity: sha512-xDMFRCCAEK9fOH6As2z8ELsC+VDGSFRHwIKVSilw+xhgLwTDFu37rtmRbmUlx8rRGS6cWKQPTc47AVxAZEVVPQ==} + '@oxlint/binding-darwin-x64@1.57.0': + resolution: {integrity: sha512-UvrSuzBaYOue+QMAcuDITe0k/Vhj6KZGjfnI6x+NkxBTke/VoM7ZisaxgNY0LWuBkTnd1OmeQfEQdQ48fRjkQg==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [x64] os: [darwin] - '@oxlint/binding-freebsd-x64@1.55.0': - resolution: {integrity: sha512-mYZqnwUD7ALCRxGenyLd1uuG+rHCL+OTT6S8FcAbVm/ZT2AZMGjvibp3F6k1SKOb2aeqFATmwRykrE41Q0GWVw==} + '@oxlint/binding-freebsd-x64@1.57.0': + resolution: {integrity: sha512-wtQq0dCoiw4bUwlsNVDJJ3pxJA218fOezpgtLKrbQqUtQJcM9yP8z+I9fu14aHg0uyAxIY+99toL6uBa2r7nxA==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [x64] os: [freebsd] - '@oxlint/binding-linux-arm-gnueabihf@1.55.0': - resolution: {integrity: sha512-LcX6RYcF9vL9ESGwJW3yyIZ/d/ouzdOKXxCdey1q0XJOW1asrHsIg5MmyKdEBR4plQx+shvYeQne7AzW5f3T1w==} + '@oxlint/binding-linux-arm-gnueabihf@1.57.0': + resolution: {integrity: sha512-qxFWl2BBBFcT4djKa+OtMdnLgoHEJXpqjyGwz8OhW35ImoCwR5qtAGqApNYce5260FQqoAHW8S8eZTjiX67Tsg==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [arm] os: [linux] - '@oxlint/binding-linux-arm-musleabihf@1.55.0': - resolution: {integrity: sha512-C+8GS1rPtK+dI7mJFkqoRBkDuqbrNihnyYQsJPS9ez+8zF9JzfvU19lawqt4l/Y23o5uQswE/DORa8aiXUih3w==} + '@oxlint/binding-linux-arm-musleabihf@1.57.0': + resolution: {integrity: sha512-SQoIsBU7J0bDW15/f0/RvxHfY3Y0+eB/caKBQtNFbuerTiA6JCYx9P1MrrFTwY2dTm/lMgTSgskvCEYk2AtG/Q==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [arm] os: [linux] - '@oxlint/binding-linux-arm64-gnu@1.55.0': - resolution: {integrity: sha512-ErLE4XbmcCopA4/CIDiH6J1IAaDOMnf/KSx/aFObs4/OjAAM3sFKWGZ57pNOMxhhyBdcmcXwYymph9GwcpcqgQ==} + '@oxlint/binding-linux-arm64-gnu@1.57.0': + resolution: {integrity: sha512-jqxYd1W6WMeozsCmqe9Rzbu3SRrGTyGDAipRlRggetyYbUksJqJKvUNTQtZR/KFoJPb+grnSm5SHhdWrywv3RQ==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [arm64] os: [linux] libc: [glibc] - '@oxlint/binding-linux-arm64-musl@1.55.0': - resolution: {integrity: sha512-/kp65avi6zZfqEng56TTuhiy3P/3pgklKIdf38yvYeJ9/PgEeRA2A2AqKAKbZBNAqUzrzHhz9jF6j/PZvhJzTQ==} + '@oxlint/binding-linux-arm64-musl@1.57.0': + resolution: {integrity: sha512-i66WyEPVEvq9bxRUCJ/MP5EBfnTDN3nhwEdFZFTO5MmLLvzngfWEG3NSdXQzTT3vk5B9i6C2XSIYBh+aG6uqyg==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [arm64] os: [linux] libc: [musl] - '@oxlint/binding-linux-ppc64-gnu@1.55.0': - resolution: {integrity: sha512-A6pTdXwcEEwL/nmz0eUJ6WxmxcoIS+97GbH96gikAyre3s5deC7sts38ZVVowjS2QQFuSWkpA4ZmQC0jZSNvJQ==} + '@oxlint/binding-linux-ppc64-gnu@1.57.0': + resolution: {integrity: sha512-oMZDCwz4NobclZU3pH+V1/upVlJZiZvne4jQP+zhJwt+lmio4XXr4qG47CehvrW1Lx2YZiIHuxM2D4YpkG3KVA==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [ppc64] os: [linux] libc: [glibc] - '@oxlint/binding-linux-riscv64-gnu@1.55.0': - resolution: {integrity: sha512-clj0lnIN+V52G9tdtZl0LbdTSurnZ1NZj92Je5X4lC7gP5jiCSW+Y/oiDiSauBAD4wrHt2S7nN3pA0zfKYK/6Q==} + '@oxlint/binding-linux-riscv64-gnu@1.57.0': + resolution: {integrity: sha512-uoBnjJ3MMEBbfnWC1jSFr7/nSCkcQYa72NYoNtLl1imshDnWSolYCjzb8LVCwYCCfLJXD+0gBLD7fyC14c0+0g==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [riscv64] os: [linux] libc: [glibc] - '@oxlint/binding-linux-riscv64-musl@1.55.0': - resolution: {integrity: sha512-NNu08pllN5x/O94/sgR3DA8lbrGBnTHsINZZR0hcav1sj79ksTiKKm1mRzvZvacwQ0hUnGinFo+JO75ok2PxYg==} + '@oxlint/binding-linux-riscv64-musl@1.57.0': + resolution: {integrity: sha512-BdrwD7haPZ8a9KrZhKJRSj6jwCor+Z8tHFZ3PT89Y3Jq5v3LfMfEePeAmD0LOTWpiTmzSzdmyw9ijneapiVHKQ==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [riscv64] os: [linux] libc: [musl] - '@oxlint/binding-linux-s390x-gnu@1.55.0': - resolution: {integrity: sha512-BvfQz3PRlWZRoEZ17dZCqgQsMRdpzGZomJkVATwCIGhHVVeHJMQdmdXPSjcT1DCNUrOjXnVyj1RGDj5+/Je2+Q==} + '@oxlint/binding-linux-s390x-gnu@1.57.0': + resolution: {integrity: sha512-BNs+7ZNsRstVg2tpNxAXfMX/Iv5oZh204dVyb8Z37+/gCh+yZqNTlg6YwCLIMPSk5wLWIGOaQjT0GUOahKYImw==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [s390x] os: [linux] libc: [glibc] - '@oxlint/binding-linux-x64-gnu@1.55.0': - resolution: {integrity: sha512-ngSOoFCSBMKVQd24H8zkbcBNc7EHhjnF1sv3mC9NNXQ/4rRjI/4Dj9+9XoDZeFEkF1SX1COSBXF1b2Pr9rqdEw==} + '@oxlint/binding-linux-x64-gnu@1.57.0': + resolution: {integrity: sha512-AghS18w+XcENcAX0+BQGLiqjpqpaxKJa4cWWP0OWNLacs27vHBxu7TYkv9LUSGe5w8lOJHeMxcYfZNOAPqw2bg==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [x64] os: [linux] libc: [glibc] - '@oxlint/binding-linux-x64-musl@1.55.0': - resolution: {integrity: sha512-BDpP7W8GlaG7BR6QjGZAleYzxoyKc/D24spZIF2mB3XsfALQJJT/OBmP8YpeTb1rveFSBHzl8T7l0aqwkWNdGA==} + '@oxlint/binding-linux-x64-musl@1.57.0': + resolution: {integrity: sha512-E/FV3GB8phu/Rpkhz5T96hAiJlGzn91qX5yj5gU754P5cmVGXY1Jw/VSjDSlZBCY3VHjsVLdzgdkJaomEmcNOg==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [x64] os: [linux] libc: [musl] - '@oxlint/binding-openharmony-arm64@1.55.0': - resolution: {integrity: sha512-PS6GFvmde/pc3fCA2Srt51glr8Lcxhpf6WIBFfLphndjRrD34NEcses4TSxQrEcxYo6qVywGfylM0ZhSCF2gGA==} + '@oxlint/binding-openharmony-arm64@1.57.0': + resolution: {integrity: sha512-xvZ2yZt0nUVfU14iuGv3V25jpr9pov5N0Wr28RXnHFxHCRxNDMtYPHV61gGLhN9IlXM96gI4pyYpLSJC5ClLCQ==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [arm64] os: [openharmony] - '@oxlint/binding-win32-arm64-msvc@1.55.0': - resolution: {integrity: sha512-P6JcLJGs/q1UOvDLzN8otd9JsH4tsuuPDv+p7aHqHM3PrKmYdmUvkNj4K327PTd35AYcznOCN+l4ZOaq76QzSw==} + '@oxlint/binding-win32-arm64-msvc@1.57.0': + resolution: {integrity: sha512-Z4D8Pd0AyHBKeazhdIXeUUy5sIS3Mo0veOlzlDECg6PhRRKgEsBJCCV1n+keUZtQ04OP+i7+itS3kOykUyNhDg==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [arm64] os: [win32] - '@oxlint/binding-win32-ia32-msvc@1.55.0': - resolution: {integrity: sha512-gzkk4zE2zsE+WmRxFOiAZHpCpUNDFytEakqNXoNHW+PnYEOTPKDdW6nrzgSeTbGKVPXNAKQnRnMgrh7+n3Xueg==} + '@oxlint/binding-win32-ia32-msvc@1.57.0': + resolution: {integrity: sha512-StOZ9nFMVKvevicbQfql6Pouu9pgbeQnu60Fvhz2S6yfMaii+wnueLnqQ5I1JPgNF0Syew4voBlAaHD13wH6tw==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [ia32] os: [win32] - '@oxlint/binding-win32-x64-msvc@1.55.0': - resolution: {integrity: sha512-ZFALNow2/og75gvYzNP7qe+rREQ5xunktwA+lgykoozHZ6hw9bqg4fn5j2UvG4gIn1FXqrZHkOAXuPf5+GOYTQ==} + '@oxlint/binding-win32-x64-msvc@1.57.0': + resolution: {integrity: sha512-6PuxhYgth8TuW0+ABPOIkGdBYw+qYGxgIdXPHSVpiCDm+hqTTWCmC739St1Xni0DJBt8HnSHTG67i1y6gr8qrA==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [x64] os: [win32] @@ -2345,6 +3070,11 @@ packages: resolution: {integrity: sha512-QNqXyfVS2wm9hweSYD2O7F0G06uurj9kZ96TRQE5Y9hU7+tgdZwIkbAKc5Ocy1HxEY2kuDQa6cQ1WRs/O5LFKA==} engines: {node: ^12.20.0 || ^14.18.0 || >=16.0.0} + '@playwright/test@1.58.2': + resolution: {integrity: sha512-akea+6bHYBBfA9uQqSYmlJXn61cTa+jbO87xVLCWbTqbWadRVmhxlXATaOjOgcBaWU4ePo0wB41KMFv3o35IXA==} + engines: {node: '>=18'} + hasBin: true + '@polka/url@1.0.0-next.29': resolution: {integrity: sha512-wwQAWhWSuHaag8c4q/KN/vCoeOJYshAIvMQwD4GpSb3OiZklFfvAgmj0VCBBImRpuF/aFgIRzllXlVX93Jevww==} @@ -2633,6 +3363,104 @@ packages: resolution: {integrity: sha512-UuBOt7BOsKVOkFXRe4Ypd/lADuNIfqJXv8GvHqtXaTYXPPKkj2nS2zPllVsrtRjcomDhIJVBnZwfmlI222WH8g==} engines: {node: '>=14.0.0'} + '@rolldown/binding-android-arm64@1.0.0-rc.12': + resolution: {integrity: sha512-pv1y2Fv0JybcykuiiD3qBOBdz6RteYojRFY1d+b95WVuzx211CRh+ytI/+9iVyWQ6koTh5dawe4S/yRfOFjgaA==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [android] + + '@rolldown/binding-darwin-arm64@1.0.0-rc.12': + resolution: {integrity: sha512-cFYr6zTG/3PXXF3pUO+umXxt1wkRK/0AYT8lDwuqvRC+LuKYWSAQAQZjCWDQpAH172ZV6ieYrNnFzVVcnSflAg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [darwin] + + '@rolldown/binding-darwin-x64@1.0.0-rc.12': + resolution: {integrity: sha512-ZCsYknnHzeXYps0lGBz8JrF37GpE9bFVefrlmDrAQhOEi4IOIlcoU1+FwHEtyXGx2VkYAvhu7dyBf75EJQffBw==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [darwin] + + '@rolldown/binding-freebsd-x64@1.0.0-rc.12': + resolution: {integrity: sha512-dMLeprcVsyJsKolRXyoTH3NL6qtsT0Y2xeuEA8WQJquWFXkEC4bcu1rLZZSnZRMtAqwtrF/Ib9Ddtpa/Gkge9Q==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [freebsd] + + '@rolldown/binding-linux-arm-gnueabihf@1.0.0-rc.12': + resolution: {integrity: sha512-YqWjAgGC/9M1lz3GR1r1rP79nMgo3mQiiA+Hfo+pvKFK1fAJ1bCi0ZQVh8noOqNacuY1qIcfyVfP6HoyBRZ85Q==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm] + os: [linux] + + '@rolldown/binding-linux-arm64-gnu@1.0.0-rc.12': + resolution: {integrity: sha512-/I5AS4cIroLpslsmzXfwbe5OmWvSsrFuEw3mwvbQ1kDxJ822hFHIx+vsN/TAzNVyepI/j/GSzrtCIwQPeKCLIg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [linux] + libc: [glibc] + + '@rolldown/binding-linux-arm64-musl@1.0.0-rc.12': + resolution: {integrity: sha512-V6/wZztnBqlx5hJQqNWwFdxIKN0m38p8Jas+VoSfgH54HSj9tKTt1dZvG6JRHcjh6D7TvrJPWFGaY9UBVOaWPw==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [linux] + libc: [musl] + + '@rolldown/binding-linux-ppc64-gnu@1.0.0-rc.12': + resolution: {integrity: sha512-AP3E9BpcUYliZCxa3w5Kwj9OtEVDYK6sVoUzy4vTOJsjPOgdaJZKFmN4oOlX0Wp0RPV2ETfmIra9x1xuayFB7g==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [ppc64] + os: [linux] + libc: [glibc] + + '@rolldown/binding-linux-s390x-gnu@1.0.0-rc.12': + resolution: {integrity: sha512-nWwpvUSPkoFmZo0kQazZYOrT7J5DGOJ/+QHHzjvNlooDZED8oH82Yg67HvehPPLAg5fUff7TfWFHQS8IV1n3og==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [s390x] + os: [linux] + libc: [glibc] + + '@rolldown/binding-linux-x64-gnu@1.0.0-rc.12': + resolution: {integrity: sha512-RNrafz5bcwRy+O9e6P8Z/OCAJW/A+qtBczIqVYwTs14pf4iV1/+eKEjdOUta93q2TsT/FI0XYDP3TCky38LMAg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [linux] + libc: [glibc] + + '@rolldown/binding-linux-x64-musl@1.0.0-rc.12': + resolution: {integrity: sha512-Jpw/0iwoKWx3LJ2rc1yjFrj+T7iHZn2JDg1Yny1ma0luviFS4mhAIcd1LFNxK3EYu3DHWCps0ydXQ5i/rrJ2ig==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [linux] + libc: [musl] + + '@rolldown/binding-openharmony-arm64@1.0.0-rc.12': + resolution: {integrity: sha512-vRugONE4yMfVn0+7lUKdKvN4D5YusEiPilaoO2sgUWpCvrncvWgPMzK00ZFFJuiPgLwgFNP5eSiUlv2tfc+lpA==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [openharmony] + + '@rolldown/binding-wasm32-wasi@1.0.0-rc.12': + resolution: {integrity: sha512-ykGiLr/6kkiHc0XnBfmFJuCjr5ZYKKofkx+chJWDjitX+KsJuAmrzWhwyOMSHzPhzOHOy7u9HlFoa5MoAOJ/Zg==} + engines: {node: '>=14.0.0'} + cpu: [wasm32] + + '@rolldown/binding-win32-arm64-msvc@1.0.0-rc.12': + resolution: {integrity: sha512-5eOND4duWkwx1AzCxadcOrNeighiLwMInEADT0YM7xeEOOFcovWZCq8dadXgcRHSf3Ulh1kFo/qvzoFiCLOL1Q==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [win32] + + '@rolldown/binding-win32-x64-msvc@1.0.0-rc.12': + resolution: {integrity: sha512-PyqoipaswDLAZtot351MLhrlrh6lcZPo2LSYE+VDxbVk24LVKAGOuE4hb8xZQmrPAuEtTZW8E6D2zc5EUZX4Lw==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [win32] + + '@rolldown/pluginutils@1.0.0-rc.12': + resolution: {integrity: sha512-HHMwmarRKvoFsJorqYlFeFRzXZqCt2ETQlEDOb9aqssrnVBB1/+xgTGtuTrIk5vzLNX1MjMtTf7W9z3tsSbrxw==} + '@rolldown/pluginutils@1.0.0-rc.5': resolution: {integrity: sha512-RxlLX/DPoarZ9PtxVrQgZhPoor987YtKQqCo5zkjX+0S0yLJ7Vv515Wk6+xtTL67VONKJKxETWZwuZjss2idYw==} @@ -2795,32 +3623,32 @@ packages: cpu: [x64] os: [win32] - '@sentry-internal/browser-utils@10.44.0': - resolution: {integrity: sha512-z9xz3T/v+MnfHY6kdUCmOZI8CiAl3LlKYtGH2p3rAsrxhwX+BTnUp01VhMVnEZIDgUXNt3AhJac+4kcDIPu1Hg==} + '@sentry-internal/browser-utils@10.46.0': + resolution: {integrity: sha512-WB1gBT9G13V02ekZ6NpUhoI1aGHV2eNfjEPthkU2bGBvFpQKnstwzjg7waIRGR7cu+YSW2Q6UI6aQLgBeOPD1g==} engines: {node: '>=18'} - '@sentry-internal/feedback@10.44.0': - resolution: {integrity: sha512-yNS2EGK1bNm8YUI+Orzpa7yr05Da+b1VEe/9x7dl7gTjw/+tfutoXlG6Y+iFZBB3gQ9QU+nxZAhU+KcxiPEURw==} + '@sentry-internal/feedback@10.46.0': + resolution: {integrity: sha512-c4pI/z9nZCQXe9GYEw/hE/YTY9AxGBp8/wgKI+T8zylrN35SGHaXv63szzE1WbI8lacBY8lBF7rstq9bQVCaHw==} engines: {node: '>=18'} - '@sentry-internal/replay-canvas@10.44.0': - resolution: {integrity: sha512-RA7XgYZWHY7M+vaHvuMxDFT51wCs4puS2smElM5oh+j3YqbFXY7P16fOCwIAGoyI4gVsj8aTeBgVqUmrmzhAXQ==} + '@sentry-internal/replay-canvas@10.46.0': + resolution: {integrity: sha512-ub314MWUsekVCuoH0/HJbbimlI24SkV745UW2pj9xRbxOAEf1wjkmIzxKrMDbTgJGuEunug02XZVdJFJUzOcDw==} engines: {node: '>=18'} - '@sentry-internal/replay@10.44.0': - resolution: {integrity: sha512-KDmoqBsRmkaoc+eKLR2CbScd2eBmLcw+1+D441lLttAO3WWhvYyCaYdu/HIGGUoybuSgt+IcpCJdi7hFuCvYqw==} + '@sentry-internal/replay@10.46.0': + resolution: {integrity: sha512-JBsWeXG6bRbxBFK8GzWymWGOB9QE7Kl57BeF3jzgdHTuHSWZ2mRnAmb1K05T4LU+gVygk6yW0KmdC8Py9Qzg9A==} engines: {node: '>=18'} - '@sentry/browser@10.44.0': - resolution: {integrity: sha512-UpMx5forbVKieNULma3gT2SsLYqsYT4nLXa6s1io/Y8BFej9sH2dD5ExA8TrkQThQwAWFI3qKsQzYnF+EX/Bfg==} + '@sentry/browser@10.46.0': + resolution: {integrity: sha512-80DmGlTk5Z2/OxVOzLNxwolMyouuAYKqG8KUcoyintZqHbF6kO1RulI610HmyUt3OagKeBCqt9S7w0VIfCRL+Q==} engines: {node: '>=18'} - '@sentry/core@10.44.0': - resolution: {integrity: sha512-aa7CiDaNFZvHpqd97LJhuskolfJ/4IH5xyuVVLnv7l6B0v9KTwskPUxb0tH1ej3FxuzfH+i8iTiTFuqpfHS3QA==} + '@sentry/core@10.46.0': + resolution: {integrity: sha512-N3fj4zqBQOhXliS1Ne9euqIKuciHCGOJfPGQLwBoW9DNz03jF+NB8+dUKtrJ79YLoftjVgf8nbgwtADK7NR+2Q==} engines: {node: '>=18'} - '@sentry/react@10.44.0': - resolution: {integrity: sha512-blaYoLk/UgFZXj9ieKZeY1JIiqzeL2VegQt22S9IQk8gHpunDZux5XC4CdcPdavcVusddaB/SmHAmhy2RCBdPQ==} + '@sentry/react@10.46.0': + resolution: {integrity: sha512-Rb1S+9OuUPVwsz7GWnQ6Kgf3azbsseUymIegg3JZHNcW/fM1nPpaljzTBnuineia113DH0pgMBcdrrZDLaosFQ==} engines: {node: '>=18'} peerDependencies: react: ^16.14.0 || 17.x || 18.x || 19.x @@ -2870,42 +3698,42 @@ packages: '@standard-schema/spec@1.1.0': resolution: {integrity: sha512-l2aFy5jALhniG5HgqrD6jXLi/rUWrKvqN/qJx6yoJsgKhblVd+iqqU4RCXavm/jPityDo5TCvKMnpjKnOriy0w==} - '@storybook/addon-docs@10.3.0': - resolution: {integrity: sha512-g9bc4YDiy4g/peLsUDmVcy2q/QXI3eHCQtHrVp2sHWef2SYjwUJ2+TOtJHScO8LuKhGnU3h2UeE59tPWTF2quw==} + '@storybook/addon-docs@10.3.3': + resolution: {integrity: sha512-trJQTpOtuOEuNv1Rn8X2Sopp5hSPpb0u0soEJ71BZAbxe4d2Y1d/1MYcxBdRKwncum6sCTsnxTpqQ/qvSJKlTQ==} peerDependencies: - storybook: ^10.3.0 + storybook: ^10.3.3 - '@storybook/addon-links@10.3.0': - resolution: {integrity: sha512-F0/UPO3HysoJoAFrBSqWkRP3lK2owHSAgQNEFB9mNihsAQbHHg9xer22VROL012saprs98+V/hNUZs4zPy9zlg==} + '@storybook/addon-links@10.3.3': + resolution: {integrity: sha512-tazBHlB+YbU62bde5DWsq0lnxZjcAsPB3YRUpN2hSMfAySsudRingyWrgu5KeOxXhJvKJj0ohjQvGcMx/wgQUA==} peerDependencies: react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 - storybook: ^10.3.0 + storybook: ^10.3.3 peerDependenciesMeta: react: optional: true - '@storybook/addon-onboarding@10.3.0': - resolution: {integrity: sha512-zhSmxO1VDntnAxSCvw1R9h2+KvAnY0PeDdhyrr9hQdVL1j3SEXxegc3dm/YJRhtBk6S2KPLgPU5+UQuFF0p2nA==} + '@storybook/addon-onboarding@10.3.3': + resolution: {integrity: sha512-HZiHfXdcLc29WkYFW+1VAMtJCeAZOOLRYPvs97woJUcZqW8yfWEJ9MWH+j++736SFAv2aqZWNmP47OdBJ/kMkw==} peerDependencies: - storybook: ^10.3.0 + storybook: ^10.3.3 - '@storybook/addon-themes@10.3.0': - resolution: {integrity: sha512-tMNRnEXv91u2lYgyUUAPhWiPD2XTLw2prj6r9/e9wmKYqJ5a2q0gQ7MiGzbgNYWmqq+DZ7g4vvGt8MXt2GmSHQ==} + '@storybook/addon-themes@10.3.3': + resolution: {integrity: sha512-6PgH1o7yNnWRVj4lAT1DNcX/eZXKgzjhfmzgWh3oFpPfDDvUzpFxx+MClM5f/ZieIbyQscxEuq8li7+e/F5VEQ==} peerDependencies: - storybook: ^10.3.0 + storybook: ^10.3.3 - '@storybook/builder-vite@10.3.0': - resolution: {integrity: sha512-T7LfZPE31j94Jkk66bnsxMibBnbLYmebLIDgPSYzeN3ZkjPfoFhhi2+8Zxneth5cQCGRkCAhRTV0tYmFp1+H6g==} + '@storybook/builder-vite@10.3.3': + resolution: {integrity: sha512-awspKCTZvXyeV3KabL0id62mFbxR5u/5yyGQultwCiSb2/yVgBfip2MAqLyS850pvTiB6QFVM9deOyd2/G/bEA==} peerDependencies: - storybook: ^10.3.0 + storybook: ^10.3.3 vite: ^5.0.0 || ^6.0.0 || ^7.0.0 || ^8.0.0 - '@storybook/csf-plugin@10.3.0': - resolution: {integrity: sha512-zlBnNpv0wtmICdQPDoY91HNzn6BNqnS2hur580J+qJtcP+5ZOYU7+gNyU+vfAnQuLEWbPz34rx8b1cTzXZQCDg==} + '@storybook/csf-plugin@10.3.3': + resolution: {integrity: sha512-Utlh7zubm+4iOzBBfzLW4F4vD99UBtl2Do4edlzK2F7krQIcFvR2ontjAE8S1FQVLZAC3WHalCOS+Ch8zf3knA==} peerDependencies: esbuild: 0.27.2 rollup: 4.59.0 - storybook: ^10.3.0 + storybook: ^10.3.3 vite: '*' webpack: '*' peerDependenciesMeta: @@ -2927,40 +3755,40 @@ packages: react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 - '@storybook/nextjs-vite@10.3.0': - resolution: {integrity: sha512-PQSQiUVxiR3eO3lmGbSyuPAbVwNJpOQDzkiC337IqWHhzZZQFVRgGU9j39hsUiP/d23BVuXPOWZtmTPASXDVMQ==} + '@storybook/nextjs-vite@10.3.3': + resolution: {integrity: sha512-/OzOo0dSd0eFIAF9ft+ptwaXHa5Xj01cw3NXEtmPdODZXl0eiPmTvWYIJeP26UEPzI2FFSm4fK64ZZJluKpGOA==} peerDependencies: next: ^14.1.0 || ^15.0.0 || ^16.0.0 react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 - storybook: ^10.3.0 + storybook: ^10.3.3 typescript: '*' vite: ^5.0.0 || ^6.0.0 || ^7.0.0 || ^8.0.0 peerDependenciesMeta: typescript: optional: true - '@storybook/react-dom-shim@10.3.0': - resolution: {integrity: sha512-dmAnIjkMmUYZCdg3FUL83Lavybin3bYKRNRXFZq1okCH8SINa2J+zKEzJhPlqixAKkbd7x1PFDgXnxxM/Nisig==} + '@storybook/react-dom-shim@10.3.3': + resolution: {integrity: sha512-lkhuh4G3UTreU9M3Iz5Dt32c6U+l/4XuvqLtbe1sDHENZH6aPj7y0b5FwnfHyvuTvYRhtbo29xZrF5Bp9kCC0w==} peerDependencies: react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 - storybook: ^10.3.0 + storybook: ^10.3.3 - '@storybook/react-vite@10.3.0': - resolution: {integrity: sha512-34t+30j+gglcRchPuZx4S4uusD746cvPeUPli7iJRWd3+vpnHSct03uGFAlsVJo6DZvVgH5s7vP4QU66C76K8A==} + '@storybook/react-vite@10.3.3': + resolution: {integrity: sha512-qHdlBe1hjqFAGXa8JL7bWTLbP/gDqXbWDm+SYCB646NHh5yvVDkZLwigP5Y+UL7M2ASfqFtosnroUK9tcCM2dw==} peerDependencies: react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 - storybook: ^10.3.0 + storybook: ^10.3.3 vite: ^5.0.0 || ^6.0.0 || ^7.0.0 || ^8.0.0 - '@storybook/react@10.3.0': - resolution: {integrity: sha512-pN++HZYVwjyJWeNg+6cewjOPkWlSho+BaUxCq/2e6yYUCr1J6MkBCYN/l1F7/ex9pDTKv9AW0da0o1aRXm3ivg==} + '@storybook/react@10.3.3': + resolution: {integrity: sha512-cGG5TbR8Tdx9zwlpsWyBEfWrejm5iWdYF26EwIhwuKq9GFUTAVrQzo0Rs7Tqc3ZyVhRS/YfsRiWSEH+zmq2JiQ==} peerDependencies: react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 - storybook: ^10.3.0 + storybook: ^10.3.3 typescript: '>= 4.9.x' peerDependenciesMeta: typescript: @@ -2983,11 +3811,11 @@ packages: '@swc/helpers@0.5.15': resolution: {integrity: sha512-JQ5TuMi45Owi4/BIMAJBoSQoOJu12oOk/gADqlcUL9JEdHB8vyjUSsxqeNXnmXHjYKMi2WcYtezGEEhqUI/E2g==} - '@swc/helpers@0.5.19': - resolution: {integrity: sha512-QamiFeIK3txNjgUTNppE6MiG3p7TdninpZu0E0PbqVh1a9FNLT2FRhisaa4NcaX52XVhA5l7Pk58Ft7Sqi/2sA==} + '@swc/helpers@0.5.20': + resolution: {integrity: sha512-2egEBHUMasdypIzrprsu8g+OEVd7Vp2MM3a2eVlM/cyFYto0nGz5BX5BTgh/ShZZI9ed+ozEq+Ngt+rgmUs8tw==} - '@t3-oss/env-core@0.13.10': - resolution: {integrity: sha512-NNFfdlJ+HmPHkLi2HKy7nwuat9SIYOxei9K10lO2YlcSObDILY7mHZNSHsieIM3A0/5OOzw/P/b+yLvPdaG52g==} + '@t3-oss/env-core@0.13.11': + resolution: {integrity: sha512-sM7GYY+KL7H/Hl0BE0inWfk3nRHZOLhmVn7sHGxaZt9FAR6KqREXAE+6TqKfiavfXmpRxO/OZ2QgKRd+oiBYRQ==} peerDependencies: arktype: ^2.1.0 typescript: '>=5.0.0' @@ -3003,8 +3831,8 @@ packages: zod: optional: true - '@t3-oss/env-nextjs@0.13.10': - resolution: {integrity: sha512-JfSA2WXOnvcc/uMdp31paMsfbYhhdvLLRxlwvrnlPE9bwM/n0Z+Qb9xRv48nPpvfMhOrkrTYw1I5Yc06WIKBJQ==} + '@t3-oss/env-nextjs@0.13.11': + resolution: {integrity: sha512-NC+3j7YWgpzdFu1t5y/8wqibTK0lm5RS4bjXA1n8uwik3wIR4iZM4Fa+U2BaMa5k3Qk8RZiYhoAIX0WogmGkzg==} peerDependencies: arktype: ^2.1.0 typescript: '>=5.0.0' @@ -3073,8 +3901,8 @@ packages: peerDependencies: solid-js: 1.9.11 - '@tanstack/eslint-plugin-query@5.91.5': - resolution: {integrity: sha512-4pqgoT5J+ntkyOoBtnxJu8LYRj3CurfNe92fghJw66mI7pZijKmOulM32Wa48cyVzGtgiuQ2o5KWC9LJVXYcBQ==} + '@tanstack/eslint-plugin-query@5.95.2': + resolution: {integrity: sha512-EYUFRaqjBep4EHMPpZR12sXP7Kr5qv9iDIlq93NfbhHwhITaW6Txu3ROO6dLFz5r84T8p+oZXBG77pa2Wuok7A==} peerDependencies: eslint: ^8.57.0 || ^9.0.0 typescript: ^5.4.0 @@ -3094,11 +3922,11 @@ packages: resolution: {integrity: sha512-y/xtNPNt/YeyoVxE/JCx+T7yjEzpezmbb+toK8DDD1P4m7Kzs5YR956+7OKexG3f8aXgC3rLZl7b1V+yNUSy5w==} engines: {node: '>=18'} - '@tanstack/query-core@5.91.0': - resolution: {integrity: sha512-FYXN8Kk9Q5VKuV6AIVaNwMThSi0nvAtR4X7HQoigf6ePOtFcavJYVIzgFhOVdtbBQtCJE3KimDIMMJM2DR1hjw==} + '@tanstack/query-core@5.95.2': + resolution: {integrity: sha512-o4T8vZHZET4Bib3jZ/tCW9/7080urD4c+0/AUaYVpIqOsr7y0reBc1oX3ttNaSW5mYyvZHctiQ/UOP2PfdmFEQ==} - '@tanstack/query-devtools@5.93.0': - resolution: {integrity: sha512-+kpsx1NQnOFTZsw6HAFCW3HkKg0+2cepGtAWXjiiSOJJ1CtQpt72EE2nyZb+AjAbLRPoeRmPJ8MtQd8r8gsPdg==} + '@tanstack/query-devtools@5.95.2': + resolution: {integrity: sha512-QfaoqBn9uAZ+ICkA8brd1EHj+qBF6glCFgt94U8XP5BT6ppSsDBI8IJ00BU+cAGjQzp6wcKJL2EmRYvxy0TWIg==} '@tanstack/react-devtools@0.10.0': resolution: {integrity: sha512-cUMzOQb1IHmkb8MsD0TrxHT8EL92Rx3G0Huq+IFkWeoaZPGlIiaIcGTpS5VvQDeI4BVUT+ZGt6CQTpx8oSTECg==} @@ -3123,19 +3951,19 @@ packages: '@tanstack/react-start': optional: true - '@tanstack/react-query-devtools@5.91.3': - resolution: {integrity: sha512-nlahjMtd/J1h7IzOOfqeyDh5LNfG0eULwlltPEonYy0QL+nqrBB+nyzJfULV+moL7sZyxc2sHdNJki+vLA9BSA==} + '@tanstack/react-query-devtools@5.95.2': + resolution: {integrity: sha512-AFQFmbznVkbtfpx8VJ2DylW17wWagQel/qLstVLkYmNRo2CmJt3SNej5hvl6EnEeljJIdC3BTB+W7HZtpsH+3g==} peerDependencies: - '@tanstack/react-query': ^5.90.20 + '@tanstack/react-query': ^5.95.2 react: ^18 || ^19 - '@tanstack/react-query@5.91.0': - resolution: {integrity: sha512-S8FODsDTNv0Ym+o/JVBvA6EWiWVhg6K2Q4qFehZyFKk6uW4H9OPbXl4kyiN9hAly0uHJ/1GEbR6kAI4MZWfjEA==} + '@tanstack/react-query@5.95.2': + resolution: {integrity: sha512-/wGkvLj/st5Ud1Q76KF1uFxScV7WeqN1slQx5280ycwAyYkIPGaRZAEgHxe3bjirSd5Zpwkj6zNcR4cqYni/ZA==} peerDependencies: react: ^18 || ^19 - '@tanstack/react-store@0.9.2': - resolution: {integrity: sha512-Vt5usJE5sHG/cMechQfmwvwne6ktGCELe89Lmvoxe3LKRoFrhPa8OCKWs0NliG8HTJElEIj7PLtaBQIcux5pAQ==} + '@tanstack/react-store@0.9.3': + resolution: {integrity: sha512-y2iHd/N9OkoQbFJLUX1T9vbc2O9tjH0pQRgTcx1/Nz4IlwLvkgpuglXUx+mXt0g5ZDFrEeDnONPqkbfxXJKwRg==} peerDependencies: react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 @@ -3146,12 +3974,16 @@ packages: react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 - '@tanstack/store@0.9.2': - resolution: {integrity: sha512-K013lUJEFJK2ofFQ/hZKJUmCnpcV00ebLyOyFOWQvyQHUOZp/iYO84BM6aOGiV81JzwbX0APTVmW8YI7yiG5oA==} + '@tanstack/store@0.9.3': + resolution: {integrity: sha512-8reSzl/qGWGGVKhBoxXPMWzATSbZLZFWhwBAFO9NAyp0TxzfBP0mIrGb8CP8KrQTmvzXlR/vFPPUrHTLBGyFyw==} '@tanstack/virtual-core@3.13.23': resolution: {integrity: sha512-zSz2Z2HNyLjCplANTDyl3BcdQJc2k1+yyFoKhNRmCr7V7dY8o8q5m8uFTI1/Pg1kL+Hgrz6u3Xo6eFUB7l66cg==} + '@teppeis/multimaps@3.0.0': + resolution: {integrity: sha512-ID7fosbc50TbT0MK0EG12O+gAP3W3Aa/Pz4DaTtQtEvlc9Odaqi0de+xuZ7Li2GtK4HzEX7IuRWS/JmZLksR3Q==} + engines: {node: '>=14'} + '@testing-library/dom@10.4.1': resolution: {integrity: sha512-o4PXJQidqJl82ckFaXUeoAW+XysPLauYI43Abki5hABd853iMhitooc6znOnczgbTYmEP6U6/y1ZyKAIsvMKGg==} engines: {node: '>=18'} @@ -3328,8 +4160,8 @@ packages: '@types/d3@7.4.3': resolution: {integrity: sha512-lZXZ9ckh5R8uiFVt8ogUNf+pIrK4EsWrx2Np75WvF/eTpJ0FMHNhjXk8CKEx/+gpHbNQyJWehbFaTvqmHWB3ww==} - '@types/debug@4.1.12': - resolution: {integrity: sha512-vIChWdVG3LG1SMxEvI/AK+FWJthlrqlTu7fbrlywTkkaONwk/UAGaULXRlf8vkzFBLVm0zkMdCquhL5aOjhXPQ==} + '@types/debug@4.1.13': + resolution: {integrity: sha512-KSVgmQmzMwPlmtljOomayoR89W4FynCAi3E8PPs7vmDVPe84hT+vGPKkJfThkmXs0x0jAaa9U8uW8bbfyS2fWw==} '@types/deep-eql@4.0.2': resolution: {integrity: sha512-c9h9dVVMigMPc4bwTvC5dxqtqJZwQPePsWjPlpSOnojbor6pGqdk541lfA7AqFQr5pB1BRdq0juY9db81BwyFw==} @@ -3388,6 +4220,9 @@ packages: '@types/node@25.5.0': resolution: {integrity: sha512-jp2P3tQMSxWugkCUKLRPVUpGaL5MVFwF8RDuSRztfwgN1wmqJeMSbKlnEtQqU8UrhTmzEmZdu2I6v2dpp7XIxw==} + '@types/normalize-package-data@2.4.4': + resolution: {integrity: sha512-37i+OaWTh9qeK4LSHPsyRC7NahnGotNuZvjLSgcPzblpHB3rrCJxAOgI5gCdKm7coonsaX1Of0ILiTcnZjbfxA==} + '@types/papaparse@5.5.2': resolution: {integrity: sha512-gFnFp/JMzLHCwRf7tQHrNnfhN4eYBVYYI897CGX4MY1tzY9l2aLkVyx2IlKZ/SAqDbB3I1AOZW5gTMGGsqWliA==} @@ -3402,9 +4237,6 @@ packages: peerDependencies: '@types/react': ^19.2.0 - '@types/react-slider@1.3.6': - resolution: {integrity: sha512-RS8XN5O159YQ6tu3tGZIQz1/9StMLTg/FCIPxwqh2gwVixJnlfIodtVx+fpXVMZHe7A58lAX1Q4XTgAGOQaCQg==} - '@types/react-syntax-highlighter@15.5.13': resolution: {integrity: sha512-uLGJ87j6Sz8UaBAooU0T6lWJ0dBmjZgN1PZTrj05TNql2/XpC6+4HhMT5syIdFUUt+FASfCeLLv4kBygNU+8qA==} @@ -3429,114 +4261,120 @@ packages: '@types/unist@3.0.3': resolution: {integrity: sha512-ko/gIFJRv177XgZsZcBwnqJN5x/Gien8qNOn0D5bQU/zAzVf9Zt3BlcUiLqhV9y4ARk0GbT3tnUiPNgnTXzc/Q==} + '@types/whatwg-mimetype@3.0.2': + resolution: {integrity: sha512-c2AKvDT8ToxLIOUlN51gTiHXflsfIFisS4pO7pDPoKouJCESkhZnEy623gwP9laCy5lnLDAw1vAzu2vM2YLOrA==} + + '@types/ws@8.18.1': + resolution: {integrity: sha512-ThVF6DCVhA8kUGy+aazFQ4kXQ7E1Ty7A3ypFOe0IcJV8O/M511G99AW24irKrW56Wt44yG9+ij8FaqoBGkuBXg==} + '@types/yauzl@2.10.3': resolution: {integrity: sha512-oJoftv0LSuaDZE3Le4DbKX+KS9G36NzOeSap90UIK0yMA/NhKJhqlSGtNDORNRaIbQfzjXDrQa0ytJ6mNRGz/Q==} '@types/zen-observable@0.8.3': resolution: {integrity: sha512-fbF6oTd4sGGy0xjHPKAt+eS2CrxJ3+6gQ3FGcBoIJR2TLAyCkCyI8JqZNy+FeON0AhVgNJoUumVoZQjBFUqHkw==} - '@typescript-eslint/eslint-plugin@8.57.1': - resolution: {integrity: sha512-Gn3aqnvNl4NGc6x3/Bqk1AOn0thyTU9bqDRhiRnUWezgvr2OnhYCWCgC8zXXRVqBsIL1pSDt7T9nJUe0oM0kDQ==} + '@typescript-eslint/eslint-plugin@8.57.2': + resolution: {integrity: sha512-NZZgp0Fm2IkD+La5PR81sd+g+8oS6JwJje+aRWsDocxHkjyRw0J5L5ZTlN3LI1LlOcGL7ph3eaIUmTXMIjLk0w==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} peerDependencies: - '@typescript-eslint/parser': ^8.57.1 + '@typescript-eslint/parser': ^8.57.2 eslint: ^8.57.0 || ^9.0.0 || ^10.0.0 typescript: '>=4.8.4 <6.0.0' - '@typescript-eslint/parser@8.57.1': - resolution: {integrity: sha512-k4eNDan0EIMTT/dUKc/g+rsJ6wcHYhNPdY19VoX/EOtaAG8DLtKCykhrUnuHPYvinn5jhAPgD2Qw9hXBwrahsw==} + '@typescript-eslint/parser@8.57.2': + resolution: {integrity: sha512-30ScMRHIAD33JJQkgfGW1t8CURZtjc2JpTrq5n2HFhOefbAhb7ucc7xJwdWcrEtqUIYJ73Nybpsggii6GtAHjA==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} peerDependencies: eslint: ^8.57.0 || ^9.0.0 || ^10.0.0 typescript: '>=4.8.4 <6.0.0' - '@typescript-eslint/project-service@8.57.1': - resolution: {integrity: sha512-vx1F37BRO1OftsYlmG9xay1TqnjNVlqALymwWVuYTdo18XuKxtBpCj1QlzNIEHlvlB27osvXFWptYiEWsVdYsg==} + '@typescript-eslint/project-service@8.57.2': + resolution: {integrity: sha512-FuH0wipFywXRTHf+bTTjNyuNQQsQC3qh/dYzaM4I4W0jrCqjCVuUh99+xd9KamUfmCGPvbO8NDngo/vsnNVqgw==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} peerDependencies: typescript: '>=4.8.4 <6.0.0' - '@typescript-eslint/rule-tester@8.57.1': - resolution: {integrity: sha512-gk0q0rLa7a1uEB0iD2t1GZELK1z6HfudiKYeSVhjQ5gW5FdL0OcZ+8f09Lg7NbmHSBF3V+S9BDuw0qoCFkHR+w==} + '@typescript-eslint/rule-tester@8.57.2': + resolution: {integrity: sha512-cb5m0irr1449waTuYzGi4KD3SGUH3khL4ta/o9lzShvT7gnIwR5qVhU0VM0p966kCrtFId8hwmkvz1fOElsxTg==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} peerDependencies: eslint: ^8.57.0 || ^9.0.0 || ^10.0.0 - '@typescript-eslint/scope-manager@8.57.1': - resolution: {integrity: sha512-hs/QcpCwlwT2L5S+3fT6gp0PabyGk4Q0Rv2doJXA0435/OpnSR3VRgvrp8Xdoc3UAYSg9cyUjTeFXZEPg/3OKg==} + '@typescript-eslint/scope-manager@8.57.2': + resolution: {integrity: sha512-snZKH+W4WbWkrBqj4gUNRIGb/jipDW3qMqVJ4C9rzdFc+wLwruxk+2a5D+uoFcKPAqyqEnSb4l2ULuZf95eSkw==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} - '@typescript-eslint/tsconfig-utils@8.57.1': - resolution: {integrity: sha512-0lgOZB8cl19fHO4eI46YUx2EceQqhgkPSuCGLlGi79L2jwYY1cxeYc1Nae8Aw1xjgW3PKVDLlr3YJ6Bxx8HkWg==} + '@typescript-eslint/tsconfig-utils@8.57.2': + resolution: {integrity: sha512-3Lm5DSM+DCowsUOJC+YqHHnKEfFh5CoGkj5Z31NQSNF4l5wdOwqGn99wmwN/LImhfY3KJnmordBq/4+VDe2eKw==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} peerDependencies: typescript: '>=4.8.4 <6.0.0' - '@typescript-eslint/type-utils@8.57.1': - resolution: {integrity: sha512-+Bwwm0ScukFdyoJsh2u6pp4S9ktegF98pYUU0hkphOOqdMB+1sNQhIz8y5E9+4pOioZijrkfNO/HUJVAFFfPKA==} + '@typescript-eslint/type-utils@8.57.2': + resolution: {integrity: sha512-Co6ZCShm6kIbAM/s+oYVpKFfW7LBc6FXoPXjTRQ449PPNBY8U0KZXuevz5IFuuUj2H9ss40atTaf9dlGLzbWZg==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} peerDependencies: eslint: ^8.57.0 || ^9.0.0 || ^10.0.0 typescript: '>=4.8.4 <6.0.0' - '@typescript-eslint/types@8.57.1': - resolution: {integrity: sha512-S29BOBPJSFUiblEl6RzPPjJt6w25A6XsBqRVDt53tA/tlL8q7ceQNZHTjPeONt/3S7KRI4quk+yP9jK2WjBiPQ==} + '@typescript-eslint/types@8.57.2': + resolution: {integrity: sha512-/iZM6FnM4tnx9csuTxspMW4BOSegshwX5oBDznJ7S4WggL7Vczz5d2W11ecc4vRrQMQHXRSxzrCsyG5EsPPTbA==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} - '@typescript-eslint/typescript-estree@8.57.1': - resolution: {integrity: sha512-ybe2hS9G6pXpqGtPli9Gx9quNV0TWLOmh58ADlmZe9DguLq0tiAKVjirSbtM1szG6+QH6rVXyU6GTLQbWnMY+g==} + '@typescript-eslint/typescript-estree@8.57.2': + resolution: {integrity: sha512-2MKM+I6g8tJxfSmFKOnHv2t8Sk3T6rF20A1Puk0svLK+uVapDZB/4pfAeB7nE83uAZrU6OxW+HmOd5wHVdXwXA==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} peerDependencies: typescript: '>=4.8.4 <6.0.0' - '@typescript-eslint/utils@8.57.1': - resolution: {integrity: sha512-XUNSJ/lEVFttPMMoDVA2r2bwrl8/oPx8cURtczkSEswY5T3AeLmCy+EKWQNdL4u0MmAHOjcWrqJp2cdvgjn8dQ==} + '@typescript-eslint/utils@8.57.2': + resolution: {integrity: sha512-krRIbvPK1ju1WBKIefiX+bngPs+odIQUtR7kymzPfo1POVw3jlF+nLkmexdSSd4UCbDcQn+wMBATOOmpBbqgKg==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} peerDependencies: eslint: ^8.57.0 || ^9.0.0 || ^10.0.0 typescript: '>=4.8.4 <6.0.0' - '@typescript-eslint/visitor-keys@8.57.1': - resolution: {integrity: sha512-YWnmJkXbofiz9KbnbbwuA2rpGkFPLbAIetcCNO6mJ8gdhdZ/v7WDXsoGFAJuM6ikUFKTlSQnjWnVO4ux+UzS6A==} + '@typescript-eslint/visitor-keys@8.57.2': + resolution: {integrity: sha512-zhahknjobV2FiD6Ee9iLbS7OV9zi10rG26odsQdfBO/hjSzUQbkIYgda+iNKK1zNiW2ey+Lf8MU5btN17V3dUw==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} - '@typescript/native-preview-darwin-arm64@7.0.0-dev.20260318.1': - resolution: {integrity: sha512-hsXZC0M5N2F/KdX/wjRywZPovdGBgWw9ARy0GWCw1dAynqdfDcuceKbUw+QwMSdvvsFbUjSomTlyFdT09p1mcA==} + '@typescript/native-preview-darwin-arm64@7.0.0-dev.20260329.1': + resolution: {integrity: sha512-zS1thDk7luD82nXVwvMd97F7FgxAE6jGtSmnHeXdaQ+6hJQcQLOVkfUdaehhdodqKDapWA2jEURxQAYjDGvv3g==} cpu: [arm64] os: [darwin] - '@typescript/native-preview-darwin-x64@7.0.0-dev.20260318.1': - resolution: {integrity: sha512-lQl7DQkROqPZrx4C1MpFP0WNxdqv+9r4lErhd+57M2Kmxx1BmX3K5VMLJT9FZQFRtgntnYbwQAQ774Z17fv8rA==} + '@typescript/native-preview-darwin-x64@7.0.0-dev.20260329.1': + resolution: {integrity: sha512-3IJ2qmpjQ1OXpZNUhJRjF1+SbDuqGC14Ug8DjWJlPBp06isi1fcJph90f5qW//FxEsNnJPYRcNwpP0A2RbTASg==} cpu: [x64] os: [darwin] - '@typescript/native-preview-linux-arm64@7.0.0-dev.20260318.1': - resolution: {integrity: sha512-1wv0qpJW4okKadShemVi4s7zGuiIRI7zTInRYDV/FfyQVyKrkTOzMtZXB6CF3Reus1HmRpGp5ADyc4MI7CCeJg==} + '@typescript/native-preview-linux-arm64@7.0.0-dev.20260329.1': + resolution: {integrity: sha512-gQb6SjB5JlUKDaDuz6mv/m+/OBWVDlcjHINFOykBZZYZtgxBx6nEDjLrT8TiJRjmHEG6hSbv+yisUL9IThWycA==} cpu: [arm64] os: [linux] - '@typescript/native-preview-linux-arm@7.0.0-dev.20260318.1': - resolution: {integrity: sha512-tE7uN00Po/oBg5VYaYM0C/QXroo6gdIRmFVZl543o46ihl0YKEZBMnyStRKKgPCI9oeYXyCNT6WR4MxSMz6ndA==} + '@typescript/native-preview-linux-arm@7.0.0-dev.20260329.1': + resolution: {integrity: sha512-WKSSJrH611DFFAg6YCkgbnkdy0a4RRpzvDpNXtPzLTbMYC5oJdq3Dpvncx5nrJvGh4J4yvzXoMxraGPyygqGLw==} cpu: [arm] os: [linux] - '@typescript/native-preview-linux-x64@7.0.0-dev.20260318.1': - resolution: {integrity: sha512-aSE7xAKYTOrxsFrIgmcaHjgXSSOnWrZ6ozNBeNxpGzd/gl2Ho3FCIwQb0NCXrDwF9AhpFRtHMWPpAPaJk24+rg==} + '@typescript/native-preview-linux-x64@7.0.0-dev.20260329.1': + resolution: {integrity: sha512-kg4r+ssxoEWruBynUg9bFMdcMpo5NupzAPqNBlV8uWbmYGZjaPLonFWAi9ZZMiVJY/x5ZQ9GBl6xskwLdd3PJQ==} cpu: [x64] os: [linux] - '@typescript/native-preview-win32-arm64@7.0.0-dev.20260318.1': - resolution: {integrity: sha512-TV/Tn8cgWamb+6mvY45X2wF0vrTkQmRFCiN1pRRehEwxslDkqLVlpGAFpZndLaPlMb/wzwVpz1e/926xdAoO1w==} + '@typescript/native-preview-win32-arm64@7.0.0-dev.20260329.1': + resolution: {integrity: sha512-Qi4lddVxl5MG7Tk67gYhCFnoqqLGd4TvaI8RN4qHFjt3GV8s6c+0cQGsJXJnVgMx27qbyDTdsyAa2pvb42rYcQ==} cpu: [arm64] os: [win32] - '@typescript/native-preview-win32-x64@7.0.0-dev.20260318.1': - resolution: {integrity: sha512-AgOZODSYeTlQWVTioRG3AxHzIBSLbZZhyK19WPzjHW0LtxCcFi59G/Gn1uIshVL3sp1ESRg9SZ5mSiFdgvfK4g==} + '@typescript/native-preview-win32-x64@7.0.0-dev.20260329.1': + resolution: {integrity: sha512-+k5+usuB8HZ6Xc+enLdb95ZJd25bQqsnI1zXxfRCHP+RS9mxs70Mi9ezQz3lKOLZFFXShSH7iW9iulm8KwVzCQ==} cpu: [x64] os: [win32] - '@typescript/native-preview@7.0.0-dev.20260318.1': - resolution: {integrity: sha512-/7LF/2x29K++k147445omxNixPANTmwJl9p/IIzK8NbOeqVOFv1Gj1GQyOQqRdT4j/X6YDwO/p400/JKE+cBOw==} + '@typescript/native-preview@7.0.0-dev.20260329.1': + resolution: {integrity: sha512-v5lJ0TgSt2m9yVk2xoj9+NH/gTDeWTLaWGPx6MJsUKOYd6bmCJhHbMcWmb8d/zlfhE9ffpixUKYj62CdYfriqA==} hasBin: true '@ungap/structured-clone@1.3.0': @@ -3567,6 +4405,19 @@ packages: resolution: {integrity: sha512-hBcWIOppZV14bi+eAmCZj8Elj8hVSUZJTpf1lgGBhVD85pervzQ1poM/qYfFUlPraYSZYP+ASg6To5BwYmUSGQ==} engines: {node: '>=16'} + '@vitejs/devtools-kit@0.1.11': + resolution: {integrity: sha512-ZmBr54Nk8IwdbNCBNtOkQ3WcskWcL55ndfiB0UM8eTZ0ZoNwzPTCHiHgk/RnbhviXiB0kTowyTTYp4RfqGEWUQ==} + peerDependencies: + vite: '*' + + '@vitejs/devtools-rpc@0.1.11': + resolution: {integrity: sha512-APo34qbV05bNJB//Jmn4QLDrCU1CQuFvYbQdqvvyCKjxwWuoHhGobqzgoRS5V23tn8Sbliz7/Fyhfh+7C0LtKA==} + peerDependencies: + ws: '*' + peerDependenciesMeta: + ws: + optional: true + '@vitejs/plugin-react@6.0.1': resolution: {integrity: sha512-l9X/E3cDb+xY3SWzlG1MOGt2usfEHGMNIaegaUGFsLkb3RCn/k8/TOXBcab+OndDI4TBtktT8/9BwwW8Vi9KUQ==} engines: {node: ^20.19.0 || >=22.12.0} @@ -3591,23 +4442,26 @@ packages: react-server-dom-webpack: optional: true - '@vitest/coverage-v8@4.1.0': - resolution: {integrity: sha512-nDWulKeik2bL2Va/Wl4x7DLuTKAXa906iRFooIRPR+huHkcvp9QDkPQ2RJdmjOFrqOqvNfoSQLF68deE3xC3CQ==} + '@vitest/coverage-v8@4.1.2': + resolution: {integrity: sha512-sPK//PHO+kAkScb8XITeB1bf7fsk85Km7+rt4eeuRR3VS1/crD47cmV5wicisJmjNdfeokTZwjMk4Mj2d58Mgg==} peerDependencies: - '@vitest/browser': 4.1.0 - vitest: 4.1.0 + '@vitest/browser': 4.1.2 + vitest: 4.1.2 peerDependenciesMeta: '@vitest/browser': optional: true - '@vitest/eslint-plugin@1.6.12': - resolution: {integrity: sha512-4kI47BJNFE+EQ5bmPbHzBF+ibNzx2Fj0Jo9xhWsTPxMddlHwIWl6YAxagefh461hrwx/W0QwBZpxGS404kBXyg==} + '@vitest/eslint-plugin@1.6.13': + resolution: {integrity: sha512-ui7JGWBoQpS5NKKW0FDb1eTuFEZ5EupEv2Psemuyfba7DfA5K52SeDLelt6P4pQJJ/4UGkker/BgMk/KrjH3WQ==} engines: {node: '>=18'} peerDependencies: + '@typescript-eslint/eslint-plugin': '*' eslint: '>=8.57.0' typescript: '>=5.0.0' vitest: '*' peerDependenciesMeta: + '@typescript-eslint/eslint-plugin': + optional: true typescript: optional: true vitest: @@ -3619,8 +4473,8 @@ packages: '@vitest/pretty-format@3.2.4': resolution: {integrity: sha512-IVNZik8IVRJRTr9fxlitMKeJeXFFFN0JaB9PHPGQ8NKQbGpfjlTx9zO4RefN8gp7eqjNy8nyK3NZmBzOPeIxtA==} - '@vitest/pretty-format@4.1.0': - resolution: {integrity: sha512-3RZLZlh88Ib0J7NQTRATfc/3ZPOnSUn2uDBUoGNn5T36+bALixmzphN26OUD3LRXWkJu4H0s5vvUeqBiw+kS0A==} + '@vitest/pretty-format@4.1.2': + resolution: {integrity: sha512-dwQga8aejqeuB+TvXCMzSQemvV9hNEtDDpgUKDzOmNQayl2OG241PSWeJwKRH3CiC+sESrmoFd49rfnq7T4RnA==} '@vitest/spy@3.2.4': resolution: {integrity: sha512-vAfasCOe6AIK70iP5UD11Ac4siNUNJ9i/9PZ3NKx07sG6sUxeag1LWdNrMWeKKYBLlzuK+Gn65Yd5nyL6ds+nw==} @@ -3628,18 +4482,18 @@ packages: '@vitest/utils@3.2.4': resolution: {integrity: sha512-fB2V0JFrQSMsCo9HiSq3Ezpdv4iYaXRG1Sx8edX3MwxfyNn83mKiGzOcH+Fkxt4MHxr3y42fQi1oeAInqgX2QA==} - '@vitest/utils@4.1.0': - resolution: {integrity: sha512-XfPXT6a8TZY3dcGY8EdwsBulFCIw+BeeX0RZn2x/BtiY/75YGh8FeWGG8QISN/WhaqSrE2OrlDgtF8q5uhOTmw==} + '@vitest/utils@4.1.2': + resolution: {integrity: sha512-xw2/TiX82lQHA06cgbqRKFb5lCAy3axQ4H4SoUFhUsg+wztiet+co86IAMDtF6Vm1hc7J6j09oh/rgDn+JdKIQ==} - '@voidzero-dev/vite-plus-core@0.1.12': - resolution: {integrity: sha512-j8YNe7A+8JcSoddztf5whvom/yJ7OKUO3Y5a3UoLIUmOL8YEKVv5nPANrxJ7eaFfHJoMnBEwzBpq1YVZ+H3uPA==} + '@voidzero-dev/vite-plus-core@0.1.14': + resolution: {integrity: sha512-CCWzdkfW0fo0cQNlIsYp5fOuH2IwKuPZEb2UY2Z8gXcp5pG74A82H2Pthj0heAuvYTAnfT7kEC6zM+RbiBgQbg==} engines: {node: ^20.19.0 || >=22.12.0} peerDependencies: '@arethetypeswrong/core': ^0.18.1 - '@tsdown/css': 0.21.3 - '@tsdown/exe': 0.21.3 + '@tsdown/css': 0.21.4 + '@tsdown/exe': 0.21.4 '@types/node': ^20.19.0 || >=22.12.0 - '@vitejs/devtools': ^0.0.0-alpha.31 + '@vitejs/devtools': ^0.1.0 esbuild: 0.27.2 jiti: '>=1.21.0' less: ^4.0.0 @@ -3652,7 +4506,7 @@ packages: tsx: ^4.8.1 typescript: ^5.0.0 unplugin-unused: ^0.5.0 - yaml: ^2.4.2 + yaml: 2.8.3 peerDependenciesMeta: '@arethetypeswrong/core': optional: true @@ -3691,43 +4545,57 @@ packages: yaml: optional: true - '@voidzero-dev/vite-plus-darwin-arm64@0.1.12': - resolution: {integrity: sha512-tYQrfmcLxIqqr/de00oN7ayu+rYobEOjyR9AxoeJoNUqRyNQCdT0A5vg78kJNPaQCyL6ctgRRvpEKr0WHVmduQ==} + '@voidzero-dev/vite-plus-darwin-arm64@0.1.14': + resolution: {integrity: sha512-q2ESUSbapwsxVRe/KevKATahNRraoX5nti3HT9S3266OHT5sMroBY14jaxTv74ekjQc9E6EPhyLGQWuWQuuBRw==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [arm64] os: [darwin] - '@voidzero-dev/vite-plus-darwin-x64@0.1.12': - resolution: {integrity: sha512-852hO/Onx9Z5u0tOYOVEUVzYJUmWdlHeqYnNT6pj0IClgVp0+KSabxr7A2paTWEFWp6XbKWvqw5Y5cVwUV3A6Q==} + '@voidzero-dev/vite-plus-darwin-x64@0.1.14': + resolution: {integrity: sha512-UpcDZc9G99E/4HDRoobvYHxMvFOG5uv3RwEcq0HF70u4DsnEMl1z8RaJLeWV7a09LGwj9Q+YWC3Z4INWnTLs8g==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [x64] os: [darwin] - '@voidzero-dev/vite-plus-linux-arm64-gnu@0.1.12': - resolution: {integrity: sha512-/gTh4tGyJKCNBn9SZUs3sq9QVRUmyuyseZefBgS223QRxdwFaxc7tIKaw91X59WXXYOzUYZOD5zsTcaIF4hc9A==} + '@voidzero-dev/vite-plus-linux-arm64-gnu@0.1.14': + resolution: {integrity: sha512-GIjn35RABUEDB9gHD26nRq7T72Te+Qy2+NIzogwEaUE728PvPkatF5gMCeF4sigCoc8c4qxDwsG+A2A2LYGnDg==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [arm64] os: [linux] libc: [glibc] - '@voidzero-dev/vite-plus-linux-x64-gnu@0.1.12': - resolution: {integrity: sha512-9oN9ITjK/Xq9Werx+6G6jnI3+F1S3g9lB36J1VAHyRlAEtuiCDV0E3YMoW2O7KzM/PlodZIZ8LStVkH7aA5ZCw==} + '@voidzero-dev/vite-plus-linux-arm64-musl@0.1.14': + resolution: {integrity: sha512-qo2RToGirG0XCcxZ2AEOuonLM256z6dNbJzDDIo5gWYA+cIKigFQJbkPyr25zsT1tsP2aY0OTxt2038XbVlRkQ==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [linux] + libc: [musl] + + '@voidzero-dev/vite-plus-linux-x64-gnu@0.1.14': + resolution: {integrity: sha512-BsMWKZfdfGcYLxxLyaePpg6NW54xqzzcfq8sFUwKfwby0kgOKQ4WymUXyBvO9nnBb0ZPsJQrV0sx+Onac/LTaw==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [x64] os: [linux] libc: [glibc] - '@voidzero-dev/vite-plus-test@0.1.12': - resolution: {integrity: sha512-EE8Y2vQvqS4c/1qSa7qlhUY9koAG6wYev0NFAtDZsijQCHUqE7nYXGJYnyUInAE6GX4zlQDGg7tf2DAl+CISYw==} + '@voidzero-dev/vite-plus-linux-x64-musl@0.1.14': + resolution: {integrity: sha512-mOrEpj7ntW9RopGbcOYG/L0pOs0qHzUG4Vz7NXbuf4dbOSlY4JjyoMOIWxjKQORQht02Hzuf8YrMGNwa6AjVSQ==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [linux] + libc: [musl] + + '@voidzero-dev/vite-plus-test@0.1.14': + resolution: {integrity: sha512-rjF+qpYD+5+THOJZ3gbE3+cxsk5sW7nJ0ODK7y6ZKeS4amREUMedEDYykzKBwR7OZDC/WwE90A0iLWCr6qAXhA==} engines: {node: ^20.0.0 || ^22.0.0 || >=24.0.0} peerDependencies: '@edge-runtime/vm': '*' '@opentelemetry/api': ^1.9.0 '@types/node': ^20.0.0 || ^22.0.0 || >=24.0.0 - '@vitest/ui': 4.1.0 + '@vitest/ui': 4.1.1 happy-dom: '*' jsdom: '*' - vite: ^6.0.0 || ^7.0.0 || ^8.0.0-0 + vite: ^6.0.0 || ^7.0.0 || ^8.0.0 peerDependenciesMeta: '@edge-runtime/vm': optional: true @@ -3742,14 +4610,14 @@ packages: jsdom: optional: true - '@voidzero-dev/vite-plus-win32-arm64-msvc@0.1.12': - resolution: {integrity: sha512-JanAb6Y+6BmPhKNLvpZB/syeyY99bt7EPJCaLlbaCt3V0Y2Iw7c7dWBM4Sg4GZ7szGYdGw385fRz0n2M32f1rg==} + '@voidzero-dev/vite-plus-win32-arm64-msvc@0.1.14': + resolution: {integrity: sha512-7iC+Ig+8D/zACy0IJf7w/vQ7duTjux9Ttmm3KOBdVWH4dl3JihydA7+SQVMhz71a4WiqJ6nPidoG8D6hUP4MVQ==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [arm64] os: [win32] - '@voidzero-dev/vite-plus-win32-x64-msvc@0.1.12': - resolution: {integrity: sha512-Ei/UtTTp7UgeEGyV83jhDpSMXhwaZZzfS7Xiaj+zj80GGOwsBre0i+oHGZ7+TuVsZ7Im0sD8IZ9enCpKpV//AQ==} + '@voidzero-dev/vite-plus-win32-x64-msvc@0.1.14': + resolution: {integrity: sha512-yRJ/8yAYFluNHx0Ej6Kevx65MIeM3wFKklnxosVZRlz2ZRL1Ea1Qh3tWATr3Ipk1ciRxBv8KJgp6zXqjxtZSoQ==} engines: {node: ^20.19.0 || >=22.12.0} cpu: [x64] os: [win32] @@ -3766,20 +4634,20 @@ packages: '@volar/typescript@2.4.28': resolution: {integrity: sha512-Ja6yvWrbis2QtN4ClAKreeUZPVYMARDYZl9LMEv1iQ1QdepB6wn0jTRxA9MftYmYa4DQ4k/DaSZpFPUfxl8giw==} - '@vue/compiler-core@3.5.30': - resolution: {integrity: sha512-s3DfdZkcu/qExZ+td75015ljzHc6vE+30cFMGRPROYjqkroYI5NV2X1yAMX9UeyBNWB9MxCfPcsjpLS11nzkkw==} + '@vue/compiler-core@3.5.31': + resolution: {integrity: sha512-k/ueL14aNIEy5Onf0OVzR8kiqF/WThgLdFhxwa4e/KF/0qe38IwIdofoSWBTvvxQOesaz6riAFAUaYjoF9fLLQ==} - '@vue/compiler-dom@3.5.30': - resolution: {integrity: sha512-eCFYESUEVYHhiMuK4SQTldO3RYxyMR/UQL4KdGD1Yrkfdx4m/HYuZ9jSfPdA+nWJY34VWndiYdW/wZXyiPEB9g==} + '@vue/compiler-dom@3.5.31': + resolution: {integrity: sha512-BMY/ozS/xxjYqRFL+tKdRpATJYDTTgWSo0+AJvJNg4ig+Hgb0dOsHPXvloHQ5hmlivUqw1Yt2pPIqp4e0v1GUw==} - '@vue/compiler-sfc@3.5.30': - resolution: {integrity: sha512-LqmFPDn89dtU9vI3wHJnwaV6GfTRD87AjWpTWpyrdVOObVtjIuSeZr181z5C4PmVx/V3j2p+0f7edFKGRMpQ5A==} + '@vue/compiler-sfc@3.5.31': + resolution: {integrity: sha512-M8wpPgR9UJ8MiRGjppvx9uWJfLV7A/T+/rL8s/y3QG3u0c2/YZgff3d6SuimKRIhcYnWg5fTfDMlz2E6seUW8Q==} - '@vue/compiler-ssr@3.5.30': - resolution: {integrity: sha512-NsYK6OMTnx109PSL2IAyf62JP6EUdk4Dmj6AkWcJGBvN0dQoMYtVekAmdqgTtWQgEJo+Okstbf/1p7qZr5H+bA==} + '@vue/compiler-ssr@3.5.31': + resolution: {integrity: sha512-h0xIMxrt/LHOvJKMri+vdYT92BrK3HFLtDqq9Pr/lVVfE4IyKZKvWf0vJFW10Yr6nX02OR4MkJwI0c1HDa1hog==} - '@vue/shared@3.5.30': - resolution: {integrity: sha512-YXgQ7JjaO18NeK2K9VTbDHaFy62WrObMa6XERNfNOkAhD1F1oDSf3ZJ7K6GqabZ0BvSDHajp8qfS5Sa2I9n8uQ==} + '@vue/shared@3.5.31': + resolution: {integrity: sha512-nBxuiuS9Lj5bPkPbWogPUnjxxWpkRniX7e5UBQDWl6Fsf4roq9wwV+cR7ezQ4zXswNvPIlsdj1slcLB7XCsRAw==} '@webassemblyjs/ast@1.14.1': resolution: {integrity: sha512-nuBEDgQfm1ccRp/8bCQrx1frohyufl4JlbMMZ4P1wpeOfDhF6FQkxZJ1b/e+PLwr6X1Nhw6OLme5usuBWYBvuQ==} @@ -3858,8 +4726,8 @@ packages: engines: {node: '>=0.4.0'} hasBin: true - agentation@2.3.3: - resolution: {integrity: sha512-AUZgFCdBQ/nAohlFsHByM9S2Dp7ECMNqVjlOke4hv/90v+wTiwrGladEkgWS60RDQp+CJ5p97meeCthYgTFlKQ==} + agentation@3.0.2: + resolution: {integrity: sha512-iGzBxFVTuZEIKzLY6AExSLAQH6i6SwxV4pAu7v7m3X6bInZ7qlZXAwrEqyc4+EfP4gM7z2RXBF6SF4DeH0f2lA==} peerDependencies: react: '>=18.0.0' react-dom: '>=18.0.0' @@ -3869,8 +4737,8 @@ packages: react-dom: optional: true - ahooks@3.9.6: - resolution: {integrity: sha512-Mr7f05swd5SmKlR9SZo5U6M0LsL4ErweLzpdgXjA1JPmnZ78Vr6wzx0jUtvoxrcqGKYnX0Yjc02iEASVxHFPjQ==} + ahooks@3.9.7: + resolution: {integrity: sha512-S0lvzhbdlhK36RFBkGv+RbOM/dbbweym+BIHM/bwwuWVSVN5TuVErHPMWo4w0t1NDYg5KPp2iEf7Y7E5LASYiw==} peerDependencies: react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 @@ -3898,6 +4766,10 @@ packages: resolution: {integrity: sha512-BvU8nYgGQBxcmMuEeUEmNTvrMVjJNSH7RgW24vXexN4Ven6qCvy4TntnvlnwnMLTVlcRQQdbRY8NKnaIoeWDNg==} engines: {node: '>=18'} + ansi-regex@4.1.1: + resolution: {integrity: sha512-ILlv4k/3f6vfQ4OoP2AGvirOktlQ98ZEL1k9FaQjxa3L1abBgbuTDAdPOpvbGncC0BTVQrl+OM8xZGK6tWXt7g==} + engines: {node: '>=6'} + ansi-regex@5.0.1: resolution: {integrity: sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==} engines: {node: '>=8'} @@ -3950,6 +4822,9 @@ packages: resolution: {integrity: sha512-COROpnaoap1E2F000S62r6A60uHZnmlvomhfyT2DlTcrY1OrBKn2UhH7qn5wTC9zMvD0AY7csdPSNwKP+7WiQw==} engines: {node: '>= 0.4'} + assertion-error-formatter@3.0.0: + resolution: {integrity: sha512-6YyAVLrEze0kQ7CmJfUgrLHb+Y7XghmL2Ie7ijVa2Y9ynP3LV+VDiwFk62Dn0qtqbmY0BT0ss6p1xxpiF2PYbQ==} + assertion-error@2.0.1: resolution: {integrity: sha512-Izi8RQcffqCeNVgFigKli1ssklIbpHnCYc6AknXGYoB6grJqyeby7jv12JUQgmTAnIDnbck1uxksT4dzN3PWBA==} engines: {node: '>=12'} @@ -3968,6 +4843,9 @@ packages: async@3.2.6: resolution: {integrity: sha512-htCUDlxyyCLMgaM3xXg0C0LW2xqfuQ6p05pCEIsXuyQ+a1koYKTuBMzRNwmybfLgvJDMd0r1LTn4+E0Ti6C2AA==} + asynckit@0.4.0: + resolution: {integrity: sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==} + autoprefixer@10.4.27: resolution: {integrity: sha512-NP9APE+tO+LuJGn7/9+cohklunJsXWiaWEfV3si4Gi/XHDwVNgkwr1J3RQYFIvPy76GmJ9/bW8vyoU1LcxwKHA==} engines: {node: ^10 || ^12 || >=14} @@ -3975,6 +4853,9 @@ packages: peerDependencies: postcss: ^8.1.0 + axios@1.14.0: + resolution: {integrity: sha512-3Y8yrqLSwjuzpXuZ0oIYZ/XGgLwUIBU3uLvbcpb0pidD9ctpShJd43KSlEEkVQg6DS0G9NKyzOvBfUtDKEyHvQ==} + bail@2.0.2: resolution: {integrity: sha512-0xO6mYd7JB2YesxDKplafRpsiOzPt9V02ddPCLbY1xYGPOX24NTyN50qnUxgCPcSoYMhKpAuBTjQoRZCAkUDRw==} @@ -3996,20 +4877,11 @@ packages: base64-js@1.5.1: resolution: {integrity: sha512-AKpaYlHn8t4SVbOHCy+b5+KKgvR4vrsD8vbvrbiQJps7fKDTkjkDry6ji0rUJjC0kzbNePLwzxq8iypo41qeWA==} - baseline-browser-mapping@2.10.8: - resolution: {integrity: sha512-PCLz/LXGBsNTErbtB6i5u4eLpHeMfi93aUv5duMmj6caNu6IphS4q6UevDnL36sZQv9lrP11dbPKGMaXPwMKfQ==} + baseline-browser-mapping@2.10.12: + resolution: {integrity: sha512-qyq26DxfY4awP2gIRXhhLWfwzwI+N5Nxk6iQi8EFizIaWIjqicQTE4sLnZZVdeKPRcVNoJOkkpfzoIYuvCKaIQ==} engines: {node: '>=6.0.0'} hasBin: true - before-after-hook@4.0.0: - resolution: {integrity: sha512-q6tR3RPqIB1pMiTRMFcZwuG5T8vwp+vUvEG0vuI6B+Rikh5BfPp2fQ82c925FOs+b0lcFQ8CFrL+KbilfZFhOQ==} - - bezier-easing@2.1.0: - resolution: {integrity: sha512-gbIqZ/eslnUFC1tjEvtz0sgx+xTK20wDnYMIA27VA04R7w6xxXQPZDbibjA9DTWZRA2CXtwHykkVzlCaAJAZig==} - - bidi-js@1.0.3: - resolution: {integrity: sha512-RKshQI1R3YQ+n9YJz2QQ147P66ELpa1FQEg20Dk8oW9t2KgLbpDLLp9aGZ7y8WHSshDknG0bknqGw5/tyCs5tw==} - binary-extensions@2.3.0: resolution: {integrity: sha512-Ceh+7ox5qe7LJuLHoY0feh3pHuUDHAcRUeyL2VYghZwfpkNIy/+8Ocg0a3UuSoYzavmylwuLWQOf3hl0jjMMIw==} engines: {node: '>=8'} @@ -4017,8 +4889,8 @@ packages: birecord@0.1.1: resolution: {integrity: sha512-VUpsf/qykW0heRlC8LooCq28Kxn3mAqKohhDG/49rrsQ1dT1CXyj/pgXS+5BSRzFTR/3DyIBOqQOrGyZOh71Aw==} - birpc@2.9.0: - resolution: {integrity: sha512-KrayHS5pBi69Xi9JmvoqrIgYGDkD6mcSe/i6YKi3w5kekCLzrX4+nawcXqrj2tIp50Kw/mT/s3p+GVK0A0sKxw==} + birpc@4.0.0: + resolution: {integrity: sha512-LShSxJP0KTmd101b6DRyGBj57LZxSDYWKitQNW/mi8GRMvZb078Uf9+pveax1DrVL89vm7mWe+TovdI/UDOuPw==} bl@4.1.0: resolution: {integrity: sha512-1W07cM9gS6DcLperZfFSj+bWLtaPGSOHWhPiGzXmvVJbRLdG82sH/Kn8EtW1VqWVA54AKf2h5k5BbnIbwF3h6w==} @@ -4029,8 +4901,8 @@ packages: brace-expansion@2.0.2: resolution: {integrity: sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==} - brace-expansion@5.0.4: - resolution: {integrity: sha512-h+DEnpVvxmfVefa4jFbCf5HdH5YMDXRsmKflpf1pILZWRFlTbJpxeU55nJl4Smt5HQaGzg1o6RHFPJaOqnmBDg==} + brace-expansion@5.0.5: + resolution: {integrity: sha512-VZznLgtwhn+Mact9tfiwx64fA9erHH/MCXEUfB/0bX/6Fz6ny5EGTXYltMocqg4xFAQZtnO3DHWWXi8RiuN7cQ==} engines: {node: 18 || 20 || >=22} braces@3.0.3: @@ -4063,6 +4935,12 @@ packages: resolution: {integrity: sha512-tjwM5exMg6BGRI+kNmTntNsvdZS1X8BFYS6tnJ2hdH0kVxM6/eVZ2xy+FqStSWvYmtfFMDLIxurorHwDKfDz5Q==} engines: {node: '>=18'} + bundle-require@5.1.0: + resolution: {integrity: sha512-3WrrOuZiyaaZPWiEt4G3+IffISVC9HYlWueJEBWED4ZH4aIAC2PnkdnuRrR94M+w6yGWn4AglWtJtBI8YqvgoA==} + engines: {node: ^12.20.0 || ^14.13.1 || >=16.0.0} + peerDependencies: + esbuild: 0.27.2 + bytes@3.1.2: resolution: {integrity: sha512-/Nf7TyzTx6S3yRJObOAV7956r8cr2+Oj8AC5dt8wSP3BQAoeX58NoHyCU8P8zGkNXStjTSi6fzO6F0pBdcYbEg==} engines: {node: '>= 0.8'} @@ -4075,6 +4953,10 @@ packages: resolution: {integrity: sha512-tixWYgm5ZoOD+3g6UTea91eow5z6AAHaho3g0V9CNSNb45gM8SmflpAc+GRd1InC4AqN/07Unrgp56Y94N9hJQ==} engines: {node: '>=20.19.0'} + call-bind-apply-helpers@1.0.2: + resolution: {integrity: sha512-Sp1ablJ0ivDkSzjcaJdxEunN5/XvksFJ2sMBFfq6x0ryhQV/2b/KwFe21cMpmHtPOSij8K99/wSfoEuTObmuMQ==} + engines: {node: '>= 0.4'} + callsites@3.1.0: resolution: {integrity: sha512-P8BjAsXvZS+VIDUI11hHCQEv74YT67YUi5JJFNWIqL235sBmjX4+qx9Muvls5ivyNENctx46xQLQ3aTuE7ssaQ==} engines: {node: '>=6'} @@ -4086,13 +4968,16 @@ packages: camelize@1.0.1: resolution: {integrity: sha512-dU+Tx2fsypxTgtLoE36npi3UqcjSSMNYfkqgmoEhtZrraP5VWq0K7FkWVTYa8eMPtnU/G2txVsfdCJTn9uzpuQ==} - caniuse-lite@1.0.30001780: - resolution: {integrity: sha512-llngX0E7nQci5BPJDqoZSbuZ5Bcs9F5db7EtgfwBerX9XGtkkiO4NwfDDIRzHTTwcYC8vC7bmeUEPGrKlR/TkQ==} + caniuse-lite@1.0.30001781: + resolution: {integrity: sha512-RdwNCyMsNBftLjW6w01z8bKEvT6e/5tpPVEgtn22TiLGlstHOVecsX2KHFkD5e/vRnIE4EGzpuIODb3mtswtkw==} - canvas@3.2.1: - resolution: {integrity: sha512-ej1sPFR5+0YWtaVp6S1N1FVz69TQCqmrkGeRvQxZeAB1nAIcjNTHVwrZtYtWFFBmQsF40/uDLehsW5KuYC99mg==} + canvas@3.2.2: + resolution: {integrity: sha512-duEt4h1HHu9sJZyVKfLRXR6tsKPY7cEELzxSRJkwddOXYvQT3P/+es98SV384JA0zMOZ5s+9gatnGfM6sL4Drg==} engines: {node: ^18.12.0 || >= 20.9.0} + capital-case@1.0.4: + resolution: {integrity: sha512-ds37W8CytHgwnhGGTi88pcPyR15qoNkOpYwmMMfnWqqWgESapLqvDx6huFjQ5vqWSn2Z06173XNA7LtMOeUh1A==} + ccount@2.0.1: resolution: {integrity: sha512-eyrF0jiFpY+3drT6383f1qhkbGsLSifNAjA61IUjZjmLCWjItY6LB9ft9YhoDgwfmclB2zhu51Lc7+95b8NRAg==} @@ -4186,6 +5071,9 @@ packages: resolution: {integrity: sha512-77PSwercCZU2Fc4sX94eF8k8Pxte6JAwL4/ICZLFjJLqegs7kCuAsqqj/70NQF6TvDpgFjkubQB2FW2ZZddvQg==} engines: {node: '>=8'} + class-transformer@0.5.1: + resolution: {integrity: sha512-SQa1Ws6hUbfC98vKGxZH3KFY0Y1lm5Zm0SY8XX9zbK7FJCyVEac3ATW0RIpwzW+oOfmHE5PMPufDG9hCfoEOMw==} + class-variance-authority@0.7.1: resolution: {integrity: sha512-Ka+9Trutv7G8M6WT6SeiRWz792K5qEqIGEGzXKhAE6xOWAY6pPH8U+9IY3oCMv6kqTmLsv7Xh/2w2RigkePMsg==} @@ -4203,6 +5091,10 @@ packages: resolution: {integrity: sha512-aCj4O5wKyszjMmDT4tZj93kxyydN/K5zPWSCe6/0AV/AA1pqe5ZBIw0a2ZfPQV7lL5/yb5HsUreJ6UFAF1tEQw==} engines: {node: '>=18'} + cli-table3@0.6.5: + resolution: {integrity: sha512-+W/5efTR7y5HRD7gACw9yQjqMVvEMLBHmboM/kPWam+H+Hmyrgjh6YncVKK122YZkXrLudzTuAukUw9FnMf7IQ==} + engines: {node: 10.* || >= 12.*} + cli-truncate@5.2.0: resolution: {integrity: sha512-xRwvIOMGrfOAnM1JYtqQImuaNtDEv9v6oIYAs4LIHwTiKee8uwvIi363igssOC0O5U04i4AlENs79LQLu9tEMw==} engines: {node: '>=20'} @@ -4220,8 +5112,8 @@ packages: react: ^18 || ^19 || ^19.0.0-rc react-dom: ^18 || ^19 || ^19.0.0-rc - code-inspector-plugin@1.4.4: - resolution: {integrity: sha512-fdrSiP5jJ+FFLQmUyaF52xBB1yelJJtGdzr9wwFUJlbq5di4+rfyBHIzSrYgCTU5EAMrsRZ2eSnJb4zFa8Svvw==} + code-inspector-plugin@1.4.5: + resolution: {integrity: sha512-yp3zHd5AZhtVoBNOzKQuJVo1wZe7AIO2vAiVhF8WIAK02IwM9+gY+Pr9deajx+XyJLbzMW+3CgdfLIh+xxW2Hg==} collapse-white-space@2.1.0: resolution: {integrity: sha512-loKTxY1zCOuG4j9f6EPnuyyYkf58RnhhWTvRoZEokgB+WbdXehfjFviyOVYkqzEWz1Q5kRiZdBYS5SwxbQYwzw==} @@ -4236,12 +5128,24 @@ packages: colorette@2.0.20: resolution: {integrity: sha512-IfEDxwoWIjkeXL1eXcDiow4UbKjhLdq6/EuSVR9GMN7KVH3r9gQ83e73hsz1Nd1T3ijd5xv1wcWRYO+D6kCI2w==} + combined-stream@1.0.8: + resolution: {integrity: sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg==} + engines: {node: '>= 0.8'} + comma-separated-tokens@1.0.8: resolution: {integrity: sha512-GHuDRO12Sypu2cV70d1dkA2EUmXHgntrzbpvOB+Qy+49ypNfGgFQIC2fhhXbnyrJRynDCAARsT7Ou0M6hirpfw==} comma-separated-tokens@2.0.3: resolution: {integrity: sha512-Fu4hJdvzeylCfQPp9SGWidpzrMs7tTrlu6Vb8XGaRGck8QSNZJJp538Wrb60Lax4fPwR64ViY468OIUTbRlGZg==} + commander@14.0.0: + resolution: {integrity: sha512-2uM9rYjPvyq39NwLRqaiLtWHyDC1FvryJDa2ATTVims5YAS4PupsEQsDvP14FqhFr0P49CYDugi59xaxJlTXRA==} + engines: {node: '>=20'} + + commander@14.0.2: + resolution: {integrity: sha512-TywoWNNRbhoD0BXs1P3ZEScW8W5iKrnbithIl0YH+uCmBd0QpPOA8yc82DS3BIE5Ma6FnBVUsJ7wVUDz4dvOWQ==} + engines: {node: '>=20'} + commander@14.0.3: resolution: {integrity: sha512-H+y0Jo/T1RZ9qPP4Eh1pkcQcLRglraJaSLoyOtHxu6AapkjWVCy2Sit1QQ4x3Dng8qDlSsZEet7g5Pq06MvTgw==} engines: {node: '>=20'} @@ -4265,6 +5169,10 @@ packages: resolution: {integrity: sha512-aRDkn3uyIlCFfk5NUA+VdwMmMsh8JGhc4hapfV4yxymHGQ3BVskMQfoXGpCo5IoBuQ9tS5iiVKhCpTcB4pW4qw==} engines: {node: '>= 12.0.0'} + comment-parser@1.4.6: + resolution: {integrity: sha512-ObxuY6vnbWTN6Od72xfwN9DbzC7Y2vv8u1Soi9ahRKL37gb6y1qk6/dgjs+3JWuXJHWvsg3BXIwzd/rkmAwavg==} + engines: {node: '>= 12.0.0'} + compare-versions@6.1.1: resolution: {integrity: sha512-4hm4VPpIecmlg59CHXnRDnqGplJFrbLG4aFEl5vl6cK1u76ws3LLvX7ikFnTDl5vo39sjWD6AaDPYodJp/NNHg==} @@ -4274,6 +5182,10 @@ packages: confbox@0.2.4: resolution: {integrity: sha512-ysOGlgTFbN2/Y6Cg3Iye8YKulHw+R2fNXHrgSmXISQdMnomY6eNDprVdW9R5xBguEqI954+S6709UyiO7B+6OQ==} + consola@3.4.2: + resolution: {integrity: sha512-5IKcdX0nnYavi6G7TtOhwkYzyjfJlatbjMjuLSfE2kYT5pMDOilZ4OvMhi637CcDICTmz3wARPoyhqyX1Y+XvA==} + engines: {node: ^14.18.0 || >=16.10.0} + convert-source-map@2.0.0: resolution: {integrity: sha512-Kvp459HrV2FEJ1CAsi1Ku+MY3kasH19TFykTz2xWmMeq6bk2NU3XXvfJ+Q61m0xktWwt+1HSYf3JZsTms3aRJg==} @@ -4311,9 +5223,6 @@ packages: resolution: {integrity: sha512-3O5QdqgFRUbXvK1x5INf1YkBz1UKSWqrd63vWsum8MNHDBYD5urm3QtxZbKU259OrEXNM26lP/MPY3d1IGkBgA==} engines: {node: '>=16'} - css-mediaquery@0.1.2: - resolution: {integrity: sha512-COtn4EROW5dBGlE/4PiKnh6rZpAPxDeFLaEEwt4i10jpDMFt2EhQGS79QmmrO+iKCHv0PU/HrOWEhijFd1x99Q==} - css-select@5.2.2: resolution: {integrity: sha512-TizTzUddG/xYLA3NXodFM0fSbNizXjOKhqiQQwvhlspadZokn1KDy0NZFS0wuEubIYAV5/c1/lAr0TaaFXEXzw==} @@ -4328,10 +5237,6 @@ packages: resolution: {integrity: sha512-6Fv1DV/TYw//QF5IzQdqsNDjx/wc8TrMBZsqjL9eW01tWb7R7k/mq+/VXfJCl7SoD5emsJop9cOByJZfs8hYIw==} engines: {node: ^10 || ^12.20.0 || ^14.13.0 || >=15.0.0} - css-tree@3.2.1: - resolution: {integrity: sha512-X7sjQzceUhu1u7Y/ylrRZFU2FS6LRiFVp6rKLPg23y3x3c3DOKAwuXGDp+PAGjh6CSnCjYeAul8pcT8bAl+lSA==} - engines: {node: ^10 || ^12.20.0 || ^14.13.0 || >=15.0.0} - css-what@6.2.2: resolution: {integrity: sha512-u/O3vwbptzhMs3L1fQE82ZSLHQQfto5gyZzwteVIEyeaY5Fc7R4dapF/BvRoSYFeqfBk4m0V1Vafq5Pjv25wvA==} engines: {node: '>= 6'} @@ -4510,10 +5415,6 @@ packages: dagre-d3-es@7.0.14: resolution: {integrity: sha512-P4rFMVq9ESWqmOgK+dlXvOtLwYg0i7u0HBGJER0LZDJT2VHIPAMZ/riPxqJceWMStH5+E61QxFra9kIS3AqdMg==} - data-urls@7.0.0: - resolution: {integrity: sha512-23XHcCF+coGYevirZceTVD7NdJOqVn+49IHyxgszm+JIiHLoB2TkmPtsYkNWT1pvRSGkc35L6NHs0yHkN2SumA==} - engines: {node: ^20.19.0 || ^22.12.0 || >=24.0.0} - dayjs@1.11.20: resolution: {integrity: sha512-YbwwqR/uYpeoP4pu043q+LTDLFBLApUP6VxRihdfNTqu4ubqMlGDLd6ErXhEgsyvY0K6nCs7nggYumAN+9uEuQ==} @@ -4562,8 +5463,12 @@ packages: defu@6.1.4: resolution: {integrity: sha512-mEQCMmwJu317oSz8CwdIOdwf3xMif1ttiM8LTufzc3g6kR+9Pe236twL8j3IYT1F7GfRgGcW6MWxzZjLIkuHIg==} - delaunator@5.0.1: - resolution: {integrity: sha512-8nvh+XBe96aCESrGOqMp/84b13H9cdKbG5P2ejQCh4d4sK9RL4371qou9drQjMhvnPmhWl5hnmqbEE0fXr9Xnw==} + delaunator@5.1.0: + resolution: {integrity: sha512-AGrQ4QSgssa1NGmWmLPqN5NY2KajF5MqxetNEO+o0n3ZwZZeTmt7bBnvzHWrmkZFxGgr4HdyFgelzgi06otLuQ==} + + delayed-stream@1.0.0: + resolution: {integrity: sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ==} + engines: {node: '>=0.4.0'} dequal@2.0.3: resolution: {integrity: sha512-0je+qPKHEMohvfRTCEo3CrPG6cAzAYgmzKyxRiYSSDkS6eGJdyVJm7WaYA5ECaAD9wLB2T4EEeymA5aFVcYXCA==} @@ -4589,6 +5494,10 @@ packages: resolution: {integrity: sha512-EjePK1srD3P08o2j4f0ExnylqRs5B9tJjcp9t1krH2qRi8CCdsYfwe9JgSLurFBWwq4uOlipzfk5fHNvwFKr8Q==} engines: {node: ^14.15.0 || ^16.10.0 || >=18.0.0} + diff@4.0.4: + resolution: {integrity: sha512-X07nttJQkwkfKfvTPG/KSnE2OMdcUCao6+eXF3wmnIQRn2aPAHH3VxDbDOdegkd6JbPsXqShpvEOHfAT+nCNwQ==} + engines: {node: '>=0.3.1'} + dlv@1.1.3: resolution: {integrity: sha512-+HlytyjlPKnIG8XuRG8WvmBP8xs8P71y+SKKS6ZXWoEgLuePxtDoUEiH7WkdePWrQ5JBpE6aoVqfZfJUQkjXwA==} @@ -4626,6 +5535,10 @@ packages: resolution: {integrity: sha512-uBq4egWHTcTt33a72vpSG0z3HnPuIl6NqYcTrKEg2azoEyl2hpW0zqlxysq2pK9HlDIHyHyakeYaYnSAwd8bow==} engines: {node: '>=12'} + dunder-proto@1.0.1: + resolution: {integrity: sha512-KIN/nDJBQRcXw0MLVhZE9iQHmG68qAVIBg9CqmUYjmQIhgij9U5MFvrqkUL5FbtyyzZuOeOt0zdeRe4UY7ct+A==} + engines: {node: '>= 0.4'} + echarts-for-react@3.0.6: resolution: {integrity: sha512-4zqLgTGWS3JvkQDXjzkR1k1CHRdpd6by0988TWMJgnvDytegWLbeP/VNZmMa+0VJx2eD7Y632bi2JquXDgiGJg==} peerDependencies: @@ -4635,8 +5548,8 @@ packages: echarts@6.0.0: resolution: {integrity: sha512-Tte/grDQRiETQP4xz3iZWSvoHrkCQtwqd6hs+mifXcjrCuo2iKWbajFObuLJVBlDIJlOzgQPd1hsaKt/3+OMkQ==} - electron-to-chromium@1.5.313: - resolution: {integrity: sha512-QBMrTWEf00GXZmJyx2lbYD45jpI3TUFnNIzJ5BBc8piGUDwMPa1GV6HJWTZVvY/eiN3fSopl7NRbgGp9sZ9LTA==} + electron-to-chromium@1.5.328: + resolution: {integrity: sha512-QNQ5l45DzYytThO21403XN3FvK0hOkWDG8viNf6jqS42msJ8I4tGDSpBCgvDRRPnkffafiwAym2X2eHeGD2V0w==} elkjs@0.11.1: resolution: {integrity: sha512-zxxR9k+rx5ktMwT/FwyLdPCrq7xN6e4VGGHH8hA01vVYKjTFik7nHOxBnAYtrgYUB1RpAiLvA1/U2YraWxyKKg==} @@ -4699,12 +5612,31 @@ packages: error-stack-parser-es@1.0.5: resolution: {integrity: sha512-5qucVt2XcuGMcEGgWI7i+yZpmpByQ8J1lHhcL7PwqCwu9FPP3VUXzT4ltHe5i2z9dePwEHcDVOAfSnHsOlCXRA==} + error-stack-parser@2.1.4: + resolution: {integrity: sha512-Sk5V6wVazPhq5MhpO+AUxJn5x7XSXGl1R93Vn7i+zS15KDVxQijejNCrz8340/2bgLBjR9GtEG8ZVKONDjcqGQ==} + + es-define-property@1.0.1: + resolution: {integrity: sha512-e3nRfgfUZ4rNGL232gUgX06QNyyez04KdjFrF+LTRoOXmrOgFKDg4BCdsjW8EnT69eqdYGmRpJwiPVYNrCaW3g==} + engines: {node: '>= 0.4'} + + es-errors@1.3.0: + resolution: {integrity: sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==} + engines: {node: '>= 0.4'} + es-module-lexer@1.7.0: resolution: {integrity: sha512-jEQoCwk8hyb2AZziIOLhDqpm5+2ww5uIE6lkO/6jcOCusfk6LhMHpXXfBLXTZ7Ydyt0j4VoUQv6uGNYbdW+kBA==} es-module-lexer@2.0.0: resolution: {integrity: sha512-5POEcUuZybH7IdmGsD8wlf0AI55wMecM9rVBTI/qEAy2c1kTOm3DjFYjrBdI2K3BaJjJYfYFeRtM0t9ssnRuxw==} + es-object-atoms@1.1.1: + resolution: {integrity: sha512-FGgH2h8zKNim9ljj7dankFPcICIK9Cp5bm+c2gQSYePhpaG5+esrLODihIorn+Pe6FGJzWhXQotPv73jTaldXA==} + engines: {node: '>= 0.4'} + + es-set-tostringtag@2.1.0: + resolution: {integrity: sha512-j6vWzfrGVfyXxge+O0x5sh6cvxAog0a/4Rdd2K36zCMV5eJ+/+tOAngRO8cODMNWbVRdVlmGZQL2YS3yR8bIUA==} + engines: {node: '>= 0.4'} + es-toolkit@1.45.1: resolution: {integrity: sha512-/jhoOj/Fx+A+IIyDNOvO3TItGmlMKhtX8ISAHKE90c4b/k1tqaqEZ+uUqfpU8DMnW5cgNJv606zS55jGvza0Xw==} @@ -4744,8 +5676,8 @@ packages: peerDependencies: eslint: '>=6.0.0' - eslint-config-flat-gitignore@2.2.1: - resolution: {integrity: sha512-wA5EqN0era7/7Gt5Botlsfin/UNY0etJSEeBgbUlFLFrBi47rAN//+39fI7fpYcl8RENutlFtvp/zRa/M/pZNg==} + eslint-config-flat-gitignore@2.3.0: + resolution: {integrity: sha512-bg4ZLGgoARg1naWfsINUUb/52Ksw/K22K+T16D38Y8v+/sGwwIYrGvH/JBjOin+RQtxxC9tzNNiy4shnGtGyyQ==} peerDependencies: eslint: ^9.5.0 || ^10.0.0 @@ -4763,6 +5695,15 @@ packages: '@eslint/json': optional: true + eslint-markdown@0.6.0: + resolution: {integrity: sha512-NrgfiNto5IJrW1F/Akf2hJYoJTCbXoClOUvtUMDgoqmQNH0VRihNvFh+MFay4E0HV2eozfgxsLSGxnndtRJA8w==} + engines: {node: ^20.19.0 || ^22.13.0 || >=24.0.0} + peerDependencies: + eslint: ^9.31.0 || ^10.0.0-rc.0 + peerDependenciesMeta: + eslint: + optional: true + eslint-merge-processors@2.0.0: resolution: {integrity: sha512-sUuhSf3IrJdGooquEUB5TNpGNpBoQccbnaLHsb1XkBLUPPqCNivCpY05ZcpCOiV9uHwO2yxXEWVczVclzMxYlA==} peerDependencies: @@ -4816,8 +5757,8 @@ packages: peerDependencies: eslint: '>=9.0.0' - eslint-plugin-jsdoc@62.8.0: - resolution: {integrity: sha512-hu3r9/6JBmPG6wTcqtYzgZAnjEG2eqRUATfkFscokESg1VDxZM21ZaMire0KjeMwfj+SXvgB4Rvh5LBuesj92w==} + eslint-plugin-jsdoc@62.8.1: + resolution: {integrity: sha512-e9358PdHgvcMF98foNd3L7hVCw70Lt+YcSL7JzlJebB8eT5oRJtW6bHMQKoAwJtw6q0q0w/fRIr2kwnHdFDI6A==} engines: {node: ^20.19.0 || ^22.13.0 || >=24} peerDependencies: eslint: ^7.0.0 || ^8.0.0 || ^9.0.0 || ^10.0.0 @@ -4828,12 +5769,22 @@ packages: peerDependencies: eslint: '>=9.38.0' + eslint-plugin-markdown-preferences@0.40.3: + resolution: {integrity: sha512-R3CCAEFwnnYXukTdtvdsamGjbTgVs9UZKqMKhNeWNXzFtOP1Frc89bgbd56lJUN7ASaxgvzc5fUpKvDCOTtDpg==} + engines: {node: ^20.19.0 || ^22.12.0 || >=24.0.0} + peerDependencies: + '@eslint/markdown': ^7.4.0 + eslint: '>=9.0.0' + eslint-plugin-n@17.24.0: resolution: {integrity: sha512-/gC7/KAYmfNnPNOb3eu8vw+TdVnV0zhdQwexsw6FLXbhzroVj20vRn2qL8lDWDGnAQ2J8DhdfvXxX9EoxvERvw==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} peerDependencies: eslint: '>=8.23.0' + eslint-plugin-no-barrel-files@1.2.2: + resolution: {integrity: sha512-DF2bnHuEHClmL1+maBO5TD2HnnRsLj8J69FFtVkjObkELyjCXaWBsk+URJkqBpdOWURlL+raGX9AEpWCAiOV0g==} + eslint-plugin-no-only-tests@3.3.0: resolution: {integrity: sha512-brcKcxGnISN2CcVhXJ/kEQlNa0MEfGRtwKtWA16SkqXHKitaKIMrfemJKLKX1YqDU5C/5JY3PvZXd5jEW04e0Q==} engines: {node: '>=5.0.0'} @@ -4849,19 +5800,12 @@ packages: peerDependencies: eslint: ^9.0.0 || ^10.0.0 - eslint-plugin-react-dom@2.13.0: - resolution: {integrity: sha512-+2IZzQ1WEFYOWatW+xvNUqmZn55YBCufzKA7hX3XQ/8eu85Mp4vnlOyNvdVHEOGhUnGuC6+9+zLK+IlEHKdKLQ==} - engines: {node: '>=20.19.0'} + eslint-plugin-react-dom@3.0.0: + resolution: {integrity: sha512-NhxPJSGZzR/bW02wop2whWXYKE8ZLZ9JupC5MWRq1AdM+Z84jnUU8c+eobiRzIhy2OupEjKcB8TaqHuQ+3sVoQ==} + engines: {node: '>=22.0.0'} peerDependencies: - eslint: ^8.57.0 || ^9.0.0 || ^10.0.0 - typescript: '>=4.8.4 <6.0.0' - - eslint-plugin-react-hooks-extra@2.13.0: - resolution: {integrity: sha512-qIbha1nzuyhXM9SbEfrcGVqmyvQu7GAOB2sy9Y4Qo5S8nCqw4fSBxq+8lSce5Tk5Y7XzIkgHOhNyXEvUHRWFMQ==} - engines: {node: '>=20.19.0'} - peerDependencies: - eslint: ^8.57.0 || ^9.0.0 || ^10.0.0 - typescript: '>=4.8.4 <6.0.0' + eslint: ^10.0.0 + typescript: '*' eslint-plugin-react-hooks@7.0.1: resolution: {integrity: sha512-O0d0m04evaNzEPoSW+59Mezf8Qt0InfgGIBJnpC0h3NH/WjUAR7BIKUfysC6todmtiZ/A0oUVS8Gce0WhBrHsA==} @@ -4869,38 +5813,38 @@ packages: peerDependencies: eslint: ^3.0.0 || ^4.0.0 || ^5.0.0 || ^6.0.0 || ^7.0.0 || ^8.0.0-0 || ^9.0.0 - eslint-plugin-react-naming-convention@2.13.0: - resolution: {integrity: sha512-uSd25JzSg2R4p81s3Wqck0AdwRlO9Yc+cZqTEXv7vW8exGGAM3mWnF6hgrgdqVJqBEGJIbS/Vx1r5BdKcY/MHA==} - engines: {node: '>=20.19.0'} + eslint-plugin-react-naming-convention@3.0.0: + resolution: {integrity: sha512-pAtOZST5/NhWIa/I5yz7H1HEZTtCY7LHMhzmN9zvaOdTWyZYtz2g9pxPRDBnkR9uSmHsNt44gj+2JSAD4xwgew==} + engines: {node: '>=22.0.0'} peerDependencies: - eslint: ^8.57.0 || ^9.0.0 || ^10.0.0 - typescript: '>=4.8.4 <6.0.0' + eslint: ^10.0.0 + typescript: '*' eslint-plugin-react-refresh@0.5.2: resolution: {integrity: sha512-hmgTH57GfzoTFjVN0yBwTggnsVUF2tcqi7RJZHqi9lIezSs4eFyAMktA68YD4r5kNw1mxyY4dmkyoFDb3FIqrA==} peerDependencies: eslint: ^9 || ^10 - eslint-plugin-react-rsc@2.13.0: - resolution: {integrity: sha512-RaftgITDLQm1zIgYyvR51sBdy4FlVaXFts5VISBaKbSUB0oqXyzOPxMHasfr9BCSjPLKus9zYe+G/Hr6rjFLXQ==} + eslint-plugin-react-rsc@3.0.0: + resolution: {integrity: sha512-HNP1hVO63WsV4wcXxPJJIcnYrvrN5UZyrXIbDOoCNA0axSXjJ6vA63tI2JHgyGcMTdbKxDJwaVd/dJlMudSZBQ==} engines: {node: '>=20.19.0'} peerDependencies: - eslint: ^8.57.0 || ^9.0.0 || ^10.0.0 - typescript: '>=4.8.4 <6.0.0' + eslint: ^10.0.0 + typescript: '*' - eslint-plugin-react-web-api@2.13.0: - resolution: {integrity: sha512-nmJbzIAte7PeAkp22CwcKEASkKi49MshSdiDGO1XuN3f4N4/8sBfDcWbQuLPde6JiuzDT/0+l7Gi8wwTHtR1kg==} - engines: {node: '>=20.19.0'} + eslint-plugin-react-web-api@3.0.0: + resolution: {integrity: sha512-DZZh9DkZp/BE5ibaDOXaV4p8rEuMNnoPkCvAlyifB/Gz6ZhHonFRTpg+PEK6et8sx6uroUfhy5QGducmZU8Oug==} + engines: {node: '>=22.0.0'} peerDependencies: - eslint: ^8.57.0 || ^9.0.0 || ^10.0.0 - typescript: '>=4.8.4 <6.0.0' + eslint: ^10.0.0 + typescript: '*' - eslint-plugin-react-x@2.13.0: - resolution: {integrity: sha512-cMNX0+ws/fWTgVxn52qAQbaFF2rqvaDAtjrPUzY6XOzPjY0rJQdR2tSlWJttz43r2yBfqu+LGvHlGpWL2wfpTQ==} - engines: {node: '>=20.19.0'} + eslint-plugin-react-x@3.0.0: + resolution: {integrity: sha512-W8QGWk03iqj6EiOhQk2SrrnaiTb2RZFREg1YXgYAh2/zyFztHHnNz4LTeSN+6gFwWDypMFzuFF6uoHO/1KY0Yw==} + engines: {node: '>=22.0.0'} peerDependencies: - eslint: ^8.57.0 || ^9.0.0 || ^10.0.0 - typescript: '>=4.8.4 <6.0.0' + eslint: ^10.0.0 + typescript: '*' eslint-plugin-regexp@3.1.0: resolution: {integrity: sha512-qGXIC3DIKZHcK1H9A9+Byz9gmndY6TTSRkSMTZpNXdyCw2ObSehRgccJv35n9AdUakEjQp5VFNLas6BMXizCZg==} @@ -4913,11 +5857,11 @@ packages: peerDependencies: eslint: ^8.0.0 || ^9.0.0 || ^10.0.0 - eslint-plugin-storybook@10.3.0: - resolution: {integrity: sha512-8R0/RjELXkJ2RxPusX14ZiIj1So90bPnrjbxmQx1BD+4M2VoMHfn3n+6IvzJWQH4FT5tMRRUBqjLBe1fJjRRkg==} + eslint-plugin-storybook@10.3.3: + resolution: {integrity: sha512-jo8wZvKaJlxxrNvf4hCsROJP3CdlpaLiYewAs5Ww+PJxCrLelIi5XVHWOAgBvvr3H9WDKvUw8xuvqPYqAlpkFg==} peerDependencies: eslint: '>=8' - storybook: ^10.3.0 + storybook: ^10.3.3 eslint-plugin-toml@1.3.1: resolution: {integrity: sha512-1l00fBP03HIt9IPV7ZxBi7x0y0NMdEZmakL1jBD6N/FoKBvfKxPw5S8XkmzBecOnFBTn5Z8sNJtL5vdf9cpRMQ==} @@ -4990,8 +5934,8 @@ packages: resolution: {integrity: sha512-tD40eHxA35h0PEIZNeIjkHoDR4YjjJp34biM0mDvplBe//mB+IHCqHDGV7pxF+7MklTvighcCPPZC7ynWyjdTA==} engines: {node: ^20.19.0 || ^22.13.0 || >=24} - eslint@10.0.3: - resolution: {integrity: sha512-COV33RzXZkqhG9P2rZCFl9ZmJ7WL+gQSCRzE7RhkbclbQPtLAWReL7ysA0Sh4c8Im2U9ynybdR56PV0XcKvqaQ==} + eslint@10.1.0: + resolution: {integrity: sha512-S9jlY/ELKEUwwQnqWDO+f+m6sercqOPSqXM5Go94l7DOmxHVDgmSFGWEzeE/gwgTAr0W103BWt0QLe/7mabIvA==} engines: {node: ^20.19.0 || ^22.13.0 || >=24} hasBin: true peerDependencies: @@ -5092,9 +6036,6 @@ packages: engines: {node: '>= 10.17.0'} hasBin: true - fast-content-type-parse@3.0.0: - resolution: {integrity: sha512-ZvLdcY8P+N8mGQJahJV5G4U88CSvT1rP8ApL6uETe88MBXrBHAkZlSEySdUlyztF7ccb+Znos3TFqaepHxdhBg==} - fast-deep-equal@3.1.3: resolution: {integrity: sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==} @@ -5131,7 +6072,7 @@ packages: resolution: {integrity: sha512-tIbYtZbucOs0BRGqPJkshJUYdL+SDH7dVM8gjy+ERp3WAUjLEFJE+02kanyHtwjWOnwrKYBiwAmM0p4kLJAnXg==} engines: {node: '>=12.0.0'} peerDependencies: - picomatch: ^3 || ^4 + picomatch: 4.0.4 peerDependenciesMeta: picomatch: optional: true @@ -5142,6 +6083,10 @@ packages: fflate@0.7.4: resolution: {integrity: sha512-5u2V/CDW15QM1XbbgS+0DfPxVB+jUKhWEKuuFuHncbk3tEEqzmoXL+2KyOFuKGqOnmdIy0/davWF1CkuwtibCw==} + figures@3.2.0: + resolution: {integrity: sha512-yaduQFRKLXYOGgEn6AZau90j3ggSOyiqXU0F9JZfeXYhNa+Jk4X+s45A2zg5jns87GAFa34BBm2kXw4XpNcbdg==} + engines: {node: '>=8'} + file-entry-cache@8.0.0: resolution: {integrity: sha512-XXTUwCvisa5oacNGRP9SfNtYBNAMi+RPwBFmblZEF7N7swHYQS6/Zfk7SRwx4D5j3CH211YNRco1DEMNVfZCnQ==} engines: {node: '>=16.0.0'} @@ -5162,6 +6107,9 @@ packages: resolution: {integrity: sha512-78/PXT1wlLLDgTzDs7sjq9hzz0vXD+zn+7wypEe4fXQxCmdmqfGsEPQxmiCSQI3ajFV91bVSsvNtrJRiW6nGng==} engines: {node: '>=10'} + fix-dts-default-cjs-exports@1.0.1: + resolution: {integrity: sha512-pVIECanWFC61Hzl2+oOCtoJ3F17kglZC/6N94eRWycFgBH35hHx0Li604ZIzhseh97mf2p0cv7vVrOZGoqhlEg==} + flat-cache@4.0.1: resolution: {integrity: sha512-f7ccFPK3SXFHpx15UIGyRJ/FJQctuKZ0zVuN3frBo4HnK3cay9VEW0R6yPYFHC0AgqhukPzKjq22t5DmAyqGyw==} engines: {node: '>=16'} @@ -5169,6 +6117,19 @@ packages: flatted@3.4.2: resolution: {integrity: sha512-PjDse7RzhcPkIJwy5t7KPWQSZ9cAbzQXcafsetQoD7sOJRQlGikNbx7yZp2OotDnJyrDcbyRq3Ttb18iYOqkxA==} + follow-redirects@1.15.11: + resolution: {integrity: sha512-deG2P0JfjrTxl50XGCDyfI97ZGVCxIpfKYmfyrQ54n5FO/0gfIES8C/Psl6kWVDolizcaaxZJnTS0QSMxvnsBQ==} + engines: {node: '>=4.0'} + peerDependencies: + debug: '*' + peerDependenciesMeta: + debug: + optional: true + + form-data@4.0.5: + resolution: {integrity: sha512-8RipRLol37bNs2bhoV67fiTEvdTrbMUYcFTiy3+wuuOnUog2QBHCZWXDRijWQfAkhBj2Uf5UnVaiWwA5vdd82w==} + engines: {node: '>= 6'} + format@0.2.2: resolution: {integrity: sha512-wzsgA6WOq+09wrU1tsJ09udeR/YZRaeArL9e1wPbFg3GG2yDnC2ldKpxs4xunpFF9DgqCqOIra3bc1HWrJ37Ww==} engines: {node: '>=0.4.x'} @@ -5195,11 +6156,19 @@ packages: fs-constants@1.0.0: resolution: {integrity: sha512-y6OAwoSIf7FyjMIv94u+b5rdheZEjzR63GTyZJm5qh4Bi+2YgwLCcI/fPFZkL5PSixOt6ZNKm+w+Hfp/Bciwow==} + fsevents@2.3.2: + resolution: {integrity: sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA==} + engines: {node: ^8.16.0 || ^10.6.0 || >=11.0.0} + os: [darwin] + fsevents@2.3.3: resolution: {integrity: sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==} engines: {node: ^8.16.0 || ^10.6.0 || >=11.0.0} os: [darwin] + function-bind@1.1.2: + resolution: {integrity: sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==} + functional-red-black-tree@1.0.1: resolution: {integrity: sha512-dsKNQNdj6xA3T+QlADDA7mOSlX0qiMINjn0cgr+eGHGsbSHzTabcIogz2+p/iqP1Xs6EP/sS2SbqH+brGTbq0g==} @@ -5214,16 +6183,24 @@ packages: resolution: {integrity: sha512-CQ+bEO+Tva/qlmw24dCejulK5pMzVnUOFOijVogd3KQs07HnRIgp8TGipvCCRT06xeYEbpbgwaCxglFyiuIcmA==} engines: {node: '>=18'} + get-intrinsic@1.3.0: + resolution: {integrity: sha512-9fSjSaos/fRIVIp+xSJlE6lfwhES7LNtKaCBIamHsjr2na1BiABJPo0mOjjz8GJDURarmCPGqaiVg5mfjb98CQ==} + engines: {node: '>= 0.4'} + get-nonce@1.0.1: resolution: {integrity: sha512-FJhYRoDaiatfEkUK8HKlicmu/3SGFD51q3itKDGoSTysQJBnfOcxU5GxnhE1E6soB76MbT0MBtnKJuXyAx+96Q==} engines: {node: '>=6'} + get-proto@1.0.1: + resolution: {integrity: sha512-sTSfBjoXBp89JvIKIefqw7U2CCebsc74kiY6awiGogKtoSGbgjYE/G/+l9sF3MWFPNc9IcoOC4ODfKHfxFmp0g==} + engines: {node: '>= 0.4'} + get-stream@5.2.0: resolution: {integrity: sha512-nBF+F1rAZVCu/p7rjzgA+Yb4lfYXrpl7a6VmJrU8wF9I1CKvP/QwPNZHnOlwbTkY6dvtFIzFMSyQXbLoTQPRpA==} engines: {node: '>=8'} - get-tsconfig@4.13.6: - resolution: {integrity: sha512-shZT/QMiSHc/YBLxxOkMtgSid5HFoauqCE3/exfsEcwg1WkeqjG+V40yBbBrsD+jW2HDXcs28xOfcbm2jI8Ddw==} + get-tsconfig@4.13.7: + resolution: {integrity: sha512-7tN6rFgBlMgpBML5j8typ92BKFi2sFQvIdpAqLA2beia5avZDrMs0FLZiM5etShWq5irVyGcGMEA1jcDaK7A/Q==} github-from-package@0.0.0: resolution: {integrity: sha512-SyHy3T1v2NUXn29OsWdxmK6RwHD+vkj3v8en8AOBZ1wBQ/hCAQ5bAQTD02kW4W9tUp/3Qh6J8r9EvntiyCmOOw==} @@ -5246,6 +6223,10 @@ packages: resolution: {integrity: sha512-Wjlyrolmm8uDpm/ogGyXZXb1Z+Ca2B8NbJwqBVg0axK9GbBeoS7yGV6vjXnYdGm6X53iehEuxxbyiKp8QmN4Vw==} engines: {node: 18 || 20 || >=22} + global-dirs@3.0.1: + resolution: {integrity: sha512-NBcGGFbBA9s1VzD41QXDG+3++t9Mn5t1FpLdhESY6oKY4gYTFpX4wO3sqGUa0Srjtbfj3szX0RnemmrVRUdULA==} + engines: {node: '>=10'} + globals@14.0.0: resolution: {integrity: sha512-oahGvuMGQlPw/ivIYBjVSrWAfWLBeku5tpPE2fOPLi+WHffIWbuh2tCjhyQhTBPMf5E9jDEH4FOmTYgYwbKwtQ==} engines: {node: '>=18'} @@ -5270,16 +6251,36 @@ packages: peerDependencies: csstype: ^3.0.10 + gopd@1.2.0: + resolution: {integrity: sha512-ZUKRh6/kUFoAiTAtTYPZJ3hw9wNxx+BIBOijnlG9PnrJsCcSjs1wyyD6vJpaYtgnzDrKYRSqf3OO6Rfa93xsRg==} + engines: {node: '>= 0.4'} + graceful-fs@4.2.11: resolution: {integrity: sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ==} hachure-fill@0.5.2: resolution: {integrity: sha512-3GKBOn+m2LX9iq+JC1064cSFprJY4jL1jCXTcpnfER5HYE2l/4EfWSGzkPa/ZDBmYI0ZOEj5VHV/eKnPGkHuOg==} + happy-dom@20.8.9: + resolution: {integrity: sha512-Tz23LR9T9jOGVZm2x1EPdXqwA37G/owYMxRwU0E4miurAtFsPMQ1d2Jc2okUaSjZqAFz2oEn3FLXC5a0a+siyA==} + engines: {node: '>=20.0.0'} + + has-ansi@4.0.1: + resolution: {integrity: sha512-Qr4RtTm30xvEdqUXbSBVWDu+PrTokJOwe/FU+VdfJPk+MXAPoeOzKpRyrDTnZIJwAkQ4oBLTU53nu0HrkF/Z2A==} + engines: {node: '>=8'} + has-flag@4.0.0: resolution: {integrity: sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==} engines: {node: '>=8'} + has-symbols@1.1.0: + resolution: {integrity: sha512-1cDNdwJ2Jaohmb3sg4OmKaMBwuC48sYni5HUw2DvsC8LjGTLK9h+eb1X6RyuOHe4hT0ULCW68iomhjUoKUqlPQ==} + engines: {node: '>= 0.4'} + + has-tostringtag@1.0.2: + resolution: {integrity: sha512-NqADB8VjPFLM2V0VvHUewwwsw0ZWBaIdgo+ieHtK3hasLz4qeCRjYcqfB6AQrBggRKppKF8L52/VqdVsO47Dlw==} + engines: {node: '>= 0.4'} + hast-util-from-dom@5.0.1: resolution: {integrity: sha512-N+LqofjR2zuzTjCPzyDUdSshy4Ma6li7p/c3pA78uTwzFgENbgbUrm2ugwsOdcjI1muO+o6Dgzp9p8WHtn/39Q==} @@ -5344,13 +6345,13 @@ packages: highlightjs-vue@1.0.0: resolution: {integrity: sha512-PDEfEF102G23vHmPhLyPboFCD+BkMGu+GuJe2d9/eH4FsCwvgBpnc9n0pGE+ffKdph38s6foEZiEjdgHdzp+IA==} - hono@4.12.8: - resolution: {integrity: sha512-VJCEvtrezO1IAR+kqEYnxUOoStaQPGrCmX3j4wDTNOcD1uRPFpGlwQUIW8niPuvHXaTUxeOUl5MMDGrl+tmO9A==} + hono@4.12.9: + resolution: {integrity: sha512-wy3T8Zm2bsEvxKZM5w21VdHDDcwVS1yUFFY6i8UobSsKfFceT7TOwhbhfKsDyx7tYQlmRM5FLpIuYvNFyjctiA==} engines: {node: '>=16.9.0'} - html-encoding-sniffer@6.0.0: - resolution: {integrity: sha512-CV9TW3Y3f8/wT0BRFc1/KAVQ3TUHiXmaAb6VW9vtiMFf7SLoMd1PdAc4W3KFOFETBJUb90KatHqlsZMWV+R9Gg==} - engines: {node: ^20.19.0 || ^22.12.0 || >=24.0.0} + hosted-git-info@9.0.2: + resolution: {integrity: sha512-M422h7o/BR3rmCQ8UHi7cyyMqKltdP9Uo+J2fXK+RSAY+wTcKOIRyhTuKv4qn+DJf3g+PL890AzId5KZpX+CBg==} + engines: {node: ^20.17.0 || >=22.9.0} html-entities@2.6.0: resolution: {integrity: sha512-kig+rMn/QOVRvr7c86gQ8lWXq+Hkv6CbAH1hLu+RG338StTpE8Z0b44SDVaqVu7HGKf27frdmUYEs9hTUX/cLQ==} @@ -5381,10 +6382,10 @@ packages: i18next-resources-to-backend@1.2.1: resolution: {integrity: sha512-okHbVA+HZ7n1/76MsfhPqDou0fptl2dAlhRDu2ideXloRRduzHsqDOznJBef+R3DFZnbvWoBW+KxJ7fnFjd6Yw==} - i18next@25.8.18: - resolution: {integrity: sha512-lzY5X83BiL5AP77+9DydbrqkQHFN9hUzWGjqjLpPcp5ZOzuu1aSoKaU3xbBLSjWx9dAzW431y+d+aogxOZaKRA==} + i18next@25.10.10: + resolution: {integrity: sha512-cqUW2Z3EkRx7NqSyywjkgCLK7KLCL6IFVFcONG7nVYIJ3ekZ1/N5jUsihHV6Bq37NfhgtczxJcxduELtjTwkuQ==} peerDependencies: - typescript: ^5 + typescript: ^5 || ^6 peerDependenciesMeta: typescript: optional: true @@ -5440,12 +6441,20 @@ packages: resolution: {integrity: sha512-m6FAo/spmsW2Ab2fU35JTYwtOKa2yAwXSwgjSv1TJzh4Mh7mC3lzAOVLBprb72XsTrgkEIsl7YrFNAiDiRhIGg==} engines: {node: '>=12'} + index-to-position@1.2.0: + resolution: {integrity: sha512-Yg7+ztRkqslMAS2iFaU+Oa4KTSidr63OsFGlOrJoW981kIYO3CGCS3wA95P1mUi/IVSJkn0D479KTJpVpvFNuw==} + engines: {node: '>=18'} + inherits@2.0.4: resolution: {integrity: sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==} ini@1.3.8: resolution: {integrity: sha512-JV/yugV2uzW5iMRSiZAyDtQd+nxtUnjeLt0acNdw98kKLrvuRVyB80tsREOE7yvGVgalhZ6RNXCmEHkUKBKxew==} + ini@2.0.0: + resolution: {integrity: sha512-7PnF4oN3CvZF23ADhA5wRaYEQpJ8qygSkbtTXWBeXWXmEVRXK+1ITciHWwHhsjv1TmW0MgacIv6hEi5pX5NQdA==} + engines: {node: '>=10'} + inline-style-parser@0.2.7: resolution: {integrity: sha512-Nb2ctOyNR8DqQoR0OwRG95uNWIC0C1lCgf5Naz5H6Ji72KZ8OcFZLz2P5sNgwlyoJ8Yif11oMuYs5pBQa86csA==} @@ -5509,34 +6518,38 @@ packages: is-hexadecimal@2.0.1: resolution: {integrity: sha512-DgZQp241c8oO6cA1SbTEWiXeoxV42vlcJxgH+B3hi1AiqqKruZR3ZGF8In3fj4+/y/7rHvlOZLZtgJ/4ttYGZg==} - is-immutable-type@5.0.1: - resolution: {integrity: sha512-LkHEOGVZZXxGl8vDs+10k3DvP++SEoYEAJLRk6buTFi6kD7QekThV7xHS0j6gpnUCQ0zpud/gMDGiV4dQneLTg==} - peerDependencies: - eslint: '*' - typescript: '>=4.7.4' + is-in-ssh@1.0.0: + resolution: {integrity: sha512-jYa6Q9rH90kR1vKB6NM7qqd1mge3Fx4Dhw5TVlK1MUBqhEOuCagrEHMevNuCcbECmXZ0ThXkRm+Ymr51HwEPAw==} + engines: {node: '>=20'} is-inside-container@1.0.0: resolution: {integrity: sha512-KIYLCCJghfHZxqjYBE7rEy0OBuTd5xCHS7tHVgvCLkx7StIoaxwNW3hCALgEUjFfeRk+MG/Qxmp/vtETEF3tRA==} engines: {node: '>=14.16'} hasBin: true - is-node-process@1.2.0: - resolution: {integrity: sha512-Vg4o6/fqPxIjtxgUH5QLJhwZ7gW5diGCVlXpuUfELC62CuxM1iHcRe51f2W1FDy04Ai4KJkagKjx3XaqyfRKXw==} + is-installed-globally@0.4.0: + resolution: {integrity: sha512-iwGqO3J21aaSkC7jWnHP/difazwS7SFeIqxv6wEtLU8Y5KlzFTjyqcSIT0d8s4+dDhKytsk9PJZ2BkS5eZwQRQ==} + engines: {node: '>=10'} is-number@7.0.0: resolution: {integrity: sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng==} engines: {node: '>=0.12.0'} + is-path-inside@3.0.3: + resolution: {integrity: sha512-Fd4gABb+ycGAmKou8eMftCupSir5lRxqf4aD/vd0cD2qc4HL07OjCeuHMr8Ro4CoMaeCKDB0/ECBOVWjTwUvPQ==} + engines: {node: '>=8'} + is-plain-obj@4.1.0: resolution: {integrity: sha512-+Pgi+vMuUNkJyExiMBt5IlFoMyKnr5zhJ4Uspz58WOhBF5QoIZkFyNHIbBAtHwzVAgk5RtndVNsDRN61/mmDqg==} engines: {node: '>=12'} - is-potential-custom-element-name@1.0.1: - resolution: {integrity: sha512-bCYeRA2rVibKZd+s2625gGnGF/t7DSqDs4dP7CrLA1m7jKWz6pps0LpYLJN8Q64HtmPKJ1hrN3nzPNKFEKOUiQ==} - is-reference@3.0.3: resolution: {integrity: sha512-ixkJoqQvAP88E6wLydLGGqCJsrFUnqoH6HnaczB8XmDH1oaWU+xxdptvikTgaEhtZ53Ky6YXiBuUI2WXLMCwjw==} + is-stream@2.0.1: + resolution: {integrity: sha512-hFoiJiTl63nn+kstHGBtewWSKnQLpyb155KHheA1l39uvtO9nWIop1p3udqPcUd/xbF1VLMO4n7OI6p7RbngDg==} + engines: {node: '>=8'} + is-wsl@3.1.1: resolution: {integrity: sha512-e6rvdUCiQCAuumZslxRJWR/Doq4VpPR82kqclvcS0efgt430SlGIk05vdCN58+VrzgtIcfNODjozVielycD4Sw==} engines: {node: '>=16'} @@ -5571,8 +6584,8 @@ packages: resolution: {integrity: sha512-ekilCSN1jwRvIbgeg/57YFh8qQDNbwDb9xT/qu2DAHbFFZUicIl4ygVaAvzveMhMVr3LnpSKTNnwt8PoOfmKhQ==} hasBin: true - jotai@2.18.1: - resolution: {integrity: sha512-e0NOzK+yRFwHo7DOp0DS0Ycq74KMEAObDWFGmfEL28PD9nLqBTt3/Ug7jf9ca72x0gC9LQZG9zH+0ISICmy3iA==} + jotai@2.19.0: + resolution: {integrity: sha512-r2wwxEXP1F2JteDLZEOPoIpAHhV89paKsN5GWVYndPNMMP/uVZDcC+fNj0A8NjKgaPWzdyO8Vp8YcYKe0uCEqQ==} engines: {node: '>=12.20.0'} peerDependencies: '@babel/core': '>=7.0.0' @@ -5589,6 +6602,10 @@ packages: react: optional: true + joycon@3.1.1: + resolution: {integrity: sha512-34wB/Y7MW7bzjKRjUKTa46I2Z7eV62Rkhva+KkopW7Qvv/OSWBqvkSY7vusOPrNuZcUG3tApvdVgNB8POj3SPw==} + engines: {node: '>=10'} + js-audio-recorder@1.0.7: resolution: {integrity: sha512-JiDODCElVHGrFyjGYwYyNi7zCbKk9va9C77w+zCPMmi4C6ix7zsX2h3ddHugmo4dOTOTCym9++b/wVW9nC0IaA==} @@ -5616,19 +6633,6 @@ packages: resolution: {integrity: sha512-/2uqY7x6bsrpi3i9LVU6J89352C0rpMk0as8trXxCtvd4kPk1ke/Eyif6wqfSLvoNJqcDG9Vk4UsXgygzCt2xA==} engines: {node: '>=20.0.0'} - jsdom-testing-mocks@1.16.0: - resolution: {integrity: sha512-wLrulXiLpjmcUYOYGEvz4XARkrmdVpyxzdBl9IAMbQ+ib2/UhUTRCn49McdNfXLff2ysGBUms49ZKX0LR1Q0gg==} - engines: {node: '>=14'} - - jsdom@29.0.0: - resolution: {integrity: sha512-9FshNB6OepopZ08unmmGpsF7/qCjxGPbo3NbgfJAnPeHXnsODE9WWffXZtRFRFe0ntzaAOcSKNJFz8wiyvF1jQ==} - engines: {node: ^20.19.0 || ^22.13.0 || >=24.0.0} - peerDependencies: - canvas: ^3.2.1 - peerDependenciesMeta: - canvas: - optional: true - jsesc@3.1.0: resolution: {integrity: sha512-/sM3dO2FOzXjKQhJuo0Q173wf2KOo8t4I8vHy6lF9poUp7bKT0/NHE8fPX23PwfhnykfqnC2xRxOnVw5XuGIaA==} engines: {node: '>=6'} @@ -5649,12 +6653,6 @@ packages: json-stable-stringify-without-jsonify@1.0.1: resolution: {integrity: sha512-Bdboy+l7tA3OGW6FjyFHWkP5LuByj1Tk33Ljyq0axyzdk9//JSi2u3fP1QSmd1KNwq6VOKYGlAu87CisVir6Pw==} - json-stringify-safe@5.0.1: - resolution: {integrity: sha512-ZClg6AaYvamvYEE82d3Iyd3vSSIjQ+odgjaTzRuO3s7toCdFKczob2i0zCh7JE8kWn17yvAWhUVxvqGwUalsRA==} - - json-with-bigint@3.5.7: - resolution: {integrity: sha512-7ei3MdAI5+fJPVnKlW77TKNKwQ5ppSzWvhPuSuINT/GYW9ZOC1eRKOuhV9yHG5aEsUPj9BBx5JIekkmoLHxZOw==} - json5@2.2.3: resolution: {integrity: sha512-XmOWe7eyHYH14cLdVPoyg+GOH3rYX++KpzrylJwSW98t3Nk+U8XOl8FWKOgwtzdb8lXGf6zYwDUzeHMWfxasyg==} engines: {node: '>=6'} @@ -5674,8 +6672,8 @@ packages: resolution: {integrity: sha512-eQQBjBnsVtGacsG9uJNB8qOr3yA8rga4wAaGG1qRcBzSIvfhERLrWxMAM1hp5fcS6Abo8M4+bUBTekYR0qTPQw==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} - katex@0.16.38: - resolution: {integrity: sha512-cjHooZUmIAUmDsHBN+1n8LaZdpmbj03LtYeYPyuYB7OuloiaeaV6N4LcfjcnHVzGWjVQmKrxxTrpDcmSzEZQwQ==} + katex@0.16.44: + resolution: {integrity: sha512-EkxoDTk8ufHqHlf9QxGwcxeLkWRR3iOuYfRpfORgYfqc8s13bgb+YtRY59NK5ZpRaCwq1kqA6a5lpX8C/eLphQ==} hasBin: true keyv@4.5.4: @@ -5684,13 +6682,13 @@ packages: khroma@2.1.0: resolution: {integrity: sha512-Ls993zuzfayK269Svk9hzpeGUKob/sIgZzyHYdjQoAdQetRKpOLj+k/QQQ/6Qi0Yz65mlROrfd+Ev+1+7dz9Kw==} - knip@5.88.0: - resolution: {integrity: sha512-FZjQYLYwUbVrtC3C1cKyEMMqR4K2ZlkQLZszJgF5cfDo4GUSBZAdAV0P3eyzZrkssRoghLJQA9HTQUW7G+Tc8Q==} - engines: {node: '>=18.18.0'} + knip@6.1.0: + resolution: {integrity: sha512-n5eVbJP7HXmwTsiJcELWJe2O1ESxyCTNxJzRTIECDYDTM465qnqk7fL2dv6ae3NUFvFWorZvGlh9mcwxwJ5Xgw==} + engines: {node: ^20.19.0 || >=22.12.0} hasBin: true - peerDependencies: - '@types/node': '>=18' - typescript: '>=5.0.4 <7' + + knuth-shuffle-seeded@1.0.6: + resolution: {integrity: sha512-9pFH0SplrfyKyojCLxZfMcvkhf5hH0d+UwR9nTVJ/DDQJGuzcXjTwB7TP7sDfehSudlGGaOLblmEWqv04ERVWg==} kolorist@1.8.0: resolution: {integrity: sha512-Y+60/zizpJ3HRH8DCss+q95yr6145JXZo46OTpFvDZWLfRCE4qChOyk1b26nMaNpfHHgxagk9dXT5OP0Tfe+dQ==} @@ -5725,8 +6723,8 @@ packages: '@lexical/utils': '>=0.28.0' lexical: '>=0.28.0' - lexical@0.41.0: - resolution: {integrity: sha512-pNIm5+n+hVnJHB9gYPDYsIO5Y59dNaDU9rJmPPsfqQhP2ojKFnUoPbcRnrI9FJLXB14sSumcY8LUw7Sq70TZqA==} + lexical@0.42.0: + resolution: {integrity: sha512-GY9Lg3YEIU7nSFaiUlLspZ1fm4NfIcfABaxy9nT+fRVDkX7iV005T5Swil83gXUmxFUNKGal3j+hUxHOUDr+Aw==} lib0@0.2.117: resolution: {integrity: sha512-DeXj9X5xDCjgKLU/7RR+/HQEVzuuEUiwldwOGsHK/sfAfELGWEyTcf0x+uOvCvK3O2zPmZePXWL85vtia6GyZw==} @@ -5826,6 +6824,10 @@ packages: resolution: {integrity: sha512-ME4Fb83LgEgwNw96RKNvKV4VTLuXfoKudAmm2lP8Kk87KaMK0/Xrx/aAkMWmT8mDb+3MlFDspfbCs7adjRxA2g==} engines: {node: '>=20.0.0'} + load-tsconfig@0.2.5: + resolution: {integrity: sha512-IXO6OCs9yg8tMKzfPZ1YmheJbZCiEsnBdcB03l0OcfK9prKnJb96siuHCr5Fl37/yo9DnKU+TLpxzTUspw9shg==} + engines: {node: ^12.20.0 || ^14.13.1 || >=16.0.0} + loader-runner@4.3.1: resolution: {integrity: sha512-IWqP2SCPhyVFTBtRcgMHdzlf9ul25NwaFx4wCEH/KjAXuuHY4yNjvPXsBokp8jCB936PyWRaPKUNh8NvylLp2Q==} engines: {node: '>=6.11.5'} @@ -5844,6 +6846,12 @@ packages: lodash.merge@4.6.2: resolution: {integrity: sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ==} + lodash.mergewith@4.6.2: + resolution: {integrity: sha512-GK3g5RPZWTRSeLSpgP8Xhra+pnjBC56q9FZYe1d5RN3TJ35dbkGy3YqBSMbyCrlbi+CM9Z3Jk5yTL7RCsqboyQ==} + + lodash.sortby@4.7.0: + resolution: {integrity: sha512-HDWXG8isMntAyRF5vZ7xKuEvOhT4AhlRt/3czTSjvGUxjYCBVRQY48ViDHyfYz9VIoBkW4TMGQNapx+l3RUwdA==} + lodash@4.17.23: resolution: {integrity: sha512-LgVTMpQtIopCi79SJeDiP0TfWi5CNEc/L/aRdTh3yIvmZXTnheWpKjSZhnvMl8iXbC1tFg9gdHHDMLoV7CnG+w==} @@ -5861,6 +6869,9 @@ packages: loupe@3.2.1: resolution: {integrity: sha512-CdzqowRJCeLU72bHvWqwRBBlLcMEtIvGrlvef74kMnV2AolS9Y8xUv1I0U/MNAWMhBlKIoyuEgoJ0t/bbwHbLQ==} + lower-case@2.0.2: + resolution: {integrity: sha512-7fm3l3NAF9WfN6W3JOmf5drwpVqX78JtoGJ3A6W0a6ZnldM41w2fV5D490psKFTpMds8TJse/eHLFFsNHHjHgg==} + lowlight@1.20.0: resolution: {integrity: sha512-8Ktj+prEb1RoCPkEOrPMYUN/nCggB7qAWe3a7OpMjWQkh3l2RD5wKRQ+o8Q8YuI9RG/xs95waaI/E6ym/7NsTw==} @@ -5906,11 +6917,15 @@ packages: engines: {node: '>= 20'} hasBin: true - marked@17.0.4: - resolution: {integrity: sha512-NOmVMM+KAokHMvjWmC5N/ZOvgmSWuqJB8FoYI019j4ogb/PeRMKoKIjReZ2w3376kkA8dSJIP8uD993Kxc0iRQ==} + marked@17.0.5: + resolution: {integrity: sha512-6hLvc0/JEbRjRgzI6wnT2P1XuM1/RrrDEX0kPt0N7jGm1133g6X7DlxFasUIx+72aKAr904GTxhSLDrd5DIlZg==} engines: {node: '>= 20'} hasBin: true + math-intrinsics@1.1.0: + resolution: {integrity: sha512-/IXtbwEk5HTPyEwyKX6hGkYXxM9nbj64B+ilVJnC/R6B0pH5G4V3b0pVbL7DBj4tkhBAppbQUlf6F6Xl9LHu1g==} + engines: {node: '>= 0.4'} + mdast-util-directive@3.1.0: resolution: {integrity: sha512-I3fNFt+DHmpWCYAT7quoM6lHf9wuqtI+oCOfvILnoicNIqjh5E3dEJWiXuYME2gNe8vl1iMQwyUHa7bgFmak6Q==} @@ -5980,9 +6995,6 @@ packages: mdn-data@2.23.0: resolution: {integrity: sha512-786vq1+4079JSeu2XdcDjrhi/Ry7BWtjDl9WtGPWLiIHb2T66GvIVflZTBoSNZ5JqTtJGYEVMuFA/lbQlMOyDQ==} - mdn-data@2.27.1: - resolution: {integrity: sha512-9Yubnt3e8A0OKwxYSXyhLymGW4sCufcLG6VdiDdUGVkPhpqLxlvP5vl1983gQjJl3tqbrM731mjaZaP68AgosQ==} - memoize-one@5.2.1: resolution: {integrity: sha512-zYiwtZUcYyXKo/np96AGZAckk+FWWsUdJ3cHGGmld7+AhvcWmQyGCYUh1hc4Q/pkOhb65dQR/pqCyK0cOaHz4Q==} @@ -6122,6 +7134,11 @@ packages: resolution: {integrity: sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==} engines: {node: '>= 0.6'} + mime@3.0.0: + resolution: {integrity: sha512-jSCU7/VB1loIWBZe14aEYHU/+1UMEHoaO7qxCOVJOw9GgH72VAWppxNcjU+x9a2k3GSIBXNKxXQFqRvvZ7vr3A==} + engines: {node: '>=10.0.0'} + hasBin: true + mime@4.1.0: resolution: {integrity: sha512-X5ju04+cAzsojXKes0B/S4tcYtFAJ6tTMuSPBEn9CPGlrWr8Fiw7qYeLT0XyH80HSoAoqWCaz+MWKh22P7G1cw==} engines: {node: '>=16'} @@ -6163,8 +7180,13 @@ packages: mkdirp-classic@0.5.3: resolution: {integrity: sha512-gKLcREMhtuZRwRAfqP3RFW+TK4JqApVBtOIftVgjuABpAtpxhPGaDcfvbhNvD0B8iD1oUr/txX35NjcaY6Ns/A==} - mlly@1.8.1: - resolution: {integrity: sha512-SnL6sNutTwRWWR/vcmCYHSADjiEesp5TGQQ0pXyLhW5IoeibRlF/CbSLailbB3CNqJUk9cVJ9dUDnbD7GrcHBQ==} + mkdirp@3.0.1: + resolution: {integrity: sha512-+NsyUUAZDmo6YVHzL/stxSu3t9YS1iljliy3BSDrXJ/dkn1KYdmtZODGGjLcc9XLgVVpH4KshHB8XmZgMhaBXg==} + engines: {node: '>=10'} + hasBin: true + + mlly@1.8.2: + resolution: {integrity: sha512-d+ObxMQFmbt10sretNDytwt85VrbkhhUA/JBGm1MPaWJ65Cl4wOgLaB1NYvJSZ0Ef03MMEU/0xpPMXUIQ29UfA==} module-alias@2.3.4: resolution: {integrity: sha512-bOclZt8hkpuGgSSoG07PKmvzTizROilUTvLNyrMqvlC9snhs7y7GzjNWAVbISIOlhCP1T14rH1PDAV9iNyBq/w==} @@ -6216,8 +7238,8 @@ packages: react: ^16.8 || ^17 || ^18 || ^19 || ^19.0.0-rc react-dom: ^16.8 || ^17 || ^18 || ^19 || ^19.0.0-rc - next@16.2.0: - resolution: {integrity: sha512-NLBVrJy1pbV1Yn00L5sU4vFyAHt5XuSjzrNyFnxo6Com0M0KrL6hHM5B99dbqXb2bE9pm4Ow3Zl1xp6HVY9edQ==} + next@16.2.1: + resolution: {integrity: sha512-VaChzNL7o9rbfdt60HUj8tev4m6d7iC1igAy157526+cJlXOQu5LzsBXNT+xaJnTP/k+utSX5vMv7m0G+zKH+Q==} engines: {node: '>=20.9.0'} hasBin: true peerDependencies: @@ -6237,9 +7259,8 @@ packages: sass: optional: true - nock@14.0.11: - resolution: {integrity: sha512-u5xUnYE+UOOBA6SpELJheMCtj2Laqx15Vl70QxKo43Wz/6nMHXS7PrEioXLjXAwhmawdEMNImwKCcPhBJWbKVw==} - engines: {node: '>=18.20.0 <20 || >=20.12.1'} + no-case@3.0.4: + resolution: {integrity: sha512-fgAN3jGAh+RoxUGZHTSOLJIqUc2wmoBwGR4tbpNAKmmovFoWq0OdRkb0VkldReO2a2iBT/OEulG9XSUc10r3zg==} node-abi@3.89.0: resolution: {integrity: sha512-6u9UwL0HlAl21+agMN3YAMXcKByMqwGx+pq+P76vii5f7hTPtKDp08/H9py6DY+cfDw7kQNTGEj/rly3IgbNQA==} @@ -6254,6 +7275,10 @@ packages: node-releases@2.0.36: resolution: {integrity: sha512-TdC8FSgHz8Mwtw9g5L4gR/Sh9XhSP/0DEkQxfEFXOpiul5IiHgHan2VhYYb6agDSfp4KuvltmGApc8HMgUrIkA==} + normalize-package-data@8.0.0: + resolution: {integrity: sha512-RWk+PI433eESQ7ounYxIp67CYuVsS1uYSonX3kA6ps/3LWfjVQa/ptEg6Y3T6uAMq1mWpX9PQ+qx+QaHpsc7gQ==} + engines: {node: ^20.17.0 || >=22.9.0} + normalize-path@3.0.0: resolution: {integrity: sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA==} engines: {node: '>=0.10.0'} @@ -6316,6 +7341,10 @@ packages: resolution: {integrity: sha512-YgBpdJHPyQ2UE5x+hlSXcnejzAvD0b22U2OuAP+8OnlJT+PjWPxtgmGqKKc+RgTM63U9gN0YzrYc71R2WT/hTA==} engines: {node: '>=18'} + open@11.0.0: + resolution: {integrity: sha512-smsWv2LzFjP03xmvFoJ331ss6h+jixfA4UUV/Bsiyuu4YJPfN+FIQGOIiv4w9/+MoHkfkJ22UIaQWRVFRfH6Vw==} + engines: {node: '>=20'} + openapi-types@12.1.3: resolution: {integrity: sha512-N4YtSYJqghVu4iek2ZUvcN/0aqH1kRDuNqzcycDxhOUpg7GdvLa2F3DgS6yBNhInhv2r/6I0Flkn7CqL8+nIcw==} @@ -6323,23 +7352,24 @@ packages: resolution: {integrity: sha512-6IpQ7mKUxRcZNLIObR0hz7lxsapSSIYNZJwXPGeF0mTVqGKFIXj1DQcMoT22S3ROcLyY/rz0PWaWZ9ayWmad9g==} engines: {node: '>= 0.8.0'} - outvariant@1.4.3: - resolution: {integrity: sha512-+Sl2UErvtsoajRDKCE5/dBz4DIvHXQQnAxtQTF04OJxY0+DyZXSo5P5Bb7XYWOh81syohlYL24hbDwxedPUJCA==} + oxc-parser@0.121.0: + resolution: {integrity: sha512-ek9o58+SCv6AV7nchiAcUJy1DNE2CC5WRdBcO0mF+W4oRjNQfPO7b3pLjTHSFECpHkKGOZSQxx3hk8viIL5YCg==} + engines: {node: ^20.19.0 || >=22.12.0} oxc-resolver@11.19.1: resolution: {integrity: sha512-qE/CIg/spwrTBFt5aKmwe3ifeDdLfA2NESN30E42X/lII5ClF8V7Wt6WIJhcGZjp0/Q+nQ+9vgxGk//xZNX2hg==} - oxfmt@0.40.0: - resolution: {integrity: sha512-g0C3I7xUj4b4DcagevM9kgH6+pUHytikxUcn3/VUkvzTNaaXBeyZqb7IBsHwojeXm4mTBEC/aBjBTMVUkZwWUQ==} + oxfmt@0.42.0: + resolution: {integrity: sha512-QhejGErLSMReNuZ6vxgFHDyGoPbjTRNi6uGHjy0cvIjOQFqD6xmr/T+3L41ixR3NIgzcNiJ6ylQKpvShTgDfqg==} engines: {node: ^20.19.0 || >=22.12.0} hasBin: true - oxlint-tsgolint@0.17.0: - resolution: {integrity: sha512-TdrKhDZCgEYqONFo/j+KvGan7/k3tP5Ouz88wCqpOvJtI2QmcLfGsm1fcMvDnTik48Jj6z83IJBqlkmK9DnY1A==} + oxlint-tsgolint@0.17.3: + resolution: {integrity: sha512-1eh4bcpOMw0e7+YYVxmhFc2mo/V6hJ2+zfukqf+GprvVn3y94b69M/xNrYLmx5A+VdYe0i/bJ2xOs6Hp/jRmRA==} hasBin: true - oxlint@1.55.0: - resolution: {integrity: sha512-T+FjepiyWpaZMhekqRpH8Z3I4vNM610p6w+Vjfqgj5TZUxHXl7N8N5IPvmOU8U4XdTRxqtNNTh9Y4hLtr7yvFg==} + oxlint@1.57.0: + resolution: {integrity: sha512-DGFsuBX5MFZX9yiDdtKjTrYPq45CZ8Fft6qCltJITYZxfwYjVdGf/6wycGYTACloauwIPxUnYhBVeZbHvleGhw==} engines: {node: ^20.19.0 || >=22.12.0} hasBin: true peerDependencies: @@ -6352,6 +7382,10 @@ packages: resolution: {integrity: sha512-TYOanM3wGwNGsZN2cVTYPArw454xnXj5qmWF1bEoAc4+cU/ol7GVh7odevjp1FNHduHc3KZMcFduxU5Xc6uJRQ==} engines: {node: '>=10'} + p-limit@7.3.0: + resolution: {integrity: sha512-7cIXg/Z0M5WZRblrsOla88S4wAK+zOQQWeBYfV3qJuJXMr+LnbYjaadrFaS0JILfEDPVqHyKnZ1Z/1d6J9VVUw==} + engines: {node: '>=20'} + p-locate@5.0.0: resolution: {integrity: sha512-LaNjtRWUBY++zB5nE/NwcaoMylSPk+S+ZHNB1TzdbMJMny6dynpAGt7X/tl/QYq3TIeE6nxHppbo2LGymrG5Pw==} engines: {node: '>=10'} @@ -6359,6 +7393,10 @@ packages: package-manager-detector@1.6.0: resolution: {integrity: sha512-61A5ThoTiDG/C8s8UMZwSorAGwMJ0ERVGj2OjoW5pAalsNOg15+iQiPzrLJ4jhZ1HJzmC2PIHT2oEiH3R5fzNA==} + pad-right@0.2.2: + resolution: {integrity: sha512-4cy8M95ioIGolCoMmm2cMntGR1lPLEbOMzOKu8bzjuJP6JpzEMQcDHmh7hHLYGgob+nKe1YHFMaG4V59HQa89g==} + engines: {node: '>=0.10.0'} + pako@0.2.9: resolution: {integrity: sha512-NUcwaKxUxWrZLpDG+z/xZaCgQITkA/Dv4V/T6bw7VON6l1Xz/VnrBqrYjZQ12TamKHzITTfOEIYUj48y2KXImA==} @@ -6385,6 +7423,10 @@ packages: parse-imports-exports@0.2.4: resolution: {integrity: sha512-4s6vd6dx1AotCx/RCI2m7t7GCh5bDRUtGNvRfHSP2wbBQdMi67pPe7mtzmgwcaQ8VKK/6IB7Glfyu3qdZJPybQ==} + parse-json@8.3.0: + resolution: {integrity: sha512-ybiGyvspI+fAoRQbIPRddCcSTV9/LsJbf0e/S85VLowVGzRmokfneg2kwVW/KU5rOXrPSbF1qAKPMgNTqqROQQ==} + engines: {node: '>=18'} + parse-statements@1.0.11: resolution: {integrity: sha512-HlsyYdMBnbPQ9Jr/VgJ1YF4scnldvJpJxCVx6KgqPL4dxppsWrJHCIIxQXMJrqGnsRkNPATbeMJ8Yxu7JMsYcA==} @@ -6448,12 +7490,12 @@ packages: picocolors@1.1.1: resolution: {integrity: sha512-xceH2snhtb5M9liqDsmEw56le376mTZkEX/jEb/RxNFyegNul7eNslCXP9FDj/Lcu0X8KEyMceP2ntpaHrDEVA==} - picomatch@2.3.1: - resolution: {integrity: sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==} + picomatch@2.3.2: + resolution: {integrity: sha512-V7+vQEJ06Z+c5tSye8S+nHUfI51xoXIXjHQ99cQtKUkQqqO1kO/KCJUfZXuB47h/YBlDhah2H3hdUGXn8ie0oA==} engines: {node: '>=8.6'} - picomatch@4.0.3: - resolution: {integrity: sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==} + picomatch@4.0.4: + resolution: {integrity: sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==} engines: {node: '>=12'} pify@2.3.0: @@ -6477,6 +7519,16 @@ packages: pkg-types@2.3.0: resolution: {integrity: sha512-SIqCzDRg0s9npO5XQ3tNZioRY1uK06lA41ynBC1YmFTmnY6FjUjVt6s4LoADmwoig1qqD0oK8h1p/8mlMx8Oig==} + playwright-core@1.58.2: + resolution: {integrity: sha512-yZkEtftgwS8CsfYo7nm0KE8jsvm6i/PTgVtB8DL726wNf6H2IMsDuxCpJj59KDaxCtSnrWan2AeDqM7JBaultg==} + engines: {node: '>=18'} + hasBin: true + + playwright@1.58.2: + resolution: {integrity: sha512-vA30H8Nvkq/cPBnNw4Q8TWz1EJyqgpuinBcHET0YVJVFldr8JDNiU9LaWAE1KqSkRYazuaBhTpB5ZzShOezQ6A==} + engines: {node: '>=18'} + hasBin: true + pluralize@8.0.0: resolution: {integrity: sha512-Nc3IT5yHzflTfbjgqWcCPpo7DaKy4FnpB0l/zCAW0Tc7jxAiuqSxHasntB3D7887LSrA93kDJ9IXovxJYxyLCA==} engines: {node: '>=4'} @@ -6523,7 +7575,7 @@ packages: jiti: '>=1.21.0' postcss: '>=8.0.9' tsx: ^4.8.1 - yaml: ^2.4.2 + yaml: 2.8.3 peerDependenciesMeta: jiti: optional: true @@ -6563,6 +7615,10 @@ packages: resolution: {integrity: sha512-OW/rX8O/jXnm82Ey1k44pObPtdblfiuWnrd8X7GJ7emImCOstunGbXUpp7HdBrFQX6rJzn3sPT397Wp5aCwCHg==} engines: {node: ^10 || ^12 || >=14} + powershell-utils@0.1.0: + resolution: {integrity: sha512-dM0jVuXJPsDN6DvRpea484tCUaMiXWjuCn++HGTqUWzGDjv5tZkEZldAJ/UMlqRYGFrD/etByo4/xOuC/snX2A==} + engines: {node: '>=20'} + prebuild-install@7.1.3: resolution: {integrity: sha512-8Mf2cbV7x1cXPUILADGI3wuhfqWvtiLA1iclTDbFRZkgRQS0NqsPZphna9V+HyTEadheuPmjaJMsbzKQFOzLug==} engines: {node: '>=10'} @@ -6581,12 +7637,15 @@ packages: resolution: {integrity: sha512-DEvV2ZF2r2/63V+tK8hQvrR2ZGn10srHbXviTlcv7Kpzw8jWiNTqbVgjO3IY8RxrrOUF8VPMQQFysYYYv0YZxw==} engines: {node: '>=6'} + progress@2.0.3: + resolution: {integrity: sha512-7PiHtLll5LdnKIMw100I+8xJXR5gW2QwWYkT6iJva0bXitZKa/XMrSbdmg3r2Xnaidz9Qumd0VPaMrZlF9V9sA==} + engines: {node: '>=0.4.0'} + prop-types@15.8.1: resolution: {integrity: sha512-oj87CgZICdulUohogVAR7AjlC0327U4el4L6eAvOqCeudMDVU0NThNaV+b9Df4dXgSP1gXMTnPdhfe/2qDH5cg==} - propagate@2.0.1: - resolution: {integrity: sha512-vGrhOavPSTz4QVNuBNdcNXePNdNMaO1xj9yBeH1ScQPjk/rhg9sSlCXPhMkFuaNNW/syTvYqsnbIJxMBfRbbag==} - engines: {node: '>= 8'} + property-expr@2.0.6: + resolution: {integrity: sha512-SVtmxhRE/CGkn3eZY1T6pC8Nln6Fr/lu1mKSgRud0eC73whjGfoAogbn78LkD8aFL0zz3bAFerKSnOl7NlErBA==} property-information@5.6.0: resolution: {integrity: sha512-YUHSPk+A30YPv+0Qf8i9Mbfe/C0hdPXk1s1jPVToV8pk8BQtpw10ct89Eo7OWkutrwqvT0eicAxlOg3dOAu8JA==} @@ -6594,6 +7653,10 @@ packages: property-information@7.1.0: resolution: {integrity: sha512-TwEZ+X+yCJmYfL7TPUOcvBZ4QfoT5YenQiJuX//0th53DE6w0xxLEtfK3iyryQFddXuvkIk51EEgrJQ0WJkOmQ==} + proxy-from-env@2.1.0: + resolution: {integrity: sha512-cJ+oHTW1VAEa8cJslgmUZrc+sjRKgAKl3Zyse6+PV38hZe/V6Z14TbCuXcan9F9ghlz4QrFr2c92TNF82UkYHA==} + engines: {node: '>=10'} + pump@3.0.4: resolution: {integrity: sha512-VS7sjc6KR7e1ukRFhQSY5LM2uBWAUPiOPa/A3mkKmiMwSmRFUITt0xuj+/lesgnCv+dPIEYlkzrcyXgquIHMcA==} @@ -6658,8 +7721,8 @@ packages: react: '>= 16.3.0' react-dom: '>= 16.3.0' - react-easy-crop@5.5.6: - resolution: {integrity: sha512-Jw3/ozs8uXj3NpL511Suc4AHY+mLRO23rUgipXvNYKqezcFSYHxe4QXibBymkOoY6oOtLVMPO2HNPRHYvMPyTw==} + react-easy-crop@5.5.7: + resolution: {integrity: sha512-kYo4NtMeXFQB7h1U+h5yhUkE46WQbQdq7if54uDlbMdZHdRgNehfvaFrXnFw5NR1PNoUOJIfTwLnWmEx/MaZnA==} peerDependencies: react: '>=16.4.0' react-dom: '>=16.4.0' @@ -6678,14 +7741,14 @@ packages: react: '>=16.8.0' react-dom: '>=16.8.0' - react-i18next@16.5.8: - resolution: {integrity: sha512-2ABeHHlakxVY+LSirD+OiERxFL6+zip0PaHo979bgwzeHg27Sqc82xxXWIrSFmfWX0ZkrvXMHwhsi/NGUf5VQg==} + react-i18next@16.6.6: + resolution: {integrity: sha512-ZgL2HUoW34UKUkOV7uSQFE1CDnRPD+tCR3ywSuWH7u2iapnz86U8Bi3Vrs620qNDzCf1F47NxglCEkchCTDOHw==} peerDependencies: - i18next: '>= 25.6.2' + i18next: '>= 25.10.9' react: '>= 16.8.0' react-dom: '*' react-native: '*' - typescript: ^5 + typescript: ^5 || ^6 peerDependenciesMeta: react-dom: optional: true @@ -6750,11 +7813,6 @@ packages: react-dom: ^19.2.4 webpack: ^5.59.0 - react-slider@2.0.6: - resolution: {integrity: sha512-gJxG1HwmuMTJ+oWIRCmVWvgwotNCbByTwRkFZC6U4MBsHqJBmxwbYRJUmxy4Tke1ef8r9jfXjgkmY/uHOCEvbA==} - peerDependencies: - react: ^16 || ^17 || ^18 - react-sortablejs@6.1.4: resolution: {integrity: sha512-fc7cBosfhnbh53Mbm6a45W+F735jwZ1UFIYSrIqcO/gRIFoDyZeMtgKlpV4DdyQfbCzdh5LoALLTDRxhMpTyXQ==} peerDependencies: @@ -6804,6 +7862,14 @@ packages: read-cache@1.0.0: resolution: {integrity: sha512-Owdv/Ft7IjOgm/i0xvNDZ1LrRANRfew4b2prF3OWMQLxLfu3bS8FVhCsrSCMK4lR56Y9ya+AThoTpDCTxCmpRA==} + read-package-up@12.0.0: + resolution: {integrity: sha512-Q5hMVBYur/eQNWDdbF4/Wqqr9Bjvtrw2kjGxxBbKLbx8bVCL8gcArjTy8zDUuLGQicftpMuU0riQNcAsbtOVsw==} + engines: {node: '>=20'} + + read-pkg@10.1.0: + resolution: {integrity: sha512-I8g2lArQiP78ll51UeMZojewtYgIRCKCWqZEgOO8c/uefTI+XDXvCSXu3+YNUaTNvZzobrL5+SqHjBrByRRTdg==} + engines: {node: '>=20'} + readable-stream@3.6.2: resolution: {integrity: sha512-9u/sniCrY3D5WdsERHzHE4G2YCXqoG5FTHUiCC4SIbr6XcLZBY05ya9EKjYek9O5xOAwjGq+1JdGBAS7Q9ScoA==} engines: {node: '>= 6'} @@ -6842,6 +7908,9 @@ packages: resolution: {integrity: sha512-J8rn6v4DBb2nnFqkqwy6/NnTYMcgLA+sLr0iIO41qpv0n+ngb7ksag2tMRl0inb1bbO/esUwzW1vbJi7K0sI0g==} engines: {node: ^12.0.0 || ^14.0.0 || >=16.0.0} + reflect-metadata@0.2.2: + resolution: {integrity: sha512-urBwgfrvVP/eAyXx4hluJivBKzuEbSQs9rKWCrCkbSxNv8mxPcUZKeuoF3Uy4mJl3Lwprp6yy5/39VWigZ4K6Q==} + refractor@3.6.0: resolution: {integrity: sha512-MY9W41IOWxxk31o+YvFCNyNzdkc9M20NoZK5vq6jkv4I/uh2zkWcfudj0Q1fovjUQJrNewS9NMzeTtqPf+n5EA==} @@ -6849,6 +7918,9 @@ packages: resolution: {integrity: sha512-sZuz1dYW/ZsfG17WSAG7eS85r5a0dDsvg+7BiiYR5o6lKCAtUrEwdmRmaGF6rwVj3LcmAeYkOWKEPlbPzN3Y3A==} engines: {node: ^12.0.0 || ^14.0.0 || >=16.0.0} + regexp-match-indices@1.0.2: + resolution: {integrity: sha512-DwZuAkt8NF5mKwGGER1EGh2PRqyvhRhhLviH+R8y8dIuaQROlUfXjt4s9ZTXstIsSkptf06BSvwcEmmfheJJWQ==} + regexp-tree@0.1.27: resolution: {integrity: sha512-iETxpjK6YoRWJG5o6hXLwvjYAoW+FEZn9os0PD/b6AP6xQwsa/Y7lCVgIixBbUPMfhu+i2LtdeAqVTgGlQarfA==} hasBin: true @@ -6899,6 +7971,10 @@ packages: remend@1.3.0: resolution: {integrity: sha512-iIhggPkhW3hFImKtB10w0dz4EZbs28mV/dmbcYVonWEJ6UGHHpP+bFZnTh6GNWJONg5m+U56JrL+8IxZRdgWjw==} + repeat-string@1.6.1: + resolution: {integrity: sha512-PV0dzCYDNfRi1jCDbJzpW7jNNDRuCOG/jI5ctQcGKt/clZD+YcPS3yIlWuTJMmESC8aevCFmWJy5wjAFgNqN6w==} + engines: {node: '>=0.10'} + require-from-string@2.0.2: resolution: {integrity: sha512-Xf0nWe6RseziFMu+Ap9biiUbmplq6S9/p+7w7YXP/JBHhrUDDUhwa+vANyubuqfZWTveU//DYVGsDG7RKL/vEw==} engines: {node: '>=0.10.0'} @@ -6917,6 +7993,10 @@ packages: resolution: {integrity: sha512-pb/MYmXstAkysRFx8piNI1tGFNQIFA3vkE3Gq4EuA1dF6gHp/+vgZqsCGJapvy8N3Q+4o7FwvquPJcnZ7RYy4g==} engines: {node: '>=4'} + resolve-from@5.0.0: + resolution: {integrity: sha512-qYg9KP24dD5qka9J47d0aVky0N+b4fTU89LN9iDnjB5waksiC49rvMB0PrUJQGoTmH50XPiqOvAjDfaijGxYZw==} + engines: {node: '>=8'} + resolve-pkg-maps@1.0.0: resolution: {integrity: sha512-seS2Tj26TBVOC2NIc2rOe2y2ZO7efxITtLZcGSOnHHNOQ7CkiUBfw0Iw2ck6xkIhPwLhKNLS8BO+hEpngQlqzw==} @@ -6936,8 +8016,13 @@ packages: rfdc@1.4.1: resolution: {integrity: sha512-q1b3N5QkRUWUl7iyylaaj3kOpIT0N2i9MqIEQXP73GVsN9cw3fdx8X63cEmWhJGi2PPCF23Ijp7ktmd39rawIA==} - robust-predicates@3.0.2: - resolution: {integrity: sha512-IXgzBWvWQwE6PrDI05OvmXUIruQTcoMDzRsOd5CDvHCVLcLHMTSYvOK5Cm46kWqlV3yAbuSpBZdJ5oP5OUoStg==} + robust-predicates@3.0.3: + resolution: {integrity: sha512-NS3levdsRIUOmiJ8FZWCP7LG3QpJyrs/TE0Zpf1yvZu8cAJJ6QMW92H1c7kWpdIHo8RvmLxN/o2JXTKHp74lUA==} + + rolldown@1.0.0-rc.12: + resolution: {integrity: sha512-yP4USLIMYrwpPHEFB5JGH1uxhcslv6/hL0OyvTuY+3qlOSJvZ7ntYnoWpehBxufkgN0cvXxppuTu5hHa/zPh+A==} + engines: {node: ^20.19.0 || >=22.12.0} + hasBin: true rollup@4.59.0: resolution: {integrity: sha512-2oMpl67a3zCH9H79LeMcbDhXW/UmWG/y2zuqnF2jQq5uq9TbM9TVyXvA4+t+ne2IIkBdrLpAaRQAvo7YI/Yyeg==} @@ -6979,10 +8064,6 @@ packages: resolution: {integrity: sha512-6R3J5M4AcbtLUdZmRv2SygeVaM7IhrLXu9BmnOGmmACak8fiUtOsYNWUS4uK7upbmHIBbLBeFeI//477BKLBzA==} engines: {node: '>=11.0.0'} - saxes@6.0.0: - resolution: {integrity: sha512-xAg7SOnEhrm5zI3puOOKyy1OMcMlIJZYNJY7xLBwSze0UjhPLnWfj2GF2EpT0jmzaJKIWKHLsaSSajf35bcYnA==} - engines: {node: '>=v12.22.7'} - scheduler@0.27.0: resolution: {integrity: sha512-eNv+WrVbKu1f3vbYJT/xtiF5syA5HPIMtf9IgY/nKg0sWqzAUEvqY/xm7OcZc/qafLx/iO9FgOmeSAp4v5ti/Q==} @@ -6998,6 +8079,9 @@ packages: resolution: {integrity: sha512-3A6sD0WYP7+QrjbfNA2FN3FsOaGGFoekCVgTyypy53gPxhbkCIjtO6YWgdrfM+n/8sI8JeXZOIxsHjMTNxQ4nQ==} engines: {node: ^14.0.0 || >=16.0.0} + seed-random@2.2.0: + resolution: {integrity: sha512-34EQV6AAHQGhoc0tn/96a9Fsi6v2xdqe/dMUwljGRaFOzR3EgRmECvD0O8vi8X+/uQ50LGHfkNu/Eue5TPKZkQ==} + semver@6.3.1: resolution: {integrity: sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==} hasBin: true @@ -7060,8 +8144,8 @@ packages: resolution: {integrity: sha512-stxByr12oeeOyY2BlviTNQlYV5xOj47GirPr4yA1hE9JCtxfQN0+tVbkxwCtYDQWhEKWFHsEK48ORg5jrouCAg==} engines: {node: '>=20'} - smol-toml@1.6.0: - resolution: {integrity: sha512-4zemZi0HvTnYwLfrpk/CF9LOd9Lt87kAt50GnqhMpyF9U3poDAP2+iukq2bZsO/ufegbYehBkqINbsWxj4l4cw==} + smol-toml@1.6.1: + resolution: {integrity: sha512-dWUG8F5sIIARXih1DTaQAX4SsiTXhInKf1buxdY9DIg4ZYPZK5nGM1VRIYmEbDbsHt7USo99xSLFu5Q1IqTmsg==} engines: {node: '>= 18'} solid-js@1.9.11: @@ -7091,20 +8175,29 @@ packages: space-separated-tokens@2.0.2: resolution: {integrity: sha512-PEGlAwrG8yXGXRjW32fGbg66JAlOAwbObuqVoJpv/mRgoWDQfgH1wDPvtzWyUSNAXBGSk8h755YDbbcEy3SH2Q==} + spdx-correct@3.2.0: + resolution: {integrity: sha512-kN9dJbvnySHULIluDHy32WHRUu3Og7B9sbY7tsFLctQkIqnMh3hErYgdMjTYuqmcXX+lK5T1lnUt3G7zNswmZA==} + spdx-exceptions@2.5.0: resolution: {integrity: sha512-PiU42r+xO4UbUS1buo3LPJkjlO7430Xn5SVAhdpzzsPHsjbYVflnnFdATgabnLude+Cqu25p6N+g2lw/PFsa4w==} + spdx-expression-parse@3.0.1: + resolution: {integrity: sha512-cbqHunsQWnJNE6KhVSMsMeH5H/L9EpymbzqTQ3uLwNCLZ1Q481oWaofqH7nO6V07xlXwY6PhQdQ2IedWx/ZK4Q==} + spdx-expression-parse@4.0.0: resolution: {integrity: sha512-Clya5JIij/7C6bRR22+tnGXbc4VKlibKSVj2iHvVeX5iMW7s1SIQlqu699JkODJJIhh/pUu8L0/VLh8xflD+LQ==} spdx-license-ids@3.0.23: resolution: {integrity: sha512-CWLcCCH7VLu13TgOH+r8p1O/Znwhqv/dbb6lqWy67G+pT1kHmeD/+V36AVb/vq8QMIQwVShJ6Ssl5FPh0fuSdw==} - srvx@0.11.12: - resolution: {integrity: sha512-AQfrGqntqVPXgP03pvBDN1KyevHC+KmYVqb8vVf4N+aomQqdhaZxjvoVp+AOm4u6x+GgNQY3MVzAUIn+TqwkOA==} + srvx@0.11.13: + resolution: {integrity: sha512-oknN6qduuMPafxKtHucUeG32Q963pjriA5g3/Bl05cwEsUe5VVbIU4qR9LrALHbipSCyBe+VmfDGGydqazDRkw==} engines: {node: '>=20.16.0'} hasBin: true + stackframe@1.3.4: + resolution: {integrity: sha512-oeVtt7eWQS+Na6F//S4kJ2K2VbRlS9D43mAlMyVpVWovy9o+jfgH8O9agzANzaiLjclA0oYzUXEM4PurhSUChw==} + state-local@1.0.7: resolution: {integrity: sha512-HTEHMNieakEnoe33shBYcZ7NX83ACUjCu8c40iOGEZsngj9zRnkqS9j1pqQPXwobB0ZcVTk27REb7COQ0UR59w==} @@ -7115,8 +8208,8 @@ packages: resolution: {integrity: sha512-9SN0XIjBBXCT6ZXXVnScJN4KP2RyFg6B8sEoFlugVHMANysfaEni4LTWlvUQQ/R0wgZl1Ovt9KBQbzn21kHoZA==} engines: {node: '>=20.19.0'} - storybook@10.3.0: - resolution: {integrity: sha512-OpLdng98l7cACuqBoQwewx21Vhgl9XPssgLdXQudW0+N5QPjinWXZpZCquZpXpNCyw5s5BFAcv+jKB3Qkf9jeA==} + storybook@10.3.3: + resolution: {integrity: sha512-tMoRAts9EVqf+mEMPLC6z1DPyHbcPe+CV1MhLN55IKsl0HxNjvVGK44rVPSePbltPE6vIsn4bdRj6CCUt8SJwQ==} hasBin: true peerDependencies: prettier: ^2 || ^3 @@ -7130,8 +8223,9 @@ packages: react: ^18.0.0 || ^19.0.0 react-dom: ^18.0.0 || ^19.0.0 - strict-event-emitter@0.5.1: - resolution: {integrity: sha512-vMgjE/GGEPEFnhFub6pa4FmJBRBVOLpIII2hvCZ8Kzb7K0hlHo7mQv6xYrBvCL2LtAIBwFUK8wvuJgTVSQ5MFQ==} + string-argv@0.3.1: + resolution: {integrity: sha512-a1uQGz7IyVy9YwhqjZIZu1c8JO8dNIe20xBmSS6qu9kv++k3JGzCVmprbNN5Kn+BgzD5E7YYwg1CcjuJMRNsvg==} + engines: {node: '>=0.6.19'} string-argv@0.3.2: resolution: {integrity: sha512-aqD2Q0144Z+/RqG52NeHEkZauTAUWJO8c6yTftGJKO3Tja5tUgIfmIl6kExvhtxSDP7fXB6DvzkfMpCd/F3G+Q==} @@ -7184,6 +8278,9 @@ packages: strip-literal@3.1.0: resolution: {integrity: sha512-8r3mkIM/2+PpjHoOtiAW8Rg3jJLHaV7xPwG+YRGrv6FP0wwk/toTpATxWYOW0BKdWwl82VT2tFYi5DlROa0Mxg==} + structured-clone-es@2.0.0: + resolution: {integrity: sha512-5UuAHmBLXYPCl22xWJrFuGmIhBKQzxISPVz6E7nmTmTcAOpUzlbjKJsRrCE4vADmMQ0dzeCnlWn9XufnAGf76Q==} + style-to-js@1.1.21: resolution: {integrity: sha512-RjQetxJrrUJLQPHbLku6U/ocGtzyjbJMP9lCNK7Ag0CNh690nSH8woqWH9u16nMjYBAok+i7JO1NP2pOy8IsPQ==} @@ -7228,9 +8325,6 @@ packages: engines: {node: '>=14.0.0'} hasBin: true - symbol-tree@3.2.4: - resolution: {integrity: sha512-9QNk5KwDF+Bvz+PyObkmSYjI5ksVUYtjW7AU22r2NKcfLJcXp96hkDWU3+XndOsUb+AQ9QhfzfCT2O+CNWT5Tw==} - synckit@0.11.12: resolution: {integrity: sha512-Bh7QjT8/SuKUIfObSXNHNSK6WHo6J1tHCqJsuaFDP7gP0fkzSfTxI8y85JrppZ0h8l0maIgc2tfuZQ6/t3GtnQ==} engines: {node: ^14.18.0 || >=16.0.0} @@ -7262,8 +8356,8 @@ packages: engines: {node: '>=14.0.0'} hasBin: true - tapable@2.3.0: - resolution: {integrity: sha512-g9ljZiwki/LfxmQADO3dEY1CbpmXT5Hm2fJ+QaGKwSXUylMybePR7/67YW7jOrrvjEgL1Fmz5kzyAjWVWLlucg==} + tapable@2.3.2: + resolution: {integrity: sha512-1MOpMXuhGzGL5TTCZFItxCc0AARf1EZFQkGqMm7ERKj8+Hgr5oLvJOVFcC+lRmR8hCe2S3jC4T5D7Vg/d7/fhA==} engines: {node: '>=6'} tar-fs@2.1.4: @@ -7309,6 +8403,9 @@ packages: thenify@3.3.1: resolution: {integrity: sha512-RVZSIV5IG10Hk3enotrhvz0T9em6cyHBLkH/YAZuKqd8hRkKhSfCGIcP2KUY0EPxndzANBmNllzWPwak+bheSw==} + tiny-case@1.0.3: + resolution: {integrity: sha512-Eet/eeMhkO6TX8mnUteS9zgPbUMQa4I6Kkp5ORiBD5476/m+PIRiumP5tmh5ioJpH7k51Kehawy2UDfsnxxY8Q==} + tiny-inflate@1.0.3: resolution: {integrity: sha512-pkY1fj1cKHb2seWDy0B16HeWyczlJA9/WW3u3c4z/NiWDsO3DOU5D7nhTLE9CF0yXv/QZFY7sEJmj24dK+Rrqw==} @@ -7321,6 +8418,9 @@ packages: tinybench@2.9.0: resolution: {integrity: sha512-0+DUvqWMValLmha6lr4kD8iAMK1HzV0/aKnCtWb9v9641TnP/MFb7Pc2bxoxQjTXAErryXVgUOfv2YqNllqGeg==} + tinyexec@0.3.2: + resolution: {integrity: sha512-KQQR9yN7R5+OSwaK0XQoj22pwHoTlgYqmUscPYoknOoWCWfj/5/ABTMRi69FrKU5ffPVh5QcFikpWJI/P1ocHA==} + tinyexec@1.0.4: resolution: {integrity: sha512-u9r3uZC0bdpGOXtlxUIdwf9pkmvhqJdrVCH9fapQtgy/OeTTMZ1nqH7agtvEfmGui6e1XxjcdrlxvxJvc3sMqw==} engines: {node: '>=18'} @@ -7345,11 +8445,11 @@ packages: resolution: {integrity: sha512-azl+t0z7pw/z958Gy9svOTuzqIk6xq+NSheJzn5MMWtWTFywIacg2wUlzKFGtt3cthx0r2SxMK0yzJOR0IES7Q==} engines: {node: '>=14.0.0'} - tldts-core@7.0.26: - resolution: {integrity: sha512-5WJ2SqFsv4G2Dwi7ZFVRnz6b2H1od39QME1lc2y5Ew3eWiZMAeqOAfWpRP9jHvhUl881406QtZTODvjttJs+ew==} + tldts-core@7.0.27: + resolution: {integrity: sha512-YQ7uPjgWUibIK6DW5lrKujGwUKhLevU4hcGbP5O6TcIUb+oTjJYJVWPS4nZsIHrEEEG6myk/oqAJUEQmpZrHsg==} - tldts@7.0.26: - resolution: {integrity: sha512-WiGwQjr0qYdNNG8KpMKlSvpxz652lqa3Rd+/hSaDcY4Uo6SKWZq2LAF+hsAhUewTtYhXlorBKgNF3Kk8hnjGoQ==} + tldts@7.0.27: + resolution: {integrity: sha512-I4FZcVFcqCRuT0ph6dCDpPuO4Xgzvh+spkcTr1gK7peIvxWauoloVO0vuy1FQnijT63ss6AsHB6+OIM4aXHbPg==} hasBin: true to-regex-range@5.0.1: @@ -7367,17 +8467,16 @@ packages: resolution: {integrity: sha512-A5F0cM6+mDleacLIEUkmfpkBbnHJFV1d2rprHU2MXNk7mlxHq2zGojA+SRvQD1RoMo9gqjZPWEaKG4v1BQ48lw==} engines: {node: ^20.19.0 || ^22.13.0 || >=24} + toposort@2.0.2: + resolution: {integrity: sha512-0a5EOkAUp8D4moMi2W8ZF8jcga7BgZd91O/yabJCFY8az+XSzeGyTKs0Aoo897iV1Nj6guFq8orWDS96z91oGg==} + totalist@3.0.1: resolution: {integrity: sha512-sf4i37nQ2LBx4m3wB74y+ubopq6W/dIzXg0FDGjsYnZHVa1Da8FH853wlL2gtUhg+xJXjfk3kUZS3BRoQeoQBQ==} engines: {node: '>=6'} - tough-cookie@6.0.1: - resolution: {integrity: sha512-LktZQb3IeoUWB9lqR5EWTHgW/VTITCXg4D21M+lvybRVdylLrRMnqaIONLVb5mav8vM19m44HIcGq4qASeu2Qw==} - engines: {node: '>=16'} - - tr46@6.0.0: - resolution: {integrity: sha512-bLVMLPtstlZ4iMQHpFHTR7GAGj2jxi8Dg0s2h2MafAE4uSWF98FC/3MomU51iQAMf8/qDUbKWf5GxuvvVcXEhw==} - engines: {node: '>=20'} + tree-kill@1.2.2: + resolution: {integrity: sha512-L0Orpi8qGpRG//Nd+H90vFB+3iHnue1zSSGmNOOCh1GLJ7rUKVwV2HvijphGQS2UmhUZewS9VgvxYIdgr+fG1A==} + hasBin: true trim-lines@3.0.1: resolution: {integrity: sha512-kRj8B+YHZCc9kQYdWfJB2/oUl9rA99qbowYYBtr4ui4mZyAQ2JpvVBd/6U2YloATfqBhBTSMhTpgBHtU0Mf3Rg==} @@ -7385,8 +8484,8 @@ packages: trough@2.2.0: resolution: {integrity: sha512-tmMpK00BjZiUyVyvrBK7knerNgmgvcV/KLVyuma/SC+TQN167GrMRciANTz09+k3zW8L8t60jWO1GpfkZdjTaw==} - ts-api-utils@2.4.0: - resolution: {integrity: sha512-3TaVTaAv2gTiMB35i3FiGJaRfwb3Pyn/j3m/bfAvGe8FB7CF6u+LMYqYlDh7reQf7UNvoTvdfAqHGmPGOSsPmA==} + ts-api-utils@2.5.0: + resolution: {integrity: sha512-OJ/ibxhPlqrMM0UiNHJ/0CKQkoKF243/AEmplt3qpRgkW8VG7IfOS41h7V8TjITqdByHzrjcS/2si+y4lIh8NA==} engines: {node: '>=18.12'} peerDependencies: typescript: '>=4.8.4' @@ -7436,6 +8535,25 @@ packages: tslib@2.8.1: resolution: {integrity: sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==} + tsup@8.5.1: + resolution: {integrity: sha512-xtgkqwdhpKWr3tKPmCkvYmS9xnQK3m3XgxZHwSUjvfTjp7YfXe5tT3GgWi0F2N+ZSMsOeWeZFh7ZZFg5iPhing==} + engines: {node: '>=18'} + hasBin: true + peerDependencies: + '@microsoft/api-extractor': ^7.36.0 + '@swc/core': ^1 + postcss: ^8.4.12 + typescript: '>=4.5.0' + peerDependenciesMeta: + '@microsoft/api-extractor': + optional: true + '@swc/core': + optional: true + postcss: + optional: true + typescript: + optional: true + tsx@4.21.0: resolution: {integrity: sha512-5C1sg4USs1lfG0GFb2RLXsdpXqBSEhAaA/0kPL01wxzpMqLILNxIxIOKiILz+cdg/pLnOUxFYOR5yhHU666wbw==} engines: {node: '>=18.0.0'} @@ -7451,8 +8569,16 @@ packages: resolution: {integrity: sha512-XleUoc9uwGXqjWwXaUTZAmzMcFZ5858QA2vvx1Ur5xIcixXIP+8LnFDgRplU30us6teqdlskFfu+ae4K79Ooew==} engines: {node: '>= 0.8.0'} - type-fest@5.4.4: - resolution: {integrity: sha512-JnTrzGu+zPV3aXIUhnyWJj4z/wigMsdYajGLIYakqyOW1nPllzXEJee0QQbHj+CTIQtXGlAjuK0UY+2xTyjVAw==} + type-fest@2.19.0: + resolution: {integrity: sha512-RAH822pAdBgcNMAfWnCBU3CFZcfZ/i1eZjwFU/dsLKumyuuP3niueg2UAukXYF0E2AAoc82ZSSf9J0WQBinzHA==} + engines: {node: '>=12.20'} + + type-fest@4.41.0: + resolution: {integrity: sha512-TeTSQ6H5YHvpqVwBRcnLDCBnDOHWYu7IvGbHT6N8AOymcr9PJGjc1GTtiWZTYg0NCgYwvnYWEkVChQAr9bjfwA==} + engines: {node: '>=16'} + + type-fest@5.5.0: + resolution: {integrity: sha512-PlBfpQwiUvGViBNX84Yxwjsdhd1TUlXr6zjX7eoirtCPIr08NAmxwa+fcYBTeRQxHo9YC9wwF3m9i700sHma8g==} engines: {node: '>=20'} typescript@5.9.3: @@ -7485,13 +8611,13 @@ packages: resolution: {integrity: sha512-jxytwMHhsbdpBXxLAcuu0fzlQeXCNnWdDyRHpvWsUl8vd98UwYdl9YTyn8/HcpcJPC3pwUveefsa3zTxyD/ERg==} engines: {node: '>=20.18.1'} - undici@7.24.4: - resolution: {integrity: sha512-BM/JzwwaRXxrLdElV2Uo6cTLEjhSb3WXboncJamZ15NgUURmvlXvxa6xkwIOILIjPNo9i8ku136ZvWV0Uly8+w==} - engines: {node: '>=20.18.1'} - unicode-trie@2.0.0: resolution: {integrity: sha512-x7bc76x0bm4prf1VLg79uhAzKw8DVboClSN5VxJuQ+LKDOVEW9CdH+VY7SP+vX7xCYQqzzgQpFqz15zeLvAtZQ==} + unicorn-magic@0.4.0: + resolution: {integrity: sha512-wH590V9VNgYH9g3lH9wWjTrUoKsjLF6sGLjhR4sH1LWpLmCOH0Zf7PukhDA8BiS7KHe4oPNkcTHqYkj7SOGUOw==} + engines: {node: '>=20'} + unified@11.0.5: resolution: {integrity: sha512-xKvGhPWw3k84Qjh8bI3ZeJjqnyadK+GEFtazSfZv/rKeTkTjOJho6mFqh2SM96iIcZokxiOpg78GazTSg8+KHA==} @@ -7519,9 +8645,6 @@ packages: unist-util-visit@5.1.0: resolution: {integrity: sha512-m+vIdyeCOpdr/QeQCu2EzxX/ohgS8KbnPDgFni4dQsfSCtpz8UqDyY5GjRru8PDKuYn7Fq19j1CQ+nJSsGKOzg==} - universal-user-agent@7.0.3: - resolution: {integrity: sha512-TmnEAEAsBJVZM/AADELsK76llnwcf9vMKuPz8JflO1frO8Lchitr0fNaN9d+Ap0BjKtqWqd/J17qeDnXh8CL2A==} - universalify@2.0.1: resolution: {integrity: sha512-gptHNQghINnc/vTGIk0SOFGFNXw7JVrlRUtConJRlvaw6DuX0wO5Jeko9sWrMBhh+PsYAZ7oXAiOnf/UKogyiw==} engines: {node: '>= 10.0.0'} @@ -7543,6 +8666,9 @@ packages: peerDependencies: browserslist: '>= 4.21.0' + upper-case-first@2.0.2: + resolution: {integrity: sha512-514ppYHBaKwfJRK/pNC6c/OxfGa0obSnAl106u97Ed0I625Nin96KAjttZF6ZL3e1XLtphxnqrOi9iWgm+u+bg==} + uri-js@4.4.1: resolution: {integrity: sha512-7rKUyy33Q1yc98pQ1DAmLtwX109F7TIfWlW1Ydo8Wl1ii1SeHieeh0HHfPeL2fMXK6z0s8ecKs9frCuLJvndBg==} @@ -7607,6 +8733,9 @@ packages: peerDependencies: react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 + util-arity@1.1.0: + resolution: {integrity: sha512-kkyIsXKwemfSy8ZEoaIz06ApApnWsk5hQO0vLjZS6UkBiGiW++Jsyb8vSBoc0WKlffGoGs5yYy/j5pp8zckrFA==} + util-deprecate@1.0.2: resolution: {integrity: sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==} @@ -7618,14 +8747,17 @@ packages: resolution: {integrity: sha512-XQegIaBTVUjSHliKqcnFqYypAd4S+WCYt5NIeRs6w/UAry7z8Y9j5ZwRRL4kzq9U3sD6v+85er9FvkEaBpji2w==} hasBin: true - valibot@1.3.0: - resolution: {integrity: sha512-SItIaOFnWYho/AcRU5gOtyfkTsuDTC3tRv+jy4/py8xERPnvHdM+ybD1iIqWTATVWG1nZetOfwZKq5upBjSqzw==} + valibot@1.3.1: + resolution: {integrity: sha512-sfdRir/QFM0JaF22hqTroPc5xy4DimuGQVKFrzF1YfGwaS1nJot3Y8VqMdLO2Lg27fMzat2yD3pY5PbAYO39Gg==} peerDependencies: typescript: '>=5' peerDependenciesMeta: typescript: optional: true + validate-npm-package-license@3.0.4: + resolution: {integrity: sha512-DpKm2Ui/xN7/HQKCtpZxoRWBhZ9Z0kqtygG8XCgNQ8ZlDnxuQmWhj566j8fN4Cu3/JmbhsDo7fcAJq4s9h27Ew==} + vfile-location@5.0.3: resolution: {integrity: sha512-5yXvWDEgqeiYiBe1lbxYF7UMAIm/IcopxMHrMQDq3nvKcjPKIhZklUKL+AE7J7uApI4kwe2snsK+eI6UTj9EHg==} @@ -7635,14 +8767,14 @@ packages: vfile@6.0.3: resolution: {integrity: sha512-KzIbH/9tXat2u30jf+smMwFCsno4wHVdNmzFyL+T/L3UGqqk6JKfVqOFOZEpZSHADH1k40ab6NUIXZq422ov3Q==} - vinext@0.0.31: - resolution: {integrity: sha512-4se0fL74VjlUOluvozC3ZfMpKSlvkNc5Akqmm8jehXKKp7ncuTLkggLpuMsEnS9vVMywcSdE+B00d1liMfLVow==} + vinext@0.0.38: + resolution: {integrity: sha512-zlQswirXCApDgAFq1eoO/YbRlavGE+Bnowz5vXoQa2EmbFhYg52+T8SZs1QWdOqkbZMhpLIV/iaWvHtkRv2t4Q==} engines: {node: '>=22'} hasBin: true peerDependencies: '@mdx-js/rollup': ^3.0.0 '@vitejs/plugin-react': ^5.1.4 || ^6.0.0 - '@vitejs/plugin-rsc': ^0.5.19 + '@vitejs/plugin-rsc': ^0.5.21 react: '>=19.2.0' react-dom: '>=19.2.0' react-server-dom-webpack: ^19.2.4 @@ -7655,41 +8787,31 @@ packages: react-server-dom-webpack: optional: true - vite-dev-rpc@1.1.0: - resolution: {integrity: sha512-pKXZlgoXGoE8sEKiKJSng4hI1sQ4wi5YT24FCrwrLt6opmkjlqPPVmiPWWJn8M8byMxRGzp1CrFuqQs4M/Z39A==} - peerDependencies: - vite: ^2.9.0 || ^3.0.0-0 || ^4.0.0-0 || ^5.0.0-0 || ^6.0.1 || ^7.0.0-0 - - vite-hot-client@2.1.0: - resolution: {integrity: sha512-7SpgZmU7R+dDnSmvXE1mfDtnHLHQSisdySVR7lO8ceAXvM0otZeuQQ6C8LrS5d/aYyP/QZ0hI0L+dIPrm4YlFQ==} - peerDependencies: - vite: ^2.6.0 || ^3.0.0 || ^4.0.0 || ^5.0.0-0 || ^6.0.0-0 || ^7.0.0-0 - vite-plugin-commonjs@0.10.4: resolution: {integrity: sha512-eWQuvQKCcx0QYB5e5xfxBNjQKyrjEWZIR9UOkOV6JAgxVhtbZvCOF+FNC2ZijBJ3U3Px04ZMMyyMyFBVWIJ5+g==} vite-plugin-dynamic-import@1.6.0: resolution: {integrity: sha512-TM0sz70wfzTIo9YCxVFwS8OA9lNREsh+0vMHGSkWDTZ7bgd1Yjs5RV8EgB634l/91IsXJReg0xtmuQqP0mf+rg==} - vite-plugin-inspect@11.3.3: - resolution: {integrity: sha512-u2eV5La99oHoYPHE6UvbwgEqKKOQGz86wMg40CCosP6q8BkB6e5xPneZfYagK4ojPJSj5anHCrnvC20DpwVdRA==} + vite-plugin-inspect@12.0.0-beta.1: + resolution: {integrity: sha512-ang8DMcQxr2MJRjdvwabkD0uOPFB5/fP4hldZvAqCl82SABXK1zYLyZKGrauCblR61cvDUavxyiHbtD4zTdw0A==} engines: {node: '>=14'} peerDependencies: '@nuxt/kit': '*' - vite: ^6.0.0 || ^7.0.0-0 + vite: ^8.0.0-0 peerDependenciesMeta: '@nuxt/kit': optional: true - vite-plugin-storybook-nextjs@3.2.3: - resolution: {integrity: sha512-NQvkiZKfbGmk0j3mYeTJnGiucV+VOcryCsm/CoE7rBVRrpVntg5lWj+CbosFwHhGPpWQ3I4HJ3nSRzDq0u74Ug==} + vite-plugin-storybook-nextjs@3.2.4: + resolution: {integrity: sha512-shFOJpGQsWDS1FLm8BR8b6FIQC65pFZ5a0IUFGLiBHAX1eRz0N8TOhUJN4p708zfPBLDXqWzj++ocECe8gSoMg==} peerDependencies: next: ^14.1.0 || ^15.0.0 || ^16.0.0 - storybook: ^0.0.0-0 || ^9.0.0 || ^10.0.0 || ^10.0.0-0 || ^10.1.0-0 || ^10.2.0-0 || ^10.3.0-0 + storybook: ^0.0.0-0 || ^9.0.0 || ^10.0.0 || ^10.0.0-0 || ^10.1.0-0 || ^10.2.0-0 || ^10.3.0-0 || ^10.4.0-0 vite: ^5.0.0 || ^6.0.0 || ^7.0.0 || ^8.0.0 - vite-plus@0.1.12: - resolution: {integrity: sha512-8s1RzomZkgrJRiwiYWGq3R0txFPYfBBJGp73XNHQnme0KTTVH5dNm/E2GNyBSMFJbeeF7eh1OSgqWVc2FpR6eA==} + vite-plus@0.1.14: + resolution: {integrity: sha512-p4pWlpZZNiEsHxPWNdeIU9iuPix3ydm3ficb0dXPggoyIkdotfXtvn2NPX9KwfiQImU72EVEs4+VYBZYNcUYrw==} engines: {node: ^20.19.0 || >=22.12.0} hasBin: true @@ -7706,6 +8828,49 @@ packages: peerDependencies: vite: '*' + vite@8.0.3: + resolution: {integrity: sha512-B9ifbFudT1TFhfltfaIPgjo9Z3mDynBTJSUYxTjOQruf/zHH+ezCQKcoqO+h7a9Pw9Nm/OtlXAiGT1axBgwqrQ==} + engines: {node: ^20.19.0 || >=22.12.0} + hasBin: true + peerDependencies: + '@types/node': ^20.19.0 || >=22.12.0 + '@vitejs/devtools': ^0.1.0 + esbuild: 0.27.2 + jiti: '>=1.21.0' + less: ^4.0.0 + sass: ^1.70.0 + sass-embedded: ^1.70.0 + stylus: '>=0.54.8' + sugarss: ^5.0.0 + terser: ^5.16.0 + tsx: ^4.8.1 + yaml: 2.8.3 + peerDependenciesMeta: + '@types/node': + optional: true + '@vitejs/devtools': + optional: true + esbuild: + optional: true + jiti: + optional: true + less: + optional: true + sass: + optional: true + sass-embedded: + optional: true + stylus: + optional: true + sugarss: + optional: true + terser: + optional: true + tsx: + optional: true + yaml: + optional: true + vitefu@1.1.2: resolution: {integrity: sha512-zpKATdUbzbsycPFBN71nS2uzBUQiVnFoOrr2rvqv34S1lcAgMKKkjWleLGeiJlZ8lwCXvtWaRn7R3ZC16SYRuw==} peerDependencies: @@ -7714,8 +8879,8 @@ packages: vite: optional: true - vitest-canvas-mock@1.1.3: - resolution: {integrity: sha512-zlKJR776Qgd+bcACPh0Pq5MG3xWq+CdkACKY/wX4Jyija0BSz8LH3aCCgwFKYFwtm565+050YFEGG9Ki0gE/Hw==} + vitest-canvas-mock@1.1.4: + resolution: {integrity: sha512-4boWHY+STwAxGl1+uwakNNoQky5EjPLC8HuponXNoAscYyT1h/F7RUvTkl4IyF/MiWr3V8Q626je3Iel3eArqA==} peerDependencies: vitest: ^3.0.0 || ^4.0.0 @@ -7749,10 +8914,6 @@ packages: peerDependencies: eslint: ^8.57.0 || ^9.0.0 || ^10.0.0 - w3c-xmlserializer@5.0.0: - resolution: {integrity: sha512-o8qghlI8NZHU1lLPrpi2+Uq7abh4GGPpYANlalzWxyWteJOCsr/P+oPBA49TOLu5FTZO4d3F9MnWJfiMo4BkmA==} - engines: {node: '>=18'} - walk-up-path@4.0.0: resolution: {integrity: sha512-3hu+tD8YzSLGuFYtPRb48vdhKMi0KQV5sn+uWr8+7dMEq/2G/dtLrdDinkLjqq5TIbIBjYJ4Ax/n3YiaW7QM8A==} engines: {node: 20 || >=22} @@ -7767,10 +8928,6 @@ packages: web-vitals@5.1.0: resolution: {integrity: sha512-ArI3kx5jI0atlTtmV0fWU3fjpLmq/nD3Zr1iFFlJLaqa5wLBkUSzINwBPySCX/8jRyjlmy1Volw1kz1g9XE4Jg==} - webidl-conversions@8.0.1: - resolution: {integrity: sha512-BMhLD/Sw+GbJC21C/UgyaZX41nPt8bUTg+jWyDeg7e7YN4xOM05YPSIXceACnXVtqyEw/LMClUQMtMZ+PGGpqQ==} - engines: {node: '>=20'} - webpack-sources@3.3.4: resolution: {integrity: sha512-7tP1PdV4vF+lYPnkMR0jMY5/la2ub5Fc/8VQrrU+lXkiM6C4TjVfGw7iKfyhnTQOsD+6Q/iKw0eFciziRgD58Q==} engines: {node: '>=10.13.0'} @@ -7793,18 +8950,14 @@ packages: engines: {node: '>=18'} deprecated: Use @exodus/bytes instead for a more spec-conformant and faster implementation + whatwg-mimetype@3.0.0: + resolution: {integrity: sha512-nt+N2dzIutVRxARx1nghPKGv1xHikU7HKdfafKkLNLindmPU/ch3U31NOCGGA/dmPcmb1VlofO0vnKAcsm0o/Q==} + engines: {node: '>=12'} + whatwg-mimetype@4.0.0: resolution: {integrity: sha512-QaKxh0eNIi2mE9p2vEdzfagOKHCcj1pJ56EEHGQOVxp8r9/iszLUUV7v89x9O1p/T+NlTM5W7jW6+cz4Fq1YVg==} engines: {node: '>=18'} - whatwg-mimetype@5.0.0: - resolution: {integrity: sha512-sXcNcHOC51uPGF0P/D4NVtrkjSU2fNsm9iog4ZvZJsL3rjoDAzXZhkm2MWt1y+PUdggKAYVoMAIYcs78wJ51Cw==} - engines: {node: '>=20'} - - whatwg-url@16.0.1: - resolution: {integrity: sha512-1to4zXBxmXHV3IiSSEInrreIlu02vUOvrhxJJH5vcxYTBDAx51cqZiKdyTxlecdKNSjj8EcxGBxNf6Vg+945gw==} - engines: {node: ^20.19.0 || ^22.12.0 || >=24.0.0} - which@2.0.2: resolution: {integrity: sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==} engines: {node: '>= 8'} @@ -7821,8 +8974,8 @@ packages: wrappy@1.0.2: resolution: {integrity: sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==} - ws@8.19.0: - resolution: {integrity: sha512-blAT2mjOEIi0ZzruJfIhb3nps74PRWTCz1IjglWEEpQl5XS/UNama6u2/rjFkDDouqr4L67ry+1aGIALViWjDg==} + ws@8.20.0: + resolution: {integrity: sha512-sAt8BhgNbzCtgGbt2OxmpuryO63ZoDk/sqaB/znQm94T4fCEsy/yV+7CdC1kJhOU9lboAEU7R3kquuycDoibVA==} engines: {node: '>=10.0.0'} peerDependencies: bufferutil: ^4.0.1 @@ -7837,16 +8990,17 @@ packages: resolution: {integrity: sha512-h3Fbisa2nKGPxCpm89Hk33lBLsnaGBvctQopaBSOW/uIs6FTe1ATyAnKFJrzVs9vpGdsTe73WF3V4lIsk4Gacw==} engines: {node: '>=18'} + wsl-utils@0.3.1: + resolution: {integrity: sha512-g/eziiSUNBSsdDJtCLB8bdYEUMj4jR7AGeUo96p/3dTafgjHhpF4RiCFPiRILwjQoDXx5MqkBr4fwWtR3Ky4Wg==} + engines: {node: '>=20'} + xml-name-validator@4.0.0: resolution: {integrity: sha512-ICP2e+jsHvAj2E2lIHxa5tjXRlKDJo4IdvPvCXbXQGdzSfmSpNVyIKMvoZHjDY9DP0zV17iI85o90vRFXNccRw==} engines: {node: '>=12'} - xml-name-validator@5.0.0: - resolution: {integrity: sha512-EvGK8EJ3DhaHfbRlETOWAS5pO9MZITeauHKJyb8wyajUfQUenkIg2MvLDTZ4T/TgIcm3HU0TFBgWWboAZ30UHg==} - engines: {node: '>=18'} - - xmlchars@2.2.0: - resolution: {integrity: sha512-JZnDKK8B0RCDw84FNdDAIpZK+JuJw+s7Lz8nksI7SIuU3UXJJslUthsi+uWBUYOwPFwW7W7PRLRfUKpxjtjFCw==} + xmlbuilder@15.1.1: + resolution: {integrity: sha512-yMqGBqtXyeN1e3TGYvgNgDVZ3j84W4cwkOXQswghol6APgZWaff9lnbvN7MHYJOiXsvGPXtjTYJEiC9J2wv9Eg==} + engines: {node: '>=8.0'} xtend@4.0.2: resolution: {integrity: sha512-LKYU1iAXJXUgAXn9URjiu+MWhyUXHsvfp7mcuYm9dSUKK0/CjtrUwFAxD82/mCWbtLsGjFIad0wIsod4zrTAEQ==} @@ -7863,8 +9017,8 @@ packages: resolution: {integrity: sha512-h0uDm97wvT2bokfwwTmY6kJ1hp6YDFL0nRHwNKz8s/VD1FH/vvZjAKoMUE+un0eaYBSG7/c6h+lJTP+31tjgTw==} engines: {node: ^20.19.0 || ^22.13.0 || >=24} - yaml@2.8.2: - resolution: {integrity: sha512-mplynKqc1C2hTVYxd0PU2xQAc22TI1vShAYGksCCfxbn/dFwnHTNi1bvYsBTkhdUNtGIf5xNOg938rrSSYvS9A==} + yaml@2.8.3: + resolution: {integrity: sha512-AvbaCLOO2Otw/lW5bmh9d/WEdcDFdQp2Z2ZUH3pX9U2ihyUY0nvLv7J6TrWowklRGPYbB/IuIMfYgxaCPg5Bpg==} engines: {node: '>= 14.6'} hasBin: true @@ -7880,9 +9034,16 @@ packages: resolution: {integrity: sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q==} engines: {node: '>=10'} + yocto-queue@1.2.2: + resolution: {integrity: sha512-4LCcse/U2MHZ63HAJVE+v71o7yOdIe4cZ70Wpf8D/IyjDKYQLV5GD46B+hSTjJsvV5PztjvHoU580EftxjDZFQ==} + engines: {node: '>=12.20'} + yoga-layout@3.2.1: resolution: {integrity: sha512-0LPOt3AxKqMdFBZA3HBAt/t/8vIKq7VaQYbuA8WxCgung+p9TVyKRYdpvCb80HcdTN2NkbIKbhNwKUfm3tQywQ==} + yup@1.7.1: + resolution: {integrity: sha512-GKHFX2nXul2/4Dtfxhozv701jLQHdf6J34YDh2cEkpqoo8le5Mg6/LrdseVLrFarmFygZTlfIhHx/QKfb/QWXw==} + zen-observable@0.10.0: resolution: {integrity: sha512-iI3lT0iojZhKwT5DaFy2Ce42n3yFcLdFyOh01G7H0flMY60P8MJuVFEoJoNwXlmAyQ45GrjL6AcZmmlv8A5rbw==} @@ -7948,26 +9109,27 @@ snapshots: '@alloc/quick-lru@5.2.0': {} - '@amplitude/analytics-browser@2.36.7': + '@amplitude/analytics-browser@2.38.0': dependencies: - '@amplitude/analytics-core': 2.41.7 - '@amplitude/plugin-autocapture-browser': 1.23.7 - '@amplitude/plugin-network-capture-browser': 1.9.7 - '@amplitude/plugin-page-url-enrichment-browser': 0.6.11 - '@amplitude/plugin-page-view-tracking-browser': 2.8.7 - '@amplitude/plugin-web-vitals-browser': 1.1.22 + '@amplitude/analytics-core': 2.44.0 + '@amplitude/plugin-autocapture-browser': 1.25.0 + '@amplitude/plugin-custom-enrichment-browser': 0.1.2 + '@amplitude/plugin-network-capture-browser': 1.9.11 + '@amplitude/plugin-page-url-enrichment-browser': 0.7.3 + '@amplitude/plugin-page-view-tracking-browser': 2.9.4 + '@amplitude/plugin-web-vitals-browser': 1.1.26 tslib: 2.8.1 - '@amplitude/analytics-client-common@2.4.37': + '@amplitude/analytics-client-common@2.4.41': dependencies: '@amplitude/analytics-connector': 1.6.4 - '@amplitude/analytics-core': 2.41.7 + '@amplitude/analytics-core': 2.44.0 '@amplitude/analytics-types': 2.11.1 tslib: 2.8.1 '@amplitude/analytics-connector@1.6.4': {} - '@amplitude/analytics-core@2.41.7': + '@amplitude/analytics-core@2.44.0': dependencies: '@amplitude/analytics-connector': 1.6.4 '@types/zen-observable': 0.8.3 @@ -7981,94 +9143,103 @@ snapshots: dependencies: js-base64: 3.7.8 - '@amplitude/plugin-autocapture-browser@1.23.7': + '@amplitude/plugin-autocapture-browser@1.25.0': dependencies: - '@amplitude/analytics-core': 2.41.7 + '@amplitude/analytics-core': 2.44.0 tslib: 2.8.1 - '@amplitude/plugin-network-capture-browser@1.9.7': + '@amplitude/plugin-custom-enrichment-browser@0.1.2': dependencies: - '@amplitude/analytics-core': 2.41.7 + '@amplitude/analytics-core': 2.44.0 tslib: 2.8.1 - '@amplitude/plugin-page-url-enrichment-browser@0.6.11': + '@amplitude/plugin-network-capture-browser@1.9.11': dependencies: - '@amplitude/analytics-core': 2.41.7 + '@amplitude/analytics-core': 2.44.0 tslib: 2.8.1 - '@amplitude/plugin-page-view-tracking-browser@2.8.7': + '@amplitude/plugin-page-url-enrichment-browser@0.7.3': dependencies: - '@amplitude/analytics-core': 2.41.7 + '@amplitude/analytics-core': 2.44.0 tslib: 2.8.1 - '@amplitude/plugin-session-replay-browser@1.26.4(@amplitude/rrweb@2.0.0-alpha.35)(rollup@4.59.0)': + '@amplitude/plugin-page-view-tracking-browser@2.9.4': dependencies: - '@amplitude/analytics-client-common': 2.4.37 - '@amplitude/analytics-core': 2.41.7 + '@amplitude/analytics-core': 2.44.0 + tslib: 2.8.1 + + '@amplitude/plugin-session-replay-browser@1.27.5(@amplitude/rrweb@2.0.0-alpha.37)(rollup@4.59.0)': + dependencies: + '@amplitude/analytics-client-common': 2.4.41 + '@amplitude/analytics-core': 2.44.0 '@amplitude/analytics-types': 2.11.1 - '@amplitude/rrweb-plugin-console-record': 2.0.0-alpha.35(@amplitude/rrweb@2.0.0-alpha.35) - '@amplitude/rrweb-record': 2.0.0-alpha.35 - '@amplitude/session-replay-browser': 1.33.1(@amplitude/rrweb@2.0.0-alpha.35)(rollup@4.59.0) + '@amplitude/rrweb-plugin-console-record': 2.0.0-alpha.36(@amplitude/rrweb@2.0.0-alpha.37) + '@amplitude/rrweb-record': 2.0.0-alpha.36 + '@amplitude/session-replay-browser': 1.35.0(@amplitude/rrweb@2.0.0-alpha.37)(rollup@4.59.0) idb-keyval: 6.2.2 tslib: 2.8.1 transitivePeerDependencies: - '@amplitude/rrweb' - rollup - '@amplitude/plugin-web-vitals-browser@1.1.22': + '@amplitude/plugin-web-vitals-browser@1.1.26': dependencies: - '@amplitude/analytics-core': 2.41.7 + '@amplitude/analytics-core': 2.44.0 tslib: 2.8.1 web-vitals: 5.1.0 - '@amplitude/rrdom@2.0.0-alpha.35': + '@amplitude/rrdom@2.0.0-alpha.37': dependencies: - '@amplitude/rrweb-snapshot': 2.0.0-alpha.35 + '@amplitude/rrweb-snapshot': 2.0.0-alpha.37 - '@amplitude/rrweb-packer@2.0.0-alpha.35': + '@amplitude/rrweb-packer@2.0.0-alpha.36': dependencies: - '@amplitude/rrweb-types': 2.0.0-alpha.35 + '@amplitude/rrweb-types': 2.0.0-alpha.37 fflate: 0.4.8 - '@amplitude/rrweb-plugin-console-record@2.0.0-alpha.35(@amplitude/rrweb@2.0.0-alpha.35)': + '@amplitude/rrweb-plugin-console-record@2.0.0-alpha.36(@amplitude/rrweb@2.0.0-alpha.37)': dependencies: - '@amplitude/rrweb': 2.0.0-alpha.35 + '@amplitude/rrweb': 2.0.0-alpha.37 - '@amplitude/rrweb-record@2.0.0-alpha.35': + '@amplitude/rrweb-record@2.0.0-alpha.36': dependencies: - '@amplitude/rrweb': 2.0.0-alpha.35 - '@amplitude/rrweb-types': 2.0.0-alpha.35 + '@amplitude/rrweb': 2.0.0-alpha.37 + '@amplitude/rrweb-types': 2.0.0-alpha.37 - '@amplitude/rrweb-snapshot@2.0.0-alpha.35': + '@amplitude/rrweb-snapshot@2.0.0-alpha.37': dependencies: postcss: 8.5.8 - '@amplitude/rrweb-types@2.0.0-alpha.35': {} + '@amplitude/rrweb-types@2.0.0-alpha.36': {} - '@amplitude/rrweb-utils@2.0.0-alpha.35': {} + '@amplitude/rrweb-types@2.0.0-alpha.37': {} - '@amplitude/rrweb@2.0.0-alpha.35': + '@amplitude/rrweb-utils@2.0.0-alpha.36': {} + + '@amplitude/rrweb-utils@2.0.0-alpha.37': {} + + '@amplitude/rrweb@2.0.0-alpha.37': dependencies: - '@amplitude/rrdom': 2.0.0-alpha.35 - '@amplitude/rrweb-snapshot': 2.0.0-alpha.35 - '@amplitude/rrweb-types': 2.0.0-alpha.35 - '@amplitude/rrweb-utils': 2.0.0-alpha.35 + '@amplitude/rrdom': 2.0.0-alpha.37 + '@amplitude/rrweb-snapshot': 2.0.0-alpha.37 + '@amplitude/rrweb-types': 2.0.0-alpha.37 + '@amplitude/rrweb-utils': 2.0.0-alpha.37 '@types/css-font-loading-module': 0.0.7 '@xstate/fsm': 1.6.5 base64-arraybuffer: 1.0.2 mitt: 3.0.1 - '@amplitude/session-replay-browser@1.33.1(@amplitude/rrweb@2.0.0-alpha.35)(rollup@4.59.0)': + '@amplitude/session-replay-browser@1.35.0(@amplitude/rrweb@2.0.0-alpha.37)(rollup@4.59.0)': dependencies: - '@amplitude/analytics-client-common': 2.4.37 - '@amplitude/analytics-core': 2.41.7 + '@amplitude/analytics-client-common': 2.4.41 + '@amplitude/analytics-core': 2.44.0 '@amplitude/analytics-types': 2.11.1 '@amplitude/experiment-core': 0.7.2 - '@amplitude/rrweb-packer': 2.0.0-alpha.35 - '@amplitude/rrweb-plugin-console-record': 2.0.0-alpha.35(@amplitude/rrweb@2.0.0-alpha.35) - '@amplitude/rrweb-record': 2.0.0-alpha.35 - '@amplitude/rrweb-types': 2.0.0-alpha.35 - '@amplitude/rrweb-utils': 2.0.0-alpha.35 + '@amplitude/rrweb-packer': 2.0.0-alpha.36 + '@amplitude/rrweb-plugin-console-record': 2.0.0-alpha.36(@amplitude/rrweb@2.0.0-alpha.37) + '@amplitude/rrweb-record': 2.0.0-alpha.36 + '@amplitude/rrweb-types': 2.0.0-alpha.36 + '@amplitude/rrweb-utils': 2.0.0-alpha.36 '@amplitude/targeting': 0.2.0 '@rollup/plugin-replace': 6.0.3(rollup@4.59.0) idb: 8.0.0 @@ -8079,57 +9250,57 @@ snapshots: '@amplitude/targeting@0.2.0': dependencies: - '@amplitude/analytics-client-common': 2.4.37 - '@amplitude/analytics-core': 2.41.7 + '@amplitude/analytics-client-common': 2.4.41 + '@amplitude/analytics-core': 2.44.0 '@amplitude/analytics-types': 2.11.1 '@amplitude/experiment-core': 0.7.2 idb: 8.0.0 tslib: 2.8.1 - '@antfu/eslint-config@7.7.3(@eslint-react/eslint-plugin@2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3))(@next/eslint-plugin-next@16.2.0)(@typescript-eslint/rule-tester@8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3))(@typescript-eslint/typescript-estree@8.57.1(typescript@5.9.3))(@typescript-eslint/utils@8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3))(@voidzero-dev/vite-plus-test@0.1.12(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.0(canvas@3.2.1))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(@vue/compiler-sfc@3.5.30)(eslint-plugin-react-hooks@7.0.1(eslint@10.0.3(jiti@1.21.7)))(eslint-plugin-react-refresh@0.5.2(eslint@10.0.3(jiti@1.21.7)))(eslint@10.0.3(jiti@1.21.7))(oxlint@1.55.0(oxlint-tsgolint@0.17.0))(typescript@5.9.3)': + '@antfu/eslint-config@7.7.3(@eslint-react/eslint-plugin@3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(@next/eslint-plugin-next@16.2.1)(@typescript-eslint/rule-tester@8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(@typescript-eslint/typescript-estree@8.57.2(typescript@5.9.3))(@typescript-eslint/utils@8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(@voidzero-dev/vite-plus-test@0.1.14(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(@vue/compiler-sfc@3.5.31)(eslint-plugin-react-hooks@7.0.1(eslint@10.1.0(jiti@1.21.7)))(eslint-plugin-react-refresh@0.5.2(eslint@10.1.0(jiti@1.21.7)))(eslint@10.1.0(jiti@1.21.7))(oxlint@1.57.0(oxlint-tsgolint@0.17.3))(typescript@5.9.3)': dependencies: '@antfu/install-pkg': 1.1.0 '@clack/prompts': 1.1.0 - '@e18e/eslint-plugin': 0.2.0(eslint@10.0.3(jiti@1.21.7))(oxlint@1.55.0(oxlint-tsgolint@0.17.0)) - '@eslint-community/eslint-plugin-eslint-comments': 4.7.1(eslint@10.0.3(jiti@1.21.7)) + '@e18e/eslint-plugin': 0.2.0(eslint@10.1.0(jiti@1.21.7))(oxlint@1.57.0(oxlint-tsgolint@0.17.3)) + '@eslint-community/eslint-plugin-eslint-comments': 4.7.1(eslint@10.1.0(jiti@1.21.7)) '@eslint/markdown': 7.5.1 - '@stylistic/eslint-plugin': 5.10.0(eslint@10.0.3(jiti@1.21.7)) - '@typescript-eslint/eslint-plugin': 8.57.1(@typescript-eslint/parser@8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3))(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@typescript-eslint/parser': 8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@vitest/eslint-plugin': 1.6.12(@voidzero-dev/vite-plus-test@0.1.12(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.0(canvas@3.2.1))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) + '@stylistic/eslint-plugin': 5.10.0(eslint@10.1.0(jiti@1.21.7)) + '@typescript-eslint/eslint-plugin': 8.57.2(@typescript-eslint/parser@8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/parser': 8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@vitest/eslint-plugin': 1.6.13(@typescript-eslint/eslint-plugin@8.57.2(@typescript-eslint/parser@8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(@voidzero-dev/vite-plus-test@0.1.14(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) ansis: 4.2.0 cac: 7.0.0 - eslint: 10.0.3(jiti@1.21.7) - eslint-config-flat-gitignore: 2.2.1(eslint@10.0.3(jiti@1.21.7)) + eslint: 10.1.0(jiti@1.21.7) + eslint-config-flat-gitignore: 2.3.0(eslint@10.1.0(jiti@1.21.7)) eslint-flat-config-utils: 3.0.2 - eslint-merge-processors: 2.0.0(eslint@10.0.3(jiti@1.21.7)) - eslint-plugin-antfu: 3.2.2(eslint@10.0.3(jiti@1.21.7)) - eslint-plugin-command: 3.5.2(@typescript-eslint/rule-tester@8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3))(@typescript-eslint/typescript-estree@8.57.1(typescript@5.9.3))(@typescript-eslint/utils@8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3))(eslint@10.0.3(jiti@1.21.7)) - eslint-plugin-import-lite: 0.5.2(eslint@10.0.3(jiti@1.21.7)) - eslint-plugin-jsdoc: 62.8.0(eslint@10.0.3(jiti@1.21.7)) - eslint-plugin-jsonc: 3.1.2(eslint@10.0.3(jiti@1.21.7)) - eslint-plugin-n: 17.24.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) + eslint-merge-processors: 2.0.0(eslint@10.1.0(jiti@1.21.7)) + eslint-plugin-antfu: 3.2.2(eslint@10.1.0(jiti@1.21.7)) + eslint-plugin-command: 3.5.2(@typescript-eslint/rule-tester@8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(@typescript-eslint/typescript-estree@8.57.2(typescript@5.9.3))(@typescript-eslint/utils@8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(eslint@10.1.0(jiti@1.21.7)) + eslint-plugin-import-lite: 0.5.2(eslint@10.1.0(jiti@1.21.7)) + eslint-plugin-jsdoc: 62.8.1(eslint@10.1.0(jiti@1.21.7)) + eslint-plugin-jsonc: 3.1.2(eslint@10.1.0(jiti@1.21.7)) + eslint-plugin-n: 17.24.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) eslint-plugin-no-only-tests: 3.3.0 - eslint-plugin-perfectionist: 5.7.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - eslint-plugin-pnpm: 1.6.0(eslint@10.0.3(jiti@1.21.7)) - eslint-plugin-regexp: 3.1.0(eslint@10.0.3(jiti@1.21.7)) - eslint-plugin-toml: 1.3.1(eslint@10.0.3(jiti@1.21.7)) - eslint-plugin-unicorn: 63.0.0(eslint@10.0.3(jiti@1.21.7)) - eslint-plugin-unused-imports: 4.4.1(@typescript-eslint/eslint-plugin@8.57.1(@typescript-eslint/parser@8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3))(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3))(eslint@10.0.3(jiti@1.21.7)) - eslint-plugin-vue: 10.8.0(@stylistic/eslint-plugin@5.10.0(eslint@10.0.3(jiti@1.21.7)))(@typescript-eslint/parser@8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3))(eslint@10.0.3(jiti@1.21.7))(vue-eslint-parser@10.4.0(eslint@10.0.3(jiti@1.21.7))) - eslint-plugin-yml: 3.3.1(eslint@10.0.3(jiti@1.21.7)) - eslint-processor-vue-blocks: 2.0.0(@vue/compiler-sfc@3.5.30)(eslint@10.0.3(jiti@1.21.7)) + eslint-plugin-perfectionist: 5.7.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + eslint-plugin-pnpm: 1.6.0(eslint@10.1.0(jiti@1.21.7)) + eslint-plugin-regexp: 3.1.0(eslint@10.1.0(jiti@1.21.7)) + eslint-plugin-toml: 1.3.1(eslint@10.1.0(jiti@1.21.7)) + eslint-plugin-unicorn: 63.0.0(eslint@10.1.0(jiti@1.21.7)) + eslint-plugin-unused-imports: 4.4.1(@typescript-eslint/eslint-plugin@8.57.2(@typescript-eslint/parser@8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(eslint@10.1.0(jiti@1.21.7)) + eslint-plugin-vue: 10.8.0(@stylistic/eslint-plugin@5.10.0(eslint@10.1.0(jiti@1.21.7)))(@typescript-eslint/parser@8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(eslint@10.1.0(jiti@1.21.7))(vue-eslint-parser@10.4.0(eslint@10.1.0(jiti@1.21.7))) + eslint-plugin-yml: 3.3.1(eslint@10.1.0(jiti@1.21.7)) + eslint-processor-vue-blocks: 2.0.0(@vue/compiler-sfc@3.5.31)(eslint@10.1.0(jiti@1.21.7)) globals: 17.4.0 local-pkg: 1.1.2 parse-gitignore: 2.0.0 toml-eslint-parser: 1.0.3 - vue-eslint-parser: 10.4.0(eslint@10.0.3(jiti@1.21.7)) + vue-eslint-parser: 10.4.0(eslint@10.1.0(jiti@1.21.7)) yaml-eslint-parser: 2.0.0 optionalDependencies: - '@eslint-react/eslint-plugin': 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@next/eslint-plugin-next': 16.2.0 - eslint-plugin-react-hooks: 7.0.1(eslint@10.0.3(jiti@1.21.7)) - eslint-plugin-react-refresh: 0.5.2(eslint@10.0.3(jiti@1.21.7)) + '@eslint-react/eslint-plugin': 3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@next/eslint-plugin-next': 16.2.1 + eslint-plugin-react-hooks: 7.0.1(eslint@10.1.0(jiti@1.21.7)) + eslint-plugin-react-refresh: 0.5.2(eslint@10.1.0(jiti@1.21.7)) transitivePeerDependencies: - '@eslint/json' - '@typescript-eslint/rule-tester' @@ -8156,24 +9327,6 @@ snapshots: '@antfu/utils@8.1.1': {} - '@asamuzakjp/css-color@5.0.1': - dependencies: - '@csstools/css-calc': 3.1.1(@csstools/css-parser-algorithms@4.0.0(@csstools/css-tokenizer@4.0.0))(@csstools/css-tokenizer@4.0.0) - '@csstools/css-color-parser': 4.0.2(@csstools/css-parser-algorithms@4.0.0(@csstools/css-tokenizer@4.0.0))(@csstools/css-tokenizer@4.0.0) - '@csstools/css-parser-algorithms': 4.0.0(@csstools/css-tokenizer@4.0.0) - '@csstools/css-tokenizer': 4.0.0 - lru-cache: 11.2.7 - - '@asamuzakjp/dom-selector@7.0.3': - dependencies: - '@asamuzakjp/nwsapi': 2.3.9 - bidi-js: 1.0.3 - css-tree: 3.2.1 - is-potential-custom-element-name: 1.0.1 - lru-cache: 11.2.7 - - '@asamuzakjp/nwsapi@2.3.9': {} - '@babel/code-frame@7.29.0': dependencies: '@babel/helper-validator-identifier': 7.28.5 @@ -8195,7 +9348,7 @@ snapshots: '@babel/types': 7.29.0 '@jridgewell/remapping': 2.3.5 convert-source-map: 2.0.0 - debug: 4.4.3 + debug: 4.4.3(supports-color@8.1.1) gensync: 1.0.0-beta.2 json5: 2.2.3 semver: 6.3.1 @@ -8267,7 +9420,7 @@ snapshots: '@babel/parser': 7.29.2 '@babel/template': 7.28.6 '@babel/types': 7.29.0 - debug: 4.4.3 + debug: 4.4.3(supports-color@8.1.1) transitivePeerDependencies: - supports-color @@ -8304,10 +9457,6 @@ snapshots: '@braintree/sanitize-url@7.1.2': {} - '@bramus/specificity@2.4.2': - dependencies: - css-tree: 3.2.1 - '@chevrotain/cst-dts-gen@11.1.2': dependencies: '@chevrotain/gast': 11.1.2 @@ -8325,13 +9474,13 @@ snapshots: '@chevrotain/utils@11.1.2': {} - '@chromatic-com/storybook@5.0.1(storybook@10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))': + '@chromatic-com/storybook@5.1.1(storybook@10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))': dependencies: '@neoconfetti/react': 1.0.0 chromatic: 13.3.5 filesize: 10.1.6 jsonfile: 6.2.0 - storybook: 10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) + storybook: 10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) strip-ansi: 7.2.0 transitivePeerDependencies: - '@chromatic-com/cypress' @@ -8357,9 +9506,9 @@ snapshots: '@clack/core': 1.1.0 sisteransi: 1.0.5 - '@code-inspector/core@1.4.4': + '@code-inspector/core@1.4.5': dependencies: - '@vue/compiler-dom': 3.5.30 + '@vue/compiler-dom': 3.5.31 chalk: 4.1.2 dotenv: 16.6.1 launch-ide: 1.4.3 @@ -8367,81 +9516,167 @@ snapshots: transitivePeerDependencies: - supports-color - '@code-inspector/esbuild@1.4.4': + '@code-inspector/esbuild@1.4.5': dependencies: - '@code-inspector/core': 1.4.4 + '@code-inspector/core': 1.4.5 transitivePeerDependencies: - supports-color - '@code-inspector/mako@1.4.4': + '@code-inspector/mako@1.4.5': dependencies: - '@code-inspector/core': 1.4.4 + '@code-inspector/core': 1.4.5 transitivePeerDependencies: - supports-color - '@code-inspector/turbopack@1.4.4': + '@code-inspector/turbopack@1.4.5': dependencies: - '@code-inspector/core': 1.4.4 - '@code-inspector/webpack': 1.4.4 + '@code-inspector/core': 1.4.5 + '@code-inspector/webpack': 1.4.5 transitivePeerDependencies: - supports-color - '@code-inspector/vite@1.4.4': + '@code-inspector/vite@1.4.5': dependencies: - '@code-inspector/core': 1.4.4 + '@code-inspector/core': 1.4.5 chalk: 4.1.1 transitivePeerDependencies: - supports-color - '@code-inspector/webpack@1.4.4': + '@code-inspector/webpack@1.4.5': dependencies: - '@code-inspector/core': 1.4.4 + '@code-inspector/core': 1.4.5 transitivePeerDependencies: - supports-color - '@csstools/color-helpers@6.0.2': {} + '@colors/colors@1.5.0': + optional: true - '@csstools/css-calc@3.1.1(@csstools/css-parser-algorithms@4.0.0(@csstools/css-tokenizer@4.0.0))(@csstools/css-tokenizer@4.0.0)': + '@cucumber/ci-environment@13.0.0': {} + + '@cucumber/cucumber-expressions@19.0.0': dependencies: - '@csstools/css-parser-algorithms': 4.0.0(@csstools/css-tokenizer@4.0.0) - '@csstools/css-tokenizer': 4.0.0 + regexp-match-indices: 1.0.2 - '@csstools/css-color-parser@4.0.2(@csstools/css-parser-algorithms@4.0.0(@csstools/css-tokenizer@4.0.0))(@csstools/css-tokenizer@4.0.0)': + '@cucumber/cucumber@12.7.0': dependencies: - '@csstools/color-helpers': 6.0.2 - '@csstools/css-calc': 3.1.1(@csstools/css-parser-algorithms@4.0.0(@csstools/css-tokenizer@4.0.0))(@csstools/css-tokenizer@4.0.0) - '@csstools/css-parser-algorithms': 4.0.0(@csstools/css-tokenizer@4.0.0) - '@csstools/css-tokenizer': 4.0.0 + '@cucumber/ci-environment': 13.0.0 + '@cucumber/cucumber-expressions': 19.0.0 + '@cucumber/gherkin': 38.0.0 + '@cucumber/gherkin-streams': 6.0.0(@cucumber/gherkin@38.0.0)(@cucumber/message-streams@4.0.1(@cucumber/messages@32.0.1))(@cucumber/messages@32.0.1) + '@cucumber/gherkin-utils': 11.0.0 + '@cucumber/html-formatter': 23.0.0(@cucumber/messages@32.0.1) + '@cucumber/junit-xml-formatter': 0.9.0(@cucumber/messages@32.0.1) + '@cucumber/message-streams': 4.0.1(@cucumber/messages@32.0.1) + '@cucumber/messages': 32.0.1 + '@cucumber/pretty-formatter': 1.0.1(@cucumber/cucumber@12.7.0)(@cucumber/messages@32.0.1) + '@cucumber/tag-expressions': 9.1.0 + assertion-error-formatter: 3.0.0 + capital-case: 1.0.4 + chalk: 4.1.2 + cli-table3: 0.6.5 + commander: 14.0.3 + debug: 4.4.3(supports-color@8.1.1) + error-stack-parser: 2.1.4 + figures: 3.2.0 + glob: 13.0.6 + has-ansi: 4.0.1 + indent-string: 4.0.0 + is-installed-globally: 0.4.0 + is-stream: 2.0.1 + knuth-shuffle-seeded: 1.0.6 + lodash.merge: 4.6.2 + lodash.mergewith: 4.6.2 + luxon: 3.7.2 + mime: 3.0.0 + mkdirp: 3.0.1 + mz: 2.7.0 + progress: 2.0.3 + read-package-up: 12.0.0 + semver: 7.7.4 + string-argv: 0.3.1 + supports-color: 8.1.1 + type-fest: 4.41.0 + util-arity: 1.1.0 + yaml: 2.8.3 + yup: 1.7.1 - '@csstools/css-parser-algorithms@4.0.0(@csstools/css-tokenizer@4.0.0)': + '@cucumber/gherkin-streams@6.0.0(@cucumber/gherkin@38.0.0)(@cucumber/message-streams@4.0.1(@cucumber/messages@32.0.1))(@cucumber/messages@32.0.1)': dependencies: - '@csstools/css-tokenizer': 4.0.0 + '@cucumber/gherkin': 38.0.0 + '@cucumber/message-streams': 4.0.1(@cucumber/messages@32.0.1) + '@cucumber/messages': 32.0.1 + commander: 14.0.0 + source-map-support: 0.5.21 - '@csstools/css-syntax-patches-for-csstree@1.1.1(css-tree@3.2.1)': + '@cucumber/gherkin-utils@11.0.0': + dependencies: + '@cucumber/gherkin': 38.0.0 + '@cucumber/messages': 32.0.1 + '@teppeis/multimaps': 3.0.0 + commander: 14.0.2 + source-map-support: 0.5.21 + + '@cucumber/gherkin@38.0.0': + dependencies: + '@cucumber/messages': 32.0.1 + + '@cucumber/html-formatter@23.0.0(@cucumber/messages@32.0.1)': + dependencies: + '@cucumber/messages': 32.0.1 + + '@cucumber/junit-xml-formatter@0.9.0(@cucumber/messages@32.0.1)': + dependencies: + '@cucumber/messages': 32.0.1 + '@cucumber/query': 14.7.0(@cucumber/messages@32.0.1) + '@teppeis/multimaps': 3.0.0 + luxon: 3.7.2 + xmlbuilder: 15.1.1 + + '@cucumber/message-streams@4.0.1(@cucumber/messages@32.0.1)': + dependencies: + '@cucumber/messages': 32.0.1 + + '@cucumber/messages@32.0.1': + dependencies: + class-transformer: 0.5.1 + reflect-metadata: 0.2.2 + + '@cucumber/pretty-formatter@1.0.1(@cucumber/cucumber@12.7.0)(@cucumber/messages@32.0.1)': + dependencies: + '@cucumber/cucumber': 12.7.0 + '@cucumber/messages': 32.0.1 + ansi-styles: 5.2.0 + cli-table3: 0.6.5 + figures: 3.2.0 + ts-dedent: 2.2.0 + + '@cucumber/query@14.7.0(@cucumber/messages@32.0.1)': + dependencies: + '@cucumber/messages': 32.0.1 + '@teppeis/multimaps': 3.0.0 + lodash.sortby: 4.7.0 + + '@cucumber/tag-expressions@9.1.0': {} + + '@e18e/eslint-plugin@0.2.0(eslint@10.1.0(jiti@1.21.7))(oxlint@1.57.0(oxlint-tsgolint@0.17.3))': + dependencies: + eslint-plugin-depend: 1.5.0(eslint@10.1.0(jiti@1.21.7)) optionalDependencies: - css-tree: 3.2.1 + eslint: 10.1.0(jiti@1.21.7) + oxlint: 1.57.0(oxlint-tsgolint@0.17.3) - '@csstools/css-tokenizer@4.0.0': {} - - '@e18e/eslint-plugin@0.2.0(eslint@10.0.3(jiti@1.21.7))(oxlint@1.55.0(oxlint-tsgolint@0.17.0))': - dependencies: - eslint-plugin-depend: 1.5.0(eslint@10.0.3(jiti@1.21.7)) - optionalDependencies: - eslint: 10.0.3(jiti@1.21.7) - oxlint: 1.55.0(oxlint-tsgolint@0.17.0) - - '@egoist/tailwindcss-icons@1.9.2(tailwindcss@3.4.19(tsx@4.21.0)(yaml@2.8.2))': + '@egoist/tailwindcss-icons@1.9.2(tailwindcss@3.4.19(tsx@4.21.0)(yaml@2.8.3))': dependencies: '@iconify/utils': 3.1.0 - tailwindcss: 3.4.19(tsx@4.21.0)(yaml@2.8.2) + tailwindcss: 3.4.19(tsx@4.21.0)(yaml@2.8.3) - '@emnapi/core@1.9.0': + '@emnapi/core@1.9.1': dependencies: '@emnapi/wasi-threads': 1.2.0 tslib: 2.8.1 optional: true - '@emnapi/runtime@1.9.0': + '@emnapi/runtime@1.9.1': dependencies: tslib: 2.8.1 optional: true @@ -8456,7 +9691,7 @@ snapshots: '@es-joy/jsdoccomment@0.84.0': dependencies: '@types/estree': 1.0.8 - '@typescript-eslint/types': 8.57.1 + '@typescript-eslint/types': 8.57.2 comment-parser: 1.4.5 esquery: 1.7.0 jsdoc-type-pratt-parser: 7.1.1 @@ -8541,15 +9776,20 @@ snapshots: '@esbuild/win32-x64@0.27.2': optional: true - '@eslint-community/eslint-plugin-eslint-comments@4.7.1(eslint@10.0.3(jiti@1.21.7))': + '@eslint-community/eslint-plugin-eslint-comments@4.7.1(eslint@10.1.0(jiti@1.21.7))': dependencies: escape-string-regexp: 4.0.0 - eslint: 10.0.3(jiti@1.21.7) + eslint: 10.1.0(jiti@1.21.7) ignore: 7.0.5 - '@eslint-community/eslint-utils@4.9.1(eslint@10.0.3(jiti@1.21.7))': + '@eslint-community/eslint-utils@4.9.1(eslint@10.1.0(jiti@1.21.7))': dependencies: - eslint: 10.0.3(jiti@1.21.7) + eslint: 10.1.0(jiti@1.21.7) + eslint-visitor-keys: 3.4.3 + + '@eslint-community/eslint-utils@4.9.1(eslint@10.1.0(jiti@2.6.1))': + dependencies: + eslint: 10.1.0(jiti@2.6.1) eslint-visitor-keys: 3.4.3 '@eslint-community/eslint-utils@4.9.1(eslint@9.27.0(jiti@1.21.7))': @@ -8559,90 +9799,82 @@ snapshots: '@eslint-community/regexpp@4.12.2': {} - '@eslint-react/ast@2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3)': + '@eslint-react/ast@3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3)': dependencies: - '@eslint-react/eff': 2.13.0 - '@typescript-eslint/types': 8.57.1 - '@typescript-eslint/typescript-estree': 8.57.1(typescript@5.9.3) - '@typescript-eslint/utils': 8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - eslint: 10.0.3(jiti@1.21.7) + '@typescript-eslint/types': 8.57.2 + '@typescript-eslint/typescript-estree': 8.57.2(typescript@5.9.3) + '@typescript-eslint/utils': 8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + eslint: 10.1.0(jiti@1.21.7) string-ts: 2.3.1 typescript: 5.9.3 transitivePeerDependencies: - supports-color - '@eslint-react/core@2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3)': + '@eslint-react/core@3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3)': dependencies: - '@eslint-react/ast': 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/eff': 2.13.0 - '@eslint-react/shared': 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/var': 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@typescript-eslint/scope-manager': 8.57.1 - '@typescript-eslint/types': 8.57.1 - '@typescript-eslint/utils': 8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - eslint: 10.0.3(jiti@1.21.7) + '@eslint-react/ast': 3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/shared': 3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/var': 3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/scope-manager': 8.57.2 + '@typescript-eslint/types': 8.57.2 + '@typescript-eslint/utils': 8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + eslint: 10.1.0(jiti@1.21.7) ts-pattern: 5.9.0 typescript: 5.9.3 transitivePeerDependencies: - supports-color - '@eslint-react/eff@2.13.0': {} - - '@eslint-react/eslint-plugin@2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3)': + '@eslint-react/eslint-plugin@3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3)': dependencies: - '@eslint-react/eff': 2.13.0 - '@eslint-react/shared': 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@typescript-eslint/scope-manager': 8.57.1 - '@typescript-eslint/type-utils': 8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@typescript-eslint/types': 8.57.1 - '@typescript-eslint/utils': 8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - eslint: 10.0.3(jiti@1.21.7) - eslint-plugin-react-dom: 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - eslint-plugin-react-hooks-extra: 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - eslint-plugin-react-naming-convention: 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - eslint-plugin-react-rsc: 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - eslint-plugin-react-web-api: 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - eslint-plugin-react-x: 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - ts-api-utils: 2.4.0(typescript@5.9.3) + '@eslint-react/shared': 3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/scope-manager': 8.57.2 + '@typescript-eslint/type-utils': 8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/types': 8.57.2 + '@typescript-eslint/utils': 8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + eslint: 10.1.0(jiti@1.21.7) + eslint-plugin-react-dom: 3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + eslint-plugin-react-naming-convention: 3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + eslint-plugin-react-rsc: 3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + eslint-plugin-react-web-api: 3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + eslint-plugin-react-x: 3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + ts-api-utils: 2.5.0(typescript@5.9.3) typescript: 5.9.3 transitivePeerDependencies: - supports-color - '@eslint-react/shared@2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3)': + '@eslint-react/shared@3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3)': dependencies: - '@eslint-react/eff': 2.13.0 - '@typescript-eslint/utils': 8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - eslint: 10.0.3(jiti@1.21.7) + '@typescript-eslint/utils': 8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + eslint: 10.1.0(jiti@1.21.7) ts-pattern: 5.9.0 typescript: 5.9.3 zod: 4.3.6 transitivePeerDependencies: - supports-color - '@eslint-react/var@2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3)': + '@eslint-react/var@3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3)': dependencies: - '@eslint-react/ast': 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/eff': 2.13.0 - '@eslint-react/shared': 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@typescript-eslint/scope-manager': 8.57.1 - '@typescript-eslint/types': 8.57.1 - '@typescript-eslint/utils': 8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - eslint: 10.0.3(jiti@1.21.7) + '@eslint-react/ast': 3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/shared': 3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/scope-manager': 8.57.2 + '@typescript-eslint/types': 8.57.2 + '@typescript-eslint/utils': 8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + eslint: 10.1.0(jiti@1.21.7) ts-pattern: 5.9.0 typescript: 5.9.3 transitivePeerDependencies: - supports-color - '@eslint/compat@2.0.3(eslint@10.0.3(jiti@1.21.7))': + '@eslint/compat@2.0.3(eslint@10.1.0(jiti@1.21.7))': dependencies: '@eslint/core': 1.1.1 optionalDependencies: - eslint: 10.0.3(jiti@1.21.7) + eslint: 10.1.0(jiti@1.21.7) '@eslint/config-array@0.20.1': dependencies: '@eslint/object-schema': 2.1.7 - debug: 4.4.3 + debug: 4.4.3(supports-color@8.1.1) minimatch: 3.1.5 transitivePeerDependencies: - supports-color @@ -8650,7 +9882,7 @@ snapshots: '@eslint/config-array@0.23.3': dependencies: '@eslint/object-schema': 3.0.3 - debug: 4.4.3 + debug: 4.4.3(supports-color@8.1.1) minimatch: 10.2.4 transitivePeerDependencies: - supports-color @@ -8685,7 +9917,7 @@ snapshots: '@eslint/eslintrc@3.3.5': dependencies: ajv: 6.14.0 - debug: 4.4.3 + debug: 4.4.3(supports-color@8.1.1) espree: 10.4.0 globals: 14.0.0 ignore: 5.3.2 @@ -8696,6 +9928,10 @@ snapshots: transitivePeerDependencies: - supports-color + '@eslint/js@10.0.1(eslint@10.1.0(jiti@2.6.1))': + optionalDependencies: + eslint: 10.1.0(jiti@2.6.1) + '@eslint/js@9.27.0': {} '@eslint/markdown@7.5.1': @@ -8731,8 +9967,6 @@ snapshots: '@eslint/core': 1.1.1 levn: 0.4.1 - '@exodus/bytes@1.15.0': {} - '@floating-ui/core@1.7.5': dependencies: '@floating-ui/utils': 0.2.11 @@ -8788,9 +10022,9 @@ snapshots: dependencies: react: 19.2.4 - '@hono/node-server@1.19.11(hono@4.12.8)': + '@hono/node-server@1.19.11(hono@4.12.9)': dependencies: - hono: 4.12.8 + hono: 4.12.9 '@humanfs/core@0.19.1': {} @@ -8832,11 +10066,11 @@ snapshots: '@antfu/install-pkg': 1.1.0 '@antfu/utils': 8.1.1 '@iconify/types': 2.0.0 - debug: 4.4.3 + debug: 4.4.3(supports-color@8.1.1) globals: 15.15.0 kolorist: 1.8.0 local-pkg: 1.1.2 - mlly: 1.8.1 + mlly: 1.8.2 transitivePeerDependencies: - supports-color @@ -8844,7 +10078,7 @@ snapshots: dependencies: '@antfu/install-pkg': 1.1.0 '@iconify/types': 2.0.0 - mlly: 1.8.1 + mlly: 1.8.2 '@img/colour@1.1.0': {} @@ -8930,7 +10164,7 @@ snapshots: '@img/sharp-wasm32@0.34.5': dependencies: - '@emnapi/runtime': 1.9.0 + '@emnapi/runtime': 1.9.1 optional: true '@img/sharp-win32-arm64@0.34.5': @@ -8946,11 +10180,11 @@ snapshots: dependencies: minipass: 7.1.3 - '@joshwooding/vite-plugin-react-docgen-typescript@0.6.4(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(typescript@5.9.3)': + '@joshwooding/vite-plugin-react-docgen-typescript@0.6.4(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(typescript@5.9.3)': dependencies: glob: 13.0.6 react-docgen-typescript: 2.4.0(typescript@5.9.3) - vite: '@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' + vite: '@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' optionalDependencies: typescript: 5.9.3 @@ -8978,157 +10212,161 @@ snapshots: '@jridgewell/resolve-uri': 3.1.2 '@jridgewell/sourcemap-codec': 1.5.5 - '@lexical/clipboard@0.41.0': + '@lexical/clipboard@0.42.0': dependencies: - '@lexical/html': 0.41.0 - '@lexical/list': 0.41.0 - '@lexical/selection': 0.41.0 - '@lexical/utils': 0.41.0 - lexical: 0.41.0 + '@lexical/html': 0.42.0 + '@lexical/list': 0.42.0 + '@lexical/selection': 0.42.0 + '@lexical/utils': 0.42.0 + lexical: 0.42.0 - '@lexical/devtools-core@0.41.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4)': + '@lexical/code-core@0.42.0': dependencies: - '@lexical/html': 0.41.0 - '@lexical/link': 0.41.0 - '@lexical/mark': 0.41.0 - '@lexical/table': 0.41.0 - '@lexical/utils': 0.41.0 - lexical: 0.41.0 + lexical: 0.42.0 + + '@lexical/devtools-core@0.42.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4)': + dependencies: + '@lexical/html': 0.42.0 + '@lexical/link': 0.42.0 + '@lexical/mark': 0.42.0 + '@lexical/table': 0.42.0 + '@lexical/utils': 0.42.0 + lexical: 0.42.0 react: 19.2.4 react-dom: 19.2.4(react@19.2.4) - '@lexical/dragon@0.41.0': + '@lexical/dragon@0.42.0': dependencies: - '@lexical/extension': 0.41.0 - lexical: 0.41.0 + '@lexical/extension': 0.42.0 + lexical: 0.42.0 - '@lexical/extension@0.41.0': + '@lexical/extension@0.42.0': dependencies: - '@lexical/utils': 0.41.0 + '@lexical/utils': 0.42.0 '@preact/signals-core': 1.14.0 - lexical: 0.41.0 + lexical: 0.42.0 - '@lexical/hashtag@0.41.0': + '@lexical/hashtag@0.42.0': dependencies: - '@lexical/text': 0.41.0 - '@lexical/utils': 0.41.0 - lexical: 0.41.0 + '@lexical/text': 0.42.0 + '@lexical/utils': 0.42.0 + lexical: 0.42.0 - '@lexical/history@0.41.0': + '@lexical/history@0.42.0': dependencies: - '@lexical/extension': 0.41.0 - '@lexical/utils': 0.41.0 - lexical: 0.41.0 + '@lexical/extension': 0.42.0 + '@lexical/utils': 0.42.0 + lexical: 0.42.0 - '@lexical/html@0.41.0': + '@lexical/html@0.42.0': dependencies: - '@lexical/selection': 0.41.0 - '@lexical/utils': 0.41.0 - lexical: 0.41.0 + '@lexical/selection': 0.42.0 + '@lexical/utils': 0.42.0 + lexical: 0.42.0 - '@lexical/link@0.41.0': + '@lexical/link@0.42.0': dependencies: - '@lexical/extension': 0.41.0 - '@lexical/utils': 0.41.0 - lexical: 0.41.0 + '@lexical/extension': 0.42.0 + '@lexical/utils': 0.42.0 + lexical: 0.42.0 - '@lexical/list@0.41.0': + '@lexical/list@0.42.0': dependencies: - '@lexical/extension': 0.41.0 - '@lexical/selection': 0.41.0 - '@lexical/utils': 0.41.0 - lexical: 0.41.0 + '@lexical/extension': 0.42.0 + '@lexical/selection': 0.42.0 + '@lexical/utils': 0.42.0 + lexical: 0.42.0 - '@lexical/mark@0.41.0': + '@lexical/mark@0.42.0': dependencies: - '@lexical/utils': 0.41.0 - lexical: 0.41.0 + '@lexical/utils': 0.42.0 + lexical: 0.42.0 - '@lexical/markdown@0.41.0': + '@lexical/markdown@0.42.0': dependencies: - '@lexical/code': lexical-code-no-prism@0.41.0(@lexical/utils@0.41.0)(lexical@0.41.0) - '@lexical/link': 0.41.0 - '@lexical/list': 0.41.0 - '@lexical/rich-text': 0.41.0 - '@lexical/text': 0.41.0 - '@lexical/utils': 0.41.0 - lexical: 0.41.0 + '@lexical/code-core': 0.42.0 + '@lexical/link': 0.42.0 + '@lexical/list': 0.42.0 + '@lexical/rich-text': 0.42.0 + '@lexical/text': 0.42.0 + '@lexical/utils': 0.42.0 + lexical: 0.42.0 - '@lexical/offset@0.41.0': + '@lexical/offset@0.42.0': dependencies: - lexical: 0.41.0 + lexical: 0.42.0 - '@lexical/overflow@0.41.0': + '@lexical/overflow@0.42.0': dependencies: - lexical: 0.41.0 + lexical: 0.42.0 - '@lexical/plain-text@0.41.0': + '@lexical/plain-text@0.42.0': dependencies: - '@lexical/clipboard': 0.41.0 - '@lexical/dragon': 0.41.0 - '@lexical/selection': 0.41.0 - '@lexical/utils': 0.41.0 - lexical: 0.41.0 + '@lexical/clipboard': 0.42.0 + '@lexical/dragon': 0.42.0 + '@lexical/selection': 0.42.0 + '@lexical/utils': 0.42.0 + lexical: 0.42.0 - '@lexical/react@0.41.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(yjs@13.6.30)': + '@lexical/react@0.42.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(yjs@13.6.30)': dependencies: '@floating-ui/react': 0.27.19(react-dom@19.2.4(react@19.2.4))(react@19.2.4) - '@lexical/devtools-core': 0.41.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4) - '@lexical/dragon': 0.41.0 - '@lexical/extension': 0.41.0 - '@lexical/hashtag': 0.41.0 - '@lexical/history': 0.41.0 - '@lexical/link': 0.41.0 - '@lexical/list': 0.41.0 - '@lexical/mark': 0.41.0 - '@lexical/markdown': 0.41.0 - '@lexical/overflow': 0.41.0 - '@lexical/plain-text': 0.41.0 - '@lexical/rich-text': 0.41.0 - '@lexical/table': 0.41.0 - '@lexical/text': 0.41.0 - '@lexical/utils': 0.41.0 - '@lexical/yjs': 0.41.0(yjs@13.6.30) - lexical: 0.41.0 + '@lexical/devtools-core': 0.42.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4) + '@lexical/dragon': 0.42.0 + '@lexical/extension': 0.42.0 + '@lexical/hashtag': 0.42.0 + '@lexical/history': 0.42.0 + '@lexical/link': 0.42.0 + '@lexical/list': 0.42.0 + '@lexical/mark': 0.42.0 + '@lexical/markdown': 0.42.0 + '@lexical/overflow': 0.42.0 + '@lexical/plain-text': 0.42.0 + '@lexical/rich-text': 0.42.0 + '@lexical/table': 0.42.0 + '@lexical/text': 0.42.0 + '@lexical/utils': 0.42.0 + '@lexical/yjs': 0.42.0(yjs@13.6.30) + lexical: 0.42.0 react: 19.2.4 react-dom: 19.2.4(react@19.2.4) react-error-boundary: 6.1.1(react@19.2.4) transitivePeerDependencies: - yjs - '@lexical/rich-text@0.41.0': + '@lexical/rich-text@0.42.0': dependencies: - '@lexical/clipboard': 0.41.0 - '@lexical/dragon': 0.41.0 - '@lexical/selection': 0.41.0 - '@lexical/utils': 0.41.0 - lexical: 0.41.0 + '@lexical/clipboard': 0.42.0 + '@lexical/dragon': 0.42.0 + '@lexical/selection': 0.42.0 + '@lexical/utils': 0.42.0 + lexical: 0.42.0 - '@lexical/selection@0.41.0': + '@lexical/selection@0.42.0': dependencies: - lexical: 0.41.0 + lexical: 0.42.0 - '@lexical/table@0.41.0': + '@lexical/table@0.42.0': dependencies: - '@lexical/clipboard': 0.41.0 - '@lexical/extension': 0.41.0 - '@lexical/utils': 0.41.0 - lexical: 0.41.0 + '@lexical/clipboard': 0.42.0 + '@lexical/extension': 0.42.0 + '@lexical/utils': 0.42.0 + lexical: 0.42.0 - '@lexical/text@0.41.0': + '@lexical/text@0.42.0': dependencies: - lexical: 0.41.0 + lexical: 0.42.0 - '@lexical/utils@0.41.0': + '@lexical/utils@0.42.0': dependencies: - '@lexical/selection': 0.41.0 - lexical: 0.41.0 + '@lexical/selection': 0.42.0 + lexical: 0.42.0 - '@lexical/yjs@0.41.0(yjs@13.6.30)': + '@lexical/yjs@0.42.0(yjs@13.6.30)': dependencies: - '@lexical/offset': 0.41.0 - '@lexical/selection': 0.41.0 - lexical: 0.41.0 + '@lexical/offset': 0.42.0 + '@lexical/selection': 0.42.0 + lexical: 0.42.0 yjs: 13.6.30 '@mdx-js/loader@3.1.1(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3))': @@ -9201,19 +10439,10 @@ snapshots: react: 19.2.4 react-dom: 19.2.4(react@19.2.4) - '@mswjs/interceptors@0.41.3': + '@napi-rs/wasm-runtime@1.1.2(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1)': dependencies: - '@open-draft/deferred-promise': 2.2.0 - '@open-draft/logger': 0.3.0 - '@open-draft/until': 2.1.0 - is-node-process: 1.2.0 - outvariant: 1.4.3 - strict-event-emitter: 0.5.1 - - '@napi-rs/wasm-runtime@1.1.1': - dependencies: - '@emnapi/core': 1.9.0 - '@emnapi/runtime': 1.9.0 + '@emnapi/core': 1.9.1 + '@emnapi/runtime': 1.9.1 '@tybys/wasm-util': 0.10.1 optional: true @@ -9221,41 +10450,41 @@ snapshots: '@next/env@16.0.0': {} - '@next/env@16.2.0': {} + '@next/env@16.2.1': {} - '@next/eslint-plugin-next@16.2.0': + '@next/eslint-plugin-next@16.2.1': dependencies: fast-glob: 3.3.1 - '@next/mdx@16.2.0(@mdx-js/loader@3.1.1(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)))(@mdx-js/react@3.1.1(@types/react@19.2.14)(react@19.2.4))': + '@next/mdx@16.2.1(@mdx-js/loader@3.1.1(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)))(@mdx-js/react@3.1.1(@types/react@19.2.14)(react@19.2.4))': dependencies: source-map: 0.7.6 optionalDependencies: '@mdx-js/loader': 3.1.1(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) '@mdx-js/react': 3.1.1(@types/react@19.2.14)(react@19.2.4) - '@next/swc-darwin-arm64@16.2.0': + '@next/swc-darwin-arm64@16.2.1': optional: true - '@next/swc-darwin-x64@16.2.0': + '@next/swc-darwin-x64@16.2.1': optional: true - '@next/swc-linux-arm64-gnu@16.2.0': + '@next/swc-linux-arm64-gnu@16.2.1': optional: true - '@next/swc-linux-arm64-musl@16.2.0': + '@next/swc-linux-arm64-musl@16.2.1': optional: true - '@next/swc-linux-x64-gnu@16.2.0': + '@next/swc-linux-x64-gnu@16.2.1': optional: true - '@next/swc-linux-x64-musl@16.2.0': + '@next/swc-linux-x64-musl@16.2.1': optional: true - '@next/swc-win32-arm64-msvc@16.2.0': + '@next/swc-win32-arm64-msvc@16.2.1': optional: true - '@next/swc-win32-x64-msvc@16.2.0': + '@next/swc-win32-x64-msvc@16.2.1': optional: true '@nodelib/fs.scandir@2.1.5': @@ -9270,128 +10499,146 @@ snapshots: '@nodelib/fs.scandir': 2.1.5 fastq: 1.20.1 + '@nolyfill/hasown@1.0.44': {} + '@nolyfill/is-core-module@1.0.39': {} '@nolyfill/safer-buffer@1.0.44': {} '@nolyfill/side-channel@1.0.44': {} - '@octokit/auth-token@6.0.0': {} - - '@octokit/core@7.0.6': + '@orpc/client@1.13.13': dependencies: - '@octokit/auth-token': 6.0.0 - '@octokit/graphql': 9.0.3 - '@octokit/request': 10.0.8 - '@octokit/request-error': 7.1.0 - '@octokit/types': 16.0.0 - before-after-hook: 4.0.0 - universal-user-agent: 7.0.3 - - '@octokit/endpoint@11.0.3': - dependencies: - '@octokit/types': 16.0.0 - universal-user-agent: 7.0.3 - - '@octokit/graphql@9.0.3': - dependencies: - '@octokit/request': 10.0.8 - '@octokit/types': 16.0.0 - universal-user-agent: 7.0.3 - - '@octokit/openapi-types@27.0.0': {} - - '@octokit/request-error@7.1.0': - dependencies: - '@octokit/types': 16.0.0 - - '@octokit/request@10.0.8': - dependencies: - '@octokit/endpoint': 11.0.3 - '@octokit/request-error': 7.1.0 - '@octokit/types': 16.0.0 - fast-content-type-parse: 3.0.0 - json-with-bigint: 3.5.7 - universal-user-agent: 7.0.3 - - '@octokit/types@16.0.0': - dependencies: - '@octokit/openapi-types': 27.0.0 - - '@open-draft/deferred-promise@2.2.0': {} - - '@open-draft/logger@0.3.0': - dependencies: - is-node-process: 1.2.0 - outvariant: 1.4.3 - - '@open-draft/until@2.1.0': {} - - '@orpc/client@1.13.8': - dependencies: - '@orpc/shared': 1.13.8 - '@orpc/standard-server': 1.13.8 - '@orpc/standard-server-fetch': 1.13.8 - '@orpc/standard-server-peer': 1.13.8 + '@orpc/shared': 1.13.13 + '@orpc/standard-server': 1.13.13 + '@orpc/standard-server-fetch': 1.13.13 + '@orpc/standard-server-peer': 1.13.13 transitivePeerDependencies: - '@opentelemetry/api' - '@orpc/contract@1.13.8': + '@orpc/contract@1.13.13': dependencies: - '@orpc/client': 1.13.8 - '@orpc/shared': 1.13.8 + '@orpc/client': 1.13.13 + '@orpc/shared': 1.13.13 '@standard-schema/spec': 1.1.0 openapi-types: 12.1.3 transitivePeerDependencies: - '@opentelemetry/api' - '@orpc/openapi-client@1.13.8': + '@orpc/openapi-client@1.13.13': dependencies: - '@orpc/client': 1.13.8 - '@orpc/contract': 1.13.8 - '@orpc/shared': 1.13.8 - '@orpc/standard-server': 1.13.8 + '@orpc/client': 1.13.13 + '@orpc/contract': 1.13.13 + '@orpc/shared': 1.13.13 + '@orpc/standard-server': 1.13.13 transitivePeerDependencies: - '@opentelemetry/api' - '@orpc/shared@1.13.8': + '@orpc/shared@1.13.13': dependencies: radash: 12.1.1 - type-fest: 5.4.4 + type-fest: 5.5.0 - '@orpc/standard-server-fetch@1.13.8': + '@orpc/standard-server-fetch@1.13.13': dependencies: - '@orpc/shared': 1.13.8 - '@orpc/standard-server': 1.13.8 + '@orpc/shared': 1.13.13 + '@orpc/standard-server': 1.13.13 transitivePeerDependencies: - '@opentelemetry/api' - '@orpc/standard-server-peer@1.13.8': + '@orpc/standard-server-peer@1.13.13': dependencies: - '@orpc/shared': 1.13.8 - '@orpc/standard-server': 1.13.8 + '@orpc/shared': 1.13.13 + '@orpc/standard-server': 1.13.13 transitivePeerDependencies: - '@opentelemetry/api' - '@orpc/standard-server@1.13.8': + '@orpc/standard-server@1.13.13': dependencies: - '@orpc/shared': 1.13.8 + '@orpc/shared': 1.13.13 transitivePeerDependencies: - '@opentelemetry/api' - '@orpc/tanstack-query@1.13.8(@orpc/client@1.13.8)(@tanstack/query-core@5.91.0)': + '@orpc/tanstack-query@1.13.13(@orpc/client@1.13.13)(@tanstack/query-core@5.95.2)': dependencies: - '@orpc/client': 1.13.8 - '@orpc/shared': 1.13.8 - '@tanstack/query-core': 5.91.0 + '@orpc/client': 1.13.13 + '@orpc/shared': 1.13.13 + '@tanstack/query-core': 5.95.2 transitivePeerDependencies: - '@opentelemetry/api' '@ota-meshi/ast-token-store@0.3.0': {} - '@oxc-project/runtime@0.115.0': {} + '@oxc-parser/binding-android-arm-eabi@0.121.0': + optional: true - '@oxc-project/types@0.115.0': {} + '@oxc-parser/binding-android-arm64@0.121.0': + optional: true + + '@oxc-parser/binding-darwin-arm64@0.121.0': + optional: true + + '@oxc-parser/binding-darwin-x64@0.121.0': + optional: true + + '@oxc-parser/binding-freebsd-x64@0.121.0': + optional: true + + '@oxc-parser/binding-linux-arm-gnueabihf@0.121.0': + optional: true + + '@oxc-parser/binding-linux-arm-musleabihf@0.121.0': + optional: true + + '@oxc-parser/binding-linux-arm64-gnu@0.121.0': + optional: true + + '@oxc-parser/binding-linux-arm64-musl@0.121.0': + optional: true + + '@oxc-parser/binding-linux-ppc64-gnu@0.121.0': + optional: true + + '@oxc-parser/binding-linux-riscv64-gnu@0.121.0': + optional: true + + '@oxc-parser/binding-linux-riscv64-musl@0.121.0': + optional: true + + '@oxc-parser/binding-linux-s390x-gnu@0.121.0': + optional: true + + '@oxc-parser/binding-linux-x64-gnu@0.121.0': + optional: true + + '@oxc-parser/binding-linux-x64-musl@0.121.0': + optional: true + + '@oxc-parser/binding-openharmony-arm64@0.121.0': + optional: true + + '@oxc-parser/binding-wasm32-wasi@0.121.0(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1)': + dependencies: + '@napi-rs/wasm-runtime': 1.1.2(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1) + transitivePeerDependencies: + - '@emnapi/core' + - '@emnapi/runtime' + optional: true + + '@oxc-parser/binding-win32-arm64-msvc@0.121.0': + optional: true + + '@oxc-parser/binding-win32-ia32-msvc@0.121.0': + optional: true + + '@oxc-parser/binding-win32-x64-msvc@0.121.0': + optional: true + + '@oxc-project/runtime@0.121.0': {} + + '@oxc-project/types@0.121.0': {} + + '@oxc-project/types@0.122.0': {} '@oxc-resolver/binding-android-arm-eabi@11.19.1': optional: true @@ -9441,9 +10688,12 @@ snapshots: '@oxc-resolver/binding-openharmony-arm64@11.19.1': optional: true - '@oxc-resolver/binding-wasm32-wasi@11.19.1': + '@oxc-resolver/binding-wasm32-wasi@11.19.1(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1)': dependencies: - '@napi-rs/wasm-runtime': 1.1.1 + '@napi-rs/wasm-runtime': 1.1.2(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1) + transitivePeerDependencies: + - '@emnapi/core' + - '@emnapi/runtime' optional: true '@oxc-resolver/binding-win32-arm64-msvc@11.19.1': @@ -9455,136 +10705,136 @@ snapshots: '@oxc-resolver/binding-win32-x64-msvc@11.19.1': optional: true - '@oxfmt/binding-android-arm-eabi@0.40.0': + '@oxfmt/binding-android-arm-eabi@0.42.0': optional: true - '@oxfmt/binding-android-arm64@0.40.0': + '@oxfmt/binding-android-arm64@0.42.0': optional: true - '@oxfmt/binding-darwin-arm64@0.40.0': + '@oxfmt/binding-darwin-arm64@0.42.0': optional: true - '@oxfmt/binding-darwin-x64@0.40.0': + '@oxfmt/binding-darwin-x64@0.42.0': optional: true - '@oxfmt/binding-freebsd-x64@0.40.0': + '@oxfmt/binding-freebsd-x64@0.42.0': optional: true - '@oxfmt/binding-linux-arm-gnueabihf@0.40.0': + '@oxfmt/binding-linux-arm-gnueabihf@0.42.0': optional: true - '@oxfmt/binding-linux-arm-musleabihf@0.40.0': + '@oxfmt/binding-linux-arm-musleabihf@0.42.0': optional: true - '@oxfmt/binding-linux-arm64-gnu@0.40.0': + '@oxfmt/binding-linux-arm64-gnu@0.42.0': optional: true - '@oxfmt/binding-linux-arm64-musl@0.40.0': + '@oxfmt/binding-linux-arm64-musl@0.42.0': optional: true - '@oxfmt/binding-linux-ppc64-gnu@0.40.0': + '@oxfmt/binding-linux-ppc64-gnu@0.42.0': optional: true - '@oxfmt/binding-linux-riscv64-gnu@0.40.0': + '@oxfmt/binding-linux-riscv64-gnu@0.42.0': optional: true - '@oxfmt/binding-linux-riscv64-musl@0.40.0': + '@oxfmt/binding-linux-riscv64-musl@0.42.0': optional: true - '@oxfmt/binding-linux-s390x-gnu@0.40.0': + '@oxfmt/binding-linux-s390x-gnu@0.42.0': optional: true - '@oxfmt/binding-linux-x64-gnu@0.40.0': + '@oxfmt/binding-linux-x64-gnu@0.42.0': optional: true - '@oxfmt/binding-linux-x64-musl@0.40.0': + '@oxfmt/binding-linux-x64-musl@0.42.0': optional: true - '@oxfmt/binding-openharmony-arm64@0.40.0': + '@oxfmt/binding-openharmony-arm64@0.42.0': optional: true - '@oxfmt/binding-win32-arm64-msvc@0.40.0': + '@oxfmt/binding-win32-arm64-msvc@0.42.0': optional: true - '@oxfmt/binding-win32-ia32-msvc@0.40.0': + '@oxfmt/binding-win32-ia32-msvc@0.42.0': optional: true - '@oxfmt/binding-win32-x64-msvc@0.40.0': + '@oxfmt/binding-win32-x64-msvc@0.42.0': optional: true - '@oxlint-tsgolint/darwin-arm64@0.17.0': + '@oxlint-tsgolint/darwin-arm64@0.17.3': optional: true - '@oxlint-tsgolint/darwin-x64@0.17.0': + '@oxlint-tsgolint/darwin-x64@0.17.3': optional: true - '@oxlint-tsgolint/linux-arm64@0.17.0': + '@oxlint-tsgolint/linux-arm64@0.17.3': optional: true - '@oxlint-tsgolint/linux-x64@0.17.0': + '@oxlint-tsgolint/linux-x64@0.17.3': optional: true - '@oxlint-tsgolint/win32-arm64@0.17.0': + '@oxlint-tsgolint/win32-arm64@0.17.3': optional: true - '@oxlint-tsgolint/win32-x64@0.17.0': + '@oxlint-tsgolint/win32-x64@0.17.3': optional: true - '@oxlint/binding-android-arm-eabi@1.55.0': + '@oxlint/binding-android-arm-eabi@1.57.0': optional: true - '@oxlint/binding-android-arm64@1.55.0': + '@oxlint/binding-android-arm64@1.57.0': optional: true - '@oxlint/binding-darwin-arm64@1.55.0': + '@oxlint/binding-darwin-arm64@1.57.0': optional: true - '@oxlint/binding-darwin-x64@1.55.0': + '@oxlint/binding-darwin-x64@1.57.0': optional: true - '@oxlint/binding-freebsd-x64@1.55.0': + '@oxlint/binding-freebsd-x64@1.57.0': optional: true - '@oxlint/binding-linux-arm-gnueabihf@1.55.0': + '@oxlint/binding-linux-arm-gnueabihf@1.57.0': optional: true - '@oxlint/binding-linux-arm-musleabihf@1.55.0': + '@oxlint/binding-linux-arm-musleabihf@1.57.0': optional: true - '@oxlint/binding-linux-arm64-gnu@1.55.0': + '@oxlint/binding-linux-arm64-gnu@1.57.0': optional: true - '@oxlint/binding-linux-arm64-musl@1.55.0': + '@oxlint/binding-linux-arm64-musl@1.57.0': optional: true - '@oxlint/binding-linux-ppc64-gnu@1.55.0': + '@oxlint/binding-linux-ppc64-gnu@1.57.0': optional: true - '@oxlint/binding-linux-riscv64-gnu@1.55.0': + '@oxlint/binding-linux-riscv64-gnu@1.57.0': optional: true - '@oxlint/binding-linux-riscv64-musl@1.55.0': + '@oxlint/binding-linux-riscv64-musl@1.57.0': optional: true - '@oxlint/binding-linux-s390x-gnu@1.55.0': + '@oxlint/binding-linux-s390x-gnu@1.57.0': optional: true - '@oxlint/binding-linux-x64-gnu@1.55.0': + '@oxlint/binding-linux-x64-gnu@1.57.0': optional: true - '@oxlint/binding-linux-x64-musl@1.55.0': + '@oxlint/binding-linux-x64-musl@1.57.0': optional: true - '@oxlint/binding-openharmony-arm64@1.55.0': + '@oxlint/binding-openharmony-arm64@1.57.0': optional: true - '@oxlint/binding-win32-arm64-msvc@1.55.0': + '@oxlint/binding-win32-arm64-msvc@1.57.0': optional: true - '@oxlint/binding-win32-ia32-msvc@1.55.0': + '@oxlint/binding-win32-ia32-msvc@1.57.0': optional: true - '@oxlint/binding-win32-x64-msvc@1.55.0': + '@oxlint/binding-win32-x64-msvc@1.57.0': optional: true '@parcel/watcher-android-arm64@2.5.6': @@ -9631,7 +10881,7 @@ snapshots: detect-libc: 2.1.2 is-glob: 4.0.3 node-addon-api: 7.1.1 - picomatch: 4.0.3 + picomatch: 4.0.4 optionalDependencies: '@parcel/watcher-android-arm64': 2.5.6 '@parcel/watcher-darwin-arm64': 2.5.6 @@ -9650,6 +10900,10 @@ snapshots: '@pkgr/core@0.2.9': {} + '@playwright/test@1.58.2': + dependencies: + playwright: 1.58.2 + '@polka/url@1.0.0-next.29': {} '@preact/signals-core@1.14.0': {} @@ -9822,7 +11076,7 @@ snapshots: '@react-aria/interactions': 3.27.1(react-dom@19.2.4(react@19.2.4))(react@19.2.4) '@react-aria/utils': 3.33.1(react-dom@19.2.4(react@19.2.4))(react@19.2.4) '@react-types/shared': 3.33.1(react@19.2.4) - '@swc/helpers': 0.5.19 + '@swc/helpers': 0.5.20 clsx: 2.1.1 react: 19.2.4 react-dom: 19.2.4(react@19.2.4) @@ -9833,13 +11087,13 @@ snapshots: '@react-aria/utils': 3.33.1(react-dom@19.2.4(react@19.2.4))(react@19.2.4) '@react-stately/flags': 3.1.2 '@react-types/shared': 3.33.1(react@19.2.4) - '@swc/helpers': 0.5.19 + '@swc/helpers': 0.5.20 react: 19.2.4 react-dom: 19.2.4(react@19.2.4) '@react-aria/ssr@3.9.10(react@19.2.4)': dependencies: - '@swc/helpers': 0.5.19 + '@swc/helpers': 0.5.20 react: 19.2.4 '@react-aria/utils@3.33.1(react-dom@19.2.4(react@19.2.4))(react@19.2.4)': @@ -9848,18 +11102,18 @@ snapshots: '@react-stately/flags': 3.1.2 '@react-stately/utils': 3.11.0(react@19.2.4) '@react-types/shared': 3.33.1(react@19.2.4) - '@swc/helpers': 0.5.19 + '@swc/helpers': 0.5.20 clsx: 2.1.1 react: 19.2.4 react-dom: 19.2.4(react@19.2.4) '@react-stately/flags@3.1.2': dependencies: - '@swc/helpers': 0.5.19 + '@swc/helpers': 0.5.20 '@react-stately/utils@3.11.0(react@19.2.4)': dependencies: - '@swc/helpers': 0.5.19 + '@swc/helpers': 0.5.20 react: 19.2.4 '@react-types/shared@3.33.1(react@19.2.4)': @@ -9952,6 +11206,58 @@ snapshots: '@rgrove/parse-xml@4.2.0': {} + '@rolldown/binding-android-arm64@1.0.0-rc.12': + optional: true + + '@rolldown/binding-darwin-arm64@1.0.0-rc.12': + optional: true + + '@rolldown/binding-darwin-x64@1.0.0-rc.12': + optional: true + + '@rolldown/binding-freebsd-x64@1.0.0-rc.12': + optional: true + + '@rolldown/binding-linux-arm-gnueabihf@1.0.0-rc.12': + optional: true + + '@rolldown/binding-linux-arm64-gnu@1.0.0-rc.12': + optional: true + + '@rolldown/binding-linux-arm64-musl@1.0.0-rc.12': + optional: true + + '@rolldown/binding-linux-ppc64-gnu@1.0.0-rc.12': + optional: true + + '@rolldown/binding-linux-s390x-gnu@1.0.0-rc.12': + optional: true + + '@rolldown/binding-linux-x64-gnu@1.0.0-rc.12': + optional: true + + '@rolldown/binding-linux-x64-musl@1.0.0-rc.12': + optional: true + + '@rolldown/binding-openharmony-arm64@1.0.0-rc.12': + optional: true + + '@rolldown/binding-wasm32-wasi@1.0.0-rc.12(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1)': + dependencies: + '@napi-rs/wasm-runtime': 1.1.2(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1) + transitivePeerDependencies: + - '@emnapi/core' + - '@emnapi/runtime' + optional: true + + '@rolldown/binding-win32-arm64-msvc@1.0.0-rc.12': + optional: true + + '@rolldown/binding-win32-x64-msvc@1.0.0-rc.12': + optional: true + + '@rolldown/pluginutils@1.0.0-rc.12': {} + '@rolldown/pluginutils@1.0.0-rc.5': {} '@rolldown/pluginutils@1.0.0-rc.7': {} @@ -9967,7 +11273,7 @@ snapshots: dependencies: '@types/estree': 1.0.8 estree-walker: 2.0.2 - picomatch: 4.0.3 + picomatch: 4.0.4 optionalDependencies: rollup: 4.59.0 @@ -10046,38 +11352,38 @@ snapshots: '@rollup/rollup-win32-x64-msvc@4.59.0': optional: true - '@sentry-internal/browser-utils@10.44.0': + '@sentry-internal/browser-utils@10.46.0': dependencies: - '@sentry/core': 10.44.0 + '@sentry/core': 10.46.0 - '@sentry-internal/feedback@10.44.0': + '@sentry-internal/feedback@10.46.0': dependencies: - '@sentry/core': 10.44.0 + '@sentry/core': 10.46.0 - '@sentry-internal/replay-canvas@10.44.0': + '@sentry-internal/replay-canvas@10.46.0': dependencies: - '@sentry-internal/replay': 10.44.0 - '@sentry/core': 10.44.0 + '@sentry-internal/replay': 10.46.0 + '@sentry/core': 10.46.0 - '@sentry-internal/replay@10.44.0': + '@sentry-internal/replay@10.46.0': dependencies: - '@sentry-internal/browser-utils': 10.44.0 - '@sentry/core': 10.44.0 + '@sentry-internal/browser-utils': 10.46.0 + '@sentry/core': 10.46.0 - '@sentry/browser@10.44.0': + '@sentry/browser@10.46.0': dependencies: - '@sentry-internal/browser-utils': 10.44.0 - '@sentry-internal/feedback': 10.44.0 - '@sentry-internal/replay': 10.44.0 - '@sentry-internal/replay-canvas': 10.44.0 - '@sentry/core': 10.44.0 + '@sentry-internal/browser-utils': 10.46.0 + '@sentry-internal/feedback': 10.46.0 + '@sentry-internal/replay': 10.46.0 + '@sentry-internal/replay-canvas': 10.46.0 + '@sentry/core': 10.46.0 - '@sentry/core@10.44.0': {} + '@sentry/core@10.46.0': {} - '@sentry/react@10.44.0(react@19.2.4)': + '@sentry/react@10.46.0(react@19.2.4)': dependencies: - '@sentry/browser': 10.44.0 - '@sentry/core': 10.44.0 + '@sentry/browser': 10.46.0 + '@sentry/core': 10.46.0 react: 19.2.4 '@shuding/opentype.js@1.4.0-beta.0': @@ -10125,15 +11431,15 @@ snapshots: '@standard-schema/spec@1.1.0': {} - '@storybook/addon-docs@10.3.0(@types/react@19.2.14)(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(rollup@4.59.0)(storybook@10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3))': + '@storybook/addon-docs@10.3.3(@types/react@19.2.14)(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(rollup@4.59.0)(storybook@10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3))': dependencies: '@mdx-js/react': 3.1.1(@types/react@19.2.14)(react@19.2.4) - '@storybook/csf-plugin': 10.3.0(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(rollup@4.59.0)(storybook@10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) + '@storybook/csf-plugin': 10.3.3(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(rollup@4.59.0)(storybook@10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) '@storybook/icons': 2.0.1(react-dom@19.2.4(react@19.2.4))(react@19.2.4) - '@storybook/react-dom-shim': 10.3.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(storybook@10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)) + '@storybook/react-dom-shim': 10.3.3(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(storybook@10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)) react: 19.2.4 react-dom: 19.2.4(react@19.2.4) - storybook: 10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) + storybook: 10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) ts-dedent: 2.2.0 transitivePeerDependencies: - '@types/react' @@ -10142,41 +11448,41 @@ snapshots: - vite - webpack - '@storybook/addon-links@10.3.0(react@19.2.4)(storybook@10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))': + '@storybook/addon-links@10.3.3(react@19.2.4)(storybook@10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))': dependencies: '@storybook/global': 5.0.0 - storybook: 10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) + storybook: 10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) optionalDependencies: react: 19.2.4 - '@storybook/addon-onboarding@10.3.0(storybook@10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))': + '@storybook/addon-onboarding@10.3.3(storybook@10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))': dependencies: - storybook: 10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) + storybook: 10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) - '@storybook/addon-themes@10.3.0(storybook@10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))': + '@storybook/addon-themes@10.3.3(storybook@10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))': dependencies: - storybook: 10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) + storybook: 10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) ts-dedent: 2.2.0 - '@storybook/builder-vite@10.3.0(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(rollup@4.59.0)(storybook@10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3))': + '@storybook/builder-vite@10.3.3(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(rollup@4.59.0)(storybook@10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3))': dependencies: - '@storybook/csf-plugin': 10.3.0(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(rollup@4.59.0)(storybook@10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) - storybook: 10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) + '@storybook/csf-plugin': 10.3.3(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(rollup@4.59.0)(storybook@10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) + storybook: 10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) ts-dedent: 2.2.0 - vite: '@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' + vite: '@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' transitivePeerDependencies: - esbuild - rollup - webpack - '@storybook/csf-plugin@10.3.0(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(rollup@4.59.0)(storybook@10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3))': + '@storybook/csf-plugin@10.3.3(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(rollup@4.59.0)(storybook@10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3))': dependencies: - storybook: 10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) + storybook: 10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) unplugin: 2.3.11 optionalDependencies: esbuild: 0.27.2 rollup: 4.59.0 - vite: '@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' + vite: '@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' webpack: 5.105.4(esbuild@0.27.2)(uglify-js@3.19.3) '@storybook/global@5.0.0': {} @@ -10186,18 +11492,18 @@ snapshots: react: 19.2.4 react-dom: 19.2.4(react@19.2.4) - '@storybook/nextjs-vite@10.3.0(@babel/core@7.29.0)(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(next@16.2.0(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(rollup@4.59.0)(storybook@10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3))': + '@storybook/nextjs-vite@10.3.3(@babel/core@7.29.0)(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(next@16.2.1(@babel/core@7.29.0)(@playwright/test@1.58.2)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(rollup@4.59.0)(storybook@10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3))': dependencies: - '@storybook/builder-vite': 10.3.0(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(rollup@4.59.0)(storybook@10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) - '@storybook/react': 10.3.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(storybook@10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3) - '@storybook/react-vite': 10.3.0(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(rollup@4.59.0)(storybook@10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) - next: 16.2.0(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0) + '@storybook/builder-vite': 10.3.3(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(rollup@4.59.0)(storybook@10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) + '@storybook/react': 10.3.3(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(storybook@10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3) + '@storybook/react-vite': 10.3.3(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(rollup@4.59.0)(storybook@10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) + next: 16.2.1(@babel/core@7.29.0)(@playwright/test@1.58.2)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0) react: 19.2.4 react-dom: 19.2.4(react@19.2.4) - storybook: 10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) + storybook: 10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) styled-jsx: 5.1.6(@babel/core@7.29.0)(react@19.2.4) - vite: '@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' - vite-plugin-storybook-nextjs: 3.2.3(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(next@16.2.0(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0))(storybook@10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3) + vite: '@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' + vite-plugin-storybook-nextjs: 3.2.4(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(next@16.2.1(@babel/core@7.29.0)(@playwright/test@1.58.2)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0))(storybook@10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3) optionalDependencies: typescript: 5.9.3 transitivePeerDependencies: @@ -10208,27 +11514,27 @@ snapshots: - supports-color - webpack - '@storybook/react-dom-shim@10.3.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(storybook@10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))': + '@storybook/react-dom-shim@10.3.3(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(storybook@10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))': dependencies: react: 19.2.4 react-dom: 19.2.4(react@19.2.4) - storybook: 10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) + storybook: 10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) - '@storybook/react-vite@10.3.0(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(rollup@4.59.0)(storybook@10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3))': + '@storybook/react-vite@10.3.3(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(rollup@4.59.0)(storybook@10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3))': dependencies: - '@joshwooding/vite-plugin-react-docgen-typescript': 0.6.4(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(typescript@5.9.3) + '@joshwooding/vite-plugin-react-docgen-typescript': 0.6.4(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(typescript@5.9.3) '@rollup/pluginutils': 5.3.0(rollup@4.59.0) - '@storybook/builder-vite': 10.3.0(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(rollup@4.59.0)(storybook@10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) - '@storybook/react': 10.3.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(storybook@10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3) + '@storybook/builder-vite': 10.3.3(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(rollup@4.59.0)(storybook@10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) + '@storybook/react': 10.3.3(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(storybook@10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3) empathic: 2.0.0 magic-string: 0.30.21 react: 19.2.4 react-docgen: 8.0.3 react-dom: 19.2.4(react@19.2.4) resolve: 1.22.11 - storybook: 10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) + storybook: 10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) tsconfig-paths: 4.2.0 - vite: '@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' + vite: '@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' transitivePeerDependencies: - esbuild - rollup @@ -10236,15 +11542,15 @@ snapshots: - typescript - webpack - '@storybook/react@10.3.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(storybook@10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3)': + '@storybook/react@10.3.3(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(storybook@10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3)': dependencies: '@storybook/global': 5.0.0 - '@storybook/react-dom-shim': 10.3.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(storybook@10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)) + '@storybook/react-dom-shim': 10.3.3(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(storybook@10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)) react: 19.2.4 react-docgen: 8.0.3 react-docgen-typescript: 2.4.0(typescript@5.9.3) react-dom: 19.2.4(react@19.2.4) - storybook: 10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) + storybook: 10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) optionalDependencies: typescript: 5.9.3 transitivePeerDependencies: @@ -10252,22 +11558,22 @@ snapshots: '@streamdown/math@1.0.2(react@19.2.4)': dependencies: - katex: 0.16.38 + katex: 0.16.44 react: 19.2.4 rehype-katex: 7.0.1 remark-math: 6.0.0 transitivePeerDependencies: - supports-color - '@stylistic/eslint-plugin@5.10.0(eslint@10.0.3(jiti@1.21.7))': + '@stylistic/eslint-plugin@5.10.0(eslint@10.1.0(jiti@1.21.7))': dependencies: - '@eslint-community/eslint-utils': 4.9.1(eslint@10.0.3(jiti@1.21.7)) - '@typescript-eslint/types': 8.57.1 - eslint: 10.0.3(jiti@1.21.7) + '@eslint-community/eslint-utils': 4.9.1(eslint@10.1.0(jiti@1.21.7)) + '@typescript-eslint/types': 8.57.2 + eslint: 10.1.0(jiti@1.21.7) eslint-visitor-keys: 4.2.1 espree: 10.4.0 estraverse: 5.3.0 - picomatch: 4.0.3 + picomatch: 4.0.4 '@svgdotjs/svg.js@3.2.5': {} @@ -10275,28 +11581,28 @@ snapshots: dependencies: tslib: 2.8.1 - '@swc/helpers@0.5.19': + '@swc/helpers@0.5.20': dependencies: tslib: 2.8.1 - '@t3-oss/env-core@0.13.10(typescript@5.9.3)(valibot@1.3.0(typescript@5.9.3))(zod@4.3.6)': + '@t3-oss/env-core@0.13.11(typescript@5.9.3)(valibot@1.3.1(typescript@5.9.3))(zod@4.3.6)': optionalDependencies: typescript: 5.9.3 - valibot: 1.3.0(typescript@5.9.3) + valibot: 1.3.1(typescript@5.9.3) zod: 4.3.6 - '@t3-oss/env-nextjs@0.13.10(typescript@5.9.3)(valibot@1.3.0(typescript@5.9.3))(zod@4.3.6)': + '@t3-oss/env-nextjs@0.13.11(typescript@5.9.3)(valibot@1.3.1(typescript@5.9.3))(zod@4.3.6)': dependencies: - '@t3-oss/env-core': 0.13.10(typescript@5.9.3)(valibot@1.3.0(typescript@5.9.3))(zod@4.3.6) + '@t3-oss/env-core': 0.13.11(typescript@5.9.3)(valibot@1.3.1(typescript@5.9.3))(zod@4.3.6) optionalDependencies: typescript: 5.9.3 - valibot: 1.3.0(typescript@5.9.3) + valibot: 1.3.1(typescript@5.9.3) zod: 4.3.6 - '@tailwindcss/typography@0.5.19(tailwindcss@3.4.19(tsx@4.21.0)(yaml@2.8.2))': + '@tailwindcss/typography@0.5.19(tailwindcss@3.4.19(tsx@4.21.0)(yaml@2.8.3))': dependencies: postcss-selector-parser: 6.0.10 - tailwindcss: 3.4.19(tsx@4.21.0)(yaml@2.8.2) + tailwindcss: 3.4.19(tsx@4.21.0)(yaml@2.8.3) '@tanstack/devtools-client@0.0.6': dependencies: @@ -10304,7 +11610,7 @@ snapshots: '@tanstack/devtools-event-bus@0.4.1': dependencies: - ws: 8.19.0 + ws: 8.20.0 transitivePeerDependencies: - bufferutil - utf-8-validate @@ -10342,10 +11648,10 @@ snapshots: - csstype - utf-8-validate - '@tanstack/eslint-plugin-query@5.91.5(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3)': + '@tanstack/eslint-plugin-query@5.95.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3)': dependencies: - '@typescript-eslint/utils': 8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - eslint: 10.0.3(jiti@1.21.7) + '@typescript-eslint/utils': 8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + eslint: 10.1.0(jiti@1.21.7) optionalDependencies: typescript: 5.9.3 transitivePeerDependencies: @@ -10355,7 +11661,7 @@ snapshots: dependencies: '@tanstack/devtools-event-client': 0.4.3 '@tanstack/pacer-lite': 0.1.1 - '@tanstack/store': 0.9.2 + '@tanstack/store': 0.9.3 '@tanstack/form-devtools@0.2.19(@types/react@19.2.14)(csstype@3.2.3)(react@19.2.4)(solid-js@1.9.11)': dependencies: @@ -10375,9 +11681,9 @@ snapshots: '@tanstack/pacer-lite@0.1.1': {} - '@tanstack/query-core@5.91.0': {} + '@tanstack/query-core@5.95.2': {} - '@tanstack/query-devtools@5.93.0': {} + '@tanstack/query-devtools@5.95.2': {} '@tanstack/react-devtools@0.10.0(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(csstype@3.2.3)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(solid-js@1.9.11)': dependencies: @@ -10407,25 +11713,25 @@ snapshots: '@tanstack/react-form@1.28.5(react-dom@19.2.4(react@19.2.4))(react@19.2.4)': dependencies: '@tanstack/form-core': 1.28.5 - '@tanstack/react-store': 0.9.2(react-dom@19.2.4(react@19.2.4))(react@19.2.4) + '@tanstack/react-store': 0.9.3(react-dom@19.2.4(react@19.2.4))(react@19.2.4) react: 19.2.4 transitivePeerDependencies: - react-dom - '@tanstack/react-query-devtools@5.91.3(@tanstack/react-query@5.91.0(react@19.2.4))(react@19.2.4)': + '@tanstack/react-query-devtools@5.95.2(@tanstack/react-query@5.95.2(react@19.2.4))(react@19.2.4)': dependencies: - '@tanstack/query-devtools': 5.93.0 - '@tanstack/react-query': 5.91.0(react@19.2.4) + '@tanstack/query-devtools': 5.95.2 + '@tanstack/react-query': 5.95.2(react@19.2.4) react: 19.2.4 - '@tanstack/react-query@5.91.0(react@19.2.4)': + '@tanstack/react-query@5.95.2(react@19.2.4)': dependencies: - '@tanstack/query-core': 5.91.0 + '@tanstack/query-core': 5.95.2 react: 19.2.4 - '@tanstack/react-store@0.9.2(react-dom@19.2.4(react@19.2.4))(react@19.2.4)': + '@tanstack/react-store@0.9.3(react-dom@19.2.4(react@19.2.4))(react@19.2.4)': dependencies: - '@tanstack/store': 0.9.2 + '@tanstack/store': 0.9.3 react: 19.2.4 react-dom: 19.2.4(react@19.2.4) use-sync-external-store: 1.6.0(react@19.2.4) @@ -10436,10 +11742,12 @@ snapshots: react: 19.2.4 react-dom: 19.2.4(react@19.2.4) - '@tanstack/store@0.9.2': {} + '@tanstack/store@0.9.3': {} '@tanstack/virtual-core@3.13.23': {} + '@teppeis/multimaps@3.0.0': {} + '@testing-library/dom@10.4.1': dependencies: '@babel/code-frame': 7.29.0 @@ -10491,7 +11799,7 @@ snapshots: '@tsslint/compat-eslint@3.0.2(jiti@1.21.7)(typescript@5.9.3)': dependencies: '@tsslint/types': 3.0.2 - '@typescript-eslint/parser': 8.57.1(eslint@9.27.0(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/parser': 8.57.2(eslint@9.27.0(jiti@1.21.7))(typescript@5.9.3) eslint: 9.27.0(jiti@1.21.7) transitivePeerDependencies: - jiti @@ -10502,7 +11810,7 @@ snapshots: dependencies: '@tsslint/types': 3.0.2 minimatch: 10.2.4 - ts-api-utils: 2.4.0(typescript@5.9.3) + ts-api-utils: 2.5.0(typescript@5.9.3) optionalDependencies: '@tsslint/compat-eslint': 3.0.2(jiti@1.21.7)(typescript@5.9.3) transitivePeerDependencies: @@ -10667,7 +11975,7 @@ snapshots: '@types/d3-transition': 3.0.9 '@types/d3-zoom': 3.0.8 - '@types/debug@4.1.12': + '@types/debug@4.1.13': dependencies: '@types/ms': 2.1.0 @@ -10725,6 +12033,8 @@ snapshots: dependencies: undici-types: 7.18.2 + '@types/normalize-package-data@2.4.4': {} + '@types/papaparse@5.5.2': dependencies: '@types/node': 25.5.0 @@ -10739,10 +12049,6 @@ snapshots: dependencies: '@types/react': 19.2.14 - '@types/react-slider@1.3.6': - dependencies: - '@types/react': 19.2.14 - '@types/react-syntax-highlighter@15.5.13': dependencies: '@types/react': 19.2.14 @@ -10766,6 +12072,12 @@ snapshots: '@types/unist@3.0.3': {} + '@types/whatwg-mimetype@3.0.2': {} + + '@types/ws@8.18.1': + dependencies: + '@types/node': 25.5.0 + '@types/yauzl@2.10.3': dependencies: '@types/node': 25.5.0 @@ -10773,62 +12085,90 @@ snapshots: '@types/zen-observable@0.8.3': {} - '@typescript-eslint/eslint-plugin@8.57.1(@typescript-eslint/parser@8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3))(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3)': + '@typescript-eslint/eslint-plugin@8.57.2(@typescript-eslint/parser@8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3)': dependencies: '@eslint-community/regexpp': 4.12.2 - '@typescript-eslint/parser': 8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@typescript-eslint/scope-manager': 8.57.1 - '@typescript-eslint/type-utils': 8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@typescript-eslint/utils': 8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@typescript-eslint/visitor-keys': 8.57.1 - eslint: 10.0.3(jiti@1.21.7) + '@typescript-eslint/parser': 8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/scope-manager': 8.57.2 + '@typescript-eslint/type-utils': 8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/utils': 8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/visitor-keys': 8.57.2 + eslint: 10.1.0(jiti@1.21.7) ignore: 7.0.5 natural-compare: 1.4.0 - ts-api-utils: 2.4.0(typescript@5.9.3) + ts-api-utils: 2.5.0(typescript@5.9.3) typescript: 5.9.3 transitivePeerDependencies: - supports-color - '@typescript-eslint/parser@8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3)': + '@typescript-eslint/eslint-plugin@8.57.2(@typescript-eslint/parser@8.57.2(eslint@10.1.0(jiti@2.6.1))(typescript@5.9.3))(eslint@10.1.0(jiti@2.6.1))(typescript@5.9.3)': dependencies: - '@typescript-eslint/scope-manager': 8.57.1 - '@typescript-eslint/types': 8.57.1 - '@typescript-eslint/typescript-estree': 8.57.1(typescript@5.9.3) - '@typescript-eslint/visitor-keys': 8.57.1 - debug: 4.4.3 - eslint: 10.0.3(jiti@1.21.7) + '@eslint-community/regexpp': 4.12.2 + '@typescript-eslint/parser': 8.57.2(eslint@10.1.0(jiti@2.6.1))(typescript@5.9.3) + '@typescript-eslint/scope-manager': 8.57.2 + '@typescript-eslint/type-utils': 8.57.2(eslint@10.1.0(jiti@2.6.1))(typescript@5.9.3) + '@typescript-eslint/utils': 8.57.2(eslint@10.1.0(jiti@2.6.1))(typescript@5.9.3) + '@typescript-eslint/visitor-keys': 8.57.2 + eslint: 10.1.0(jiti@2.6.1) + ignore: 7.0.5 + natural-compare: 1.4.0 + ts-api-utils: 2.5.0(typescript@5.9.3) typescript: 5.9.3 transitivePeerDependencies: - supports-color - '@typescript-eslint/parser@8.57.1(eslint@9.27.0(jiti@1.21.7))(typescript@5.9.3)': + '@typescript-eslint/parser@8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3)': dependencies: - '@typescript-eslint/scope-manager': 8.57.1 - '@typescript-eslint/types': 8.57.1 - '@typescript-eslint/typescript-estree': 8.57.1(typescript@5.9.3) - '@typescript-eslint/visitor-keys': 8.57.1 - debug: 4.4.3 + '@typescript-eslint/scope-manager': 8.57.2 + '@typescript-eslint/types': 8.57.2 + '@typescript-eslint/typescript-estree': 8.57.2(typescript@5.9.3) + '@typescript-eslint/visitor-keys': 8.57.2 + debug: 4.4.3(supports-color@8.1.1) + eslint: 10.1.0(jiti@1.21.7) + typescript: 5.9.3 + transitivePeerDependencies: + - supports-color + + '@typescript-eslint/parser@8.57.2(eslint@10.1.0(jiti@2.6.1))(typescript@5.9.3)': + dependencies: + '@typescript-eslint/scope-manager': 8.57.2 + '@typescript-eslint/types': 8.57.2 + '@typescript-eslint/typescript-estree': 8.57.2(typescript@5.9.3) + '@typescript-eslint/visitor-keys': 8.57.2 + debug: 4.4.3(supports-color@8.1.1) + eslint: 10.1.0(jiti@2.6.1) + typescript: 5.9.3 + transitivePeerDependencies: + - supports-color + + '@typescript-eslint/parser@8.57.2(eslint@9.27.0(jiti@1.21.7))(typescript@5.9.3)': + dependencies: + '@typescript-eslint/scope-manager': 8.57.2 + '@typescript-eslint/types': 8.57.2 + '@typescript-eslint/typescript-estree': 8.57.2(typescript@5.9.3) + '@typescript-eslint/visitor-keys': 8.57.2 + debug: 4.4.3(supports-color@8.1.1) eslint: 9.27.0(jiti@1.21.7) typescript: 5.9.3 transitivePeerDependencies: - supports-color - '@typescript-eslint/project-service@8.57.1(typescript@5.9.3)': + '@typescript-eslint/project-service@8.57.2(typescript@5.9.3)': dependencies: - '@typescript-eslint/tsconfig-utils': 8.57.1(typescript@5.9.3) - '@typescript-eslint/types': 8.57.1 - debug: 4.4.3 + '@typescript-eslint/tsconfig-utils': 8.57.2(typescript@5.9.3) + '@typescript-eslint/types': 8.57.2 + debug: 4.4.3(supports-color@8.1.1) typescript: 5.9.3 transitivePeerDependencies: - supports-color - '@typescript-eslint/rule-tester@8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3)': + '@typescript-eslint/rule-tester@8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3)': dependencies: - '@typescript-eslint/parser': 8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@typescript-eslint/typescript-estree': 8.57.1(typescript@5.9.3) - '@typescript-eslint/utils': 8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/parser': 8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/typescript-estree': 8.57.2(typescript@5.9.3) + '@typescript-eslint/utils': 8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) ajv: 6.14.0 - eslint: 10.0.3(jiti@1.21.7) + eslint: 10.1.0(jiti@1.21.7) json-stable-stringify-without-jsonify: 1.0.1 lodash.merge: 4.6.2 semver: 7.7.4 @@ -10836,90 +12176,113 @@ snapshots: - supports-color - typescript - '@typescript-eslint/scope-manager@8.57.1': + '@typescript-eslint/scope-manager@8.57.2': dependencies: - '@typescript-eslint/types': 8.57.1 - '@typescript-eslint/visitor-keys': 8.57.1 + '@typescript-eslint/types': 8.57.2 + '@typescript-eslint/visitor-keys': 8.57.2 - '@typescript-eslint/tsconfig-utils@8.57.1(typescript@5.9.3)': + '@typescript-eslint/tsconfig-utils@8.57.2(typescript@5.9.3)': dependencies: typescript: 5.9.3 - '@typescript-eslint/type-utils@8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3)': + '@typescript-eslint/type-utils@8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3)': dependencies: - '@typescript-eslint/types': 8.57.1 - '@typescript-eslint/typescript-estree': 8.57.1(typescript@5.9.3) - '@typescript-eslint/utils': 8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - debug: 4.4.3 - eslint: 10.0.3(jiti@1.21.7) - ts-api-utils: 2.4.0(typescript@5.9.3) + '@typescript-eslint/types': 8.57.2 + '@typescript-eslint/typescript-estree': 8.57.2(typescript@5.9.3) + '@typescript-eslint/utils': 8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + debug: 4.4.3(supports-color@8.1.1) + eslint: 10.1.0(jiti@1.21.7) + ts-api-utils: 2.5.0(typescript@5.9.3) typescript: 5.9.3 transitivePeerDependencies: - supports-color - '@typescript-eslint/types@8.57.1': {} - - '@typescript-eslint/typescript-estree@8.57.1(typescript@5.9.3)': + '@typescript-eslint/type-utils@8.57.2(eslint@10.1.0(jiti@2.6.1))(typescript@5.9.3)': dependencies: - '@typescript-eslint/project-service': 8.57.1(typescript@5.9.3) - '@typescript-eslint/tsconfig-utils': 8.57.1(typescript@5.9.3) - '@typescript-eslint/types': 8.57.1 - '@typescript-eslint/visitor-keys': 8.57.1 - debug: 4.4.3 + '@typescript-eslint/types': 8.57.2 + '@typescript-eslint/typescript-estree': 8.57.2(typescript@5.9.3) + '@typescript-eslint/utils': 8.57.2(eslint@10.1.0(jiti@2.6.1))(typescript@5.9.3) + debug: 4.4.3(supports-color@8.1.1) + eslint: 10.1.0(jiti@2.6.1) + ts-api-utils: 2.5.0(typescript@5.9.3) + typescript: 5.9.3 + transitivePeerDependencies: + - supports-color + + '@typescript-eslint/types@8.57.2': {} + + '@typescript-eslint/typescript-estree@8.57.2(typescript@5.9.3)': + dependencies: + '@typescript-eslint/project-service': 8.57.2(typescript@5.9.3) + '@typescript-eslint/tsconfig-utils': 8.57.2(typescript@5.9.3) + '@typescript-eslint/types': 8.57.2 + '@typescript-eslint/visitor-keys': 8.57.2 + debug: 4.4.3(supports-color@8.1.1) minimatch: 10.2.4 semver: 7.7.4 tinyglobby: 0.2.15 - ts-api-utils: 2.4.0(typescript@5.9.3) + ts-api-utils: 2.5.0(typescript@5.9.3) typescript: 5.9.3 transitivePeerDependencies: - supports-color - '@typescript-eslint/utils@8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3)': + '@typescript-eslint/utils@8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3)': dependencies: - '@eslint-community/eslint-utils': 4.9.1(eslint@10.0.3(jiti@1.21.7)) - '@typescript-eslint/scope-manager': 8.57.1 - '@typescript-eslint/types': 8.57.1 - '@typescript-eslint/typescript-estree': 8.57.1(typescript@5.9.3) - eslint: 10.0.3(jiti@1.21.7) + '@eslint-community/eslint-utils': 4.9.1(eslint@10.1.0(jiti@1.21.7)) + '@typescript-eslint/scope-manager': 8.57.2 + '@typescript-eslint/types': 8.57.2 + '@typescript-eslint/typescript-estree': 8.57.2(typescript@5.9.3) + eslint: 10.1.0(jiti@1.21.7) typescript: 5.9.3 transitivePeerDependencies: - supports-color - '@typescript-eslint/visitor-keys@8.57.1': + '@typescript-eslint/utils@8.57.2(eslint@10.1.0(jiti@2.6.1))(typescript@5.9.3)': dependencies: - '@typescript-eslint/types': 8.57.1 + '@eslint-community/eslint-utils': 4.9.1(eslint@10.1.0(jiti@2.6.1)) + '@typescript-eslint/scope-manager': 8.57.2 + '@typescript-eslint/types': 8.57.2 + '@typescript-eslint/typescript-estree': 8.57.2(typescript@5.9.3) + eslint: 10.1.0(jiti@2.6.1) + typescript: 5.9.3 + transitivePeerDependencies: + - supports-color + + '@typescript-eslint/visitor-keys@8.57.2': + dependencies: + '@typescript-eslint/types': 8.57.2 eslint-visitor-keys: 5.0.1 - '@typescript/native-preview-darwin-arm64@7.0.0-dev.20260318.1': + '@typescript/native-preview-darwin-arm64@7.0.0-dev.20260329.1': optional: true - '@typescript/native-preview-darwin-x64@7.0.0-dev.20260318.1': + '@typescript/native-preview-darwin-x64@7.0.0-dev.20260329.1': optional: true - '@typescript/native-preview-linux-arm64@7.0.0-dev.20260318.1': + '@typescript/native-preview-linux-arm64@7.0.0-dev.20260329.1': optional: true - '@typescript/native-preview-linux-arm@7.0.0-dev.20260318.1': + '@typescript/native-preview-linux-arm@7.0.0-dev.20260329.1': optional: true - '@typescript/native-preview-linux-x64@7.0.0-dev.20260318.1': + '@typescript/native-preview-linux-x64@7.0.0-dev.20260329.1': optional: true - '@typescript/native-preview-win32-arm64@7.0.0-dev.20260318.1': + '@typescript/native-preview-win32-arm64@7.0.0-dev.20260329.1': optional: true - '@typescript/native-preview-win32-x64@7.0.0-dev.20260318.1': + '@typescript/native-preview-win32-x64@7.0.0-dev.20260329.1': optional: true - '@typescript/native-preview@7.0.0-dev.20260318.1': + '@typescript/native-preview@7.0.0-dev.20260329.1': optionalDependencies: - '@typescript/native-preview-darwin-arm64': 7.0.0-dev.20260318.1 - '@typescript/native-preview-darwin-x64': 7.0.0-dev.20260318.1 - '@typescript/native-preview-linux-arm': 7.0.0-dev.20260318.1 - '@typescript/native-preview-linux-arm64': 7.0.0-dev.20260318.1 - '@typescript/native-preview-linux-x64': 7.0.0-dev.20260318.1 - '@typescript/native-preview-win32-arm64': 7.0.0-dev.20260318.1 - '@typescript/native-preview-win32-x64': 7.0.0-dev.20260318.1 + '@typescript/native-preview-darwin-arm64': 7.0.0-dev.20260329.1 + '@typescript/native-preview-darwin-x64': 7.0.0-dev.20260329.1 + '@typescript/native-preview-linux-arm': 7.0.0-dev.20260329.1 + '@typescript/native-preview-linux-arm64': 7.0.0-dev.20260329.1 + '@typescript/native-preview-linux-x64': 7.0.0-dev.20260329.1 + '@typescript/native-preview-win32-arm64': 7.0.0-dev.20260329.1 + '@typescript/native-preview-win32-x64': 7.0.0-dev.20260329.1 '@ungap/structured-clone@1.3.0': {} @@ -10927,34 +12290,56 @@ snapshots: dependencies: unpic: 4.2.2 - '@unpic/react@1.0.2(next@16.2.0(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)': + '@unpic/react@1.0.2(next@16.2.1(@babel/core@7.29.0)(@playwright/test@1.58.2)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)': dependencies: '@unpic/core': 1.0.3 react: 19.2.4 react-dom: 19.2.4(react@19.2.4) optionalDependencies: - next: 16.2.0(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0) + next: 16.2.1(@babel/core@7.29.0)(@playwright/test@1.58.2)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0) '@upsetjs/venn.js@2.0.0': optionalDependencies: d3-selection: 3.0.0 d3-transition: 3.0.1(d3-selection@3.0.0) - '@valibot/to-json-schema@1.6.0(valibot@1.3.0(typescript@5.9.3))': + '@valibot/to-json-schema@1.6.0(valibot@1.3.1(typescript@5.9.3))': dependencies: - valibot: 1.3.0(typescript@5.9.3) + valibot: 1.3.1(typescript@5.9.3) '@vercel/og@0.8.6': dependencies: '@resvg/resvg-wasm': 2.4.0 satori: 0.16.0 - '@vitejs/plugin-react@6.0.1(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))': + '@vitejs/devtools-kit@0.1.11(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(typescript@5.9.3)(ws@8.20.0)': + dependencies: + '@vitejs/devtools-rpc': 0.1.11(typescript@5.9.3)(ws@8.20.0) + birpc: 4.0.0 + ohash: 2.0.11 + vite: '@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' + transitivePeerDependencies: + - typescript + - ws + + '@vitejs/devtools-rpc@0.1.11(typescript@5.9.3)(ws@8.20.0)': + dependencies: + birpc: 4.0.0 + ohash: 2.0.11 + p-limit: 7.3.0 + structured-clone-es: 2.0.0 + valibot: 1.3.1(typescript@5.9.3) + optionalDependencies: + ws: 8.20.0 + transitivePeerDependencies: + - typescript + + '@vitejs/plugin-react@6.0.1(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))': dependencies: '@rolldown/pluginutils': 1.0.0-rc.7 - vite: '@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' + vite: '@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' - '@vitejs/plugin-rsc@0.5.21(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(react-dom@19.2.4(react@19.2.4))(react-server-dom-webpack@19.2.4(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)))(react@19.2.4)': + '@vitejs/plugin-rsc@0.5.21(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(react-dom@19.2.4(react@19.2.4))(react-server-dom-webpack@19.2.4(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)))(react@19.2.4)': dependencies: '@rolldown/pluginutils': 1.0.0-rc.5 es-module-lexer: 2.0.0 @@ -10963,18 +12348,18 @@ snapshots: periscopic: 4.0.2 react: 19.2.4 react-dom: 19.2.4(react@19.2.4) - srvx: 0.11.12 + srvx: 0.11.13 strip-literal: 3.1.0 turbo-stream: 3.2.0 - vite: '@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' - vitefu: 1.1.2(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)) + vite: '@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' + vitefu: 1.1.2(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)) optionalDependencies: react-server-dom-webpack: 19.2.4(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) - '@vitest/coverage-v8@4.1.0(@voidzero-dev/vite-plus-test@0.1.12(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.0(canvas@3.2.1))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))': + '@vitest/coverage-v8@4.1.2(@voidzero-dev/vite-plus-test@0.1.14(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))': dependencies: '@bcoe/v8-coverage': 1.0.2 - '@vitest/utils': 4.1.0 + '@vitest/utils': 4.1.2 ast-v8-to-istanbul: 1.0.0 istanbul-lib-coverage: 3.2.2 istanbul-lib-report: 3.0.1 @@ -10983,16 +12368,31 @@ snapshots: obug: 2.1.1 std-env: 4.0.0 tinyrainbow: 3.1.0 - vitest: '@voidzero-dev/vite-plus-test@0.1.12(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.0(canvas@3.2.1))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' + vitest: '@voidzero-dev/vite-plus-test@0.1.14(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' - '@vitest/eslint-plugin@1.6.12(@voidzero-dev/vite-plus-test@0.1.12(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.0(canvas@3.2.1))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3)': + '@vitest/coverage-v8@4.1.2(@voidzero-dev/vite-plus-test@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(vite@8.0.3(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1)(@types/node@25.5.0)(esbuild@0.27.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(yaml@2.8.3))(yaml@2.8.3))': dependencies: - '@typescript-eslint/scope-manager': 8.57.1 - '@typescript-eslint/utils': 8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - eslint: 10.0.3(jiti@1.21.7) + '@bcoe/v8-coverage': 1.0.2 + '@vitest/utils': 4.1.2 + ast-v8-to-istanbul: 1.0.0 + istanbul-lib-coverage: 3.2.2 + istanbul-lib-report: 3.0.1 + istanbul-reports: 3.2.0 + magicast: 0.5.2 + obug: 2.1.1 + std-env: 4.0.0 + tinyrainbow: 3.1.0 + vitest: '@voidzero-dev/vite-plus-test@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(vite@8.0.3(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1)(@types/node@25.5.0)(esbuild@0.27.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(yaml@2.8.3))(yaml@2.8.3)' + + '@vitest/eslint-plugin@1.6.13(@typescript-eslint/eslint-plugin@8.57.2(@typescript-eslint/parser@8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(@voidzero-dev/vite-plus-test@0.1.14(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3)': + dependencies: + '@typescript-eslint/scope-manager': 8.57.2 + '@typescript-eslint/utils': 8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + eslint: 10.1.0(jiti@1.21.7) optionalDependencies: + '@typescript-eslint/eslint-plugin': 8.57.2(@typescript-eslint/parser@8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) typescript: 5.9.3 - vitest: '@voidzero-dev/vite-plus-test@0.1.12(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.0(canvas@3.2.1))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' + vitest: '@voidzero-dev/vite-plus-test@0.1.14(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' transitivePeerDependencies: - supports-color @@ -11008,7 +12408,7 @@ snapshots: dependencies: tinyrainbow: 2.0.0 - '@vitest/pretty-format@4.1.0': + '@vitest/pretty-format@4.1.2': dependencies: tinyrainbow: 3.1.0 @@ -11022,16 +12422,16 @@ snapshots: loupe: 3.2.1 tinyrainbow: 2.0.0 - '@vitest/utils@4.1.0': + '@vitest/utils@4.1.2': dependencies: - '@vitest/pretty-format': 4.1.0 + '@vitest/pretty-format': 4.1.2 convert-source-map: 2.0.0 tinyrainbow: 3.1.0 - '@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)': + '@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)': dependencies: - '@oxc-project/runtime': 0.115.0 - '@oxc-project/types': 0.115.0 + '@oxc-project/runtime': 0.121.0 + '@oxc-project/types': 0.122.0 lightningcss: 1.32.0 postcss: 8.5.8 optionalDependencies: @@ -11043,25 +12443,48 @@ snapshots: terser: 5.46.1 tsx: 4.21.0 typescript: 5.9.3 - yaml: 2.8.2 + yaml: 2.8.3 - '@voidzero-dev/vite-plus-darwin-arm64@0.1.12': + '@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)': + dependencies: + '@oxc-project/runtime': 0.121.0 + '@oxc-project/types': 0.122.0 + lightningcss: 1.32.0 + postcss: 8.5.8 + optionalDependencies: + '@types/node': 25.5.0 + esbuild: 0.27.2 + fsevents: 2.3.3 + jiti: 2.6.1 + sass: 1.98.0 + terser: 5.46.1 + tsx: 4.21.0 + typescript: 5.9.3 + yaml: 2.8.3 + + '@voidzero-dev/vite-plus-darwin-arm64@0.1.14': optional: true - '@voidzero-dev/vite-plus-darwin-x64@0.1.12': + '@voidzero-dev/vite-plus-darwin-x64@0.1.14': optional: true - '@voidzero-dev/vite-plus-linux-arm64-gnu@0.1.12': + '@voidzero-dev/vite-plus-linux-arm64-gnu@0.1.14': optional: true - '@voidzero-dev/vite-plus-linux-x64-gnu@0.1.12': + '@voidzero-dev/vite-plus-linux-arm64-musl@0.1.14': optional: true - '@voidzero-dev/vite-plus-test@0.1.12(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.0(canvas@3.2.1))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)': + '@voidzero-dev/vite-plus-linux-x64-gnu@0.1.14': + optional: true + + '@voidzero-dev/vite-plus-linux-x64-musl@0.1.14': + optional: true + + '@voidzero-dev/vite-plus-test@0.1.14(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)': dependencies: '@standard-schema/spec': 1.1.0 '@types/chai': 5.2.3 - '@voidzero-dev/vite-plus-core': 0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2) + '@voidzero-dev/vite-plus-core': 0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3) es-module-lexer: 1.7.0 obug: 2.1.1 pixelmatch: 7.1.0 @@ -11071,11 +12494,11 @@ snapshots: tinybench: 2.9.0 tinyexec: 1.0.4 tinyglobby: 0.2.15 - vite: '@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' - ws: 8.19.0 + vite: '@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' + ws: 8.20.0 optionalDependencies: '@types/node': 25.5.0 - jsdom: 29.0.0(canvas@3.2.1) + happy-dom: 20.8.9 transitivePeerDependencies: - '@arethetypeswrong/core' - '@tsdown/css' @@ -11097,10 +12520,50 @@ snapshots: - utf-8-validate - yaml - '@voidzero-dev/vite-plus-win32-arm64-msvc@0.1.12': + '@voidzero-dev/vite-plus-test@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(vite@8.0.3(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1)(@types/node@25.5.0)(esbuild@0.27.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(yaml@2.8.3))(yaml@2.8.3)': + dependencies: + '@standard-schema/spec': 1.1.0 + '@types/chai': 5.2.3 + '@voidzero-dev/vite-plus-core': 0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3) + es-module-lexer: 1.7.0 + obug: 2.1.1 + pixelmatch: 7.1.0 + pngjs: 7.0.0 + sirv: 3.0.2 + std-env: 4.0.0 + tinybench: 2.9.0 + tinyexec: 1.0.4 + tinyglobby: 0.2.15 + vite: 8.0.3(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1)(@types/node@25.5.0)(esbuild@0.27.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(yaml@2.8.3) + ws: 8.20.0 + optionalDependencies: + '@types/node': 25.5.0 + happy-dom: 20.8.9 + transitivePeerDependencies: + - '@arethetypeswrong/core' + - '@tsdown/css' + - '@tsdown/exe' + - '@vitejs/devtools' + - bufferutil + - esbuild + - jiti + - less + - publint + - sass + - sass-embedded + - stylus + - sugarss + - terser + - tsx + - typescript + - unplugin-unused + - utf-8-validate + - yaml + + '@voidzero-dev/vite-plus-win32-arm64-msvc@0.1.14': optional: true - '@voidzero-dev/vite-plus-win32-x64-msvc@0.1.12': + '@voidzero-dev/vite-plus-win32-x64-msvc@0.1.14': optional: true '@volar/language-core@2.4.28': @@ -11117,37 +12580,37 @@ snapshots: path-browserify: 1.0.1 vscode-uri: 3.1.0 - '@vue/compiler-core@3.5.30': + '@vue/compiler-core@3.5.31': dependencies: '@babel/parser': 7.29.2 - '@vue/shared': 3.5.30 + '@vue/shared': 3.5.31 entities: 7.0.1 estree-walker: 2.0.2 source-map-js: 1.2.1 - '@vue/compiler-dom@3.5.30': + '@vue/compiler-dom@3.5.31': dependencies: - '@vue/compiler-core': 3.5.30 - '@vue/shared': 3.5.30 + '@vue/compiler-core': 3.5.31 + '@vue/shared': 3.5.31 - '@vue/compiler-sfc@3.5.30': + '@vue/compiler-sfc@3.5.31': dependencies: '@babel/parser': 7.29.2 - '@vue/compiler-core': 3.5.30 - '@vue/compiler-dom': 3.5.30 - '@vue/compiler-ssr': 3.5.30 - '@vue/shared': 3.5.30 + '@vue/compiler-core': 3.5.31 + '@vue/compiler-dom': 3.5.31 + '@vue/compiler-ssr': 3.5.31 + '@vue/shared': 3.5.31 estree-walker: 2.0.2 magic-string: 0.30.21 postcss: 8.5.8 source-map-js: 1.2.1 - '@vue/compiler-ssr@3.5.30': + '@vue/compiler-ssr@3.5.31': dependencies: - '@vue/compiler-dom': 3.5.30 - '@vue/shared': 3.5.30 + '@vue/compiler-dom': 3.5.31 + '@vue/shared': 3.5.31 - '@vue/shared@3.5.30': {} + '@vue/shared@3.5.31': {} '@webassemblyjs/ast@1.14.1': dependencies: @@ -11247,12 +12710,12 @@ snapshots: acorn@8.16.0: {} - agentation@2.3.3(react-dom@19.2.4(react@19.2.4))(react@19.2.4): + agentation@3.0.2(react-dom@19.2.4(react@19.2.4))(react@19.2.4): optionalDependencies: react: 19.2.4 react-dom: 19.2.4(react@19.2.4) - ahooks@3.9.6(react-dom@19.2.4(react@19.2.4))(react@19.2.4): + ahooks@3.9.7(react-dom@19.2.4(react@19.2.4))(react@19.2.4): dependencies: '@babel/runtime': 7.29.2 '@types/js-cookie': 3.0.6 @@ -11294,6 +12757,8 @@ snapshots: dependencies: environment: 1.1.0 + ansi-regex@4.1.1: {} + ansi-regex@5.0.1: {} ansi-regex@6.2.2: {} @@ -11313,7 +12778,7 @@ snapshots: anymatch@3.1.3: dependencies: normalize-path: 3.0.0 - picomatch: 2.3.1 + picomatch: 2.3.2 are-docs-informative@0.0.2: {} @@ -11331,6 +12796,12 @@ snapshots: aria-query@5.3.2: {} + assertion-error-formatter@3.0.0: + dependencies: + diff: 4.0.4 + pad-right: 0.2.2 + repeat-string: 1.6.1 + assertion-error@2.0.1: {} ast-types@0.16.1: @@ -11347,15 +12818,25 @@ snapshots: async@3.2.6: {} + asynckit@0.4.0: {} + autoprefixer@10.4.27(postcss@8.5.8): dependencies: browserslist: 4.28.1 - caniuse-lite: 1.0.30001780 + caniuse-lite: 1.0.30001781 fraction.js: 5.3.4 picocolors: 1.1.1 postcss: 8.5.8 postcss-value-parser: 4.2.0 + axios@1.14.0: + dependencies: + follow-redirects: 1.15.11 + form-data: 4.0.5 + proxy-from-env: 2.1.0 + transitivePeerDependencies: + - debug + bail@2.0.2: {} balanced-match@1.0.2: {} @@ -11369,21 +12850,13 @@ snapshots: base64-js@1.5.1: optional: true - baseline-browser-mapping@2.10.8: {} - - before-after-hook@4.0.0: {} - - bezier-easing@2.1.0: {} - - bidi-js@1.0.3: - dependencies: - require-from-string: 2.0.2 + baseline-browser-mapping@2.10.12: {} binary-extensions@2.3.0: {} birecord@0.1.1: {} - birpc@2.9.0: {} + birpc@4.0.0: {} bl@4.1.0: dependencies: @@ -11398,7 +12871,7 @@ snapshots: dependencies: balanced-match: 1.0.2 - brace-expansion@5.0.4: + brace-expansion@5.0.5: dependencies: balanced-match: 4.0.4 @@ -11408,9 +12881,9 @@ snapshots: browserslist@4.28.1: dependencies: - baseline-browser-mapping: 2.10.8 - caniuse-lite: 1.0.30001780 - electron-to-chromium: 1.5.313 + baseline-browser-mapping: 2.10.12 + caniuse-lite: 1.0.30001781 + electron-to-chromium: 1.5.328 node-releases: 2.0.36 update-browserslist-db: 1.2.3(browserslist@4.28.1) @@ -11432,26 +12905,42 @@ snapshots: dependencies: run-applescript: 7.1.0 + bundle-require@5.1.0(esbuild@0.27.2): + dependencies: + esbuild: 0.27.2 + load-tsconfig: 0.2.5 + bytes@3.1.2: {} cac@6.7.14: {} cac@7.0.0: {} + call-bind-apply-helpers@1.0.2: + dependencies: + es-errors: 1.3.0 + function-bind: 1.1.2 + callsites@3.1.0: {} camelcase-css@2.0.1: {} camelize@1.0.1: {} - caniuse-lite@1.0.30001780: {} + caniuse-lite@1.0.30001781: {} - canvas@3.2.1: + canvas@3.2.2: dependencies: node-addon-api: 7.1.1 prebuild-install: 7.1.3 optional: true + capital-case@1.0.4: + dependencies: + no-case: 3.0.4 + tslib: 2.8.1 + upper-case-first: 2.0.2 + ccount@2.0.1: {} chai@5.3.3: @@ -11554,6 +13043,8 @@ snapshots: ci-info@4.4.0: {} + class-transformer@0.5.1: {} + class-variance-authority@0.7.1: dependencies: clsx: 2.1.1 @@ -11570,6 +13061,12 @@ snapshots: dependencies: restore-cursor: 5.1.0 + cli-table3@0.6.5: + dependencies: + string-width: 8.2.0 + optionalDependencies: + '@colors/colors': 1.5.0 + cli-truncate@5.2.0: dependencies: slice-ansi: 8.0.0 @@ -11591,14 +13088,14 @@ snapshots: - '@types/react' - '@types/react-dom' - code-inspector-plugin@1.4.4: + code-inspector-plugin@1.4.5: dependencies: - '@code-inspector/core': 1.4.4 - '@code-inspector/esbuild': 1.4.4 - '@code-inspector/mako': 1.4.4 - '@code-inspector/turbopack': 1.4.4 - '@code-inspector/vite': 1.4.4 - '@code-inspector/webpack': 1.4.4 + '@code-inspector/core': 1.4.5 + '@code-inspector/esbuild': 1.4.5 + '@code-inspector/mako': 1.4.5 + '@code-inspector/turbopack': 1.4.5 + '@code-inspector/vite': 1.4.5 + '@code-inspector/webpack': 1.4.5 chalk: 4.1.1 transitivePeerDependencies: - supports-color @@ -11613,10 +13110,18 @@ snapshots: colorette@2.0.20: {} + combined-stream@1.0.8: + dependencies: + delayed-stream: 1.0.0 + comma-separated-tokens@1.0.8: {} comma-separated-tokens@2.0.3: {} + commander@14.0.0: {} + + commander@14.0.2: {} + commander@14.0.3: {} commander@2.20.3: {} @@ -11629,12 +13134,16 @@ snapshots: comment-parser@1.4.5: {} + comment-parser@1.4.6: {} + compare-versions@6.1.1: {} confbox@0.1.8: {} confbox@0.2.4: {} + consola@3.4.2: {} + convert-source-map@2.0.0: {} copy-to-clipboard@3.3.3: @@ -11671,8 +13180,6 @@ snapshots: css-gradient-parser@0.0.16: {} - css-mediaquery@0.1.2: {} - css-select@5.2.2: dependencies: boolbase: 1.0.0 @@ -11697,11 +13204,6 @@ snapshots: mdn-data: 2.0.30 source-map-js: 1.2.1 - css-tree@3.2.1: - dependencies: - mdn-data: 2.27.1 - source-map-js: 1.2.1 - css-what@6.2.2: {} css.escape@1.5.1: {} @@ -11758,7 +13260,7 @@ snapshots: d3-delaunay@6.0.4: dependencies: - delaunator: 5.0.1 + delaunator: 5.1.0 d3-dispatch@3.0.1: {} @@ -11900,18 +13402,13 @@ snapshots: d3: 7.9.0 lodash-es: 4.17.23 - data-urls@7.0.0: - dependencies: - whatwg-mimetype: 5.0.0 - whatwg-url: 16.0.1 - transitivePeerDependencies: - - '@noble/hashes' - dayjs@1.11.20: {} - debug@4.4.3: + debug@4.4.3(supports-color@8.1.1): dependencies: ms: 2.1.3 + optionalDependencies: + supports-color: 8.1.1 decimal.js@10.6.0: {} @@ -11942,9 +13439,11 @@ snapshots: defu@6.1.4: {} - delaunator@5.0.1: + delaunator@5.1.0: dependencies: - robust-predicates: 3.0.2 + robust-predicates: 3.0.3 + + delayed-stream@1.0.0: {} dequal@2.0.3: {} @@ -11962,6 +13461,8 @@ snapshots: diff-sequences@29.6.3: {} + diff@4.0.4: {} + dlv@1.1.3: {} doctrine@3.0.0: @@ -12000,6 +13501,12 @@ snapshots: dotenv@16.6.1: {} + dunder-proto@1.0.1: + dependencies: + call-bind-apply-helpers: 1.0.2 + es-errors: 1.3.0 + gopd: 1.2.0 + echarts-for-react@3.0.6(echarts@6.0.0)(react@19.2.4): dependencies: echarts: 6.0.0 @@ -12012,7 +13519,7 @@ snapshots: tslib: 2.3.0 zrender: 6.0.0 - electron-to-chromium@1.5.313: {} + electron-to-chromium@1.5.328: {} elkjs@0.11.1: {} @@ -12050,7 +13557,7 @@ snapshots: enhanced-resolve@5.20.1: dependencies: graceful-fs: 4.2.11 - tapable: 2.3.0 + tapable: 2.3.2 entities@4.5.0: {} @@ -12062,10 +13569,29 @@ snapshots: error-stack-parser-es@1.0.5: {} + error-stack-parser@2.1.4: + dependencies: + stackframe: 1.3.4 + + es-define-property@1.0.1: {} + + es-errors@1.3.0: {} + es-module-lexer@1.7.0: {} es-module-lexer@2.0.0: {} + es-object-atoms@1.1.1: + dependencies: + es-errors: 1.3.0 + + es-set-tostringtag@2.1.0: + dependencies: + es-errors: 1.3.0 + get-intrinsic: 1.3.0 + has-tostringtag: 1.0.2 + hasown: '@nolyfill/hasown@1.0.44' + es-toolkit@1.45.1: {} esast-util-from-estree@2.0.0: @@ -12121,92 +13647,102 @@ snapshots: escape-string-regexp@5.0.0: {} - eslint-compat-utils@0.5.1(eslint@10.0.3(jiti@1.21.7)): + eslint-compat-utils@0.5.1(eslint@10.1.0(jiti@1.21.7)): dependencies: - eslint: 10.0.3(jiti@1.21.7) + eslint: 10.1.0(jiti@1.21.7) semver: 7.7.4 - eslint-config-flat-gitignore@2.2.1(eslint@10.0.3(jiti@1.21.7)): + eslint-config-flat-gitignore@2.3.0(eslint@10.1.0(jiti@1.21.7)): dependencies: - '@eslint/compat': 2.0.3(eslint@10.0.3(jiti@1.21.7)) - eslint: 10.0.3(jiti@1.21.7) + '@eslint/compat': 2.0.3(eslint@10.1.0(jiti@1.21.7)) + eslint: 10.1.0(jiti@1.21.7) eslint-flat-config-utils@3.0.2: dependencies: '@eslint/config-helpers': 0.5.3 pathe: 2.0.3 - eslint-json-compat-utils@0.2.3(eslint@10.0.3(jiti@1.21.7))(jsonc-eslint-parser@3.1.0): + eslint-json-compat-utils@0.2.3(eslint@10.1.0(jiti@1.21.7))(jsonc-eslint-parser@3.1.0): dependencies: - eslint: 10.0.3(jiti@1.21.7) + eslint: 10.1.0(jiti@1.21.7) esquery: 1.7.0 jsonc-eslint-parser: 3.1.0 - eslint-merge-processors@2.0.0(eslint@10.0.3(jiti@1.21.7)): + eslint-markdown@0.6.0(eslint@10.1.0(jiti@1.21.7)): dependencies: - eslint: 10.0.3(jiti@1.21.7) + '@eslint/markdown': 7.5.1 + micromark-util-normalize-identifier: 2.0.1 + parse5: 8.0.0 + optionalDependencies: + eslint: 10.1.0(jiti@1.21.7) + transitivePeerDependencies: + - supports-color - eslint-plugin-antfu@3.2.2(eslint@10.0.3(jiti@1.21.7)): + eslint-merge-processors@2.0.0(eslint@10.1.0(jiti@1.21.7)): dependencies: - eslint: 10.0.3(jiti@1.21.7) + eslint: 10.1.0(jiti@1.21.7) - eslint-plugin-better-tailwindcss@4.3.2(eslint@10.0.3(jiti@1.21.7))(oxlint@1.55.0(oxlint-tsgolint@0.17.0))(tailwindcss@3.4.19(tsx@4.21.0)(yaml@2.8.2))(typescript@5.9.3): + eslint-plugin-antfu@3.2.2(eslint@10.1.0(jiti@1.21.7)): + dependencies: + eslint: 10.1.0(jiti@1.21.7) + + eslint-plugin-better-tailwindcss@4.3.2(eslint@10.1.0(jiti@1.21.7))(oxlint@1.57.0(oxlint-tsgolint@0.17.3))(tailwindcss@3.4.19(tsx@4.21.0)(yaml@2.8.3))(typescript@5.9.3): dependencies: '@eslint/css-tree': 3.6.9 - '@valibot/to-json-schema': 1.6.0(valibot@1.3.0(typescript@5.9.3)) + '@valibot/to-json-schema': 1.6.0(valibot@1.3.1(typescript@5.9.3)) enhanced-resolve: 5.20.1 jiti: 2.6.1 synckit: 0.11.12 tailwind-csstree: 0.1.5 - tailwindcss: 3.4.19(tsx@4.21.0)(yaml@2.8.2) + tailwindcss: 3.4.19(tsx@4.21.0)(yaml@2.8.3) tsconfig-paths-webpack-plugin: 4.2.0 - valibot: 1.3.0(typescript@5.9.3) + valibot: 1.3.1(typescript@5.9.3) optionalDependencies: - eslint: 10.0.3(jiti@1.21.7) - oxlint: 1.55.0(oxlint-tsgolint@0.17.0) + eslint: 10.1.0(jiti@1.21.7) + oxlint: 1.57.0(oxlint-tsgolint@0.17.3) transitivePeerDependencies: - '@eslint/css' - typescript - eslint-plugin-command@3.5.2(@typescript-eslint/rule-tester@8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3))(@typescript-eslint/typescript-estree@8.57.1(typescript@5.9.3))(@typescript-eslint/utils@8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3))(eslint@10.0.3(jiti@1.21.7)): + eslint-plugin-command@3.5.2(@typescript-eslint/rule-tester@8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(@typescript-eslint/typescript-estree@8.57.2(typescript@5.9.3))(@typescript-eslint/utils@8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(eslint@10.1.0(jiti@1.21.7)): dependencies: '@es-joy/jsdoccomment': 0.84.0 - '@typescript-eslint/rule-tester': 8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@typescript-eslint/typescript-estree': 8.57.1(typescript@5.9.3) - '@typescript-eslint/utils': 8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - eslint: 10.0.3(jiti@1.21.7) + '@typescript-eslint/rule-tester': 8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/typescript-estree': 8.57.2(typescript@5.9.3) + '@typescript-eslint/utils': 8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + eslint: 10.1.0(jiti@1.21.7) - eslint-plugin-depend@1.5.0(eslint@10.0.3(jiti@1.21.7)): + eslint-plugin-depend@1.5.0(eslint@10.1.0(jiti@1.21.7)): dependencies: empathic: 2.0.0 - eslint: 10.0.3(jiti@1.21.7) + eslint: 10.1.0(jiti@1.21.7) module-replacements: 2.11.0 semver: 7.7.4 - eslint-plugin-es-x@7.8.0(eslint@10.0.3(jiti@1.21.7)): + eslint-plugin-es-x@7.8.0(eslint@10.1.0(jiti@1.21.7)): dependencies: - '@eslint-community/eslint-utils': 4.9.1(eslint@10.0.3(jiti@1.21.7)) + '@eslint-community/eslint-utils': 4.9.1(eslint@10.1.0(jiti@1.21.7)) '@eslint-community/regexpp': 4.12.2 - eslint: 10.0.3(jiti@1.21.7) - eslint-compat-utils: 0.5.1(eslint@10.0.3(jiti@1.21.7)) + eslint: 10.1.0(jiti@1.21.7) + eslint-compat-utils: 0.5.1(eslint@10.1.0(jiti@1.21.7)) - eslint-plugin-hyoban@0.14.1(eslint@10.0.3(jiti@1.21.7)): + eslint-plugin-hyoban@0.14.1(eslint@10.1.0(jiti@1.21.7)): dependencies: - eslint: 10.0.3(jiti@1.21.7) + eslint: 10.1.0(jiti@1.21.7) - eslint-plugin-import-lite@0.5.2(eslint@10.0.3(jiti@1.21.7)): + eslint-plugin-import-lite@0.5.2(eslint@10.1.0(jiti@1.21.7)): dependencies: - eslint: 10.0.3(jiti@1.21.7) + eslint: 10.1.0(jiti@1.21.7) - eslint-plugin-jsdoc@62.8.0(eslint@10.0.3(jiti@1.21.7)): + eslint-plugin-jsdoc@62.8.1(eslint@10.1.0(jiti@1.21.7)): dependencies: '@es-joy/jsdoccomment': 0.84.0 '@es-joy/resolve.exports': 1.2.0 are-docs-informative: 0.0.2 comment-parser: 1.4.5 - debug: 4.4.3 + debug: 4.4.3(supports-color@8.1.1) escape-string-regexp: 4.0.0 - eslint: 10.0.3(jiti@1.21.7) + eslint: 10.1.0(jiti@1.21.7) espree: 11.2.0 esquery: 1.7.0 html-entities: 2.6.0 @@ -12218,28 +13754,48 @@ snapshots: transitivePeerDependencies: - supports-color - eslint-plugin-jsonc@3.1.2(eslint@10.0.3(jiti@1.21.7)): + eslint-plugin-jsonc@3.1.2(eslint@10.1.0(jiti@1.21.7)): dependencies: - '@eslint-community/eslint-utils': 4.9.1(eslint@10.0.3(jiti@1.21.7)) + '@eslint-community/eslint-utils': 4.9.1(eslint@10.1.0(jiti@1.21.7)) '@eslint/core': 1.1.1 '@eslint/plugin-kit': 0.6.1 '@ota-meshi/ast-token-store': 0.3.0 diff-sequences: 29.6.3 - eslint: 10.0.3(jiti@1.21.7) - eslint-json-compat-utils: 0.2.3(eslint@10.0.3(jiti@1.21.7))(jsonc-eslint-parser@3.1.0) + eslint: 10.1.0(jiti@1.21.7) + eslint-json-compat-utils: 0.2.3(eslint@10.1.0(jiti@1.21.7))(jsonc-eslint-parser@3.1.0) jsonc-eslint-parser: 3.1.0 natural-compare: 1.4.0 synckit: 0.11.12 transitivePeerDependencies: - '@eslint/json' - eslint-plugin-n@17.24.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3): + eslint-plugin-markdown-preferences@0.40.3(@eslint/markdown@7.5.1)(eslint@10.1.0(jiti@1.21.7)): dependencies: - '@eslint-community/eslint-utils': 4.9.1(eslint@10.0.3(jiti@1.21.7)) + '@eslint/markdown': 7.5.1 + diff-sequences: 29.6.3 + emoji-regex-xs: 2.0.1 + eslint: 10.1.0(jiti@1.21.7) + mdast-util-from-markdown: 2.0.3 + mdast-util-frontmatter: 2.0.1 + mdast-util-gfm: 3.1.0 + mdast-util-math: 3.0.0 + micromark-extension-frontmatter: 2.0.0 + micromark-extension-gfm: 3.0.0 + micromark-extension-math: 3.1.0 + micromark-factory-space: 2.0.1 + micromark-util-character: 2.1.1 + micromark-util-symbol: 2.0.1 + string-width: 8.2.0 + transitivePeerDependencies: + - supports-color + + eslint-plugin-n@17.24.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3): + dependencies: + '@eslint-community/eslint-utils': 4.9.1(eslint@10.1.0(jiti@1.21.7)) enhanced-resolve: 5.20.1 - eslint: 10.0.3(jiti@1.21.7) - eslint-plugin-es-x: 7.8.0(eslint@10.0.3(jiti@1.21.7)) - get-tsconfig: 4.13.6 + eslint: 10.1.0(jiti@1.21.7) + eslint-plugin-es-x: 7.8.0(eslint@10.1.0(jiti@1.21.7)) + get-tsconfig: 4.13.7 globals: 15.15.0 globrex: 0.1.2 ignore: 5.3.2 @@ -12248,163 +13804,152 @@ snapshots: transitivePeerDependencies: - typescript + eslint-plugin-no-barrel-files@1.2.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3): + dependencies: + '@typescript-eslint/utils': 8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + transitivePeerDependencies: + - eslint + - supports-color + - typescript + eslint-plugin-no-only-tests@3.3.0: {} - eslint-plugin-perfectionist@5.7.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3): + eslint-plugin-perfectionist@5.7.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3): dependencies: - '@typescript-eslint/utils': 8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - eslint: 10.0.3(jiti@1.21.7) + '@typescript-eslint/utils': 8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + eslint: 10.1.0(jiti@1.21.7) natural-orderby: 5.0.0 transitivePeerDependencies: - supports-color - typescript - eslint-plugin-pnpm@1.6.0(eslint@10.0.3(jiti@1.21.7)): + eslint-plugin-pnpm@1.6.0(eslint@10.1.0(jiti@1.21.7)): dependencies: empathic: 2.0.0 - eslint: 10.0.3(jiti@1.21.7) + eslint: 10.1.0(jiti@1.21.7) jsonc-eslint-parser: 3.1.0 pathe: 2.0.3 pnpm-workspace-yaml: 1.6.0 tinyglobby: 0.2.15 - yaml: 2.8.2 + yaml: 2.8.3 yaml-eslint-parser: 2.0.0 - eslint-plugin-react-dom@2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3): + eslint-plugin-react-dom@3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3): dependencies: - '@eslint-react/ast': 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/core': 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/eff': 2.13.0 - '@eslint-react/shared': 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/var': 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@typescript-eslint/scope-manager': 8.57.1 - '@typescript-eslint/types': 8.57.1 - '@typescript-eslint/utils': 8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/ast': 3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/core': 3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/shared': 3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/var': 3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/scope-manager': 8.57.2 + '@typescript-eslint/types': 8.57.2 + '@typescript-eslint/utils': 8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) compare-versions: 6.1.1 - eslint: 10.0.3(jiti@1.21.7) + eslint: 10.1.0(jiti@1.21.7) ts-pattern: 5.9.0 typescript: 5.9.3 transitivePeerDependencies: - supports-color - eslint-plugin-react-hooks-extra@2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3): - dependencies: - '@eslint-react/ast': 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/core': 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/eff': 2.13.0 - '@eslint-react/shared': 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/var': 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@typescript-eslint/scope-manager': 8.57.1 - '@typescript-eslint/type-utils': 8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@typescript-eslint/types': 8.57.1 - '@typescript-eslint/utils': 8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - eslint: 10.0.3(jiti@1.21.7) - ts-pattern: 5.9.0 - typescript: 5.9.3 - transitivePeerDependencies: - - supports-color - - eslint-plugin-react-hooks@7.0.1(eslint@10.0.3(jiti@1.21.7)): + eslint-plugin-react-hooks@7.0.1(eslint@10.1.0(jiti@1.21.7)): dependencies: '@babel/core': 7.29.0 '@babel/parser': 7.29.2 - eslint: 10.0.3(jiti@1.21.7) + eslint: 10.1.0(jiti@1.21.7) hermes-parser: 0.25.1 zod: 4.3.6 zod-validation-error: 4.0.2(zod@4.3.6) transitivePeerDependencies: - supports-color - eslint-plugin-react-naming-convention@2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3): + eslint-plugin-react-naming-convention@3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3): dependencies: - '@eslint-react/ast': 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/core': 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/eff': 2.13.0 - '@eslint-react/shared': 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/var': 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@typescript-eslint/scope-manager': 8.57.1 - '@typescript-eslint/type-utils': 8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@typescript-eslint/types': 8.57.1 - '@typescript-eslint/utils': 8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/ast': 3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/core': 3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/shared': 3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/var': 3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/scope-manager': 8.57.2 + '@typescript-eslint/type-utils': 8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/types': 8.57.2 + '@typescript-eslint/utils': 8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) compare-versions: 6.1.1 - eslint: 10.0.3(jiti@1.21.7) + eslint: 10.1.0(jiti@1.21.7) string-ts: 2.3.1 ts-pattern: 5.9.0 typescript: 5.9.3 transitivePeerDependencies: - supports-color - eslint-plugin-react-refresh@0.5.2(eslint@10.0.3(jiti@1.21.7)): + eslint-plugin-react-refresh@0.5.2(eslint@10.1.0(jiti@1.21.7)): dependencies: - eslint: 10.0.3(jiti@1.21.7) + eslint: 10.1.0(jiti@1.21.7) - eslint-plugin-react-rsc@2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3): + eslint-plugin-react-rsc@3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3): dependencies: - '@eslint-react/ast': 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/shared': 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/var': 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@typescript-eslint/types': 8.57.1 - '@typescript-eslint/utils': 8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - eslint: 10.0.3(jiti@1.21.7) + '@eslint-react/ast': 3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/shared': 3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/var': 3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/scope-manager': 8.57.2 + '@typescript-eslint/type-utils': 8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/types': 8.57.2 + '@typescript-eslint/utils': 8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + eslint: 10.1.0(jiti@1.21.7) ts-pattern: 5.9.0 typescript: 5.9.3 transitivePeerDependencies: - supports-color - eslint-plugin-react-web-api@2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3): + eslint-plugin-react-web-api@3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3): dependencies: - '@eslint-react/ast': 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/core': 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/eff': 2.13.0 - '@eslint-react/shared': 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/var': 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@typescript-eslint/scope-manager': 8.57.1 - '@typescript-eslint/types': 8.57.1 - '@typescript-eslint/utils': 8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/ast': 3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/core': 3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/shared': 3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/var': 3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/scope-manager': 8.57.2 + '@typescript-eslint/types': 8.57.2 + '@typescript-eslint/utils': 8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) birecord: 0.1.1 - eslint: 10.0.3(jiti@1.21.7) + eslint: 10.1.0(jiti@1.21.7) ts-pattern: 5.9.0 typescript: 5.9.3 transitivePeerDependencies: - supports-color - eslint-plugin-react-x@2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3): + eslint-plugin-react-x@3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3): dependencies: - '@eslint-react/ast': 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/core': 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/eff': 2.13.0 - '@eslint-react/shared': 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@eslint-react/var': 2.13.0(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@typescript-eslint/scope-manager': 8.57.1 - '@typescript-eslint/type-utils': 8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - '@typescript-eslint/types': 8.57.1 - '@typescript-eslint/utils': 8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/ast': 3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/core': 3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/shared': 3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@eslint-react/var': 3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/scope-manager': 8.57.2 + '@typescript-eslint/type-utils': 8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/types': 8.57.2 + '@typescript-eslint/utils': 8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) compare-versions: 6.1.1 - eslint: 10.0.3(jiti@1.21.7) - is-immutable-type: 5.0.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - ts-api-utils: 2.4.0(typescript@5.9.3) + eslint: 10.1.0(jiti@1.21.7) + string-ts: 2.3.1 + ts-api-utils: 2.5.0(typescript@5.9.3) ts-pattern: 5.9.0 typescript: 5.9.3 transitivePeerDependencies: - supports-color - eslint-plugin-regexp@3.1.0(eslint@10.0.3(jiti@1.21.7)): + eslint-plugin-regexp@3.1.0(eslint@10.1.0(jiti@1.21.7)): dependencies: - '@eslint-community/eslint-utils': 4.9.1(eslint@10.0.3(jiti@1.21.7)) + '@eslint-community/eslint-utils': 4.9.1(eslint@10.1.0(jiti@1.21.7)) '@eslint-community/regexpp': 4.12.2 - comment-parser: 1.4.5 - eslint: 10.0.3(jiti@1.21.7) + comment-parser: 1.4.6 + eslint: 10.1.0(jiti@1.21.7) jsdoc-type-pratt-parser: 7.1.1 refa: 0.12.1 regexp-ast-analysis: 0.7.1 scslre: 0.3.0 - eslint-plugin-sonarjs@4.0.2(eslint@10.0.3(jiti@1.21.7)): + eslint-plugin-sonarjs@4.0.2(eslint@10.1.0(jiti@1.21.7)): dependencies: '@eslint-community/regexpp': 4.12.2 builtin-modules: 3.3.0 bytes: 3.1.2 - eslint: 10.0.3(jiti@1.21.7) + eslint: 10.1.0(jiti@1.21.7) functional-red-black-tree: 1.0.1 globals: 17.4.0 jsx-ast-utils-x: 0.1.0 @@ -12412,38 +13957,38 @@ snapshots: minimatch: 10.2.4 scslre: 0.3.0 semver: 7.7.4 - ts-api-utils: 2.4.0(typescript@5.9.3) + ts-api-utils: 2.5.0(typescript@5.9.3) typescript: 5.9.3 - eslint-plugin-storybook@10.3.0(eslint@10.0.3(jiti@1.21.7))(storybook@10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3): + eslint-plugin-storybook@10.3.3(eslint@10.1.0(jiti@1.21.7))(storybook@10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3): dependencies: - '@typescript-eslint/utils': 8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - eslint: 10.0.3(jiti@1.21.7) - storybook: 10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) + '@typescript-eslint/utils': 8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + eslint: 10.1.0(jiti@1.21.7) + storybook: 10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) transitivePeerDependencies: - supports-color - typescript - eslint-plugin-toml@1.3.1(eslint@10.0.3(jiti@1.21.7)): + eslint-plugin-toml@1.3.1(eslint@10.1.0(jiti@1.21.7)): dependencies: '@eslint/core': 1.1.1 '@eslint/plugin-kit': 0.6.1 '@ota-meshi/ast-token-store': 0.3.0 - debug: 4.4.3 - eslint: 10.0.3(jiti@1.21.7) + debug: 4.4.3(supports-color@8.1.1) + eslint: 10.1.0(jiti@1.21.7) toml-eslint-parser: 1.0.3 transitivePeerDependencies: - supports-color - eslint-plugin-unicorn@63.0.0(eslint@10.0.3(jiti@1.21.7)): + eslint-plugin-unicorn@63.0.0(eslint@10.1.0(jiti@1.21.7)): dependencies: '@babel/helper-validator-identifier': 7.28.5 - '@eslint-community/eslint-utils': 4.9.1(eslint@10.0.3(jiti@1.21.7)) + '@eslint-community/eslint-utils': 4.9.1(eslint@10.1.0(jiti@1.21.7)) change-case: 5.4.4 ci-info: 4.4.0 clean-regexp: 1.0.0 core-js-compat: 3.49.0 - eslint: 10.0.3(jiti@1.21.7) + eslint: 10.1.0(jiti@1.21.7) find-up-simple: 1.0.1 globals: 16.5.0 indent-string: 5.0.0 @@ -12455,44 +14000,44 @@ snapshots: semver: 7.7.4 strip-indent: 4.1.1 - eslint-plugin-unused-imports@4.4.1(@typescript-eslint/eslint-plugin@8.57.1(@typescript-eslint/parser@8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3))(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3))(eslint@10.0.3(jiti@1.21.7)): + eslint-plugin-unused-imports@4.4.1(@typescript-eslint/eslint-plugin@8.57.2(@typescript-eslint/parser@8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(eslint@10.1.0(jiti@1.21.7)): dependencies: - eslint: 10.0.3(jiti@1.21.7) + eslint: 10.1.0(jiti@1.21.7) optionalDependencies: - '@typescript-eslint/eslint-plugin': 8.57.1(@typescript-eslint/parser@8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3))(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) + '@typescript-eslint/eslint-plugin': 8.57.2(@typescript-eslint/parser@8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) - eslint-plugin-vue@10.8.0(@stylistic/eslint-plugin@5.10.0(eslint@10.0.3(jiti@1.21.7)))(@typescript-eslint/parser@8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3))(eslint@10.0.3(jiti@1.21.7))(vue-eslint-parser@10.4.0(eslint@10.0.3(jiti@1.21.7))): + eslint-plugin-vue@10.8.0(@stylistic/eslint-plugin@5.10.0(eslint@10.1.0(jiti@1.21.7)))(@typescript-eslint/parser@8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(eslint@10.1.0(jiti@1.21.7))(vue-eslint-parser@10.4.0(eslint@10.1.0(jiti@1.21.7))): dependencies: - '@eslint-community/eslint-utils': 4.9.1(eslint@10.0.3(jiti@1.21.7)) - eslint: 10.0.3(jiti@1.21.7) + '@eslint-community/eslint-utils': 4.9.1(eslint@10.1.0(jiti@1.21.7)) + eslint: 10.1.0(jiti@1.21.7) natural-compare: 1.4.0 nth-check: 2.1.1 postcss-selector-parser: 7.1.1 semver: 7.7.4 - vue-eslint-parser: 10.4.0(eslint@10.0.3(jiti@1.21.7)) + vue-eslint-parser: 10.4.0(eslint@10.1.0(jiti@1.21.7)) xml-name-validator: 4.0.0 optionalDependencies: - '@stylistic/eslint-plugin': 5.10.0(eslint@10.0.3(jiti@1.21.7)) - '@typescript-eslint/parser': 8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) + '@stylistic/eslint-plugin': 5.10.0(eslint@10.1.0(jiti@1.21.7)) + '@typescript-eslint/parser': 8.57.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) - eslint-plugin-yml@3.3.1(eslint@10.0.3(jiti@1.21.7)): + eslint-plugin-yml@3.3.1(eslint@10.1.0(jiti@1.21.7)): dependencies: '@eslint/core': 1.1.1 '@eslint/plugin-kit': 0.6.1 '@ota-meshi/ast-token-store': 0.3.0 - debug: 4.4.3 + debug: 4.4.3(supports-color@8.1.1) diff-sequences: 29.6.3 escape-string-regexp: 5.0.0 - eslint: 10.0.3(jiti@1.21.7) + eslint: 10.1.0(jiti@1.21.7) natural-compare: 1.4.0 yaml-eslint-parser: 2.0.0 transitivePeerDependencies: - supports-color - eslint-processor-vue-blocks@2.0.0(@vue/compiler-sfc@3.5.30)(eslint@10.0.3(jiti@1.21.7)): + eslint-processor-vue-blocks@2.0.0(@vue/compiler-sfc@3.5.31)(eslint@10.1.0(jiti@1.21.7)): dependencies: - '@vue/compiler-sfc': 3.5.30 - eslint: 10.0.3(jiti@1.21.7) + '@vue/compiler-sfc': 3.5.31 + eslint: 10.1.0(jiti@1.21.7) eslint-scope@5.1.1: dependencies: @@ -12517,9 +14062,9 @@ snapshots: eslint-visitor-keys@5.0.1: {} - eslint@10.0.3(jiti@1.21.7): + eslint@10.1.0(jiti@1.21.7): dependencies: - '@eslint-community/eslint-utils': 4.9.1(eslint@10.0.3(jiti@1.21.7)) + '@eslint-community/eslint-utils': 4.9.1(eslint@10.1.0(jiti@1.21.7)) '@eslint-community/regexpp': 4.12.2 '@eslint/config-array': 0.23.3 '@eslint/config-helpers': 0.5.3 @@ -12531,7 +14076,7 @@ snapshots: '@types/estree': 1.0.8 ajv: 6.14.0 cross-spawn: 7.0.6 - debug: 4.4.3 + debug: 4.4.3(supports-color@8.1.1) escape-string-regexp: 4.0.0 eslint-scope: 9.1.2 eslint-visitor-keys: 5.0.1 @@ -12554,6 +14099,43 @@ snapshots: transitivePeerDependencies: - supports-color + eslint@10.1.0(jiti@2.6.1): + dependencies: + '@eslint-community/eslint-utils': 4.9.1(eslint@10.1.0(jiti@2.6.1)) + '@eslint-community/regexpp': 4.12.2 + '@eslint/config-array': 0.23.3 + '@eslint/config-helpers': 0.5.3 + '@eslint/core': 1.1.1 + '@eslint/plugin-kit': 0.6.1 + '@humanfs/node': 0.16.7 + '@humanwhocodes/module-importer': 1.0.1 + '@humanwhocodes/retry': 0.4.3 + '@types/estree': 1.0.8 + ajv: 6.14.0 + cross-spawn: 7.0.6 + debug: 4.4.3(supports-color@8.1.1) + escape-string-regexp: 4.0.0 + eslint-scope: 9.1.2 + eslint-visitor-keys: 5.0.1 + espree: 11.2.0 + esquery: 1.7.0 + esutils: 2.0.3 + fast-deep-equal: 3.1.3 + file-entry-cache: 8.0.0 + find-up: 5.0.0 + glob-parent: 6.0.2 + ignore: 5.3.2 + imurmurhash: 0.1.4 + is-glob: 4.0.3 + json-stable-stringify-without-jsonify: 1.0.1 + minimatch: 10.2.4 + natural-compare: 1.4.0 + optionator: 0.9.4 + optionalDependencies: + jiti: 2.6.1 + transitivePeerDependencies: + - supports-color + eslint@9.27.0(jiti@1.21.7): dependencies: '@eslint-community/eslint-utils': 4.9.1(eslint@9.27.0(jiti@1.21.7)) @@ -12572,7 +14154,7 @@ snapshots: ajv: 6.14.0 chalk: 4.1.2 cross-spawn: 7.0.6 - debug: 4.4.3 + debug: 4.4.3(supports-color@8.1.1) escape-string-regexp: 4.0.0 eslint-scope: 8.4.0 eslint-visitor-keys: 4.2.1 @@ -12674,7 +14256,7 @@ snapshots: extract-zip@2.0.1: dependencies: - debug: 4.4.3 + debug: 4.4.3(supports-color@8.1.1) get-stream: 5.2.0 yauzl: 3.2.1 optionalDependencies: @@ -12682,8 +14264,6 @@ snapshots: transitivePeerDependencies: - supports-color - fast-content-type-parse@3.0.0: {} - fast-deep-equal@3.1.3: {} fast-glob@3.3.1: @@ -12724,14 +14304,18 @@ snapshots: dependencies: walk-up-path: 4.0.0 - fdir@6.5.0(picomatch@4.0.3): + fdir@6.5.0(picomatch@4.0.4): optionalDependencies: - picomatch: 4.0.3 + picomatch: 4.0.4 fflate@0.4.8: {} fflate@0.7.4: {} + figures@3.2.0: + dependencies: + escape-string-regexp: 1.0.5 + file-entry-cache@8.0.0: dependencies: flat-cache: 4.0.1 @@ -12749,6 +14333,12 @@ snapshots: locate-path: 6.0.0 path-exists: 4.0.0 + fix-dts-default-cjs-exports@1.0.1: + dependencies: + magic-string: 0.30.21 + mlly: 1.8.2 + rollup: 4.59.0 + flat-cache@4.0.1: dependencies: flatted: 3.4.2 @@ -12756,6 +14346,16 @@ snapshots: flatted@3.4.2: {} + follow-redirects@1.15.11: {} + + form-data@4.0.5: + dependencies: + asynckit: 0.4.0 + combined-stream: 1.0.8 + es-set-tostringtag: 2.1.0 + hasown: '@nolyfill/hasown@1.0.44' + mime-types: 2.1.35 + format@0.2.2: {} formatly@0.3.0: @@ -12776,9 +14376,14 @@ snapshots: fs-constants@1.0.0: optional: true + fsevents@2.3.2: + optional: true + fsevents@2.3.3: optional: true + function-bind@1.1.2: {} + functional-red-black-tree@1.0.1: {} fzf@0.5.2: {} @@ -12787,13 +14392,31 @@ snapshots: get-east-asian-width@1.5.0: {} + get-intrinsic@1.3.0: + dependencies: + call-bind-apply-helpers: 1.0.2 + es-define-property: 1.0.1 + es-errors: 1.3.0 + es-object-atoms: 1.1.1 + function-bind: 1.1.2 + get-proto: 1.0.1 + gopd: 1.2.0 + has-symbols: 1.1.0 + hasown: '@nolyfill/hasown@1.0.44' + math-intrinsics: 1.1.0 + get-nonce@1.0.1: {} + get-proto@1.0.1: + dependencies: + dunder-proto: 1.0.1 + es-object-atoms: 1.1.1 + get-stream@5.2.0: dependencies: pump: 3.0.4 - get-tsconfig@4.13.6: + get-tsconfig@4.13.7: dependencies: resolve-pkg-maps: 1.0.0 @@ -12818,6 +14441,10 @@ snapshots: minipass: 7.1.3 path-scurry: 2.0.2 + global-dirs@3.0.1: + dependencies: + ini: 2.0.0 + globals@14.0.0: {} globals@15.15.0: {} @@ -12832,12 +14459,36 @@ snapshots: dependencies: csstype: 3.2.3 + gopd@1.2.0: {} + graceful-fs@4.2.11: {} hachure-fill@0.5.2: {} + happy-dom@20.8.9: + dependencies: + '@types/node': 25.5.0 + '@types/whatwg-mimetype': 3.0.2 + '@types/ws': 8.18.1 + entities: 7.0.1 + whatwg-mimetype: 3.0.0 + ws: 8.20.0 + transitivePeerDependencies: + - bufferutil + - utf-8-validate + + has-ansi@4.0.1: + dependencies: + ansi-regex: 4.1.1 + has-flag@4.0.0: {} + has-symbols@1.1.0: {} + + has-tostringtag@1.0.2: + dependencies: + has-symbols: 1.1.0 + hast-util-from-dom@5.0.1: dependencies: '@types/hast': 3.0.4 @@ -12993,13 +14644,11 @@ snapshots: highlightjs-vue@1.0.0: {} - hono@4.12.8: {} + hono@4.12.9: {} - html-encoding-sniffer@6.0.0: + hosted-git-info@9.0.2: dependencies: - '@exodus/bytes': 1.15.0 - transitivePeerDependencies: - - '@noble/hashes' + lru-cache: 11.2.7 html-entities@2.6.0: {} @@ -13028,7 +14677,7 @@ snapshots: dependencies: '@babel/runtime': 7.29.2 - i18next@25.8.18(typescript@5.9.3): + i18next@25.10.10(typescript@5.9.3): dependencies: '@babel/runtime': 7.29.2 optionalDependencies: @@ -13074,12 +14723,16 @@ snapshots: indent-string@5.0.0: {} + index-to-position@1.2.0: {} + inherits@2.0.4: optional: true ini@1.3.8: optional: true + ini@2.0.0: {} + inline-style-parser@0.2.7: {} internmap@1.0.1: {} @@ -13130,32 +14783,29 @@ snapshots: is-hexadecimal@2.0.1: {} - is-immutable-type@5.0.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3): - dependencies: - '@typescript-eslint/type-utils': 8.57.1(eslint@10.0.3(jiti@1.21.7))(typescript@5.9.3) - eslint: 10.0.3(jiti@1.21.7) - ts-api-utils: 2.4.0(typescript@5.9.3) - ts-declaration-location: 1.0.7(typescript@5.9.3) - typescript: 5.9.3 - transitivePeerDependencies: - - supports-color + is-in-ssh@1.0.0: {} is-inside-container@1.0.0: dependencies: is-docker: 3.0.0 - is-node-process@1.2.0: {} + is-installed-globally@0.4.0: + dependencies: + global-dirs: 3.0.1 + is-path-inside: 3.0.3 is-number@7.0.0: {} - is-plain-obj@4.1.0: {} + is-path-inside@3.0.3: {} - is-potential-custom-element-name@1.0.1: {} + is-plain-obj@4.1.0: {} is-reference@3.0.3: dependencies: '@types/estree': 1.0.8 + is-stream@2.0.1: {} + is-wsl@3.1.1: dependencies: is-inside-container: 1.0.0 @@ -13187,13 +14837,15 @@ snapshots: jiti@2.6.1: {} - jotai@2.18.1(@babel/core@7.29.0)(@babel/template@7.28.6)(@types/react@19.2.14)(react@19.2.4): + jotai@2.19.0(@babel/core@7.29.0)(@babel/template@7.28.6)(@types/react@19.2.14)(react@19.2.4): optionalDependencies: '@babel/core': 7.29.0 '@babel/template': 7.28.6 '@types/react': 19.2.14 react: 19.2.4 + joycon@3.1.1: {} + js-audio-recorder@1.0.7: {} js-base64@3.7.8: {} @@ -13212,39 +14864,6 @@ snapshots: jsdoc-type-pratt-parser@7.1.1: {} - jsdom-testing-mocks@1.16.0: - dependencies: - bezier-easing: 2.1.0 - css-mediaquery: 0.1.2 - - jsdom@29.0.0(canvas@3.2.1): - dependencies: - '@asamuzakjp/css-color': 5.0.1 - '@asamuzakjp/dom-selector': 7.0.3 - '@bramus/specificity': 2.4.2 - '@csstools/css-syntax-patches-for-csstree': 1.1.1(css-tree@3.2.1) - '@exodus/bytes': 1.15.0 - css-tree: 3.2.1 - data-urls: 7.0.0 - decimal.js: 10.6.0 - html-encoding-sniffer: 6.0.0 - is-potential-custom-element-name: 1.0.1 - lru-cache: 11.2.7 - parse5: 8.0.0 - saxes: 6.0.0 - symbol-tree: 3.2.4 - tough-cookie: 6.0.1 - undici: 7.24.4 - w3c-xmlserializer: 5.0.0 - webidl-conversions: 8.0.1 - whatwg-mimetype: 5.0.0 - whatwg-url: 16.0.1 - xml-name-validator: 5.0.0 - optionalDependencies: - canvas: 3.2.1 - transitivePeerDependencies: - - '@noble/hashes' - jsesc@3.1.0: {} json-buffer@3.0.1: {} @@ -13257,10 +14876,6 @@ snapshots: json-stable-stringify-without-jsonify@1.0.1: {} - json-stringify-safe@5.0.1: {} - - json-with-bigint@3.5.7: {} - json5@2.2.3: {} jsonc-eslint-parser@3.1.0: @@ -13279,7 +14894,7 @@ snapshots: jsx-ast-utils-x@0.1.0: {} - katex@0.16.38: + katex@0.16.44: dependencies: commander: 8.3.0 @@ -13289,23 +14904,30 @@ snapshots: khroma@2.1.0: {} - knip@5.88.0(@types/node@25.5.0)(typescript@5.9.3): + knip@6.1.0(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1): dependencies: '@nodelib/fs.walk': 1.2.8 - '@types/node': 25.5.0 fast-glob: 3.3.3 formatly: 0.3.0 + get-tsconfig: 4.13.7 jiti: 2.6.1 minimist: 1.2.8 - oxc-resolver: 11.19.1 + oxc-parser: 0.121.0(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1) + oxc-resolver: 11.19.1(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1) picocolors: 1.1.1 - picomatch: 4.0.3 - smol-toml: 1.6.0 + picomatch: 4.0.4 + smol-toml: 1.6.1 strip-json-comments: 5.0.3 - typescript: 5.9.3 unbash: 2.2.0 - yaml: 2.8.2 + yaml: 2.8.3 zod: 4.3.6 + transitivePeerDependencies: + - '@emnapi/core' + - '@emnapi/runtime' + + knuth-shuffle-seeded@1.0.6: + dependencies: + seed-random: 2.2.0 kolorist@1.8.0: {} @@ -13337,12 +14959,12 @@ snapshots: prelude-ls: 1.2.1 type-check: 0.4.0 - lexical-code-no-prism@0.41.0(@lexical/utils@0.41.0)(lexical@0.41.0): + lexical-code-no-prism@0.41.0(@lexical/utils@0.42.0)(lexical@0.42.0): dependencies: - '@lexical/utils': 0.41.0 - lexical: 0.41.0 + '@lexical/utils': 0.42.0 + lexical: 0.42.0 - lexical@0.41.0: {} + lexical@0.42.0: {} lib0@0.2.117: dependencies: @@ -13410,10 +15032,10 @@ snapshots: dependencies: commander: 14.0.3 listr2: 9.0.5 - picomatch: 4.0.3 + picomatch: 4.0.4 string-argv: 0.3.2 tinyexec: 1.0.4 - yaml: 2.8.2 + yaml: 2.8.3 listr2@9.0.5: dependencies: @@ -13424,11 +15046,13 @@ snapshots: rfdc: 1.4.1 wrap-ansi: 9.0.2 + load-tsconfig@0.2.5: {} + loader-runner@4.3.1: {} local-pkg@1.1.2: dependencies: - mlly: 1.8.1 + mlly: 1.8.2 pkg-types: 2.3.0 quansync: 0.2.11 @@ -13440,6 +15064,10 @@ snapshots: lodash.merge@4.6.2: {} + lodash.mergewith@4.6.2: {} + + lodash.sortby@4.7.0: {} + lodash@4.17.23: {} log-update@6.1.0: @@ -13458,6 +15086,10 @@ snapshots: loupe@3.2.1: {} + lower-case@2.0.2: + dependencies: + tslib: 2.8.1 + lowlight@1.20.0: dependencies: fault: 1.0.4 @@ -13495,7 +15127,9 @@ snapshots: marked@16.4.2: {} - marked@17.0.4: {} + marked@17.0.5: {} + + math-intrinsics@1.1.0: {} mdast-util-directive@3.1.0: dependencies: @@ -13708,8 +15342,6 @@ snapshots: mdn-data@2.23.0: {} - mdn-data@2.27.1: {} - memoize-one@5.2.1: {} merge-stream@2.0.0: {} @@ -13731,7 +15363,7 @@ snapshots: dagre-d3-es: 7.0.14 dayjs: 1.11.20 dompurify: 3.3.2 - katex: 0.16.38 + katex: 0.16.44 khroma: 2.1.0 lodash-es: 4.17.23 marked: 16.4.2 @@ -13838,7 +15470,7 @@ snapshots: dependencies: '@types/katex': 0.16.8 devlop: 1.1.0 - katex: 0.16.38 + katex: 0.16.44 micromark-factory-space: 2.0.1 micromark-util-character: 2.1.1 micromark-util-symbol: 2.0.1 @@ -14011,8 +15643,8 @@ snapshots: micromark@4.0.2: dependencies: - '@types/debug': 4.1.12 - debug: 4.4.3 + '@types/debug': 4.1.13 + debug: 4.4.3(supports-color@8.1.1) decode-named-character-reference: 1.3.0 devlop: 1.1.0 micromark-core-commonmark: 2.0.3 @@ -14034,7 +15666,7 @@ snapshots: micromatch@4.0.8: dependencies: braces: 3.0.3 - picomatch: 2.3.1 + picomatch: 2.3.2 mime-db@1.52.0: {} @@ -14042,6 +15674,8 @@ snapshots: dependencies: mime-db: 1.52.0 + mime@3.0.0: {} + mime@4.1.0: {} mimic-function@5.0.1: {} @@ -14053,7 +15687,7 @@ snapshots: minimatch@10.2.4: dependencies: - brace-expansion: 5.0.4 + brace-expansion: 5.0.5 minimatch@3.1.5: dependencies: @@ -14072,7 +15706,9 @@ snapshots: mkdirp-classic@0.5.3: optional: true - mlly@1.8.1: + mkdirp@3.0.1: {} + + mlly@1.8.2: dependencies: acorn: 8.16.0 pathe: 2.0.3 @@ -14120,36 +15756,36 @@ snapshots: react: 19.2.4 react-dom: 19.2.4(react@19.2.4) - next@16.2.0(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0): + next@16.2.1(@babel/core@7.29.0)(@playwright/test@1.58.2)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0): dependencies: - '@next/env': 16.2.0 + '@next/env': 16.2.1 '@swc/helpers': 0.5.15 - baseline-browser-mapping: 2.10.8 - caniuse-lite: 1.0.30001780 + baseline-browser-mapping: 2.10.12 + caniuse-lite: 1.0.30001781 postcss: 8.4.31 react: 19.2.4 react-dom: 19.2.4(react@19.2.4) styled-jsx: 5.1.6(@babel/core@7.29.0)(react@19.2.4) optionalDependencies: - '@next/swc-darwin-arm64': 16.2.0 - '@next/swc-darwin-x64': 16.2.0 - '@next/swc-linux-arm64-gnu': 16.2.0 - '@next/swc-linux-arm64-musl': 16.2.0 - '@next/swc-linux-x64-gnu': 16.2.0 - '@next/swc-linux-x64-musl': 16.2.0 - '@next/swc-win32-arm64-msvc': 16.2.0 - '@next/swc-win32-x64-msvc': 16.2.0 + '@next/swc-darwin-arm64': 16.2.1 + '@next/swc-darwin-x64': 16.2.1 + '@next/swc-linux-arm64-gnu': 16.2.1 + '@next/swc-linux-arm64-musl': 16.2.1 + '@next/swc-linux-x64-gnu': 16.2.1 + '@next/swc-linux-x64-musl': 16.2.1 + '@next/swc-win32-arm64-msvc': 16.2.1 + '@next/swc-win32-x64-msvc': 16.2.1 + '@playwright/test': 1.58.2 sass: 1.98.0 sharp: 0.34.5 transitivePeerDependencies: - '@babel/core' - babel-plugin-macros - nock@14.0.11: + no-case@3.0.4: dependencies: - '@mswjs/interceptors': 0.41.3 - json-stringify-safe: 5.0.1 - propagate: 2.0.1 + lower-case: 2.0.2 + tslib: 2.8.1 node-abi@3.89.0: dependencies: @@ -14163,6 +15799,12 @@ snapshots: node-releases@2.0.36: {} + normalize-package-data@8.0.0: + dependencies: + hosted-git-info: 9.0.2 + semver: 7.7.4 + validate-npm-package-license: 3.0.4 + normalize-path@3.0.0: {} normalize-wheel@1.0.1: {} @@ -14171,12 +15813,12 @@ snapshots: dependencies: boolbase: 1.0.0 - nuqs@2.8.9(next@16.2.0(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0))(react@19.2.4): + nuqs@2.8.9(next@16.2.1(@babel/core@7.29.0)(@playwright/test@1.58.2)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0))(react@19.2.4): dependencies: '@standard-schema/spec': 1.0.0 react: 19.2.4 optionalDependencies: - next: 16.2.0(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0) + next: 16.2.1(@babel/core@7.29.0)(@playwright/test@1.58.2)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0) object-assign@4.1.1: {} @@ -14209,6 +15851,15 @@ snapshots: is-inside-container: 1.0.0 wsl-utils: 0.1.0 + open@11.0.0: + dependencies: + default-browser: 5.5.0 + define-lazy-prop: 3.0.0 + is-in-ssh: 1.0.0 + is-inside-container: 1.0.0 + powershell-utils: 0.1.0 + wsl-utils: 0.3.1 + openapi-types@12.1.3: {} optionator@0.9.4: @@ -14220,9 +15871,35 @@ snapshots: type-check: 0.4.0 word-wrap: 1.2.5 - outvariant@1.4.3: {} + oxc-parser@0.121.0(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1): + dependencies: + '@oxc-project/types': 0.121.0 + optionalDependencies: + '@oxc-parser/binding-android-arm-eabi': 0.121.0 + '@oxc-parser/binding-android-arm64': 0.121.0 + '@oxc-parser/binding-darwin-arm64': 0.121.0 + '@oxc-parser/binding-darwin-x64': 0.121.0 + '@oxc-parser/binding-freebsd-x64': 0.121.0 + '@oxc-parser/binding-linux-arm-gnueabihf': 0.121.0 + '@oxc-parser/binding-linux-arm-musleabihf': 0.121.0 + '@oxc-parser/binding-linux-arm64-gnu': 0.121.0 + '@oxc-parser/binding-linux-arm64-musl': 0.121.0 + '@oxc-parser/binding-linux-ppc64-gnu': 0.121.0 + '@oxc-parser/binding-linux-riscv64-gnu': 0.121.0 + '@oxc-parser/binding-linux-riscv64-musl': 0.121.0 + '@oxc-parser/binding-linux-s390x-gnu': 0.121.0 + '@oxc-parser/binding-linux-x64-gnu': 0.121.0 + '@oxc-parser/binding-linux-x64-musl': 0.121.0 + '@oxc-parser/binding-openharmony-arm64': 0.121.0 + '@oxc-parser/binding-wasm32-wasi': 0.121.0(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1) + '@oxc-parser/binding-win32-arm64-msvc': 0.121.0 + '@oxc-parser/binding-win32-ia32-msvc': 0.121.0 + '@oxc-parser/binding-win32-x64-msvc': 0.121.0 + transitivePeerDependencies: + - '@emnapi/core' + - '@emnapi/runtime' - oxc-resolver@11.19.1: + oxc-resolver@11.19.1(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1): optionalDependencies: '@oxc-resolver/binding-android-arm-eabi': 11.19.1 '@oxc-resolver/binding-android-arm64': 11.19.1 @@ -14240,77 +15917,88 @@ snapshots: '@oxc-resolver/binding-linux-x64-gnu': 11.19.1 '@oxc-resolver/binding-linux-x64-musl': 11.19.1 '@oxc-resolver/binding-openharmony-arm64': 11.19.1 - '@oxc-resolver/binding-wasm32-wasi': 11.19.1 + '@oxc-resolver/binding-wasm32-wasi': 11.19.1(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1) '@oxc-resolver/binding-win32-arm64-msvc': 11.19.1 '@oxc-resolver/binding-win32-ia32-msvc': 11.19.1 '@oxc-resolver/binding-win32-x64-msvc': 11.19.1 + transitivePeerDependencies: + - '@emnapi/core' + - '@emnapi/runtime' - oxfmt@0.40.0: + oxfmt@0.42.0: dependencies: tinypool: 2.1.0 optionalDependencies: - '@oxfmt/binding-android-arm-eabi': 0.40.0 - '@oxfmt/binding-android-arm64': 0.40.0 - '@oxfmt/binding-darwin-arm64': 0.40.0 - '@oxfmt/binding-darwin-x64': 0.40.0 - '@oxfmt/binding-freebsd-x64': 0.40.0 - '@oxfmt/binding-linux-arm-gnueabihf': 0.40.0 - '@oxfmt/binding-linux-arm-musleabihf': 0.40.0 - '@oxfmt/binding-linux-arm64-gnu': 0.40.0 - '@oxfmt/binding-linux-arm64-musl': 0.40.0 - '@oxfmt/binding-linux-ppc64-gnu': 0.40.0 - '@oxfmt/binding-linux-riscv64-gnu': 0.40.0 - '@oxfmt/binding-linux-riscv64-musl': 0.40.0 - '@oxfmt/binding-linux-s390x-gnu': 0.40.0 - '@oxfmt/binding-linux-x64-gnu': 0.40.0 - '@oxfmt/binding-linux-x64-musl': 0.40.0 - '@oxfmt/binding-openharmony-arm64': 0.40.0 - '@oxfmt/binding-win32-arm64-msvc': 0.40.0 - '@oxfmt/binding-win32-ia32-msvc': 0.40.0 - '@oxfmt/binding-win32-x64-msvc': 0.40.0 + '@oxfmt/binding-android-arm-eabi': 0.42.0 + '@oxfmt/binding-android-arm64': 0.42.0 + '@oxfmt/binding-darwin-arm64': 0.42.0 + '@oxfmt/binding-darwin-x64': 0.42.0 + '@oxfmt/binding-freebsd-x64': 0.42.0 + '@oxfmt/binding-linux-arm-gnueabihf': 0.42.0 + '@oxfmt/binding-linux-arm-musleabihf': 0.42.0 + '@oxfmt/binding-linux-arm64-gnu': 0.42.0 + '@oxfmt/binding-linux-arm64-musl': 0.42.0 + '@oxfmt/binding-linux-ppc64-gnu': 0.42.0 + '@oxfmt/binding-linux-riscv64-gnu': 0.42.0 + '@oxfmt/binding-linux-riscv64-musl': 0.42.0 + '@oxfmt/binding-linux-s390x-gnu': 0.42.0 + '@oxfmt/binding-linux-x64-gnu': 0.42.0 + '@oxfmt/binding-linux-x64-musl': 0.42.0 + '@oxfmt/binding-openharmony-arm64': 0.42.0 + '@oxfmt/binding-win32-arm64-msvc': 0.42.0 + '@oxfmt/binding-win32-ia32-msvc': 0.42.0 + '@oxfmt/binding-win32-x64-msvc': 0.42.0 - oxlint-tsgolint@0.17.0: + oxlint-tsgolint@0.17.3: optionalDependencies: - '@oxlint-tsgolint/darwin-arm64': 0.17.0 - '@oxlint-tsgolint/darwin-x64': 0.17.0 - '@oxlint-tsgolint/linux-arm64': 0.17.0 - '@oxlint-tsgolint/linux-x64': 0.17.0 - '@oxlint-tsgolint/win32-arm64': 0.17.0 - '@oxlint-tsgolint/win32-x64': 0.17.0 + '@oxlint-tsgolint/darwin-arm64': 0.17.3 + '@oxlint-tsgolint/darwin-x64': 0.17.3 + '@oxlint-tsgolint/linux-arm64': 0.17.3 + '@oxlint-tsgolint/linux-x64': 0.17.3 + '@oxlint-tsgolint/win32-arm64': 0.17.3 + '@oxlint-tsgolint/win32-x64': 0.17.3 - oxlint@1.55.0(oxlint-tsgolint@0.17.0): + oxlint@1.57.0(oxlint-tsgolint@0.17.3): optionalDependencies: - '@oxlint/binding-android-arm-eabi': 1.55.0 - '@oxlint/binding-android-arm64': 1.55.0 - '@oxlint/binding-darwin-arm64': 1.55.0 - '@oxlint/binding-darwin-x64': 1.55.0 - '@oxlint/binding-freebsd-x64': 1.55.0 - '@oxlint/binding-linux-arm-gnueabihf': 1.55.0 - '@oxlint/binding-linux-arm-musleabihf': 1.55.0 - '@oxlint/binding-linux-arm64-gnu': 1.55.0 - '@oxlint/binding-linux-arm64-musl': 1.55.0 - '@oxlint/binding-linux-ppc64-gnu': 1.55.0 - '@oxlint/binding-linux-riscv64-gnu': 1.55.0 - '@oxlint/binding-linux-riscv64-musl': 1.55.0 - '@oxlint/binding-linux-s390x-gnu': 1.55.0 - '@oxlint/binding-linux-x64-gnu': 1.55.0 - '@oxlint/binding-linux-x64-musl': 1.55.0 - '@oxlint/binding-openharmony-arm64': 1.55.0 - '@oxlint/binding-win32-arm64-msvc': 1.55.0 - '@oxlint/binding-win32-ia32-msvc': 1.55.0 - '@oxlint/binding-win32-x64-msvc': 1.55.0 - oxlint-tsgolint: 0.17.0 + '@oxlint/binding-android-arm-eabi': 1.57.0 + '@oxlint/binding-android-arm64': 1.57.0 + '@oxlint/binding-darwin-arm64': 1.57.0 + '@oxlint/binding-darwin-x64': 1.57.0 + '@oxlint/binding-freebsd-x64': 1.57.0 + '@oxlint/binding-linux-arm-gnueabihf': 1.57.0 + '@oxlint/binding-linux-arm-musleabihf': 1.57.0 + '@oxlint/binding-linux-arm64-gnu': 1.57.0 + '@oxlint/binding-linux-arm64-musl': 1.57.0 + '@oxlint/binding-linux-ppc64-gnu': 1.57.0 + '@oxlint/binding-linux-riscv64-gnu': 1.57.0 + '@oxlint/binding-linux-riscv64-musl': 1.57.0 + '@oxlint/binding-linux-s390x-gnu': 1.57.0 + '@oxlint/binding-linux-x64-gnu': 1.57.0 + '@oxlint/binding-linux-x64-musl': 1.57.0 + '@oxlint/binding-openharmony-arm64': 1.57.0 + '@oxlint/binding-win32-arm64-msvc': 1.57.0 + '@oxlint/binding-win32-ia32-msvc': 1.57.0 + '@oxlint/binding-win32-x64-msvc': 1.57.0 + oxlint-tsgolint: 0.17.3 p-limit@3.1.0: dependencies: yocto-queue: 0.1.0 + p-limit@7.3.0: + dependencies: + yocto-queue: 1.2.2 + p-locate@5.0.0: dependencies: p-limit: 3.1.0 package-manager-detector@1.6.0: {} + pad-right@0.2.2: + dependencies: + repeat-string: 1.6.1 + pako@0.2.9: {} papaparse@5.5.3: {} @@ -14349,6 +16037,12 @@ snapshots: dependencies: parse-statements: 1.0.11 + parse-json@8.3.0: + dependencies: + '@babel/code-frame': 7.29.0 + index-to-position: 1.2.0 + type-fest: 4.41.0 + parse-statements@1.0.11: {} parse5-htmlparser2-tree-adapter@7.1.0: @@ -14392,7 +16086,7 @@ snapshots: pdfjs-dist@4.4.168: optionalDependencies: - canvas: 3.2.1 + canvas: 3.2.2 path2d: 0.2.2 pend@1.2.0: {} @@ -14407,9 +16101,9 @@ snapshots: picocolors@1.1.1: {} - picomatch@2.3.1: {} + picomatch@2.3.2: {} - picomatch@4.0.3: {} + picomatch@4.0.4: {} pify@2.3.0: {} @@ -14424,7 +16118,7 @@ snapshots: pkg-types@1.3.1: dependencies: confbox: 0.1.8 - mlly: 1.8.1 + mlly: 1.8.2 pathe: 2.0.3 pkg-types@2.3.0: @@ -14433,13 +16127,21 @@ snapshots: exsolve: 1.0.8 pathe: 2.0.3 + playwright-core@1.58.2: {} + + playwright@1.58.2: + dependencies: + playwright-core: 1.58.2 + optionalDependencies: + fsevents: 2.3.2 + pluralize@8.0.0: {} pngjs@7.0.0: {} pnpm-workspace-yaml@1.6.0: dependencies: - yaml: 2.8.2 + yaml: 2.8.3 points-on-curve@0.2.0: {} @@ -14451,7 +16153,7 @@ snapshots: portfinder@1.0.38: dependencies: async: 3.2.6 - debug: 4.4.3 + debug: 4.4.3(supports-color@8.1.1) transitivePeerDependencies: - supports-color @@ -14471,14 +16173,23 @@ snapshots: dependencies: postcss: 8.5.8 - postcss-load-config@6.0.1(jiti@1.21.7)(postcss@8.5.8)(tsx@4.21.0)(yaml@2.8.2): + postcss-load-config@6.0.1(jiti@1.21.7)(postcss@8.5.8)(tsx@4.21.0)(yaml@2.8.3): dependencies: lilconfig: 3.1.3 optionalDependencies: jiti: 1.21.7 postcss: 8.5.8 tsx: 4.21.0 - yaml: 2.8.2 + yaml: 2.8.3 + + postcss-load-config@6.0.1(jiti@2.6.1)(postcss@8.5.8)(tsx@4.21.0)(yaml@2.8.3): + dependencies: + lilconfig: 3.1.3 + optionalDependencies: + jiti: 2.6.1 + postcss: 8.5.8 + tsx: 4.21.0 + yaml: 2.8.3 postcss-nested@6.2.0(postcss@8.5.8): dependencies: @@ -14514,6 +16225,8 @@ snapshots: picocolors: 1.1.1 source-map-js: 1.2.1 + powershell-utils@0.1.0: {} + prebuild-install@7.1.3: dependencies: detect-libc: 2.1.2 @@ -14540,13 +16253,15 @@ snapshots: prismjs@1.30.0: {} + progress@2.0.3: {} + prop-types@15.8.1: dependencies: loose-envify: 1.4.0 object-assign: 4.1.1 react-is: 16.13.1 - propagate@2.0.1: {} + property-expr@2.0.6: {} property-information@5.6.0: dependencies: @@ -14554,6 +16269,8 @@ snapshots: property-information@7.1.0: {} + proxy-from-env@2.1.0: {} + pump@3.0.4: dependencies: end-of-stream: 1.4.5 @@ -14626,7 +16343,7 @@ snapshots: react: 19.2.4 react-dom: 19.2.4(react@19.2.4) - react-easy-crop@5.5.6(react-dom@19.2.4(react@19.2.4))(react@19.2.4): + react-easy-crop@5.5.7(react-dom@19.2.4(react@19.2.4))(react@19.2.4): dependencies: normalize-wheel: 1.0.1 react: 19.2.4 @@ -14644,11 +16361,11 @@ snapshots: react: 19.2.4 react-dom: 19.2.4(react@19.2.4) - react-i18next@16.5.8(i18next@25.8.18(typescript@5.9.3))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3): + react-i18next@16.6.6(i18next@25.10.10(typescript@5.9.3))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3): dependencies: '@babel/runtime': 7.29.2 html-parse-stringify: 3.0.1 - i18next: 25.8.18(typescript@5.9.3) + i18next: 25.10.10(typescript@5.9.3) react: 19.2.4 use-sync-external-store: 1.6.0(react@19.2.4) optionalDependencies: @@ -14713,11 +16430,6 @@ snapshots: webpack: 5.105.4(esbuild@0.27.2)(uglify-js@3.19.3) webpack-sources: 3.3.4 - react-slider@2.0.6(react@19.2.4): - dependencies: - prop-types: 15.8.1 - react: 19.2.4 - react-sortablejs@6.1.4(@types/sortablejs@1.15.9)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sortablejs@1.15.7): dependencies: '@types/sortablejs': 1.15.9 @@ -14781,6 +16493,20 @@ snapshots: dependencies: pify: 2.3.0 + read-package-up@12.0.0: + dependencies: + find-up-simple: 1.0.1 + read-pkg: 10.1.0 + type-fest: 5.5.0 + + read-pkg@10.1.0: + dependencies: + '@types/normalize-package-data': 2.4.4 + normalize-package-data: 8.0.0 + parse-json: 8.3.0 + type-fest: 5.5.0 + unicorn-magic: 0.4.0 + readable-stream@3.6.2: dependencies: inherits: 2.0.4 @@ -14790,7 +16516,7 @@ snapshots: readdirp@3.6.0: dependencies: - picomatch: 2.3.1 + picomatch: 2.3.2 readdirp@4.1.2: {} @@ -14840,6 +16566,8 @@ snapshots: dependencies: '@eslint-community/regexpp': 4.12.2 + reflect-metadata@0.2.2: {} + refractor@3.6.0: dependencies: hastscript: 6.0.0 @@ -14851,6 +16579,10 @@ snapshots: '@eslint-community/regexpp': 4.12.2 refa: 0.12.1 + regexp-match-indices@1.0.2: + dependencies: + regexp-tree: 0.1.27 + regexp-tree@0.1.27: {} regjsparser@0.13.0: @@ -14867,7 +16599,7 @@ snapshots: '@types/katex': 0.16.8 hast-util-from-html-isomorphic: 2.0.0 hast-util-to-text: 4.0.2 - katex: 0.16.38 + katex: 0.16.44 unist-util-visit-parents: 6.0.2 vfile: 6.0.3 @@ -14957,6 +16689,8 @@ snapshots: remend@1.3.0: {} + repeat-string@1.6.1: {} + require-from-string@2.0.2: {} reselect@5.1.1: {} @@ -14967,6 +16701,8 @@ snapshots: resolve-from@4.0.0: {} + resolve-from@5.0.0: {} + resolve-pkg-maps@1.0.0: {} resolve@1.22.11: @@ -14984,7 +16720,31 @@ snapshots: rfdc@1.4.1: {} - robust-predicates@3.0.2: {} + robust-predicates@3.0.3: {} + + rolldown@1.0.0-rc.12(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1): + dependencies: + '@oxc-project/types': 0.122.0 + '@rolldown/pluginutils': 1.0.0-rc.12 + optionalDependencies: + '@rolldown/binding-android-arm64': 1.0.0-rc.12 + '@rolldown/binding-darwin-arm64': 1.0.0-rc.12 + '@rolldown/binding-darwin-x64': 1.0.0-rc.12 + '@rolldown/binding-freebsd-x64': 1.0.0-rc.12 + '@rolldown/binding-linux-arm-gnueabihf': 1.0.0-rc.12 + '@rolldown/binding-linux-arm64-gnu': 1.0.0-rc.12 + '@rolldown/binding-linux-arm64-musl': 1.0.0-rc.12 + '@rolldown/binding-linux-ppc64-gnu': 1.0.0-rc.12 + '@rolldown/binding-linux-s390x-gnu': 1.0.0-rc.12 + '@rolldown/binding-linux-x64-gnu': 1.0.0-rc.12 + '@rolldown/binding-linux-x64-musl': 1.0.0-rc.12 + '@rolldown/binding-openharmony-arm64': 1.0.0-rc.12 + '@rolldown/binding-wasm32-wasi': 1.0.0-rc.12(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1) + '@rolldown/binding-win32-arm64-msvc': 1.0.0-rc.12 + '@rolldown/binding-win32-x64-msvc': 1.0.0-rc.12 + transitivePeerDependencies: + - '@emnapi/core' + - '@emnapi/runtime' rollup@4.59.0: dependencies: @@ -15063,10 +16823,6 @@ snapshots: sax@1.6.0: {} - saxes@6.0.0: - dependencies: - xmlchars: 2.2.0 - scheduler@0.27.0: {} schema-utils@4.3.3: @@ -15084,6 +16840,8 @@ snapshots: refa: 0.12.1 regexp-ast-analysis: 0.7.1 + seed-random@2.2.0: {} + semver@6.3.1: {} semver@7.7.4: {} @@ -15165,7 +16923,7 @@ snapshots: ansi-styles: 6.2.3 is-fullwidth-code-point: 5.1.0 - smol-toml@1.6.0: {} + smol-toml@1.6.1: {} solid-js@1.9.11: dependencies: @@ -15190,8 +16948,18 @@ snapshots: space-separated-tokens@2.0.2: {} + spdx-correct@3.2.0: + dependencies: + spdx-expression-parse: 3.0.1 + spdx-license-ids: 3.0.23 + spdx-exceptions@2.5.0: {} + spdx-expression-parse@3.0.1: + dependencies: + spdx-exceptions: 2.5.0 + spdx-license-ids: 3.0.23 + spdx-expression-parse@4.0.0: dependencies: spdx-exceptions: 2.5.0 @@ -15199,7 +16967,9 @@ snapshots: spdx-license-ids@3.0.23: {} - srvx@0.11.12: {} + srvx@0.11.13: {} + + stackframe@1.3.4: {} state-local@1.0.7: {} @@ -15207,7 +16977,7 @@ snapshots: std-semver@1.0.8: {} - storybook@10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4): + storybook@10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4): dependencies: '@storybook/global': 5.0.0 '@storybook/icons': 2.0.1(react-dom@19.2.4(react@19.2.4))(react@19.2.4) @@ -15220,7 +16990,7 @@ snapshots: recast: 0.23.11 semver: 7.7.4 use-sync-external-store: 1.6.0(react@19.2.4) - ws: 8.19.0 + ws: 8.20.0 transitivePeerDependencies: - '@testing-library/dom' - bufferutil @@ -15233,7 +17003,7 @@ snapshots: clsx: 2.1.1 hast-util-to-jsx-runtime: 2.3.6 html-url-attributes: 3.0.1 - marked: 17.0.4 + marked: 17.0.5 mermaid: 11.13.0 react: 19.2.4 react-dom: 19.2.4(react@19.2.4) @@ -15251,7 +17021,7 @@ snapshots: transitivePeerDependencies: - supports-color - strict-event-emitter@0.5.1: {} + string-argv@0.3.1: {} string-argv@0.3.2: {} @@ -15297,6 +17067,8 @@ snapshots: dependencies: js-tokens: 9.0.1 + structured-clone-es@2.0.0: {} + style-to-js@1.1.21: dependencies: style-to-object: 1.0.14 @@ -15344,8 +17116,6 @@ snapshots: picocolors: 1.1.1 sax: 1.6.0 - symbol-tree@3.2.4: {} - synckit@0.11.12: dependencies: '@pkgr/core': 0.2.9 @@ -15360,7 +17130,7 @@ snapshots: tailwind-merge@3.5.0: {} - tailwindcss@3.4.19(tsx@4.21.0)(yaml@2.8.2): + tailwindcss@3.4.19(tsx@4.21.0)(yaml@2.8.3): dependencies: '@alloc/quick-lru': 5.2.0 arg: 5.0.2 @@ -15379,7 +17149,7 @@ snapshots: postcss: 8.5.8 postcss-import: 15.1.0(postcss@8.5.8) postcss-js: 4.1.0(postcss@8.5.8) - postcss-load-config: 6.0.1(jiti@1.21.7)(postcss@8.5.8)(tsx@4.21.0)(yaml@2.8.2) + postcss-load-config: 6.0.1(jiti@1.21.7)(postcss@8.5.8)(tsx@4.21.0)(yaml@2.8.3) postcss-nested: 6.2.0(postcss@8.5.8) postcss-selector-parser: 6.1.2 resolve: 1.22.11 @@ -15388,7 +17158,7 @@ snapshots: - tsx - yaml - tapable@2.3.0: {} + tapable@2.3.2: {} tar-fs@2.1.4: dependencies: @@ -15429,7 +17199,7 @@ snapshots: tinyexec: 1.0.4 tinyglobby: 0.2.15 unconfig: 7.5.0 - yaml: 2.8.2 + yaml: 2.8.3 terser-webpack-plugin@5.4.0(esbuild@0.27.2)(uglify-js@3.19.3)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)): dependencies: @@ -15457,6 +17227,8 @@ snapshots: dependencies: any-promise: 1.3.0 + tiny-case@1.0.3: {} + tiny-inflate@1.0.3: {} tiny-invariant@1.2.0: {} @@ -15465,12 +17237,14 @@ snapshots: tinybench@2.9.0: {} + tinyexec@0.3.2: {} + tinyexec@1.0.4: {} tinyglobby@0.2.15: dependencies: - fdir: 6.5.0(picomatch@4.0.3) - picomatch: 4.0.3 + fdir: 6.5.0(picomatch@4.0.4) + picomatch: 4.0.4 tinypool@2.1.0: {} @@ -15480,11 +17254,11 @@ snapshots: tinyspy@4.0.4: {} - tldts-core@7.0.26: {} + tldts-core@7.0.27: {} - tldts@7.0.26: + tldts@7.0.27: dependencies: - tldts-core: 7.0.26 + tldts-core: 7.0.27 to-regex-range@5.0.1: dependencies: @@ -15501,21 +17275,17 @@ snapshots: dependencies: eslint-visitor-keys: 5.0.1 + toposort@2.0.2: {} + totalist@3.0.1: {} - tough-cookie@6.0.1: - dependencies: - tldts: 7.0.26 - - tr46@6.0.0: - dependencies: - punycode: 2.3.1 + tree-kill@1.2.2: {} trim-lines@3.0.1: {} trough@2.2.0: {} - ts-api-utils@2.4.0(typescript@5.9.3): + ts-api-utils@2.5.0(typescript@5.9.3): dependencies: typescript: 5.9.3 @@ -15523,7 +17293,7 @@ snapshots: ts-declaration-location@1.0.7(typescript@5.9.3): dependencies: - picomatch: 4.0.3 + picomatch: 4.0.4 typescript: 5.9.3 ts-dedent@2.2.0: {} @@ -15540,7 +17310,7 @@ snapshots: dependencies: chalk: 4.1.2 enhanced-resolve: 5.20.1 - tapable: 2.3.0 + tapable: 2.3.2 tsconfig-paths: 4.2.0 tsconfig-paths@4.2.0: @@ -15555,10 +17325,38 @@ snapshots: tslib@2.8.1: {} + tsup@8.5.1(jiti@2.6.1)(postcss@8.5.8)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3): + dependencies: + bundle-require: 5.1.0(esbuild@0.27.2) + cac: 6.7.14 + chokidar: 4.0.3 + consola: 3.4.2 + debug: 4.4.3(supports-color@8.1.1) + esbuild: 0.27.2 + fix-dts-default-cjs-exports: 1.0.1 + joycon: 3.1.1 + picocolors: 1.1.1 + postcss-load-config: 6.0.1(jiti@2.6.1)(postcss@8.5.8)(tsx@4.21.0)(yaml@2.8.3) + resolve-from: 5.0.0 + rollup: 4.59.0 + source-map: 0.7.6 + sucrase: 3.35.1 + tinyexec: 0.3.2 + tinyglobby: 0.2.15 + tree-kill: 1.2.2 + optionalDependencies: + postcss: 8.5.8 + typescript: 5.9.3 + transitivePeerDependencies: + - jiti + - supports-color + - tsx + - yaml + tsx@4.21.0: dependencies: esbuild: 0.27.2 - get-tsconfig: 4.13.6 + get-tsconfig: 4.13.7 optionalDependencies: fsevents: 2.3.3 @@ -15573,7 +17371,11 @@ snapshots: dependencies: prelude-ls: 1.2.1 - type-fest@5.4.4: + type-fest@2.19.0: {} + + type-fest@4.41.0: {} + + type-fest@5.5.0: dependencies: tagged-tag: 1.0.0 @@ -15602,13 +17404,13 @@ snapshots: undici@7.24.0: {} - undici@7.24.4: {} - unicode-trie@2.0.0: dependencies: pako: 0.2.9 tiny-inflate: 1.0.3 + unicorn-magic@0.4.0: {} + unified@11.0.5: dependencies: '@types/unist': 3.0.3 @@ -15656,8 +17458,6 @@ snapshots: unist-util-is: 6.0.1 unist-util-visit-parents: 6.0.2 - universal-user-agent@7.0.3: {} - universalify@2.0.1: {} unpic@4.2.2: {} @@ -15665,13 +17465,13 @@ snapshots: unplugin-utils@0.3.1: dependencies: pathe: 2.0.3 - picomatch: 4.0.3 + picomatch: 4.0.4 unplugin@2.3.11: dependencies: '@jridgewell/remapping': 2.3.5 acorn: 8.16.0 - picomatch: 4.0.3 + picomatch: 4.0.4 webpack-virtual-modules: 0.6.2 update-browserslist-db@1.2.3(browserslist@4.28.1): @@ -15680,6 +17480,10 @@ snapshots: escalade: 3.2.0 picocolors: 1.1.1 + upper-case-first@2.0.2: + dependencies: + tslib: 2.8.1 + uri-js@4.4.1: dependencies: punycode: 2.3.1 @@ -15729,16 +17533,23 @@ snapshots: dependencies: react: 19.2.4 + util-arity@1.1.0: {} + util-deprecate@1.0.2: {} uuid@11.1.0: {} uuid@13.0.0: {} - valibot@1.3.0(typescript@5.9.3): + valibot@1.3.1(typescript@5.9.3): optionalDependencies: typescript: 5.9.3 + validate-npm-package-license@3.0.4: + dependencies: + spdx-correct: 3.2.0 + spdx-expression-parse: 3.0.1 + vfile-location@5.0.3: dependencies: '@types/unist': 3.0.3 @@ -15754,37 +17565,27 @@ snapshots: '@types/unist': 3.0.3 vfile-message: 4.0.3 - vinext@0.0.31(d43efe4756ad5ea698dcdb002ea787ea): + vinext@0.0.38(f5786d681f520e26604259e094ebaa46): dependencies: - '@unpic/react': 1.0.2(next@16.2.0(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0))(react-dom@19.2.4(react@19.2.4))(react@19.2.4) + '@unpic/react': 1.0.2(next@16.2.1(@babel/core@7.29.0)(@playwright/test@1.58.2)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0))(react-dom@19.2.4(react@19.2.4))(react@19.2.4) '@vercel/og': 0.8.6 - '@vitejs/plugin-react': 6.0.1(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)) + '@vitejs/plugin-react': 6.0.1(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)) magic-string: 0.30.21 react: 19.2.4 react-dom: 19.2.4(react@19.2.4) rsc-html-stream: 0.0.7 - vite: '@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' + vite: '@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' vite-plugin-commonjs: 0.10.4 - vite-tsconfig-paths: 6.1.1(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(typescript@5.9.3) + vite-tsconfig-paths: 6.1.1(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(typescript@5.9.3) optionalDependencies: '@mdx-js/rollup': 3.1.1(rollup@4.59.0) - '@vitejs/plugin-rsc': 0.5.21(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(react-dom@19.2.4(react@19.2.4))(react-server-dom-webpack@19.2.4(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)))(react@19.2.4) + '@vitejs/plugin-rsc': 0.5.21(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(react-dom@19.2.4(react@19.2.4))(react-server-dom-webpack@19.2.4(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)))(react@19.2.4) react-server-dom-webpack: 19.2.4(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) transitivePeerDependencies: - next - supports-color - typescript - vite-dev-rpc@1.1.0(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)): - dependencies: - birpc: 2.9.0 - vite: '@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' - vite-hot-client: 2.1.0(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)) - - vite-hot-client@2.1.0(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)): - dependencies: - vite: '@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' - vite-plugin-commonjs@0.10.4: dependencies: acorn: 8.16.0 @@ -15798,54 +17599,57 @@ snapshots: fast-glob: 3.3.3 magic-string: 0.30.21 - vite-plugin-inspect@11.3.3(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)): + vite-plugin-inspect@12.0.0-beta.1(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(typescript@5.9.3)(ws@8.20.0): dependencies: + '@vitejs/devtools-kit': 0.1.11(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(typescript@5.9.3)(ws@8.20.0) ansis: 4.2.0 - debug: 4.4.3 error-stack-parser-es: 1.0.5 + obug: 2.1.1 ohash: 2.0.11 - open: 10.2.0 + open: 11.0.0 perfect-debounce: 2.1.0 sirv: 3.0.2 unplugin-utils: 0.3.1 - vite: '@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' - vite-dev-rpc: 1.1.0(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)) + vite: '@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' transitivePeerDependencies: - - supports-color + - typescript + - ws - vite-plugin-storybook-nextjs@3.2.3(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(next@16.2.0(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0))(storybook@10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3): + vite-plugin-storybook-nextjs@3.2.4(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(next@16.2.1(@babel/core@7.29.0)(@playwright/test@1.58.2)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0))(storybook@10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3): dependencies: '@next/env': 16.0.0 image-size: 2.0.2 magic-string: 0.30.21 module-alias: 2.3.4 - next: 16.2.0(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0) - storybook: 10.3.0(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) + next: 16.2.1(@babel/core@7.29.0)(@playwright/test@1.58.2)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0) + storybook: 10.3.3(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) ts-dedent: 2.2.0 - vite: '@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' - vite-tsconfig-paths: 5.1.4(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(typescript@5.9.3) + vite: '@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' + vite-tsconfig-paths: 5.1.4(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(typescript@5.9.3) transitivePeerDependencies: - supports-color - typescript - vite-plus@0.1.12(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.0(canvas@3.2.1))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2): + vite-plus@0.1.14(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3): dependencies: - '@oxc-project/types': 0.115.0 - '@voidzero-dev/vite-plus-core': 0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2) - '@voidzero-dev/vite-plus-test': 0.1.12(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.0(canvas@3.2.1))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2) - cac: 6.7.14 + '@oxc-project/types': 0.122.0 + '@voidzero-dev/vite-plus-core': 0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3) + '@voidzero-dev/vite-plus-test': 0.1.14(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3) + cac: 7.0.0 cross-spawn: 7.0.6 - oxfmt: 0.40.0 - oxlint: 1.55.0(oxlint-tsgolint@0.17.0) - oxlint-tsgolint: 0.17.0 + oxfmt: 0.42.0 + oxlint: 1.57.0(oxlint-tsgolint@0.17.3) + oxlint-tsgolint: 0.17.3 picocolors: 1.1.1 optionalDependencies: - '@voidzero-dev/vite-plus-darwin-arm64': 0.1.12 - '@voidzero-dev/vite-plus-darwin-x64': 0.1.12 - '@voidzero-dev/vite-plus-linux-arm64-gnu': 0.1.12 - '@voidzero-dev/vite-plus-linux-x64-gnu': 0.1.12 - '@voidzero-dev/vite-plus-win32-arm64-msvc': 0.1.12 - '@voidzero-dev/vite-plus-win32-x64-msvc': 0.1.12 + '@voidzero-dev/vite-plus-darwin-arm64': 0.1.14 + '@voidzero-dev/vite-plus-darwin-x64': 0.1.14 + '@voidzero-dev/vite-plus-linux-arm64-gnu': 0.1.14 + '@voidzero-dev/vite-plus-linux-arm64-musl': 0.1.14 + '@voidzero-dev/vite-plus-linux-x64-gnu': 0.1.14 + '@voidzero-dev/vite-plus-linux-x64-musl': 0.1.14 + '@voidzero-dev/vite-plus-win32-arm64-msvc': 0.1.14 + '@voidzero-dev/vite-plus-win32-x64-msvc': 0.1.14 transitivePeerDependencies: - '@arethetypeswrong/core' - '@edge-runtime/vm' @@ -15874,36 +17678,104 @@ snapshots: - vite - yaml - vite-tsconfig-paths@5.1.4(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(typescript@5.9.3): + vite-plus@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(vite@8.0.3(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1)(@types/node@25.5.0)(esbuild@0.27.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(yaml@2.8.3))(yaml@2.8.3): dependencies: - debug: 4.4.3 + '@oxc-project/types': 0.122.0 + '@voidzero-dev/vite-plus-core': 0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3) + '@voidzero-dev/vite-plus-test': 0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(vite@8.0.3(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1)(@types/node@25.5.0)(esbuild@0.27.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(yaml@2.8.3))(yaml@2.8.3) + cac: 7.0.0 + cross-spawn: 7.0.6 + oxfmt: 0.42.0 + oxlint: 1.57.0(oxlint-tsgolint@0.17.3) + oxlint-tsgolint: 0.17.3 + picocolors: 1.1.1 + optionalDependencies: + '@voidzero-dev/vite-plus-darwin-arm64': 0.1.14 + '@voidzero-dev/vite-plus-darwin-x64': 0.1.14 + '@voidzero-dev/vite-plus-linux-arm64-gnu': 0.1.14 + '@voidzero-dev/vite-plus-linux-arm64-musl': 0.1.14 + '@voidzero-dev/vite-plus-linux-x64-gnu': 0.1.14 + '@voidzero-dev/vite-plus-linux-x64-musl': 0.1.14 + '@voidzero-dev/vite-plus-win32-arm64-msvc': 0.1.14 + '@voidzero-dev/vite-plus-win32-x64-msvc': 0.1.14 + transitivePeerDependencies: + - '@arethetypeswrong/core' + - '@edge-runtime/vm' + - '@opentelemetry/api' + - '@tsdown/css' + - '@tsdown/exe' + - '@types/node' + - '@vitejs/devtools' + - '@vitest/ui' + - bufferutil + - esbuild + - happy-dom + - jiti + - jsdom + - less + - publint + - sass + - sass-embedded + - stylus + - sugarss + - terser + - tsx + - typescript + - unplugin-unused + - utf-8-validate + - vite + - yaml + + vite-tsconfig-paths@5.1.4(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(typescript@5.9.3): + dependencies: + debug: 4.4.3(supports-color@8.1.1) globrex: 0.1.2 tsconfck: 3.1.6(typescript@5.9.3) optionalDependencies: - vite: '@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' + vite: '@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' transitivePeerDependencies: - supports-color - typescript - vite-tsconfig-paths@6.1.1(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(typescript@5.9.3): + vite-tsconfig-paths@6.1.1(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(typescript@5.9.3): dependencies: - debug: 4.4.3 + debug: 4.4.3(supports-color@8.1.1) globrex: 0.1.2 tsconfck: 3.1.6(typescript@5.9.3) - vite: '@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' + vite: '@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' transitivePeerDependencies: - supports-color - typescript - vitefu@1.1.2(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)): + vite@8.0.3(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1)(@types/node@25.5.0)(esbuild@0.27.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(yaml@2.8.3): + dependencies: + lightningcss: 1.32.0 + picomatch: 4.0.4 + postcss: 8.5.8 + rolldown: 1.0.0-rc.12(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1) + tinyglobby: 0.2.15 optionalDependencies: - vite: '@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' + '@types/node': 25.5.0 + esbuild: 0.27.2 + fsevents: 2.3.3 + jiti: 2.6.1 + sass: 1.98.0 + terser: 5.46.1 + tsx: 4.21.0 + yaml: 2.8.3 + transitivePeerDependencies: + - '@emnapi/core' + - '@emnapi/runtime' - vitest-canvas-mock@1.1.3(@voidzero-dev/vite-plus-test@0.1.12(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.0(canvas@3.2.1))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)): + vitefu@1.1.2(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)): + optionalDependencies: + vite: '@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' + + vitest-canvas-mock@1.1.4(@voidzero-dev/vite-plus-test@0.1.14(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)): dependencies: cssfontparser: 1.2.1 moo-color: 1.0.3 - vitest: '@voidzero-dev/vite-plus-test@0.1.12(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.12(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.0(canvas@3.2.1))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' + vitest: '@voidzero-dev/vite-plus-test@0.1.14(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' void-elements@3.1.0: {} @@ -15924,10 +17796,10 @@ snapshots: vscode-uri@3.1.0: {} - vue-eslint-parser@10.4.0(eslint@10.0.3(jiti@1.21.7)): + vue-eslint-parser@10.4.0(eslint@10.1.0(jiti@1.21.7)): dependencies: - debug: 4.4.3 - eslint: 10.0.3(jiti@1.21.7) + debug: 4.4.3(supports-color@8.1.1) + eslint: 10.1.0(jiti@1.21.7) eslint-scope: 9.1.2 eslint-visitor-keys: 5.0.1 espree: 11.2.0 @@ -15936,10 +17808,6 @@ snapshots: transitivePeerDependencies: - supports-color - w3c-xmlserializer@5.0.0: - dependencies: - xml-name-validator: 5.0.0 - walk-up-path@4.0.0: {} watchpack@2.5.1: @@ -15951,8 +17819,6 @@ snapshots: web-vitals@5.1.0: {} - webidl-conversions@8.0.1: {} - webpack-sources@3.3.4: {} webpack-virtual-modules@0.6.2: {} @@ -15980,7 +17846,7 @@ snapshots: mime-types: 2.1.35 neo-async: 2.6.2 schema-utils: 4.3.3 - tapable: 2.3.0 + tapable: 2.3.2 terser-webpack-plugin: 5.4.0(esbuild@0.27.2)(uglify-js@3.19.3)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) watchpack: 2.5.1 webpack-sources: 3.3.4 @@ -15993,18 +17859,10 @@ snapshots: dependencies: iconv-lite: 0.6.3 + whatwg-mimetype@3.0.0: {} + whatwg-mimetype@4.0.0: {} - whatwg-mimetype@5.0.0: {} - - whatwg-url@16.0.1: - dependencies: - '@exodus/bytes': 1.15.0 - tr46: 6.0.0 - webidl-conversions: 8.0.1 - transitivePeerDependencies: - - '@noble/hashes' - which@2.0.2: dependencies: isexe: 2.0.0 @@ -16019,17 +17877,20 @@ snapshots: wrappy@1.0.2: {} - ws@8.19.0: {} + ws@8.20.0: {} wsl-utils@0.1.0: dependencies: is-wsl: 3.1.1 + wsl-utils@0.3.1: + dependencies: + is-wsl: 3.1.1 + powershell-utils: 0.1.0 + xml-name-validator@4.0.0: {} - xml-name-validator@5.0.0: {} - - xmlchars@2.2.0: {} + xmlbuilder@15.1.1: {} xtend@4.0.2: {} @@ -16040,9 +17901,9 @@ snapshots: yaml-eslint-parser@2.0.0: dependencies: eslint-visitor-keys: 5.0.1 - yaml: 2.8.2 + yaml: 2.8.3 - yaml@2.8.2: {} + yaml@2.8.3: {} yauzl@3.2.1: dependencies: @@ -16055,8 +17916,17 @@ snapshots: yocto-queue@0.1.0: {} + yocto-queue@1.2.2: {} + yoga-layout@3.2.1: {} + yup@1.7.1: + dependencies: + property-expr: 2.0.6 + tiny-case: 1.0.3 + toposort: 2.0.2 + type-fest: 2.19.0 + zen-observable@0.10.0: {} zimmerframe@1.1.4: {} diff --git a/pnpm-workspace.yaml b/pnpm-workspace.yaml new file mode 100644 index 00000000000..dece6f3f4fe --- /dev/null +++ b/pnpm-workspace.yaml @@ -0,0 +1,257 @@ +packages: + - web + - e2e + - sdks/nodejs-client +overrides: + "@lexical/code": npm:lexical-code-no-prism@0.41.0 + "@monaco-editor/loader": 1.7.0 + "@nolyfill/safe-buffer": npm:safe-buffer@^5.2.1 + array-includes: npm:@nolyfill/array-includes@^1.0.44 + array.prototype.findlast: npm:@nolyfill/array.prototype.findlast@^1.0.44 + array.prototype.findlastindex: npm:@nolyfill/array.prototype.findlastindex@^1.0.44 + array.prototype.flat: npm:@nolyfill/array.prototype.flat@^1.0.44 + array.prototype.flatmap: npm:@nolyfill/array.prototype.flatmap@^1.0.44 + array.prototype.tosorted: npm:@nolyfill/array.prototype.tosorted@^1.0.44 + assert: npm:@nolyfill/assert@^1.0.26 + brace-expansion@<2.0.2: 2.0.2 + canvas: ^3.2.2 + devalue@<5.3.2: 5.3.2 + dompurify@>=3.1.3 <=3.3.1: 3.3.2 + es-iterator-helpers: npm:@nolyfill/es-iterator-helpers@^1.0.21 + esbuild@<0.27.2: 0.27.2 + flatted@<=3.4.1: 3.4.2 + glob@>=10.2.0 <10.5.0: 11.1.0 + hasown: npm:@nolyfill/hasown@^1.0.44 + is-arguments: npm:@nolyfill/is-arguments@^1.0.44 + is-core-module: npm:@nolyfill/is-core-module@^1.0.39 + is-generator-function: npm:@nolyfill/is-generator-function@^1.0.44 + is-typed-array: npm:@nolyfill/is-typed-array@^1.0.44 + isarray: npm:@nolyfill/isarray@^1.0.44 + object.assign: npm:@nolyfill/object.assign@^1.0.44 + object.entries: npm:@nolyfill/object.entries@^1.0.44 + object.fromentries: npm:@nolyfill/object.fromentries@^1.0.44 + object.groupby: npm:@nolyfill/object.groupby@^1.0.44 + object.values: npm:@nolyfill/object.values@^1.0.44 + pbkdf2: ~3.1.5 + pbkdf2@<3.1.3: 3.1.3 + picomatch@<2.3.2: 2.3.2 + picomatch@>=4.0.0 <4.0.4: 4.0.4 + prismjs: ~1.30 + prismjs@<1.30.0: 1.30.0 + rollup@>=4.0.0 <4.59.0: 4.59.0 + safe-buffer: ^5.2.1 + safe-regex-test: npm:@nolyfill/safe-regex-test@^1.0.44 + safer-buffer: npm:@nolyfill/safer-buffer@^1.0.44 + side-channel: npm:@nolyfill/side-channel@^1.0.44 + smol-toml@<1.6.1: 1.6.1 + solid-js: 1.9.11 + string-width: ~8.2.0 + string.prototype.includes: npm:@nolyfill/string.prototype.includes@^1.0.44 + string.prototype.matchall: npm:@nolyfill/string.prototype.matchall@^1.0.44 + string.prototype.repeat: npm:@nolyfill/string.prototype.repeat@^1.0.44 + string.prototype.trimend: npm:@nolyfill/string.prototype.trimend@^1.0.44 + svgo@>=3.0.0 <3.3.3: 3.3.3 + tar@<=7.5.10: 7.5.11 + typed-array-buffer: npm:@nolyfill/typed-array-buffer@^1.0.44 + undici@>=7.0.0 <7.24.0: 7.24.0 + vite: npm:@voidzero-dev/vite-plus-core@0.1.14 + vitest: npm:@voidzero-dev/vite-plus-test@0.1.14 + which-typed-array: npm:@nolyfill/which-typed-array@^1.0.44 + yaml@>=2.0.0 <2.8.3: 2.8.3 + yauzl@<3.2.1: 3.2.1 +ignoredBuiltDependencies: + - canvas + - core-js-pure +onlyBuiltDependencies: + - "@parcel/watcher" + - esbuild + - sharp +catalog: + "@amplitude/analytics-browser": 2.38.0 + "@amplitude/plugin-session-replay-browser": 1.27.5 + "@antfu/eslint-config": 7.7.3 + "@base-ui/react": 1.3.0 + "@chromatic-com/storybook": 5.1.1 + "@cucumber/cucumber": 12.7.0 + "@egoist/tailwindcss-icons": 1.9.2 + "@emoji-mart/data": 1.2.1 + "@eslint-react/eslint-plugin": 3.0.0 + "@eslint/js": ^10.0.1 + "@floating-ui/react": 0.27.19 + "@formatjs/intl-localematcher": 0.8.2 + "@headlessui/react": 2.2.9 + "@heroicons/react": 2.2.0 + "@hono/node-server": 1.19.11 + "@iconify-json/heroicons": 1.2.3 + "@iconify-json/ri": 1.2.10 + "@lexical/code": 0.42.0 + "@lexical/link": 0.42.0 + "@lexical/list": 0.42.0 + "@lexical/react": 0.42.0 + "@lexical/selection": 0.42.0 + "@lexical/text": 0.42.0 + "@lexical/utils": 0.42.0 + "@mdx-js/loader": 3.1.1 + "@mdx-js/react": 3.1.1 + "@mdx-js/rollup": 3.1.1 + "@monaco-editor/react": 4.7.0 + "@next/eslint-plugin-next": 16.2.1 + "@next/mdx": 16.2.1 + "@orpc/client": 1.13.13 + "@orpc/contract": 1.13.13 + "@orpc/openapi-client": 1.13.13 + "@orpc/tanstack-query": 1.13.13 + "@playwright/test": 1.58.2 + "@remixicon/react": 4.9.0 + "@rgrove/parse-xml": 4.2.0 + "@sentry/react": 10.46.0 + "@storybook/addon-docs": 10.3.3 + "@storybook/addon-links": 10.3.3 + "@storybook/addon-onboarding": 10.3.3 + "@storybook/addon-themes": 10.3.3 + "@storybook/nextjs-vite": 10.3.3 + "@storybook/react": 10.3.3 + "@streamdown/math": 1.0.2 + "@svgdotjs/svg.js": 3.2.5 + "@t3-oss/env-nextjs": 0.13.11 + "@tailwindcss/typography": 0.5.19 + "@tanstack/eslint-plugin-query": 5.95.2 + "@tanstack/react-devtools": 0.10.0 + "@tanstack/react-form": 1.28.5 + "@tanstack/react-form-devtools": 0.2.19 + "@tanstack/react-query": 5.95.2 + "@tanstack/react-query-devtools": 5.95.2 + "@testing-library/dom": 10.4.1 + "@testing-library/jest-dom": 6.9.1 + "@testing-library/react": 16.3.2 + "@testing-library/user-event": 14.6.1 + "@tsslint/cli": 3.0.2 + "@tsslint/compat-eslint": 3.0.2 + "@tsslint/config": 3.0.2 + "@types/js-cookie": 3.0.6 + "@types/js-yaml": 4.0.9 + "@types/negotiator": 0.6.4 + "@types/node": 25.5.0 + "@types/postcss-js": 4.1.0 + "@types/qs": 6.15.0 + "@types/react": 19.2.14 + "@types/react-dom": 19.2.3 + "@types/react-syntax-highlighter": 15.5.13 + "@types/react-window": 1.8.8 + "@types/sortablejs": 1.15.9 + "@typescript-eslint/eslint-plugin": ^8.57.2 + "@typescript-eslint/parser": 8.57.2 + "@typescript/native-preview": 7.0.0-dev.20260329.1 + "@vitejs/plugin-react": 6.0.1 + "@vitejs/plugin-rsc": 0.5.21 + "@vitest/coverage-v8": 4.1.2 + abcjs: 6.6.2 + agentation: 3.0.2 + ahooks: 3.9.7 + autoprefixer: 10.4.27 + axios: ^1.14.0 + class-variance-authority: 0.7.1 + clsx: 2.1.1 + cmdk: 1.1.1 + code-inspector-plugin: 1.4.5 + copy-to-clipboard: 3.3.3 + cron-parser: 5.5.0 + dayjs: 1.11.20 + decimal.js: 10.6.0 + dompurify: 3.3.3 + echarts: 6.0.0 + echarts-for-react: 3.0.6 + elkjs: 0.11.1 + embla-carousel-autoplay: 8.6.0 + embla-carousel-react: 8.6.0 + emoji-mart: 5.6.0 + es-toolkit: 1.45.1 + eslint: 10.1.0 + eslint-markdown: 0.6.0 + eslint-plugin-better-tailwindcss: 4.3.2 + eslint-plugin-hyoban: 0.14.1 + eslint-plugin-markdown-preferences: 0.40.3 + eslint-plugin-no-barrel-files: 1.2.2 + eslint-plugin-react-hooks: 7.0.1 + eslint-plugin-react-refresh: 0.5.2 + eslint-plugin-sonarjs: 4.0.2 + eslint-plugin-storybook: 10.3.3 + fast-deep-equal: 3.1.3 + foxact: 0.3.0 + happy-dom: 20.8.9 + hono: 4.12.9 + html-entities: 2.6.0 + html-to-image: 1.11.13 + husky: 9.1.7 + i18next: 25.10.10 + i18next-resources-to-backend: 1.2.1 + iconify-import-svg: 0.1.2 + immer: 11.1.4 + jotai: 2.19.0 + js-audio-recorder: 1.0.7 + js-cookie: 3.0.5 + js-yaml: 4.1.1 + jsonschema: 1.5.0 + katex: 0.16.44 + knip: 6.1.0 + ky: 1.14.3 + lamejs: 1.2.1 + lexical: 0.42.0 + lint-staged: 16.4.0 + mermaid: 11.13.0 + mime: 4.1.0 + mitt: 3.0.1 + negotiator: 1.0.0 + next: 16.2.1 + next-themes: 0.4.6 + nuqs: 2.8.9 + pinyin-pro: 3.28.0 + postcss: 8.5.8 + postcss-js: 5.1.0 + qrcode.react: 4.2.0 + qs: 6.15.0 + react: 19.2.4 + react-18-input-autosize: 3.0.0 + react-dom: 19.2.4 + react-easy-crop: 5.5.7 + react-hotkeys-hook: 5.2.4 + react-i18next: 16.6.6 + react-multi-email: 1.0.25 + react-papaparse: 4.4.0 + react-pdf-highlighter: 8.0.0-rc.0 + react-server-dom-webpack: 19.2.4 + react-sortablejs: 6.1.4 + react-syntax-highlighter: 15.6.6 + react-textarea-autosize: 8.5.9 + react-window: 1.8.11 + reactflow: 11.11.4 + remark-breaks: 4.0.0 + remark-directive: 4.0.0 + sass: 1.98.0 + scheduler: 0.27.0 + sharp: 0.34.5 + sortablejs: 1.15.7 + std-semver: 1.0.8 + storybook: 10.3.3 + streamdown: 2.5.0 + string-ts: 2.3.1 + tailwind-merge: 2.6.1 + tailwindcss: 3.4.19 + taze: 19.10.0 + tldts: 7.0.27 + tsup: ^8.5.1 + tsx: 4.21.0 + typescript: 5.9.3 + uglify-js: 3.19.3 + unist-util-visit: 5.1.0 + use-context-selector: 2.0.0 + uuid: 13.0.0 + vinext: 0.0.38 + vite: npm:@voidzero-dev/vite-plus-core@0.1.14 + vite-plugin-inspect: 12.0.0-beta.1 + vite-plus: 0.1.14 + vitest: npm:@voidzero-dev/vite-plus-test@0.1.14 + vitest-canvas-mock: 1.1.4 + zod: 4.3.6 + zundo: 2.3.0 + zustand: 5.0.12 diff --git a/sdks/nodejs-client/README.md b/sdks/nodejs-client/README.md index f8c2803c08b..7051bbc788a 100644 --- a/sdks/nodejs-client/README.md +++ b/sdks/nodejs-client/README.md @@ -100,6 +100,10 @@ Notes: - Chat/completion require a stable `user` identifier in the request payload. - For streaming responses, iterate the returned AsyncIterable. Use `stream.toText()` to collect text. +## Maintainers + +This package is published from the repository workspace. Install dependencies from the repository root with `pnpm install`, then use `./scripts/publish.sh` for dry runs and publishing so `catalog:` dependencies are resolved before release. + ## License This SDK is released under the MIT License. diff --git a/sdks/nodejs-client/package.json b/sdks/nodejs-client/package.json index 7c8a2934469..63fa6799b10 100644 --- a/sdks/nodejs-client/package.json +++ b/sdks/nodejs-client/package.json @@ -54,22 +54,17 @@ "publish:npm": "./scripts/publish.sh" }, "dependencies": { - "axios": "^1.13.6" + "axios": "catalog:" }, "devDependencies": { - "@eslint/js": "^10.0.1", - "@types/node": "^25.4.0", - "@typescript-eslint/eslint-plugin": "^8.57.0", - "@typescript-eslint/parser": "^8.57.0", - "@vitest/coverage-v8": "4.0.18", - "eslint": "^10.0.3", - "tsup": "^8.5.1", - "typescript": "^5.9.3", - "vitest": "^4.0.18" - }, - "pnpm": { - "overrides": { - "rollup@>=4.0.0,<4.59.0": "4.59.0" - } + "@eslint/js": "catalog:", + "@types/node": "catalog:", + "@typescript-eslint/eslint-plugin": "catalog:", + "@typescript-eslint/parser": "catalog:", + "@vitest/coverage-v8": "catalog:", + "eslint": "catalog:", + "tsup": "catalog:", + "typescript": "catalog:", + "vitest": "catalog:" } } diff --git a/sdks/nodejs-client/pnpm-lock.yaml b/sdks/nodejs-client/pnpm-lock.yaml deleted file mode 100644 index b0aee38cdf6..00000000000 --- a/sdks/nodejs-client/pnpm-lock.yaml +++ /dev/null @@ -1,2266 +0,0 @@ -lockfileVersion: '9.0' - -settings: - autoInstallPeers: true - excludeLinksFromLockfile: false - -overrides: - rollup@>=4.0.0,<4.59.0: 4.59.0 - -importers: - - .: - dependencies: - axios: - specifier: ^1.13.6 - version: 1.13.6 - devDependencies: - '@eslint/js': - specifier: ^10.0.1 - version: 10.0.1(eslint@10.0.3) - '@types/node': - specifier: ^25.4.0 - version: 25.4.0 - '@typescript-eslint/eslint-plugin': - specifier: ^8.57.0 - version: 8.57.0(@typescript-eslint/parser@8.57.0(eslint@10.0.3)(typescript@5.9.3))(eslint@10.0.3)(typescript@5.9.3) - '@typescript-eslint/parser': - specifier: ^8.57.0 - version: 8.57.0(eslint@10.0.3)(typescript@5.9.3) - '@vitest/coverage-v8': - specifier: 4.0.18 - version: 4.0.18(vitest@4.0.18(@types/node@25.4.0)) - eslint: - specifier: ^10.0.3 - version: 10.0.3 - tsup: - specifier: ^8.5.1 - version: 8.5.1(postcss@8.5.8)(typescript@5.9.3) - typescript: - specifier: ^5.9.3 - version: 5.9.3 - vitest: - specifier: ^4.0.18 - version: 4.0.18(@types/node@25.4.0) - -packages: - - '@babel/helper-string-parser@7.27.1': - resolution: {integrity: sha512-qMlSxKbpRlAridDExk92nSobyDdpPijUq2DW6oDnUqd0iOGxmQjyqhMIihI9+zv4LPyZdRje2cavWPbCbWm3eA==} - engines: {node: '>=6.9.0'} - - '@babel/helper-validator-identifier@7.28.5': - resolution: {integrity: sha512-qSs4ifwzKJSV39ucNjsvc6WVHs6b7S03sOh2OcHF9UHfVPqWWALUsNUVzhSBiItjRZoLHx7nIarVjqKVusUZ1Q==} - engines: {node: '>=6.9.0'} - - '@babel/parser@7.29.0': - resolution: {integrity: sha512-IyDgFV5GeDUVX4YdF/3CPULtVGSXXMLh1xVIgdCgxApktqnQV0r7/8Nqthg+8YLGaAtdyIlo2qIdZrbCv4+7ww==} - engines: {node: '>=6.0.0'} - hasBin: true - - '@babel/types@7.29.0': - resolution: {integrity: sha512-LwdZHpScM4Qz8Xw2iKSzS+cfglZzJGvofQICy7W7v4caru4EaAmyUuO6BGrbyQ2mYV11W0U8j5mBhd14dd3B0A==} - engines: {node: '>=6.9.0'} - - '@bcoe/v8-coverage@1.0.2': - resolution: {integrity: sha512-6zABk/ECA/QYSCQ1NGiVwwbQerUCZ+TQbp64Q3AgmfNvurHH0j8TtXa1qbShXA6qqkpAj4V5W8pP6mLe1mcMqA==} - engines: {node: '>=18'} - - '@esbuild/aix-ppc64@0.27.3': - resolution: {integrity: sha512-9fJMTNFTWZMh5qwrBItuziu834eOCUcEqymSH7pY+zoMVEZg3gcPuBNxH1EvfVYe9h0x/Ptw8KBzv7qxb7l8dg==} - engines: {node: '>=18'} - cpu: [ppc64] - os: [aix] - - '@esbuild/android-arm64@0.27.3': - resolution: {integrity: sha512-YdghPYUmj/FX2SYKJ0OZxf+iaKgMsKHVPF1MAq/P8WirnSpCStzKJFjOjzsW0QQ7oIAiccHdcqjbHmJxRb/dmg==} - engines: {node: '>=18'} - cpu: [arm64] - os: [android] - - '@esbuild/android-arm@0.27.3': - resolution: {integrity: sha512-i5D1hPY7GIQmXlXhs2w8AWHhenb00+GxjxRncS2ZM7YNVGNfaMxgzSGuO8o8SJzRc/oZwU2bcScvVERk03QhzA==} - engines: {node: '>=18'} - cpu: [arm] - os: [android] - - '@esbuild/android-x64@0.27.3': - resolution: {integrity: sha512-IN/0BNTkHtk8lkOM8JWAYFg4ORxBkZQf9zXiEOfERX/CzxW3Vg1ewAhU7QSWQpVIzTW+b8Xy+lGzdYXV6UZObQ==} - engines: {node: '>=18'} - cpu: [x64] - os: [android] - - '@esbuild/darwin-arm64@0.27.3': - resolution: {integrity: sha512-Re491k7ByTVRy0t3EKWajdLIr0gz2kKKfzafkth4Q8A5n1xTHrkqZgLLjFEHVD+AXdUGgQMq+Godfq45mGpCKg==} - engines: {node: '>=18'} - cpu: [arm64] - os: [darwin] - - '@esbuild/darwin-x64@0.27.3': - resolution: {integrity: sha512-vHk/hA7/1AckjGzRqi6wbo+jaShzRowYip6rt6q7VYEDX4LEy1pZfDpdxCBnGtl+A5zq8iXDcyuxwtv3hNtHFg==} - engines: {node: '>=18'} - cpu: [x64] - os: [darwin] - - '@esbuild/freebsd-arm64@0.27.3': - resolution: {integrity: sha512-ipTYM2fjt3kQAYOvo6vcxJx3nBYAzPjgTCk7QEgZG8AUO3ydUhvelmhrbOheMnGOlaSFUoHXB6un+A7q4ygY9w==} - engines: {node: '>=18'} - cpu: [arm64] - os: [freebsd] - - '@esbuild/freebsd-x64@0.27.3': - resolution: {integrity: sha512-dDk0X87T7mI6U3K9VjWtHOXqwAMJBNN2r7bejDsc+j03SEjtD9HrOl8gVFByeM0aJksoUuUVU9TBaZa2rgj0oA==} - engines: {node: '>=18'} - cpu: [x64] - os: [freebsd] - - '@esbuild/linux-arm64@0.27.3': - resolution: {integrity: sha512-sZOuFz/xWnZ4KH3YfFrKCf1WyPZHakVzTiqji3WDc0BCl2kBwiJLCXpzLzUBLgmp4veFZdvN5ChW4Eq/8Fc2Fg==} - engines: {node: '>=18'} - cpu: [arm64] - os: [linux] - - '@esbuild/linux-arm@0.27.3': - resolution: {integrity: sha512-s6nPv2QkSupJwLYyfS+gwdirm0ukyTFNl3KTgZEAiJDd+iHZcbTPPcWCcRYH+WlNbwChgH2QkE9NSlNrMT8Gfw==} - engines: {node: '>=18'} - cpu: [arm] - os: [linux] - - '@esbuild/linux-ia32@0.27.3': - resolution: {integrity: sha512-yGlQYjdxtLdh0a3jHjuwOrxQjOZYD/C9PfdbgJJF3TIZWnm/tMd/RcNiLngiu4iwcBAOezdnSLAwQDPqTmtTYg==} - engines: {node: '>=18'} - cpu: [ia32] - os: [linux] - - '@esbuild/linux-loong64@0.27.3': - resolution: {integrity: sha512-WO60Sn8ly3gtzhyjATDgieJNet/KqsDlX5nRC5Y3oTFcS1l0KWba+SEa9Ja1GfDqSF1z6hif/SkpQJbL63cgOA==} - engines: {node: '>=18'} - cpu: [loong64] - os: [linux] - - '@esbuild/linux-mips64el@0.27.3': - resolution: {integrity: sha512-APsymYA6sGcZ4pD6k+UxbDjOFSvPWyZhjaiPyl/f79xKxwTnrn5QUnXR5prvetuaSMsb4jgeHewIDCIWljrSxw==} - engines: {node: '>=18'} - cpu: [mips64el] - os: [linux] - - '@esbuild/linux-ppc64@0.27.3': - resolution: {integrity: sha512-eizBnTeBefojtDb9nSh4vvVQ3V9Qf9Df01PfawPcRzJH4gFSgrObw+LveUyDoKU3kxi5+9RJTCWlj4FjYXVPEA==} - engines: {node: '>=18'} - cpu: [ppc64] - os: [linux] - - '@esbuild/linux-riscv64@0.27.3': - resolution: {integrity: sha512-3Emwh0r5wmfm3ssTWRQSyVhbOHvqegUDRd0WhmXKX2mkHJe1SFCMJhagUleMq+Uci34wLSipf8Lagt4LlpRFWQ==} - engines: {node: '>=18'} - cpu: [riscv64] - os: [linux] - - '@esbuild/linux-s390x@0.27.3': - resolution: {integrity: sha512-pBHUx9LzXWBc7MFIEEL0yD/ZVtNgLytvx60gES28GcWMqil8ElCYR4kvbV2BDqsHOvVDRrOxGySBM9Fcv744hw==} - engines: {node: '>=18'} - cpu: [s390x] - os: [linux] - - '@esbuild/linux-x64@0.27.3': - resolution: {integrity: sha512-Czi8yzXUWIQYAtL/2y6vogER8pvcsOsk5cpwL4Gk5nJqH5UZiVByIY8Eorm5R13gq+DQKYg0+JyQoytLQas4dA==} - engines: {node: '>=18'} - cpu: [x64] - os: [linux] - - '@esbuild/netbsd-arm64@0.27.3': - resolution: {integrity: sha512-sDpk0RgmTCR/5HguIZa9n9u+HVKf40fbEUt+iTzSnCaGvY9kFP0YKBWZtJaraonFnqef5SlJ8/TiPAxzyS+UoA==} - engines: {node: '>=18'} - cpu: [arm64] - os: [netbsd] - - '@esbuild/netbsd-x64@0.27.3': - resolution: {integrity: sha512-P14lFKJl/DdaE00LItAukUdZO5iqNH7+PjoBm+fLQjtxfcfFE20Xf5CrLsmZdq5LFFZzb5JMZ9grUwvtVYzjiA==} - engines: {node: '>=18'} - cpu: [x64] - os: [netbsd] - - '@esbuild/openbsd-arm64@0.27.3': - resolution: {integrity: sha512-AIcMP77AvirGbRl/UZFTq5hjXK+2wC7qFRGoHSDrZ5v5b8DK/GYpXW3CPRL53NkvDqb9D+alBiC/dV0Fb7eJcw==} - engines: {node: '>=18'} - cpu: [arm64] - os: [openbsd] - - '@esbuild/openbsd-x64@0.27.3': - resolution: {integrity: sha512-DnW2sRrBzA+YnE70LKqnM3P+z8vehfJWHXECbwBmH/CU51z6FiqTQTHFenPlHmo3a8UgpLyH3PT+87OViOh1AQ==} - engines: {node: '>=18'} - cpu: [x64] - os: [openbsd] - - '@esbuild/openharmony-arm64@0.27.3': - resolution: {integrity: sha512-NinAEgr/etERPTsZJ7aEZQvvg/A6IsZG/LgZy+81wON2huV7SrK3e63dU0XhyZP4RKGyTm7aOgmQk0bGp0fy2g==} - engines: {node: '>=18'} - cpu: [arm64] - os: [openharmony] - - '@esbuild/sunos-x64@0.27.3': - resolution: {integrity: sha512-PanZ+nEz+eWoBJ8/f8HKxTTD172SKwdXebZ0ndd953gt1HRBbhMsaNqjTyYLGLPdoWHy4zLU7bDVJztF5f3BHA==} - engines: {node: '>=18'} - cpu: [x64] - os: [sunos] - - '@esbuild/win32-arm64@0.27.3': - resolution: {integrity: sha512-B2t59lWWYrbRDw/tjiWOuzSsFh1Y/E95ofKz7rIVYSQkUYBjfSgf6oeYPNWHToFRr2zx52JKApIcAS/D5TUBnA==} - engines: {node: '>=18'} - cpu: [arm64] - os: [win32] - - '@esbuild/win32-ia32@0.27.3': - resolution: {integrity: sha512-QLKSFeXNS8+tHW7tZpMtjlNb7HKau0QDpwm49u0vUp9y1WOF+PEzkU84y9GqYaAVW8aH8f3GcBck26jh54cX4Q==} - engines: {node: '>=18'} - cpu: [ia32] - os: [win32] - - '@esbuild/win32-x64@0.27.3': - resolution: {integrity: sha512-4uJGhsxuptu3OcpVAzli+/gWusVGwZZHTlS63hh++ehExkVT8SgiEf7/uC/PclrPPkLhZqGgCTjd0VWLo6xMqA==} - engines: {node: '>=18'} - cpu: [x64] - os: [win32] - - '@eslint-community/eslint-utils@4.9.1': - resolution: {integrity: sha512-phrYmNiYppR7znFEdqgfWHXR6NCkZEK7hwWDHZUjit/2/U0r6XvkDl0SYnoM51Hq7FhCGdLDT6zxCCOY1hexsQ==} - engines: {node: ^12.22.0 || ^14.17.0 || >=16.0.0} - peerDependencies: - eslint: ^6.0.0 || ^7.0.0 || >=8.0.0 - - '@eslint-community/regexpp@4.12.2': - resolution: {integrity: sha512-EriSTlt5OC9/7SXkRSCAhfSxxoSUgBm33OH+IkwbdpgoqsSsUg7y3uh+IICI/Qg4BBWr3U2i39RpmycbxMq4ew==} - engines: {node: ^12.0.0 || ^14.0.0 || >=16.0.0} - - '@eslint/config-array@0.23.3': - resolution: {integrity: sha512-j+eEWmB6YYLwcNOdlwQ6L2OsptI/LO6lNBuLIqe5R7RetD658HLoF+Mn7LzYmAWWNNzdC6cqP+L6r8ujeYXWLw==} - engines: {node: ^20.19.0 || ^22.13.0 || >=24} - - '@eslint/config-helpers@0.5.3': - resolution: {integrity: sha512-lzGN0onllOZCGroKJmRwY6QcEHxbjBw1gwB8SgRSqK8YbbtEXMvKynsXc3553ckIEBxsbMBU7oOZXKIPGZNeZw==} - engines: {node: ^20.19.0 || ^22.13.0 || >=24} - - '@eslint/core@1.1.1': - resolution: {integrity: sha512-QUPblTtE51/7/Zhfv8BDwO0qkkzQL7P/aWWbqcf4xWLEYn1oKjdO0gglQBB4GAsu7u6wjijbCmzsUTy6mnk6oQ==} - engines: {node: ^20.19.0 || ^22.13.0 || >=24} - - '@eslint/js@10.0.1': - resolution: {integrity: sha512-zeR9k5pd4gxjZ0abRoIaxdc7I3nDktoXZk2qOv9gCNWx3mVwEn32VRhyLaRsDiJjTs0xq/T8mfPtyuXu7GWBcA==} - engines: {node: ^20.19.0 || ^22.13.0 || >=24} - peerDependencies: - eslint: ^10.0.0 - peerDependenciesMeta: - eslint: - optional: true - - '@eslint/object-schema@3.0.3': - resolution: {integrity: sha512-iM869Pugn9Nsxbh/YHRqYiqd23AmIbxJOcpUMOuWCVNdoQJ5ZtwL6h3t0bcZzJUlC3Dq9jCFCESBZnX0GTv7iQ==} - engines: {node: ^20.19.0 || ^22.13.0 || >=24} - - '@eslint/plugin-kit@0.6.1': - resolution: {integrity: sha512-iH1B076HoAshH1mLpHMgwdGeTs0CYwL0SPMkGuSebZrwBp16v415e9NZXg2jtrqPVQjf6IANe2Vtlr5KswtcZQ==} - engines: {node: ^20.19.0 || ^22.13.0 || >=24} - - '@humanfs/core@0.19.1': - resolution: {integrity: sha512-5DyQ4+1JEUzejeK1JGICcideyfUbGixgS9jNgex5nqkW+cY7WZhxBigmieN5Qnw9ZosSNVC9KQKyb+GUaGyKUA==} - engines: {node: '>=18.18.0'} - - '@humanfs/node@0.16.7': - resolution: {integrity: sha512-/zUx+yOsIrG4Y43Eh2peDeKCxlRt/gET6aHfaKpuq267qXdYDFViVHfMaLyygZOnl0kGWxFIgsBy8QFuTLUXEQ==} - engines: {node: '>=18.18.0'} - - '@humanwhocodes/module-importer@1.0.1': - resolution: {integrity: sha512-bxveV4V8v5Yb4ncFTT3rPSgZBOpCkjfK0y4oVVVJwIuDVBRMDXrPyXRL988i5ap9m9bnyEEjWfm5WkBmtffLfA==} - engines: {node: '>=12.22'} - - '@humanwhocodes/retry@0.4.3': - resolution: {integrity: sha512-bV0Tgo9K4hfPCek+aMAn81RppFKv2ySDQeMoSZuvTASywNTnVJCArCZE2FWqpvIatKu7VMRLWlR1EazvVhDyhQ==} - engines: {node: '>=18.18'} - - '@jridgewell/gen-mapping@0.3.13': - resolution: {integrity: sha512-2kkt/7niJ6MgEPxF0bYdQ6etZaA+fQvDcLKckhy1yIQOzaoKjBBjSj63/aLVjYE3qhRt5dvM+uUyfCg6UKCBbA==} - - '@jridgewell/resolve-uri@3.1.2': - resolution: {integrity: sha512-bRISgCIjP20/tbWSPWMEi54QVPRZExkuD9lJL+UIxUKtwVJA8wW1Trb1jMs1RFXo1CBTNZ/5hpC9QvmKWdopKw==} - engines: {node: '>=6.0.0'} - - '@jridgewell/sourcemap-codec@1.5.5': - resolution: {integrity: sha512-cYQ9310grqxueWbl+WuIUIaiUaDcj7WOq5fVhEljNVgRfOUhY9fy2zTvfoqWsnebh8Sl70VScFbICvJnLKB0Og==} - - '@jridgewell/trace-mapping@0.3.31': - resolution: {integrity: sha512-zzNR+SdQSDJzc8joaeP8QQoCQr8NuYx2dIIytl1QeBEZHJ9uW6hebsrYgbz8hJwUQao3TWCMtmfV8Nu1twOLAw==} - - '@rollup/rollup-android-arm-eabi@4.59.0': - resolution: {integrity: sha512-upnNBkA6ZH2VKGcBj9Fyl9IGNPULcjXRlg0LLeaioQWueH30p6IXtJEbKAgvyv+mJaMxSm1l6xwDXYjpEMiLMg==} - cpu: [arm] - os: [android] - - '@rollup/rollup-android-arm64@4.59.0': - resolution: {integrity: sha512-hZ+Zxj3SySm4A/DylsDKZAeVg0mvi++0PYVceVyX7hemkw7OreKdCvW2oQ3T1FMZvCaQXqOTHb8qmBShoqk69Q==} - cpu: [arm64] - os: [android] - - '@rollup/rollup-darwin-arm64@4.59.0': - resolution: {integrity: sha512-W2Psnbh1J8ZJw0xKAd8zdNgF9HRLkdWwwdWqubSVk0pUuQkoHnv7rx4GiF9rT4t5DIZGAsConRE3AxCdJ4m8rg==} - cpu: [arm64] - os: [darwin] - - '@rollup/rollup-darwin-x64@4.59.0': - resolution: {integrity: sha512-ZW2KkwlS4lwTv7ZVsYDiARfFCnSGhzYPdiOU4IM2fDbL+QGlyAbjgSFuqNRbSthybLbIJ915UtZBtmuLrQAT/w==} - cpu: [x64] - os: [darwin] - - '@rollup/rollup-freebsd-arm64@4.59.0': - resolution: {integrity: sha512-EsKaJ5ytAu9jI3lonzn3BgG8iRBjV4LxZexygcQbpiU0wU0ATxhNVEpXKfUa0pS05gTcSDMKpn3Sx+QB9RlTTA==} - cpu: [arm64] - os: [freebsd] - - '@rollup/rollup-freebsd-x64@4.59.0': - resolution: {integrity: sha512-d3DuZi2KzTMjImrxoHIAODUZYoUUMsuUiY4SRRcJy6NJoZ6iIqWnJu9IScV9jXysyGMVuW+KNzZvBLOcpdl3Vg==} - cpu: [x64] - os: [freebsd] - - '@rollup/rollup-linux-arm-gnueabihf@4.59.0': - resolution: {integrity: sha512-t4ONHboXi/3E0rT6OZl1pKbl2Vgxf9vJfWgmUoCEVQVxhW6Cw/c8I6hbbu7DAvgp82RKiH7TpLwxnJeKv2pbsw==} - cpu: [arm] - os: [linux] - libc: [glibc] - - '@rollup/rollup-linux-arm-musleabihf@4.59.0': - resolution: {integrity: sha512-CikFT7aYPA2ufMD086cVORBYGHffBo4K8MQ4uPS/ZnY54GKj36i196u8U+aDVT2LX4eSMbyHtyOh7D7Zvk2VvA==} - cpu: [arm] - os: [linux] - libc: [musl] - - '@rollup/rollup-linux-arm64-gnu@4.59.0': - resolution: {integrity: sha512-jYgUGk5aLd1nUb1CtQ8E+t5JhLc9x5WdBKew9ZgAXg7DBk0ZHErLHdXM24rfX+bKrFe+Xp5YuJo54I5HFjGDAA==} - cpu: [arm64] - os: [linux] - libc: [glibc] - - '@rollup/rollup-linux-arm64-musl@4.59.0': - resolution: {integrity: sha512-peZRVEdnFWZ5Bh2KeumKG9ty7aCXzzEsHShOZEFiCQlDEepP1dpUl/SrUNXNg13UmZl+gzVDPsiCwnV1uI0RUA==} - cpu: [arm64] - os: [linux] - libc: [musl] - - '@rollup/rollup-linux-loong64-gnu@4.59.0': - resolution: {integrity: sha512-gbUSW/97f7+r4gHy3Jlup8zDG190AuodsWnNiXErp9mT90iCy9NKKU0Xwx5k8VlRAIV2uU9CsMnEFg/xXaOfXg==} - cpu: [loong64] - os: [linux] - libc: [glibc] - - '@rollup/rollup-linux-loong64-musl@4.59.0': - resolution: {integrity: sha512-yTRONe79E+o0FWFijasoTjtzG9EBedFXJMl888NBEDCDV9I2wGbFFfJQQe63OijbFCUZqxpHz1GzpbtSFikJ4Q==} - cpu: [loong64] - os: [linux] - libc: [musl] - - '@rollup/rollup-linux-ppc64-gnu@4.59.0': - resolution: {integrity: sha512-sw1o3tfyk12k3OEpRddF68a1unZ5VCN7zoTNtSn2KndUE+ea3m3ROOKRCZxEpmT9nsGnogpFP9x6mnLTCaoLkA==} - cpu: [ppc64] - os: [linux] - libc: [glibc] - - '@rollup/rollup-linux-ppc64-musl@4.59.0': - resolution: {integrity: sha512-+2kLtQ4xT3AiIxkzFVFXfsmlZiG5FXYW7ZyIIvGA7Bdeuh9Z0aN4hVyXS/G1E9bTP/vqszNIN/pUKCk/BTHsKA==} - cpu: [ppc64] - os: [linux] - libc: [musl] - - '@rollup/rollup-linux-riscv64-gnu@4.59.0': - resolution: {integrity: sha512-NDYMpsXYJJaj+I7UdwIuHHNxXZ/b/N2hR15NyH3m2qAtb/hHPA4g4SuuvrdxetTdndfj9b1WOmy73kcPRoERUg==} - cpu: [riscv64] - os: [linux] - libc: [glibc] - - '@rollup/rollup-linux-riscv64-musl@4.59.0': - resolution: {integrity: sha512-nLckB8WOqHIf1bhymk+oHxvM9D3tyPndZH8i8+35p/1YiVoVswPid2yLzgX7ZJP0KQvnkhM4H6QZ5m0LzbyIAg==} - cpu: [riscv64] - os: [linux] - libc: [musl] - - '@rollup/rollup-linux-s390x-gnu@4.59.0': - resolution: {integrity: sha512-oF87Ie3uAIvORFBpwnCvUzdeYUqi2wY6jRFWJAy1qus/udHFYIkplYRW+wo+GRUP4sKzYdmE1Y3+rY5Gc4ZO+w==} - cpu: [s390x] - os: [linux] - libc: [glibc] - - '@rollup/rollup-linux-x64-gnu@4.59.0': - resolution: {integrity: sha512-3AHmtQq/ppNuUspKAlvA8HtLybkDflkMuLK4DPo77DfthRb71V84/c4MlWJXixZz4uruIH4uaa07IqoAkG64fg==} - cpu: [x64] - os: [linux] - libc: [glibc] - - '@rollup/rollup-linux-x64-musl@4.59.0': - resolution: {integrity: sha512-2UdiwS/9cTAx7qIUZB/fWtToJwvt0Vbo0zmnYt7ED35KPg13Q0ym1g442THLC7VyI6JfYTP4PiSOWyoMdV2/xg==} - cpu: [x64] - os: [linux] - libc: [musl] - - '@rollup/rollup-openbsd-x64@4.59.0': - resolution: {integrity: sha512-M3bLRAVk6GOwFlPTIxVBSYKUaqfLrn8l0psKinkCFxl4lQvOSz8ZrKDz2gxcBwHFpci0B6rttydI4IpS4IS/jQ==} - cpu: [x64] - os: [openbsd] - - '@rollup/rollup-openharmony-arm64@4.59.0': - resolution: {integrity: sha512-tt9KBJqaqp5i5HUZzoafHZX8b5Q2Fe7UjYERADll83O4fGqJ49O1FsL6LpdzVFQcpwvnyd0i+K/VSwu/o/nWlA==} - cpu: [arm64] - os: [openharmony] - - '@rollup/rollup-win32-arm64-msvc@4.59.0': - resolution: {integrity: sha512-V5B6mG7OrGTwnxaNUzZTDTjDS7F75PO1ae6MJYdiMu60sq0CqN5CVeVsbhPxalupvTX8gXVSU9gq+Rx1/hvu6A==} - cpu: [arm64] - os: [win32] - - '@rollup/rollup-win32-ia32-msvc@4.59.0': - resolution: {integrity: sha512-UKFMHPuM9R0iBegwzKF4y0C4J9u8C6MEJgFuXTBerMk7EJ92GFVFYBfOZaSGLu6COf7FxpQNqhNS4c4icUPqxA==} - cpu: [ia32] - os: [win32] - - '@rollup/rollup-win32-x64-gnu@4.59.0': - resolution: {integrity: sha512-laBkYlSS1n2L8fSo1thDNGrCTQMmxjYY5G0WFWjFFYZkKPjsMBsgJfGf4TLxXrF6RyhI60L8TMOjBMvXiTcxeA==} - cpu: [x64] - os: [win32] - - '@rollup/rollup-win32-x64-msvc@4.59.0': - resolution: {integrity: sha512-2HRCml6OztYXyJXAvdDXPKcawukWY2GpR5/nxKp4iBgiO3wcoEGkAaqctIbZcNB6KlUQBIqt8VYkNSj2397EfA==} - cpu: [x64] - os: [win32] - - '@standard-schema/spec@1.1.0': - resolution: {integrity: sha512-l2aFy5jALhniG5HgqrD6jXLi/rUWrKvqN/qJx6yoJsgKhblVd+iqqU4RCXavm/jPityDo5TCvKMnpjKnOriy0w==} - - '@types/chai@5.2.3': - resolution: {integrity: sha512-Mw558oeA9fFbv65/y4mHtXDs9bPnFMZAL/jxdPFUpOHHIXX91mcgEHbS5Lahr+pwZFR8A7GQleRWeI6cGFC2UA==} - - '@types/deep-eql@4.0.2': - resolution: {integrity: sha512-c9h9dVVMigMPc4bwTvC5dxqtqJZwQPePsWjPlpSOnojbor6pGqdk541lfA7AqFQr5pB1BRdq0juY9db81BwyFw==} - - '@types/esrecurse@4.3.1': - resolution: {integrity: sha512-xJBAbDifo5hpffDBuHl0Y8ywswbiAp/Wi7Y/GtAgSlZyIABppyurxVueOPE8LUQOxdlgi6Zqce7uoEpqNTeiUw==} - - '@types/estree@1.0.8': - resolution: {integrity: sha512-dWHzHa2WqEXI/O1E9OjrocMTKJl2mSrEolh1Iomrv6U+JuNwaHXsXx9bLu5gG7BUWFIN0skIQJQ/L1rIex4X6w==} - - '@types/json-schema@7.0.15': - resolution: {integrity: sha512-5+fP8P8MFNC+AyZCDxrB2pkZFPGzqQWUzpSeuuVLvm8VMcorNYavBqoFcxK8bQz4Qsbn4oUEEem4wDLfcysGHA==} - - '@types/node@25.4.0': - resolution: {integrity: sha512-9wLpoeWuBlcbBpOY3XmzSTG3oscB6xjBEEtn+pYXTfhyXhIxC5FsBer2KTopBlvKEiW9l13po9fq+SJY/5lkhw==} - - '@typescript-eslint/eslint-plugin@8.57.0': - resolution: {integrity: sha512-qeu4rTHR3/IaFORbD16gmjq9+rEs9fGKdX0kF6BKSfi+gCuG3RCKLlSBYzn/bGsY9Tj7KE/DAQStbp8AHJGHEQ==} - engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} - peerDependencies: - '@typescript-eslint/parser': ^8.57.0 - eslint: ^8.57.0 || ^9.0.0 || ^10.0.0 - typescript: '>=4.8.4 <6.0.0' - - '@typescript-eslint/parser@8.57.0': - resolution: {integrity: sha512-XZzOmihLIr8AD1b9hL9ccNMzEMWt/dE2u7NyTY9jJG6YNiNthaD5XtUHVF2uCXZ15ng+z2hT3MVuxnUYhq6k1g==} - engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} - peerDependencies: - eslint: ^8.57.0 || ^9.0.0 || ^10.0.0 - typescript: '>=4.8.4 <6.0.0' - - '@typescript-eslint/project-service@8.57.0': - resolution: {integrity: sha512-pR+dK0BlxCLxtWfaKQWtYr7MhKmzqZxuii+ZjuFlZlIGRZm22HnXFqa2eY+90MUz8/i80YJmzFGDUsi8dMOV5w==} - engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} - peerDependencies: - typescript: '>=4.8.4 <6.0.0' - - '@typescript-eslint/scope-manager@8.57.0': - resolution: {integrity: sha512-nvExQqAHF01lUM66MskSaZulpPL5pgy5hI5RfrxviLgzZVffB5yYzw27uK/ft8QnKXI2X0LBrHJFr1TaZtAibw==} - engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} - - '@typescript-eslint/tsconfig-utils@8.57.0': - resolution: {integrity: sha512-LtXRihc5ytjJIQEH+xqjB0+YgsV4/tW35XKX3GTZHpWtcC8SPkT/d4tqdf1cKtesryHm2bgp6l555NYcT2NLvA==} - engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} - peerDependencies: - typescript: '>=4.8.4 <6.0.0' - - '@typescript-eslint/type-utils@8.57.0': - resolution: {integrity: sha512-yjgh7gmDcJ1+TcEg8x3uWQmn8ifvSupnPfjP21twPKrDP/pTHlEQgmKcitzF/rzPSmv7QjJ90vRpN4U+zoUjwQ==} - engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} - peerDependencies: - eslint: ^8.57.0 || ^9.0.0 || ^10.0.0 - typescript: '>=4.8.4 <6.0.0' - - '@typescript-eslint/types@8.57.0': - resolution: {integrity: sha512-dTLI8PEXhjUC7B9Kre+u0XznO696BhXcTlOn0/6kf1fHaQW8+VjJAVHJ3eTI14ZapTxdkOmc80HblPQLaEeJdg==} - engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} - - '@typescript-eslint/typescript-estree@8.57.0': - resolution: {integrity: sha512-m7faHcyVg0BT3VdYTlX8GdJEM7COexXxS6KqGopxdtkQRvBanK377QDHr4W/vIPAR+ah9+B/RclSW5ldVniO1Q==} - engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} - peerDependencies: - typescript: '>=4.8.4 <6.0.0' - - '@typescript-eslint/utils@8.57.0': - resolution: {integrity: sha512-5iIHvpD3CZe06riAsbNxxreP+MuYgVUsV0n4bwLH//VJmgtt54sQeY2GszntJ4BjYCpMzrfVh2SBnUQTtys2lQ==} - engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} - peerDependencies: - eslint: ^8.57.0 || ^9.0.0 || ^10.0.0 - typescript: '>=4.8.4 <6.0.0' - - '@typescript-eslint/visitor-keys@8.57.0': - resolution: {integrity: sha512-zm6xx8UT/Xy2oSr2ZXD0pZo7Jx2XsCoID2IUh9YSTFRu7z+WdwYTRk6LhUftm1crwqbuoF6I8zAFeCMw0YjwDg==} - engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} - - '@vitest/coverage-v8@4.0.18': - resolution: {integrity: sha512-7i+N2i0+ME+2JFZhfuz7Tg/FqKtilHjGyGvoHYQ6iLV0zahbsJ9sljC9OcFcPDbhYKCet+sG8SsVqlyGvPflZg==} - peerDependencies: - '@vitest/browser': 4.0.18 - vitest: 4.0.18 - peerDependenciesMeta: - '@vitest/browser': - optional: true - - '@vitest/expect@4.0.18': - resolution: {integrity: sha512-8sCWUyckXXYvx4opfzVY03EOiYVxyNrHS5QxX3DAIi5dpJAAkyJezHCP77VMX4HKA2LDT/Jpfo8i2r5BE3GnQQ==} - - '@vitest/mocker@4.0.18': - resolution: {integrity: sha512-HhVd0MDnzzsgevnOWCBj5Otnzobjy5wLBe4EdeeFGv8luMsGcYqDuFRMcttKWZA5vVO8RFjexVovXvAM4JoJDQ==} - peerDependencies: - msw: ^2.4.9 - vite: ^6.0.0 || ^7.0.0-0 - peerDependenciesMeta: - msw: - optional: true - vite: - optional: true - - '@vitest/pretty-format@4.0.18': - resolution: {integrity: sha512-P24GK3GulZWC5tz87ux0m8OADrQIUVDPIjjj65vBXYG17ZeU3qD7r+MNZ1RNv4l8CGU2vtTRqixrOi9fYk/yKw==} - - '@vitest/runner@4.0.18': - resolution: {integrity: sha512-rpk9y12PGa22Jg6g5M3UVVnTS7+zycIGk9ZNGN+m6tZHKQb7jrP7/77WfZy13Y/EUDd52NDsLRQhYKtv7XfPQw==} - - '@vitest/snapshot@4.0.18': - resolution: {integrity: sha512-PCiV0rcl7jKQjbgYqjtakly6T1uwv/5BQ9SwBLekVg/EaYeQFPiXcgrC2Y7vDMA8dM1SUEAEV82kgSQIlXNMvA==} - - '@vitest/spy@4.0.18': - resolution: {integrity: sha512-cbQt3PTSD7P2OARdVW3qWER5EGq7PHlvE+QfzSC0lbwO+xnt7+XH06ZzFjFRgzUX//JmpxrCu92VdwvEPlWSNw==} - - '@vitest/utils@4.0.18': - resolution: {integrity: sha512-msMRKLMVLWygpK3u2Hybgi4MNjcYJvwTb0Ru09+fOyCXIgT5raYP041DRRdiJiI3k/2U6SEbAETB3YtBrUkCFA==} - - acorn-jsx@5.3.2: - resolution: {integrity: sha512-rq9s+JNhf0IChjtDXxllJ7g41oZk5SlXtp0LHwyA5cejwn7vKmKp4pPri6YEePv2PU65sAsegbXtIinmDFDXgQ==} - peerDependencies: - acorn: ^6.0.0 || ^7.0.0 || ^8.0.0 - - acorn@8.16.0: - resolution: {integrity: sha512-UVJyE9MttOsBQIDKw1skb9nAwQuR5wuGD3+82K6JgJlm/Y+KI92oNsMNGZCYdDsVtRHSak0pcV5Dno5+4jh9sw==} - engines: {node: '>=0.4.0'} - hasBin: true - - ajv@6.14.0: - resolution: {integrity: sha512-IWrosm/yrn43eiKqkfkHis7QioDleaXQHdDVPKg0FSwwd/DuvyX79TZnFOnYpB7dcsFAMmtFztZuXPDvSePkFw==} - - any-promise@1.3.0: - resolution: {integrity: sha512-7UvmKalWRt1wgjL1RrGxoSJW/0QZFIegpeGvZG9kjp8vrRu55XTHbwnqq2GpXm9uLbcuhxm3IqX9OB4MZR1b2A==} - - assertion-error@2.0.1: - resolution: {integrity: sha512-Izi8RQcffqCeNVgFigKli1ssklIbpHnCYc6AknXGYoB6grJqyeby7jv12JUQgmTAnIDnbck1uxksT4dzN3PWBA==} - engines: {node: '>=12'} - - ast-v8-to-istanbul@0.3.12: - resolution: {integrity: sha512-BRRC8VRZY2R4Z4lFIL35MwNXmwVqBityvOIwETtsCSwvjl0IdgFsy9NhdaA6j74nUdtJJlIypeRhpDam19Wq3g==} - - asynckit@0.4.0: - resolution: {integrity: sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==} - - axios@1.13.6: - resolution: {integrity: sha512-ChTCHMouEe2kn713WHbQGcuYrr6fXTBiu460OTwWrWob16g1bXn4vtz07Ope7ewMozJAnEquLk5lWQWtBig9DQ==} - - balanced-match@4.0.4: - resolution: {integrity: sha512-BLrgEcRTwX2o6gGxGOCNyMvGSp35YofuYzw9h1IMTRmKqttAZZVU67bdb9Pr2vUHA8+j3i2tJfjO6C6+4myGTA==} - engines: {node: 18 || 20 || >=22} - - brace-expansion@5.0.4: - resolution: {integrity: sha512-h+DEnpVvxmfVefa4jFbCf5HdH5YMDXRsmKflpf1pILZWRFlTbJpxeU55nJl4Smt5HQaGzg1o6RHFPJaOqnmBDg==} - engines: {node: 18 || 20 || >=22} - - bundle-require@5.1.0: - resolution: {integrity: sha512-3WrrOuZiyaaZPWiEt4G3+IffISVC9HYlWueJEBWED4ZH4aIAC2PnkdnuRrR94M+w6yGWn4AglWtJtBI8YqvgoA==} - engines: {node: ^12.20.0 || ^14.13.1 || >=16.0.0} - peerDependencies: - esbuild: '>=0.18' - - cac@6.7.14: - resolution: {integrity: sha512-b6Ilus+c3RrdDk+JhLKUAQfzzgLEPy6wcXqS7f/xe1EETvsDP6GORG7SFuOs6cID5YkqchW/LXZbX5bc8j7ZcQ==} - engines: {node: '>=8'} - - call-bind-apply-helpers@1.0.2: - resolution: {integrity: sha512-Sp1ablJ0ivDkSzjcaJdxEunN5/XvksFJ2sMBFfq6x0ryhQV/2b/KwFe21cMpmHtPOSij8K99/wSfoEuTObmuMQ==} - engines: {node: '>= 0.4'} - - chai@6.2.2: - resolution: {integrity: sha512-NUPRluOfOiTKBKvWPtSD4PhFvWCqOi0BGStNWs57X9js7XGTprSmFoz5F0tWhR4WPjNeR9jXqdC7/UpSJTnlRg==} - engines: {node: '>=18'} - - chokidar@4.0.3: - resolution: {integrity: sha512-Qgzu8kfBvo+cA4962jnP1KkS6Dop5NS6g7R5LFYJr4b8Ub94PPQXUksCw9PvXoeXPRRddRNC5C1JQUR2SMGtnA==} - engines: {node: '>= 14.16.0'} - - combined-stream@1.0.8: - resolution: {integrity: sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg==} - engines: {node: '>= 0.8'} - - commander@4.1.1: - resolution: {integrity: sha512-NOKm8xhkzAjzFx8B2v5OAHT+u5pRQc2UCa2Vq9jYL/31o2wi9mxBA7LIFs3sV5VSC49z6pEhfbMULvShKj26WA==} - engines: {node: '>= 6'} - - confbox@0.1.8: - resolution: {integrity: sha512-RMtmw0iFkeR4YV+fUOSucriAQNb9g8zFR52MWCtl+cCZOFRNL6zeB395vPzFhEjjn4fMxXudmELnl/KF/WrK6w==} - - consola@3.4.2: - resolution: {integrity: sha512-5IKcdX0nnYavi6G7TtOhwkYzyjfJlatbjMjuLSfE2kYT5pMDOilZ4OvMhi637CcDICTmz3wARPoyhqyX1Y+XvA==} - engines: {node: ^14.18.0 || >=16.10.0} - - cross-spawn@7.0.6: - resolution: {integrity: sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==} - engines: {node: '>= 8'} - - debug@4.4.3: - resolution: {integrity: sha512-RGwwWnwQvkVfavKVt22FGLw+xYSdzARwm0ru6DhTVA3umU5hZc28V3kO4stgYryrTlLpuvgI9GiijltAjNbcqA==} - engines: {node: '>=6.0'} - peerDependencies: - supports-color: '*' - peerDependenciesMeta: - supports-color: - optional: true - - deep-is@0.1.4: - resolution: {integrity: sha512-oIPzksmTg4/MriiaYGO+okXDT7ztn/w3Eptv/+gSIdMdKsJo0u4CfYNFJPy+4SKMuCqGw2wxnA+URMg3t8a/bQ==} - - delayed-stream@1.0.0: - resolution: {integrity: sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ==} - engines: {node: '>=0.4.0'} - - dunder-proto@1.0.1: - resolution: {integrity: sha512-KIN/nDJBQRcXw0MLVhZE9iQHmG68qAVIBg9CqmUYjmQIhgij9U5MFvrqkUL5FbtyyzZuOeOt0zdeRe4UY7ct+A==} - engines: {node: '>= 0.4'} - - es-define-property@1.0.1: - resolution: {integrity: sha512-e3nRfgfUZ4rNGL232gUgX06QNyyez04KdjFrF+LTRoOXmrOgFKDg4BCdsjW8EnT69eqdYGmRpJwiPVYNrCaW3g==} - engines: {node: '>= 0.4'} - - es-errors@1.3.0: - resolution: {integrity: sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==} - engines: {node: '>= 0.4'} - - es-module-lexer@1.7.0: - resolution: {integrity: sha512-jEQoCwk8hyb2AZziIOLhDqpm5+2ww5uIE6lkO/6jcOCusfk6LhMHpXXfBLXTZ7Ydyt0j4VoUQv6uGNYbdW+kBA==} - - es-object-atoms@1.1.1: - resolution: {integrity: sha512-FGgH2h8zKNim9ljj7dankFPcICIK9Cp5bm+c2gQSYePhpaG5+esrLODihIorn+Pe6FGJzWhXQotPv73jTaldXA==} - engines: {node: '>= 0.4'} - - es-set-tostringtag@2.1.0: - resolution: {integrity: sha512-j6vWzfrGVfyXxge+O0x5sh6cvxAog0a/4Rdd2K36zCMV5eJ+/+tOAngRO8cODMNWbVRdVlmGZQL2YS3yR8bIUA==} - engines: {node: '>= 0.4'} - - esbuild@0.27.3: - resolution: {integrity: sha512-8VwMnyGCONIs6cWue2IdpHxHnAjzxnw2Zr7MkVxB2vjmQ2ivqGFb4LEG3SMnv0Gb2F/G/2yA8zUaiL1gywDCCg==} - engines: {node: '>=18'} - hasBin: true - - escape-string-regexp@4.0.0: - resolution: {integrity: sha512-TtpcNJ3XAzx3Gq8sWRzJaVajRs0uVxA2YAkdb1jm2YkPz4G6egUFAyA3n5vtEIZefPk5Wa4UXbKuS5fKkJWdgA==} - engines: {node: '>=10'} - - eslint-scope@9.1.2: - resolution: {integrity: sha512-xS90H51cKw0jltxmvmHy2Iai1LIqrfbw57b79w/J7MfvDfkIkFZ+kj6zC3BjtUwh150HsSSdxXZcsuv72miDFQ==} - engines: {node: ^20.19.0 || ^22.13.0 || >=24} - - eslint-visitor-keys@3.4.3: - resolution: {integrity: sha512-wpc+LXeiyiisxPlEkUzU6svyS1frIO3Mgxj1fdy7Pm8Ygzguax2N3Fa/D/ag1WqbOprdI+uY6wMUl8/a2G+iag==} - engines: {node: ^12.22.0 || ^14.17.0 || >=16.0.0} - - eslint-visitor-keys@5.0.1: - resolution: {integrity: sha512-tD40eHxA35h0PEIZNeIjkHoDR4YjjJp34biM0mDvplBe//mB+IHCqHDGV7pxF+7MklTvighcCPPZC7ynWyjdTA==} - engines: {node: ^20.19.0 || ^22.13.0 || >=24} - - eslint@10.0.3: - resolution: {integrity: sha512-COV33RzXZkqhG9P2rZCFl9ZmJ7WL+gQSCRzE7RhkbclbQPtLAWReL7ysA0Sh4c8Im2U9ynybdR56PV0XcKvqaQ==} - engines: {node: ^20.19.0 || ^22.13.0 || >=24} - hasBin: true - peerDependencies: - jiti: '*' - peerDependenciesMeta: - jiti: - optional: true - - espree@11.2.0: - resolution: {integrity: sha512-7p3DrVEIopW1B1avAGLuCSh1jubc01H2JHc8B4qqGblmg5gI9yumBgACjWo4JlIc04ufug4xJ3SQI8HkS/Rgzw==} - engines: {node: ^20.19.0 || ^22.13.0 || >=24} - - esquery@1.7.0: - resolution: {integrity: sha512-Ap6G0WQwcU/LHsvLwON1fAQX9Zp0A2Y6Y/cJBl9r/JbW90Zyg4/zbG6zzKa2OTALELarYHmKu0GhpM5EO+7T0g==} - engines: {node: '>=0.10'} - - esrecurse@4.3.0: - resolution: {integrity: sha512-KmfKL3b6G+RXvP8N1vr3Tq1kL/oCFgn2NYXEtqP8/L3pKapUA4G8cFVaoF3SU323CD4XypR/ffioHmkti6/Tag==} - engines: {node: '>=4.0'} - - estraverse@5.3.0: - resolution: {integrity: sha512-MMdARuVEQziNTeJD8DgMqmhwR11BRQ/cBP+pLtYdSTnf3MIO8fFeiINEbX36ZdNlfU/7A9f3gUw49B3oQsvwBA==} - engines: {node: '>=4.0'} - - estree-walker@3.0.3: - resolution: {integrity: sha512-7RUKfXgSMMkzt6ZuXmqapOurLGPPfgj6l9uRZ7lRGolvk0y2yocc35LdcxKC5PQZdn2DMqioAQ2NoWcrTKmm6g==} - - esutils@2.0.3: - resolution: {integrity: sha512-kVscqXk4OCp68SZ0dkgEKVi6/8ij300KBWTJq32P/dYeWTSwK41WyTxalN1eRmA5Z9UU/LX9D7FWSmV9SAYx6g==} - engines: {node: '>=0.10.0'} - - expect-type@1.3.0: - resolution: {integrity: sha512-knvyeauYhqjOYvQ66MznSMs83wmHrCycNEN6Ao+2AeYEfxUIkuiVxdEa1qlGEPK+We3n0THiDciYSsCcgW/DoA==} - engines: {node: '>=12.0.0'} - - fast-deep-equal@3.1.3: - resolution: {integrity: sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==} - - fast-json-stable-stringify@2.1.0: - resolution: {integrity: sha512-lhd/wF+Lk98HZoTCtlVraHtfh5XYijIjalXck7saUtuanSDyLMxnHhSXEDJqHxD7msR8D0uCmqlkwjCV8xvwHw==} - - fast-levenshtein@2.0.6: - resolution: {integrity: sha512-DCXu6Ifhqcks7TZKY3Hxp3y6qphY5SJZmrWMDrKcERSOXWQdMhU9Ig/PYrzyw/ul9jOIyh0N4M0tbC5hodg8dw==} - - fdir@6.5.0: - resolution: {integrity: sha512-tIbYtZbucOs0BRGqPJkshJUYdL+SDH7dVM8gjy+ERp3WAUjLEFJE+02kanyHtwjWOnwrKYBiwAmM0p4kLJAnXg==} - engines: {node: '>=12.0.0'} - peerDependencies: - picomatch: ^3 || ^4 - peerDependenciesMeta: - picomatch: - optional: true - - file-entry-cache@8.0.0: - resolution: {integrity: sha512-XXTUwCvisa5oacNGRP9SfNtYBNAMi+RPwBFmblZEF7N7swHYQS6/Zfk7SRwx4D5j3CH211YNRco1DEMNVfZCnQ==} - engines: {node: '>=16.0.0'} - - find-up@5.0.0: - resolution: {integrity: sha512-78/PXT1wlLLDgTzDs7sjq9hzz0vXD+zn+7wypEe4fXQxCmdmqfGsEPQxmiCSQI3ajFV91bVSsvNtrJRiW6nGng==} - engines: {node: '>=10'} - - fix-dts-default-cjs-exports@1.0.1: - resolution: {integrity: sha512-pVIECanWFC61Hzl2+oOCtoJ3F17kglZC/6N94eRWycFgBH35hHx0Li604ZIzhseh97mf2p0cv7vVrOZGoqhlEg==} - - flat-cache@4.0.1: - resolution: {integrity: sha512-f7ccFPK3SXFHpx15UIGyRJ/FJQctuKZ0zVuN3frBo4HnK3cay9VEW0R6yPYFHC0AgqhukPzKjq22t5DmAyqGyw==} - engines: {node: '>=16'} - - flatted@3.4.1: - resolution: {integrity: sha512-IxfVbRFVlV8V/yRaGzk0UVIcsKKHMSfYw66T/u4nTwlWteQePsxe//LjudR1AMX4tZW3WFCh3Zqa/sjlqpbURQ==} - - follow-redirects@1.15.11: - resolution: {integrity: sha512-deG2P0JfjrTxl50XGCDyfI97ZGVCxIpfKYmfyrQ54n5FO/0gfIES8C/Psl6kWVDolizcaaxZJnTS0QSMxvnsBQ==} - engines: {node: '>=4.0'} - peerDependencies: - debug: '*' - peerDependenciesMeta: - debug: - optional: true - - form-data@4.0.5: - resolution: {integrity: sha512-8RipRLol37bNs2bhoV67fiTEvdTrbMUYcFTiy3+wuuOnUog2QBHCZWXDRijWQfAkhBj2Uf5UnVaiWwA5vdd82w==} - engines: {node: '>= 6'} - - fsevents@2.3.3: - resolution: {integrity: sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==} - engines: {node: ^8.16.0 || ^10.6.0 || >=11.0.0} - os: [darwin] - - function-bind@1.1.2: - resolution: {integrity: sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==} - - get-intrinsic@1.3.0: - resolution: {integrity: sha512-9fSjSaos/fRIVIp+xSJlE6lfwhES7LNtKaCBIamHsjr2na1BiABJPo0mOjjz8GJDURarmCPGqaiVg5mfjb98CQ==} - engines: {node: '>= 0.4'} - - get-proto@1.0.1: - resolution: {integrity: sha512-sTSfBjoXBp89JvIKIefqw7U2CCebsc74kiY6awiGogKtoSGbgjYE/G/+l9sF3MWFPNc9IcoOC4ODfKHfxFmp0g==} - engines: {node: '>= 0.4'} - - glob-parent@6.0.2: - resolution: {integrity: sha512-XxwI8EOhVQgWp6iDL+3b0r86f4d6AX6zSU55HfB4ydCEuXLXc5FcYeOu+nnGftS4TEju/11rt4KJPTMgbfmv4A==} - engines: {node: '>=10.13.0'} - - gopd@1.2.0: - resolution: {integrity: sha512-ZUKRh6/kUFoAiTAtTYPZJ3hw9wNxx+BIBOijnlG9PnrJsCcSjs1wyyD6vJpaYtgnzDrKYRSqf3OO6Rfa93xsRg==} - engines: {node: '>= 0.4'} - - has-flag@4.0.0: - resolution: {integrity: sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==} - engines: {node: '>=8'} - - has-symbols@1.1.0: - resolution: {integrity: sha512-1cDNdwJ2Jaohmb3sg4OmKaMBwuC48sYni5HUw2DvsC8LjGTLK9h+eb1X6RyuOHe4hT0ULCW68iomhjUoKUqlPQ==} - engines: {node: '>= 0.4'} - - has-tostringtag@1.0.2: - resolution: {integrity: sha512-NqADB8VjPFLM2V0VvHUewwwsw0ZWBaIdgo+ieHtK3hasLz4qeCRjYcqfB6AQrBggRKppKF8L52/VqdVsO47Dlw==} - engines: {node: '>= 0.4'} - - hasown@2.0.2: - resolution: {integrity: sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==} - engines: {node: '>= 0.4'} - - html-escaper@2.0.2: - resolution: {integrity: sha512-H2iMtd0I4Mt5eYiapRdIDjp+XzelXQ0tFE4JS7YFwFevXXMmOp9myNrUvCg0D6ws8iqkRPBfKHgbwig1SmlLfg==} - - ignore@5.3.2: - resolution: {integrity: sha512-hsBTNUqQTDwkWtcdYI2i06Y/nUBEsNEDJKjWdigLvegy8kDuJAS8uRlpkkcQpyEXL0Z/pjDy5HBmMjRCJ2gq+g==} - engines: {node: '>= 4'} - - ignore@7.0.5: - resolution: {integrity: sha512-Hs59xBNfUIunMFgWAbGX5cq6893IbWg4KnrjbYwX3tx0ztorVgTDA6B2sxf8ejHJ4wz8BqGUMYlnzNBer5NvGg==} - engines: {node: '>= 4'} - - imurmurhash@0.1.4: - resolution: {integrity: sha512-JmXMZ6wuvDmLiHEml9ykzqO6lwFbof0GG4IkcGaENdCRDDmMVnny7s5HsIgHCbaq0w2MyPhDqkhTUgS2LU2PHA==} - engines: {node: '>=0.8.19'} - - is-extglob@2.1.1: - resolution: {integrity: sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ==} - engines: {node: '>=0.10.0'} - - is-glob@4.0.3: - resolution: {integrity: sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg==} - engines: {node: '>=0.10.0'} - - isexe@2.0.0: - resolution: {integrity: sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==} - - istanbul-lib-coverage@3.2.2: - resolution: {integrity: sha512-O8dpsF+r0WV/8MNRKfnmrtCWhuKjxrq2w+jpzBL5UZKTi2LeVWnWOmWRxFlesJONmc+wLAGvKQZEOanko0LFTg==} - engines: {node: '>=8'} - - istanbul-lib-report@3.0.1: - resolution: {integrity: sha512-GCfE1mtsHGOELCU8e/Z7YWzpmybrx/+dSTfLrvY8qRmaY6zXTKWn6WQIjaAFw069icm6GVMNkgu0NzI4iPZUNw==} - engines: {node: '>=10'} - - istanbul-reports@3.2.0: - resolution: {integrity: sha512-HGYWWS/ehqTV3xN10i23tkPkpH46MLCIMFNCaaKNavAXTF1RkqxawEPtnjnGZ6XKSInBKkiOA5BKS+aZiY3AvA==} - engines: {node: '>=8'} - - joycon@3.1.1: - resolution: {integrity: sha512-34wB/Y7MW7bzjKRjUKTa46I2Z7eV62Rkhva+KkopW7Qvv/OSWBqvkSY7vusOPrNuZcUG3tApvdVgNB8POj3SPw==} - engines: {node: '>=10'} - - js-tokens@10.0.0: - resolution: {integrity: sha512-lM/UBzQmfJRo9ABXbPWemivdCW8V2G8FHaHdypQaIy523snUjog0W71ayWXTjiR+ixeMyVHN2XcpnTd/liPg/Q==} - - json-buffer@3.0.1: - resolution: {integrity: sha512-4bV5BfR2mqfQTJm+V5tPPdf+ZpuhiIvTuAB5g8kcrXOZpTT/QwwVRWBywX1ozr6lEuPdbHxwaJlm9G6mI2sfSQ==} - - json-schema-traverse@0.4.1: - resolution: {integrity: sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg==} - - json-stable-stringify-without-jsonify@1.0.1: - resolution: {integrity: sha512-Bdboy+l7tA3OGW6FjyFHWkP5LuByj1Tk33Ljyq0axyzdk9//JSi2u3fP1QSmd1KNwq6VOKYGlAu87CisVir6Pw==} - - keyv@4.5.4: - resolution: {integrity: sha512-oxVHkHR/EJf2CNXnWxRLW6mg7JyCCUcG0DtEGmL2ctUo1PNTin1PUil+r/+4r5MpVgC/fn1kjsx7mjSujKqIpw==} - - levn@0.4.1: - resolution: {integrity: sha512-+bT2uH4E5LGE7h/n3evcS/sQlJXCpIp6ym8OWJ5eV6+67Dsql/LaaT7qJBAt2rzfoa/5QBGBhxDix1dMt2kQKQ==} - engines: {node: '>= 0.8.0'} - - lilconfig@3.1.3: - resolution: {integrity: sha512-/vlFKAoH5Cgt3Ie+JLhRbwOsCQePABiU3tJ1egGvyQ+33R/vcwM2Zl2QR/LzjsBeItPt3oSVXapn+m4nQDvpzw==} - engines: {node: '>=14'} - - lines-and-columns@1.2.4: - resolution: {integrity: sha512-7ylylesZQ/PV29jhEDl3Ufjo6ZX7gCqJr5F7PKrqc93v7fzSymt1BpwEU8nAUXs8qzzvqhbjhK5QZg6Mt/HkBg==} - - load-tsconfig@0.2.5: - resolution: {integrity: sha512-IXO6OCs9yg8tMKzfPZ1YmheJbZCiEsnBdcB03l0OcfK9prKnJb96siuHCr5Fl37/yo9DnKU+TLpxzTUspw9shg==} - engines: {node: ^12.20.0 || ^14.13.1 || >=16.0.0} - - locate-path@6.0.0: - resolution: {integrity: sha512-iPZK6eYjbxRu3uB4/WZ3EsEIMJFMqAoopl3R+zuq0UjcAm/MO6KCweDgPfP3elTztoKP3KtnVHxTn2NHBSDVUw==} - engines: {node: '>=10'} - - magic-string@0.30.21: - resolution: {integrity: sha512-vd2F4YUyEXKGcLHoq+TEyCjxueSeHnFxyyjNp80yg0XV4vUhnDer/lvvlqM/arB5bXQN5K2/3oinyCRyx8T2CQ==} - - magicast@0.5.2: - resolution: {integrity: sha512-E3ZJh4J3S9KfwdjZhe2afj6R9lGIN5Pher1pF39UGrXRqq/VDaGVIGN13BjHd2u8B61hArAGOnso7nBOouW3TQ==} - - make-dir@4.0.0: - resolution: {integrity: sha512-hXdUTZYIVOt1Ex//jAQi+wTZZpUpwBj/0QsOzqegb3rGMMeJiSEu5xLHnYfBrRV4RH2+OCSOO95Is/7x1WJ4bw==} - engines: {node: '>=10'} - - math-intrinsics@1.1.0: - resolution: {integrity: sha512-/IXtbwEk5HTPyEwyKX6hGkYXxM9nbj64B+ilVJnC/R6B0pH5G4V3b0pVbL7DBj4tkhBAppbQUlf6F6Xl9LHu1g==} - engines: {node: '>= 0.4'} - - mime-db@1.52.0: - resolution: {integrity: sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==} - engines: {node: '>= 0.6'} - - mime-types@2.1.35: - resolution: {integrity: sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==} - engines: {node: '>= 0.6'} - - minimatch@10.2.4: - resolution: {integrity: sha512-oRjTw/97aTBN0RHbYCdtF1MQfvusSIBQM0IZEgzl6426+8jSC0nF1a/GmnVLpfB9yyr6g6FTqWqiZVbxrtaCIg==} - engines: {node: 18 || 20 || >=22} - - mlly@1.8.1: - resolution: {integrity: sha512-SnL6sNutTwRWWR/vcmCYHSADjiEesp5TGQQ0pXyLhW5IoeibRlF/CbSLailbB3CNqJUk9cVJ9dUDnbD7GrcHBQ==} - - ms@2.1.3: - resolution: {integrity: sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==} - - mz@2.7.0: - resolution: {integrity: sha512-z81GNO7nnYMEhrGh9LeymoE4+Yr0Wn5McHIZMK5cfQCl+NDX08sCZgUc9/6MHni9IWuFLm1Z3HTCXu2z9fN62Q==} - - nanoid@3.3.11: - resolution: {integrity: sha512-N8SpfPUnUp1bK+PMYW8qSWdl9U+wwNWI4QKxOYDy9JAro3WMX7p2OeVRF9v+347pnakNevPmiHhNmZ2HbFA76w==} - engines: {node: ^10 || ^12 || ^13.7 || ^14 || >=15.0.1} - hasBin: true - - natural-compare@1.4.0: - resolution: {integrity: sha512-OWND8ei3VtNC9h7V60qff3SVobHr996CTwgxubgyQYEpg290h9J0buyECNNJexkFm5sOajh5G116RYA1c8ZMSw==} - - object-assign@4.1.1: - resolution: {integrity: sha512-rJgTQnkUnH1sFw8yT6VSU3zD3sWmu6sZhIseY8VX+GRu3P6F7Fu+JNDoXfklElbLJSnc3FUQHVe4cU5hj+BcUg==} - engines: {node: '>=0.10.0'} - - obug@2.1.1: - resolution: {integrity: sha512-uTqF9MuPraAQ+IsnPf366RG4cP9RtUi7MLO1N3KEc+wb0a6yKpeL0lmk2IB1jY5KHPAlTc6T/JRdC/YqxHNwkQ==} - - optionator@0.9.4: - resolution: {integrity: sha512-6IpQ7mKUxRcZNLIObR0hz7lxsapSSIYNZJwXPGeF0mTVqGKFIXj1DQcMoT22S3ROcLyY/rz0PWaWZ9ayWmad9g==} - engines: {node: '>= 0.8.0'} - - p-limit@3.1.0: - resolution: {integrity: sha512-TYOanM3wGwNGsZN2cVTYPArw454xnXj5qmWF1bEoAc4+cU/ol7GVh7odevjp1FNHduHc3KZMcFduxU5Xc6uJRQ==} - engines: {node: '>=10'} - - p-locate@5.0.0: - resolution: {integrity: sha512-LaNjtRWUBY++zB5nE/NwcaoMylSPk+S+ZHNB1TzdbMJMny6dynpAGt7X/tl/QYq3TIeE6nxHppbo2LGymrG5Pw==} - engines: {node: '>=10'} - - path-exists@4.0.0: - resolution: {integrity: sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w==} - engines: {node: '>=8'} - - path-key@3.1.1: - resolution: {integrity: sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==} - engines: {node: '>=8'} - - pathe@2.0.3: - resolution: {integrity: sha512-WUjGcAqP1gQacoQe+OBJsFA7Ld4DyXuUIjZ5cc75cLHvJ7dtNsTugphxIADwspS+AraAUePCKrSVtPLFj/F88w==} - - picocolors@1.1.1: - resolution: {integrity: sha512-xceH2snhtb5M9liqDsmEw56le376mTZkEX/jEb/RxNFyegNul7eNslCXP9FDj/Lcu0X8KEyMceP2ntpaHrDEVA==} - - picomatch@4.0.3: - resolution: {integrity: sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==} - engines: {node: '>=12'} - - pirates@4.0.7: - resolution: {integrity: sha512-TfySrs/5nm8fQJDcBDuUng3VOUKsd7S+zqvbOTiGXHfxX4wK31ard+hoNuvkicM/2YFzlpDgABOevKSsB4G/FA==} - engines: {node: '>= 6'} - - pkg-types@1.3.1: - resolution: {integrity: sha512-/Jm5M4RvtBFVkKWRu2BLUTNP8/M2a+UwuAX+ae4770q1qVGtfjG+WTCupoZixokjmHiry8uI+dlY8KXYV5HVVQ==} - - postcss-load-config@6.0.1: - resolution: {integrity: sha512-oPtTM4oerL+UXmx+93ytZVN82RrlY/wPUV8IeDxFrzIjXOLF1pN+EmKPLbubvKHT2HC20xXsCAH2Z+CKV6Oz/g==} - engines: {node: '>= 18'} - peerDependencies: - jiti: '>=1.21.0' - postcss: '>=8.0.9' - tsx: ^4.8.1 - yaml: ^2.4.2 - peerDependenciesMeta: - jiti: - optional: true - postcss: - optional: true - tsx: - optional: true - yaml: - optional: true - - postcss@8.5.8: - resolution: {integrity: sha512-OW/rX8O/jXnm82Ey1k44pObPtdblfiuWnrd8X7GJ7emImCOstunGbXUpp7HdBrFQX6rJzn3sPT397Wp5aCwCHg==} - engines: {node: ^10 || ^12 || >=14} - - prelude-ls@1.2.1: - resolution: {integrity: sha512-vkcDPrRZo1QZLbn5RLGPpg/WmIQ65qoWWhcGKf/b5eplkkarX0m9z8ppCat4mlOqUsWpyNuYgO3VRyrYHSzX5g==} - engines: {node: '>= 0.8.0'} - - proxy-from-env@1.1.0: - resolution: {integrity: sha512-D+zkORCbA9f1tdWRK0RaCR3GPv50cMxcrz4X8k5LTSUD1Dkw47mKJEZQNunItRTkWwgtaUSo1RVFRIG9ZXiFYg==} - - punycode@2.3.1: - resolution: {integrity: sha512-vYt7UD1U9Wg6138shLtLOvdAu+8DsC/ilFtEVHcH+wydcSpNE20AfSOduf6MkRFahL5FY7X1oU7nKVZFtfq8Fg==} - engines: {node: '>=6'} - - readdirp@4.1.2: - resolution: {integrity: sha512-GDhwkLfywWL2s6vEjyhri+eXmfH6j1L7JE27WhqLeYzoh/A3DBaYGEj2H/HFZCn/kMfim73FXxEJTw06WtxQwg==} - engines: {node: '>= 14.18.0'} - - resolve-from@5.0.0: - resolution: {integrity: sha512-qYg9KP24dD5qka9J47d0aVky0N+b4fTU89LN9iDnjB5waksiC49rvMB0PrUJQGoTmH50XPiqOvAjDfaijGxYZw==} - engines: {node: '>=8'} - - rollup@4.59.0: - resolution: {integrity: sha512-2oMpl67a3zCH9H79LeMcbDhXW/UmWG/y2zuqnF2jQq5uq9TbM9TVyXvA4+t+ne2IIkBdrLpAaRQAvo7YI/Yyeg==} - engines: {node: '>=18.0.0', npm: '>=8.0.0'} - hasBin: true - - semver@7.7.4: - resolution: {integrity: sha512-vFKC2IEtQnVhpT78h1Yp8wzwrf8CM+MzKMHGJZfBtzhZNycRFnXsHk6E5TxIkkMsgNS7mdX3AGB7x2QM2di4lA==} - engines: {node: '>=10'} - hasBin: true - - shebang-command@2.0.0: - resolution: {integrity: sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==} - engines: {node: '>=8'} - - shebang-regex@3.0.0: - resolution: {integrity: sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==} - engines: {node: '>=8'} - - siginfo@2.0.0: - resolution: {integrity: sha512-ybx0WO1/8bSBLEWXZvEd7gMW3Sn3JFlW3TvX1nREbDLRNQNaeNN8WK0meBwPdAaOI7TtRRRJn/Es1zhrrCHu7g==} - - source-map-js@1.2.1: - resolution: {integrity: sha512-UXWMKhLOwVKb728IUtQPXxfYU+usdybtUrK/8uGE8CQMvrhOpwvzDBwj0QhSL7MQc7vIsISBG8VQ8+IDQxpfQA==} - engines: {node: '>=0.10.0'} - - source-map@0.7.6: - resolution: {integrity: sha512-i5uvt8C3ikiWeNZSVZNWcfZPItFQOsYTUAOkcUPGd8DqDy1uOUikjt5dG+uRlwyvR108Fb9DOd4GvXfT0N2/uQ==} - engines: {node: '>= 12'} - - stackback@0.0.2: - resolution: {integrity: sha512-1XMJE5fQo1jGH6Y/7ebnwPOBEkIEnT4QF32d5R1+VXdXveM0IBMJt8zfaxX1P3QhVwrYe+576+jkANtSS2mBbw==} - - std-env@3.10.0: - resolution: {integrity: sha512-5GS12FdOZNliM5mAOxFRg7Ir0pWz8MdpYm6AY6VPkGpbA7ZzmbzNcBJQ0GPvvyWgcY7QAhCgf9Uy89I03faLkg==} - - sucrase@3.35.1: - resolution: {integrity: sha512-DhuTmvZWux4H1UOnWMB3sk0sbaCVOoQZjv8u1rDoTV0HTdGem9hkAZtl4JZy8P2z4Bg0nT+YMeOFyVr4zcG5Tw==} - engines: {node: '>=16 || 14 >=14.17'} - hasBin: true - - supports-color@7.2.0: - resolution: {integrity: sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==} - engines: {node: '>=8'} - - thenify-all@1.6.0: - resolution: {integrity: sha512-RNxQH/qI8/t3thXJDwcstUO4zeqo64+Uy/+sNVRBx4Xn2OX+OZ9oP+iJnNFqplFra2ZUVeKCSa2oVWi3T4uVmA==} - engines: {node: '>=0.8'} - - thenify@3.3.1: - resolution: {integrity: sha512-RVZSIV5IG10Hk3enotrhvz0T9em6cyHBLkH/YAZuKqd8hRkKhSfCGIcP2KUY0EPxndzANBmNllzWPwak+bheSw==} - - tinybench@2.9.0: - resolution: {integrity: sha512-0+DUvqWMValLmha6lr4kD8iAMK1HzV0/aKnCtWb9v9641TnP/MFb7Pc2bxoxQjTXAErryXVgUOfv2YqNllqGeg==} - - tinyexec@0.3.2: - resolution: {integrity: sha512-KQQR9yN7R5+OSwaK0XQoj22pwHoTlgYqmUscPYoknOoWCWfj/5/ABTMRi69FrKU5ffPVh5QcFikpWJI/P1ocHA==} - - tinyexec@1.0.2: - resolution: {integrity: sha512-W/KYk+NFhkmsYpuHq5JykngiOCnxeVL8v8dFnqxSD8qEEdRfXk1SDM6JzNqcERbcGYj9tMrDQBYV9cjgnunFIg==} - engines: {node: '>=18'} - - tinyglobby@0.2.15: - resolution: {integrity: sha512-j2Zq4NyQYG5XMST4cbs02Ak8iJUdxRM0XI5QyxXuZOzKOINmWurp3smXu3y5wDcJrptwpSjgXHzIQxR0omXljQ==} - engines: {node: '>=12.0.0'} - - tinyrainbow@3.0.3: - resolution: {integrity: sha512-PSkbLUoxOFRzJYjjxHJt9xro7D+iilgMX/C9lawzVuYiIdcihh9DXmVibBe8lmcFrRi/VzlPjBxbN7rH24q8/Q==} - engines: {node: '>=14.0.0'} - - tree-kill@1.2.2: - resolution: {integrity: sha512-L0Orpi8qGpRG//Nd+H90vFB+3iHnue1zSSGmNOOCh1GLJ7rUKVwV2HvijphGQS2UmhUZewS9VgvxYIdgr+fG1A==} - hasBin: true - - ts-api-utils@2.4.0: - resolution: {integrity: sha512-3TaVTaAv2gTiMB35i3FiGJaRfwb3Pyn/j3m/bfAvGe8FB7CF6u+LMYqYlDh7reQf7UNvoTvdfAqHGmPGOSsPmA==} - engines: {node: '>=18.12'} - peerDependencies: - typescript: '>=4.8.4' - - ts-interface-checker@0.1.13: - resolution: {integrity: sha512-Y/arvbn+rrz3JCKl9C4kVNfTfSm2/mEp5FSz5EsZSANGPSlQrpRI5M4PKF+mJnE52jOO90PnPSc3Ur3bTQw0gA==} - - tsup@8.5.1: - resolution: {integrity: sha512-xtgkqwdhpKWr3tKPmCkvYmS9xnQK3m3XgxZHwSUjvfTjp7YfXe5tT3GgWi0F2N+ZSMsOeWeZFh7ZZFg5iPhing==} - engines: {node: '>=18'} - hasBin: true - peerDependencies: - '@microsoft/api-extractor': ^7.36.0 - '@swc/core': ^1 - postcss: ^8.4.12 - typescript: '>=4.5.0' - peerDependenciesMeta: - '@microsoft/api-extractor': - optional: true - '@swc/core': - optional: true - postcss: - optional: true - typescript: - optional: true - - type-check@0.4.0: - resolution: {integrity: sha512-XleUoc9uwGXqjWwXaUTZAmzMcFZ5858QA2vvx1Ur5xIcixXIP+8LnFDgRplU30us6teqdlskFfu+ae4K79Ooew==} - engines: {node: '>= 0.8.0'} - - typescript@5.9.3: - resolution: {integrity: sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==} - engines: {node: '>=14.17'} - hasBin: true - - ufo@1.6.3: - resolution: {integrity: sha512-yDJTmhydvl5lJzBmy/hyOAA0d+aqCBuwl818haVdYCRrWV84o7YyeVm4QlVHStqNrrJSTb6jKuFAVqAFsr+K3Q==} - - undici-types@7.18.2: - resolution: {integrity: sha512-AsuCzffGHJybSaRrmr5eHr81mwJU3kjw6M+uprWvCXiNeN9SOGwQ3Jn8jb8m3Z6izVgknn1R0FTCEAP2QrLY/w==} - - uri-js@4.4.1: - resolution: {integrity: sha512-7rKUyy33Q1yc98pQ1DAmLtwX109F7TIfWlW1Ydo8Wl1ii1SeHieeh0HHfPeL2fMXK6z0s8ecKs9frCuLJvndBg==} - - vite@7.3.1: - resolution: {integrity: sha512-w+N7Hifpc3gRjZ63vYBXA56dvvRlNWRczTdmCBBa+CotUzAPf5b7YMdMR/8CQoeYE5LX3W4wj6RYTgonm1b9DA==} - engines: {node: ^20.19.0 || >=22.12.0} - hasBin: true - peerDependencies: - '@types/node': ^20.19.0 || >=22.12.0 - jiti: '>=1.21.0' - less: ^4.0.0 - lightningcss: ^1.21.0 - sass: ^1.70.0 - sass-embedded: ^1.70.0 - stylus: '>=0.54.8' - sugarss: ^5.0.0 - terser: ^5.16.0 - tsx: ^4.8.1 - yaml: ^2.4.2 - peerDependenciesMeta: - '@types/node': - optional: true - jiti: - optional: true - less: - optional: true - lightningcss: - optional: true - sass: - optional: true - sass-embedded: - optional: true - stylus: - optional: true - sugarss: - optional: true - terser: - optional: true - tsx: - optional: true - yaml: - optional: true - - vitest@4.0.18: - resolution: {integrity: sha512-hOQuK7h0FGKgBAas7v0mSAsnvrIgAvWmRFjmzpJ7SwFHH3g1k2u37JtYwOwmEKhK6ZO3v9ggDBBm0La1LCK4uQ==} - engines: {node: ^20.0.0 || ^22.0.0 || >=24.0.0} - hasBin: true - peerDependencies: - '@edge-runtime/vm': '*' - '@opentelemetry/api': ^1.9.0 - '@types/node': ^20.0.0 || ^22.0.0 || >=24.0.0 - '@vitest/browser-playwright': 4.0.18 - '@vitest/browser-preview': 4.0.18 - '@vitest/browser-webdriverio': 4.0.18 - '@vitest/ui': 4.0.18 - happy-dom: '*' - jsdom: '*' - peerDependenciesMeta: - '@edge-runtime/vm': - optional: true - '@opentelemetry/api': - optional: true - '@types/node': - optional: true - '@vitest/browser-playwright': - optional: true - '@vitest/browser-preview': - optional: true - '@vitest/browser-webdriverio': - optional: true - '@vitest/ui': - optional: true - happy-dom: - optional: true - jsdom: - optional: true - - which@2.0.2: - resolution: {integrity: sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==} - engines: {node: '>= 8'} - hasBin: true - - why-is-node-running@2.3.0: - resolution: {integrity: sha512-hUrmaWBdVDcxvYqnyh09zunKzROWjbZTiNy8dBEjkS7ehEDQibXJ7XvlmtbwuTclUiIyN+CyXQD4Vmko8fNm8w==} - engines: {node: '>=8'} - hasBin: true - - word-wrap@1.2.5: - resolution: {integrity: sha512-BN22B5eaMMI9UMtjrGd5g5eCYPpCPDUy0FJXbYsaT5zYxjFOckS53SQDE3pWkVoWpHXVb3BrYcEN4Twa55B5cA==} - engines: {node: '>=0.10.0'} - - yocto-queue@0.1.0: - resolution: {integrity: sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q==} - engines: {node: '>=10'} - -snapshots: - - '@babel/helper-string-parser@7.27.1': {} - - '@babel/helper-validator-identifier@7.28.5': {} - - '@babel/parser@7.29.0': - dependencies: - '@babel/types': 7.29.0 - - '@babel/types@7.29.0': - dependencies: - '@babel/helper-string-parser': 7.27.1 - '@babel/helper-validator-identifier': 7.28.5 - - '@bcoe/v8-coverage@1.0.2': {} - - '@esbuild/aix-ppc64@0.27.3': - optional: true - - '@esbuild/android-arm64@0.27.3': - optional: true - - '@esbuild/android-arm@0.27.3': - optional: true - - '@esbuild/android-x64@0.27.3': - optional: true - - '@esbuild/darwin-arm64@0.27.3': - optional: true - - '@esbuild/darwin-x64@0.27.3': - optional: true - - '@esbuild/freebsd-arm64@0.27.3': - optional: true - - '@esbuild/freebsd-x64@0.27.3': - optional: true - - '@esbuild/linux-arm64@0.27.3': - optional: true - - '@esbuild/linux-arm@0.27.3': - optional: true - - '@esbuild/linux-ia32@0.27.3': - optional: true - - '@esbuild/linux-loong64@0.27.3': - optional: true - - '@esbuild/linux-mips64el@0.27.3': - optional: true - - '@esbuild/linux-ppc64@0.27.3': - optional: true - - '@esbuild/linux-riscv64@0.27.3': - optional: true - - '@esbuild/linux-s390x@0.27.3': - optional: true - - '@esbuild/linux-x64@0.27.3': - optional: true - - '@esbuild/netbsd-arm64@0.27.3': - optional: true - - '@esbuild/netbsd-x64@0.27.3': - optional: true - - '@esbuild/openbsd-arm64@0.27.3': - optional: true - - '@esbuild/openbsd-x64@0.27.3': - optional: true - - '@esbuild/openharmony-arm64@0.27.3': - optional: true - - '@esbuild/sunos-x64@0.27.3': - optional: true - - '@esbuild/win32-arm64@0.27.3': - optional: true - - '@esbuild/win32-ia32@0.27.3': - optional: true - - '@esbuild/win32-x64@0.27.3': - optional: true - - '@eslint-community/eslint-utils@4.9.1(eslint@10.0.3)': - dependencies: - eslint: 10.0.3 - eslint-visitor-keys: 3.4.3 - - '@eslint-community/regexpp@4.12.2': {} - - '@eslint/config-array@0.23.3': - dependencies: - '@eslint/object-schema': 3.0.3 - debug: 4.4.3 - minimatch: 10.2.4 - transitivePeerDependencies: - - supports-color - - '@eslint/config-helpers@0.5.3': - dependencies: - '@eslint/core': 1.1.1 - - '@eslint/core@1.1.1': - dependencies: - '@types/json-schema': 7.0.15 - - '@eslint/js@10.0.1(eslint@10.0.3)': - optionalDependencies: - eslint: 10.0.3 - - '@eslint/object-schema@3.0.3': {} - - '@eslint/plugin-kit@0.6.1': - dependencies: - '@eslint/core': 1.1.1 - levn: 0.4.1 - - '@humanfs/core@0.19.1': {} - - '@humanfs/node@0.16.7': - dependencies: - '@humanfs/core': 0.19.1 - '@humanwhocodes/retry': 0.4.3 - - '@humanwhocodes/module-importer@1.0.1': {} - - '@humanwhocodes/retry@0.4.3': {} - - '@jridgewell/gen-mapping@0.3.13': - dependencies: - '@jridgewell/sourcemap-codec': 1.5.5 - '@jridgewell/trace-mapping': 0.3.31 - - '@jridgewell/resolve-uri@3.1.2': {} - - '@jridgewell/sourcemap-codec@1.5.5': {} - - '@jridgewell/trace-mapping@0.3.31': - dependencies: - '@jridgewell/resolve-uri': 3.1.2 - '@jridgewell/sourcemap-codec': 1.5.5 - - '@rollup/rollup-android-arm-eabi@4.59.0': - optional: true - - '@rollup/rollup-android-arm64@4.59.0': - optional: true - - '@rollup/rollup-darwin-arm64@4.59.0': - optional: true - - '@rollup/rollup-darwin-x64@4.59.0': - optional: true - - '@rollup/rollup-freebsd-arm64@4.59.0': - optional: true - - '@rollup/rollup-freebsd-x64@4.59.0': - optional: true - - '@rollup/rollup-linux-arm-gnueabihf@4.59.0': - optional: true - - '@rollup/rollup-linux-arm-musleabihf@4.59.0': - optional: true - - '@rollup/rollup-linux-arm64-gnu@4.59.0': - optional: true - - '@rollup/rollup-linux-arm64-musl@4.59.0': - optional: true - - '@rollup/rollup-linux-loong64-gnu@4.59.0': - optional: true - - '@rollup/rollup-linux-loong64-musl@4.59.0': - optional: true - - '@rollup/rollup-linux-ppc64-gnu@4.59.0': - optional: true - - '@rollup/rollup-linux-ppc64-musl@4.59.0': - optional: true - - '@rollup/rollup-linux-riscv64-gnu@4.59.0': - optional: true - - '@rollup/rollup-linux-riscv64-musl@4.59.0': - optional: true - - '@rollup/rollup-linux-s390x-gnu@4.59.0': - optional: true - - '@rollup/rollup-linux-x64-gnu@4.59.0': - optional: true - - '@rollup/rollup-linux-x64-musl@4.59.0': - optional: true - - '@rollup/rollup-openbsd-x64@4.59.0': - optional: true - - '@rollup/rollup-openharmony-arm64@4.59.0': - optional: true - - '@rollup/rollup-win32-arm64-msvc@4.59.0': - optional: true - - '@rollup/rollup-win32-ia32-msvc@4.59.0': - optional: true - - '@rollup/rollup-win32-x64-gnu@4.59.0': - optional: true - - '@rollup/rollup-win32-x64-msvc@4.59.0': - optional: true - - '@standard-schema/spec@1.1.0': {} - - '@types/chai@5.2.3': - dependencies: - '@types/deep-eql': 4.0.2 - assertion-error: 2.0.1 - - '@types/deep-eql@4.0.2': {} - - '@types/esrecurse@4.3.1': {} - - '@types/estree@1.0.8': {} - - '@types/json-schema@7.0.15': {} - - '@types/node@25.4.0': - dependencies: - undici-types: 7.18.2 - - '@typescript-eslint/eslint-plugin@8.57.0(@typescript-eslint/parser@8.57.0(eslint@10.0.3)(typescript@5.9.3))(eslint@10.0.3)(typescript@5.9.3)': - dependencies: - '@eslint-community/regexpp': 4.12.2 - '@typescript-eslint/parser': 8.57.0(eslint@10.0.3)(typescript@5.9.3) - '@typescript-eslint/scope-manager': 8.57.0 - '@typescript-eslint/type-utils': 8.57.0(eslint@10.0.3)(typescript@5.9.3) - '@typescript-eslint/utils': 8.57.0(eslint@10.0.3)(typescript@5.9.3) - '@typescript-eslint/visitor-keys': 8.57.0 - eslint: 10.0.3 - ignore: 7.0.5 - natural-compare: 1.4.0 - ts-api-utils: 2.4.0(typescript@5.9.3) - typescript: 5.9.3 - transitivePeerDependencies: - - supports-color - - '@typescript-eslint/parser@8.57.0(eslint@10.0.3)(typescript@5.9.3)': - dependencies: - '@typescript-eslint/scope-manager': 8.57.0 - '@typescript-eslint/types': 8.57.0 - '@typescript-eslint/typescript-estree': 8.57.0(typescript@5.9.3) - '@typescript-eslint/visitor-keys': 8.57.0 - debug: 4.4.3 - eslint: 10.0.3 - typescript: 5.9.3 - transitivePeerDependencies: - - supports-color - - '@typescript-eslint/project-service@8.57.0(typescript@5.9.3)': - dependencies: - '@typescript-eslint/tsconfig-utils': 8.57.0(typescript@5.9.3) - '@typescript-eslint/types': 8.57.0 - debug: 4.4.3 - typescript: 5.9.3 - transitivePeerDependencies: - - supports-color - - '@typescript-eslint/scope-manager@8.57.0': - dependencies: - '@typescript-eslint/types': 8.57.0 - '@typescript-eslint/visitor-keys': 8.57.0 - - '@typescript-eslint/tsconfig-utils@8.57.0(typescript@5.9.3)': - dependencies: - typescript: 5.9.3 - - '@typescript-eslint/type-utils@8.57.0(eslint@10.0.3)(typescript@5.9.3)': - dependencies: - '@typescript-eslint/types': 8.57.0 - '@typescript-eslint/typescript-estree': 8.57.0(typescript@5.9.3) - '@typescript-eslint/utils': 8.57.0(eslint@10.0.3)(typescript@5.9.3) - debug: 4.4.3 - eslint: 10.0.3 - ts-api-utils: 2.4.0(typescript@5.9.3) - typescript: 5.9.3 - transitivePeerDependencies: - - supports-color - - '@typescript-eslint/types@8.57.0': {} - - '@typescript-eslint/typescript-estree@8.57.0(typescript@5.9.3)': - dependencies: - '@typescript-eslint/project-service': 8.57.0(typescript@5.9.3) - '@typescript-eslint/tsconfig-utils': 8.57.0(typescript@5.9.3) - '@typescript-eslint/types': 8.57.0 - '@typescript-eslint/visitor-keys': 8.57.0 - debug: 4.4.3 - minimatch: 10.2.4 - semver: 7.7.4 - tinyglobby: 0.2.15 - ts-api-utils: 2.4.0(typescript@5.9.3) - typescript: 5.9.3 - transitivePeerDependencies: - - supports-color - - '@typescript-eslint/utils@8.57.0(eslint@10.0.3)(typescript@5.9.3)': - dependencies: - '@eslint-community/eslint-utils': 4.9.1(eslint@10.0.3) - '@typescript-eslint/scope-manager': 8.57.0 - '@typescript-eslint/types': 8.57.0 - '@typescript-eslint/typescript-estree': 8.57.0(typescript@5.9.3) - eslint: 10.0.3 - typescript: 5.9.3 - transitivePeerDependencies: - - supports-color - - '@typescript-eslint/visitor-keys@8.57.0': - dependencies: - '@typescript-eslint/types': 8.57.0 - eslint-visitor-keys: 5.0.1 - - '@vitest/coverage-v8@4.0.18(vitest@4.0.18(@types/node@25.4.0))': - dependencies: - '@bcoe/v8-coverage': 1.0.2 - '@vitest/utils': 4.0.18 - ast-v8-to-istanbul: 0.3.12 - istanbul-lib-coverage: 3.2.2 - istanbul-lib-report: 3.0.1 - istanbul-reports: 3.2.0 - magicast: 0.5.2 - obug: 2.1.1 - std-env: 3.10.0 - tinyrainbow: 3.0.3 - vitest: 4.0.18(@types/node@25.4.0) - - '@vitest/expect@4.0.18': - dependencies: - '@standard-schema/spec': 1.1.0 - '@types/chai': 5.2.3 - '@vitest/spy': 4.0.18 - '@vitest/utils': 4.0.18 - chai: 6.2.2 - tinyrainbow: 3.0.3 - - '@vitest/mocker@4.0.18(vite@7.3.1(@types/node@25.4.0))': - dependencies: - '@vitest/spy': 4.0.18 - estree-walker: 3.0.3 - magic-string: 0.30.21 - optionalDependencies: - vite: 7.3.1(@types/node@25.4.0) - - '@vitest/pretty-format@4.0.18': - dependencies: - tinyrainbow: 3.0.3 - - '@vitest/runner@4.0.18': - dependencies: - '@vitest/utils': 4.0.18 - pathe: 2.0.3 - - '@vitest/snapshot@4.0.18': - dependencies: - '@vitest/pretty-format': 4.0.18 - magic-string: 0.30.21 - pathe: 2.0.3 - - '@vitest/spy@4.0.18': {} - - '@vitest/utils@4.0.18': - dependencies: - '@vitest/pretty-format': 4.0.18 - tinyrainbow: 3.0.3 - - acorn-jsx@5.3.2(acorn@8.16.0): - dependencies: - acorn: 8.16.0 - - acorn@8.16.0: {} - - ajv@6.14.0: - dependencies: - fast-deep-equal: 3.1.3 - fast-json-stable-stringify: 2.1.0 - json-schema-traverse: 0.4.1 - uri-js: 4.4.1 - - any-promise@1.3.0: {} - - assertion-error@2.0.1: {} - - ast-v8-to-istanbul@0.3.12: - dependencies: - '@jridgewell/trace-mapping': 0.3.31 - estree-walker: 3.0.3 - js-tokens: 10.0.0 - - asynckit@0.4.0: {} - - axios@1.13.6: - dependencies: - follow-redirects: 1.15.11 - form-data: 4.0.5 - proxy-from-env: 1.1.0 - transitivePeerDependencies: - - debug - - balanced-match@4.0.4: {} - - brace-expansion@5.0.4: - dependencies: - balanced-match: 4.0.4 - - bundle-require@5.1.0(esbuild@0.27.3): - dependencies: - esbuild: 0.27.3 - load-tsconfig: 0.2.5 - - cac@6.7.14: {} - - call-bind-apply-helpers@1.0.2: - dependencies: - es-errors: 1.3.0 - function-bind: 1.1.2 - - chai@6.2.2: {} - - chokidar@4.0.3: - dependencies: - readdirp: 4.1.2 - - combined-stream@1.0.8: - dependencies: - delayed-stream: 1.0.0 - - commander@4.1.1: {} - - confbox@0.1.8: {} - - consola@3.4.2: {} - - cross-spawn@7.0.6: - dependencies: - path-key: 3.1.1 - shebang-command: 2.0.0 - which: 2.0.2 - - debug@4.4.3: - dependencies: - ms: 2.1.3 - - deep-is@0.1.4: {} - - delayed-stream@1.0.0: {} - - dunder-proto@1.0.1: - dependencies: - call-bind-apply-helpers: 1.0.2 - es-errors: 1.3.0 - gopd: 1.2.0 - - es-define-property@1.0.1: {} - - es-errors@1.3.0: {} - - es-module-lexer@1.7.0: {} - - es-object-atoms@1.1.1: - dependencies: - es-errors: 1.3.0 - - es-set-tostringtag@2.1.0: - dependencies: - es-errors: 1.3.0 - get-intrinsic: 1.3.0 - has-tostringtag: 1.0.2 - hasown: 2.0.2 - - esbuild@0.27.3: - optionalDependencies: - '@esbuild/aix-ppc64': 0.27.3 - '@esbuild/android-arm': 0.27.3 - '@esbuild/android-arm64': 0.27.3 - '@esbuild/android-x64': 0.27.3 - '@esbuild/darwin-arm64': 0.27.3 - '@esbuild/darwin-x64': 0.27.3 - '@esbuild/freebsd-arm64': 0.27.3 - '@esbuild/freebsd-x64': 0.27.3 - '@esbuild/linux-arm': 0.27.3 - '@esbuild/linux-arm64': 0.27.3 - '@esbuild/linux-ia32': 0.27.3 - '@esbuild/linux-loong64': 0.27.3 - '@esbuild/linux-mips64el': 0.27.3 - '@esbuild/linux-ppc64': 0.27.3 - '@esbuild/linux-riscv64': 0.27.3 - '@esbuild/linux-s390x': 0.27.3 - '@esbuild/linux-x64': 0.27.3 - '@esbuild/netbsd-arm64': 0.27.3 - '@esbuild/netbsd-x64': 0.27.3 - '@esbuild/openbsd-arm64': 0.27.3 - '@esbuild/openbsd-x64': 0.27.3 - '@esbuild/openharmony-arm64': 0.27.3 - '@esbuild/sunos-x64': 0.27.3 - '@esbuild/win32-arm64': 0.27.3 - '@esbuild/win32-ia32': 0.27.3 - '@esbuild/win32-x64': 0.27.3 - - escape-string-regexp@4.0.0: {} - - eslint-scope@9.1.2: - dependencies: - '@types/esrecurse': 4.3.1 - '@types/estree': 1.0.8 - esrecurse: 4.3.0 - estraverse: 5.3.0 - - eslint-visitor-keys@3.4.3: {} - - eslint-visitor-keys@5.0.1: {} - - eslint@10.0.3: - dependencies: - '@eslint-community/eslint-utils': 4.9.1(eslint@10.0.3) - '@eslint-community/regexpp': 4.12.2 - '@eslint/config-array': 0.23.3 - '@eslint/config-helpers': 0.5.3 - '@eslint/core': 1.1.1 - '@eslint/plugin-kit': 0.6.1 - '@humanfs/node': 0.16.7 - '@humanwhocodes/module-importer': 1.0.1 - '@humanwhocodes/retry': 0.4.3 - '@types/estree': 1.0.8 - ajv: 6.14.0 - cross-spawn: 7.0.6 - debug: 4.4.3 - escape-string-regexp: 4.0.0 - eslint-scope: 9.1.2 - eslint-visitor-keys: 5.0.1 - espree: 11.2.0 - esquery: 1.7.0 - esutils: 2.0.3 - fast-deep-equal: 3.1.3 - file-entry-cache: 8.0.0 - find-up: 5.0.0 - glob-parent: 6.0.2 - ignore: 5.3.2 - imurmurhash: 0.1.4 - is-glob: 4.0.3 - json-stable-stringify-without-jsonify: 1.0.1 - minimatch: 10.2.4 - natural-compare: 1.4.0 - optionator: 0.9.4 - transitivePeerDependencies: - - supports-color - - espree@11.2.0: - dependencies: - acorn: 8.16.0 - acorn-jsx: 5.3.2(acorn@8.16.0) - eslint-visitor-keys: 5.0.1 - - esquery@1.7.0: - dependencies: - estraverse: 5.3.0 - - esrecurse@4.3.0: - dependencies: - estraverse: 5.3.0 - - estraverse@5.3.0: {} - - estree-walker@3.0.3: - dependencies: - '@types/estree': 1.0.8 - - esutils@2.0.3: {} - - expect-type@1.3.0: {} - - fast-deep-equal@3.1.3: {} - - fast-json-stable-stringify@2.1.0: {} - - fast-levenshtein@2.0.6: {} - - fdir@6.5.0(picomatch@4.0.3): - optionalDependencies: - picomatch: 4.0.3 - - file-entry-cache@8.0.0: - dependencies: - flat-cache: 4.0.1 - - find-up@5.0.0: - dependencies: - locate-path: 6.0.0 - path-exists: 4.0.0 - - fix-dts-default-cjs-exports@1.0.1: - dependencies: - magic-string: 0.30.21 - mlly: 1.8.1 - rollup: 4.59.0 - - flat-cache@4.0.1: - dependencies: - flatted: 3.4.1 - keyv: 4.5.4 - - flatted@3.4.1: {} - - follow-redirects@1.15.11: {} - - form-data@4.0.5: - dependencies: - asynckit: 0.4.0 - combined-stream: 1.0.8 - es-set-tostringtag: 2.1.0 - hasown: 2.0.2 - mime-types: 2.1.35 - - fsevents@2.3.3: - optional: true - - function-bind@1.1.2: {} - - get-intrinsic@1.3.0: - dependencies: - call-bind-apply-helpers: 1.0.2 - es-define-property: 1.0.1 - es-errors: 1.3.0 - es-object-atoms: 1.1.1 - function-bind: 1.1.2 - get-proto: 1.0.1 - gopd: 1.2.0 - has-symbols: 1.1.0 - hasown: 2.0.2 - math-intrinsics: 1.1.0 - - get-proto@1.0.1: - dependencies: - dunder-proto: 1.0.1 - es-object-atoms: 1.1.1 - - glob-parent@6.0.2: - dependencies: - is-glob: 4.0.3 - - gopd@1.2.0: {} - - has-flag@4.0.0: {} - - has-symbols@1.1.0: {} - - has-tostringtag@1.0.2: - dependencies: - has-symbols: 1.1.0 - - hasown@2.0.2: - dependencies: - function-bind: 1.1.2 - - html-escaper@2.0.2: {} - - ignore@5.3.2: {} - - ignore@7.0.5: {} - - imurmurhash@0.1.4: {} - - is-extglob@2.1.1: {} - - is-glob@4.0.3: - dependencies: - is-extglob: 2.1.1 - - isexe@2.0.0: {} - - istanbul-lib-coverage@3.2.2: {} - - istanbul-lib-report@3.0.1: - dependencies: - istanbul-lib-coverage: 3.2.2 - make-dir: 4.0.0 - supports-color: 7.2.0 - - istanbul-reports@3.2.0: - dependencies: - html-escaper: 2.0.2 - istanbul-lib-report: 3.0.1 - - joycon@3.1.1: {} - - js-tokens@10.0.0: {} - - json-buffer@3.0.1: {} - - json-schema-traverse@0.4.1: {} - - json-stable-stringify-without-jsonify@1.0.1: {} - - keyv@4.5.4: - dependencies: - json-buffer: 3.0.1 - - levn@0.4.1: - dependencies: - prelude-ls: 1.2.1 - type-check: 0.4.0 - - lilconfig@3.1.3: {} - - lines-and-columns@1.2.4: {} - - load-tsconfig@0.2.5: {} - - locate-path@6.0.0: - dependencies: - p-locate: 5.0.0 - - magic-string@0.30.21: - dependencies: - '@jridgewell/sourcemap-codec': 1.5.5 - - magicast@0.5.2: - dependencies: - '@babel/parser': 7.29.0 - '@babel/types': 7.29.0 - source-map-js: 1.2.1 - - make-dir@4.0.0: - dependencies: - semver: 7.7.4 - - math-intrinsics@1.1.0: {} - - mime-db@1.52.0: {} - - mime-types@2.1.35: - dependencies: - mime-db: 1.52.0 - - minimatch@10.2.4: - dependencies: - brace-expansion: 5.0.4 - - mlly@1.8.1: - dependencies: - acorn: 8.16.0 - pathe: 2.0.3 - pkg-types: 1.3.1 - ufo: 1.6.3 - - ms@2.1.3: {} - - mz@2.7.0: - dependencies: - any-promise: 1.3.0 - object-assign: 4.1.1 - thenify-all: 1.6.0 - - nanoid@3.3.11: {} - - natural-compare@1.4.0: {} - - object-assign@4.1.1: {} - - obug@2.1.1: {} - - optionator@0.9.4: - dependencies: - deep-is: 0.1.4 - fast-levenshtein: 2.0.6 - levn: 0.4.1 - prelude-ls: 1.2.1 - type-check: 0.4.0 - word-wrap: 1.2.5 - - p-limit@3.1.0: - dependencies: - yocto-queue: 0.1.0 - - p-locate@5.0.0: - dependencies: - p-limit: 3.1.0 - - path-exists@4.0.0: {} - - path-key@3.1.1: {} - - pathe@2.0.3: {} - - picocolors@1.1.1: {} - - picomatch@4.0.3: {} - - pirates@4.0.7: {} - - pkg-types@1.3.1: - dependencies: - confbox: 0.1.8 - mlly: 1.8.1 - pathe: 2.0.3 - - postcss-load-config@6.0.1(postcss@8.5.8): - dependencies: - lilconfig: 3.1.3 - optionalDependencies: - postcss: 8.5.8 - - postcss@8.5.8: - dependencies: - nanoid: 3.3.11 - picocolors: 1.1.1 - source-map-js: 1.2.1 - - prelude-ls@1.2.1: {} - - proxy-from-env@1.1.0: {} - - punycode@2.3.1: {} - - readdirp@4.1.2: {} - - resolve-from@5.0.0: {} - - rollup@4.59.0: - dependencies: - '@types/estree': 1.0.8 - optionalDependencies: - '@rollup/rollup-android-arm-eabi': 4.59.0 - '@rollup/rollup-android-arm64': 4.59.0 - '@rollup/rollup-darwin-arm64': 4.59.0 - '@rollup/rollup-darwin-x64': 4.59.0 - '@rollup/rollup-freebsd-arm64': 4.59.0 - '@rollup/rollup-freebsd-x64': 4.59.0 - '@rollup/rollup-linux-arm-gnueabihf': 4.59.0 - '@rollup/rollup-linux-arm-musleabihf': 4.59.0 - '@rollup/rollup-linux-arm64-gnu': 4.59.0 - '@rollup/rollup-linux-arm64-musl': 4.59.0 - '@rollup/rollup-linux-loong64-gnu': 4.59.0 - '@rollup/rollup-linux-loong64-musl': 4.59.0 - '@rollup/rollup-linux-ppc64-gnu': 4.59.0 - '@rollup/rollup-linux-ppc64-musl': 4.59.0 - '@rollup/rollup-linux-riscv64-gnu': 4.59.0 - '@rollup/rollup-linux-riscv64-musl': 4.59.0 - '@rollup/rollup-linux-s390x-gnu': 4.59.0 - '@rollup/rollup-linux-x64-gnu': 4.59.0 - '@rollup/rollup-linux-x64-musl': 4.59.0 - '@rollup/rollup-openbsd-x64': 4.59.0 - '@rollup/rollup-openharmony-arm64': 4.59.0 - '@rollup/rollup-win32-arm64-msvc': 4.59.0 - '@rollup/rollup-win32-ia32-msvc': 4.59.0 - '@rollup/rollup-win32-x64-gnu': 4.59.0 - '@rollup/rollup-win32-x64-msvc': 4.59.0 - fsevents: 2.3.3 - - semver@7.7.4: {} - - shebang-command@2.0.0: - dependencies: - shebang-regex: 3.0.0 - - shebang-regex@3.0.0: {} - - siginfo@2.0.0: {} - - source-map-js@1.2.1: {} - - source-map@0.7.6: {} - - stackback@0.0.2: {} - - std-env@3.10.0: {} - - sucrase@3.35.1: - dependencies: - '@jridgewell/gen-mapping': 0.3.13 - commander: 4.1.1 - lines-and-columns: 1.2.4 - mz: 2.7.0 - pirates: 4.0.7 - tinyglobby: 0.2.15 - ts-interface-checker: 0.1.13 - - supports-color@7.2.0: - dependencies: - has-flag: 4.0.0 - - thenify-all@1.6.0: - dependencies: - thenify: 3.3.1 - - thenify@3.3.1: - dependencies: - any-promise: 1.3.0 - - tinybench@2.9.0: {} - - tinyexec@0.3.2: {} - - tinyexec@1.0.2: {} - - tinyglobby@0.2.15: - dependencies: - fdir: 6.5.0(picomatch@4.0.3) - picomatch: 4.0.3 - - tinyrainbow@3.0.3: {} - - tree-kill@1.2.2: {} - - ts-api-utils@2.4.0(typescript@5.9.3): - dependencies: - typescript: 5.9.3 - - ts-interface-checker@0.1.13: {} - - tsup@8.5.1(postcss@8.5.8)(typescript@5.9.3): - dependencies: - bundle-require: 5.1.0(esbuild@0.27.3) - cac: 6.7.14 - chokidar: 4.0.3 - consola: 3.4.2 - debug: 4.4.3 - esbuild: 0.27.3 - fix-dts-default-cjs-exports: 1.0.1 - joycon: 3.1.1 - picocolors: 1.1.1 - postcss-load-config: 6.0.1(postcss@8.5.8) - resolve-from: 5.0.0 - rollup: 4.59.0 - source-map: 0.7.6 - sucrase: 3.35.1 - tinyexec: 0.3.2 - tinyglobby: 0.2.15 - tree-kill: 1.2.2 - optionalDependencies: - postcss: 8.5.8 - typescript: 5.9.3 - transitivePeerDependencies: - - jiti - - supports-color - - tsx - - yaml - - type-check@0.4.0: - dependencies: - prelude-ls: 1.2.1 - - typescript@5.9.3: {} - - ufo@1.6.3: {} - - undici-types@7.18.2: {} - - uri-js@4.4.1: - dependencies: - punycode: 2.3.1 - - vite@7.3.1(@types/node@25.4.0): - dependencies: - esbuild: 0.27.3 - fdir: 6.5.0(picomatch@4.0.3) - picomatch: 4.0.3 - postcss: 8.5.8 - rollup: 4.59.0 - tinyglobby: 0.2.15 - optionalDependencies: - '@types/node': 25.4.0 - fsevents: 2.3.3 - - vitest@4.0.18(@types/node@25.4.0): - dependencies: - '@vitest/expect': 4.0.18 - '@vitest/mocker': 4.0.18(vite@7.3.1(@types/node@25.4.0)) - '@vitest/pretty-format': 4.0.18 - '@vitest/runner': 4.0.18 - '@vitest/snapshot': 4.0.18 - '@vitest/spy': 4.0.18 - '@vitest/utils': 4.0.18 - es-module-lexer: 1.7.0 - expect-type: 1.3.0 - magic-string: 0.30.21 - obug: 2.1.1 - pathe: 2.0.3 - picomatch: 4.0.3 - std-env: 3.10.0 - tinybench: 2.9.0 - tinyexec: 1.0.2 - tinyglobby: 0.2.15 - tinyrainbow: 3.0.3 - vite: 7.3.1(@types/node@25.4.0) - why-is-node-running: 2.3.0 - optionalDependencies: - '@types/node': 25.4.0 - transitivePeerDependencies: - - jiti - - less - - lightningcss - - msw - - sass - - sass-embedded - - stylus - - sugarss - - terser - - tsx - - yaml - - which@2.0.2: - dependencies: - isexe: 2.0.0 - - why-is-node-running@2.3.0: - dependencies: - siginfo: 2.0.0 - stackback: 0.0.2 - - word-wrap@1.2.5: {} - - yocto-queue@0.1.0: {} diff --git a/sdks/nodejs-client/pnpm-workspace.yaml b/sdks/nodejs-client/pnpm-workspace.yaml deleted file mode 100644 index efc037aa846..00000000000 --- a/sdks/nodejs-client/pnpm-workspace.yaml +++ /dev/null @@ -1,2 +0,0 @@ -onlyBuiltDependencies: - - esbuild diff --git a/sdks/nodejs-client/scripts/publish.sh b/sdks/nodejs-client/scripts/publish.sh index 043cac046d7..5f8e73f8c01 100755 --- a/sdks/nodejs-client/scripts/publish.sh +++ b/sdks/nodejs-client/scripts/publish.sh @@ -5,10 +5,12 @@ # A beautiful and reliable script to publish the SDK to npm # # Usage: -# ./scripts/publish.sh # Normal publish +# ./scripts/publish.sh # Normal publish # ./scripts/publish.sh --dry-run # Test without publishing # ./scripts/publish.sh --skip-tests # Skip tests (not recommended) # +# This script requires pnpm because the workspace uses catalog: dependencies. +# set -euo pipefail @@ -62,11 +64,27 @@ divider() { echo -e "${DIM}─────────────────────────────────────────────────────────────────${NC}" } +run_npm() { + env \ + -u npm_config_npm_globalconfig \ + -u NPM_CONFIG_NPM_GLOBALCONFIG \ + -u npm_config_verify_deps_before_run \ + -u NPM_CONFIG_VERIFY_DEPS_BEFORE_RUN \ + -u npm_config__jsr_registry \ + -u NPM_CONFIG__JSR_REGISTRY \ + -u npm_config_catalog \ + -u NPM_CONFIG_CATALOG \ + -u npm_config_overrides \ + -u NPM_CONFIG_OVERRIDES \ + npm "$@" +} + # ============================================================================ # Configuration # ============================================================================ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" PROJECT_DIR="$(cd "$SCRIPT_DIR/.." && pwd)" +REPO_ROOT="$(git -C "$SCRIPT_DIR" rev-parse --show-toplevel 2>/dev/null || (cd "$SCRIPT_DIR/../../.." && pwd))" DRY_RUN=false SKIP_TESTS=false @@ -123,23 +141,23 @@ main() { error "npm is not installed" exit 1 fi - NPM_VERSION=$(npm -v) + NPM_VERSION=$(run_npm -v) success "npm: v$NPM_VERSION" - # Check pnpm (optional, for local dev) - if command -v pnpm &> /dev/null; then - PNPM_VERSION=$(pnpm -v) - success "pnpm: v$PNPM_VERSION" - else - info "pnpm not found (optional)" + if ! command -v pnpm &> /dev/null; then + error "pnpm is required because this workspace publishes catalog: dependencies" + info "Install pnpm with Corepack: corepack enable" + exit 1 fi + PNPM_VERSION=$(pnpm -v) + success "pnpm: v$PNPM_VERSION" # Check npm login status - if ! npm whoami &> /dev/null; then + if ! run_npm whoami &> /dev/null; then error "Not logged in to npm. Run 'npm login' first." exit 1 fi - NPM_USER=$(npm whoami) + NPM_USER=$(run_npm whoami) success "Logged in as: ${BOLD}$NPM_USER${NC}" # ======================================================================== @@ -154,11 +172,11 @@ main() { success "Version: ${BOLD}$PACKAGE_VERSION${NC}" # Check if version already exists on npm - if npm view "$PACKAGE_NAME@$PACKAGE_VERSION" version &> /dev/null; then + if run_npm view "$PACKAGE_NAME@$PACKAGE_VERSION" version &> /dev/null; then error "Version $PACKAGE_VERSION already exists on npm!" echo "" info "Current published versions:" - npm view "$PACKAGE_NAME" versions --json 2>/dev/null | tail -5 + run_npm view "$PACKAGE_NAME" versions --json 2>/dev/null | tail -5 echo "" warning "Please update the version in package.json before publishing." exit 1 @@ -170,11 +188,7 @@ main() { # ======================================================================== step "Step 3/6: Installing dependencies..." - if command -v pnpm &> /dev/null; then - pnpm install --frozen-lockfile 2>/dev/null || pnpm install - else - npm ci 2>/dev/null || npm install - fi + pnpm --dir "$REPO_ROOT" install --frozen-lockfile 2>/dev/null || pnpm --dir "$REPO_ROOT" install success "Dependencies installed" # ======================================================================== @@ -185,11 +199,7 @@ main() { if [[ "$SKIP_TESTS" == true ]]; then warning "Skipping tests (--skip-tests flag)" else - if command -v pnpm &> /dev/null; then - pnpm test - else - npm test - fi + pnpm test success "All tests passed" fi @@ -201,11 +211,7 @@ main() { # Clean previous build rm -rf dist - if command -v pnpm &> /dev/null; then - pnpm run build - else - npm run build - fi + pnpm run build success "Build completed" # Verify build output @@ -223,15 +229,32 @@ main() { # Step 6: Publish # ======================================================================== step "Step 6/6: Publishing to npm..." - + + PACK_DIR="$(mktemp -d)" + trap 'rm -rf "$PACK_DIR"' EXIT + + pnpm pack --pack-destination "$PACK_DIR" >/dev/null + PACKAGE_TARBALL="$(find "$PACK_DIR" -maxdepth 1 -name '*.tgz' | head -n 1)" + + if [[ -z "$PACKAGE_TARBALL" ]]; then + error "Pack failed - no tarball generated" + exit 1 + fi + + if tar -xOf "$PACKAGE_TARBALL" package/package.json | grep -q '"catalog:'; then + error "Packed manifest still contains catalog: references" + exit 1 + fi + divider echo -e "${CYAN}Package contents:${NC}" - npm pack --dry-run 2>&1 | head -30 + tar -tzf "$PACKAGE_TARBALL" | head -30 divider if [[ "$DRY_RUN" == true ]]; then warning "DRY-RUN: Skipping actual publish" echo "" + info "Packed artifact: $PACKAGE_TARBALL" info "To publish for real, run without --dry-run flag" else echo "" @@ -239,7 +262,7 @@ main() { echo -e "${DIM}Press Enter to continue, or Ctrl+C to cancel...${NC}" read -r - npm publish --access public + pnpm publish --access public --no-git-checks echo "" success "🎉 Successfully published ${BOLD}$PACKAGE_NAME@$PACKAGE_VERSION${NC} to npm!" diff --git a/web/taze.config.js b/taze.config.js similarity index 91% rename from web/taze.config.js rename to taze.config.js index 4e97a50d2e5..d21756e207d 100644 --- a/web/taze.config.js +++ b/taze.config.js @@ -10,7 +10,7 @@ export default defineConfig({ // We can not upgrade these yet 'tailwind-merge', 'tailwindcss', - '@eslint-react/eslint-plugin', + 'typescript', ], write: true, diff --git a/web/.dockerignore b/web/.dockerignore deleted file mode 100644 index 91437a2259a..00000000000 --- a/web/.dockerignore +++ /dev/null @@ -1,32 +0,0 @@ -.env -.env.* - -# Logs -logs -*.log* - -# node -node_modules -dist -build -coverage -.husky -.next -.pnpm-store - -# vscode -.vscode - -# webstorm -.idea -*.iml -*.iws -*.ipr - - -# Jetbrains -.idea - -# git -.git -.gitignore \ No newline at end of file diff --git a/web/.env.example b/web/.env.example index ed06ebe2c9b..62d4fa6c568 100644 --- a/web/.env.example +++ b/web/.env.example @@ -6,19 +6,23 @@ NEXT_PUBLIC_EDITION=SELF_HOSTED NEXT_PUBLIC_BASE_PATH= # The base URL of console application, refers to the Console base URL of WEB service if console domain is # different from api or web app domain. -# example: http://cloud.dify.ai/console/api +# example: https://cloud.dify.ai/console/api NEXT_PUBLIC_API_PREFIX=http://localhost:5001/console/api # The URL for Web APP, refers to the Web App base URL of WEB service if web app domain is different from # console or api domain. -# example: http://udify.app/api +# example: https://udify.app/api NEXT_PUBLIC_PUBLIC_API_PREFIX=http://localhost:5001/api -# Dev-only Hono proxy targets. The frontend keeps requesting http://localhost:5001 directly. +# When the frontend and backend run on different subdomains, set NEXT_PUBLIC_COOKIE_DOMAIN=1. +NEXT_PUBLIC_COOKIE_DOMAIN= + +# Dev-only Hono proxy targets. +# The frontend keeps requesting http://localhost:5001 directly, +# the proxy server will forward the request to the target server, +# so that you don't need to run a separate backend server and use online API in development. HONO_PROXY_HOST=127.0.0.1 HONO_PROXY_PORT=5001 HONO_CONSOLE_API_PROXY_TARGET= HONO_PUBLIC_API_PROXY_TARGET= -# When the frontend and backend run on different subdomains, set NEXT_PUBLIC_COOKIE_DOMAIN=1. -NEXT_PUBLIC_COOKIE_DOMAIN= # The API PREFIX for MARKETPLACE NEXT_PUBLIC_MARKETPLACE_API_PREFIX=https://marketplace.dify.ai/api/v1 @@ -47,8 +51,6 @@ NEXT_PUBLIC_ALLOW_EMBED= # Allow rendering unsafe URLs which have "data:" scheme. NEXT_PUBLIC_ALLOW_UNSAFE_DATA_SCHEME=false -# Github Access Token, used for invoking Github API -NEXT_PUBLIC_GITHUB_ACCESS_TOKEN= # The maximum number of top-k value for RAG. NEXT_PUBLIC_TOP_K_MAX_VALUE=10 diff --git a/web/.storybook/preview.tsx b/web/.storybook/preview.tsx index 5b384247765..072244c33f9 100644 --- a/web/.storybook/preview.tsx +++ b/web/.storybook/preview.tsx @@ -7,7 +7,7 @@ import { I18nClientProvider as I18N } from '../app/components/provider/i18n' import commonEnUS from '../i18n/en-US/common.json' import '../app/styles/globals.css' -import '../app/styles/markdown.scss' +import '../app/styles/markdown.css' import './storybook.css' const queryClient = new QueryClient({ diff --git a/web/Dockerfile b/web/Dockerfile index b54bae706ca..75024db4f3c 100644 --- a/web/Dockerfile +++ b/web/Dockerfile @@ -19,21 +19,27 @@ ENV NEXT_PUBLIC_BASE_PATH="$NEXT_PUBLIC_BASE_PATH" # install packages FROM base AS packages -WORKDIR /app/web +WORKDIR /app -COPY package.json pnpm-lock.yaml /app/web/ +COPY package.json pnpm-lock.yaml pnpm-workspace.yaml /app/ +COPY web/package.json /app/web/ +COPY e2e/package.json /app/e2e/ +COPY sdks/nodejs-client/package.json /app/sdks/nodejs-client/ # Use packageManager from package.json RUN corepack install -RUN pnpm install --frozen-lockfile +# Install only the web workspace to keep image builds from pulling in +# unrelated workspace dependencies such as e2e tooling. +RUN pnpm install --filter ./web... --frozen-lockfile # build resources FROM base AS builder -WORKDIR /app/web -COPY --from=packages /app/web/ . +WORKDIR /app +COPY --from=packages /app/ . COPY . . +WORKDIR /app/web ENV NODE_OPTIONS="--max-old-space-size=4096" RUN pnpm build @@ -64,13 +70,13 @@ RUN addgroup -S -g ${dify_uid} dify && \ chown -R dify:dify /app -WORKDIR /app/web +WORKDIR /app -COPY --from=builder --chown=dify:dify /app/web/public ./public +COPY --from=builder --chown=dify:dify /app/web/public ./web/public COPY --from=builder --chown=dify:dify /app/web/.next/standalone ./ -COPY --from=builder --chown=dify:dify /app/web/.next/static ./.next/static +COPY --from=builder --chown=dify:dify /app/web/.next/static ./web/.next/static -COPY --chown=dify:dify --chmod=755 docker/entrypoint.sh ./entrypoint.sh +COPY --chown=dify:dify --chmod=755 web/docker/entrypoint.sh ./entrypoint.sh ARG COMMIT_SHA ENV COMMIT_SHA=${COMMIT_SHA} diff --git a/web/Dockerfile.dockerignore b/web/Dockerfile.dockerignore new file mode 100644 index 00000000000..9801003d892 --- /dev/null +++ b/web/Dockerfile.dockerignore @@ -0,0 +1,34 @@ +** +!package.json +!pnpm-lock.yaml +!pnpm-workspace.yaml +!.nvmrc +!web/ +!web/** +!e2e/ +!e2e/package.json +!sdks/ +!sdks/nodejs-client/ +!sdks/nodejs-client/package.json + +.git +node_modules +.pnpm-store +web/.env +web/.env.* +web/logs +web/*.log* +web/node_modules +web/dist +web/build +web/coverage +web/.husky +web/.next +web/.pnpm-store +web/.vscode +web/.idea +web/*.iml +web/*.iws +web/*.ipr +e2e/node_modules +sdks/nodejs-client/node_modules diff --git a/web/README.md b/web/README.md index 1e57e7c6a97..2d69a94dbd9 100644 --- a/web/README.md +++ b/web/README.md @@ -1,6 +1,6 @@ # Dify Frontend -This is a [Next.js](https://nextjs.org/) project bootstrapped with [`create-next-app`](https://github.com/vercel/next.js/tree/canary/packages/create-next-app). +This is a [Next.js] project, but you can dev with [vinext]. ## Getting Started @@ -8,8 +8,11 @@ This is a [Next.js](https://nextjs.org/) project bootstrapped with [`create-next Before starting the web frontend service, please make sure the following environment is ready. -- [Node.js](https://nodejs.org) -- [pnpm](https://pnpm.io) +- [Node.js] +- [pnpm] + +You can also use [Vite+] with the corresponding `vp` commands. +For example, use `vp install` instead of `pnpm install` and `vp test` instead of `pnpm run test`. > [!TIP] > It is recommended to install and enable Corepack to manage package manager versions automatically: @@ -19,7 +22,9 @@ Before starting the web frontend service, please make sure the following environ > corepack enable > ``` > -> Learn more: [Corepack](https://github.com/nodejs/corepack#readme) +> Learn more: [Corepack] + +Run the following commands from the repository root. First, install the dependencies: @@ -27,29 +32,16 @@ First, install the dependencies: pnpm install ``` -Then, configure the environment variables. Create a file named `.env.local` in the current directory and copy the contents from `.env.example`. Modify the values of these environment variables according to your requirements: +> [!NOTE] +> JavaScript dependencies are managed by the workspace files at the repository root: `package.json`, `pnpm-lock.yaml`, `pnpm-workspace.yaml`, and `.nvmrc`. +> Install dependencies from the repository root, then run frontend scripts from `web/`. + +Then, configure the environment variables. +Create `web/.env.local` and copy the contents from `web/.env.example`. +Modify the values of these environment variables according to your requirements: ```bash -cp .env.example .env.local -``` - -```txt -# For production release, change this to PRODUCTION -NEXT_PUBLIC_DEPLOY_ENV=DEVELOPMENT -# The deployment edition, SELF_HOSTED -NEXT_PUBLIC_EDITION=SELF_HOSTED -# The base URL of console application, refers to the Console base URL of WEB service if console domain is -# different from api or web app domain. -# example: http://cloud.dify.ai/console/api -NEXT_PUBLIC_API_PREFIX=http://localhost:5001/console/api -NEXT_PUBLIC_COOKIE_DOMAIN= -# The URL for Web APP, refers to the Web App base URL of WEB service if web app domain is different from -# console or api domain. -# example: http://udify.app/api -NEXT_PUBLIC_PUBLIC_API_PREFIX=http://localhost:5001/api - -# SENTRY -NEXT_PUBLIC_SENTRY_DSN= +cp web/.env.example web/.env.local ``` > [!IMPORTANT] @@ -60,12 +52,17 @@ NEXT_PUBLIC_SENTRY_DSN= Finally, run the development server: ```bash -pnpm run dev +pnpm -C web run dev +# or if you are using vinext which provides a better development experience +pnpm -C web run dev:vinext +# (optional) start the dev proxy server so that you can use online API in development +pnpm -C web run dev:proxy ``` -Open [http://localhost:3000](http://localhost:3000) with your browser to see the result. +Open with your browser to see the result. -You can start editing the file under folder `app`. The page auto-updates as you edit the file. +You can start editing the files under `web/app`. +The page auto-updates as you edit the file. ## Deploy @@ -74,57 +71,73 @@ You can start editing the file under folder `app`. The page auto-updates as you First, build the app for production: ```bash -pnpm run build +pnpm -C web run build ``` Then, start the server: ```bash -pnpm run start +pnpm -C web run start +``` + +If you build the Docker image manually, use the repository root as the build context: + +```bash +docker build -f web/Dockerfile -t dify-web . ``` If you want to customize the host and port: ```bash -pnpm run start --port=3001 --host=0.0.0.0 +pnpm -C web run start --port=3001 --host=0.0.0.0 ``` ## Storybook -This project uses [Storybook](https://storybook.js.org/) for UI component development. +This project uses [Storybook] for UI component development. To start the storybook server, run: ```bash -pnpm storybook +pnpm -C web storybook ``` -Open [http://localhost:6006](http://localhost:6006) with your browser to see the result. +Open with your browser to see the result. ## Lint Code If your IDE is VSCode, rename `web/.vscode/settings.example.json` to `web/.vscode/settings.json` for lint code setting. -Then follow the [Lint Documentation](./docs/lint.md) to lint the code. +Then follow the [Lint Documentation] to lint the code. ## Test -We use [Vitest](https://vitest.dev/) and [React Testing Library](https://testing-library.com/docs/react-testing-library/intro/) for Unit Testing. +We use [Vitest] and [React Testing Library] for Unit Testing. -**📖 Complete Testing Guide**: See [web/testing/testing.md](./testing/testing.md) for detailed testing specifications, best practices, and examples. +**📖 Complete Testing Guide**: See [web/docs/test.md] for detailed testing specifications, best practices, and examples. + +> [!IMPORTANT] +> As we are using Vite+, the `vitest` command is not available. +> Please make sure to run tests with `vp` commands. +> For example, use `npx vp test` instead of `npx vitest`. Run test: ```bash -pnpm test +pnpm -C web test ``` +> [!NOTE] +> Our test is not fully stable yet, and we are actively working on improving it. +> If you encounter test failures only in CI but not locally, please feel free to ignore them and report the issue to us. +> You can try to re-run the test in CI, and it may pass successfully. + ### Example Code If you are not familiar with writing tests, refer to: -- [classnames.spec.ts](./utils/classnames.spec.ts) - Utility function test example -- [index.spec.tsx](./app/components/base/button/index.spec.tsx) - Component test example +- [classnames.spec.ts] - Utility function test example +- [index.spec.tsx] - Component test example ### Analyze Component Complexity @@ -134,7 +147,7 @@ Before writing tests, use the script to analyze component complexity: pnpm analyze-component app/components/your-component/index.tsx ``` -This will help you determine the testing strategy. See [web/testing/testing.md](./testing/testing.md) for details. +This will help you determine the testing strategy. See [web/testing/testing.md] for details. ## Documentation @@ -142,4 +155,19 @@ Visit to view the full documentation. ## Community -The Dify community can be found on [Discord community](https://discord.gg/5AEfbxcd9k), where you can ask questions, voice ideas, and share your projects. +The Dify community can be found on [Discord community], where you can ask questions, voice ideas, and share your projects. + +[Corepack]: https://github.com/nodejs/corepack#readme +[Discord community]: https://discord.gg/5AEfbxcd9k +[Lint Documentation]: ./docs/lint.md +[Next.js]: https://nextjs.org +[Node.js]: https://nodejs.org +[React Testing Library]: https://testing-library.com/docs/react-testing-library/intro +[Storybook]: https://storybook.js.org +[Vite+]: https://viteplus.dev +[Vitest]: https://vitest.dev +[classnames.spec.ts]: ./utils/classnames.spec.ts +[index.spec.tsx]: ./app/components/base/button/index.spec.tsx +[pnpm]: https://pnpm.io +[vinext]: https://github.com/cloudflare/vinext +[web/docs/test.md]: ./docs/test.md diff --git a/web/__tests__/billing/cloud-plan-payment-flow.test.tsx b/web/__tests__/billing/cloud-plan-payment-flow.test.tsx index 84653cd68c0..0c1efbe1afb 100644 --- a/web/__tests__/billing/cloud-plan-payment-flow.test.tsx +++ b/web/__tests__/billing/cloud-plan-payment-flow.test.tsx @@ -95,7 +95,7 @@ describe('Cloud Plan Payment Flow', () => { beforeEach(() => { vi.clearAllMocks() cleanup() - toast.close() + toast.dismiss() setupAppContext() mockFetchSubscriptionUrls.mockResolvedValue({ url: 'https://pay.example.com/checkout' }) mockInvoices.mockResolvedValue({ url: 'https://billing.example.com/invoices' }) diff --git a/web/__tests__/billing/self-hosted-plan-flow.test.tsx b/web/__tests__/billing/self-hosted-plan-flow.test.tsx index 0802b760e12..a3386d0092a 100644 --- a/web/__tests__/billing/self-hosted-plan-flow.test.tsx +++ b/web/__tests__/billing/self-hosted-plan-flow.test.tsx @@ -66,7 +66,7 @@ describe('Self-Hosted Plan Flow', () => { beforeEach(() => { vi.clearAllMocks() cleanup() - toast.close() + toast.dismiss() setupAppContext() // Mock window.location with minimal getter/setter (Location props are non-enumerable) diff --git a/web/__tests__/check-i18n.test.ts b/web/__tests__/check-i18n.test.ts deleted file mode 100644 index de78ae997ec..00000000000 --- a/web/__tests__/check-i18n.test.ts +++ /dev/null @@ -1,860 +0,0 @@ -import fs from 'node:fs' -import path from 'node:path' -import vm from 'node:vm' -import { transpile } from 'typescript' - -describe('i18n:check script functionality', () => { - const testDir = path.join(__dirname, '../i18n-test') - const testEnDir = path.join(testDir, 'en-US') - const testZhDir = path.join(testDir, 'zh-Hans') - - // Helper function that replicates the getKeysFromLanguage logic - async function getKeysFromLanguage(language: string, testPath = testDir): Promise { - return new Promise((resolve, reject) => { - const folderPath = path.resolve(testPath, language) - const allKeys: string[] = [] - - if (!fs.existsSync(folderPath)) { - resolve([]) - return - } - - fs.readdir(folderPath, (err, files) => { - if (err) { - reject(err) - return - } - - const translationFiles = files.filter(file => /\.(ts|js)$/.test(file)) - - translationFiles.forEach((file) => { - const filePath = path.join(folderPath, file) - const fileName = file.replace(/\.[^/.]+$/, '') - const camelCaseFileName = fileName.replace(/[-_](.)/g, (_, c) => - c.toUpperCase()) - - try { - const content = fs.readFileSync(filePath, 'utf8') - const moduleExports = {} - const context = { - exports: moduleExports, - module: { exports: moduleExports }, - require, - console, - __filename: filePath, - __dirname: folderPath, - } - - vm.runInNewContext(transpile(content), context) - const translationObj = (context.module.exports as any).default || context.module.exports - - if (!translationObj || typeof translationObj !== 'object') - throw new Error(`Error parsing file: ${filePath}`) - - const nestedKeys: string[] = [] - const iterateKeys = (obj: any, prefix = '') => { - for (const key in obj) { - const nestedKey = prefix ? `${prefix}.${key}` : key - if (typeof obj[key] === 'object' && obj[key] !== null && !Array.isArray(obj[key])) { - // This is an object (but not array), recurse into it but don't add it as a key - iterateKeys(obj[key], nestedKey) - } - else { - // This is a leaf node (string, number, boolean, array, etc.), add it as a key - nestedKeys.push(nestedKey) - } - } - } - iterateKeys(translationObj) - - const fileKeys = nestedKeys.map(key => `${camelCaseFileName}.${key}`) - allKeys.push(...fileKeys) - } - catch (error) { - reject(error) - } - }) - resolve(allKeys) - }) - }) - } - - beforeEach(() => { - // Clean up and create test directories - if (fs.existsSync(testDir)) - fs.rmSync(testDir, { recursive: true }) - - fs.mkdirSync(testDir, { recursive: true }) - fs.mkdirSync(testEnDir, { recursive: true }) - fs.mkdirSync(testZhDir, { recursive: true }) - }) - - afterEach(() => { - // Clean up test files - if (fs.existsSync(testDir)) - fs.rmSync(testDir, { recursive: true }) - }) - - describe('Key extraction logic', () => { - it('should extract only leaf node keys, not intermediate objects', async () => { - const testContent = `const translation = { - simple: 'Simple Value', - nested: { - level1: 'Level 1 Value', - deep: { - level2: 'Level 2 Value' - } - }, - array: ['not extracted'], - number: 42, - boolean: true -} - -export default translation -` - - fs.writeFileSync(path.join(testEnDir, 'test.ts'), testContent) - - const keys = await getKeysFromLanguage('en-US') - - expect(keys).toEqual([ - 'test.simple', - 'test.nested.level1', - 'test.nested.deep.level2', - 'test.array', - 'test.number', - 'test.boolean', - ]) - - // Should not include intermediate object keys - expect(keys).not.toContain('test.nested') - expect(keys).not.toContain('test.nested.deep') - }) - - it('should handle camelCase file name conversion correctly', async () => { - const testContent = `const translation = { - key: 'value' -} - -export default translation -` - - fs.writeFileSync(path.join(testEnDir, 'app-debug.ts'), testContent) - fs.writeFileSync(path.join(testEnDir, 'user_profile.ts'), testContent) - - const keys = await getKeysFromLanguage('en-US') - - expect(keys).toContain('appDebug.key') - expect(keys).toContain('userProfile.key') - }) - }) - - describe('Missing keys detection', () => { - it('should detect missing keys in target language', async () => { - const enContent = `const translation = { - common: { - save: 'Save', - cancel: 'Cancel', - delete: 'Delete' - }, - app: { - title: 'My App', - version: '1.0' - } -} - -export default translation -` - - const zhContent = `const translation = { - common: { - save: '保存', - cancel: '取消' - // missing 'delete' - }, - app: { - title: '我的应用' - // missing 'version' - } -} - -export default translation -` - - fs.writeFileSync(path.join(testEnDir, 'test.ts'), enContent) - fs.writeFileSync(path.join(testZhDir, 'test.ts'), zhContent) - - const enKeys = await getKeysFromLanguage('en-US') - const zhKeys = await getKeysFromLanguage('zh-Hans') - - const missingKeys = enKeys.filter(key => !zhKeys.includes(key)) - - expect(missingKeys).toContain('test.common.delete') - expect(missingKeys).toContain('test.app.version') - expect(missingKeys).toHaveLength(2) - }) - }) - - describe('Extra keys detection', () => { - it('should detect extra keys in target language', async () => { - const enContent = `const translation = { - common: { - save: 'Save', - cancel: 'Cancel' - } -} - -export default translation -` - - const zhContent = `const translation = { - common: { - save: '保存', - cancel: '取消', - delete: '删除', // extra key - extra: '额外的' // another extra key - }, - newSection: { - someKey: '某个值' // extra section - } -} - -export default translation -` - - fs.writeFileSync(path.join(testEnDir, 'test.ts'), enContent) - fs.writeFileSync(path.join(testZhDir, 'test.ts'), zhContent) - - const enKeys = await getKeysFromLanguage('en-US') - const zhKeys = await getKeysFromLanguage('zh-Hans') - - const extraKeys = zhKeys.filter(key => !enKeys.includes(key)) - - expect(extraKeys).toContain('test.common.delete') - expect(extraKeys).toContain('test.common.extra') - expect(extraKeys).toContain('test.newSection.someKey') - expect(extraKeys).toHaveLength(3) - }) - }) - - describe('File filtering logic', () => { - it('should filter keys by specific file correctly', async () => { - // Create multiple files - const file1Content = `const translation = { - button: 'Button', - text: 'Text' -} - -export default translation -` - - const file2Content = `const translation = { - title: 'Title', - description: 'Description' -} - -export default translation -` - - fs.writeFileSync(path.join(testEnDir, 'components.ts'), file1Content) - fs.writeFileSync(path.join(testEnDir, 'pages.ts'), file2Content) - fs.writeFileSync(path.join(testZhDir, 'components.ts'), file1Content) - fs.writeFileSync(path.join(testZhDir, 'pages.ts'), file2Content) - - const allEnKeys = await getKeysFromLanguage('en-US') - - // Test file filtering logic - const targetFile = 'components' - const filteredEnKeys = allEnKeys.filter(key => - key.startsWith(targetFile.replace(/[-_](.)/g, (_, c) => c.toUpperCase())), - ) - - expect(allEnKeys).toHaveLength(4) // 2 keys from each file - expect(filteredEnKeys).toHaveLength(2) // only components keys - expect(filteredEnKeys).toContain('components.button') - expect(filteredEnKeys).toContain('components.text') - expect(filteredEnKeys).not.toContain('pages.title') - expect(filteredEnKeys).not.toContain('pages.description') - }) - }) - - describe('Complex nested structure handling', () => { - it('should handle deeply nested objects correctly', async () => { - const complexContent = `const translation = { - level1: { - level2: { - level3: { - level4: { - deepValue: 'Deep Value' - }, - anotherValue: 'Another Value' - }, - simpleValue: 'Simple Value' - }, - directValue: 'Direct Value' - }, - rootValue: 'Root Value' -} - -export default translation -` - - fs.writeFileSync(path.join(testEnDir, 'complex.ts'), complexContent) - - const keys = await getKeysFromLanguage('en-US') - - expect(keys).toContain('complex.level1.level2.level3.level4.deepValue') - expect(keys).toContain('complex.level1.level2.level3.anotherValue') - expect(keys).toContain('complex.level1.level2.simpleValue') - expect(keys).toContain('complex.level1.directValue') - expect(keys).toContain('complex.rootValue') - - // Should not include intermediate objects - expect(keys).not.toContain('complex.level1') - expect(keys).not.toContain('complex.level1.level2') - expect(keys).not.toContain('complex.level1.level2.level3') - expect(keys).not.toContain('complex.level1.level2.level3.level4') - }) - }) - - describe('Edge cases', () => { - it('should handle empty objects', async () => { - const emptyContent = `const translation = { - empty: {}, - withValue: 'value' -} - -export default translation -` - - fs.writeFileSync(path.join(testEnDir, 'empty.ts'), emptyContent) - - const keys = await getKeysFromLanguage('en-US') - - expect(keys).toContain('empty.withValue') - expect(keys).not.toContain('empty.empty') - }) - - it('should handle special characters in keys', async () => { - const specialContent = `const translation = { - 'key-with-dash': 'value1', - 'key_with_underscore': 'value2', - 'key.with.dots': 'value3', - normalKey: 'value4' -} - -export default translation -` - - fs.writeFileSync(path.join(testEnDir, 'special.ts'), specialContent) - - const keys = await getKeysFromLanguage('en-US') - - expect(keys).toContain('special.key-with-dash') - expect(keys).toContain('special.key_with_underscore') - expect(keys).toContain('special.key.with.dots') - expect(keys).toContain('special.normalKey') - }) - - it('should handle different value types', async () => { - const typesContent = `const translation = { - stringValue: 'string', - numberValue: 42, - booleanValue: true, - nullValue: null, - undefinedValue: undefined, - arrayValue: ['array', 'values'], - objectValue: { - nested: 'nested value' - } -} - -export default translation -` - - fs.writeFileSync(path.join(testEnDir, 'types.ts'), typesContent) - - const keys = await getKeysFromLanguage('en-US') - - expect(keys).toContain('types.stringValue') - expect(keys).toContain('types.numberValue') - expect(keys).toContain('types.booleanValue') - expect(keys).toContain('types.nullValue') - expect(keys).toContain('types.undefinedValue') - expect(keys).toContain('types.arrayValue') - expect(keys).toContain('types.objectValue.nested') - expect(keys).not.toContain('types.objectValue') - }) - }) - - describe('Real-world scenario tests', () => { - it('should handle app-debug structure like real files', async () => { - const appDebugEn = `const translation = { - pageTitle: { - line1: 'Prompt', - line2: 'Engineering' - }, - operation: { - applyConfig: 'Publish', - resetConfig: 'Reset', - debugConfig: 'Debug' - }, - generate: { - instruction: 'Instructions', - generate: 'Generate', - resTitle: 'Generated Prompt', - noDataLine1: 'Describe your use case on the left,', - noDataLine2: 'the orchestration preview will show here.' - } -} - -export default translation -` - - const appDebugZh = `const translation = { - pageTitle: { - line1: '提示词', - line2: '编排' - }, - operation: { - applyConfig: '发布', - resetConfig: '重置', - debugConfig: '调试' - }, - generate: { - instruction: '指令', - generate: '生成', - resTitle: '生成的提示词', - noData: '在左侧描述您的用例,编排预览将在此处显示。' // This is extra - } -} - -export default translation -` - - fs.writeFileSync(path.join(testEnDir, 'app-debug.ts'), appDebugEn) - fs.writeFileSync(path.join(testZhDir, 'app-debug.ts'), appDebugZh) - - const enKeys = await getKeysFromLanguage('en-US') - const zhKeys = await getKeysFromLanguage('zh-Hans') - - const missingKeys = enKeys.filter(key => !zhKeys.includes(key)) - const extraKeys = zhKeys.filter(key => !enKeys.includes(key)) - - expect(missingKeys).toContain('appDebug.generate.noDataLine1') - expect(missingKeys).toContain('appDebug.generate.noDataLine2') - expect(extraKeys).toContain('appDebug.generate.noData') - - expect(missingKeys).toHaveLength(2) - expect(extraKeys).toHaveLength(1) - }) - - it('should handle time structure with operation nested keys', async () => { - const timeEn = `const translation = { - months: { - January: 'January', - February: 'February' - }, - operation: { - now: 'Now', - ok: 'OK', - cancel: 'Cancel', - pickDate: 'Pick Date' - }, - title: { - pickTime: 'Pick Time' - }, - defaultPlaceholder: 'Pick a time...' -} - -export default translation -` - - const timeZh = `const translation = { - months: { - January: '一月', - February: '二月' - }, - operation: { - now: '此刻', - ok: '确定', - cancel: '取消', - pickDate: '选择日期' - }, - title: { - pickTime: '选择时间' - }, - pickDate: '选择日期', // This is extra - duplicates operation.pickDate - defaultPlaceholder: '请选择时间...' -} - -export default translation -` - - fs.writeFileSync(path.join(testEnDir, 'time.ts'), timeEn) - fs.writeFileSync(path.join(testZhDir, 'time.ts'), timeZh) - - const enKeys = await getKeysFromLanguage('en-US') - const zhKeys = await getKeysFromLanguage('zh-Hans') - - const missingKeys = enKeys.filter(key => !zhKeys.includes(key)) - const extraKeys = zhKeys.filter(key => !enKeys.includes(key)) - - expect(missingKeys).toHaveLength(0) // No missing keys - expect(extraKeys).toContain('time.pickDate') // Extra root-level pickDate - expect(extraKeys).toHaveLength(1) - - // Should have both keys available - expect(zhKeys).toContain('time.operation.pickDate') // Correct nested key - expect(zhKeys).toContain('time.pickDate') // Extra duplicate key - }) - }) - - describe('Statistics calculation', () => { - it('should calculate correct difference statistics', async () => { - const enContent = `const translation = { - key1: 'value1', - key2: 'value2', - key3: 'value3' -} - -export default translation -` - - const zhContentMissing = `const translation = { - key1: 'value1', - key2: 'value2' - // missing key3 -} - -export default translation -` - - const zhContentExtra = `const translation = { - key1: 'value1', - key2: 'value2', - key3: 'value3', - key4: 'extra', - key5: 'extra2' -} - -export default translation -` - - fs.writeFileSync(path.join(testEnDir, 'stats.ts'), enContent) - - // Test missing keys scenario - fs.writeFileSync(path.join(testZhDir, 'stats.ts'), zhContentMissing) - - const enKeys = await getKeysFromLanguage('en-US') - const zhKeysMissing = await getKeysFromLanguage('zh-Hans') - - expect(enKeys.length - zhKeysMissing.length).toBe(1) // +1 means 1 missing key - - // Test extra keys scenario - fs.writeFileSync(path.join(testZhDir, 'stats.ts'), zhContentExtra) - - const zhKeysExtra = await getKeysFromLanguage('zh-Hans') - - expect(enKeys.length - zhKeysExtra.length).toBe(-2) // -2 means 2 extra keys - }) - }) - - describe('Auto-remove multiline key-value pairs', () => { - // Helper function to simulate removeExtraKeysFromFile logic - function removeExtraKeysFromFile(content: string, keysToRemove: string[]): string { - const lines = content.split('\n') - const linesToRemove: number[] = [] - - for (const keyToRemove of keysToRemove) { - let targetLineIndex = -1 - const linesToRemoveForKey: number[] = [] - - // Find the key line (simplified for single-level keys in test) - for (let i = 0; i < lines.length; i++) { - const line = lines[i] - const keyPattern = new RegExp(`^\\s*${keyToRemove}\\s*:`) - if (keyPattern.test(line)) { - targetLineIndex = i - break - } - } - - if (targetLineIndex !== -1) { - linesToRemoveForKey.push(targetLineIndex) - - // Check if this is a multiline key-value pair - const keyLine = lines[targetLineIndex] - const trimmedKeyLine = keyLine.trim() - - // If key line ends with ":" (not complete value), it's likely multiline - if (trimmedKeyLine.endsWith(':') && !trimmedKeyLine.includes('{') && !/:\s*['"`]/.exec(trimmedKeyLine)) { - // Find the value lines that belong to this key - let currentLine = targetLineIndex + 1 - let foundValue = false - - while (currentLine < lines.length) { - const line = lines[currentLine] - const trimmed = line.trim() - - // Skip empty lines - if (trimmed === '') { - currentLine++ - continue - } - - // Check if this line starts a new key (indicates end of current value) - if (/^\w+\s*:/.exec(trimmed)) - break - - // Check if this line is part of the value - if (trimmed.startsWith('\'') || trimmed.startsWith('"') || trimmed.startsWith('`') || foundValue) { - linesToRemoveForKey.push(currentLine) - foundValue = true - - // Check if this line ends the value (ends with quote and comma/no comma) - if ((trimmed.endsWith('\',') || trimmed.endsWith('",') || trimmed.endsWith('`,') - || trimmed.endsWith('\'') || trimmed.endsWith('"') || trimmed.endsWith('`')) - && !trimmed.startsWith('//')) { - break - } - } - else { - break - } - - currentLine++ - } - } - - linesToRemove.push(...linesToRemoveForKey) - } - } - - // Remove duplicates and sort in reverse order - const uniqueLinesToRemove = [...new Set(linesToRemove)].sort((a, b) => b - a) - - for (const lineIndex of uniqueLinesToRemove) - lines.splice(lineIndex, 1) - - return lines.join('\n') - } - - it('should remove single-line key-value pairs correctly', () => { - const content = `const translation = { - keepThis: 'This should stay', - removeThis: 'This should be removed', - alsoKeep: 'This should also stay', -} - -export default translation` - - const result = removeExtraKeysFromFile(content, ['removeThis']) - - expect(result).toContain('keepThis: \'This should stay\'') - expect(result).toContain('alsoKeep: \'This should also stay\'') - expect(result).not.toContain('removeThis: \'This should be removed\'') - }) - - it('should remove multiline key-value pairs completely', () => { - const content = `const translation = { - keepThis: 'This should stay', - removeMultiline: - 'This is a multiline value that should be removed completely', - alsoKeep: 'This should also stay', -} - -export default translation` - - const result = removeExtraKeysFromFile(content, ['removeMultiline']) - - expect(result).toContain('keepThis: \'This should stay\'') - expect(result).toContain('alsoKeep: \'This should also stay\'') - expect(result).not.toContain('removeMultiline:') - expect(result).not.toContain('This is a multiline value that should be removed completely') - }) - - it('should handle mixed single-line and multiline removals', () => { - const content = `const translation = { - keepThis: 'Keep this', - removeSingle: 'Remove this single line', - removeMultiline: - 'Remove this multiline value', - anotherMultiline: - 'Another multiline that spans multiple lines', - keepAnother: 'Keep this too', -} - -export default translation` - - const result = removeExtraKeysFromFile(content, ['removeSingle', 'removeMultiline', 'anotherMultiline']) - - expect(result).toContain('keepThis: \'Keep this\'') - expect(result).toContain('keepAnother: \'Keep this too\'') - expect(result).not.toContain('removeSingle:') - expect(result).not.toContain('removeMultiline:') - expect(result).not.toContain('anotherMultiline:') - expect(result).not.toContain('Remove this single line') - expect(result).not.toContain('Remove this multiline value') - expect(result).not.toContain('Another multiline that spans multiple lines') - }) - - it('should properly detect multiline vs single-line patterns', () => { - const multilineContent = `const translation = { - singleLine: 'This is single line', - multilineKey: - 'This is multiline', - keyWithColon: 'Value with: colon inside', - objectKey: { - nested: 'value' - }, -} - -export default translation` - - // Test that single line with colon in value is not treated as multiline - const result1 = removeExtraKeysFromFile(multilineContent, ['keyWithColon']) - expect(result1).not.toContain('keyWithColon:') - expect(result1).not.toContain('Value with: colon inside') - - // Test that true multiline is handled correctly - const result2 = removeExtraKeysFromFile(multilineContent, ['multilineKey']) - expect(result2).not.toContain('multilineKey:') - expect(result2).not.toContain('This is multiline') - - // Test that object key removal works (note: this is a simplified test) - // In real scenario, object removal would be more complex - const result3 = removeExtraKeysFromFile(multilineContent, ['objectKey']) - expect(result3).not.toContain('objectKey: {') - // Note: Our simplified test function doesn't handle nested object removal perfectly - // This is acceptable as it's testing the main multiline string removal functionality - }) - - it('should handle real-world Polish translation structure', () => { - const polishContent = `const translation = { - createApp: 'UTWÓRZ APLIKACJĘ', - newApp: { - captionAppType: 'Jaki typ aplikacji chcesz stworzyć?', - chatbotDescription: - 'Zbuduj aplikację opartą na czacie. Ta aplikacja używa formatu pytań i odpowiedzi.', - agentDescription: - 'Zbuduj inteligentnego agenta, który może autonomicznie wybierać narzędzia.', - basic: 'Podstawowy', - }, -} - -export default translation` - - const result = removeExtraKeysFromFile(polishContent, ['captionAppType', 'chatbotDescription', 'agentDescription']) - - expect(result).toContain('createApp: \'UTWÓRZ APLIKACJĘ\'') - expect(result).toContain('basic: \'Podstawowy\'') - expect(result).not.toContain('captionAppType:') - expect(result).not.toContain('chatbotDescription:') - expect(result).not.toContain('agentDescription:') - expect(result).not.toContain('Jaki typ aplikacji') - expect(result).not.toContain('Zbuduj aplikację opartą na czacie') - expect(result).not.toContain('Zbuduj inteligentnego agenta') - }) - }) - - describe('Performance and Scalability', () => { - it('should handle large translation files efficiently', async () => { - // Create a large translation file with 1000 keys - const largeContent = `const translation = { -${Array.from({ length: 1000 }, (_, i) => ` key${i}: 'value${i}',`).join('\n')} -} - -export default translation` - - fs.writeFileSync(path.join(testEnDir, 'large.ts'), largeContent) - - const startTime = Date.now() - const keys = await getKeysFromLanguage('en-US') - const endTime = Date.now() - - expect(keys.length).toBe(1000) - expect(endTime - startTime).toBeLessThan(1000) // Should complete in under 1 second - }) - - it('should handle multiple translation files concurrently', async () => { - // Create multiple files - for (let i = 0; i < 10; i++) { - const content = `const translation = { - key${i}: 'value${i}', - nested${i}: { - subkey: 'subvalue' - } -} - -export default translation` - fs.writeFileSync(path.join(testEnDir, `file${i}.ts`), content) - } - - const startTime = Date.now() - const keys = await getKeysFromLanguage('en-US') - const endTime = Date.now() - - expect(keys.length).toBe(20) // 10 files * 2 keys each - expect(endTime - startTime).toBeLessThan(500) - }) - }) - - describe('Unicode and Internationalization', () => { - it('should handle Unicode characters in keys and values', async () => { - const unicodeContent = `const translation = { - '中文键': '中文值', - 'العربية': 'قيمة', - 'emoji_😀': 'value with emoji 🎉', - 'mixed_中文_English': 'mixed value' -} - -export default translation` - - fs.writeFileSync(path.join(testEnDir, 'unicode.ts'), unicodeContent) - - const keys = await getKeysFromLanguage('en-US') - - expect(keys).toContain('unicode.中文键') - expect(keys).toContain('unicode.العربية') - expect(keys).toContain('unicode.emoji_😀') - expect(keys).toContain('unicode.mixed_中文_English') - }) - - it('should handle RTL language files', async () => { - const rtlContent = `const translation = { - مرحبا: 'Hello', - العالم: 'World', - nested: { - مفتاح: 'key' - } -} - -export default translation` - - fs.writeFileSync(path.join(testEnDir, 'rtl.ts'), rtlContent) - - const keys = await getKeysFromLanguage('en-US') - - expect(keys).toContain('rtl.مرحبا') - expect(keys).toContain('rtl.العالم') - expect(keys).toContain('rtl.nested.مفتاح') - }) - }) - - describe('Error Recovery', () => { - it('should handle syntax errors in translation files gracefully', async () => { - const invalidContent = `const translation = { - validKey: 'valid value', - invalidKey: 'missing quote, - anotherKey: 'another value' -} - -export default translation` - - fs.writeFileSync(path.join(testEnDir, 'invalid.ts'), invalidContent) - - await expect(getKeysFromLanguage('en-US')).rejects.toThrow() - }) - }) -}) diff --git a/web/__tests__/datasets/dataset-settings-flow.test.tsx b/web/__tests__/datasets/dataset-settings-flow.test.tsx index 607cd8c2d5a..b4a5e783264 100644 --- a/web/__tests__/datasets/dataset-settings-flow.test.tsx +++ b/web/__tests__/datasets/dataset-settings-flow.test.tsx @@ -19,6 +19,10 @@ import { RETRIEVE_METHOD } from '@/types/app' // --- Mocks --- +const { mockToastError } = vi.hoisted(() => ({ + mockToastError: vi.fn(), +})) + const mockMutateDatasets = vi.fn() const mockInvalidDatasetList = vi.fn() const mockUpdateDatasetSetting = vi.fn().mockResolvedValue({}) @@ -55,8 +59,11 @@ vi.mock('@/app/components/datasets/common/check-rerank-model', () => ({ isReRankModelSelected: () => true, })) -vi.mock('@/app/components/base/toast', () => ({ - default: { notify: vi.fn() }, +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + error: mockToastError, + success: vi.fn(), + }, })) // --- Dataset factory --- @@ -311,7 +318,7 @@ describe('Dataset Settings Flow - Cross-Module Configuration Cascade', () => { describe('Form Submission Validation → All Fields Together', () => { it('should reject empty name on save', async () => { - const Toast = await import('@/app/components/base/toast') + const { toast } = await import('@/app/components/base/ui/toast') const { result } = renderHook(() => useFormState()) act(() => { @@ -322,10 +329,7 @@ describe('Dataset Settings Flow - Cross-Module Configuration Cascade', () => { await result.current.handleSave() }) - expect(Toast.default.notify).toHaveBeenCalledWith({ - type: 'error', - message: expect.any(String), - }) + expect(toast.error).toHaveBeenCalledWith(expect.any(String)) expect(mockUpdateDatasetSetting).not.toHaveBeenCalled() }) diff --git a/web/__tests__/explore/sidebar-lifecycle-flow.test.tsx b/web/__tests__/explore/sidebar-lifecycle-flow.test.tsx index f3d3128ccbf..64dd5321ac5 100644 --- a/web/__tests__/explore/sidebar-lifecycle-flow.test.tsx +++ b/web/__tests__/explore/sidebar-lifecycle-flow.test.tsx @@ -11,8 +11,8 @@ import SideBar from '@/app/components/explore/sidebar' import { MediaType } from '@/hooks/use-breakpoints' import { AppModeEnum } from '@/types/app' -const { mockToastAdd } = vi.hoisted(() => ({ - mockToastAdd: vi.fn(), +const { mockToastSuccess } = vi.hoisted(() => ({ + mockToastSuccess: vi.fn(), })) let mockMediaType: string = MediaType.pc @@ -53,14 +53,16 @@ vi.mock('@/service/use-explore', () => ({ }), })) -vi.mock('@/app/components/base/ui/toast', () => ({ - toast: { - add: mockToastAdd, - close: vi.fn(), - update: vi.fn(), - promise: vi.fn(), - }, -})) +vi.mock('@/app/components/base/ui/toast', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + toast: { + ...actual.toast, + success: mockToastSuccess, + }, + } +}) const createInstalledApp = (overrides: Partial = {}): InstalledApp => ({ id: overrides.id ?? 'app-1', @@ -105,9 +107,7 @@ describe('Sidebar Lifecycle Flow', () => { await waitFor(() => { expect(mockUpdatePinStatus).toHaveBeenCalledWith({ appId: 'app-1', isPinned: true }) - expect(mockToastAdd).toHaveBeenCalledWith(expect.objectContaining({ - type: 'success', - })) + expect(mockToastSuccess).toHaveBeenCalled() }) // Step 2: Simulate refetch returning pinned state, then unpin @@ -124,9 +124,7 @@ describe('Sidebar Lifecycle Flow', () => { await waitFor(() => { expect(mockUpdatePinStatus).toHaveBeenCalledWith({ appId: 'app-1', isPinned: false }) - expect(mockToastAdd).toHaveBeenCalledWith(expect.objectContaining({ - type: 'success', - })) + expect(mockToastSuccess).toHaveBeenCalled() }) }) @@ -150,10 +148,7 @@ describe('Sidebar Lifecycle Flow', () => { // Step 4: Uninstall API called and success toast shown await waitFor(() => { expect(mockUninstall).toHaveBeenCalledWith('app-1') - expect(mockToastAdd).toHaveBeenCalledWith(expect.objectContaining({ - type: 'success', - title: 'common.api.remove', - })) + expect(mockToastSuccess).toHaveBeenCalledWith('common.api.remove') }) }) diff --git a/web/app/components/__tests__/browser-initializer.spec.ts b/web/__tests__/instrumentation-client.spec.ts similarity index 100% rename from web/app/components/__tests__/browser-initializer.spec.ts rename to web/__tests__/instrumentation-client.spec.ts diff --git a/web/__tests__/plugins/plugin-install-flow.test.ts b/web/__tests__/plugins/plugin-install-flow.test.ts index 8edb6705d46..dd5a18b7245 100644 --- a/web/__tests__/plugins/plugin-install-flow.test.ts +++ b/web/__tests__/plugins/plugin-install-flow.test.ts @@ -5,15 +5,21 @@ * upload handling, and task status polling. Verifies the complete plugin * installation pipeline from source discovery to completion. */ -import { beforeEach, describe, expect, it, vi } from 'vitest' -vi.mock('@/config', () => ({ - GITHUB_ACCESS_TOKEN: '', -})) +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { checkForUpdates, fetchReleases, handleUpload } from '@/app/components/plugins/install-plugin/hooks' const mockToastNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ - default: { notify: (...args: unknown[]) => mockToastNotify(...args) }, +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign((message: string, options?: { type?: string }) => mockToastNotify({ type: options?.type, message }), { + success: (message: string) => mockToastNotify({ type: 'success', message }), + error: (message: string) => mockToastNotify({ type: 'error', message }), + warning: (message: string) => mockToastNotify({ type: 'warning', message }), + info: (message: string) => mockToastNotify({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }), })) const mockUploadGitHub = vi.fn() @@ -22,10 +28,6 @@ vi.mock('@/service/plugins', () => ({ checkTaskStatus: vi.fn(), })) -const { useGitHubReleases, useGitHubUpload } = await import( - '@/app/components/plugins/install-plugin/hooks', -) - describe('Plugin Installation Flow Integration', () => { beforeEach(() => { vi.clearAllMocks() @@ -36,22 +38,22 @@ describe('Plugin Installation Flow Integration', () => { it('fetches releases, checks for updates, and uploads the new version', async () => { const mockReleases = [ { - tag_name: 'v2.0.0', - assets: [{ browser_download_url: 'https://github.com/test/v2.difypkg', name: 'plugin-v2.difypkg' }], + tag: 'v2.0.0', + assets: [{ downloadUrl: 'https://github.com/test/v2.difypkg' }], }, { - tag_name: 'v1.5.0', - assets: [{ browser_download_url: 'https://github.com/test/v1.5.difypkg', name: 'plugin-v1.5.difypkg' }], + tag: 'v1.5.0', + assets: [{ downloadUrl: 'https://github.com/test/v1.5.difypkg' }], }, { - tag_name: 'v1.0.0', - assets: [{ browser_download_url: 'https://github.com/test/v1.difypkg', name: 'plugin-v1.difypkg' }], + tag: 'v1.0.0', + assets: [{ downloadUrl: 'https://github.com/test/v1.difypkg' }], }, ] ;(globalThis.fetch as ReturnType).mockResolvedValue({ ok: true, - json: () => Promise.resolve(mockReleases), + json: () => Promise.resolve({ releases: mockReleases }), }) mockUploadGitHub.mockResolvedValue({ @@ -59,8 +61,6 @@ describe('Plugin Installation Flow Integration', () => { unique_identifier: 'test-plugin:2.0.0', }) - const { fetchReleases, checkForUpdates } = useGitHubReleases() - const releases = await fetchReleases('test-org', 'test-repo') expect(releases).toHaveLength(3) expect(releases[0].tag_name).toBe('v2.0.0') @@ -69,7 +69,6 @@ describe('Plugin Installation Flow Integration', () => { expect(needUpdate).toBe(true) expect(toastProps.message).toContain('v2.0.0') - const { handleUpload } = useGitHubUpload() const onSuccess = vi.fn() const result = await handleUpload( 'https://github.com/test-org/test-repo', @@ -96,18 +95,16 @@ describe('Plugin Installation Flow Integration', () => { it('handles no new version available', async () => { const mockReleases = [ { - tag_name: 'v1.0.0', - assets: [{ browser_download_url: 'https://github.com/test/v1.difypkg', name: 'plugin-v1.difypkg' }], + tag: 'v1.0.0', + assets: [{ downloadUrl: 'https://github.com/test/v1.difypkg' }], }, ] ;(globalThis.fetch as ReturnType).mockResolvedValue({ ok: true, - json: () => Promise.resolve(mockReleases), + json: () => Promise.resolve({ releases: mockReleases }), }) - const { fetchReleases, checkForUpdates } = useGitHubReleases() - const releases = await fetchReleases('test-org', 'test-repo') const { needUpdate, toastProps } = checkForUpdates(releases, 'v1.0.0') @@ -119,11 +116,9 @@ describe('Plugin Installation Flow Integration', () => { it('handles empty releases', async () => { ;(globalThis.fetch as ReturnType).mockResolvedValue({ ok: true, - json: () => Promise.resolve([]), + json: () => Promise.resolve({ releases: [] }), }) - const { fetchReleases, checkForUpdates } = useGitHubReleases() - const releases = await fetchReleases('test-org', 'test-repo') expect(releases).toHaveLength(0) @@ -139,7 +134,6 @@ describe('Plugin Installation Flow Integration', () => { status: 404, }) - const { fetchReleases } = useGitHubReleases() const releases = await fetchReleases('nonexistent-org', 'nonexistent-repo') expect(releases).toEqual([]) @@ -151,7 +145,6 @@ describe('Plugin Installation Flow Integration', () => { it('handles upload failure gracefully', async () => { mockUploadGitHub.mockRejectedValue(new Error('Upload failed')) - const { handleUpload } = useGitHubUpload() const onSuccess = vi.fn() await expect( diff --git a/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx b/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx index 6a4e71f574b..1d1c6518fe1 100644 --- a/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx +++ b/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx @@ -24,17 +24,11 @@ export default function CheckCode() { const verify = async () => { try { if (!code.trim()) { - toast.add({ - type: 'error', - title: t('checkCode.emptyCode', { ns: 'login' }), - }) + toast.error(t('checkCode.emptyCode', { ns: 'login' })) return } if (!/\d{6}/.test(code)) { - toast.add({ - type: 'error', - title: t('checkCode.invalidCode', { ns: 'login' }), - }) + toast.error(t('checkCode.invalidCode', { ns: 'login' })) return } setIsLoading(true) diff --git a/web/app/(shareLayout)/webapp-reset-password/page.tsx b/web/app/(shareLayout)/webapp-reset-password/page.tsx index 08a42478aa5..0cdfb4ec11c 100644 --- a/web/app/(shareLayout)/webapp-reset-password/page.tsx +++ b/web/app/(shareLayout)/webapp-reset-password/page.tsx @@ -27,15 +27,12 @@ export default function CheckCode() { const handleGetEMailVerificationCode = async () => { try { if (!email) { - toast.add({ type: 'error', title: t('error.emailEmpty', { ns: 'login' }) }) + toast.error(t('error.emailEmpty', { ns: 'login' })) return } if (!emailRegex.test(email)) { - toast.add({ - type: 'error', - title: t('error.emailInValid', { ns: 'login' }), - }) + toast.error(t('error.emailInValid', { ns: 'login' })) return } setIsLoading(true) @@ -48,16 +45,10 @@ export default function CheckCode() { router.push(`/webapp-reset-password/check-code?${params.toString()}`) } else if (res.code === 'account_not_found') { - toast.add({ - type: 'error', - title: t('error.registrationNotAllowed', { ns: 'login' }), - }) + toast.error(t('error.registrationNotAllowed', { ns: 'login' })) } else { - toast.add({ - type: 'error', - title: res.data, - }) + toast.error(res.data) } } catch (error) { diff --git a/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx b/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx index 22d2d22879a..bc8f651d170 100644 --- a/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx +++ b/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx @@ -24,10 +24,7 @@ const ChangePasswordForm = () => { const [showConfirmPassword, setShowConfirmPassword] = useState(false) const showErrorMessage = useCallback((message: string) => { - toast.add({ - type: 'error', - title: message, - }) + toast.error(message) }, []) const getSignInUrl = () => { diff --git a/web/app/(shareLayout)/webapp-signin/check-code/page.tsx b/web/app/(shareLayout)/webapp-signin/check-code/page.tsx index 603369a858f..f209ad9e5c8 100644 --- a/web/app/(shareLayout)/webapp-signin/check-code/page.tsx +++ b/web/app/(shareLayout)/webapp-signin/check-code/page.tsx @@ -43,24 +43,15 @@ export default function CheckCode() { try { const appCode = getAppCodeFromRedirectUrl() if (!code.trim()) { - toast.add({ - type: 'error', - title: t('checkCode.emptyCode', { ns: 'login' }), - }) + toast.error(t('checkCode.emptyCode', { ns: 'login' })) return } if (!/\d{6}/.test(code)) { - toast.add({ - type: 'error', - title: t('checkCode.invalidCode', { ns: 'login' }), - }) + toast.error(t('checkCode.invalidCode', { ns: 'login' })) return } if (!redirectUrl || !appCode) { - toast.add({ - type: 'error', - title: t('error.redirectUrlMissing', { ns: 'login' }), - }) + toast.error(t('error.redirectUrlMissing', { ns: 'login' })) return } setIsLoading(true) diff --git a/web/app/(shareLayout)/webapp-signin/components/external-member-sso-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/external-member-sso-auth.tsx index b7fb7036e87..9b4a369908a 100644 --- a/web/app/(shareLayout)/webapp-signin/components/external-member-sso-auth.tsx +++ b/web/app/(shareLayout)/webapp-signin/components/external-member-sso-auth.tsx @@ -17,10 +17,7 @@ const ExternalMemberSSOAuth = () => { const redirectUrl = searchParams.get('redirect_url') const showErrorToast = (message: string) => { - toast.add({ - type: 'error', - title: message, - }) + toast.error(message) } const getAppCodeFromRedirectUrl = useCallback(() => { diff --git a/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx index 7a20713e05d..fbd6b216df0 100644 --- a/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx +++ b/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx @@ -22,15 +22,12 @@ export default function MailAndCodeAuth() { const handleGetEMailVerificationCode = async () => { try { if (!email) { - toast.add({ type: 'error', title: t('error.emailEmpty', { ns: 'login' }) }) + toast.error(t('error.emailEmpty', { ns: 'login' })) return } if (!emailRegex.test(email)) { - toast.add({ - type: 'error', - title: t('error.emailInValid', { ns: 'login' }), - }) + toast.error(t('error.emailInValid', { ns: 'login' })) return } setIsLoading(true) diff --git a/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx index bbc4cc8efd5..1e9355e7baa 100644 --- a/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx +++ b/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx @@ -46,26 +46,20 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut const appCode = getAppCodeFromRedirectUrl() const handleEmailPasswordLogin = async () => { if (!email) { - toast.add({ type: 'error', title: t('error.emailEmpty', { ns: 'login' }) }) + toast.error(t('error.emailEmpty', { ns: 'login' })) return } if (!emailRegex.test(email)) { - toast.add({ - type: 'error', - title: t('error.emailInValid', { ns: 'login' }), - }) + toast.error(t('error.emailInValid', { ns: 'login' })) return } if (!password?.trim()) { - toast.add({ type: 'error', title: t('error.passwordEmpty', { ns: 'login' }) }) + toast.error(t('error.passwordEmpty', { ns: 'login' })) return } if (!redirectUrl || !appCode) { - toast.add({ - type: 'error', - title: t('error.redirectUrlMissing', { ns: 'login' }), - }) + toast.error(t('error.redirectUrlMissing', { ns: 'login' })) return } try { @@ -94,15 +88,12 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut router.replace(decodeURIComponent(redirectUrl)) } else { - toast.add({ - type: 'error', - title: res.data, - }) + toast.error(res.data) } } catch (e: any) { if (e.code === 'authentication_failed') - toast.add({ type: 'error', title: e.message }) + toast.error(e.message) } finally { setIsLoading(false) diff --git a/web/app/(shareLayout)/webapp-signin/components/sso-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/sso-auth.tsx index fd12c2060f3..3178c638cc4 100644 --- a/web/app/(shareLayout)/webapp-signin/components/sso-auth.tsx +++ b/web/app/(shareLayout)/webapp-signin/components/sso-auth.tsx @@ -37,10 +37,7 @@ const SSOAuth: FC = ({ const handleSSOLogin = () => { const appCode = getAppCodeFromRedirectUrl() if (!redirectUrl || !appCode) { - toast.add({ - type: 'error', - title: t('error.invalidRedirectUrlOrAppCode', { ns: 'login' }), - }) + toast.error(t('error.invalidRedirectUrlOrAppCode', { ns: 'login' })) return } setIsLoading(true) @@ -66,10 +63,7 @@ const SSOAuth: FC = ({ }) } else { - toast.add({ - type: 'error', - title: t('error.invalidSSOProtocol', { ns: 'login' }), - }) + toast.error(t('error.invalidSSOProtocol', { ns: 'login' })) setIsLoading(false) } } diff --git a/web/app/account/oauth/authorize/page.tsx b/web/app/account/oauth/authorize/page.tsx index 219a26ddeb8..f4984fb85a3 100644 --- a/web/app/account/oauth/authorize/page.tsx +++ b/web/app/account/oauth/authorize/page.tsx @@ -93,10 +93,7 @@ export default function OAuthAuthorize() { globalThis.location.href = url.toString() } catch (err: any) { - toast.add({ - type: 'error', - title: `${t('error.authorizeFailed', { ns: 'oauth' })}: ${err.message}`, - }) + toast.error(`${t('error.authorizeFailed', { ns: 'oauth' })}: ${err.message}`) } } @@ -104,11 +101,10 @@ export default function OAuthAuthorize() { const invalidParams = !client_id || !redirect_uri if ((invalidParams || isError) && !hasNotifiedRef.current) { hasNotifiedRef.current = true - toast.add({ - type: 'error', - title: invalidParams ? t('error.invalidParams', { ns: 'oauth' }) : t('error.authAppInfoFetchFailed', { ns: 'oauth' }), - timeout: 0, - }) + toast.error( + invalidParams ? t('error.invalidParams', { ns: 'oauth' }) : t('error.authAppInfoFetchFailed', { ns: 'oauth' }), + { timeout: 0 }, + ) } }, [client_id, redirect_uri, isError]) diff --git a/web/app/components/app/app-access-control/access-control.spec.tsx b/web/app/components/app/app-access-control/access-control.spec.tsx index 3950bdf7ee7..3a5f2272edd 100644 --- a/web/app/components/app/app-access-control/access-control.spec.tsx +++ b/web/app/components/app/app-access-control/access-control.spec.tsx @@ -109,7 +109,7 @@ beforeAll(() => { disconnect = vi.fn(() => undefined) unobserve = vi.fn(() => undefined) } - // @ts-expect-error jsdom does not implement IntersectionObserver + // @ts-expect-error test DOM typings do not guarantee IntersectionObserver here globalThis.IntersectionObserver = MockIntersectionObserver }) diff --git a/web/app/components/app/configuration/config/agent/agent-setting/index.spec.tsx b/web/app/components/app/configuration/config/agent/agent-setting/index.spec.tsx index b3a9bd7abc0..1b8d64b911f 100644 --- a/web/app/components/app/configuration/config/agent/agent-setting/index.spec.tsx +++ b/web/app/components/app/configuration/config/agent/agent-setting/index.spec.tsx @@ -12,15 +12,15 @@ vi.mock('ahooks', async (importOriginal) => { } }) -vi.mock('react-slider', () => ({ - default: (props: { className?: string, min?: number, max?: number, value: number, onChange: (value: number) => void }) => ( +vi.mock('@/app/components/base/ui/slider', () => ({ + Slider: (props: { className?: string, min?: number, max?: number, value: number, onValueChange: (value: number) => void }) => ( props.onChange(Number(e.target.value))} + onChange={e => props.onValueChange(Number(e.target.value))} /> ), })) diff --git a/web/app/components/app/configuration/config/agent/agent-setting/index.tsx b/web/app/components/app/configuration/config/agent/agent-setting/index.tsx index ec42e946dd0..bce4e74aabe 100644 --- a/web/app/components/app/configuration/config/agent/agent-setting/index.tsx +++ b/web/app/components/app/configuration/config/agent/agent-setting/index.tsx @@ -9,7 +9,7 @@ import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import { CuteRobot } from '@/app/components/base/icons/src/vender/solid/communication' import { Unblur } from '@/app/components/base/icons/src/vender/solid/education' -import Slider from '@/app/components/base/slider' +import { Slider } from '@/app/components/base/ui/slider' import { DEFAULT_AGENT_PROMPT, MAX_ITERATIONS_NUM } from '@/config' import ItemPanel from './item-panel' @@ -105,12 +105,13 @@ const AgentSetting: FC = ({ min={maxIterationsMin} max={MAX_ITERATIONS_NUM} value={tempPayload.max_iteration} - onChange={(value) => { + onValueChange={(value) => { setTempPayload({ ...tempPayload, max_iteration: value, }) }} + aria-label={t('agent.setting.maximumIterations.name', { ns: 'appDebug' })} /> FeatureStoreState>(() => mockFeatureStoreState), } diff --git a/web/app/components/app/configuration/config/config-document.spec.tsx b/web/app/components/app/configuration/config/config-document.spec.tsx index 2aa87717fce..300acb7ce7d 100644 --- a/web/app/components/app/configuration/config/config-document.spec.tsx +++ b/web/app/components/app/configuration/config/config-document.spec.tsx @@ -1,4 +1,3 @@ -import type { Mock } from 'vitest' import type { FeatureStoreState } from '@/app/components/base/features/store' import { render, screen } from '@testing-library/react' import userEvent from '@testing-library/user-event' @@ -28,7 +27,7 @@ type SetupOptions = { } let mockFeatureStoreState: FeatureStoreState -let mockSetFeatures: Mock +let mockSetFeatures = vi.fn() const mockStore = { getState: vi.fn<() => FeatureStoreState>(() => mockFeatureStoreState), } diff --git a/web/app/components/app/configuration/config/index.spec.tsx b/web/app/components/app/configuration/config/index.spec.tsx index 875e5833975..b24c719b994 100644 --- a/web/app/components/app/configuration/config/index.spec.tsx +++ b/web/app/components/app/configuration/config/index.spec.tsx @@ -1,4 +1,3 @@ -import type { Mock } from 'vitest' import type { ModelConfig, PromptVariable } from '@/models/debug' import type { ToolItem } from '@/types/app' import { render, screen } from '@testing-library/react' @@ -74,10 +73,10 @@ type MockContext = { history: boolean query: boolean } - showHistoryModal: Mock + showHistoryModal: () => void modelConfig: ModelConfig - setModelConfig: Mock - setPrevPromptConfig: Mock + setModelConfig: (modelConfig: ModelConfig) => void + setPrevPromptConfig: (configs: ModelConfig['configs']) => void } const createPromptVariable = (overrides: Partial = {}): PromptVariable => ({ @@ -142,7 +141,7 @@ const createContextValue = (overrides: Partial = {}): MockContext = ...overrides, }) -const mockUseContext = useContextSelector.useContext as Mock +const mockUseContext = vi.mocked(useContextSelector.useContext) const renderConfig = (contextOverrides: Partial = {}) => { const contextValue = createContextValue(contextOverrides) diff --git a/web/app/components/app/configuration/dataset-config/params-config/config-content.spec.tsx b/web/app/components/app/configuration/dataset-config/params-config/config-content.spec.tsx index 4d8b10b22a7..2cd8418c656 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/config-content.spec.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/config-content.spec.tsx @@ -288,10 +288,8 @@ describe('ConfigContent', () => { />, ) - const weightedScoreSlider = screen.getAllByRole('slider') - .find(slider => slider.getAttribute('aria-valuemax') === '1') - expect(weightedScoreSlider).toBeDefined() - await user.click(weightedScoreSlider!) + const weightedScoreSlider = screen.getByLabelText('dataset.weightedScore.semantic') + weightedScoreSlider.focus() const callsBefore = onChange.mock.calls.length await user.keyboard('{ArrowRight}') diff --git a/web/app/components/app/configuration/dataset-config/params-config/index.spec.tsx b/web/app/components/app/configuration/dataset-config/params-config/index.spec.tsx index 93660394140..024432112d3 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/index.spec.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/index.spec.tsx @@ -3,7 +3,7 @@ import type { DatasetConfigs } from '@/models/debug' import { render, screen, waitFor, within } from '@testing-library/react' import userEvent from '@testing-library/user-event' import * as React from 'react' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel, @@ -75,7 +75,7 @@ vi.mock('@/app/components/header/account-setting/model-provider-page/model-param const mockedUseModelListAndDefaultModelAndCurrentProviderAndModel = useModelListAndDefaultModelAndCurrentProviderAndModel as MockedFunction const mockedUseCurrentProviderAndModel = useCurrentProviderAndModel as MockedFunction -let toastNotifySpy: MockInstance +let toastErrorSpy: MockInstance const createDatasetConfigs = (overrides: Partial = {}): DatasetConfigs => { return { @@ -140,7 +140,7 @@ describe('dataset-config/params-config', () => { beforeEach(() => { vi.clearAllMocks() vi.useRealTimers() - toastNotifySpy = vi.spyOn(Toast, 'notify').mockImplementation(() => ({})) + toastErrorSpy = vi.spyOn(toast, 'error').mockImplementation(() => '') mockedUseModelListAndDefaultModelAndCurrentProviderAndModel.mockReturnValue({ modelList: [], defaultModel: undefined, @@ -154,7 +154,7 @@ describe('dataset-config/params-config', () => { }) afterEach(() => { - toastNotifySpy.mockRestore() + toastErrorSpy.mockRestore() }) // Rendering tests (REQUIRED) @@ -254,10 +254,7 @@ describe('dataset-config/params-config', () => { await user.click(dialogScope.getByRole('button', { name: 'common.operation.save' })) // Assert - expect(toastNotifySpy).toHaveBeenCalledWith({ - type: 'error', - message: 'appDebug.datasetConfig.rerankModelRequired', - }) + expect(toastErrorSpy).toHaveBeenCalledWith('appDebug.datasetConfig.rerankModelRequired') expect(screen.getByRole('dialog')).toBeInTheDocument() }) }) diff --git a/web/app/components/app/configuration/dataset-config/params-config/index.tsx b/web/app/components/app/configuration/dataset-config/params-config/index.tsx index 692ae120227..89410203df6 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/index.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/index.tsx @@ -7,7 +7,7 @@ import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' import Modal from '@/app/components/base/modal' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { @@ -66,10 +66,7 @@ const ParamsConfig = ({ } } if (errMsg) { - Toast.notify({ - type: 'error', - message: errMsg, - }) + toast.error(errMsg) } return !errMsg } diff --git a/web/app/components/app/configuration/dataset-config/params-config/weighted-score.css b/web/app/components/app/configuration/dataset-config/params-config/weighted-score.css deleted file mode 100644 index ef9350645a9..00000000000 --- a/web/app/components/app/configuration/dataset-config/params-config/weighted-score.css +++ /dev/null @@ -1,7 +0,0 @@ -.weightedScoreSliderTrack { - background: var(--color-util-colors-blue-light-blue-light-500) !important; -} - -.weightedScoreSliderTrack-1 { - background: transparent !important; -} diff --git a/web/app/components/app/configuration/dataset-config/params-config/weighted-score.spec.tsx b/web/app/components/app/configuration/dataset-config/params-config/weighted-score.spec.tsx index 77298303489..8e9348c77ab 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/weighted-score.spec.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/weighted-score.spec.tsx @@ -3,6 +3,8 @@ import userEvent from '@testing-library/user-event' import WeightedScore from './weighted-score' describe('WeightedScore', () => { + const getSliderInput = () => screen.getByLabelText('dataset.weightedScore.semantic') + beforeEach(() => { vi.clearAllMocks() }) @@ -48,8 +50,8 @@ describe('WeightedScore', () => { render() // Act - await user.tab() - const slider = screen.getByRole('slider') + const slider = getSliderInput() + slider.focus() expect(slider).toHaveFocus() const callsBefore = onChange.mock.calls.length await user.keyboard('{ArrowRight}') @@ -69,9 +71,8 @@ describe('WeightedScore', () => { render() // Act - await user.tab() - const slider = screen.getByRole('slider') - expect(slider).toHaveFocus() + const slider = getSliderInput() + expect(slider).toBeDisabled() await user.keyboard('{ArrowRight}') // Assert diff --git a/web/app/components/app/configuration/dataset-config/params-config/weighted-score.tsx b/web/app/components/app/configuration/dataset-config/params-config/weighted-score.tsx index 40beef52e82..d4ce935a4df 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/weighted-score.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/weighted-score.tsx @@ -1,9 +1,13 @@ +import type { CSSProperties } from 'react' import { noop } from 'es-toolkit/function' import { memo } from 'react' import { useTranslation } from 'react-i18next' -import Slider from '@/app/components/base/slider' -import { cn } from '@/utils/classnames' -import './weighted-score.css' +import { Slider } from '@/app/components/base/ui/slider' + +const weightedScoreSliderStyle: CSSProperties & Record<'--slider-track' | '--slider-range', string> = { + '--slider-track': 'var(--color-util-colors-teal-teal-500)', + '--slider-range': 'var(--color-util-colors-blue-light-blue-light-500)', +} const formatNumber = (value: number) => { if (value > 0 && value < 1) @@ -33,24 +37,26 @@ const WeightedScore = ({ return (
- !readonly && onChange({ value: [v, (10 - v * 10) / 10] })} - trackClassName="weightedScoreSliderTrack" - disabled={readonly} - /> +
+ !readonly && onChange({ value: [v, (10 - v * 10) / 10] })} + disabled={readonly} + aria-label={t('weightedScore.semantic', { ns: 'dataset' })} + /> +
-
+
{t('weightedScore.semantic', { ns: 'dataset' })}
{formatNumber(value.value[0])}
-
+
{formatNumber(value.value[1])}
{t('weightedScore.keyword', { ns: 'dataset' })} diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/index.spec.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/index.spec.tsx index 188086246a5..389ab189e9f 100644 --- a/web/app/components/app/configuration/debug/debug-with-multiple-model/index.spec.tsx +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/index.spec.tsx @@ -556,8 +556,8 @@ describe('DebugWithMultipleModel', () => { ) const twoItems = screen.getAllByTestId('debug-item') - expect(twoItems[0].style.width).toBe('calc(50% - 28px)') - expect(twoItems[1].style.width).toBe('calc(50% - 28px)') + expect(twoItems[0].style.width).toBe('calc(50% - 4px - 24px)') + expect(twoItems[1].style.width).toBe('calc(50% - 4px - 24px)') }) }) @@ -596,13 +596,13 @@ describe('DebugWithMultipleModel', () => { // Assert expect(items).toHaveLength(2) expectItemLayout(items[0], { - width: 'calc(50% - 28px)', + width: 'calc(50% - 4px - 24px)', height: '100%', transform: 'translateX(0) translateY(0)', classes: ['mr-2'], }) expectItemLayout(items[1], { - width: 'calc(50% - 28px)', + width: 'calc(50% - 4px - 24px)', height: '100%', transform: 'translateX(calc(100% + 8px)) translateY(0)', classes: [], @@ -620,19 +620,19 @@ describe('DebugWithMultipleModel', () => { // Assert expect(items).toHaveLength(3) expectItemLayout(items[0], { - width: 'calc(33.3% - 21.33px)', + width: 'calc(33.3% - 5.33px - 16px)', height: '100%', transform: 'translateX(0) translateY(0)', classes: ['mr-2'], }) expectItemLayout(items[1], { - width: 'calc(33.3% - 21.33px)', + width: 'calc(33.3% - 5.33px - 16px)', height: '100%', transform: 'translateX(calc(100% + 8px)) translateY(0)', classes: ['mr-2'], }) expectItemLayout(items[2], { - width: 'calc(33.3% - 21.33px)', + width: 'calc(33.3% - 5.33px - 16px)', height: '100%', transform: 'translateX(calc(200% + 16px)) translateY(0)', classes: [], @@ -655,25 +655,25 @@ describe('DebugWithMultipleModel', () => { // Assert expect(items).toHaveLength(4) expectItemLayout(items[0], { - width: 'calc(50% - 28px)', + width: 'calc(50% - 4px - 24px)', height: 'calc(50% - 4px)', transform: 'translateX(0) translateY(0)', classes: ['mr-2', 'mb-2'], }) expectItemLayout(items[1], { - width: 'calc(50% - 28px)', + width: 'calc(50% - 4px - 24px)', height: 'calc(50% - 4px)', transform: 'translateX(calc(100% + 8px)) translateY(0)', classes: ['mb-2'], }) expectItemLayout(items[2], { - width: 'calc(50% - 28px)', + width: 'calc(50% - 4px - 24px)', height: 'calc(50% - 4px)', transform: 'translateX(0) translateY(calc(100% + 8px))', classes: ['mr-2'], }) expectItemLayout(items[3], { - width: 'calc(50% - 28px)', + width: 'calc(50% - 4px - 24px)', height: 'calc(50% - 4px)', transform: 'translateX(calc(100% + 8px)) translateY(calc(100% + 8px))', classes: [], diff --git a/web/app/components/app/create-app-dialog/app-list/index.tsx b/web/app/components/app/create-app-dialog/app-list/index.tsx index 8b1876be043..1aa40d2014e 100644 --- a/web/app/components/app/create-app-dialog/app-list/index.tsx +++ b/web/app/components/app/create-app-dialog/app-list/index.tsx @@ -137,10 +137,7 @@ const Apps = ({ }) setIsShowCreateModal(false) - toast.add({ - type: 'success', - title: t('newApp.appCreated', { ns: 'app' }), - }) + toast.success(t('newApp.appCreated', { ns: 'app' })) if (onSuccess) onSuccess() if (app.app_id) @@ -149,7 +146,7 @@ const Apps = ({ getRedirection(isCurrentWorkspaceEditor, { id: app.app_id!, mode }, push) } catch { - toast.add({ type: 'error', title: t('newApp.appCreateFailed', { ns: 'app' }) }) + toast.error(t('newApp.appCreateFailed', { ns: 'app' })) } } diff --git a/web/app/components/app/overview/settings/index.spec.tsx b/web/app/components/app/overview/settings/index.spec.tsx index b849b4f0156..e933855ca8d 100644 --- a/web/app/components/app/overview/settings/index.spec.tsx +++ b/web/app/components/app/overview/settings/index.spec.tsx @@ -1,6 +1,3 @@ -/** - * @vitest-environment jsdom - */ import type { ReactNode } from 'react' import type { ModalContextState } from '@/context/modal-context' import type { ProviderContextState } from '@/context/provider-context' diff --git a/web/app/components/app/type-selector/index.spec.tsx b/web/app/components/app/type-selector/index.spec.tsx index e24d9633054..711678f0a8c 100644 --- a/web/app/components/app/type-selector/index.spec.tsx +++ b/web/app/components/app/type-selector/index.spec.tsx @@ -1,4 +1,4 @@ -import { fireEvent, render, screen, within } from '@testing-library/react' +import { fireEvent, render, screen, waitFor, within } from '@testing-library/react' import * as React from 'react' import { AppModeEnum } from '@/types/app' import AppTypeSelector, { AppTypeIcon, AppTypeLabel } from './index' @@ -14,7 +14,7 @@ describe('AppTypeSelector', () => { render() expect(screen.getByText('app.typeSelector.all')).toBeInTheDocument() - expect(screen.queryByRole('tooltip')).not.toBeInTheDocument() + expect(screen.queryByText('app.typeSelector.workflow')).not.toBeInTheDocument() }) }) @@ -39,24 +39,27 @@ describe('AppTypeSelector', () => { // Covers opening/closing the dropdown and selection updates. describe('User interactions', () => { - it('should toggle option list when clicking the trigger', () => { + it('should close option list when clicking outside', () => { render() - expect(screen.queryByRole('tooltip')).not.toBeInTheDocument() + expect(screen.queryByRole('list')).not.toBeInTheDocument() - fireEvent.click(screen.getByText('app.typeSelector.all')) - expect(screen.getByRole('tooltip')).toBeInTheDocument() + fireEvent.click(screen.getByRole('button', { name: 'app.typeSelector.all' })) + expect(screen.getByRole('list')).toBeInTheDocument() - fireEvent.click(screen.getByText('app.typeSelector.all')) - expect(screen.queryByRole('tooltip')).not.toBeInTheDocument() + fireEvent.pointerDown(document.body) + fireEvent.click(document.body) + return waitFor(() => { + expect(screen.queryByRole('list')).not.toBeInTheDocument() + }) }) it('should call onChange with added type when selecting an unselected item', () => { const onChange = vi.fn() render() - fireEvent.click(screen.getByText('app.typeSelector.all')) - fireEvent.click(within(screen.getByRole('tooltip')).getByText('app.typeSelector.workflow')) + fireEvent.click(screen.getByRole('button', { name: 'app.typeSelector.all' })) + fireEvent.click(within(screen.getByRole('list')).getByRole('button', { name: 'app.typeSelector.workflow' })) expect(onChange).toHaveBeenCalledWith([AppModeEnum.WORKFLOW]) }) @@ -65,8 +68,8 @@ describe('AppTypeSelector', () => { const onChange = vi.fn() render() - fireEvent.click(screen.getByText('app.typeSelector.workflow')) - fireEvent.click(within(screen.getByRole('tooltip')).getByText('app.typeSelector.workflow')) + fireEvent.click(screen.getByRole('button', { name: 'app.typeSelector.workflow' })) + fireEvent.click(within(screen.getByRole('list')).getByRole('button', { name: 'app.typeSelector.workflow' })) expect(onChange).toHaveBeenCalledWith([]) }) @@ -75,8 +78,8 @@ describe('AppTypeSelector', () => { const onChange = vi.fn() render() - fireEvent.click(screen.getByText('app.typeSelector.chatbot')) - fireEvent.click(within(screen.getByRole('tooltip')).getByText('app.typeSelector.agent')) + fireEvent.click(screen.getByRole('button', { name: 'app.typeSelector.chatbot' })) + fireEvent.click(within(screen.getByRole('list')).getByRole('button', { name: 'app.typeSelector.agent' })) expect(onChange).toHaveBeenCalledWith([AppModeEnum.CHAT, AppModeEnum.AGENT_CHAT]) }) @@ -88,7 +91,7 @@ describe('AppTypeSelector', () => { fireEvent.click(screen.getByRole('button', { name: 'common.operation.clear' })) expect(onChange).toHaveBeenCalledWith([]) - expect(screen.queryByRole('tooltip')).not.toBeInTheDocument() + expect(screen.queryByText('app.typeSelector.workflow')).not.toBeInTheDocument() }) }) }) diff --git a/web/app/components/app/type-selector/index.tsx b/web/app/components/app/type-selector/index.tsx index e97da4b7f30..a1475f9effa 100644 --- a/web/app/components/app/type-selector/index.tsx +++ b/web/app/components/app/type-selector/index.tsx @@ -4,13 +4,12 @@ import { useState } from 'react' import { useTranslation } from 'react-i18next' import { BubbleTextMod, ChatBot, ListSparkle, Logic } from '@/app/components/base/icons/src/vender/solid/communication' import { - PortalToFollowElem, - PortalToFollowElemContent, - PortalToFollowElemTrigger, -} from '@/app/components/base/portal-to-follow-elem' + Popover, + PopoverContent, + PopoverTrigger, +} from '@/app/components/base/ui/popover' import { AppModeEnum } from '@/types/app' import { cn } from '@/utils/classnames' -import Checkbox from '../../base/checkbox' export type AppSelectorProps = { value: Array @@ -22,43 +21,43 @@ const allTypes: AppModeEnum[] = [AppModeEnum.WORKFLOW, AppModeEnum.ADVANCED_CHAT const AppTypeSelector = ({ value, onChange }: AppSelectorProps) => { const [open, setOpen] = useState(false) const { t } = useTranslation() + const triggerLabel = value.length === 0 + ? t('typeSelector.all', { ns: 'app' }) + : value.map(type => getAppTypeLabel(type, t)).join(', ') return ( -
- setOpen(v => !v)} - className="block" - > -
0 && 'pr-7', )} + > + + + {value.length > 0 && ( + - )} -
-
- -
    + + + )} + +
      {allTypes.map(mode => ( { /> ))}
    - +
-
+ ) } @@ -173,33 +172,54 @@ type AppTypeSelectorItemProps = { } function AppTypeSelectorItem({ checked, type, onClick }: AppTypeSelectorItemProps) { return ( -
  • - - -
    - -
    +
  • +
  • ) } +function getAppTypeLabel(type: AppModeEnum, t: ReturnType['t']) { + if (type === AppModeEnum.CHAT) + return t('typeSelector.chatbot', { ns: 'app' }) + if (type === AppModeEnum.AGENT_CHAT) + return t('typeSelector.agent', { ns: 'app' }) + if (type === AppModeEnum.COMPLETION) + return t('typeSelector.completion', { ns: 'app' }) + if (type === AppModeEnum.ADVANCED_CHAT) + return t('typeSelector.advanced', { ns: 'app' }) + if (type === AppModeEnum.WORKFLOW) + return t('typeSelector.workflow', { ns: 'app' }) + + return '' +} + type AppTypeLabelProps = { type: AppModeEnum className?: string } export function AppTypeLabel({ type, className }: AppTypeLabelProps) { const { t } = useTranslation() - let label = '' - if (type === AppModeEnum.CHAT) - label = t('typeSelector.chatbot', { ns: 'app' }) - if (type === AppModeEnum.AGENT_CHAT) - label = t('typeSelector.agent', { ns: 'app' }) - if (type === AppModeEnum.COMPLETION) - label = t('typeSelector.completion', { ns: 'app' }) - if (type === AppModeEnum.ADVANCED_CHAT) - label = t('typeSelector.advanced', { ns: 'app' }) - if (type === AppModeEnum.WORKFLOW) - label = t('typeSelector.workflow', { ns: 'app' }) - return {label} + return {getAppTypeLabel(type, t)} } diff --git a/web/app/components/apps/hooks/__tests__/use-dsl-drag-drop.spec.ts b/web/app/components/apps/hooks/__tests__/use-dsl-drag-drop.spec.ts index 58fed4caa8a..00e2d69ab24 100644 --- a/web/app/components/apps/hooks/__tests__/use-dsl-drag-drop.spec.ts +++ b/web/app/components/apps/hooks/__tests__/use-dsl-drag-drop.spec.ts @@ -1,16 +1,15 @@ -import type { Mock } from 'vitest' import { act, renderHook } from '@testing-library/react' import { useDSLDragDrop } from '../use-dsl-drag-drop' describe('useDSLDragDrop', () => { let container: HTMLDivElement - let mockOnDSLFileDropped: Mock + let mockOnDSLFileDropped = vi.fn<(file: File) => void>() beforeEach(() => { vi.clearAllMocks() container = document.createElement('div') document.body.appendChild(container) - mockOnDSLFileDropped = vi.fn() + mockOnDSLFileDropped = vi.fn<(file: File) => void>() }) afterEach(() => { diff --git a/web/app/components/apps/index.tsx b/web/app/components/apps/index.tsx index dce9de190d2..b6ca60bd7bd 100644 --- a/web/app/components/apps/index.tsx +++ b/web/app/components/apps/index.tsx @@ -8,12 +8,14 @@ import AppListContext from '@/context/app-list-context' import useDocumentTitle from '@/hooks/use-document-title' import { useImportDSL } from '@/hooks/use-import-dsl' import { DSLImportMode } from '@/models/app' +import dynamic from '@/next/dynamic' import { fetchAppDetail } from '@/service/explore' -import DSLConfirmModal from '../app/create-from-dsl-modal/dsl-confirm-modal' -import CreateAppModal from '../explore/create-app-modal' -import TryApp from '../explore/try-app' import List from './list' +const DSLConfirmModal = dynamic(() => import('../app/create-from-dsl-modal/dsl-confirm-modal'), { ssr: false }) +const CreateAppModal = dynamic(() => import('../explore/create-app-modal'), { ssr: false }) +const TryApp = dynamic(() => import('../explore/try-app'), { ssr: false }) + const Apps = () => { const { t } = useTranslation() diff --git a/web/app/components/apps/list.tsx b/web/app/components/apps/list.tsx index 0d52bd468cf..2ef344f816a 100644 --- a/web/app/components/apps/list.tsx +++ b/web/app/components/apps/list.tsx @@ -5,11 +5,11 @@ import { useDebounceFn } from 'ahooks' import { parseAsStringLiteral, useQueryState } from 'nuqs' import { useCallback, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' +import Checkbox from '@/app/components/base/checkbox' import Input from '@/app/components/base/input' import TabSliderNew from '@/app/components/base/tab-slider-new' import TagFilter from '@/app/components/base/tag-management/filter' import { useStore as useTagStore } from '@/app/components/base/tag-management/store' -import CheckboxWithLabel from '@/app/components/datasets/create/website/base/checkbox-with-label' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useAppContext } from '@/context/app-context' import { useGlobalPublicStore } from '@/context/global-public-context' @@ -205,12 +205,12 @@ const List: FC = ({ options={options} />
    - + { ) const button = screen.getByRole('button', { name: 'Custom Style' }) expect(button).toHaveStyle({ - color: 'rgb(255, 0, 0)', - backgroundColor: 'rgb(0, 0, 255)', + color: 'red', + backgroundColor: 'blue', }) }) diff --git a/web/app/components/base/amplitude/AmplitudeProvider.tsx b/web/app/components/base/amplitude/AmplitudeProvider.tsx index e1d8e52eacb..00af15e24db 100644 --- a/web/app/components/base/amplitude/AmplitudeProvider.tsx +++ b/web/app/components/base/amplitude/AmplitudeProvider.tsx @@ -5,17 +5,12 @@ import * as amplitude from '@amplitude/analytics-browser' import { sessionReplayPlugin } from '@amplitude/plugin-session-replay-browser' import * as React from 'react' import { useEffect } from 'react' -import { AMPLITUDE_API_KEY, IS_CLOUD_EDITION } from '@/config' +import { AMPLITUDE_API_KEY, isAmplitudeEnabled } from '@/config' export type IAmplitudeProps = { sessionReplaySampleRate?: number } -// Check if Amplitude should be enabled -export const isAmplitudeEnabled = () => { - return IS_CLOUD_EDITION && !!AMPLITUDE_API_KEY -} - // Map URL pathname to English page name for consistent Amplitude tracking const getEnglishPageName = (pathname: string): string => { // Remove leading slash and get the first segment @@ -59,7 +54,7 @@ const AmplitudeProvider: FC = ({ }) => { useEffect(() => { // Only enable in Saas edition with valid API key - if (!isAmplitudeEnabled()) + if (!isAmplitudeEnabled) return // Initialize Amplitude diff --git a/web/app/components/base/amplitude/__tests__/AmplitudeProvider.spec.tsx b/web/app/components/base/amplitude/__tests__/AmplitudeProvider.spec.tsx index b30da72091c..5835634eb72 100644 --- a/web/app/components/base/amplitude/__tests__/AmplitudeProvider.spec.tsx +++ b/web/app/components/base/amplitude/__tests__/AmplitudeProvider.spec.tsx @@ -2,14 +2,24 @@ import * as amplitude from '@amplitude/analytics-browser' import { sessionReplayPlugin } from '@amplitude/plugin-session-replay-browser' import { render } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' -import AmplitudeProvider, { isAmplitudeEnabled } from '../AmplitudeProvider' +import AmplitudeProvider from '../AmplitudeProvider' const mockConfig = vi.hoisted(() => ({ AMPLITUDE_API_KEY: 'test-api-key', IS_CLOUD_EDITION: true, })) -vi.mock('@/config', () => mockConfig) +vi.mock('@/config', () => ({ + get AMPLITUDE_API_KEY() { + return mockConfig.AMPLITUDE_API_KEY + }, + get IS_CLOUD_EDITION() { + return mockConfig.IS_CLOUD_EDITION + }, + get isAmplitudeEnabled() { + return mockConfig.IS_CLOUD_EDITION && !!mockConfig.AMPLITUDE_API_KEY + }, +})) vi.mock('@amplitude/analytics-browser', () => ({ init: vi.fn(), @@ -27,22 +37,6 @@ describe('AmplitudeProvider', () => { mockConfig.IS_CLOUD_EDITION = true }) - describe('isAmplitudeEnabled', () => { - it('returns true when cloud edition and api key present', () => { - expect(isAmplitudeEnabled()).toBe(true) - }) - - it('returns false when cloud edition but no api key', () => { - mockConfig.AMPLITUDE_API_KEY = '' - expect(isAmplitudeEnabled()).toBe(false) - }) - - it('returns false when not cloud edition', () => { - mockConfig.IS_CLOUD_EDITION = false - expect(isAmplitudeEnabled()).toBe(false) - }) - }) - describe('Component', () => { it('initializes amplitude when enabled', () => { render() diff --git a/web/app/components/base/amplitude/__tests__/index.spec.ts b/web/app/components/base/amplitude/__tests__/index.spec.ts deleted file mode 100644 index 2d7ad6ab84d..00000000000 --- a/web/app/components/base/amplitude/__tests__/index.spec.ts +++ /dev/null @@ -1,32 +0,0 @@ -import { describe, expect, it } from 'vitest' -import AmplitudeProvider, { isAmplitudeEnabled } from '../AmplitudeProvider' -import indexDefault, { - isAmplitudeEnabled as indexIsAmplitudeEnabled, - resetUser, - setUserId, - setUserProperties, - trackEvent, -} from '../index' -import { - resetUser as utilsResetUser, - setUserId as utilsSetUserId, - setUserProperties as utilsSetUserProperties, - trackEvent as utilsTrackEvent, -} from '../utils' - -describe('Amplitude index exports', () => { - it('exports AmplitudeProvider as default', () => { - expect(indexDefault).toBe(AmplitudeProvider) - }) - - it('exports isAmplitudeEnabled', () => { - expect(indexIsAmplitudeEnabled).toBe(isAmplitudeEnabled) - }) - - it('exports utils', () => { - expect(resetUser).toBe(utilsResetUser) - expect(setUserId).toBe(utilsSetUserId) - expect(setUserProperties).toBe(utilsSetUserProperties) - expect(trackEvent).toBe(utilsTrackEvent) - }) -}) diff --git a/web/app/components/base/amplitude/__tests__/utils.spec.ts b/web/app/components/base/amplitude/__tests__/utils.spec.ts index ecbc57e387c..f1ff5db1e35 100644 --- a/web/app/components/base/amplitude/__tests__/utils.spec.ts +++ b/web/app/components/base/amplitude/__tests__/utils.spec.ts @@ -20,8 +20,10 @@ const MockIdentify = vi.hoisted(() => }, ) -vi.mock('../AmplitudeProvider', () => ({ - isAmplitudeEnabled: () => mockState.enabled, +vi.mock('@/config', () => ({ + get isAmplitudeEnabled() { + return mockState.enabled + }, })) vi.mock('@amplitude/analytics-browser', () => ({ diff --git a/web/app/components/base/amplitude/index.ts b/web/app/components/base/amplitude/index.ts index acc792339ef..44cbf728e22 100644 --- a/web/app/components/base/amplitude/index.ts +++ b/web/app/components/base/amplitude/index.ts @@ -1,2 +1,2 @@ -export { default, isAmplitudeEnabled } from './AmplitudeProvider' +export { default } from './lazy-amplitude-provider' export { resetUser, setUserId, setUserProperties, trackEvent } from './utils' diff --git a/web/app/components/base/amplitude/lazy-amplitude-provider.tsx b/web/app/components/base/amplitude/lazy-amplitude-provider.tsx new file mode 100644 index 00000000000..5dfa0e7b539 --- /dev/null +++ b/web/app/components/base/amplitude/lazy-amplitude-provider.tsx @@ -0,0 +1,11 @@ +'use client' + +import type { FC } from 'react' +import type { IAmplitudeProps } from './AmplitudeProvider' +import dynamic from '@/next/dynamic' + +const AmplitudeProvider = dynamic(() => import('./AmplitudeProvider'), { ssr: false }) + +const LazyAmplitudeProvider: FC = props => + +export default LazyAmplitudeProvider diff --git a/web/app/components/base/amplitude/utils.ts b/web/app/components/base/amplitude/utils.ts index 57b96243ec1..8faa8e852e7 100644 --- a/web/app/components/base/amplitude/utils.ts +++ b/web/app/components/base/amplitude/utils.ts @@ -1,5 +1,5 @@ import * as amplitude from '@amplitude/analytics-browser' -import { isAmplitudeEnabled } from './AmplitudeProvider' +import { isAmplitudeEnabled } from '@/config' /** * Track custom event @@ -7,7 +7,7 @@ import { isAmplitudeEnabled } from './AmplitudeProvider' * @param eventProperties Event properties (optional) */ export const trackEvent = (eventName: string, eventProperties?: Record) => { - if (!isAmplitudeEnabled()) + if (!isAmplitudeEnabled) return amplitude.track(eventName, eventProperties) } @@ -17,7 +17,7 @@ export const trackEvent = (eventName: string, eventProperties?: Record { - if (!isAmplitudeEnabled()) + if (!isAmplitudeEnabled) return amplitude.setUserId(userId) } @@ -27,7 +27,7 @@ export const setUserId = (userId: string) => { * @param properties User properties */ export const setUserProperties = (properties: Record) => { - if (!isAmplitudeEnabled()) + if (!isAmplitudeEnabled) return const identifyEvent = new amplitude.Identify() Object.entries(properties).forEach(([key, value]) => { @@ -40,7 +40,7 @@ export const setUserProperties = (properties: Record) => { * Reset user (e.g., when user logs out) */ export const resetUser = () => { - if (!isAmplitudeEnabled()) + if (!isAmplitudeEnabled) return amplitude.reset() } diff --git a/web/app/components/base/avatar/index.tsx b/web/app/components/base/avatar/index.tsx index 2d55ec2720d..f53e1f8985c 100644 --- a/web/app/components/base/avatar/index.tsx +++ b/web/app/components/base/avatar/index.tsx @@ -1,8 +1,9 @@ import type { ImageLoadingStatus } from '@base-ui/react/avatar' +import type * as React from 'react' import { Avatar as BaseAvatar } from '@base-ui/react/avatar' import { cn } from '@/utils/classnames' -const SIZES = { +const avatarSizeClasses = { 'xxs': { root: 'size-4', text: 'text-[7px]' }, 'xs': { root: 'size-5', text: 'text-[8px]' }, 'sm': { root: 'size-6', text: 'text-[10px]' }, @@ -13,7 +14,7 @@ const SIZES = { '3xl': { root: 'size-16', text: 'text-2xl' }, } as const -export type AvatarSize = keyof typeof SIZES +export type AvatarSize = keyof typeof avatarSizeClasses export type AvatarProps = { name: string @@ -23,7 +24,61 @@ export type AvatarProps = { onLoadingStatusChange?: (status: ImageLoadingStatus) => void } -const BASE_CLASS = 'relative inline-flex shrink-0 select-none items-center justify-center overflow-hidden rounded-full bg-primary-600' +export type AvatarRootProps = React.ComponentPropsWithRef & { + size?: AvatarSize +} + +export function AvatarRoot({ + size = 'md', + className, + ...props +}: AvatarRootProps) { + return ( + + ) +} + +export type AvatarImageProps = React.ComponentPropsWithRef + +export function AvatarImage({ + className, + ...props +}: AvatarImageProps) { + return ( + + ) +} + +export type AvatarFallbackProps = React.ComponentPropsWithRef & { + size?: AvatarSize +} + +export function AvatarFallback({ + size = 'md', + className, + ...props +}: AvatarFallbackProps) { + return ( + + ) +} export const Avatar = ({ name, @@ -32,21 +87,18 @@ export const Avatar = ({ className, onLoadingStatusChange, }: AvatarProps) => { - const sizeConfig = SIZES[size] - return ( - + {avatar && ( - )} - + {name?.[0]?.toLocaleUpperCase()} - - + + ) } diff --git a/web/app/components/base/carousel/__tests__/index.spec.tsx b/web/app/components/base/carousel/__tests__/index.spec.tsx index cc452569374..e409b85757f 100644 --- a/web/app/components/base/carousel/__tests__/index.spec.tsx +++ b/web/app/components/base/carousel/__tests__/index.spec.tsx @@ -11,15 +11,15 @@ type EmblaEventName = 'reInit' | 'select' type EmblaListener = (api: MockEmblaApi | undefined) => void type MockEmblaApi = { - scrollPrev: Mock - scrollNext: Mock - scrollTo: Mock - selectedScrollSnap: Mock - canScrollPrev: Mock - canScrollNext: Mock - slideNodes: Mock - on: Mock - off: Mock + scrollPrev: Mock<() => void> + scrollNext: Mock<() => void> + scrollTo: Mock<(index: number) => void> + selectedScrollSnap: Mock<() => number> + canScrollPrev: Mock<() => boolean> + canScrollNext: Mock<() => boolean> + slideNodes: Mock<() => HTMLDivElement[]> + on: Mock<(event: EmblaEventName, callback: EmblaListener) => void> + off: Mock<(event: EmblaEventName, callback: EmblaListener) => void> } let mockCanScrollPrev = false @@ -33,19 +33,19 @@ const mockCarouselRef = vi.fn() const mockedUseEmblaCarousel = vi.mocked(useEmblaCarousel) const createMockEmblaApi = (): MockEmblaApi => ({ - scrollPrev: vi.fn(), - scrollNext: vi.fn(), - scrollTo: vi.fn(), - selectedScrollSnap: vi.fn(() => mockSelectedIndex), - canScrollPrev: vi.fn(() => mockCanScrollPrev), - canScrollNext: vi.fn(() => mockCanScrollNext), - slideNodes: vi.fn(() => - Array.from({ length: mockSlideCount }).fill(document.createElement('div')), + scrollPrev: vi.fn<() => void>(), + scrollNext: vi.fn<() => void>(), + scrollTo: vi.fn<(index: number) => void>(), + selectedScrollSnap: vi.fn<() => number>(() => mockSelectedIndex), + canScrollPrev: vi.fn<() => boolean>(() => mockCanScrollPrev), + canScrollNext: vi.fn<() => boolean>(() => mockCanScrollNext), + slideNodes: vi.fn<() => HTMLDivElement[]>(() => + Array.from({ length: mockSlideCount }, () => document.createElement('div')), ), - on: vi.fn((event: EmblaEventName, callback: EmblaListener) => { + on: vi.fn<(event: EmblaEventName, callback: EmblaListener) => void>((event, callback) => { listeners[event].push(callback) }), - off: vi.fn((event: EmblaEventName, callback: EmblaListener) => { + off: vi.fn<(event: EmblaEventName, callback: EmblaListener) => void>((event, callback) => { listeners[event] = listeners[event].filter(listener => listener !== callback) }), }) diff --git a/web/app/components/base/chat/chat/__tests__/hooks.spec.tsx b/web/app/components/base/chat/chat/__tests__/hooks.spec.tsx index f5b261d5f31..92fa9ea42ee 100644 --- a/web/app/components/base/chat/chat/__tests__/hooks.spec.tsx +++ b/web/app/components/base/chat/chat/__tests__/hooks.spec.tsx @@ -141,6 +141,145 @@ describe('useChat', () => { expect(result.current.chatList[0].suggestedQuestions).toEqual(['Ask Bob']) }) + describe('opening statement referential stability', () => { + it('should keep the same item reference across multiple streaming chatTree mutations', () => { + let callbacks: HookCallbacks + + vi.mocked(ssePost).mockImplementation(async (_url, _params, options) => { + callbacks = options as HookCallbacks + }) + + const config = { + opening_statement: 'Welcome!', + suggested_questions: ['Q1', 'Q2'], + } + const { result } = renderHook(() => useChat(config as ChatConfig)) + + const openerInitial = result.current.chatList[0] + expect(openerInitial.isOpeningStatement).toBe(true) + expect(openerInitial.content).toBe('Welcome!') + + act(() => { + result.current.handleSend('url', { query: 'hello' }, {}) + }) + + act(() => { + callbacks.onWorkflowStarted({ workflow_run_id: 'wr-1', task_id: 't-1' }) + }) + expect(result.current.chatList[0]).toBe(openerInitial) + + act(() => { + callbacks.onData('chunk-1 ', true, { messageId: 'm-1', conversationId: 'c-1', taskId: 't-1' }) + }) + expect(result.current.chatList.length).toBeGreaterThan(1) + expect(result.current.chatList[0]).toBe(openerInitial) + + act(() => { + callbacks.onData('chunk-2 ', false, { messageId: 'm-1' }) + }) + expect(result.current.chatList[0]).toBe(openerInitial) + + act(() => { + callbacks.onData('chunk-3', false, { messageId: 'm-1' }) + callbacks.onMessageEnd({ metadata: { retriever_resources: [] } }) + callbacks.onWorkflowFinished({ data: { status: 'succeeded' } }) + callbacks.onCompleted() + }) + expect(result.current.chatList[0]).toBe(openerInitial) + expect(result.current.chatList.at(-1)!.content).toBe('chunk-1 chunk-2 chunk-3') + }) + + it('should keep stable reference when getIntroduction identity changes but output is identical', () => { + const config = { + opening_statement: 'Hello {{name}}', + suggested_questions: ['Ask about {{name}}'], + } + + const { result, rerender } = renderHook( + ({ fs }) => useChat(config as ChatConfig, fs as UseChatFormSettings), + { initialProps: { fs: { inputs: { name: 'Alice' }, inputsForm: [] } } }, + ) + + const openerBefore = result.current.chatList[0] + expect(openerBefore.content).toBe('Hello Alice') + expect(openerBefore.suggestedQuestions).toEqual(['Ask about Alice']) + + rerender({ fs: { inputs: { name: 'Alice' }, inputsForm: [] } }) + + expect(result.current.chatList[0]).toBe(openerBefore) + }) + + it('should produce a new item when the processed content actually changes', () => { + const config = { + opening_statement: 'Hello {{name}}', + suggested_questions: ['Ask {{name}}'], + } + + const { result, rerender } = renderHook( + ({ fs }) => useChat(config as ChatConfig, fs as UseChatFormSettings), + { initialProps: { fs: { inputs: { name: 'Alice' }, inputsForm: [] } } }, + ) + + const before = result.current.chatList[0] + + rerender({ fs: { inputs: { name: 'Bob' }, inputsForm: [] } }) + + const after = result.current.chatList[0] + expect(after).not.toBe(before) + expect(after.content).toBe('Hello Bob') + expect(after.suggestedQuestions).toEqual(['Ask Bob']) + }) + + it('should keep content and suggestedQuestions stable for opener already in prevChatTree even when sibling metadata changes', () => { + let callbacks: HookCallbacks + vi.mocked(ssePost).mockImplementation(async (_url, _params, options) => { + callbacks = options as HookCallbacks + }) + + const config = { + opening_statement: 'Hello updated', + suggested_questions: ['S1'], + } + const prevChatTree = [{ + id: 'opening-statement', + content: 'old', + isAnswer: true, + isOpeningStatement: true, + suggestedQuestions: [], + }] + + const { result } = renderHook(() => + useChat(config as ChatConfig, undefined, prevChatTree as ChatItemInTree[]), + ) + + const openerBefore = result.current.chatList[0] + expect(openerBefore.content).toBe('Hello updated') + expect(openerBefore.suggestedQuestions).toEqual(['S1']) + + const contentBefore = openerBefore.content + const suggestionsBefore = openerBefore.suggestedQuestions + + act(() => { + result.current.handleSend('url', { query: 'msg' }, {}) + }) + act(() => { + callbacks.onData('resp', true, { messageId: 'm-1', conversationId: 'c-1', taskId: 't-1' }) + }) + + expect(result.current.chatList.length).toBeGreaterThan(1) + const openerAfter = result.current.chatList[0] + expect(openerAfter.content).toBe(contentBefore) + expect(openerAfter.suggestedQuestions).toBe(suggestionsBefore) + }) + + it('should use a stable id of "opening-statement"', () => { + const { result } = renderHook(() => + useChat({ opening_statement: 'Hi' } as ChatConfig), + ) + expect(result.current.chatList[0].id).toBe('opening-statement') + }) + }) + describe('handleSend', () => { it('should block send if already responding', async () => { const { result } = renderHook(() => useChat()) diff --git a/web/app/components/base/chat/chat/__tests__/index.spec.tsx b/web/app/components/base/chat/chat/__tests__/index.spec.tsx index 781b5e86f31..0100b059f01 100644 --- a/web/app/components/base/chat/chat/__tests__/index.spec.tsx +++ b/web/app/components/base/chat/chat/__tests__/index.spec.tsx @@ -8,10 +8,10 @@ import Chat from '../index' // ─── Why each mock exists ───────────────────────────────────────────────────── // // Answer – transitively pulls Markdown (rehype/remark/katex), AgentContent, -// WorkflowProcessItem and Operation; none can resolve in jsdom. +// WorkflowProcessItem and Operation; none can resolve in the test DOM runtime. // Question – pulls Markdown, copy-to-clipboard, react-textarea-autosize. // ChatInputArea – pulls js-audio-recorder (requires Web Audio API unavailable in -// jsdom) and VoiceInput / FileContextProvider chains. +// the test DOM runtime) and VoiceInput / FileContextProvider chains. // PromptLogModal– pulls CopyFeedbackNew and deep modal dep chain. // AgentLogModal – pulls @remixicon/react (causes lint push error), useClickAway // from ahooks, and AgentLogDetail (workflow graph renderer). diff --git a/web/app/components/base/chat/chat/hooks.ts b/web/app/components/base/chat/chat/hooks.ts index 9c06f49b3d7..a0f335f567c 100644 --- a/web/app/components/base/chat/chat/hooks.ts +++ b/web/app/components/base/chat/chat/hooks.ts @@ -88,30 +88,54 @@ export const useChat = ( return processOpeningStatement(str, formSettings?.inputs || {}, formSettings?.inputsForm || []) }, [formSettings?.inputs, formSettings?.inputsForm]) + const processedOpeningContent = config?.opening_statement + ? getIntroduction(config.opening_statement) + : undefined + const processedSuggestionsKey = config?.suggested_questions + ? JSON.stringify(config.suggested_questions.map(q => getIntroduction(q))) + : undefined + + const openingStatementItem = useMemo(() => { + if (!processedOpeningContent) + return null + return { + id: 'opening-statement', + content: processedOpeningContent, + isAnswer: true, + isOpeningStatement: true, + suggestedQuestions: processedSuggestionsKey + ? JSON.parse(processedSuggestionsKey) as string[] + : undefined, + } + }, [processedOpeningContent, processedSuggestionsKey]) + + const threadOpener = useMemo( + () => threadMessages.find(item => item.isOpeningStatement) ?? null, + [threadMessages], + ) + + const mergedOpeningItem = useMemo(() => { + if (!threadOpener || !openingStatementItem) + return null + return { + ...threadOpener, + content: openingStatementItem.content, + suggestedQuestions: openingStatementItem.suggestedQuestions, + } + }, [threadOpener, openingStatementItem]) + /** Final chat list that will be rendered */ const chatList = useMemo(() => { const ret = [...threadMessages] - if (config?.opening_statement) { + if (openingStatementItem) { const index = threadMessages.findIndex(item => item.isOpeningStatement) - if (index > -1) { - ret[index] = { - ...ret[index], - content: getIntroduction(config.opening_statement), - suggestedQuestions: config.suggested_questions?.map(item => getIntroduction(item)), - } - } - else { - ret.unshift({ - id: 'opening-statement', - content: getIntroduction(config.opening_statement), - isAnswer: true, - isOpeningStatement: true, - suggestedQuestions: config.suggested_questions?.map(item => getIntroduction(item)), - }) - } + if (index > -1 && mergedOpeningItem) + ret[index] = mergedOpeningItem + else if (index === -1) + ret.unshift(openingStatementItem) } return ret - }, [threadMessages, config, getIntroduction]) + }, [threadMessages, openingStatementItem, mergedOpeningItem]) useEffect(() => { setAutoFreeze(false) diff --git a/web/app/components/base/chat/utils.ts b/web/app/components/base/chat/utils.ts index b47fec1d0ad..5881f565a49 100644 --- a/web/app/components/base/chat/utils.ts +++ b/web/app/components/base/chat/utils.ts @@ -158,7 +158,7 @@ function buildChatItemTree(allMessages: IChatItem[]): ChatItemInTree[] { rootNodes.push(questionNode) } else { - map[parentMessageId]?.children!.push(questionNode) + map[parentMessageId].children!.push(questionNode) } } } diff --git a/web/app/components/base/date-and-time-picker/calendar/__tests__/index.spec.tsx b/web/app/components/base/date-and-time-picker/calendar/__tests__/index.spec.tsx index d8e00780b1e..8839798c15e 100644 --- a/web/app/components/base/date-and-time-picker/calendar/__tests__/index.spec.tsx +++ b/web/app/components/base/date-and-time-picker/calendar/__tests__/index.spec.tsx @@ -3,7 +3,7 @@ import { fireEvent, render, screen } from '@testing-library/react' import dayjs from '../../utils/dayjs' import Calendar from '../index' -// Mock scrollIntoView since jsdom doesn't implement it +// Mock scrollIntoView since the test DOM runtime doesn't implement it beforeAll(() => { Element.prototype.scrollIntoView = vi.fn() }) diff --git a/web/app/components/base/date-and-time-picker/time-picker/__tests__/index.spec.tsx b/web/app/components/base/date-and-time-picker/time-picker/__tests__/index.spec.tsx index 910faf9cd42..199ed4ee411 100644 --- a/web/app/components/base/date-and-time-picker/time-picker/__tests__/index.spec.tsx +++ b/web/app/components/base/date-and-time-picker/time-picker/__tests__/index.spec.tsx @@ -3,7 +3,7 @@ import { fireEvent, render, screen, within } from '@testing-library/react' import dayjs, { isDayjsObject } from '../../utils/dayjs' import TimePicker from '../index' -// Mock scrollIntoView since jsdom doesn't implement it +// Mock scrollIntoView since the test DOM runtime doesn't implement it beforeAll(() => { Element.prototype.scrollIntoView = vi.fn() }) diff --git a/web/app/components/base/features/new-feature-panel/annotation-reply/config-param-modal.tsx b/web/app/components/base/features/new-feature-panel/annotation-reply/config-param-modal.tsx index 332b87cb300..ac0b6d0f579 100644 --- a/web/app/components/base/features/new-feature-panel/annotation-reply/config-param-modal.tsx +++ b/web/app/components/base/features/new-feature-panel/annotation-reply/config-param-modal.tsx @@ -93,7 +93,6 @@ const ConfigParamModal: FC = ({ className="mt-1" value={(annotationConfig.score_threshold || ANNOTATION_DEFAULT.score_threshold) * 100} onChange={(val) => { - /* v8 ignore next -- callback dispatch depends on react-slider drag mechanics that are flaky in jsdom. @preserve */ setAnnotationConfig({ ...annotationConfig, score_threshold: val / 100, diff --git a/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/__tests__/index.spec.tsx b/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/__tests__/index.spec.tsx index 2bc30e4ead7..ffa9c330439 100644 --- a/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/__tests__/index.spec.tsx +++ b/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/__tests__/index.spec.tsx @@ -1,20 +1,9 @@ import { render, screen } from '@testing-library/react' import ScoreSlider from '../index' -vi.mock('@/app/components/base/features/new-feature-panel/annotation-reply/score-slider/base-slider', () => ({ - default: ({ value, onChange, min, max }: { value: number, onChange: (v: number) => void, min: number, max: number }) => ( - onChange(Number(e.target.value))} - /> - ), -})) - describe('ScoreSlider', () => { + const getSliderInput = () => screen.getByLabelText('appDebug.feature.annotation.scoreThreshold.title') + beforeEach(() => { vi.clearAllMocks() }) @@ -22,7 +11,7 @@ describe('ScoreSlider', () => { it('should render the slider', () => { render() - expect(screen.getByTestId('slider')).toBeInTheDocument() + expect(getSliderInput()).toBeInTheDocument() }) it('should display easy match and accurate match labels', () => { @@ -37,14 +26,14 @@ describe('ScoreSlider', () => { it('should render with custom className', () => { const { container } = render() - // Verifying the component renders successfully with a custom className - expect(screen.getByTestId('slider')).toBeInTheDocument() + expect(getSliderInput()).toBeInTheDocument() expect(container.firstChild).toHaveClass('custom-class') }) it('should pass value to the slider', () => { render() - expect(screen.getByTestId('slider')).toHaveValue('95') + expect(getSliderInput()).toHaveValue('95') + expect(screen.getByText('0.95')).toBeInTheDocument() }) }) diff --git a/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/base-slider/__tests__/index.spec.tsx b/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/base-slider/__tests__/index.spec.tsx deleted file mode 100644 index 815e8ffe491..00000000000 --- a/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/base-slider/__tests__/index.spec.tsx +++ /dev/null @@ -1,50 +0,0 @@ -import { render, screen } from '@testing-library/react' -import Slider from '../index' - -describe('BaseSlider', () => { - beforeEach(() => { - vi.clearAllMocks() - }) - - it('should render the slider component', () => { - render() - - expect(screen.getByRole('slider')).toBeInTheDocument() - }) - - it('should display the formatted value in the thumb', () => { - render() - - expect(screen.getByText('0.85')).toBeInTheDocument() - }) - - it('should use default min/max/step when not provided', () => { - render() - - const slider = screen.getByRole('slider') - expect(slider).toHaveAttribute('aria-valuemin', '0') - expect(slider).toHaveAttribute('aria-valuemax', '100') - expect(slider).toHaveAttribute('aria-valuenow', '50') - }) - - it('should use custom min/max/step when provided', () => { - render() - - const slider = screen.getByRole('slider') - expect(slider).toHaveAttribute('aria-valuemin', '80') - expect(slider).toHaveAttribute('aria-valuemax', '100') - expect(slider).toHaveAttribute('aria-valuenow', '90') - }) - - it('should handle NaN value as 0', () => { - render() - - expect(screen.getByRole('slider')).toHaveAttribute('aria-valuenow', '0') - }) - - it('should pass disabled prop', () => { - render() - - expect(screen.getByRole('slider')).toHaveAttribute('aria-disabled', 'true') - }) -}) diff --git a/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/base-slider/index.tsx b/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/base-slider/index.tsx deleted file mode 100644 index 509426c08eb..00000000000 --- a/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/base-slider/index.tsx +++ /dev/null @@ -1,40 +0,0 @@ -import ReactSlider from 'react-slider' -import { cn } from '@/utils/classnames' -import s from './style.module.css' - -type ISliderProps = { - className?: string - value: number - max?: number - min?: number - step?: number - disabled?: boolean - onChange: (value: number) => void -} - -const Slider: React.FC = ({ className, max, min, step, value, disabled, onChange }) => { - return ( - ( -
    -
    -
    - {(state.valueNow / 100).toFixed(2)} -
    -
    -
    - )} - /> - ) -} - -export default Slider diff --git a/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/base-slider/style.module.css b/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/base-slider/style.module.css deleted file mode 100644 index 8ef23b54b5b..00000000000 --- a/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/base-slider/style.module.css +++ /dev/null @@ -1,20 +0,0 @@ -.slider { - position: relative; -} - -.slider.disabled { - opacity: 0.6; -} - -.slider-thumb:focus { - outline: none; -} - -.slider-track { - background-color: #528BFF; - height: 2px; -} - -.slider-track-1 { - background-color: #E5E7EB; -} diff --git a/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/index.tsx b/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/index.tsx index c6fb1a0b4e8..0363eb28208 100644 --- a/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/index.tsx +++ b/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/index.tsx @@ -2,7 +2,7 @@ import type { FC } from 'react' import * as React from 'react' import { useTranslation } from 'react-i18next' -import Slider from '@/app/components/base/features/new-feature-panel/annotation-reply/score-slider/base-slider' +import { Slider } from '@/app/components/base/ui/slider' type Props = { className?: string @@ -10,23 +10,42 @@ type Props = { onChange: (value: number) => void } +const clamp = (value: number, min: number, max: number) => { + if (!Number.isFinite(value)) + return min + + return Math.min(Math.max(value, min), max) +} + const ScoreSlider: FC = ({ className, value, onChange, }) => { const { t } = useTranslation() + const safeValue = clamp(value, 80, 100) return (
    -
    +
    +
    + {(safeValue / 100).toFixed(2)} +
    diff --git a/web/app/components/base/file-uploader/file-from-link-or-local/index.tsx b/web/app/components/base/file-uploader/file-from-link-or-local/index.tsx index 69496903a67..de9cc7ecd00 100644 --- a/web/app/components/base/file-uploader/file-from-link-or-local/index.tsx +++ b/web/app/components/base/file-uploader/file-from-link-or-local/index.tsx @@ -37,11 +37,11 @@ const FileFromLinkOrLocal = ({ const { handleLoadFileFromLink } = useFile(fileConfig) const disabled = !!fileConfig.number_limits && files.length >= fileConfig.number_limits const fileLinkPlaceholder = t('fileUploader.pasteFileLinkInputPlaceholder', { ns: 'common' }) - /* v8 ignore next -- fallback for missing i18n key is not reliably testable under current global translation mocks in jsdom @preserve */ + /* v8 ignore next -- fallback for a missing i18n key is not reliably testable under the current global translation mocks in the test DOM runtime. @preserve */ const fileLinkPlaceholderText = fileLinkPlaceholder || '' const handleSaveUrl = () => { - /* v8 ignore next -- guarded by UI-level disabled state (`disabled={!url || disabled}`), not reachable in jsdom click flow @preserve */ + /* v8 ignore next -- guarded by UI-level disabled state (`disabled={!url || disabled}`), not reachable in the current test click flow. @preserve */ if (!url) return diff --git a/web/app/components/base/icons/__tests__/utils.spec.ts b/web/app/components/base/icons/__tests__/utils.spec.ts index a25f39111d9..f8534038bfa 100644 --- a/web/app/components/base/icons/__tests__/utils.spec.ts +++ b/web/app/components/base/icons/__tests__/utils.spec.ts @@ -62,7 +62,7 @@ describe('generate icon base utils', () => { const { container } = render(generate(node, 'key')) // to svg element expect(container.firstChild).toHaveClass('container') - expect(container.querySelector('span')).toHaveStyle({ color: 'rgb(0, 0, 255)' }) + expect(container.querySelector('span')).toHaveStyle({ color: 'blue' }) }) // add not has children diff --git a/web/app/components/base/input/__tests__/index.spec.tsx b/web/app/components/base/input/__tests__/index.spec.tsx index 2c5b563a127..dfab8617c28 100644 --- a/web/app/components/base/input/__tests__/index.spec.tsx +++ b/web/app/components/base/input/__tests__/index.spec.tsx @@ -99,7 +99,7 @@ describe('Input component', () => { render() const input = screen.getByPlaceholderText(/input/i) expect(input).toHaveClass(customClass) - expect(input).toHaveStyle({ color: 'rgb(255, 0, 0)' }) + expect(input).toHaveStyle({ color: 'red' }) }) it('applies large size variant correctly', () => { diff --git a/web/app/components/base/markdown-blocks/__tests__/code-block.spec.tsx b/web/app/components/base/markdown-blocks/__tests__/code-block.spec.tsx index 308232fd0f4..a16686801c6 100644 --- a/web/app/components/base/markdown-blocks/__tests__/code-block.spec.tsx +++ b/web/app/components/base/markdown-blocks/__tests__/code-block.spec.tsx @@ -1,6 +1,6 @@ -import { createRequire } from 'node:module' import { act, render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' +import * as echarts from 'echarts' import { Theme } from '@/types/app' import CodeBlock from '../code-block' @@ -10,17 +10,28 @@ type UseThemeReturn = { } const mockUseTheme = vi.fn<() => UseThemeReturn>(() => ({ theme: Theme.light })) -const require = createRequire(import.meta.url) -const echartsCjs = require('echarts') as { - getInstanceByDom: (dom: HTMLDivElement | null) => { - resize: (opts?: { width?: string, height?: string }) => void - } | null -} +const mockEcharts = vi.hoisted(() => { + const state = { + finishedHandler: undefined as undefined | ((event?: unknown) => void), + echartsInstance: { + resize: vi.fn<(opts?: { width?: string, height?: string }) => void>(), + trigger: vi.fn((eventName: string, event?: unknown) => { + if (eventName === 'finished') + state.finishedHandler?.(event) + }), + }, + getInstanceByDom: vi.fn(() => state.echartsInstance), + } + + return state +}) let clientWidthSpy: { mockRestore: () => void } | null = null let clientHeightSpy: { mockRestore: () => void } | null = null let offsetWidthSpy: { mockRestore: () => void } | null = null let offsetHeightSpy: { mockRestore: () => void } | null = null +let consoleErrorSpy: ReturnType | null = null +let consoleWarnSpy: ReturnType | null = null type AudioContextCtor = new () => unknown type WindowWithLegacyAudio = Window & { @@ -59,6 +70,42 @@ vi.mock('@/hooks/use-theme', () => ({ default: () => mockUseTheme(), })) +vi.mock('echarts', () => ({ + getInstanceByDom: mockEcharts.getInstanceByDom, +})) + +vi.mock('echarts-for-react', async () => { + const React = await vi.importActual('react') + + const MockReactEcharts = React.forwardRef(({ + onChartReady, + onEvents, + }: { + onChartReady?: (instance: typeof mockEcharts.echartsInstance) => void + onEvents?: { finished?: (event?: unknown) => void } + }, ref: React.ForwardedRef<{ getEchartsInstance: () => typeof mockEcharts.echartsInstance }>) => { + React.useImperativeHandle(ref, () => ({ + getEchartsInstance: () => mockEcharts.echartsInstance, + })) + + React.useEffect(() => { + mockEcharts.finishedHandler = onEvents?.finished + onChartReady?.(mockEcharts.echartsInstance) + onEvents?.finished?.({}) + return () => { + mockEcharts.finishedHandler = undefined + } + }, [onChartReady, onEvents]) + + return
    + }) + + return { + __esModule: true, + default: MockReactEcharts, + } +}) + vi.mock('@/app/components/base/mermaid', () => ({ __esModule: true, default: ({ PrimitiveCode }: { PrimitiveCode: string }) =>
    {PrimitiveCode}
    , @@ -74,15 +121,17 @@ const findEchartsHost = async () => { const findEchartsInstance = async () => { const host = await findEchartsHost() await waitFor(() => { - expect(echartsCjs.getInstanceByDom(host)).toBeTruthy() + expect(echarts.getInstanceByDom(host)).toBeTruthy() }) - return echartsCjs.getInstanceByDom(host)! + return echarts.getInstanceByDom(host)! } describe('CodeBlock', () => { beforeEach(() => { vi.clearAllMocks() mockUseTheme.mockReturnValue({ theme: Theme.light }) + consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) + consoleWarnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) clientWidthSpy = vi.spyOn(HTMLElement.prototype, 'clientWidth', 'get').mockReturnValue(900) clientHeightSpy = vi.spyOn(HTMLElement.prototype, 'clientHeight', 'get').mockReturnValue(400) offsetWidthSpy = vi.spyOn(HTMLElement.prototype, 'offsetWidth', 'get').mockReturnValue(900) @@ -98,6 +147,10 @@ describe('CodeBlock', () => { afterEach(() => { vi.useRealTimers() + consoleErrorSpy?.mockRestore() + consoleWarnSpy?.mockRestore() + consoleErrorSpy = null + consoleWarnSpy = null clientWidthSpy?.mockRestore() clientHeightSpy?.mockRestore() offsetWidthSpy?.mockRestore() diff --git a/web/app/components/base/markdown-blocks/__tests__/think-block.spec.tsx b/web/app/components/base/markdown-blocks/__tests__/think-block.spec.tsx index e8b956cbbfd..4f224681575 100644 --- a/web/app/components/base/markdown-blocks/__tests__/think-block.spec.tsx +++ b/web/app/components/base/markdown-blocks/__tests__/think-block.spec.tsx @@ -163,25 +163,16 @@ describe('ThinkBlock', () => { expect(screen.getByText(/Thought/)).toBeInTheDocument() }) - it('should NOT stop timer when isResponding is undefined (outside ChatContextProvider)', () => { - // Render without ChatContextProvider + it('should stop timer when isResponding is undefined (historical conversation outside active response)', () => { + // Render without ChatContextProvider — simulates historical conversation render(

    Content without ENDTHINKFLAG

    , ) - // Initial state should show "Thinking..." - expect(screen.getByText(/Thinking\.\.\./)).toBeInTheDocument() - - // Advance timer - act(() => { - vi.advanceTimersByTime(2000) - }) - - // Timer should still be running (showing "Thinking..." not "Thought") - expect(screen.getByText(/Thinking\.\.\./)).toBeInTheDocument() - expect(screen.getByText(/\(2\.0s\)/)).toBeInTheDocument() + // Timer should be stopped immediately — isResponding undefined means not in active response + expect(screen.getByText(/Thought/)).toBeInTheDocument() }) }) diff --git a/web/app/components/base/markdown-blocks/code-block.tsx b/web/app/components/base/markdown-blocks/code-block.tsx index b36d8d77885..412c61d52d3 100644 --- a/web/app/components/base/markdown-blocks/code-block.tsx +++ b/web/app/components/base/markdown-blocks/code-block.tsx @@ -85,13 +85,30 @@ const CodeBlock: any = memo(({ inline, className, children = '', ...props }: any const processedRef = useRef(false) // Track if content was successfully processed const isInitialRenderRef = useRef(true) // Track if this is initial render const chartInstanceRef = useRef(null) // Direct reference to ECharts instance - const resizeTimerRef = useRef(null) // For debounce handling + const resizeTimerRef = useRef | null>(null) // For debounce handling + const chartReadyTimerRef = useRef | null>(null) const finishedEventCountRef = useRef(0) // Track finished event trigger count const match = /language-(\w+)/.exec(className || '') const language = match?.[1] const languageShowName = getCorrectCapitalizationLanguageName(language || '') const isDarkMode = theme === Theme.dark + const clearResizeTimer = useCallback(() => { + if (!resizeTimerRef.current) + return + + clearTimeout(resizeTimerRef.current) + resizeTimerRef.current = null + }, []) + + const clearChartReadyTimer = useCallback(() => { + if (!chartReadyTimerRef.current) + return + + clearTimeout(chartReadyTimerRef.current) + chartReadyTimerRef.current = null + }, []) + const echartsStyle = useMemo(() => ({ height: '350px', width: '100%', @@ -104,26 +121,27 @@ const CodeBlock: any = memo(({ inline, className, children = '', ...props }: any // Debounce resize operations const debouncedResize = useCallback(() => { - if (resizeTimerRef.current) - clearTimeout(resizeTimerRef.current) + clearResizeTimer() resizeTimerRef.current = setTimeout(() => { if (chartInstanceRef.current) chartInstanceRef.current.resize() resizeTimerRef.current = null }, 200) - }, []) + }, [clearResizeTimer]) // Handle ECharts instance initialization const handleChartReady = useCallback((instance: any) => { chartInstanceRef.current = instance // Force resize to ensure timeline displays correctly - setTimeout(() => { + clearChartReadyTimer() + chartReadyTimerRef.current = setTimeout(() => { if (chartInstanceRef.current) chartInstanceRef.current.resize() + chartReadyTimerRef.current = null }, 200) - }, []) + }, [clearChartReadyTimer]) // Store event handlers in useMemo to avoid recreating them const echartsEvents = useMemo(() => ({ @@ -157,10 +175,20 @@ const CodeBlock: any = memo(({ inline, className, children = '', ...props }: any return () => { window.removeEventListener('resize', handleResize) - if (resizeTimerRef.current) - clearTimeout(resizeTimerRef.current) + clearResizeTimer() + clearChartReadyTimer() + chartInstanceRef.current = null } - }, [language, debouncedResize]) + }, [language, debouncedResize, clearResizeTimer, clearChartReadyTimer]) + + useEffect(() => { + return () => { + clearResizeTimer() + clearChartReadyTimer() + chartInstanceRef.current = null + echartsRef.current = null + } + }, [clearResizeTimer, clearChartReadyTimer]) // Process chart data when content changes useEffect(() => { // Only process echarts content diff --git a/web/app/components/base/markdown-blocks/think-block.tsx b/web/app/components/base/markdown-blocks/think-block.tsx index f920218152e..184ed892745 100644 --- a/web/app/components/base/markdown-blocks/think-block.tsx +++ b/web/app/components/base/markdown-blocks/think-block.tsx @@ -39,9 +39,10 @@ const removeEndThink = (children: any): any => { const useThinkTimer = (children: any) => { const { isResponding } = useChatContext() + const endThinkDetected = hasEndThink(children) const [startTime] = useState(() => Date.now()) const [elapsedTime, setElapsedTime] = useState(0) - const [isComplete, setIsComplete] = useState(false) + const [isComplete, setIsComplete] = useState(() => endThinkDetected) const timerRef = useRef(null) useEffect(() => { @@ -61,11 +62,10 @@ const useThinkTimer = (children: any) => { useEffect(() => { // Stop timer when: // 1. Content has [ENDTHINKFLAG] marker (normal completion) - // 2. isResponding is explicitly false (user clicked stop button) - // Note: Don't stop when isResponding is undefined (component used outside ChatContextProvider) - if (hasEndThink(children) || isResponding === false) + // 2. isResponding is not true (false = user clicked stop, undefined = historical conversation) + if (endThinkDetected || !isResponding) setIsComplete(true) - }, [children, isResponding]) + }, [endThinkDetected, isResponding]) return { elapsedTime, isComplete } } diff --git a/web/app/components/base/node-status/__tests__/index.spec.tsx b/web/app/components/base/node-status/__tests__/index.spec.tsx index f74af4965e7..37b12946c8e 100644 --- a/web/app/components/base/node-status/__tests__/index.spec.tsx +++ b/web/app/components/base/node-status/__tests__/index.spec.tsx @@ -41,7 +41,7 @@ describe('NodeStatus', () => { it('applies styleCss correctly', () => { const { container } = render() - expect(container.firstChild).toHaveStyle({ color: 'rgb(255, 0, 0)' }) + expect(container.firstChild).toHaveStyle({ color: 'red' }) }) it('applies iconClassName to the icon', () => { diff --git a/web/app/components/base/pagination/__tests__/pagination.spec.tsx b/web/app/components/base/pagination/__tests__/pagination.spec.tsx index 776802ff19a..06eac9bfbd8 100644 --- a/web/app/components/base/pagination/__tests__/pagination.spec.tsx +++ b/web/app/components/base/pagination/__tests__/pagination.spec.tsx @@ -131,7 +131,7 @@ describe('Pagination', () => { setCurrentPage, children: Prev, }) - fireEvent.keyPress(screen.getByText(/prev/i), { key: 'Enter', charCode: 13 }) + fireEvent.keyDown(screen.getByText(/prev/i).closest('button')!, { key: 'Enter', code: 'Enter', keyCode: 13, which: 13 }) expect(setCurrentPage).toHaveBeenCalledWith(2) }) @@ -142,7 +142,7 @@ describe('Pagination', () => { setCurrentPage, children: Prev, }) - fireEvent.keyPress(screen.getByText(/prev/i), { key: 'Enter', charCode: 13 }) + fireEvent.keyDown(screen.getByText(/prev/i).closest('button')!, { key: 'Enter', code: 'Enter', keyCode: 13, which: 13 }) expect(setCurrentPage).not.toHaveBeenCalled() }) @@ -213,7 +213,7 @@ describe('Pagination', () => { setCurrentPage, children: Next, }) - fireEvent.keyPress(screen.getByText(/next/i), { key: 'Enter', charCode: 13 }) + fireEvent.keyDown(screen.getByText(/next/i).closest('button')!, { key: 'Enter', code: 'Enter', keyCode: 13, which: 13 }) expect(setCurrentPage).toHaveBeenCalledWith(1) }) @@ -225,7 +225,7 @@ describe('Pagination', () => { setCurrentPage, children: Next, }) - fireEvent.keyPress(screen.getByText(/next/i), { key: 'Enter', charCode: 13 }) + fireEvent.keyDown(screen.getByText(/next/i).closest('button')!, { key: 'Enter', code: 'Enter', keyCode: 13, which: 13 }) expect(setCurrentPage).not.toHaveBeenCalled() }) @@ -318,7 +318,7 @@ describe('Pagination', () => { /> ), }) - fireEvent.keyPress(screen.getByText('4'), { key: 'Enter', charCode: 13 }) + fireEvent.keyDown(screen.getByText('4').closest('a')!, { key: 'Enter', code: 'Enter', keyCode: 13, which: 13 }) expect(setCurrentPage).toHaveBeenCalledWith(3) // 0-indexed }) diff --git a/web/app/components/base/pagination/pagination.tsx b/web/app/components/base/pagination/pagination.tsx index 0eb06b594c7..b258090d80d 100644 --- a/web/app/components/base/pagination/pagination.tsx +++ b/web/app/components/base/pagination/pagination.tsx @@ -50,7 +50,7 @@ export const PrevButton = ({ tabIndex={disabled ? '-1' : 0} disabled={disabled} data-testid={dataTestId} - onKeyPress={(event: React.KeyboardEvent) => { + onKeyDown={(event: React.KeyboardEvent) => { event.preventDefault() if (event.key === 'Enter' && !disabled) previous() @@ -85,7 +85,7 @@ export const NextButton = ({ tabIndex={disabled ? '-1' : 0} disabled={disabled} data-testid={dataTestId} - onKeyPress={(event: React.KeyboardEvent) => { + onKeyDown={(event: React.KeyboardEvent) => { event.preventDefault() if (event.key === 'Enter' && !disabled) next() @@ -140,7 +140,7 @@ export const PageButton = ({ }) || undefined } tabIndex={0} - onKeyPress={(event: React.KeyboardEvent) => { + onKeyDown={(event: React.KeyboardEvent) => { if (event.key === 'Enter') pagination.setCurrentPage(page - 1) }} diff --git a/web/app/components/base/param-item/__tests__/index-slider.spec.tsx b/web/app/components/base/param-item/__tests__/index-slider.spec.tsx index 0048b896442..64488358449 100644 --- a/web/app/components/base/param-item/__tests__/index-slider.spec.tsx +++ b/web/app/components/base/param-item/__tests__/index-slider.spec.tsx @@ -14,12 +14,14 @@ describe('ParamItem Slider onChange', () => { vi.clearAllMocks() }) + const getSlider = () => screen.getByLabelText('Test Param') + it('should divide slider value by 100 when max < 5', async () => { const user = userEvent.setup() render() - const slider = screen.getByRole('slider') + const slider = getSlider() - await user.click(slider) + slider.focus() await user.keyboard('{ArrowRight}') // max=1 < 5, so slider value change (50->51) becomes 0.51 @@ -29,9 +31,9 @@ describe('ParamItem Slider onChange', () => { it('should not divide slider value when max >= 5', async () => { const user = userEvent.setup() render() - const slider = screen.getByRole('slider') + const slider = getSlider() - await user.click(slider) + slider.focus() await user.keyboard('{ArrowRight}') // max=10 >= 5, so value remains raw (5->6) diff --git a/web/app/components/base/param-item/__tests__/index.spec.tsx b/web/app/components/base/param-item/__tests__/index.spec.tsx index 96591446c80..889662c87db 100644 --- a/web/app/components/base/param-item/__tests__/index.spec.tsx +++ b/web/app/components/base/param-item/__tests__/index.spec.tsx @@ -17,6 +17,8 @@ describe('ParamItem', () => { vi.clearAllMocks() }) + const getSlider = () => screen.getByLabelText('Test Param') + describe('Rendering', () => { it('should render the parameter name', () => { render() @@ -54,7 +56,7 @@ describe('ParamItem', () => { render() expect(screen.getByRole('textbox')).toBeInTheDocument() - expect(screen.getByRole('slider')).toBeInTheDocument() + expect(getSlider()).toBeInTheDocument() }) }) @@ -74,7 +76,7 @@ describe('ParamItem', () => { it('should disable Slider when enable is false', () => { render() - expect(screen.getByRole('slider')).toHaveAttribute('aria-disabled', 'true') + expect(getSlider()).toBeDisabled() }) it('should set switch value based on enable prop', () => { @@ -135,7 +137,7 @@ describe('ParamItem', () => { await user.clear(input) expect(defaultProps.onChange).toHaveBeenLastCalledWith('test_param', 0) - expect(screen.getByRole('slider')).toHaveAttribute('aria-valuenow', '0') + expect(getSlider()).toHaveAttribute('aria-valuenow', '0') await user.tab() @@ -166,12 +168,12 @@ describe('ParamItem', () => { await user.type(input, '1.5') expect(defaultProps.onChange).toHaveBeenLastCalledWith('test_param', 1) - expect(screen.getByRole('slider')).toHaveAttribute('aria-valuenow', '100') + expect(getSlider()).toHaveAttribute('aria-valuenow', '100') }) it('should pass scaled value to slider when max < 5', () => { render() - const slider = screen.getByRole('slider') + const slider = getSlider() // When max < 5, slider value = value * 100 = 50 expect(slider).toHaveAttribute('aria-valuenow', '50') @@ -179,7 +181,7 @@ describe('ParamItem', () => { it('should pass raw value to slider when max >= 5', () => { render() - const slider = screen.getByRole('slider') + const slider = getSlider() // When max >= 5, slider value = value = 5 expect(slider).toHaveAttribute('aria-valuenow', '5') @@ -212,15 +214,15 @@ describe('ParamItem', () => { render() // Slider should get value * 100 = 50, min * 100 = 0, max * 100 = 100 - const slider = screen.getByRole('slider') - expect(slider).toHaveAttribute('aria-valuemax', '100') + const slider = getSlider() + expect(slider).toHaveAttribute('max', '100') }) it('should not scale slider value when max >= 5', () => { render() - const slider = screen.getByRole('slider') - expect(slider).toHaveAttribute('aria-valuemax', '10') + const slider = getSlider() + expect(slider).toHaveAttribute('max', '10') }) it('should expose default minimum of 0 when min is not provided', () => { diff --git a/web/app/components/base/param-item/__tests__/score-threshold-item.spec.tsx b/web/app/components/base/param-item/__tests__/score-threshold-item.spec.tsx index 54a13e1b749..ddc286942b4 100644 --- a/web/app/components/base/param-item/__tests__/score-threshold-item.spec.tsx +++ b/web/app/components/base/param-item/__tests__/score-threshold-item.spec.tsx @@ -14,6 +14,8 @@ describe('ScoreThresholdItem', () => { vi.clearAllMocks() }) + const getSlider = () => screen.getByLabelText('appDebug.datasetConfig.score_threshold') + describe('Rendering', () => { it('should render the translated parameter name', () => { render() @@ -32,7 +34,7 @@ describe('ScoreThresholdItem', () => { render() expect(screen.getByRole('textbox')).toBeInTheDocument() - expect(screen.getByRole('slider')).toBeInTheDocument() + expect(getSlider()).toBeInTheDocument() }) }) @@ -63,7 +65,7 @@ describe('ScoreThresholdItem', () => { render() expect(screen.getByRole('textbox')).toBeDisabled() - expect(screen.getByRole('slider')).toHaveAttribute('aria-disabled', 'true') + expect(getSlider()).toBeDisabled() }) }) diff --git a/web/app/components/base/param-item/__tests__/top-k-item.spec.tsx b/web/app/components/base/param-item/__tests__/top-k-item.spec.tsx index 1b8555213b9..c84fd505186 100644 --- a/web/app/components/base/param-item/__tests__/top-k-item.spec.tsx +++ b/web/app/components/base/param-item/__tests__/top-k-item.spec.tsx @@ -19,6 +19,8 @@ describe('TopKItem', () => { vi.clearAllMocks() }) + const getSlider = () => screen.getByLabelText('appDebug.datasetConfig.top_k') + describe('Rendering', () => { it('should render the translated parameter name', () => { render() @@ -37,7 +39,7 @@ describe('TopKItem', () => { render() expect(screen.getByRole('textbox')).toBeInTheDocument() - expect(screen.getByRole('slider')).toBeInTheDocument() + expect(getSlider()).toBeInTheDocument() }) }) @@ -52,7 +54,7 @@ describe('TopKItem', () => { render() expect(screen.getByRole('textbox')).toBeDisabled() - expect(screen.getByRole('slider')).toHaveAttribute('aria-disabled', 'true') + expect(getSlider()).toBeDisabled() }) }) @@ -77,10 +79,10 @@ describe('TopKItem', () => { it('should render slider with max >= 5 so no scaling is applied', () => { render() - const slider = screen.getByRole('slider') + const slider = getSlider() // max=10 >= 5 so slider shows raw values - expect(slider).toHaveAttribute('aria-valuemax', '10') + expect(slider).toHaveAttribute('max', '10') }) it('should not render a switch (no hasSwitch prop)', () => { @@ -116,9 +118,9 @@ describe('TopKItem', () => { it('should call onChange with integer value when slider changes', async () => { const user = userEvent.setup() render() - const slider = screen.getByRole('slider') + const slider = getSlider() - await user.click(slider) + slider.focus() await user.keyboard('{ArrowRight}') expect(defaultProps.onChange).toHaveBeenLastCalledWith('top_k', 3) diff --git a/web/app/components/base/param-item/index.tsx b/web/app/components/base/param-item/index.tsx index 63af4bca842..56999fc6ea6 100644 --- a/web/app/components/base/param-item/index.tsx +++ b/web/app/components/base/param-item/index.tsx @@ -1,8 +1,8 @@ 'use client' import type { FC } from 'react' -import Slider from '@/app/components/base/slider' import Switch from '@/app/components/base/switch' import Tooltip from '@/app/components/base/tooltip' +import { Slider } from '@/app/components/base/ui/slider' import { NumberField, NumberFieldControls, @@ -78,7 +78,8 @@ const ParamItem: FC = ({ className, id, name, noTooltip, tip, step = 0.1, value={max < 5 ? value * 100 : value} min={min < 1 ? min * 100 : min} max={max < 5 ? max * 100 : max} - onChange={value => onChange(id, value / (max < 5 ? 100 : 1))} + onValueChange={value => onChange(id, value / (max < 5 ? 100 : 1))} + aria-label={name} />
    diff --git a/web/app/components/base/premium-badge/__tests__/index.spec.tsx b/web/app/components/base/premium-badge/__tests__/index.spec.tsx index af8ace22f03..d107c07e52e 100644 --- a/web/app/components/base/premium-badge/__tests__/index.spec.tsx +++ b/web/app/components/base/premium-badge/__tests__/index.spec.tsx @@ -41,6 +41,6 @@ describe('PremiumBadge', () => { ) const badge = screen.getByText('Premium') expect(badge).toBeInTheDocument() - expect(badge).toHaveStyle('background-color: rgb(255, 0, 0)') // Note: React converts 'red' to 'rgb(255, 0, 0)' + expect(badge).toHaveStyle('background-color: red') }) }) diff --git a/web/app/components/base/prompt-editor/plugins/__tests__/on-blur-or-focus-block.spec.tsx b/web/app/components/base/prompt-editor/plugins/__tests__/on-blur-or-focus-block.spec.tsx index dd2f74f7e50..a16ae9d8233 100644 --- a/web/app/components/base/prompt-editor/plugins/__tests__/on-blur-or-focus-block.spec.tsx +++ b/web/app/components/base/prompt-editor/plugins/__tests__/on-blur-or-focus-block.spec.tsx @@ -3,13 +3,10 @@ import { LexicalComposer } from '@lexical/react/LexicalComposer' import { act, render, waitFor } from '@testing-library/react' import { BLUR_COMMAND, - COMMAND_PRIORITY_EDITOR, FOCUS_COMMAND, - KEY_ESCAPE_COMMAND, } from 'lexical' import OnBlurBlock from '../on-blur-or-focus-block' import { CaptureEditorPlugin } from '../test-utils' -import { CLEAR_HIDE_MENU_TIMEOUT } from '../workflow-variable-block' const renderOnBlurBlock = (props?: { onBlur?: () => void @@ -75,7 +72,7 @@ describe('OnBlurBlock', () => { expect(onFocus).toHaveBeenCalledTimes(1) }) - it('should call onBlur and dispatch escape after delay when blur target is not var-search-input', async () => { + it('should call onBlur when blur target is not var-search-input', async () => { const onBlur = vi.fn() const { getEditor } = renderOnBlurBlock({ onBlur }) @@ -85,14 +82,6 @@ describe('OnBlurBlock', () => { const editor = getEditor() expect(editor).not.toBeNull() - vi.useFakeTimers() - - const onEscape = vi.fn(() => true) - const unregister = editor!.registerCommand( - KEY_ESCAPE_COMMAND, - onEscape, - COMMAND_PRIORITY_EDITOR, - ) let handled = false act(() => { @@ -101,18 +90,9 @@ describe('OnBlurBlock', () => { expect(handled).toBe(true) expect(onBlur).toHaveBeenCalledTimes(1) - expect(onEscape).not.toHaveBeenCalled() - - act(() => { - vi.advanceTimersByTime(200) - }) - - expect(onEscape).toHaveBeenCalledTimes(1) - unregister() - vi.useRealTimers() }) - it('should dispatch delayed escape when onBlur callback is not provided', async () => { + it('should handle blur when onBlur callback is not provided', async () => { const { getEditor } = renderOnBlurBlock() await waitFor(() => { @@ -121,28 +101,16 @@ describe('OnBlurBlock', () => { const editor = getEditor() expect(editor).not.toBeNull() - vi.useFakeTimers() - - const onEscape = vi.fn(() => true) - const unregister = editor!.registerCommand( - KEY_ESCAPE_COMMAND, - onEscape, - COMMAND_PRIORITY_EDITOR, - ) + let handled = false act(() => { - editor!.dispatchCommand(BLUR_COMMAND, createBlurEvent(document.createElement('div'))) - }) - act(() => { - vi.advanceTimersByTime(200) + handled = editor!.dispatchCommand(BLUR_COMMAND, createBlurEvent(document.createElement('div'))) }) - expect(onEscape).toHaveBeenCalledTimes(1) - unregister() - vi.useRealTimers() + expect(handled).toBe(true) }) - it('should skip onBlur and delayed escape when blur target is var-search-input', async () => { + it('should skip onBlur when blur target is var-search-input', async () => { const onBlur = vi.fn() const { getEditor } = renderOnBlurBlock({ onBlur }) @@ -152,31 +120,17 @@ describe('OnBlurBlock', () => { const editor = getEditor() expect(editor).not.toBeNull() - vi.useFakeTimers() const target = document.createElement('input') target.classList.add('var-search-input') - const onEscape = vi.fn(() => true) - const unregister = editor!.registerCommand( - KEY_ESCAPE_COMMAND, - onEscape, - COMMAND_PRIORITY_EDITOR, - ) - let handled = false act(() => { handled = editor!.dispatchCommand(BLUR_COMMAND, createBlurEvent(target)) }) - act(() => { - vi.advanceTimersByTime(200) - }) expect(handled).toBe(true) expect(onBlur).not.toHaveBeenCalled() - expect(onEscape).not.toHaveBeenCalled() - unregister() - vi.useRealTimers() }) it('should handle focus command when onFocus callback is not provided', async () => { @@ -198,59 +152,6 @@ describe('OnBlurBlock', () => { }) }) - describe('Clear timeout command', () => { - it('should clear scheduled escape timeout when clear command is dispatched', async () => { - const { getEditor } = renderOnBlurBlock({ onBlur: vi.fn() }) - - await waitFor(() => { - expect(getEditor()).not.toBeNull() - }) - - const editor = getEditor() - expect(editor).not.toBeNull() - vi.useFakeTimers() - - const onEscape = vi.fn(() => true) - const unregister = editor!.registerCommand( - KEY_ESCAPE_COMMAND, - onEscape, - COMMAND_PRIORITY_EDITOR, - ) - - act(() => { - editor!.dispatchCommand(BLUR_COMMAND, createBlurEvent(document.createElement('div'))) - }) - act(() => { - editor!.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined) - }) - act(() => { - vi.advanceTimersByTime(200) - }) - - expect(onEscape).not.toHaveBeenCalled() - unregister() - vi.useRealTimers() - }) - - it('should handle clear command when no timeout is scheduled', async () => { - const { getEditor } = renderOnBlurBlock() - - await waitFor(() => { - expect(getEditor()).not.toBeNull() - }) - - const editor = getEditor() - expect(editor).not.toBeNull() - - let handled = false - act(() => { - handled = editor!.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined) - }) - - expect(handled).toBe(true) - }) - }) - describe('Lifecycle cleanup', () => { it('should unregister commands when component unmounts', async () => { const { getEditor, unmount } = renderOnBlurBlock() @@ -266,16 +167,13 @@ describe('OnBlurBlock', () => { let blurHandled = true let focusHandled = true - let clearHandled = true act(() => { blurHandled = editor!.dispatchCommand(BLUR_COMMAND, createBlurEvent(document.createElement('div'))) focusHandled = editor!.dispatchCommand(FOCUS_COMMAND, createFocusEvent()) - clearHandled = editor!.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined) }) expect(blurHandled).toBe(false) expect(focusHandled).toBe(false) - expect(clearHandled).toBe(false) }) }) }) diff --git a/web/app/components/base/prompt-editor/plugins/__tests__/update-block.spec.tsx b/web/app/components/base/prompt-editor/plugins/__tests__/update-block.spec.tsx index 8f6a72a7def..4283910c318 100644 --- a/web/app/components/base/prompt-editor/plugins/__tests__/update-block.spec.tsx +++ b/web/app/components/base/prompt-editor/plugins/__tests__/update-block.spec.tsx @@ -1,14 +1,13 @@ import type { LexicalEditor } from 'lexical' import { LexicalComposer } from '@lexical/react/LexicalComposer' import { act, render, waitFor } from '@testing-library/react' -import { $getRoot, COMMAND_PRIORITY_EDITOR } from 'lexical' +import { $getRoot } from 'lexical' import { CustomTextNode } from '../custom-text/node' import { CaptureEditorPlugin } from '../test-utils' import UpdateBlock, { PROMPT_EDITOR_INSERT_QUICKLY, PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER, } from '../update-block' -import { CLEAR_HIDE_MENU_TIMEOUT } from '../workflow-variable-block' const { mockUseEventEmitterContextContext } = vi.hoisted(() => ({ mockUseEventEmitterContextContext: vi.fn(), @@ -157,7 +156,7 @@ describe('UpdateBlock', () => { }) describe('Quick insert event', () => { - it('should insert slash and dispatch clear command when quick insert event matches instance id', async () => { + it('should insert slash when quick insert event matches instance id', async () => { const { emit, getEditor } = setup({ instanceId: 'instance-1' }) await waitFor(() => { @@ -168,13 +167,6 @@ describe('UpdateBlock', () => { selectRootEnd(editor!) - const clearCommandHandler = vi.fn(() => true) - const unregister = editor!.registerCommand( - CLEAR_HIDE_MENU_TIMEOUT, - clearCommandHandler, - COMMAND_PRIORITY_EDITOR, - ) - emit({ type: PROMPT_EDITOR_INSERT_QUICKLY, instanceId: 'instance-1', @@ -183,9 +175,6 @@ describe('UpdateBlock', () => { await waitFor(() => { expect(readEditorText(editor!)).toBe('/') }) - expect(clearCommandHandler).toHaveBeenCalledTimes(1) - - unregister() }) it('should ignore quick insert event when instance id does not match', async () => { diff --git a/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx b/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx index 6cc6c3a67fe..51b14b76c82 100644 --- a/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx +++ b/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx @@ -23,6 +23,8 @@ import { $createTextNode, $getRoot, $setSelection, + BLUR_COMMAND, + FOCUS_COMMAND, KEY_ESCAPE_COMMAND, } from 'lexical' import * as React from 'react' @@ -631,4 +633,180 @@ describe('ComponentPicker (component-picker-block/index.tsx)', () => { // With a single option group, the only divider should be the workflow-var/options separator. expect(document.querySelectorAll('.bg-divider-subtle')).toHaveLength(1) }) + + describe('blur/focus menu visibility', () => { + it('hides the menu after a 200ms delay when blur command is dispatched', async () => { + const captures: Captures = { editor: null, eventEmitter: null } + + render(( + + )) + + const editor = await waitForEditor(captures) + await setEditorText(editor, '{', true) + expect(await screen.findByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + vi.useFakeTimers() + + act(() => { + editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: document.createElement('button') })) + }) + + expect(screen.queryByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + act(() => { + vi.advanceTimersByTime(200) + }) + + expect(screen.queryByText('common.promptEditor.context.item.title')).not.toBeInTheDocument() + + vi.useRealTimers() + }) + + it('restores menu visibility when focus command is dispatched after blur hides it', async () => { + const captures: Captures = { editor: null, eventEmitter: null } + + render(( + + )) + + const editor = await waitForEditor(captures) + await setEditorText(editor, '{', true) + expect(await screen.findByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + vi.useFakeTimers() + + act(() => { + editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: document.createElement('button') })) + }) + act(() => { + vi.advanceTimersByTime(200) + }) + + expect(screen.queryByText('common.promptEditor.context.item.title')).not.toBeInTheDocument() + + act(() => { + editor.dispatchCommand(FOCUS_COMMAND, new FocusEvent('focus')) + }) + + vi.useRealTimers() + + await setEditorText(editor, '{', true) + await waitFor(() => { + expect(screen.queryByText('common.promptEditor.context.item.title')).toBeInTheDocument() + }) + }) + + it('cancels the blur timer when focus arrives before the 200ms timeout', async () => { + const captures: Captures = { editor: null, eventEmitter: null } + + render(( + + )) + + const editor = await waitForEditor(captures) + await setEditorText(editor, '{', true) + expect(await screen.findByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + vi.useFakeTimers() + + act(() => { + editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: document.createElement('button') })) + }) + + act(() => { + editor.dispatchCommand(FOCUS_COMMAND, new FocusEvent('focus')) + }) + + act(() => { + vi.advanceTimersByTime(200) + }) + + expect(screen.queryByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + vi.useRealTimers() + }) + + it('cancels a pending blur timer when a subsequent blur targets var-search-input', async () => { + const captures: Captures = { editor: null, eventEmitter: null } + + render(( + + )) + + const editor = await waitForEditor(captures) + await setEditorText(editor, '{', true) + expect(await screen.findByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + vi.useFakeTimers() + + act(() => { + editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: document.createElement('button') })) + }) + + const varInput = document.createElement('input') + varInput.classList.add('var-search-input') + + act(() => { + editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: varInput })) + }) + + act(() => { + vi.advanceTimersByTime(200) + }) + + expect(screen.queryByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + vi.useRealTimers() + }) + + it('does not hide the menu when blur target is var-search-input', async () => { + const captures: Captures = { editor: null, eventEmitter: null } + + render(( + + )) + + const editor = await waitForEditor(captures) + await setEditorText(editor, '{', true) + expect(await screen.findByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + vi.useFakeTimers() + + const target = document.createElement('input') + target.classList.add('var-search-input') + + act(() => { + editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: target })) + }) + + act(() => { + vi.advanceTimersByTime(200) + }) + + expect(screen.queryByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + vi.useRealTimers() + }) + }) }) diff --git a/web/app/components/base/prompt-editor/plugins/component-picker-block/index.tsx b/web/app/components/base/prompt-editor/plugins/component-picker-block/index.tsx index 8001a2755b6..bebc1b59af7 100644 --- a/web/app/components/base/prompt-editor/plugins/component-picker-block/index.tsx +++ b/web/app/components/base/prompt-editor/plugins/component-picker-block/index.tsx @@ -21,11 +21,19 @@ import { } from '@floating-ui/react' import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext' import { LexicalTypeaheadMenuPlugin } from '@lexical/react/LexicalTypeaheadMenuPlugin' -import { KEY_ESCAPE_COMMAND } from 'lexical' +import { mergeRegister } from '@lexical/utils' +import { + BLUR_COMMAND, + COMMAND_PRIORITY_EDITOR, + FOCUS_COMMAND, + KEY_ESCAPE_COMMAND, +} from 'lexical' import { Fragment, memo, useCallback, + useEffect, + useRef, useState, } from 'react' import ReactDOM from 'react-dom' @@ -87,6 +95,46 @@ const ComponentPicker = ({ }) const [queryString, setQueryString] = useState(null) + const [blurHidden, setBlurHidden] = useState(false) + const blurTimerRef = useRef | null>(null) + + const clearBlurTimer = useCallback(() => { + if (blurTimerRef.current) { + clearTimeout(blurTimerRef.current) + blurTimerRef.current = null + } + }, []) + + useEffect(() => { + const unregister = mergeRegister( + editor.registerCommand( + BLUR_COMMAND, + (event) => { + clearBlurTimer() + const target = event?.relatedTarget as HTMLElement + if (!target?.classList?.contains('var-search-input')) + blurTimerRef.current = setTimeout(() => setBlurHidden(true), 200) + return false + }, + COMMAND_PRIORITY_EDITOR, + ), + editor.registerCommand( + FOCUS_COMMAND, + () => { + clearBlurTimer() + setBlurHidden(false) + return false + }, + COMMAND_PRIORITY_EDITOR, + ), + ) + + return () => { + if (blurTimerRef.current) + clearTimeout(blurTimerRef.current) + unregister() + } + }, [editor, clearBlurTimer]) eventEmitter?.useSubscription((v: any) => { if (v.type === INSERT_VARIABLE_VALUE_BLOCK_COMMAND) @@ -159,6 +207,8 @@ const ComponentPicker = ({ anchorElementRef, { options, selectedIndex, selectOptionAndCleanUp, setHighlightedIndex }, ) => { + if (blurHidden) + return null if (!(anchorElementRef.current && (allFlattenOptions.length || workflowVariableBlock?.show))) return null @@ -240,7 +290,7 @@ const ComponentPicker = ({ } ) - }, [allFlattenOptions.length, workflowVariableBlock?.show, floatingStyles, isPositioned, refs, workflowVariableOptions, isSupportFileVar, handleClose, currentBlock?.generatorType, handleSelectWorkflowVariable, queryString, workflowVariableBlock?.showManageInputField, workflowVariableBlock?.onManageInputField]) + }, [blurHidden, allFlattenOptions.length, workflowVariableBlock?.show, floatingStyles, isPositioned, refs, workflowVariableOptions, isSupportFileVar, handleClose, currentBlock?.generatorType, handleSelectWorkflowVariable, queryString, workflowVariableBlock?.showManageInputField, workflowVariableBlock?.onManageInputField]) return ( void @@ -20,35 +18,13 @@ const OnBlurBlock: FC = ({ }) => { const [editor] = useLexicalComposerContext() - const ref = useRef | null>(null) - useEffect(() => { - const clearHideMenuTimeout = () => { - if (ref.current) { - clearTimeout(ref.current) - ref.current = null - } - } - - const unregister = mergeRegister( - editor.registerCommand( - CLEAR_HIDE_MENU_TIMEOUT, - () => { - clearHideMenuTimeout() - return true - }, - COMMAND_PRIORITY_EDITOR, - ), + return mergeRegister( editor.registerCommand( BLUR_COMMAND, (event) => { - // Check if the clicked target element is var-search-input const target = event?.relatedTarget as HTMLElement if (!target?.classList?.contains('var-search-input')) { - clearHideMenuTimeout() - ref.current = setTimeout(() => { - editor.dispatchCommand(KEY_ESCAPE_COMMAND, new KeyboardEvent('keydown', { key: 'Escape' })) - }, 200) if (onBlur) onBlur() } @@ -66,11 +42,6 @@ const OnBlurBlock: FC = ({ COMMAND_PRIORITY_EDITOR, ), ) - - return () => { - clearHideMenuTimeout() - unregister() - } }, [editor, onBlur, onFocus]) return null diff --git a/web/app/components/base/prompt-editor/plugins/shortcuts-popup-plugin/index.tsx b/web/app/components/base/prompt-editor/plugins/shortcuts-popup-plugin/index.tsx index abe6ea9a452..7dcda803f23 100644 --- a/web/app/components/base/prompt-editor/plugins/shortcuts-popup-plugin/index.tsx +++ b/web/app/components/base/prompt-editor/plugins/shortcuts-popup-plugin/index.tsx @@ -141,7 +141,7 @@ export default function ShortcutsPopupPlugin({ const portalRef = useRef(null) const lastSelectionRef = useRef(null) - /* v8 ignore next -- defensive non-browser fallback; this client-only plugin runs where document exists (browser/jsdom). @preserve */ + /* v8 ignore next -- defensive non-browser fallback; this client-only plugin runs where document exists (browser/test DOM runtime). @preserve */ const containerEl = useMemo(() => container ?? (typeof document !== 'undefined' ? document.body : null), [container]) const useContainer = !!containerEl && containerEl !== document.body @@ -210,7 +210,7 @@ export default function ShortcutsPopupPlugin({ if (rect.width === 0 && rect.height === 0) { const root = editor.getRootElement() - /* v8 ignore next 10 -- zero-size rect recovery depends on browser layout/selection geometry; deterministic reproduction in jsdom is unreliable. @preserve */ + /* v8 ignore next 10 -- zero-size rect recovery depends on browser layout/selection geometry; deterministic reproduction in the test DOM runtime is unreliable. @preserve */ if (root) { const sc = range.startContainer const node = sc.nodeType === Node.ELEMENT_NODE diff --git a/web/app/components/base/prompt-editor/plugins/update-block.tsx b/web/app/components/base/prompt-editor/plugins/update-block.tsx index bf89a259af5..2d83573b1fe 100644 --- a/web/app/components/base/prompt-editor/plugins/update-block.tsx +++ b/web/app/components/base/prompt-editor/plugins/update-block.tsx @@ -3,7 +3,6 @@ import { $insertNodes } from 'lexical' import { useEventEmitterContextContext } from '@/context/event-emitter' import { textToEditorState } from '../utils' import { CustomTextNode } from './custom-text/node' -import { CLEAR_HIDE_MENU_TIMEOUT } from './workflow-variable-block' export const PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER = 'PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER' export const PROMPT_EDITOR_INSERT_QUICKLY = 'PROMPT_EDITOR_INSERT_QUICKLY' @@ -30,8 +29,6 @@ const UpdateBlock = ({ editor.update(() => { const textNode = new CustomTextNode('/') $insertNodes([textNode]) - - editor.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined) }) } }) diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/index.spec.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/index.spec.tsx index ca4973b8300..1591dc44f9f 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/index.spec.tsx +++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/index.spec.tsx @@ -9,7 +9,6 @@ import { $insertNodes, COMMAND_PRIORITY_EDITOR } from 'lexical' import { Type } from '@/app/components/workflow/nodes/llm/types' import { BlockEnum } from '@/app/components/workflow/types' import { - CLEAR_HIDE_MENU_TIMEOUT, DELETE_WORKFLOW_VARIABLE_BLOCK_COMMAND, INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND, UPDATE_WORKFLOW_NODES_MAP, @@ -134,7 +133,6 @@ describe('WorkflowVariableBlock', () => { const insertHandler = mockRegisterCommand.mock.calls[0][1] as (variables: string[]) => boolean const result = insertHandler(['node-1', 'answer']) - expect(mockDispatchCommand).toHaveBeenCalledWith(CLEAR_HIDE_MENU_TIMEOUT, undefined) expect($createWorkflowVariableBlockNode).toHaveBeenCalledWith( ['node-1', 'answer'], workflowNodesMap, diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx index 76b27958031..c8cac64d19d 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx +++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx @@ -18,7 +18,6 @@ import { export const INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND = createCommand('INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND') export const DELETE_WORKFLOW_VARIABLE_BLOCK_COMMAND = createCommand('DELETE_WORKFLOW_VARIABLE_BLOCK_COMMAND') -export const CLEAR_HIDE_MENU_TIMEOUT = createCommand('CLEAR_HIDE_MENU_TIMEOUT') export const UPDATE_WORKFLOW_NODES_MAP = createCommand('UPDATE_WORKFLOW_NODES_MAP') export type WorkflowVariableBlockProps = { @@ -49,7 +48,6 @@ const WorkflowVariableBlock = memo(({ editor.registerCommand( INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND, (variables: string[]) => { - editor.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined) const workflowVariableBlockNode = $createWorkflowVariableBlockNode(variables, workflowNodesMap, getVarType) $insertNodes([workflowVariableBlockNode]) diff --git a/web/app/components/base/slider/__tests__/index.spec.tsx b/web/app/components/base/slider/__tests__/index.spec.tsx deleted file mode 100644 index bb1f030689a..00000000000 --- a/web/app/components/base/slider/__tests__/index.spec.tsx +++ /dev/null @@ -1,77 +0,0 @@ -import { act, render, screen } from '@testing-library/react' -import userEvent from '@testing-library/user-event' -import { describe, expect, it, vi } from 'vitest' -import Slider from '../index' - -describe('Slider Component', () => { - it('should render with correct default ARIA limits and current value', () => { - render() - - const slider = screen.getByRole('slider') - expect(slider).toHaveAttribute('aria-valuemin', '0') - expect(slider).toHaveAttribute('aria-valuemax', '100') - expect(slider).toHaveAttribute('aria-valuenow', '50') - }) - - it('should apply custom min, max, and step values', () => { - render() - - const slider = screen.getByRole('slider') - expect(slider).toHaveAttribute('aria-valuemin', '5') - expect(slider).toHaveAttribute('aria-valuemax', '20') - expect(slider).toHaveAttribute('aria-valuenow', '10') - }) - - it('should default to 0 if the value prop is NaN', () => { - render() - - const slider = screen.getByRole('slider') - expect(slider).toHaveAttribute('aria-valuenow', '0') - }) - - it('should call onChange when arrow keys are pressed', async () => { - const user = userEvent.setup() - const onChange = vi.fn() - - render() - - const slider = screen.getByRole('slider') - - await act(async () => { - slider.focus() - await user.keyboard('{ArrowRight}') - }) - - expect(onChange).toHaveBeenCalledTimes(1) - expect(onChange).toHaveBeenCalledWith(21, 0) - }) - - it('should not trigger onChange when disabled', async () => { - const user = userEvent.setup() - const onChange = vi.fn() - render() - - const slider = screen.getByRole('slider') - - expect(slider).toHaveAttribute('aria-disabled', 'true') - - await act(async () => { - slider.focus() - await user.keyboard('{ArrowRight}') - }) - - expect(onChange).not.toHaveBeenCalled() - }) - - it('should apply custom class names', () => { - render( - , - ) - - const sliderWrapper = screen.getByRole('slider').closest('.outer-test') - expect(sliderWrapper).toBeInTheDocument() - - const thumb = screen.getByRole('slider') - expect(thumb).toHaveClass('thumb-test') - }) -}) diff --git a/web/app/components/base/slider/index.stories.tsx b/web/app/components/base/slider/index.stories.tsx deleted file mode 100644 index bde937ffadb..00000000000 --- a/web/app/components/base/slider/index.stories.tsx +++ /dev/null @@ -1,635 +0,0 @@ -import type { Meta, StoryObj } from '@storybook/nextjs-vite' -import { useState } from 'react' -import Slider from '.' - -const meta = { - title: 'Base/Data Entry/Slider', - component: Slider, - parameters: { - layout: 'centered', - docs: { - description: { - component: 'Slider component for selecting a numeric value within a range. Built on react-slider with customizable min/max/step values.', - }, - }, - }, - tags: ['autodocs'], - argTypes: { - value: { - control: 'number', - description: 'Current slider value', - }, - min: { - control: 'number', - description: 'Minimum value (default: 0)', - }, - max: { - control: 'number', - description: 'Maximum value (default: 100)', - }, - step: { - control: 'number', - description: 'Step increment (default: 1)', - }, - disabled: { - control: 'boolean', - description: 'Disabled state', - }, - }, - args: { - onChange: (value) => { - console.log('Slider value:', value) - }, - }, -} satisfies Meta - -export default meta -type Story = StoryObj - -// Interactive demo wrapper -const SliderDemo = (args: any) => { - const [value, setValue] = useState(args.value || 50) - - return ( -
    - { - setValue(v) - console.log('Slider value:', v) - }} - /> -
    - Value: - {' '} - {value} -
    -
    - ) -} - -// Default state -export const Default: Story = { - render: args => , - args: { - value: 50, - min: 0, - max: 100, - step: 1, - disabled: false, - }, -} - -// With custom range -export const CustomRange: Story = { - render: args => , - args: { - value: 25, - min: 0, - max: 50, - step: 1, - disabled: false, - }, -} - -// With step increment -export const WithStepIncrement: Story = { - render: args => , - args: { - value: 50, - min: 0, - max: 100, - step: 10, - disabled: false, - }, -} - -// Decimal values -export const DecimalValues: Story = { - render: args => , - args: { - value: 2.5, - min: 0, - max: 5, - step: 0.5, - disabled: false, - }, -} - -// Disabled state -export const Disabled: Story = { - render: args => , - args: { - value: 75, - min: 0, - max: 100, - step: 1, - disabled: true, - }, -} - -// Real-world example - Volume control -const VolumeControlDemo = () => { - const [volume, setVolume] = useState(70) - - const getVolumeIcon = (vol: number) => { - if (vol === 0) - return '🔇' - if (vol < 33) - return '🔈' - if (vol < 66) - return '🔉' - return '🔊' - } - - return ( -
    -
    -

    Volume Control

    - {getVolumeIcon(volume)} -
    - -
    - Mute - - {volume} - % - - Max -
    -
    - ) -} - -export const VolumeControl: Story = { - render: () => , - parameters: { controls: { disable: true } }, -} as unknown as Story - -// Real-world example - Brightness control -const BrightnessControlDemo = () => { - const [brightness, setBrightness] = useState(80) - - return ( -
    -
    -

    Screen Brightness

    - ☀️ -
    - -
    -
    - Preview at - {' '} - {brightness} - % brightness -
    -
    -
    - ) -} - -export const BrightnessControl: Story = { - render: () => , - parameters: { controls: { disable: true } }, -} as unknown as Story - -// Real-world example - Price range filter -const PriceRangeFilterDemo = () => { - const [maxPrice, setMaxPrice] = useState(500) - const minPrice = 0 - - const products = [ - { name: 'Product A', price: 150 }, - { name: 'Product B', price: 350 }, - { name: 'Product C', price: 600 }, - { name: 'Product D', price: 250 }, - { name: 'Product E', price: 450 }, - ] - - const filteredProducts = products.filter(p => p.price >= minPrice && p.price <= maxPrice) - - return ( -
    -

    Filter by Price

    -
    -
    - Maximum Price - - $ - {maxPrice} - -
    - -
    -
    -
    - Showing - {' '} - {filteredProducts.length} - {' '} - of - {' '} - {products.length} - {' '} - products -
    -
    - {filteredProducts.map(product => ( -
    - {product.name} - - $ - {product.price} - -
    - ))} -
    -
    -
    - ) -} - -export const PriceRangeFilter: Story = { - render: () => , - parameters: { controls: { disable: true } }, -} as unknown as Story - -// Real-world example - Temperature selector -const TemperatureSelectorDemo = () => { - const [temperature, setTemperature] = useState(22) - const fahrenheit = ((temperature * 9) / 5 + 32).toFixed(1) - - return ( -
    -

    Thermostat Control

    -
    - -
    -
    -
    -
    Celsius
    -
    - {temperature} - °C -
    -
    -
    -
    Fahrenheit
    -
    - {fahrenheit} - °F -
    -
    -
    -
    - {temperature < 18 && '🥶 Too cold'} - {temperature >= 18 && temperature <= 24 && '😊 Comfortable'} - {temperature > 24 && '🥵 Too warm'} -
    -
    - ) -} - -export const TemperatureSelector: Story = { - render: () => , - parameters: { controls: { disable: true } }, -} as unknown as Story - -// Real-world example - Progress/completion slider -const ProgressSliderDemo = () => { - const [progress, setProgress] = useState(65) - - return ( -
    -

    Project Completion

    - -
    -
    - Progress - - {progress} - % - -
    -
    -
    - = 25 ? '✅' : '⏳'}>Planning - 25% -
    -
    - = 50 ? '✅' : '⏳'}>Development - 50% -
    -
    - = 75 ? '✅' : '⏳'}>Testing - 75% -
    -
    - = 100 ? '✅' : '⏳'}>Deployment - 100% -
    -
    -
    -
    - ) -} - -export const ProgressSlider: Story = { - render: () => , - parameters: { controls: { disable: true } }, -} as unknown as Story - -// Real-world example - Zoom control -const ZoomControlDemo = () => { - const [zoom, setZoom] = useState(100) - - return ( -
    -

    Zoom Level

    -
    - -
    - -
    - -
    -
    - 50% - - {zoom} - % - - 200% -
    -
    -
    Preview content
    -
    -
    - ) -} - -export const ZoomControl: Story = { - render: () => , - parameters: { controls: { disable: true } }, -} as unknown as Story - -// Real-world example - AI model parameters -const AIModelParametersDemo = () => { - const [temperature, setTemperature] = useState(0.7) - const [maxTokens, setMaxTokens] = useState(2000) - const [topP, setTopP] = useState(0.9) - - return ( -
    -

    Model Configuration

    -
    -
    -
    - - {temperature} -
    - -

    - Controls randomness. Lower is more focused, higher is more creative. -

    -
    - -
    -
    - - {maxTokens} -
    - -

    - Maximum length of generated response. -

    -
    - -
    -
    - - {topP} -
    - -

    - Nucleus sampling threshold. -

    -
    -
    -
    -
    - Temperature: - {' '} - {temperature} -
    -
    - Max Tokens: - {' '} - {maxTokens} -
    -
    - Top P: - {' '} - {topP} -
    -
    -
    - ) -} - -export const AIModelParameters: Story = { - render: () => , - parameters: { controls: { disable: true } }, -} as unknown as Story - -// Real-world example - Image quality selector -const ImageQualitySelectorDemo = () => { - const [quality, setQuality] = useState(80) - - const getQualityLabel = (q: number) => { - if (q < 50) - return 'Low' - if (q < 70) - return 'Medium' - if (q < 90) - return 'High' - return 'Maximum' - } - - const estimatedSize = Math.round((quality / 100) * 5) - - return ( -
    -

    Image Export Quality

    - -
    -
    -
    Quality
    -
    {getQualityLabel(quality)}
    -
    - {quality} - % -
    -
    -
    -
    File Size
    -
    - ~ - {estimatedSize} - {' '} - MB -
    -
    Estimated
    -
    -
    -
    - ) -} - -export const ImageQualitySelector: Story = { - render: () => , - parameters: { controls: { disable: true } }, -} as unknown as Story - -// Multiple sliders -const MultipleSlidersDemo = () => { - const [red, setRed] = useState(128) - const [green, setGreen] = useState(128) - const [blue, setBlue] = useState(128) - - const rgbColor = `rgb(${red}, ${green}, ${blue})` - - return ( -
    -

    RGB Color Picker

    -
    -
    -
    - - {red} -
    - -
    -
    -
    - - {green} -
    - -
    -
    -
    - - {blue} -
    - -
    -
    -
    -
    -
    -
    Color Value
    -
    {rgbColor}
    -
    - # - {red.toString(16).padStart(2, '0')} - {green.toString(16).padStart(2, '0')} - {blue.toString(16).padStart(2, '0')} -
    -
    -
    -
    - ) -} - -export const MultipleSliders: Story = { - render: () => , - parameters: { controls: { disable: true } }, -} as unknown as Story - -// Interactive playground -export const Playground: Story = { - render: args => , - args: { - value: 50, - min: 0, - max: 100, - step: 1, - disabled: false, - }, -} diff --git a/web/app/components/base/slider/index.tsx b/web/app/components/base/slider/index.tsx deleted file mode 100644 index 4e4656f590c..00000000000 --- a/web/app/components/base/slider/index.tsx +++ /dev/null @@ -1,43 +0,0 @@ -import ReactSlider from 'react-slider' -import { cn } from '@/utils/classnames' -import './style.css' - -type ISliderProps = { - className?: string - thumbClassName?: string - trackClassName?: string - value: number - max?: number - min?: number - step?: number - disabled?: boolean - onChange: (value: number) => void -} - -const Slider: React.FC = ({ - className, - thumbClassName, - trackClassName, - max, - min, - step, - value, - disabled, - onChange, -}) => { - return ( - - ) -} - -export default Slider diff --git a/web/app/components/base/slider/style.css b/web/app/components/base/slider/style.css deleted file mode 100644 index 5d87fb0897a..00000000000 --- a/web/app/components/base/slider/style.css +++ /dev/null @@ -1,11 +0,0 @@ -.slider.disabled { - opacity: 0.6; -} - -.slider-track { - background-color: var(--color-components-slider-range); -} - -.slider-track-1 { - background-color: var(--color-components-slider-track); -} diff --git a/web/app/components/base/tag-management/__tests__/filter.spec.tsx b/web/app/components/base/tag-management/__tests__/filter.spec.tsx index 3cffac29b26..a455d1a7912 100644 --- a/web/app/components/base/tag-management/__tests__/filter.spec.tsx +++ b/web/app/components/base/tag-management/__tests__/filter.spec.tsx @@ -14,23 +14,11 @@ vi.mock('@/service/tag', () => ({ fetchTagList, })) -// Mock ahooks to avoid timer-related issues in tests vi.mock('ahooks', () => { return { - useDebounceFn: (fn: (...args: unknown[]) => void) => { - const ref = React.useRef(fn) - ref.current = fn - const stableRun = React.useRef((...args: unknown[]) => { - // Schedule to run after current event handler finishes, - // allowing React to process pending state updates first - Promise.resolve().then(() => ref.current(...args)) - }) - return { run: stableRun.current } - }, useMount: (fn: () => void) => { React.useEffect(() => { fn() - // eslint-disable-next-line react-hooks/exhaustive-deps }, []) }, } @@ -228,7 +216,6 @@ describe('TagFilter', () => { const searchInput = screen.getByRole('textbox') await user.type(searchInput, 'Front') - // With debounce mocked to be synchronous, results should be immediate expect(screen.getByText('Frontend')).toBeInTheDocument() expect(screen.queryByText('Backend')).not.toBeInTheDocument() expect(screen.queryByText('API Design')).not.toBeInTheDocument() @@ -257,22 +244,14 @@ describe('TagFilter', () => { const searchInput = screen.getByRole('textbox') await user.type(searchInput, 'Front') - // Wait for the debounced search to filter - await waitFor(() => { - expect(screen.queryByText('Backend')).not.toBeInTheDocument() - }) + expect(screen.queryByText('Backend')).not.toBeInTheDocument() - // Clear the search using the Input's clear button const clearButton = screen.getByTestId('input-clear') await user.click(clearButton) - // The input value should be cleared expect(searchInput).toHaveValue('') - // After the clear + microtask re-render, all app tags should be visible again - await waitFor(() => { - expect(screen.getByText('Backend')).toBeInTheDocument() - }) + expect(screen.getByText('Backend')).toBeInTheDocument() expect(screen.getByText('Frontend')).toBeInTheDocument() expect(screen.getByText('API Design')).toBeInTheDocument() }) diff --git a/web/app/components/base/tag-management/filter.tsx b/web/app/components/base/tag-management/filter.tsx index ad71334ddbf..fcd59bcf7d6 100644 --- a/web/app/components/base/tag-management/filter.tsx +++ b/web/app/components/base/tag-management/filter.tsx @@ -1,15 +1,15 @@ import type { FC } from 'react' import type { Tag } from '@/app/components/base/tag-management/constant' -import { useDebounceFn, useMount } from 'ahooks' +import { useMount } from 'ahooks' import { useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' import { Tag01, Tag03 } from '@/app/components/base/icons/src/vender/line/financeAndECommerce' import Input from '@/app/components/base/input' import { - PortalToFollowElem, - PortalToFollowElemContent, - PortalToFollowElemTrigger, -} from '@/app/components/base/portal-to-follow-elem' + Popover, + PopoverContent, + PopoverTrigger, +} from '@/app/components/base/ui/popover' import { fetchTagList } from '@/service/tag' import { cn } from '@/utils/classnames' @@ -33,18 +33,10 @@ const TagFilter: FC = ({ const setShowTagManagementModal = useTagStore(s => s.setShowTagManagementModal) const [keywords, setKeywords] = useState('') - const [searchKeywords, setSearchKeywords] = useState('') - const { run: handleSearch } = useDebounceFn(() => { - setSearchKeywords(keywords) - }, { wait: 500 }) - const handleKeywordsChange = (value: string) => { - setKeywords(value) - handleSearch() - } const filteredTagList = useMemo(() => { - return tagList.filter(tag => tag.type === type && tag.name.includes(searchKeywords)) - }, [type, tagList, searchKeywords]) + return tagList.filter(tag => tag.type === type && tag.name.includes(keywords)) + }, [type, tagList, keywords]) const currentTag = useMemo(() => { return tagList.find(tag => tag.id === value[0]) @@ -64,61 +56,61 @@ const TagFilter: FC = ({ }) return ( -
    - setOpen(v => !v)} - className="block" - > -
    -
    - -
    -
    - {!value.length && t('tag.placeholder', { ns: 'common' })} - {!!value.length && currentTag?.name} -
    - {value.length > 1 && ( -
    {`+${value.length - 1}`}
    - )} - {!value.length && ( +
    - +
    - )} - {!!value.length && ( -
    { - e.stopPropagation() - onChange([]) - }} - data-testid="tag-filter-clear-button" - > - +
    + {!value.length && t('tag.placeholder', { ns: 'common' })} + {!!value.length && currentTag?.name}
    - )} -
    - - -
    + {value.length > 1 && ( +
    {`+${value.length - 1}`}
    + )} + {!value.length && ( +
    + +
    + )} + + )} + /> + {!!value.length && ( + + )} + +
    handleKeywordsChange(e.target.value)} - onClear={() => handleKeywordsChange('')} + onChange={e => setKeywords(e.target.value)} + onClear={() => setKeywords('')} />
    @@ -155,9 +147,9 @@ const TagFilter: FC = ({
    -
    +
    - + ) } diff --git a/web/app/components/base/ui/slider/__tests__/index.spec.tsx b/web/app/components/base/ui/slider/__tests__/index.spec.tsx new file mode 100644 index 00000000000..f34de5010d3 --- /dev/null +++ b/web/app/components/base/ui/slider/__tests__/index.spec.tsx @@ -0,0 +1,73 @@ +import { act, render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { describe, expect, it, vi } from 'vitest' +import { Slider } from '../index' + +describe('Slider', () => { + const getSliderInput = () => screen.getByLabelText('Value') + + it('should render with correct default ARIA limits and current value', () => { + render() + + const slider = getSliderInput() + expect(slider).toHaveAttribute('min', '0') + expect(slider).toHaveAttribute('max', '100') + expect(slider).toHaveAttribute('aria-valuenow', '50') + }) + + it('should apply custom min, max, and step values', () => { + render() + + const slider = getSliderInput() + expect(slider).toHaveAttribute('min', '5') + expect(slider).toHaveAttribute('max', '20') + expect(slider).toHaveAttribute('aria-valuenow', '10') + }) + + it('should clamp non-finite values to min', () => { + render() + + expect(getSliderInput()).toHaveAttribute('aria-valuenow', '5') + }) + + it('should call onValueChange when arrow keys are pressed', async () => { + const user = userEvent.setup() + const onValueChange = vi.fn() + + render() + + const slider = getSliderInput() + + await act(async () => { + slider.focus() + await user.keyboard('{ArrowRight}') + }) + + expect(onValueChange).toHaveBeenCalledTimes(1) + expect(onValueChange).toHaveBeenLastCalledWith(21, expect.anything()) + }) + + it('should not trigger onValueChange when disabled', async () => { + const user = userEvent.setup() + const onValueChange = vi.fn() + render() + + const slider = getSliderInput() + + expect(slider).toBeDisabled() + + await act(async () => { + slider.focus() + await user.keyboard('{ArrowRight}') + }) + + expect(onValueChange).not.toHaveBeenCalled() + }) + + it('should apply custom class names on root', () => { + const { container } = render() + + const sliderWrapper = container.querySelector('.outer-test') + expect(sliderWrapper).toBeInTheDocument() + }) +}) diff --git a/web/app/components/base/ui/slider/index.stories.tsx b/web/app/components/base/ui/slider/index.stories.tsx new file mode 100644 index 00000000000..b61a6cb2888 --- /dev/null +++ b/web/app/components/base/ui/slider/index.stories.tsx @@ -0,0 +1,92 @@ +import type { Meta, StoryObj } from '@storybook/nextjs-vite' +import type * as React from 'react' +import { useState } from 'react' +import { Slider } from '.' + +const meta = { + title: 'Base UI/Data Entry/Slider', + component: Slider, + parameters: { + layout: 'centered', + docs: { + description: { + component: 'Single-value horizontal slider built on Base UI.', + }, + }, + }, + tags: ['autodocs'], + argTypes: { + value: { + control: 'number', + }, + min: { + control: 'number', + }, + max: { + control: 'number', + }, + step: { + control: 'number', + }, + disabled: { + control: 'boolean', + }, + }, +} satisfies Meta + +export default meta + +type Story = StoryObj + +function SliderDemo({ + value: initialValue = 50, + defaultValue: _defaultValue, + ...args +}: React.ComponentProps) { + const [value, setValue] = useState(initialValue) + + return ( +
    + +
    + {value} +
    +
    + ) +} + +export const Default: Story = { + render: args => , + args: { + value: 50, + min: 0, + max: 100, + step: 1, + }, +} + +export const Decimal: Story = { + render: args => , + args: { + value: 0.5, + min: 0, + max: 1, + step: 0.1, + }, +} + +export const Disabled: Story = { + render: args => , + args: { + value: 75, + min: 0, + max: 100, + step: 1, + disabled: true, + }, +} diff --git a/web/app/components/base/ui/slider/index.tsx b/web/app/components/base/ui/slider/index.tsx new file mode 100644 index 00000000000..8e1dc969bc8 --- /dev/null +++ b/web/app/components/base/ui/slider/index.tsx @@ -0,0 +1,100 @@ +'use client' + +import { Slider as BaseSlider } from '@base-ui/react/slider' +import * as React from 'react' +import { cn } from '@/utils/classnames' + +type SliderRootProps = BaseSlider.Root.Props +type SliderThumbProps = BaseSlider.Thumb.Props + +type SliderBaseProps = Pick< + SliderRootProps, + 'onValueChange' | 'min' | 'max' | 'step' | 'disabled' | 'name' +> & Pick & { + className?: string +} + +type ControlledSliderProps = SliderBaseProps & { + value: number + defaultValue?: never +} + +type UncontrolledSliderProps = SliderBaseProps & { + value?: never + defaultValue?: number +} + +export type SliderProps = ControlledSliderProps | UncontrolledSliderProps + +const sliderRootClassName = 'group/slider relative inline-flex w-full data-[disabled]:opacity-30' +const sliderControlClassName = cn( + 'relative flex h-5 w-full touch-none select-none items-center', + 'data-[disabled]:cursor-not-allowed', +) +const sliderTrackClassName = cn( + 'relative h-1 w-full overflow-hidden rounded-full', + 'bg-[var(--slider-track,var(--color-components-slider-track))]', +) +const sliderIndicatorClassName = cn( + 'h-full rounded-full', + 'bg-[var(--slider-range,var(--color-components-slider-range))]', +) +const sliderThumbClassName = cn( + 'block h-5 w-2 shrink-0 rounded-[3px] border-[0.5px]', + 'border-[var(--slider-knob-border,var(--color-components-slider-knob-border))]', + 'bg-[var(--slider-knob,var(--color-components-slider-knob))] shadow-sm', + 'transition-[background-color,border-color,box-shadow,opacity] motion-reduce:transition-none', + 'hover:bg-[var(--slider-knob-hover,var(--color-components-slider-knob-hover))]', + 'focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-components-slider-knob-border-hover focus-visible:ring-offset-0', + 'active:shadow-md', + 'group-data-[disabled]/slider:bg-[var(--slider-knob-disabled,var(--color-components-slider-knob-disabled))]', + 'group-data-[disabled]/slider:border-[var(--slider-knob-border,var(--color-components-slider-knob-border))]', + 'group-data-[disabled]/slider:shadow-none', +) + +const getSafeValue = (value: number | undefined, min: number) => { + if (value === undefined) + return undefined + + return Number.isFinite(value) ? value : min +} + +export function Slider({ + value, + defaultValue, + onValueChange, + min = 0, + max = 100, + step = 1, + disabled = false, + name, + className, + 'aria-label': ariaLabel, + 'aria-labelledby': ariaLabelledby, +}: SliderProps) { + return ( + + + + + + + + + ) +} diff --git a/web/app/components/base/ui/toast/__tests__/index.spec.tsx b/web/app/components/base/ui/toast/__tests__/index.spec.tsx index 75364117c3f..db6d86719a3 100644 --- a/web/app/components/base/ui/toast/__tests__/index.spec.tsx +++ b/web/app/components/base/ui/toast/__tests__/index.spec.tsx @@ -7,27 +7,25 @@ describe('base/ui/toast', () => { vi.clearAllMocks() vi.useFakeTimers({ shouldAdvanceTime: true }) act(() => { - toast.close() + toast.dismiss() }) }) afterEach(() => { act(() => { - toast.close() + toast.dismiss() vi.runOnlyPendingTimers() }) vi.useRealTimers() }) // Core host and manager integration. - it('should render a toast when add is called', async () => { + it('should render a success toast when called through the typed shortcut', async () => { render() act(() => { - toast.add({ - title: 'Saved', + toast.success('Saved', { description: 'Your changes are available now.', - type: 'success', }) }) @@ -47,20 +45,14 @@ describe('base/ui/toast', () => { render() act(() => { - toast.add({ - title: 'First toast', - }) + toast('First toast') }) expect(await screen.findByText('First toast')).toBeInTheDocument() act(() => { - toast.add({ - title: 'Second toast', - }) - toast.add({ - title: 'Third toast', - }) + toast('Second toast') + toast('Third toast') }) expect(await screen.findByText('Third toast')).toBeInTheDocument() @@ -74,13 +66,25 @@ describe('base/ui/toast', () => { }) }) + // Neutral calls should map directly to a toast with only a title. + it('should render a neutral toast when called directly', async () => { + render() + + act(() => { + toast('Neutral toast') + }) + + expect(await screen.findByText('Neutral toast')).toBeInTheDocument() + expect(document.body.querySelector('[aria-hidden="true"].i-ri-information-2-fill')).not.toBeInTheDocument() + }) + // Base UI limit should cap the visible stack and mark overflow toasts as limited. it('should mark overflow toasts as limited when the stack exceeds the configured limit', async () => { render() act(() => { - toast.add({ title: 'First toast' }) - toast.add({ title: 'Second toast' }) + toast('First toast') + toast('Second toast') }) expect(await screen.findByText('Second toast')).toBeInTheDocument() @@ -88,13 +92,12 @@ describe('base/ui/toast', () => { }) // Closing should work through the public manager API. - it('should close a toast when close(id) is called', async () => { + it('should dismiss a toast when dismiss(id) is called', async () => { render() let toastId = '' act(() => { - toastId = toast.add({ - title: 'Closable', + toastId = toast('Closable', { description: 'This toast can be removed.', }) }) @@ -102,7 +105,7 @@ describe('base/ui/toast', () => { expect(await screen.findByText('Closable')).toBeInTheDocument() act(() => { - toast.close(toastId) + toast.dismiss(toastId) }) await waitFor(() => { @@ -117,8 +120,7 @@ describe('base/ui/toast', () => { render() act(() => { - toast.add({ - title: 'Dismiss me', + toast('Dismiss me', { description: 'Manual dismissal path.', onClose, }) @@ -143,9 +145,7 @@ describe('base/ui/toast', () => { render() act(() => { - toast.add({ - title: 'Default timeout', - }) + toast('Default timeout') }) expect(await screen.findByText('Default timeout')).toBeInTheDocument() @@ -170,9 +170,7 @@ describe('base/ui/toast', () => { render() act(() => { - toast.add({ - title: 'Configured timeout', - }) + toast('Configured timeout') }) expect(await screen.findByText('Configured timeout')).toBeInTheDocument() @@ -197,8 +195,7 @@ describe('base/ui/toast', () => { render() act(() => { - toast.add({ - title: 'Custom timeout', + toast('Custom timeout', { timeout: 1000, }) }) @@ -214,8 +211,7 @@ describe('base/ui/toast', () => { }) act(() => { - toast.add({ - title: 'Persistent', + toast('Persistent', { timeout: 0, }) }) @@ -235,10 +231,8 @@ describe('base/ui/toast', () => { let toastId = '' act(() => { - toastId = toast.add({ - title: 'Loading', + toastId = toast.info('Loading', { description: 'Preparing your data…', - type: 'info', }) }) @@ -264,8 +258,7 @@ describe('base/ui/toast', () => { render() act(() => { - toast.add({ - title: 'Action toast', + toast('Action toast', { actionProps: { children: 'Undo', onClick: onAction, diff --git a/web/app/components/base/ui/toast/index.stories.tsx b/web/app/components/base/ui/toast/index.stories.tsx index 045ca96823d..a0dd806d19c 100644 --- a/web/app/components/base/ui/toast/index.stories.tsx +++ b/web/app/components/base/ui/toast/index.stories.tsx @@ -57,9 +57,8 @@ const VariantExamples = () => { }, } as const - toast.add({ - type, - ...copy[type], + toast[type](copy[type].title, { + description: copy[type].description, }) } @@ -103,14 +102,16 @@ const StackExamples = () => { title: 'Ready to publish', description: 'The newest toast stays frontmost while older items tuck behind it.', }, - ].forEach(item => toast.add(item)) + ].forEach((item) => { + toast[item.type](item.title, { + description: item.description, + }) + }) } const createBurst = () => { Array.from({ length: 5 }).forEach((_, index) => { - toast.add({ - type: index % 2 === 0 ? 'info' : 'success', - title: `Background task ${index + 1}`, + toast[index % 2 === 0 ? 'info' : 'success'](`Background task ${index + 1}`, { description: 'Use this to inspect how the stack behaves near the host limit.', }) }) @@ -191,16 +192,12 @@ const PromiseExamples = () => { const ActionExamples = () => { const createActionToast = () => { - toast.add({ - type: 'warning', - title: 'Project archived', + toast.warning('Project archived', { description: 'You can restore it from workspace settings for the next 30 days.', actionProps: { children: 'Undo', onClick: () => { - toast.add({ - type: 'success', - title: 'Project restored', + toast.success('Project restored', { description: 'The workspace is active again.', }) }, @@ -209,17 +206,12 @@ const ActionExamples = () => { } const createLongCopyToast = () => { - toast.add({ - type: 'info', - title: 'Knowledge ingestion in progress', + toast.info('Knowledge ingestion in progress', { description: 'This longer example helps validate line wrapping, close button alignment, and action button placement when the content spans multiple rows.', actionProps: { children: 'View details', onClick: () => { - toast.add({ - type: 'info', - title: 'Job details opened', - }) + toast.info('Job details opened') }, }, }) @@ -243,9 +235,7 @@ const ActionExamples = () => { const UpdateExamples = () => { const createUpdatableToast = () => { - const toastId = toast.add({ - type: 'info', - title: 'Import started', + const toastId = toast.info('Import started', { description: 'Preparing assets and metadata for processing.', timeout: 0, }) @@ -261,7 +251,7 @@ const UpdateExamples = () => { } const clearAll = () => { - toast.close() + toast.dismiss() } return ( diff --git a/web/app/components/base/ui/toast/index.tsx b/web/app/components/base/ui/toast/index.tsx index d91648e44a9..a3f4e13727b 100644 --- a/web/app/components/base/ui/toast/index.tsx +++ b/web/app/components/base/ui/toast/index.tsx @@ -5,6 +5,7 @@ import type { ToastManagerUpdateOptions, ToastObject, } from '@base-ui/react/toast' +import type { ReactNode } from 'react' import { Toast as BaseToast } from '@base-ui/react/toast' import { useTranslation } from 'react-i18next' import { cn } from '@/utils/classnames' @@ -44,6 +45,9 @@ export type ToastUpdateOptions = Omit, 'dat type?: ToastType } +export type ToastOptions = Omit +export type TypedToastOptions = Omit + type ToastPromiseResultOption = string | ToastUpdateOptions | ((value: Value) => string | ToastUpdateOptions) export type ToastPromiseOptions = { @@ -57,6 +61,21 @@ export type ToastHostProps = { limit?: number } +type ToastDismiss = (toastId?: string) => void +type ToastCall = (title: ReactNode, options?: ToastOptions) => string +type TypedToastCall = (title: ReactNode, options?: TypedToastOptions) => string + +export type ToastApi = { + (title: ReactNode, options?: ToastOptions): string + success: TypedToastCall + error: TypedToastCall + warning: TypedToastCall + info: TypedToastCall + dismiss: ToastDismiss + update: (toastId: string, options: ToastUpdateOptions) => void + promise: (promiseValue: Promise, options: ToastPromiseOptions) => Promise +} + const toastManager = BaseToast.createToastManager() function isToastType(type: string): type is ToastType { @@ -67,21 +86,48 @@ function getToastType(type?: string): ToastType | undefined { return type && isToastType(type) ? type : undefined } -export const toast = { - add(options: ToastAddOptions) { - return toastManager.add(options) - }, - close(toastId?: string) { - toastManager.close(toastId) - }, - update(toastId: string, options: ToastUpdateOptions) { - toastManager.update(toastId, options) - }, - promise(promiseValue: Promise, options: ToastPromiseOptions) { - return toastManager.promise(promiseValue, options) - }, +function addToast(options: ToastAddOptions) { + return toastManager.add(options) } +const showToast: ToastCall = (title, options) => addToast({ + ...options, + title, +}) + +const dismissToast: ToastDismiss = (toastId) => { + toastManager.close(toastId) +} + +function createTypedToast(type: ToastType): TypedToastCall { + return (title, options) => addToast({ + ...options, + title, + type, + }) +} + +function updateToast(toastId: string, options: ToastUpdateOptions) { + toastManager.update(toastId, options) +} + +function promiseToast(promiseValue: Promise, options: ToastPromiseOptions) { + return toastManager.promise(promiseValue, options) +} + +export const toast: ToastApi = Object.assign( + showToast, + { + success: createTypedToast('success'), + error: createTypedToast('error'), + warning: createTypedToast('warning'), + info: createTypedToast('info'), + dismiss: dismissToast, + update: updateToast, + promise: promiseToast, + }, +) + function ToastIcon({ type }: { type?: ToastType }) { return type ?