Merge branch 'refactor-enhance-billing-info-guard' into deploy/dev

This commit is contained in:
hj24
2026-04-02 11:02:00 +08:00
159 changed files with 8254 additions and 8597 deletions

View File

@@ -6,7 +6,6 @@ runs:
- name: Setup Vite+
uses: voidzero-dev/setup-vp@20553a7a7429c429a74894104a2835d7fed28a72 # v1.3.0
with:
working-directory: web
node-version-file: .nvmrc
cache: true
run-install: true

View File

@@ -35,7 +35,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@@ -84,7 +84,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@@ -156,7 +156,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
with:
enable-cache: true
python-version: "3.12"
@@ -203,7 +203,7 @@ jobs:
- name: Report coverage
if: ${{ env.CODECOV_TOKEN != '' }}
uses: codecov/codecov-action@1af58845a975a7985b0beb0cbe6fbbb71a41dbad # v5.5.3
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
with:
files: ./coverage.xml
disable_search: true

View File

@@ -39,6 +39,10 @@ jobs:
with:
files: |
web/**
package.json
pnpm-lock.yaml
pnpm-workspace.yaml
.nvmrc
- name: Check api inputs
if: github.event_name != 'merge_group'
id: api-changes
@@ -52,7 +56,7 @@ jobs:
python-version: "3.11"
- if: github.event_name != 'merge_group'
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
- name: Generate Docker Compose
if: github.event_name != 'merge_group' && steps.docker-compose-changes.outputs.any_changed == 'true'

View File

@@ -24,27 +24,39 @@ env:
jobs:
build:
runs-on: ${{ matrix.platform == 'linux/arm64' && 'arm64_runner' || 'ubuntu-latest' }}
runs-on: ${{ matrix.runs_on }}
if: github.repository == 'langgenius/dify'
strategy:
matrix:
include:
- service_name: "build-api-amd64"
image_name_env: "DIFY_API_IMAGE_NAME"
context: "api"
artifact_context: "api"
build_context: "{{defaultContext}}:api"
file: "Dockerfile"
platform: linux/amd64
runs_on: ubuntu-latest
- service_name: "build-api-arm64"
image_name_env: "DIFY_API_IMAGE_NAME"
context: "api"
artifact_context: "api"
build_context: "{{defaultContext}}:api"
file: "Dockerfile"
platform: linux/arm64
runs_on: ubuntu-24.04-arm
- service_name: "build-web-amd64"
image_name_env: "DIFY_WEB_IMAGE_NAME"
context: "web"
artifact_context: "web"
build_context: "{{defaultContext}}"
file: "web/Dockerfile"
platform: linux/amd64
runs_on: ubuntu-latest
- service_name: "build-web-arm64"
image_name_env: "DIFY_WEB_IMAGE_NAME"
context: "web"
artifact_context: "web"
build_context: "{{defaultContext}}"
file: "web/Dockerfile"
platform: linux/arm64
runs_on: ubuntu-24.04-arm
steps:
- name: Prepare
@@ -58,9 +70,6 @@ jobs:
username: ${{ env.DOCKERHUB_USER }}
password: ${{ env.DOCKERHUB_TOKEN }}
- name: Set up QEMU
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
@@ -74,7 +83,8 @@ jobs:
id: build
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
with:
context: "{{defaultContext}}:${{ matrix.context }}"
context: ${{ matrix.build_context }}
file: ${{ matrix.file }}
platforms: ${{ matrix.platform }}
build-args: COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }}
labels: ${{ steps.meta.outputs.labels }}
@@ -93,7 +103,7 @@ jobs:
- name: Upload digest
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: digests-${{ matrix.context }}-${{ env.PLATFORM_PAIR }}
name: digests-${{ matrix.artifact_context }}-${{ env.PLATFORM_PAIR }}
path: /tmp/digests/*
if-no-files-found: error
retention-days: 1

View File

@@ -19,7 +19,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
with:
enable-cache: true
python-version: "3.12"
@@ -69,7 +69,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
with:
enable-cache: true
python-version: "3.12"

View File

@@ -6,7 +6,12 @@ on:
- "main"
paths:
- api/Dockerfile
- web/docker/**
- web/Dockerfile
- package.json
- pnpm-lock.yaml
- pnpm-workspace.yaml
- .nvmrc
concurrency:
group: docker-build-${{ github.head_ref || github.run_id }}
@@ -14,26 +19,31 @@ concurrency:
jobs:
build-docker:
runs-on: ubuntu-latest
runs-on: ${{ matrix.runs_on }}
strategy:
matrix:
include:
- service_name: "api-amd64"
platform: linux/amd64
context: "api"
runs_on: ubuntu-latest
context: "{{defaultContext}}:api"
file: "Dockerfile"
- service_name: "api-arm64"
platform: linux/arm64
context: "api"
runs_on: ubuntu-24.04-arm
context: "{{defaultContext}}:api"
file: "Dockerfile"
- service_name: "web-amd64"
platform: linux/amd64
context: "web"
runs_on: ubuntu-latest
context: "{{defaultContext}}"
file: "web/Dockerfile"
- service_name: "web-arm64"
platform: linux/arm64
context: "web"
runs_on: ubuntu-24.04-arm
context: "{{defaultContext}}"
file: "web/Dockerfile"
steps:
- name: Set up QEMU
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
@@ -41,8 +51,8 @@ jobs:
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
with:
push: false
context: "{{defaultContext}}:${{ matrix.context }}"
file: "${{ matrix.file }}"
context: ${{ matrix.context }}
file: ${{ matrix.file }}
platforms: ${{ matrix.platform }}
cache-from: type=gha
cache-to: type=gha,mode=max

View File

@@ -65,6 +65,10 @@ jobs:
- 'docker/volumes/sandbox/conf/**'
web:
- 'web/**'
- 'package.json'
- 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml'
- '.nvmrc'
- '.github/workflows/web-tests.yml'
- '.github/actions/setup-web/**'
e2e:
@@ -73,6 +77,10 @@ jobs:
- 'api/uv.lock'
- 'e2e/**'
- 'web/**'
- 'package.json'
- 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml'
- '.nvmrc'
- 'docker/docker-compose.middleware.yaml'
- 'docker/middleware.env.example'
- '.github/workflows/web-e2e.yml'

View File

@@ -22,7 +22,7 @@ jobs:
fetch-depth: 0
- name: Setup Python & UV
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
with:
enable-cache: true
@@ -50,6 +50,17 @@ jobs:
run: |
diff -u /tmp/pyrefly_base.txt /tmp/pyrefly_pr.txt > pyrefly_diff.txt || true
- name: Check if line counts match
id: line_count_check
run: |
base_lines=$(wc -l < /tmp/pyrefly_base.txt)
pr_lines=$(wc -l < /tmp/pyrefly_pr.txt)
if [ "$base_lines" -eq "$pr_lines" ]; then
echo "same=true" >> $GITHUB_OUTPUT
else
echo "same=false" >> $GITHUB_OUTPUT
fi
- name: Save PR number
run: |
echo ${{ github.event.pull_request.number }} > pr_number.txt
@@ -63,7 +74,7 @@ jobs:
pr_number.txt
- name: Comment PR with pyrefly diff
if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
if: ${{ github.event.pull_request.head.repo.full_name == github.repository && steps.line_count_check.outputs.same == 'false' }}
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -33,7 +33,7 @@ jobs:
- name: Setup UV and Python
if: steps.changed-files.outputs.any_changed == 'true'
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
with:
enable-cache: false
python-version: "3.12"
@@ -77,6 +77,10 @@ jobs:
with:
files: |
web/**
package.json
pnpm-lock.yaml
pnpm-workspace.yaml
.nvmrc
.github/workflows/style.yml
.github/actions/setup-web/**
@@ -90,9 +94,9 @@ jobs:
uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
with:
path: web/.eslintcache
key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
restore-keys: |
${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
- name: Web style check
if: steps.changed-files.outputs.any_changed == 'true'

View File

@@ -6,6 +6,9 @@ on:
- main
paths:
- sdks/**
- package.json
- pnpm-lock.yaml
- pnpm-workspace.yaml
concurrency:
group: sdk-tests-${{ github.head_ref || github.run_id }}

95
.github/workflows/vdb-tests-full.yml vendored Normal file
View File

@@ -0,0 +1,95 @@
name: Run Full VDB Tests
on:
schedule:
- cron: '0 3 * * 1'
workflow_dispatch:
permissions:
contents: read
concurrency:
group: vdb-tests-full-${{ github.ref || github.run_id }}
cancel-in-progress: true
jobs:
test:
name: Full VDB Tests
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
strategy:
matrix:
python-version:
- "3.12"
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Free Disk Space
uses: endersonmenezes/free-disk-space@7901478139cff6e9d44df5972fd8ab8fcade4db1 # v3.2.2
with:
remove_dotnet: true
remove_haskell: true
remove_tool_cache: true
- name: Setup UV and Python
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
cache-dependency-glob: api/uv.lock
- name: Check UV lockfile
run: uv lock --project api --check
- name: Install dependencies
run: uv sync --project api --dev
- name: Set up dotenvs
run: |
cp docker/.env.example docker/.env
cp docker/middleware.env.example docker/middleware.env
- name: Expose Service Ports
run: sh .github/workflows/expose_service_ports.sh
# - name: Set up Vector Store (TiDB)
# uses: hoverkraft-tech/compose-action@v2.0.2
# with:
# compose-file: docker/tidb/docker-compose.yaml
# services: |
# tidb
# tiflash
- name: Set up Full Vector Store Matrix
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
with:
compose-file: |
docker/docker-compose.yaml
services: |
weaviate
qdrant
couchbase-server
etcd
minio
milvus-standalone
pgvecto-rs
pgvector
chroma
elasticsearch
oceanbase
- name: setup test config
run: |
echo $(pwd)
ls -lah .
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
# - name: Check VDB Ready (TiDB)
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
- name: Test Vector Stores
run: uv run --project api bash dev/pytest/pytest_vdb.sh

View File

@@ -1,15 +1,18 @@
name: Run VDB Tests
name: Run VDB Smoke Tests
on:
workflow_call:
permissions:
contents: read
concurrency:
group: vdb-tests-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
test:
name: VDB Tests
name: VDB Smoke Tests
runs-on: ubuntu-latest
strategy:
matrix:
@@ -30,7 +33,7 @@ jobs:
remove_tool_cache: true
- name: Setup UV and Python
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@@ -58,23 +61,18 @@ jobs:
# tidb
# tiflash
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase, OceanBase)
- name: Set up Vector Stores for Smoke Coverage
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
with:
compose-file: |
docker/docker-compose.yaml
services: |
db_postgres
redis
weaviate
qdrant
couchbase-server
etcd
minio
milvus-standalone
pgvecto-rs
pgvector
chroma
elasticsearch
oceanbase
- name: setup test config
run: |
@@ -86,4 +84,9 @@ jobs:
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
- name: Test Vector Stores
run: uv run --project api bash dev/pytest/pytest_vdb.sh
run: |
uv run --project api pytest --timeout "${PYTEST_TIMEOUT:-180}" \
api/tests/integration_tests/vdb/chroma \
api/tests/integration_tests/vdb/pgvector \
api/tests/integration_tests/vdb/qdrant \
api/tests/integration_tests/vdb/weaviate

View File

@@ -27,12 +27,8 @@ jobs:
- name: Setup web dependencies
uses: ./.github/actions/setup-web
- name: Install E2E package dependencies
working-directory: ./e2e
run: vp install --frozen-lockfile
- name: Setup UV and Python
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
with:
enable-cache: true
python-version: "3.12"

View File

@@ -83,40 +83,9 @@ jobs:
- name: Report coverage
if: ${{ env.CODECOV_TOKEN != '' }}
uses: codecov/codecov-action@1af58845a975a7985b0beb0cbe6fbbb71a41dbad # v5.5.3
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
with:
directory: web/coverage
flags: web
env:
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
web-build:
name: Web Build
runs-on: ubuntu-latest
defaults:
run:
working-directory: ./web
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
with:
files: |
web/**
.github/workflows/web-tests.yml
.github/actions/setup-web/**
- name: Setup web environment
if: steps.changed-files.outputs.any_changed == 'true'
uses: ./.github/actions/setup-web
- name: Web build check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: vp run build

3
.gitignore vendored
View File

@@ -212,6 +212,7 @@ api/.vscode
# pnpm
/.pnpm-store
/node_modules
# plugin migrate
plugins.jsonl
@@ -239,4 +240,4 @@ scripts/stress-test/reports/
*.local.md
# Code Agent Folder
.qoder/*
.qoder/*

View File

View File

@@ -24,8 +24,8 @@ prepare-docker:
# Step 2: Prepare web environment
prepare-web:
@echo "🌐 Setting up web environment..."
@cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists"
@cd web && pnpm install
@cp -n web/.env.example web/.env.local 2>/dev/null || echo "Web .env.local already exists"
@pnpm install
@echo "✅ Web environment prepared (not started)"
# Step 3: Prepare API environment
@@ -93,7 +93,7 @@ test:
# Build Docker images
build-web:
@echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..."
docker build -t $(WEB_IMAGE):$(VERSION) ./web
docker build -f web/Dockerfile -t $(WEB_IMAGE):$(VERSION) .
@echo "Web Docker image built successfully: $(WEB_IMAGE):$(VERSION)"
build-api:

View File

@@ -40,6 +40,8 @@ The scripts resolve paths relative to their location, so you can run them from a
./dev/start-web
```
`./dev/setup` and `./dev/start-web` install JavaScript dependencies through the repository root workspace, so you do not need a separate `cd web && pnpm install` step.
1. Set up your application by visiting `http://localhost:3000`.
1. Start the worker service (async and scheduler tasks, runs from `api`).

View File

@@ -120,7 +120,8 @@ class DatasourceOAuthCallback(Resource):
if context is None:
raise Forbidden("Invalid context_id")
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
user_id: str = context["user_id"]
tenant_id: str = context["tenant_id"]
datasource_provider_id = DatasourceProviderID(provider_id)
plugin_id = datasource_provider_id.plugin_id
datasource_provider_service = DatasourceProviderService()
@@ -141,7 +142,7 @@ class DatasourceOAuthCallback(Resource):
system_credentials=oauth_client_params,
request=request,
)
credential_id = context.get("credential_id")
credential_id: str | None = context.get("credential_id")
if credential_id:
datasource_provider_service.reauthorize_datasource_oauth_provider(
tenant_id=tenant_id,
@@ -150,7 +151,7 @@ class DatasourceOAuthCallback(Resource):
name=oauth_response.metadata.get("name") or None,
expire_at=oauth_response.expires_at,
credentials=dict(oauth_response.credentials),
credential_id=context.get("credential_id"),
credential_id=credential_id,
)
else:
datasource_provider_service.add_datasource_oauth_provider(

View File

@@ -287,12 +287,10 @@ class ModelProviderModelCredentialApi(Resource):
provider=provider,
)
else:
# Normalize model_type to the origin value stored in DB (e.g., "text-generation" for LLM)
normalized_model_type = args.model_type.to_origin_model_type()
available_credentials = model_provider_service.get_provider_model_available_credentials(
tenant_id=tenant_id,
provider=provider,
model_type=normalized_model_type,
model_type=args.model_type,
model=args.model,
)

View File

@@ -832,7 +832,8 @@ class ToolOAuthCallback(Resource):
tool_provider = ToolProviderID(provider)
plugin_id = tool_provider.plugin_id
provider_name = tool_provider.provider_name
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
user_id: str = context["user_id"]
tenant_id: str = context["tenant_id"]
oauth_handler = OAuthHandler()
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id, provider)

View File

@@ -499,9 +499,9 @@ class TriggerOAuthCallbackApi(Resource):
provider_id = TriggerProviderID(provider)
plugin_id = provider_id.plugin_id
provider_name = provider_id.provider_name
user_id = context.get("user_id")
tenant_id = context.get("tenant_id")
subscription_builder_id = context.get("subscription_builder_id")
user_id: str = context["user_id"]
tenant_id: str = context["tenant_id"]
subscription_builder_id: str = context["subscription_builder_id"]
# Get OAuth client configuration
oauth_client_params = TriggerProviderService.get_oauth_client(

View File

@@ -174,6 +174,7 @@ class MCPAppApi(Resource):
required=variable.get("required", False),
max_length=variable.get("max_length"),
options=variable.get("options") or [],
json_schema=variable.get("json_schema"),
)
def _parse_mcp_request(self, args: dict) -> mcp_types.ClientRequest | mcp_types.ClientNotification:

View File

@@ -29,6 +29,31 @@ from services.entities.knowledge_entities.knowledge_entities import SegmentUpdat
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
from services.summary_index_service import SummaryIndexService
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict:
"""Marshal a single segment and enrich it with summary content."""
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
segment_dict["summary"] = summary.summary_content if summary else None
return segment_dict
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict]:
"""Marshal multiple segments and enrich them with summary content (batch query)."""
segment_ids = [segment.id for segment in segments]
summaries: dict = {}
if segment_ids:
summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()}
result = []
for segment in segments:
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
segment_dict["summary"] = summaries.get(segment.id)
result.append(segment_dict)
return result
class SegmentCreatePayload(BaseModel):
@@ -132,7 +157,7 @@ class SegmentApi(DatasetApiResource):
for args_item in payload.segments:
SegmentService.segment_create_args_validate(args_item, document)
segments = SegmentService.multi_create_segment(payload.segments, document, dataset)
return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200
return {"data": _marshal_segments_with_summary(segments, dataset_id), "doc_form": document.doc_form}, 200
else:
return {"error": "Segments is required"}, 400
@@ -196,7 +221,7 @@ class SegmentApi(DatasetApiResource):
)
response = {
"data": marshal(segments, segment_fields),
"data": _marshal_segments_with_summary(segments, dataset_id),
"doc_form": document.doc_form,
"total": total,
"has_more": len(segments) == limit,
@@ -296,7 +321,7 @@ class DatasetSegmentApi(DatasetApiResource):
payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {})
updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset)
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200
return {"data": _marshal_segment_with_summary(updated_segment, dataset_id), "doc_form": document.doc_form}, 200
@service_api_ns.doc("get_segment")
@service_api_ns.doc(description="Get a specific segment by ID")
@@ -326,7 +351,7 @@ class DatasetSegmentApi(DatasetApiResource):
if not segment:
raise NotFound("Segment not found.")
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
return {"data": _marshal_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
@service_api_ns.route(

View File

@@ -81,7 +81,7 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_type == system_configuration.current_quota_type,
Provider.quota_limit > Provider.quota_used,
)
.values(

View File

@@ -8,6 +8,7 @@ associates with the node span.
"""
import logging
from contextvars import Token
from dataclasses import dataclass
from typing import cast, final
@@ -35,7 +36,7 @@ logger = logging.getLogger(__name__)
@dataclass(slots=True)
class _NodeSpanContext:
span: "Span"
token: object
token: Token[context_api.Context]
@final

View File

@@ -403,7 +403,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
ProviderModelCredential.model_type == model_type,
),
)
@@ -753,7 +753,7 @@ class ProviderConfiguration(BaseModel):
ProviderModel.tenant_id == self.tenant_id,
ProviderModel.provider_name.in_(provider_names),
ProviderModel.model_name == model,
ProviderModel.model_type == model_type.to_origin_model_type(),
ProviderModel.model_type == model_type,
)
return session.execute(stmt).scalar_one_or_none()
@@ -778,7 +778,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
ProviderModelCredential.model_type == model_type,
)
credential_record = session.execute(stmt).scalar_one_or_none()
@@ -825,7 +825,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
ProviderModelCredential.model_type == model_type,
ProviderModelCredential.credential_name == credential_name,
)
if exclude_id:
@@ -901,7 +901,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
ProviderModelCredential.model_type == model_type,
)
credential_record = s.execute(stmt).scalar_one_or_none()
original_credentials = (
@@ -970,7 +970,7 @@ class ProviderConfiguration(BaseModel):
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_name=model,
model_type=model_type.to_origin_model_type(),
model_type=model_type,
encrypted_config=json.dumps(credentials),
credential_name=credential_name,
)
@@ -983,7 +983,7 @@ class ProviderConfiguration(BaseModel):
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_name=model,
model_type=model_type.to_origin_model_type(),
model_type=model_type,
credential_id=credential.id,
is_valid=True,
)
@@ -1038,7 +1038,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
ProviderModelCredential.model_type == model_type,
)
credential_record = session.execute(stmt).scalar_one_or_none()
if not credential_record:
@@ -1083,7 +1083,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
ProviderModelCredential.model_type == model_type,
)
credential_record = session.execute(stmt).scalar_one_or_none()
if not credential_record:
@@ -1116,7 +1116,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
ProviderModelCredential.model_type == model_type,
)
available_credentials_count = session.execute(count_stmt).scalar() or 0
session.delete(credential_record)
@@ -1156,7 +1156,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
ProviderModelCredential.model_type == model_type,
)
credential_record = session.execute(stmt).scalar_one_or_none()
if not credential_record:
@@ -1171,7 +1171,7 @@ class ProviderConfiguration(BaseModel):
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_name=model,
model_type=model_type.to_origin_model_type(),
model_type=model_type,
is_valid=True,
credential_id=credential_id,
)
@@ -1207,7 +1207,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
ProviderModelCredential.model_type == model_type,
)
credential_record = session.execute(stmt).scalar_one_or_none()
if not credential_record:
@@ -1263,7 +1263,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelSetting).where(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name.in_(self._get_provider_names()),
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_type == model_type,
ProviderModelSetting.model_name == model,
)
return session.execute(stmt).scalars().first()
@@ -1286,7 +1286,7 @@ class ProviderConfiguration(BaseModel):
model_setting = ProviderModelSetting(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(),
model_type=model_type,
model_name=model,
enabled=True,
)
@@ -1312,7 +1312,7 @@ class ProviderConfiguration(BaseModel):
model_setting = ProviderModelSetting(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(),
model_type=model_type,
model_name=model,
enabled=False,
)
@@ -1348,7 +1348,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(func.count(LoadBalancingModelConfig.id)).where(
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name.in_(provider_names),
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_type == model_type,
LoadBalancingModelConfig.model_name == model,
)
load_balancing_config_count = session.execute(stmt).scalar() or 0
@@ -1364,7 +1364,7 @@ class ProviderConfiguration(BaseModel):
model_setting = ProviderModelSetting(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(),
model_type=model_type,
model_name=model,
load_balancing_enabled=True,
)
@@ -1391,7 +1391,7 @@ class ProviderConfiguration(BaseModel):
model_setting = ProviderModelSetting(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(),
model_type=model_type,
model_name=model,
load_balancing_enabled=False,
)

View File

@@ -260,4 +260,12 @@ def convert_input_form_to_parameters(
parameters[item.variable]["enum"] = item.options
elif item.type == VariableEntityType.NUMBER:
parameters[item.variable]["type"] = "number"
elif item.type == VariableEntityType.CHECKBOX:
parameters[item.variable]["type"] = "boolean"
elif item.type == VariableEntityType.JSON_OBJECT:
parameters[item.variable]["type"] = "object"
if item.json_schema:
for key in ("properties", "required", "additionalProperties"):
if key in item.json_schema:
parameters[item.variable][key] = item.json_schema[key]
return parameters, required

View File

@@ -16,7 +16,13 @@ from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExport
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.util.instrumentation import InstrumentationScope
from opentelemetry.semconv.resource import ResourceAttributes
from opentelemetry.semconv._incubating.attributes.deployment_attributes import ( # type: ignore[import-untyped]
DEPLOYMENT_ENVIRONMENT,
)
from opentelemetry.semconv._incubating.attributes.host_attributes import ( # type: ignore[import-untyped]
HOST_NAME,
)
from opentelemetry.semconv.attributes import service_attributes
from opentelemetry.trace import Link, SpanContext, TraceFlags
from configs import dify_config
@@ -45,10 +51,10 @@ class TraceClient:
self.endpoint = endpoint
self.resource = Resource(
attributes={
ResourceAttributes.SERVICE_NAME: service_name,
ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
ResourceAttributes.HOST_NAME: socket.gethostname(),
service_attributes.SERVICE_NAME: service_name,
service_attributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
HOST_NAME: socket.gethostname(),
ACS_ARMS_SERVICE_FEATURE: "genai_app",
}
)

View File

@@ -19,7 +19,7 @@ from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExport
from opentelemetry.sdk import trace as trace_sdk
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
from opentelemetry.semconv.attributes import exception_attributes
from opentelemetry.trace import Span, Status, StatusCode, set_span_in_context, use_span
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from opentelemetry.util.types import AttributeValue
@@ -134,10 +134,10 @@ def set_span_status(current_span: Span, error: Exception | str | None = None):
if not exception_message:
exception_message = repr(error)
attributes: dict[str, AttributeValue] = {
OTELSpanAttributes.EXCEPTION_TYPE: exception_type,
OTELSpanAttributes.EXCEPTION_MESSAGE: exception_message,
OTELSpanAttributes.EXCEPTION_ESCAPED: False,
OTELSpanAttributes.EXCEPTION_STACKTRACE: error_string,
exception_attributes.EXCEPTION_TYPE: exception_type,
exception_attributes.EXCEPTION_MESSAGE: exception_message,
exception_attributes.EXCEPTION_ESCAPED: False,
exception_attributes.EXCEPTION_STACKTRACE: error_string,
}
current_span.add_event(name="exception", attributes=attributes)
else:

View File

@@ -1,9 +1,19 @@
import logging
import os
from datetime import datetime, timedelta
import uuid
from datetime import UTC, datetime, timedelta
from graphon.enums import BuiltinNodeTypes
from langfuse import Langfuse
from langfuse.api import (
CreateGenerationBody,
CreateSpanBody,
IngestionEvent_GenerationCreate,
IngestionEvent_SpanCreate,
IngestionEvent_TraceCreate,
TraceBody,
)
from langfuse.api.commons.types.usage import Usage
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
@@ -396,18 +406,61 @@ class LangFuseDataTrace(BaseTraceInstance):
)
self.add_span(langfuse_span_data=name_generation_span_data)
def _make_event_id(self) -> str:
return str(uuid.uuid4())
def _now_iso(self) -> str:
return datetime.now(UTC).isoformat()
def add_trace(self, langfuse_trace_data: LangfuseTrace | None = None):
format_trace_data = filter_none_values(langfuse_trace_data.model_dump()) if langfuse_trace_data else {}
data = filter_none_values(langfuse_trace_data.model_dump()) if langfuse_trace_data else {}
try:
self.langfuse_client.trace(**format_trace_data)
body = TraceBody(
id=data.get("id"),
name=data.get("name"),
user_id=data.get("user_id"),
input=data.get("input"),
output=data.get("output"),
metadata=data.get("metadata"),
session_id=data.get("session_id"),
version=data.get("version"),
release=data.get("release"),
tags=data.get("tags"),
public=data.get("public"),
)
event = IngestionEvent_TraceCreate(
body=body,
id=self._make_event_id(),
timestamp=self._now_iso(),
)
self.langfuse_client.api.ingestion.batch(batch=[event])
logger.debug("LangFuse Trace created successfully")
except Exception as e:
raise ValueError(f"LangFuse Failed to create trace: {str(e)}")
def add_span(self, langfuse_span_data: LangfuseSpan | None = None):
format_span_data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {}
data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {}
try:
self.langfuse_client.span(**format_span_data)
body = CreateSpanBody(
id=data.get("id"),
trace_id=data.get("trace_id"),
name=data.get("name"),
start_time=data.get("start_time"),
end_time=data.get("end_time"),
input=data.get("input"),
output=data.get("output"),
metadata=data.get("metadata"),
level=data.get("level"),
status_message=data.get("status_message"),
parent_observation_id=data.get("parent_observation_id"),
version=data.get("version"),
)
event = IngestionEvent_SpanCreate(
body=body,
id=self._make_event_id(),
timestamp=self._now_iso(),
)
self.langfuse_client.api.ingestion.batch(batch=[event])
logger.debug("LangFuse Span created successfully")
except Exception as e:
raise ValueError(f"LangFuse Failed to create span: {str(e)}")
@@ -418,11 +471,45 @@ class LangFuseDataTrace(BaseTraceInstance):
span.end(**format_span_data)
def add_generation(self, langfuse_generation_data: LangfuseGeneration | None = None):
format_generation_data = (
filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {}
)
data = filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {}
try:
self.langfuse_client.generation(**format_generation_data)
usage_data = data.pop("usage", None)
usage = None
if usage_data:
usage = Usage(
input=usage_data.get("input", 0) or 0,
output=usage_data.get("output", 0) or 0,
total=usage_data.get("total", 0) or 0,
unit=usage_data.get("unit"),
input_cost=usage_data.get("inputCost"),
output_cost=usage_data.get("outputCost"),
total_cost=usage_data.get("totalCost"),
)
body = CreateGenerationBody(
id=data.get("id"),
trace_id=data.get("trace_id"),
name=data.get("name"),
start_time=data.get("start_time"),
end_time=data.get("end_time"),
model=data.get("model"),
model_parameters=data.get("model_parameters"),
input=data.get("input"),
output=data.get("output"),
usage=usage,
metadata=data.get("metadata"),
level=data.get("level"),
status_message=data.get("status_message"),
parent_observation_id=data.get("parent_observation_id"),
version=data.get("version"),
completion_start_time=data.get("completion_start_time"),
)
event = IngestionEvent_GenerationCreate(
body=body,
id=self._make_event_id(),
timestamp=self._now_iso(),
)
self.langfuse_client.api.ingestion.batch(batch=[event])
logger.debug("LangFuse Generation created successfully")
except Exception as e:
raise ValueError(f"LangFuse Failed to create generation: {str(e)}")
@@ -443,7 +530,7 @@ class LangFuseDataTrace(BaseTraceInstance):
def get_project_key(self):
try:
projects = self.langfuse_client.client.projects.get()
projects = self.langfuse_client.api.projects.get()
return projects.data[0].id
except Exception as e:
logger.debug("LangFuse get project key failed: %s", str(e))

View File

@@ -26,7 +26,13 @@ from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExport
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.semconv.resource import ResourceAttributes
from opentelemetry.semconv._incubating.attributes.deployment_attributes import ( # type: ignore[import-untyped]
DEPLOYMENT_ENVIRONMENT,
)
from opentelemetry.semconv._incubating.attributes.host_attributes import ( # type: ignore[import-untyped]
HOST_NAME,
)
from opentelemetry.semconv.attributes import service_attributes
from opentelemetry.trace import SpanKind
from opentelemetry.util.types import AttributeValue
@@ -73,13 +79,13 @@ class TencentTraceClient:
self.resource = Resource(
attributes={
ResourceAttributes.SERVICE_NAME: service_name,
ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
ResourceAttributes.HOST_NAME: socket.gethostname(),
ResourceAttributes.TELEMETRY_SDK_LANGUAGE: "python",
ResourceAttributes.TELEMETRY_SDK_NAME: "opentelemetry",
ResourceAttributes.TELEMETRY_SDK_VERSION: _get_opentelemetry_sdk_version(),
service_attributes.SERVICE_NAME: service_name,
service_attributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
HOST_NAME: socket.gethostname(),
"telemetry.sdk.language": "python",
"telemetry.sdk.name": "opentelemetry",
"telemetry.sdk.version": _get_opentelemetry_sdk_version(),
}
)
# Prepare gRPC endpoint/metadata

View File

@@ -306,7 +306,7 @@ class ProviderManager:
"""
stmt = select(TenantDefaultModel).where(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
TenantDefaultModel.model_type == model_type,
)
default_model = db.session.scalar(stmt)
@@ -324,7 +324,7 @@ class ProviderManager:
default_model = TenantDefaultModel(
tenant_id=tenant_id,
model_type=model_type.to_origin_model_type(),
model_type=model_type,
provider_name=available_model.provider.provider,
model_name=available_model.model,
)
@@ -391,7 +391,7 @@ class ProviderManager:
raise ValueError(f"Model {model} does not exist.")
stmt = select(TenantDefaultModel).where(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
TenantDefaultModel.model_type == model_type,
)
default_model = db.session.scalar(stmt)
@@ -405,7 +405,7 @@ class ProviderManager:
# create default model
default_model = TenantDefaultModel(
tenant_id=tenant_id,
model_type=model_type.to_origin_model_type(),
model_type=model_type,
provider_name=provider,
model_name=model,
)
@@ -626,9 +626,8 @@ class ProviderManager:
if provider_record.provider_type != ProviderType.SYSTEM:
continue
provider_quota_to_provider_record_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = (
provider_record
)
if provider_record.quota_type is not None:
provider_quota_to_provider_record_dict[provider_record.quota_type] = provider_record
for quota in configuration.quotas:
if quota.quota_type in (ProviderQuotaType.TRIAL, ProviderQuotaType.PAID):
@@ -641,7 +640,7 @@ class ProviderManager:
# TODO: Use provider name with prefix after the data migration.
provider_name=ModelProviderID(provider_name).provider_name,
provider_type=ProviderType.SYSTEM,
quota_type=quota.quota_type,
quota_type=quota.quota_type, # type: ignore[arg-type]
quota_limit=0, # type: ignore
quota_used=0,
is_valid=True,
@@ -823,7 +822,7 @@ class ProviderManager:
custom_model_configurations.append(
CustomModelConfiguration(
model=provider_model_record.model_name,
model_type=ModelType.value_of(provider_model_record.model_type),
model_type=provider_model_record.model_type,
credentials=provider_model_credentials,
current_credential_id=provider_model_record.credential_id,
current_credential_name=provider_model_record.credential_name,
@@ -921,9 +920,8 @@ class ProviderManager:
if provider_record.provider_type != ProviderType.SYSTEM:
continue
quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = (
provider_record
)
if provider_record.quota_type is not None:
quota_type_to_provider_records_dict[provider_record.quota_type] = provider_record # type: ignore[index]
quota_configurations = []
if dify_config.EDITION == "CLOUD":
@@ -1203,7 +1201,7 @@ class ProviderManager:
model_settings.append(
ModelSettings(
model=provider_model_setting.model_name,
model_type=ModelType.value_of(provider_model_setting.model_type),
model_type=provider_model_setting.model_type,
enabled=provider_model_setting.enabled,
load_balancing_enabled=provider_model_setting.load_balancing_enabled,
load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [],

View File

@@ -97,13 +97,13 @@ class Jieba(BaseKeyword):
documents = []
segment_query_stmt = db.session.query(DocumentSegment).where(
segment_query_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id.in_(sorted_chunk_indices)
)
if document_ids_filter:
segment_query_stmt = segment_query_stmt.where(DocumentSegment.document_id.in_(document_ids_filter))
segments = db.session.execute(segment_query_stmt).scalars().all()
segments = db.session.scalars(segment_query_stmt).all()
segment_map = {segment.index_node_id: segment for segment in segments}
for chunk_index in sorted_chunk_indices:
segment = segment_map.get(chunk_index)

View File

@@ -432,10 +432,11 @@ class RetrievalService:
# Batch query dataset documents
dataset_documents = {
doc.id: doc
for doc in db.session.query(DatasetDocument)
.where(DatasetDocument.id.in_(document_ids))
.options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id))
.all()
for doc in db.session.scalars(
select(DatasetDocument)
.where(DatasetDocument.id.in_(document_ids))
.options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id))
).all()
}
valid_dataset_documents = {}

View File

@@ -426,11 +426,10 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
else:
idle_tidb_auth_binding = (
db.session.query(TidbAuthBinding)
idle_tidb_auth_binding = db.session.scalar(
select(TidbAuthBinding)
.where(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
.limit(1)
.one_or_none()
)
if idle_tidb_auth_binding:
idle_tidb_auth_binding.active = True

View File

@@ -277,7 +277,7 @@ class Vector:
return self._vector_processor.search_by_vector(query_vector, **kwargs)
def search_by_file(self, file_id: str, **kwargs: Any) -> list[Document]:
upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
upload_file: UploadFile | None = db.session.get(UploadFile, file_id)
if not upload_file:
return []

View File

@@ -4,7 +4,7 @@ from collections.abc import Sequence
from typing import Any
from graphon.model_runtime.entities.model_entities import ModelType
from sqlalchemy import func, select
from sqlalchemy import delete, func, select
from core.model_manager import ModelManager
from core.rag.index_processor.constant.index_type import IndexTechniqueType
@@ -63,10 +63,8 @@ class DatasetDocumentStore:
return output
def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False):
max_position = (
db.session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == self._document_id)
.scalar()
max_position = db.session.scalar(
select(func.max(DocumentSegment.position)).where(DocumentSegment.document_id == self._document_id)
)
if max_position is None:
@@ -155,12 +153,14 @@ class DatasetDocumentStore:
)
if save_child and doc.children:
# delete the existing child chunks
db.session.query(ChildChunk).where(
ChildChunk.tenant_id == self._dataset.tenant_id,
ChildChunk.dataset_id == self._dataset.id,
ChildChunk.document_id == self._document_id,
ChildChunk.segment_id == segment_document.id,
).delete()
db.session.execute(
delete(ChildChunk).where(
ChildChunk.tenant_id == self._dataset.tenant_id,
ChildChunk.dataset_id == self._dataset.id,
ChildChunk.document_id == self._document_id,
ChildChunk.segment_id == segment_document.id,
)
)
# add new child chunks
for position, child in enumerate(doc.children, start=1):
child_segment = ChildChunk(

View File

@@ -6,6 +6,7 @@ from typing import Any, cast
import numpy as np
from graphon.model_runtime.entities.model_entities import ModelPropertyKey
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from configs import dify_config
@@ -31,14 +32,14 @@ class CacheEmbedding(Embeddings):
embedding_queue_indices = []
for i, text in enumerate(texts):
hash = helper.generate_text_hash(text)
embedding = (
db.session.query(Embedding)
.filter_by(
model_name=self._model_instance.model_name,
hash=hash,
provider_name=self._model_instance.provider,
embedding = db.session.scalar(
select(Embedding)
.where(
Embedding.model_name == self._model_instance.model_name,
Embedding.hash == hash,
Embedding.provider_name == self._model_instance.provider,
)
.first()
.limit(1)
)
if embedding:
text_embeddings[i] = embedding.get_embedding()
@@ -112,14 +113,14 @@ class CacheEmbedding(Embeddings):
embedding_queue_indices = []
for i, multimodel_document in enumerate(multimodel_documents):
file_id = multimodel_document["file_id"]
embedding = (
db.session.query(Embedding)
.filter_by(
model_name=self._model_instance.model_name,
hash=file_id,
provider_name=self._model_instance.provider,
embedding = db.session.scalar(
select(Embedding)
.where(
Embedding.model_name == self._model_instance.model_name,
Embedding.hash == file_id,
Embedding.provider_name == self._model_instance.provider,
)
.first()
.limit(1)
)
if embedding:
multimodel_embeddings[i] = embedding.get_embedding()

View File

@@ -4,6 +4,7 @@ import operator
from typing import Any, cast
import httpx
from sqlalchemy import update
from configs import dify_config
from core.rag.extractor.extractor_base import BaseExtractor
@@ -346,9 +347,11 @@ class NotionExtractor(BaseExtractor):
if data_source_info:
data_source_info["last_edited_time"] = last_edited_time
db.session.query(DocumentModel).filter_by(id=document_model.id).update(
{DocumentModel.data_source_info: json.dumps(data_source_info)}
) # type: ignore
db.session.execute(
update(DocumentModel)
.where(DocumentModel.id == document_model.id)
.values(data_source_info=json.dumps(data_source_info))
)
db.session.commit()
def get_notion_last_edited_time(self) -> str:

View File

@@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Any, NotRequired, Optional
from urllib.parse import unquote, urlparse
import httpx
from sqlalchemy import select
from typing_extensions import TypedDict
from configs import dify_config
@@ -200,7 +201,7 @@ class BaseIndexProcessor(ABC):
# Get unique IDs for database query
unique_upload_file_ids = list(set(upload_file_id_list))
upload_files = db.session.query(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids)).all()
upload_files = db.session.scalars(select(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids))).all()
# Create a mapping from ID to UploadFile for quick lookup
upload_file_map = {upload_file.id: upload_file for upload_file in upload_files}
@@ -312,7 +313,7 @@ class BaseIndexProcessor(ABC):
"""
from services.file_service import FileService
tool_file = db.session.query(ToolFile).where(ToolFile.id == tool_file_id).first()
tool_file = db.session.get(ToolFile, tool_file_id)
if not tool_file:
return None
blob = storage.load_once(tool_file.file_key)

View File

@@ -18,6 +18,7 @@ from graphon.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
from sqlalchemy import select
from core.app.file_access import DatabaseFileAccessController
from core.app.llm import deduct_llm_quota
@@ -145,14 +146,12 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
if delete_summaries:
if node_ids:
# Find segments by index_node_id
segments = (
db.session.query(DocumentSegment)
.filter(
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
)
.all()
)
).all()
segment_ids = [segment.id for segment in segments]
if segment_ids:
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
@@ -537,11 +536,9 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
# Get unique IDs for database query
unique_upload_file_ids = list(set(upload_file_id_list))
upload_files = (
db.session.query(UploadFile)
.where(UploadFile.id.in_(unique_upload_file_ids), UploadFile.tenant_id == tenant_id)
.all()
)
upload_files = db.session.scalars(
select(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids), UploadFile.tenant_id == tenant_id)
).all()
# Create File objects from UploadFile records
file_objects = []

View File

@@ -6,6 +6,8 @@ import uuid
from collections.abc import Mapping
from typing import Any
from sqlalchemy import delete, select
from configs import dify_config
from core.db.session_factory import session_factory
from core.entities.knowledge_entities import PreviewDetail
@@ -177,17 +179,16 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
child_node_ids = precomputed_child_node_ids
else:
# Fallback to original query (may fail if segments are already deleted)
child_node_ids = (
db.session.query(ChildChunk.index_node_id)
rows = db.session.execute(
select(ChildChunk.index_node_id)
.join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
.where(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
ChildChunk.dataset_id == dataset.id,
)
.all()
)
child_node_ids = [child_node_id[0] for child_node_id in child_node_ids if child_node_id[0]]
).all()
child_node_ids = [row[0] for row in rows if row[0]]
# Delete from vector index
if child_node_ids:
@@ -195,18 +196,22 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
# Delete from database
if delete_child_chunks and child_node_ids:
db.session.query(ChildChunk).where(
ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids)
).delete(synchronize_session=False)
db.session.execute(
delete(ChildChunk).where(
ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids)
)
)
db.session.commit()
else:
vector.delete()
if delete_child_chunks:
# Use existing compound index: (tenant_id, dataset_id, ...)
db.session.query(ChildChunk).where(
ChildChunk.tenant_id == dataset.tenant_id, ChildChunk.dataset_id == dataset.id
).delete(synchronize_session=False)
db.session.execute(
delete(ChildChunk).where(
ChildChunk.tenant_id == dataset.tenant_id, ChildChunk.dataset_id == dataset.id
)
)
db.session.commit()
def retrieve(

View File

@@ -134,9 +134,7 @@ class RerankModelRunner(BaseRerankRunner):
):
if document.metadata.get("doc_type") == DocType.IMAGE:
# Query file info within db.session context to ensure thread-safe access
upload_file = (
db.session.query(UploadFile).where(UploadFile.id == document.metadata["doc_id"]).first()
)
upload_file = db.session.get(UploadFile, document.metadata["doc_id"])
if upload_file:
blob = storage.load_once(upload_file.key)
document_file_base64 = base64.b64encode(blob).decode()
@@ -169,7 +167,7 @@ class RerankModelRunner(BaseRerankRunner):
return rerank_result, unique_documents
elif query_type == QueryType.IMAGE_QUERY:
# Query file info within db.session context to ensure thread-safe access
upload_file = db.session.query(UploadFile).where(UploadFile.id == query).first()
upload_file = db.session.get(UploadFile, query)
if upload_file:
blob = storage.load_once(upload_file.key)
file_query = base64.b64encode(blob).decode()

View File

@@ -1340,7 +1340,7 @@ class DatasetRetrieval:
metadata_filtering_conditions: MetadataFilteringCondition | None,
inputs: dict,
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]:
document_query = db.session.query(DatasetDocument).where(
document_query = select(DatasetDocument).where(
DatasetDocument.dataset_id.in_(dataset_ids),
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
@@ -1411,7 +1411,7 @@ class DatasetRetrieval:
document_query = document_query.where(and_(*filters))
else:
document_query = document_query.where(or_(*filters))
documents = document_query.all()
documents = db.session.scalars(document_query).all()
# group by dataset_id
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
for document in documents:

View File

@@ -27,7 +27,10 @@ from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.trace.sampling import ParentBasedTraceIdRatio
from opentelemetry.semconv.resource import ResourceAttributes
from opentelemetry.semconv._incubating.attributes.host_attributes import ( # type: ignore[import-untyped]
HOST_NAME,
)
from opentelemetry.semconv.attributes import service_attributes
from opentelemetry.trace import SpanContext, TraceFlags
from opentelemetry.util.types import Attributes, AttributeValue
@@ -114,8 +117,8 @@ class EnterpriseExporter:
resource = Resource(
attributes={
ResourceAttributes.SERVICE_NAME: service_name,
ResourceAttributes.HOST_NAME: socket.gethostname(),
service_attributes.SERVICE_NAME: service_name,
HOST_NAME: socket.gethostname(),
}
)
sampler = ParentBasedTraceIdRatio(sampling_rate)

View File

@@ -157,7 +157,7 @@ def handle(sender: Message, **kwargs):
tenant_id=tenant_id,
provider_name=ModelProviderID(model_config.provider).provider_name,
provider_type=ProviderType.SYSTEM.value,
quota_type=provider_configuration.system_configuration.current_quota_type.value,
quota_type=provider_configuration.system_configuration.current_quota_type,
),
values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
additional_filters=_ProviderUpdateAdditionalFilters(

View File

@@ -6,15 +6,24 @@ def init_app(app: DifyApp):
if dify_config.SENTRY_DSN:
import sentry_sdk
from graphon.model_runtime.errors.invoke import InvokeRateLimitError
from langfuse import parse_error
from sentry_sdk.integrations.celery import CeleryIntegration
from sentry_sdk.integrations.flask import FlaskIntegration
from werkzeug.exceptions import HTTPException
try:
from langfuse._utils import parse_error
_langfuse_error_response = parse_error.defaultErrorResponse
except (ImportError, AttributeError):
_langfuse_error_response = (
"Unexpected error occurred. Please check your request"
" and contact support: https://langfuse.com/support."
)
def before_send(event, hint):
if "exc_info" in hint:
_, exc_value, _ = hint["exc_info"]
if parse_error.defaultErrorResponse in str(exc_value):
if _langfuse_error_response in str(exc_value):
return None
return event
@@ -27,7 +36,7 @@ def init_app(app: DifyApp):
ValueError,
FileNotFoundError,
InvokeRateLimitError,
parse_error.defaultErrorResponse,
_langfuse_error_response,
],
traces_sample_rate=dify_config.SENTRY_TRACES_SAMPLE_RATE,
profiles_sample_rate=dify_config.SENTRY_PROFILES_SAMPLE_RATE,

View File

@@ -1,5 +1,7 @@
import contextlib
import logging
from collections.abc import Callable
from typing import Protocol, cast
import flask
from opentelemetry.instrumentation.celery import CeleryInstrumentor
@@ -21,6 +23,38 @@ from extensions.otel.runtime import is_celery_worker
logger = logging.getLogger(__name__)
class SupportsInstrument(Protocol):
def instrument(self, **kwargs: object) -> None: ...
class SupportsFlaskInstrumentor(Protocol):
def instrument_app(
self, app: DifyApp, response_hook: Callable[[Span, str, list], None] | None = None, **kwargs: object
) -> None: ...
# Some OpenTelemetry instrumentor constructors are typed loosely enough that
# pyrefly infers `NoneType`. Narrow the instances to just the methods we use
# while leaving runtime behavior unchanged.
def _new_celery_instrumentor() -> SupportsInstrument:
return cast(
SupportsInstrument,
CeleryInstrumentor(tracer_provider=get_tracer_provider(), meter_provider=get_meter_provider()),
)
def _new_httpx_instrumentor() -> SupportsInstrument:
return cast(SupportsInstrument, HTTPXClientInstrumentor())
def _new_redis_instrumentor() -> SupportsInstrument:
return cast(SupportsInstrument, RedisInstrumentor())
def _new_sqlalchemy_instrumentor() -> SupportsInstrument:
return cast(SupportsInstrument, SQLAlchemyInstrumentor())
class ExceptionLoggingHandler(logging.Handler):
"""
Handler that records exceptions to the current OpenTelemetry span.
@@ -97,7 +131,7 @@ def init_flask_instrumentor(app: DifyApp) -> None:
from opentelemetry.instrumentation.flask import FlaskInstrumentor
instrumentor = FlaskInstrumentor()
instrumentor = cast(SupportsFlaskInstrumentor, FlaskInstrumentor())
if dify_config.DEBUG:
logger.info("Initializing Flask instrumentor")
instrumentor.instrument_app(app, response_hook=response_hook)
@@ -106,21 +140,21 @@ def init_flask_instrumentor(app: DifyApp) -> None:
def init_sqlalchemy_instrumentor(app: DifyApp) -> None:
with app.app_context():
engines = list(app.extensions["sqlalchemy"].engines.values())
SQLAlchemyInstrumentor().instrument(enable_commenter=True, engines=engines)
_new_sqlalchemy_instrumentor().instrument(enable_commenter=True, engines=engines)
def init_redis_instrumentor() -> None:
RedisInstrumentor().instrument()
_new_redis_instrumentor().instrument()
def init_httpx_instrumentor() -> None:
HTTPXClientInstrumentor().instrument()
_new_httpx_instrumentor().instrument()
def init_instruments(app: DifyApp) -> None:
if not is_celery_worker():
init_flask_instrumentor(app)
CeleryInstrumentor(tracer_provider=get_tracer_provider(), meter_provider=get_meter_provider()).instrument()
_new_celery_instrumentor().instrument()
instrument_exception_logging()
init_sqlalchemy_instrumentor(app)

View File

@@ -6,6 +6,7 @@ from functools import cached_property
from uuid import uuid4
import sqlalchemy as sa
from graphon.model_runtime.entities.model_entities import ModelType
from sqlalchemy import DateTime, String, func, select, text
from sqlalchemy.orm import Mapped, mapped_column
@@ -13,7 +14,7 @@ from libs.uuid_utils import uuidv7
from .base import TypeBase
from .engine import db
from .enums import CredentialSourceType, PaymentStatus
from .enums import CredentialSourceType, PaymentStatus, ProviderQuotaType
from .types import EnumText, LongText, StringUUID
@@ -29,24 +30,6 @@ class ProviderType(StrEnum):
raise ValueError(f"No matching enum found for value '{value}'")
class ProviderQuotaType(StrEnum):
PAID = auto()
"""hosted paid quota"""
FREE = auto()
"""third-party free quota"""
TRIAL = auto()
"""hosted trial quota"""
@staticmethod
def value_of(value: str) -> ProviderQuotaType:
for member in ProviderQuotaType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class Provider(TypeBase):
"""
Provider model representing the API providers and their configurations.
@@ -77,7 +60,9 @@ class Provider(TypeBase):
last_used: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, init=False)
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
quota_type: Mapped[str | None] = mapped_column(String(40), nullable=True, server_default=text("''"), default="")
quota_type: Mapped[ProviderQuotaType | None] = mapped_column(
EnumText(ProviderQuotaType, length=40), nullable=True, server_default=text("''"), default=None
)
quota_limit: Mapped[int | None] = mapped_column(sa.BigInteger, nullable=True, default=None)
quota_used: Mapped[int | None] = mapped_column(sa.BigInteger, nullable=True, default=0)
@@ -147,7 +132,7 @@ class ProviderModel(TypeBase):
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
model_type: Mapped[ModelType] = mapped_column(EnumText(ModelType, length=40), nullable=False)
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"), default=False)
created_at: Mapped[datetime] = mapped_column(
@@ -189,7 +174,7 @@ class TenantDefaultModel(TypeBase):
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
model_type: Mapped[ModelType] = mapped_column(EnumText(ModelType, length=40), nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
@@ -269,7 +254,7 @@ class ProviderModelSetting(TypeBase):
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
model_type: Mapped[ModelType] = mapped_column(EnumText(ModelType, length=40), nullable=False)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"), default=True)
load_balancing_enabled: Mapped[bool] = mapped_column(
sa.Boolean, nullable=False, server_default=text("false"), default=False
@@ -299,7 +284,7 @@ class LoadBalancingModelConfig(TypeBase):
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
model_type: Mapped[ModelType] = mapped_column(EnumText(ModelType, length=40), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
encrypted_config: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
@@ -364,7 +349,7 @@ class ProviderModelCredential(TypeBase):
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
model_type: Mapped[ModelType] = mapped_column(EnumText(ModelType, length=40), nullable=False)
credential_name: Mapped[str] = mapped_column(String(255), nullable=False)
encrypted_config: Mapped[str] = mapped_column(LongText, nullable=False)
created_at: Mapped[datetime] = mapped_column(

View File

@@ -144,8 +144,8 @@ class EnumText(TypeDecorator[_E | None], Generic[_E]):
return dialect.type_descriptor(VARCHAR(self._length))
def process_result_value(self, value: str | None, dialect: Dialect) -> _E | None:
if value is None:
return value
if value is None or value == "":
return None
# Type annotation guarantees value is str at this point
return self._enum_class(value)

View File

@@ -33,7 +33,7 @@ dependencies = [
"httpx[socks]~=0.28.0",
"jieba==0.42.1",
"json-repair>=0.55.1",
"langfuse~=2.51.3",
"langfuse>=3.0.0,<5.0.0",
"langsmith~=0.7.16",
"markdown~=3.10.2",
"mlflow-skinny>=3.0.0",
@@ -41,23 +41,23 @@ dependencies = [
"openpyxl~=3.1.5",
"opik~=1.10.37",
"litellm==1.82.6", # Pinned to avoid madoka dependency issue
"opentelemetry-api==1.28.0",
"opentelemetry-distro==0.49b0",
"opentelemetry-exporter-otlp==1.28.0",
"opentelemetry-exporter-otlp-proto-common==1.28.0",
"opentelemetry-exporter-otlp-proto-grpc==1.28.0",
"opentelemetry-exporter-otlp-proto-http==1.28.0",
"opentelemetry-instrumentation==0.49b0",
"opentelemetry-instrumentation-celery==0.49b0",
"opentelemetry-instrumentation-flask==0.49b0",
"opentelemetry-instrumentation-httpx==0.49b0",
"opentelemetry-instrumentation-redis==0.49b0",
"opentelemetry-instrumentation-sqlalchemy==0.49b0",
"opentelemetry-api==1.40.0",
"opentelemetry-distro==0.61b0",
"opentelemetry-exporter-otlp==1.40.0",
"opentelemetry-exporter-otlp-proto-common==1.40.0",
"opentelemetry-exporter-otlp-proto-grpc==1.40.0",
"opentelemetry-exporter-otlp-proto-http==1.40.0",
"opentelemetry-instrumentation==0.61b0",
"opentelemetry-instrumentation-celery==0.61b0",
"opentelemetry-instrumentation-flask==0.61b0",
"opentelemetry-instrumentation-httpx==0.61b0",
"opentelemetry-instrumentation-redis==0.61b0",
"opentelemetry-instrumentation-sqlalchemy==0.61b0",
"opentelemetry-propagator-b3==1.40.0",
"opentelemetry-proto==1.28.0",
"opentelemetry-sdk==1.28.0",
"opentelemetry-semantic-conventions==0.49b0",
"opentelemetry-util-http==0.49b0",
"opentelemetry-proto==1.40.0",
"opentelemetry-sdk==1.40.0",
"opentelemetry-semantic-conventions==0.61b0",
"opentelemetry-util-http==0.61b0",
"pandas[excel,output-formatting,performance]~=3.0.1",
"psycogreen~=1.0.2",
"psycopg2-binary~=2.9.6",

View File

@@ -7,9 +7,19 @@ from datetime import UTC, datetime, timedelta
from hashlib import sha256
from typing import Any, cast
from pydantic import BaseModel
from pydantic import BaseModel, TypeAdapter
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from typing_extensions import TypedDict
class InvitationData(TypedDict):
    """Shape of a member-invitation record cached as JSON under an invite-token key.

    Parsed from the cache via ``_invitation_adapter`` in
    ``RegisterService.get_invitation_by_token``.
    """

    # Account that was invited.
    account_id: str
    # Email address the invitation was sent to.
    email: str
    # Workspace (tenant) the invitee is being added to.
    workspace_id: str


# Module-level adapter: the TypedDict validation schema is built once at import time.
_invitation_adapter: TypeAdapter[InvitationData] = TypeAdapter(InvitationData)
from werkzeug.exceptions import Unauthorized
from configs import dify_config
@@ -1571,7 +1581,7 @@ class RegisterService:
@classmethod
def get_invitation_by_token(
cls, token: str, workspace_id: str | None = None, email: str | None = None
) -> dict[str, str] | None:
) -> InvitationData | None:
if workspace_id is not None and email is not None:
email_hash = sha256(email.encode()).hexdigest()
cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}"
@@ -1590,7 +1600,7 @@ class RegisterService:
if not data:
return None
invitation: dict = json.loads(data)
invitation = _invitation_adapter.validate_json(data)
return invitation
@classmethod

View File

@@ -2,7 +2,7 @@ import json
import logging
import os
from collections.abc import Sequence
from typing import Literal
from typing import Literal, NotRequired
import httpx
from pydantic import TypeAdapter
@@ -47,6 +47,58 @@ class QuotaReleaseResult(TypedDict):
_quota_reserve_adapter = TypeAdapter(QuotaReserveResult)
_quota_commit_adapter = TypeAdapter(QuotaCommitResult)
_quota_release_adapter = TypeAdapter(QuotaReleaseResult)
class _BillingQuota(TypedDict):
    """Generic usage/limit pair as returned by the billing API."""

    # Current usage count.
    size: int
    # Maximum allowed by the plan.
    limit: int


class _VectorSpaceQuota(TypedDict):
    """Vector-space quota; unlike the other quotas, billing reports size as a float."""

    size: float
    limit: int


class _KnowledgeRateLimit(TypedDict):
    """Knowledge-base rate limit."""

    # NOTE (hj24):
    # 1. Returned for sandbox users but null for other plans; it is defined but never used.
    # 2. Kept for compatibility for now; can be deprecated in future versions.
    size: NotRequired[int]
    # NOTE END
    limit: int


class _BillingSubscription(TypedDict):
    """Subscription summary: plan name, billing interval, and education flag."""

    plan: str
    interval: str
    education: bool


class BillingInfo(TypedDict):
    """Response of /subscription/info.

    NOTE (hj24):
    - Fields not listed here (e.g. trigger_event, api_rate_limit) are stripped by
      TypeAdapter.validate_python()
    - To preserve precision, billing may encode fields such as ints as strings;
      be careful when using TypeAdapter:
      1. validate_python in non-strict mode will coerce them to the expected type
      2. In strict mode, it will raise ValidationError
      3. To preserve compatibility, always keep non-strict mode here and avoid strict mode
    """

    enabled: bool
    subscription: _BillingSubscription
    members: _BillingQuota
    apps: _BillingQuota
    vector_space: _VectorSpaceQuota
    knowledge_rate_limit: _KnowledgeRateLimit
    documents_upload_quota: _BillingQuota
    annotation_quota_limit: _BillingQuota
    docs_processing: str
    can_replace_logo: bool
    model_load_balancing_enabled: bool
    knowledge_pipeline_publish_enabled: bool
    # Not always present in the response (NotRequired).
    next_credit_reset_date: NotRequired[int]


# Module-level adapter: the TypedDict validation schema is built once at import time.
_billing_info_adapter = TypeAdapter(BillingInfo)
class BillingService:
@@ -61,11 +113,11 @@ class BillingService:
_PLAN_CACHE_TTL = 600
@classmethod
def get_info(cls, tenant_id: str):
def get_info(cls, tenant_id: str) -> BillingInfo:
params = {"tenant_id": tenant_id}
billing_info = cls._send_request("GET", "/subscription/info", params=params)
return billing_info
return _billing_info_adapter.validate_python(billing_info)
@classmethod
def get_tenant_feature_plan_usage_info(cls, tenant_id: str):

View File

@@ -312,7 +312,10 @@ class FeatureService:
features.apps.limit = billing_info["apps"]["limit"]
if "vector_space" in billing_info:
features.vector_space.size = billing_info["vector_space"]["size"]
# NOTE (hj24): billing API returns vector_space.size as float (e.g. 0.0)
# but LimitationModel.size is int; truncate here for compatibility
features.vector_space.size = int(billing_info["vector_space"]["size"])
# NOTE END
features.vector_space.limit = billing_info["vector_space"]["limit"]
if "documents_upload_quota" in billing_info:
@@ -333,7 +336,11 @@ class FeatureService:
features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"]
if "knowledge_rate_limit" in billing_info:
# NOTE (hj24):
# 1. knowledge_rate_limit size is nullable, currently it's defined but never used, only limit is used.
# 2. So be careful if later we decide to use [size], we cannot assume it is always present.
features.knowledge_rate_limit = billing_info["knowledge_rate_limit"]["limit"]
# NOTE END
if "knowledge_pipeline_publish_enabled" in billing_info:
features.knowledge_pipeline.publish_enabled = billing_info["knowledge_pipeline_publish_enabled"]

View File

@@ -1,8 +1,8 @@
import json
from collections.abc import Sequence
from typing import Union
from graphon.model_runtime.entities.model_entities import ModelType
from pydantic import TypeAdapter
from sqlalchemy.orm import sessionmaker
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
@@ -17,7 +17,7 @@ from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import Account
from models.enums import FeedbackFromSource, FeedbackRating
from models.model import App, AppMode, AppModelConfig, EndUser, Message, MessageFeedback
from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, EndUser, Message, MessageFeedback
from repositories.execution_extra_content_repository import ExecutionExtraContentRepository
from repositories.sqlalchemy_execution_extra_content_repository import (
SQLAlchemyExecutionExtraContentRepository,
@@ -31,6 +31,8 @@ from services.errors.message import (
)
from services.workflow_service import WorkflowService
_app_model_config_adapter: TypeAdapter[AppModelConfigDict] = TypeAdapter(AppModelConfigDict)
def _create_execution_extra_content_repository() -> ExecutionExtraContentRepository:
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
@@ -286,7 +288,9 @@ class MessageService:
.first()
)
else:
conversation_override_model_configs = json.loads(conversation.override_model_configs)
conversation_override_model_configs = _app_model_config_adapter.validate_json(
conversation.override_model_configs
)
app_model_config = AppModelConfig(
app_id=app_model.id,
)

View File

@@ -1,7 +1,6 @@
import json
import logging
from json import JSONDecodeError
from typing import Union
from typing import Any, Union
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.entities.provider_entities import (
@@ -116,7 +115,7 @@ class ModelLoadBalancingService:
.where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
LoadBalancingModelConfig.model_type == model_type_enum,
LoadBalancingModelConfig.model_name == model,
or_(
LoadBalancingModelConfig.credential_source_type == credential_source_type,
@@ -168,10 +167,10 @@ class ModelLoadBalancingService:
try:
if load_balancing_config.encrypted_config:
credentials: dict[str, object] = json.loads(load_balancing_config.encrypted_config)
credentials: dict[str, Any] = json.loads(load_balancing_config.encrypted_config)
else:
credentials = {}
except JSONDecodeError:
except (json.JSONDecodeError, ValueError):
credentials = {}
# Get provider credential secret variables
@@ -241,7 +240,7 @@ class ModelLoadBalancingService:
.where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
LoadBalancingModelConfig.model_type == model_type_enum,
LoadBalancingModelConfig.model_name == model,
LoadBalancingModelConfig.id == config_id,
)
@@ -256,7 +255,7 @@ class ModelLoadBalancingService:
credentials = json.loads(load_balancing_model_config.encrypted_config)
else:
credentials = {}
except JSONDecodeError:
except (json.JSONDecodeError, ValueError):
credentials = {}
# Get credential form schemas from model credential schema or provider credential schema
@@ -289,7 +288,7 @@ class ModelLoadBalancingService:
inherit_config = LoadBalancingModelConfig(
tenant_id=tenant_id,
provider_name=provider,
model_type=model_type.to_origin_model_type(),
model_type=model_type,
model_name=model,
name="__inherit__",
)
@@ -329,7 +328,7 @@ class ModelLoadBalancingService:
select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
LoadBalancingModelConfig.model_type == model_type_enum,
LoadBalancingModelConfig.model_name == model,
)
).all()
@@ -369,7 +368,7 @@ class ModelLoadBalancingService:
tenant_id=tenant_id,
provider_name=provider_configuration.provider.provider,
model_name=model,
model_type=model_type_enum.to_origin_model_type(),
model_type=model_type_enum,
)
.first()
)
@@ -433,7 +432,7 @@ class ModelLoadBalancingService:
load_balancing_model_config = LoadBalancingModelConfig(
tenant_id=tenant_id,
provider_name=provider_configuration.provider.provider,
model_type=model_type_enum.to_origin_model_type(),
model_type=model_type_enum,
model_name=model,
name=credential_record.credential_name,
encrypted_config=credential_record.encrypted_config,
@@ -461,7 +460,7 @@ class ModelLoadBalancingService:
load_balancing_model_config = LoadBalancingModelConfig(
tenant_id=tenant_id,
provider_name=provider_configuration.provider.provider,
model_type=model_type_enum.to_origin_model_type(),
model_type=model_type_enum,
model_name=model,
name=name,
encrypted_config=json.dumps(credentials),
@@ -516,7 +515,7 @@ class ModelLoadBalancingService:
.where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider,
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
LoadBalancingModelConfig.model_type == model_type_enum,
LoadBalancingModelConfig.model_name == model,
LoadBalancingModelConfig.id == config_id,
)
@@ -575,7 +574,7 @@ class ModelLoadBalancingService:
original_credentials = json.loads(load_balancing_model_config.encrypted_config)
else:
original_credentials = {}
except JSONDecodeError:
except (json.JSONDecodeError, ValueError):
original_credentials = {}
# encrypt credentials

View File

@@ -12,7 +12,9 @@ import click
import sqlalchemy as sa
import tqdm
from flask import Flask, current_app
from pydantic import TypeAdapter
from sqlalchemy.orm import Session
from typing_extensions import TypedDict
from core.agent.entities import AgentToolEntity
from core.helper import marketplace
@@ -33,6 +35,14 @@ logger = logging.getLogger(__name__)
excluded_providers = ["time", "audio", "code", "webscraper"]
class _TenantPluginRecord(TypedDict):
    """One JSONL line exchanged by the plugin-migration commands:
    a tenant and the plugin IDs extracted for it."""

    tenant_id: str
    plugins: list[str]


# Shared adapter: validates each JSONL line against the schema compiled once above.
_tenant_plugin_adapter: TypeAdapter[_TenantPluginRecord] = TypeAdapter(_TenantPluginRecord)
class PluginMigration:
@classmethod
def extract_plugins(cls, filepath: str, workers: int):
@@ -308,9 +318,8 @@ class PluginMigration:
logger.info("Extracting unique plugins from %s", extracted_plugins)
with open(extracted_plugins) as f:
for line in f:
data = json.loads(line)
new_plugin_ids = data.get("plugins", [])
for plugin_id in new_plugin_ids:
data = _tenant_plugin_adapter.validate_json(line)
for plugin_id in data["plugins"]:
if plugin_id not in plugin_ids:
plugin_ids.append(plugin_id)
@@ -381,21 +390,23 @@ class PluginMigration:
Read line by line, and install plugins for each tenant.
"""
for line in f:
data = json.loads(line)
tenant_id = data.get("tenant_id")
plugin_ids = data.get("plugins", [])
current_not_installed = {
"tenant_id": tenant_id,
"plugin_not_exist": [],
}
data = _tenant_plugin_adapter.validate_json(line)
tenant_id = data["tenant_id"]
plugin_ids = data["plugins"]
plugin_not_exist: list[str] = []
# get plugin unique identifier
for plugin_id in plugin_ids:
unique_identifier = plugins.get(plugin_id)
if unique_identifier:
current_not_installed["plugin_not_exist"].append(plugin_id)
plugin_not_exist.append(plugin_id)
if current_not_installed["plugin_not_exist"]:
not_installed.append(current_not_installed)
if plugin_not_exist:
not_installed.append(
{
"tenant_id": tenant_id,
"plugin_not_exist": plugin_not_exist,
}
)
thread_pool.submit(install, tenant_id, plugin_ids)

View File

@@ -6,7 +6,6 @@ back to the database.
"""
import io
import json
import logging
import time
import zipfile
@@ -17,8 +16,23 @@ from datetime import datetime
from typing import Any, cast
import click
from pydantic import TypeAdapter
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.engine import CursorResult
from typing_extensions import TypedDict
class _TableInfo(TypedDict, total=False):
    # Row count recorded for one archived table; optional (total=False).
    row_count: int


class ArchiveManifest(TypedDict, total=False):
    """Schema of manifest.json inside a workflow-run archive bundle.

    All keys are optional (total=False), so consumers must handle absence —
    e.g. ``_get_schema_version`` falls back to "1.0" when schema_version is missing.
    """

    tables: dict[str, _TableInfo]
    schema_version: str


# Module-level adapter so manifest JSON is validated against a pre-built schema.
_manifest_adapter: TypeAdapter[ArchiveManifest] = TypeAdapter(ArchiveManifest)
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
from extensions.ext_database import db
@@ -239,12 +253,12 @@ class WorkflowRunRestore:
return self.workflow_run_repo
@staticmethod
def _load_manifest_from_zip(archive: zipfile.ZipFile) -> dict[str, Any]:
def _load_manifest_from_zip(archive: zipfile.ZipFile) -> ArchiveManifest:
try:
data = archive.read("manifest.json")
except KeyError as e:
raise ValueError("manifest.json missing from archive bundle") from e
return json.loads(data.decode("utf-8"))
return _manifest_adapter.validate_json(data)
def _restore_table_records(
self,
@@ -332,7 +346,7 @@ class WorkflowRunRestore:
return result
def _get_schema_version(self, manifest: dict[str, Any]) -> str:
def _get_schema_version(self, manifest: ArchiveManifest) -> str:
schema_version = manifest.get("schema_version")
if not schema_version:
logger.warning("Manifest missing schema_version; defaulting to 1.0")

View File

@@ -3,7 +3,7 @@ import logging
from collections.abc import Mapping
from typing import Any, Union
from pydantic import ValidationError
from pydantic import TypeAdapter, ValidationError
from yarl import URL
from configs import dify_config
@@ -31,6 +31,8 @@ from services.plugin.plugin_service import PluginService
logger = logging.getLogger(__name__)
_mcp_tools_adapter: TypeAdapter[list[MCPTool]] = TypeAdapter(list[MCPTool])
class ToolTransformService:
_MCP_SCHEMA_TYPE_RESOLUTION_MAX_DEPTH = 10
@@ -53,7 +55,7 @@ class ToolTransformService:
if isinstance(icon, str):
return json.loads(icon)
return icon
except Exception:
except (json.JSONDecodeError, ValueError):
return {"background": "#252525", "content": "\ud83d\ude01"}
elif provider_type == ToolProviderType.MCP:
return icon
@@ -247,8 +249,8 @@ class ToolTransformService:
response = provider_entity.to_api_response(user_name=user_name, include_sensitive=include_sensitive)
try:
mcp_tools = [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
except (ValidationError, json.JSONDecodeError):
mcp_tools = _mcp_tools_adapter.validate_json(db_provider.tools)
except (ValidationError, ValueError):
mcp_tools = []
# Add additional fields specific to the transform
response["id"] = db_provider.server_identifier if not for_list else db_provider.id

View File

@@ -1,4 +1,3 @@
import json
import logging
import uuid
from collections.abc import Mapping
@@ -7,6 +6,7 @@ from datetime import datetime
from typing import Any
from flask import Request, Response
from pydantic import TypeAdapter
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.entities.request import TriggerDispatchResponse
@@ -29,6 +29,8 @@ from services.trigger.trigger_provider_service import TriggerProviderService
logger = logging.getLogger(__name__)
_request_logs_adapter: TypeAdapter[list[RequestLog]] = TypeAdapter(list[RequestLog])
class TriggerSubscriptionBuilderService:
"""Service for managing trigger providers and credentials"""
@@ -398,7 +400,7 @@ class TriggerSubscriptionBuilderService:
cache_key = cls.encode_cache_key(endpoint_id)
subscription_cache = redis_client.get(cache_key)
if subscription_cache:
return SubscriptionBuilder.model_validate(json.loads(subscription_cache))
return SubscriptionBuilder.model_validate_json(subscription_cache)
return None
@@ -423,12 +425,16 @@ class TriggerSubscriptionBuilderService:
)
key = f"trigger:subscription:builder:logs:{endpoint_id}"
logs = json.loads(redis_client.get(key) or "[]")
logs.append(log.model_dump(mode="json"))
logs = _request_logs_adapter.validate_json(redis_client.get(key) or b"[]")
logs.append(log)
# Keep last N logs
logs = logs[-cls.__VALIDATION_REQUEST_CACHE_COUNT__ :]
redis_client.setex(key, cls.__VALIDATION_REQUEST_CACHE_EXPIRE_SECONDS__, json.dumps(logs, default=str))
redis_client.setex(
key,
cls.__VALIDATION_REQUEST_CACHE_EXPIRE_SECONDS__,
_request_logs_adapter.dump_json(logs),
)
@classmethod
def list_logs(cls, endpoint_id: str) -> list[RequestLog]:
@@ -437,7 +443,7 @@ class TriggerSubscriptionBuilderService:
logs_json = redis_client.get(key)
if not logs_json:
return []
return [RequestLog.model_validate(log) for log in json.loads(logs_json)]
return _request_logs_adapter.validate_json(logs_json)
@classmethod
def process_builder_validation_endpoint(cls, endpoint_id: str, request: Request) -> Response | None:

View File

@@ -1118,7 +1118,7 @@ class WorkflowService:
continue
try:
payload = json.loads(recipient.recipient_payload)
except Exception:
except (json.JSONDecodeError, ValueError):
logger.exception("Failed to parse human input recipient payload for delivery test.")
continue
email = payload.get("email")

View File

@@ -0,0 +1,182 @@
import pytest
from sqlalchemy import delete
from core.db.session_factory import session_factory
from models import Tenant
from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
from services.plugin.plugin_permission_service import PluginPermissionService
@pytest.fixture
def tenant(flask_req_ctx):
    """Create a throwaway tenant; remove its plugin rows and itself afterwards."""
    with session_factory.create_session() as session:
        created = Tenant(name="plugin_it_tenant")
        session.add(created)
        session.commit()
        created_id = created.id

    yield created_id

    # Teardown in dependency order: child plugin rows first, the tenant last.
    with session_factory.create_session() as session:
        cleanup_statements = (
            delete(TenantPluginPermission).where(TenantPluginPermission.tenant_id == created_id),
            delete(TenantPluginAutoUpgradeStrategy).where(
                TenantPluginAutoUpgradeStrategy.tenant_id == created_id
            ),
            delete(Tenant).where(Tenant.id == created_id),
        )
        for statement in cleanup_statements:
            session.execute(statement)
        session.commit()
class TestPluginPermissionLifecycle:
    """Integration checks for PluginPermissionService create/read/update flow."""

    def test_get_returns_none_for_new_tenant(self, tenant):
        # A tenant without a permission row yields None, not a default record.
        assert PluginPermissionService.get_permission(tenant) is None

    def test_change_creates_row(self, tenant):
        changed = PluginPermissionService.change_permission(
            tenant,
            TenantPluginPermission.InstallPermission.ADMINS,
            TenantPluginPermission.DebugPermission.EVERYONE,
        )
        assert changed is True

        stored = PluginPermissionService.get_permission(tenant)
        assert stored is not None
        assert stored.install_permission == TenantPluginPermission.InstallPermission.ADMINS
        assert stored.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE

    def test_change_updates_existing_row(self, tenant):
        # First call establishes the row; the second must update it in place.
        PluginPermissionService.change_permission(
            tenant,
            TenantPluginPermission.InstallPermission.ADMINS,
            TenantPluginPermission.DebugPermission.NOBODY,
        )
        PluginPermissionService.change_permission(
            tenant,
            TenantPluginPermission.InstallPermission.EVERYONE,
            TenantPluginPermission.DebugPermission.ADMINS,
        )

        stored = PluginPermissionService.get_permission(tenant)
        assert stored is not None
        assert stored.install_permission == TenantPluginPermission.InstallPermission.EVERYONE
        assert stored.debug_permission == TenantPluginPermission.DebugPermission.ADMINS

        # Update, not insert: exactly one permission row must exist for the tenant.
        with session_factory.create_session() as session:
            row_count = (
                session.query(TenantPluginPermission)
                .where(TenantPluginPermission.tenant_id == tenant)
                .count()
            )
            assert row_count == 1
class TestPluginAutoUpgradeLifecycle:
    """Integration tests for PluginAutoUpgradeService: strategy create/update
    plus the mode transitions performed by exclude_plugin()."""

    def test_get_returns_none_for_new_tenant(self, tenant):
        # No strategy row exists until one is explicitly created.
        assert PluginAutoUpgradeService.get_strategy(tenant) is None

    def test_change_creates_row(self, tenant):
        # Creating a strategy for a fresh tenant returns True and persists the values.
        result = PluginAutoUpgradeService.change_strategy(
            tenant,
            strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
            upgrade_time_of_day=3,
            upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
            exclude_plugins=[],
            include_plugins=[],
        )
        assert result is True
        strategy = PluginAutoUpgradeService.get_strategy(tenant)
        assert strategy is not None
        assert strategy.strategy_setting == TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST
        assert strategy.upgrade_time_of_day == 3

    def test_change_updates_existing_row(self, tenant):
        # Second change_strategy call must overwrite the first row's values.
        PluginAutoUpgradeService.change_strategy(
            tenant,
            strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
            upgrade_time_of_day=0,
            upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
            exclude_plugins=[],
            include_plugins=[],
        )
        PluginAutoUpgradeService.change_strategy(
            tenant,
            strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
            upgrade_time_of_day=12,
            upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL,
            exclude_plugins=[],
            include_plugins=["plugin-a"],
        )
        strategy = PluginAutoUpgradeService.get_strategy(tenant)
        assert strategy is not None
        assert strategy.strategy_setting == TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST
        assert strategy.upgrade_time_of_day == 12
        assert strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL
        assert strategy.include_plugins == ["plugin-a"]

    def test_exclude_plugin_creates_strategy_when_none_exists(self, tenant):
        # Excluding with no prior strategy creates one directly in EXCLUDE mode.
        PluginAutoUpgradeService.exclude_plugin(tenant, "my-plugin")
        strategy = PluginAutoUpgradeService.get_strategy(tenant)
        assert strategy is not None
        assert strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE
        assert "my-plugin" in strategy.exclude_plugins

    def test_exclude_plugin_appends_in_exclude_mode(self, tenant):
        # In EXCLUDE mode, excluding adds to the list without dropping prior entries.
        PluginAutoUpgradeService.change_strategy(
            tenant,
            strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
            upgrade_time_of_day=0,
            upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
            exclude_plugins=["existing"],
            include_plugins=[],
        )
        PluginAutoUpgradeService.exclude_plugin(tenant, "new-plugin")
        strategy = PluginAutoUpgradeService.get_strategy(tenant)
        assert strategy is not None
        assert "existing" in strategy.exclude_plugins
        assert "new-plugin" in strategy.exclude_plugins

    def test_exclude_plugin_dedup_in_exclude_mode(self, tenant):
        # Excluding an already-excluded plugin must not create a duplicate entry.
        PluginAutoUpgradeService.change_strategy(
            tenant,
            strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
            upgrade_time_of_day=0,
            upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
            exclude_plugins=["same-plugin"],
            include_plugins=[],
        )
        PluginAutoUpgradeService.exclude_plugin(tenant, "same-plugin")
        strategy = PluginAutoUpgradeService.get_strategy(tenant)
        assert strategy is not None
        assert strategy.exclude_plugins.count("same-plugin") == 1

    def test_exclude_from_partial_mode_removes_from_include(self, tenant):
        # In PARTIAL mode, excluding removes the plugin from the include list
        # and leaves the other included plugins untouched.
        PluginAutoUpgradeService.change_strategy(
            tenant,
            strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
            upgrade_time_of_day=0,
            upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL,
            exclude_plugins=[],
            include_plugins=["p1", "p2"],
        )
        PluginAutoUpgradeService.exclude_plugin(tenant, "p1")
        strategy = PluginAutoUpgradeService.get_strategy(tenant)
        assert strategy is not None
        assert "p1" not in strategy.include_plugins
        assert "p2" in strategy.include_plugins

    def test_exclude_from_all_mode_switches_to_exclude(self, tenant):
        # In ALL mode, the first exclusion flips the strategy into EXCLUDE mode.
        PluginAutoUpgradeService.change_strategy(
            tenant,
            strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
            upgrade_time_of_day=0,
            upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
            exclude_plugins=[],
            include_plugins=[],
        )
        PluginAutoUpgradeService.exclude_plugin(tenant, "excluded-plugin")
        strategy = PluginAutoUpgradeService.get_strategy(tenant)
        assert strategy is not None
        assert strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE
        assert "excluded-plugin" in strategy.exclude_plugins

View File

@@ -0,0 +1,348 @@
import datetime
import math
import uuid
import pytest
from sqlalchemy import delete
from core.db.session_factory import session_factory
from models import Tenant
from models.enums import FeedbackFromSource, FeedbackRating
from models.model import (
App,
Conversation,
Message,
MessageAnnotation,
MessageFeedback,
)
from services.retention.conversation.messages_clean_policy import BillingDisabledPolicy
from services.retention.conversation.messages_clean_service import MessagesCleanService
# Fixed "now" so seeded timestamps and the cleaning window are fully deterministic.
_NOW = datetime.datetime(2026, 1, 15, 12, 0, 0, tzinfo=datetime.UTC)
# Message ages relative to _NOW used by the fixtures below.
_OLD = _NOW - datetime.timedelta(days=60)
_VERY_OLD = _NOW - datetime.timedelta(days=90)
_RECENT = _NOW - datetime.timedelta(days=5)
# Cleaning window padded by an hour on each side so all three ages fall inside it.
_WINDOW_START = _VERY_OLD - datetime.timedelta(hours=1)
_WINDOW_END = _RECENT + datetime.timedelta(hours=1)
# Batch size large enough to process all seeded rows in a single pass.
_DEFAULT_BATCH_SIZE = 100
# Pagination scenario: 25 rows with a batch size of 8 forces multiple batches.
_PAGINATION_MESSAGE_COUNT = 25
_PAGINATION_BATCH_SIZE = 8
@pytest.fixture
def tenant_and_app(flask_req_ctx):
    """Creates a Tenant, App and Conversation for the test and cleans up after.

    Yields a dict with "tenant_id", "app_id" and "conversation_id".
    """
    with session_factory.create_session() as session:
        tenant = Tenant(name="retention_it_tenant")
        session.add(tenant)
        session.flush()  # why: populate tenant.id before it is referenced below
        app = App(
            tenant_id=tenant.id,
            name="Retention IT App",
            mode="chat",
            enable_site=True,
            enable_api=True,
        )
        session.add(app)
        session.flush()  # why: populate app.id for the conversation row
        conv = Conversation(
            app_id=app.id,
            mode="chat",
            name="test_conv",
            status="normal",
            from_source="console",
            _inputs={},
        )
        session.add(conv)
        session.commit()
        # Capture plain IDs before the session closes and the ORM objects detach.
        tenant_id = tenant.id
        app_id = app.id
        conv_id = conv.id

    yield {"tenant_id": tenant_id, "app_id": app_id, "conversation_id": conv_id}

    # Teardown in reverse creation order: conversation, app, then tenant.
    with session_factory.create_session() as session:
        session.execute(delete(Conversation).where(Conversation.id == conv_id))
        session.execute(delete(App).where(App.id == app_id))
        session.execute(delete(Tenant).where(Tenant.id == tenant_id))
        session.commit()
def _make_message(app_id: str, conversation_id: str, created_at: datetime.datetime) -> Message:
    """Build an unsaved minimal chat Message row pinned to ``created_at``."""
    # Fixed placeholder content: the tests only care about timestamps and identity.
    fields: dict = {
        "app_id": app_id,
        "conversation_id": conversation_id,
        "query": "test",
        "message": [{"text": "hello"}],
        "answer": "world",
        "message_tokens": 1,
        "message_unit_price": 0,
        "answer_tokens": 1,
        "answer_unit_price": 0,
        "from_source": "console",
        "currency": "USD",
        "_inputs": {},
        "created_at": created_at,
    }
    return Message(**fields)
class TestMessagesCleanServiceIntegration:
@pytest.fixture
def seed_messages(self, tenant_and_app):
    """Seeds one message at each of _VERY_OLD, _OLD, and _RECENT.

    Yields a semantic mapping keyed by age label.
    """
    data = tenant_and_app
    app_id = data["app_id"]
    conv_id = data["conversation_id"]
    # Ordered tuple of (label, timestamp) for deterministic seeding
    timestamps = [
        ("very_old", _VERY_OLD),
        ("old", _OLD),
        ("recent", _RECENT),
    ]
    msg_ids: dict[str, str] = {}
    with session_factory.create_session() as session:
        for label, ts in timestamps:
            msg = _make_message(app_id, conv_id, ts)
            session.add(msg)
            session.flush()  # why: populate msg.id before recording it
            msg_ids[label] = msg.id
        session.commit()

    yield {"msg_ids": msg_ids, **data}

    # Remove whatever the test under run did not delete itself.
    with session_factory.create_session() as session:
        session.execute(
            delete(Message)
            .where(Message.id.in_(list(msg_ids.values())))
            .execution_options(synchronize_session=False)
        )
        session.commit()
@pytest.fixture
def paginated_seed_messages(self, tenant_and_app):
    """Seed _PAGINATION_MESSAGE_COUNT messages at 1-second intervals from _OLD."""
    ctx = tenant_and_app
    seeded_ids: list[str] = []
    with session_factory.create_session() as session:
        for offset in range(_PAGINATION_MESSAGE_COUNT):
            row = _make_message(
                ctx["app_id"],
                ctx["conversation_id"],
                _OLD + datetime.timedelta(seconds=offset),
            )
            session.add(row)
            session.flush()  # why: populate row.id before recording it
            seeded_ids.append(row.id)
        session.commit()

    yield {"msg_ids": seeded_ids, **ctx}

    # Remove any rows the exercised service left behind.
    with session_factory.create_session() as session:
        session.execute(
            delete(Message).where(Message.id.in_(seeded_ids)).execution_options(synchronize_session=False)
        )
        session.commit()
@pytest.fixture
def cascade_test_data(self, tenant_and_app):
    """Seeds one Message with an associated Feedback and Annotation.

    Yields the IDs of all three rows so tests can verify cascade behavior.
    """
    data = tenant_and_app
    app_id = data["app_id"]
    conv_id = data["conversation_id"]
    with session_factory.create_session() as session:
        msg = _make_message(app_id, conv_id, _OLD)
        session.add(msg)
        session.flush()  # why: populate msg.id for the dependent rows below
        feedback = MessageFeedback(
            app_id=app_id,
            conversation_id=conv_id,
            message_id=msg.id,
            rating=FeedbackRating.LIKE,
            from_source=FeedbackFromSource.USER,
        )
        annotation = MessageAnnotation(
            app_id=app_id,
            conversation_id=conv_id,
            message_id=msg.id,
            question="q",
            content="a",
            # Random account: the annotation only needs a syntactically valid owner.
            account_id=str(uuid.uuid4()),
        )
        session.add_all([feedback, annotation])
        session.commit()
        msg_id = msg.id
        fb_id = feedback.id
        ann_id = annotation.id

    yield {"msg_id": msg_id, "fb_id": fb_id, "ann_id": ann_id, **data}

    # Teardown children first, then the message itself.
    with session_factory.create_session() as session:
        session.execute(delete(MessageAnnotation).where(MessageAnnotation.id == ann_id))
        session.execute(delete(MessageFeedback).where(MessageFeedback.id == fb_id))
        session.execute(delete(Message).where(Message.id == msg_id))
        session.commit()
def test_dry_run_does_not_delete(self, seed_messages):
"""Dry-run must count eligible rows without deleting any of them."""
data = seed_messages
msg_ids = data["msg_ids"]
all_ids = list(msg_ids.values())
svc = MessagesCleanService.from_time_range(
policy=BillingDisabledPolicy(),
start_from=_WINDOW_START,
end_before=_WINDOW_END,
batch_size=_DEFAULT_BATCH_SIZE,
dry_run=True,
)
stats = svc.run()
assert stats["filtered_messages"] == len(all_ids)
assert stats["total_deleted"] == 0
with session_factory.create_session() as session:
remaining = session.query(Message).where(Message.id.in_(all_ids)).count()
assert remaining == len(all_ids)
def test_billing_disabled_deletes_all_in_range(self, seed_messages):
"""All 3 seeded messages fall within the window and must be deleted."""
data = seed_messages
msg_ids = data["msg_ids"]
all_ids = list(msg_ids.values())
svc = MessagesCleanService.from_time_range(
policy=BillingDisabledPolicy(),
start_from=_WINDOW_START,
end_before=_WINDOW_END,
batch_size=_DEFAULT_BATCH_SIZE,
dry_run=False,
)
stats = svc.run()
assert stats["total_deleted"] == len(all_ids)
with session_factory.create_session() as session:
remaining = session.query(Message).where(Message.id.in_(all_ids)).count()
assert remaining == 0
def test_start_from_filters_correctly(self, seed_messages):
"""Only the message at _OLD falls within the narrow ±1 h window."""
data = seed_messages
msg_ids = data["msg_ids"]
start = _OLD - datetime.timedelta(hours=1)
end = _OLD + datetime.timedelta(hours=1)
svc = MessagesCleanService.from_time_range(
policy=BillingDisabledPolicy(),
start_from=start,
end_before=end,
batch_size=_DEFAULT_BATCH_SIZE,
)
stats = svc.run()
assert stats["total_deleted"] == 1
with session_factory.create_session() as session:
all_ids = list(msg_ids.values())
remaining_ids = {r[0] for r in session.query(Message.id).where(Message.id.in_(all_ids)).all()}
assert msg_ids["old"] not in remaining_ids
assert msg_ids["very_old"] in remaining_ids
assert msg_ids["recent"] in remaining_ids
def test_cursor_pagination_across_batches(self, paginated_seed_messages):
"""Messages must be deleted across multiple batches."""
data = paginated_seed_messages
msg_ids = data["msg_ids"]
# _OLD is the earliest; the last one is _OLD + (_PAGINATION_MESSAGE_COUNT - 1) s.
pagination_window_start = _OLD - datetime.timedelta(seconds=1)
pagination_window_end = _OLD + datetime.timedelta(seconds=_PAGINATION_MESSAGE_COUNT)
svc = MessagesCleanService.from_time_range(
policy=BillingDisabledPolicy(),
start_from=pagination_window_start,
end_before=pagination_window_end,
batch_size=_PAGINATION_BATCH_SIZE,
)
stats = svc.run()
assert stats["total_deleted"] == _PAGINATION_MESSAGE_COUNT
expected_batches = math.ceil(_PAGINATION_MESSAGE_COUNT / _PAGINATION_BATCH_SIZE)
assert stats["batches"] >= expected_batches
with session_factory.create_session() as session:
remaining = session.query(Message).where(Message.id.in_(msg_ids)).count()
assert remaining == 0
def test_no_messages_in_range_returns_empty_stats(self, seed_messages):
"""A window entirely in the future must yield zero matches."""
far_future = _NOW + datetime.timedelta(days=365)
even_further = far_future + datetime.timedelta(days=1)
svc = MessagesCleanService.from_time_range(
policy=BillingDisabledPolicy(),
start_from=far_future,
end_before=even_further,
batch_size=_DEFAULT_BATCH_SIZE,
)
stats = svc.run()
assert stats["total_messages"] == 0
assert stats["total_deleted"] == 0
def test_relation_cascade_deletes(self, cascade_test_data):
"""Deleting a Message must cascade to its Feedback and Annotation rows."""
data = cascade_test_data
msg_id = data["msg_id"]
fb_id = data["fb_id"]
ann_id = data["ann_id"]
svc = MessagesCleanService.from_time_range(
policy=BillingDisabledPolicy(),
start_from=_OLD - datetime.timedelta(hours=1),
end_before=_OLD + datetime.timedelta(hours=1),
batch_size=_DEFAULT_BATCH_SIZE,
)
stats = svc.run()
assert stats["total_deleted"] == 1
with session_factory.create_session() as session:
assert session.query(Message).where(Message.id == msg_id).count() == 0
assert session.query(MessageFeedback).where(MessageFeedback.id == fb_id).count() == 0
assert session.query(MessageAnnotation).where(MessageAnnotation.id == ann_id).count() == 0
def test_factory_from_time_range_validation(self):
with pytest.raises(ValueError, match="start_from"):
MessagesCleanService.from_time_range(
policy=BillingDisabledPolicy(),
start_from=_NOW,
end_before=_OLD,
)
def test_factory_from_days_validation(self):
with pytest.raises(ValueError, match="days"):
MessagesCleanService.from_days(
policy=BillingDisabledPolicy(),
days=-1,
)
def test_factory_batch_size_validation(self):
with pytest.raises(ValueError, match="batch_size"):
MessagesCleanService.from_time_range(
policy=BillingDisabledPolicy(),
start_from=_OLD,
end_before=_NOW,
batch_size=0,
)

View File

@@ -0,0 +1,177 @@
import datetime
import io
import json
import uuid
import zipfile
from unittest.mock import MagicMock, patch
import pytest
from services.retention.workflow_run.archive_paid_plan_workflow_run import (
ArchiveSummary,
WorkflowRunArchiver,
)
from services.retention.workflow_run.constants import ARCHIVE_SCHEMA_VERSION
class TestWorkflowRunArchiverInit:
    """Constructor validation and default-value behavior of WorkflowRunArchiver."""

    def test_start_from_without_end_before_raises(self):
        lone_start = datetime.datetime(2025, 1, 1)
        with pytest.raises(ValueError, match="start_from and end_before must be provided together"):
            WorkflowRunArchiver(start_from=lone_start)

    def test_end_before_without_start_from_raises(self):
        lone_end = datetime.datetime(2025, 1, 1)
        with pytest.raises(ValueError, match="start_from and end_before must be provided together"):
            WorkflowRunArchiver(end_before=lone_end)

    def test_start_equals_end_raises(self):
        # An empty window (start == end) is invalid.
        moment = datetime.datetime(2025, 1, 1)
        with pytest.raises(ValueError, match="start_from must be earlier than end_before"):
            WorkflowRunArchiver(start_from=moment, end_before=moment)

    def test_start_after_end_raises(self):
        later = datetime.datetime(2025, 6, 1)
        earlier = datetime.datetime(2025, 1, 1)
        with pytest.raises(ValueError, match="start_from must be earlier than end_before"):
            WorkflowRunArchiver(start_from=later, end_before=earlier)

    def test_workers_zero_raises(self):
        with pytest.raises(ValueError, match="workers must be at least 1"):
            WorkflowRunArchiver(workers=0)

    def test_valid_init_defaults(self):
        # Only days/batch_size supplied: every optional flag keeps its default.
        archiver = WorkflowRunArchiver(days=30, batch_size=50)
        assert archiver.days == 30
        assert archiver.batch_size == 50
        assert archiver.dry_run is False
        assert archiver.delete_after_archive is False
        assert archiver.start_from is None

    def test_valid_init_with_time_range(self):
        window_start = datetime.datetime(2025, 1, 1)
        window_end = datetime.datetime(2025, 6, 1)
        archiver = WorkflowRunArchiver(start_from=window_start, end_before=window_end, workers=2)
        assert archiver.start_from is not None
        assert archiver.end_before is not None
        assert archiver.workers == 2
class TestBuildArchiveBundle:
    """Checks that _build_archive_bundle emits a zip with manifest plus all tables."""

    def test_bundle_contains_manifest_and_all_tables(self):
        archiver = WorkflowRunArchiver(days=90)
        manifest_blob = json.dumps({"schema_version": ARCHIVE_SCHEMA_VERSION}).encode("utf-8")
        # One (empty) payload per archived table.
        payloads = {table: b"" for table in archiver.ARCHIVED_TABLES}
        bundle = archiver._build_archive_bundle(manifest_blob, payloads)
        with zipfile.ZipFile(io.BytesIO(bundle), "r") as zf:
            entries = set(zf.namelist())
            assert "manifest.json" in entries
            for table in archiver.ARCHIVED_TABLES:
                assert f"{table}.jsonl" in entries, f"Missing {table}.jsonl in bundle"

    def test_bundle_missing_table_payload_raises(self):
        archiver = WorkflowRunArchiver(days=90)
        # Supply a payload for the first table only; the remainder are absent.
        partial_payloads = {archiver.ARCHIVED_TABLES[0]: b"data"}
        with pytest.raises(ValueError, match="Missing archive payload"):
            archiver._build_archive_bundle(b"{}", partial_payloads)
class TestGenerateManifest:
    """Shape checks for the per-run archive manifest."""

    def test_manifest_structure(self):
        archiver = WorkflowRunArchiver(days=90)
        from services.retention.workflow_run.archive_paid_plan_workflow_run import TableStats

        # Mock workflow run with random ids and a fixed creation time.
        run = MagicMock()
        run.id = str(uuid.uuid4())
        run.tenant_id = str(uuid.uuid4())
        run.app_id = str(uuid.uuid4())
        run.workflow_id = str(uuid.uuid4())
        run.created_at = datetime.datetime(2025, 3, 15, 10, 0, 0)
        table_stats = []
        table_stats.append(TableStats(table_name="workflow_runs", row_count=1, checksum="abc123", size_bytes=512))
        table_stats.append(TableStats(table_name="workflow_app_logs", row_count=2, checksum="def456", size_bytes=1024))
        manifest = archiver._generate_manifest(run, table_stats)
        assert manifest["schema_version"] == ARCHIVE_SCHEMA_VERSION
        assert manifest["workflow_run_id"] == run.id
        assert manifest["tenant_id"] == run.tenant_id
        assert manifest["app_id"] == run.app_id
        assert "tables" in manifest
        # Per-table stats are keyed by table name in the manifest.
        assert manifest["tables"]["workflow_runs"]["row_count"] == 1
        assert manifest["tables"]["workflow_runs"]["checksum"] == "abc123"
        assert manifest["tables"]["workflow_app_logs"]["row_count"] == 2
class TestFilterPaidTenants:
    """_filter_paid_tenants behavior with billing disabled, enabled, and failing."""

    def test_all_tenants_paid_when_billing_disabled(self):
        archiver = WorkflowRunArchiver(days=90)
        everyone = {"t1", "t2", "t3"}
        with patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") as cfg:
            cfg.BILLING_ENABLED = False
            # Without billing, every tenant is treated as paid.
            assert archiver._filter_paid_tenants(everyone) == everyone

    def test_empty_tenants_returns_empty(self):
        archiver = WorkflowRunArchiver(days=90)
        with patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") as cfg:
            cfg.BILLING_ENABLED = True
            assert archiver._filter_paid_tenants(set()) == set()

    def test_only_paid_plans_returned(self):
        archiver = WorkflowRunArchiver(days=90)
        plan_lookup = {
            "t1": {"plan": "professional"},
            "t2": {"plan": "sandbox"},
            "t3": {"plan": "team"},
        }
        with (
            patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") as cfg,
            patch("services.retention.workflow_run.archive_paid_plan_workflow_run.BillingService") as billing,
        ):
            cfg.BILLING_ENABLED = True
            billing.get_plan_bulk_with_cache.return_value = plan_lookup
            paid = archiver._filter_paid_tenants({"t1", "t2", "t3"})
            # sandbox is the free tier and must be filtered out.
            assert "t1" in paid
            assert "t3" in paid
            assert "t2" not in paid

    def test_billing_api_failure_returns_empty(self):
        archiver = WorkflowRunArchiver(days=90)
        with (
            patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") as cfg,
            patch("services.retention.workflow_run.archive_paid_plan_workflow_run.BillingService") as billing,
        ):
            cfg.BILLING_ENABLED = True
            billing.get_plan_bulk_with_cache.side_effect = RuntimeError("API down")
            # A billing outage must fail closed: treat no tenant as paid.
            assert archiver._filter_paid_tenants({"t1"}) == set()
class TestDryRunArchive:
    """Dry-run mode must never touch archive storage."""

    @patch("services.retention.workflow_run.archive_paid_plan_workflow_run.get_archive_storage")
    def test_dry_run_does_not_call_storage(self, mock_get_storage, flask_req_ctx):
        archiver = WorkflowRunArchiver(days=90, dry_run=True)
        # No runs to process: run() should complete without touching storage.
        with patch.object(archiver, "_get_runs_batch", return_value=[]):
            outcome = archiver.run()
        mock_get_storage.assert_not_called()
        assert isinstance(outcome, ArchiveSummary)
        assert outcome.runs_failed == 0

View File

@@ -1,7 +1,4 @@
"""
Additional tests to improve coverage for low-coverage modules in controllers/console/app.
Target: increase coverage for files with <75% coverage.
"""
"""Testcontainers integration tests for controllers/console/app endpoints."""
from __future__ import annotations
@@ -70,26 +67,12 @@ def _unwrap(func):
return func
class _ConnContext:
def __init__(self, rows):
self._rows = rows
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def execute(self, _query, _args):
return self._rows
# ========== Completion Tests ==========
class TestCompletionEndpoints:
"""Tests for completion API endpoints."""
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_completion_create_payload(self):
"""Test completion creation payload."""
payload = CompletionMessagePayload(inputs={"prompt": "test"}, model_config={})
assert payload.inputs == {"prompt": "test"}
@@ -209,7 +192,9 @@ class TestCompletionEndpoints:
class TestAppEndpoints:
"""Tests for app endpoints."""
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_app_put_should_preserve_icon_type_when_payload_omits_it(self, app, monkeypatch):
api = app_module.AppApi()
@@ -250,12 +235,12 @@ class TestAppEndpoints:
)
# ========== OpsTrace Tests ==========
class TestOpsTraceEndpoints:
"""Tests for ops_trace endpoint."""
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_ops_trace_query_basic(self):
"""Test ops_trace query."""
query = TraceProviderQuery(tracing_provider="langfuse")
assert query.tracing_provider == "langfuse"
@@ -310,12 +295,12 @@ class TestOpsTraceEndpoints:
method(app_id="app-1")
# ========== Site Tests ==========
class TestSiteEndpoints:
"""Tests for site endpoint."""
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_site_response_structure(self):
"""Test site response structure."""
payload = AppSiteUpdatePayload(title="My Site", description="Test site")
assert payload.title == "My Site"
@@ -369,27 +354,22 @@ class TestSiteEndpoints:
assert result is site
# ========== Workflow Tests ==========
class TestWorkflowEndpoints:
"""Tests for workflow endpoints."""
def test_workflow_copy_payload(self):
"""Test workflow copy payload."""
payload = SyncDraftWorkflowPayload(graph={}, features={})
assert payload.graph == {}
def test_workflow_mode_query(self):
"""Test workflow mode query."""
payload = AdvancedChatWorkflowRunPayload(inputs={}, query="hi")
assert payload.query == "hi"
# ========== Workflow App Log Tests ==========
class TestWorkflowAppLogEndpoints:
"""Tests for workflow app log endpoints."""
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_workflow_app_log_query(self):
"""Test workflow app log query."""
query = WorkflowAppLogQuery(keyword="test", page=1, limit=20)
assert query.keyword == "test"
@@ -427,12 +407,12 @@ class TestWorkflowAppLogEndpoints:
assert result == {"items": [], "total": 0}
# ========== Workflow Draft Variable Tests ==========
class TestWorkflowDraftVariableEndpoints:
"""Tests for workflow draft variable endpoints."""
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_workflow_variable_creation(self):
"""Test workflow variable creation."""
payload = WorkflowDraftVariableUpdatePayload(name="var1", value="test")
assert payload.name == "var1"
@@ -472,12 +452,12 @@ class TestWorkflowDraftVariableEndpoints:
assert result == {"items": [], "total": 0}
# ========== Workflow Statistic Tests ==========
class TestWorkflowStatisticEndpoints:
"""Tests for workflow statistic endpoints."""
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_workflow_statistic_time_range(self):
"""Test workflow statistic time range query."""
query = WorkflowStatisticQuery(start="2024-01-01", end="2024-12-31")
assert query.start == "2024-01-01"
@@ -541,12 +521,12 @@ class TestWorkflowStatisticEndpoints:
assert response.get_json() == {"data": [{"date": "2024-01-02"}]}
# ========== Workflow Trigger Tests ==========
class TestWorkflowTriggerEndpoints:
"""Tests for workflow trigger endpoints."""
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_webhook_trigger_payload(self):
"""Test webhook trigger payload."""
payload = Parser(node_id="node-1")
assert payload.node_id == "node-1"
@@ -578,22 +558,13 @@ class TestWorkflowTriggerEndpoints:
assert result is trigger
# ========== Wraps Tests ==========
class TestWrapsEndpoints:
"""Tests for wraps utility functions."""
def test_get_app_model_context(self):
"""Test get_app_model wrapper context."""
# These are decorator functions, so we test their availability
assert hasattr(wraps_module, "get_app_model")
# ========== MCP Server Tests ==========
class TestMCPServerEndpoints:
"""Tests for MCP server endpoints."""
def test_mcp_server_connection(self):
"""Test MCP server connection."""
payload = MCPServerCreatePayload(parameters={"url": "http://localhost:3000"})
assert payload.parameters["url"] == "http://localhost:3000"
@@ -602,22 +573,14 @@ class TestMCPServerEndpoints:
assert payload.status == "active"
# ========== Error Handling Tests ==========
class TestErrorHandling:
"""Tests for error handling in various endpoints."""
def test_annotation_list_query_validation(self):
"""Test annotation list query validation."""
with pytest.raises(ValueError):
annotation_module.AnnotationListQuery(page=0)
# ========== Integration-like Tests ==========
class TestPayloadIntegration:
"""Integration tests for payload handling."""
def test_multiple_payload_types(self):
"""Test handling of multiple payload types."""
payloads = [
annotation_module.AnnotationReplyPayload(
score_threshold=0.5, embedding_provider_name="openai", embedding_model_name="text-embedding-3-small"

View File

@@ -0,0 +1,142 @@
"""Testcontainers integration tests for controllers.console.app.app_import endpoints."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from controllers.console.app import app_import as app_import_module
from services.app_dsl_service import ImportStatus
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
class _Result:
def __init__(self, status: ImportStatus, app_id: str | None = "app-1"):
self.status = status
self.app_id = app_id
def model_dump(self, mode: str = "json"):
return {"status": self.status, "app_id": self.app_id}
def _install_features(monkeypatch: pytest.MonkeyPatch, enabled: bool) -> None:
    """Stub FeatureService.get_system_features with a minimal webapp_auth flag."""
    stub = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=enabled))
    monkeypatch.setattr(app_import_module.FeatureService, "get_system_features", lambda: stub)
class TestAppImportApi:
    """POST /console/api/apps/imports behavior for each ImportStatus outcome."""

    @pytest.fixture
    def app(self, flask_app_with_containers):
        # Containerized Flask app so test_request_context is available.
        return flask_app_with_containers

    def test_import_post_returns_failed_status(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
        # A FAILED import result must map to HTTP 400.
        api = app_import_module.AppImportApi()
        method = _unwrap(api.post)
        _install_features(monkeypatch, enabled=False)
        monkeypatch.setattr(
            app_import_module.AppDslService,
            "import_app",
            lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None),
        )
        monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
        with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
            response, status = method()
            assert status == 400
            assert response["status"] == ImportStatus.FAILED

    def test_import_post_returns_pending_status(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
        # A PENDING import result must map to HTTP 202 (accepted, not done).
        api = app_import_module.AppImportApi()
        method = _unwrap(api.post)
        _install_features(monkeypatch, enabled=False)
        monkeypatch.setattr(
            app_import_module.AppDslService,
            "import_app",
            lambda *_args, **_kwargs: _Result(ImportStatus.PENDING),
        )
        monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
        with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
            response, status = method()
            assert status == 202
            assert response["status"] == ImportStatus.PENDING

    def test_import_post_updates_webapp_auth_when_enabled(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
        # With webapp auth enabled, a COMPLETED import must set the new
        # app's access mode to "private".
        api = app_import_module.AppImportApi()
        method = _unwrap(api.post)
        _install_features(monkeypatch, enabled=True)
        monkeypatch.setattr(
            app_import_module.AppDslService,
            "import_app",
            lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"),
        )
        update_access = MagicMock()
        monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access)
        monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
        with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
            response, status = method()
            update_access.assert_called_once_with("app-123", "private")
            assert status == 200
            assert response["status"] == ImportStatus.COMPLETED
class TestAppImportConfirmApi:
    """POST /console/api/apps/imports/<id>/confirm failure path."""

    @pytest.fixture
    def app(self, flask_app_with_containers):
        return flask_app_with_containers

    def test_import_confirm_returns_failed_status(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
        endpoint = _unwrap(app_import_module.AppImportConfirmApi().post)
        # confirm_import reports FAILED -> the endpoint must answer 400.
        monkeypatch.setattr(
            app_import_module.AppDslService,
            "confirm_import",
            lambda *_args, **_kwargs: _Result(ImportStatus.FAILED),
        )
        monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
        with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"):
            body, code = endpoint(import_id="import-1")
        assert code == 400
        assert body["status"] == ImportStatus.FAILED
class TestAppImportCheckDependenciesApi:
    """GET /console/api/apps/imports/<app_id>/check-dependencies happy path."""

    @pytest.fixture
    def app(self, flask_app_with_containers):
        return flask_app_with_containers

    def test_import_check_dependencies_returns_result(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
        endpoint = _unwrap(app_import_module.AppImportCheckDependenciesApi().get)
        # Service returns a pydantic-like object with no leaked dependencies.
        stub_result = SimpleNamespace(model_dump=lambda mode="json": {"leaked_dependencies": []})
        monkeypatch.setattr(
            app_import_module.AppDslService,
            "check_dependencies",
            lambda *_args, **_kwargs: stub_result,
        )
        with app.test_request_context("/console/api/apps/imports/app-1/check-dependencies", method="GET"):
            body, code = endpoint(app_model=SimpleNamespace(id="app-1"))
        assert code == 200
        assert body["leaked_dependencies"] == []

View File

@@ -1,6 +1,12 @@
"""Testcontainers integration tests for rag_pipeline controller endpoints."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from controllers.console import console_ns
from controllers.console.datasets.rag_pipeline.rag_pipeline import (
@@ -9,6 +15,7 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline import (
PipelineTemplateListApi,
PublishCustomizedPipelineTemplateApi,
)
from models.dataset import PipelineCustomizedTemplate
def unwrap(func):
@@ -18,6 +25,10 @@ def unwrap(func):
class TestPipelineTemplateListApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_get_success(self, app):
api = PipelineTemplateListApi()
method = unwrap(api.get)
@@ -38,6 +49,10 @@ class TestPipelineTemplateListApi:
class TestPipelineTemplateDetailApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_get_success(self, app):
api = PipelineTemplateDetailApi()
method = unwrap(api.get)
@@ -99,6 +114,10 @@ class TestPipelineTemplateDetailApi:
class TestCustomizedPipelineTemplateApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_patch_success(self, app):
api = CustomizedPipelineTemplateApi()
method = unwrap(api.patch)
@@ -136,35 +155,29 @@ class TestCustomizedPipelineTemplateApi:
delete_mock.assert_called_once_with("tpl-1")
assert response == 200
def test_post_success(self, app):
def test_post_success(self, app, db_session_with_containers: Session):
api = CustomizedPipelineTemplateApi()
method = unwrap(api.post)
template = MagicMock()
template.yaml_content = "yaml-data"
tenant_id = str(uuid4())
template = PipelineCustomizedTemplate(
tenant_id=tenant_id,
name="Test Template",
description="Test",
chunk_structure="hierarchical",
icon={"icon": "📘"},
position=0,
yaml_content="yaml-data",
install_count=0,
language="en-US",
created_by=str(uuid4()),
)
db_session_with_containers.add(template)
db_session_with_containers.commit()
db_session_with_containers.expire_all()
fake_db = MagicMock()
fake_db.engine = MagicMock()
session = MagicMock()
session.query.return_value.where.return_value.first.return_value = template
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline.Session",
return_value=session_ctx,
),
):
response, status = method(api, "tpl-1")
with app.test_request_context("/"):
response, status = method(api, template.id)
assert status == 200
assert response == {"data": "yaml-data"}
@@ -173,32 +186,16 @@ class TestCustomizedPipelineTemplateApi:
api = CustomizedPipelineTemplateApi()
method = unwrap(api.post)
fake_db = MagicMock()
fake_db.engine = MagicMock()
session = MagicMock()
session.query.return_value.where.return_value.first.return_value = None
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline.Session",
return_value=session_ctx,
),
):
with app.test_request_context("/"):
with pytest.raises(ValueError):
method(api, "tpl-1")
method(api, str(uuid4()))
class TestPublishCustomizedPipelineTemplateApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_post_success(self, app):
api = PublishCustomizedPipelineTemplateApi()
method = unwrap(api.post)

View File

@@ -1,3 +1,7 @@
"""Testcontainers integration tests for rag_pipeline_datasets controller endpoints."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
@@ -19,6 +23,10 @@ def unwrap(func):
class TestCreateRagPipelineDatasetApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def _valid_payload(self):
return {"yaml_content": "name: test"}
@@ -33,13 +41,6 @@ class TestCreateRagPipelineDatasetApi:
mock_service = MagicMock()
mock_service.create_rag_pipeline_dataset.return_value = import_info
mock_session_ctx = MagicMock()
mock_session_ctx.__enter__.return_value = MagicMock()
mock_session_ctx.__exit__.return_value = None
fake_db = MagicMock()
fake_db.engine = MagicMock()
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
@@ -47,14 +48,6 @@ class TestCreateRagPipelineDatasetApi:
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session",
return_value=mock_session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService",
return_value=mock_service,
@@ -93,13 +86,6 @@ class TestCreateRagPipelineDatasetApi:
mock_service = MagicMock()
mock_service.create_rag_pipeline_dataset.side_effect = services.errors.dataset.DatasetNameDuplicateError()
mock_session_ctx = MagicMock()
mock_session_ctx.__enter__.return_value = MagicMock()
mock_session_ctx.__exit__.return_value = None
fake_db = MagicMock()
fake_db.engine = MagicMock()
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
@@ -107,14 +93,6 @@ class TestCreateRagPipelineDatasetApi:
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session",
return_value=mock_session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService",
return_value=mock_service,
@@ -143,6 +121,10 @@ class TestCreateRagPipelineDatasetApi:
class TestCreateEmptyRagPipelineDatasetApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_post_success(self, app):
api = CreateEmptyRagPipelineDatasetApi()
method = unwrap(api.post)

View File

@@ -1,5 +1,11 @@
"""Testcontainers integration tests for rag_pipeline_import controller endpoints."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from controllers.console import console_ns
from controllers.console.datasets.rag_pipeline.rag_pipeline_import import (
RagPipelineExportApi,
@@ -18,6 +24,10 @@ def unwrap(func):
class TestRagPipelineImportApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def _payload(self, mode="create"):
return {
"mode": mode,
@@ -30,7 +40,6 @@ class TestRagPipelineImportApi:
method = unwrap(api.post)
payload = self._payload()
user = MagicMock()
result = MagicMock()
result.status = "completed"
@@ -39,13 +48,6 @@ class TestRagPipelineImportApi:
service = MagicMock()
service.import_rag_pipeline.return_value = result
fake_db = MagicMock()
fake_db.engine = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = MagicMock()
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
@@ -53,14 +55,6 @@ class TestRagPipelineImportApi:
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
return_value=(user, "tenant"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
return_value=service,
@@ -76,7 +70,6 @@ class TestRagPipelineImportApi:
method = unwrap(api.post)
payload = self._payload()
user = MagicMock()
result = MagicMock()
result.status = ImportStatus.FAILED
@@ -85,13 +78,6 @@ class TestRagPipelineImportApi:
service = MagicMock()
service.import_rag_pipeline.return_value = result
fake_db = MagicMock()
fake_db.engine = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = MagicMock()
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
@@ -99,14 +85,6 @@ class TestRagPipelineImportApi:
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
return_value=(user, "tenant"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
return_value=service,
@@ -122,7 +100,6 @@ class TestRagPipelineImportApi:
method = unwrap(api.post)
payload = self._payload()
user = MagicMock()
result = MagicMock()
result.status = ImportStatus.PENDING
@@ -131,13 +108,6 @@ class TestRagPipelineImportApi:
service = MagicMock()
service.import_rag_pipeline.return_value = result
fake_db = MagicMock()
fake_db.engine = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = MagicMock()
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
@@ -145,14 +115,6 @@ class TestRagPipelineImportApi:
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
return_value=(user, "tenant"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
return_value=service,
@@ -165,6 +127,10 @@ class TestRagPipelineImportApi:
class TestRagPipelineImportConfirmApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_confirm_success(self, app):
api = RagPipelineImportConfirmApi()
method = unwrap(api.post)
@@ -177,27 +143,12 @@ class TestRagPipelineImportConfirmApi:
service = MagicMock()
service.confirm_import.return_value = result
fake_db = MagicMock()
fake_db.engine = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = MagicMock()
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
return_value=(user, "tenant"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
return_value=service,
@@ -220,27 +171,12 @@ class TestRagPipelineImportConfirmApi:
service = MagicMock()
service.confirm_import.return_value = result
fake_db = MagicMock()
fake_db.engine = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = MagicMock()
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
return_value=(user, "tenant"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
return_value=service,
@@ -253,6 +189,10 @@ class TestRagPipelineImportConfirmApi:
class TestRagPipelineImportCheckDependenciesApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_get_success(self, app):
api = RagPipelineImportCheckDependenciesApi()
method = unwrap(api.get)
@@ -264,23 +204,8 @@ class TestRagPipelineImportCheckDependenciesApi:
service = MagicMock()
service.check_dependencies.return_value = result
fake_db = MagicMock()
fake_db.engine = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = MagicMock()
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
return_value=service,
@@ -293,6 +218,10 @@ class TestRagPipelineImportCheckDependenciesApi:
class TestRagPipelineExportApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_get_with_include_secret(self, app):
api = RagPipelineExportApi()
method = unwrap(api.get)
@@ -301,23 +230,8 @@ class TestRagPipelineExportApi:
service = MagicMock()
service.export_rag_pipeline_dsl.return_value = {"yaml": "data"}
fake_db = MagicMock()
fake_db.engine = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = MagicMock()
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/?include_secret=true"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
return_value=service,

View File

@@ -1,7 +1,13 @@
"""Testcontainers integration tests for rag_pipeline_workflow controller endpoints."""
from __future__ import annotations
from datetime import datetime
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, HTTPException, NotFound
import services
@@ -38,6 +44,10 @@ def unwrap(func):
class TestDraftWorkflowApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_get_draft_success(self, app):
api = DraftRagPipelineApi()
method = unwrap(api.get)
@@ -200,6 +210,10 @@ class TestDraftWorkflowApi:
class TestDraftRunNodes:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_iteration_node_success(self, app):
api = RagPipelineDraftRunIterationNodeApi()
method = unwrap(api.post)
@@ -275,6 +289,10 @@ class TestDraftRunNodes:
class TestPipelineRunApis:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_draft_run_success(self, app):
api = DraftRagPipelineRunApi()
method = unwrap(api.post)
@@ -337,6 +355,10 @@ class TestPipelineRunApis:
class TestDraftNodeRun:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_execution_not_found(self, app):
api = RagPipelineDraftNodeRunApi()
method = unwrap(api.post)
@@ -364,45 +386,43 @@ class TestDraftNodeRun:
class TestPublishedPipelineApis:
def test_publish_success(self, app):
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_publish_success(self, app, db_session_with_containers: Session):
from models.dataset import Pipeline
api = PublishedRagPipelineApi()
method = unwrap(api.post)
pipeline = MagicMock()
tenant_id = str(uuid4())
pipeline = Pipeline(
tenant_id=tenant_id,
name="test-pipeline",
description="test",
created_by=str(uuid4()),
)
db_session_with_containers.add(pipeline)
db_session_with_containers.commit()
db_session_with_containers.expire_all()
user = MagicMock(id="u1")
workflow = MagicMock(
id="w1",
id=str(uuid4()),
created_at=naive_utc_now(),
)
session = MagicMock()
session.merge.return_value = pipeline
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = None
service = MagicMock()
service.publish_workflow.return_value = workflow
fake_db = MagicMock()
fake_db.engine = MagicMock()
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
@@ -415,6 +435,10 @@ class TestPublishedPipelineApis:
class TestMiscApis:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_task_stop(self, app):
api = RagPipelineTaskStopApi()
method = unwrap(api.post)
@@ -471,6 +495,10 @@ class TestMiscApis:
class TestPublishedRagPipelineRunApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_published_run_success(self, app):
api = PublishedRagPipelineRunApi()
method = unwrap(api.post)
@@ -536,6 +564,10 @@ class TestPublishedRagPipelineRunApi:
class TestDefaultBlockConfigApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_get_block_config_success(self, app):
api = DefaultRagPipelineBlockConfigApi()
method = unwrap(api.get)
@@ -567,6 +599,10 @@ class TestDefaultBlockConfigApi:
class TestPublishedAllRagPipelineApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_get_published_workflows_success(self, app):
api = PublishedAllRagPipelineApi()
method = unwrap(api.get)
@@ -577,28 +613,12 @@ class TestPublishedAllRagPipelineApi:
service = MagicMock()
service.get_all_published_workflow.return_value = ([{"id": "w1"}], False)
session = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = None
fake_db = MagicMock()
fake_db.engine = MagicMock()
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
@@ -628,6 +648,10 @@ class TestPublishedAllRagPipelineApi:
class TestRagPipelineByIdApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_patch_success(self, app):
api = RagPipelineByIdApi()
method = unwrap(api.patch)
@@ -640,14 +664,6 @@ class TestRagPipelineByIdApi:
service = MagicMock()
service.update_workflow.return_value = workflow
session = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = None
fake_db = MagicMock()
fake_db.engine = MagicMock()
payload = {"marked_name": "test"}
with (
@@ -657,14 +673,6 @@ class TestRagPipelineByIdApi:
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
@@ -700,24 +708,8 @@ class TestRagPipelineByIdApi:
workflow_service = MagicMock()
session = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = None
fake_db = MagicMock()
fake_db.engine = MagicMock()
with (
app.test_request_context("/", method="DELETE"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.WorkflowService",
return_value=workflow_service,
@@ -725,12 +717,7 @@ class TestRagPipelineByIdApi:
):
result = method(api, pipeline, "old-workflow")
workflow_service.delete_workflow.assert_called_once_with(
session=session,
workflow_id="old-workflow",
tenant_id="t1",
)
session.commit.assert_called_once()
workflow_service.delete_workflow.assert_called_once()
assert result == (None, 204)
def test_delete_active_workflow_rejected(self, app):
@@ -745,6 +732,10 @@ class TestRagPipelineByIdApi:
class TestRagPipelineWorkflowLastRunApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_last_run_success(self, app):
api = RagPipelineWorkflowLastRunApi()
method = unwrap(api.get)
@@ -788,6 +779,10 @@ class TestRagPipelineWorkflowLastRunApi:
class TestRagPipelineDatasourceVariableApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_set_datasource_variables_success(self, app):
api = RagPipelineDatasourceVariableApi()
method = unwrap(api.post)

View File

@@ -1,3 +1,7 @@
"""Testcontainers integration tests for controllers.console.datasets.data_source endpoints."""
from __future__ import annotations
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
@@ -46,6 +50,10 @@ def mock_engine():
class TestDataSourceApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_get_success(self, app, patch_tenant):
api = DataSourceApi()
method = unwrap(api.get)
@@ -179,6 +187,10 @@ class TestDataSourceApi:
class TestDataSourceNotionListApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_get_credential_not_found(self, app, patch_tenant):
api = DataSourceNotionListApi()
method = unwrap(api.get)
@@ -310,6 +322,10 @@ class TestDataSourceNotionListApi:
class TestDataSourceNotionApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_get_preview_success(self, app, patch_tenant):
api = DataSourceNotionApi()
method = unwrap(api.get)
@@ -364,6 +380,10 @@ class TestDataSourceNotionApi:
class TestDataSourceNotionDatasetSyncApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_get_success(self, app, patch_tenant):
api = DataSourceNotionDatasetSyncApi()
method = unwrap(api.get)
@@ -403,6 +423,10 @@ class TestDataSourceNotionDatasetSyncApi:
class TestDataSourceNotionDocumentSyncApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_get_success(self, app, patch_tenant):
api = DataSourceNotionDocumentSyncApi()
method = unwrap(api.get)

View File

@@ -1,7 +1,10 @@
"""Testcontainers integration tests for controllers.console.explore.conversation endpoints."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
import controllers.console.explore.conversation as conversation_module
@@ -48,24 +51,12 @@ def user():
return user
@pytest.fixture(autouse=True)
def mock_db_and_session():
with (
patch.object(
conversation_module,
"db",
MagicMock(session=MagicMock(), engine=MagicMock()),
),
patch(
"controllers.console.explore.conversation.Session",
MagicMock(),
),
):
yield
class TestConversationListApi:
def test_get_success(self, app: Flask, chat_app, user):
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_get_success(self, app, chat_app, user):
api = conversation_module.ConversationListApi()
method = unwrap(api.get)
@@ -90,7 +81,7 @@ class TestConversationListApi:
assert result["has_more"] is False
assert len(result["data"]) == 2
def test_last_conversation_not_exists(self, app: Flask, chat_app, user):
def test_last_conversation_not_exists(self, app, chat_app, user):
api = conversation_module.ConversationListApi()
method = unwrap(api.get)
@@ -106,7 +97,7 @@ class TestConversationListApi:
with pytest.raises(NotFound):
method(chat_app)
def test_wrong_app_mode(self, app: Flask, non_chat_app):
def test_wrong_app_mode(self, app, non_chat_app):
api = conversation_module.ConversationListApi()
method = unwrap(api.get)
@@ -116,7 +107,11 @@ class TestConversationListApi:
class TestConversationApi:
def test_delete_success(self, app: Flask, chat_app, user):
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_delete_success(self, app, chat_app, user):
api = conversation_module.ConversationApi()
method = unwrap(api.delete)
@@ -134,7 +129,7 @@ class TestConversationApi:
assert status == 204
assert body["result"] == "success"
def test_delete_not_found(self, app: Flask, chat_app, user):
def test_delete_not_found(self, app, chat_app, user):
api = conversation_module.ConversationApi()
method = unwrap(api.delete)
@@ -150,7 +145,7 @@ class TestConversationApi:
with pytest.raises(NotFound):
method(chat_app, "cid")
def test_delete_wrong_app_mode(self, app: Flask, non_chat_app):
def test_delete_wrong_app_mode(self, app, non_chat_app):
api = conversation_module.ConversationApi()
method = unwrap(api.delete)
@@ -160,7 +155,11 @@ class TestConversationApi:
class TestConversationRenameApi:
def test_rename_success(self, app: Flask, chat_app, user):
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_rename_success(self, app, chat_app, user):
api = conversation_module.ConversationRenameApi()
method = unwrap(api.post)
@@ -179,7 +178,7 @@ class TestConversationRenameApi:
assert result["id"] == "cid"
def test_rename_not_found(self, app: Flask, chat_app, user):
def test_rename_not_found(self, app, chat_app, user):
api = conversation_module.ConversationRenameApi()
method = unwrap(api.post)
@@ -197,7 +196,11 @@ class TestConversationRenameApi:
class TestConversationPinApi:
def test_pin_success(self, app: Flask, chat_app, user):
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_pin_success(self, app, chat_app, user):
api = conversation_module.ConversationPinApi()
method = unwrap(api.patch)
@@ -215,7 +218,11 @@ class TestConversationPinApi:
class TestConversationUnPinApi:
def test_unpin_success(self, app: Flask, chat_app, user):
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_unpin_success(self, app, chat_app, user):
api = conversation_module.ConversationUnPinApi()
method = unwrap(api.patch)

View File

@@ -1,9 +1,11 @@
"""Testcontainers integration tests for controllers.console.workspace.tool_providers endpoints."""
from __future__ import annotations
import json
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from flask_restx import Api
from werkzeug.exceptions import Forbidden
from controllers.console.workspace.tool_providers import (
@@ -31,7 +33,6 @@ from controllers.console.workspace.tool_providers import (
ToolOAuthCustomClient,
ToolPluginOAuthApi,
ToolProviderListApi,
ToolProviderMCPApi,
ToolWorkflowListApi,
ToolWorkflowProviderCreateApi,
ToolWorkflowProviderDeleteApi,
@@ -39,8 +40,6 @@ from controllers.console.workspace.tool_providers import (
ToolWorkflowProviderUpdateApi,
is_valid_url,
)
from core.db.session_factory import configure_session_factory
from extensions.ext_database import db
from services.tools.mcp_tools_manage_service import ReconnectResult
@@ -61,17 +60,8 @@ def _mock_user_tenant():
@pytest.fixture
def client():
app = Flask(__name__)
app.config["TESTING"] = True
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
api = Api(app)
api.add_resource(ToolProviderMCPApi, "/console/api/workspaces/current/tool-provider/mcp")
db.init_app(app)
# Configure session factory used by controller code
with app.app_context():
configure_session_factory(db.engine)
return app.test_client()
def client(flask_app_with_containers):
return flask_app_with_containers.test_client()
@patch(
@@ -152,10 +142,14 @@ class TestUtils:
assert not is_valid_url("")
assert not is_valid_url("ftp://example.com")
assert not is_valid_url("not-a-url")
assert not is_valid_url(None)
assert not is_valid_url(None) # type: ignore[arg-type]
class TestToolProviderListApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_get_success(self, app):
api = ToolProviderListApi()
method = unwrap(api.get)
@@ -175,6 +169,10 @@ class TestToolProviderListApi:
class TestBuiltinProviderApis:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_list_tools(self, app):
api = ToolBuiltinProviderListToolsApi()
method = unwrap(api.get)
@@ -379,6 +377,10 @@ class TestBuiltinProviderApis:
class TestApiProviderApis:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_add(self, app):
api = ToolApiProviderAddApi()
method = unwrap(api.post)
@@ -502,6 +504,10 @@ class TestApiProviderApis:
class TestWorkflowApis:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_create(self, app):
api = ToolWorkflowProviderCreateApi()
method = unwrap(api.post)
@@ -587,6 +593,10 @@ class TestWorkflowApis:
class TestLists:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_builtin_list(self, app):
api = ToolBuiltinListApi()
method = unwrap(api.get)
@@ -649,6 +659,10 @@ class TestLists:
class TestLabels:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_labels(self, app):
api = ToolLabelsApi()
method = unwrap(api.get)
@@ -664,6 +678,10 @@ class TestLabels:
class TestOAuth:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_oauth_no_client(self, app):
api = ToolPluginOAuthApi()
method = unwrap(api.get)
@@ -692,6 +710,10 @@ class TestOAuth:
class TestOAuthCustomClient:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_save_custom_client(self, app):
api = ToolOAuthCustomClient()
method = unwrap(api.post)

View File

@@ -1,3 +1,7 @@
"""Testcontainers integration tests for controllers.console.workspace.trigger_providers endpoints."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
@@ -40,6 +44,10 @@ def mock_user():
class TestTriggerProviderApis:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_icon_success(self, app):
api = TriggerProviderIconApi()
method = unwrap(api.get)
@@ -84,6 +92,10 @@ class TestTriggerProviderApis:
class TestTriggerSubscriptionListApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_list_success(self, app):
api = TriggerSubscriptionListApi()
method = unwrap(api.get)
@@ -115,6 +127,10 @@ class TestTriggerSubscriptionListApi:
class TestTriggerSubscriptionBuilderApis:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_create_builder(self, app):
api = TriggerSubscriptionBuilderCreateApi()
method = unwrap(api.post)
@@ -219,6 +235,10 @@ class TestTriggerSubscriptionBuilderApis:
class TestTriggerSubscriptionCrud:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_update_rename_only(self, app):
api = TriggerSubscriptionUpdateApi()
method = unwrap(api.post)
@@ -321,6 +341,10 @@ class TestTriggerSubscriptionCrud:
class TestTriggerOAuthApis:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_oauth_authorize_success(self, app):
api = TriggerOAuthAuthorizeApi()
method = unwrap(api.get)
@@ -455,6 +479,10 @@ class TestTriggerOAuthApis:
class TestTriggerOAuthClientManageApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_get_client(self, app):
api = TriggerOAuthClientManageApi()
method = unwrap(api.get)
@@ -527,6 +555,10 @@ class TestTriggerOAuthClientManageApi:
class TestTriggerSubscriptionVerifyApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_verify_success(self, app):
api = TriggerSubscriptionVerifyApi()
method = unwrap(api.post)

View File

@@ -0,0 +1,185 @@
"""Testcontainers integration tests for plugin_permission_required decorator."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import patch
import pytest
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console.workspace import plugin_permission_required
from models.account import Tenant, TenantPluginPermission, TenantStatus
def _create_tenant(db_session: Session) -> Tenant:
tenant = Tenant(name="test-tenant", status=TenantStatus.NORMAL, plan="basic")
db_session.add(tenant)
db_session.commit()
db_session.expire_all()
return tenant
def _create_permission(
db_session: Session,
tenant_id: str,
install: TenantPluginPermission.InstallPermission = TenantPluginPermission.InstallPermission.EVERYONE,
debug: TenantPluginPermission.DebugPermission = TenantPluginPermission.DebugPermission.EVERYONE,
) -> TenantPluginPermission:
perm = TenantPluginPermission(
tenant_id=tenant_id,
install_permission=install,
debug_permission=debug,
)
db_session.add(perm)
db_session.commit()
db_session.expire_all()
return perm
class TestPluginPermissionRequired:
def test_allows_without_permission(self, db_session_with_containers: Session):
tenant = _create_tenant(db_session_with_containers)
user = SimpleNamespace(is_admin_or_owner=False)
with patch(
"controllers.console.workspace.current_account_with_tenant",
return_value=(user, tenant.id),
):
@plugin_permission_required()
def handler():
return "ok"
assert handler() == "ok"
def test_install_nobody_forbidden(self, db_session_with_containers: Session):
tenant = _create_tenant(db_session_with_containers)
_create_permission(
db_session_with_containers,
tenant.id,
install=TenantPluginPermission.InstallPermission.NOBODY,
debug=TenantPluginPermission.DebugPermission.EVERYONE,
)
user = SimpleNamespace(is_admin_or_owner=True)
with patch(
"controllers.console.workspace.current_account_with_tenant",
return_value=(user, tenant.id),
):
@plugin_permission_required(install_required=True)
def handler():
return "ok"
with pytest.raises(Forbidden):
handler()
def test_install_admin_requires_admin(self, db_session_with_containers: Session):
tenant = _create_tenant(db_session_with_containers)
_create_permission(
db_session_with_containers,
tenant.id,
install=TenantPluginPermission.InstallPermission.ADMINS,
debug=TenantPluginPermission.DebugPermission.EVERYONE,
)
user = SimpleNamespace(is_admin_or_owner=False)
with patch(
"controllers.console.workspace.current_account_with_tenant",
return_value=(user, tenant.id),
):
@plugin_permission_required(install_required=True)
def handler():
return "ok"
with pytest.raises(Forbidden):
handler()
def test_install_admin_allows_admin(self, db_session_with_containers: Session):
tenant = _create_tenant(db_session_with_containers)
_create_permission(
db_session_with_containers,
tenant.id,
install=TenantPluginPermission.InstallPermission.ADMINS,
debug=TenantPluginPermission.DebugPermission.EVERYONE,
)
user = SimpleNamespace(is_admin_or_owner=True)
with patch(
"controllers.console.workspace.current_account_with_tenant",
return_value=(user, tenant.id),
):
@plugin_permission_required(install_required=True)
def handler():
return "ok"
assert handler() == "ok"
def test_debug_nobody_forbidden(self, db_session_with_containers: Session):
tenant = _create_tenant(db_session_with_containers)
_create_permission(
db_session_with_containers,
tenant.id,
install=TenantPluginPermission.InstallPermission.EVERYONE,
debug=TenantPluginPermission.DebugPermission.NOBODY,
)
user = SimpleNamespace(is_admin_or_owner=True)
with patch(
"controllers.console.workspace.current_account_with_tenant",
return_value=(user, tenant.id),
):
@plugin_permission_required(debug_required=True)
def handler():
return "ok"
with pytest.raises(Forbidden):
handler()
def test_debug_admin_requires_admin(self, db_session_with_containers: Session):
tenant = _create_tenant(db_session_with_containers)
_create_permission(
db_session_with_containers,
tenant.id,
install=TenantPluginPermission.InstallPermission.EVERYONE,
debug=TenantPluginPermission.DebugPermission.ADMINS,
)
user = SimpleNamespace(is_admin_or_owner=False)
with patch(
"controllers.console.workspace.current_account_with_tenant",
return_value=(user, tenant.id),
):
@plugin_permission_required(debug_required=True)
def handler():
return "ok"
with pytest.raises(Forbidden):
handler()
def test_debug_admin_allows_admin(self, db_session_with_containers: Session):
tenant = _create_tenant(db_session_with_containers)
_create_permission(
db_session_with_containers,
tenant.id,
install=TenantPluginPermission.InstallPermission.EVERYONE,
debug=TenantPluginPermission.DebugPermission.ADMINS,
)
user = SimpleNamespace(is_admin_or_owner=True)
with patch(
"controllers.console.workspace.current_account_with_tenant",
return_value=(user, tenant.id),
):
@plugin_permission_required(debug_required=True)
def handler():
return "ok"
assert handler() == "ok"

View File

@@ -1,5 +1,10 @@
"""Testcontainers integration tests for controllers.mcp.mcp endpoints."""
from __future__ import annotations
import types
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from flask import Response
@@ -14,24 +19,6 @@ def unwrap(func):
return func
@pytest.fixture(autouse=True)
def mock_db():
module.db = types.SimpleNamespace(engine=object())
@pytest.fixture
def fake_session():
session = MagicMock()
session.__enter__.return_value = session
session.__exit__.return_value = False
return session
@pytest.fixture(autouse=True)
def mock_session(fake_session):
module.Session = MagicMock(return_value=fake_session)
@pytest.fixture(autouse=True)
def mock_mcp_ns():
fake_ns = types.SimpleNamespace()
@@ -44,8 +31,13 @@ def fake_payload(data):
module.mcp_ns.payload = data
_TENANT_ID = str(uuid4())
_APP_ID = str(uuid4())
_SERVER_ID = str(uuid4())
class DummyServer:
def __init__(self, status, app_id="app-1", tenant_id="tenant-1", server_id="srv-1"):
def __init__(self, status, app_id=_APP_ID, tenant_id=_TENANT_ID, server_id=_SERVER_ID):
self.status = status
self.app_id = app_id
self.tenant_id = tenant_id
@@ -54,8 +46,8 @@ class DummyServer:
class DummyApp:
def __init__(self, mode, workflow=None, app_model_config=None):
self.id = "app-1"
self.tenant_id = "tenant-1"
self.id = _APP_ID
self.tenant_id = _TENANT_ID
self.mode = mode
self.workflow = workflow
self.app_model_config = app_model_config
@@ -76,6 +68,7 @@ class DummyResult:
return {"jsonrpc": "2.0", "result": "ok", "id": 1}
@pytest.mark.usefixtures("flask_req_ctx_with_containers")
class TestMCPAppApi:
@patch.object(module, "handle_mcp_request", return_value=DummyResult(), autospec=True)
def test_success_request(self, mock_handle):

View File

@@ -1,4 +1,4 @@
"""Unit tests for controllers.web.conversation endpoints."""
"""Testcontainers integration tests for controllers.web.conversation endpoints."""
from __future__ import annotations
@@ -7,7 +7,6 @@ from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
from controllers.web.conversation import (
@@ -33,18 +32,18 @@ def _end_user() -> SimpleNamespace:
return SimpleNamespace(id="eu-1")
# ---------------------------------------------------------------------------
# ConversationListApi
# ---------------------------------------------------------------------------
class TestConversationListApi:
def test_non_chat_mode_raises(self, app: Flask) -> None:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_non_chat_mode_raises(self, app) -> None:
with app.test_request_context("/conversations"):
with pytest.raises(NotChatAppError):
ConversationListApi().get(_completion_app(), _end_user())
@patch("controllers.web.conversation.WebConversationService.pagination_by_last_id")
@patch("controllers.web.conversation.db")
def test_happy_path(self, mock_db: MagicMock, mock_paginate: MagicMock, app: Flask) -> None:
def test_happy_path(self, mock_paginate: MagicMock, app) -> None:
conv_id = str(uuid4())
conv = SimpleNamespace(
id=conv_id,
@@ -56,34 +55,26 @@ class TestConversationListApi:
updated_at=1700000000,
)
mock_paginate.return_value = SimpleNamespace(limit=20, has_more=False, data=[conv])
mock_db.engine = "engine"
session_mock = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__ = MagicMock(return_value=session_mock)
session_ctx.__exit__ = MagicMock(return_value=False)
with (
app.test_request_context("/conversations?limit=20"),
patch("controllers.web.conversation.Session", return_value=session_ctx),
):
with app.test_request_context("/conversations?limit=20"):
result = ConversationListApi().get(_chat_app(), _end_user())
assert result["limit"] == 20
assert result["has_more"] is False
# ---------------------------------------------------------------------------
# ConversationApi (delete)
# ---------------------------------------------------------------------------
class TestConversationApi:
def test_non_chat_mode_raises(self, app: Flask) -> None:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_non_chat_mode_raises(self, app) -> None:
with app.test_request_context(f"/conversations/{uuid4()}"):
with pytest.raises(NotChatAppError):
ConversationApi().delete(_completion_app(), _end_user(), uuid4())
@patch("controllers.web.conversation.ConversationService.delete")
def test_delete_success(self, mock_delete: MagicMock, app: Flask) -> None:
def test_delete_success(self, mock_delete: MagicMock, app) -> None:
c_id = uuid4()
with app.test_request_context(f"/conversations/{c_id}"):
result, status = ConversationApi().delete(_chat_app(), _end_user(), c_id)
@@ -92,25 +83,26 @@ class TestConversationApi:
assert result["result"] == "success"
@patch("controllers.web.conversation.ConversationService.delete", side_effect=ConversationNotExistsError())
def test_delete_not_found(self, mock_delete: MagicMock, app: Flask) -> None:
def test_delete_not_found(self, mock_delete: MagicMock, app) -> None:
c_id = uuid4()
with app.test_request_context(f"/conversations/{c_id}"):
with pytest.raises(NotFound, match="Conversation Not Exists"):
ConversationApi().delete(_chat_app(), _end_user(), c_id)
# ---------------------------------------------------------------------------
# ConversationRenameApi
# ---------------------------------------------------------------------------
class TestConversationRenameApi:
def test_non_chat_mode_raises(self, app: Flask) -> None:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_non_chat_mode_raises(self, app) -> None:
with app.test_request_context(f"/conversations/{uuid4()}/name", method="POST", json={"name": "x"}):
with pytest.raises(NotChatAppError):
ConversationRenameApi().post(_completion_app(), _end_user(), uuid4())
@patch("controllers.web.conversation.ConversationService.rename")
@patch("controllers.web.conversation.web_ns")
def test_rename_success(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None:
def test_rename_success(self, mock_ns: MagicMock, mock_rename: MagicMock, app) -> None:
c_id = uuid4()
mock_ns.payload = {"name": "New Name", "auto_generate": False}
conv = SimpleNamespace(
@@ -134,7 +126,7 @@ class TestConversationRenameApi:
side_effect=ConversationNotExistsError(),
)
@patch("controllers.web.conversation.web_ns")
def test_rename_not_found(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None:
def test_rename_not_found(self, mock_ns: MagicMock, mock_rename: MagicMock, app) -> None:
c_id = uuid4()
mock_ns.payload = {"name": "X", "auto_generate": False}
@@ -143,17 +135,18 @@ class TestConversationRenameApi:
ConversationRenameApi().post(_chat_app(), _end_user(), c_id)
# ---------------------------------------------------------------------------
# ConversationPinApi / ConversationUnPinApi
# ---------------------------------------------------------------------------
class TestConversationPinApi:
def test_non_chat_mode_raises(self, app: Flask) -> None:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_non_chat_mode_raises(self, app) -> None:
with app.test_request_context(f"/conversations/{uuid4()}/pin", method="PATCH"):
with pytest.raises(NotChatAppError):
ConversationPinApi().patch(_completion_app(), _end_user(), uuid4())
@patch("controllers.web.conversation.WebConversationService.pin")
def test_pin_success(self, mock_pin: MagicMock, app: Flask) -> None:
def test_pin_success(self, mock_pin: MagicMock, app) -> None:
c_id = uuid4()
with app.test_request_context(f"/conversations/{c_id}/pin", method="PATCH"):
result = ConversationPinApi().patch(_chat_app(), _end_user(), c_id)
@@ -161,7 +154,7 @@ class TestConversationPinApi:
assert result["result"] == "success"
@patch("controllers.web.conversation.WebConversationService.pin", side_effect=ConversationNotExistsError())
def test_pin_not_found(self, mock_pin: MagicMock, app: Flask) -> None:
def test_pin_not_found(self, mock_pin: MagicMock, app) -> None:
c_id = uuid4()
with app.test_request_context(f"/conversations/{c_id}/pin", method="PATCH"):
with pytest.raises(NotFound):
@@ -169,13 +162,17 @@ class TestConversationPinApi:
class TestConversationUnPinApi:
def test_non_chat_mode_raises(self, app: Flask) -> None:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def test_non_chat_mode_raises(self, app) -> None:
with app.test_request_context(f"/conversations/{uuid4()}/unpin", method="PATCH"):
with pytest.raises(NotChatAppError):
ConversationUnPinApi().patch(_completion_app(), _end_user(), uuid4())
@patch("controllers.web.conversation.WebConversationService.unpin")
def test_unpin_success(self, mock_unpin: MagicMock, app: Flask) -> None:
def test_unpin_success(self, mock_unpin: MagicMock, app) -> None:
c_id = uuid4()
with app.test_request_context(f"/conversations/{c_id}/unpin", method="PATCH"):
result = ConversationUnPinApi().patch(_chat_app(), _end_user(), c_id)

View File

@@ -1,9 +1,12 @@
"""Testcontainers integration tests for controllers.web.forgot_password endpoints."""
from __future__ import annotations
import base64
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.web.forgot_password import (
ForgotPasswordCheckApi,
@@ -12,13 +15,6 @@ from controllers.web.forgot_password import (
)
@pytest.fixture
def app():
flask_app = Flask(__name__)
flask_app.config["TESTING"] = True
return flask_app
@pytest.fixture(autouse=True)
def _patch_wraps():
wraps_features = SimpleNamespace(enable_email_password_login=True)
@@ -33,6 +29,10 @@ def _patch_wraps():
class TestForgotPasswordSendEmailApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
@patch("controllers.web.forgot_password.AccountService.send_reset_password_email")
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
@@ -69,6 +69,10 @@ class TestForgotPasswordSendEmailApi:
class TestForgotPasswordCheckApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
@patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
@patch("controllers.web.forgot_password.AccountService.generate_reset_password_token")
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
@@ -143,6 +147,10 @@ class TestForgotPasswordCheckApi:
class TestForgotPasswordResetApi:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
@patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account")
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.web.forgot_password.Session")

View File

@@ -1,13 +1,14 @@
"""Unit tests for controllers.web.wraps — JWT auth decorator and validation helpers."""
"""Testcontainers integration tests for controllers.web.wraps — JWT auth decorator and validation helpers."""
from __future__ import annotations
from datetime import UTC, datetime, timedelta
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from flask import Flask
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError
@@ -18,12 +19,8 @@ from controllers.web.wraps import (
)
# ---------------------------------------------------------------------------
# _validate_webapp_token
# ---------------------------------------------------------------------------
class TestValidateWebappToken:
def test_enterprise_enabled_and_app_auth_requires_webapp_source(self) -> None:
"""When both flags are true, a non-webapp source must raise."""
decoded = {"token_source": "other"}
with pytest.raises(WebAppAuthRequiredError):
_validate_webapp_token(decoded, app_web_auth_enabled=True, system_webapp_auth_enabled=True)
@@ -38,7 +35,6 @@ class TestValidateWebappToken:
_validate_webapp_token(decoded, app_web_auth_enabled=True, system_webapp_auth_enabled=True)
def test_public_app_rejects_webapp_source(self) -> None:
"""When auth is not required, a webapp-sourced token must be rejected."""
decoded = {"token_source": "webapp"}
with pytest.raises(Unauthorized):
_validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=False)
@@ -52,18 +48,13 @@ class TestValidateWebappToken:
_validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=False)
def test_system_enabled_but_app_public(self) -> None:
"""system_webapp_auth_enabled=True but app is public — webapp source rejected."""
decoded = {"token_source": "webapp"}
with pytest.raises(Unauthorized):
_validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=True)
# ---------------------------------------------------------------------------
# _validate_user_accessibility
# ---------------------------------------------------------------------------
class TestValidateUserAccessibility:
def test_skips_when_auth_disabled(self) -> None:
"""No checks when system or app auth is disabled."""
_validate_user_accessibility(
decoded={},
app_code="code",
@@ -123,7 +114,6 @@ class TestValidateUserAccessibility:
def test_external_auth_type_checks_sso_update_time(
self, mock_perm_check: MagicMock, mock_sso_time: MagicMock
) -> None:
# granted_at is before SSO update time → denied
mock_sso_time.return_value = datetime.now(UTC)
old_granted = int((datetime.now(UTC) - timedelta(hours=1)).timestamp())
decoded = {"user_id": "u1", "auth_type": "external", "granted_at": old_granted}
@@ -164,7 +154,6 @@ class TestValidateUserAccessibility:
recent_granted = int(datetime.now(UTC).timestamp())
decoded = {"user_id": "u1", "auth_type": "external", "granted_at": recent_granted}
settings = SimpleNamespace(access_mode="public")
# Should not raise
_validate_user_accessibility(
decoded=decoded,
app_code="code",
@@ -191,10 +180,49 @@ class TestValidateUserAccessibility:
)
# ---------------------------------------------------------------------------
# decode_jwt_token
# ---------------------------------------------------------------------------
class TestDecodeJwtToken:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
def _create_app_site_enduser(self, db_session: Session, *, enable_site: bool = True):
from models.model import App, AppMode, CustomizeTokenStrategy, EndUser, Site
tenant_id = str(uuid4())
app_model = App(
tenant_id=tenant_id,
mode=AppMode.CHAT.value,
name="test-app",
enable_site=enable_site,
enable_api=True,
)
db_session.add(app_model)
db_session.commit()
db_session.expire_all()
site = Site(
app_id=app_model.id,
title="test-site",
default_language="en-US",
customize_token_strategy=CustomizeTokenStrategy.NOT_ALLOW,
code="code1",
)
db_session.add(site)
db_session.commit()
db_session.expire_all()
end_user = EndUser(
tenant_id=tenant_id,
app_id=app_model.id,
type="browser",
session_id="sess-1",
)
db_session.add(end_user)
db_session.commit()
db_session.expire_all()
return app_model, site, end_user
@patch("controllers.web.wraps._validate_user_accessibility")
@patch("controllers.web.wraps._validate_webapp_token")
@patch("controllers.web.wraps.EnterpriseService.WebAppAuth.get_app_access_mode_by_id")
@@ -202,10 +230,8 @@ class TestDecodeJwtToken:
@patch("controllers.web.wraps.FeatureService.get_system_features")
@patch("controllers.web.wraps.PassportService")
@patch("controllers.web.wraps.extract_webapp_passport")
@patch("controllers.web.wraps.db")
def test_happy_path(
self,
mock_db: MagicMock,
mock_extract: MagicMock,
mock_passport_cls: MagicMock,
mock_features: MagicMock,
@@ -213,40 +239,28 @@ class TestDecodeJwtToken:
mock_access_mode: MagicMock,
mock_validate_token: MagicMock,
mock_validate_user: MagicMock,
app: Flask,
app,
db_session_with_containers: Session,
) -> None:
app_model, site, end_user = self._create_app_site_enduser(db_session_with_containers)
mock_extract.return_value = "jwt-token"
mock_passport_cls.return_value.verify.return_value = {
"app_code": "code1",
"app_id": "app-1",
"end_user_id": "eu-1",
"app_code": site.code,
"app_id": app_model.id,
"end_user_id": end_user.id,
}
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
app_model = SimpleNamespace(id="app-1", enable_site=True)
site = SimpleNamespace(code="code1")
end_user = SimpleNamespace(id="eu-1", session_id="sess-1")
with app.test_request_context("/", headers={"X-App-Code": site.code}):
result_app, result_user = decode_jwt_token()
# Configure session mock to return correct objects via scalar()
session_mock = MagicMock()
session_mock.scalar.side_effect = [app_model, site, end_user]
session_ctx = MagicMock()
session_ctx.__enter__ = MagicMock(return_value=session_mock)
session_ctx.__exit__ = MagicMock(return_value=False)
mock_db.engine = "engine"
with patch("controllers.web.wraps.Session", return_value=session_ctx):
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
result_app, result_user = decode_jwt_token()
assert result_app.id == "app-1"
assert result_user.id == "eu-1"
assert result_app.id == app_model.id
assert result_user.id == end_user.id
@patch("controllers.web.wraps.FeatureService.get_system_features")
@patch("controllers.web.wraps.extract_webapp_passport")
def test_missing_token_raises_unauthorized(
self, mock_extract: MagicMock, mock_features: MagicMock, app: Flask
) -> None:
def test_missing_token_raises_unauthorized(self, mock_extract: MagicMock, mock_features: MagicMock, app) -> None:
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
mock_extract.return_value = None
@@ -257,137 +271,98 @@ class TestDecodeJwtToken:
@patch("controllers.web.wraps.FeatureService.get_system_features")
@patch("controllers.web.wraps.PassportService")
@patch("controllers.web.wraps.extract_webapp_passport")
@patch("controllers.web.wraps.db")
def test_missing_app_raises_not_found(
self,
mock_db: MagicMock,
mock_extract: MagicMock,
mock_passport_cls: MagicMock,
mock_features: MagicMock,
app: Flask,
app,
) -> None:
non_existent_id = str(uuid4())
mock_extract.return_value = "jwt-token"
mock_passport_cls.return_value.verify.return_value = {
"app_code": "code1",
"app_id": "app-1",
"end_user_id": "eu-1",
"app_id": non_existent_id,
"end_user_id": str(uuid4()),
}
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
session_mock = MagicMock()
session_mock.scalar.return_value = None # No app found
session_ctx = MagicMock()
session_ctx.__enter__ = MagicMock(return_value=session_mock)
session_ctx.__exit__ = MagicMock(return_value=False)
mock_db.engine = "engine"
with patch("controllers.web.wraps.Session", return_value=session_ctx):
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
with pytest.raises(NotFound):
decode_jwt_token()
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
with pytest.raises(NotFound):
decode_jwt_token()
@patch("controllers.web.wraps.FeatureService.get_system_features")
@patch("controllers.web.wraps.PassportService")
@patch("controllers.web.wraps.extract_webapp_passport")
@patch("controllers.web.wraps.db")
def test_disabled_site_raises_bad_request(
self,
mock_db: MagicMock,
mock_extract: MagicMock,
mock_passport_cls: MagicMock,
mock_features: MagicMock,
app: Flask,
app,
db_session_with_containers: Session,
) -> None:
app_model, site, end_user = self._create_app_site_enduser(db_session_with_containers, enable_site=False)
mock_extract.return_value = "jwt-token"
mock_passport_cls.return_value.verify.return_value = {
"app_code": "code1",
"app_id": "app-1",
"end_user_id": "eu-1",
"app_code": site.code,
"app_id": app_model.id,
"end_user_id": end_user.id,
}
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
app_model = SimpleNamespace(id="app-1", enable_site=False)
session_mock = MagicMock()
# scalar calls: app_model, site (code found), then end_user
session_mock.scalar.side_effect = [app_model, SimpleNamespace(code="code1"), None]
session_ctx = MagicMock()
session_ctx.__enter__ = MagicMock(return_value=session_mock)
session_ctx.__exit__ = MagicMock(return_value=False)
mock_db.engine = "engine"
with patch("controllers.web.wraps.Session", return_value=session_ctx):
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
with pytest.raises(BadRequest, match="Site is disabled"):
decode_jwt_token()
with app.test_request_context("/", headers={"X-App-Code": site.code}):
with pytest.raises(BadRequest, match="Site is disabled"):
decode_jwt_token()
@patch("controllers.web.wraps.FeatureService.get_system_features")
@patch("controllers.web.wraps.PassportService")
@patch("controllers.web.wraps.extract_webapp_passport")
@patch("controllers.web.wraps.db")
def test_missing_end_user_raises_not_found(
self,
mock_db: MagicMock,
mock_extract: MagicMock,
mock_passport_cls: MagicMock,
mock_features: MagicMock,
app: Flask,
app,
db_session_with_containers: Session,
) -> None:
app_model, site, _ = self._create_app_site_enduser(db_session_with_containers)
non_existent_eu = str(uuid4())
mock_extract.return_value = "jwt-token"
mock_passport_cls.return_value.verify.return_value = {
"app_code": "code1",
"app_id": "app-1",
"end_user_id": "eu-1",
"app_code": site.code,
"app_id": app_model.id,
"end_user_id": non_existent_eu,
}
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
app_model = SimpleNamespace(id="app-1", enable_site=True)
site = SimpleNamespace(code="code1")
session_mock = MagicMock()
session_mock.scalar.side_effect = [app_model, site, None] # end_user is None
session_ctx = MagicMock()
session_ctx.__enter__ = MagicMock(return_value=session_mock)
session_ctx.__exit__ = MagicMock(return_value=False)
mock_db.engine = "engine"
with patch("controllers.web.wraps.Session", return_value=session_ctx):
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
with pytest.raises(NotFound):
decode_jwt_token()
with app.test_request_context("/", headers={"X-App-Code": site.code}):
with pytest.raises(NotFound):
decode_jwt_token()
@patch("controllers.web.wraps.FeatureService.get_system_features")
@patch("controllers.web.wraps.PassportService")
@patch("controllers.web.wraps.extract_webapp_passport")
@patch("controllers.web.wraps.db")
def test_user_id_mismatch_raises_unauthorized(
self,
mock_db: MagicMock,
mock_extract: MagicMock,
mock_passport_cls: MagicMock,
mock_features: MagicMock,
app: Flask,
app,
db_session_with_containers: Session,
) -> None:
app_model, site, end_user = self._create_app_site_enduser(db_session_with_containers)
mock_extract.return_value = "jwt-token"
mock_passport_cls.return_value.verify.return_value = {
"app_code": "code1",
"app_id": "app-1",
"end_user_id": "eu-1",
"app_code": site.code,
"app_id": app_model.id,
"end_user_id": end_user.id,
}
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
app_model = SimpleNamespace(id="app-1", enable_site=True)
site = SimpleNamespace(code="code1")
end_user = SimpleNamespace(id="eu-1", session_id="sess-1")
session_mock = MagicMock()
session_mock.scalar.side_effect = [app_model, site, end_user]
session_ctx = MagicMock()
session_ctx.__enter__ = MagicMock(return_value=session_mock)
session_ctx.__exit__ = MagicMock(return_value=False)
mock_db.engine = "engine"
with patch("controllers.web.wraps.Session", return_value=session_ctx):
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
with pytest.raises(Unauthorized, match="expired"):
decode_jwt_token(user_id="different-user")
with app.test_request_context("/", headers={"X-App-Code": site.code}):
with pytest.raises(Unauthorized, match="expired"):
decode_jwt_token(user_id="different-user")

View File

@@ -141,7 +141,7 @@ class TestModelLoadBalancingService:
tenant_id=tenant_id,
provider_name="openai",
model_name="gpt-3.5-turbo",
model_type="text-generation", # Use the origin model type that matches the query
model_type="llm",
enabled=True,
load_balancing_enabled=False,
)
@@ -298,7 +298,7 @@ class TestModelLoadBalancingService:
tenant_id=tenant.id,
provider_name="openai",
model_name="gpt-3.5-turbo",
model_type="text-generation", # Use the origin model type that matches the query
model_type="llm",
name="config1",
encrypted_config='{"api_key": "test_key"}',
enabled=True,
@@ -417,7 +417,7 @@ class TestModelLoadBalancingService:
tenant_id=tenant.id,
provider_name="openai",
model_name="gpt-3.5-turbo",
model_type="text-generation", # Use the origin model type that matches the query
model_type="llm",
name="config1",
encrypted_config='{"api_key": "test_key"}',
enabled=True,

View File

@@ -1,157 +0,0 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from controllers.console.app import app_import as app_import_module
from services.app_dsl_service import ImportStatus
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
class _Result:
def __init__(self, status: ImportStatus, app_id: str | None = "app-1"):
self.status = status
self.app_id = app_id
def model_dump(self, mode: str = "json"):
return {"status": self.status, "app_id": self.app_id}
class _SessionContext:
def __init__(self, session):
self._session = session
def __enter__(self):
return self._session
def __exit__(self, exc_type, exc, tb):
return False
def _install_session(monkeypatch: pytest.MonkeyPatch, session: MagicMock) -> None:
monkeypatch.setattr(app_import_module, "Session", lambda *_: _SessionContext(session))
monkeypatch.setattr(app_import_module, "db", SimpleNamespace(engine=object()))
def _install_features(monkeypatch: pytest.MonkeyPatch, enabled: bool) -> None:
features = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=enabled))
monkeypatch.setattr(app_import_module.FeatureService, "get_system_features", lambda: features)
def test_import_post_returns_failed_status(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportApi()
method = _unwrap(api.post)
session = MagicMock()
_install_session(monkeypatch, session)
_install_features(monkeypatch, enabled=False)
monkeypatch.setattr(
app_import_module.AppDslService,
"import_app",
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None),
)
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
response, status = method()
session.commit.assert_called_once()
assert status == 400
assert response["status"] == ImportStatus.FAILED
def test_import_post_returns_pending_status(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportApi()
method = _unwrap(api.post)
session = MagicMock()
_install_session(monkeypatch, session)
_install_features(monkeypatch, enabled=False)
monkeypatch.setattr(
app_import_module.AppDslService,
"import_app",
lambda *_args, **_kwargs: _Result(ImportStatus.PENDING),
)
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
response, status = method()
session.commit.assert_called_once()
assert status == 202
assert response["status"] == ImportStatus.PENDING
def test_import_post_updates_webapp_auth_when_enabled(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportApi()
method = _unwrap(api.post)
session = MagicMock()
_install_session(monkeypatch, session)
_install_features(monkeypatch, enabled=True)
monkeypatch.setattr(
app_import_module.AppDslService,
"import_app",
lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"),
)
update_access = MagicMock()
monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access)
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
response, status = method()
session.commit.assert_called_once()
update_access.assert_called_once_with("app-123", "private")
assert status == 200
assert response["status"] == ImportStatus.COMPLETED
def test_import_confirm_returns_failed_status(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportConfirmApi()
method = _unwrap(api.post)
session = MagicMock()
_install_session(monkeypatch, session)
monkeypatch.setattr(
app_import_module.AppDslService,
"confirm_import",
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED),
)
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"):
response, status = method(import_id="import-1")
session.commit.assert_called_once()
assert status == 400
assert response["status"] == ImportStatus.FAILED
def test_import_check_dependencies_returns_result(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportCheckDependenciesApi()
method = _unwrap(api.get)
session = MagicMock()
_install_session(monkeypatch, session)
monkeypatch.setattr(
app_import_module.AppDslService,
"check_dependencies",
lambda *_args, **_kwargs: SimpleNamespace(model_dump=lambda mode="json": {"leaked_dependencies": []}),
)
with app.test_request_context("/console/api/apps/imports/app-1/check-dependencies", method="GET"):
response, status = method(app_model=SimpleNamespace(id="app-1"))
assert status == 200
assert response["leaked_dependencies"] == []

View File

@@ -1,142 +0,0 @@
from __future__ import annotations
import importlib
from types import SimpleNamespace
import pytest
from werkzeug.exceptions import Forbidden
from controllers.console.workspace import plugin_permission_required
from models.account import TenantPluginPermission
class _SessionStub:
def __init__(self, permission):
self._permission = permission
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def query(self, *_args, **_kwargs):
return self
def where(self, *_args, **_kwargs):
return self
def first(self):
return self._permission
def _workspace_module():
return importlib.import_module(plugin_permission_required.__module__)
def _patch_session(monkeypatch: pytest.MonkeyPatch, permission):
module = _workspace_module()
monkeypatch.setattr(module, "Session", lambda *_args, **_kwargs: _SessionStub(permission))
monkeypatch.setattr(module, "db", SimpleNamespace(engine=object()))
def test_plugin_permission_allows_without_permission(monkeypatch: pytest.MonkeyPatch) -> None:
user = SimpleNamespace(is_admin_or_owner=False)
module = _workspace_module()
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
_patch_session(monkeypatch, None)
@plugin_permission_required()
def handler():
return "ok"
assert handler() == "ok"
def test_plugin_permission_install_nobody_forbidden(monkeypatch: pytest.MonkeyPatch) -> None:
user = SimpleNamespace(is_admin_or_owner=True)
permission = SimpleNamespace(
install_permission=TenantPluginPermission.InstallPermission.NOBODY,
debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
)
module = _workspace_module()
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
_patch_session(monkeypatch, permission)
@plugin_permission_required(install_required=True)
def handler():
return "ok"
with pytest.raises(Forbidden):
handler()
def test_plugin_permission_install_admin_requires_admin(monkeypatch: pytest.MonkeyPatch) -> None:
user = SimpleNamespace(is_admin_or_owner=False)
permission = SimpleNamespace(
install_permission=TenantPluginPermission.InstallPermission.ADMINS,
debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
)
module = _workspace_module()
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
_patch_session(monkeypatch, permission)
@plugin_permission_required(install_required=True)
def handler():
return "ok"
with pytest.raises(Forbidden):
handler()
def test_plugin_permission_install_admin_allows_admin(monkeypatch: pytest.MonkeyPatch) -> None:
user = SimpleNamespace(is_admin_or_owner=True)
permission = SimpleNamespace(
install_permission=TenantPluginPermission.InstallPermission.ADMINS,
debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
)
module = _workspace_module()
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
_patch_session(monkeypatch, permission)
@plugin_permission_required(install_required=True)
def handler():
return "ok"
assert handler() == "ok"
def test_plugin_permission_debug_nobody_forbidden(monkeypatch: pytest.MonkeyPatch) -> None:
user = SimpleNamespace(is_admin_or_owner=True)
permission = SimpleNamespace(
install_permission=TenantPluginPermission.InstallPermission.EVERYONE,
debug_permission=TenantPluginPermission.DebugPermission.NOBODY,
)
module = _workspace_module()
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
_patch_session(monkeypatch, permission)
@plugin_permission_required(debug_required=True)
def handler():
return "ok"
with pytest.raises(Forbidden):
handler()
def test_plugin_permission_debug_admin_requires_admin(monkeypatch: pytest.MonkeyPatch) -> None:
user = SimpleNamespace(is_admin_or_owner=False)
permission = SimpleNamespace(
install_permission=TenantPluginPermission.InstallPermission.EVERYONE,
debug_permission=TenantPluginPermission.DebugPermission.ADMINS,
)
module = _workspace_module()
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
_patch_session(monkeypatch, permission)
@plugin_permission_required(debug_required=True)
def handler():
return "ok"
with pytest.raises(Forbidden):
handler()

View File

@@ -768,6 +768,7 @@ class TestSegmentApiGet:
``current_account_with_tenant()`` and ``marshal``.
"""
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
@patch("controllers.service_api.dataset.segment.marshal")
@patch("controllers.service_api.dataset.segment.SegmentService")
@patch("controllers.service_api.dataset.segment.DocumentService")
@@ -780,6 +781,7 @@ class TestSegmentApiGet:
mock_doc_svc,
mock_seg_svc,
mock_marshal,
mock_summary_svc,
app,
mock_tenant,
mock_dataset,
@@ -791,7 +793,8 @@ class TestSegmentApiGet:
mock_db.session.scalar.return_value = mock_dataset
mock_doc_svc.get_document.return_value = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
mock_seg_svc.get_segments.return_value = ([mock_segment], 1)
mock_marshal.return_value = [{"id": mock_segment.id}]
mock_marshal.return_value = {"id": mock_segment.id}
mock_summary_svc.get_segments_summaries.return_value = {}
# Act
with app.test_request_context(
@@ -872,6 +875,7 @@ class TestSegmentApiPost:
mock_rate_limit.enabled = False
mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
@patch("controllers.service_api.dataset.segment.marshal")
@patch("controllers.service_api.dataset.segment.SegmentService")
@patch("controllers.service_api.dataset.segment.DocumentService")
@@ -888,6 +892,7 @@ class TestSegmentApiPost:
mock_doc_svc,
mock_seg_svc,
mock_marshal,
mock_summary_svc,
app,
mock_tenant,
mock_dataset,
@@ -909,7 +914,8 @@ class TestSegmentApiPost:
mock_seg_svc.segment_create_args_validate.return_value = None
mock_seg_svc.multi_create_segment.return_value = [mock_segment]
mock_marshal.return_value = [{"id": mock_segment.id}]
mock_marshal.return_value = {"id": mock_segment.id}
mock_summary_svc.get_segments_summaries.return_value = {}
segments_data = [{"content": "Test segment content", "answer": "Test answer"}]
@@ -1206,6 +1212,7 @@ class TestDatasetSegmentApiUpdate:
mock_rate_limit.enabled = False
mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
@patch("controllers.service_api.dataset.segment.marshal")
@patch("controllers.service_api.dataset.segment.SegmentService")
@patch("controllers.service_api.dataset.segment.DocumentService")
@@ -1224,6 +1231,7 @@ class TestDatasetSegmentApiUpdate:
mock_doc_svc,
mock_seg_svc,
mock_marshal,
mock_summary_svc,
app,
mock_tenant,
mock_dataset,
@@ -1240,6 +1248,7 @@ class TestDatasetSegmentApiUpdate:
updated = Mock()
mock_seg_svc.update_segment.return_value = updated
mock_marshal.return_value = {"id": mock_segment.id}
mock_summary_svc.get_segment_summary.return_value = None
with app.test_request_context(
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{mock_segment.id}",
@@ -1349,6 +1358,7 @@ class TestDatasetSegmentApiGetSingle:
``current_account_with_tenant()`` and ``marshal``.
"""
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
@patch("controllers.service_api.dataset.segment.marshal")
@patch("controllers.service_api.dataset.segment.SegmentService")
@patch("controllers.service_api.dataset.segment.DocumentService")
@@ -1363,6 +1373,7 @@ class TestDatasetSegmentApiGetSingle:
mock_doc_svc,
mock_seg_svc,
mock_marshal,
mock_summary_svc,
app,
mock_tenant,
mock_dataset,
@@ -1376,6 +1387,7 @@ class TestDatasetSegmentApiGetSingle:
mock_doc_svc.get_document.return_value = mock_doc
mock_seg_svc.get_segment_by_id.return_value = mock_segment
mock_marshal.return_value = {"id": mock_segment.id}
mock_summary_svc.get_segment_summary.return_value = None
with app.test_request_context(
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{mock_segment.id}",
@@ -1393,6 +1405,55 @@ class TestDatasetSegmentApiGetSingle:
assert "data" in response
assert response["doc_form"] == IndexStructureType.PARAGRAPH_INDEX
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
@patch("controllers.service_api.dataset.segment.marshal")
@patch("controllers.service_api.dataset.segment.SegmentService")
@patch("controllers.service_api.dataset.segment.DocumentService")
@patch("controllers.service_api.dataset.segment.DatasetService")
@patch("controllers.service_api.dataset.segment.current_account_with_tenant")
@patch("controllers.service_api.dataset.segment.db")
def test_get_single_segment_includes_summary(
self,
mock_db,
mock_account_fn,
mock_dataset_svc,
mock_doc_svc,
mock_seg_svc,
mock_marshal,
mock_summary_svc,
app,
mock_tenant,
mock_dataset,
mock_segment,
):
"""Test that single segment response includes summary content from SummaryIndexService."""
mock_account_fn.return_value = (Mock(), mock_tenant.id)
mock_db.session.scalar.return_value = mock_dataset
mock_dataset_svc.check_dataset_model_setting.return_value = None
mock_doc = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
mock_doc_svc.get_document.return_value = mock_doc
mock_seg_svc.get_segment_by_id.return_value = mock_segment
mock_marshal.return_value = {"id": mock_segment.id, "summary": None}
mock_summary_record = Mock()
mock_summary_record.summary_content = "This is the segment summary"
mock_summary_svc.get_segment_summary.return_value = mock_summary_record
with app.test_request_context(
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{mock_segment.id}",
method="GET",
):
api = DatasetSegmentApi()
response, status = api.get(
tenant_id=mock_tenant.id,
dataset_id=mock_dataset.id,
document_id="doc-id",
segment_id=mock_segment.id,
)
assert status == 200
assert response["data"]["summary"] == "This is the segment summary"
@patch("controllers.service_api.dataset.segment.current_account_with_tenant")
@patch("controllers.service_api.dataset.segment.db")
def test_get_single_segment_dataset_not_found(

View File

@@ -415,12 +415,44 @@ class TestUtilityFunctions:
label="Upload",
required=False,
),
VariableEntity(
type=VariableEntityType.CHECKBOX,
variable="enabled",
description="Enable flag",
label="Enabled",
required=False,
),
VariableEntity(
type=VariableEntityType.JSON_OBJECT,
variable="config",
description="Config object",
label="Config",
required=True,
),
VariableEntity(
type=VariableEntityType.JSON_OBJECT,
variable="schema_config",
description="Config with schema",
label="Schema Config",
required=False,
json_schema={
"properties": {
"host": {"type": "string"},
"port": {"type": "number"},
},
"required": ["host"],
"additionalProperties": False,
},
),
]
parameters_dict: dict[str, str] = {
"name": "Enter your name",
"category": "Select category",
"count": "Enter count",
"enabled": "Enable flag",
"config": "Config object",
"schema_config": "Config with schema",
}
parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict)
@@ -437,20 +469,35 @@ class TestUtilityFunctions:
assert "count" in parameters
assert parameters["count"]["type"] == "number"
# FILE type should be skipped - it creates empty dict but gets filtered later
# Check that it doesn't have any meaningful content
if "upload" in parameters:
assert parameters["upload"] == {}
# FILE type is skipped entirely via `continue` — key should not exist
assert "upload" not in parameters
# CHECKBOX maps to boolean
assert parameters["enabled"]["type"] == "boolean"
# JSON_OBJECT without json_schema maps to object
assert parameters["config"]["type"] == "object"
assert "properties" not in parameters["config"]
# JSON_OBJECT with json_schema forwards schema keys
assert parameters["schema_config"]["type"] == "object"
assert parameters["schema_config"]["properties"] == {
"host": {"type": "string"},
"port": {"type": "number"},
}
assert parameters["schema_config"]["required"] == ["host"]
assert parameters["schema_config"]["additionalProperties"] is False
# Check required fields
assert "name" in required
assert "count" in required
assert "config" in required
assert "category" not in required
# Note: _get_request_id function has been removed as request_id is now passed as parameter
def test_convert_input_form_to_parameters_jsonschema_validation_ok(self):
"""Current schema uses 'number' for numeric fields; it should be a valid JSON Schema."""
"""Generated schema with all supported types should be valid JSON Schema."""
user_input_form = [
VariableEntity(
type=VariableEntityType.NUMBER,
@@ -466,11 +513,27 @@ class TestUtilityFunctions:
label="Name",
required=False,
),
VariableEntity(
type=VariableEntityType.CHECKBOX,
variable="enabled",
description="Toggle",
label="Enabled",
required=False,
),
VariableEntity(
type=VariableEntityType.JSON_OBJECT,
variable="metadata",
description="Metadata",
label="Metadata",
required=False,
),
]
parameters_dict = {
"count": "Enter count",
"name": "Enter your name",
"enabled": "Toggle flag",
"metadata": "Metadata object",
}
parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict)
@@ -485,9 +548,12 @@ class TestUtilityFunctions:
# 1) The schema itself must be valid
jsonschema.Draft202012Validator.check_schema(schema)
# 2) Both float and integer instances should pass validation
# 2) Validate instances with all types
jsonschema.validate(instance={"count": 3.14, "name": "alice"}, schema=schema)
jsonschema.validate(instance={"count": 2, "name": "bob"}, schema=schema)
jsonschema.validate(
instance={"count": 2, "enabled": True, "metadata": {"key": "val"}},
schema=schema,
)
def test_legacy_float_type_schema_is_invalid(self):
"""Legacy/buggy behavior: using 'float' should produce an invalid JSON Schema."""

View File

@@ -521,11 +521,11 @@ def test_generate_name_trace(trace_instance):
def test_add_trace_success(trace_instance):
data = LangfuseTrace(id="t1", name="trace")
trace_instance.add_trace(data)
trace_instance.langfuse_client.trace.assert_called_once()
trace_instance.langfuse_client.api.ingestion.batch.assert_called_once()
def test_add_trace_error(trace_instance):
trace_instance.langfuse_client.trace.side_effect = Exception("error")
trace_instance.langfuse_client.api.ingestion.batch.side_effect = Exception("error")
data = LangfuseTrace(id="t1", name="trace")
with pytest.raises(ValueError, match="LangFuse Failed to create trace: error"):
trace_instance.add_trace(data)
@@ -534,11 +534,11 @@ def test_add_trace_error(trace_instance):
def test_add_span_success(trace_instance):
data = LangfuseSpan(id="s1", name="span", trace_id="t1")
trace_instance.add_span(data)
trace_instance.langfuse_client.span.assert_called_once()
trace_instance.langfuse_client.api.ingestion.batch.assert_called_once()
def test_add_span_error(trace_instance):
trace_instance.langfuse_client.span.side_effect = Exception("error")
trace_instance.langfuse_client.api.ingestion.batch.side_effect = Exception("error")
data = LangfuseSpan(id="s1", name="span", trace_id="t1")
with pytest.raises(ValueError, match="LangFuse Failed to create span: error"):
trace_instance.add_span(data)
@@ -554,11 +554,11 @@ def test_update_span(trace_instance):
def test_add_generation_success(trace_instance):
data = LangfuseGeneration(id="g1", name="gen", trace_id="t1")
trace_instance.add_generation(data)
trace_instance.langfuse_client.generation.assert_called_once()
trace_instance.langfuse_client.api.ingestion.batch.assert_called_once()
def test_add_generation_error(trace_instance):
trace_instance.langfuse_client.generation.side_effect = Exception("error")
trace_instance.langfuse_client.api.ingestion.batch.side_effect = Exception("error")
data = LangfuseGeneration(id="g1", name="gen", trace_id="t1")
with pytest.raises(ValueError, match="LangFuse Failed to create generation: error"):
trace_instance.add_generation(data)
@@ -585,12 +585,12 @@ def test_api_check_error(trace_instance):
def test_get_project_key_success(trace_instance):
mock_data = MagicMock()
mock_data.id = "proj-1"
trace_instance.langfuse_client.client.projects.get.return_value = MagicMock(data=[mock_data])
trace_instance.langfuse_client.api.projects.get.return_value = MagicMock(data=[mock_data])
assert trace_instance.get_project_key() == "proj-1"
def test_get_project_key_error(trace_instance):
trace_instance.langfuse_client.client.projects.get.side_effect = Exception("fail")
trace_instance.langfuse_client.api.projects.get.side_effect = Exception("fail")
with pytest.raises(ValueError, match="LangFuse get project key failed: fail"):
trace_instance.get_project_key()

View File

@@ -201,27 +201,23 @@ def test_search_returns_documents_in_rank_order_and_applies_filter(monkeypatch,
document_id = _Field("document_id")
keyword = Jieba(_dataset(_dataset_keyword_table()))
query_stmt = _FakeQuery()
patched_runtime.session.query.return_value = query_stmt
patched_runtime.session.execute.return_value = _FakeExecuteResult(
[
SimpleNamespace(
index_node_id="node-2",
content="segment-content",
index_node_hash="hash-2",
document_id="doc-2",
dataset_id="dataset-1",
)
]
)
patched_runtime.session.scalars.return_value.all.return_value = [
SimpleNamespace(
index_node_id="node-2",
content="segment-content",
index_node_hash="hash-2",
document_id="doc-2",
dataset_id="dataset-1",
)
]
monkeypatch.setattr(jieba_module, "DocumentSegment", _FakeDocumentSegment)
monkeypatch.setattr(jieba_module, "select", lambda *_: _FakeSelect())
monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={"k": {"node-1", "node-2"}}))
monkeypatch.setattr(keyword, "_retrieve_ids_by_query", MagicMock(return_value=["node-1", "node-2"]))
documents = keyword.search("query", top_k=2, document_ids_filter=["doc-2"])
assert len(query_stmt.where_calls) == 2
assert len(documents) == 1
assert documents[0].page_content == "segment-content"
assert documents[0].metadata["doc_id"] == "node-2"

View File

@@ -714,13 +714,13 @@ class TestRetrievalServiceInternals:
dataset_id="dataset-id",
)
dataset_query = Mock()
dataset_query.where.return_value.options.return_value.all.return_value = [
scalars_result = Mock()
scalars_result.all.return_value = [
dataset_doc_parent,
dataset_doc_text,
dataset_doc_parent_summary,
]
monkeypatch.setattr(retrieval_service_module.db.session, "query", Mock(return_value=dataset_query))
monkeypatch.setattr(retrieval_service_module.db.session, "scalars", Mock(return_value=scalars_result))
monkeypatch.setattr(retrieval_service_module, "RetrievalChildChunk", _SimpleRetrievalChildChunk)
monkeypatch.setattr(retrieval_service_module, "RetrievalSegments", _SimpleRetrievalSegment)
@@ -882,7 +882,7 @@ class TestRetrievalServiceInternals:
def test_format_retrieval_documents_rolls_back_and_raises_when_db_fails(self, monkeypatch):
rollback = Mock()
monkeypatch.setattr(retrieval_service_module.db.session, "rollback", rollback)
monkeypatch.setattr(retrieval_service_module.db.session, "query", Mock(side_effect=RuntimeError("db error")))
monkeypatch.setattr(retrieval_service_module.db.session, "scalars", Mock(side_effect=RuntimeError("db error")))
documents = [Document(page_content="content", metadata={"document_id": "doc-1"}, provider="dify")]

View File

@@ -340,15 +340,13 @@ def test_search_by_file_handles_missing_and_existing_upload(vector_factory_modul
vector._embeddings = MagicMock()
vector._vector_processor = MagicMock()
mock_session = SimpleNamespace(get=lambda _model, _id: None)
monkeypatch.setattr(vector_factory_module, "UploadFile", SimpleNamespace(id=_Field()))
monkeypatch.setattr(
vector_factory_module, "db", SimpleNamespace(session=SimpleNamespace(query=lambda _model: upload_query))
)
monkeypatch.setattr(vector_factory_module, "db", SimpleNamespace(session=mock_session))
upload_query.first.return_value = None
assert vector.search_by_file("file-1") == []
upload_query.first.return_value = SimpleNamespace(key="blob-key")
mock_session.get = lambda _model, _id: SimpleNamespace(key="blob-key")
monkeypatch.setattr(vector_factory_module.storage, "load_once", MagicMock(return_value=b"file-bytes"))
vector._embeddings.embed_multimodal_query.return_value = [0.3, 0.4]
vector._vector_processor.search_by_vector.return_value = ["hit"]

View File

@@ -167,7 +167,7 @@ class TestDatasetDocumentStoreAddDocuments:
):
mock_session = MagicMock()
mock_db.session = mock_session
mock_db.session.query.return_value.where.return_value.scalar.return_value = None
mock_db.session.scalar.return_value = None
mock_manager = MagicMock()
mock_manager.get_model_instance.return_value = mock_model_instance
@@ -211,7 +211,7 @@ class TestDatasetDocumentStoreAddDocuments:
with patch("core.rag.docstore.dataset_docstore.db") as mock_db:
mock_session = MagicMock()
mock_db.session = mock_session
mock_db.session.query.return_value.where.return_value.scalar.return_value = 5
mock_db.session.scalar.return_value = 5
with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_existing_segment):
with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"):
@@ -276,7 +276,7 @@ class TestDatasetDocumentStoreAddDocuments:
with patch("core.rag.docstore.dataset_docstore.db") as mock_db:
mock_session = MagicMock()
mock_db.session = mock_session
mock_db.session.query.return_value.where.return_value.scalar.return_value = None
mock_db.session.scalar.return_value = None
with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None):
with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"):
@@ -353,7 +353,7 @@ class TestDatasetDocumentStoreAddDocuments:
with patch("core.rag.docstore.dataset_docstore.db") as mock_db:
mock_session = MagicMock()
mock_db.session = mock_session
mock_db.session.query.return_value.where.return_value.scalar.return_value = None
mock_db.session.scalar.return_value = None
with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None):
with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"):
@@ -755,7 +755,7 @@ class TestDatasetDocumentStoreAddDocumentsUpdateChild:
with patch("core.rag.docstore.dataset_docstore.db") as mock_db:
mock_session = MagicMock()
mock_db.session = mock_session
mock_db.session.query.return_value.where.return_value.scalar.return_value = 5
mock_db.session.scalar.return_value = 5
with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_existing_segment):
with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"):
@@ -767,7 +767,7 @@ class TestDatasetDocumentStoreAddDocumentsUpdateChild:
store.add_documents([mock_doc], save_child=True)
mock_db.session.query.return_value.where.return_value.delete.assert_called()
mock_db.session.execute.assert_called()
mock_db.session.commit.assert_called()
@@ -798,7 +798,7 @@ class TestDatasetDocumentStoreAddDocumentsUpdateAnswer:
with patch("core.rag.docstore.dataset_docstore.db") as mock_db:
mock_session = MagicMock()
mock_db.session = mock_session
mock_db.session.query.return_value.where.return_value.scalar.return_value = 5
mock_db.session.scalar.return_value = 5
with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_existing_segment):
with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"):

View File

@@ -69,7 +69,7 @@ class TestCacheEmbeddingMultimodalDocuments:
documents = [{"file_id": "file123", "content": "test content"}]
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
mock_model_instance.invoke_multimodal_embedding.return_value = sample_multimodal_result
result = cache_embedding.embed_multimodal_documents(documents)
@@ -114,7 +114,7 @@ class TestCacheEmbeddingMultimodalDocuments:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result
result = cache_embedding.embed_multimodal_documents(documents)
@@ -134,7 +134,7 @@ class TestCacheEmbeddingMultimodalDocuments:
mock_cached_embedding.get_embedding.return_value = normalized_cached
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_cached_embedding
mock_session.scalar.return_value = mock_cached_embedding
result = cache_embedding.embed_multimodal_documents(documents)
@@ -180,18 +180,7 @@ class TestCacheEmbeddingMultimodalDocuments:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
call_count = [0]
def mock_filter_by(**kwargs):
call_count[0] += 1
mock_query = Mock()
if call_count[0] == 1:
mock_query.first.return_value = mock_cached_embedding
else:
mock_query.first.return_value = None
return mock_query
mock_session.query.return_value.filter_by = mock_filter_by
mock_session.scalar.side_effect = [mock_cached_embedding, None, None]
mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result
result = cache_embedding.embed_multimodal_documents(documents)
@@ -224,7 +213,7 @@ class TestCacheEmbeddingMultimodalDocuments:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result
with patch("core.rag.embedding.cached_embedding.logger") as mock_logger:
@@ -265,7 +254,7 @@ class TestCacheEmbeddingMultimodalDocuments:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
batch_results = [create_batch_result(10), create_batch_result(10), create_batch_result(5)]
mock_model_instance.invoke_multimodal_embedding.side_effect = batch_results
@@ -281,7 +270,7 @@ class TestCacheEmbeddingMultimodalDocuments:
documents = [{"file_id": "file123"}]
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
mock_model_instance.invoke_multimodal_embedding.side_effect = Exception("API Error")
with pytest.raises(Exception) as exc_info:
@@ -298,7 +287,7 @@ class TestCacheEmbeddingMultimodalDocuments:
documents = [{"file_id": "file123"}]
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
mock_model_instance.invoke_multimodal_embedding.return_value = sample_multimodal_result
mock_session.commit.side_effect = IntegrityError("Duplicate key", None, None)

View File

@@ -139,7 +139,7 @@ class TestCacheEmbeddingDocuments:
# Mock database query to return no cached embedding (cache miss)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
# Mock model invocation
mock_model_instance.invoke_text_embedding.return_value = sample_embedding_result
@@ -203,7 +203,7 @@ class TestCacheEmbeddingDocuments:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = embedding_result
# Act
@@ -240,7 +240,7 @@ class TestCacheEmbeddingDocuments:
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
# Mock database to return cached embedding (cache hit)
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_cached_embedding
mock_session.scalar.return_value = mock_cached_embedding
# Act
result = cache_embedding.embed_documents(texts)
@@ -313,19 +313,7 @@ class TestCacheEmbeddingDocuments:
mock_hash.side_effect = generate_hash
# Mock database to return cached embedding only for first text (hash_1)
call_count = [0]
def mock_filter_by(**kwargs):
call_count[0] += 1
mock_query = Mock()
# First call (hash_1) returns cached, others return None
if call_count[0] == 1:
mock_query.first.return_value = mock_cached_embedding
else:
mock_query.first.return_value = None
return mock_query
mock_session.query.return_value.filter_by = mock_filter_by
mock_session.scalar.side_effect = [mock_cached_embedding, None, None]
mock_model_instance.invoke_text_embedding.return_value = embedding_result
# Act
@@ -392,7 +380,7 @@ class TestCacheEmbeddingDocuments:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
# Mock model to return appropriate batch results
batch_results = [
@@ -455,7 +443,7 @@ class TestCacheEmbeddingDocuments:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = embedding_result
with patch("core.rag.embedding.cached_embedding.logger") as mock_logger:
@@ -489,7 +477,7 @@ class TestCacheEmbeddingDocuments:
texts = ["Test text"]
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
# Mock model to raise connection error
mock_model_instance.invoke_text_embedding.side_effect = InvokeConnectionError("Failed to connect to API")
@@ -515,7 +503,7 @@ class TestCacheEmbeddingDocuments:
texts = ["Test text"]
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
# Mock model to raise rate limit error
mock_model_instance.invoke_text_embedding.side_effect = InvokeRateLimitError("Rate limit exceeded")
@@ -539,7 +527,7 @@ class TestCacheEmbeddingDocuments:
texts = ["Test text"]
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
# Mock model to raise authorization error
mock_model_instance.invoke_text_embedding.side_effect = InvokeAuthorizationError("Invalid API key")
@@ -564,7 +552,7 @@ class TestCacheEmbeddingDocuments:
texts = ["Test text"]
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = sample_embedding_result
# Mock database commit to raise IntegrityError
@@ -884,7 +872,7 @@ class TestEmbeddingModelSwitching:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
model_instance_ada.invoke_text_embedding.return_value = result_ada
model_instance_3_small.invoke_text_embedding.return_value = result_3_small
@@ -1047,7 +1035,7 @@ class TestEmbeddingDimensionValidation:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = embedding_result
# Act
@@ -1100,7 +1088,7 @@ class TestEmbeddingDimensionValidation:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = embedding_result
# Act
@@ -1186,7 +1174,7 @@ class TestEmbeddingDimensionValidation:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
model_instance_ada.invoke_text_embedding.return_value = result_ada
model_instance_cohere.invoke_text_embedding.return_value = result_cohere
@@ -1284,7 +1272,7 @@ class TestEmbeddingEdgeCases:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = embedding_result
# Act
@@ -1327,7 +1315,7 @@ class TestEmbeddingEdgeCases:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = embedding_result
# Act
@@ -1375,7 +1363,7 @@ class TestEmbeddingEdgeCases:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = embedding_result
# Act
@@ -1427,7 +1415,7 @@ class TestEmbeddingEdgeCases:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = embedding_result
# Act
@@ -1483,7 +1471,7 @@ class TestEmbeddingEdgeCases:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = embedding_result
# Act
@@ -1551,7 +1539,7 @@ class TestEmbeddingEdgeCases:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = embedding_result
# Act
@@ -1649,7 +1637,7 @@ class TestEmbeddingEdgeCases:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = embedding_result
# Act
@@ -1728,7 +1716,7 @@ class TestEmbeddingCachePerformance:
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
# First call: cache miss
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
usage = EmbeddingUsage(
tokens=5,
@@ -1756,7 +1744,7 @@ class TestEmbeddingCachePerformance:
assert len(result1) == 1
# Arrange - Second call: cache hit
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_cached_embedding
mock_session.scalar.return_value = mock_cached_embedding
# Act - Second call (cache hit)
result2 = cache_embedding.embed_documents([text])
@@ -1816,7 +1804,7 @@ class TestEmbeddingCachePerformance:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
# Mock model to return appropriate batch results
batch_results = [

View File

@@ -405,35 +405,36 @@ class TestNotionMetadataAndCredentialMethods:
class FakeDocumentModel:
data_source_info = "data_source_info"
id = "id"
update_calls = []
execute_calls = []
class FakeQuery:
def filter_by(self, **kwargs):
class FakeUpdateStmt:
def where(self, *args):
return self
def update(self, payload):
update_calls.append(payload)
def values(self, **kwargs):
return self
class FakeSession:
committed = False
def query(self, model):
assert model is FakeDocumentModel
return FakeQuery()
def execute(self, stmt):
execute_calls.append(stmt)
def commit(self):
self.committed = True
fake_db = SimpleNamespace(session=FakeSession())
monkeypatch.setattr(notion_extractor, "DocumentModel", FakeDocumentModel)
monkeypatch.setattr(notion_extractor, "update", lambda model: FakeUpdateStmt())
monkeypatch.setattr(notion_extractor, "db", fake_db)
monkeypatch.setattr(extractor, "get_notion_last_edited_time", lambda: "2026-01-01T00:00:00.000Z")
doc_model = SimpleNamespace(id="doc-1", data_source_info_dict={"source": "notion"})
extractor.update_last_edited_time(doc_model)
assert update_calls
assert execute_calls
assert fake_db.session.committed is True
def test_get_notion_last_edited_time_uses_page_and_database_urls(self, mocker: MockerFixture):

View File

@@ -188,10 +188,10 @@ class TestParagraphIndexProcessor:
mock_keyword_cls.return_value.add_texts.assert_called_once_with(docs)
def test_clean_deletes_summaries_and_vector(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None:
segment_query = Mock()
segment_query.filter.return_value.all.return_value = [SimpleNamespace(id="seg-1")]
scalars_result = Mock()
scalars_result.all.return_value = [SimpleNamespace(id="seg-1")]
session = Mock()
session.query.return_value = segment_query
session.scalars.return_value = scalars_result
with (
patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session),
@@ -531,10 +531,10 @@ class TestParagraphIndexProcessor:
size=1,
key="key",
)
query = Mock()
query.where.return_value.all.return_value = [image_upload, non_image_upload]
scalars_result = Mock()
scalars_result.all.return_value = [image_upload, non_image_upload]
session = Mock()
session.query.return_value = query
session.scalars.return_value = scalars_result
with (
patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session),
@@ -565,10 +565,10 @@ class TestParagraphIndexProcessor:
size=1,
key="key",
)
query = Mock()
query.where.return_value.all.return_value = [image_upload]
scalars_result = Mock()
scalars_result.all.return_value = [image_upload]
session = Mock()
session.query.return_value = query
session.scalars.return_value = scalars_result
with (
patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session),

View File

@@ -208,11 +208,7 @@ class TestParentChildIndexProcessor:
vector.create_multimodal.assert_called_once_with(multimodal_docs)
def test_clean_with_precomputed_child_ids(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None:
delete_query = Mock()
where_query = Mock()
where_query.delete.return_value = 2
session = Mock()
session.query.return_value.where.return_value = where_query
with (
patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls,
@@ -227,16 +223,16 @@ class TestParentChildIndexProcessor:
)
vector.delete_by_ids.assert_called_once_with(["child-1", "child-2"])
where_query.delete.assert_called_once_with(synchronize_session=False)
session.execute.assert_called()
session.commit.assert_called_once()
def test_clean_queries_child_ids_when_not_precomputed(
self, processor: ParentChildIndexProcessor, dataset: Mock
) -> None:
child_query = Mock()
child_query.join.return_value.where.return_value.all.return_value = [("child-1",), (None,), ("child-2",)]
execute_result = Mock()
execute_result.all.return_value = [("child-1",), (None,), ("child-2",)]
session = Mock()
session.query.return_value = child_query
session.execute.return_value = execute_result
with (
patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls,
@@ -248,10 +244,7 @@ class TestParentChildIndexProcessor:
vector.delete_by_ids.assert_called_once_with(["child-1", "child-2"])
def test_clean_dataset_wide_cleanup(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None:
where_query = Mock()
where_query.delete.return_value = 3
session = Mock()
session.query.return_value.where.return_value = where_query
with (
patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls,
@@ -261,7 +254,7 @@ class TestParentChildIndexProcessor:
processor.clean(dataset, None, delete_child_chunks=True)
vector.delete.assert_called_once()
where_query.delete.assert_called_once_with(synchronize_session=False)
session.execute.assert_called()
session.commit.assert_called_once()
def test_clean_deletes_summaries_when_requested(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None:

View File

@@ -133,10 +133,10 @@ class TestBaseIndexProcessor:
upload_b = SimpleNamespace(id="bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", name="b.png")
upload_tool = SimpleNamespace(id="tool-upload-id", name="tool.png")
upload_remote = SimpleNamespace(id="remote-upload-id", name="remote.png")
db_query = Mock()
db_query.where.return_value.all.return_value = [upload_a, upload_b, upload_tool, upload_remote]
scalars_result = Mock()
scalars_result.all.return_value = [upload_a, upload_b, upload_tool, upload_remote]
db_session = Mock()
db_session.query.return_value = db_query
db_session.scalars.return_value = scalars_result
with (
patch.object(processor, "_extract_markdown_images", return_value=images),
@@ -170,10 +170,10 @@ class TestBaseIndexProcessor:
def test_get_content_files_ignores_missing_upload_records(self, processor: _ForwardingBaseIndexProcessor) -> None:
document = Document(page_content="ignored", metadata={"document_id": "doc-1", "dataset_id": "ds-1"})
images = ["/files/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/image-preview"]
db_query = Mock()
db_query.where.return_value.all.return_value = []
scalars_result = Mock()
scalars_result.all.return_value = []
db_session = Mock()
db_session.query.return_value = db_query
db_session.scalars.return_value = scalars_result
with (
patch.object(processor, "_extract_markdown_images", return_value=images),
@@ -259,20 +259,16 @@ class TestBaseIndexProcessor:
assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None
def test_download_tool_file_returns_none_when_not_found(self, processor: _ForwardingBaseIndexProcessor) -> None:
db_query = Mock()
db_query.where.return_value.first.return_value = None
db_session = Mock()
db_session.query.return_value = db_query
db_session.get.return_value = None
with patch("core.rag.index_processor.index_processor_base.db.session", db_session):
assert processor._download_tool_file("tool-id", current_user=Mock()) is None
def test_download_tool_file_uploads_file_when_found(self, processor: _ForwardingBaseIndexProcessor) -> None:
tool_file = SimpleNamespace(file_key="k1", name="tool.png", mimetype="image/png")
db_query = Mock()
db_query.where.return_value.first.return_value = tool_file
db_session = Mock()
db_session.query.return_value = db_query
db_session.get.return_value = tool_file
mock_db = Mock()
mock_db.session = db_session
mock_db.engine = Mock()

View File

@@ -473,12 +473,10 @@ class TestRerankModelRunnerMultimodal:
metadata={},
provider="external",
)
query = Mock()
query.where.return_value.first.return_value = SimpleNamespace(key="image-key")
rerank_result = RerankResult(model="rerank-model", docs=[])
with (
patch("core.rag.rerank.rerank_model.db.session.query", return_value=query),
patch("core.rag.rerank.rerank_model.db.session.get", return_value=SimpleNamespace(key="image-key")),
patch("core.rag.rerank.rerank_model.storage.load_once", return_value=b"image-bytes") as mock_load_once,
patch.object(
rerank_runner,
@@ -504,12 +502,10 @@ class TestRerankModelRunnerMultimodal:
metadata={"doc_id": "img-missing", "doc_type": DocType.IMAGE},
provider="dify",
)
query = Mock()
query.where.return_value.first.return_value = None
rerank_result = RerankResult(model="rerank-model", docs=[])
with (
patch("core.rag.rerank.rerank_model.db.session.query", return_value=query),
patch("core.rag.rerank.rerank_model.db.session.get", return_value=None),
patch.object(
rerank_runner,
"fetch_text_rerank",
@@ -533,8 +529,6 @@ class TestRerankModelRunnerMultimodal:
metadata={"doc_id": "txt-1", "doc_type": DocType.TEXT},
provider="dify",
)
query_chain = Mock()
query_chain.where.return_value.first.return_value = SimpleNamespace(key="query-image-key")
rerank_result = RerankResult(
model="rerank-model",
docs=[RerankDocument(index=0, text="text-content", score=0.77)],
@@ -542,7 +536,7 @@ class TestRerankModelRunnerMultimodal:
mock_model_instance.invoke_multimodal_rerank.return_value = rerank_result
session = MagicMock()
session.query.return_value = query_chain
session.get.return_value = SimpleNamespace(key="query-image-key")
with (
patch("core.rag.rerank.rerank_model.db.session", session),
patch("core.rag.rerank.rerank_model.storage.load_once", return_value=b"query-image-bytes"),
@@ -563,10 +557,7 @@ class TestRerankModelRunnerMultimodal:
assert "user" not in invoke_kwargs
def test_fetch_multimodal_rerank_raises_when_query_image_not_found(self, rerank_runner):
query_chain = Mock()
query_chain.where.return_value.first.return_value = None
with patch("core.rag.rerank.rerank_model.db.session.query", return_value=query_chain):
with patch("core.rag.rerank.rerank_model.db.session.get", return_value=None):
with pytest.raises(ValueError, match="Upload file not found for query"):
rerank_runner.fetch_multimodal_rerank(
query="missing-upload-id",

View File

@@ -3971,11 +3971,10 @@ class TestDatasetRetrievalAdditionalHelpers:
)
def test_get_metadata_filter_condition(self, retrieval: DatasetRetrieval) -> None:
db_query = Mock()
db_query.where.return_value = db_query
db_query.all.return_value = [SimpleNamespace(dataset_id="d1", id="doc-1")]
scalars_result = Mock()
scalars_result.all.return_value = [SimpleNamespace(dataset_id="d1", id="doc-1")]
with patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query):
with patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=scalars_result):
mapping, condition = retrieval.get_metadata_filter_condition(
dataset_ids=["d1"],
query="python",
@@ -3991,7 +3990,7 @@ class TestDatasetRetrievalAdditionalHelpers:
automatic_filters = [{"condition": "contains", "metadata_name": "author", "value": "Alice"}]
with (
patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query),
patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=scalars_result),
patch.object(retrieval, "_automatic_metadata_filter_func", return_value=automatic_filters),
):
mapping, condition = retrieval.get_metadata_filter_condition(
@@ -4012,7 +4011,7 @@ class TestDatasetRetrievalAdditionalHelpers:
logical_operator="and",
conditions=[AppCondition(name="author", comparison_operator="contains", value="{{name}}")],
)
with patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query):
with patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=scalars_result):
mapping, condition = retrieval.get_metadata_filter_condition(
dataset_ids=["d1"],
query="python",
@@ -4027,7 +4026,7 @@ class TestDatasetRetrievalAdditionalHelpers:
assert condition is not None
assert condition.conditions[0].value == "Alice"
with patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query):
with patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=scalars_result):
with pytest.raises(ValueError, match="Invalid metadata filtering mode"):
retrieval.get_metadata_filter_condition(
dataset_ids=["d1"],

Some files were not shown because too many files have changed in this diff Show More