Compare commits

..

1 Commits

Author SHA1 Message Date
autofix-ci[bot]
096661b874 [autofix.ci] apply automated fixes 2026-01-07 09:41:56 +00:00
510 changed files with 8108 additions and 26756 deletions

View File

@@ -5,18 +5,5 @@
"typescript-lsp@claude-plugins-official": true,
"pyright-lsp@claude-plugins-official": true,
"ralph-loop@claude-plugins-official": true
},
"hooks": {
"PreToolUse": [
{
"matcher": "Bash",
"hooks": [
{
"type": "command",
"command": "npx -y block-no-verify@1.1.1"
}
]
}
]
}
}

View File

@@ -39,6 +39,12 @@ jobs:
- name: Install dependencies
run: uv sync --project api --dev
- name: Run pyrefly check
run: |
cd api
uv add --dev pyrefly
uv run pyrefly check || true
- name: Run dify config tests
run: uv run --project api dev/pytest/pytest_config_tests.py

View File

@@ -1,29 +0,0 @@
name: Deploy HITL
on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- "feat/hitl-frontend"
- "feat/hitl-backend"
types:
- completed
jobs:
deploy:
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success' &&
(
github.event.workflow_run.head_branch == 'feat/hitl-frontend' ||
github.event.workflow_run.head_branch == 'feat/hitl-backend'
)
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v0.1.8
with:
host: ${{ secrets.HITL_SSH_HOST }}
username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }}
script: |
${{ vars.SSH_SCRIPT || secrets.SSH_SCRIPT }}

View File

@@ -1,4 +1,4 @@
name: Deploy Agent Dev
name: Deploy Trigger Dev
permissions:
contents: read
@@ -7,7 +7,7 @@ on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- "deploy/agent-dev"
- "deploy/trigger-dev"
types:
- completed
@@ -16,12 +16,12 @@ jobs:
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/agent-dev'
github.event.workflow_run.head_branch == 'deploy/trigger-dev'
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v0.1.8
with:
host: ${{ secrets.AGENT_DEV_SSH_HOST }}
host: ${{ secrets.TRIGGER_SSH_HOST }}
username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }}
script: |

View File

@@ -0,0 +1,94 @@
name: Translate i18n Files Based on English
on:
push:
branches: [main]
paths:
- 'web/i18n/en-US/*.json'
workflow_dispatch:
permissions:
contents: write
pull-requests: write
jobs:
check-and-update:
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
defaults:
run:
working-directory: web
steps:
# Keep use old checkout action version for https://github.com/peter-evans/create-pull-request/issues/4272
- uses: actions/checkout@v4
with:
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- name: Check for file changes in i18n/en-US
id: check_files
run: |
# Skip check for manual trigger, translate all files
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
echo "FILES_CHANGED=true" >> $GITHUB_ENV
echo "FILE_ARGS=" >> $GITHUB_ENV
echo "Manual trigger: translating all files"
else
git fetch origin "${{ github.event.before }}" || true
git fetch origin "${{ github.sha }}" || true
changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.json')
echo "Changed files: $changed_files"
if [ -n "$changed_files" ]; then
echo "FILES_CHANGED=true" >> $GITHUB_ENV
file_args=""
for file in $changed_files; do
filename=$(basename "$file" .json)
file_args="$file_args --file $filename"
done
echo "FILE_ARGS=$file_args" >> $GITHUB_ENV
echo "File arguments: $file_args"
else
echo "FILES_CHANGED=false" >> $GITHUB_ENV
fi
fi
- name: Install pnpm
uses: pnpm/action-setup@v4
with:
package_json_file: web/package.json
run_install: false
- name: Set up Node.js
if: env.FILES_CHANGED == 'true'
uses: actions/setup-node@v6
with:
node-version: 'lts/*'
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Install dependencies
if: env.FILES_CHANGED == 'true'
working-directory: ./web
run: pnpm install --frozen-lockfile
- name: Generate i18n translations
if: env.FILES_CHANGED == 'true'
working-directory: ./web
run: pnpm run i18n:gen ${{ env.FILE_ARGS }}
- name: Create Pull Request
if: env.FILES_CHANGED == 'true'
uses: peter-evans/create-pull-request@v6
with:
token: ${{ secrets.GITHUB_TOKEN }}
commit-message: 'chore(i18n): update translations based on en-US changes'
title: 'chore(i18n): translate i18n files based on en-US changes'
body: |
This PR was automatically created to update i18n translation files based on changes in en-US locale.
**Triggered by:** ${{ github.sha }}
**Changes included:**
- Updated translation files for all locales
branch: chore/automated-i18n-updates-${{ github.sha }}
delete-branch: true

View File

