mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 09:39:25 +08:00
Merge branch 'refactor-enhance-billing-info-guard' into deploy/dev
This commit is contained in:
1
.github/actions/setup-web/action.yml
vendored
1
.github/actions/setup-web/action.yml
vendored
@@ -6,7 +6,6 @@ runs:
|
||||
- name: Setup Vite+
|
||||
uses: voidzero-dev/setup-vp@20553a7a7429c429a74894104a2835d7fed28a72 # v1.3.0
|
||||
with:
|
||||
working-directory: web
|
||||
node-version-file: .nvmrc
|
||||
cache: true
|
||||
run-install: true
|
||||
|
||||
8
.github/workflows/api-tests.yml
vendored
8
.github/workflows/api-tests.yml
vendored
@@ -35,7 +35,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@@ -84,7 +84,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@@ -156,7 +156,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
@@ -203,7 +203,7 @@ jobs:
|
||||
|
||||
- name: Report coverage
|
||||
if: ${{ env.CODECOV_TOKEN != '' }}
|
||||
uses: codecov/codecov-action@1af58845a975a7985b0beb0cbe6fbbb71a41dbad # v5.5.3
|
||||
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
|
||||
with:
|
||||
files: ./coverage.xml
|
||||
disable_search: true
|
||||
|
||||
6
.github/workflows/autofix.yml
vendored
6
.github/workflows/autofix.yml
vendored
@@ -39,6 +39,10 @@ jobs:
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
package.json
|
||||
pnpm-lock.yaml
|
||||
pnpm-workspace.yaml
|
||||
.nvmrc
|
||||
- name: Check api inputs
|
||||
if: github.event_name != 'merge_group'
|
||||
id: api-changes
|
||||
@@ -52,7 +56,7 @@ jobs:
|
||||
python-version: "3.11"
|
||||
|
||||
- if: github.event_name != 'merge_group'
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
|
||||
- name: Generate Docker Compose
|
||||
if: github.event_name != 'merge_group' && steps.docker-compose-changes.outputs.any_changed == 'true'
|
||||
|
||||
30
.github/workflows/build-push.yml
vendored
30
.github/workflows/build-push.yml
vendored
@@ -24,27 +24,39 @@ env:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ${{ matrix.platform == 'linux/arm64' && 'arm64_runner' || 'ubuntu-latest' }}
|
||||
runs-on: ${{ matrix.runs_on }}
|
||||
if: github.repository == 'langgenius/dify'
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- service_name: "build-api-amd64"
|
||||
image_name_env: "DIFY_API_IMAGE_NAME"
|
||||
context: "api"
|
||||
artifact_context: "api"
|
||||
build_context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
platform: linux/amd64
|
||||
runs_on: ubuntu-latest
|
||||
- service_name: "build-api-arm64"
|
||||
image_name_env: "DIFY_API_IMAGE_NAME"
|
||||
context: "api"
|
||||
artifact_context: "api"
|
||||
build_context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
platform: linux/arm64
|
||||
runs_on: ubuntu-24.04-arm
|
||||
- service_name: "build-web-amd64"
|
||||
image_name_env: "DIFY_WEB_IMAGE_NAME"
|
||||
context: "web"
|
||||
artifact_context: "web"
|
||||
build_context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
platform: linux/amd64
|
||||
runs_on: ubuntu-latest
|
||||
- service_name: "build-web-arm64"
|
||||
image_name_env: "DIFY_WEB_IMAGE_NAME"
|
||||
context: "web"
|
||||
artifact_context: "web"
|
||||
build_context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
platform: linux/arm64
|
||||
runs_on: ubuntu-24.04-arm
|
||||
|
||||
steps:
|
||||
- name: Prepare
|
||||
@@ -58,9 +70,6 @@ jobs:
|
||||
username: ${{ env.DOCKERHUB_USER }}
|
||||
password: ${{ env.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||
|
||||
@@ -74,7 +83,8 @@ jobs:
|
||||
id: build
|
||||
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
|
||||
with:
|
||||
context: "{{defaultContext}}:${{ matrix.context }}"
|
||||
context: ${{ matrix.build_context }}
|
||||
file: ${{ matrix.file }}
|
||||
platforms: ${{ matrix.platform }}
|
||||
build-args: COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
@@ -93,7 +103,7 @@ jobs:
|
||||
- name: Upload digest
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
with:
|
||||
name: digests-${{ matrix.context }}-${{ env.PLATFORM_PAIR }}
|
||||
name: digests-${{ matrix.artifact_context }}-${{ env.PLATFORM_PAIR }}
|
||||
path: /tmp/digests/*
|
||||
if-no-files-found: error
|
||||
retention-days: 1
|
||||
|
||||
4
.github/workflows/db-migration-test.yml
vendored
4
.github/workflows/db-migration-test.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
@@ -69,7 +69,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
|
||||
30
.github/workflows/docker-build.yml
vendored
30
.github/workflows/docker-build.yml
vendored
@@ -6,7 +6,12 @@ on:
|
||||
- "main"
|
||||
paths:
|
||||
- api/Dockerfile
|
||||
- web/docker/**
|
||||
- web/Dockerfile
|
||||
- package.json
|
||||
- pnpm-lock.yaml
|
||||
- pnpm-workspace.yaml
|
||||
- .nvmrc
|
||||
|
||||
concurrency:
|
||||
group: docker-build-${{ github.head_ref || github.run_id }}
|
||||
@@ -14,26 +19,31 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
build-docker:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ${{ matrix.runs_on }}
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- service_name: "api-amd64"
|
||||
platform: linux/amd64
|
||||
context: "api"
|
||||
runs_on: ubuntu-latest
|
||||
context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
- service_name: "api-arm64"
|
||||
platform: linux/arm64
|
||||
context: "api"
|
||||
runs_on: ubuntu-24.04-arm
|
||||
context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
- service_name: "web-amd64"
|
||||
platform: linux/amd64
|
||||
context: "web"
|
||||
runs_on: ubuntu-latest
|
||||
context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
- service_name: "web-arm64"
|
||||
platform: linux/arm64
|
||||
context: "web"
|
||||
runs_on: ubuntu-24.04-arm
|
||||
context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
steps:
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||
|
||||
@@ -41,8 +51,8 @@ jobs:
|
||||
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
|
||||
with:
|
||||
push: false
|
||||
context: "{{defaultContext}}:${{ matrix.context }}"
|
||||
file: "${{ matrix.file }}"
|
||||
context: ${{ matrix.context }}
|
||||
file: ${{ matrix.file }}
|
||||
platforms: ${{ matrix.platform }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
8
.github/workflows/main-ci.yml
vendored
8
.github/workflows/main-ci.yml
vendored
@@ -65,6 +65,10 @@ jobs:
|
||||
- 'docker/volumes/sandbox/conf/**'
|
||||
web:
|
||||
- 'web/**'
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- '.nvmrc'
|
||||
- '.github/workflows/web-tests.yml'
|
||||
- '.github/actions/setup-web/**'
|
||||
e2e:
|
||||
@@ -73,6 +77,10 @@ jobs:
|
||||
- 'api/uv.lock'
|
||||
- 'e2e/**'
|
||||
- 'web/**'
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- '.nvmrc'
|
||||
- 'docker/docker-compose.middleware.yaml'
|
||||
- 'docker/middleware.env.example'
|
||||
- '.github/workflows/web-e2e.yml'
|
||||
|
||||
15
.github/workflows/pyrefly-diff.yml
vendored
15
.github/workflows/pyrefly-diff.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python & UV
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
@@ -50,6 +50,17 @@ jobs:
|
||||
run: |
|
||||
diff -u /tmp/pyrefly_base.txt /tmp/pyrefly_pr.txt > pyrefly_diff.txt || true
|
||||
|
||||
- name: Check if line counts match
|
||||
id: line_count_check
|
||||
run: |
|
||||
base_lines=$(wc -l < /tmp/pyrefly_base.txt)
|
||||
pr_lines=$(wc -l < /tmp/pyrefly_pr.txt)
|
||||
if [ "$base_lines" -eq "$pr_lines" ]; then
|
||||
echo "same=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "same=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Save PR number
|
||||
run: |
|
||||
echo ${{ github.event.pull_request.number }} > pr_number.txt
|
||||
@@ -63,7 +74,7 @@ jobs:
|
||||
pr_number.txt
|
||||
|
||||
- name: Comment PR with pyrefly diff
|
||||
if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
|
||||
if: ${{ github.event.pull_request.head.repo.full_name == github.repository && steps.line_count_check.outputs.same == 'false' }}
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
10
.github/workflows/style.yml
vendored
10
.github/workflows/style.yml
vendored
@@ -33,7 +33,7 @@ jobs:
|
||||
|
||||
- name: Setup UV and Python
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
with:
|
||||
enable-cache: false
|
||||
python-version: "3.12"
|
||||
@@ -77,6 +77,10 @@ jobs:
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
package.json
|
||||
pnpm-lock.yaml
|
||||
pnpm-workspace.yaml
|
||||
.nvmrc
|
||||
.github/workflows/style.yml
|
||||
.github/actions/setup-web/**
|
||||
|
||||
@@ -90,9 +94,9 @@ jobs:
|
||||
uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
|
||||
with:
|
||||
path: web/.eslintcache
|
||||
key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
|
||||
key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
|
||||
${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
|
||||
|
||||
- name: Web style check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
|
||||
3
.github/workflows/tool-test-sdks.yaml
vendored
3
.github/workflows/tool-test-sdks.yaml
vendored
@@ -6,6 +6,9 @@ on:
|
||||
- main
|
||||
paths:
|
||||
- sdks/**
|
||||
- package.json
|
||||
- pnpm-lock.yaml
|
||||
- pnpm-workspace.yaml
|
||||
|
||||
concurrency:
|
||||
group: sdk-tests-${{ github.head_ref || github.run_id }}
|
||||
|
||||
95
.github/workflows/vdb-tests-full.yml
vendored
Normal file
95
.github/workflows/vdb-tests-full.yml
vendored
Normal file
@@ -0,0 +1,95 @@
|
||||
name: Run Full VDB Tests
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 3 * * 1'
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: vdb-tests-full-${{ github.ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: Full VDB Tests
|
||||
if: github.repository == 'langgenius/dify'
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version:
|
||||
- "3.12"
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Free Disk Space
|
||||
uses: endersonmenezes/free-disk-space@7901478139cff6e9d44df5972fd8ab8fcade4db1 # v3.2.2
|
||||
with:
|
||||
remove_dotnet: true
|
||||
remove_haskell: true
|
||||
remove_tool_cache: true
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache-dependency-glob: api/uv.lock
|
||||
|
||||
- name: Check UV lockfile
|
||||
run: uv lock --project api --check
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Set up dotenvs
|
||||
run: |
|
||||
cp docker/.env.example docker/.env
|
||||
cp docker/middleware.env.example docker/middleware.env
|
||||
|
||||
- name: Expose Service Ports
|
||||
run: sh .github/workflows/expose_service_ports.sh
|
||||
|
||||
# - name: Set up Vector Store (TiDB)
|
||||
# uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
# with:
|
||||
# compose-file: docker/tidb/docker-compose.yaml
|
||||
# services: |
|
||||
# tidb
|
||||
# tiflash
|
||||
|
||||
- name: Set up Full Vector Store Matrix
|
||||
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.yaml
|
||||
services: |
|
||||
weaviate
|
||||
qdrant
|
||||
couchbase-server
|
||||
etcd
|
||||
minio
|
||||
milvus-standalone
|
||||
pgvecto-rs
|
||||
pgvector
|
||||
chroma
|
||||
elasticsearch
|
||||
oceanbase
|
||||
|
||||
- name: setup test config
|
||||
run: |
|
||||
echo $(pwd)
|
||||
ls -lah .
|
||||
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
|
||||
|
||||
# - name: Check VDB Ready (TiDB)
|
||||
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
|
||||
|
||||
- name: Test Vector Stores
|
||||
run: uv run --project api bash dev/pytest/pytest_vdb.sh
|
||||
27
.github/workflows/vdb-tests.yml
vendored
27
.github/workflows/vdb-tests.yml
vendored
@@ -1,15 +1,18 @@
|
||||
name: Run VDB Tests
|
||||
name: Run VDB Smoke Tests
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: vdb-tests-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: VDB Tests
|
||||
name: VDB Smoke Tests
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
@@ -30,7 +33,7 @@ jobs:
|
||||
remove_tool_cache: true
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@@ -58,23 +61,18 @@ jobs:
|
||||
# tidb
|
||||
# tiflash
|
||||
|
||||
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase, OceanBase)
|
||||
- name: Set up Vector Stores for Smoke Coverage
|
||||
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.yaml
|
||||
services: |
|
||||
db_postgres
|
||||
redis
|
||||
weaviate
|
||||
qdrant
|
||||
couchbase-server
|
||||
etcd
|
||||
minio
|
||||
milvus-standalone
|
||||
pgvecto-rs
|
||||
pgvector
|
||||
chroma
|
||||
elasticsearch
|
||||
oceanbase
|
||||
|
||||
- name: setup test config
|
||||
run: |
|
||||
@@ -86,4 +84,9 @@ jobs:
|
||||
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
|
||||
|
||||
- name: Test Vector Stores
|
||||
run: uv run --project api bash dev/pytest/pytest_vdb.sh
|
||||
run: |
|
||||
uv run --project api pytest --timeout "${PYTEST_TIMEOUT:-180}" \
|
||||
api/tests/integration_tests/vdb/chroma \
|
||||
api/tests/integration_tests/vdb/pgvector \
|
||||
api/tests/integration_tests/vdb/qdrant \
|
||||
api/tests/integration_tests/vdb/weaviate
|
||||
|
||||
6
.github/workflows/web-e2e.yml
vendored
6
.github/workflows/web-e2e.yml
vendored
@@ -27,12 +27,8 @@ jobs:
|
||||
- name: Setup web dependencies
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Install E2E package dependencies
|
||||
working-directory: ./e2e
|
||||
run: vp install --frozen-lockfile
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
|
||||
33
.github/workflows/web-tests.yml
vendored
33
.github/workflows/web-tests.yml
vendored
@@ -83,40 +83,9 @@ jobs:
|
||||
|
||||
- name: Report coverage
|
||||
if: ${{ env.CODECOV_TOKEN != '' }}
|
||||
uses: codecov/codecov-action@1af58845a975a7985b0beb0cbe6fbbb71a41dbad # v5.5.3
|
||||
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
|
||||
with:
|
||||
directory: web/coverage
|
||||
flags: web
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
|
||||
|
||||
web-build:
|
||||
name: Web Build
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
working-directory: ./web
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
.github/workflows/web-tests.yml
|
||||
.github/actions/setup-web/**
|
||||
|
||||
- name: Setup web environment
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Web build check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: vp run build
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -212,6 +212,7 @@ api/.vscode
|
||||
|
||||
# pnpm
|
||||
/.pnpm-store
|
||||
/node_modules
|
||||
|
||||
# plugin migrate
|
||||
plugins.jsonl
|
||||
@@ -239,4 +240,4 @@ scripts/stress-test/reports/
|
||||
*.local.md
|
||||
|
||||
# Code Agent Folder
|
||||
.qoder/*
|
||||
.qoder/*
|
||||
|
||||
6
Makefile
6
Makefile
@@ -24,8 +24,8 @@ prepare-docker:
|
||||
# Step 2: Prepare web environment
|
||||
prepare-web:
|
||||
@echo "🌐 Setting up web environment..."
|
||||
@cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists"
|
||||
@cd web && pnpm install
|
||||
@cp -n web/.env.example web/.env.local 2>/dev/null || echo "Web .env.local already exists"
|
||||
@pnpm install
|
||||
@echo "✅ Web environment prepared (not started)"
|
||||
|
||||
# Step 3: Prepare API environment
|
||||
@@ -93,7 +93,7 @@ test:
|
||||
# Build Docker images
|
||||
build-web:
|
||||
@echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..."
|
||||
docker build -t $(WEB_IMAGE):$(VERSION) ./web
|
||||
docker build -f web/Dockerfile -t $(WEB_IMAGE):$(VERSION) .
|
||||
@echo "Web Docker image built successfully: $(WEB_IMAGE):$(VERSION)"
|
||||
|
||||
build-api:
|
||||
|
||||
@@ -40,6 +40,8 @@ The scripts resolve paths relative to their location, so you can run them from a
|
||||
./dev/start-web
|
||||
```
|
||||
|
||||
`./dev/setup` and `./dev/start-web` install JavaScript dependencies through the repository root workspace, so you do not need a separate `cd web && pnpm install` step.
|
||||
|
||||
1. Set up your application by visiting `http://localhost:3000`.
|
||||
|
||||
1. Start the worker service (async and scheduler tasks, runs from `api`).
|
||||
|
||||
@@ -120,7 +120,8 @@ class DatasourceOAuthCallback(Resource):
|
||||
if context is None:
|
||||
raise Forbidden("Invalid context_id")
|
||||
|
||||
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
|
||||
user_id: str = context["user_id"]
|
||||
tenant_id: str = context["tenant_id"]
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
plugin_id = datasource_provider_id.plugin_id
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
@@ -141,7 +142,7 @@ class DatasourceOAuthCallback(Resource):
|
||||
system_credentials=oauth_client_params,
|
||||
request=request,
|
||||
)
|
||||
credential_id = context.get("credential_id")
|
||||
credential_id: str | None = context.get("credential_id")
|
||||
if credential_id:
|
||||
datasource_provider_service.reauthorize_datasource_oauth_provider(
|
||||
tenant_id=tenant_id,
|
||||
@@ -150,7 +151,7 @@ class DatasourceOAuthCallback(Resource):
|
||||
name=oauth_response.metadata.get("name") or None,
|
||||
expire_at=oauth_response.expires_at,
|
||||
credentials=dict(oauth_response.credentials),
|
||||
credential_id=context.get("credential_id"),
|
||||
credential_id=credential_id,
|
||||
)
|
||||
else:
|
||||
datasource_provider_service.add_datasource_oauth_provider(
|
||||
|
||||
@@ -287,12 +287,10 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
provider=provider,
|
||||
)
|
||||
else:
|
||||
# Normalize model_type to the origin value stored in DB (e.g., "text-generation" for LLM)
|
||||
normalized_model_type = args.model_type.to_origin_model_type()
|
||||
available_credentials = model_provider_service.get_provider_model_available_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model_type=normalized_model_type,
|
||||
model_type=args.model_type,
|
||||
model=args.model,
|
||||
)
|
||||
|
||||
|
||||
@@ -832,7 +832,8 @@ class ToolOAuthCallback(Resource):
|
||||
tool_provider = ToolProviderID(provider)
|
||||
plugin_id = tool_provider.plugin_id
|
||||
provider_name = tool_provider.provider_name
|
||||
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
|
||||
user_id: str = context["user_id"]
|
||||
tenant_id: str = context["tenant_id"]
|
||||
|
||||
oauth_handler = OAuthHandler()
|
||||
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id, provider)
|
||||
|
||||
@@ -499,9 +499,9 @@ class TriggerOAuthCallbackApi(Resource):
|
||||
provider_id = TriggerProviderID(provider)
|
||||
plugin_id = provider_id.plugin_id
|
||||
provider_name = provider_id.provider_name
|
||||
user_id = context.get("user_id")
|
||||
tenant_id = context.get("tenant_id")
|
||||
subscription_builder_id = context.get("subscription_builder_id")
|
||||
user_id: str = context["user_id"]
|
||||
tenant_id: str = context["tenant_id"]
|
||||
subscription_builder_id: str = context["subscription_builder_id"]
|
||||
|
||||
# Get OAuth client configuration
|
||||
oauth_client_params = TriggerProviderService.get_oauth_client(
|
||||
|
||||
@@ -174,6 +174,7 @@ class MCPAppApi(Resource):
|
||||
required=variable.get("required", False),
|
||||
max_length=variable.get("max_length"),
|
||||
options=variable.get("options") or [],
|
||||
json_schema=variable.get("json_schema"),
|
||||
)
|
||||
|
||||
def _parse_mcp_request(self, args: dict) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
|
||||
|
||||
@@ -29,6 +29,31 @@ from services.entities.knowledge_entities.knowledge_entities import SegmentUpdat
|
||||
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
|
||||
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
|
||||
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
|
||||
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict:
|
||||
"""Marshal a single segment and enrich it with summary content."""
|
||||
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
|
||||
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
|
||||
segment_dict["summary"] = summary.summary_content if summary else None
|
||||
return segment_dict
|
||||
|
||||
|
||||
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict]:
|
||||
"""Marshal multiple segments and enrich them with summary content (batch query)."""
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
summaries: dict = {}
|
||||
if segment_ids:
|
||||
summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
|
||||
summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()}
|
||||
|
||||
result = []
|
||||
for segment in segments:
|
||||
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
|
||||
segment_dict["summary"] = summaries.get(segment.id)
|
||||
result.append(segment_dict)
|
||||
return result
|
||||
|
||||
|
||||
class SegmentCreatePayload(BaseModel):
|
||||
@@ -132,7 +157,7 @@ class SegmentApi(DatasetApiResource):
|
||||
for args_item in payload.segments:
|
||||
SegmentService.segment_create_args_validate(args_item, document)
|
||||
segments = SegmentService.multi_create_segment(payload.segments, document, dataset)
|
||||
return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200
|
||||
return {"data": _marshal_segments_with_summary(segments, dataset_id), "doc_form": document.doc_form}, 200
|
||||
else:
|
||||
return {"error": "Segments is required"}, 400
|
||||
|
||||
@@ -196,7 +221,7 @@ class SegmentApi(DatasetApiResource):
|
||||
)
|
||||
|
||||
response = {
|
||||
"data": marshal(segments, segment_fields),
|
||||
"data": _marshal_segments_with_summary(segments, dataset_id),
|
||||
"doc_form": document.doc_form,
|
||||
"total": total,
|
||||
"has_more": len(segments) == limit,
|
||||
@@ -296,7 +321,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset)
|
||||
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
return {"data": _marshal_segment_with_summary(updated_segment, dataset_id), "doc_form": document.doc_form}, 200
|
||||
|
||||
@service_api_ns.doc("get_segment")
|
||||
@service_api_ns.doc(description="Get a specific segment by ID")
|
||||
@@ -326,7 +351,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
return {"data": _marshal_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
|
||||
|
||||
|
||||
@service_api_ns.route(
|
||||
|
||||
@@ -81,7 +81,7 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
)
|
||||
.values(
|
||||
|
||||
@@ -8,6 +8,7 @@ associates with the node span.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from contextvars import Token
|
||||
from dataclasses import dataclass
|
||||
from typing import cast, final
|
||||
|
||||
@@ -35,7 +36,7 @@ logger = logging.getLogger(__name__)
|
||||
@dataclass(slots=True)
|
||||
class _NodeSpanContext:
|
||||
span: "Span"
|
||||
token: object
|
||||
token: Token[context_api.Context]
|
||||
|
||||
|
||||
@final
|
||||
|
||||
@@ -403,7 +403,7 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -753,7 +753,7 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderModel.tenant_id == self.tenant_id,
|
||||
ProviderModel.provider_name.in_(provider_names),
|
||||
ProviderModel.model_name == model,
|
||||
ProviderModel.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModel.model_type == model_type,
|
||||
)
|
||||
|
||||
return session.execute(stmt).scalar_one_or_none()
|
||||
@@ -778,7 +778,7 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
)
|
||||
|
||||
credential_record = session.execute(stmt).scalar_one_or_none()
|
||||
@@ -825,7 +825,7 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
ProviderModelCredential.credential_name == credential_name,
|
||||
)
|
||||
if exclude_id:
|
||||
@@ -901,7 +901,7 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
)
|
||||
credential_record = s.execute(stmt).scalar_one_or_none()
|
||||
original_credentials = (
|
||||
@@ -970,7 +970,7 @@ class ProviderConfiguration(BaseModel):
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_name=model,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_type=model_type,
|
||||
encrypted_config=json.dumps(credentials),
|
||||
credential_name=credential_name,
|
||||
)
|
||||
@@ -983,7 +983,7 @@ class ProviderConfiguration(BaseModel):
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_name=model,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_type=model_type,
|
||||
credential_id=credential.id,
|
||||
is_valid=True,
|
||||
)
|
||||
@@ -1038,7 +1038,7 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
)
|
||||
credential_record = session.execute(stmt).scalar_one_or_none()
|
||||
if not credential_record:
|
||||
@@ -1083,7 +1083,7 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
)
|
||||
credential_record = session.execute(stmt).scalar_one_or_none()
|
||||
if not credential_record:
|
||||
@@ -1116,7 +1116,7 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
)
|
||||
available_credentials_count = session.execute(count_stmt).scalar() or 0
|
||||
session.delete(credential_record)
|
||||
@@ -1156,7 +1156,7 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
)
|
||||
credential_record = session.execute(stmt).scalar_one_or_none()
|
||||
if not credential_record:
|
||||
@@ -1171,7 +1171,7 @@ class ProviderConfiguration(BaseModel):
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_name=model,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_type=model_type,
|
||||
is_valid=True,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
@@ -1207,7 +1207,7 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
)
|
||||
credential_record = session.execute(stmt).scalar_one_or_none()
|
||||
if not credential_record:
|
||||
@@ -1263,7 +1263,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = select(ProviderModelSetting).where(
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelSetting.model_type == model_type,
|
||||
ProviderModelSetting.model_name == model,
|
||||
)
|
||||
return session.execute(stmt).scalars().first()
|
||||
@@ -1286,7 +1286,7 @@ class ProviderConfiguration(BaseModel):
|
||||
model_setting = ProviderModelSetting(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_type=model_type,
|
||||
model_name=model,
|
||||
enabled=True,
|
||||
)
|
||||
@@ -1312,7 +1312,7 @@ class ProviderConfiguration(BaseModel):
|
||||
model_setting = ProviderModelSetting(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_type=model_type,
|
||||
model_name=model,
|
||||
enabled=False,
|
||||
)
|
||||
@@ -1348,7 +1348,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = select(func.count(LoadBalancingModelConfig.id)).where(
|
||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||
LoadBalancingModelConfig.provider_name.in_(provider_names),
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_type == model_type,
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
)
|
||||
load_balancing_config_count = session.execute(stmt).scalar() or 0
|
||||
@@ -1364,7 +1364,7 @@ class ProviderConfiguration(BaseModel):
|
||||
model_setting = ProviderModelSetting(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_type=model_type,
|
||||
model_name=model,
|
||||
load_balancing_enabled=True,
|
||||
)
|
||||
@@ -1391,7 +1391,7 @@ class ProviderConfiguration(BaseModel):
|
||||
model_setting = ProviderModelSetting(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_type=model_type,
|
||||
model_name=model,
|
||||
load_balancing_enabled=False,
|
||||
)
|
||||
|
||||
@@ -260,4 +260,12 @@ def convert_input_form_to_parameters(
|
||||
parameters[item.variable]["enum"] = item.options
|
||||
elif item.type == VariableEntityType.NUMBER:
|
||||
parameters[item.variable]["type"] = "number"
|
||||
elif item.type == VariableEntityType.CHECKBOX:
|
||||
parameters[item.variable]["type"] = "boolean"
|
||||
elif item.type == VariableEntityType.JSON_OBJECT:
|
||||
parameters[item.variable]["type"] = "object"
|
||||
if item.json_schema:
|
||||
for key in ("properties", "required", "additionalProperties"):
|
||||
if key in item.json_schema:
|
||||
parameters[item.variable][key] = item.json_schema[key]
|
||||
return parameters, required
|
||||
|
||||
@@ -16,7 +16,13 @@ from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExport
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace import ReadableSpan
|
||||
from opentelemetry.sdk.util.instrumentation import InstrumentationScope
|
||||
from opentelemetry.semconv.resource import ResourceAttributes
|
||||
from opentelemetry.semconv._incubating.attributes.deployment_attributes import ( # type: ignore[import-untyped]
|
||||
DEPLOYMENT_ENVIRONMENT,
|
||||
)
|
||||
from opentelemetry.semconv._incubating.attributes.host_attributes import ( # type: ignore[import-untyped]
|
||||
HOST_NAME,
|
||||
)
|
||||
from opentelemetry.semconv.attributes import service_attributes
|
||||
from opentelemetry.trace import Link, SpanContext, TraceFlags
|
||||
|
||||
from configs import dify_config
|
||||
@@ -45,10 +51,10 @@ class TraceClient:
|
||||
self.endpoint = endpoint
|
||||
self.resource = Resource(
|
||||
attributes={
|
||||
ResourceAttributes.SERVICE_NAME: service_name,
|
||||
ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
|
||||
ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
|
||||
ResourceAttributes.HOST_NAME: socket.gethostname(),
|
||||
service_attributes.SERVICE_NAME: service_name,
|
||||
service_attributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
|
||||
DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
|
||||
HOST_NAME: socket.gethostname(),
|
||||
ACS_ARMS_SERVICE_FEATURE: "genai_app",
|
||||
}
|
||||
)
|
||||
|
||||
@@ -19,7 +19,7 @@ from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExport
|
||||
from opentelemetry.sdk import trace as trace_sdk
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
|
||||
from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
|
||||
from opentelemetry.semconv.attributes import exception_attributes
|
||||
from opentelemetry.trace import Span, Status, StatusCode, set_span_in_context, use_span
|
||||
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
|
||||
from opentelemetry.util.types import AttributeValue
|
||||
@@ -134,10 +134,10 @@ def set_span_status(current_span: Span, error: Exception | str | None = None):
|
||||
if not exception_message:
|
||||
exception_message = repr(error)
|
||||
attributes: dict[str, AttributeValue] = {
|
||||
OTELSpanAttributes.EXCEPTION_TYPE: exception_type,
|
||||
OTELSpanAttributes.EXCEPTION_MESSAGE: exception_message,
|
||||
OTELSpanAttributes.EXCEPTION_ESCAPED: False,
|
||||
OTELSpanAttributes.EXCEPTION_STACKTRACE: error_string,
|
||||
exception_attributes.EXCEPTION_TYPE: exception_type,
|
||||
exception_attributes.EXCEPTION_MESSAGE: exception_message,
|
||||
exception_attributes.EXCEPTION_ESCAPED: False,
|
||||
exception_attributes.EXCEPTION_STACKTRACE: error_string,
|
||||
}
|
||||
current_span.add_event(name="exception", attributes=attributes)
|
||||
else:
|
||||
|
||||
@@ -1,9 +1,19 @@
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from graphon.enums import BuiltinNodeTypes
|
||||
from langfuse import Langfuse
|
||||
from langfuse.api import (
|
||||
CreateGenerationBody,
|
||||
CreateSpanBody,
|
||||
IngestionEvent_GenerationCreate,
|
||||
IngestionEvent_SpanCreate,
|
||||
IngestionEvent_TraceCreate,
|
||||
TraceBody,
|
||||
)
|
||||
from langfuse.api.commons.types.usage import Usage
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
@@ -396,18 +406,61 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
)
|
||||
self.add_span(langfuse_span_data=name_generation_span_data)
|
||||
|
||||
def _make_event_id(self) -> str:
|
||||
return str(uuid.uuid4())
|
||||
|
||||
def _now_iso(self) -> str:
|
||||
return datetime.now(UTC).isoformat()
|
||||
|
||||
def add_trace(self, langfuse_trace_data: LangfuseTrace | None = None):
|
||||
format_trace_data = filter_none_values(langfuse_trace_data.model_dump()) if langfuse_trace_data else {}
|
||||
data = filter_none_values(langfuse_trace_data.model_dump()) if langfuse_trace_data else {}
|
||||
try:
|
||||
self.langfuse_client.trace(**format_trace_data)
|
||||
body = TraceBody(
|
||||
id=data.get("id"),
|
||||
name=data.get("name"),
|
||||
user_id=data.get("user_id"),
|
||||
input=data.get("input"),
|
||||
output=data.get("output"),
|
||||
metadata=data.get("metadata"),
|
||||
session_id=data.get("session_id"),
|
||||
version=data.get("version"),
|
||||
release=data.get("release"),
|
||||
tags=data.get("tags"),
|
||||
public=data.get("public"),
|
||||
)
|
||||
event = IngestionEvent_TraceCreate(
|
||||
body=body,
|
||||
id=self._make_event_id(),
|
||||
timestamp=self._now_iso(),
|
||||
)
|
||||
self.langfuse_client.api.ingestion.batch(batch=[event])
|
||||
logger.debug("LangFuse Trace created successfully")
|
||||
except Exception as e:
|
||||
raise ValueError(f"LangFuse Failed to create trace: {str(e)}")
|
||||
|
||||
def add_span(self, langfuse_span_data: LangfuseSpan | None = None):
|
||||
format_span_data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {}
|
||||
data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {}
|
||||
try:
|
||||
self.langfuse_client.span(**format_span_data)
|
||||
body = CreateSpanBody(
|
||||
id=data.get("id"),
|
||||
trace_id=data.get("trace_id"),
|
||||
name=data.get("name"),
|
||||
start_time=data.get("start_time"),
|
||||
end_time=data.get("end_time"),
|
||||
input=data.get("input"),
|
||||
output=data.get("output"),
|
||||
metadata=data.get("metadata"),
|
||||
level=data.get("level"),
|
||||
status_message=data.get("status_message"),
|
||||
parent_observation_id=data.get("parent_observation_id"),
|
||||
version=data.get("version"),
|
||||
)
|
||||
event = IngestionEvent_SpanCreate(
|
||||
body=body,
|
||||
id=self._make_event_id(),
|
||||
timestamp=self._now_iso(),
|
||||
)
|
||||
self.langfuse_client.api.ingestion.batch(batch=[event])
|
||||
logger.debug("LangFuse Span created successfully")
|
||||
except Exception as e:
|
||||
raise ValueError(f"LangFuse Failed to create span: {str(e)}")
|
||||
@@ -418,11 +471,45 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
span.end(**format_span_data)
|
||||
|
||||
def add_generation(self, langfuse_generation_data: LangfuseGeneration | None = None):
|
||||
format_generation_data = (
|
||||
filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {}
|
||||
)
|
||||
data = filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {}
|
||||
try:
|
||||
self.langfuse_client.generation(**format_generation_data)
|
||||
usage_data = data.pop("usage", None)
|
||||
usage = None
|
||||
if usage_data:
|
||||
usage = Usage(
|
||||
input=usage_data.get("input", 0) or 0,
|
||||
output=usage_data.get("output", 0) or 0,
|
||||
total=usage_data.get("total", 0) or 0,
|
||||
unit=usage_data.get("unit"),
|
||||
input_cost=usage_data.get("inputCost"),
|
||||
output_cost=usage_data.get("outputCost"),
|
||||
total_cost=usage_data.get("totalCost"),
|
||||
)
|
||||
|
||||
body = CreateGenerationBody(
|
||||
id=data.get("id"),
|
||||
trace_id=data.get("trace_id"),
|
||||
name=data.get("name"),
|
||||
start_time=data.get("start_time"),
|
||||
end_time=data.get("end_time"),
|
||||
model=data.get("model"),
|
||||
model_parameters=data.get("model_parameters"),
|
||||
input=data.get("input"),
|
||||
output=data.get("output"),
|
||||
usage=usage,
|
||||
metadata=data.get("metadata"),
|
||||
level=data.get("level"),
|
||||
status_message=data.get("status_message"),
|
||||
parent_observation_id=data.get("parent_observation_id"),
|
||||
version=data.get("version"),
|
||||
completion_start_time=data.get("completion_start_time"),
|
||||
)
|
||||
event = IngestionEvent_GenerationCreate(
|
||||
body=body,
|
||||
id=self._make_event_id(),
|
||||
timestamp=self._now_iso(),
|
||||
)
|
||||
self.langfuse_client.api.ingestion.batch(batch=[event])
|
||||
logger.debug("LangFuse Generation created successfully")
|
||||
except Exception as e:
|
||||
raise ValueError(f"LangFuse Failed to create generation: {str(e)}")
|
||||
@@ -443,7 +530,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
|
||||
def get_project_key(self):
|
||||
try:
|
||||
projects = self.langfuse_client.client.projects.get()
|
||||
projects = self.langfuse_client.api.projects.get()
|
||||
return projects.data[0].id
|
||||
except Exception as e:
|
||||
logger.debug("LangFuse get project key failed: %s", str(e))
|
||||
|
||||
@@ -26,7 +26,13 @@ from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExport
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.semconv.resource import ResourceAttributes
|
||||
from opentelemetry.semconv._incubating.attributes.deployment_attributes import ( # type: ignore[import-untyped]
|
||||
DEPLOYMENT_ENVIRONMENT,
|
||||
)
|
||||
from opentelemetry.semconv._incubating.attributes.host_attributes import ( # type: ignore[import-untyped]
|
||||
HOST_NAME,
|
||||
)
|
||||
from opentelemetry.semconv.attributes import service_attributes
|
||||
from opentelemetry.trace import SpanKind
|
||||
from opentelemetry.util.types import AttributeValue
|
||||
|
||||
@@ -73,13 +79,13 @@ class TencentTraceClient:
|
||||
|
||||
self.resource = Resource(
|
||||
attributes={
|
||||
ResourceAttributes.SERVICE_NAME: service_name,
|
||||
ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
|
||||
ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
|
||||
ResourceAttributes.HOST_NAME: socket.gethostname(),
|
||||
ResourceAttributes.TELEMETRY_SDK_LANGUAGE: "python",
|
||||
ResourceAttributes.TELEMETRY_SDK_NAME: "opentelemetry",
|
||||
ResourceAttributes.TELEMETRY_SDK_VERSION: _get_opentelemetry_sdk_version(),
|
||||
service_attributes.SERVICE_NAME: service_name,
|
||||
service_attributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
|
||||
DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
|
||||
HOST_NAME: socket.gethostname(),
|
||||
"telemetry.sdk.language": "python",
|
||||
"telemetry.sdk.name": "opentelemetry",
|
||||
"telemetry.sdk.version": _get_opentelemetry_sdk_version(),
|
||||
}
|
||||
)
|
||||
# Prepare gRPC endpoint/metadata
|
||||
|
||||
@@ -306,7 +306,7 @@ class ProviderManager:
|
||||
"""
|
||||
stmt = select(TenantDefaultModel).where(
|
||||
TenantDefaultModel.tenant_id == tenant_id,
|
||||
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
|
||||
TenantDefaultModel.model_type == model_type,
|
||||
)
|
||||
default_model = db.session.scalar(stmt)
|
||||
|
||||
@@ -324,7 +324,7 @@ class ProviderManager:
|
||||
|
||||
default_model = TenantDefaultModel(
|
||||
tenant_id=tenant_id,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_type=model_type,
|
||||
provider_name=available_model.provider.provider,
|
||||
model_name=available_model.model,
|
||||
)
|
||||
@@ -391,7 +391,7 @@ class ProviderManager:
|
||||
raise ValueError(f"Model {model} does not exist.")
|
||||
stmt = select(TenantDefaultModel).where(
|
||||
TenantDefaultModel.tenant_id == tenant_id,
|
||||
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
|
||||
TenantDefaultModel.model_type == model_type,
|
||||
)
|
||||
default_model = db.session.scalar(stmt)
|
||||
|
||||
@@ -405,7 +405,7 @@ class ProviderManager:
|
||||
# create default model
|
||||
default_model = TenantDefaultModel(
|
||||
tenant_id=tenant_id,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_type=model_type,
|
||||
provider_name=provider,
|
||||
model_name=model,
|
||||
)
|
||||
@@ -626,9 +626,8 @@ class ProviderManager:
|
||||
if provider_record.provider_type != ProviderType.SYSTEM:
|
||||
continue
|
||||
|
||||
provider_quota_to_provider_record_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = (
|
||||
provider_record
|
||||
)
|
||||
if provider_record.quota_type is not None:
|
||||
provider_quota_to_provider_record_dict[provider_record.quota_type] = provider_record
|
||||
|
||||
for quota in configuration.quotas:
|
||||
if quota.quota_type in (ProviderQuotaType.TRIAL, ProviderQuotaType.PAID):
|
||||
@@ -641,7 +640,7 @@ class ProviderManager:
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
provider_name=ModelProviderID(provider_name).provider_name,
|
||||
provider_type=ProviderType.SYSTEM,
|
||||
quota_type=quota.quota_type,
|
||||
quota_type=quota.quota_type, # type: ignore[arg-type]
|
||||
quota_limit=0, # type: ignore
|
||||
quota_used=0,
|
||||
is_valid=True,
|
||||
@@ -823,7 +822,7 @@ class ProviderManager:
|
||||
custom_model_configurations.append(
|
||||
CustomModelConfiguration(
|
||||
model=provider_model_record.model_name,
|
||||
model_type=ModelType.value_of(provider_model_record.model_type),
|
||||
model_type=provider_model_record.model_type,
|
||||
credentials=provider_model_credentials,
|
||||
current_credential_id=provider_model_record.credential_id,
|
||||
current_credential_name=provider_model_record.credential_name,
|
||||
@@ -921,9 +920,8 @@ class ProviderManager:
|
||||
if provider_record.provider_type != ProviderType.SYSTEM:
|
||||
continue
|
||||
|
||||
quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = (
|
||||
provider_record
|
||||
)
|
||||
if provider_record.quota_type is not None:
|
||||
quota_type_to_provider_records_dict[provider_record.quota_type] = provider_record # type: ignore[index]
|
||||
quota_configurations = []
|
||||
|
||||
if dify_config.EDITION == "CLOUD":
|
||||
@@ -1203,7 +1201,7 @@ class ProviderManager:
|
||||
model_settings.append(
|
||||
ModelSettings(
|
||||
model=provider_model_setting.model_name,
|
||||
model_type=ModelType.value_of(provider_model_setting.model_type),
|
||||
model_type=provider_model_setting.model_type,
|
||||
enabled=provider_model_setting.enabled,
|
||||
load_balancing_enabled=provider_model_setting.load_balancing_enabled,
|
||||
load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [],
|
||||
|
||||
@@ -97,13 +97,13 @@ class Jieba(BaseKeyword):
|
||||
|
||||
documents = []
|
||||
|
||||
segment_query_stmt = db.session.query(DocumentSegment).where(
|
||||
segment_query_stmt = select(DocumentSegment).where(
|
||||
DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id.in_(sorted_chunk_indices)
|
||||
)
|
||||
if document_ids_filter:
|
||||
segment_query_stmt = segment_query_stmt.where(DocumentSegment.document_id.in_(document_ids_filter))
|
||||
|
||||
segments = db.session.execute(segment_query_stmt).scalars().all()
|
||||
segments = db.session.scalars(segment_query_stmt).all()
|
||||
segment_map = {segment.index_node_id: segment for segment in segments}
|
||||
for chunk_index in sorted_chunk_indices:
|
||||
segment = segment_map.get(chunk_index)
|
||||
|
||||
@@ -432,10 +432,11 @@ class RetrievalService:
|
||||
# Batch query dataset documents
|
||||
dataset_documents = {
|
||||
doc.id: doc
|
||||
for doc in db.session.query(DatasetDocument)
|
||||
.where(DatasetDocument.id.in_(document_ids))
|
||||
.options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id))
|
||||
.all()
|
||||
for doc in db.session.scalars(
|
||||
select(DatasetDocument)
|
||||
.where(DatasetDocument.id.in_(document_ids))
|
||||
.options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id))
|
||||
).all()
|
||||
}
|
||||
|
||||
valid_dataset_documents = {}
|
||||
|
||||
@@ -426,11 +426,10 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
||||
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
|
||||
|
||||
else:
|
||||
idle_tidb_auth_binding = (
|
||||
db.session.query(TidbAuthBinding)
|
||||
idle_tidb_auth_binding = db.session.scalar(
|
||||
select(TidbAuthBinding)
|
||||
.where(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
|
||||
.limit(1)
|
||||
.one_or_none()
|
||||
)
|
||||
if idle_tidb_auth_binding:
|
||||
idle_tidb_auth_binding.active = True
|
||||
|
||||
@@ -277,7 +277,7 @@ class Vector:
|
||||
return self._vector_processor.search_by_vector(query_vector, **kwargs)
|
||||
|
||||
def search_by_file(self, file_id: str, **kwargs: Any) -> list[Document]:
|
||||
upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
upload_file: UploadFile | None = db.session.get(UploadFile, file_id)
|
||||
|
||||
if not upload_file:
|
||||
return []
|
||||
|
||||
@@ -4,7 +4,7 @@ from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy import delete, func, select
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
@@ -63,10 +63,8 @@ class DatasetDocumentStore:
|
||||
return output
|
||||
|
||||
def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False):
|
||||
max_position = (
|
||||
db.session.query(func.max(DocumentSegment.position))
|
||||
.where(DocumentSegment.document_id == self._document_id)
|
||||
.scalar()
|
||||
max_position = db.session.scalar(
|
||||
select(func.max(DocumentSegment.position)).where(DocumentSegment.document_id == self._document_id)
|
||||
)
|
||||
|
||||
if max_position is None:
|
||||
@@ -155,12 +153,14 @@ class DatasetDocumentStore:
|
||||
)
|
||||
if save_child and doc.children:
|
||||
# delete the existing child chunks
|
||||
db.session.query(ChildChunk).where(
|
||||
ChildChunk.tenant_id == self._dataset.tenant_id,
|
||||
ChildChunk.dataset_id == self._dataset.id,
|
||||
ChildChunk.document_id == self._document_id,
|
||||
ChildChunk.segment_id == segment_document.id,
|
||||
).delete()
|
||||
db.session.execute(
|
||||
delete(ChildChunk).where(
|
||||
ChildChunk.tenant_id == self._dataset.tenant_id,
|
||||
ChildChunk.dataset_id == self._dataset.id,
|
||||
ChildChunk.document_id == self._document_id,
|
||||
ChildChunk.segment_id == segment_document.id,
|
||||
)
|
||||
)
|
||||
# add new child chunks
|
||||
for position, child in enumerate(doc.children, start=1):
|
||||
child_segment = ChildChunk(
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any, cast
|
||||
import numpy as np
|
||||
from graphon.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from configs import dify_config
|
||||
@@ -31,14 +32,14 @@ class CacheEmbedding(Embeddings):
|
||||
embedding_queue_indices = []
|
||||
for i, text in enumerate(texts):
|
||||
hash = helper.generate_text_hash(text)
|
||||
embedding = (
|
||||
db.session.query(Embedding)
|
||||
.filter_by(
|
||||
model_name=self._model_instance.model_name,
|
||||
hash=hash,
|
||||
provider_name=self._model_instance.provider,
|
||||
embedding = db.session.scalar(
|
||||
select(Embedding)
|
||||
.where(
|
||||
Embedding.model_name == self._model_instance.model_name,
|
||||
Embedding.hash == hash,
|
||||
Embedding.provider_name == self._model_instance.provider,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if embedding:
|
||||
text_embeddings[i] = embedding.get_embedding()
|
||||
@@ -112,14 +113,14 @@ class CacheEmbedding(Embeddings):
|
||||
embedding_queue_indices = []
|
||||
for i, multimodel_document in enumerate(multimodel_documents):
|
||||
file_id = multimodel_document["file_id"]
|
||||
embedding = (
|
||||
db.session.query(Embedding)
|
||||
.filter_by(
|
||||
model_name=self._model_instance.model_name,
|
||||
hash=file_id,
|
||||
provider_name=self._model_instance.provider,
|
||||
embedding = db.session.scalar(
|
||||
select(Embedding)
|
||||
.where(
|
||||
Embedding.model_name == self._model_instance.model_name,
|
||||
Embedding.hash == file_id,
|
||||
Embedding.provider_name == self._model_instance.provider,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if embedding:
|
||||
multimodel_embeddings[i] = embedding.get_embedding()
|
||||
|
||||
@@ -4,6 +4,7 @@ import operator
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
from sqlalchemy import update
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
@@ -346,9 +347,11 @@ class NotionExtractor(BaseExtractor):
|
||||
if data_source_info:
|
||||
data_source_info["last_edited_time"] = last_edited_time
|
||||
|
||||
db.session.query(DocumentModel).filter_by(id=document_model.id).update(
|
||||
{DocumentModel.data_source_info: json.dumps(data_source_info)}
|
||||
) # type: ignore
|
||||
db.session.execute(
|
||||
update(DocumentModel)
|
||||
.where(DocumentModel.id == document_model.id)
|
||||
.values(data_source_info=json.dumps(data_source_info))
|
||||
)
|
||||
db.session.commit()
|
||||
|
||||
def get_notion_last_edited_time(self) -> str:
|
||||
|
||||
@@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Any, NotRequired, Optional
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
import httpx
|
||||
from sqlalchemy import select
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
@@ -200,7 +201,7 @@ class BaseIndexProcessor(ABC):
|
||||
|
||||
# Get unique IDs for database query
|
||||
unique_upload_file_ids = list(set(upload_file_id_list))
|
||||
upload_files = db.session.query(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids)).all()
|
||||
upload_files = db.session.scalars(select(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids))).all()
|
||||
|
||||
# Create a mapping from ID to UploadFile for quick lookup
|
||||
upload_file_map = {upload_file.id: upload_file for upload_file in upload_files}
|
||||
@@ -312,7 +313,7 @@ class BaseIndexProcessor(ABC):
|
||||
"""
|
||||
from services.file_service import FileService
|
||||
|
||||
tool_file = db.session.query(ToolFile).where(ToolFile.id == tool_file_id).first()
|
||||
tool_file = db.session.get(ToolFile, tool_file_id)
|
||||
if not tool_file:
|
||||
return None
|
||||
blob = storage.load_once(tool_file.file_key)
|
||||
|
||||
@@ -18,6 +18,7 @@ from graphon.model_runtime.entities.message_entities import (
|
||||
UserPromptMessage,
|
||||
)
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.app.llm import deduct_llm_quota
|
||||
@@ -145,14 +146,12 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
if delete_summaries:
|
||||
if node_ids:
|
||||
# Find segments by index_node_id
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.index_node_id.in_(node_ids),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
if segment_ids:
|
||||
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
|
||||
@@ -537,11 +536,9 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
|
||||
# Get unique IDs for database query
|
||||
unique_upload_file_ids = list(set(upload_file_id_list))
|
||||
upload_files = (
|
||||
db.session.query(UploadFile)
|
||||
.where(UploadFile.id.in_(unique_upload_file_ids), UploadFile.tenant_id == tenant_id)
|
||||
.all()
|
||||
)
|
||||
upload_files = db.session.scalars(
|
||||
select(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids), UploadFile.tenant_id == tenant_id)
|
||||
).all()
|
||||
|
||||
# Create File objects from UploadFile records
|
||||
file_objects = []
|
||||
|
||||
@@ -6,6 +6,8 @@ import uuid
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
from core.entities.knowledge_entities import PreviewDetail
|
||||
@@ -177,17 +179,16 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
child_node_ids = precomputed_child_node_ids
|
||||
else:
|
||||
# Fallback to original query (may fail if segments are already deleted)
|
||||
child_node_ids = (
|
||||
db.session.query(ChildChunk.index_node_id)
|
||||
rows = db.session.execute(
|
||||
select(ChildChunk.index_node_id)
|
||||
.join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
|
||||
.where(
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.index_node_id.in_(node_ids),
|
||||
ChildChunk.dataset_id == dataset.id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
child_node_ids = [child_node_id[0] for child_node_id in child_node_ids if child_node_id[0]]
|
||||
).all()
|
||||
child_node_ids = [row[0] for row in rows if row[0]]
|
||||
|
||||
# Delete from vector index
|
||||
if child_node_ids:
|
||||
@@ -195,18 +196,22 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
|
||||
# Delete from database
|
||||
if delete_child_chunks and child_node_ids:
|
||||
db.session.query(ChildChunk).where(
|
||||
ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids)
|
||||
).delete(synchronize_session=False)
|
||||
db.session.execute(
|
||||
delete(ChildChunk).where(
|
||||
ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids)
|
||||
)
|
||||
)
|
||||
db.session.commit()
|
||||
else:
|
||||
vector.delete()
|
||||
|
||||
if delete_child_chunks:
|
||||
# Use existing compound index: (tenant_id, dataset_id, ...)
|
||||
db.session.query(ChildChunk).where(
|
||||
ChildChunk.tenant_id == dataset.tenant_id, ChildChunk.dataset_id == dataset.id
|
||||
).delete(synchronize_session=False)
|
||||
db.session.execute(
|
||||
delete(ChildChunk).where(
|
||||
ChildChunk.tenant_id == dataset.tenant_id, ChildChunk.dataset_id == dataset.id
|
||||
)
|
||||
)
|
||||
db.session.commit()
|
||||
|
||||
def retrieve(
|
||||
|
||||
@@ -134,9 +134,7 @@ class RerankModelRunner(BaseRerankRunner):
|
||||
):
|
||||
if document.metadata.get("doc_type") == DocType.IMAGE:
|
||||
# Query file info within db.session context to ensure thread-safe access
|
||||
upload_file = (
|
||||
db.session.query(UploadFile).where(UploadFile.id == document.metadata["doc_id"]).first()
|
||||
)
|
||||
upload_file = db.session.get(UploadFile, document.metadata["doc_id"])
|
||||
if upload_file:
|
||||
blob = storage.load_once(upload_file.key)
|
||||
document_file_base64 = base64.b64encode(blob).decode()
|
||||
@@ -169,7 +167,7 @@ class RerankModelRunner(BaseRerankRunner):
|
||||
return rerank_result, unique_documents
|
||||
elif query_type == QueryType.IMAGE_QUERY:
|
||||
# Query file info within db.session context to ensure thread-safe access
|
||||
upload_file = db.session.query(UploadFile).where(UploadFile.id == query).first()
|
||||
upload_file = db.session.get(UploadFile, query)
|
||||
if upload_file:
|
||||
blob = storage.load_once(upload_file.key)
|
||||
file_query = base64.b64encode(blob).decode()
|
||||
|
||||
@@ -1340,7 +1340,7 @@ class DatasetRetrieval:
|
||||
metadata_filtering_conditions: MetadataFilteringCondition | None,
|
||||
inputs: dict,
|
||||
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]:
|
||||
document_query = db.session.query(DatasetDocument).where(
|
||||
document_query = select(DatasetDocument).where(
|
||||
DatasetDocument.dataset_id.in_(dataset_ids),
|
||||
DatasetDocument.indexing_status == "completed",
|
||||
DatasetDocument.enabled == True,
|
||||
@@ -1411,7 +1411,7 @@ class DatasetRetrieval:
|
||||
document_query = document_query.where(and_(*filters))
|
||||
else:
|
||||
document_query = document_query.where(or_(*filters))
|
||||
documents = document_query.all()
|
||||
documents = db.session.scalars(document_query).all()
|
||||
# group by dataset_id
|
||||
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
|
||||
for document in documents:
|
||||
|
||||
@@ -27,7 +27,10 @@ from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.sdk.trace.sampling import ParentBasedTraceIdRatio
|
||||
from opentelemetry.semconv.resource import ResourceAttributes
|
||||
from opentelemetry.semconv._incubating.attributes.host_attributes import ( # type: ignore[import-untyped]
|
||||
HOST_NAME,
|
||||
)
|
||||
from opentelemetry.semconv.attributes import service_attributes
|
||||
from opentelemetry.trace import SpanContext, TraceFlags
|
||||
from opentelemetry.util.types import Attributes, AttributeValue
|
||||
|
||||
@@ -114,8 +117,8 @@ class EnterpriseExporter:
|
||||
|
||||
resource = Resource(
|
||||
attributes={
|
||||
ResourceAttributes.SERVICE_NAME: service_name,
|
||||
ResourceAttributes.HOST_NAME: socket.gethostname(),
|
||||
service_attributes.SERVICE_NAME: service_name,
|
||||
HOST_NAME: socket.gethostname(),
|
||||
}
|
||||
)
|
||||
sampler = ParentBasedTraceIdRatio(sampling_rate)
|
||||
|
||||
@@ -157,7 +157,7 @@ def handle(sender: Message, **kwargs):
|
||||
tenant_id=tenant_id,
|
||||
provider_name=ModelProviderID(model_config.provider).provider_name,
|
||||
provider_type=ProviderType.SYSTEM.value,
|
||||
quota_type=provider_configuration.system_configuration.current_quota_type.value,
|
||||
quota_type=provider_configuration.system_configuration.current_quota_type,
|
||||
),
|
||||
values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
|
||||
additional_filters=_ProviderUpdateAdditionalFilters(
|
||||
|
||||
@@ -6,15 +6,24 @@ def init_app(app: DifyApp):
|
||||
if dify_config.SENTRY_DSN:
|
||||
import sentry_sdk
|
||||
from graphon.model_runtime.errors.invoke import InvokeRateLimitError
|
||||
from langfuse import parse_error
|
||||
from sentry_sdk.integrations.celery import CeleryIntegration
|
||||
from sentry_sdk.integrations.flask import FlaskIntegration
|
||||
from werkzeug.exceptions import HTTPException
|
||||
|
||||
try:
|
||||
from langfuse._utils import parse_error
|
||||
|
||||
_langfuse_error_response = parse_error.defaultErrorResponse
|
||||
except (ImportError, AttributeError):
|
||||
_langfuse_error_response = (
|
||||
"Unexpected error occurred. Please check your request"
|
||||
" and contact support: https://langfuse.com/support."
|
||||
)
|
||||
|
||||
def before_send(event, hint):
|
||||
if "exc_info" in hint:
|
||||
_, exc_value, _ = hint["exc_info"]
|
||||
if parse_error.defaultErrorResponse in str(exc_value):
|
||||
if _langfuse_error_response in str(exc_value):
|
||||
return None
|
||||
|
||||
return event
|
||||
@@ -27,7 +36,7 @@ def init_app(app: DifyApp):
|
||||
ValueError,
|
||||
FileNotFoundError,
|
||||
InvokeRateLimitError,
|
||||
parse_error.defaultErrorResponse,
|
||||
_langfuse_error_response,
|
||||
],
|
||||
traces_sample_rate=dify_config.SENTRY_TRACES_SAMPLE_RATE,
|
||||
profiles_sample_rate=dify_config.SENTRY_PROFILES_SAMPLE_RATE,
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Protocol, cast
|
||||
|
||||
import flask
|
||||
from opentelemetry.instrumentation.celery import CeleryInstrumentor
|
||||
@@ -21,6 +23,38 @@ from extensions.otel.runtime import is_celery_worker
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SupportsInstrument(Protocol):
|
||||
def instrument(self, **kwargs: object) -> None: ...
|
||||
|
||||
|
||||
class SupportsFlaskInstrumentor(Protocol):
|
||||
def instrument_app(
|
||||
self, app: DifyApp, response_hook: Callable[[Span, str, list], None] | None = None, **kwargs: object
|
||||
) -> None: ...
|
||||
|
||||
|
||||
# Some OpenTelemetry instrumentor constructors are typed loosely enough that
|
||||
# pyrefly infers `NoneType`. Narrow the instances to just the methods we use
|
||||
# while leaving runtime behavior unchanged.
|
||||
def _new_celery_instrumentor() -> SupportsInstrument:
|
||||
return cast(
|
||||
SupportsInstrument,
|
||||
CeleryInstrumentor(tracer_provider=get_tracer_provider(), meter_provider=get_meter_provider()),
|
||||
)
|
||||
|
||||
|
||||
def _new_httpx_instrumentor() -> SupportsInstrument:
|
||||
return cast(SupportsInstrument, HTTPXClientInstrumentor())
|
||||
|
||||
|
||||
def _new_redis_instrumentor() -> SupportsInstrument:
|
||||
return cast(SupportsInstrument, RedisInstrumentor())
|
||||
|
||||
|
||||
def _new_sqlalchemy_instrumentor() -> SupportsInstrument:
|
||||
return cast(SupportsInstrument, SQLAlchemyInstrumentor())
|
||||
|
||||
|
||||
class ExceptionLoggingHandler(logging.Handler):
|
||||
"""
|
||||
Handler that records exceptions to the current OpenTelemetry span.
|
||||
@@ -97,7 +131,7 @@ def init_flask_instrumentor(app: DifyApp) -> None:
|
||||
|
||||
from opentelemetry.instrumentation.flask import FlaskInstrumentor
|
||||
|
||||
instrumentor = FlaskInstrumentor()
|
||||
instrumentor = cast(SupportsFlaskInstrumentor, FlaskInstrumentor())
|
||||
if dify_config.DEBUG:
|
||||
logger.info("Initializing Flask instrumentor")
|
||||
instrumentor.instrument_app(app, response_hook=response_hook)
|
||||
@@ -106,21 +140,21 @@ def init_flask_instrumentor(app: DifyApp) -> None:
|
||||
def init_sqlalchemy_instrumentor(app: DifyApp) -> None:
|
||||
with app.app_context():
|
||||
engines = list(app.extensions["sqlalchemy"].engines.values())
|
||||
SQLAlchemyInstrumentor().instrument(enable_commenter=True, engines=engines)
|
||||
_new_sqlalchemy_instrumentor().instrument(enable_commenter=True, engines=engines)
|
||||
|
||||
|
||||
def init_redis_instrumentor() -> None:
|
||||
RedisInstrumentor().instrument()
|
||||
_new_redis_instrumentor().instrument()
|
||||
|
||||
|
||||
def init_httpx_instrumentor() -> None:
|
||||
HTTPXClientInstrumentor().instrument()
|
||||
_new_httpx_instrumentor().instrument()
|
||||
|
||||
|
||||
def init_instruments(app: DifyApp) -> None:
|
||||
if not is_celery_worker():
|
||||
init_flask_instrumentor(app)
|
||||
CeleryInstrumentor(tracer_provider=get_tracer_provider(), meter_provider=get_meter_provider()).instrument()
|
||||
_new_celery_instrumentor().instrument()
|
||||
|
||||
instrument_exception_logging()
|
||||
init_sqlalchemy_instrumentor(app)
|
||||
|
||||
@@ -6,6 +6,7 @@ from functools import cached_property
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from sqlalchemy import DateTime, String, func, select, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
@@ -13,7 +14,7 @@ from libs.uuid_utils import uuidv7
|
||||
|
||||
from .base import TypeBase
|
||||
from .engine import db
|
||||
from .enums import CredentialSourceType, PaymentStatus
|
||||
from .enums import CredentialSourceType, PaymentStatus, ProviderQuotaType
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
|
||||
|
||||
@@ -29,24 +30,6 @@ class ProviderType(StrEnum):
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
|
||||
class ProviderQuotaType(StrEnum):
|
||||
PAID = auto()
|
||||
"""hosted paid quota"""
|
||||
|
||||
FREE = auto()
|
||||
"""third-party free quota"""
|
||||
|
||||
TRIAL = auto()
|
||||
"""hosted trial quota"""
|
||||
|
||||
@staticmethod
|
||||
def value_of(value: str) -> ProviderQuotaType:
|
||||
for member in ProviderQuotaType:
|
||||
if member.value == value:
|
||||
return member
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
|
||||
class Provider(TypeBase):
|
||||
"""
|
||||
Provider model representing the API providers and their configurations.
|
||||
@@ -77,7 +60,9 @@ class Provider(TypeBase):
|
||||
last_used: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, init=False)
|
||||
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
|
||||
quota_type: Mapped[str | None] = mapped_column(String(40), nullable=True, server_default=text("''"), default="")
|
||||
quota_type: Mapped[ProviderQuotaType | None] = mapped_column(
|
||||
EnumText(ProviderQuotaType, length=40), nullable=True, server_default=text("''"), default=None
|
||||
)
|
||||
quota_limit: Mapped[int | None] = mapped_column(sa.BigInteger, nullable=True, default=None)
|
||||
quota_used: Mapped[int | None] = mapped_column(sa.BigInteger, nullable=True, default=0)
|
||||
|
||||
@@ -147,7 +132,7 @@ class ProviderModel(TypeBase):
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
model_type: Mapped[ModelType] = mapped_column(EnumText(ModelType, length=40), nullable=False)
|
||||
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"), default=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
@@ -189,7 +174,7 @@ class TenantDefaultModel(TypeBase):
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
model_type: Mapped[ModelType] = mapped_column(EnumText(ModelType, length=40), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
@@ -269,7 +254,7 @@ class ProviderModelSetting(TypeBase):
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
model_type: Mapped[ModelType] = mapped_column(EnumText(ModelType, length=40), nullable=False)
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"), default=True)
|
||||
load_balancing_enabled: Mapped[bool] = mapped_column(
|
||||
sa.Boolean, nullable=False, server_default=text("false"), default=False
|
||||
@@ -299,7 +284,7 @@ class LoadBalancingModelConfig(TypeBase):
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
model_type: Mapped[ModelType] = mapped_column(EnumText(ModelType, length=40), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
encrypted_config: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
@@ -364,7 +349,7 @@ class ProviderModelCredential(TypeBase):
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
model_type: Mapped[ModelType] = mapped_column(EnumText(ModelType, length=40), nullable=False)
|
||||
credential_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
encrypted_config: Mapped[str] = mapped_column(LongText, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
|
||||
@@ -144,8 +144,8 @@ class EnumText(TypeDecorator[_E | None], Generic[_E]):
|
||||
return dialect.type_descriptor(VARCHAR(self._length))
|
||||
|
||||
def process_result_value(self, value: str | None, dialect: Dialect) -> _E | None:
|
||||
if value is None:
|
||||
return value
|
||||
if value is None or value == "":
|
||||
return None
|
||||
# Type annotation guarantees value is str at this point
|
||||
return self._enum_class(value)
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ dependencies = [
|
||||
"httpx[socks]~=0.28.0",
|
||||
"jieba==0.42.1",
|
||||
"json-repair>=0.55.1",
|
||||
"langfuse~=2.51.3",
|
||||
"langfuse>=3.0.0,<5.0.0",
|
||||
"langsmith~=0.7.16",
|
||||
"markdown~=3.10.2",
|
||||
"mlflow-skinny>=3.0.0",
|
||||
@@ -41,23 +41,23 @@ dependencies = [
|
||||
"openpyxl~=3.1.5",
|
||||
"opik~=1.10.37",
|
||||
"litellm==1.82.6", # Pinned to avoid madoka dependency issue
|
||||
"opentelemetry-api==1.28.0",
|
||||
"opentelemetry-distro==0.49b0",
|
||||
"opentelemetry-exporter-otlp==1.28.0",
|
||||
"opentelemetry-exporter-otlp-proto-common==1.28.0",
|
||||
"opentelemetry-exporter-otlp-proto-grpc==1.28.0",
|
||||
"opentelemetry-exporter-otlp-proto-http==1.28.0",
|
||||
"opentelemetry-instrumentation==0.49b0",
|
||||
"opentelemetry-instrumentation-celery==0.49b0",
|
||||
"opentelemetry-instrumentation-flask==0.49b0",
|
||||
"opentelemetry-instrumentation-httpx==0.49b0",
|
||||
"opentelemetry-instrumentation-redis==0.49b0",
|
||||
"opentelemetry-instrumentation-sqlalchemy==0.49b0",
|
||||
"opentelemetry-api==1.40.0",
|
||||
"opentelemetry-distro==0.61b0",
|
||||
"opentelemetry-exporter-otlp==1.40.0",
|
||||
"opentelemetry-exporter-otlp-proto-common==1.40.0",
|
||||
"opentelemetry-exporter-otlp-proto-grpc==1.40.0",
|
||||
"opentelemetry-exporter-otlp-proto-http==1.40.0",
|
||||
"opentelemetry-instrumentation==0.61b0",
|
||||
"opentelemetry-instrumentation-celery==0.61b0",
|
||||
"opentelemetry-instrumentation-flask==0.61b0",
|
||||
"opentelemetry-instrumentation-httpx==0.61b0",
|
||||
"opentelemetry-instrumentation-redis==0.61b0",
|
||||
"opentelemetry-instrumentation-sqlalchemy==0.61b0",
|
||||
"opentelemetry-propagator-b3==1.40.0",
|
||||
"opentelemetry-proto==1.28.0",
|
||||
"opentelemetry-sdk==1.28.0",
|
||||
"opentelemetry-semantic-conventions==0.49b0",
|
||||
"opentelemetry-util-http==0.49b0",
|
||||
"opentelemetry-proto==1.40.0",
|
||||
"opentelemetry-sdk==1.40.0",
|
||||
"opentelemetry-semantic-conventions==0.61b0",
|
||||
"opentelemetry-util-http==0.61b0",
|
||||
"pandas[excel,output-formatting,performance]~=3.0.1",
|
||||
"psycogreen~=1.0.2",
|
||||
"psycopg2-binary~=2.9.6",
|
||||
|
||||
@@ -7,9 +7,19 @@ from datetime import UTC, datetime, timedelta
|
||||
from hashlib import sha256
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
class InvitationData(TypedDict):
|
||||
account_id: str
|
||||
email: str
|
||||
workspace_id: str
|
||||
|
||||
|
||||
_invitation_adapter: TypeAdapter[InvitationData] = TypeAdapter(InvitationData)
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
@@ -1571,7 +1581,7 @@ class RegisterService:
|
||||
@classmethod
|
||||
def get_invitation_by_token(
|
||||
cls, token: str, workspace_id: str | None = None, email: str | None = None
|
||||
) -> dict[str, str] | None:
|
||||
) -> InvitationData | None:
|
||||
if workspace_id is not None and email is not None:
|
||||
email_hash = sha256(email.encode()).hexdigest()
|
||||
cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}"
|
||||
@@ -1590,7 +1600,7 @@ class RegisterService:
|
||||
if not data:
|
||||
return None
|
||||
|
||||
invitation: dict = json.loads(data)
|
||||
invitation = _invitation_adapter.validate_json(data)
|
||||
return invitation
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal
|
||||
from typing import Literal, NotRequired
|
||||
|
||||
import httpx
|
||||
from pydantic import TypeAdapter
|
||||
@@ -47,6 +47,58 @@ class QuotaReleaseResult(TypedDict):
|
||||
_quota_reserve_adapter = TypeAdapter(QuotaReserveResult)
|
||||
_quota_commit_adapter = TypeAdapter(QuotaCommitResult)
|
||||
_quota_release_adapter = TypeAdapter(QuotaReleaseResult)
|
||||
class _BillingQuota(TypedDict):
|
||||
size: int
|
||||
limit: int
|
||||
|
||||
|
||||
class _VectorSpaceQuota(TypedDict):
|
||||
size: float
|
||||
limit: int
|
||||
|
||||
|
||||
class _KnowledgeRateLimit(TypedDict):
|
||||
# NOTE (hj24):
|
||||
# 1. Return for sandbox users but is null for other plans, it's defined but never used.
|
||||
# 2. Keep it for compatibility for now, can be deprecated in future versions.
|
||||
size: NotRequired[int]
|
||||
# NOTE END
|
||||
limit: int
|
||||
|
||||
|
||||
class _BillingSubscription(TypedDict):
|
||||
plan: str
|
||||
interval: str
|
||||
education: bool
|
||||
|
||||
|
||||
class BillingInfo(TypedDict):
|
||||
"""Response of /subscription/info.
|
||||
|
||||
NOTE (hj24):
|
||||
- Fields not listed here (e.g. trigger_event, api_rate_limit) are stripped by TypeAdapter.validate_python()
|
||||
- To ensure the precision, billing may convert fields like int as str, be careful when use TypeAdapter:
|
||||
1. validate_python in non-strict mode will coerce it to the expected type
|
||||
2. In strict mode, it will raise ValidationError
|
||||
3. To preserve compatibility, always keep non-strict mode here and avoid strict mode
|
||||
"""
|
||||
|
||||
enabled: bool
|
||||
subscription: _BillingSubscription
|
||||
members: _BillingQuota
|
||||
apps: _BillingQuota
|
||||
vector_space: _VectorSpaceQuota
|
||||
knowledge_rate_limit: _KnowledgeRateLimit
|
||||
documents_upload_quota: _BillingQuota
|
||||
annotation_quota_limit: _BillingQuota
|
||||
docs_processing: str
|
||||
can_replace_logo: bool
|
||||
model_load_balancing_enabled: bool
|
||||
knowledge_pipeline_publish_enabled: bool
|
||||
next_credit_reset_date: NotRequired[int]
|
||||
|
||||
|
||||
_billing_info_adapter = TypeAdapter(BillingInfo)
|
||||
|
||||
|
||||
class BillingService:
|
||||
@@ -61,11 +113,11 @@ class BillingService:
|
||||
_PLAN_CACHE_TTL = 600
|
||||
|
||||
@classmethod
|
||||
def get_info(cls, tenant_id: str):
|
||||
def get_info(cls, tenant_id: str) -> BillingInfo:
|
||||
params = {"tenant_id": tenant_id}
|
||||
|
||||
billing_info = cls._send_request("GET", "/subscription/info", params=params)
|
||||
return billing_info
|
||||
return _billing_info_adapter.validate_python(billing_info)
|
||||
|
||||
@classmethod
|
||||
def get_tenant_feature_plan_usage_info(cls, tenant_id: str):
|
||||
|
||||
@@ -312,7 +312,10 @@ class FeatureService:
|
||||
features.apps.limit = billing_info["apps"]["limit"]
|
||||
|
||||
if "vector_space" in billing_info:
|
||||
features.vector_space.size = billing_info["vector_space"]["size"]
|
||||
# NOTE (hj24): billing API returns vector_space.size as float (e.g. 0.0)
|
||||
# but LimitationModel.size is int; truncate here for compatibility
|
||||
features.vector_space.size = int(billing_info["vector_space"]["size"])
|
||||
# NOTE END
|
||||
features.vector_space.limit = billing_info["vector_space"]["limit"]
|
||||
|
||||
if "documents_upload_quota" in billing_info:
|
||||
@@ -333,7 +336,11 @@ class FeatureService:
|
||||
features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"]
|
||||
|
||||
if "knowledge_rate_limit" in billing_info:
|
||||
# NOTE (hj24):
|
||||
# 1. knowledge_rate_limit size is nullable, currently it's defined but never used, only limit is used.
|
||||
# 2. So be careful if later we decide to use [size], we cannot assume it is always present.
|
||||
features.knowledge_rate_limit = billing_info["knowledge_rate_limit"]["limit"]
|
||||
# NOTE END
|
||||
|
||||
if "knowledge_pipeline_publish_enabled" in billing_info:
|
||||
features.knowledge_pipeline.publish_enabled = billing_info["knowledge_pipeline_publish_enabled"]
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Union
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||
@@ -17,7 +17,7 @@ from extensions.ext_database import db
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models import Account
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
from models.model import App, AppMode, AppModelConfig, EndUser, Message, MessageFeedback
|
||||
from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, EndUser, Message, MessageFeedback
|
||||
from repositories.execution_extra_content_repository import ExecutionExtraContentRepository
|
||||
from repositories.sqlalchemy_execution_extra_content_repository import (
|
||||
SQLAlchemyExecutionExtraContentRepository,
|
||||
@@ -31,6 +31,8 @@ from services.errors.message import (
|
||||
)
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
_app_model_config_adapter: TypeAdapter[AppModelConfigDict] = TypeAdapter(AppModelConfigDict)
|
||||
|
||||
|
||||
def _create_execution_extra_content_repository() -> ExecutionExtraContentRepository:
|
||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
@@ -286,7 +288,9 @@ class MessageService:
|
||||
.first()
|
||||
)
|
||||
else:
|
||||
conversation_override_model_configs = json.loads(conversation.override_model_configs)
|
||||
conversation_override_model_configs = _app_model_config_adapter.validate_json(
|
||||
conversation.override_model_configs
|
||||
)
|
||||
app_model_config = AppModelConfig(
|
||||
app_id=app_model.id,
|
||||
)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from json import JSONDecodeError
|
||||
from typing import Union
|
||||
from typing import Any, Union
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.entities.provider_entities import (
|
||||
@@ -116,7 +115,7 @@ class ModelLoadBalancingService:
|
||||
.where(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_type == model_type_enum,
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
or_(
|
||||
LoadBalancingModelConfig.credential_source_type == credential_source_type,
|
||||
@@ -168,10 +167,10 @@ class ModelLoadBalancingService:
|
||||
|
||||
try:
|
||||
if load_balancing_config.encrypted_config:
|
||||
credentials: dict[str, object] = json.loads(load_balancing_config.encrypted_config)
|
||||
credentials: dict[str, Any] = json.loads(load_balancing_config.encrypted_config)
|
||||
else:
|
||||
credentials = {}
|
||||
except JSONDecodeError:
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
credentials = {}
|
||||
|
||||
# Get provider credential secret variables
|
||||
@@ -241,7 +240,7 @@ class ModelLoadBalancingService:
|
||||
.where(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_type == model_type_enum,
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
LoadBalancingModelConfig.id == config_id,
|
||||
)
|
||||
@@ -256,7 +255,7 @@ class ModelLoadBalancingService:
|
||||
credentials = json.loads(load_balancing_model_config.encrypted_config)
|
||||
else:
|
||||
credentials = {}
|
||||
except JSONDecodeError:
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
credentials = {}
|
||||
|
||||
# Get credential form schemas from model credential schema or provider credential schema
|
||||
@@ -289,7 +288,7 @@ class ModelLoadBalancingService:
|
||||
inherit_config = LoadBalancingModelConfig(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_type=model_type,
|
||||
model_name=model,
|
||||
name="__inherit__",
|
||||
)
|
||||
@@ -329,7 +328,7 @@ class ModelLoadBalancingService:
|
||||
select(LoadBalancingModelConfig).where(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_type == model_type_enum,
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
)
|
||||
).all()
|
||||
@@ -369,7 +368,7 @@ class ModelLoadBalancingService:
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_configuration.provider.provider,
|
||||
model_name=model,
|
||||
model_type=model_type_enum.to_origin_model_type(),
|
||||
model_type=model_type_enum,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
@@ -433,7 +432,7 @@ class ModelLoadBalancingService:
|
||||
load_balancing_model_config = LoadBalancingModelConfig(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_configuration.provider.provider,
|
||||
model_type=model_type_enum.to_origin_model_type(),
|
||||
model_type=model_type_enum,
|
||||
model_name=model,
|
||||
name=credential_record.credential_name,
|
||||
encrypted_config=credential_record.encrypted_config,
|
||||
@@ -461,7 +460,7 @@ class ModelLoadBalancingService:
|
||||
load_balancing_model_config = LoadBalancingModelConfig(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_configuration.provider.provider,
|
||||
model_type=model_type_enum.to_origin_model_type(),
|
||||
model_type=model_type_enum,
|
||||
model_name=model,
|
||||
name=name,
|
||||
encrypted_config=json.dumps(credentials),
|
||||
@@ -516,7 +515,7 @@ class ModelLoadBalancingService:
|
||||
.where(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider,
|
||||
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_type == model_type_enum,
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
LoadBalancingModelConfig.id == config_id,
|
||||
)
|
||||
@@ -575,7 +574,7 @@ class ModelLoadBalancingService:
|
||||
original_credentials = json.loads(load_balancing_model_config.encrypted_config)
|
||||
else:
|
||||
original_credentials = {}
|
||||
except JSONDecodeError:
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
original_credentials = {}
|
||||
|
||||
# encrypt credentials
|
||||
|
||||
@@ -12,7 +12,9 @@ import click
|
||||
import sqlalchemy as sa
|
||||
import tqdm
|
||||
from flask import Flask, current_app
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy.orm import Session
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.helper import marketplace
|
||||
@@ -33,6 +35,14 @@ logger = logging.getLogger(__name__)
|
||||
excluded_providers = ["time", "audio", "code", "webscraper"]
|
||||
|
||||
|
||||
class _TenantPluginRecord(TypedDict):
|
||||
tenant_id: str
|
||||
plugins: list[str]
|
||||
|
||||
|
||||
_tenant_plugin_adapter: TypeAdapter[_TenantPluginRecord] = TypeAdapter(_TenantPluginRecord)
|
||||
|
||||
|
||||
class PluginMigration:
|
||||
@classmethod
|
||||
def extract_plugins(cls, filepath: str, workers: int):
|
||||
@@ -308,9 +318,8 @@ class PluginMigration:
|
||||
logger.info("Extracting unique plugins from %s", extracted_plugins)
|
||||
with open(extracted_plugins) as f:
|
||||
for line in f:
|
||||
data = json.loads(line)
|
||||
new_plugin_ids = data.get("plugins", [])
|
||||
for plugin_id in new_plugin_ids:
|
||||
data = _tenant_plugin_adapter.validate_json(line)
|
||||
for plugin_id in data["plugins"]:
|
||||
if plugin_id not in plugin_ids:
|
||||
plugin_ids.append(plugin_id)
|
||||
|
||||
@@ -381,21 +390,23 @@ class PluginMigration:
|
||||
Read line by line, and install plugins for each tenant.
|
||||
"""
|
||||
for line in f:
|
||||
data = json.loads(line)
|
||||
tenant_id = data.get("tenant_id")
|
||||
plugin_ids = data.get("plugins", [])
|
||||
current_not_installed = {
|
||||
"tenant_id": tenant_id,
|
||||
"plugin_not_exist": [],
|
||||
}
|
||||
data = _tenant_plugin_adapter.validate_json(line)
|
||||
tenant_id = data["tenant_id"]
|
||||
plugin_ids = data["plugins"]
|
||||
plugin_not_exist: list[str] = []
|
||||
# get plugin unique identifier
|
||||
for plugin_id in plugin_ids:
|
||||
unique_identifier = plugins.get(plugin_id)
|
||||
if unique_identifier:
|
||||
current_not_installed["plugin_not_exist"].append(plugin_id)
|
||||
plugin_not_exist.append(plugin_id)
|
||||
|
||||
if current_not_installed["plugin_not_exist"]:
|
||||
not_installed.append(current_not_installed)
|
||||
if plugin_not_exist:
|
||||
not_installed.append(
|
||||
{
|
||||
"tenant_id": tenant_id,
|
||||
"plugin_not_exist": plugin_not_exist,
|
||||
}
|
||||
)
|
||||
|
||||
thread_pool.submit(install, tenant_id, plugin_ids)
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ back to the database.
|
||||
"""
|
||||
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import zipfile
|
||||
@@ -17,8 +16,23 @@ from datetime import datetime
|
||||
from typing import Any, cast
|
||||
|
||||
import click
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
class _TableInfo(TypedDict, total=False):
|
||||
row_count: int
|
||||
|
||||
|
||||
class ArchiveManifest(TypedDict, total=False):
|
||||
tables: dict[str, _TableInfo]
|
||||
schema_version: str
|
||||
|
||||
|
||||
_manifest_adapter: TypeAdapter[ArchiveManifest] = TypeAdapter(ArchiveManifest)
|
||||
|
||||
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
|
||||
|
||||
from extensions.ext_database import db
|
||||
@@ -239,12 +253,12 @@ class WorkflowRunRestore:
|
||||
return self.workflow_run_repo
|
||||
|
||||
@staticmethod
|
||||
def _load_manifest_from_zip(archive: zipfile.ZipFile) -> dict[str, Any]:
|
||||
def _load_manifest_from_zip(archive: zipfile.ZipFile) -> ArchiveManifest:
|
||||
try:
|
||||
data = archive.read("manifest.json")
|
||||
except KeyError as e:
|
||||
raise ValueError("manifest.json missing from archive bundle") from e
|
||||
return json.loads(data.decode("utf-8"))
|
||||
return _manifest_adapter.validate_json(data)
|
||||
|
||||
def _restore_table_records(
|
||||
self,
|
||||
@@ -332,7 +346,7 @@ class WorkflowRunRestore:
|
||||
|
||||
return result
|
||||
|
||||
def _get_schema_version(self, manifest: dict[str, Any]) -> str:
|
||||
def _get_schema_version(self, manifest: ArchiveManifest) -> str:
|
||||
schema_version = manifest.get("schema_version")
|
||||
if not schema_version:
|
||||
logger.warning("Manifest missing schema_version; defaulting to 1.0")
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import ValidationError
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
@@ -31,6 +31,8 @@ from services.plugin.plugin_service import PluginService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_mcp_tools_adapter: TypeAdapter[list[MCPTool]] = TypeAdapter(list[MCPTool])
|
||||
|
||||
|
||||
class ToolTransformService:
|
||||
_MCP_SCHEMA_TYPE_RESOLUTION_MAX_DEPTH = 10
|
||||
@@ -53,7 +55,7 @@ class ToolTransformService:
|
||||
if isinstance(icon, str):
|
||||
return json.loads(icon)
|
||||
return icon
|
||||
except Exception:
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
elif provider_type == ToolProviderType.MCP:
|
||||
return icon
|
||||
@@ -247,8 +249,8 @@ class ToolTransformService:
|
||||
|
||||
response = provider_entity.to_api_response(user_name=user_name, include_sensitive=include_sensitive)
|
||||
try:
|
||||
mcp_tools = [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
|
||||
except (ValidationError, json.JSONDecodeError):
|
||||
mcp_tools = _mcp_tools_adapter.validate_json(db_provider.tools)
|
||||
except (ValidationError, ValueError):
|
||||
mcp_tools = []
|
||||
# Add additional fields specific to the transform
|
||||
response["id"] = db_provider.server_identifier if not for_list else db_provider.id
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
@@ -7,6 +6,7 @@ from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from flask import Request, Response
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.entities.request import TriggerDispatchResponse
|
||||
@@ -29,6 +29,8 @@ from services.trigger.trigger_provider_service import TriggerProviderService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_request_logs_adapter: TypeAdapter[list[RequestLog]] = TypeAdapter(list[RequestLog])
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderService:
|
||||
"""Service for managing trigger providers and credentials"""
|
||||
@@ -398,7 +400,7 @@ class TriggerSubscriptionBuilderService:
|
||||
cache_key = cls.encode_cache_key(endpoint_id)
|
||||
subscription_cache = redis_client.get(cache_key)
|
||||
if subscription_cache:
|
||||
return SubscriptionBuilder.model_validate(json.loads(subscription_cache))
|
||||
return SubscriptionBuilder.model_validate_json(subscription_cache)
|
||||
|
||||
return None
|
||||
|
||||
@@ -423,12 +425,16 @@ class TriggerSubscriptionBuilderService:
|
||||
)
|
||||
|
||||
key = f"trigger:subscription:builder:logs:{endpoint_id}"
|
||||
logs = json.loads(redis_client.get(key) or "[]")
|
||||
logs.append(log.model_dump(mode="json"))
|
||||
logs = _request_logs_adapter.validate_json(redis_client.get(key) or b"[]")
|
||||
logs.append(log)
|
||||
|
||||
# Keep last N logs
|
||||
logs = logs[-cls.__VALIDATION_REQUEST_CACHE_COUNT__ :]
|
||||
redis_client.setex(key, cls.__VALIDATION_REQUEST_CACHE_EXPIRE_SECONDS__, json.dumps(logs, default=str))
|
||||
redis_client.setex(
|
||||
key,
|
||||
cls.__VALIDATION_REQUEST_CACHE_EXPIRE_SECONDS__,
|
||||
_request_logs_adapter.dump_json(logs),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def list_logs(cls, endpoint_id: str) -> list[RequestLog]:
|
||||
@@ -437,7 +443,7 @@ class TriggerSubscriptionBuilderService:
|
||||
logs_json = redis_client.get(key)
|
||||
if not logs_json:
|
||||
return []
|
||||
return [RequestLog.model_validate(log) for log in json.loads(logs_json)]
|
||||
return _request_logs_adapter.validate_json(logs_json)
|
||||
|
||||
@classmethod
|
||||
def process_builder_validation_endpoint(cls, endpoint_id: str, request: Request) -> Response | None:
|
||||
|
||||
@@ -1118,7 +1118,7 @@ class WorkflowService:
|
||||
continue
|
||||
try:
|
||||
payload = json.loads(recipient.recipient_payload)
|
||||
except Exception:
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
logger.exception("Failed to parse human input recipient payload for delivery test.")
|
||||
continue
|
||||
email = payload.get("email")
|
||||
|
||||
@@ -0,0 +1,182 @@
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from models import Tenant
|
||||
from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tenant(flask_req_ctx):
|
||||
with session_factory.create_session() as session:
|
||||
t = Tenant(name="plugin_it_tenant")
|
||||
session.add(t)
|
||||
session.commit()
|
||||
tenant_id = t.id
|
||||
|
||||
yield tenant_id
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
session.execute(delete(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id))
|
||||
session.execute(
|
||||
delete(TenantPluginAutoUpgradeStrategy).where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
)
|
||||
session.execute(delete(Tenant).where(Tenant.id == tenant_id))
|
||||
session.commit()
|
||||
|
||||
|
||||
class TestPluginPermissionLifecycle:
|
||||
def test_get_returns_none_for_new_tenant(self, tenant):
|
||||
assert PluginPermissionService.get_permission(tenant) is None
|
||||
|
||||
def test_change_creates_row(self, tenant):
|
||||
result = PluginPermissionService.change_permission(
|
||||
tenant,
|
||||
TenantPluginPermission.InstallPermission.ADMINS,
|
||||
TenantPluginPermission.DebugPermission.EVERYONE,
|
||||
)
|
||||
assert result is True
|
||||
|
||||
perm = PluginPermissionService.get_permission(tenant)
|
||||
assert perm is not None
|
||||
assert perm.install_permission == TenantPluginPermission.InstallPermission.ADMINS
|
||||
assert perm.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE
|
||||
|
||||
def test_change_updates_existing_row(self, tenant):
|
||||
PluginPermissionService.change_permission(
|
||||
tenant,
|
||||
TenantPluginPermission.InstallPermission.ADMINS,
|
||||
TenantPluginPermission.DebugPermission.NOBODY,
|
||||
)
|
||||
PluginPermissionService.change_permission(
|
||||
tenant,
|
||||
TenantPluginPermission.InstallPermission.EVERYONE,
|
||||
TenantPluginPermission.DebugPermission.ADMINS,
|
||||
)
|
||||
perm = PluginPermissionService.get_permission(tenant)
|
||||
assert perm is not None
|
||||
assert perm.install_permission == TenantPluginPermission.InstallPermission.EVERYONE
|
||||
assert perm.debug_permission == TenantPluginPermission.DebugPermission.ADMINS
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
count = session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant).count()
|
||||
assert count == 1
|
||||
|
||||
|
||||
class TestPluginAutoUpgradeLifecycle:
|
||||
def test_get_returns_none_for_new_tenant(self, tenant):
|
||||
assert PluginAutoUpgradeService.get_strategy(tenant) is None
|
||||
|
||||
def test_change_creates_row(self, tenant):
|
||||
result = PluginAutoUpgradeService.change_strategy(
|
||||
tenant,
|
||||
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
|
||||
upgrade_time_of_day=3,
|
||||
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
|
||||
exclude_plugins=[],
|
||||
include_plugins=[],
|
||||
)
|
||||
assert result is True
|
||||
|
||||
strategy = PluginAutoUpgradeService.get_strategy(tenant)
|
||||
assert strategy is not None
|
||||
assert strategy.strategy_setting == TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST
|
||||
assert strategy.upgrade_time_of_day == 3
|
||||
|
||||
def test_change_updates_existing_row(self, tenant):
|
||||
PluginAutoUpgradeService.change_strategy(
|
||||
tenant,
|
||||
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
|
||||
upgrade_time_of_day=0,
|
||||
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
|
||||
exclude_plugins=[],
|
||||
include_plugins=[],
|
||||
)
|
||||
PluginAutoUpgradeService.change_strategy(
|
||||
tenant,
|
||||
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
|
||||
upgrade_time_of_day=12,
|
||||
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL,
|
||||
exclude_plugins=[],
|
||||
include_plugins=["plugin-a"],
|
||||
)
|
||||
|
||||
strategy = PluginAutoUpgradeService.get_strategy(tenant)
|
||||
assert strategy is not None
|
||||
assert strategy.strategy_setting == TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST
|
||||
assert strategy.upgrade_time_of_day == 12
|
||||
assert strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL
|
||||
assert strategy.include_plugins == ["plugin-a"]
|
||||
|
||||
def test_exclude_plugin_creates_strategy_when_none_exists(self, tenant):
|
||||
PluginAutoUpgradeService.exclude_plugin(tenant, "my-plugin")
|
||||
|
||||
strategy = PluginAutoUpgradeService.get_strategy(tenant)
|
||||
assert strategy is not None
|
||||
assert strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE
|
||||
assert "my-plugin" in strategy.exclude_plugins
|
||||
|
||||
def test_exclude_plugin_appends_in_exclude_mode(self, tenant):
|
||||
PluginAutoUpgradeService.change_strategy(
|
||||
tenant,
|
||||
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
|
||||
upgrade_time_of_day=0,
|
||||
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
|
||||
exclude_plugins=["existing"],
|
||||
include_plugins=[],
|
||||
)
|
||||
PluginAutoUpgradeService.exclude_plugin(tenant, "new-plugin")
|
||||
|
||||
strategy = PluginAutoUpgradeService.get_strategy(tenant)
|
||||
assert strategy is not None
|
||||
assert "existing" in strategy.exclude_plugins
|
||||
assert "new-plugin" in strategy.exclude_plugins
|
||||
|
||||
def test_exclude_plugin_dedup_in_exclude_mode(self, tenant):
|
||||
PluginAutoUpgradeService.change_strategy(
|
||||
tenant,
|
||||
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
|
||||
upgrade_time_of_day=0,
|
||||
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
|
||||
exclude_plugins=["same-plugin"],
|
||||
include_plugins=[],
|
||||
)
|
||||
PluginAutoUpgradeService.exclude_plugin(tenant, "same-plugin")
|
||||
|
||||
strategy = PluginAutoUpgradeService.get_strategy(tenant)
|
||||
assert strategy is not None
|
||||
assert strategy.exclude_plugins.count("same-plugin") == 1
|
||||
|
||||
def test_exclude_from_partial_mode_removes_from_include(self, tenant):
|
||||
PluginAutoUpgradeService.change_strategy(
|
||||
tenant,
|
||||
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
|
||||
upgrade_time_of_day=0,
|
||||
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL,
|
||||
exclude_plugins=[],
|
||||
include_plugins=["p1", "p2"],
|
||||
)
|
||||
PluginAutoUpgradeService.exclude_plugin(tenant, "p1")
|
||||
|
||||
strategy = PluginAutoUpgradeService.get_strategy(tenant)
|
||||
assert strategy is not None
|
||||
assert "p1" not in strategy.include_plugins
|
||||
assert "p2" in strategy.include_plugins
|
||||
|
||||
def test_exclude_from_all_mode_switches_to_exclude(self, tenant):
|
||||
PluginAutoUpgradeService.change_strategy(
|
||||
tenant,
|
||||
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
|
||||
upgrade_time_of_day=0,
|
||||
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
|
||||
exclude_plugins=[],
|
||||
include_plugins=[],
|
||||
)
|
||||
PluginAutoUpgradeService.exclude_plugin(tenant, "excluded-plugin")
|
||||
|
||||
strategy = PluginAutoUpgradeService.get_strategy(tenant)
|
||||
assert strategy is not None
|
||||
assert strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE
|
||||
assert "excluded-plugin" in strategy.exclude_plugins
|
||||
@@ -0,0 +1,348 @@
|
||||
import datetime
|
||||
import math
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from models import Tenant
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
from models.model import (
|
||||
App,
|
||||
Conversation,
|
||||
Message,
|
||||
MessageAnnotation,
|
||||
MessageFeedback,
|
||||
)
|
||||
from services.retention.conversation.messages_clean_policy import BillingDisabledPolicy
|
||||
from services.retention.conversation.messages_clean_service import MessagesCleanService
|
||||
|
||||
_NOW = datetime.datetime(2026, 1, 15, 12, 0, 0, tzinfo=datetime.UTC)
|
||||
_OLD = _NOW - datetime.timedelta(days=60)
|
||||
_VERY_OLD = _NOW - datetime.timedelta(days=90)
|
||||
_RECENT = _NOW - datetime.timedelta(days=5)
|
||||
|
||||
_WINDOW_START = _VERY_OLD - datetime.timedelta(hours=1)
|
||||
_WINDOW_END = _RECENT + datetime.timedelta(hours=1)
|
||||
|
||||
_DEFAULT_BATCH_SIZE = 100
|
||||
_PAGINATION_MESSAGE_COUNT = 25
|
||||
_PAGINATION_BATCH_SIZE = 8
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tenant_and_app(flask_req_ctx):
|
||||
"""Creates a Tenant, App and Conversation for the test and cleans up after."""
|
||||
with session_factory.create_session() as session:
|
||||
tenant = Tenant(name="retention_it_tenant")
|
||||
session.add(tenant)
|
||||
session.flush()
|
||||
|
||||
app = App(
|
||||
tenant_id=tenant.id,
|
||||
name="Retention IT App",
|
||||
mode="chat",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
)
|
||||
session.add(app)
|
||||
session.flush()
|
||||
|
||||
conv = Conversation(
|
||||
app_id=app.id,
|
||||
mode="chat",
|
||||
name="test_conv",
|
||||
status="normal",
|
||||
from_source="console",
|
||||
_inputs={},
|
||||
)
|
||||
session.add(conv)
|
||||
session.commit()
|
||||
|
||||
tenant_id = tenant.id
|
||||
app_id = app.id
|
||||
conv_id = conv.id
|
||||
|
||||
yield {"tenant_id": tenant_id, "app_id": app_id, "conversation_id": conv_id}
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
session.execute(delete(Conversation).where(Conversation.id == conv_id))
|
||||
session.execute(delete(App).where(App.id == app_id))
|
||||
session.execute(delete(Tenant).where(Tenant.id == tenant_id))
|
||||
session.commit()
|
||||
|
||||
|
||||
def _make_message(app_id: str, conversation_id: str, created_at: datetime.datetime) -> Message:
|
||||
return Message(
|
||||
app_id=app_id,
|
||||
conversation_id=conversation_id,
|
||||
query="test",
|
||||
message=[{"text": "hello"}],
|
||||
answer="world",
|
||||
message_tokens=1,
|
||||
message_unit_price=0,
|
||||
answer_tokens=1,
|
||||
answer_unit_price=0,
|
||||
from_source="console",
|
||||
currency="USD",
|
||||
_inputs={},
|
||||
created_at=created_at,
|
||||
)
|
||||
|
||||
|
||||
class TestMessagesCleanServiceIntegration:
|
||||
@pytest.fixture
|
||||
def seed_messages(self, tenant_and_app):
|
||||
"""Seeds one message at each of _VERY_OLD, _OLD, and _RECENT.
|
||||
Yields a semantic mapping keyed by age label.
|
||||
"""
|
||||
data = tenant_and_app
|
||||
app_id = data["app_id"]
|
||||
conv_id = data["conversation_id"]
|
||||
# Ordered tuple of (label, timestamp) for deterministic seeding
|
||||
timestamps = [
|
||||
("very_old", _VERY_OLD),
|
||||
("old", _OLD),
|
||||
("recent", _RECENT),
|
||||
]
|
||||
msg_ids: dict[str, str] = {}
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
for label, ts in timestamps:
|
||||
msg = _make_message(app_id, conv_id, ts)
|
||||
session.add(msg)
|
||||
session.flush()
|
||||
msg_ids[label] = msg.id
|
||||
session.commit()
|
||||
|
||||
yield {"msg_ids": msg_ids, **data}
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
session.execute(
|
||||
delete(Message)
|
||||
.where(Message.id.in_(list(msg_ids.values())))
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
@pytest.fixture
|
||||
def paginated_seed_messages(self, tenant_and_app):
|
||||
"""Seeds multiple messages separated by 1-second increments starting at _OLD."""
|
||||
data = tenant_and_app
|
||||
app_id = data["app_id"]
|
||||
conv_id = data["conversation_id"]
|
||||
msg_ids: list[str] = []
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
for i in range(_PAGINATION_MESSAGE_COUNT):
|
||||
ts = _OLD + datetime.timedelta(seconds=i)
|
||||
msg = _make_message(app_id, conv_id, ts)
|
||||
session.add(msg)
|
||||
session.flush()
|
||||
msg_ids.append(msg.id)
|
||||
session.commit()
|
||||
|
||||
yield {"msg_ids": msg_ids, **data}
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
session.execute(delete(Message).where(Message.id.in_(msg_ids)).execution_options(synchronize_session=False))
|
||||
session.commit()
|
||||
|
||||
@pytest.fixture
|
||||
def cascade_test_data(self, tenant_and_app):
|
||||
"""Seeds one Message with an associated Feedback and Annotation."""
|
||||
data = tenant_and_app
|
||||
app_id = data["app_id"]
|
||||
conv_id = data["conversation_id"]
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
msg = _make_message(app_id, conv_id, _OLD)
|
||||
session.add(msg)
|
||||
session.flush()
|
||||
|
||||
feedback = MessageFeedback(
|
||||
app_id=app_id,
|
||||
conversation_id=conv_id,
|
||||
message_id=msg.id,
|
||||
rating=FeedbackRating.LIKE,
|
||||
from_source=FeedbackFromSource.USER,
|
||||
)
|
||||
annotation = MessageAnnotation(
|
||||
app_id=app_id,
|
||||
conversation_id=conv_id,
|
||||
message_id=msg.id,
|
||||
question="q",
|
||||
content="a",
|
||||
account_id=str(uuid.uuid4()),
|
||||
)
|
||||
session.add_all([feedback, annotation])
|
||||
session.commit()
|
||||
|
||||
msg_id = msg.id
|
||||
fb_id = feedback.id
|
||||
ann_id = annotation.id
|
||||
|
||||
yield {"msg_id": msg_id, "fb_id": fb_id, "ann_id": ann_id, **data}
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
session.execute(delete(MessageAnnotation).where(MessageAnnotation.id == ann_id))
|
||||
session.execute(delete(MessageFeedback).where(MessageFeedback.id == fb_id))
|
||||
session.execute(delete(Message).where(Message.id == msg_id))
|
||||
session.commit()
|
||||
|
||||
def test_dry_run_does_not_delete(self, seed_messages):
|
||||
"""Dry-run must count eligible rows without deleting any of them."""
|
||||
data = seed_messages
|
||||
msg_ids = data["msg_ids"]
|
||||
all_ids = list(msg_ids.values())
|
||||
|
||||
svc = MessagesCleanService.from_time_range(
|
||||
policy=BillingDisabledPolicy(),
|
||||
start_from=_WINDOW_START,
|
||||
end_before=_WINDOW_END,
|
||||
batch_size=_DEFAULT_BATCH_SIZE,
|
||||
dry_run=True,
|
||||
)
|
||||
stats = svc.run()
|
||||
|
||||
assert stats["filtered_messages"] == len(all_ids)
|
||||
assert stats["total_deleted"] == 0
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
remaining = session.query(Message).where(Message.id.in_(all_ids)).count()
|
||||
assert remaining == len(all_ids)
|
||||
|
||||
def test_billing_disabled_deletes_all_in_range(self, seed_messages):
|
||||
"""All 3 seeded messages fall within the window and must be deleted."""
|
||||
data = seed_messages
|
||||
msg_ids = data["msg_ids"]
|
||||
all_ids = list(msg_ids.values())
|
||||
|
||||
svc = MessagesCleanService.from_time_range(
|
||||
policy=BillingDisabledPolicy(),
|
||||
start_from=_WINDOW_START,
|
||||
end_before=_WINDOW_END,
|
||||
batch_size=_DEFAULT_BATCH_SIZE,
|
||||
dry_run=False,
|
||||
)
|
||||
stats = svc.run()
|
||||
|
||||
assert stats["total_deleted"] == len(all_ids)
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
remaining = session.query(Message).where(Message.id.in_(all_ids)).count()
|
||||
assert remaining == 0
|
||||
|
||||
def test_start_from_filters_correctly(self, seed_messages):
|
||||
"""Only the message at _OLD falls within the narrow ±1 h window."""
|
||||
data = seed_messages
|
||||
msg_ids = data["msg_ids"]
|
||||
|
||||
start = _OLD - datetime.timedelta(hours=1)
|
||||
end = _OLD + datetime.timedelta(hours=1)
|
||||
|
||||
svc = MessagesCleanService.from_time_range(
|
||||
policy=BillingDisabledPolicy(),
|
||||
start_from=start,
|
||||
end_before=end,
|
||||
batch_size=_DEFAULT_BATCH_SIZE,
|
||||
)
|
||||
stats = svc.run()
|
||||
|
||||
assert stats["total_deleted"] == 1
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
all_ids = list(msg_ids.values())
|
||||
remaining_ids = {r[0] for r in session.query(Message.id).where(Message.id.in_(all_ids)).all()}
|
||||
|
||||
assert msg_ids["old"] not in remaining_ids
|
||||
assert msg_ids["very_old"] in remaining_ids
|
||||
assert msg_ids["recent"] in remaining_ids
|
||||
|
||||
def test_cursor_pagination_across_batches(self, paginated_seed_messages):
|
||||
"""Messages must be deleted across multiple batches."""
|
||||
data = paginated_seed_messages
|
||||
msg_ids = data["msg_ids"]
|
||||
|
||||
# _OLD is the earliest; the last one is _OLD + (_PAGINATION_MESSAGE_COUNT - 1) s.
|
||||
pagination_window_start = _OLD - datetime.timedelta(seconds=1)
|
||||
pagination_window_end = _OLD + datetime.timedelta(seconds=_PAGINATION_MESSAGE_COUNT)
|
||||
|
||||
svc = MessagesCleanService.from_time_range(
|
||||
policy=BillingDisabledPolicy(),
|
||||
start_from=pagination_window_start,
|
||||
end_before=pagination_window_end,
|
||||
batch_size=_PAGINATION_BATCH_SIZE,
|
||||
)
|
||||
stats = svc.run()
|
||||
|
||||
assert stats["total_deleted"] == _PAGINATION_MESSAGE_COUNT
|
||||
expected_batches = math.ceil(_PAGINATION_MESSAGE_COUNT / _PAGINATION_BATCH_SIZE)
|
||||
assert stats["batches"] >= expected_batches
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
remaining = session.query(Message).where(Message.id.in_(msg_ids)).count()
|
||||
assert remaining == 0
|
||||
|
||||
def test_no_messages_in_range_returns_empty_stats(self, seed_messages):
|
||||
"""A window entirely in the future must yield zero matches."""
|
||||
far_future = _NOW + datetime.timedelta(days=365)
|
||||
even_further = far_future + datetime.timedelta(days=1)
|
||||
|
||||
svc = MessagesCleanService.from_time_range(
|
||||
policy=BillingDisabledPolicy(),
|
||||
start_from=far_future,
|
||||
end_before=even_further,
|
||||
batch_size=_DEFAULT_BATCH_SIZE,
|
||||
)
|
||||
stats = svc.run()
|
||||
|
||||
assert stats["total_messages"] == 0
|
||||
assert stats["total_deleted"] == 0
|
||||
|
||||
def test_relation_cascade_deletes(self, cascade_test_data):
|
||||
"""Deleting a Message must cascade to its Feedback and Annotation rows."""
|
||||
data = cascade_test_data
|
||||
msg_id = data["msg_id"]
|
||||
fb_id = data["fb_id"]
|
||||
ann_id = data["ann_id"]
|
||||
|
||||
svc = MessagesCleanService.from_time_range(
|
||||
policy=BillingDisabledPolicy(),
|
||||
start_from=_OLD - datetime.timedelta(hours=1),
|
||||
end_before=_OLD + datetime.timedelta(hours=1),
|
||||
batch_size=_DEFAULT_BATCH_SIZE,
|
||||
)
|
||||
stats = svc.run()
|
||||
|
||||
assert stats["total_deleted"] == 1
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
assert session.query(Message).where(Message.id == msg_id).count() == 0
|
||||
assert session.query(MessageFeedback).where(MessageFeedback.id == fb_id).count() == 0
|
||||
assert session.query(MessageAnnotation).where(MessageAnnotation.id == ann_id).count() == 0
|
||||
|
||||
def test_factory_from_time_range_validation(self):
|
||||
with pytest.raises(ValueError, match="start_from"):
|
||||
MessagesCleanService.from_time_range(
|
||||
policy=BillingDisabledPolicy(),
|
||||
start_from=_NOW,
|
||||
end_before=_OLD,
|
||||
)
|
||||
|
||||
def test_factory_from_days_validation(self):
|
||||
with pytest.raises(ValueError, match="days"):
|
||||
MessagesCleanService.from_days(
|
||||
policy=BillingDisabledPolicy(),
|
||||
days=-1,
|
||||
)
|
||||
|
||||
def test_factory_batch_size_validation(self):
|
||||
with pytest.raises(ValueError, match="batch_size"):
|
||||
MessagesCleanService.from_time_range(
|
||||
policy=BillingDisabledPolicy(),
|
||||
start_from=_OLD,
|
||||
end_before=_NOW,
|
||||
batch_size=0,
|
||||
)
|
||||
@@ -0,0 +1,177 @@
|
||||
import datetime
|
||||
import io
|
||||
import json
|
||||
import uuid
|
||||
import zipfile
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from services.retention.workflow_run.archive_paid_plan_workflow_run import (
|
||||
ArchiveSummary,
|
||||
WorkflowRunArchiver,
|
||||
)
|
||||
from services.retention.workflow_run.constants import ARCHIVE_SCHEMA_VERSION
|
||||
|
||||
|
||||
class TestWorkflowRunArchiverInit:
|
||||
def test_start_from_without_end_before_raises(self):
|
||||
with pytest.raises(ValueError, match="start_from and end_before must be provided together"):
|
||||
WorkflowRunArchiver(start_from=datetime.datetime(2025, 1, 1))
|
||||
|
||||
def test_end_before_without_start_from_raises(self):
|
||||
with pytest.raises(ValueError, match="start_from and end_before must be provided together"):
|
||||
WorkflowRunArchiver(end_before=datetime.datetime(2025, 1, 1))
|
||||
|
||||
def test_start_equals_end_raises(self):
|
||||
ts = datetime.datetime(2025, 1, 1)
|
||||
with pytest.raises(ValueError, match="start_from must be earlier than end_before"):
|
||||
WorkflowRunArchiver(start_from=ts, end_before=ts)
|
||||
|
||||
def test_start_after_end_raises(self):
|
||||
with pytest.raises(ValueError, match="start_from must be earlier than end_before"):
|
||||
WorkflowRunArchiver(
|
||||
start_from=datetime.datetime(2025, 6, 1),
|
||||
end_before=datetime.datetime(2025, 1, 1),
|
||||
)
|
||||
|
||||
def test_workers_zero_raises(self):
|
||||
with pytest.raises(ValueError, match="workers must be at least 1"):
|
||||
WorkflowRunArchiver(workers=0)
|
||||
|
||||
def test_valid_init_defaults(self):
|
||||
archiver = WorkflowRunArchiver(days=30, batch_size=50)
|
||||
assert archiver.days == 30
|
||||
assert archiver.batch_size == 50
|
||||
assert archiver.dry_run is False
|
||||
assert archiver.delete_after_archive is False
|
||||
assert archiver.start_from is None
|
||||
|
||||
def test_valid_init_with_time_range(self):
|
||||
start = datetime.datetime(2025, 1, 1)
|
||||
end = datetime.datetime(2025, 6, 1)
|
||||
archiver = WorkflowRunArchiver(start_from=start, end_before=end, workers=2)
|
||||
assert archiver.start_from is not None
|
||||
assert archiver.end_before is not None
|
||||
assert archiver.workers == 2
|
||||
|
||||
|
||||
class TestBuildArchiveBundle:
|
||||
def test_bundle_contains_manifest_and_all_tables(self):
|
||||
archiver = WorkflowRunArchiver(days=90)
|
||||
|
||||
manifest_data = json.dumps({"schema_version": ARCHIVE_SCHEMA_VERSION}).encode("utf-8")
|
||||
table_payloads = dict.fromkeys(archiver.ARCHIVED_TABLES, b"")
|
||||
|
||||
bundle_bytes = archiver._build_archive_bundle(manifest_data, table_payloads)
|
||||
|
||||
with zipfile.ZipFile(io.BytesIO(bundle_bytes), "r") as zf:
|
||||
names = set(zf.namelist())
|
||||
assert "manifest.json" in names
|
||||
for table in archiver.ARCHIVED_TABLES:
|
||||
assert f"{table}.jsonl" in names, f"Missing {table}.jsonl in bundle"
|
||||
|
||||
def test_bundle_missing_table_payload_raises(self):
|
||||
archiver = WorkflowRunArchiver(days=90)
|
||||
manifest_data = b"{}"
|
||||
incomplete_payloads = {archiver.ARCHIVED_TABLES[0]: b"data"}
|
||||
|
||||
with pytest.raises(ValueError, match="Missing archive payload"):
|
||||
archiver._build_archive_bundle(manifest_data, incomplete_payloads)
|
||||
|
||||
|
||||
class TestGenerateManifest:
|
||||
def test_manifest_structure(self):
|
||||
archiver = WorkflowRunArchiver(days=90)
|
||||
from services.retention.workflow_run.archive_paid_plan_workflow_run import TableStats
|
||||
|
||||
run = MagicMock()
|
||||
run.id = str(uuid.uuid4())
|
||||
run.tenant_id = str(uuid.uuid4())
|
||||
run.app_id = str(uuid.uuid4())
|
||||
run.workflow_id = str(uuid.uuid4())
|
||||
run.created_at = datetime.datetime(2025, 3, 15, 10, 0, 0)
|
||||
|
||||
stats = [
|
||||
TableStats(table_name="workflow_runs", row_count=1, checksum="abc123", size_bytes=512),
|
||||
TableStats(table_name="workflow_app_logs", row_count=2, checksum="def456", size_bytes=1024),
|
||||
]
|
||||
|
||||
manifest = archiver._generate_manifest(run, stats)
|
||||
|
||||
assert manifest["schema_version"] == ARCHIVE_SCHEMA_VERSION
|
||||
assert manifest["workflow_run_id"] == run.id
|
||||
assert manifest["tenant_id"] == run.tenant_id
|
||||
assert manifest["app_id"] == run.app_id
|
||||
assert "tables" in manifest
|
||||
assert manifest["tables"]["workflow_runs"]["row_count"] == 1
|
||||
assert manifest["tables"]["workflow_runs"]["checksum"] == "abc123"
|
||||
assert manifest["tables"]["workflow_app_logs"]["row_count"] == 2
|
||||
|
||||
|
||||
class TestFilterPaidTenants:
|
||||
def test_all_tenants_paid_when_billing_disabled(self):
|
||||
archiver = WorkflowRunArchiver(days=90)
|
||||
tenant_ids = {"t1", "t2", "t3"}
|
||||
|
||||
with patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") as cfg:
|
||||
cfg.BILLING_ENABLED = False
|
||||
result = archiver._filter_paid_tenants(tenant_ids)
|
||||
|
||||
assert result == tenant_ids
|
||||
|
||||
def test_empty_tenants_returns_empty(self):
|
||||
archiver = WorkflowRunArchiver(days=90)
|
||||
|
||||
with patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") as cfg:
|
||||
cfg.BILLING_ENABLED = True
|
||||
result = archiver._filter_paid_tenants(set())
|
||||
|
||||
assert result == set()
|
||||
|
||||
def test_only_paid_plans_returned(self):
|
||||
archiver = WorkflowRunArchiver(days=90)
|
||||
|
||||
mock_bulk = {
|
||||
"t1": {"plan": "professional"},
|
||||
"t2": {"plan": "sandbox"},
|
||||
"t3": {"plan": "team"},
|
||||
}
|
||||
|
||||
with (
|
||||
patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") as cfg,
|
||||
patch("services.retention.workflow_run.archive_paid_plan_workflow_run.BillingService") as billing,
|
||||
):
|
||||
cfg.BILLING_ENABLED = True
|
||||
billing.get_plan_bulk_with_cache.return_value = mock_bulk
|
||||
result = archiver._filter_paid_tenants({"t1", "t2", "t3"})
|
||||
|
||||
assert "t1" in result
|
||||
assert "t3" in result
|
||||
assert "t2" not in result
|
||||
|
||||
def test_billing_api_failure_returns_empty(self):
|
||||
archiver = WorkflowRunArchiver(days=90)
|
||||
|
||||
with (
|
||||
patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") as cfg,
|
||||
patch("services.retention.workflow_run.archive_paid_plan_workflow_run.BillingService") as billing,
|
||||
):
|
||||
cfg.BILLING_ENABLED = True
|
||||
billing.get_plan_bulk_with_cache.side_effect = RuntimeError("API down")
|
||||
result = archiver._filter_paid_tenants({"t1"})
|
||||
|
||||
assert result == set()
|
||||
|
||||
|
||||
class TestDryRunArchive:
|
||||
@patch("services.retention.workflow_run.archive_paid_plan_workflow_run.get_archive_storage")
|
||||
def test_dry_run_does_not_call_storage(self, mock_get_storage, flask_req_ctx):
|
||||
archiver = WorkflowRunArchiver(days=90, dry_run=True)
|
||||
|
||||
with patch.object(archiver, "_get_runs_batch", return_value=[]):
|
||||
summary = archiver.run()
|
||||
|
||||
mock_get_storage.assert_not_called()
|
||||
assert isinstance(summary, ArchiveSummary)
|
||||
assert summary.runs_failed == 0
|
||||
@@ -1,7 +1,4 @@
|
||||
"""
|
||||
Additional tests to improve coverage for low-coverage modules in controllers/console/app.
|
||||
Target: increase coverage for files with <75% coverage.
|
||||
"""
|
||||
"""Testcontainers integration tests for controllers/console/app endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -70,26 +67,12 @@ def _unwrap(func):
|
||||
return func
|
||||
|
||||
|
||||
class _ConnContext:
|
||||
def __init__(self, rows):
|
||||
self._rows = rows
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def execute(self, _query, _args):
|
||||
return self._rows
|
||||
|
||||
|
||||
# ========== Completion Tests ==========
|
||||
class TestCompletionEndpoints:
|
||||
"""Tests for completion API endpoints."""
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_completion_create_payload(self):
|
||||
"""Test completion creation payload."""
|
||||
payload = CompletionMessagePayload(inputs={"prompt": "test"}, model_config={})
|
||||
assert payload.inputs == {"prompt": "test"}
|
||||
|
||||
@@ -209,7 +192,9 @@ class TestCompletionEndpoints:
|
||||
|
||||
|
||||
class TestAppEndpoints:
|
||||
"""Tests for app endpoints."""
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_app_put_should_preserve_icon_type_when_payload_omits_it(self, app, monkeypatch):
|
||||
api = app_module.AppApi()
|
||||
@@ -250,12 +235,12 @@ class TestAppEndpoints:
|
||||
)
|
||||
|
||||
|
||||
# ========== OpsTrace Tests ==========
|
||||
class TestOpsTraceEndpoints:
|
||||
"""Tests for ops_trace endpoint."""
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_ops_trace_query_basic(self):
|
||||
"""Test ops_trace query."""
|
||||
query = TraceProviderQuery(tracing_provider="langfuse")
|
||||
assert query.tracing_provider == "langfuse"
|
||||
|
||||
@@ -310,12 +295,12 @@ class TestOpsTraceEndpoints:
|
||||
method(app_id="app-1")
|
||||
|
||||
|
||||
# ========== Site Tests ==========
|
||||
class TestSiteEndpoints:
|
||||
"""Tests for site endpoint."""
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_site_response_structure(self):
|
||||
"""Test site response structure."""
|
||||
payload = AppSiteUpdatePayload(title="My Site", description="Test site")
|
||||
assert payload.title == "My Site"
|
||||
|
||||
@@ -369,27 +354,22 @@ class TestSiteEndpoints:
|
||||
assert result is site
|
||||
|
||||
|
||||
# ========== Workflow Tests ==========
|
||||
class TestWorkflowEndpoints:
|
||||
"""Tests for workflow endpoints."""
|
||||
|
||||
def test_workflow_copy_payload(self):
|
||||
"""Test workflow copy payload."""
|
||||
payload = SyncDraftWorkflowPayload(graph={}, features={})
|
||||
assert payload.graph == {}
|
||||
|
||||
def test_workflow_mode_query(self):
|
||||
"""Test workflow mode query."""
|
||||
payload = AdvancedChatWorkflowRunPayload(inputs={}, query="hi")
|
||||
assert payload.query == "hi"
|
||||
|
||||
|
||||
# ========== Workflow App Log Tests ==========
|
||||
class TestWorkflowAppLogEndpoints:
|
||||
"""Tests for workflow app log endpoints."""
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_workflow_app_log_query(self):
|
||||
"""Test workflow app log query."""
|
||||
query = WorkflowAppLogQuery(keyword="test", page=1, limit=20)
|
||||
assert query.keyword == "test"
|
||||
|
||||
@@ -427,12 +407,12 @@ class TestWorkflowAppLogEndpoints:
|
||||
assert result == {"items": [], "total": 0}
|
||||
|
||||
|
||||
# ========== Workflow Draft Variable Tests ==========
|
||||
class TestWorkflowDraftVariableEndpoints:
|
||||
"""Tests for workflow draft variable endpoints."""
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_workflow_variable_creation(self):
|
||||
"""Test workflow variable creation."""
|
||||
payload = WorkflowDraftVariableUpdatePayload(name="var1", value="test")
|
||||
assert payload.name == "var1"
|
||||
|
||||
@@ -472,12 +452,12 @@ class TestWorkflowDraftVariableEndpoints:
|
||||
assert result == {"items": [], "total": 0}
|
||||
|
||||
|
||||
# ========== Workflow Statistic Tests ==========
|
||||
class TestWorkflowStatisticEndpoints:
|
||||
"""Tests for workflow statistic endpoints."""
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_workflow_statistic_time_range(self):
|
||||
"""Test workflow statistic time range query."""
|
||||
query = WorkflowStatisticQuery(start="2024-01-01", end="2024-12-31")
|
||||
assert query.start == "2024-01-01"
|
||||
|
||||
@@ -541,12 +521,12 @@ class TestWorkflowStatisticEndpoints:
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-02"}]}
|
||||
|
||||
|
||||
# ========== Workflow Trigger Tests ==========
|
||||
class TestWorkflowTriggerEndpoints:
|
||||
"""Tests for workflow trigger endpoints."""
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_webhook_trigger_payload(self):
|
||||
"""Test webhook trigger payload."""
|
||||
payload = Parser(node_id="node-1")
|
||||
assert payload.node_id == "node-1"
|
||||
|
||||
@@ -578,22 +558,13 @@ class TestWorkflowTriggerEndpoints:
|
||||
assert result is trigger
|
||||
|
||||
|
||||
# ========== Wraps Tests ==========
|
||||
class TestWrapsEndpoints:
|
||||
"""Tests for wraps utility functions."""
|
||||
|
||||
def test_get_app_model_context(self):
|
||||
"""Test get_app_model wrapper context."""
|
||||
# These are decorator functions, so we test their availability
|
||||
assert hasattr(wraps_module, "get_app_model")
|
||||
|
||||
|
||||
# ========== MCP Server Tests ==========
|
||||
class TestMCPServerEndpoints:
|
||||
"""Tests for MCP server endpoints."""
|
||||
|
||||
def test_mcp_server_connection(self):
|
||||
"""Test MCP server connection."""
|
||||
payload = MCPServerCreatePayload(parameters={"url": "http://localhost:3000"})
|
||||
assert payload.parameters["url"] == "http://localhost:3000"
|
||||
|
||||
@@ -602,22 +573,14 @@ class TestMCPServerEndpoints:
|
||||
assert payload.status == "active"
|
||||
|
||||
|
||||
# ========== Error Handling Tests ==========
|
||||
class TestErrorHandling:
|
||||
"""Tests for error handling in various endpoints."""
|
||||
|
||||
def test_annotation_list_query_validation(self):
|
||||
"""Test annotation list query validation."""
|
||||
with pytest.raises(ValueError):
|
||||
annotation_module.AnnotationListQuery(page=0)
|
||||
|
||||
|
||||
# ========== Integration-like Tests ==========
|
||||
class TestPayloadIntegration:
|
||||
"""Integration tests for payload handling."""
|
||||
|
||||
def test_multiple_payload_types(self):
|
||||
"""Test handling of multiple payload types."""
|
||||
payloads = [
|
||||
annotation_module.AnnotationReplyPayload(
|
||||
score_threshold=0.5, embedding_provider_name="openai", embedding_model_name="text-embedding-3-small"
|
||||
@@ -0,0 +1,142 @@
|
||||
"""Testcontainers integration tests for controllers.console.app.app_import endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.app import app_import as app_import_module
|
||||
from services.app_dsl_service import ImportStatus
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
class _Result:
|
||||
def __init__(self, status: ImportStatus, app_id: str | None = "app-1"):
|
||||
self.status = status
|
||||
self.app_id = app_id
|
||||
|
||||
def model_dump(self, mode: str = "json"):
|
||||
return {"status": self.status, "app_id": self.app_id}
|
||||
|
||||
|
||||
def _install_features(monkeypatch: pytest.MonkeyPatch, enabled: bool) -> None:
|
||||
features = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=enabled))
|
||||
monkeypatch.setattr(app_import_module.FeatureService, "get_system_features", lambda: features)
|
||||
|
||||
|
||||
class TestAppImportApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_import_post_returns_failed_status(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
_install_features(monkeypatch, enabled=False)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
|
||||
assert status == 400
|
||||
assert response["status"] == ImportStatus.FAILED
|
||||
|
||||
def test_import_post_returns_pending_status(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
_install_features(monkeypatch, enabled=False)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.PENDING),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
|
||||
assert status == 202
|
||||
assert response["status"] == ImportStatus.PENDING
|
||||
|
||||
def test_import_post_updates_webapp_auth_when_enabled(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
_install_features(monkeypatch, enabled=True)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"),
|
||||
)
|
||||
update_access = MagicMock()
|
||||
monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
|
||||
update_access.assert_called_once_with("app-123", "private")
|
||||
assert status == 200
|
||||
assert response["status"] == ImportStatus.COMPLETED
|
||||
|
||||
|
||||
class TestAppImportConfirmApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_import_confirm_returns_failed_status(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportConfirmApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"confirm_import",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"):
|
||||
response, status = method(import_id="import-1")
|
||||
|
||||
assert status == 400
|
||||
assert response["status"] == ImportStatus.FAILED
|
||||
|
||||
|
||||
class TestAppImportCheckDependenciesApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_import_check_dependencies_returns_result(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportCheckDependenciesApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"check_dependencies",
|
||||
lambda *_args, **_kwargs: SimpleNamespace(model_dump=lambda mode="json": {"leaked_dependencies": []}),
|
||||
)
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports/app-1/check-dependencies", method="GET"):
|
||||
response, status = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert status == 200
|
||||
assert response["leaked_dependencies"] == []
|
||||
@@ -1,6 +1,12 @@
|
||||
"""Testcontainers integration tests for rag_pipeline controller endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline import (
|
||||
@@ -9,6 +15,7 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline import (
|
||||
PipelineTemplateListApi,
|
||||
PublishCustomizedPipelineTemplateApi,
|
||||
)
|
||||
from models.dataset import PipelineCustomizedTemplate
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
@@ -18,6 +25,10 @@ def unwrap(func):
|
||||
|
||||
|
||||
class TestPipelineTemplateListApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app):
|
||||
api = PipelineTemplateListApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -38,6 +49,10 @@ class TestPipelineTemplateListApi:
|
||||
|
||||
|
||||
class TestPipelineTemplateDetailApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app):
|
||||
api = PipelineTemplateDetailApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -99,6 +114,10 @@ class TestPipelineTemplateDetailApi:
|
||||
|
||||
|
||||
class TestCustomizedPipelineTemplateApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_patch_success(self, app):
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.patch)
|
||||
@@ -136,35 +155,29 @@ class TestCustomizedPipelineTemplateApi:
|
||||
delete_mock.assert_called_once_with("tpl-1")
|
||||
assert response == 200
|
||||
|
||||
def test_post_success(self, app):
|
||||
def test_post_success(self, app, db_session_with_containers: Session):
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
template = MagicMock()
|
||||
template.yaml_content = "yaml-data"
|
||||
tenant_id = str(uuid4())
|
||||
template = PipelineCustomizedTemplate(
|
||||
tenant_id=tenant_id,
|
||||
name="Test Template",
|
||||
description="Test",
|
||||
chunk_structure="hierarchical",
|
||||
icon={"icon": "📘"},
|
||||
position=0,
|
||||
yaml_content="yaml-data",
|
||||
install_count=0,
|
||||
language="en-US",
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
db_session_with_containers.add(template)
|
||||
db_session_with_containers.commit()
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = template
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "tpl-1")
|
||||
with app.test_request_context("/"):
|
||||
response, status = method(api, template.id)
|
||||
|
||||
assert status == 200
|
||||
assert response == {"data": "yaml-data"}
|
||||
@@ -173,32 +186,16 @@ class TestCustomizedPipelineTemplateApi:
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
):
|
||||
with app.test_request_context("/"):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "tpl-1")
|
||||
method(api, str(uuid4()))
|
||||
|
||||
|
||||
class TestPublishCustomizedPipelineTemplateApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_post_success(self, app):
|
||||
api = PublishCustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -1,3 +1,7 @@
|
||||
"""Testcontainers integration tests for rag_pipeline_datasets controller endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -19,6 +23,10 @@ def unwrap(func):
|
||||
|
||||
|
||||
class TestCreateRagPipelineDatasetApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def _valid_payload(self):
|
||||
return {"yaml_content": "name: test"}
|
||||
|
||||
@@ -33,13 +41,6 @@ class TestCreateRagPipelineDatasetApi:
|
||||
mock_service = MagicMock()
|
||||
mock_service.create_rag_pipeline_dataset.return_value = import_info
|
||||
|
||||
mock_session_ctx = MagicMock()
|
||||
mock_session_ctx.__enter__.return_value = MagicMock()
|
||||
mock_session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
@@ -47,14 +48,6 @@ class TestCreateRagPipelineDatasetApi:
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session",
|
||||
return_value=mock_session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService",
|
||||
return_value=mock_service,
|
||||
@@ -93,13 +86,6 @@ class TestCreateRagPipelineDatasetApi:
|
||||
mock_service = MagicMock()
|
||||
mock_service.create_rag_pipeline_dataset.side_effect = services.errors.dataset.DatasetNameDuplicateError()
|
||||
|
||||
mock_session_ctx = MagicMock()
|
||||
mock_session_ctx.__enter__.return_value = MagicMock()
|
||||
mock_session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
@@ -107,14 +93,6 @@ class TestCreateRagPipelineDatasetApi:
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session",
|
||||
return_value=mock_session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService",
|
||||
return_value=mock_service,
|
||||
@@ -143,6 +121,10 @@ class TestCreateRagPipelineDatasetApi:
|
||||
|
||||
|
||||
class TestCreateEmptyRagPipelineDatasetApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_post_success(self, app):
|
||||
api = CreateEmptyRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -1,5 +1,11 @@
|
||||
"""Testcontainers integration tests for rag_pipeline_import controller endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline_import import (
|
||||
RagPipelineExportApi,
|
||||
@@ -18,6 +24,10 @@ def unwrap(func):
|
||||
|
||||
|
||||
class TestRagPipelineImportApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def _payload(self, mode="create"):
|
||||
return {
|
||||
"mode": mode,
|
||||
@@ -30,7 +40,6 @@ class TestRagPipelineImportApi:
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._payload()
|
||||
|
||||
user = MagicMock()
|
||||
result = MagicMock()
|
||||
result.status = "completed"
|
||||
@@ -39,13 +48,6 @@ class TestRagPipelineImportApi:
|
||||
service = MagicMock()
|
||||
service.import_rag_pipeline.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
@@ -53,14 +55,6 @@ class TestRagPipelineImportApi:
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
@@ -76,7 +70,6 @@ class TestRagPipelineImportApi:
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._payload()
|
||||
|
||||
user = MagicMock()
|
||||
result = MagicMock()
|
||||
result.status = ImportStatus.FAILED
|
||||
@@ -85,13 +78,6 @@ class TestRagPipelineImportApi:
|
||||
service = MagicMock()
|
||||
service.import_rag_pipeline.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
@@ -99,14 +85,6 @@ class TestRagPipelineImportApi:
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
@@ -122,7 +100,6 @@ class TestRagPipelineImportApi:
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._payload()
|
||||
|
||||
user = MagicMock()
|
||||
result = MagicMock()
|
||||
result.status = ImportStatus.PENDING
|
||||
@@ -131,13 +108,6 @@ class TestRagPipelineImportApi:
|
||||
service = MagicMock()
|
||||
service.import_rag_pipeline.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
@@ -145,14 +115,6 @@ class TestRagPipelineImportApi:
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
@@ -165,6 +127,10 @@ class TestRagPipelineImportApi:
|
||||
|
||||
|
||||
class TestRagPipelineImportConfirmApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_confirm_success(self, app):
|
||||
api = RagPipelineImportConfirmApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -177,27 +143,12 @@ class TestRagPipelineImportConfirmApi:
|
||||
service = MagicMock()
|
||||
service.confirm_import.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
@@ -220,27 +171,12 @@ class TestRagPipelineImportConfirmApi:
|
||||
service = MagicMock()
|
||||
service.confirm_import.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
@@ -253,6 +189,10 @@ class TestRagPipelineImportConfirmApi:
|
||||
|
||||
|
||||
class TestRagPipelineImportCheckDependenciesApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app):
|
||||
api = RagPipelineImportCheckDependenciesApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -264,23 +204,8 @@ class TestRagPipelineImportCheckDependenciesApi:
|
||||
service = MagicMock()
|
||||
service.check_dependencies.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
@@ -293,6 +218,10 @@ class TestRagPipelineImportCheckDependenciesApi:
|
||||
|
||||
|
||||
class TestRagPipelineExportApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_with_include_secret(self, app):
|
||||
api = RagPipelineExportApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -301,23 +230,8 @@ class TestRagPipelineExportApi:
|
||||
service = MagicMock()
|
||||
service.export_rag_pipeline_dsl.return_value = {"yaml": "data"}
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/?include_secret=true"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
@@ -1,7 +1,13 @@
|
||||
"""Testcontainers integration tests for rag_pipeline_workflow controller endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, HTTPException, NotFound
|
||||
|
||||
import services
|
||||
@@ -38,6 +44,10 @@ def unwrap(func):
|
||||
|
||||
|
||||
class TestDraftWorkflowApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_draft_success(self, app):
|
||||
api = DraftRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -200,6 +210,10 @@ class TestDraftWorkflowApi:
|
||||
|
||||
|
||||
class TestDraftRunNodes:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_iteration_node_success(self, app):
|
||||
api = RagPipelineDraftRunIterationNodeApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -275,6 +289,10 @@ class TestDraftRunNodes:
|
||||
|
||||
|
||||
class TestPipelineRunApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_draft_run_success(self, app):
|
||||
api = DraftRagPipelineRunApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -337,6 +355,10 @@ class TestPipelineRunApis:
|
||||
|
||||
|
||||
class TestDraftNodeRun:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_execution_not_found(self, app):
|
||||
api = RagPipelineDraftNodeRunApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -364,45 +386,43 @@ class TestDraftNodeRun:
|
||||
|
||||
|
||||
class TestPublishedPipelineApis:
|
||||
def test_publish_success(self, app):
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_publish_success(self, app, db_session_with_containers: Session):
|
||||
from models.dataset import Pipeline
|
||||
|
||||
api = PublishedRagPipelineApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
tenant_id = str(uuid4())
|
||||
pipeline = Pipeline(
|
||||
tenant_id=tenant_id,
|
||||
name="test-pipeline",
|
||||
description="test",
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
db_session_with_containers.add(pipeline)
|
||||
db_session_with_containers.commit()
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
workflow = MagicMock(
|
||||
id="w1",
|
||||
id=str(uuid4()),
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
session = MagicMock()
|
||||
session.merge.return_value = pipeline
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
service = MagicMock()
|
||||
service.publish_workflow.return_value = workflow
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
@@ -415,6 +435,10 @@ class TestPublishedPipelineApis:
|
||||
|
||||
|
||||
class TestMiscApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_task_stop(self, app):
|
||||
api = RagPipelineTaskStopApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -471,6 +495,10 @@ class TestMiscApis:
|
||||
|
||||
|
||||
class TestPublishedRagPipelineRunApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_published_run_success(self, app):
|
||||
api = PublishedRagPipelineRunApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -536,6 +564,10 @@ class TestPublishedRagPipelineRunApi:
|
||||
|
||||
|
||||
class TestDefaultBlockConfigApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_block_config_success(self, app):
|
||||
api = DefaultRagPipelineBlockConfigApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -567,6 +599,10 @@ class TestDefaultBlockConfigApi:
|
||||
|
||||
|
||||
class TestPublishedAllRagPipelineApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_published_workflows_success(self, app):
|
||||
api = PublishedAllRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -577,28 +613,12 @@ class TestPublishedAllRagPipelineApi:
|
||||
service = MagicMock()
|
||||
service.get_all_published_workflow.return_value = ([{"id": "w1"}], False)
|
||||
|
||||
session = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
@@ -628,6 +648,10 @@ class TestPublishedAllRagPipelineApi:
|
||||
|
||||
|
||||
class TestRagPipelineByIdApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_patch_success(self, app):
|
||||
api = RagPipelineByIdApi()
|
||||
method = unwrap(api.patch)
|
||||
@@ -640,14 +664,6 @@ class TestRagPipelineByIdApi:
|
||||
service = MagicMock()
|
||||
service.update_workflow.return_value = workflow
|
||||
|
||||
session = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
payload = {"marked_name": "test"}
|
||||
|
||||
with (
|
||||
@@ -657,14 +673,6 @@ class TestRagPipelineByIdApi:
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
@@ -700,24 +708,8 @@ class TestRagPipelineByIdApi:
|
||||
|
||||
workflow_service = MagicMock()
|
||||
|
||||
session = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", method="DELETE"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.WorkflowService",
|
||||
return_value=workflow_service,
|
||||
@@ -725,12 +717,7 @@ class TestRagPipelineByIdApi:
|
||||
):
|
||||
result = method(api, pipeline, "old-workflow")
|
||||
|
||||
workflow_service.delete_workflow.assert_called_once_with(
|
||||
session=session,
|
||||
workflow_id="old-workflow",
|
||||
tenant_id="t1",
|
||||
)
|
||||
session.commit.assert_called_once()
|
||||
workflow_service.delete_workflow.assert_called_once()
|
||||
assert result == (None, 204)
|
||||
|
||||
def test_delete_active_workflow_rejected(self, app):
|
||||
@@ -745,6 +732,10 @@ class TestRagPipelineByIdApi:
|
||||
|
||||
|
||||
class TestRagPipelineWorkflowLastRunApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_last_run_success(self, app):
|
||||
api = RagPipelineWorkflowLastRunApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -788,6 +779,10 @@ class TestRagPipelineWorkflowLastRunApi:
|
||||
|
||||
|
||||
class TestRagPipelineDatasourceVariableApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_set_datasource_variables_success(self, app):
|
||||
api = RagPipelineDatasourceVariableApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -1,3 +1,7 @@
|
||||
"""Testcontainers integration tests for controllers.console.datasets.data_source endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -46,6 +50,10 @@ def mock_engine():
|
||||
|
||||
|
||||
class TestDataSourceApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app, patch_tenant):
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -179,6 +187,10 @@ class TestDataSourceApi:
|
||||
|
||||
|
||||
class TestDataSourceNotionListApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_credential_not_found(self, app, patch_tenant):
|
||||
api = DataSourceNotionListApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -310,6 +322,10 @@ class TestDataSourceNotionListApi:
|
||||
|
||||
|
||||
class TestDataSourceNotionApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_preview_success(self, app, patch_tenant):
|
||||
api = DataSourceNotionApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -364,6 +380,10 @@ class TestDataSourceNotionApi:
|
||||
|
||||
|
||||
class TestDataSourceNotionDatasetSyncApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app, patch_tenant):
|
||||
api = DataSourceNotionDatasetSyncApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -403,6 +423,10 @@ class TestDataSourceNotionDatasetSyncApi:
|
||||
|
||||
|
||||
class TestDataSourceNotionDocumentSyncApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app, patch_tenant):
|
||||
api = DataSourceNotionDocumentSyncApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -1,7 +1,10 @@
|
||||
"""Testcontainers integration tests for controllers.console.explore.conversation endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
import controllers.console.explore.conversation as conversation_module
|
||||
@@ -48,24 +51,12 @@ def user():
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_db_and_session():
|
||||
with (
|
||||
patch.object(
|
||||
conversation_module,
|
||||
"db",
|
||||
MagicMock(session=MagicMock(), engine=MagicMock()),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.explore.conversation.Session",
|
||||
MagicMock(),
|
||||
),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
class TestConversationListApi:
|
||||
def test_get_success(self, app: Flask, chat_app, user):
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app, chat_app, user):
|
||||
api = conversation_module.ConversationListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@@ -90,7 +81,7 @@ class TestConversationListApi:
|
||||
assert result["has_more"] is False
|
||||
assert len(result["data"]) == 2
|
||||
|
||||
def test_last_conversation_not_exists(self, app: Flask, chat_app, user):
|
||||
def test_last_conversation_not_exists(self, app, chat_app, user):
|
||||
api = conversation_module.ConversationListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@@ -106,7 +97,7 @@ class TestConversationListApi:
|
||||
with pytest.raises(NotFound):
|
||||
method(chat_app)
|
||||
|
||||
def test_wrong_app_mode(self, app: Flask, non_chat_app):
|
||||
def test_wrong_app_mode(self, app, non_chat_app):
|
||||
api = conversation_module.ConversationListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@@ -116,7 +107,11 @@ class TestConversationListApi:
|
||||
|
||||
|
||||
class TestConversationApi:
|
||||
def test_delete_success(self, app: Flask, chat_app, user):
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_delete_success(self, app, chat_app, user):
|
||||
api = conversation_module.ConversationApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
@@ -134,7 +129,7 @@ class TestConversationApi:
|
||||
assert status == 204
|
||||
assert body["result"] == "success"
|
||||
|
||||
def test_delete_not_found(self, app: Flask, chat_app, user):
|
||||
def test_delete_not_found(self, app, chat_app, user):
|
||||
api = conversation_module.ConversationApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
@@ -150,7 +145,7 @@ class TestConversationApi:
|
||||
with pytest.raises(NotFound):
|
||||
method(chat_app, "cid")
|
||||
|
||||
def test_delete_wrong_app_mode(self, app: Flask, non_chat_app):
|
||||
def test_delete_wrong_app_mode(self, app, non_chat_app):
|
||||
api = conversation_module.ConversationApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
@@ -160,7 +155,11 @@ class TestConversationApi:
|
||||
|
||||
|
||||
class TestConversationRenameApi:
|
||||
def test_rename_success(self, app: Flask, chat_app, user):
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_rename_success(self, app, chat_app, user):
|
||||
api = conversation_module.ConversationRenameApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@@ -179,7 +178,7 @@ class TestConversationRenameApi:
|
||||
|
||||
assert result["id"] == "cid"
|
||||
|
||||
def test_rename_not_found(self, app: Flask, chat_app, user):
|
||||
def test_rename_not_found(self, app, chat_app, user):
|
||||
api = conversation_module.ConversationRenameApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@@ -197,7 +196,11 @@ class TestConversationRenameApi:
|
||||
|
||||
|
||||
class TestConversationPinApi:
|
||||
def test_pin_success(self, app: Flask, chat_app, user):
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_pin_success(self, app, chat_app, user):
|
||||
api = conversation_module.ConversationPinApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
@@ -215,7 +218,11 @@ class TestConversationPinApi:
|
||||
|
||||
|
||||
class TestConversationUnPinApi:
|
||||
def test_unpin_success(self, app: Flask, chat_app, user):
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_unpin_success(self, app, chat_app, user):
|
||||
api = conversation_module.ConversationUnPinApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
"""Testcontainers integration tests for controllers.console.workspace.tool_providers endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask_restx import Api
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console.workspace.tool_providers import (
|
||||
@@ -31,7 +33,6 @@ from controllers.console.workspace.tool_providers import (
|
||||
ToolOAuthCustomClient,
|
||||
ToolPluginOAuthApi,
|
||||
ToolProviderListApi,
|
||||
ToolProviderMCPApi,
|
||||
ToolWorkflowListApi,
|
||||
ToolWorkflowProviderCreateApi,
|
||||
ToolWorkflowProviderDeleteApi,
|
||||
@@ -39,8 +40,6 @@ from controllers.console.workspace.tool_providers import (
|
||||
ToolWorkflowProviderUpdateApi,
|
||||
is_valid_url,
|
||||
)
|
||||
from core.db.session_factory import configure_session_factory
|
||||
from extensions.ext_database import db
|
||||
from services.tools.mcp_tools_manage_service import ReconnectResult
|
||||
|
||||
|
||||
@@ -61,17 +60,8 @@ def _mock_user_tenant():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
|
||||
api = Api(app)
|
||||
api.add_resource(ToolProviderMCPApi, "/console/api/workspaces/current/tool-provider/mcp")
|
||||
db.init_app(app)
|
||||
# Configure session factory used by controller code
|
||||
with app.app_context():
|
||||
configure_session_factory(db.engine)
|
||||
return app.test_client()
|
||||
def client(flask_app_with_containers):
|
||||
return flask_app_with_containers.test_client()
|
||||
|
||||
|
||||
@patch(
|
||||
@@ -152,10 +142,14 @@ class TestUtils:
|
||||
assert not is_valid_url("")
|
||||
assert not is_valid_url("ftp://example.com")
|
||||
assert not is_valid_url("not-a-url")
|
||||
assert not is_valid_url(None)
|
||||
assert not is_valid_url(None) # type: ignore[arg-type]
|
||||
|
||||
|
||||
class TestToolProviderListApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app):
|
||||
api = ToolProviderListApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -175,6 +169,10 @@ class TestToolProviderListApi:
|
||||
|
||||
|
||||
class TestBuiltinProviderApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_list_tools(self, app):
|
||||
api = ToolBuiltinProviderListToolsApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -379,6 +377,10 @@ class TestBuiltinProviderApis:
|
||||
|
||||
|
||||
class TestApiProviderApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_add(self, app):
|
||||
api = ToolApiProviderAddApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -502,6 +504,10 @@ class TestApiProviderApis:
|
||||
|
||||
|
||||
class TestWorkflowApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_create(self, app):
|
||||
api = ToolWorkflowProviderCreateApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -587,6 +593,10 @@ class TestWorkflowApis:
|
||||
|
||||
|
||||
class TestLists:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_builtin_list(self, app):
|
||||
api = ToolBuiltinListApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -649,6 +659,10 @@ class TestLists:
|
||||
|
||||
|
||||
class TestLabels:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_labels(self, app):
|
||||
api = ToolLabelsApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -664,6 +678,10 @@ class TestLabels:
|
||||
|
||||
|
||||
class TestOAuth:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_oauth_no_client(self, app):
|
||||
api = ToolPluginOAuthApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -692,6 +710,10 @@ class TestOAuth:
|
||||
|
||||
|
||||
class TestOAuthCustomClient:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_save_custom_client(self, app):
|
||||
api = ToolOAuthCustomClient()
|
||||
method = unwrap(api.post)
|
||||
@@ -1,3 +1,7 @@
|
||||
"""Testcontainers integration tests for controllers.console.workspace.trigger_providers endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -40,6 +44,10 @@ def mock_user():
|
||||
|
||||
|
||||
class TestTriggerProviderApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_icon_success(self, app):
|
||||
api = TriggerProviderIconApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -84,6 +92,10 @@ class TestTriggerProviderApis:
|
||||
|
||||
|
||||
class TestTriggerSubscriptionListApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_list_success(self, app):
|
||||
api = TriggerSubscriptionListApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -115,6 +127,10 @@ class TestTriggerSubscriptionListApi:
|
||||
|
||||
|
||||
class TestTriggerSubscriptionBuilderApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_create_builder(self, app):
|
||||
api = TriggerSubscriptionBuilderCreateApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -219,6 +235,10 @@ class TestTriggerSubscriptionBuilderApis:
|
||||
|
||||
|
||||
class TestTriggerSubscriptionCrud:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_update_rename_only(self, app):
|
||||
api = TriggerSubscriptionUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -321,6 +341,10 @@ class TestTriggerSubscriptionCrud:
|
||||
|
||||
|
||||
class TestTriggerOAuthApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_oauth_authorize_success(self, app):
|
||||
api = TriggerOAuthAuthorizeApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -455,6 +479,10 @@ class TestTriggerOAuthApis:
|
||||
|
||||
|
||||
class TestTriggerOAuthClientManageApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_client(self, app):
|
||||
api = TriggerOAuthClientManageApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -527,6 +555,10 @@ class TestTriggerOAuthClientManageApi:
|
||||
|
||||
|
||||
class TestTriggerSubscriptionVerifyApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_verify_success(self, app):
|
||||
api = TriggerSubscriptionVerifyApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -0,0 +1,185 @@
|
||||
"""Testcontainers integration tests for plugin_permission_required decorator."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console.workspace import plugin_permission_required
|
||||
from models.account import Tenant, TenantPluginPermission, TenantStatus
|
||||
|
||||
|
||||
def _create_tenant(db_session: Session) -> Tenant:
|
||||
tenant = Tenant(name="test-tenant", status=TenantStatus.NORMAL, plan="basic")
|
||||
db_session.add(tenant)
|
||||
db_session.commit()
|
||||
db_session.expire_all()
|
||||
return tenant
|
||||
|
||||
|
||||
def _create_permission(
|
||||
db_session: Session,
|
||||
tenant_id: str,
|
||||
install: TenantPluginPermission.InstallPermission = TenantPluginPermission.InstallPermission.EVERYONE,
|
||||
debug: TenantPluginPermission.DebugPermission = TenantPluginPermission.DebugPermission.EVERYONE,
|
||||
) -> TenantPluginPermission:
|
||||
perm = TenantPluginPermission(
|
||||
tenant_id=tenant_id,
|
||||
install_permission=install,
|
||||
debug_permission=debug,
|
||||
)
|
||||
db_session.add(perm)
|
||||
db_session.commit()
|
||||
db_session.expire_all()
|
||||
return perm
|
||||
|
||||
|
||||
class TestPluginPermissionRequired:
|
||||
def test_allows_without_permission(self, db_session_with_containers: Session):
|
||||
tenant = _create_tenant(db_session_with_containers)
|
||||
user = SimpleNamespace(is_admin_or_owner=False)
|
||||
|
||||
with patch(
|
||||
"controllers.console.workspace.current_account_with_tenant",
|
||||
return_value=(user, tenant.id),
|
||||
):
|
||||
|
||||
@plugin_permission_required()
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
assert handler() == "ok"
|
||||
|
||||
def test_install_nobody_forbidden(self, db_session_with_containers: Session):
|
||||
tenant = _create_tenant(db_session_with_containers)
|
||||
_create_permission(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
install=TenantPluginPermission.InstallPermission.NOBODY,
|
||||
debug=TenantPluginPermission.DebugPermission.EVERYONE,
|
||||
)
|
||||
user = SimpleNamespace(is_admin_or_owner=True)
|
||||
|
||||
with patch(
|
||||
"controllers.console.workspace.current_account_with_tenant",
|
||||
return_value=(user, tenant.id),
|
||||
):
|
||||
|
||||
@plugin_permission_required(install_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
handler()
|
||||
|
||||
def test_install_admin_requires_admin(self, db_session_with_containers: Session):
|
||||
tenant = _create_tenant(db_session_with_containers)
|
||||
_create_permission(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
install=TenantPluginPermission.InstallPermission.ADMINS,
|
||||
debug=TenantPluginPermission.DebugPermission.EVERYONE,
|
||||
)
|
||||
user = SimpleNamespace(is_admin_or_owner=False)
|
||||
|
||||
with patch(
|
||||
"controllers.console.workspace.current_account_with_tenant",
|
||||
return_value=(user, tenant.id),
|
||||
):
|
||||
|
||||
@plugin_permission_required(install_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
handler()
|
||||
|
||||
def test_install_admin_allows_admin(self, db_session_with_containers: Session):
|
||||
tenant = _create_tenant(db_session_with_containers)
|
||||
_create_permission(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
install=TenantPluginPermission.InstallPermission.ADMINS,
|
||||
debug=TenantPluginPermission.DebugPermission.EVERYONE,
|
||||
)
|
||||
user = SimpleNamespace(is_admin_or_owner=True)
|
||||
|
||||
with patch(
|
||||
"controllers.console.workspace.current_account_with_tenant",
|
||||
return_value=(user, tenant.id),
|
||||
):
|
||||
|
||||
@plugin_permission_required(install_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
assert handler() == "ok"
|
||||
|
||||
def test_debug_nobody_forbidden(self, db_session_with_containers: Session):
|
||||
tenant = _create_tenant(db_session_with_containers)
|
||||
_create_permission(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
install=TenantPluginPermission.InstallPermission.EVERYONE,
|
||||
debug=TenantPluginPermission.DebugPermission.NOBODY,
|
||||
)
|
||||
user = SimpleNamespace(is_admin_or_owner=True)
|
||||
|
||||
with patch(
|
||||
"controllers.console.workspace.current_account_with_tenant",
|
||||
return_value=(user, tenant.id),
|
||||
):
|
||||
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
handler()
|
||||
|
||||
def test_debug_admin_requires_admin(self, db_session_with_containers: Session):
|
||||
tenant = _create_tenant(db_session_with_containers)
|
||||
_create_permission(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
install=TenantPluginPermission.InstallPermission.EVERYONE,
|
||||
debug=TenantPluginPermission.DebugPermission.ADMINS,
|
||||
)
|
||||
user = SimpleNamespace(is_admin_or_owner=False)
|
||||
|
||||
with patch(
|
||||
"controllers.console.workspace.current_account_with_tenant",
|
||||
return_value=(user, tenant.id),
|
||||
):
|
||||
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
handler()
|
||||
|
||||
def test_debug_admin_allows_admin(self, db_session_with_containers: Session):
|
||||
tenant = _create_tenant(db_session_with_containers)
|
||||
_create_permission(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
install=TenantPluginPermission.InstallPermission.EVERYONE,
|
||||
debug=TenantPluginPermission.DebugPermission.ADMINS,
|
||||
)
|
||||
user = SimpleNamespace(is_admin_or_owner=True)
|
||||
|
||||
with patch(
|
||||
"controllers.console.workspace.current_account_with_tenant",
|
||||
return_value=(user, tenant.id),
|
||||
):
|
||||
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
assert handler() == "ok"
|
||||
@@ -1,5 +1,10 @@
|
||||
"""Testcontainers integration tests for controllers.mcp.mcp endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import types
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from flask import Response
|
||||
@@ -14,24 +19,6 @@ def unwrap(func):
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_db():
|
||||
module.db = types.SimpleNamespace(engine=object())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_session():
|
||||
session = MagicMock()
|
||||
session.__enter__.return_value = session
|
||||
session.__exit__.return_value = False
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_session(fake_session):
|
||||
module.Session = MagicMock(return_value=fake_session)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_mcp_ns():
|
||||
fake_ns = types.SimpleNamespace()
|
||||
@@ -44,8 +31,13 @@ def fake_payload(data):
|
||||
module.mcp_ns.payload = data
|
||||
|
||||
|
||||
_TENANT_ID = str(uuid4())
|
||||
_APP_ID = str(uuid4())
|
||||
_SERVER_ID = str(uuid4())
|
||||
|
||||
|
||||
class DummyServer:
|
||||
def __init__(self, status, app_id="app-1", tenant_id="tenant-1", server_id="srv-1"):
|
||||
def __init__(self, status, app_id=_APP_ID, tenant_id=_TENANT_ID, server_id=_SERVER_ID):
|
||||
self.status = status
|
||||
self.app_id = app_id
|
||||
self.tenant_id = tenant_id
|
||||
@@ -54,8 +46,8 @@ class DummyServer:
|
||||
|
||||
class DummyApp:
|
||||
def __init__(self, mode, workflow=None, app_model_config=None):
|
||||
self.id = "app-1"
|
||||
self.tenant_id = "tenant-1"
|
||||
self.id = _APP_ID
|
||||
self.tenant_id = _TENANT_ID
|
||||
self.mode = mode
|
||||
self.workflow = workflow
|
||||
self.app_model_config = app_model_config
|
||||
@@ -76,6 +68,7 @@ class DummyResult:
|
||||
return {"jsonrpc": "2.0", "result": "ok", "id": 1}
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("flask_req_ctx_with_containers")
|
||||
class TestMCPAppApi:
|
||||
@patch.object(module, "handle_mcp_request", return_value=DummyResult(), autospec=True)
|
||||
def test_success_request(self, mock_handle):
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Unit tests for controllers.web.conversation endpoints."""
|
||||
"""Testcontainers integration tests for controllers.web.conversation endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -7,7 +7,6 @@ from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.web.conversation import (
|
||||
@@ -33,18 +32,18 @@ def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ConversationListApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestConversationListApi:
|
||||
def test_non_chat_mode_raises(self, app: Flask) -> None:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_non_chat_mode_raises(self, app) -> None:
|
||||
with app.test_request_context("/conversations"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ConversationListApi().get(_completion_app(), _end_user())
|
||||
|
||||
@patch("controllers.web.conversation.WebConversationService.pagination_by_last_id")
|
||||
@patch("controllers.web.conversation.db")
|
||||
def test_happy_path(self, mock_db: MagicMock, mock_paginate: MagicMock, app: Flask) -> None:
|
||||
def test_happy_path(self, mock_paginate: MagicMock, app) -> None:
|
||||
conv_id = str(uuid4())
|
||||
conv = SimpleNamespace(
|
||||
id=conv_id,
|
||||
@@ -56,34 +55,26 @@ class TestConversationListApi:
|
||||
updated_at=1700000000,
|
||||
)
|
||||
mock_paginate.return_value = SimpleNamespace(limit=20, has_more=False, data=[conv])
|
||||
mock_db.engine = "engine"
|
||||
|
||||
session_mock = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with (
|
||||
app.test_request_context("/conversations?limit=20"),
|
||||
patch("controllers.web.conversation.Session", return_value=session_ctx),
|
||||
):
|
||||
with app.test_request_context("/conversations?limit=20"):
|
||||
result = ConversationListApi().get(_chat_app(), _end_user())
|
||||
|
||||
assert result["limit"] == 20
|
||||
assert result["has_more"] is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ConversationApi (delete)
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestConversationApi:
|
||||
def test_non_chat_mode_raises(self, app: Flask) -> None:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_non_chat_mode_raises(self, app) -> None:
|
||||
with app.test_request_context(f"/conversations/{uuid4()}"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ConversationApi().delete(_completion_app(), _end_user(), uuid4())
|
||||
|
||||
@patch("controllers.web.conversation.ConversationService.delete")
|
||||
def test_delete_success(self, mock_delete: MagicMock, app: Flask) -> None:
|
||||
def test_delete_success(self, mock_delete: MagicMock, app) -> None:
|
||||
c_id = uuid4()
|
||||
with app.test_request_context(f"/conversations/{c_id}"):
|
||||
result, status = ConversationApi().delete(_chat_app(), _end_user(), c_id)
|
||||
@@ -92,25 +83,26 @@ class TestConversationApi:
|
||||
assert result["result"] == "success"
|
||||
|
||||
@patch("controllers.web.conversation.ConversationService.delete", side_effect=ConversationNotExistsError())
|
||||
def test_delete_not_found(self, mock_delete: MagicMock, app: Flask) -> None:
|
||||
def test_delete_not_found(self, mock_delete: MagicMock, app) -> None:
|
||||
c_id = uuid4()
|
||||
with app.test_request_context(f"/conversations/{c_id}"):
|
||||
with pytest.raises(NotFound, match="Conversation Not Exists"):
|
||||
ConversationApi().delete(_chat_app(), _end_user(), c_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ConversationRenameApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestConversationRenameApi:
|
||||
def test_non_chat_mode_raises(self, app: Flask) -> None:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_non_chat_mode_raises(self, app) -> None:
|
||||
with app.test_request_context(f"/conversations/{uuid4()}/name", method="POST", json={"name": "x"}):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ConversationRenameApi().post(_completion_app(), _end_user(), uuid4())
|
||||
|
||||
@patch("controllers.web.conversation.ConversationService.rename")
|
||||
@patch("controllers.web.conversation.web_ns")
|
||||
def test_rename_success(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None:
|
||||
def test_rename_success(self, mock_ns: MagicMock, mock_rename: MagicMock, app) -> None:
|
||||
c_id = uuid4()
|
||||
mock_ns.payload = {"name": "New Name", "auto_generate": False}
|
||||
conv = SimpleNamespace(
|
||||
@@ -134,7 +126,7 @@ class TestConversationRenameApi:
|
||||
side_effect=ConversationNotExistsError(),
|
||||
)
|
||||
@patch("controllers.web.conversation.web_ns")
|
||||
def test_rename_not_found(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None:
|
||||
def test_rename_not_found(self, mock_ns: MagicMock, mock_rename: MagicMock, app) -> None:
|
||||
c_id = uuid4()
|
||||
mock_ns.payload = {"name": "X", "auto_generate": False}
|
||||
|
||||
@@ -143,17 +135,18 @@ class TestConversationRenameApi:
|
||||
ConversationRenameApi().post(_chat_app(), _end_user(), c_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ConversationPinApi / ConversationUnPinApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestConversationPinApi:
|
||||
def test_non_chat_mode_raises(self, app: Flask) -> None:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_non_chat_mode_raises(self, app) -> None:
|
||||
with app.test_request_context(f"/conversations/{uuid4()}/pin", method="PATCH"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ConversationPinApi().patch(_completion_app(), _end_user(), uuid4())
|
||||
|
||||
@patch("controllers.web.conversation.WebConversationService.pin")
|
||||
def test_pin_success(self, mock_pin: MagicMock, app: Flask) -> None:
|
||||
def test_pin_success(self, mock_pin: MagicMock, app) -> None:
|
||||
c_id = uuid4()
|
||||
with app.test_request_context(f"/conversations/{c_id}/pin", method="PATCH"):
|
||||
result = ConversationPinApi().patch(_chat_app(), _end_user(), c_id)
|
||||
@@ -161,7 +154,7 @@ class TestConversationPinApi:
|
||||
assert result["result"] == "success"
|
||||
|
||||
@patch("controllers.web.conversation.WebConversationService.pin", side_effect=ConversationNotExistsError())
|
||||
def test_pin_not_found(self, mock_pin: MagicMock, app: Flask) -> None:
|
||||
def test_pin_not_found(self, mock_pin: MagicMock, app) -> None:
|
||||
c_id = uuid4()
|
||||
with app.test_request_context(f"/conversations/{c_id}/pin", method="PATCH"):
|
||||
with pytest.raises(NotFound):
|
||||
@@ -169,13 +162,17 @@ class TestConversationPinApi:
|
||||
|
||||
|
||||
class TestConversationUnPinApi:
|
||||
def test_non_chat_mode_raises(self, app: Flask) -> None:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_non_chat_mode_raises(self, app) -> None:
|
||||
with app.test_request_context(f"/conversations/{uuid4()}/unpin", method="PATCH"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ConversationUnPinApi().patch(_completion_app(), _end_user(), uuid4())
|
||||
|
||||
@patch("controllers.web.conversation.WebConversationService.unpin")
|
||||
def test_unpin_success(self, mock_unpin: MagicMock, app: Flask) -> None:
|
||||
def test_unpin_success(self, mock_unpin: MagicMock, app) -> None:
|
||||
c_id = uuid4()
|
||||
with app.test_request_context(f"/conversations/{c_id}/unpin", method="PATCH"):
|
||||
result = ConversationUnPinApi().patch(_chat_app(), _end_user(), c_id)
|
||||
@@ -1,9 +1,12 @@
|
||||
"""Testcontainers integration tests for controllers.web.forgot_password endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.forgot_password import (
|
||||
ForgotPasswordCheckApi,
|
||||
@@ -12,13 +15,6 @@ from controllers.web.forgot_password import (
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _patch_wraps():
|
||||
wraps_features = SimpleNamespace(enable_email_password_login=True)
|
||||
@@ -33,6 +29,10 @@ def _patch_wraps():
|
||||
|
||||
|
||||
class TestForgotPasswordSendEmailApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
@patch("controllers.web.forgot_password.AccountService.send_reset_password_email")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@@ -69,6 +69,10 @@ class TestForgotPasswordSendEmailApi:
|
||||
|
||||
|
||||
class TestForgotPasswordCheckApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
@patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
|
||||
@patch("controllers.web.forgot_password.AccountService.generate_reset_password_token")
|
||||
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@@ -143,6 +147,10 @@ class TestForgotPasswordCheckApi:
|
||||
|
||||
|
||||
class TestForgotPasswordResetApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
@patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.web.forgot_password.Session")
|
||||
@@ -1,13 +1,14 @@
|
||||
"""Unit tests for controllers.web.wraps — JWT auth decorator and validation helpers."""
|
||||
"""Testcontainers integration tests for controllers.web.wraps — JWT auth decorator and validation helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
|
||||
|
||||
from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError
|
||||
@@ -18,12 +19,8 @@ from controllers.web.wraps import (
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _validate_webapp_token
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestValidateWebappToken:
|
||||
def test_enterprise_enabled_and_app_auth_requires_webapp_source(self) -> None:
|
||||
"""When both flags are true, a non-webapp source must raise."""
|
||||
decoded = {"token_source": "other"}
|
||||
with pytest.raises(WebAppAuthRequiredError):
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled=True, system_webapp_auth_enabled=True)
|
||||
@@ -38,7 +35,6 @@ class TestValidateWebappToken:
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled=True, system_webapp_auth_enabled=True)
|
||||
|
||||
def test_public_app_rejects_webapp_source(self) -> None:
|
||||
"""When auth is not required, a webapp-sourced token must be rejected."""
|
||||
decoded = {"token_source": "webapp"}
|
||||
with pytest.raises(Unauthorized):
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=False)
|
||||
@@ -52,18 +48,13 @@ class TestValidateWebappToken:
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=False)
|
||||
|
||||
def test_system_enabled_but_app_public(self) -> None:
|
||||
"""system_webapp_auth_enabled=True but app is public — webapp source rejected."""
|
||||
decoded = {"token_source": "webapp"}
|
||||
with pytest.raises(Unauthorized):
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=True)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _validate_user_accessibility
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestValidateUserAccessibility:
|
||||
def test_skips_when_auth_disabled(self) -> None:
|
||||
"""No checks when system or app auth is disabled."""
|
||||
_validate_user_accessibility(
|
||||
decoded={},
|
||||
app_code="code",
|
||||
@@ -123,7 +114,6 @@ class TestValidateUserAccessibility:
|
||||
def test_external_auth_type_checks_sso_update_time(
|
||||
self, mock_perm_check: MagicMock, mock_sso_time: MagicMock
|
||||
) -> None:
|
||||
# granted_at is before SSO update time → denied
|
||||
mock_sso_time.return_value = datetime.now(UTC)
|
||||
old_granted = int((datetime.now(UTC) - timedelta(hours=1)).timestamp())
|
||||
decoded = {"user_id": "u1", "auth_type": "external", "granted_at": old_granted}
|
||||
@@ -164,7 +154,6 @@ class TestValidateUserAccessibility:
|
||||
recent_granted = int(datetime.now(UTC).timestamp())
|
||||
decoded = {"user_id": "u1", "auth_type": "external", "granted_at": recent_granted}
|
||||
settings = SimpleNamespace(access_mode="public")
|
||||
# Should not raise
|
||||
_validate_user_accessibility(
|
||||
decoded=decoded,
|
||||
app_code="code",
|
||||
@@ -191,10 +180,49 @@ class TestValidateUserAccessibility:
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# decode_jwt_token
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestDecodeJwtToken:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def _create_app_site_enduser(self, db_session: Session, *, enable_site: bool = True):
|
||||
from models.model import App, AppMode, CustomizeTokenStrategy, EndUser, Site
|
||||
|
||||
tenant_id = str(uuid4())
|
||||
app_model = App(
|
||||
tenant_id=tenant_id,
|
||||
mode=AppMode.CHAT.value,
|
||||
name="test-app",
|
||||
enable_site=enable_site,
|
||||
enable_api=True,
|
||||
)
|
||||
db_session.add(app_model)
|
||||
db_session.commit()
|
||||
db_session.expire_all()
|
||||
|
||||
site = Site(
|
||||
app_id=app_model.id,
|
||||
title="test-site",
|
||||
default_language="en-US",
|
||||
customize_token_strategy=CustomizeTokenStrategy.NOT_ALLOW,
|
||||
code="code1",
|
||||
)
|
||||
db_session.add(site)
|
||||
db_session.commit()
|
||||
db_session.expire_all()
|
||||
|
||||
end_user = EndUser(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
type="browser",
|
||||
session_id="sess-1",
|
||||
)
|
||||
db_session.add(end_user)
|
||||
db_session.commit()
|
||||
db_session.expire_all()
|
||||
|
||||
return app_model, site, end_user
|
||||
|
||||
@patch("controllers.web.wraps._validate_user_accessibility")
|
||||
@patch("controllers.web.wraps._validate_webapp_token")
|
||||
@patch("controllers.web.wraps.EnterpriseService.WebAppAuth.get_app_access_mode_by_id")
|
||||
@@ -202,10 +230,8 @@ class TestDecodeJwtToken:
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.PassportService")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
@patch("controllers.web.wraps.db")
|
||||
def test_happy_path(
|
||||
self,
|
||||
mock_db: MagicMock,
|
||||
mock_extract: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_features: MagicMock,
|
||||
@@ -213,40 +239,28 @@ class TestDecodeJwtToken:
|
||||
mock_access_mode: MagicMock,
|
||||
mock_validate_token: MagicMock,
|
||||
mock_validate_user: MagicMock,
|
||||
app: Flask,
|
||||
app,
|
||||
db_session_with_containers: Session,
|
||||
) -> None:
|
||||
app_model, site, end_user = self._create_app_site_enduser(db_session_with_containers)
|
||||
|
||||
mock_extract.return_value = "jwt-token"
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"app_code": "code1",
|
||||
"app_id": "app-1",
|
||||
"end_user_id": "eu-1",
|
||||
"app_code": site.code,
|
||||
"app_id": app_model.id,
|
||||
"end_user_id": end_user.id,
|
||||
}
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", enable_site=True)
|
||||
site = SimpleNamespace(code="code1")
|
||||
end_user = SimpleNamespace(id="eu-1", session_id="sess-1")
|
||||
with app.test_request_context("/", headers={"X-App-Code": site.code}):
|
||||
result_app, result_user = decode_jwt_token()
|
||||
|
||||
# Configure session mock to return correct objects via scalar()
|
||||
session_mock = MagicMock()
|
||||
session_mock.scalar.side_effect = [app_model, site, end_user]
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
mock_db.engine = "engine"
|
||||
|
||||
with patch("controllers.web.wraps.Session", return_value=session_ctx):
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
result_app, result_user = decode_jwt_token()
|
||||
|
||||
assert result_app.id == "app-1"
|
||||
assert result_user.id == "eu-1"
|
||||
assert result_app.id == app_model.id
|
||||
assert result_user.id == end_user.id
|
||||
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
def test_missing_token_raises_unauthorized(
|
||||
self, mock_extract: MagicMock, mock_features: MagicMock, app: Flask
|
||||
) -> None:
|
||||
def test_missing_token_raises_unauthorized(self, mock_extract: MagicMock, mock_features: MagicMock, app) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
mock_extract.return_value = None
|
||||
|
||||
@@ -257,137 +271,98 @@ class TestDecodeJwtToken:
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.PassportService")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
@patch("controllers.web.wraps.db")
|
||||
def test_missing_app_raises_not_found(
|
||||
self,
|
||||
mock_db: MagicMock,
|
||||
mock_extract: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_features: MagicMock,
|
||||
app: Flask,
|
||||
app,
|
||||
) -> None:
|
||||
non_existent_id = str(uuid4())
|
||||
mock_extract.return_value = "jwt-token"
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"app_code": "code1",
|
||||
"app_id": "app-1",
|
||||
"end_user_id": "eu-1",
|
||||
"app_id": non_existent_id,
|
||||
"end_user_id": str(uuid4()),
|
||||
}
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
|
||||
session_mock = MagicMock()
|
||||
session_mock.scalar.return_value = None # No app found
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
mock_db.engine = "engine"
|
||||
|
||||
with patch("controllers.web.wraps.Session", return_value=session_ctx):
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(NotFound):
|
||||
decode_jwt_token()
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(NotFound):
|
||||
decode_jwt_token()
|
||||
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.PassportService")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
@patch("controllers.web.wraps.db")
|
||||
def test_disabled_site_raises_bad_request(
|
||||
self,
|
||||
mock_db: MagicMock,
|
||||
mock_extract: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_features: MagicMock,
|
||||
app: Flask,
|
||||
app,
|
||||
db_session_with_containers: Session,
|
||||
) -> None:
|
||||
app_model, site, end_user = self._create_app_site_enduser(db_session_with_containers, enable_site=False)
|
||||
|
||||
mock_extract.return_value = "jwt-token"
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"app_code": "code1",
|
||||
"app_id": "app-1",
|
||||
"end_user_id": "eu-1",
|
||||
"app_code": site.code,
|
||||
"app_id": app_model.id,
|
||||
"end_user_id": end_user.id,
|
||||
}
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", enable_site=False)
|
||||
|
||||
session_mock = MagicMock()
|
||||
# scalar calls: app_model, site (code found), then end_user
|
||||
session_mock.scalar.side_effect = [app_model, SimpleNamespace(code="code1"), None]
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
mock_db.engine = "engine"
|
||||
|
||||
with patch("controllers.web.wraps.Session", return_value=session_ctx):
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(BadRequest, match="Site is disabled"):
|
||||
decode_jwt_token()
|
||||
with app.test_request_context("/", headers={"X-App-Code": site.code}):
|
||||
with pytest.raises(BadRequest, match="Site is disabled"):
|
||||
decode_jwt_token()
|
||||
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.PassportService")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
@patch("controllers.web.wraps.db")
|
||||
def test_missing_end_user_raises_not_found(
|
||||
self,
|
||||
mock_db: MagicMock,
|
||||
mock_extract: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_features: MagicMock,
|
||||
app: Flask,
|
||||
app,
|
||||
db_session_with_containers: Session,
|
||||
) -> None:
|
||||
app_model, site, _ = self._create_app_site_enduser(db_session_with_containers)
|
||||
non_existent_eu = str(uuid4())
|
||||
|
||||
mock_extract.return_value = "jwt-token"
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"app_code": "code1",
|
||||
"app_id": "app-1",
|
||||
"end_user_id": "eu-1",
|
||||
"app_code": site.code,
|
||||
"app_id": app_model.id,
|
||||
"end_user_id": non_existent_eu,
|
||||
}
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", enable_site=True)
|
||||
site = SimpleNamespace(code="code1")
|
||||
|
||||
session_mock = MagicMock()
|
||||
session_mock.scalar.side_effect = [app_model, site, None] # end_user is None
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
mock_db.engine = "engine"
|
||||
|
||||
with patch("controllers.web.wraps.Session", return_value=session_ctx):
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(NotFound):
|
||||
decode_jwt_token()
|
||||
with app.test_request_context("/", headers={"X-App-Code": site.code}):
|
||||
with pytest.raises(NotFound):
|
||||
decode_jwt_token()
|
||||
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.PassportService")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
@patch("controllers.web.wraps.db")
|
||||
def test_user_id_mismatch_raises_unauthorized(
|
||||
self,
|
||||
mock_db: MagicMock,
|
||||
mock_extract: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_features: MagicMock,
|
||||
app: Flask,
|
||||
app,
|
||||
db_session_with_containers: Session,
|
||||
) -> None:
|
||||
app_model, site, end_user = self._create_app_site_enduser(db_session_with_containers)
|
||||
|
||||
mock_extract.return_value = "jwt-token"
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"app_code": "code1",
|
||||
"app_id": "app-1",
|
||||
"end_user_id": "eu-1",
|
||||
"app_code": site.code,
|
||||
"app_id": app_model.id,
|
||||
"end_user_id": end_user.id,
|
||||
}
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", enable_site=True)
|
||||
site = SimpleNamespace(code="code1")
|
||||
end_user = SimpleNamespace(id="eu-1", session_id="sess-1")
|
||||
|
||||
session_mock = MagicMock()
|
||||
session_mock.scalar.side_effect = [app_model, site, end_user]
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
mock_db.engine = "engine"
|
||||
|
||||
with patch("controllers.web.wraps.Session", return_value=session_ctx):
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(Unauthorized, match="expired"):
|
||||
decode_jwt_token(user_id="different-user")
|
||||
with app.test_request_context("/", headers={"X-App-Code": site.code}):
|
||||
with pytest.raises(Unauthorized, match="expired"):
|
||||
decode_jwt_token(user_id="different-user")
|
||||
@@ -141,7 +141,7 @@ class TestModelLoadBalancingService:
|
||||
tenant_id=tenant_id,
|
||||
provider_name="openai",
|
||||
model_name="gpt-3.5-turbo",
|
||||
model_type="text-generation", # Use the origin model type that matches the query
|
||||
model_type="llm",
|
||||
enabled=True,
|
||||
load_balancing_enabled=False,
|
||||
)
|
||||
@@ -298,7 +298,7 @@ class TestModelLoadBalancingService:
|
||||
tenant_id=tenant.id,
|
||||
provider_name="openai",
|
||||
model_name="gpt-3.5-turbo",
|
||||
model_type="text-generation", # Use the origin model type that matches the query
|
||||
model_type="llm",
|
||||
name="config1",
|
||||
encrypted_config='{"api_key": "test_key"}',
|
||||
enabled=True,
|
||||
@@ -417,7 +417,7 @@ class TestModelLoadBalancingService:
|
||||
tenant_id=tenant.id,
|
||||
provider_name="openai",
|
||||
model_name="gpt-3.5-turbo",
|
||||
model_type="text-generation", # Use the origin model type that matches the query
|
||||
model_type="llm",
|
||||
name="config1",
|
||||
encrypted_config='{"api_key": "test_key"}',
|
||||
enabled=True,
|
||||
|
||||
@@ -1,157 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.app import app_import as app_import_module
|
||||
from services.app_dsl_service import ImportStatus
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
class _Result:
|
||||
def __init__(self, status: ImportStatus, app_id: str | None = "app-1"):
|
||||
self.status = status
|
||||
self.app_id = app_id
|
||||
|
||||
def model_dump(self, mode: str = "json"):
|
||||
return {"status": self.status, "app_id": self.app_id}
|
||||
|
||||
|
||||
class _SessionContext:
|
||||
def __init__(self, session):
|
||||
self._session = session
|
||||
|
||||
def __enter__(self):
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
|
||||
def _install_session(monkeypatch: pytest.MonkeyPatch, session: MagicMock) -> None:
|
||||
monkeypatch.setattr(app_import_module, "Session", lambda *_: _SessionContext(session))
|
||||
monkeypatch.setattr(app_import_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
|
||||
def _install_features(monkeypatch: pytest.MonkeyPatch, enabled: bool) -> None:
|
||||
features = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=enabled))
|
||||
monkeypatch.setattr(app_import_module.FeatureService, "get_system_features", lambda: features)
|
||||
|
||||
|
||||
def test_import_post_returns_failed_status(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
session = MagicMock()
|
||||
_install_session(monkeypatch, session)
|
||||
_install_features(monkeypatch, enabled=False)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
|
||||
session.commit.assert_called_once()
|
||||
assert status == 400
|
||||
assert response["status"] == ImportStatus.FAILED
|
||||
|
||||
|
||||
def test_import_post_returns_pending_status(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
session = MagicMock()
|
||||
_install_session(monkeypatch, session)
|
||||
_install_features(monkeypatch, enabled=False)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.PENDING),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
|
||||
session.commit.assert_called_once()
|
||||
assert status == 202
|
||||
assert response["status"] == ImportStatus.PENDING
|
||||
|
||||
|
||||
def test_import_post_updates_webapp_auth_when_enabled(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
session = MagicMock()
|
||||
_install_session(monkeypatch, session)
|
||||
_install_features(monkeypatch, enabled=True)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"),
|
||||
)
|
||||
update_access = MagicMock()
|
||||
monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
|
||||
session.commit.assert_called_once()
|
||||
update_access.assert_called_once_with("app-123", "private")
|
||||
assert status == 200
|
||||
assert response["status"] == ImportStatus.COMPLETED
|
||||
|
||||
|
||||
def test_import_confirm_returns_failed_status(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportConfirmApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
session = MagicMock()
|
||||
_install_session(monkeypatch, session)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"confirm_import",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"):
|
||||
response, status = method(import_id="import-1")
|
||||
|
||||
session.commit.assert_called_once()
|
||||
assert status == 400
|
||||
assert response["status"] == ImportStatus.FAILED
|
||||
|
||||
|
||||
def test_import_check_dependencies_returns_result(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportCheckDependenciesApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
session = MagicMock()
|
||||
_install_session(monkeypatch, session)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"check_dependencies",
|
||||
lambda *_args, **_kwargs: SimpleNamespace(model_dump=lambda mode="json": {"leaked_dependencies": []}),
|
||||
)
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports/app-1/check-dependencies", method="GET"):
|
||||
response, status = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert status == 200
|
||||
assert response["leaked_dependencies"] == []
|
||||
@@ -1,142 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console.workspace import plugin_permission_required
|
||||
from models.account import TenantPluginPermission
|
||||
|
||||
|
||||
class _SessionStub:
|
||||
def __init__(self, permission):
|
||||
self._permission = permission
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def query(self, *_args, **_kwargs):
|
||||
return self
|
||||
|
||||
def where(self, *_args, **_kwargs):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return self._permission
|
||||
|
||||
|
||||
def _workspace_module():
|
||||
return importlib.import_module(plugin_permission_required.__module__)
|
||||
|
||||
|
||||
def _patch_session(monkeypatch: pytest.MonkeyPatch, permission):
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "Session", lambda *_args, **_kwargs: _SessionStub(permission))
|
||||
monkeypatch.setattr(module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
|
||||
def test_plugin_permission_allows_without_permission(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(is_admin_or_owner=False)
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
|
||||
_patch_session(monkeypatch, None)
|
||||
|
||||
@plugin_permission_required()
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
assert handler() == "ok"
|
||||
|
||||
|
||||
def test_plugin_permission_install_nobody_forbidden(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(is_admin_or_owner=True)
|
||||
permission = SimpleNamespace(
|
||||
install_permission=TenantPluginPermission.InstallPermission.NOBODY,
|
||||
debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
|
||||
)
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
|
||||
_patch_session(monkeypatch, permission)
|
||||
|
||||
@plugin_permission_required(install_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
handler()
|
||||
|
||||
|
||||
def test_plugin_permission_install_admin_requires_admin(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(is_admin_or_owner=False)
|
||||
permission = SimpleNamespace(
|
||||
install_permission=TenantPluginPermission.InstallPermission.ADMINS,
|
||||
debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
|
||||
)
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
|
||||
_patch_session(monkeypatch, permission)
|
||||
|
||||
@plugin_permission_required(install_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
handler()
|
||||
|
||||
|
||||
def test_plugin_permission_install_admin_allows_admin(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(is_admin_or_owner=True)
|
||||
permission = SimpleNamespace(
|
||||
install_permission=TenantPluginPermission.InstallPermission.ADMINS,
|
||||
debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
|
||||
)
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
|
||||
_patch_session(monkeypatch, permission)
|
||||
|
||||
@plugin_permission_required(install_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
assert handler() == "ok"
|
||||
|
||||
|
||||
def test_plugin_permission_debug_nobody_forbidden(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(is_admin_or_owner=True)
|
||||
permission = SimpleNamespace(
|
||||
install_permission=TenantPluginPermission.InstallPermission.EVERYONE,
|
||||
debug_permission=TenantPluginPermission.DebugPermission.NOBODY,
|
||||
)
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
|
||||
_patch_session(monkeypatch, permission)
|
||||
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
handler()
|
||||
|
||||
|
||||
def test_plugin_permission_debug_admin_requires_admin(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(is_admin_or_owner=False)
|
||||
permission = SimpleNamespace(
|
||||
install_permission=TenantPluginPermission.InstallPermission.EVERYONE,
|
||||
debug_permission=TenantPluginPermission.DebugPermission.ADMINS,
|
||||
)
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
|
||||
_patch_session(monkeypatch, permission)
|
||||
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
handler()
|
||||
@@ -768,6 +768,7 @@ class TestSegmentApiGet:
|
||||
``current_account_with_tenant()`` and ``marshal``.
|
||||
"""
|
||||
|
||||
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
|
||||
@patch("controllers.service_api.dataset.segment.marshal")
|
||||
@patch("controllers.service_api.dataset.segment.SegmentService")
|
||||
@patch("controllers.service_api.dataset.segment.DocumentService")
|
||||
@@ -780,6 +781,7 @@ class TestSegmentApiGet:
|
||||
mock_doc_svc,
|
||||
mock_seg_svc,
|
||||
mock_marshal,
|
||||
mock_summary_svc,
|
||||
app,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
@@ -791,7 +793,8 @@ class TestSegmentApiGet:
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_doc_svc.get_document.return_value = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
|
||||
mock_seg_svc.get_segments.return_value = ([mock_segment], 1)
|
||||
mock_marshal.return_value = [{"id": mock_segment.id}]
|
||||
mock_marshal.return_value = {"id": mock_segment.id}
|
||||
mock_summary_svc.get_segments_summaries.return_value = {}
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
@@ -872,6 +875,7 @@ class TestSegmentApiPost:
|
||||
mock_rate_limit.enabled = False
|
||||
mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit
|
||||
|
||||
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
|
||||
@patch("controllers.service_api.dataset.segment.marshal")
|
||||
@patch("controllers.service_api.dataset.segment.SegmentService")
|
||||
@patch("controllers.service_api.dataset.segment.DocumentService")
|
||||
@@ -888,6 +892,7 @@ class TestSegmentApiPost:
|
||||
mock_doc_svc,
|
||||
mock_seg_svc,
|
||||
mock_marshal,
|
||||
mock_summary_svc,
|
||||
app,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
@@ -909,7 +914,8 @@ class TestSegmentApiPost:
|
||||
|
||||
mock_seg_svc.segment_create_args_validate.return_value = None
|
||||
mock_seg_svc.multi_create_segment.return_value = [mock_segment]
|
||||
mock_marshal.return_value = [{"id": mock_segment.id}]
|
||||
mock_marshal.return_value = {"id": mock_segment.id}
|
||||
mock_summary_svc.get_segments_summaries.return_value = {}
|
||||
|
||||
segments_data = [{"content": "Test segment content", "answer": "Test answer"}]
|
||||
|
||||
@@ -1206,6 +1212,7 @@ class TestDatasetSegmentApiUpdate:
|
||||
mock_rate_limit.enabled = False
|
||||
mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit
|
||||
|
||||
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
|
||||
@patch("controllers.service_api.dataset.segment.marshal")
|
||||
@patch("controllers.service_api.dataset.segment.SegmentService")
|
||||
@patch("controllers.service_api.dataset.segment.DocumentService")
|
||||
@@ -1224,6 +1231,7 @@ class TestDatasetSegmentApiUpdate:
|
||||
mock_doc_svc,
|
||||
mock_seg_svc,
|
||||
mock_marshal,
|
||||
mock_summary_svc,
|
||||
app,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
@@ -1240,6 +1248,7 @@ class TestDatasetSegmentApiUpdate:
|
||||
updated = Mock()
|
||||
mock_seg_svc.update_segment.return_value = updated
|
||||
mock_marshal.return_value = {"id": mock_segment.id}
|
||||
mock_summary_svc.get_segment_summary.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{mock_segment.id}",
|
||||
@@ -1349,6 +1358,7 @@ class TestDatasetSegmentApiGetSingle:
|
||||
``current_account_with_tenant()`` and ``marshal``.
|
||||
"""
|
||||
|
||||
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
|
||||
@patch("controllers.service_api.dataset.segment.marshal")
|
||||
@patch("controllers.service_api.dataset.segment.SegmentService")
|
||||
@patch("controllers.service_api.dataset.segment.DocumentService")
|
||||
@@ -1363,6 +1373,7 @@ class TestDatasetSegmentApiGetSingle:
|
||||
mock_doc_svc,
|
||||
mock_seg_svc,
|
||||
mock_marshal,
|
||||
mock_summary_svc,
|
||||
app,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
@@ -1376,6 +1387,7 @@ class TestDatasetSegmentApiGetSingle:
|
||||
mock_doc_svc.get_document.return_value = mock_doc
|
||||
mock_seg_svc.get_segment_by_id.return_value = mock_segment
|
||||
mock_marshal.return_value = {"id": mock_segment.id}
|
||||
mock_summary_svc.get_segment_summary.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{mock_segment.id}",
|
||||
@@ -1393,6 +1405,55 @@ class TestDatasetSegmentApiGetSingle:
|
||||
assert "data" in response
|
||||
assert response["doc_form"] == IndexStructureType.PARAGRAPH_INDEX
|
||||
|
||||
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
|
||||
@patch("controllers.service_api.dataset.segment.marshal")
|
||||
@patch("controllers.service_api.dataset.segment.SegmentService")
|
||||
@patch("controllers.service_api.dataset.segment.DocumentService")
|
||||
@patch("controllers.service_api.dataset.segment.DatasetService")
|
||||
@patch("controllers.service_api.dataset.segment.current_account_with_tenant")
|
||||
@patch("controllers.service_api.dataset.segment.db")
|
||||
def test_get_single_segment_includes_summary(
|
||||
self,
|
||||
mock_db,
|
||||
mock_account_fn,
|
||||
mock_dataset_svc,
|
||||
mock_doc_svc,
|
||||
mock_seg_svc,
|
||||
mock_marshal,
|
||||
mock_summary_svc,
|
||||
app,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
mock_segment,
|
||||
):
|
||||
"""Test that single segment response includes summary content from SummaryIndexService."""
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_model_setting.return_value = None
|
||||
mock_doc = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
|
||||
mock_doc_svc.get_document.return_value = mock_doc
|
||||
mock_seg_svc.get_segment_by_id.return_value = mock_segment
|
||||
mock_marshal.return_value = {"id": mock_segment.id, "summary": None}
|
||||
|
||||
mock_summary_record = Mock()
|
||||
mock_summary_record.summary_content = "This is the segment summary"
|
||||
mock_summary_svc.get_segment_summary.return_value = mock_summary_record
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{mock_segment.id}",
|
||||
method="GET",
|
||||
):
|
||||
api = DatasetSegmentApi()
|
||||
response, status = api.get(
|
||||
tenant_id=mock_tenant.id,
|
||||
dataset_id=mock_dataset.id,
|
||||
document_id="doc-id",
|
||||
segment_id=mock_segment.id,
|
||||
)
|
||||
|
||||
assert status == 200
|
||||
assert response["data"]["summary"] == "This is the segment summary"
|
||||
|
||||
@patch("controllers.service_api.dataset.segment.current_account_with_tenant")
|
||||
@patch("controllers.service_api.dataset.segment.db")
|
||||
def test_get_single_segment_dataset_not_found(
|
||||
|
||||
@@ -415,12 +415,44 @@ class TestUtilityFunctions:
|
||||
label="Upload",
|
||||
required=False,
|
||||
),
|
||||
VariableEntity(
|
||||
type=VariableEntityType.CHECKBOX,
|
||||
variable="enabled",
|
||||
description="Enable flag",
|
||||
label="Enabled",
|
||||
required=False,
|
||||
),
|
||||
VariableEntity(
|
||||
type=VariableEntityType.JSON_OBJECT,
|
||||
variable="config",
|
||||
description="Config object",
|
||||
label="Config",
|
||||
required=True,
|
||||
),
|
||||
VariableEntity(
|
||||
type=VariableEntityType.JSON_OBJECT,
|
||||
variable="schema_config",
|
||||
description="Config with schema",
|
||||
label="Schema Config",
|
||||
required=False,
|
||||
json_schema={
|
||||
"properties": {
|
||||
"host": {"type": "string"},
|
||||
"port": {"type": "number"},
|
||||
},
|
||||
"required": ["host"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
parameters_dict: dict[str, str] = {
|
||||
"name": "Enter your name",
|
||||
"category": "Select category",
|
||||
"count": "Enter count",
|
||||
"enabled": "Enable flag",
|
||||
"config": "Config object",
|
||||
"schema_config": "Config with schema",
|
||||
}
|
||||
|
||||
parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict)
|
||||
@@ -437,20 +469,35 @@ class TestUtilityFunctions:
|
||||
assert "count" in parameters
|
||||
assert parameters["count"]["type"] == "number"
|
||||
|
||||
# FILE type should be skipped - it creates empty dict but gets filtered later
|
||||
# Check that it doesn't have any meaningful content
|
||||
if "upload" in parameters:
|
||||
assert parameters["upload"] == {}
|
||||
# FILE type is skipped entirely via `continue` — key should not exist
|
||||
assert "upload" not in parameters
|
||||
|
||||
# CHECKBOX maps to boolean
|
||||
assert parameters["enabled"]["type"] == "boolean"
|
||||
|
||||
# JSON_OBJECT without json_schema maps to object
|
||||
assert parameters["config"]["type"] == "object"
|
||||
assert "properties" not in parameters["config"]
|
||||
|
||||
# JSON_OBJECT with json_schema forwards schema keys
|
||||
assert parameters["schema_config"]["type"] == "object"
|
||||
assert parameters["schema_config"]["properties"] == {
|
||||
"host": {"type": "string"},
|
||||
"port": {"type": "number"},
|
||||
}
|
||||
assert parameters["schema_config"]["required"] == ["host"]
|
||||
assert parameters["schema_config"]["additionalProperties"] is False
|
||||
|
||||
# Check required fields
|
||||
assert "name" in required
|
||||
assert "count" in required
|
||||
assert "config" in required
|
||||
assert "category" not in required
|
||||
|
||||
# Note: _get_request_id function has been removed as request_id is now passed as parameter
|
||||
|
||||
def test_convert_input_form_to_parameters_jsonschema_validation_ok(self):
|
||||
"""Current schema uses 'number' for numeric fields; it should be a valid JSON Schema."""
|
||||
"""Generated schema with all supported types should be valid JSON Schema."""
|
||||
user_input_form = [
|
||||
VariableEntity(
|
||||
type=VariableEntityType.NUMBER,
|
||||
@@ -466,11 +513,27 @@ class TestUtilityFunctions:
|
||||
label="Name",
|
||||
required=False,
|
||||
),
|
||||
VariableEntity(
|
||||
type=VariableEntityType.CHECKBOX,
|
||||
variable="enabled",
|
||||
description="Toggle",
|
||||
label="Enabled",
|
||||
required=False,
|
||||
),
|
||||
VariableEntity(
|
||||
type=VariableEntityType.JSON_OBJECT,
|
||||
variable="metadata",
|
||||
description="Metadata",
|
||||
label="Metadata",
|
||||
required=False,
|
||||
),
|
||||
]
|
||||
|
||||
parameters_dict = {
|
||||
"count": "Enter count",
|
||||
"name": "Enter your name",
|
||||
"enabled": "Toggle flag",
|
||||
"metadata": "Metadata object",
|
||||
}
|
||||
|
||||
parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict)
|
||||
@@ -485,9 +548,12 @@ class TestUtilityFunctions:
|
||||
# 1) The schema itself must be valid
|
||||
jsonschema.Draft202012Validator.check_schema(schema)
|
||||
|
||||
# 2) Both float and integer instances should pass validation
|
||||
# 2) Validate instances with all types
|
||||
jsonschema.validate(instance={"count": 3.14, "name": "alice"}, schema=schema)
|
||||
jsonschema.validate(instance={"count": 2, "name": "bob"}, schema=schema)
|
||||
jsonschema.validate(
|
||||
instance={"count": 2, "enabled": True, "metadata": {"key": "val"}},
|
||||
schema=schema,
|
||||
)
|
||||
|
||||
def test_legacy_float_type_schema_is_invalid(self):
|
||||
"""Legacy/buggy behavior: using 'float' should produce an invalid JSON Schema."""
|
||||
|
||||
@@ -521,11 +521,11 @@ def test_generate_name_trace(trace_instance):
|
||||
def test_add_trace_success(trace_instance):
|
||||
data = LangfuseTrace(id="t1", name="trace")
|
||||
trace_instance.add_trace(data)
|
||||
trace_instance.langfuse_client.trace.assert_called_once()
|
||||
trace_instance.langfuse_client.api.ingestion.batch.assert_called_once()
|
||||
|
||||
|
||||
def test_add_trace_error(trace_instance):
|
||||
trace_instance.langfuse_client.trace.side_effect = Exception("error")
|
||||
trace_instance.langfuse_client.api.ingestion.batch.side_effect = Exception("error")
|
||||
data = LangfuseTrace(id="t1", name="trace")
|
||||
with pytest.raises(ValueError, match="LangFuse Failed to create trace: error"):
|
||||
trace_instance.add_trace(data)
|
||||
@@ -534,11 +534,11 @@ def test_add_trace_error(trace_instance):
|
||||
def test_add_span_success(trace_instance):
|
||||
data = LangfuseSpan(id="s1", name="span", trace_id="t1")
|
||||
trace_instance.add_span(data)
|
||||
trace_instance.langfuse_client.span.assert_called_once()
|
||||
trace_instance.langfuse_client.api.ingestion.batch.assert_called_once()
|
||||
|
||||
|
||||
def test_add_span_error(trace_instance):
|
||||
trace_instance.langfuse_client.span.side_effect = Exception("error")
|
||||
trace_instance.langfuse_client.api.ingestion.batch.side_effect = Exception("error")
|
||||
data = LangfuseSpan(id="s1", name="span", trace_id="t1")
|
||||
with pytest.raises(ValueError, match="LangFuse Failed to create span: error"):
|
||||
trace_instance.add_span(data)
|
||||
@@ -554,11 +554,11 @@ def test_update_span(trace_instance):
|
||||
def test_add_generation_success(trace_instance):
|
||||
data = LangfuseGeneration(id="g1", name="gen", trace_id="t1")
|
||||
trace_instance.add_generation(data)
|
||||
trace_instance.langfuse_client.generation.assert_called_once()
|
||||
trace_instance.langfuse_client.api.ingestion.batch.assert_called_once()
|
||||
|
||||
|
||||
def test_add_generation_error(trace_instance):
|
||||
trace_instance.langfuse_client.generation.side_effect = Exception("error")
|
||||
trace_instance.langfuse_client.api.ingestion.batch.side_effect = Exception("error")
|
||||
data = LangfuseGeneration(id="g1", name="gen", trace_id="t1")
|
||||
with pytest.raises(ValueError, match="LangFuse Failed to create generation: error"):
|
||||
trace_instance.add_generation(data)
|
||||
@@ -585,12 +585,12 @@ def test_api_check_error(trace_instance):
|
||||
def test_get_project_key_success(trace_instance):
|
||||
mock_data = MagicMock()
|
||||
mock_data.id = "proj-1"
|
||||
trace_instance.langfuse_client.client.projects.get.return_value = MagicMock(data=[mock_data])
|
||||
trace_instance.langfuse_client.api.projects.get.return_value = MagicMock(data=[mock_data])
|
||||
assert trace_instance.get_project_key() == "proj-1"
|
||||
|
||||
|
||||
def test_get_project_key_error(trace_instance):
|
||||
trace_instance.langfuse_client.client.projects.get.side_effect = Exception("fail")
|
||||
trace_instance.langfuse_client.api.projects.get.side_effect = Exception("fail")
|
||||
with pytest.raises(ValueError, match="LangFuse get project key failed: fail"):
|
||||
trace_instance.get_project_key()
|
||||
|
||||
|
||||
@@ -201,27 +201,23 @@ def test_search_returns_documents_in_rank_order_and_applies_filter(monkeypatch,
|
||||
document_id = _Field("document_id")
|
||||
|
||||
keyword = Jieba(_dataset(_dataset_keyword_table()))
|
||||
query_stmt = _FakeQuery()
|
||||
patched_runtime.session.query.return_value = query_stmt
|
||||
patched_runtime.session.execute.return_value = _FakeExecuteResult(
|
||||
[
|
||||
SimpleNamespace(
|
||||
index_node_id="node-2",
|
||||
content="segment-content",
|
||||
index_node_hash="hash-2",
|
||||
document_id="doc-2",
|
||||
dataset_id="dataset-1",
|
||||
)
|
||||
]
|
||||
)
|
||||
patched_runtime.session.scalars.return_value.all.return_value = [
|
||||
SimpleNamespace(
|
||||
index_node_id="node-2",
|
||||
content="segment-content",
|
||||
index_node_hash="hash-2",
|
||||
document_id="doc-2",
|
||||
dataset_id="dataset-1",
|
||||
)
|
||||
]
|
||||
|
||||
monkeypatch.setattr(jieba_module, "DocumentSegment", _FakeDocumentSegment)
|
||||
monkeypatch.setattr(jieba_module, "select", lambda *_: _FakeSelect())
|
||||
monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={"k": {"node-1", "node-2"}}))
|
||||
monkeypatch.setattr(keyword, "_retrieve_ids_by_query", MagicMock(return_value=["node-1", "node-2"]))
|
||||
|
||||
documents = keyword.search("query", top_k=2, document_ids_filter=["doc-2"])
|
||||
|
||||
assert len(query_stmt.where_calls) == 2
|
||||
assert len(documents) == 1
|
||||
assert documents[0].page_content == "segment-content"
|
||||
assert documents[0].metadata["doc_id"] == "node-2"
|
||||
|
||||
@@ -714,13 +714,13 @@ class TestRetrievalServiceInternals:
|
||||
dataset_id="dataset-id",
|
||||
)
|
||||
|
||||
dataset_query = Mock()
|
||||
dataset_query.where.return_value.options.return_value.all.return_value = [
|
||||
scalars_result = Mock()
|
||||
scalars_result.all.return_value = [
|
||||
dataset_doc_parent,
|
||||
dataset_doc_text,
|
||||
dataset_doc_parent_summary,
|
||||
]
|
||||
monkeypatch.setattr(retrieval_service_module.db.session, "query", Mock(return_value=dataset_query))
|
||||
monkeypatch.setattr(retrieval_service_module.db.session, "scalars", Mock(return_value=scalars_result))
|
||||
monkeypatch.setattr(retrieval_service_module, "RetrievalChildChunk", _SimpleRetrievalChildChunk)
|
||||
monkeypatch.setattr(retrieval_service_module, "RetrievalSegments", _SimpleRetrievalSegment)
|
||||
|
||||
@@ -882,7 +882,7 @@ class TestRetrievalServiceInternals:
|
||||
def test_format_retrieval_documents_rolls_back_and_raises_when_db_fails(self, monkeypatch):
|
||||
rollback = Mock()
|
||||
monkeypatch.setattr(retrieval_service_module.db.session, "rollback", rollback)
|
||||
monkeypatch.setattr(retrieval_service_module.db.session, "query", Mock(side_effect=RuntimeError("db error")))
|
||||
monkeypatch.setattr(retrieval_service_module.db.session, "scalars", Mock(side_effect=RuntimeError("db error")))
|
||||
|
||||
documents = [Document(page_content="content", metadata={"document_id": "doc-1"}, provider="dify")]
|
||||
|
||||
|
||||
@@ -340,15 +340,13 @@ def test_search_by_file_handles_missing_and_existing_upload(vector_factory_modul
|
||||
vector._embeddings = MagicMock()
|
||||
vector._vector_processor = MagicMock()
|
||||
|
||||
mock_session = SimpleNamespace(get=lambda _model, _id: None)
|
||||
monkeypatch.setattr(vector_factory_module, "UploadFile", SimpleNamespace(id=_Field()))
|
||||
monkeypatch.setattr(
|
||||
vector_factory_module, "db", SimpleNamespace(session=SimpleNamespace(query=lambda _model: upload_query))
|
||||
)
|
||||
monkeypatch.setattr(vector_factory_module, "db", SimpleNamespace(session=mock_session))
|
||||
|
||||
upload_query.first.return_value = None
|
||||
assert vector.search_by_file("file-1") == []
|
||||
|
||||
upload_query.first.return_value = SimpleNamespace(key="blob-key")
|
||||
mock_session.get = lambda _model, _id: SimpleNamespace(key="blob-key")
|
||||
monkeypatch.setattr(vector_factory_module.storage, "load_once", MagicMock(return_value=b"file-bytes"))
|
||||
vector._embeddings.embed_multimodal_query.return_value = [0.3, 0.4]
|
||||
vector._vector_processor.search_by_vector.return_value = ["hit"]
|
||||
|
||||
@@ -167,7 +167,7 @@ class TestDatasetDocumentStoreAddDocuments:
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_db.session = mock_session
|
||||
mock_db.session.query.return_value.where.return_value.scalar.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.get_model_instance.return_value = mock_model_instance
|
||||
@@ -211,7 +211,7 @@ class TestDatasetDocumentStoreAddDocuments:
|
||||
with patch("core.rag.docstore.dataset_docstore.db") as mock_db:
|
||||
mock_session = MagicMock()
|
||||
mock_db.session = mock_session
|
||||
mock_db.session.query.return_value.where.return_value.scalar.return_value = 5
|
||||
mock_db.session.scalar.return_value = 5
|
||||
|
||||
with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_existing_segment):
|
||||
with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"):
|
||||
@@ -276,7 +276,7 @@ class TestDatasetDocumentStoreAddDocuments:
|
||||
with patch("core.rag.docstore.dataset_docstore.db") as mock_db:
|
||||
mock_session = MagicMock()
|
||||
mock_db.session = mock_session
|
||||
mock_db.session.query.return_value.where.return_value.scalar.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None):
|
||||
with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"):
|
||||
@@ -353,7 +353,7 @@ class TestDatasetDocumentStoreAddDocuments:
|
||||
with patch("core.rag.docstore.dataset_docstore.db") as mock_db:
|
||||
mock_session = MagicMock()
|
||||
mock_db.session = mock_session
|
||||
mock_db.session.query.return_value.where.return_value.scalar.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None):
|
||||
with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"):
|
||||
@@ -755,7 +755,7 @@ class TestDatasetDocumentStoreAddDocumentsUpdateChild:
|
||||
with patch("core.rag.docstore.dataset_docstore.db") as mock_db:
|
||||
mock_session = MagicMock()
|
||||
mock_db.session = mock_session
|
||||
mock_db.session.query.return_value.where.return_value.scalar.return_value = 5
|
||||
mock_db.session.scalar.return_value = 5
|
||||
|
||||
with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_existing_segment):
|
||||
with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"):
|
||||
@@ -767,7 +767,7 @@ class TestDatasetDocumentStoreAddDocumentsUpdateChild:
|
||||
|
||||
store.add_documents([mock_doc], save_child=True)
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.delete.assert_called()
|
||||
mock_db.session.execute.assert_called()
|
||||
mock_db.session.commit.assert_called()
|
||||
|
||||
|
||||
@@ -798,7 +798,7 @@ class TestDatasetDocumentStoreAddDocumentsUpdateAnswer:
|
||||
with patch("core.rag.docstore.dataset_docstore.db") as mock_db:
|
||||
mock_session = MagicMock()
|
||||
mock_db.session = mock_session
|
||||
mock_db.session.query.return_value.where.return_value.scalar.return_value = 5
|
||||
mock_db.session.scalar.return_value = 5
|
||||
|
||||
with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_existing_segment):
|
||||
with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"):
|
||||
|
||||
@@ -69,7 +69,7 @@ class TestCacheEmbeddingMultimodalDocuments:
|
||||
documents = [{"file_id": "file123", "content": "test content"}]
|
||||
|
||||
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
mock_model_instance.invoke_multimodal_embedding.return_value = sample_multimodal_result
|
||||
|
||||
result = cache_embedding.embed_multimodal_documents(documents)
|
||||
@@ -114,7 +114,7 @@ class TestCacheEmbeddingMultimodalDocuments:
|
||||
)
|
||||
|
||||
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result
|
||||
|
||||
result = cache_embedding.embed_multimodal_documents(documents)
|
||||
@@ -134,7 +134,7 @@ class TestCacheEmbeddingMultimodalDocuments:
|
||||
mock_cached_embedding.get_embedding.return_value = normalized_cached
|
||||
|
||||
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_cached_embedding
|
||||
mock_session.scalar.return_value = mock_cached_embedding
|
||||
|
||||
result = cache_embedding.embed_multimodal_documents(documents)
|
||||
|
||||
@@ -180,18 +180,7 @@ class TestCacheEmbeddingMultimodalDocuments:
|
||||
)
|
||||
|
||||
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
|
||||
call_count = [0]
|
||||
|
||||
def mock_filter_by(**kwargs):
|
||||
call_count[0] += 1
|
||||
mock_query = Mock()
|
||||
if call_count[0] == 1:
|
||||
mock_query.first.return_value = mock_cached_embedding
|
||||
else:
|
||||
mock_query.first.return_value = None
|
||||
return mock_query
|
||||
|
||||
mock_session.query.return_value.filter_by = mock_filter_by
|
||||
mock_session.scalar.side_effect = [mock_cached_embedding, None, None]
|
||||
mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result
|
||||
|
||||
result = cache_embedding.embed_multimodal_documents(documents)
|
||||
@@ -224,7 +213,7 @@ class TestCacheEmbeddingMultimodalDocuments:
|
||||
)
|
||||
|
||||
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result
|
||||
|
||||
with patch("core.rag.embedding.cached_embedding.logger") as mock_logger:
|
||||
@@ -265,7 +254,7 @@ class TestCacheEmbeddingMultimodalDocuments:
|
||||
)
|
||||
|
||||
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
batch_results = [create_batch_result(10), create_batch_result(10), create_batch_result(5)]
|
||||
mock_model_instance.invoke_multimodal_embedding.side_effect = batch_results
|
||||
@@ -281,7 +270,7 @@ class TestCacheEmbeddingMultimodalDocuments:
|
||||
documents = [{"file_id": "file123"}]
|
||||
|
||||
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
mock_model_instance.invoke_multimodal_embedding.side_effect = Exception("API Error")
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
@@ -298,7 +287,7 @@ class TestCacheEmbeddingMultimodalDocuments:
|
||||
documents = [{"file_id": "file123"}]
|
||||
|
||||
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
mock_model_instance.invoke_multimodal_embedding.return_value = sample_multimodal_result
|
||||
|
||||
mock_session.commit.side_effect = IntegrityError("Duplicate key", None, None)
|
||||
|
||||
@@ -139,7 +139,7 @@ class TestCacheEmbeddingDocuments:
|
||||
|
||||
# Mock database query to return no cached embedding (cache miss)
|
||||
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
# Mock model invocation
|
||||
mock_model_instance.invoke_text_embedding.return_value = sample_embedding_result
|
||||
@@ -203,7 +203,7 @@ class TestCacheEmbeddingDocuments:
|
||||
)
|
||||
|
||||
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
mock_model_instance.invoke_text_embedding.return_value = embedding_result
|
||||
|
||||
# Act
|
||||
@@ -240,7 +240,7 @@ class TestCacheEmbeddingDocuments:
|
||||
|
||||
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
|
||||
# Mock database to return cached embedding (cache hit)
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_cached_embedding
|
||||
mock_session.scalar.return_value = mock_cached_embedding
|
||||
|
||||
# Act
|
||||
result = cache_embedding.embed_documents(texts)
|
||||
@@ -313,19 +313,7 @@ class TestCacheEmbeddingDocuments:
|
||||
mock_hash.side_effect = generate_hash
|
||||
|
||||
# Mock database to return cached embedding only for first text (hash_1)
|
||||
call_count = [0]
|
||||
|
||||
def mock_filter_by(**kwargs):
|
||||
call_count[0] += 1
|
||||
mock_query = Mock()
|
||||
# First call (hash_1) returns cached, others return None
|
||||
if call_count[0] == 1:
|
||||
mock_query.first.return_value = mock_cached_embedding
|
||||
else:
|
||||
mock_query.first.return_value = None
|
||||
return mock_query
|
||||
|
||||
mock_session.query.return_value.filter_by = mock_filter_by
|
||||
mock_session.scalar.side_effect = [mock_cached_embedding, None, None]
|
||||
mock_model_instance.invoke_text_embedding.return_value = embedding_result
|
||||
|
||||
# Act
|
||||
@@ -392,7 +380,7 @@ class TestCacheEmbeddingDocuments:
|
||||
)
|
||||
|
||||
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
# Mock model to return appropriate batch results
|
||||
batch_results = [
|
||||
@@ -455,7 +443,7 @@ class TestCacheEmbeddingDocuments:
|
||||
)
|
||||
|
||||
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
mock_model_instance.invoke_text_embedding.return_value = embedding_result
|
||||
|
||||
with patch("core.rag.embedding.cached_embedding.logger") as mock_logger:
|
||||
@@ -489,7 +477,7 @@ class TestCacheEmbeddingDocuments:
|
||||
texts = ["Test text"]
|
||||
|
||||
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
# Mock model to raise connection error
|
||||
mock_model_instance.invoke_text_embedding.side_effect = InvokeConnectionError("Failed to connect to API")
|
||||
@@ -515,7 +503,7 @@ class TestCacheEmbeddingDocuments:
|
||||
texts = ["Test text"]
|
||||
|
||||
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
# Mock model to raise rate limit error
|
||||
mock_model_instance.invoke_text_embedding.side_effect = InvokeRateLimitError("Rate limit exceeded")
|
||||
@@ -539,7 +527,7 @@ class TestCacheEmbeddingDocuments:
|
||||
texts = ["Test text"]
|
||||
|
||||
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
# Mock model to raise authorization error
|
||||
mock_model_instance.invoke_text_embedding.side_effect = InvokeAuthorizationError("Invalid API key")
|
||||
@@ -564,7 +552,7 @@ class TestCacheEmbeddingDocuments:
|
||||
texts = ["Test text"]
|
||||
|
||||
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
mock_model_instance.invoke_text_embedding.return_value = sample_embedding_result
|
||||
|
||||
# Mock database commit to raise IntegrityError
|
||||
@@ -884,7 +872,7 @@ class TestEmbeddingModelSwitching:
|
||||
)
|
||||
|
||||
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
model_instance_ada.invoke_text_embedding.return_value = result_ada
|
||||
model_instance_3_small.invoke_text_embedding.return_value = result_3_small
|
||||
@@ -1047,7 +1035,7 @@ class TestEmbeddingDimensionValidation:
|
||||
)
|
||||
|
||||
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
mock_model_instance.invoke_text_embedding.return_value = embedding_result
|
||||
|
||||
# Act
|
||||
@@ -1100,7 +1088,7 @@ class TestEmbeddingDimensionValidation:
|
||||
)
|
||||
|
||||
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
mock_model_instance.invoke_text_embedding.return_value = embedding_result
|
||||
|
||||
# Act
|
||||
@@ -1186,7 +1174,7 @@ class TestEmbeddingDimensionValidation:
|
||||
)
|
||||
|
||||
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
model_instance_ada.invoke_text_embedding.return_value = result_ada
|
||||
model_instance_cohere.invoke_text_embedding.return_value = result_cohere
|
||||
@@ -1284,7 +1272,7 @@ class TestEmbeddingEdgeCases:
|
||||
)
|
||||
|
||||
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
mock_model_instance.invoke_text_embedding.return_value = embedding_result
|
||||
|
||||
# Act
|
||||
@@ -1327,7 +1315,7 @@ class TestEmbeddingEdgeCases:
|
||||
)
|
||||
|
||||
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
mock_model_instance.invoke_text_embedding.return_value = embedding_result
|
||||
|
||||
# Act
|
||||
@@ -1375,7 +1363,7 @@ class TestEmbeddingEdgeCases:
|
||||
)
|
||||
|
||||
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
mock_model_instance.invoke_text_embedding.return_value = embedding_result
|
||||
|
||||
# Act
|
||||
@@ -1427,7 +1415,7 @@ class TestEmbeddingEdgeCases:
|
||||
)
|
||||
|
||||
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
mock_model_instance.invoke_text_embedding.return_value = embedding_result
|
||||
|
||||
# Act
|
||||
@@ -1483,7 +1471,7 @@ class TestEmbeddingEdgeCases:
|
||||
)
|
||||
|
||||
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
mock_model_instance.invoke_text_embedding.return_value = embedding_result
|
||||
|
||||
# Act
|
||||
@@ -1551,7 +1539,7 @@ class TestEmbeddingEdgeCases:
|
||||
)
|
||||
|
||||
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
mock_model_instance.invoke_text_embedding.return_value = embedding_result
|
||||
|
||||
# Act
|
||||
@@ -1649,7 +1637,7 @@ class TestEmbeddingEdgeCases:
|
||||
)
|
||||
|
||||
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
mock_model_instance.invoke_text_embedding.return_value = embedding_result
|
||||
|
||||
# Act
|
||||
@@ -1728,7 +1716,7 @@ class TestEmbeddingCachePerformance:
|
||||
|
||||
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
|
||||
# First call: cache miss
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
usage = EmbeddingUsage(
|
||||
tokens=5,
|
||||
@@ -1756,7 +1744,7 @@ class TestEmbeddingCachePerformance:
|
||||
assert len(result1) == 1
|
||||
|
||||
# Arrange - Second call: cache hit
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_cached_embedding
|
||||
mock_session.scalar.return_value = mock_cached_embedding
|
||||
|
||||
# Act - Second call (cache hit)
|
||||
result2 = cache_embedding.embed_documents([text])
|
||||
@@ -1816,7 +1804,7 @@ class TestEmbeddingCachePerformance:
|
||||
)
|
||||
|
||||
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
# Mock model to return appropriate batch results
|
||||
batch_results = [
|
||||
|
||||
@@ -405,35 +405,36 @@ class TestNotionMetadataAndCredentialMethods:
|
||||
|
||||
class FakeDocumentModel:
|
||||
data_source_info = "data_source_info"
|
||||
id = "id"
|
||||
|
||||
update_calls = []
|
||||
execute_calls = []
|
||||
|
||||
class FakeQuery:
|
||||
def filter_by(self, **kwargs):
|
||||
class FakeUpdateStmt:
|
||||
def where(self, *args):
|
||||
return self
|
||||
|
||||
def update(self, payload):
|
||||
update_calls.append(payload)
|
||||
def values(self, **kwargs):
|
||||
return self
|
||||
|
||||
class FakeSession:
|
||||
committed = False
|
||||
|
||||
def query(self, model):
|
||||
assert model is FakeDocumentModel
|
||||
return FakeQuery()
|
||||
def execute(self, stmt):
|
||||
execute_calls.append(stmt)
|
||||
|
||||
def commit(self):
|
||||
self.committed = True
|
||||
|
||||
fake_db = SimpleNamespace(session=FakeSession())
|
||||
monkeypatch.setattr(notion_extractor, "DocumentModel", FakeDocumentModel)
|
||||
monkeypatch.setattr(notion_extractor, "update", lambda model: FakeUpdateStmt())
|
||||
monkeypatch.setattr(notion_extractor, "db", fake_db)
|
||||
monkeypatch.setattr(extractor, "get_notion_last_edited_time", lambda: "2026-01-01T00:00:00.000Z")
|
||||
|
||||
doc_model = SimpleNamespace(id="doc-1", data_source_info_dict={"source": "notion"})
|
||||
extractor.update_last_edited_time(doc_model)
|
||||
|
||||
assert update_calls
|
||||
assert execute_calls
|
||||
assert fake_db.session.committed is True
|
||||
|
||||
def test_get_notion_last_edited_time_uses_page_and_database_urls(self, mocker: MockerFixture):
|
||||
|
||||
@@ -188,10 +188,10 @@ class TestParagraphIndexProcessor:
|
||||
mock_keyword_cls.return_value.add_texts.assert_called_once_with(docs)
|
||||
|
||||
def test_clean_deletes_summaries_and_vector(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None:
|
||||
segment_query = Mock()
|
||||
segment_query.filter.return_value.all.return_value = [SimpleNamespace(id="seg-1")]
|
||||
scalars_result = Mock()
|
||||
scalars_result.all.return_value = [SimpleNamespace(id="seg-1")]
|
||||
session = Mock()
|
||||
session.query.return_value = segment_query
|
||||
session.scalars.return_value = scalars_result
|
||||
|
||||
with (
|
||||
patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session),
|
||||
@@ -531,10 +531,10 @@ class TestParagraphIndexProcessor:
|
||||
size=1,
|
||||
key="key",
|
||||
)
|
||||
query = Mock()
|
||||
query.where.return_value.all.return_value = [image_upload, non_image_upload]
|
||||
scalars_result = Mock()
|
||||
scalars_result.all.return_value = [image_upload, non_image_upload]
|
||||
session = Mock()
|
||||
session.query.return_value = query
|
||||
session.scalars.return_value = scalars_result
|
||||
|
||||
with (
|
||||
patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session),
|
||||
@@ -565,10 +565,10 @@ class TestParagraphIndexProcessor:
|
||||
size=1,
|
||||
key="key",
|
||||
)
|
||||
query = Mock()
|
||||
query.where.return_value.all.return_value = [image_upload]
|
||||
scalars_result = Mock()
|
||||
scalars_result.all.return_value = [image_upload]
|
||||
session = Mock()
|
||||
session.query.return_value = query
|
||||
session.scalars.return_value = scalars_result
|
||||
|
||||
with (
|
||||
patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session),
|
||||
|
||||
@@ -208,11 +208,7 @@ class TestParentChildIndexProcessor:
|
||||
vector.create_multimodal.assert_called_once_with(multimodal_docs)
|
||||
|
||||
def test_clean_with_precomputed_child_ids(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None:
|
||||
delete_query = Mock()
|
||||
where_query = Mock()
|
||||
where_query.delete.return_value = 2
|
||||
session = Mock()
|
||||
session.query.return_value.where.return_value = where_query
|
||||
|
||||
with (
|
||||
patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls,
|
||||
@@ -227,16 +223,16 @@ class TestParentChildIndexProcessor:
|
||||
)
|
||||
|
||||
vector.delete_by_ids.assert_called_once_with(["child-1", "child-2"])
|
||||
where_query.delete.assert_called_once_with(synchronize_session=False)
|
||||
session.execute.assert_called()
|
||||
session.commit.assert_called_once()
|
||||
|
||||
def test_clean_queries_child_ids_when_not_precomputed(
|
||||
self, processor: ParentChildIndexProcessor, dataset: Mock
|
||||
) -> None:
|
||||
child_query = Mock()
|
||||
child_query.join.return_value.where.return_value.all.return_value = [("child-1",), (None,), ("child-2",)]
|
||||
execute_result = Mock()
|
||||
execute_result.all.return_value = [("child-1",), (None,), ("child-2",)]
|
||||
session = Mock()
|
||||
session.query.return_value = child_query
|
||||
session.execute.return_value = execute_result
|
||||
|
||||
with (
|
||||
patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls,
|
||||
@@ -248,10 +244,7 @@ class TestParentChildIndexProcessor:
|
||||
vector.delete_by_ids.assert_called_once_with(["child-1", "child-2"])
|
||||
|
||||
def test_clean_dataset_wide_cleanup(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None:
|
||||
where_query = Mock()
|
||||
where_query.delete.return_value = 3
|
||||
session = Mock()
|
||||
session.query.return_value.where.return_value = where_query
|
||||
|
||||
with (
|
||||
patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls,
|
||||
@@ -261,7 +254,7 @@ class TestParentChildIndexProcessor:
|
||||
processor.clean(dataset, None, delete_child_chunks=True)
|
||||
|
||||
vector.delete.assert_called_once()
|
||||
where_query.delete.assert_called_once_with(synchronize_session=False)
|
||||
session.execute.assert_called()
|
||||
session.commit.assert_called_once()
|
||||
|
||||
def test_clean_deletes_summaries_when_requested(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None:
|
||||
|
||||
@@ -133,10 +133,10 @@ class TestBaseIndexProcessor:
|
||||
upload_b = SimpleNamespace(id="bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", name="b.png")
|
||||
upload_tool = SimpleNamespace(id="tool-upload-id", name="tool.png")
|
||||
upload_remote = SimpleNamespace(id="remote-upload-id", name="remote.png")
|
||||
db_query = Mock()
|
||||
db_query.where.return_value.all.return_value = [upload_a, upload_b, upload_tool, upload_remote]
|
||||
scalars_result = Mock()
|
||||
scalars_result.all.return_value = [upload_a, upload_b, upload_tool, upload_remote]
|
||||
db_session = Mock()
|
||||
db_session.query.return_value = db_query
|
||||
db_session.scalars.return_value = scalars_result
|
||||
|
||||
with (
|
||||
patch.object(processor, "_extract_markdown_images", return_value=images),
|
||||
@@ -170,10 +170,10 @@ class TestBaseIndexProcessor:
|
||||
def test_get_content_files_ignores_missing_upload_records(self, processor: _ForwardingBaseIndexProcessor) -> None:
|
||||
document = Document(page_content="ignored", metadata={"document_id": "doc-1", "dataset_id": "ds-1"})
|
||||
images = ["/files/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/image-preview"]
|
||||
db_query = Mock()
|
||||
db_query.where.return_value.all.return_value = []
|
||||
scalars_result = Mock()
|
||||
scalars_result.all.return_value = []
|
||||
db_session = Mock()
|
||||
db_session.query.return_value = db_query
|
||||
db_session.scalars.return_value = scalars_result
|
||||
|
||||
with (
|
||||
patch.object(processor, "_extract_markdown_images", return_value=images),
|
||||
@@ -259,20 +259,16 @@ class TestBaseIndexProcessor:
|
||||
assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None
|
||||
|
||||
def test_download_tool_file_returns_none_when_not_found(self, processor: _ForwardingBaseIndexProcessor) -> None:
|
||||
db_query = Mock()
|
||||
db_query.where.return_value.first.return_value = None
|
||||
db_session = Mock()
|
||||
db_session.query.return_value = db_query
|
||||
db_session.get.return_value = None
|
||||
|
||||
with patch("core.rag.index_processor.index_processor_base.db.session", db_session):
|
||||
assert processor._download_tool_file("tool-id", current_user=Mock()) is None
|
||||
|
||||
def test_download_tool_file_uploads_file_when_found(self, processor: _ForwardingBaseIndexProcessor) -> None:
|
||||
tool_file = SimpleNamespace(file_key="k1", name="tool.png", mimetype="image/png")
|
||||
db_query = Mock()
|
||||
db_query.where.return_value.first.return_value = tool_file
|
||||
db_session = Mock()
|
||||
db_session.query.return_value = db_query
|
||||
db_session.get.return_value = tool_file
|
||||
mock_db = Mock()
|
||||
mock_db.session = db_session
|
||||
mock_db.engine = Mock()
|
||||
|
||||
@@ -473,12 +473,10 @@ class TestRerankModelRunnerMultimodal:
|
||||
metadata={},
|
||||
provider="external",
|
||||
)
|
||||
query = Mock()
|
||||
query.where.return_value.first.return_value = SimpleNamespace(key="image-key")
|
||||
rerank_result = RerankResult(model="rerank-model", docs=[])
|
||||
|
||||
with (
|
||||
patch("core.rag.rerank.rerank_model.db.session.query", return_value=query),
|
||||
patch("core.rag.rerank.rerank_model.db.session.get", return_value=SimpleNamespace(key="image-key")),
|
||||
patch("core.rag.rerank.rerank_model.storage.load_once", return_value=b"image-bytes") as mock_load_once,
|
||||
patch.object(
|
||||
rerank_runner,
|
||||
@@ -504,12 +502,10 @@ class TestRerankModelRunnerMultimodal:
|
||||
metadata={"doc_id": "img-missing", "doc_type": DocType.IMAGE},
|
||||
provider="dify",
|
||||
)
|
||||
query = Mock()
|
||||
query.where.return_value.first.return_value = None
|
||||
rerank_result = RerankResult(model="rerank-model", docs=[])
|
||||
|
||||
with (
|
||||
patch("core.rag.rerank.rerank_model.db.session.query", return_value=query),
|
||||
patch("core.rag.rerank.rerank_model.db.session.get", return_value=None),
|
||||
patch.object(
|
||||
rerank_runner,
|
||||
"fetch_text_rerank",
|
||||
@@ -533,8 +529,6 @@ class TestRerankModelRunnerMultimodal:
|
||||
metadata={"doc_id": "txt-1", "doc_type": DocType.TEXT},
|
||||
provider="dify",
|
||||
)
|
||||
query_chain = Mock()
|
||||
query_chain.where.return_value.first.return_value = SimpleNamespace(key="query-image-key")
|
||||
rerank_result = RerankResult(
|
||||
model="rerank-model",
|
||||
docs=[RerankDocument(index=0, text="text-content", score=0.77)],
|
||||
@@ -542,7 +536,7 @@ class TestRerankModelRunnerMultimodal:
|
||||
mock_model_instance.invoke_multimodal_rerank.return_value = rerank_result
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query_chain
|
||||
session.get.return_value = SimpleNamespace(key="query-image-key")
|
||||
with (
|
||||
patch("core.rag.rerank.rerank_model.db.session", session),
|
||||
patch("core.rag.rerank.rerank_model.storage.load_once", return_value=b"query-image-bytes"),
|
||||
@@ -563,10 +557,7 @@ class TestRerankModelRunnerMultimodal:
|
||||
assert "user" not in invoke_kwargs
|
||||
|
||||
def test_fetch_multimodal_rerank_raises_when_query_image_not_found(self, rerank_runner):
|
||||
query_chain = Mock()
|
||||
query_chain.where.return_value.first.return_value = None
|
||||
|
||||
with patch("core.rag.rerank.rerank_model.db.session.query", return_value=query_chain):
|
||||
with patch("core.rag.rerank.rerank_model.db.session.get", return_value=None):
|
||||
with pytest.raises(ValueError, match="Upload file not found for query"):
|
||||
rerank_runner.fetch_multimodal_rerank(
|
||||
query="missing-upload-id",
|
||||
|
||||
@@ -3971,11 +3971,10 @@ class TestDatasetRetrievalAdditionalHelpers:
|
||||
)
|
||||
|
||||
def test_get_metadata_filter_condition(self, retrieval: DatasetRetrieval) -> None:
|
||||
db_query = Mock()
|
||||
db_query.where.return_value = db_query
|
||||
db_query.all.return_value = [SimpleNamespace(dataset_id="d1", id="doc-1")]
|
||||
scalars_result = Mock()
|
||||
scalars_result.all.return_value = [SimpleNamespace(dataset_id="d1", id="doc-1")]
|
||||
|
||||
with patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query):
|
||||
with patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=scalars_result):
|
||||
mapping, condition = retrieval.get_metadata_filter_condition(
|
||||
dataset_ids=["d1"],
|
||||
query="python",
|
||||
@@ -3991,7 +3990,7 @@ class TestDatasetRetrievalAdditionalHelpers:
|
||||
|
||||
automatic_filters = [{"condition": "contains", "metadata_name": "author", "value": "Alice"}]
|
||||
with (
|
||||
patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query),
|
||||
patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=scalars_result),
|
||||
patch.object(retrieval, "_automatic_metadata_filter_func", return_value=automatic_filters),
|
||||
):
|
||||
mapping, condition = retrieval.get_metadata_filter_condition(
|
||||
@@ -4012,7 +4011,7 @@ class TestDatasetRetrievalAdditionalHelpers:
|
||||
logical_operator="and",
|
||||
conditions=[AppCondition(name="author", comparison_operator="contains", value="{{name}}")],
|
||||
)
|
||||
with patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query):
|
||||
with patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=scalars_result):
|
||||
mapping, condition = retrieval.get_metadata_filter_condition(
|
||||
dataset_ids=["d1"],
|
||||
query="python",
|
||||
@@ -4027,7 +4026,7 @@ class TestDatasetRetrievalAdditionalHelpers:
|
||||
assert condition is not None
|
||||
assert condition.conditions[0].value == "Alice"
|
||||
|
||||
with patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query):
|
||||
with patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=scalars_result):
|
||||
with pytest.raises(ValueError, match="Invalid metadata filtering mode"):
|
||||
retrieval.get_metadata_filter_condition(
|
||||
dataset_ids=["d1"],
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user