@@ -1,421 +0,0 @@
name: Translate i18n Files with Claude Code
# Note: claude-code-action doesn't support push events directly.
# Push events are handled by trigger-i18n-sync.yml which sends repository_dispatch.
# See: https://github.com/langgenius/dify/issues/30743
on:
repository_dispatch:
types: [i18n-sync]
workflow_dispatch:
inputs:
files:
description: 'Specific files to translate (space-separated, e.g., "app common"). Leave empty for all files.'
required: false
type: string
languages:
description: 'Specific languages to translate (space-separated, e.g., "zh-Hans ja-JP"). Leave empty for all supported languages.'
required: false
type: string
mode:
description: 'Sync mode: incremental (only changes) or full (re-check all keys)'
required: false
default: 'incremental'
type: choice
options:
- incremental
- full
permissions:
contents: write
pull-requests: write
jobs:
translate:
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
timeout-minutes: 60
steps:
- name: Checkout repository
uses: actions/checkout@v6
with:
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- name: Configure Git
run: |
git config --global user.name "github-actions[bot]"
git config --global user.email "github-actions[bot]@users.noreply.github.com"
- name: Install pnpm
uses: pnpm/action-setup@v4
with:
package_json_file: web/package.json
run_install: false
- name: Set up Node.js
uses: actions/setup-node@v6
with:
node-version: 'lts/*'
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Detect changed files and generate diff
id: detect_changes
run: |
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
# Manual trigger
if [ -n "${{ github.event.inputs.files }}" ]; then
echo "CHANGED_FILES=${{ github.event.inputs.files }}" >> $GITHUB_OUTPUT
else
# Get all JSON files in en-US directory
files=$(ls web/i18n/en-US/*.json 2>/dev/null | xargs -n1 basename | sed 's/.json$//' | tr '\n' ' ')
echo "CHANGED_FILES=$files" >> $GITHUB_OUTPUT
fi
echo "TARGET_LANGS=${{ github.event.inputs.languages }}" >> $GITHUB_OUTPUT
echo "SYNC_MODE=${{ github.event.inputs.mode || 'incremental' }}" >> $GITHUB_OUTPUT
# For manual trigger with incremental mode, get diff from last commit
# For full mode, we'll do a complete check anyway
if [ "${{ github.event.inputs.mode }}" == "full" ]; then
echo "Full mode: will check all keys" > /tmp/i18n-diff.txt
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
else
git diff HEAD~1..HEAD -- 'web/i18n/en-US/*.json' > /tmp/i18n-diff.txt 2>/dev/null || echo "" > /tmp/i18n-diff.txt
if [ -s /tmp/i18n-diff.txt ]; then
echo "DIFF_AVAILABLE=true" >> $GITHUB_OUTPUT
else
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
fi
fi
elif [ "${{ github.event_name }}" == "repository_dispatch" ]; then
# Triggered by push via trigger-i18n-sync.yml workflow
# Validate required payload fields
if [ -z "${{ github.event.client_payload.changed_files }}" ]; then
echo "Error: repository_dispatch payload missing required 'changed_files' field" >&2
exit 1
fi
echo "CHANGED_FILES=${{ github.event.client_payload.changed_files }}" >> $GITHUB_OUTPUT
echo "TARGET_LANGS=" >> $GITHUB_OUTPUT
echo "SYNC_MODE=${{ github.event.client_payload.sync_mode || 'incremental' }}" >> $GITHUB_OUTPUT
# Decode the base64-encoded diff from the trigger workflow
if [ -n "${{ github.event.client_payload.diff_base64 }}" ]; then
if ! echo "${{ github.event.client_payload.diff_base64 }}" | base64 -d > /tmp/i18n-diff.txt 2>&1; then
echo "Warning: Failed to decode base64 diff payload" >&2
echo "" > /tmp/i18n-diff.txt
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
elif [ -s /tmp/i18n-diff.txt ]; then
echo "DIFF_AVAILABLE=true" >> $GITHUB_OUTPUT
else
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
fi
else
echo "" > /tmp/i18n-diff.txt
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
fi
else
echo "Unsupported event type: ${{ github.event_name }}"
exit 1
fi
# Truncate diff if too large (keep first 50KB)
if [ -f /tmp/i18n-diff.txt ]; then
head -c 50000 /tmp/i18n-diff.txt > /tmp/i18n-diff-truncated.txt
mv /tmp/i18n-diff-truncated.txt /tmp/i18n-diff.txt
fi
echo "Detected files: $(cat $GITHUB_OUTPUT | grep CHANGED_FILES || echo 'none')"
- name: Run Claude Code for Translation Sync
if: steps.detect_changes.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@v1
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}
prompt: |
You are a professional i18n synchronization engineer for the Dify project.
Your task is to keep all language translations in sync with the English source (en-US).
## CRITICAL TOOL RESTRICTIONS
- Use **Read** tool to read files (NOT cat or bash)
- Use **Edit** tool to modify JSON files (NOT node, jq, or bash scripts)
- Use **Bash** ONLY for: git commands, gh commands, pnpm commands
- Run bash commands ONE BY ONE, never combine with && or ||
- NEVER use `$()` command substitution - it's not supported. Split into separate commands instead.
## WORKING DIRECTORY & ABSOLUTE PATHS
Claude Code sandbox working directory may vary. Always use absolute paths:
- For pnpm: `pnpm --dir ${{ github.workspace }}/web <command>`
- For git: `git -C ${{ github.workspace }} <command>`
- For gh: `gh --repo ${{ github.repository }} <command>`
- For file paths: `${{ github.workspace }}/web/i18n/`
## EFFICIENCY RULES
- **ONE Edit per language file** - batch all key additions into a single Edit
- Insert new keys at the beginning of JSON (after `{`), lint:fix will sort them
- Translate ALL keys for a language mentally first, then do ONE Edit
## Context
- Changed/target files: ${{ steps.detect_changes.outputs.CHANGED_FILES }}
- Target languages (empty means all supported): ${{ steps.detect_changes.outputs.TARGET_LANGS }}
- Sync mode: ${{ steps.detect_changes.outputs.SYNC_MODE }}
- Translation files are located in: ${{ github.workspace }}/web/i18n/{locale}/{filename}.json
- Language configuration is in: ${{ github.workspace }}/web/i18n-config/languages.ts
- Git diff is available: ${{ steps.detect_changes.outputs.DIFF_AVAILABLE }}
## CRITICAL DESIGN: Verify First, Then Sync
You MUST follow this three-phase approach:
═══════════════════════════════════════════════════════════════
║ PHASE 1: VERIFY - Analyze and Generate Change Report ║
═══════════════════════════════════════════════════════════════
### Step 1.1: Analyze Git Diff (for incremental mode)
Use the Read tool to read `/tmp/i18n-diff.txt` to see the git diff.
Parse the diff to categorize changes:
- Lines with `+` (not `+++`): Added or modified values
- Lines with `-` (not `---`): Removed or old values
- Identify specific keys for each category:
* ADD: Keys that appear only in `+` lines (new keys)
* UPDATE: Keys that appear in both `-` and `+` lines (value changed)
* DELETE: Keys that appear only in `-` lines (removed keys)
### Step 1.2: Read Language Configuration
Use the Read tool to read `${{ github.workspace }}/web/i18n-config/languages.ts`.
Extract all languages with `supported: true`.
### Step 1.3: Run i18n:check for Each Language
```bash
pnpm --dir ${{ github.workspace }}/web install --frozen-lockfile
```
```bash
pnpm --dir ${{ github.workspace }}/web run i18n:check
```
This will report:
- Missing keys (need to ADD)
- Extra keys (need to DELETE)
### Step 1.4: Generate Change Report
Create a structured report identifying:
```
╔══════════════════════════════════════════════════════════════╗
║ I18N SYNC CHANGE REPORT ║
╠══════════════════════════════════════════════════════════════╣
║ Files to process: [list] ║
║ Languages to sync: [list] ║
╠══════════════════════════════════════════════════════════════╣
║ ADD (New Keys): ║
║ - [filename].[key]: "English value" ║
║ ... ║
╠══════════════════════════════════════════════════════════════╣
║ UPDATE (Modified Keys - MUST re-translate): ║
║ - [filename].[key]: "Old value" → "New value" ║
║ ... ║
╠══════════════════════════════════════════════════════════════╣
║ DELETE (Extra Keys): ║
║ - [language]/[filename].[key] ║
║ ... ║
╚══════════════════════════════════════════════════════════════╝
```
**IMPORTANT**: For UPDATE detection, compare git diff to find keys where
the English value changed. These MUST be re-translated even if target
language already has a translation (it's now stale!).
═══════════════════════════════════════════════════════════════
║ PHASE 2: SYNC - Execute Changes Based on Report ║
═══════════════════════════════════════════════════════════════
### Step 2.1: Process ADD Operations (BATCH per language file)
**CRITICAL WORKFLOW for efficiency:**
1. First, translate ALL new keys for ALL languages mentally
2. Then, for EACH language file, do ONE Edit operation:
- Read the file once
- Insert ALL new keys at the beginning (right after the opening `{`)
- Don't worry about alphabetical order - lint:fix will sort them later
Example Edit (adding 3 keys to zh-Hans/app.json):
```
old_string: '{\n "accessControl"'
new_string: '{\n "newKey1": "translation1",\n "newKey2": "translation2",\n "newKey3": "translation3",\n "accessControl"'
```
**IMPORTANT**:
- ONE Edit per language file (not one Edit per key!)
- Always use the Edit tool. NEVER use bash scripts, node, or jq.
### Step 2.2: Process UPDATE Operations
**IMPORTANT: Special handling for zh-Hans and ja-JP**
If zh-Hans or ja-JP files were ALSO modified in the same push:
- Run: `git -C ${{ github.workspace }} diff HEAD~1 --name-only` and check for zh-Hans or ja-JP files
- If found, it means someone manually translated them. Apply these rules:
1. **Missing keys**: Still ADD them (completeness required)
2. **Existing translations**: Compare with the NEW English value:
- If translation is **completely wrong** or **unrelated** → Update it
- If translation is **roughly correct** (captures the meaning) → Keep it, respect manual work
- When in doubt, **keep the manual translation**
Example:
- English changed: "Save" → "Save Changes"
- Manual translation: "保存更改" → Keep it (correct meaning)
- Manual translation: "删除" → Update it (completely wrong)
For other languages:
Use Edit tool to replace the old value with the new translation.
You can batch multiple updates in one Edit if they are adjacent.
### Step 2.3: Process DELETE Operations
For extra keys reported by i18n:check:
- Run: `pnpm --dir ${{ github.workspace }}/web run i18n:check --auto-remove`
- Or manually remove from target language JSON files
## Translation Guidelines
- PRESERVE all placeholders exactly as-is:
- `{{variable}}` - Mustache interpolation
- `${variable}` - Template literal
- `<tag>content</tag>` - HTML tags
- `_one`, `_other` - Pluralization suffixes (these are KEY suffixes, not values)
- Use appropriate language register (formal/informal) based on existing translations
- Match existing translation style in each language
- Technical terms: check existing conventions per language
- For CJK languages: no spaces between characters unless necessary
- For RTL languages (ar-TN, fa-IR): ensure proper text handling
## Output Format Requirements
- Alphabetical key ordering (if original file uses it)
- 2-space indentation
- Trailing newline at end of file
- Valid JSON (use proper escaping for special characters)
═══════════════════════════════════════════════════════════════
║ PHASE 3: RE-VERIFY - Confirm All Issues Resolved ║
═══════════════════════════════════════════════════════════════
### Step 3.1: Run Lint Fix (IMPORTANT!)
```bash
pnpm --dir ${{ github.workspace }}/web lint:fix --quiet -- 'i18n/**/*.json'
```
This ensures:
- JSON keys are sorted alphabetically (jsonc/sort-keys rule)
- Valid i18n keys (dify-i18n/valid-i18n-keys rule)
- No extra keys (dify-i18n/no-extra-keys rule)
### Step 3.2: Run Final i18n Check
```bash
pnpm --dir ${{ github.workspace }}/web run i18n:check
```
### Step 3.3: Fix Any Remaining Issues
If check reports issues:
- Go back to PHASE 2 for unresolved items
- Repeat until check passes
### Step 3.4: Generate Final Summary
```
╔══════════════════════════════════════════════════════════════╗
║ SYNC COMPLETED SUMMARY ║
╠══════════════════════════════════════════════════════════════╣
║ Language │ Added │ Updated │ Deleted │ Status ║
╠══════════════════════════════════════════════════════════════╣
║ zh-Hans │ 5 │ 2 │ 1 │ ✓ Complete ║
║ ja-JP │ 5 │ 2 │ 1 │ ✓ Complete ║
║ ... │ ... │ ... │ ... │ ... ║
╠══════════════════════════════════════════════════════════════╣
║ i18n:check │ PASSED - All keys in sync ║
╚══════════════════════════════════════════════════════════════╝
```
## Mode-Specific Behavior
**SYNC_MODE = "incremental"** (default):
- Focus on keys identified from git diff
- Also check i18n:check output for any missing/extra keys
- Efficient for small changes
**SYNC_MODE = "full"**:
- Compare ALL keys between en-US and each language
- Run i18n:check to identify all discrepancies
- Use for first-time sync or fixing historical issues
## Important Notes
1. Always run i18n:check BEFORE and AFTER making changes
2. The check script is the source of truth for missing/extra keys
3. For UPDATE scenario: git diff is the source of truth for changed values
4. Create a single commit with all translation changes
5. If any translation fails, continue with others and report failures
═══════════════════════════════════════════════════════════════
║ PHASE 4: COMMIT AND CREATE PR ║
═══════════════════════════════════════════════════════════════
After all translations are complete and verified:
### Step 4.1: Check for changes
```bash
git -C ${{ github.workspace }} status --porcelain
```
If there are changes:
### Step 4.2: Create a new branch and commit
Run these git commands ONE BY ONE (not combined with &&).
**IMPORTANT**: Do NOT use `$()` command substitution. Use two separate commands:
1. First, get the timestamp:
```bash
date +%Y%m%d-%H%M%S
```
(Note the output, e.g., "20260115-143052")
2. Then create branch using the timestamp value:
```bash
git -C ${{ github.workspace }} checkout -b chore/i18n-sync-20260115-143052
```
(Replace "20260115-143052" with the actual timestamp from step 1)
3. Stage changes:
```bash
git -C ${{ github.workspace }} add web/i18n/
```
4. Commit:
```bash
git -C ${{ github.workspace }} commit -m "chore(i18n): sync translations with en-US - Mode: ${{ steps.detect_changes.outputs.SYNC_MODE }}"
```
5. Push:
```bash
git -C ${{ github.workspace }} push origin HEAD
```
### Step 4.3: Create Pull Request
```bash
gh pr create --repo ${{ github.repository }} --title "chore(i18n): sync translations with en-US" --body "## Summary
This PR was automatically generated to sync i18n translation files.
### Changes
- Mode: ${{ steps.detect_changes.outputs.SYNC_MODE }}
- Files processed: ${{ steps.detect_changes.outputs.CHANGED_FILES }}
### Verification
- [x] \`i18n:check\` passed
- [x] \`lint:fix\` applied
🤖 Generated with Claude Code GitHub Action" --base main
```
claude_args: |
--max-turns 150
--allowedTools "Read,Write,Edit,Bash(git *),Bash(git:*),Bash(gh *),Bash(gh:*),Bash(pnpm *),Bash(pnpm:*),Bash(date *),Bash(date:*),Glob,Grep"

View File

@@ -1,66 +0,0 @@
name: Trigger i18n Sync on Push
# This workflow bridges the push event to repository_dispatch
# because claude-code-action doesn't support push events directly.
# See: https://github.com/langgenius/dify/issues/30743
on:
push:
branches: [main]
paths:
- 'web/i18n/en-US/*.json'
permissions:
contents: write
jobs:
trigger:
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
timeout-minutes: 5
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Detect changed files and generate diff
id: detect
run: |
BEFORE_SHA="${{ github.event.before }}"
# Handle edge case: force push may have null/zero SHA
if [ -z "$BEFORE_SHA" ] || [ "$BEFORE_SHA" = "0000000000000000000000000000000000000000" ]; then
BEFORE_SHA="HEAD~1"
fi
# Detect changed i18n files
changed=$(git diff --name-only "$BEFORE_SHA" "${{ github.sha }}" -- 'web/i18n/en-US/*.json' 2>/dev/null | xargs -n1 basename 2>/dev/null | sed 's/.json$//' | tr '\n' ' ' || echo "")
echo "changed_files=$changed" >> $GITHUB_OUTPUT
# Generate diff for context
git diff "$BEFORE_SHA" "${{ github.sha }}" -- 'web/i18n/en-US/*.json' > /tmp/i18n-diff.txt 2>/dev/null || echo "" > /tmp/i18n-diff.txt
# Truncate if too large (keep first 50KB to match receiving workflow)
head -c 50000 /tmp/i18n-diff.txt > /tmp/i18n-diff-truncated.txt
mv /tmp/i18n-diff-truncated.txt /tmp/i18n-diff.txt
# Base64 encode the diff for safe JSON transport (portable, single-line)
diff_base64=$(base64 < /tmp/i18n-diff.txt | tr -d '\n')
echo "diff_base64=$diff_base64" >> $GITHUB_OUTPUT
if [ -n "$changed" ]; then
echo "has_changes=true" >> $GITHUB_OUTPUT
echo "Detected changed files: $changed"
else
echo "has_changes=false" >> $GITHUB_OUTPUT
echo "No i18n changes detected"
fi
- name: Trigger i18n sync workflow
if: steps.detect.outputs.has_changes == 'true'
uses: peter-evans/repository-dispatch@v3
with:
token: ${{ secrets.GITHUB_TOKEN }}
event-type: i18n-sync
client-payload: '{"changed_files": "${{ steps.detect.outputs.changed_files }}", "diff_base64": "${{ steps.detect.outputs.diff_base64 }}", "sync_mode": "incremental", "trigger_sha": "${{ github.sha }}"}'

1
.nvmrc Normal file
View File

@@ -0,0 +1 @@
22.11.0

View File

@@ -589,7 +589,6 @@ ENABLE_CLEAN_UNUSED_DATASETS_TASK=false
ENABLE_CREATE_TIDB_SERVERLESS_TASK=false
ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK=false
ENABLE_CLEAN_MESSAGES=false
ENABLE_WORKFLOW_RUN_CLEANUP_TASK=false
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
ENABLE_DATASETS_QUEUE_MONITOR=false
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true

View File

@@ -1,5 +1,4 @@
import base64
import datetime
import json
import logging
import secrets
@@ -35,7 +34,7 @@ from libs.rsa import generate_key_pair
from models import Tenant
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
from models.provider import Provider, ProviderModel
from models.provider_ids import DatasourceProviderID, ToolProviderID
@@ -46,7 +45,6 @@ from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpi
from services.plugin.data_migration import PluginDataMigration
from services.plugin.plugin_migration import PluginMigration
from services.plugin.plugin_service import PluginService
from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup
from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
logger = logging.getLogger(__name__)
@@ -64,10 +62,8 @@ def reset_password(email, new_password, password_confirm):
if str(new_password).strip() != str(password_confirm).strip():
click.echo(click.style("Passwords do not match.", fg="red"))
return
normalized_email = email.strip().lower()
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
account = session.query(Account).where(Account.email == email).one_or_none()
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
@@ -88,7 +84,7 @@ def reset_password(email, new_password, password_confirm):
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
AccountService.reset_login_error_rate_limit(normalized_email)
AccountService.reset_login_error_rate_limit(email)
click.echo(click.style("Password reset successfully.", fg="green"))
@@ -104,22 +100,20 @@ def reset_email(email, new_email, email_confirm):
if str(new_email).strip() != str(email_confirm).strip():
click.echo(click.style("New emails do not match.", fg="red"))
return
normalized_new_email = new_email.strip().lower()
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
account = session.query(Account).where(Account.email == email).one_or_none()
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
try:
email_validate(normalized_new_email)
email_validate(new_email)
except:
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
return
account.email = normalized_new_email
account.email = new_email
click.echo(click.style("Email updated successfully.", fg="green"))
@@ -664,7 +658,7 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
return
# Create account
email = email.strip().lower()
email = email.strip()
if "@" not in email:
click.echo(click.style("Invalid email address.", fg="red"))
@@ -858,61 +852,6 @@ def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[
click.echo(click.style("Clear free plan tenant expired logs completed.", fg="green"))
@click.command("clean-workflow-runs", help="Clean expired workflow runs and related data for free tenants.")
@click.option("--days", default=30, show_default=True, help="Delete workflow runs created before N days ago.")
@click.option("--batch-size", default=200, show_default=True, help="Batch size for selecting workflow runs.")
@click.option(
"--start-from",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
default=None,
help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.",
)
@click.option(
"--end-before",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
default=None,
help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.",
)
@click.option(
"--dry-run",
is_flag=True,
help="Preview cleanup results without deleting any workflow run data.",
)
def clean_workflow_runs(
days: int,
batch_size: int,
start_from: datetime.datetime | None,
end_before: datetime.datetime | None,
dry_run: bool,
):
"""
Clean workflow runs and related workflow data for free tenants.
"""
if (start_from is None) ^ (end_before is None):
raise click.UsageError("--start-from and --end-before must be provided together.")
start_time = datetime.datetime.now(datetime.UTC)
click.echo(click.style(f"Starting workflow run cleanup at {start_time.isoformat()}.", fg="white"))
WorkflowRunCleanup(
days=days,
batch_size=batch_size,
start_from=start_from,
end_before=end_before,
dry_run=dry_run,
).run()
end_time = datetime.datetime.now(datetime.UTC)
elapsed = end_time - start_time
click.echo(
click.style(
f"Workflow run cleanup completed. start={start_time.isoformat()} "
f"end={end_time.isoformat()} duration={elapsed}",
fg="green",
)
)
@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.")
@click.command("clear-orphaned-file-records", help="Clear orphaned file records.")
def clear_orphaned_file_records(force: bool):

View File

@@ -959,16 +959,6 @@ class MailConfig(BaseSettings):
default=None,
)
ENABLE_TRIAL_APP: bool = Field(
description="Enable trial app",
default=False,
)
ENABLE_EXPLORE_BANNER: bool = Field(
description="Enable explore banner",
default=False,
)
class RagEtlConfig(BaseSettings):
"""
@@ -1111,10 +1101,6 @@ class CeleryScheduleTasksConfig(BaseSettings):
description="Enable clean messages task",
default=False,
)
ENABLE_WORKFLOW_RUN_CLEANUP_TASK: bool = Field(
description="Enable scheduled workflow run cleanup task",
default=False,
)
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field(
description="Enable mail clean document notify task",
default=False,

View File

@@ -8,11 +8,6 @@ class HostedCreditConfig(BaseSettings):
default="",
)
HOSTED_POOL_CREDITS: int = Field(
description="Pool credits for hosted service",
default=200,
)
def get_model_credits(self, model_name: str) -> int:
"""
Get credit value for a specific model name.
@@ -65,46 +60,19 @@ class HostedOpenAiConfig(BaseSettings):
HOSTED_OPENAI_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for trial access",
default="gpt-4,"
"gpt-4-turbo-preview,"
"gpt-4-turbo-2024-04-09,"
"gpt-4-1106-preview,"
"gpt-4-0125-preview,"
"gpt-4-turbo,"
"gpt-4.1,"
"gpt-4.1-2025-04-14,"
"gpt-4.1-mini,"
"gpt-4.1-mini-2025-04-14,"
"gpt-4.1-nano,"
"gpt-4.1-nano-2025-04-14,"
"gpt-3.5-turbo,"
default="gpt-3.5-turbo,"
"gpt-3.5-turbo-1106,"
"gpt-3.5-turbo-instruct,"
"gpt-3.5-turbo-16k,"
"gpt-3.5-turbo-16k-0613,"
"gpt-3.5-turbo-1106,"
"gpt-3.5-turbo-0613,"
"gpt-3.5-turbo-0125,"
"gpt-3.5-turbo-instruct,"
"text-davinci-003,"
"chatgpt-4o-latest,"
"gpt-4o,"
"gpt-4o-2024-05-13,"
"gpt-4o-2024-08-06,"
"gpt-4o-2024-11-20,"
"gpt-4o-audio-preview,"
"gpt-4o-audio-preview-2025-06-03,"
"gpt-4o-mini,"
"gpt-4o-mini-2024-07-18,"
"o3-mini,"
"o3-mini-2025-01-31,"
"gpt-5-mini-2025-08-07,"
"gpt-5-mini,"
"o4-mini,"
"o4-mini-2025-04-16,"
"gpt-5-chat-latest,"
"gpt-5,"
"gpt-5-2025-08-07,"
"gpt-5-nano,"
"gpt-5-nano-2025-08-07",
"text-davinci-003",
)
HOSTED_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
description="Quota limit for hosted OpenAI service usage",
default=200,
)
HOSTED_OPENAI_PAID_ENABLED: bool = Field(
@@ -119,13 +87,6 @@ class HostedOpenAiConfig(BaseSettings):
"gpt-4-turbo-2024-04-09,"
"gpt-4-1106-preview,"
"gpt-4-0125-preview,"
"gpt-4-turbo,"
"gpt-4.1,"
"gpt-4.1-2025-04-14,"
"gpt-4.1-mini,"
"gpt-4.1-mini-2025-04-14,"
"gpt-4.1-nano,"
"gpt-4.1-nano-2025-04-14,"
"gpt-3.5-turbo,"
"gpt-3.5-turbo-16k,"
"gpt-3.5-turbo-16k-0613,"
@@ -133,150 +94,7 @@ class HostedOpenAiConfig(BaseSettings):
"gpt-3.5-turbo-0613,"
"gpt-3.5-turbo-0125,"
"gpt-3.5-turbo-instruct,"
"text-davinci-003,"
"chatgpt-4o-latest,"
"gpt-4o,"
"gpt-4o-2024-05-13,"
"gpt-4o-2024-08-06,"
"gpt-4o-2024-11-20,"
"gpt-4o-audio-preview,"
"gpt-4o-audio-preview-2025-06-03,"
"gpt-4o-mini,"
"gpt-4o-mini-2024-07-18,"
"o3-mini,"
"o3-mini-2025-01-31,"
"gpt-5-mini-2025-08-07,"
"gpt-5-mini,"
"o4-mini,"
"o4-mini-2025-04-16,"
"gpt-5-chat-latest,"
"gpt-5,"
"gpt-5-2025-08-07,"
"gpt-5-nano,"
"gpt-5-nano-2025-08-07",
)
class HostedGeminiConfig(BaseSettings):
"""
Configuration for fetching Gemini service
"""
HOSTED_GEMINI_API_KEY: str | None = Field(
description="API key for hosted Gemini service",
default=None,
)
HOSTED_GEMINI_API_BASE: str | None = Field(
description="Base URL for hosted Gemini API",
default=None,
)
HOSTED_GEMINI_API_ORGANIZATION: str | None = Field(
description="Organization ID for hosted Gemini service",
default=None,
)
HOSTED_GEMINI_TRIAL_ENABLED: bool = Field(
description="Enable trial access to hosted Gemini service",
default=False,
)
HOSTED_GEMINI_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for trial access",
default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,",
)
HOSTED_GEMINI_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted gemini service",
default=False,
)
HOSTED_GEMINI_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,",
)
class HostedXAIConfig(BaseSettings):
"""
Configuration for fetching XAI service
"""
HOSTED_XAI_API_KEY: str | None = Field(
description="API key for hosted XAI service",
default=None,
)
HOSTED_XAI_API_BASE: str | None = Field(
description="Base URL for hosted XAI API",
default=None,
)
HOSTED_XAI_API_ORGANIZATION: str | None = Field(
description="Organization ID for hosted XAI service",
default=None,
)
HOSTED_XAI_TRIAL_ENABLED: bool = Field(
description="Enable trial access to hosted XAI service",
default=False,
)
HOSTED_XAI_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for trial access",
default="grok-3,grok-3-mini,grok-3-mini-fast",
)
HOSTED_XAI_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted XAI service",
default=False,
)
HOSTED_XAI_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="grok-3,grok-3-mini,grok-3-mini-fast",
)
class HostedDeepseekConfig(BaseSettings):
"""
Configuration for fetching Deepseek service
"""
HOSTED_DEEPSEEK_API_KEY: str | None = Field(
description="API key for hosted Deepseek service",
default=None,
)
HOSTED_DEEPSEEK_API_BASE: str | None = Field(
description="Base URL for hosted Deepseek API",
default=None,
)
HOSTED_DEEPSEEK_API_ORGANIZATION: str | None = Field(
description="Organization ID for hosted Deepseek service",
default=None,
)
HOSTED_DEEPSEEK_TRIAL_ENABLED: bool = Field(
description="Enable trial access to hosted Deepseek service",
default=False,
)
HOSTED_DEEPSEEK_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for trial access",
default="deepseek-chat,deepseek-reasoner",
)
HOSTED_DEEPSEEK_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted Deepseek service",
default=False,
)
HOSTED_DEEPSEEK_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="deepseek-chat,deepseek-reasoner",
"text-davinci-003",
)
@@ -326,66 +144,16 @@ class HostedAnthropicConfig(BaseSettings):
default=False,
)
HOSTED_ANTHROPIC_QUOTA_LIMIT: NonNegativeInt = Field(
description="Quota limit for hosted Anthropic service usage",
default=600000,
)
HOSTED_ANTHROPIC_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted Anthropic service",
default=False,
)
HOSTED_ANTHROPIC_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="claude-opus-4-20250514,"
"claude-sonnet-4-20250514,"
"claude-3-5-haiku-20241022,"
"claude-3-opus-20240229,"
"claude-3-7-sonnet-20250219,"
"claude-3-haiku-20240307",
)
HOSTED_ANTHROPIC_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="claude-opus-4-20250514,"
"claude-sonnet-4-20250514,"
"claude-3-5-haiku-20241022,"
"claude-3-opus-20240229,"
"claude-3-7-sonnet-20250219,"
"claude-3-haiku-20240307",
)
class HostedTongyiConfig(BaseSettings):
"""
Configuration for hosted Tongyi service
"""
HOSTED_TONGYI_API_KEY: str | None = Field(
description="API key for hosted Tongyi service",
default=None,
)
HOSTED_TONGYI_USE_INTERNATIONAL_ENDPOINT: bool = Field(
description="Use international endpoint for hosted Tongyi service",
default=False,
)
HOSTED_TONGYI_TRIAL_ENABLED: bool = Field(
description="Enable trial access to hosted Tongyi service",
default=False,
)
HOSTED_TONGYI_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted Anthropic service",
default=False,
)
HOSTED_TONGYI_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for trial access",
default="",
)
HOSTED_TONGYI_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="",
)
class HostedMinmaxConfig(BaseSettings):
"""
@@ -478,13 +246,9 @@ class HostedServiceConfig(
HostedOpenAiConfig,
HostedSparkConfig,
HostedZhipuAIConfig,
HostedTongyiConfig,
# moderation
HostedModerationConfig,
# credit config
HostedCreditConfig,
HostedGeminiConfig,
HostedXAIConfig,
HostedDeepseekConfig,
):
pass

View File

@@ -107,12 +107,10 @@ from .datasets.rag_pipeline import (
# Import explore controllers
from .explore import (
banner,
installed_app,
parameter,
recommended_app,
saved_message,
trial,
)
# Import tag controllers
@@ -147,7 +145,6 @@ __all__ = [
"apikey",
"app",
"audio",
"banner",
"billing",
"bp",
"completion",
@@ -201,7 +198,6 @@ __all__ = [
"statistic",
"tags",
"tool_providers",
"trial",
"trigger_providers",
"version",
"website",

View File

@@ -15,7 +15,7 @@ from controllers.console.wraps import only_edition_cloud
from core.db.session_factory import session_factory
from extensions.ext_database import db
from libs.token import extract_access_token
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
from models.model import App, InstalledApp, RecommendedApp
P = ParamSpec("P")
R = TypeVar("R")
@@ -32,8 +32,6 @@ class InsertExploreAppPayload(BaseModel):
language: str = Field(...)
category: str = Field(...)
position: int = Field(...)
can_trial: bool = Field(default=False)
trial_limit: int = Field(default=0)
@field_validator("language")
@classmethod
@@ -41,33 +39,11 @@ class InsertExploreAppPayload(BaseModel):
return supported_language(value)
class InsertExploreBannerPayload(BaseModel):
category: str = Field(...)
title: str = Field(...)
description: str = Field(...)
img_src: str = Field(..., alias="img-src")
language: str = Field(default="en-US")
link: str = Field(...)
sort: int = Field(...)
@field_validator("language")
@classmethod
def validate_language(cls, value: str) -> str:
return supported_language(value)
model_config = {"populate_by_name": True}
console_ns.schema_model(
InsertExploreAppPayload.__name__,
InsertExploreAppPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
InsertExploreBannerPayload.__name__,
InsertExploreBannerPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
def admin_required(view: Callable[P, R]):
@wraps(view)
@@ -133,20 +109,6 @@ class InsertExploreAppListApi(Resource):
)
db.session.add(recommended_app)
if payload.can_trial:
trial_app = db.session.execute(
select(TrialApp).where(TrialApp.app_id == payload.app_id)
).scalar_one_or_none()
if not trial_app:
db.session.add(
TrialApp(
app_id=payload.app_id,
tenant_id=app.tenant_id,
trial_limit=payload.trial_limit,
)
)
else:
trial_app.trial_limit = payload.trial_limit
app.is_public = True
db.session.commit()
@@ -161,20 +123,6 @@ class InsertExploreAppListApi(Resource):
recommended_app.category = payload.category
recommended_app.position = payload.position
if payload.can_trial:
trial_app = db.session.execute(
select(TrialApp).where(TrialApp.app_id == payload.app_id)
).scalar_one_or_none()
if not trial_app:
db.session.add(
TrialApp(
app_id=payload.app_id,
tenant_id=app.tenant_id,
trial_limit=payload.trial_limit,
)
)
else:
trial_app.trial_limit = payload.trial_limit
app.is_public = True
db.session.commit()
@@ -220,62 +168,7 @@ class InsertExploreAppApi(Resource):
for installed_app in installed_apps:
session.delete(installed_app)
trial_app = session.execute(
select(TrialApp).where(TrialApp.app_id == recommended_app.app_id)
).scalar_one_or_none()
if trial_app:
session.delete(trial_app)
db.session.delete(recommended_app)
db.session.commit()
return {"result": "success"}, 204
@console_ns.route("/admin/insert-explore-banner")
class InsertExploreBannerApi(Resource):
@console_ns.doc("insert_explore_banner")
@console_ns.doc(description="Insert an explore banner")
@console_ns.expect(console_ns.models[InsertExploreBannerPayload.__name__])
@console_ns.response(201, "Banner inserted successfully")
@only_edition_cloud
@admin_required
def post(self):
payload = InsertExploreBannerPayload.model_validate(console_ns.payload)
content = {
"category": payload.category,
"title": payload.title,
"description": payload.description,
"img-src": payload.img_src,
}
banner = ExporleBanner(
content=content,
link=payload.link,
sort=payload.sort,
language=payload.language,
)
db.session.add(banner)
db.session.commit()
return {"result": "success"}, 201
@console_ns.route("/admin/insert-explore-banner/<uuid:banner_id>")
class DeleteExploreBannerApi(Resource):
@console_ns.doc("delete_explore_banner")
@console_ns.doc(description="Delete an explore banner")
@console_ns.doc(params={"banner_id": "Banner ID to delete"})
@console_ns.response(204, "Banner deleted successfully")
@only_edition_cloud
@admin_required
def delete(self, banner_id):
banner = db.session.execute(select(ExporleBanner).where(ExporleBanner.id == banner_id)).scalar_one_or_none()
if not banner:
raise NotFound(f"Banner '{banner_id}' is not found")
db.session.delete(banner)
db.session.commit()
return {"result": "success"}, 204

View File

@@ -272,7 +272,6 @@ class AnnotationExportApi(Resource):
@account_initialization_required
@edit_permission_required
def get(self, app_id):
app_id = str(app_id)
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
response_data = {"data": marshal(annotation_list, annotation_fields)}
@@ -360,6 +359,7 @@ class AnnotationBatchImportApi(Resource):
file.seek(0, 2) # Seek to end of file
file_size = file.tell()
file.seek(0) # Reset to beginning
max_size_bytes = dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT * 1024 * 1024
if file_size > max_size_bytes:
abort(

View File

@@ -115,9 +115,3 @@ class InvokeRateLimitError(BaseHTTPException):
error_code = "rate_limit_error"
description = "Rate Limit Error"
code = 429
class NeedAddIdsError(BaseHTTPException):
error_code = "need_add_ids"
description = "Need to add ids."
code = 400

View File

@@ -23,11 +23,6 @@ def _load_app_model(app_id: str) -> App | None:
return app_model
def _load_app_model_with_trial(app_id: str) -> App | None:
app_model = db.session.query(App).where(App.id == app_id, App.status == "normal").first()
return app_model
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func: Callable[P1, R1]):
@wraps(view_func)
@@ -67,44 +62,3 @@ def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, li
return decorator
else:
return decorator(view)
def get_app_model_with_trial(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func: Callable[P, R]):
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters")
app_id = kwargs.get("app_id")
app_id = str(app_id)
del kwargs["app_id"]
app_model = _load_app_model_with_trial(app_id)
if not app_model:
raise AppNotFoundError()
app_mode = AppMode.value_of(app_model.mode)
if mode is not None:
if isinstance(mode, list):
modes = mode
else:
modes = [mode]
if app_mode not in modes:
mode_values = {m.value for m in modes}
raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}")
kwargs["app_model"] = app_model
return view_func(*args, **kwargs)
return decorated_view
if view is None:
return decorator
else:
return decorator(view)

View File

@@ -63,9 +63,10 @@ class ActivateCheckApi(Resource):
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
workspaceId = args.workspace_id
reg_email = args.email
token = args.token
invitation = RegisterService.get_invitation_with_case_fallback(workspaceId, args.email, token)
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
if invitation:
data = invitation.get("data", {})
tenant = invitation.get("tenant", None)
@@ -99,12 +100,11 @@ class ActivateApi(Resource):
def post(self):
args = ActivatePayload.model_validate(console_ns.payload)
normalized_request_email = args.email.lower() if args.email else None
invitation = RegisterService.get_invitation_with_case_fallback(args.workspace_id, args.email, args.token)
invitation = RegisterService.get_invitation_if_token_valid(args.workspace_id, args.email, args.token)
if invitation is None:
raise AlreadyActivateError()
RegisterService.revoke_token(args.workspace_id, normalized_request_email, args.token)
RegisterService.revoke_token(args.workspace_id, args.email, args.token)
account = invitation["account"]
account.name = args.name

View File

@@ -1,6 +1,7 @@
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
@@ -61,7 +62,6 @@ class EmailRegisterSendEmailApi(Resource):
@email_register_enabled
def post(self):
args = EmailRegisterSendPayload.model_validate(console_ns.payload)
normalized_email = args.email.lower()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
@@ -70,12 +70,13 @@ class EmailRegisterSendEmailApi(Resource):
if args.language in languages:
language = args.language
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
raise AccountInFreezeError()
with Session(db.engine) as session:
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
token = None
token = AccountService.send_email_register_email(email=args.email, account=account, language=language)
return {"result": "success", "data": token}
@@ -87,9 +88,9 @@ class EmailRegisterCheckApi(Resource):
def post(self):
args = EmailRegisterValidityPayload.model_validate(console_ns.payload)
user_email = args.email.lower()
user_email = args.email
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(user_email)
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email)
if is_email_register_error_rate_limit:
raise EmailRegisterLimitError()
@@ -97,14 +98,11 @@ class EmailRegisterCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
token_email = token_data.get("email")
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
if user_email != normalized_token_email:
if user_email != token_data.get("email"):
raise InvalidEmailError()
if args.code != token_data.get("code"):
AccountService.add_email_register_error_rate_limit(user_email)
AccountService.add_email_register_error_rate_limit(args.email)
raise EmailCodeError()
# Verified, revoke the first token
@@ -115,8 +113,8 @@ class EmailRegisterCheckApi(Resource):
user_email, code=args.code, additional_data={"phase": "register"}
)
AccountService.reset_email_register_error_rate_limit(user_email)
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
AccountService.reset_email_register_error_rate_limit(args.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@console_ns.route("/email-register")
@@ -143,23 +141,22 @@ class EmailRegisterResetApi(Resource):
AccountService.revoke_email_register_token(args.token)
email = register_data.get("email", "")
normalized_email = email.lower()
with Session(db.engine) as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
if account:
raise EmailAlreadyInUseError()
else:
account = self._create_new_account(normalized_email, args.password_confirm)
account = self._create_new_account(email, args.password_confirm)
if not account:
raise AccountNotFoundError()
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(normalized_email)
AccountService.reset_login_error_rate_limit(email)
return {"result": "success", "data": token_pair.model_dump()}
def _create_new_account(self, email: str, password: str) -> Account | None:
def _create_new_account(self, email, password) -> Account | None:
# Create new account if allowed
account = None
try:

View File

@@ -4,6 +4,7 @@ import secrets
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.console import console_ns
@@ -20,6 +21,7 @@ from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password
from models import Account
from services.account_service import AccountService, TenantService
from services.feature_service import FeatureService
@@ -74,7 +76,6 @@ class ForgotPasswordSendEmailApi(Resource):
@email_password_login_enabled
def post(self):
args = ForgotPasswordSendPayload.model_validate(console_ns.payload)
normalized_email = args.email.lower()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
@@ -86,11 +87,11 @@ class ForgotPasswordSendEmailApi(Resource):
language = "en-US"
with Session(db.engine) as session:
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
token = AccountService.send_reset_password_email(
account=account,
email=normalized_email,
email=args.email,
language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register,
)
@@ -121,9 +122,9 @@ class ForgotPasswordCheckApi(Resource):
def post(self):
args = ForgotPasswordCheckPayload.model_validate(console_ns.payload)
user_email = args.email.lower()
user_email = args.email
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email)
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email)
if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError()
@@ -131,16 +132,11 @@ class ForgotPasswordCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
token_email = token_data.get("email")
if not isinstance(token_email, str):
raise InvalidEmailError()
normalized_token_email = token_email.lower()
if user_email != normalized_token_email:
if user_email != token_data.get("email"):
raise InvalidEmailError()
if args.code != token_data.get("code"):
AccountService.add_forgot_password_error_rate_limit(user_email)
AccountService.add_forgot_password_error_rate_limit(args.email)
raise EmailCodeError()
# Verified, revoke the first token
@@ -148,11 +144,11 @@ class ForgotPasswordCheckApi(Resource):
# Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token(
token_email, code=args.code, additional_data={"phase": "reset"}
user_email, code=args.code, additional_data={"phase": "reset"}
)
AccountService.reset_forgot_password_error_rate_limit(user_email)
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
AccountService.reset_forgot_password_error_rate_limit(args.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@console_ns.route("/forgot-password/resets")
@@ -191,8 +187,9 @@ class ForgotPasswordResetApi(Resource):
password_hashed = hash_password(args.new_password, salt)
email = reset_data.get("email", "")
with Session(db.engine) as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
if account:
self._update_existing_account(account, password_hashed, salt, session)

View File

@@ -90,38 +90,32 @@ class LoginApi(Resource):
def post(self):
"""Authenticate user and login."""
args = LoginPayload.model_validate(console_ns.payload)
request_email = args.email
normalized_email = request_email.lower()
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
raise AccountInFreezeError()
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(normalized_email)
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args.email)
if is_login_error_rate_limit:
raise EmailPasswordLoginLimitError()
invite_token = args.invite_token
invitation_data: dict[str, Any] | None = None
if invite_token:
invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token)
if invitation_data is None:
invite_token = None
if args.invite_token:
invitation_data = RegisterService.get_invitation_if_token_valid(None, args.email, args.invite_token)
try:
if invitation_data:
data = invitation_data.get("data", {})
invitee_email = data.get("email") if data else None
invitee_email_normalized = invitee_email.lower() if isinstance(invitee_email, str) else invitee_email
if invitee_email_normalized != normalized_email:
if invitee_email != args.email:
raise InvalidEmailError()
account = _authenticate_account_with_case_fallback(
request_email, normalized_email, args.password, invite_token
)
account = AccountService.authenticate(args.email, args.password, args.invite_token)
else:
account = AccountService.authenticate(args.email, args.password)
except services.errors.account.AccountLoginError:
raise AccountBannedError()
except services.errors.account.AccountPasswordError as exc:
AccountService.add_login_error_rate_limit(normalized_email)
raise AuthenticationFailedError() from exc
except services.errors.account.AccountPasswordError:
AccountService.add_login_error_rate_limit(args.email)
raise AuthenticationFailedError()
# SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account)
if len(tenants) == 0:
@@ -136,7 +130,7 @@ class LoginApi(Resource):
}
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(normalized_email)
AccountService.reset_login_error_rate_limit(args.email)
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
@@ -176,19 +170,18 @@ class ResetPasswordSendEmailApi(Resource):
@console_ns.expect(console_ns.models[EmailPayload.__name__])
def post(self):
args = EmailPayload.model_validate(console_ns.payload)
normalized_email = args.email.lower()
if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
try:
account = _get_account_with_case_fallback(args.email)
account = AccountService.get_user_through_email(args.email)
except AccountRegisterError:
raise AccountInFreezeError()
token = AccountService.send_reset_password_email(
email=normalized_email,
email=args.email,
account=account,
language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register,
@@ -203,7 +196,6 @@ class EmailCodeLoginSendEmailApi(Resource):
@console_ns.expect(console_ns.models[EmailPayload.__name__])
def post(self):
args = EmailPayload.model_validate(console_ns.payload)
normalized_email = args.email.lower()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
@@ -214,13 +206,13 @@ class EmailCodeLoginSendEmailApi(Resource):
else:
language = "en-US"
try:
account = _get_account_with_case_fallback(args.email)
account = AccountService.get_user_through_email(args.email)
except AccountRegisterError:
raise AccountInFreezeError()
if account is None:
if FeatureService.get_system_features().is_allow_register:
token = AccountService.send_email_code_login_email(email=normalized_email, language=language)
token = AccountService.send_email_code_login_email(email=args.email, language=language)
else:
raise AccountNotFound()
else:
@@ -237,17 +229,14 @@ class EmailCodeLoginApi(Resource):
def post(self):
args = EmailCodeLoginPayload.model_validate(console_ns.payload)
original_email = args.email
user_email = original_email.lower()
user_email = args.email
language = args.language
token_data = AccountService.get_email_code_login_data(args.token)
if token_data is None:
raise InvalidTokenError()
token_email = token_data.get("email")
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
if normalized_token_email != user_email:
if token_data["email"] != args.email:
raise InvalidEmailError()
if token_data["code"] != args.code:
@@ -255,7 +244,7 @@ class EmailCodeLoginApi(Resource):
AccountService.revoke_email_code_login_token(args.token)
try:
account = _get_account_with_case_fallback(original_email)
account = AccountService.get_user_through_email(user_email)
except AccountRegisterError:
raise AccountInFreezeError()
if account:
@@ -286,7 +275,7 @@ class EmailCodeLoginApi(Resource):
except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(user_email)
AccountService.reset_login_error_rate_limit(args.email)
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
@@ -320,22 +309,3 @@ class RefreshTokenApi(Resource):
return response
except Exception as e:
return {"result": "fail", "message": str(e)}, 401
def _get_account_with_case_fallback(email: str):
account = AccountService.get_user_through_email(email)
if account or email == email.lower():
return account
return AccountService.get_user_through_email(email.lower())
def _authenticate_account_with_case_fallback(
original_email: str, normalized_email: str, password: str, invite_token: str | None
):
try:
return AccountService.authenticate(original_email, password, invite_token)
except services.errors.account.AccountPasswordError:
if original_email == normalized_email:
raise
return AccountService.authenticate(normalized_email, password, invite_token)

View File

@@ -3,6 +3,7 @@ import logging
import httpx
from flask import current_app, redirect, request
from flask_restx import Resource
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized
@@ -117,10 +118,7 @@ class OAuthCallback(Resource):
invitation = RegisterService.get_invitation_by_token(token=invite_token)
if invitation:
invitation_email = invitation.get("email", None)
invitation_email_normalized = (
invitation_email.lower() if isinstance(invitation_email, str) else invitation_email
)
if invitation_email_normalized != user_info.email.lower():
if invitation_email != user_info.email:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.")
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
@@ -161,7 +159,10 @@ class OAuthCallback(Resource):
ip_address=extract_remote_ip(request),
)
response = redirect(f"{dify_config.CONSOLE_WEB_URL}?oauth_new_user={str(oauth_new_user).lower()}")
base_url = dify_config.CONSOLE_WEB_URL
query_char = "&" if "?" in base_url else "?"
target_url = f"{base_url}{query_char}oauth_new_user={str(oauth_new_user).lower()}"
response = redirect(target_url)
set_access_token_to_cookie(request, response, token_pair.access_token)
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
@@ -174,7 +175,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
if not account:
with Session(db.engine) as session:
account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session)
account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none()
return account
@@ -196,10 +197,9 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account,
tenant_was_created.send(new_tenant)
if not account:
normalized_email = user_info.email.lower()
oauth_new_user = True
if not FeatureService.get_system_features().is_allow_register:
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email):
raise AccountRegisterError(
description=(
"This email account has been deleted within the past "
@@ -210,11 +210,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account,
raise AccountRegisterError(description=("Invalid email or password"))
account_name = user_info.name or "Dify"
account = RegisterService.register(
email=normalized_email,
name=account_name,
password=None,
open_id=user_info.id,
provider=provider,
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
)
# Set interface language

View File

@@ -146,7 +146,6 @@ class DatasetUpdatePayload(BaseModel):
embedding_model: str | None = None
embedding_model_provider: str | None = None
retrieval_model: dict[str, Any] | None = None
summary_index_setting: dict[str, Any] | None = None
partial_member_list: list[dict[str, str]] | None = None
external_retrieval_model: dict[str, Any] | None = None
external_knowledge_id: str | None = None

View File

@@ -39,10 +39,9 @@ from fields.document_fields import (
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.dataset import DocumentPipelineExecutionLog, DocumentSegmentSummary
from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
from tasks.generate_summary_index_task import generate_summary_index_task
from ..app.error import (
ProviderModelCurrentlyNotSupportError,
@@ -105,10 +104,6 @@ class DocumentRenamePayload(BaseModel):
name: str
class GenerateSummaryPayload(BaseModel):
document_list: list[str]
register_schema_models(
console_ns,
KnowledgeConfig,
@@ -116,7 +111,6 @@ register_schema_models(
RetrievalModel,
DocumentRetryPayload,
DocumentRenamePayload,
GenerateSummaryPayload,
)
@@ -301,97 +295,6 @@ class DatasetDocumentListApi(Resource):
paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items
# Check if dataset has summary index enabled
has_summary_index = (
dataset.summary_index_setting
and dataset.summary_index_setting.get("enable") is True
)
# Filter documents that need summary calculation
documents_need_summary = [doc for doc in documents if doc.need_summary is True]
document_ids_need_summary = [str(doc.id) for doc in documents_need_summary]
# Calculate summary_index_status for documents that need summary (only if dataset summary index is enabled)
summary_status_map = {}
if has_summary_index and document_ids_need_summary:
# Get all segments for these documents (excluding qa_model and re_segment)
segments = (
db.session.query(DocumentSegment.id, DocumentSegment.document_id)
.where(
DocumentSegment.document_id.in_(document_ids_need_summary),
DocumentSegment.status != "re_segment",
DocumentSegment.tenant_id == current_tenant_id,
)
.all()
)
# Group segments by document_id
document_segments_map = {}
for segment in segments:
doc_id = str(segment.document_id)
if doc_id not in document_segments_map:
document_segments_map[doc_id] = []
document_segments_map[doc_id].append(segment.id)
# Get all summary records for these segments
all_segment_ids = [seg.id for seg in segments]
summaries = {}
if all_segment_ids:
summary_records = (
db.session.query(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id.in_(all_segment_ids),
DocumentSegmentSummary.dataset_id == dataset_id,
DocumentSegmentSummary.enabled == True, # Only count enabled summaries
)
.all()
)
summaries = {summary.chunk_id: summary.status for summary in summary_records}
# Calculate summary_index_status for each document
for doc_id in document_ids_need_summary:
segment_ids = document_segments_map.get(doc_id, [])
if not segment_ids:
# No segments, status is "GENERATING" (waiting to generate)
summary_status_map[doc_id] = "GENERATING"
continue
# Count summary statuses for this document's segments
status_counts = {"completed": 0, "generating": 0, "error": 0, "not_started": 0}
for segment_id in segment_ids:
status = summaries.get(segment_id, "not_started")
if status in status_counts:
status_counts[status] += 1
else:
status_counts["not_started"] += 1
total_segments = len(segment_ids)
completed_count = status_counts["completed"]
generating_count = status_counts["generating"]
error_count = status_counts["error"]
# Determine overall status (only three states: GENERATING, COMPLETED, ERROR)
if completed_count == total_segments:
summary_status_map[doc_id] = "COMPLETED"
elif error_count > 0:
# Has errors (even if some are completed or generating)
summary_status_map[doc_id] = "ERROR"
elif generating_count > 0 or status_counts["not_started"] > 0:
# Still generating or not started
summary_status_map[doc_id] = "GENERATING"
else:
# Default to generating
summary_status_map[doc_id] = "GENERATING"
# Add summary_index_status to each document
for document in documents:
if has_summary_index and document.need_summary is True:
document.summary_index_status = summary_status_map.get(str(document.id), "GENERATING")
else:
# Return null if summary index is not enabled or document doesn't need summary
document.summary_index_status = None
if fetch:
for document in documents:
completed_segments = (
@@ -490,7 +393,6 @@ class DatasetDocumentListApi(Resource):
return {"result": "success"}, 204
@console_ns.route("/datasets/init")
class DatasetInitApi(Resource):
@console_ns.doc("init_dataset")
@@ -878,7 +780,6 @@ class DocumentApi(DocumentResource):
"display_status": document.display_status,
"doc_form": document.doc_form,
"doc_language": document.doc_language,
"need_summary": document.need_summary if document.need_summary is not None else False,
}
else:
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
@@ -914,7 +815,6 @@ class DocumentApi(DocumentResource):
"display_status": document.display_status,
"doc_form": document.doc_form,
"doc_language": document.doc_language,
"need_summary": document.need_summary if document.need_summary is not None else False,
}
return response, 200
@@ -1282,211 +1182,3 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
"input_data": log.input_data,
"datasource_node_id": log.datasource_node_id,
}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/generate-summary")
class DocumentGenerateSummaryApi(Resource):
@console_ns.doc("generate_summary_for_documents")
@console_ns.doc(description="Generate summary index for documents")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect(console_ns.models[GenerateSummaryPayload.__name__])
@console_ns.response(200, "Summary generation started successfully")
@console_ns.response(400, "Invalid request or dataset configuration")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id):
"""
Generate summary index for specified documents.
This endpoint checks if the dataset configuration supports summary generation
(indexing_technique must be 'high_quality' and summary_index_setting.enable must be true),
then asynchronously generates summary indexes for the provided documents.
"""
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
# Get dataset
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# Check permissions
if not current_user.is_dataset_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# Validate request payload
payload = GenerateSummaryPayload.model_validate(console_ns.payload or {})
document_list = payload.document_list
if not document_list:
raise ValueError("document_list cannot be empty.")
# Check if dataset configuration supports summary generation
if dataset.indexing_technique != "high_quality":
raise ValueError(
f"Summary generation is only available for 'high_quality' indexing technique. "
f"Current indexing technique: {dataset.indexing_technique}"
)
summary_index_setting = dataset.summary_index_setting
if not summary_index_setting or not summary_index_setting.get("enable"):
raise ValueError(
"Summary index is not enabled for this dataset. "
"Please enable it in the dataset settings."
)
# Verify all documents exist and belong to the dataset
documents = (
db.session.query(Document)
.filter(
Document.id.in_(document_list),
Document.dataset_id == dataset_id,
)
.all()
)
if len(documents) != len(document_list):
found_ids = {doc.id for doc in documents}
missing_ids = set(document_list) - found_ids
raise NotFound(f"Some documents not found: {list(missing_ids)}")
# Dispatch async tasks for each document
for document in documents:
# Skip qa_model documents as they don't generate summaries
if document.doc_form == "qa_model":
logger.info(
f"Skipping summary generation for qa_model document {document.id}"
)
continue
# Dispatch async task
generate_summary_index_task(dataset_id, document.id)
logger.info(
f"Dispatched summary generation task for document {document.id} in dataset {dataset_id}"
)
return {"result": "success"}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/summary-status")
class DocumentSummaryStatusApi(DocumentResource):
@console_ns.doc("get_document_summary_status")
@console_ns.doc(description="Get summary index generation status for a document")
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@console_ns.response(200, "Summary status retrieved successfully")
@console_ns.response(404, "Document not found")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id, document_id):
"""
Get summary index generation status for a document.
Returns:
- total_segments: Total number of segments in the document
- summary_status: Dictionary with status counts
- completed: Number of summaries completed
- generating: Number of summaries being generated
- error: Number of summaries with errors
- not_started: Number of segments without summary records
- summaries: List of summary records with status and content preview
"""
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
document_id = str(document_id)
# Get document
document = self.get_document(dataset_id, document_id)
# Get dataset
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# Check permissions
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# Get all segments for this document
segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.document_id == document_id,
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
)
.all()
)
total_segments = len(segments)
# Get all summary records for these segments
segment_ids = [segment.id for segment in segments]
summaries = []
if segment_ids:
summaries = (
db.session.query(DocumentSegmentSummary)
.filter(
DocumentSegmentSummary.document_id == document_id,
DocumentSegmentSummary.dataset_id == dataset_id,
DocumentSegmentSummary.chunk_id.in_(segment_ids),
DocumentSegmentSummary.enabled == True, # Only return enabled summaries
)
.all()
)
# Create a mapping of chunk_id to summary
summary_map = {summary.chunk_id: summary for summary in summaries}
# Count statuses
status_counts = {
"completed": 0,
"generating": 0,
"error": 0,
"not_started": 0,
}
summary_list = []
for segment in segments:
summary = summary_map.get(segment.id)
if summary:
status = summary.status
status_counts[status] = status_counts.get(status, 0) + 1
summary_list.append({
"segment_id": segment.id,
"segment_position": segment.position,
"status": summary.status,
"summary_preview": summary.summary_content[:100] + "..." if summary.summary_content and len(summary.summary_content) > 100 else summary.summary_content,
"error": summary.error,
"created_at": int(summary.created_at.timestamp()) if summary.created_at else None,
"updated_at": int(summary.updated_at.timestamp()) if summary.updated_at else None,
})
else:
status_counts["not_started"] += 1
summary_list.append({
"segment_id": segment.id,
"segment_position": segment.position,
"status": "not_started",
"summary_preview": None,
"error": None,
"created_at": None,
"updated_at": None,
})
return {
"total_segments": total_segments,
"summary_status": status_counts,
"summaries": summary_list,
}, 200

View File

@@ -32,7 +32,7 @@ from extensions.ext_redis import redis_client
from fields.segment_fields import child_chunk_fields, segment_fields
from libs.helper import escape_like_pattern
from libs.login import current_account_with_tenant, login_required
from models.dataset import ChildChunk, DocumentSegment, DocumentSegmentSummary
from models.dataset import ChildChunk, DocumentSegment
from models.model import UploadFile
from services.dataset_service import DatasetService, DocumentService, SegmentService
from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs
@@ -41,23 +41,6 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
def _get_segment_with_summary(segment, dataset_id):
"""Helper function to marshal segment and add summary information."""
segment_dict = marshal(segment, segment_fields)
# Query summary for this segment (only enabled summaries)
summary = (
db.session.query(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset_id,
DocumentSegmentSummary.enabled == True, # Only return enabled summaries
)
.first()
)
segment_dict["summary"] = summary.summary_content if summary else None
return segment_dict
class SegmentListQuery(BaseModel):
limit: int = Field(default=20, ge=1, le=100)
status: list[str] = Field(default_factory=list)
@@ -80,7 +63,6 @@ class SegmentUpdatePayload(BaseModel):
keywords: list[str] | None = None
regenerate_child_chunks: bool = False
attachment_ids: list[str] | None = None
summary: str | None = None # Summary content for summary index
class BatchImportPayload(BaseModel):
@@ -198,34 +180,8 @@ class DatasetDocumentSegmentListApi(Resource):
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
# Query summaries for all segments in this page (batch query for efficiency)
segment_ids = [segment.id for segment in segments.items]
summaries = {}
if segment_ids:
summary_records = (
db.session.query(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id.in_(segment_ids),
DocumentSegmentSummary.dataset_id == dataset_id,
)
.all()
)
# Only include enabled summaries
summaries = {
summary.chunk_id: summary.summary_content
for summary in summary_records
if summary.enabled is True
}
# Add summary to each segment
segments_with_summary = []
for segment in segments.items:
segment_dict = marshal(segment, segment_fields)
segment_dict["summary"] = summaries.get(segment.id)
segments_with_summary.append(segment_dict)
response = {
"data": segments_with_summary,
"data": marshal(segments.items, segment_fields),
"limit": limit,
"total": segments.total,
"total_pages": segments.pages,
@@ -371,7 +327,7 @@ class DatasetDocumentSegmentAddApi(Resource):
payload_dict = payload.model_dump(exclude_none=True)
SegmentService.segment_create_args_validate(payload_dict, document)
segment = SegmentService.create_segment(payload_dict, document, dataset)
return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
@@ -433,12 +389,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
payload = SegmentUpdatePayload.model_validate(console_ns.payload or {})
payload_dict = payload.model_dump(exclude_none=True)
SegmentService.segment_create_args_validate(payload_dict, document)
# Update segment (summary update with change detection is handled in SegmentService.update_segment)
segment = SegmentService.update_segment(
SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
)
return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
@setup_required
@login_required

View File

@@ -1,4 +1,4 @@
from flask_restx import Resource, fields
from flask_restx import Resource
from controllers.common.schema import register_schema_model
from libs.login import login_required
@@ -10,56 +10,17 @@ from ..wraps import (
cloud_edition_billing_rate_limit_check,
setup_required,
)
from fields.hit_testing_fields import (
child_chunk_fields,
document_fields,
files_fields,
hit_testing_record_fields,
segment_fields,
)
register_schema_model(console_ns, HitTestingPayload)
def _get_or_create_model(model_name: str, field_def):
"""Get or create a flask_restx model to avoid dict type issues in Swagger."""
existing = console_ns.models.get(model_name)
if existing is None:
existing = console_ns.model(model_name, field_def)
return existing
# Register models for flask_restx to avoid dict type issues in Swagger
document_model = _get_or_create_model("HitTestingDocument", document_fields)
segment_fields_copy = segment_fields.copy()
segment_fields_copy["document"] = fields.Nested(document_model)
segment_model = _get_or_create_model("HitTestingSegment", segment_fields_copy)
child_chunk_model = _get_or_create_model("HitTestingChildChunk", child_chunk_fields)
files_model = _get_or_create_model("HitTestingFile", files_fields)
hit_testing_record_fields_copy = hit_testing_record_fields.copy()
hit_testing_record_fields_copy["segment"] = fields.Nested(segment_model)
hit_testing_record_fields_copy["child_chunks"] = fields.List(fields.Nested(child_chunk_model))
hit_testing_record_fields_copy["files"] = fields.List(fields.Nested(files_model))
hit_testing_record_model = _get_or_create_model("HitTestingRecord", hit_testing_record_fields_copy)
# Response model for hit testing API
hit_testing_response_fields = {
"query": fields.String,
"records": fields.List(fields.Nested(hit_testing_record_model)),
}
hit_testing_response_model = _get_or_create_model("HitTestingResponse", hit_testing_response_fields)
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
class HitTestingApi(Resource, DatasetsHitTestingBase):
@console_ns.doc("test_dataset_retrieval")
@console_ns.doc(description="Test dataset knowledge retrieval")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect(console_ns.models[HitTestingPayload.__name__])
@console_ns.response(200, "Hit testing completed successfully", model=hit_testing_response_model)
@console_ns.response(200, "Hit testing completed successfully")
@console_ns.response(404, "Dataset not found")
@console_ns.response(400, "Invalid parameters")
@setup_required

View File

@@ -1,43 +0,0 @@
from flask import request
from flask_restx import Resource
from controllers.console import api
from controllers.console.explore.wraps import explore_banner_enabled
from extensions.ext_database import db
from models.model import ExporleBanner
class BannerApi(Resource):
"""Resource for banner list."""
@explore_banner_enabled
def get(self):
"""Get banner list."""
language = request.args.get("language", "en-US")
# Build base query for enabled banners
base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == "enabled")
# Try to get banners in the requested language
banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all()
# Fallback to en-US if no banners found and language is not en-US
if not banners and language != "en-US":
banners = base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort).all()
# Convert banners to serializable format
result = []
for banner in banners:
banner_data = {
"id": banner.id,
"content": banner.content, # Already parsed as JSON by SQLAlchemy
"link": banner.link,
"sort": banner.sort,
"status": banner.status,
"created_at": banner.created_at.isoformat() if banner.created_at else None,
}
result.append(banner_data)
return result
api.add_resource(BannerApi, "/explore/banners")

View File

@@ -29,25 +29,3 @@ class AppAccessDeniedError(BaseHTTPException):
error_code = "access_denied"
description = "App access denied."
code = 403
class TrialAppNotAllowed(BaseHTTPException):
"""*403* `Trial App Not Allowed`
Raise if the user has reached the trial app limit.
"""
error_code = "trial_app_not_allowed"
code = 403
description = "the app is not allowed to be trial."
class TrialAppLimitExceeded(BaseHTTPException):
"""*403* `Trial App Limit Exceeded`
Raise if the user has exceeded the trial app limit.
"""
error_code = "trial_app_limit_exceeded"
code = 403
description = "The user has exceeded the trial app limit."

View File

@@ -29,7 +29,6 @@ recommended_app_fields = {
"category": fields.String,
"position": fields.Integer,
"is_listed": fields.Boolean,
"can_trial": fields.Boolean,
}
recommended_app_list_fields = {

View File

@@ -1,512 +0,0 @@
import logging
from typing import Any, cast
from flask import request
from flask_restx import Resource, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.common.fields import Parameters as ParametersResponse
from controllers.common.fields import Site as SiteResponse
from controllers.console import api
from controllers.console.app.error import (
AppUnavailableError,
AudioTooLargeError,
CompletionRequestError,
ConversationCompletedError,
NeedAddIdsError,
NoAudioUploadedError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderNotSupportSpeechToTextError,
ProviderQuotaExceededError,
UnsupportedAudioTypeError,
)
from controllers.console.app.wraps import get_app_model_with_trial
from controllers.console.explore.error import (
AppSuggestedQuestionsAfterAnswerDisabledError,
NotChatAppError,
NotCompletionAppError,
NotWorkflowAppError,
)
from controllers.console.explore.wraps import TrialAppResource, trial_feature_enable
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from fields.app_fields import app_detail_fields_with_site
from fields.dataset_fields import dataset_fields
from fields.workflow_fields import workflow_fields
from libs import helper
from libs.helper import uuid_value
from libs.login import current_user
from models import Account
from models.account import TenantStatus
from models.model import AppMode, Site
from models.workflow import Workflow
from services.app_generate_service import AppGenerateService
from services.app_service import AppService
from services.audio_service import AudioService
from services.dataset_service import DatasetService
from services.errors.audio import (
AudioTooLargeServiceError,
NoAudioUploadedServiceError,
ProviderNotSupportSpeechToTextServiceError,
UnsupportedAudioTypeServiceError,
)
from services.errors.conversation import ConversationNotExistsError
from services.errors.llm import InvokeRateLimitError
from services.errors.message import (
MessageNotExistsError,
SuggestedQuestionsAfterAnswerDisabledError,
)
from services.message_service import MessageService
from services.recommended_app_service import RecommendedAppService
logger = logging.getLogger(__name__)
class TrialAppWorkflowRunApi(TrialAppResource):
def post(self, trial_app):
"""
Run workflow
"""
app_model = trial_app
if not app_model:
raise NotWorkflowAppError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("files", type=list, required=False, location="json")
args = parser.parse_args()
assert current_user is not None
try:
app_id = app_model.id
user_id = current_user.id
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
)
RecommendedAppService.add_trial_app_record(app_id, user_id)
return helper.compact_generate_response(response)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
class TrialAppWorkflowTaskStopApi(TrialAppResource):
def post(self, trial_app, task_id: str):
"""
Stop workflow task
"""
app_model = trial_app
if not app_model:
raise NotWorkflowAppError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
assert current_user is not None
# Stop using both mechanisms for backward compatibility
# Legacy stop flag mechanism (without user check)
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager.send_stop_command(task_id)
return {"result": "success"}
class TrialChatApi(TrialAppResource):
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, required=True, location="json")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
args = parser.parse_args()
args["auto_generate_name"] = False
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
# Get IDs before they might be detached from session
app_id = app_model.id
user_id = current_user.id
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
)
RecommendedAppService.add_trial_app_record(app_id, user_id)
return helper.compact_generate_response(response)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
class TrialMessageSuggestedQuestionApi(TrialAppResource):
@trial_feature_enable
def get(self, trial_app, message_id):
app_model = trial_app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
message_id = str(message_id)
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
)
except MessageNotExistsError:
raise NotFound("Message not found")
except ConversationNotExistsError:
raise NotFound("Conversation not found")
except SuggestedQuestionsAfterAnswerDisabledError:
raise AppSuggestedQuestionsAfterAnswerDisabledError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
return {"data": questions}
class TrialChatAudioApi(TrialAppResource):
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
file = request.files["file"]
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
# Get IDs before they might be detached from session
app_id = app_model.id
user_id = current_user.id
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None)
RecommendedAppService.add_trial_app_record(app_id, user_id)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
except AudioTooLargeServiceError as e:
raise AudioTooLargeError(str(e))
except UnsupportedAudioTypeServiceError:
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
logger.exception("internal server error.")
raise InternalServerError()
class TrialChatTextApi(TrialAppResource):
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
try:
parser = reqparse.RequestParser()
parser.add_argument("message_id", type=str, required=False, location="json")
parser.add_argument("voice", type=str, location="json")
parser.add_argument("text", type=str, location="json")
parser.add_argument("streaming", type=bool, location="json")
args = parser.parse_args()
message_id = args.get("message_id", None)
text = args.get("text", None)
voice = args.get("voice", None)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
# Get IDs before they might be detached from session
app_id = app_model.id
user_id = current_user.id
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
RecommendedAppService.add_trial_app_record(app_id, user_id)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
except AudioTooLargeServiceError as e:
raise AudioTooLargeError(str(e))
except UnsupportedAudioTypeServiceError:
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
logger.exception("internal server error.")
raise InternalServerError()
class TrialCompletionApi(TrialAppResource):
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
if app_model.mode != "completion":
raise NotCompletionAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, location="json", default="")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
args = parser.parse_args()
streaming = args["response_mode"] == "streaming"
args["auto_generate_name"] = False
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
# Get IDs before they might be detached from session
app_id = app_model.id
user_id = current_user.id
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
)
RecommendedAppService.add_trial_app_record(app_id, user_id)
return helper.compact_generate_response(response)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
class TrialSitApi(Resource):
"""Resource for trial app sites."""
@trial_feature_enable
@get_app_model_with_trial
def get(self, app_model):
"""Retrieve app site info.
Returns the site configuration for the application including theme, icons, and text.
"""
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site:
raise Forbidden()
assert app_model.tenant
if app_model.tenant.status == TenantStatus.ARCHIVE:
raise Forbidden()
return SiteResponse.model_validate(site).model_dump(mode="json")
class TrialAppParameterApi(Resource):
"""Resource for app variables."""
@trial_feature_enable
@get_app_model_with_trial
def get(self, app_model):
"""Retrieve app parameters."""
if app_model is None:
raise AppUnavailableError()
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app_model.workflow
if workflow is None:
raise AppUnavailableError()
features_dict = workflow.features_dict
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app_model.app_model_config
if app_model_config is None:
raise AppUnavailableError()
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get("user_input_form", [])
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
return ParametersResponse.model_validate(parameters).model_dump(mode="json")
class AppApi(Resource):
@trial_feature_enable
@get_app_model_with_trial
@marshal_with(app_detail_fields_with_site)
def get(self, app_model):
"""Get app detail"""
app_service = AppService()
app_model = app_service.get_app(app_model)
return app_model
class AppWorkflowApi(Resource):
@trial_feature_enable
@get_app_model_with_trial
@marshal_with(workflow_fields)
def get(self, app_model):
"""Get workflow detail"""
if not app_model.workflow_id:
raise AppUnavailableError()
workflow = (
db.session.query(Workflow)
.where(
Workflow.id == app_model.workflow_id,
)
.first()
)
return workflow
class DatasetListApi(Resource):
@trial_feature_enable
@get_app_model_with_trial
def get(self, app_model):
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
ids = request.args.getlist("ids")
tenant_id = app_model.tenant_id
if ids:
datasets, total = DatasetService.get_datasets_by_ids(ids, tenant_id)
else:
raise NeedAddIdsError()
data = cast(list[dict[str, Any]], marshal(datasets, dataset_fields))
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
return response
api.add_resource(TrialChatApi, "/trial-apps/<uuid:app_id>/chat-messages", endpoint="trial_app_chat_completion")
api.add_resource(
TrialMessageSuggestedQuestionApi,
"/trial-apps/<uuid:app_id>/messages/<uuid:message_id>/suggested-questions",
endpoint="trial_app_suggested_question",
)
api.add_resource(TrialChatAudioApi, "/trial-apps/<uuid:app_id>/audio-to-text", endpoint="trial_app_audio")
api.add_resource(TrialChatTextApi, "/trial-apps/<uuid:app_id>/text-to-audio", endpoint="trial_app_text")
api.add_resource(TrialCompletionApi, "/trial-apps/<uuid:app_id>/completion-messages", endpoint="trial_app_completion")
api.add_resource(TrialSitApi, "/trial-apps/<uuid:app_id>/site")
api.add_resource(TrialAppParameterApi, "/trial-apps/<uuid:app_id>/parameters", endpoint="trial_app_parameters")
api.add_resource(AppApi, "/trial-apps/<uuid:app_id>", endpoint="trial_app")
api.add_resource(TrialAppWorkflowRunApi, "/trial-apps/<uuid:app_id>/workflows/run", endpoint="trial_app_workflow_run")
api.add_resource(TrialAppWorkflowTaskStopApi, "/trial-apps/<uuid:app_id>/workflows/tasks/<string:task_id>/stop")
api.add_resource(AppWorkflowApi, "/trial-apps/<uuid:app_id>/workflows", endpoint="trial_app_workflow")
api.add_resource(DatasetListApi, "/trial-apps/<uuid:app_id>/datasets", endpoint="trial_app_datasets")

View File

@@ -2,15 +2,14 @@ from collections.abc import Callable
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar
from flask import abort
from flask_restx import Resource
from werkzeug.exceptions import NotFound
from controllers.console.explore.error import AppAccessDeniedError, TrialAppLimitExceeded, TrialAppNotAllowed
from controllers.console.explore.error import AppAccessDeniedError
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from models import AccountTrialAppRecord, App, InstalledApp, TrialApp
from models import InstalledApp
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
@@ -72,61 +71,6 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
return decorator
def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
def decorator(view: Callable[Concatenate[App, P], R]):
@wraps(view)
def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
current_user, _ = current_account_with_tenant()
trial_app = db.session.query(TrialApp).where(TrialApp.app_id == str(app_id)).first()
if trial_app is None:
raise TrialAppNotAllowed()
app = trial_app.app
if app is None:
raise TrialAppNotAllowed()
account_trial_app_record = (
db.session.query(AccountTrialAppRecord)
.where(AccountTrialAppRecord.account_id == current_user.id, AccountTrialAppRecord.app_id == app_id)
.first()
)
if account_trial_app_record:
if account_trial_app_record.count >= trial_app.trial_limit:
raise TrialAppLimitExceeded()
return view(app, *args, **kwargs)
return decorated
if view:
return decorator(view)
return decorator
def trial_feature_enable(view: Callable[..., R]) -> Callable[..., R]:
@wraps(view)
def decorated(*args, **kwargs):
features = FeatureService.get_system_features()
if not features.enable_trial_app:
abort(403, "Trial app feature is not enabled.")
return view(*args, **kwargs)
return decorated
def explore_banner_enabled(view: Callable[..., R]) -> Callable[..., R]:
@wraps(view)
def decorated(*args, **kwargs):
features = FeatureService.get_system_features()
if not features.enable_explore_banner:
abort(403, "Explore banner feature is not enabled.")
return view(*args, **kwargs)
return decorated
class InstalledAppResource(Resource):
# must be reversed if there are multiple decorators
@@ -136,13 +80,3 @@ class InstalledAppResource(Resource):
account_initialization_required,
login_required,
]
class TrialAppResource(Resource):
# must be reversed if there are multiple decorators
method_decorators = [
trial_app_required,
account_initialization_required,
login_required,
]

View File

@@ -84,11 +84,10 @@ class SetupApi(Resource):
raise NotInitValidateError()
args = SetupRequestPayload.model_validate(console_ns.payload)
normalized_email = args.email.lower()
# setup
RegisterService.setup(
email=normalized_email,
email=args.email,
name=args.name,
password=args.password,
ip_address=extract_remote_ip(request),

View File

@@ -41,7 +41,7 @@ from fields.member_fields import account_fields
from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
from libs.login import current_account_with_tenant, login_required
from models import AccountIntegrate, InvitationCode
from models import Account, AccountIntegrate, InvitationCode
from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
@@ -536,8 +536,7 @@ class ChangeEmailSendEmailApi(Resource):
else:
language = "en-US"
account = None
user_email = None
email_for_sending = args.email.lower()
user_email = args.email
if args.phase is not None and args.phase == "new_email":
if args.token is None:
raise InvalidTokenError()
@@ -547,24 +546,16 @@ class ChangeEmailSendEmailApi(Resource):
raise InvalidTokenError()
user_email = reset_data.get("email", "")
if user_email.lower() != current_user.email.lower():
if user_email != current_user.email:
raise InvalidEmailError()
user_email = current_user.email
else:
with Session(db.engine) as session:
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
if account is None:
raise AccountNotFound()
email_for_sending = account.email
user_email = account.email
token = AccountService.send_change_email_email(
account=account,
email=email_for_sending,
old_email=user_email,
language=language,
phase=args.phase,
account=account, email=args.email, old_email=user_email, language=language, phase=args.phase
)
return {"result": "success", "data": token}
@@ -580,9 +571,9 @@ class ChangeEmailCheckApi(Resource):
payload = console_ns.payload or {}
args = ChangeEmailValidityPayload.model_validate(payload)
user_email = args.email.lower()
user_email = args.email
is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(user_email)
is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args.email)
if is_change_email_error_rate_limit:
raise EmailChangeLimitError()
@@ -590,13 +581,11 @@ class ChangeEmailCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
token_email = token_data.get("email")
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
if user_email != normalized_token_email:
if user_email != token_data.get("email"):
raise InvalidEmailError()
if args.code != token_data.get("code"):
AccountService.add_change_email_error_rate_limit(user_email)
AccountService.add_change_email_error_rate_limit(args.email)
raise EmailCodeError()
# Verified, revoke the first token
@@ -607,8 +596,8 @@ class ChangeEmailCheckApi(Resource):
user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
)
AccountService.reset_change_email_error_rate_limit(user_email)
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
AccountService.reset_change_email_error_rate_limit(args.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@console_ns.route("/account/change-email/reset")
@@ -622,12 +611,11 @@ class ChangeEmailResetApi(Resource):
def post(self):
payload = console_ns.payload or {}
args = ChangeEmailResetPayload.model_validate(payload)
normalized_new_email = args.new_email.lower()
if AccountService.is_account_in_freeze(normalized_new_email):
if AccountService.is_account_in_freeze(args.new_email):
raise AccountInFreezeError()
if not AccountService.check_email_unique(normalized_new_email):
if not AccountService.check_email_unique(args.new_email):
raise EmailAlreadyInUseError()
reset_data = AccountService.get_change_email_data(args.token)
@@ -638,13 +626,13 @@ class ChangeEmailResetApi(Resource):
old_email = reset_data.get("old_email", "")
current_user, _ = current_account_with_tenant()
if current_user.email.lower() != old_email.lower():
if current_user.email != old_email:
raise AccountNotFound()
updated_account = AccountService.update_account_email(current_user, email=normalized_new_email)
updated_account = AccountService.update_account_email(current_user, email=args.new_email)
AccountService.send_change_email_completed_notify_email(
email=normalized_new_email,
email=args.new_email,
)
return updated_account
@@ -657,9 +645,8 @@ class CheckEmailUnique(Resource):
def post(self):
payload = console_ns.payload or {}
args = CheckEmailUniquePayload.model_validate(payload)
normalized_email = args.email.lower()
if AccountService.is_account_in_freeze(normalized_email):
if AccountService.is_account_in_freeze(args.email):
raise AccountInFreezeError()
if not AccountService.check_email_unique(normalized_email):
if not AccountService.check_email_unique(args.email):
raise EmailAlreadyInUseError()
return {"result": "success"}

View File

@@ -116,31 +116,26 @@ class MemberInviteEmailApi(Resource):
raise WorkspaceMembersLimitExceeded()
for invitee_email in invitee_emails:
normalized_invitee_email = invitee_email.lower()
try:
if not inviter.current_tenant:
raise ValueError("No current tenant")
token = RegisterService.invite_new_member(
tenant=inviter.current_tenant,
email=invitee_email,
language=interface_language,
role=invitee_role,
inviter=inviter,
inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter
)
encoded_invitee_email = parse.quote(normalized_invitee_email)
encoded_invitee_email = parse.quote(invitee_email)
invitation_results.append(
{
"status": "success",
"email": normalized_invitee_email,
"email": invitee_email,
"url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
}
)
except AccountAlreadyInTenantError:
invitation_results.append(
{"status": "success", "email": normalized_invitee_email, "url": f"{console_web_url}/signin"}
{"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"}
)
except Exception as e:
invitation_results.append({"status": "failed", "email": normalized_invitee_email, "message": str(e)})
invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)})
return {
"result": "success",

View File

@@ -1,14 +1,14 @@
import logging
from collections.abc import Mapping
from typing import Any
from flask import make_response, redirect, request
from flask_restx import Resource
from pydantic import BaseModel, model_validator
from flask_restx import Resource, reqparse
from pydantic import BaseModel, Field, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden
from configs import dify_config
from controllers.common.schema import register_schema_models
from controllers.web.error import NotFoundError
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType
@@ -35,38 +35,35 @@ from ..wraps import (
logger = logging.getLogger(__name__)
class TriggerSubscriptionBuilderCreatePayload(BaseModel):
credential_type: str = CredentialType.UNAUTHORIZED
class TriggerSubscriptionUpdateRequest(BaseModel):
"""Request payload for updating a trigger subscription"""
class TriggerSubscriptionBuilderVerifyPayload(BaseModel):
credentials: dict[str, Any]
class TriggerSubscriptionBuilderUpdatePayload(BaseModel):
name: str | None = None
parameters: dict[str, Any] | None = None
properties: dict[str, Any] | None = None
credentials: dict[str, Any] | None = None
name: str | None = Field(default=None, description="The name for the subscription")
credentials: Mapping[str, Any] | None = Field(default=None, description="The credentials for the subscription")
parameters: Mapping[str, Any] | None = Field(default=None, description="The parameters for the subscription")
properties: Mapping[str, Any] | None = Field(default=None, description="The properties for the subscription")
@model_validator(mode="after")
def check_at_least_one_field(self):
if all(v is None for v in self.model_dump().values()):
if all(v is None for v in (self.name, self.credentials, self.parameters, self.properties)):
raise ValueError("At least one of name, credentials, parameters, or properties must be provided")
return self
class TriggerOAuthClientPayload(BaseModel):
client_params: dict[str, Any] | None = None
enabled: bool | None = None
class TriggerSubscriptionVerifyRequest(BaseModel):
"""Request payload for verifying subscription credentials."""
credentials: Mapping[str, Any] = Field(description="The credentials to verify")
register_schema_models(
console_ns,
TriggerSubscriptionBuilderCreatePayload,
TriggerSubscriptionBuilderVerifyPayload,
TriggerSubscriptionBuilderUpdatePayload,
TriggerOAuthClientPayload,
console_ns.schema_model(
TriggerSubscriptionUpdateRequest.__name__,
TriggerSubscriptionUpdateRequest.model_json_schema(ref_template="#/definitions/{model}"),
)
console_ns.schema_model(
TriggerSubscriptionVerifyRequest.__name__,
TriggerSubscriptionVerifyRequest.model_json_schema(ref_template="#/definitions/{model}"),
)
@@ -135,11 +132,16 @@ class TriggerSubscriptionListApi(Resource):
raise
parser = reqparse.RequestParser().add_argument(
"credential_type", type=str, required=False, nullable=True, location="json"
)
@console_ns.route(
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/create",
)
class TriggerSubscriptionBuilderCreateApi(Resource):
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderCreatePayload.__name__])
@console_ns.expect(parser)
@setup_required
@login_required
@edit_permission_required
@@ -149,10 +151,10 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
user = current_user
assert user.current_tenant_id is not None
payload = TriggerSubscriptionBuilderCreatePayload.model_validate(console_ns.payload or {})
args = parser.parse_args()
try:
credential_type = CredentialType.of(payload.credential_type)
credential_type = CredentialType.of(args.get("credential_type") or CredentialType.UNAUTHORIZED.value)
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
tenant_id=user.current_tenant_id,
user_id=user.id,
@@ -180,11 +182,18 @@ class TriggerSubscriptionBuilderGetApi(Resource):
)
parser_api = (
reqparse.RequestParser()
# The credentials of the subscription builder
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
)
@console_ns.route(
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify-and-update/<path:subscription_builder_id>",
)
class TriggerSubscriptionBuilderVerifyApi(Resource):
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderVerifyPayload.__name__])
class TriggerSubscriptionBuilderVerifyAndUpdateApi(Resource):
@console_ns.expect(parser_api)
@setup_required
@login_required
@edit_permission_required
@@ -194,7 +203,7 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
user = current_user
assert user.current_tenant_id is not None
payload = TriggerSubscriptionBuilderVerifyPayload.model_validate(console_ns.payload or {})
args = parser_api.parse_args()
try:
# Use atomic update_and_verify to prevent race conditions
@@ -204,7 +213,7 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
credentials=payload.credentials,
credentials=args.get("credentials", None),
),
)
except Exception as e:
@@ -212,11 +221,24 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
raise ValueError(str(e)) from e
parser_update_api = (
reqparse.RequestParser()
# The name of the subscription builder
.add_argument("name", type=str, required=False, nullable=True, location="json")
# The parameters of the subscription builder
.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
# The properties of the subscription builder
.add_argument("properties", type=dict, required=False, nullable=True, location="json")
# The credentials of the subscription builder
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
)
@console_ns.route(
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/update/<path:subscription_builder_id>",
)
class TriggerSubscriptionBuilderUpdateApi(Resource):
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderUpdatePayload.__name__])
@console_ns.expect(parser_update_api)
@setup_required
@login_required
@edit_permission_required
@@ -227,7 +249,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
assert isinstance(user, Account)
assert user.current_tenant_id is not None
payload = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {})
args = parser_update_api.parse_args()
try:
return jsonable_encoder(
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
@@ -235,10 +257,10 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
name=payload.name,
parameters=payload.parameters,
properties=payload.properties,
credentials=payload.credentials,
name=args.get("name", None),
parameters=args.get("parameters", None),
properties=args.get("properties", None),
credentials=args.get("credentials", None),
),
)
)
@@ -273,7 +295,7 @@ class TriggerSubscriptionBuilderLogsApi(Resource):
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/build/<path:subscription_builder_id>",
)
class TriggerSubscriptionBuilderBuildApi(Resource):
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderUpdatePayload.__name__])
@console_ns.expect(parser_update_api)
@setup_required
@login_required
@edit_permission_required
@@ -282,7 +304,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
"""Build a subscription instance for a trigger provider"""
user = current_user
assert user.current_tenant_id is not None
payload = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {})
args = parser_update_api.parse_args()
try:
# Use atomic update_and_build to prevent race conditions
TriggerSubscriptionBuilderService.update_and_build_builder(
@@ -291,9 +313,9 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
name=payload.name,
parameters=payload.parameters,
properties=payload.properties,
name=args.get("name", None),
parameters=args.get("parameters", None),
properties=args.get("properties", None),
),
)
return 200
@@ -306,7 +328,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
"/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/update",
)
class TriggerSubscriptionUpdateApi(Resource):
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderUpdatePayload.__name__])
@console_ns.expect(console_ns.models[TriggerSubscriptionUpdateRequest.__name__])
@setup_required
@login_required
@edit_permission_required
@@ -316,7 +338,7 @@ class TriggerSubscriptionUpdateApi(Resource):
user = current_user
assert user.current_tenant_id is not None
request = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {})
request = TriggerSubscriptionUpdateRequest.model_validate(console_ns.payload)
subscription = TriggerProviderService.get_subscription_by_id(
tenant_id=user.current_tenant_id,
@@ -546,6 +568,13 @@ class TriggerOAuthCallbackApi(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
parser_oauth_client = (
reqparse.RequestParser()
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
)
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/oauth/client")
class TriggerOAuthClientManageApi(Resource):
@setup_required
@@ -593,7 +622,7 @@ class TriggerOAuthClientManageApi(Resource):
logger.exception("Error getting OAuth client", exc_info=e)
raise
@console_ns.expect(console_ns.models[TriggerOAuthClientPayload.__name__])
@console_ns.expect(parser_oauth_client)
@setup_required
@login_required
@is_admin_or_owner_required
@@ -603,15 +632,15 @@ class TriggerOAuthClientManageApi(Resource):
user = current_user
assert user.current_tenant_id is not None
payload = TriggerOAuthClientPayload.model_validate(console_ns.payload or {})
args = parser_oauth_client.parse_args()
try:
provider_id = TriggerProviderID(provider)
return TriggerProviderService.save_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
client_params=payload.client_params,
enabled=payload.enabled,
client_params=args.get("client_params"),
enabled=args.get("enabled"),
)
except ValueError as e:
@@ -647,7 +676,7 @@ class TriggerOAuthClientManageApi(Resource):
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/verify/<path:subscription_id>",
)
class TriggerSubscriptionVerifyApi(Resource):
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderVerifyPayload.__name__])
@console_ns.expect(console_ns.models[TriggerSubscriptionVerifyRequest.__name__])
@setup_required
@login_required
@edit_permission_required
@@ -657,7 +686,9 @@ class TriggerSubscriptionVerifyApi(Resource):
user = current_user
assert user.current_tenant_id is not None
verify_request = TriggerSubscriptionBuilderVerifyPayload.model_validate(console_ns.payload or {})
verify_request: TriggerSubscriptionVerifyRequest = TriggerSubscriptionVerifyRequest.model_validate(
console_ns.payload
)
try:
result = TriggerProviderService.verify_subscription_credentials(

View File

@@ -80,9 +80,6 @@ tenant_fields = {
"in_trial": fields.Boolean,
"trial_end_reason": fields.String,
"custom_config": fields.Raw(attribute="custom_config"),
"trial_credits": fields.Integer,
"trial_credits_used": fields.Integer,
"next_credit_reset_date": fields.Integer,
}
tenants_fields = {

View File

@@ -358,12 +358,14 @@ def annotation_import_rate_limit(view: Callable[P, R]):
def decorated(*args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant()
current_time = int(time.time() * 1000)
# Check per-minute rate limit
minute_key = f"annotation_import_rate_limit:{current_tenant_id}:1min"
redis_client.zadd(minute_key, {current_time: current_time})
redis_client.zremrangebyscore(minute_key, 0, current_time - 60000)
minute_count = redis_client.zcard(minute_key)
redis_client.expire(minute_key, 120) # 2 minutes TTL
if minute_count > dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE:
abort(
429,
@@ -377,6 +379,7 @@ def annotation_import_rate_limit(view: Callable[P, R]):
redis_client.zremrangebyscore(hour_key, 0, current_time - 3600000)
hour_count = redis_client.zcard(hour_key)
redis_client.expire(hour_key, 7200) # 2 hours TTL
if hour_count > dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR:
abort(
429,

View File

@@ -4,6 +4,7 @@ import secrets
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
@@ -21,7 +22,7 @@ from controllers.web import web_ns
from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password
from models.account import Account
from models import Account
from services.account_service import AccountService
@@ -69,9 +70,6 @@ class ForgotPasswordSendEmailApi(Resource):
def post(self):
payload = ForgotPasswordSendPayload.model_validate(web_ns.payload or {})
request_email = payload.email
normalized_email = request_email.lower()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
@@ -82,12 +80,12 @@ class ForgotPasswordSendEmailApi(Resource):
language = "en-US"
with Session(db.engine) as session:
account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session)
account = session.execute(select(Account).filter_by(email=payload.email)).scalar_one_or_none()
token = None
if account is None:
raise AuthenticationFailedError()
else:
token = AccountService.send_reset_password_email(account=account, email=normalized_email, language=language)
token = AccountService.send_reset_password_email(account=account, email=payload.email, language=language)
return {"result": "success", "data": token}
@@ -106,9 +104,9 @@ class ForgotPasswordCheckApi(Resource):
def post(self):
payload = ForgotPasswordCheckPayload.model_validate(web_ns.payload or {})
user_email = payload.email.lower()
user_email = payload.email
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email)
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(payload.email)
if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError()
@@ -116,16 +114,11 @@ class ForgotPasswordCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
token_email = token_data.get("email")
if not isinstance(token_email, str):
raise InvalidEmailError()
normalized_token_email = token_email.lower()
if user_email != normalized_token_email:
if user_email != token_data.get("email"):
raise InvalidEmailError()
if payload.code != token_data.get("code"):
AccountService.add_forgot_password_error_rate_limit(user_email)
AccountService.add_forgot_password_error_rate_limit(payload.email)
raise EmailCodeError()
# Verified, revoke the first token
@@ -133,11 +126,11 @@ class ForgotPasswordCheckApi(Resource):
# Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token(
token_email, code=payload.code, additional_data={"phase": "reset"}
user_email, code=payload.code, additional_data={"phase": "reset"}
)
AccountService.reset_forgot_password_error_rate_limit(user_email)
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
AccountService.reset_forgot_password_error_rate_limit(payload.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@web_ns.route("/forgot-password/resets")
@@ -181,7 +174,7 @@ class ForgotPasswordResetApi(Resource):
email = reset_data.get("email", "")
with Session(db.engine) as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
if account:
self._update_existing_account(account, password_hashed, salt, session)

View File

@@ -10,12 +10,7 @@ from controllers.console.auth.error import (
InvalidEmailError,
)
from controllers.console.error import AccountBannedError
from controllers.console.wraps import (
decrypt_code_field,
decrypt_password_field,
only_edition_enterprise,
setup_required,
)
from controllers.console.wraps import only_edition_enterprise, setup_required
from controllers.web import web_ns
from controllers.web.wraps import decode_jwt_token
from libs.helper import email
@@ -47,7 +42,6 @@ class LoginApi(Resource):
404: "Account not found",
}
)
@decrypt_password_field
def post(self):
"""Authenticate user and login."""
parser = (
@@ -187,7 +181,6 @@ class EmailCodeLoginApi(Resource):
404: "Account not found",
}
)
@decrypt_code_field
def post(self):
parser = (
reqparse.RequestParser()
@@ -197,29 +190,25 @@ class EmailCodeLoginApi(Resource):
)
args = parser.parse_args()
user_email = args["email"].lower()
user_email = args["email"]
token_data = WebAppAuthService.get_email_code_login_data(args["token"])
if token_data is None:
raise InvalidTokenError()
token_email = token_data.get("email")
if not isinstance(token_email, str):
raise InvalidEmailError()
normalized_token_email = token_email.lower()
if normalized_token_email != user_email:
if token_data["email"] != args["email"]:
raise InvalidEmailError()
if token_data["code"] != args["code"]:
raise EmailCodeError()
WebAppAuthService.revoke_email_code_login_token(args["token"])
account = WebAppAuthService.get_user_through_email(token_email)
account = WebAppAuthService.get_user_through_email(user_email)
if not account:
raise AuthenticationFailedError()
token = WebAppAuthService.login(account=account)
AccountService.reset_login_error_rate_limit(user_email)
AccountService.reset_login_error_rate_limit(args["email"])
response = make_response({"result": "success", "data": {"access_token": token}})
# set_access_token_to_cookie(request, response, token, samesite="None", httponly=False)
return response

View File

@@ -1,7 +1,6 @@
import json
import logging
import uuid
from decimal import Decimal
from typing import Union, cast
from sqlalchemy import select
@@ -42,7 +41,6 @@ from core.tools.tool_manager import ToolManager
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from factories import file_factory
from models.enums import CreatorUserRole
from models.model import Conversation, Message, MessageAgentThought, MessageFile
logger = logging.getLogger(__name__)
@@ -302,7 +300,6 @@ class BaseAgentRunner(AppRunner):
thought = MessageAgentThought(
message_id=message_id,
message_chain_id=None,
tool_process_data=None,
thought="",
tool=tool_name,
tool_labels_str="{}",
@@ -310,20 +307,20 @@ class BaseAgentRunner(AppRunner):
tool_input=tool_input,
message=message,
message_token=0,
message_unit_price=Decimal(0),
message_price_unit=Decimal("0.001"),
message_unit_price=0,
message_price_unit=0,
message_files=json.dumps(messages_ids) if messages_ids else "",
answer="",
observation="",
answer_token=0,
answer_unit_price=Decimal(0),
answer_price_unit=Decimal("0.001"),
answer_unit_price=0,
answer_price_unit=0,
tokens=0,
total_price=Decimal(0),
total_price=0,
position=self.agent_thought_count + 1,
currency="USD",
latency=0,
created_by_role=CreatorUserRole.ACCOUNT,
created_by_role="account",
created_by=self.user_id,
)
@@ -356,8 +353,7 @@ class BaseAgentRunner(AppRunner):
raise ValueError("agent thought not found")
if thought:
existing_thought = agent_thought.thought or ""
agent_thought.thought = f"{existing_thought}{thought}"
agent_thought.thought += thought
if tool_name:
agent_thought.tool = tool_name
@@ -455,30 +451,21 @@ class BaseAgentRunner(AppRunner):
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
if agent_thoughts:
for agent_thought in agent_thoughts:
tool_names_raw = agent_thought.tool
if tool_names_raw:
tool_names = tool_names_raw.split(";")
tools = agent_thought.tool
if tools:
tools = tools.split(";")
tool_calls: list[AssistantPromptMessage.ToolCall] = []
tool_call_response: list[ToolPromptMessage] = []
tool_input_payload = agent_thought.tool_input
if tool_input_payload:
try:
tool_inputs = json.loads(tool_input_payload)
except Exception:
tool_inputs = {tool: {} for tool in tool_names}
else:
tool_inputs = {tool: {} for tool in tool_names}
try:
tool_inputs = json.loads(agent_thought.tool_input)
except Exception:
tool_inputs = {tool: {} for tool in tools}
try:
tool_responses = json.loads(agent_thought.observation)
except Exception:
tool_responses = dict.fromkeys(tools, agent_thought.observation)
observation_payload = agent_thought.observation
if observation_payload:
try:
tool_responses = json.loads(observation_payload)
except Exception:
tool_responses = dict.fromkeys(tool_names, observation_payload)
else:
tool_responses = dict.fromkeys(tool_names, observation_payload)
for tool in tool_names:
for tool in tools:
# generate a uuid for tool call
tool_call_id = str(uuid.uuid4())
tool_calls.append(
@@ -508,7 +495,7 @@ class BaseAgentRunner(AppRunner):
*tool_call_response,
]
)
if not tool_names_raw:
if not tools:
result.append(AssistantPromptMessage(content=agent_thought.thought))
else:
if message.answer:

View File

@@ -1,468 +0,0 @@
import json
import logging
from collections.abc import Generator
from copy import deepcopy
from typing import Any, Union
from core.agent.base_agent_runner import BaseAgentRunner
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.file import file_manager
from core.model_runtime.entities import (
AssistantPromptMessage,
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
LLMUsage,
PromptMessage,
PromptMessageContentType,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from core.workflow.nodes.agent.exc import AgentMaxIterationError
from models.model import Message
logger = logging.getLogger(__name__)
class FunctionCallAgentRunner(BaseAgentRunner):
def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
"""
Run FunctionCall agent application
"""
self.query = query
app_generate_entity = self.application_generate_entity
app_config = self.app_config
assert app_config is not None, "app_config is required"
assert app_config.agent is not None, "app_config.agent is required"
# convert tools into ModelRuntime Tool format
tool_instances, prompt_messages_tools = self._init_prompt_tools()
assert app_config.agent
iteration_step = 1
max_iteration_steps = min(app_config.agent.max_iteration, 99) + 1
# continue to run until there is not any tool call
function_call_state = True
llm_usage: dict[str, LLMUsage | None] = {"usage": None}
final_answer = ""
prompt_messages: list = [] # Initialize prompt_messages
# get tracing instance
trace_manager = app_generate_entity.trace_manager
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMUsage):
if not final_llm_usage_dict["usage"]:
final_llm_usage_dict["usage"] = usage
else:
llm_usage = final_llm_usage_dict["usage"]
llm_usage.prompt_tokens += usage.prompt_tokens
llm_usage.completion_tokens += usage.completion_tokens
llm_usage.total_tokens += usage.total_tokens
llm_usage.prompt_price += usage.prompt_price
llm_usage.completion_price += usage.completion_price
llm_usage.total_price += usage.total_price
model_instance = self.model_instance
while function_call_state and iteration_step <= max_iteration_steps:
function_call_state = False
if iteration_step == max_iteration_steps:
# the last iteration, remove all tools
prompt_messages_tools = []
message_file_ids: list[str] = []
agent_thought_id = self.create_agent_thought(
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
)
# recalc llm max tokens
prompt_messages = self._organize_prompt_messages()
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=app_generate_entity.model_conf.parameters,
tools=prompt_messages_tools,
stop=app_generate_entity.model_conf.stop,
stream=self.stream_tool_call,
user=self.user_id,
callbacks=[],
)
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
# save full response
response = ""
# save tool call names and inputs
tool_call_names = ""
tool_call_inputs = ""
current_llm_usage = None
if isinstance(chunks, Generator):
is_first_chunk = True
for chunk in chunks:
if is_first_chunk:
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
)
is_first_chunk = False
# check if there is any tool call
if self.check_tool_calls(chunk):
function_call_state = True
tool_calls.extend(self.extract_tool_calls(chunk) or [])
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps(
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
)
except TypeError:
# fallback: force ASCII to handle non-serializable objects
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
if chunk.delta.message and chunk.delta.message.content:
if isinstance(chunk.delta.message.content, list):
for content in chunk.delta.message.content:
response += content.data
else:
response += str(chunk.delta.message.content)
if chunk.delta.usage:
increase_usage(llm_usage, chunk.delta.usage)
current_llm_usage = chunk.delta.usage
yield chunk
else:
result = chunks
# check if there is any tool call
if self.check_blocking_tool_calls(result):
function_call_state = True
tool_calls.extend(self.extract_blocking_tool_calls(result) or [])
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps(
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
)
except TypeError:
# fallback: force ASCII to handle non-serializable objects
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
if result.usage:
increase_usage(llm_usage, result.usage)
current_llm_usage = result.usage
if result.message and result.message.content:
if isinstance(result.message.content, list):
for content in result.message.content:
response += content.data
else:
response += str(result.message.content)
if not result.message.content:
result.message.content = ""
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
)
yield LLMResultChunk(
model=model_instance.model,
prompt_messages=result.prompt_messages,
system_fingerprint=result.system_fingerprint,
delta=LLMResultChunkDelta(
index=0,
message=result.message,
usage=result.usage,
),
)
assistant_message = AssistantPromptMessage(content=response, tool_calls=[])
if tool_calls:
assistant_message.tool_calls = [
AssistantPromptMessage.ToolCall(
id=tool_call[0],
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False)
),
)
for tool_call in tool_calls
]
self._current_thoughts.append(assistant_message)
# save thought
self.save_agent_thought(
agent_thought_id=agent_thought_id,
tool_name=tool_call_names,
tool_input=tool_call_inputs,
thought=response,
tool_invoke_meta=None,
observation=None,
answer=response,
messages_ids=[],
llm_usage=current_llm_usage,
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
)
final_answer += response + "\n"
# Check if max iteration is reached and model still wants to call tools
if iteration_step == max_iteration_steps and tool_calls:
raise AgentMaxIterationError(app_config.agent.max_iteration)
# call tools
tool_responses = []
for tool_call_id, tool_call_name, tool_call_args in tool_calls:
tool_instance = tool_instances.get(tool_call_name)
if not tool_instance:
tool_response = {
"tool_call_id": tool_call_id,
"tool_call_name": tool_call_name,
"tool_response": f"there is not a tool named {tool_call_name}",
"meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(),
}
else:
# invoke tool
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
tool=tool_instance,
tool_parameters=tool_call_args,
user_id=self.user_id,
tenant_id=self.tenant_id,
message=self.message,
invoke_from=self.application_generate_entity.invoke_from,
agent_tool_callback=self.agent_callback,
trace_manager=trace_manager,
app_id=self.application_generate_entity.app_config.app_id,
message_id=self.message.id,
conversation_id=self.conversation.id,
)
# publish files
for message_file_id in message_files:
# publish message file
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
)
# add message file ids
message_file_ids.append(message_file_id)
tool_response = {
"tool_call_id": tool_call_id,
"tool_call_name": tool_call_name,
"tool_response": tool_invoke_response,
"meta": tool_invoke_meta.to_dict(),
}
tool_responses.append(tool_response)
if tool_response["tool_response"] is not None:
self._current_thoughts.append(
ToolPromptMessage(
content=str(tool_response["tool_response"]),
tool_call_id=tool_call_id,
name=tool_call_name,
)
)
if len(tool_responses) > 0:
# save agent thought
self.save_agent_thought(
agent_thought_id=agent_thought_id,
tool_name="",
tool_input="",
thought="",
tool_invoke_meta={
tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
},
observation={
tool_response["tool_call_name"]: tool_response["tool_response"]
for tool_response in tool_responses
},
answer="",
messages_ids=message_file_ids,
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
)
# update prompt tool
for prompt_tool in prompt_messages_tools:
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
iteration_step += 1
# publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
system_fingerprint="",
)
),
PublishFrom.APPLICATION_MANAGER,
)
def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
"""
Check if there is any tool call in llm result chunk
"""
if llm_result_chunk.delta.message.tool_calls:
return True
return False
def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
"""
Check if there is any blocking tool call in llm result
"""
if llm_result.message.tool_calls:
return True
return False
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
"""
Extract tool calls from llm result chunk
Returns:
List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
"""
tool_calls = []
for prompt_message in llm_result_chunk.delta.message.tool_calls:
args = {}
if prompt_message.function.arguments != "":
args = json.loads(prompt_message.function.arguments)
tool_calls.append(
(
prompt_message.id,
prompt_message.function.name,
args,
)
)
return tool_calls
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
"""
Extract blocking tool calls from llm result
Returns:
List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
"""
tool_calls = []
for prompt_message in llm_result.message.tool_calls:
args = {}
if prompt_message.function.arguments != "":
args = json.loads(prompt_message.function.arguments)
tool_calls.append(
(
prompt_message.id,
prompt_message.function.name,
args,
)
)
return tool_calls
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Initialize system message
"""
if not prompt_messages and prompt_template:
return [
SystemPromptMessage(content=prompt_template),
]
if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
return prompt_messages or []
def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Organize user query
"""
if self.files:
# get image detail config
image_detail_config = (
self.application_generate_entity.file_upload_config.image_config.detail
if (
self.application_generate_entity.file_upload_config
and self.application_generate_entity.file_upload_config.image_config
)
else None
)
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
for file in self.files:
prompt_message_contents.append(
file_manager.to_prompt_message_content(
file,
image_detail_config=image_detail_config,
)
)
prompt_message_contents.append(TextPromptMessageContent(data=query))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
prompt_messages.append(UserPromptMessage(content=query))
return prompt_messages
def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
As for now, gpt supports both fc and vision at the first iteration.
We need to remove the image messages from the prompt messages at the first iteration.
"""
prompt_messages = deepcopy(prompt_messages)
for prompt_message in prompt_messages:
if isinstance(prompt_message, UserPromptMessage):
if isinstance(prompt_message.content, list):
prompt_message.content = "\n".join(
[
content.data
if content.type == PromptMessageContentType.TEXT
else "[image]"
if content.type == PromptMessageContentType.IMAGE
else "[file]"
for content in prompt_message.content
]
)
return prompt_messages
def _organize_prompt_messages(self):
prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
query_prompt_messages = self._organize_user_query(self.query or "", [])
self.history_prompt_messages = AgentHistoryPromptTransform(
model_config=self.model_config,
prompt_messages=[*query_prompt_messages, *self._current_thoughts],
history_messages=self.history_prompt_messages,
memory=self.memory,
).get_prompt()
prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
if len(self._current_thoughts) != 0:
# clear messages after the first iteration
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
return prompt_messages

View File

@@ -1,3 +1,4 @@
import json
from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any, Literal
@@ -120,7 +121,7 @@ class VariableEntity(BaseModel):
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
json_schema: dict | None = Field(default=None)
json_schema: str | None = Field(default=None)
@field_validator("description", mode="before")
@classmethod
@@ -134,11 +135,17 @@ class VariableEntity(BaseModel):
@field_validator("json_schema")
@classmethod
def validate_json_schema(cls, schema: dict | None) -> dict | None:
def validate_json_schema(cls, schema: str | None) -> str | None:
if schema is None:
return None
try:
Draft7Validator.check_schema(schema)
json_schema = json.loads(schema)
except json.JSONDecodeError:
raise ValueError(f"invalid json_schema value {schema}")
try:
Draft7Validator.check_schema(json_schema)
except SchemaError as e:
raise ValueError(f"Invalid JSON schema: {e.message}")
return schema

View File

@@ -26,6 +26,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig:
features_dict = workflow.features_dict
app_mode = AppMode.value_of(app_model.mode)
app_config = AdvancedChatAppConfig(
tenant_id=app_model.tenant_id,

View File

@@ -39,6 +39,7 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.otel import WorkflowAppRunnerHandler, trace_span
from models import Workflow
from models.enums import UserFrom
from models.model import App, Conversation, Message, MessageAnnotation
from models.workflow import ConversationVariable
from services.conversation_variable_updater import ConversationVariableUpdater
@@ -105,11 +106,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
if not app_record:
raise ValueError("App not found")
invoke_from = self.application_generate_entity.invoke_from
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
invoke_from = InvokeFrom.DEBUGGER
user_from = self._resolve_user_from(invoke_from)
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
# Handle single iteration or single loop run
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
@@ -162,8 +158,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
workflow_id=self._workflow.id,
tenant_id=self._workflow.tenant_id,
user_id=self.application_generate_entity.user_id,
user_from=user_from,
invoke_from=invoke_from,
)
db.session.close()
@@ -181,8 +175,12 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
graph=graph,
graph_config=self._workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=user_from,
invoke_from=invoke_from,
user_from=(
UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else UserFrom.END_USER
),
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
graph_runtime_state=graph_runtime_state,

View File

@@ -471,6 +471,25 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if node_finish_resp:
yield node_finish_resp
# For ANSWER nodes, check if we need to send a message_replace event
# Only send if the final output differs from the accumulated task_state.answer
# This happens when variables were updated by variable_assigner during workflow execution
if event.node_type == NodeType.ANSWER and event.outputs:
final_answer = event.outputs.get("answer")
if final_answer is not None and final_answer != self._task_state.answer:
logger.info(
"ANSWER node final output '%s' differs from accumulated answer '%s', sending message_replace event",
final_answer,
self._task_state.answer,
)
# Update the task state answer
self._task_state.answer = str(final_answer)
# Send message_replace event to update the UI
yield self._message_cycle_manager.message_replace_to_stream_response(
answer=str(final_answer),
reason="variable_update",
)
def _handle_node_failed_events(
self,
event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],

View File

@@ -1,3 +1,4 @@
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Union, final
@@ -75,24 +76,12 @@ class BaseAppGenerator:
user_inputs = {**user_inputs, **files_inputs, **file_list_inputs}
# Check if all files are converted to File
invalid_dict_keys = [
k
for k, v in user_inputs.items()
if isinstance(v, dict)
and entity_dictionary[k].type not in {VariableEntityType.FILE, VariableEntityType.JSON_OBJECT}
]
if invalid_dict_keys:
raise ValueError(f"Invalid input type for {invalid_dict_keys}")
invalid_list_dict_keys = [
k
for k, v in user_inputs.items()
if isinstance(v, list)
and any(isinstance(item, dict) for item in v)
and entity_dictionary[k].type != VariableEntityType.FILE_LIST
]
if invalid_list_dict_keys:
raise ValueError(f"Invalid input type for {invalid_list_dict_keys}")
if any(filter(lambda v: isinstance(v, dict), user_inputs.values())):
raise ValueError("Invalid input type")
if any(
filter(lambda v: isinstance(v, dict), filter(lambda item: isinstance(item, list), user_inputs.values()))
):
raise ValueError("Invalid input type")
return user_inputs
@@ -189,8 +178,12 @@ class BaseAppGenerator:
elif value == 0:
value = False
case VariableEntityType.JSON_OBJECT:
if value and not isinstance(value, dict):
raise ValueError(f"{variable_entity.variable} in input form must be a dict")
if not isinstance(value, str):
raise ValueError(f"{variable_entity.variable} in input form must be a string")
try:
json.loads(value)
except json.JSONDecodeError:
raise ValueError(f"{variable_entity.variable} in input form must be a valid JSON object")
case _:
raise AssertionError("this statement should be unreachable.")

View File

@@ -73,15 +73,9 @@ class PipelineRunner(WorkflowBasedAppRunner):
"""
app_config = self.application_generate_entity.app_config
app_config = cast(PipelineConfig, app_config)
invoke_from = self.application_generate_entity.invoke_from
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
invoke_from = InvokeFrom.DEBUGGER
user_from = self._resolve_user_from(invoke_from)
user_id = None
if invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).where(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
@@ -123,7 +117,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
dataset_id=self.application_generate_entity.dataset_id,
datasource_type=self.application_generate_entity.datasource_type,
datasource_info=self.application_generate_entity.datasource_info,
invoke_from=invoke_from.value,
invoke_from=self.application_generate_entity.invoke_from.value,
)
rag_pipeline_variables = []
@@ -155,8 +149,6 @@ class PipelineRunner(WorkflowBasedAppRunner):
graph_runtime_state=graph_runtime_state,
start_node_id=self.application_generate_entity.start_node_id,
workflow=workflow,
user_from=user_from,
invoke_from=invoke_from,
)
# RUN WORKFLOW
@@ -167,8 +159,12 @@ class PipelineRunner(WorkflowBasedAppRunner):
graph=graph,
graph_config=workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=user_from,
invoke_from=invoke_from,
user_from=(
UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else UserFrom.END_USER
),
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
graph_runtime_state=graph_runtime_state,
variable_pool=variable_pool,
@@ -214,12 +210,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
return workflow
def _init_rag_pipeline_graph(
self,
workflow: Workflow,
graph_runtime_state: GraphRuntimeState,
start_node_id: str | None = None,
user_from: UserFrom = UserFrom.ACCOUNT,
invoke_from: InvokeFrom = InvokeFrom.SERVICE_API,
self, workflow: Workflow, graph_runtime_state: GraphRuntimeState, start_node_id: str | None = None
) -> Graph:
"""
Init pipeline graph
@@ -262,8 +253,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
workflow_id=workflow.id,
graph_config=graph_config,
user_id=self.application_generate_entity.user_id,
user_from=user_from,
invoke_from=invoke_from,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)

View File

@@ -20,6 +20,7 @@ from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_redis import redis_client
from extensions.otel import WorkflowAppRunnerHandler, trace_span
from libs.datetime_utils import naive_utc_now
from models.enums import UserFrom
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@@ -73,12 +74,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
)
invoke_from = self.application_generate_entity.invoke_from
# if only single iteration or single loop run is requested
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
invoke_from = InvokeFrom.DEBUGGER
user_from = self._resolve_user_from(invoke_from)
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
workflow=self._workflow,
@@ -106,8 +102,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
workflow_id=self._workflow.id,
tenant_id=self._workflow.tenant_id,
user_id=self.application_generate_entity.user_id,
user_from=user_from,
invoke_from=invoke_from,
root_node_id=self._root_node_id,
)
@@ -126,8 +120,12 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
graph=graph,
graph_config=self._workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=user_from,
invoke_from=invoke_from,
user_from=(
UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else UserFrom.END_USER
),
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
graph_runtime_state=graph_runtime_state,

View File

@@ -77,18 +77,10 @@ class WorkflowBasedAppRunner:
self._app_id = app_id
self._graph_engine_layers = graph_engine_layers
@staticmethod
def _resolve_user_from(invoke_from: InvokeFrom) -> UserFrom:
if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}:
return UserFrom.ACCOUNT
return UserFrom.END_USER
def _init_graph(
self,
graph_config: Mapping[str, Any],
graph_runtime_state: GraphRuntimeState,
user_from: UserFrom,
invoke_from: InvokeFrom,
workflow_id: str = "",
tenant_id: str = "",
user_id: str = "",
@@ -113,8 +105,8 @@ class WorkflowBasedAppRunner:
workflow_id=workflow_id,
graph_config=graph_config,
user_id=user_id,
user_from=user_from,
invoke_from=invoke_from,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)
@@ -258,7 +250,7 @@ class WorkflowBasedAppRunner:
graph_config=graph_config,
user_id="",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)

View File

@@ -3,7 +3,6 @@ from pydantic import BaseModel, Field, field_validator
class PreviewDetail(BaseModel):
content: str
summary: str | None = None
child_chunks: list[str] | None = None

View File

@@ -56,10 +56,6 @@ class HostingConfiguration:
self.provider_map[f"{DEFAULT_PLUGIN_ID}/minimax/minimax"] = self.init_minimax()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/spark/spark"] = self.init_spark()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/zhipuai/zhipuai"] = self.init_zhipuai()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/gemini/google"] = self.init_gemini()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/x/x"] = self.init_xai()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/deepseek/deepseek"] = self.init_deepseek()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/tongyi/tongyi"] = self.init_tongyi()
self.moderation_config = self.init_moderation_config()
@@ -132,7 +128,7 @@ class HostingConfiguration:
quotas: list[HostingQuota] = []
if dify_config.HOSTED_OPENAI_TRIAL_ENABLED:
hosted_quota_limit = 0
hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT
trial_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_TRIAL_MODELS")
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
quotas.append(trial_quota)
@@ -160,49 +156,18 @@ class HostingConfiguration:
quota_unit=quota_unit,
)
def init_gemini(self) -> HostingProvider:
quota_unit = QuotaUnit.CREDITS
quotas: list[HostingQuota] = []
if dify_config.HOSTED_GEMINI_TRIAL_ENABLED:
hosted_quota_limit = 0
trial_models = self.parse_restrict_models_from_env("HOSTED_GEMINI_TRIAL_MODELS")
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
quotas.append(trial_quota)
if dify_config.HOSTED_GEMINI_PAID_ENABLED:
paid_models = self.parse_restrict_models_from_env("HOSTED_GEMINI_PAID_MODELS")
paid_quota = PaidHostingQuota(restrict_models=paid_models)
quotas.append(paid_quota)
if len(quotas) > 0:
credentials = {
"google_api_key": dify_config.HOSTED_GEMINI_API_KEY,
}
if dify_config.HOSTED_GEMINI_API_BASE:
credentials["google_base_url"] = dify_config.HOSTED_GEMINI_API_BASE
return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
return HostingProvider(
enabled=False,
quota_unit=quota_unit,
)
def init_anthropic(self) -> HostingProvider:
quota_unit = QuotaUnit.CREDITS
@staticmethod
def init_anthropic() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
quotas: list[HostingQuota] = []
if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED:
hosted_quota_limit = 0
trail_models = self.parse_restrict_models_from_env("HOSTED_ANTHROPIC_TRIAL_MODELS")
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
hosted_quota_limit = dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit)
quotas.append(trial_quota)
if dify_config.HOSTED_ANTHROPIC_PAID_ENABLED:
paid_models = self.parse_restrict_models_from_env("HOSTED_ANTHROPIC_PAID_MODELS")
paid_quota = PaidHostingQuota(restrict_models=paid_models)
paid_quota = PaidHostingQuota()
quotas.append(paid_quota)
if len(quotas) > 0:
@@ -220,94 +185,6 @@ class HostingConfiguration:
quota_unit=quota_unit,
)
def init_tongyi(self) -> HostingProvider:
quota_unit = QuotaUnit.CREDITS
quotas: list[HostingQuota] = []
if dify_config.HOSTED_TONGYI_TRIAL_ENABLED:
hosted_quota_limit = 0
trail_models = self.parse_restrict_models_from_env("HOSTED_TONGYI_TRIAL_MODELS")
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
quotas.append(trial_quota)
if dify_config.HOSTED_TONGYI_PAID_ENABLED:
paid_models = self.parse_restrict_models_from_env("HOSTED_TONGYI_PAID_MODELS")
paid_quota = PaidHostingQuota(restrict_models=paid_models)
quotas.append(paid_quota)
if len(quotas) > 0:
credentials = {
"dashscope_api_key": dify_config.HOSTED_TONGYI_API_KEY,
"use_international_endpoint": dify_config.HOSTED_TONGYI_USE_INTERNATIONAL_ENDPOINT,
}
return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
return HostingProvider(
enabled=False,
quota_unit=quota_unit,
)
def init_xai(self) -> HostingProvider:
quota_unit = QuotaUnit.CREDITS
quotas: list[HostingQuota] = []
if dify_config.HOSTED_XAI_TRIAL_ENABLED:
hosted_quota_limit = 0
trail_models = self.parse_restrict_models_from_env("HOSTED_XAI_TRIAL_MODELS")
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
quotas.append(trial_quota)
if dify_config.HOSTED_XAI_PAID_ENABLED:
paid_models = self.parse_restrict_models_from_env("HOSTED_XAI_PAID_MODELS")
paid_quota = PaidHostingQuota(restrict_models=paid_models)
quotas.append(paid_quota)
if len(quotas) > 0:
credentials = {
"api_key": dify_config.HOSTED_XAI_API_KEY,
}
if dify_config.HOSTED_XAI_API_BASE:
credentials["endpoint_url"] = dify_config.HOSTED_XAI_API_BASE
return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
return HostingProvider(
enabled=False,
quota_unit=quota_unit,
)
def init_deepseek(self) -> HostingProvider:
quota_unit = QuotaUnit.CREDITS
quotas: list[HostingQuota] = []
if dify_config.HOSTED_DEEPSEEK_TRIAL_ENABLED:
hosted_quota_limit = 0
trail_models = self.parse_restrict_models_from_env("HOSTED_DEEPSEEK_TRIAL_MODELS")
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
quotas.append(trial_quota)
if dify_config.HOSTED_DEEPSEEK_PAID_ENABLED:
paid_models = self.parse_restrict_models_from_env("HOSTED_DEEPSEEK_PAID_MODELS")
paid_quota = PaidHostingQuota(restrict_models=paid_models)
quotas.append(paid_quota)
if len(quotas) > 0:
credentials = {
"api_key": dify_config.HOSTED_DEEPSEEK_API_KEY,
}
if dify_config.HOSTED_DEEPSEEK_API_BASE:
credentials["endpoint_url"] = dify_config.HOSTED_DEEPSEEK_API_BASE
return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
return HostingProvider(
enabled=False,
quota_unit=quota_unit,
)
@staticmethod
def init_minimax() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS

View File

@@ -311,18 +311,14 @@ class IndexingRunner:
qa_preview_texts: list[QAPreviewDetail] = []
total_segments = 0
# doc_form represents the segmentation method (general, parent-child, QA)
index_type = doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
# one extract_setting is one source document
for extract_setting in extract_settings:
# extract
processing_rule = DatasetProcessRule(
mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"])
)
# Extract document content
text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
# Cleaning and segmentation
documents = index_processor.transform(
text_docs,
current_user=None,
@@ -365,12 +361,6 @@ class IndexingRunner:
if doc_form and doc_form == "qa_model":
return IndexingEstimate(total_segments=total_segments * 20, qa_preview=qa_preview_texts, preview=[])
# Generate summary preview
summary_index_setting = tmp_processing_rule["summary_index_setting"] if "summary_index_setting" in tmp_processing_rule else None
if summary_index_setting and summary_index_setting.get('enable') and preview_texts:
preview_texts = index_processor.generate_summary_preview(tenant_id, preview_texts, summary_index_setting)
return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
def _extract(

View File

@@ -434,6 +434,3 @@ INSTRUCTION_GENERATE_TEMPLATE_PROMPT = """The output of this prompt is not as ex
You should edit the prompt according to the IDEAL OUTPUT."""
INSTRUCTION_GENERATE_TEMPLATE_CODE = """Please fix the errors in the {{#error_message#}}."""
DEFAULT_GENERATOR_SUMMARY_PROMPT = """
You are a helpful assistant that summarizes long pieces of text into concise summaries. Given the following text, generate a brief summary that captures the main points and key information. The summary should be clear, concise, and written in complete sentences. """

View File

@@ -251,7 +251,10 @@ class AssistantPromptMessage(PromptMessage):
:return: True if prompt message is empty, False otherwise
"""
return super().is_empty() and not self.tool_calls
if not super().is_empty() and not self.tool_calls:
return False
return True
class SystemPromptMessage(PromptMessage):

View File

@@ -1,7 +1,6 @@
import logging
from collections.abc import Sequence
from opentelemetry.trace import SpanKind
from sqlalchemy.orm import sessionmaker
from core.ops.aliyun_trace.data_exporter.traceclient import (
@@ -152,7 +151,6 @@ class AliyunDataTrace(BaseTraceInstance):
),
status=status,
links=trace_metadata.links,
span_kind=SpanKind.SERVER,
)
self.trace_client.add_span(message_span)
@@ -458,7 +456,6 @@ class AliyunDataTrace(BaseTraceInstance):
),
status=status,
links=trace_metadata.links,
span_kind=SpanKind.SERVER,
)
self.trace_client.add_span(message_span)
@@ -478,7 +475,6 @@ class AliyunDataTrace(BaseTraceInstance):
),
status=status,
links=trace_metadata.links,
span_kind=SpanKind.SERVER if message_span_id is None else SpanKind.INTERNAL,
)
self.trace_client.add_span(workflow_span)

View File

@@ -166,7 +166,7 @@ class SpanBuilder:
attributes=span_data.attributes,
events=span_data.events,
links=span_data.links,
kind=span_data.span_kind,
kind=trace_api.SpanKind.INTERNAL,
status=span_data.status,
start_time=span_data.start_time,
end_time=span_data.end_time,

View File

@@ -4,7 +4,7 @@ from typing import Any
from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import Event
from opentelemetry.trace import SpanKind, Status, StatusCode
from opentelemetry.trace import Status, StatusCode
from pydantic import BaseModel, Field
@@ -34,4 +34,3 @@ class SpanData(BaseModel):
status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")
start_time: int | None = Field(..., description="The start time of the span in nanoseconds.")
end_time: int | None = Field(..., description="The end time of the span in nanoseconds.")
span_kind: SpanKind = Field(default=SpanKind.INTERNAL, description="The OpenTelemetry SpanKind for this span.")

View File

@@ -618,18 +618,18 @@ class ProviderManager:
)
for quota in configuration.quotas:
if quota.quota_type in (ProviderQuotaType.TRIAL, ProviderQuotaType.PAID):
if quota.quota_type == ProviderQuotaType.TRIAL:
# Init trial provider records if not exists
if quota.quota_type not in provider_quota_to_provider_record_dict:
if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
try:
# FIXME ignore the type error, only TrialHostingQuota has limit need to change the logic
new_provider_record = Provider(
tenant_id=tenant_id,
# TODO: Use provider name with prefix after the data migration.
provider_name=ModelProviderID(provider_name).provider_name,
provider_type=ProviderType.SYSTEM.value,
quota_type=quota.quota_type,
quota_limit=0, # type: ignore
provider_type=ProviderType.SYSTEM,
quota_type=ProviderQuotaType.TRIAL,
quota_limit=quota.quota_limit, # type: ignore
quota_used=0,
is_valid=True,
)
@@ -641,8 +641,8 @@ class ProviderManager:
stmt = select(Provider).where(
Provider.tenant_id == tenant_id,
Provider.provider_name == ModelProviderID(provider_name).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == quota.quota_type,
Provider.provider_type == ProviderType.SYSTEM,
Provider.quota_type == ProviderQuotaType.TRIAL,
)
existed_provider_record = db.session.scalar(stmt)
if not existed_provider_record:
@@ -912,22 +912,6 @@ class ProviderManager:
provider_record
)
quota_configurations = []
if dify_config.EDITION == "CLOUD":
from services.credit_pool_service import CreditPoolService
trail_pool = CreditPoolService.get_pool(
tenant_id=tenant_id,
pool_type=ProviderQuotaType.TRIAL.value,
)
paid_pool = CreditPoolService.get_pool(
tenant_id=tenant_id,
pool_type=ProviderQuotaType.PAID.value,
)
else:
trail_pool = None
paid_pool = None
for provider_quota in provider_hosting_configuration.quotas:
if provider_quota.quota_type not in quota_type_to_provider_records_dict:
if provider_quota.quota_type == ProviderQuotaType.FREE:
@@ -948,36 +932,16 @@ class ProviderManager:
raise ValueError("quota_used is None")
if provider_record.quota_limit is None:
raise ValueError("quota_limit is None")
if provider_quota.quota_type == ProviderQuotaType.TRIAL and trail_pool is not None:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=trail_pool.quota_used,
quota_limit=trail_pool.quota_limit,
is_valid=trail_pool.quota_limit > trail_pool.quota_used or trail_pool.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
elif provider_quota.quota_type == ProviderQuotaType.PAID and paid_pool is not None:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=paid_pool.quota_used,
quota_limit=paid_pool.quota_limit,
is_valid=paid_pool.quota_limit > paid_pool.quota_used or paid_pool.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
else:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=provider_record.quota_used,
quota_limit=provider_record.quota_limit,
is_valid=provider_record.quota_limit > provider_record.quota_used
or provider_record.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=provider_record.quota_used,
quota_limit=provider_record.quota_limit,
is_valid=provider_record.quota_limit > provider_record.quota_used
or provider_record.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
quota_configurations.append(quota_configuration)

View File

@@ -392,69 +392,6 @@ class RetrievalService:
records = []
include_segment_ids = set()
segment_child_map = {}
segment_file_map = {}
segment_summary_map = {} # Map segment_id to summary content
summary_segment_ids = set() # Track segments retrieved via summary
with Session(bind=db.engine, expire_on_commit=False) as session:
# Process documents
for document in documents:
segment_id = None
attachment_info = None
child_chunk = None
document_id = document.metadata.get("document_id")
if document_id not in dataset_documents:
continue
dataset_document = dataset_documents[document_id]
if not dataset_document:
continue
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
# Handle parent-child documents
if document.metadata.get("doc_type") == DocType.IMAGE:
attachment_info_dict = cls.get_segment_attachment_info(
dataset_document.dataset_id,
dataset_document.tenant_id,
document.metadata.get("doc_id") or "",
session,
)
if attachment_info_dict:
attachment_info = attachment_info_dict["attachment_info"]
segment_id = attachment_info_dict["segment_id"]
else:
# Check if this is a summary document
is_summary = document.metadata.get("is_summary", False)
if is_summary:
# For summary documents, find the original chunk via original_chunk_id
original_chunk_id = document.metadata.get("original_chunk_id")
if not original_chunk_id:
continue
segment_id = original_chunk_id
# Track that this segment was retrieved via summary
summary_segment_ids.add(segment_id)
else:
# For normal documents, find by child chunk index_node_id
child_index_node_id = document.metadata.get("doc_id")
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
child_chunk = session.scalar(child_chunk_stmt)
if not child_chunk:
continue
segment_id = child_chunk.segment_id
if not segment_id:
continue
segment = (
session.query(DocumentSegment)
.where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id == segment_id,
)
.first()
)
valid_dataset_documents = {}
image_doc_ids: list[Any] = []
@@ -570,47 +507,7 @@ class RetrievalService:
max_score = max(
max_score, file_document.metadata.get("score", 0.0) if file_document else 0.0
)
segment = session.scalar(document_segment_stmt)
if segment:
segment_file_map[segment.id] = [attachment_info]
else:
# Check if this is a summary document
is_summary = document.metadata.get("is_summary", False)
if is_summary:
# For summary documents, find the original chunk via original_chunk_id
original_chunk_id = document.metadata.get("original_chunk_id")
if not original_chunk_id:
continue
# Track that this segment was retrieved via summary
summary_segment_ids.add(original_chunk_id)
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id == original_chunk_id,
)
segment = session.scalar(document_segment_stmt)
else:
# For normal documents, find by index_node_id
index_node_id = document.metadata.get("doc_id")
if not index_node_id:
continue
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
segment = session.scalar(document_segment_stmt)
if not segment:
continue
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
record = {
"segment": segment,
"score": document.metadata.get("score"), # type: ignore
}
map_detail = {
"max_score": max_score,
"child_chunks": child_chunk_details,
@@ -645,23 +542,6 @@ class RetrievalService:
if record["segment"].id in attachment_map:
record["files"] = attachment_map[record["segment"].id] # type: ignore[assignment]
# Batch query summaries for segments retrieved via summary (only enabled summaries)
if summary_segment_ids:
from models.dataset import DocumentSegmentSummary
summaries = (
session.query(DocumentSegmentSummary)
.filter(
DocumentSegmentSummary.chunk_id.in_(summary_segment_ids),
DocumentSegmentSummary.status == "completed",
DocumentSegmentSummary.enabled == True, # Only retrieve enabled summaries
)
.all()
)
for summary in summaries:
if summary.summary_content:
segment_summary_map[summary.chunk_id] = summary.summary_content
result: list[RetrievalSegments] = []
for record in records:
# Extract segment
@@ -696,16 +576,9 @@ class RetrievalService:
else None
)
# Extract summary if this segment was retrieved via summary
summary_content = segment_summary_map.get(segment.id)
# Create RetrievalSegments object
retrieval_segment = RetrievalSegments(
segment=segment,
child_chunks=child_chunks_list,
score=score,
files=files,
summary=summary_content
segment=segment, child_chunks=child_chunks_list, score=score, files=files
)
result.append(retrieval_segment)

View File

@@ -20,4 +20,3 @@ class RetrievalSegments(BaseModel):
child_chunks: list[RetrievalChildChunk] | None = None
score: float | None = None
files: list[dict[str, str | int]] | None = None
summary: str | None = None # Summary content if retrieved via summary index

View File

@@ -13,7 +13,6 @@ from urllib.parse import unquote, urlparse
import httpx
from configs import dify_config
from core.entities.knowledge_entities import PreviewDetail
from core.helper import ssrf_proxy
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.constant.doc_type import DocType
@@ -46,15 +45,6 @@ class BaseIndexProcessor(ABC):
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
raise NotImplementedError
@abstractmethod
def generate_summary_preview(self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict) -> list[PreviewDetail]:
"""
For each segment in preview_texts, generate a summary using LLM and attach it to the segment.
The summary can be stored in a new attribute, e.g., summary.
This method should be implemented by subclasses.
"""
raise NotImplementedError
@abstractmethod
def load(
self,

View File

@@ -1,13 +1,9 @@
"""Paragraph index processor."""
import logging
import uuid
from collections.abc import Mapping
from typing import Any
logger = logging.getLogger(__name__)
from core.entities.knowledge_entities import PreviewDetail
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.retrieval_service import RetrievalService
@@ -21,19 +17,12 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from extensions.ext_database import db
from libs import helper
from models.account import Account
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Dataset, DatasetProcessRule
from models.dataset import Document as DatasetDocument
from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import Rule
from services.summary_index_service import SummaryIndexService
from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
from core.model_runtime.entities.message_entities import UserPromptMessage
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from core.model_manager import ModelInstance
class ParagraphIndexProcessor(BaseIndexProcessor):
@@ -119,29 +108,6 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
keyword.add_texts(documents)
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
# For disable operations, disable_summaries_for_segments is called directly in the task.
# Only delete summaries if explicitly requested (e.g., when segment is actually deleted)
delete_summaries = kwargs.get("delete_summaries", False)
if delete_summaries:
if node_ids:
# Find segments by index_node_id
segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
)
.all()
)
segment_ids = [segment.id for segment in segments]
if segment_ids:
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
else:
# Delete all summaries for the dataset
SummaryIndexService.delete_summaries_for_segments(dataset, None)
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
if node_ids:
@@ -261,70 +227,3 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
}
else:
raise ValueError("Chunks is not a list")
def generate_summary_preview(self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict) -> list[PreviewDetail]:
"""
For each segment, concurrently call generate_summary to generate a summary
and write it to the summary attribute of PreviewDetail.
"""
import concurrent.futures
from flask import current_app
# Capture Flask app context for worker threads
flask_app = None
try:
flask_app = current_app._get_current_object() # type: ignore
except RuntimeError:
logger.warning("No Flask application context available, summary generation may fail")
def process(preview: PreviewDetail) -> None:
"""Generate summary for a single preview item."""
try:
if flask_app:
# Ensure Flask app context in worker thread
with flask_app.app_context():
summary = self.generate_summary(tenant_id, preview.content, summary_index_setting)
preview.summary = summary
else:
# Fallback: try without app context (may fail)
summary = self.generate_summary(tenant_id, preview.content, summary_index_setting)
preview.summary = summary
except Exception as e:
logger.error(f"Failed to generate summary for preview: {str(e)}")
# Don't fail the entire preview if summary generation fails
preview.summary = None
with concurrent.futures.ThreadPoolExecutor() as executor:
list(executor.map(process, preview_texts))
return preview_texts
@staticmethod
def generate_summary(tenant_id: str, text: str, summary_index_setting: dict = None) -> str:
"""
Generate summary for the given text using ModelInstance.invoke_llm and the default or custom summary prompt.
"""
if not summary_index_setting or not summary_index_setting.get("enable"):
raise ValueError("summary_index_setting is required and must be enabled to generate summary.")
model_name = summary_index_setting.get("model_name")
model_provider_name = summary_index_setting.get("model_provider_name")
summary_prompt = summary_index_setting.get("summary_prompt")
# Import default summary prompt
if not summary_prompt:
summary_prompt = DEFAULT_GENERATOR_SUMMARY_PROMPT
prompt = f"{summary_prompt}\n{text}"
provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(tenant_id, model_provider_name, ModelType.LLM)
model_instance = ModelInstance(provider_model_bundle, model_name)
prompt_messages = [UserPromptMessage(content=prompt)]
result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters={},
stream=False
)
return getattr(result.message, "content", "")

View File

@@ -25,7 +25,6 @@ from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegm
from models.dataset import Document as DatasetDocument
from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
from services.summary_index_service import SummaryIndexService
class ParentChildIndexProcessor(BaseIndexProcessor):
@@ -136,29 +135,6 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
# node_ids is segment's node_ids
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
# For disable operations, disable_summaries_for_segments is called directly in the task.
# Only delete summaries if explicitly requested (e.g., when segment is actually deleted)
delete_summaries = kwargs.get("delete_summaries", False)
if delete_summaries:
if node_ids:
# Find segments by index_node_id
segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
)
.all()
)
segment_ids = [segment.id for segment in segments]
if segment_ids:
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
else:
# Delete all summaries for the dataset
SummaryIndexService.delete_summaries_for_segments(dataset, None)
if dataset.indexing_technique == "high_quality":
delete_child_chunks = kwargs.get("delete_child_chunks") or False
precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids")

View File

@@ -25,10 +25,9 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.account import Account
from models.dataset import Dataset, DocumentSegment
from models.dataset import Dataset
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import Rule
from services.summary_index_service import SummaryIndexService
logger = logging.getLogger(__name__)
@@ -145,30 +144,6 @@ class QAIndexProcessor(BaseIndexProcessor):
vector.create_multimodal(multimodal_documents)
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
# For disable operations, disable_summaries_for_segments is called directly in the task.
# Note: qa_model doesn't generate summaries, but we clean them for completeness
# Only delete summaries if explicitly requested (e.g., when segment is actually deleted)
delete_summaries = kwargs.get("delete_summaries", False)
if delete_summaries:
if node_ids:
# Find segments by index_node_id
segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
)
.all()
)
segment_ids = [segment.id for segment in segments]
if segment_ids:
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
else:
# Delete all summaries for the dataset
SummaryIndexService.delete_summaries_for_segments(dataset, None)
vector = Vector(dataset)
if node_ids:
vector.delete_by_ids(node_ids)

View File

@@ -211,10 +211,6 @@ class WorkflowExecutionStatus(StrEnum):
def is_ended(self) -> bool:
return self in _END_STATE
@classmethod
def ended_values(cls) -> list[str]:
return [status.value for status in _END_STATE]
_END_STATE = frozenset(
[

View File

@@ -62,21 +62,6 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
inputs = {"variable_selector": variable_selector}
process_data = {"documents": value if isinstance(value, list) else [value]}
# Ensure storage_key is loaded for File objects
files_to_check = value if isinstance(value, list) else [value]
files_needing_storage_key = [
f for f in files_to_check
if isinstance(f, File) and not f.storage_key and f.related_id
]
if files_needing_storage_key:
from factories.file_factory import StorageKeyLoader
from extensions.ext_database import db
from sqlalchemy.orm import Session
with Session(bind=db.engine) as session:
storage_key_loader = StorageKeyLoader(session, tenant_id=self.tenant_id)
storage_key_loader.load_storage_keys(files_needing_storage_key)
try:
if isinstance(value, list):
extracted_text_list = list(map(_extract_text_from_file, value))
@@ -430,15 +415,6 @@ def _download_file_content(file: File) -> bytes:
response.raise_for_status()
return response.content
else:
# Check if storage_key is set
if not file.storage_key:
raise FileDownloadError(f"File storage_key is missing for file: {file.filename}")
# Check if file exists before downloading
from extensions.ext_storage import storage
if not storage.exists(file.storage_key):
raise FileDownloadError(f"File not found in storage: {file.storage_key}")
return file_manager.download(file)
except Exception as e:
raise FileDownloadError(f"Error downloading file: {str(e)}") from e

View File

@@ -158,5 +158,3 @@ class KnowledgeIndexNodeData(BaseNodeData):
type: str = "knowledge-index"
chunk_structure: str
index_chunk_variable_selector: list[str]
indexing_technique: str | None = None
summary_index_setting: dict | None = None

View File

@@ -1,11 +1,9 @@
import concurrent.futures
import datetime
import logging
import time
from collections.abc import Mapping
from typing import Any
from flask import current_app
from sqlalchemy import func, select
from core.app.entities.app_invoke_entities import InvokeFrom
@@ -18,9 +16,7 @@ from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary
from services.summary_index_service import SummaryIndexService
from tasks.generate_summary_index_task import generate_summary_index_task
from models.dataset import Dataset, Document, DocumentSegment
from .entities import KnowledgeIndexNodeData
from .exc import (
@@ -71,18 +67,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
# index knowledge
try:
if is_preview:
# Preview mode: generate summaries for chunks directly without saving to database
# Format preview and generate summaries on-the-fly
# Get indexing_technique and summary_index_setting from node_data (workflow graph config)
# or fallback to dataset if not available in node_data
indexing_technique = node_data.indexing_technique or dataset.indexing_technique
summary_index_setting = node_data.summary_index_setting or dataset.summary_index_setting
outputs = self._get_preview_output_with_summaries(
node_data.chunk_structure, chunks, dataset=dataset,
indexing_technique=indexing_technique,
summary_index_setting=summary_index_setting
)
outputs = self._get_preview_output(node_data.chunk_structure, chunks)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
@@ -178,9 +163,6 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
db.session.commit()
# Generate summary index if enabled
self._handle_summary_index_generation(dataset, document, variable_pool)
return {
"dataset_id": ds_id_value,
"dataset_name": dataset_name_value,
@@ -191,269 +173,9 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
"display_status": "completed",
}
def _handle_summary_index_generation(
self,
dataset: Dataset,
document: Document,
variable_pool: VariablePool,
) -> None:
"""
Handle summary index generation based on mode (debug/preview or production).
Args:
dataset: Dataset containing the document
document: Document to generate summaries for
variable_pool: Variable pool to check invoke_from
"""
# Only generate summary index for high_quality indexing technique
if dataset.indexing_technique != "high_quality":
return
# Check if summary index is enabled
summary_index_setting = dataset.summary_index_setting
if not summary_index_setting or not summary_index_setting.get("enable"):
return
# Skip qa_model documents
if document.doc_form == "qa_model":
return
# Determine if in preview/debug mode
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
is_preview = invoke_from and invoke_from.value == InvokeFrom.DEBUGGER
# Determine if only parent chunks should be processed
only_parent_chunks = dataset.chunk_structure == "parent_child_index"
if is_preview:
try:
# Query segments that need summary generation
query = db.session.query(DocumentSegment).filter_by(
dataset_id=dataset.id,
document_id=document.id,
status="completed",
enabled=True,
)
segments = query.all()
if not segments:
logger.info(f"No segments found for document {document.id}")
return
# Filter segments based on mode
segments_to_process = []
for segment in segments:
# Skip if summary already exists
existing_summary = (
db.session.query(DocumentSegmentSummary)
.filter_by(chunk_id=segment.id, dataset_id=dataset.id, status="completed")
.first()
)
if existing_summary:
continue
# For parent-child mode, all segments are parent chunks, so process all
segments_to_process.append(segment)
if not segments_to_process:
logger.info(f"No segments need summary generation for document {document.id}")
return
# Use ThreadPoolExecutor for concurrent generation
flask_app = current_app._get_current_object() # type: ignore
max_workers = min(10, len(segments_to_process)) # Limit to 10 workers
def process_segment(segment: DocumentSegment) -> None:
"""Process a single segment in a thread with Flask app context."""
with flask_app.app_context():
try:
SummaryIndexService.generate_and_vectorize_summary(
segment, dataset, summary_index_setting
)
except Exception as e:
logger.error(f"Failed to generate summary for segment {segment.id}: {str(e)}")
# Continue processing other segments
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = [
executor.submit(process_segment, segment) for segment in segments_to_process
]
# Wait for all tasks to complete
concurrent.futures.wait(futures, timeout=300)
logger.info(
f"Successfully generated summary index for {len(segments_to_process)} segments "
f"in document {document.id}"
)
except Exception as e:
logger.exception(f"Failed to generate summary index for document {document.id}: {str(e)}")
# Don't fail the entire indexing process if summary generation fails
else:
# Production mode: asynchronous generation
logger.info(f"Queuing summary index generation task for document {document.id} (production mode)")
try:
generate_summary_index_task.delay(dataset.id, document.id, None)
logger.info(f"Summary index generation task queued for document {document.id}")
except Exception as e:
logger.exception(f"Failed to queue summary index generation task for document {document.id}: {str(e)}")
# Don't fail the entire indexing process if task queuing fails
def _get_preview_output_with_summaries(
self, chunk_structure: str, chunks: Any, dataset: Dataset,
indexing_technique: str | None = None,
summary_index_setting: dict | None = None
) -> Mapping[str, Any]:
"""
Generate preview output with summaries for chunks in preview mode.
This method generates summaries on-the-fly without saving to database.
Args:
chunk_structure: Chunk structure type
chunks: Chunks to generate preview for
dataset: Dataset object (for tenant_id)
indexing_technique: Indexing technique from node config or dataset
summary_index_setting: Summary index setting from node config or dataset
"""
def _get_preview_output(self, chunk_structure: str, chunks: Any) -> Mapping[str, Any]:
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
preview_output = index_processor.format_preview(chunks)
# Check if summary index is enabled
if indexing_technique != "high_quality":
return preview_output
if not summary_index_setting or not summary_index_setting.get("enable"):
return preview_output
# Generate summaries for chunks
if "preview" in preview_output and isinstance(preview_output["preview"], list):
chunk_count = len(preview_output["preview"])
logger.info(
f"Generating summaries for {chunk_count} chunks in preview mode "
f"(dataset: {dataset.id})"
)
# Use ParagraphIndexProcessor's generate_summary method
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
# Get Flask app for application context in worker threads
flask_app = None
try:
flask_app = current_app._get_current_object() # type: ignore
except RuntimeError:
logger.warning("No Flask application context available, summary generation may fail")
def generate_summary_for_chunk(preview_item: dict) -> None:
"""Generate summary for a single chunk."""
if "content" in preview_item:
try:
# Set Flask application context in worker thread
if flask_app:
with flask_app.app_context():
summary = ParagraphIndexProcessor.generate_summary(
tenant_id=dataset.tenant_id,
text=preview_item["content"],
summary_index_setting=summary_index_setting,
)
if summary:
preview_item["summary"] = summary
else:
# Fallback: try without app context (may fail)
summary = ParagraphIndexProcessor.generate_summary(
tenant_id=dataset.tenant_id,
text=preview_item["content"],
summary_index_setting=summary_index_setting,
)
if summary:
preview_item["summary"] = summary
except Exception as e:
logger.error(f"Failed to generate summary for chunk: {str(e)}")
# Don't fail the entire preview if summary generation fails
# Generate summaries concurrently using ThreadPoolExecutor
# Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total)
timeout_seconds = min(300, 60 * len(preview_output["preview"]))
with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_output["preview"]))) as executor:
futures = [
executor.submit(generate_summary_for_chunk, preview_item)
for preview_item in preview_output["preview"]
]
# Wait for all tasks to complete with timeout
done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds)
# Cancel tasks that didn't complete in time
if not_done:
logger.warning(
f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s. "
"Cancelling remaining tasks..."
)
for future in not_done:
future.cancel()
# Wait a bit for cancellation to take effect
concurrent.futures.wait(not_done, timeout=5)
completed_count = sum(1 for item in preview_output["preview"] if item.get("summary") is not None)
logger.info(
f"Completed summary generation for preview chunks: {completed_count}/{len(preview_output['preview'])} succeeded"
)
return preview_output
def _get_preview_output(
self, chunk_structure: str, chunks: Any, dataset: Dataset | None = None, variable_pool: VariablePool | None = None
) -> Mapping[str, Any]:
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
preview_output = index_processor.format_preview(chunks)
# If dataset is provided, try to enrich preview with summaries
if dataset and variable_pool:
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
if document_id:
document = db.session.query(Document).filter_by(id=document_id.value).first()
if document:
# Query summaries for this document
summaries = (
db.session.query(DocumentSegmentSummary)
.filter_by(
dataset_id=dataset.id,
document_id=document.id,
status="completed",
enabled=True,
)
.all()
)
if summaries:
# Create a map of segment content to summary for matching
# Use content matching as chunks in preview might not be indexed yet
summary_by_content = {}
for summary in summaries:
segment = (
db.session.query(DocumentSegment)
.filter_by(id=summary.chunk_id, dataset_id=dataset.id)
.first()
)
if segment:
# Normalize content for matching (strip whitespace)
normalized_content = segment.content.strip()
summary_by_content[normalized_content] = summary.summary_content
# Enrich preview with summaries by content matching
if "preview" in preview_output and isinstance(preview_output["preview"], list):
matched_count = 0
for preview_item in preview_output["preview"]:
if "content" in preview_item:
# Normalize content for matching
normalized_chunk_content = preview_item["content"].strip()
if normalized_chunk_content in summary_by_content:
preview_item["summary"] = summary_by_content[normalized_chunk_content]
matched_count += 1
if matched_count > 0:
logger.info(
f"Enriched preview with {matched_count} existing summaries "
f"(dataset: {dataset.id}, document: {document.id})"
)
return preview_output
return index_processor.format_preview(chunks)
@classmethod
def version(cls) -> str:

View File

@@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.entities.provider_entities import QuotaUnit
from core.file.models import File
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
@@ -136,37 +136,21 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
used_quota = 1
if used_quota is not None and system_configuration.current_quota_type is not None:
if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
)
elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="paid",
)
else:
with Session(db.engine) as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=naive_utc_now(),
)
with Session(db.engine) as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)
session.execute(stmt)
session.commit()
.values(
quota_used=Provider.quota_used + used_quota,
last_used=naive_utc_now(),
)
)
session.execute(stmt)
session.commit()

View File

@@ -1,3 +1,4 @@
import json
from typing import Any
from jsonschema import Draft7Validator, ValidationError
@@ -42,22 +43,25 @@ class StartNode(Node[StartNodeData]):
if value is None and variable.required:
raise ValueError(f"{key} is required in input form")
# If no value provided, skip further processing for this key
if not value:
continue
if not isinstance(value, dict):
raise ValueError(f"JSON object for '{key}' must be an object")
# Overwrite with normalized dict to ensure downstream consistency
node_inputs[key] = value
# If schema exists, then validate against it
schema = variable.json_schema
if not schema:
continue
if not value:
continue
try:
Draft7Validator(schema).validate(value)
json_schema = json.loads(schema)
except json.JSONDecodeError as e:
raise ValueError(f"{schema} must be a valid JSON object")
try:
json_value = json.loads(value)
except json.JSONDecodeError as e:
raise ValueError(f"{value} must be a valid JSON object")
try:
Draft7Validator(json_schema).validate(json_value)
except ValidationError as e:
raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}")
node_inputs[key] = json_value

View File

@@ -33,15 +33,6 @@ class VariableAssignerNode(Node[VariableAssignerData]):
graph_runtime_state=graph_runtime_state,
)
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
"""
Check if this Variable Assigner node blocks the output of specific variables.
Returns True if this node updates any of the requested conversation variables.
"""
assigned_selector = tuple(self.node_data.assigned_variable_selector)
return assigned_selector in variable_selectors
@classmethod
def version(cls) -> str:
return "1"

View File

@@ -19,7 +19,6 @@ from core.workflow.graph_engine.protocols.command_channel import CommandChannel
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
from core.workflow.nodes import NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
@@ -137,11 +136,13 @@ class WorkflowEntry:
:param user_inputs: user inputs
:return:
"""
node_config = dict(workflow.get_node_config_by_id(node_id))
node_config = workflow.get_node_config_by_id(node_id)
node_config_data = node_config.get("data", {})
# Get node type
# Get node class
node_type = NodeType(node_config_data.get("type"))
node_version = node_config_data.get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
# init graph init params and runtime state
graph_init_params = GraphInitParams(
@@ -157,12 +158,12 @@ class WorkflowEntry:
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init workflow run state
node_factory = DifyNodeFactory(
node = node_cls(
id=str(uuid.uuid4()),
config=node_config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
node = node_factory.create_node(node_config)
node_cls = type(node)
try:
# variable selector to variable mapping

View File

@@ -10,7 +10,7 @@ from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, SystemConfiguration
from core.entities.provider_entities import QuotaUnit, SystemConfiguration
from events.message_event import message_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client, redis_fallback
@@ -134,38 +134,22 @@ def handle(sender: Message, **kwargs):
system_configuration=system_configuration,
model_name=model_config.model,
)
if used_quota is not None:
if provider_configuration.system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
quota_update = _ProviderUpdateOperation(
filters=_ProviderUpdateFilters(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="trial",
)
elif provider_configuration.system_configuration.current_quota_type == ProviderQuotaType.PAID:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="paid",
)
else:
quota_update = _ProviderUpdateOperation(
filters=_ProviderUpdateFilters(
tenant_id=tenant_id,
provider_name=ModelProviderID(model_config.provider).provider_name,
provider_type=ProviderType.SYSTEM.value,
quota_type=provider_configuration.system_configuration.current_quota_type.value,
),
values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
additional_filters=_ProviderUpdateAdditionalFilters(
quota_limit_check=True # Provider.quota_limit > Provider.quota_used
),
description="quota_deduction_update",
)
updates_to_perform.append(quota_update)
provider_name=ModelProviderID(model_config.provider).provider_name,
provider_type=ProviderType.SYSTEM,
quota_type=provider_configuration.system_configuration.current_quota_type.value,
),
values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
additional_filters=_ProviderUpdateAdditionalFilters(
quota_limit_check=True # Provider.quota_limit > Provider.quota_used
),
description="quota_deduction_update",
)
updates_to_perform.append(quota_update)
# Execute all updates
start_time = time_module.perf_counter()

View File

@@ -102,8 +102,6 @@ def init_app(app: DifyApp) -> Celery:
imports = [
"tasks.async_workflow_tasks", # trigger workers
"tasks.trigger_processing_tasks", # async trigger processing
"tasks.generate_summary_index_task", # summary index generation
"tasks.regenerate_summary_index_task", # summary index regeneration
]
day = dify_config.CELERY_BEAT_SCHEDULER_TIME
@@ -165,13 +163,6 @@ def init_app(app: DifyApp) -> Celery:
"task": "schedule.clean_workflow_runlogs_precise.clean_workflow_runlogs_precise",
"schedule": crontab(minute="0", hour="2"),
}
if dify_config.ENABLE_WORKFLOW_RUN_CLEANUP_TASK:
# for saas only
imports.append("schedule.clean_workflow_runs_task")
beat_schedule["clean_workflow_runs_task"] = {
"task": "schedule.clean_workflow_runs_task.clean_workflow_runs_task",
"schedule": crontab(minute="0", hour="0"),
}
if dify_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK:
imports.append("schedule.workflow_schedule_task")
beat_schedule["workflow_schedule_task"] = {

View File

@@ -4,7 +4,6 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
from commands import (
add_qdrant_index,
clean_workflow_runs,
cleanup_orphaned_draft_variables,
clear_free_plan_tenant_expired_logs,
clear_orphaned_file_records,
@@ -57,7 +56,6 @@ def init_app(app: DifyApp):
setup_datasource_oauth_client,
transform_datasource_credentials,
install_rag_pipeline_plugins,
clean_workflow_runs,
]
for cmd in cmds_to_register:
app.cli.add_command(cmd)

View File

@@ -39,14 +39,6 @@ dataset_retrieval_model_fields = {
"score_threshold_enabled": fields.Boolean,
"score_threshold": fields.Float,
}
dataset_summary_index_fields = {
"enable": fields.Boolean,
"model_name": fields.String,
"model_provider_name": fields.String,
"summary_prompt": fields.String,
}
external_retrieval_model_fields = {
"top_k": fields.Integer,
"score_threshold": fields.Float,
@@ -91,7 +83,6 @@ dataset_detail_fields = {
"embedding_model_provider": fields.String,
"embedding_available": fields.Boolean,
"retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields),
"summary_index_setting": fields.Nested(dataset_summary_index_fields),
"tags": fields.List(fields.Nested(tag_fields)),
"doc_form": fields.String,
"external_knowledge_info": fields.Nested(external_knowledge_info_fields),

View File

@@ -33,8 +33,6 @@ document_fields = {
"hit_count": fields.Integer,
"doc_form": fields.String,
"doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"),
"summary_index_status": fields.String, # Summary index generation status: "waiting", "generating", "completed", "partial_error", or null if not enabled
"need_summary": fields.Boolean, # Whether this document needs summary index generation
}
document_with_segments_fields = {
@@ -62,8 +60,6 @@ document_with_segments_fields = {
"completed_segments": fields.Integer,
"total_segments": fields.Integer,
"doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"),
"summary_index_status": fields.String, # Summary index generation status: "waiting", "generating", "completed", "partial_error", or null if not enabled
"need_summary": fields.Boolean, # Whether this document needs summary index generation
}
dataset_and_document_fields = {

View File

@@ -58,5 +58,4 @@ hit_testing_record_fields = {
"score": fields.Float,
"tsne_position": fields.Raw,
"files": fields.List(fields.Nested(files_fields)),
"summary": fields.String, # Summary content if retrieved via summary index
}

View File

@@ -49,5 +49,4 @@ segment_fields = {
"stopped_at": TimestampField,
"child_chunks": fields.List(fields.Nested(child_chunk_fields)),
"attachments": fields.List(fields.Nested(attachment_fields)),
"summary": fields.String, # Summary content for the segment
}

View File

@@ -1,29 +0,0 @@
"""Add index on workflow_runs.created_at
Revision ID: 8a7f2ad7c23e
Revises: d57accd375ae
Create Date: 2025-12-10 15:04:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "8a7f2ad7c23e"
down_revision = "d57accd375ae"
branch_labels = None
depends_on = None
def upgrade():
with op.batch_alter_table("workflow_runs", schema=None) as batch_op:
batch_op.create_index(
batch_op.f("workflow_runs_created_at_idx"),
["created_at"],
unique=False,
)
def downgrade():
with op.batch_alter_table("workflow_runs", schema=None) as batch_op:
batch_op.drop_index(batch_op.f("workflow_runs_created_at_idx"))

View File

@@ -1,64 +0,0 @@
"""Alter table pipeline_recommended_plugins add column type
Revision ID: 6bb0832495f0
Revises: 7bb281b7a422
Create Date: 2025-12-15 16:14:38.482072
"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '6bb0832495f0'
down_revision = '7bb281b7a422'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_triggers', schema=None) as batch_op:
batch_op.alter_column('provider_name',
existing_type=sa.VARCHAR(length=255),
nullable=False,
existing_server_default=sa.text("''::character varying"))
with op.batch_alter_table('operation_logs', schema=None) as batch_op:
batch_op.alter_column('content',
existing_type=postgresql.JSON(astext_type=sa.Text()),
nullable=False)
with op.batch_alter_table('pipeline_recommended_plugins', schema=None) as batch_op:
batch_op.add_column(sa.Column('type', sa.String(length=50), nullable=True))
with op.batch_alter_table('providers', schema=None) as batch_op:
batch_op.alter_column('quota_used',
existing_type=sa.BIGINT(),
nullable=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('providers', schema=None) as batch_op:
batch_op.alter_column('quota_used',
existing_type=sa.BIGINT(),
nullable=True)
with op.batch_alter_table('pipeline_recommended_plugins', schema=None) as batch_op:
batch_op.drop_column('type')
with op.batch_alter_table('operation_logs', schema=None) as batch_op:
batch_op.alter_column('content',
existing_type=postgresql.JSON(astext_type=sa.Text()),
nullable=True)
with op.batch_alter_table('app_triggers', schema=None) as batch_op:
batch_op.alter_column('provider_name',
existing_type=sa.VARCHAR(length=255),
nullable=True,
existing_server_default=sa.text("''::character varying"))
# ### end Alembic commands ###

View File

@@ -1,33 +0,0 @@
"""add type column not null default tool
Revision ID: 2536f83803a8
Revises: 6bb0832495f0
Create Date: 2025-12-16 14:24:40.740253
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '2536f83803a8'
down_revision = '6bb0832495f0'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('pipeline_recommended_plugins', schema=None) as batch_op:
batch_op.add_column(sa.Column('type', sa.String(length=50), nullable=False, server_default='tool'))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('pipeline_recommended_plugins', schema=None) as batch_op:
batch_op.drop_column('type')
# ### end Alembic commands ###

View File

@@ -12,7 +12,7 @@ from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '85c8b4a64f53'
down_revision = '7df29de0f6be'
down_revision = '03ea244985ce'
branch_labels = None
depends_on = None

View File

@@ -1,46 +0,0 @@
"""add credit pool
Revision ID: 7df29de0f6be
Revises: 03ea244985ce
Create Date: 2025-12-25 10:39:15.139304
"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '7df29de0f6be'
down_revision = '03ea244985ce'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('tenant_credit_pools',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('pool_type', sa.String(length=40), server_default='trial', nullable=False),
sa.Column('quota_limit', sa.BigInteger(), nullable=False),
sa.Column('quota_used', sa.BigInteger(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='tenant_credit_pool_pkey')
)
with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op:
batch_op.create_index('tenant_credit_pool_pool_type_idx', ['pool_type'], unique=False)
batch_op.create_index('tenant_credit_pool_tenant_id_idx', ['tenant_id'], unique=False)
# ### end Alembic commands ###
def downgrade():
with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op:
batch_op.drop_index('tenant_credit_pool_tenant_id_idx')
batch_op.drop_index('tenant_credit_pool_pool_type_idx')
op.drop_table('tenant_credit_pools')
# ### end Alembic commands ###

View File

@@ -1,73 +0,0 @@
"""add table explore banner and trial
Revision ID: f9f6d18a37f9
Revises: 7df29de0f6be
Create Date: 2026-01-09 11:10:18.079355
"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = 'f9f6d18a37f9'
down_revision = '7df29de0f6be'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('account_trial_app_records',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('account_id', models.types.StringUUID(), nullable=False),
sa.Column('app_id', models.types.StringUUID(), nullable=False),
sa.Column('count', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='user_trial_app_pkey'),
sa.UniqueConstraint('account_id', 'app_id', name='unique_account_trial_app_record')
)
with op.batch_alter_table('account_trial_app_records', schema=None) as batch_op:
batch_op.create_index('account_trial_app_record_account_id_idx', ['account_id'], unique=False)
batch_op.create_index('account_trial_app_record_app_id_idx', ['app_id'], unique=False)
op.create_table('exporle_banners',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('content', sa.JSON(), nullable=False),
sa.Column('link', sa.String(length=255), nullable=False),
sa.Column('sort', sa.Integer(), nullable=False),
sa.Column('status', sa.String(length=255), server_default=sa.text("'enabled'::character varying"), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('language', sa.String(length=255), server_default=sa.text("'en-US'::character varying"), nullable=False),
sa.PrimaryKeyConstraint('id', name='exporler_banner_pkey')
)
op.create_table('trial_apps',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('app_id', models.types.StringUUID(), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('trial_limit', sa.Integer(), nullable=False),
sa.PrimaryKeyConstraint('id', name='trial_app_pkey'),
sa.UniqueConstraint('app_id', name='unique_trail_app_id')
)
with op.batch_alter_table('trial_apps', schema=None) as batch_op:
batch_op.create_index('trial_app_app_id_idx', ['app_id'], unique=False)
batch_op.create_index('trial_app_tenant_id_idx', ['tenant_id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('trial_apps', schema=None) as batch_op:
batch_op.drop_index('trial_app_tenant_id_idx')
batch_op.drop_index('trial_app_app_id_idx')
op.drop_table('trial_apps')
op.drop_table('exporle_banners')
with op.batch_alter_table('account_trial_app_records', schema=None) as batch_op:
batch_op.drop_index('account_trial_app_record_app_id_idx')
batch_op.drop_index('account_trial_app_record_account_id_idx')
op.drop_table('account_trial_app_records')
# ### end Alembic commands ###

View File

@@ -1,30 +0,0 @@
"""add workflow_run_created_at_id_idx
Revision ID: 905527cc8fd3
Revises: 7df29de0f6be
Create Date: 2025-01-09 16:30:02.462084
"""
from alembic import op
import models as models
# revision identifiers, used by Alembic.
revision = '905527cc8fd3'
down_revision = '7df29de0f6be'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
batch_op.create_index('workflow_run_created_at_id_idx', ['created_at', 'id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
batch_op.drop_index('workflow_run_created_at_id_idx')
# ### end Alembic commands ###

View File

@@ -1,69 +0,0 @@
"""add SummaryIndex feature
Revision ID: 562dcce7d77c
Revises: 03ea244985ce
Create Date: 2026-01-12 13:58:40.584802
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '562dcce7d77c'
down_revision = '03ea244985ce'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('document_segment_summary',
sa.Column('id', models.types.StringUUID(), nullable=False),
sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
sa.Column('document_id', models.types.StringUUID(), nullable=False),
sa.Column('chunk_id', models.types.StringUUID(), nullable=False),
sa.Column('summary_content', models.types.LongText(), nullable=True),
sa.Column('summary_index_node_id', sa.String(length=255), nullable=True),
sa.Column('summary_index_node_hash', sa.String(length=255), nullable=True),
sa.Column('status', sa.String(length=32), server_default=sa.text("'generating'"), nullable=False),
sa.Column('error', models.types.LongText(), nullable=True),
sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False),
sa.Column('disabled_at', sa.DateTime(), nullable=True),
sa.Column('disabled_by', models.types.StringUUID(), nullable=True),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='document_segment_summary_pkey')
)
with op.batch_alter_table('document_segment_summary', schema=None) as batch_op:
batch_op.create_index('document_segment_summary_chunk_id_idx', ['chunk_id'], unique=False)
batch_op.create_index('document_segment_summary_dataset_id_idx', ['dataset_id'], unique=False)
batch_op.create_index('document_segment_summary_document_id_idx', ['document_id'], unique=False)
batch_op.create_index('document_segment_summary_status_idx', ['status'], unique=False)
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.add_column(sa.Column('summary_index_setting', models.types.AdjustedJSON(), nullable=True))
with op.batch_alter_table('documents', schema=None) as batch_op:
batch_op.add_column(sa.Column('need_summary', sa.Boolean(), server_default=sa.text('false'), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('documents', schema=None) as batch_op:
batch_op.drop_column('need_summary')
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.drop_column('summary_index_setting')
with op.batch_alter_table('document_segment_summary', schema=None) as batch_op:
batch_op.drop_index('document_segment_summary_status_idx')
batch_op.drop_index('document_segment_summary_document_id_idx')
batch_op.drop_index('document_segment_summary_dataset_id_idx')
batch_op.drop_index('document_segment_summary_chunk_id_idx')
op.drop_table('document_segment_summary')
# ### end Alembic commands ###

View File

@@ -35,7 +35,6 @@ from .enums import (
WorkflowTriggerStatus,
)
from .model import (
AccountTrialAppRecord,
ApiRequest,
ApiToken,
App,
@@ -48,7 +47,6 @@ from .model import (
DatasetRetrieverResource,
DifySetup,
EndUser,
ExporleBanner,
IconType,
InstalledApp,
LLMGenerationDetail,
@@ -63,9 +61,7 @@ from .model import (
Site,
Tag,
TagBinding,
TenantCreditPool,
TraceAppConfig,
TrialApp,
UploadFile,
)
from .oauth import DatasourceOauthParamConfig, DatasourceProvider
@@ -118,7 +114,6 @@ __all__ = [
"Account",
"AccountIntegrate",
"AccountStatus",
"AccountTrialAppRecord",
"ApiRequest",
"ApiToken",
"ApiToolProvider",
@@ -155,7 +150,6 @@ __all__ = [
"DocumentSegment",
"Embedding",
"EndUser",
"ExporleBanner",
"ExternalKnowledgeApis",
"ExternalKnowledgeBindings",
"IconType",
@@ -185,7 +179,6 @@ __all__ = [
"Tenant",
"TenantAccountJoin",
"TenantAccountRole",
"TenantCreditPool",
"TenantDefaultModel",
"TenantPreferredModelProvider",
"TenantStatus",
@@ -195,7 +188,6 @@ __all__ = [
"ToolLabelBinding",
"ToolModelInvoke",
"TraceAppConfig",
"TrialApp",
"TriggerOAuthSystemClient",
"TriggerOAuthTenantClient",
"TriggerSubscription",

View File

@@ -72,7 +72,6 @@ class Dataset(Base):
keyword_number = mapped_column(sa.Integer, nullable=True, server_default=sa.text("10"))
collection_binding_id = mapped_column(StringUUID, nullable=True)
retrieval_model = mapped_column(AdjustedJSON, nullable=True)
summary_index_setting = mapped_column(AdjustedJSON, nullable=True)
built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
icon_info = mapped_column(AdjustedJSON, nullable=True)
runtime_mode = mapped_column(sa.String(255), nullable=True, server_default=sa.text("'general'"))
@@ -420,7 +419,6 @@ class Document(Base):
doc_metadata = mapped_column(AdjustedJSON, nullable=True)
doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'"))
doc_language = mapped_column(String(255), nullable=True)
need_summary: Mapped[bool | None] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"))
DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"]
@@ -1569,34 +1567,3 @@ class SegmentAttachmentBinding(Base):
segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
attachment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
class DocumentSegmentSummary(Base):
__tablename__ = "document_segment_summary"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="document_segment_summary_pkey"),
sa.Index("document_segment_summary_dataset_id_idx", "dataset_id"),
sa.Index("document_segment_summary_document_id_idx", "document_id"),
sa.Index("document_segment_summary_chunk_id_idx", "chunk_id"),
sa.Index("document_segment_summary_status_idx", "status"),
)
id: Mapped[str] = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()))
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# corresponds to DocumentSegment.id or parent chunk id
chunk_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
summary_content: Mapped[str] = mapped_column(LongText, nullable=True)
summary_index_node_id: Mapped[str] = mapped_column(String(255), nullable=True)
summary_index_node_hash: Mapped[str] = mapped_column(String(255), nullable=True)
status: Mapped[str] = mapped_column(String(32), nullable=False, server_default=sa.text("'generating'"))
error: Mapped[str] = mapped_column(LongText, nullable=True)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
disabled_by = mapped_column(StringUUID, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp())
def __repr__(self):
return f"<DocumentSegmentSummary id={self.id} chunk_id={self.chunk_id} status={self.status}>"

View File

@@ -12,8 +12,8 @@ from uuid import uuid4
import sqlalchemy as sa
from flask import request
from flask_login import UserMixin # type: ignore[import-untyped]
from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
from flask_login import UserMixin
from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
from sqlalchemy.orm import Mapped, Session, mapped_column
from configs import dify_config
@@ -605,64 +605,6 @@ class InstalledApp(TypeBase):
return tenant
class TrialApp(Base):
__tablename__ = "trial_apps"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="trial_app_pkey"),
sa.Index("trial_app_app_id_idx", "app_id"),
sa.Index("trial_app_tenant_id_idx", "tenant_id"),
sa.UniqueConstraint("app_id", name="unique_trail_app_id"),
)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
tenant_id = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
trial_limit = mapped_column(sa.Integer, nullable=False, default=3)
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
class AccountTrialAppRecord(Base):
__tablename__ = "account_trial_app_records"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="user_trial_app_pkey"),
sa.Index("account_trial_app_record_account_id_idx", "account_id"),
sa.Index("account_trial_app_record_app_id_idx", "app_id"),
sa.UniqueConstraint("account_id", "app_id", name="unique_account_trial_app_record"),
)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
account_id = mapped_column(StringUUID, nullable=False)
app_id = mapped_column(StringUUID, nullable=False)
count = mapped_column(sa.Integer, nullable=False, default=0)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
@property
def user(self) -> Account | None:
user = db.session.query(Account).where(Account.id == self.account_id).first()
return user
class ExporleBanner(Base):
__tablename__ = "exporle_banners"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
content = mapped_column(sa.JSON, nullable=False)
link = mapped_column(String(255), nullable=False)
sort = mapped_column(sa.Integer, nullable=False)
status = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"))
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
language = mapped_column(String(255), nullable=False, server_default=sa.text("'en-US'::character varying"))
class OAuthProviderApp(TypeBase):
"""
Globally shared OAuth provider app information.
@@ -1915,7 +1857,7 @@ class MessageChain(TypeBase):
)
class MessageAgentThought(TypeBase):
class MessageAgentThought(Base):
__tablename__ = "message_agent_thoughts"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"),
@@ -1923,42 +1865,34 @@ class MessageAgentThought(TypeBase):
sa.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"),
)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
id = mapped_column(StringUUID, default=lambda: str(uuid4()))
message_id = mapped_column(StringUUID, nullable=False)
message_chain_id = mapped_column(StringUUID, nullable=True)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
message_chain_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
thought: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
tool: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
tool_labels_str: Mapped[str] = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
tool_meta_str: Mapped[str] = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
tool_input: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
observation: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
thought = mapped_column(LongText, nullable=True)
tool = mapped_column(LongText, nullable=True)
tool_labels_str = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
tool_meta_str = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
tool_input = mapped_column(LongText, nullable=True)
observation = mapped_column(LongText, nullable=True)
# plugin_id = mapped_column(StringUUID, nullable=True) ## for future design
tool_process_data: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
message: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
message_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
message_unit_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None)
message_price_unit: Mapped[Decimal] = mapped_column(
sa.Numeric(10, 7), nullable=False, default=Decimal("0.001"), server_default=sa.text("0.001")
)
message_files: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
answer: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
answer_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
answer_unit_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None)
answer_price_unit: Mapped[Decimal] = mapped_column(
sa.Numeric(10, 7), nullable=False, default=Decimal("0.001"), server_default=sa.text("0.001")
)
tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
total_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None)
currency: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, init=False, server_default=sa.func.current_timestamp()
)
tool_process_data = mapped_column(LongText, nullable=True)
message = mapped_column(LongText, nullable=True)
message_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
message_unit_price = mapped_column(sa.Numeric, nullable=True)
message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
message_files = mapped_column(LongText, nullable=True)
answer = mapped_column(LongText, nullable=True)
answer_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
answer_unit_price = mapped_column(sa.Numeric, nullable=True)
answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
total_price = mapped_column(sa.Numeric, nullable=True)
currency = mapped_column(String(255), nullable=True)
latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True)
created_by_role = mapped_column(String(255), nullable=False)
created_by = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp())
@property
def files(self) -> list[Any]:
@@ -2155,32 +2089,6 @@ class TraceAppConfig(TypeBase):
}
class TenantCreditPool(Base):
__tablename__ = "tenant_credit_pools"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="tenant_credit_pool_pkey"),
sa.Index("tenant_credit_pool_tenant_id_idx", "tenant_id"),
sa.Index("tenant_credit_pool_pool_type_idx", "pool_type"),
)
id = mapped_column(StringUUID, primary_key=True, server_default=text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
pool_type = mapped_column(String(40), nullable=False, default="trial", server_default="trial")
quota_limit = mapped_column(BigInteger, nullable=False, default=0)
quota_used = mapped_column(BigInteger, nullable=False, default=0)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP"))
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
@property
def remaining_credits(self) -> int:
return max(0, self.quota_limit - self.quota_used)
def has_sufficient_credits(self, required_credits: int) -> bool:
return self.remaining_credits >= required_credits
class LLMGenerationDetail(Base):
"""
Store LLM generation details including reasoning process and tool calls.

View File

@@ -628,7 +628,6 @@ class WorkflowRun(Base):
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="workflow_run_pkey"),
sa.Index("workflow_run_triggerd_from_idx", "tenant_id", "app_id", "triggered_from"),
sa.Index("workflow_run_created_at_id_idx", "created_at", "id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))

View File

@@ -1,6 +1,6 @@
[project]
name = "dify-api"
version = "1.11.3"
version = "1.11.2"
requires-python = ">=3.11,<3.13"
dependencies = [
@@ -189,7 +189,7 @@ storage = [
"opendal~=0.46.0",
"oss2==2.18.5",
"supabase~=2.18.1",
"tos~=2.9.0",
"tos~=2.7.1",
]
############################################################

View File

@@ -34,14 +34,11 @@ Example:
```
"""
from collections.abc import Callable, Sequence
from collections.abc import Sequence
from datetime import datetime
from typing import Protocol
from sqlalchemy.orm import Session
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.enums import WorkflowType
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom
@@ -256,44 +253,6 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
"""
...
def get_runs_batch_by_time_range(
self,
start_from: datetime | None,
end_before: datetime,
last_seen: tuple[datetime, str] | None,
batch_size: int,
run_types: Sequence[WorkflowType] | None = None,
tenant_ids: Sequence[str] | None = None,
) -> Sequence[WorkflowRun]:
"""
Fetch ended workflow runs in a time window for archival and clean batching.
"""
...
def delete_runs_with_related(
self,
runs: Sequence[WorkflowRun],
delete_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None,
delete_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None,
) -> dict[str, int]:
"""
Delete workflow runs and their related records (node executions, offloads, app logs,
trigger logs, pauses, pause reasons).
"""
...
def count_runs_with_related(
self,
runs: Sequence[WorkflowRun],
count_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None,
count_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None,
) -> dict[str, int]:
"""
Count workflow runs and their related records (node executions, offloads, app logs,
trigger logs, pauses, pause reasons) without deleting data.
"""
...
def create_workflow_pause(
self,
workflow_run_id: str,

View File

@@ -7,18 +7,13 @@ using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations.
from collections.abc import Sequence
from datetime import datetime
from typing import TypedDict, cast
from typing import cast
from sqlalchemy import asc, delete, desc, func, select, tuple_
from sqlalchemy import asc, delete, desc, select
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session, sessionmaker
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import (
WorkflowNodeExecutionModel,
WorkflowNodeExecutionOffload,
WorkflowNodeExecutionTriggeredFrom,
)
from models.workflow import WorkflowNodeExecutionModel
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
@@ -49,26 +44,6 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
"""
self._session_maker = session_maker
@staticmethod
def _map_run_triggered_from_to_node_triggered_from(triggered_from: str) -> str:
"""
Map workflow run triggered_from values to workflow node execution triggered_from values.
"""
if triggered_from in {
WorkflowRunTriggeredFrom.APP_RUN.value,
WorkflowRunTriggeredFrom.DEBUGGING.value,
WorkflowRunTriggeredFrom.SCHEDULE.value,
WorkflowRunTriggeredFrom.PLUGIN.value,
WorkflowRunTriggeredFrom.WEBHOOK.value,
}:
return WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
if triggered_from in {
WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value,
WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value,
}:
return WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN.value
return ""
def get_node_last_execution(
self,
tenant_id: str,
@@ -315,119 +290,3 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
result = cast(CursorResult, session.execute(stmt))
session.commit()
return result.rowcount
class RunContext(TypedDict):
run_id: str
tenant_id: str
app_id: str
workflow_id: str
triggered_from: str
@staticmethod
def delete_by_runs(session: Session, runs: Sequence[RunContext]) -> tuple[int, int]:
"""
Delete node executions (and offloads) for the given workflow runs using indexed columns.
Uses the composite index on (tenant_id, app_id, workflow_id, triggered_from, workflow_run_id)
by filtering on those columns with tuple IN.
"""
if not runs:
return 0, 0
tuple_values = [
(
run["tenant_id"],
run["app_id"],
run["workflow_id"],
DifyAPISQLAlchemyWorkflowNodeExecutionRepository._map_run_triggered_from_to_node_triggered_from(
run["triggered_from"]
),
run["run_id"],
)
for run in runs
]
node_execution_ids = session.scalars(
select(WorkflowNodeExecutionModel.id).where(
tuple_(
WorkflowNodeExecutionModel.tenant_id,
WorkflowNodeExecutionModel.app_id,
WorkflowNodeExecutionModel.workflow_id,
WorkflowNodeExecutionModel.triggered_from,
WorkflowNodeExecutionModel.workflow_run_id,
).in_(tuple_values)
)
).all()
if not node_execution_ids:
return 0, 0
offloads_deleted = (
cast(
CursorResult,
session.execute(
delete(WorkflowNodeExecutionOffload).where(
WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids)
)
),
).rowcount
or 0
)
node_executions_deleted = (
cast(
CursorResult,
session.execute(
delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(node_execution_ids))
),
).rowcount
or 0
)
return node_executions_deleted, offloads_deleted
@staticmethod
def count_by_runs(session: Session, runs: Sequence[RunContext]) -> tuple[int, int]:
"""
Count node executions (and offloads) for the given workflow runs using indexed columns.
"""
if not runs:
return 0, 0
tuple_values = [
(
run["tenant_id"],
run["app_id"],
run["workflow_id"],
DifyAPISQLAlchemyWorkflowNodeExecutionRepository._map_run_triggered_from_to_node_triggered_from(
run["triggered_from"]
),
run["run_id"],
)
for run in runs
]
tuple_filter = tuple_(
WorkflowNodeExecutionModel.tenant_id,
WorkflowNodeExecutionModel.app_id,
WorkflowNodeExecutionModel.workflow_id,
WorkflowNodeExecutionModel.triggered_from,
WorkflowNodeExecutionModel.workflow_run_id,
).in_(tuple_values)
node_executions_count = (
session.scalar(select(func.count()).select_from(WorkflowNodeExecutionModel).where(tuple_filter)) or 0
)
offloads_count = (
session.scalar(
select(func.count())
.select_from(WorkflowNodeExecutionOffload)
.join(
WorkflowNodeExecutionModel,
WorkflowNodeExecutionOffload.node_execution_id == WorkflowNodeExecutionModel.id,
)
.where(tuple_filter)
)
or 0
)
return int(node_executions_count), int(offloads_count)

View File

@@ -21,7 +21,7 @@ Implementation Notes:
import logging
import uuid
from collections.abc import Callable, Sequence
from collections.abc import Sequence
from datetime import datetime
from decimal import Decimal
from typing import Any, cast
@@ -32,7 +32,7 @@ from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session, selectinload, sessionmaker
from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, SchedulingPause
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
from core.workflow.enums import WorkflowExecutionStatus
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from libs.helper import convert_datetime_to_date
@@ -40,14 +40,8 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.time_parser import get_time_threshold
from libs.uuid_utils import uuidv7
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import (
WorkflowAppLog,
WorkflowPauseReason,
WorkflowRun,
)
from models.workflow import (
WorkflowPause as WorkflowPauseModel,
)
from models.workflow import WorkflowPause as WorkflowPauseModel
from models.workflow import WorkflowPauseReason, WorkflowRun
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.types import (
@@ -320,171 +314,6 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
logger.info("Total deleted %s workflow runs for app %s", total_deleted, app_id)
return total_deleted
def get_runs_batch_by_time_range(
self,
start_from: datetime | None,
end_before: datetime,
last_seen: tuple[datetime, str] | None,
batch_size: int,
run_types: Sequence[WorkflowType] | None = None,
tenant_ids: Sequence[str] | None = None,
) -> Sequence[WorkflowRun]:
"""
Fetch ended workflow runs in a time window for archival and clean batching.
Query scope:
- created_at in [start_from, end_before)
- type in run_types (when provided)
- status is an ended state
- optional tenant_id filter and cursor (last_seen) for pagination
"""
with self._session_maker() as session:
stmt = (
select(WorkflowRun)
.where(
WorkflowRun.created_at < end_before,
WorkflowRun.status.in_(WorkflowExecutionStatus.ended_values()),
)
.order_by(WorkflowRun.created_at.asc(), WorkflowRun.id.asc())
.limit(batch_size)
)
if run_types is not None:
if not run_types:
return []
stmt = stmt.where(WorkflowRun.type.in_(run_types))
if start_from:
stmt = stmt.where(WorkflowRun.created_at >= start_from)
if tenant_ids:
stmt = stmt.where(WorkflowRun.tenant_id.in_(tenant_ids))
if last_seen:
stmt = stmt.where(
or_(
WorkflowRun.created_at > last_seen[0],
and_(WorkflowRun.created_at == last_seen[0], WorkflowRun.id > last_seen[1]),
)
)
return session.scalars(stmt).all()
def delete_runs_with_related(
self,
runs: Sequence[WorkflowRun],
delete_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None,
delete_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None,
) -> dict[str, int]:
if not runs:
return {
"runs": 0,
"node_executions": 0,
"offloads": 0,
"app_logs": 0,
"trigger_logs": 0,
"pauses": 0,
"pause_reasons": 0,
}
with self._session_maker() as session:
run_ids = [run.id for run in runs]
if delete_node_executions:
node_executions_deleted, offloads_deleted = delete_node_executions(session, runs)
else:
node_executions_deleted, offloads_deleted = 0, 0
app_logs_result = session.execute(delete(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(run_ids)))
app_logs_deleted = cast(CursorResult, app_logs_result).rowcount or 0
pause_ids = session.scalars(
select(WorkflowPauseModel.id).where(WorkflowPauseModel.workflow_run_id.in_(run_ids))
).all()
pause_reasons_deleted = 0
pauses_deleted = 0
if pause_ids:
pause_reasons_result = session.execute(
delete(WorkflowPauseReason).where(WorkflowPauseReason.pause_id.in_(pause_ids))
)
pause_reasons_deleted = cast(CursorResult, pause_reasons_result).rowcount or 0
pauses_result = session.execute(delete(WorkflowPauseModel).where(WorkflowPauseModel.id.in_(pause_ids)))
pauses_deleted = cast(CursorResult, pauses_result).rowcount or 0
trigger_logs_deleted = delete_trigger_logs(session, run_ids) if delete_trigger_logs else 0
runs_result = session.execute(delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids)))
runs_deleted = cast(CursorResult, runs_result).rowcount or 0
session.commit()
return {
"runs": runs_deleted,
"node_executions": node_executions_deleted,
"offloads": offloads_deleted,
"app_logs": app_logs_deleted,
"trigger_logs": trigger_logs_deleted,
"pauses": pauses_deleted,
"pause_reasons": pause_reasons_deleted,
}
def count_runs_with_related(
self,
runs: Sequence[WorkflowRun],
count_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None,
count_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None,
) -> dict[str, int]:
if not runs:
return {
"runs": 0,
"node_executions": 0,
"offloads": 0,
"app_logs": 0,
"trigger_logs": 0,
"pauses": 0,
"pause_reasons": 0,
}
with self._session_maker() as session:
run_ids = [run.id for run in runs]
if count_node_executions:
node_executions_count, offloads_count = count_node_executions(session, runs)
else:
node_executions_count, offloads_count = 0, 0
app_logs_count = (
session.scalar(
select(func.count()).select_from(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(run_ids))
)
or 0
)
pause_ids = session.scalars(
select(WorkflowPauseModel.id).where(WorkflowPauseModel.workflow_run_id.in_(run_ids))
).all()
pauses_count = len(pause_ids)
pause_reasons_count = 0
if pause_ids:
pause_reasons_count = (
session.scalar(
select(func.count())
.select_from(WorkflowPauseReason)
.where(WorkflowPauseReason.pause_id.in_(pause_ids))
)
or 0
)
trigger_logs_count = count_trigger_logs(session, run_ids) if count_trigger_logs else 0
return {
"runs": len(runs),
"node_executions": node_executions_count,
"offloads": offloads_count,
"app_logs": int(app_logs_count),
"trigger_logs": trigger_logs_count,
"pauses": pauses_count,
"pause_reasons": int(pause_reasons_count),
}
def create_workflow_pause(
self,
workflow_run_id: str,

View File

@@ -4,10 +4,8 @@ SQLAlchemy implementation of WorkflowTriggerLogRepository.
from collections.abc import Sequence
from datetime import UTC, datetime, timedelta
from typing import cast
from sqlalchemy import and_, delete, func, select
from sqlalchemy.engine import CursorResult
from sqlalchemy import and_, select
from sqlalchemy.orm import Session
from models.enums import WorkflowTriggerStatus
@@ -86,37 +84,3 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository):
)
return list(self.session.scalars(query).all())
def delete_by_run_ids(self, run_ids: Sequence[str]) -> int:
"""
Delete trigger logs associated with the given workflow run ids.
Args:
run_ids: Collection of workflow run identifiers.
Returns:
Number of rows deleted.
"""
if not run_ids:
return 0
result = self.session.execute(delete(WorkflowTriggerLog).where(WorkflowTriggerLog.workflow_run_id.in_(run_ids)))
return cast(CursorResult, result).rowcount or 0
def count_by_run_ids(self, run_ids: Sequence[str]) -> int:
"""
Count trigger logs associated with the given workflow run ids.
Args:
run_ids: Collection of workflow run identifiers.
Returns:
Number of rows matched.
"""
if not run_ids:
return 0
count = self.session.scalar(
select(func.count()).select_from(WorkflowTriggerLog).where(WorkflowTriggerLog.workflow_run_id.in_(run_ids))
)
return int(count or 0)

View File

@@ -109,15 +109,3 @@ class WorkflowTriggerLogRepository(Protocol):
A sequence of recent WorkflowTriggerLog instances
"""
...
def delete_by_run_ids(self, run_ids: Sequence[str]) -> int:
"""
Delete trigger logs for workflow run IDs.
Args:
run_ids: Workflow run IDs to delete
Returns:
Number of rows deleted
"""
...

View File

@@ -1,43 +0,0 @@
from datetime import UTC, datetime
import click
import app
from configs import dify_config
from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup
@app.celery.task(queue="retention")
def clean_workflow_runs_task() -> None:
"""
Scheduled cleanup for workflow runs and related records (sandbox tenants only).
"""
click.echo(
click.style(
(
"Scheduled workflow run cleanup starting: "
f"cutoff={dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS} days, "
f"batch={dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE}"
),
fg="green",
)
)
start_time = datetime.now(UTC)
WorkflowRunCleanup(
days=dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS,
batch_size=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE,
start_from=None,
end_before=None,
).run()
end_time = datetime.now(UTC)
elapsed = end_time - start_time
click.echo(
click.style(
f"Scheduled workflow run cleanup finished. start={start_time.isoformat()} "
f"end={end_time.isoformat()} duration={elapsed}",
fg="green",
)
)

View File

@@ -8,7 +8,7 @@ from hashlib import sha256
from typing import Any, cast
from pydantic import BaseModel
from sqlalchemy import func, select
from sqlalchemy import func
from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized
@@ -748,21 +748,6 @@ class AccountService:
cls.email_code_login_rate_limiter.increment_rate_limit(email)
return token
@staticmethod
def get_account_by_email_with_case_fallback(email: str, session: Session | None = None) -> Account | None:
"""
Retrieve an account by email and fall back to the lowercase email if the original lookup fails.
This keeps backward compatibility for older records that stored uppercase emails while the
rest of the system gradually normalizes new inputs.
"""
query_session = session or db.session
account = query_session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
if account or email == email.lower():
return account
return query_session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()
@classmethod
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
return TokenManager.get_token_data(token, "email_code_login")
@@ -1014,11 +999,6 @@ class TenantService:
tenant.encrypt_public_key = generate_key_pair(tenant.id)
db.session.commit()
from services.credit_pool_service import CreditPoolService
CreditPoolService.create_default_pool(tenant.id)
return tenant
@staticmethod
@@ -1378,22 +1358,16 @@ class RegisterService:
if not inviter:
raise ValueError("Inviter is required")
normalized_email = email.lower()
"""Invite new member"""
with Session(db.engine) as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
account = session.query(Account).filter_by(email=email).first()
if not account:
TenantService.check_member_permission(tenant, inviter, None, "add")
name = normalized_email.split("@")[0]
name = email.split("@")[0]
account = cls.register(
email=normalized_email,
name=name,
language=language,
status=AccountStatus.PENDING,
is_setup=True,
email=email, name=name, language=language, status=AccountStatus.PENDING, is_setup=True
)
# Create new tenant member for invited tenant
TenantService.create_tenant_member(tenant, account, role)
@@ -1415,7 +1389,7 @@ class RegisterService:
# send email
send_invite_member_mail_task.delay(
language=language,
to=account.email,
to=email,
token=token,
inviter_name=inviter.name if inviter else "Dify",
workspace_name=tenant.name,
@@ -1514,16 +1488,6 @@ class RegisterService:
invitation: dict = json.loads(data)
return invitation
@classmethod
def get_invitation_with_case_fallback(
cls, workspace_id: str | None, email: str | None, token: str
) -> dict[str, Any] | None:
invitation = cls.get_invitation_if_token_valid(workspace_id, email, token)
if invitation or not email or email == email.lower():
return invitation
normalized_email = email.lower()
return cls.get_invitation_if_token_valid(workspace_id, normalized_email, token)
def _generate_refresh_token(length: int = 64):
token = secrets.token_hex(length)

View File

@@ -355,6 +355,7 @@ class AppAnnotationService:
def batch_import_app_annotations(cls, app_id, file: FileStorage):
"""
Batch import annotations from CSV file with enhanced security checks.
Security features:
- File size validation
- Row count limits (min/max)
@@ -363,6 +364,7 @@ class AppAnnotationService:
- Concurrency tracking
"""
from configs import dify_config
# get app info
current_user, current_tenant_id = current_account_with_tenant()
app = (
@@ -446,31 +448,27 @@ class AppAnnotationService:
f"The CSV file must contain at least {min_records} valid annotation record(s). "
f"Found {len(result)} valid record(s)."
)
# Check annotation quota limit
features = FeatureService.get_features(current_tenant_id)
if features.billing.enabled:
annotation_quota_limit = features.annotation_quota_limit
if annotation_quota_limit.limit < len(result) + annotation_quota_limit.size:
raise ValueError(
f"The number of annotations ({len(result)}) would exceed your subscription limit. "
f"Current usage: {annotation_quota_limit.size}/{annotation_quota_limit.limit}. "
f"Available: {annotation_quota_limit.limit - annotation_quota_limit.size}."
)
# Create async job
raise ValueError("The number of annotations exceeds the limit of your subscription.")
# async job
job_id = str(uuid.uuid4())
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
# Register job in active tasks list for concurrency tracking
current_time = int(naive_utc_now().timestamp() * 1000)
active_jobs_key = f"annotation_import_active:{current_tenant_id}"
redis_client.zadd(active_jobs_key, {job_id: current_time})
redis_client.expire(active_jobs_key, 7200) # 2 hours TTL
# Set job status
redis_client.setnx(indexing_cache_key, "waiting")
redis_client.expire(indexing_cache_key, 3600) # 1 hour TTL
# Send batch import task
batch_import_annotations_task.delay(str(job_id), result, app_id, current_tenant_id, current_user.id)
except ValueError as e:
return {"error_msg": str(e)}
except Exception as e:

Some files were not shown because too many files have changed in this diff Show More