mirror of
https://github.com/langgenius/dify.git
synced 2026-04-10 05:40:33 +08:00
Compare commits
1 Commits
feat/colla
...
codex/forc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a33791c3de |
@@ -64,7 +64,7 @@ export const useUpdateAccessMode = () => {
|
||||
|
||||
// Component only adds UI behavior.
|
||||
updateAccessMode({ appId, mode }, {
|
||||
onSuccess: () => toast.success('...'),
|
||||
onSuccess: () => Toast.notify({ type: 'success', message: '...' }),
|
||||
})
|
||||
|
||||
// Avoid putting invalidation knowledge in the component.
|
||||
@@ -114,7 +114,10 @@ try {
|
||||
router.push(`/orders/${order.id}`)
|
||||
}
|
||||
catch (error) {
|
||||
toast.error(error instanceof Error ? error.message : 'Unknown error')
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: error instanceof Error ? error.message : 'Unknown error',
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
1
.github/actions/setup-web/action.yml
vendored
1
.github/actions/setup-web/action.yml
vendored
@@ -6,6 +6,7 @@ runs:
|
||||
- name: Setup Vite+
|
||||
uses: voidzero-dev/setup-vp@20553a7a7429c429a74894104a2835d7fed28a72 # v1.3.0
|
||||
with:
|
||||
working-directory: web
|
||||
node-version-file: .nvmrc
|
||||
cache: true
|
||||
run-install: true
|
||||
|
||||
9
.github/labeler.yml
vendored
9
.github/labeler.yml
vendored
@@ -1,10 +1,3 @@
|
||||
web:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- 'web/**'
|
||||
- 'packages/**'
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- '.npmrc'
|
||||
- '.nvmrc'
|
||||
- any-glob-to-any-file: 'web/**'
|
||||
|
||||
2
.github/pull_request_template.md
vendored
2
.github/pull_request_template.md
vendored
@@ -20,4 +20,4 @@
|
||||
- [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
|
||||
- [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
|
||||
- [x] I've updated the documentation accordingly.
|
||||
- [x] I ran `make lint` and `make type-check` (backend) and `cd web && pnpm exec vp staged` (frontend) to appease the lint gods
|
||||
- [x] I ran `make lint` and `make type-check` (backend) and `cd web && npx lint-staged` (frontend) to appease the lint gods
|
||||
|
||||
82
.github/scripts/generate-i18n-changes.mjs
vendored
82
.github/scripts/generate-i18n-changes.mjs
vendored
@@ -1,82 +0,0 @@
|
||||
import { execFileSync } from 'node:child_process'
|
||||
import fs from 'node:fs'
|
||||
import path from 'node:path'
|
||||
|
||||
const repoRoot = process.cwd()
|
||||
const baseSha = process.env.BASE_SHA || ''
|
||||
const headSha = process.env.HEAD_SHA || ''
|
||||
const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean)
|
||||
const outputPath = process.env.I18N_CHANGES_OUTPUT_PATH || '/tmp/i18n-changes.json'
|
||||
|
||||
const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`)
|
||||
|
||||
const readCurrentJson = (fileStem) => {
|
||||
const filePath = englishPath(fileStem)
|
||||
if (!fs.existsSync(filePath))
|
||||
return null
|
||||
|
||||
return JSON.parse(fs.readFileSync(filePath, 'utf8'))
|
||||
}
|
||||
|
||||
const readBaseJson = (fileStem) => {
|
||||
if (!baseSha)
|
||||
return null
|
||||
|
||||
try {
|
||||
const relativePath = `web/i18n/en-US/${fileStem}.json`
|
||||
const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' })
|
||||
return JSON.parse(content)
|
||||
}
|
||||
catch {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue)
|
||||
|
||||
const changes = {}
|
||||
|
||||
for (const fileStem of files) {
|
||||
const currentJson = readCurrentJson(fileStem)
|
||||
const beforeJson = readBaseJson(fileStem) || {}
|
||||
const afterJson = currentJson || {}
|
||||
const added = {}
|
||||
const updated = {}
|
||||
const deleted = []
|
||||
|
||||
for (const [key, value] of Object.entries(afterJson)) {
|
||||
if (!(key in beforeJson)) {
|
||||
added[key] = value
|
||||
continue
|
||||
}
|
||||
|
||||
if (!compareJson(beforeJson[key], value)) {
|
||||
updated[key] = {
|
||||
before: beforeJson[key],
|
||||
after: value,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const key of Object.keys(beforeJson)) {
|
||||
if (!(key in afterJson))
|
||||
deleted.push(key)
|
||||
}
|
||||
|
||||
changes[fileStem] = {
|
||||
fileDeleted: currentJson === null,
|
||||
added,
|
||||
updated,
|
||||
deleted,
|
||||
}
|
||||
}
|
||||
|
||||
fs.writeFileSync(
|
||||
outputPath,
|
||||
JSON.stringify({
|
||||
baseSha,
|
||||
headSha,
|
||||
files,
|
||||
changes,
|
||||
})
|
||||
)
|
||||
11
.github/workflows/api-tests.yml
vendored
11
.github/workflows/api-tests.yml
vendored
@@ -9,6 +9,9 @@ on:
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true
|
||||
|
||||
concurrency:
|
||||
group: api-tests-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
@@ -35,7 +38,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@@ -84,7 +87,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@@ -156,7 +159,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
@@ -203,7 +206,7 @@ jobs:
|
||||
|
||||
- name: Report coverage
|
||||
if: ${{ env.CODECOV_TOKEN != '' }}
|
||||
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
|
||||
uses: codecov/codecov-action@1af58845a975a7985b0beb0cbe6fbbb71a41dbad # v5.5.3
|
||||
with:
|
||||
files: ./coverage.xml
|
||||
disable_search: true
|
||||
|
||||
11
.github/workflows/autofix.yml
vendored
11
.github/workflows/autofix.yml
vendored
@@ -10,6 +10,9 @@ on:
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true
|
||||
|
||||
jobs:
|
||||
autofix:
|
||||
if: github.repository == 'langgenius/dify'
|
||||
@@ -39,12 +42,6 @@ jobs:
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
packages/**
|
||||
package.json
|
||||
pnpm-lock.yaml
|
||||
pnpm-workspace.yaml
|
||||
.npmrc
|
||||
.nvmrc
|
||||
- name: Check api inputs
|
||||
if: github.event_name != 'merge_group'
|
||||
id: api-changes
|
||||
@@ -58,7 +55,7 @@ jobs:
|
||||
python-version: "3.11"
|
||||
|
||||
- if: github.event_name != 'merge_group'
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
|
||||
- name: Generate Docker Compose
|
||||
if: github.event_name != 'merge_group' && steps.docker-compose-changes.outputs.any_changed == 'true'
|
||||
|
||||
34
.github/workflows/build-push.yml
vendored
34
.github/workflows/build-push.yml
vendored
@@ -24,39 +24,27 @@ env:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ${{ matrix.runs_on }}
|
||||
runs-on: ${{ matrix.platform == 'linux/arm64' && 'arm64_runner' || 'ubuntu-latest' }}
|
||||
if: github.repository == 'langgenius/dify'
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- service_name: "build-api-amd64"
|
||||
image_name_env: "DIFY_API_IMAGE_NAME"
|
||||
artifact_context: "api"
|
||||
build_context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
context: "api"
|
||||
platform: linux/amd64
|
||||
runs_on: ubuntu-latest
|
||||
- service_name: "build-api-arm64"
|
||||
image_name_env: "DIFY_API_IMAGE_NAME"
|
||||
artifact_context: "api"
|
||||
build_context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
context: "api"
|
||||
platform: linux/arm64
|
||||
runs_on: ubuntu-24.04-arm
|
||||
- service_name: "build-web-amd64"
|
||||
image_name_env: "DIFY_WEB_IMAGE_NAME"
|
||||
artifact_context: "web"
|
||||
build_context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
context: "web"
|
||||
platform: linux/amd64
|
||||
runs_on: ubuntu-latest
|
||||
- service_name: "build-web-arm64"
|
||||
image_name_env: "DIFY_WEB_IMAGE_NAME"
|
||||
artifact_context: "web"
|
||||
build_context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
context: "web"
|
||||
platform: linux/arm64
|
||||
runs_on: ubuntu-24.04-arm
|
||||
|
||||
steps:
|
||||
- name: Prepare
|
||||
@@ -65,11 +53,14 @@ jobs:
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4.1.0
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
|
||||
with:
|
||||
username: ${{ env.DOCKERHUB_USER }}
|
||||
password: ${{ env.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||
|
||||
@@ -83,8 +74,7 @@ jobs:
|
||||
id: build
|
||||
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
|
||||
with:
|
||||
context: ${{ matrix.build_context }}
|
||||
file: ${{ matrix.file }}
|
||||
context: "{{defaultContext}}:${{ matrix.context }}"
|
||||
platforms: ${{ matrix.platform }}
|
||||
build-args: COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
@@ -103,7 +93,7 @@ jobs:
|
||||
- name: Upload digest
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
with:
|
||||
name: digests-${{ matrix.artifact_context }}-${{ env.PLATFORM_PAIR }}
|
||||
name: digests-${{ matrix.context }}-${{ env.PLATFORM_PAIR }}
|
||||
path: /tmp/digests/*
|
||||
if-no-files-found: error
|
||||
retention-days: 1
|
||||
@@ -130,7 +120,7 @@ jobs:
|
||||
merge-multiple: true
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4.1.0
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
|
||||
with:
|
||||
username: ${{ env.DOCKERHUB_USER }}
|
||||
password: ${{ env.DOCKERHUB_TOKEN }}
|
||||
|
||||
7
.github/workflows/db-migration-test.yml
vendored
7
.github/workflows/db-migration-test.yml
vendored
@@ -3,6 +3,9 @@ name: DB Migration Test
|
||||
on:
|
||||
workflow_call:
|
||||
|
||||
env:
|
||||
FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true
|
||||
|
||||
concurrency:
|
||||
group: db-migration-test-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
@@ -19,7 +22,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
@@ -69,7 +72,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
|
||||
32
.github/workflows/docker-build.yml
vendored
32
.github/workflows/docker-build.yml
vendored
@@ -6,14 +6,7 @@ on:
|
||||
- "main"
|
||||
paths:
|
||||
- api/Dockerfile
|
||||
- web/docker/**
|
||||
- web/Dockerfile
|
||||
- packages/**
|
||||
- package.json
|
||||
- pnpm-lock.yaml
|
||||
- pnpm-workspace.yaml
|
||||
- .npmrc
|
||||
- .nvmrc
|
||||
|
||||
concurrency:
|
||||
group: docker-build-${{ github.head_ref || github.run_id }}
|
||||
@@ -21,31 +14,26 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
build-docker:
|
||||
runs-on: ${{ matrix.runs_on }}
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- service_name: "api-amd64"
|
||||
platform: linux/amd64
|
||||
runs_on: ubuntu-latest
|
||||
context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
context: "api"
|
||||
- service_name: "api-arm64"
|
||||
platform: linux/arm64
|
||||
runs_on: ubuntu-24.04-arm
|
||||
context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
context: "api"
|
||||
- service_name: "web-amd64"
|
||||
platform: linux/amd64
|
||||
runs_on: ubuntu-latest
|
||||
context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
context: "web"
|
||||
- service_name: "web-arm64"
|
||||
platform: linux/arm64
|
||||
runs_on: ubuntu-24.04-arm
|
||||
context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
context: "web"
|
||||
steps:
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||
|
||||
@@ -53,8 +41,8 @@ jobs:
|
||||
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
|
||||
with:
|
||||
push: false
|
||||
context: ${{ matrix.context }}
|
||||
file: ${{ matrix.file }}
|
||||
context: "{{defaultContext}}:${{ matrix.context }}"
|
||||
file: "${{ matrix.file }}"
|
||||
platforms: ${{ matrix.platform }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
15
.github/workflows/main-ci.yml
vendored
15
.github/workflows/main-ci.yml
vendored
@@ -16,6 +16,9 @@ permissions:
|
||||
checks: write
|
||||
statuses: write
|
||||
|
||||
env:
|
||||
FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true
|
||||
|
||||
concurrency:
|
||||
group: main-ci-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
@@ -65,12 +68,6 @@ jobs:
|
||||
- 'docker/volumes/sandbox/conf/**'
|
||||
web:
|
||||
- 'web/**'
|
||||
- 'packages/**'
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- '.npmrc'
|
||||
- '.nvmrc'
|
||||
- '.github/workflows/web-tests.yml'
|
||||
- '.github/actions/setup-web/**'
|
||||
e2e:
|
||||
@@ -79,12 +76,6 @@ jobs:
|
||||
- 'api/uv.lock'
|
||||
- 'e2e/**'
|
||||
- 'web/**'
|
||||
- 'packages/**'
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- '.npmrc'
|
||||
- '.nvmrc'
|
||||
- 'docker/docker-compose.middleware.yaml'
|
||||
- 'docker/middleware.env.example'
|
||||
- '.github/workflows/web-e2e.yml'
|
||||
|
||||
15
.github/workflows/pyrefly-diff.yml
vendored
15
.github/workflows/pyrefly-diff.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python & UV
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
@@ -50,17 +50,6 @@ jobs:
|
||||
run: |
|
||||
diff -u /tmp/pyrefly_base.txt /tmp/pyrefly_pr.txt > pyrefly_diff.txt || true
|
||||
|
||||
- name: Check if line counts match
|
||||
id: line_count_check
|
||||
run: |
|
||||
base_lines=$(wc -l < /tmp/pyrefly_base.txt)
|
||||
pr_lines=$(wc -l < /tmp/pyrefly_pr.txt)
|
||||
if [ "$base_lines" -eq "$pr_lines" ]; then
|
||||
echo "same=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "same=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Save PR number
|
||||
run: |
|
||||
echo ${{ github.event.pull_request.number }} > pr_number.txt
|
||||
@@ -74,7 +63,7 @@ jobs:
|
||||
pr_number.txt
|
||||
|
||||
- name: Comment PR with pyrefly diff
|
||||
if: ${{ github.event.pull_request.head.repo.full_name == github.repository && steps.line_count_check.outputs.same == 'false' }}
|
||||
if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
14
.github/workflows/style.yml
vendored
14
.github/workflows/style.yml
vendored
@@ -33,7 +33,7 @@ jobs:
|
||||
|
||||
- name: Setup UV and Python
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
with:
|
||||
enable-cache: false
|
||||
python-version: "3.12"
|
||||
@@ -77,12 +77,6 @@ jobs:
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
packages/**
|
||||
package.json
|
||||
pnpm-lock.yaml
|
||||
pnpm-workspace.yaml
|
||||
.npmrc
|
||||
.nvmrc
|
||||
.github/workflows/style.yml
|
||||
.github/actions/setup-web/**
|
||||
|
||||
@@ -96,9 +90,9 @@ jobs:
|
||||
uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
|
||||
with:
|
||||
path: web/.eslintcache
|
||||
key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
|
||||
key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
|
||||
${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
|
||||
|
||||
- name: Web style check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
@@ -151,7 +145,7 @@ jobs:
|
||||
.editorconfig
|
||||
|
||||
- name: Super-linter
|
||||
uses: super-linter/super-linter/slim@9e863354e3ff62e0727d37183162c4a88873df41 # v8.6.0
|
||||
uses: super-linter/super-linter/slim@61abc07d755095a68f4987d1c2c3d1d64408f1f9 # v8.5.0
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
env:
|
||||
BASH_SEVERITY: warning
|
||||
|
||||
4
.github/workflows/tool-test-sdks.yaml
vendored
4
.github/workflows/tool-test-sdks.yaml
vendored
@@ -6,10 +6,6 @@ on:
|
||||
- main
|
||||
paths:
|
||||
- sdks/**
|
||||
- package.json
|
||||
- pnpm-lock.yaml
|
||||
- pnpm-workspace.yaml
|
||||
- .npmrc
|
||||
|
||||
concurrency:
|
||||
group: sdk-tests-${{ github.head_ref || github.run_id }}
|
||||
|
||||
226
.github/workflows/translate-i18n-claude.yml
vendored
226
.github/workflows/translate-i18n-claude.yml
vendored
@@ -1,10 +1,10 @@
|
||||
name: Translate i18n Files with Claude Code
|
||||
|
||||
# Note: claude-code-action doesn't support push events directly.
|
||||
# Push events are bridged by trigger-i18n-sync.yml via repository_dispatch.
|
||||
on:
|
||||
repository_dispatch:
|
||||
types: [i18n-sync]
|
||||
push:
|
||||
branches: [main]
|
||||
paths:
|
||||
- 'web/i18n/en-US/*.json'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
files:
|
||||
@@ -30,7 +30,7 @@ permissions:
|
||||
|
||||
concurrency:
|
||||
group: translate-i18n-${{ github.event_name }}-${{ github.ref }}
|
||||
cancel-in-progress: false
|
||||
cancel-in-progress: ${{ github.event_name == 'push' }}
|
||||
|
||||
jobs:
|
||||
translate:
|
||||
@@ -67,31 +67,19 @@ jobs:
|
||||
}
|
||||
" web/i18n-config/languages.ts | sed 's/[[:space:]]*$//')
|
||||
|
||||
generate_changes_json() {
|
||||
node .github/scripts/generate-i18n-changes.mjs
|
||||
}
|
||||
|
||||
if [ "${{ github.event_name }}" = "repository_dispatch" ]; then
|
||||
BASE_SHA="${{ github.event.client_payload.base_sha }}"
|
||||
HEAD_SHA="${{ github.event.client_payload.head_sha }}"
|
||||
CHANGED_FILES="${{ github.event.client_payload.changed_files }}"
|
||||
TARGET_LANGS="$DEFAULT_TARGET_LANGS"
|
||||
SYNC_MODE="${{ github.event.client_payload.sync_mode || 'incremental' }}"
|
||||
|
||||
if [ -n "${{ github.event.client_payload.changes_base64 }}" ]; then
|
||||
printf '%s' '${{ github.event.client_payload.changes_base64 }}' | base64 -d > /tmp/i18n-changes.json
|
||||
CHANGES_AVAILABLE="true"
|
||||
CHANGES_SOURCE="embedded"
|
||||
elif [ -n "$BASE_SHA" ] && [ -n "$CHANGED_FILES" ]; then
|
||||
export BASE_SHA HEAD_SHA CHANGED_FILES
|
||||
generate_changes_json
|
||||
CHANGES_AVAILABLE="true"
|
||||
CHANGES_SOURCE="recomputed"
|
||||
else
|
||||
printf '%s' '{"baseSha":"","headSha":"","files":[],"changes":{}}' > /tmp/i18n-changes.json
|
||||
CHANGES_AVAILABLE="false"
|
||||
CHANGES_SOURCE="unavailable"
|
||||
if [ "${{ github.event_name }}" = "push" ]; then
|
||||
BASE_SHA="${{ github.event.before }}"
|
||||
if [ -z "$BASE_SHA" ] || [ "$BASE_SHA" = "0000000000000000000000000000000000000000" ]; then
|
||||
BASE_SHA=$(git rev-parse HEAD~1 2>/dev/null || true)
|
||||
fi
|
||||
HEAD_SHA="${{ github.sha }}"
|
||||
if [ -n "$BASE_SHA" ]; then
|
||||
CHANGED_FILES=$(git diff --name-only "$BASE_SHA" "$HEAD_SHA" -- 'web/i18n/en-US/*.json' 2>/dev/null | sed -n 's@^.*/@@p' | sed 's/\.json$//' | tr '\n' ' ' | sed 's/[[:space:]]*$//')
|
||||
else
|
||||
CHANGED_FILES=$(find web/i18n/en-US -maxdepth 1 -type f -name '*.json' -print | sed -n 's@^.*/@@p' | sed 's/\.json$//' | sort | tr '\n' ' ' | sed 's/[[:space:]]*$//')
|
||||
fi
|
||||
TARGET_LANGS="$DEFAULT_TARGET_LANGS"
|
||||
SYNC_MODE="incremental"
|
||||
else
|
||||
BASE_SHA=""
|
||||
HEAD_SHA=$(git rev-parse HEAD)
|
||||
@@ -116,17 +104,6 @@ jobs:
|
||||
else
|
||||
CHANGED_FILES=""
|
||||
fi
|
||||
|
||||
if [ "$SYNC_MODE" = "incremental" ] && [ -n "$CHANGED_FILES" ]; then
|
||||
export BASE_SHA HEAD_SHA CHANGED_FILES
|
||||
generate_changes_json
|
||||
CHANGES_AVAILABLE="true"
|
||||
CHANGES_SOURCE="local"
|
||||
else
|
||||
printf '%s' '{"baseSha":"","headSha":"","files":[],"changes":{}}' > /tmp/i18n-changes.json
|
||||
CHANGES_AVAILABLE="false"
|
||||
CHANGES_SOURCE="unavailable"
|
||||
fi
|
||||
fi
|
||||
|
||||
FILE_ARGS=""
|
||||
@@ -146,8 +123,6 @@ jobs:
|
||||
echo "CHANGED_FILES=$CHANGED_FILES"
|
||||
echo "TARGET_LANGS=$TARGET_LANGS"
|
||||
echo "SYNC_MODE=$SYNC_MODE"
|
||||
echo "CHANGES_AVAILABLE=$CHANGES_AVAILABLE"
|
||||
echo "CHANGES_SOURCE=$CHANGES_SOURCE"
|
||||
echo "FILE_ARGS=$FILE_ARGS"
|
||||
echo "LANG_ARGS=$LANG_ARGS"
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
@@ -158,7 +133,7 @@ jobs:
|
||||
|
||||
- name: Run Claude Code for Translation Sync
|
||||
if: steps.context.outputs.CHANGED_FILES != ''
|
||||
uses: anthropics/claude-code-action@6e2bd52842c65e914eba5c8badd17560bd26b5de # v1.0.89
|
||||
uses: anthropics/claude-code-action@88c168b39e7e64da0286d812b6e9fbebb6708185 # v1.0.82
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
@@ -166,7 +141,7 @@ jobs:
|
||||
show_full_output: ${{ github.event_name == 'workflow_dispatch' }}
|
||||
prompt: |
|
||||
You are the i18n sync agent for the Dify repository.
|
||||
Your job is to keep translations synchronized with the English source files under `${{ github.workspace }}/web/i18n/en-US/`.
|
||||
Your job is to keep translations synchronized with the English source files under `${{ github.workspace }}/web/i18n/en-US/`, then open a PR with the result.
|
||||
|
||||
Use absolute paths at all times:
|
||||
- Repo root: `${{ github.workspace }}`
|
||||
@@ -181,15 +156,12 @@ jobs:
|
||||
- Head SHA: `${{ steps.context.outputs.HEAD_SHA }}`
|
||||
- Scoped file args: `${{ steps.context.outputs.FILE_ARGS }}`
|
||||
- Scoped language args: `${{ steps.context.outputs.LANG_ARGS }}`
|
||||
- Structured change set available: `${{ steps.context.outputs.CHANGES_AVAILABLE }}`
|
||||
- Structured change set source: `${{ steps.context.outputs.CHANGES_SOURCE }}`
|
||||
- Structured change set file: `/tmp/i18n-changes.json`
|
||||
|
||||
Tool rules:
|
||||
- Use Read for repository files.
|
||||
- Use Edit for JSON updates.
|
||||
- Use Bash only for `vp`.
|
||||
- Do not use Bash for `git`, `gh`, or branch management.
|
||||
- Use Bash only for `git`, `gh`, `pnpm`, and `date`.
|
||||
- Run Bash commands one by one. Do not combine commands with `&&`, `||`, pipes, or command substitution.
|
||||
|
||||
Required execution plan:
|
||||
1. Resolve target languages.
|
||||
@@ -200,146 +172,42 @@ jobs:
|
||||
- Only process the resolved target languages, never `en-US`.
|
||||
- Do not touch unrelated i18n files.
|
||||
- Do not modify `${{ github.workspace }}/web/i18n/en-US/`.
|
||||
3. Resolve source changes.
|
||||
- If `Structured change set available` is `true`, read `/tmp/i18n-changes.json` and use it as the source of truth for file-level and key-level changes.
|
||||
- For each file entry:
|
||||
- `added` contains new English keys that need translations.
|
||||
- `updated` contains stale keys whose English source changed; re-translate using the `after` value.
|
||||
- `deleted` contains keys that should be removed from locale files.
|
||||
- `fileDeleted: true` means the English file no longer exists; remove the matching locale file if present.
|
||||
- Read the current English JSON file for any file that still exists so wording, placeholders, and surrounding terminology stay accurate.
|
||||
- If `Structured change set available` is `false`, treat this as a scoped full sync and use the current English files plus scoped checks as the source of truth.
|
||||
3. Detect English changes per file.
|
||||
- Read the current English JSON file for each file in scope.
|
||||
- If sync mode is `incremental` and `Base SHA` is not empty, run:
|
||||
`git -C ${{ github.workspace }} show <Base SHA>:web/i18n/en-US/<file>.json`
|
||||
- If sync mode is `full` or `Base SHA` is empty, skip historical comparison and treat the current English file as the only source of truth for structural sync.
|
||||
- If the file did not exist at Base SHA, treat all current keys as ADD.
|
||||
- Compare previous and current English JSON to identify:
|
||||
- ADD: key only in current
|
||||
- UPDATE: key exists in both and the English value changed
|
||||
- DELETE: key only in previous
|
||||
- Do not rely on a truncated diff file.
|
||||
4. Run a scoped pre-check before editing:
|
||||
- `vp run dify-web#i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
|
||||
- `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
|
||||
- Use this command as the source of truth for missing and extra keys inside the current scope.
|
||||
5. Apply translations.
|
||||
- For every target language and scoped file:
|
||||
- If `fileDeleted` is `true`, remove the locale file if it exists and skip the rest of that file.
|
||||
- If the locale file does not exist yet, create it with `Write` and then continue with `Edit` as needed.
|
||||
- ADD missing keys.
|
||||
- UPDATE stale translations when the English value changed.
|
||||
- DELETE removed keys. Prefer `vp run dify-web#i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }} --auto-remove` for extra keys so deletions stay in scope.
|
||||
- DELETE removed keys. Prefer `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }} --auto-remove` for extra keys so deletions stay in scope.
|
||||
- For `zh-Hans` and `ja-JP`, if the locale file also changed between Base SHA and Head SHA, preserve manual translations unless they are clearly wrong for the new English value. If in doubt, keep the manual translation.
|
||||
- Preserve placeholders exactly: `{{variable}}`, `${variable}`, HTML tags, component tags, and variable names.
|
||||
- Match the existing terminology and register used by each locale.
|
||||
- Prefer one Edit per file when stable, but prioritize correctness over batching.
|
||||
6. Verify only the edited files.
|
||||
- Run `vp run dify-web#lint:fix --quiet -- <relative edited i18n file paths under web/>`
|
||||
- Run `vp run dify-web#i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
|
||||
- Run `pnpm --dir ${{ github.workspace }}/web lint:fix --quiet -- <relative edited i18n file paths>`
|
||||
- Run `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
|
||||
- If verification fails, fix the remaining problems before continuing.
|
||||
7. Stop after the scoped locale files are updated and verification passes.
|
||||
- Do not create branches, commits, or pull requests.
|
||||
7. Create a PR only when there are changes in `web/i18n/`.
|
||||
- Check `git -C ${{ github.workspace }} status --porcelain -- web/i18n/`
|
||||
- Create branch `chore/i18n-sync-<timestamp>`
|
||||
- Commit message: `chore(i18n): sync translations with en-US`
|
||||
- Push the branch and open a PR against `main`
|
||||
- PR title: `chore(i18n): sync translations with en-US`
|
||||
- PR body: summarize files, languages, sync mode, and verification commands
|
||||
8. If there are no translation changes after verification, do not create a branch, commit, or PR.
|
||||
claude_args: |
|
||||
--max-turns 120
|
||||
--allowedTools "Read,Write,Edit,Bash(vp *),Bash(vp:*),Glob,Grep"
|
||||
|
||||
- name: Prepare branch metadata
|
||||
id: pr_meta
|
||||
if: steps.context.outputs.CHANGED_FILES != ''
|
||||
shell: bash
|
||||
run: |
|
||||
if [ -z "$(git -C "${{ github.workspace }}" status --porcelain -- web/i18n/)" ]; then
|
||||
echo "has_changes=false" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
SCOPE_HASH=$(printf '%s|%s|%s' "${{ steps.context.outputs.CHANGED_FILES }}" "${{ steps.context.outputs.TARGET_LANGS }}" "${{ steps.context.outputs.SYNC_MODE }}" | sha256sum | cut -c1-8)
|
||||
HEAD_SHORT=$(printf '%s' "${{ steps.context.outputs.HEAD_SHA }}" | cut -c1-12)
|
||||
BRANCH_NAME="chore/i18n-sync-${HEAD_SHORT}-${SCOPE_HASH}"
|
||||
|
||||
{
|
||||
echo "has_changes=true"
|
||||
echo "branch_name=$BRANCH_NAME"
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Commit translation changes
|
||||
if: steps.pr_meta.outputs.has_changes == 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
git -C "${{ github.workspace }}" checkout -B "${{ steps.pr_meta.outputs.branch_name }}"
|
||||
git -C "${{ github.workspace }}" add web/i18n/
|
||||
git -C "${{ github.workspace }}" commit -m "chore(i18n): sync translations with en-US"
|
||||
|
||||
- name: Push translation branch
|
||||
if: steps.pr_meta.outputs.has_changes == 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
if git -C "${{ github.workspace }}" ls-remote --exit-code --heads origin "${{ steps.pr_meta.outputs.branch_name }}" >/dev/null 2>&1; then
|
||||
git -C "${{ github.workspace }}" push --force-with-lease origin "${{ steps.pr_meta.outputs.branch_name }}"
|
||||
else
|
||||
git -C "${{ github.workspace }}" push --set-upstream origin "${{ steps.pr_meta.outputs.branch_name }}"
|
||||
fi
|
||||
|
||||
- name: Create or update translation PR
|
||||
if: steps.pr_meta.outputs.has_changes == 'true'
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
BRANCH_NAME: ${{ steps.pr_meta.outputs.branch_name }}
|
||||
FILES_IN_SCOPE: ${{ steps.context.outputs.CHANGED_FILES }}
|
||||
TARGET_LANGS: ${{ steps.context.outputs.TARGET_LANGS }}
|
||||
SYNC_MODE: ${{ steps.context.outputs.SYNC_MODE }}
|
||||
CHANGES_SOURCE: ${{ steps.context.outputs.CHANGES_SOURCE }}
|
||||
BASE_SHA: ${{ steps.context.outputs.BASE_SHA }}
|
||||
HEAD_SHA: ${{ steps.context.outputs.HEAD_SHA }}
|
||||
REPO_NAME: ${{ github.repository }}
|
||||
shell: bash
|
||||
run: |
|
||||
PR_BODY_FILE=/tmp/i18n-pr-body.md
|
||||
LANG_COUNT=$(printf '%s\n' "$TARGET_LANGS" | wc -w | tr -d ' ')
|
||||
if [ "$LANG_COUNT" = "0" ]; then
|
||||
LANG_COUNT="0"
|
||||
fi
|
||||
export LANG_COUNT
|
||||
|
||||
node <<'NODE' > "$PR_BODY_FILE"
|
||||
const fs = require('node:fs')
|
||||
|
||||
const changesPath = '/tmp/i18n-changes.json'
|
||||
const changes = fs.existsSync(changesPath)
|
||||
? JSON.parse(fs.readFileSync(changesPath, 'utf8'))
|
||||
: { changes: {} }
|
||||
|
||||
const filesInScope = (process.env.FILES_IN_SCOPE || '').split(/\s+/).filter(Boolean)
|
||||
const lines = [
|
||||
'## Summary',
|
||||
'',
|
||||
`- **Files synced**: \`${process.env.FILES_IN_SCOPE || '<none>'}\``,
|
||||
`- **Languages updated**: ${process.env.TARGET_LANGS || '<none>'} (${process.env.LANG_COUNT} languages)`,
|
||||
`- **Sync mode**: ${process.env.SYNC_MODE}${process.env.BASE_SHA ? ` (base: \`${process.env.BASE_SHA.slice(0, 10)}\`, head: \`${process.env.HEAD_SHA.slice(0, 10)}\`)` : ` (head: \`${process.env.HEAD_SHA.slice(0, 10)}\`)`}`,
|
||||
'',
|
||||
'### Key changes',
|
||||
]
|
||||
|
||||
for (const fileName of filesInScope) {
|
||||
const fileChange = changes.changes?.[fileName] || { added: {}, updated: {}, deleted: [], fileDeleted: false }
|
||||
const addedKeys = Object.keys(fileChange.added || {})
|
||||
const updatedKeys = Object.keys(fileChange.updated || {})
|
||||
const deletedKeys = fileChange.deleted || []
|
||||
lines.push(`- \`${fileName}\`: +${addedKeys.length} / ~${updatedKeys.length} / -${deletedKeys.length}${fileChange.fileDeleted ? ' (file deleted in en-US)' : ''}`)
|
||||
}
|
||||
|
||||
lines.push(
|
||||
'',
|
||||
'## Verification',
|
||||
'',
|
||||
`- \`vp run dify-web#i18n:check --file ${process.env.FILES_IN_SCOPE} --lang ${process.env.TARGET_LANGS}\``,
|
||||
`- \`vp run dify-web#lint:fix --quiet -- <edited i18n files under web/>\``,
|
||||
'',
|
||||
'## Notes',
|
||||
'',
|
||||
'- This PR was generated from structured en-US key changes produced by `trigger-i18n-sync.yml`.',
|
||||
`- Structured change source: ${process.env.CHANGES_SOURCE || 'unknown'}.`,
|
||||
'- Branch name is deterministic for the head SHA and scope, so reruns update the same PR instead of opening duplicates.',
|
||||
'',
|
||||
'🤖 Generated with [Claude Code](https://claude.com/claude-code)'
|
||||
)
|
||||
|
||||
process.stdout.write(lines.join('\n'))
|
||||
NODE
|
||||
|
||||
EXISTING_PR_NUMBER=$(gh pr list --repo "$REPO_NAME" --head "$BRANCH_NAME" --state open --json number --jq '.[0].number')
|
||||
|
||||
if [ -n "$EXISTING_PR_NUMBER" ] && [ "$EXISTING_PR_NUMBER" != "null" ]; then
|
||||
gh pr edit "$EXISTING_PR_NUMBER" --repo "$REPO_NAME" --title "chore(i18n): sync translations with en-US" --body-file "$PR_BODY_FILE"
|
||||
else
|
||||
gh pr create --repo "$REPO_NAME" --head "$BRANCH_NAME" --base main --title "chore(i18n): sync translations with en-US" --body-file "$PR_BODY_FILE"
|
||||
fi
|
||||
--max-turns 80
|
||||
--allowedTools "Read,Write,Edit,Bash(git *),Bash(git:*),Bash(gh *),Bash(gh:*),Bash(pnpm *),Bash(pnpm:*),Bash(date *),Bash(date:*),Glob,Grep"
|
||||
|
||||
90
.github/workflows/trigger-i18n-sync.yml
vendored
90
.github/workflows/trigger-i18n-sync.yml
vendored
@@ -1,90 +0,0 @@
|
||||
name: Trigger i18n Sync on Push
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
paths:
|
||||
- 'web/i18n/en-US/*.json'
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
concurrency:
|
||||
group: trigger-i18n-sync-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
trigger:
|
||||
if: github.repository == 'langgenius/dify'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 5
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Detect changed files and build structured change set
|
||||
id: detect
|
||||
shell: bash
|
||||
run: |
|
||||
BASE_SHA="${{ github.event.before }}"
|
||||
if [ -z "$BASE_SHA" ] || [ "$BASE_SHA" = "0000000000000000000000000000000000000000" ]; then
|
||||
BASE_SHA=$(git rev-parse HEAD~1 2>/dev/null || true)
|
||||
fi
|
||||
HEAD_SHA="${{ github.sha }}"
|
||||
|
||||
if [ -n "$BASE_SHA" ]; then
|
||||
CHANGED_FILES=$(git diff --name-only "$BASE_SHA" "$HEAD_SHA" -- 'web/i18n/en-US/*.json' 2>/dev/null | sed -n 's@^.*/@@p' | sed 's/\.json$//' | tr '\n' ' ' | sed 's/[[:space:]]*$//')
|
||||
else
|
||||
CHANGED_FILES=$(find web/i18n/en-US -maxdepth 1 -type f -name '*.json' -print | sed -n 's@^.*/@@p' | sed 's/\.json$//' | sort | tr '\n' ' ' | sed 's/[[:space:]]*$//')
|
||||
fi
|
||||
|
||||
export BASE_SHA HEAD_SHA CHANGED_FILES
|
||||
node .github/scripts/generate-i18n-changes.mjs
|
||||
|
||||
if [ -n "$CHANGED_FILES" ]; then
|
||||
echo "has_changes=true" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
echo "has_changes=false" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
echo "base_sha=$BASE_SHA" >> "$GITHUB_OUTPUT"
|
||||
echo "head_sha=$HEAD_SHA" >> "$GITHUB_OUTPUT"
|
||||
echo "changed_files=$CHANGED_FILES" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Trigger i18n sync workflow
|
||||
if: steps.detect.outputs.has_changes == 'true'
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
env:
|
||||
BASE_SHA: ${{ steps.detect.outputs.base_sha }}
|
||||
HEAD_SHA: ${{ steps.detect.outputs.head_sha }}
|
||||
CHANGED_FILES: ${{ steps.detect.outputs.changed_files }}
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
const fs = require('fs')
|
||||
|
||||
const changesJson = fs.readFileSync('/tmp/i18n-changes.json', 'utf8')
|
||||
const changesBase64 = Buffer.from(changesJson).toString('base64')
|
||||
const maxEmbeddedChangesChars = 48000
|
||||
const changesEmbedded = changesBase64.length <= maxEmbeddedChangesChars
|
||||
|
||||
if (!changesEmbedded) {
|
||||
console.log(`Structured change set too large to embed safely (${changesBase64.length} chars). Downstream workflow will regenerate it from git history.`)
|
||||
}
|
||||
|
||||
await github.rest.repos.createDispatchEvent({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
event_type: 'i18n-sync',
|
||||
client_payload: {
|
||||
changed_files: process.env.CHANGED_FILES,
|
||||
changes_base64: changesEmbedded ? changesBase64 : '',
|
||||
changes_embedded: changesEmbedded,
|
||||
sync_mode: 'incremental',
|
||||
base_sha: process.env.BASE_SHA,
|
||||
head_sha: process.env.HEAD_SHA,
|
||||
},
|
||||
})
|
||||
95
.github/workflows/vdb-tests-full.yml
vendored
95
.github/workflows/vdb-tests-full.yml
vendored
@@ -1,95 +0,0 @@
|
||||
name: Run Full VDB Tests
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 3 * * 1'
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: vdb-tests-full-${{ github.ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: Full VDB Tests
|
||||
if: github.repository == 'langgenius/dify'
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version:
|
||||
- "3.12"
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Free Disk Space
|
||||
uses: endersonmenezes/free-disk-space@7901478139cff6e9d44df5972fd8ab8fcade4db1 # v3.2.2
|
||||
with:
|
||||
remove_dotnet: true
|
||||
remove_haskell: true
|
||||
remove_tool_cache: true
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache-dependency-glob: api/uv.lock
|
||||
|
||||
- name: Check UV lockfile
|
||||
run: uv lock --project api --check
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Set up dotenvs
|
||||
run: |
|
||||
cp docker/.env.example docker/.env
|
||||
cp docker/middleware.env.example docker/middleware.env
|
||||
|
||||
- name: Expose Service Ports
|
||||
run: sh .github/workflows/expose_service_ports.sh
|
||||
|
||||
# - name: Set up Vector Store (TiDB)
|
||||
# uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
# with:
|
||||
# compose-file: docker/tidb/docker-compose.yaml
|
||||
# services: |
|
||||
# tidb
|
||||
# tiflash
|
||||
|
||||
- name: Set up Full Vector Store Matrix
|
||||
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.yaml
|
||||
services: |
|
||||
weaviate
|
||||
qdrant
|
||||
couchbase-server
|
||||
etcd
|
||||
minio
|
||||
milvus-standalone
|
||||
pgvecto-rs
|
||||
pgvector
|
||||
chroma
|
||||
elasticsearch
|
||||
oceanbase
|
||||
|
||||
- name: setup test config
|
||||
run: |
|
||||
echo $(pwd)
|
||||
ls -lah .
|
||||
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
|
||||
|
||||
# - name: Check VDB Ready (TiDB)
|
||||
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
|
||||
|
||||
- name: Test Vector Stores
|
||||
run: uv run --project api bash dev/pytest/pytest_vdb.sh
|
||||
28
.github/workflows/vdb-tests.yml
vendored
28
.github/workflows/vdb-tests.yml
vendored
@@ -1,10 +1,10 @@
|
||||
name: Run VDB Smoke Tests
|
||||
name: Run VDB Tests
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
env:
|
||||
FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true
|
||||
|
||||
concurrency:
|
||||
group: vdb-tests-${{ github.head_ref || github.run_id }}
|
||||
@@ -12,7 +12,7 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: VDB Smoke Tests
|
||||
name: VDB Tests
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
@@ -33,7 +33,7 @@ jobs:
|
||||
remove_tool_cache: true
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@@ -61,18 +61,23 @@ jobs:
|
||||
# tidb
|
||||
# tiflash
|
||||
|
||||
- name: Set up Vector Stores for Smoke Coverage
|
||||
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase, OceanBase)
|
||||
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.yaml
|
||||
services: |
|
||||
db_postgres
|
||||
redis
|
||||
weaviate
|
||||
qdrant
|
||||
couchbase-server
|
||||
etcd
|
||||
minio
|
||||
milvus-standalone
|
||||
pgvecto-rs
|
||||
pgvector
|
||||
chroma
|
||||
elasticsearch
|
||||
oceanbase
|
||||
|
||||
- name: setup test config
|
||||
run: |
|
||||
@@ -84,9 +89,4 @@ jobs:
|
||||
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
|
||||
|
||||
- name: Test Vector Stores
|
||||
run: |
|
||||
uv run --project api pytest --timeout "${PYTEST_TIMEOUT:-180}" \
|
||||
api/tests/integration_tests/vdb/chroma \
|
||||
api/tests/integration_tests/vdb/pgvector \
|
||||
api/tests/integration_tests/vdb/qdrant \
|
||||
api/tests/integration_tests/vdb/weaviate
|
||||
run: uv run --project api bash dev/pytest/pytest_vdb.sh
|
||||
|
||||
6
.github/workflows/web-e2e.yml
vendored
6
.github/workflows/web-e2e.yml
vendored
@@ -27,8 +27,12 @@ jobs:
|
||||
- name: Setup web dependencies
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Install E2E package dependencies
|
||||
working-directory: ./e2e
|
||||
run: vp install --frozen-lockfile
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
|
||||
33
.github/workflows/web-tests.yml
vendored
33
.github/workflows/web-tests.yml
vendored
@@ -83,9 +83,40 @@ jobs:
|
||||
|
||||
- name: Report coverage
|
||||
if: ${{ env.CODECOV_TOKEN != '' }}
|
||||
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
|
||||
uses: codecov/codecov-action@1af58845a975a7985b0beb0cbe6fbbb71a41dbad # v5.5.3
|
||||
with:
|
||||
directory: web/coverage
|
||||
flags: web
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
|
||||
|
||||
web-build:
|
||||
name: Web Build
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
working-directory: ./web
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
.github/workflows/web-tests.yml
|
||||
.github/actions/setup-web/**
|
||||
|
||||
- name: Setup web environment
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Web build check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: vp run build
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -212,8 +212,6 @@ api/.vscode
|
||||
|
||||
# pnpm
|
||||
/.pnpm-store
|
||||
node_modules
|
||||
.vite-hooks/_
|
||||
|
||||
# plugin migrate
|
||||
plugins.jsonl
|
||||
@@ -241,4 +239,4 @@ scripts/stress-test/reports/
|
||||
*.local.md
|
||||
|
||||
# Code Agent Folder
|
||||
.qoder/*
|
||||
.qoder/*
|
||||
6
Makefile
6
Makefile
@@ -24,8 +24,8 @@ prepare-docker:
|
||||
# Step 2: Prepare web environment
|
||||
prepare-web:
|
||||
@echo "🌐 Setting up web environment..."
|
||||
@cp -n web/.env.example web/.env.local 2>/dev/null || echo "Web .env.local already exists"
|
||||
@pnpm install
|
||||
@cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists"
|
||||
@cd web && pnpm install
|
||||
@echo "✅ Web environment prepared (not started)"
|
||||
|
||||
# Step 3: Prepare API environment
|
||||
@@ -93,7 +93,7 @@ test:
|
||||
# Build Docker images
|
||||
build-web:
|
||||
@echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..."
|
||||
docker build -f web/Dockerfile -t $(WEB_IMAGE):$(VERSION) .
|
||||
docker build -t $(WEB_IMAGE):$(VERSION) ./web
|
||||
@echo "Web Docker image built successfully: $(WEB_IMAGE):$(VERSION)"
|
||||
|
||||
build-api:
|
||||
|
||||
@@ -33,9 +33,6 @@ TRIGGER_URL=http://localhost:5001
|
||||
# The time in seconds after the signature is rejected
|
||||
FILES_ACCESS_TIMEOUT=300
|
||||
|
||||
# Collaboration mode toggle
|
||||
ENABLE_COLLABORATION_MODE=false
|
||||
|
||||
# Access token expiration time in minutes
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=60
|
||||
|
||||
@@ -74,13 +71,6 @@ REDIS_USE_CLUSTERS=false
|
||||
REDIS_CLUSTERS=
|
||||
REDIS_CLUSTERS_PASSWORD=
|
||||
|
||||
REDIS_RETRY_RETRIES=3
|
||||
REDIS_RETRY_BACKOFF_BASE=1.0
|
||||
REDIS_RETRY_BACKOFF_CAP=10.0
|
||||
REDIS_SOCKET_TIMEOUT=5.0
|
||||
REDIS_SOCKET_CONNECT_TIMEOUT=5.0
|
||||
REDIS_HEALTH_CHECK_INTERVAL=30
|
||||
|
||||
# celery configuration
|
||||
CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1
|
||||
CELERY_BACKEND=redis
|
||||
@@ -112,7 +102,6 @@ S3_BUCKET_NAME=your-bucket-name
|
||||
S3_ACCESS_KEY=your-access-key
|
||||
S3_SECRET_KEY=your-secret-key
|
||||
S3_REGION=your-region
|
||||
S3_ADDRESS_STYLE=auto
|
||||
|
||||
# Workflow run and Conversation archive storage (S3-compatible)
|
||||
ARCHIVE_STORAGE_ENABLED=false
|
||||
|
||||
@@ -115,6 +115,12 @@ ignore = [
|
||||
"controllers/console/human_input_form.py" = ["TID251"]
|
||||
"controllers/web/human_input_form.py" = ["TID251"]
|
||||
|
||||
[lint.pyflakes]
|
||||
allowed-unused-imports = [
|
||||
"tests.integration_tests",
|
||||
"tests.unit_tests",
|
||||
]
|
||||
|
||||
[lint.flake8-tidy-imports]
|
||||
|
||||
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse"]
|
||||
|
||||
@@ -40,8 +40,6 @@ The scripts resolve paths relative to their location, so you can run them from a
|
||||
./dev/start-web
|
||||
```
|
||||
|
||||
`./dev/setup` and `./dev/start-web` install JavaScript dependencies through the repository root workspace, so you do not need a separate `cd web && pnpm install` step.
|
||||
|
||||
1. Set up your application by visiting `http://localhost:3000`.
|
||||
|
||||
1. Start the worker service (async and scheduler tasks, runs from `api`).
|
||||
|
||||
17
api/app.py
17
api/app.py
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
@@ -17,15 +16,10 @@ def is_db_command() -> bool:
|
||||
|
||||
|
||||
# create app
|
||||
flask_app = None
|
||||
socketio_app = None
|
||||
|
||||
if is_db_command():
|
||||
from app_factory import create_migrations_app
|
||||
|
||||
app = create_migrations_app()
|
||||
socketio_app = app
|
||||
flask_app = app
|
||||
else:
|
||||
# Gunicorn and Celery handle monkey patching automatically in production by
|
||||
# specifying the `gevent` worker class. Manual monkey patching is not required here.
|
||||
@@ -36,15 +30,8 @@ else:
|
||||
|
||||
from app_factory import create_app
|
||||
|
||||
socketio_app, flask_app = create_app()
|
||||
app = flask_app
|
||||
app = create_app()
|
||||
celery = cast("Celery", app.extensions["celery"])
|
||||
|
||||
if __name__ == "__main__":
|
||||
from gevent import pywsgi
|
||||
from geventwebsocket.handler import WebSocketHandler # type: ignore[reportMissingTypeStubs]
|
||||
|
||||
host = os.environ.get("HOST", "0.0.0.0")
|
||||
port = int(os.environ.get("PORT", 5001))
|
||||
server = pywsgi.WSGIServer((host, port), socketio_app, handler_class=WebSocketHandler)
|
||||
server.serve_forever()
|
||||
app.run(host="0.0.0.0", port=5001)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import socketio # type: ignore[reportMissingTypeStubs]
|
||||
from flask import request
|
||||
from opentelemetry.trace import get_current_span
|
||||
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
|
||||
@@ -11,7 +10,6 @@ from contexts.wrapper import RecyclableContextVar
|
||||
from controllers.console.error import UnauthorizedAndForceLogout
|
||||
from core.logging.context import init_request_context
|
||||
from dify_app import DifyApp
|
||||
from extensions.ext_socketio import sio
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
@@ -124,18 +122,14 @@ def create_flask_app_with_configs() -> DifyApp:
|
||||
return dify_app
|
||||
|
||||
|
||||
def create_app() -> tuple[socketio.WSGIApp, DifyApp]:
|
||||
def create_app() -> DifyApp:
|
||||
start_time = time.perf_counter()
|
||||
app = create_flask_app_with_configs()
|
||||
initialize_extensions(app)
|
||||
|
||||
sio.app = app
|
||||
socketio_app = socketio.WSGIApp(sio, app)
|
||||
|
||||
end_time = time.perf_counter()
|
||||
if dify_config.DEBUG:
|
||||
logger.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2))
|
||||
return socketio_app, app
|
||||
return app
|
||||
|
||||
|
||||
def initialize_extensions(app: DifyApp):
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
# This module provides a lightweight Celery instance for use in Docker health checks.
|
||||
# Unlike celery_entrypoint.py, this does NOT import app.py and therefore avoids
|
||||
# initializing all Flask extensions (DB, Redis, storage, blueprints, etc.).
|
||||
# Using this module keeps the health check fast and low-cost.
|
||||
from celery import Celery
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_celery import get_celery_broker_transport_options, get_celery_ssl_options
|
||||
|
||||
celery = Celery(broker=dify_config.CELERY_BROKER_URL)
|
||||
|
||||
broker_transport_options = get_celery_broker_transport_options()
|
||||
if broker_transport_options:
|
||||
celery.conf.update(broker_transport_options=broker_transport_options)
|
||||
|
||||
ssl_options = get_celery_ssl_options()
|
||||
if ssl_options:
|
||||
celery.conf.update(broker_use_ssl=ssl_options)
|
||||
@@ -1,7 +1,7 @@
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
from typing import TypedDict
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
import sqlalchemy as sa
|
||||
@@ -503,19 +503,7 @@ def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
|
||||
return [row[0] for row in result]
|
||||
|
||||
|
||||
class _AppOrphanCounts(TypedDict):
|
||||
variables: int
|
||||
files: int
|
||||
|
||||
|
||||
class OrphanedDraftVariableStatsDict(TypedDict):
|
||||
total_orphaned_variables: int
|
||||
total_orphaned_files: int
|
||||
orphaned_app_count: int
|
||||
orphaned_by_app: dict[str, _AppOrphanCounts]
|
||||
|
||||
|
||||
def _count_orphaned_draft_variables() -> OrphanedDraftVariableStatsDict:
|
||||
def _count_orphaned_draft_variables() -> dict[str, Any]:
|
||||
"""
|
||||
Count orphaned draft variables by app, including associated file counts.
|
||||
|
||||
@@ -538,7 +526,7 @@ def _count_orphaned_draft_variables() -> OrphanedDraftVariableStatsDict:
|
||||
|
||||
with db.engine.connect() as conn:
|
||||
result = conn.execute(sa.text(variables_query))
|
||||
orphaned_by_app: dict[str, _AppOrphanCounts] = {}
|
||||
orphaned_by_app = {}
|
||||
total_files = 0
|
||||
|
||||
for row in result:
|
||||
|
||||
@@ -1274,13 +1274,6 @@ class PositionConfig(BaseSettings):
|
||||
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}
|
||||
|
||||
|
||||
class CollaborationConfig(BaseSettings):
|
||||
ENABLE_COLLABORATION_MODE: bool = Field(
|
||||
description="Whether to enable collaboration mode features across the workspace",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class LoginConfig(BaseSettings):
|
||||
ENABLE_EMAIL_CODE_LOGIN: bool = Field(
|
||||
description="whether to enable email code login",
|
||||
@@ -1406,7 +1399,6 @@ class FeatureConfig(
|
||||
WorkflowConfig,
|
||||
WorkflowNodeExecutionConfig,
|
||||
WorkspaceConfig,
|
||||
CollaborationConfig,
|
||||
LoginConfig,
|
||||
AccountConfig,
|
||||
SwaggerUIConfig,
|
||||
|
||||
31
api/configs/middleware/cache/redis_config.py
vendored
31
api/configs/middleware/cache/redis_config.py
vendored
@@ -117,37 +117,6 @@ class RedisConfig(BaseSettings):
|
||||
default=None,
|
||||
)
|
||||
|
||||
REDIS_RETRY_RETRIES: NonNegativeInt = Field(
|
||||
description="Maximum number of retries per Redis command on "
|
||||
"transient failures (ConnectionError, TimeoutError, socket.timeout)",
|
||||
default=3,
|
||||
)
|
||||
|
||||
REDIS_RETRY_BACKOFF_BASE: PositiveFloat = Field(
|
||||
description="Base delay in seconds for exponential backoff between retries",
|
||||
default=1.0,
|
||||
)
|
||||
|
||||
REDIS_RETRY_BACKOFF_CAP: PositiveFloat = Field(
|
||||
description="Maximum backoff delay in seconds between retries",
|
||||
default=10.0,
|
||||
)
|
||||
|
||||
REDIS_SOCKET_TIMEOUT: PositiveFloat | None = Field(
|
||||
description="Socket timeout in seconds for Redis read/write operations",
|
||||
default=5.0,
|
||||
)
|
||||
|
||||
REDIS_SOCKET_CONNECT_TIMEOUT: PositiveFloat | None = Field(
|
||||
description="Socket timeout in seconds for Redis connection establishment",
|
||||
default=5.0,
|
||||
)
|
||||
|
||||
REDIS_HEALTH_CHECK_INTERVAL: NonNegativeInt = Field(
|
||||
description="Interval in seconds between Redis connection health checks (0 to disable)",
|
||||
default=30,
|
||||
)
|
||||
|
||||
@field_validator("REDIS_MAX_CONNECTIONS", mode="before")
|
||||
@classmethod
|
||||
def _empty_string_to_none_for_max_conns(cls, v):
|
||||
|
||||
@@ -7,16 +7,15 @@ UUID_NIL = "00000000-0000-0000-0000-000000000000"
|
||||
|
||||
DEFAULT_FILE_NUMBER_LIMITS = 3
|
||||
|
||||
_IMAGE_EXTENSION_BASE: frozenset[str] = frozenset(("jpg", "jpeg", "png", "webp", "gif", "svg"))
|
||||
_VIDEO_EXTENSION_BASE: frozenset[str] = frozenset(("mp4", "mov", "mpeg", "webm"))
|
||||
_AUDIO_EXTENSION_BASE: frozenset[str] = frozenset(("mp3", "m4a", "wav", "amr", "mpga"))
|
||||
IMAGE_EXTENSIONS = convert_to_lower_and_upper_set({"jpg", "jpeg", "png", "webp", "gif", "svg"})
|
||||
|
||||
IMAGE_EXTENSIONS: frozenset[str] = frozenset(convert_to_lower_and_upper_set(_IMAGE_EXTENSION_BASE))
|
||||
VIDEO_EXTENSIONS: frozenset[str] = frozenset(convert_to_lower_and_upper_set(_VIDEO_EXTENSION_BASE))
|
||||
AUDIO_EXTENSIONS: frozenset[str] = frozenset(convert_to_lower_and_upper_set(_AUDIO_EXTENSION_BASE))
|
||||
VIDEO_EXTENSIONS = convert_to_lower_and_upper_set({"mp4", "mov", "mpeg", "webm"})
|
||||
|
||||
_UNSTRUCTURED_DOCUMENT_EXTENSION_BASE: frozenset[str] = frozenset(
|
||||
(
|
||||
AUDIO_EXTENSIONS = convert_to_lower_and_upper_set({"mp3", "m4a", "wav", "amr", "mpga"})
|
||||
|
||||
_doc_extensions: set[str]
|
||||
if dify_config.ETL_TYPE == "Unstructured":
|
||||
_doc_extensions = {
|
||||
"txt",
|
||||
"markdown",
|
||||
"md",
|
||||
@@ -36,10 +35,11 @@ _UNSTRUCTURED_DOCUMENT_EXTENSION_BASE: frozenset[str] = frozenset(
|
||||
"pptx",
|
||||
"xml",
|
||||
"epub",
|
||||
)
|
||||
)
|
||||
_DEFAULT_DOCUMENT_EXTENSION_BASE: frozenset[str] = frozenset(
|
||||
(
|
||||
}
|
||||
if dify_config.UNSTRUCTURED_API_URL:
|
||||
_doc_extensions.add("ppt")
|
||||
else:
|
||||
_doc_extensions = {
|
||||
"txt",
|
||||
"markdown",
|
||||
"md",
|
||||
@@ -53,17 +53,8 @@ _DEFAULT_DOCUMENT_EXTENSION_BASE: frozenset[str] = frozenset(
|
||||
"csv",
|
||||
"vtt",
|
||||
"properties",
|
||||
)
|
||||
)
|
||||
|
||||
_doc_extensions: set[str]
|
||||
if dify_config.ETL_TYPE == "Unstructured":
|
||||
_doc_extensions = set(_UNSTRUCTURED_DOCUMENT_EXTENSION_BASE)
|
||||
if dify_config.UNSTRUCTURED_API_URL:
|
||||
_doc_extensions.add("ppt")
|
||||
else:
|
||||
_doc_extensions = set(_DEFAULT_DOCUMENT_EXTENSION_BASE)
|
||||
DOCUMENT_EXTENSIONS: frozenset[str] = frozenset(convert_to_lower_and_upper_set(_doc_extensions))
|
||||
}
|
||||
DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions)
|
||||
|
||||
# console
|
||||
COOKIE_NAME_ACCESS_TOKEN = "access_token"
|
||||
|
||||
@@ -10,7 +10,7 @@ import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
from typing import Any, Protocol, final, runtime_checkable
|
||||
from typing import Any, Protocol, TypeVar, final, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -188,6 +188,8 @@ class ExecutionContextBuilder:
|
||||
_capturer: Callable[[], IExecutionContext] | None = None
|
||||
_tenant_context_providers: dict[tuple[str, str], Callable[[], BaseModel]] = {}
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class ContextProviderNotFoundError(KeyError):
|
||||
"""Raised when a tenant-scoped context provider is missing."""
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from contextvars import ContextVar
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class HiddenValue:
|
||||
@@ -8,7 +11,7 @@ class HiddenValue:
|
||||
_default = HiddenValue()
|
||||
|
||||
|
||||
class RecyclableContextVar[T]:
|
||||
class RecyclableContextVar(Generic[T]):
|
||||
"""
|
||||
RecyclableContextVar is a wrapper around ContextVar
|
||||
It's safe to use in gunicorn with thread recycling, but features like `reset` are not available for now
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
|
||||
# --- Conversation schemas ---
|
||||
|
||||
|
||||
class ConversationRenamePayload(BaseModel):
|
||||
name: str | None = None
|
||||
auto_generate: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_name_requirement(self):
|
||||
if not self.auto_generate:
|
||||
if self.name is None or not self.name.strip():
|
||||
raise ValueError("name is required when auto_generate is false")
|
||||
return self
|
||||
|
||||
|
||||
# --- Message schemas ---
|
||||
|
||||
|
||||
class MessageListQuery(BaseModel):
|
||||
conversation_id: UUIDStrOrEmpty
|
||||
first_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class MessageFeedbackPayload(BaseModel):
|
||||
rating: Literal["like", "dislike"] | None = None
|
||||
content: str | None = None
|
||||
|
||||
|
||||
# --- Saved message schemas ---
|
||||
|
||||
|
||||
class SavedMessageListQuery(BaseModel):
|
||||
last_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class SavedMessageCreatePayload(BaseModel):
|
||||
message_id: UUIDStrOrEmpty
|
||||
|
||||
|
||||
# --- Workflow schemas ---
|
||||
|
||||
|
||||
class DefaultBlockConfigQuery(BaseModel):
|
||||
q: str | None = None
|
||||
|
||||
|
||||
class WorkflowListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=10, ge=1, le=100)
|
||||
user_id: str | None = None
|
||||
named_only: bool = False
|
||||
|
||||
|
||||
class WorkflowRunPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
files: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
class WorkflowUpdatePayload(BaseModel):
|
||||
marked_name: str | None = Field(default=None, max_length=20)
|
||||
marked_comment: str | None = Field(default=None, max_length=100)
|
||||
|
||||
|
||||
# --- Audio schemas ---
|
||||
|
||||
|
||||
class TextToAudioPayload(BaseModel):
|
||||
message_id: str | None = None
|
||||
voice: str | None = None
|
||||
text: str | None = None
|
||||
streaming: bool | None = None
|
||||
@@ -1,14 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
from graphon.file import helpers as file_helpers
|
||||
from pydantic import BaseModel, ConfigDict, computed_field
|
||||
|
||||
from models.model import IconType
|
||||
|
||||
type JSONValue = str | int | float | bool | None | dict[str, Any] | list[Any]
|
||||
type JSONObject = dict[str, Any]
|
||||
JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
|
||||
JSONObject: TypeAlias = dict[str, Any]
|
||||
|
||||
|
||||
class SystemParameters(BaseModel):
|
||||
|
||||
@@ -4,8 +4,8 @@ from urllib.parse import quote
|
||||
|
||||
from flask import Response
|
||||
|
||||
HTML_MIME_TYPES: frozenset[str] = frozenset(("text/html", "application/xhtml+xml"))
|
||||
HTML_EXTENSIONS: frozenset[str] = frozenset(("html", "htm"))
|
||||
HTML_MIME_TYPES = frozenset({"text/html", "application/xhtml+xml"})
|
||||
HTML_EXTENSIONS = frozenset({"html", "htm"})
|
||||
|
||||
|
||||
def _normalize_mime_type(mime_type: str | None) -> str:
|
||||
|
||||
@@ -65,7 +65,6 @@ from .app import (
|
||||
statistic,
|
||||
workflow,
|
||||
workflow_app_log,
|
||||
workflow_comment,
|
||||
workflow_draft_variable,
|
||||
workflow_run,
|
||||
workflow_statistic,
|
||||
@@ -117,7 +116,6 @@ from .explore import (
|
||||
saved_message,
|
||||
trial,
|
||||
)
|
||||
from .socketio import workflow as socketio_workflow # pyright: ignore[reportUnusedImport]
|
||||
|
||||
# Import tag controllers
|
||||
from .tag import tags
|
||||
@@ -213,7 +211,6 @@ __all__ = [
|
||||
"website",
|
||||
"workflow",
|
||||
"workflow_app_log",
|
||||
"workflow_comment",
|
||||
"workflow_draft_variable",
|
||||
"workflow_run",
|
||||
"workflow_statistic",
|
||||
|
||||
@@ -2,7 +2,7 @@ import csv
|
||||
import io
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import cast
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@@ -18,7 +18,10 @@ from core.db.session_factory import session_factory
|
||||
from extensions.ext_database import db
|
||||
from libs.token import extract_access_token
|
||||
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
|
||||
from services.billing_service import BillingService, LangContentDict
|
||||
from services.billing_service import BillingService
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
@@ -69,9 +72,9 @@ console_ns.schema_model(
|
||||
)
|
||||
|
||||
|
||||
def admin_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
def admin_required(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if not dify_config.ADMIN_API_KEY:
|
||||
raise Unauthorized("API key is invalid.")
|
||||
|
||||
@@ -329,7 +332,7 @@ class UpsertNotificationApi(Resource):
|
||||
def post(self):
|
||||
payload = UpsertNotificationPayload.model_validate(console_ns.payload)
|
||||
result = BillingService.upsert_notification(
|
||||
contents=[cast(LangContentDict, c.model_dump()) for c in payload.contents],
|
||||
contents=[c.model_dump() for c in payload.contents],
|
||||
frequency=payload.frequency,
|
||||
status=payload.status,
|
||||
notification_id=payload.notification_id,
|
||||
|
||||
@@ -2,7 +2,7 @@ import flask_restx
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from flask_restx._http import HTTPStatus
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from extensions.ext_database import db
|
||||
@@ -34,7 +34,7 @@ api_key_list_model = console_ns.model(
|
||||
|
||||
|
||||
def _get_resource(resource_id, tenant_id, resource_model):
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
resource = session.execute(
|
||||
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
|
||||
).scalar_one_or_none()
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, TypeAlias
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.file import helpers as file_helpers
|
||||
from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator
|
||||
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.common.helpers import FileInfo
|
||||
@@ -26,25 +26,25 @@ from controllers.console.wraps import (
|
||||
setup_required,
|
||||
)
|
||||
from core.ops.ops_trace_manager import OpsTraceManager
|
||||
from core.rag.entities import PreProcessingRule, Rule, Segmentation
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.trigger.constants import TRIGGER_NODE_TYPES
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, DatasetPermissionEnum, Workflow
|
||||
from models.model import IconType
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.app_dsl_service import AppDslService, ImportMode
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.entities.dsl_entities import ImportMode
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
DataSource,
|
||||
InfoList,
|
||||
NotionIcon,
|
||||
NotionInfo,
|
||||
NotionPage,
|
||||
PreProcessingRule,
|
||||
RerankingModel,
|
||||
Rule,
|
||||
Segmentation,
|
||||
WebsiteInfo,
|
||||
WeightKeywordSetting,
|
||||
WeightModel,
|
||||
@@ -152,7 +152,17 @@ class AppTracePayload(BaseModel):
|
||||
return value
|
||||
|
||||
|
||||
type JSONValue = Any
|
||||
JSONValue: TypeAlias = Any
|
||||
|
||||
|
||||
class ResponseModel(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
extra="ignore",
|
||||
populate_by_name=True,
|
||||
serialize_by_alias=True,
|
||||
protected_namespaces=(),
|
||||
)
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
@@ -632,7 +642,7 @@ class AppCopyApi(Resource):
|
||||
|
||||
args = CopyAppPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
import_service = AppDslService(session)
|
||||
yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
|
||||
result = import_service.import_app(
|
||||
@@ -645,6 +655,7 @@ class AppCopyApi(Resource):
|
||||
icon=args.icon,
|
||||
icon_background=args.icon_background,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# Inherit web app permission from original app
|
||||
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import (
|
||||
@@ -17,9 +17,8 @@ from fields.app_fields import (
|
||||
)
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import App
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.app_dsl_service import AppDslService, ImportStatus
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.entities.dsl_entities import ImportStatus
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from .. import console_ns
|
||||
@@ -72,7 +71,7 @@ class AppImportApi(Resource):
|
||||
args = AppImportPayload.model_validate(console_ns.payload)
|
||||
|
||||
# Create service with session
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
import_service = AppDslService(session)
|
||||
# Import app
|
||||
account = current_user
|
||||
@@ -88,18 +87,17 @@ class AppImportApi(Resource):
|
||||
icon_background=args.icon_background,
|
||||
app_id=args.app_id,
|
||||
)
|
||||
session.commit()
|
||||
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
|
||||
# update web app setting as private
|
||||
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
|
||||
# Return appropriate status code based on result
|
||||
status = result.status
|
||||
match status:
|
||||
case ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
case ImportStatus.PENDING:
|
||||
return result.model_dump(mode="json"), 202
|
||||
case ImportStatus.COMPLETED | ImportStatus.COMPLETED_WITH_WARNINGS:
|
||||
return result.model_dump(mode="json"), 200
|
||||
if status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
elif status == ImportStatus.PENDING:
|
||||
return result.model_dump(mode="json"), 202
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/apps/imports/<string:import_id>/confirm")
|
||||
@@ -114,11 +112,12 @@ class AppImportConfirmApi(Resource):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
# Create service with session
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
import_service = AppDslService(session)
|
||||
# Confirm import
|
||||
account = current_user
|
||||
result = import_service.confirm_import(import_id=import_id, account=account)
|
||||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
if result.status == ImportStatus.FAILED:
|
||||
@@ -135,7 +134,7 @@ class AppImportCheckDependenciesApi(Resource):
|
||||
@marshal_with(app_import_check_dependencies_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
import_service = AppDslService(session)
|
||||
result = import_service.check_dependencies(app_model=app_model)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
@@ -69,7 +69,7 @@ class ConversationVariablesApi(Resource):
|
||||
page_size = 100
|
||||
stmt = stmt.limit(page_size).offset((page - 1) * page_size)
|
||||
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
rows = session.scalars(stmt).all()
|
||||
|
||||
return {
|
||||
|
||||
@@ -8,7 +8,6 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import exists, func, select
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.controller_schemas import MessageFeedbackPayload as _MessageFeedbackPayloadBase
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
@@ -60,8 +59,10 @@ class ChatMessagesQuery(BaseModel):
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class MessageFeedbackPayload(_MessageFeedbackPayloadBase):
|
||||
class MessageFeedbackPayload(BaseModel):
|
||||
message_id: str = Field(..., description="Message ID")
|
||||
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
|
||||
content: str | None = Field(default=None, description="Feedback content")
|
||||
|
||||
@field_validator("message_id")
|
||||
@classmethod
|
||||
|
||||
@@ -9,12 +9,11 @@ from graphon.enums import NodeType
|
||||
from graphon.file import File
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel, Field, ValidationError, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import DefaultBlockConfigQuery, WorkflowListQuery, WorkflowUpdatePayload
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
from controllers.console.app.workflow_run import workflow_run_node_execution_model
|
||||
@@ -39,7 +38,6 @@ from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from factories import file_factory, variable_factory
|
||||
from fields.member_fields import simple_account_fields
|
||||
from fields.online_user_fields import online_user_list_fields
|
||||
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
|
||||
from libs import helper
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@@ -48,7 +46,6 @@ from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
from models.workflow import Workflow
|
||||
from repositories.workflow_collaboration_repository import WORKFLOW_ONLINE_USERS_PREFIX
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
@@ -145,6 +142,10 @@ class PublishWorkflowPayload(BaseModel):
|
||||
marked_comment: str | None = Field(default=None, max_length=100)
|
||||
|
||||
|
||||
class DefaultBlockConfigQuery(BaseModel):
|
||||
q: str | None = None
|
||||
|
||||
|
||||
class ConvertToWorkflowPayload(BaseModel):
|
||||
name: str | None = None
|
||||
icon_type: str | None = None
|
||||
@@ -152,12 +153,16 @@ class ConvertToWorkflowPayload(BaseModel):
|
||||
icon_background: str | None = None
|
||||
|
||||
|
||||
class WorkflowFeaturesPayload(BaseModel):
|
||||
features: dict[str, Any] = Field(..., description="Workflow feature configuration")
|
||||
class WorkflowListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=10, ge=1, le=100)
|
||||
user_id: str | None = None
|
||||
named_only: bool = False
|
||||
|
||||
|
||||
class WorkflowOnlineUsersQuery(BaseModel):
|
||||
workflow_ids: str = Field(..., description="Comma-separated workflow IDs")
|
||||
class WorkflowUpdatePayload(BaseModel):
|
||||
marked_name: str | None = Field(default=None, max_length=20)
|
||||
marked_comment: str | None = Field(default=None, max_length=100)
|
||||
|
||||
|
||||
class DraftWorkflowTriggerRunPayload(BaseModel):
|
||||
@@ -183,8 +188,6 @@ reg(DefaultBlockConfigQuery)
|
||||
reg(ConvertToWorkflowPayload)
|
||||
reg(WorkflowListQuery)
|
||||
reg(WorkflowUpdatePayload)
|
||||
reg(WorkflowFeaturesPayload)
|
||||
reg(WorkflowOnlineUsersQuery)
|
||||
reg(DraftWorkflowTriggerRunPayload)
|
||||
reg(DraftWorkflowTriggerRunAllPayload)
|
||||
|
||||
@@ -265,18 +268,22 @@ class DraftWorkflowApi(Resource):
|
||||
|
||||
content_type = request.headers.get("Content-Type", "")
|
||||
|
||||
payload_data: dict[str, Any] | None = None
|
||||
if "application/json" in content_type:
|
||||
payload_data = request.get_json(silent=True)
|
||||
if not isinstance(payload_data, dict):
|
||||
return {"message": "Invalid JSON data"}, 400
|
||||
args_model = SyncDraftWorkflowPayload.model_validate(payload_data)
|
||||
elif "text/plain" in content_type:
|
||||
try:
|
||||
args_model = SyncDraftWorkflowPayload.model_validate_json(request.data)
|
||||
except (ValueError, ValidationError):
|
||||
payload_data = json.loads(request.data.decode("utf-8"))
|
||||
except json.JSONDecodeError:
|
||||
return {"message": "Invalid JSON data"}, 400
|
||||
if not isinstance(payload_data, dict):
|
||||
return {"message": "Invalid JSON data"}, 400
|
||||
else:
|
||||
abort(415)
|
||||
|
||||
args_model = SyncDraftWorkflowPayload.model_validate(payload_data)
|
||||
args = args_model.model_dump()
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
@@ -833,7 +840,7 @@ class PublishedWorkflowApi(Resource):
|
||||
args = PublishWorkflowPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
workflow = workflow_service.publish_workflow(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
@@ -851,6 +858,8 @@ class PublishedWorkflowApi(Resource):
|
||||
|
||||
workflow_created_at = TimestampField().format(workflow.created_at)
|
||||
|
||||
session.commit()
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"created_at": workflow_created_at,
|
||||
@@ -943,31 +952,6 @@ class ConvertToWorkflowApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/features")
|
||||
class WorkflowFeaturesApi(Resource):
|
||||
"""Update draft workflow features."""
|
||||
|
||||
@console_ns.expect(console_ns.models[WorkflowFeaturesPayload.__name__])
|
||||
@console_ns.doc("update_workflow_features")
|
||||
@console_ns.doc(description="Update draft workflow features")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Workflow features updated successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = WorkflowFeaturesPayload.model_validate(console_ns.payload or {})
|
||||
features = args.features
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
workflow_service.update_draft_workflow_features(app_model=app_model, features=features, account=current_user)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows")
|
||||
class PublishedAllWorkflowApi(Resource):
|
||||
@console_ns.expect(console_ns.models[WorkflowListQuery.__name__])
|
||||
@@ -998,7 +982,7 @@ class PublishedAllWorkflowApi(Resource):
|
||||
raise Forbidden()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
workflows, has_more = workflow_service.get_all_published_workflow(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
@@ -1088,7 +1072,7 @@ class WorkflowByIdApi(Resource):
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
# Create a session and manage the transaction
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow = workflow_service.update_workflow(
|
||||
session=session,
|
||||
workflow_id=workflow_id,
|
||||
@@ -1100,6 +1084,9 @@ class WorkflowByIdApi(Resource):
|
||||
if not workflow:
|
||||
raise NotFound("Workflow not found")
|
||||
|
||||
# Commit the transaction in the controller
|
||||
session.commit()
|
||||
|
||||
return workflow
|
||||
|
||||
@setup_required
|
||||
@@ -1114,11 +1101,13 @@ class WorkflowByIdApi(Resource):
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
# Create a session and manage the transaction
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
try:
|
||||
workflow_service.delete_workflow(
|
||||
session=session, workflow_id=workflow_id, tenant_id=app_model.tenant_id
|
||||
)
|
||||
# Commit the transaction in the controller
|
||||
session.commit()
|
||||
except WorkflowInUseError as e:
|
||||
abort(400, description=str(e))
|
||||
except DraftWorkflowDeletionError as e:
|
||||
@@ -1377,32 +1366,3 @@ class DraftWorkflowTriggerRunAllApi(Resource):
|
||||
"status": "error",
|
||||
}
|
||||
), 400
|
||||
|
||||
|
||||
@console_ns.route("/apps/workflows/online-users")
|
||||
class WorkflowOnlineUsersApi(Resource):
|
||||
@console_ns.expect(console_ns.models[WorkflowOnlineUsersQuery.__name__])
|
||||
@console_ns.doc("get_workflow_online_users")
|
||||
@console_ns.doc(description="Get workflow online users")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(online_user_list_fields)
|
||||
def get(self):
|
||||
args = WorkflowOnlineUsersQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
workflow_ids = [workflow_id.strip() for workflow_id in args.workflow_ids.split(",") if workflow_id.strip()]
|
||||
|
||||
results = []
|
||||
for workflow_id in workflow_ids:
|
||||
users_json = redis_client.hgetall(f"{WORKFLOW_ONLINE_USERS_PREFIX}{workflow_id}")
|
||||
|
||||
users = []
|
||||
for _, user_info_json in users_json.items():
|
||||
try:
|
||||
users.append(json.loads(user_info_json))
|
||||
except Exception:
|
||||
continue
|
||||
results.append({"workflow_id": workflow_id, "users": users})
|
||||
|
||||
return {"data": results}
|
||||
|
||||
@@ -5,7 +5,7 @@ from flask import request
|
||||
from flask_restx import Resource, marshal_with
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
@@ -87,7 +87,7 @@ class WorkflowAppLogApi(Resource):
|
||||
|
||||
# get paginate workflow app logs
|
||||
workflow_app_service = WorkflowAppService()
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
@@ -124,7 +124,7 @@ class WorkflowArchivedLogApi(Resource):
|
||||
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
workflow_app_service = WorkflowAppService()
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_archive_logs(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
|
||||
@@ -1,322 +0,0 @@
|
||||
import logging
|
||||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.member_fields import AccountWithRole
|
||||
from fields.workflow_comment_fields import (
|
||||
workflow_comment_basic_fields,
|
||||
workflow_comment_create_fields,
|
||||
workflow_comment_detail_fields,
|
||||
workflow_comment_reply_create_fields,
|
||||
workflow_comment_reply_update_fields,
|
||||
workflow_comment_resolve_fields,
|
||||
workflow_comment_update_fields,
|
||||
)
|
||||
from libs.login import current_user, login_required
|
||||
from models import App
|
||||
from services.account_service import TenantService
|
||||
from services.workflow_comment_service import WorkflowCommentService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowCommentCreatePayload(BaseModel):
|
||||
position_x: float = Field(..., description="Comment X position")
|
||||
position_y: float = Field(..., description="Comment Y position")
|
||||
content: str = Field(..., description="Comment content")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
class WorkflowCommentUpdatePayload(BaseModel):
|
||||
content: str = Field(..., description="Comment content")
|
||||
position_x: float | None = Field(default=None, description="Comment X position")
|
||||
position_y: float | None = Field(default=None, description="Comment Y position")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
class WorkflowCommentReplyCreatePayload(BaseModel):
|
||||
content: str = Field(..., description="Reply content")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
class WorkflowCommentReplyUpdatePayload(BaseModel):
|
||||
content: str = Field(..., description="Reply content")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
class WorkflowCommentMentionUsersPayload(BaseModel):
|
||||
users: list[AccountWithRole]
|
||||
|
||||
|
||||
for model in (
|
||||
WorkflowCommentCreatePayload,
|
||||
WorkflowCommentUpdatePayload,
|
||||
WorkflowCommentReplyCreatePayload,
|
||||
WorkflowCommentReplyUpdatePayload,
|
||||
):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
register_schema_models(console_ns, AccountWithRole, WorkflowCommentMentionUsersPayload)
|
||||
|
||||
workflow_comment_basic_model = console_ns.model("WorkflowCommentBasic", workflow_comment_basic_fields)
|
||||
workflow_comment_detail_model = console_ns.model("WorkflowCommentDetail", workflow_comment_detail_fields)
|
||||
workflow_comment_create_model = console_ns.model("WorkflowCommentCreate", workflow_comment_create_fields)
|
||||
workflow_comment_update_model = console_ns.model("WorkflowCommentUpdate", workflow_comment_update_fields)
|
||||
workflow_comment_resolve_model = console_ns.model("WorkflowCommentResolve", workflow_comment_resolve_fields)
|
||||
workflow_comment_reply_create_model = console_ns.model(
|
||||
"WorkflowCommentReplyCreate", workflow_comment_reply_create_fields
|
||||
)
|
||||
workflow_comment_reply_update_model = console_ns.model(
|
||||
"WorkflowCommentReplyUpdate", workflow_comment_reply_update_fields
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments")
|
||||
class WorkflowCommentListApi(Resource):
|
||||
"""API for listing and creating workflow comments."""
|
||||
|
||||
@console_ns.doc("list_workflow_comments")
|
||||
@console_ns.doc(description="Get all comments for a workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Comments retrieved successfully", workflow_comment_basic_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_basic_model, envelope="data")
|
||||
def get(self, app_model: App):
|
||||
"""Get all comments for a workflow."""
|
||||
comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id)
|
||||
|
||||
return comments
|
||||
|
||||
@console_ns.doc("create_workflow_comment")
|
||||
@console_ns.doc(description="Create a new workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentCreatePayload.__name__])
|
||||
@console_ns.response(201, "Comment created successfully", workflow_comment_create_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_create_model)
|
||||
def post(self, app_model: App):
|
||||
"""Create a new workflow comment."""
|
||||
payload = WorkflowCommentCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.create_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
created_by=current_user.id,
|
||||
content=payload.content,
|
||||
position_x=payload.position_x,
|
||||
position_y=payload.position_y,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result, 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>")
|
||||
class WorkflowCommentDetailApi(Resource):
|
||||
"""API for managing individual workflow comments."""
|
||||
|
||||
@console_ns.doc("get_workflow_comment")
|
||||
@console_ns.doc(description="Get a specific workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(200, "Comment retrieved successfully", workflow_comment_detail_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_detail_model)
|
||||
def get(self, app_model: App, comment_id: str):
|
||||
"""Get a specific workflow comment."""
|
||||
comment = WorkflowCommentService.get_comment(
|
||||
tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id
|
||||
)
|
||||
|
||||
return comment
|
||||
|
||||
@console_ns.doc("update_workflow_comment")
|
||||
@console_ns.doc(description="Update a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Comment updated successfully", workflow_comment_update_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_update_model)
|
||||
def put(self, app_model: App, comment_id: str):
|
||||
"""Update a workflow comment."""
|
||||
payload = WorkflowCommentUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.update_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
content=payload.content,
|
||||
position_x=payload.position_x,
|
||||
position_y=payload.position_y,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@console_ns.doc("delete_workflow_comment")
|
||||
@console_ns.doc(description="Delete a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(204, "Comment deleted successfully")
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def delete(self, app_model: App, comment_id: str):
|
||||
"""Delete a workflow comment."""
|
||||
WorkflowCommentService.delete_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/resolve")
|
||||
class WorkflowCommentResolveApi(Resource):
|
||||
"""API for resolving and reopening workflow comments."""
|
||||
|
||||
@console_ns.doc("resolve_workflow_comment")
|
||||
@console_ns.doc(description="Resolve a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(200, "Comment resolved successfully", workflow_comment_resolve_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_resolve_model)
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
"""Resolve a workflow comment."""
|
||||
comment = WorkflowCommentService.resolve_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return comment
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies")
|
||||
class WorkflowCommentReplyApi(Resource):
|
||||
"""API for managing comment replies."""
|
||||
|
||||
@console_ns.doc("create_workflow_comment_reply")
|
||||
@console_ns.doc(description="Add a reply to a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentReplyCreatePayload.__name__])
|
||||
@console_ns.response(201, "Reply created successfully", workflow_comment_reply_create_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_reply_create_model)
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
"""Add a reply to a workflow comment."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
payload = WorkflowCommentReplyCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.create_reply(
|
||||
comment_id=comment_id,
|
||||
content=payload.content,
|
||||
created_by=current_user.id,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result, 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies/<string:reply_id>")
|
||||
class WorkflowCommentReplyDetailApi(Resource):
|
||||
"""API for managing individual comment replies."""
|
||||
|
||||
@console_ns.doc("update_workflow_comment_reply")
|
||||
@console_ns.doc(description="Update a comment reply")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentReplyUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Reply updated successfully", workflow_comment_reply_update_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_reply_update_model)
|
||||
def put(self, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Update a comment reply."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
payload = WorkflowCommentReplyUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
reply = WorkflowCommentService.update_reply(
|
||||
reply_id=reply_id,
|
||||
user_id=current_user.id,
|
||||
content=payload.content,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return reply
|
||||
|
||||
@console_ns.doc("delete_workflow_comment_reply")
|
||||
@console_ns.doc(description="Delete a comment reply")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
|
||||
@console_ns.response(204, "Reply deleted successfully")
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def delete(self, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Delete a comment reply."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
WorkflowCommentService.delete_reply(reply_id=reply_id, user_id=current_user.id)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/mention-users")
|
||||
class WorkflowCommentMentionUsersApi(Resource):
|
||||
"""API for getting mentionable users for workflow comments."""
|
||||
|
||||
@console_ns.doc("workflow_comment_mention_users")
|
||||
@console_ns.doc(description="Get all users in current tenant for mentions")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(
|
||||
200, "Mentionable users retrieved successfully", console_ns.models[WorkflowCommentMentionUsersPayload.__name__]
|
||||
)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model: App):
|
||||
"""Get all users in current tenant for mentions."""
|
||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||
users = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
|
||||
response = WorkflowCommentMentionUsersPayload(users=users)
|
||||
return response.model_dump(mode="json"), 200
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
from typing import Any, NoReturn, ParamSpec, TypeVar
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
@@ -10,7 +10,7 @@ from graphon.variables.segment_group import SegmentGroup
|
||||
from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from graphon.variables.types import SegmentType
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
@@ -22,7 +22,6 @@ from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from extensions.ext_database import db
|
||||
from factories import variable_factory
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from libs.login import current_user, login_required
|
||||
@@ -46,16 +45,6 @@ class WorkflowDraftVariableUpdatePayload(BaseModel):
|
||||
value: Any | None = Field(default=None, description="Variable value")
|
||||
|
||||
|
||||
class ConversationVariableUpdatePayload(BaseModel):
|
||||
conversation_variables: list[dict[str, Any]] = Field(
|
||||
..., description="Conversation variables for the draft workflow"
|
||||
)
|
||||
|
||||
|
||||
class EnvironmentVariableUpdatePayload(BaseModel):
|
||||
environment_variables: list[dict[str, Any]] = Field(..., description="Environment variables for the draft workflow")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkflowDraftVariableListQuery.__name__,
|
||||
WorkflowDraftVariableListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
@@ -64,14 +53,6 @@ console_ns.schema_model(
|
||||
WorkflowDraftVariableUpdatePayload.__name__,
|
||||
WorkflowDraftVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
ConversationVariableUpdatePayload.__name__,
|
||||
ConversationVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
EnvironmentVariableUpdatePayload.__name__,
|
||||
EnvironmentVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
def _convert_values_to_json_serializable_object(value: Segment):
|
||||
@@ -211,8 +192,11 @@ workflow_draft_variable_list_model = console_ns.model(
|
||||
"WorkflowDraftVariableList", workflow_draft_variable_list_fields_copy
|
||||
)
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
|
||||
def _api_prerequisite(f: Callable[P, R]):
|
||||
"""Common prerequisites for all draft workflow variable APIs.
|
||||
|
||||
It ensures the following conditions are satisfied:
|
||||
@@ -229,7 +213,7 @@ def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@wraps(f)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs):
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
@@ -260,7 +244,7 @@ class WorkflowVariableCollectionApi(Resource):
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
@@ -286,7 +270,7 @@ class WorkflowVariableCollectionApi(Resource):
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
def validate_node_id(node_id: str) -> None:
|
||||
def validate_node_id(node_id: str) -> NoReturn | None:
|
||||
if node_id in [
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
SYSTEM_VARIABLE_NODE_ID,
|
||||
@@ -301,6 +285,7 @@ def validate_node_id(node_id: str) -> None:
|
||||
raise InvalidArgumentError(
|
||||
f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}",
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
|
||||
@@ -313,7 +298,7 @@ class NodeVariableCollectionApi(Resource):
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, app_model: App, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
@@ -403,27 +388,24 @@ class VariableApi(Resource):
|
||||
|
||||
new_value = None
|
||||
if raw_value is not None:
|
||||
match variable.value_type:
|
||||
case SegmentType.FILE:
|
||||
if not isinstance(raw_value, dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
|
||||
raw_value = build_from_mapping(
|
||||
mapping=raw_value,
|
||||
tenant_id=app_model.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
case SegmentType.ARRAY_FILE:
|
||||
if not isinstance(raw_value, list):
|
||||
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
|
||||
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
|
||||
raw_value = build_from_mappings(
|
||||
mappings=raw_value,
|
||||
tenant_id=app_model.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
case _:
|
||||
pass
|
||||
if variable.value_type == SegmentType.FILE:
|
||||
if not isinstance(raw_value, dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
|
||||
raw_value = build_from_mapping(
|
||||
mapping=raw_value,
|
||||
tenant_id=app_model.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
elif variable.value_type == SegmentType.ARRAY_FILE:
|
||||
if not isinstance(raw_value, list):
|
||||
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
|
||||
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
|
||||
raw_value = build_from_mappings(
|
||||
mappings=raw_value,
|
||||
tenant_id=app_model.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
new_value = build_segment_with_type(variable.value_type, raw_value)
|
||||
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
|
||||
db.session.commit()
|
||||
@@ -483,7 +465,7 @@ class VariableResetApi(Resource):
|
||||
|
||||
|
||||
def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
@@ -521,34 +503,6 @@ class ConversationVariableCollectionApi(Resource):
|
||||
db.session.commit()
|
||||
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
|
||||
|
||||
@console_ns.expect(console_ns.models[ConversationVariableUpdatePayload.__name__])
|
||||
@console_ns.doc("update_conversation_variables")
|
||||
@console_ns.doc(description="Update conversation variables for workflow draft")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Conversation variables updated successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=AppMode.ADVANCED_CHAT)
|
||||
def post(self, app_model: App):
|
||||
payload = ConversationVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
conversation_variables_list = payload.conversation_variables
|
||||
conversation_variables = [
|
||||
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||
]
|
||||
|
||||
workflow_service.update_draft_workflow_conversation_variables(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/system-variables")
|
||||
class SystemVariableCollectionApi(Resource):
|
||||
@@ -600,31 +554,3 @@ class EnvironmentVariableCollectionApi(Resource):
|
||||
)
|
||||
|
||||
return {"items": env_vars_list}
|
||||
|
||||
@console_ns.expect(console_ns.models[EnvironmentVariableUpdatePayload.__name__])
|
||||
@console_ns.doc("update_environment_variables")
|
||||
@console_ns.doc(description="Update environment variables for workflow draft")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Environment variables updated successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
payload = EnvironmentVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
environment_variables_list = payload.environment_variables
|
||||
environment_variables = [
|
||||
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
|
||||
]
|
||||
|
||||
workflow_service.update_draft_workflow_environment_variables(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
environment_variables=environment_variables,
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@@ -4,7 +4,7 @@ from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
@@ -64,15 +64,15 @@ class WebhookTriggerApi(Resource):
|
||||
|
||||
node_id = args.node_id
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
# Get webhook trigger for this app and node
|
||||
webhook_trigger = session.scalar(
|
||||
select(WorkflowWebhookTrigger)
|
||||
webhook_trigger = (
|
||||
session.query(WorkflowWebhookTrigger)
|
||||
.where(
|
||||
WorkflowWebhookTrigger.app_id == app_model.id,
|
||||
WorkflowWebhookTrigger.node_id == node_id,
|
||||
)
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not webhook_trigger:
|
||||
@@ -95,7 +95,7 @@ class AppTriggersApi(Resource):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
# Get all triggers for this app using select API
|
||||
triggers = (
|
||||
session.execute(
|
||||
@@ -137,7 +137,7 @@ class AppTriggerEnableApi(Resource):
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
trigger_id = args.trigger_id
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
# Find the trigger using select
|
||||
trigger = session.execute(
|
||||
select(AppTrigger).where(
|
||||
@@ -153,6 +153,9 @@ class AppTriggerEnableApi(Resource):
|
||||
# Update status based on enable_trigger boolean
|
||||
trigger.status = AppTriggerStatus.ENABLED if args.enable_trigger else AppTriggerStatus.DISABLED
|
||||
|
||||
session.commit()
|
||||
session.refresh(trigger)
|
||||
|
||||
# Add computed icon field
|
||||
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
|
||||
if trigger.trigger_type == "trigger-plugin":
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import overload
|
||||
from typing import ParamSpec, TypeVar, Union
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
@@ -9,6 +9,11 @@ from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant
|
||||
from models import App, AppMode
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
P1 = ParamSpec("P1")
|
||||
R1 = TypeVar("R1")
|
||||
|
||||
|
||||
def _load_app_model(app_id: str) -> App | None:
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@@ -23,30 +28,10 @@ def _load_app_model_with_trial(app_id: str) -> App | None:
|
||||
return app_model
|
||||
|
||||
|
||||
@overload
|
||||
def get_app_model[**P, R](
|
||||
view: Callable[P, R],
|
||||
*,
|
||||
mode: AppMode | list[AppMode] | None = None,
|
||||
) -> Callable[P, R]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def get_app_model[**P, R](
|
||||
view: None = None,
|
||||
*,
|
||||
mode: AppMode | list[AppMode] | None = None,
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
|
||||
|
||||
|
||||
def get_app_model[**P, R](
|
||||
view: Callable[P, R] | None = None,
|
||||
*,
|
||||
mode: AppMode | list[AppMode] | None = None,
|
||||
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
|
||||
def decorator(view_func: Callable[P1, R1]):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
def decorated_view(*args: P1.args, **kwargs: P1.kwargs):
|
||||
if not kwargs.get("app_id"):
|
||||
raise ValueError("missing app_id in path parameters")
|
||||
|
||||
@@ -84,30 +69,10 @@ def get_app_model[**P, R](
|
||||
return decorator(view)
|
||||
|
||||
|
||||
@overload
|
||||
def get_app_model_with_trial[**P, R](
|
||||
view: Callable[P, R],
|
||||
*,
|
||||
mode: AppMode | list[AppMode] | None = None,
|
||||
) -> Callable[P, R]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def get_app_model_with_trial[**P, R](
|
||||
view: None = None,
|
||||
*,
|
||||
mode: AppMode | list[AppMode] | None = None,
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
|
||||
|
||||
|
||||
def get_app_model_with_trial[**P, R](
|
||||
view: Callable[P, R] | None = None,
|
||||
*,
|
||||
mode: AppMode | list[AppMode] | None = None,
|
||||
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
def get_app_model_with_trial(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
|
||||
def decorator(view_func: Callable[P, R]):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
if not kwargs.get("app_id"):
|
||||
raise ValueError("missing app_id in path parameters")
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import secrets
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
@@ -20,18 +20,35 @@ from controllers.console.wraps import email_password_login_enabled, setup_requir
|
||||
from events.tenant_event import tenant_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.password import hash_password
|
||||
from libs.password import hash_password, valid_password
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.entities.auth_entities import (
|
||||
ForgotPasswordCheckPayload,
|
||||
ForgotPasswordResetPayload,
|
||||
ForgotPasswordSendPayload,
|
||||
)
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ForgotPasswordSendPayload(BaseModel):
|
||||
email: EmailStr = Field(...)
|
||||
language: str | None = Field(default=None)
|
||||
|
||||
|
||||
class ForgotPasswordCheckPayload(BaseModel):
|
||||
email: EmailStr = Field(...)
|
||||
code: str = Field(...)
|
||||
token: str = Field(...)
|
||||
|
||||
|
||||
class ForgotPasswordResetPayload(BaseModel):
|
||||
token: str = Field(...)
|
||||
new_password: str = Field(...)
|
||||
password_confirm: str = Field(...)
|
||||
|
||||
@field_validator("new_password", "password_confirm")
|
||||
@classmethod
|
||||
def validate_password(cls, value: str) -> str:
|
||||
return valid_password(value)
|
||||
|
||||
|
||||
class ForgotPasswordEmailResponse(BaseModel):
|
||||
result: str = Field(description="Operation result")
|
||||
data: str | None = Field(default=None, description="Reset token")
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
import flask_login
|
||||
from flask import make_response, request
|
||||
from flask_restx import Resource
|
||||
@@ -40,9 +42,8 @@ from libs.token import (
|
||||
set_csrf_token_to_cookie,
|
||||
set_refresh_token_to_cookie,
|
||||
)
|
||||
from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.billing_service import BillingService
|
||||
from services.entities.auth_entities import LoginPayloadBase
|
||||
from services.errors.account import AccountRegisterError
|
||||
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
|
||||
from services.feature_service import FeatureService
|
||||
@@ -50,7 +51,9 @@ from services.feature_service import FeatureService
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class LoginPayload(LoginPayloadBase):
|
||||
class LoginPayload(BaseModel):
|
||||
email: EmailStr = Field(..., description="Email address")
|
||||
password: str = Field(..., description="Password")
|
||||
remember_me: bool = Field(default=False, description="Remember me flag")
|
||||
invite_token: str | None = Field(default=None, description="Invitation token")
|
||||
|
||||
@@ -98,7 +101,7 @@ class LoginApi(Resource):
|
||||
raise EmailPasswordLoginLimitError()
|
||||
|
||||
invite_token = args.invite_token
|
||||
invitation_data: InvitationDetailDict | None = None
|
||||
invitation_data: dict[str, Any] | None = None
|
||||
if invite_token:
|
||||
invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token)
|
||||
if invitation_data is None:
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Concatenate
|
||||
from typing import Concatenate, ParamSpec, TypeVar
|
||||
|
||||
from flask import jsonify, request
|
||||
from flask.typing import ResponseReturnValue
|
||||
from flask_restx import Resource
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel
|
||||
@@ -17,6 +16,10 @@ from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType,
|
||||
|
||||
from .. import console_ns
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class OAuthClientPayload(BaseModel):
|
||||
client_id: str
|
||||
@@ -36,11 +39,9 @@ class OAuthTokenRequest(BaseModel):
|
||||
refresh_token: str | None = None
|
||||
|
||||
|
||||
def oauth_server_client_id_required[T, **P, R](
|
||||
view: Callable[Concatenate[T, OAuthProviderApp, P], R],
|
||||
) -> Callable[Concatenate[T, P], R]:
|
||||
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
|
||||
json_data = request.get_json()
|
||||
if json_data is None:
|
||||
raise BadRequest("client_id is required")
|
||||
@@ -57,13 +58,9 @@ def oauth_server_client_id_required[T, **P, R](
|
||||
return decorated
|
||||
|
||||
|
||||
def oauth_server_access_token_required[T, **P, R](
|
||||
view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R],
|
||||
) -> Callable[Concatenate[T, OAuthProviderApp, P], R | ResponseReturnValue]:
|
||||
def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(
|
||||
self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs
|
||||
) -> R | ResponseReturnValue:
|
||||
def decorated(self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs):
|
||||
if not isinstance(oauth_provider_app, OAuthProviderApp):
|
||||
raise BadRequest("Invalid oauth_provider_app")
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ class Subscription(Resource):
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
args = SubscriptionQuery.model_validate(request.args.to_dict(flat=True))
|
||||
args = SubscriptionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
return BillingService.get_subscription(args.plan, args.interval, current_user.email, current_tenant_id)
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ class ComplianceApi(Resource):
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
args = ComplianceDownloadQuery.model_validate(request.args.to_dict(flat=True))
|
||||
args = ComplianceDownloadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
device_info = request.headers.get("User-Agent", "Unknown device")
|
||||
|
||||
@@ -6,7 +6,7 @@ from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.schema import get_or_create_model, register_schema_model
|
||||
@@ -158,11 +158,10 @@ class DataSourceApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, binding_id, action: Literal["enable", "disable"]):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
binding_id = str(binding_id)
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
data_source_binding = session.execute(
|
||||
select(DataSourceOauthBinding).filter_by(id=binding_id, tenant_id=current_tenant_id)
|
||||
select(DataSourceOauthBinding).filter_by(id=binding_id)
|
||||
).scalar_one_or_none()
|
||||
if data_source_binding is None:
|
||||
raise NotFound("Data source binding not found.")
|
||||
@@ -212,7 +211,7 @@ class DataSourceNotionListApi(Resource):
|
||||
if not credential:
|
||||
raise NotFound("Credential not found.")
|
||||
exist_page_ids = []
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
# import notion in the exist dataset
|
||||
if query.dataset_id:
|
||||
dataset = DatasetService.get_dataset(query.dataset_id)
|
||||
|
||||
@@ -173,11 +173,8 @@ class ExternalApiTemplateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, external_knowledge_api_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
external_knowledge_api_id = str(external_knowledge_api_id)
|
||||
external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(
|
||||
external_knowledge_api_id, current_tenant_id
|
||||
)
|
||||
external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(external_knowledge_api_id)
|
||||
if external_knowledge_api is None:
|
||||
raise NotFound("API template not found.")
|
||||
|
||||
|
||||
@@ -120,8 +120,7 @@ class DatasourceOAuthCallback(Resource):
|
||||
if context is None:
|
||||
raise Forbidden("Invalid context_id")
|
||||
|
||||
user_id: str = context["user_id"]
|
||||
tenant_id: str = context["tenant_id"]
|
||||
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
plugin_id = datasource_provider_id.plugin_id
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
@@ -142,7 +141,7 @@ class DatasourceOAuthCallback(Resource):
|
||||
system_credentials=oauth_client_params,
|
||||
request=request,
|
||||
)
|
||||
credential_id: str | None = context.get("credential_id")
|
||||
credential_id = context.get("credential_id")
|
||||
if credential_id:
|
||||
datasource_provider_service.reauthorize_datasource_oauth_provider(
|
||||
tenant_id=tenant_id,
|
||||
@@ -151,7 +150,7 @@ class DatasourceOAuthCallback(Resource):
|
||||
name=oauth_response.metadata.get("name") or None,
|
||||
expire_at=oauth_response.expires_at,
|
||||
credentials=dict(oauth_response.credentials),
|
||||
credential_id=credential_id,
|
||||
credential_id=context.get("credential_id"),
|
||||
)
|
||||
else:
|
||||
datasource_provider_service.add_datasource_oauth_provider(
|
||||
|
||||
@@ -3,8 +3,7 @@ import logging
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
@@ -86,9 +85,9 @@ class CustomizedPipelineTemplateApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def post(self, template_id: str):
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
template = session.scalar(
|
||||
select(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).limit(1)
|
||||
with Session(db.engine) as session:
|
||||
template = (
|
||||
session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first()
|
||||
)
|
||||
if not template:
|
||||
raise ValueError("Customized pipeline template not found.")
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from flask_restx import Resource, marshal
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
@@ -54,7 +54,7 @@ class CreateRagPipelineDatasetApi(Resource):
|
||||
yaml_content=payload.yaml_content,
|
||||
)
|
||||
try:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
rag_pipeline_dsl_service = RagPipelineDslService(session)
|
||||
import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset(
|
||||
tenant_id=current_tenant_id,
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any, NoReturn
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, marshal, marshal_with
|
||||
from graphon.variables.types import SegmentType
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
@@ -56,7 +55,7 @@ class WorkflowDraftVariablePatchPayload(BaseModel):
|
||||
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
|
||||
|
||||
|
||||
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
def _api_prerequisite(f):
|
||||
"""Common prerequisites for all draft workflow variable APIs.
|
||||
|
||||
It ensures the following conditions are satisfied:
|
||||
@@ -71,7 +70,7 @@ def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
def wrapper(*args, **kwargs):
|
||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
return f(*args, **kwargs)
|
||||
@@ -97,7 +96,7 @@ class RagPipelineVariableCollectionApi(Resource):
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
@@ -144,7 +143,7 @@ class RagPipelineNodeVariableCollectionApi(Resource):
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, pipeline: Pipeline, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
@@ -223,27 +222,24 @@ class RagPipelineVariableApi(Resource):
|
||||
|
||||
new_value = None
|
||||
if raw_value is not None:
|
||||
match variable.value_type:
|
||||
case SegmentType.FILE:
|
||||
if not isinstance(raw_value, dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
|
||||
raw_value = build_from_mapping(
|
||||
mapping=raw_value,
|
||||
tenant_id=pipeline.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
case SegmentType.ARRAY_FILE:
|
||||
if not isinstance(raw_value, list):
|
||||
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
|
||||
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
|
||||
raw_value = build_from_mappings(
|
||||
mappings=raw_value,
|
||||
tenant_id=pipeline.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
case _:
|
||||
pass
|
||||
if variable.value_type == SegmentType.FILE:
|
||||
if not isinstance(raw_value, dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
|
||||
raw_value = build_from_mapping(
|
||||
mapping=raw_value,
|
||||
tenant_id=pipeline.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
elif variable.value_type == SegmentType.ARRAY_FILE:
|
||||
if not isinstance(raw_value, list):
|
||||
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
|
||||
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
|
||||
raw_value = build_from_mappings(
|
||||
mappings=raw_value,
|
||||
tenant_id=pipeline.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
new_value = build_segment_with_type(variable.value_type, raw_value)
|
||||
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
|
||||
db.session.commit()
|
||||
@@ -293,7 +289,7 @@ class RagPipelineVariableResetApi(Resource):
|
||||
|
||||
|
||||
def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList:
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with # type: ignore
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import get_or_create_model, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
@@ -19,7 +19,7 @@ from fields.rag_pipeline_fields import (
|
||||
)
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import Pipeline
|
||||
from services.entities.dsl_entities import ImportStatus
|
||||
from services.app_dsl_service import ImportStatus
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ class RagPipelineImportApi(Resource):
|
||||
payload = RagPipelineImportPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
# Create service with session
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
# Import app
|
||||
account = current_user
|
||||
@@ -80,16 +80,15 @@ class RagPipelineImportApi(Resource):
|
||||
pipeline_id=payload.pipeline_id,
|
||||
dataset_name=payload.name,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
status = result.status
|
||||
match status:
|
||||
case ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
case ImportStatus.PENDING:
|
||||
return result.model_dump(mode="json"), 202
|
||||
case ImportStatus.COMPLETED | ImportStatus.COMPLETED_WITH_WARNINGS:
|
||||
return result.model_dump(mode="json"), 200
|
||||
if status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
elif status == ImportStatus.PENDING:
|
||||
return result.model_dump(mode="json"), 202
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/imports/<string:import_id>/confirm")
|
||||
@@ -103,11 +102,12 @@ class RagPipelineImportConfirmApi(Resource):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
# Create service with session
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
# Confirm import
|
||||
account = current_user
|
||||
result = import_service.confirm_import(import_id=import_id, account=account)
|
||||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
if result.status == ImportStatus.FAILED:
|
||||
@@ -124,7 +124,7 @@ class RagPipelineImportCheckDependenciesApi(Resource):
|
||||
@edit_permission_required
|
||||
@marshal_with(pipeline_import_check_dependencies_model)
|
||||
def get(self, pipeline: Pipeline):
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
result = import_service.check_dependencies(pipeline=pipeline)
|
||||
|
||||
@@ -142,7 +142,7 @@ class RagPipelineExportApi(Resource):
|
||||
# Add include_secret params
|
||||
query = IncludeSecretQuery.model_validate(request.args.to_dict())
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
export_service = RagPipelineDslService(session)
|
||||
result = export_service.export_rag_pipeline_dsl(
|
||||
pipeline=pipeline, include_secret=query.include_secret == "true"
|
||||
|
||||
@@ -5,12 +5,11 @@ from typing import Any, Literal, cast
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource, marshal_with # type: ignore
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import DefaultBlockConfigQuery, WorkflowListQuery, WorkflowUpdatePayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
@@ -95,6 +94,22 @@ class PublishedWorkflowRunPayload(DraftWorkflowRunPayload):
|
||||
original_document_id: str | None = None
|
||||
|
||||
|
||||
class DefaultBlockConfigQuery(BaseModel):
|
||||
q: str | None = None
|
||||
|
||||
|
||||
class WorkflowListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=10, ge=1, le=100)
|
||||
user_id: str | None = None
|
||||
named_only: bool = False
|
||||
|
||||
|
||||
class WorkflowUpdatePayload(BaseModel):
|
||||
marked_name: str | None = Field(default=None, max_length=20)
|
||||
marked_comment: str | None = Field(default=None, max_length=100)
|
||||
|
||||
|
||||
class NodeIdQuery(BaseModel):
|
||||
node_id: str
|
||||
|
||||
@@ -171,14 +186,29 @@ class DraftRagPipelineApi(Resource):
|
||||
|
||||
if "application/json" in content_type:
|
||||
payload_dict = console_ns.payload or {}
|
||||
payload = DraftWorkflowSyncPayload.model_validate(payload_dict)
|
||||
elif "text/plain" in content_type:
|
||||
try:
|
||||
payload = DraftWorkflowSyncPayload.model_validate_json(request.data)
|
||||
except (ValueError, ValidationError):
|
||||
data = json.loads(request.data.decode("utf-8"))
|
||||
if "graph" not in data or "features" not in data:
|
||||
raise ValueError("graph or features not found in data")
|
||||
|
||||
if not isinstance(data.get("graph"), dict):
|
||||
raise ValueError("graph is not a dict")
|
||||
|
||||
payload_dict = {
|
||||
"graph": data.get("graph"),
|
||||
"features": data.get("features"),
|
||||
"hash": data.get("hash"),
|
||||
"environment_variables": data.get("environment_variables"),
|
||||
"conversation_variables": data.get("conversation_variables"),
|
||||
"rag_pipeline_variables": data.get("rag_pipeline_variables"),
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
return {"message": "Invalid JSON data"}, 400
|
||||
else:
|
||||
abort(415)
|
||||
|
||||
payload = DraftWorkflowSyncPayload.model_validate(payload_dict)
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
|
||||
try:
|
||||
@@ -578,15 +608,19 @@ class PublishedRagPipelineApi(Resource):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
workflow = rag_pipeline_service.publish_workflow(
|
||||
session=db.session, # type: ignore[reportArgumentType,arg-type]
|
||||
pipeline=pipeline,
|
||||
account=current_user,
|
||||
)
|
||||
pipeline.is_published = True
|
||||
pipeline.workflow_id = workflow.id
|
||||
db.session.commit()
|
||||
workflow_created_at = TimestampField().format(workflow.created_at)
|
||||
with Session(db.engine) as session:
|
||||
pipeline = session.merge(pipeline)
|
||||
workflow = rag_pipeline_service.publish_workflow(
|
||||
session=session,
|
||||
pipeline=pipeline,
|
||||
account=current_user,
|
||||
)
|
||||
pipeline.is_published = True
|
||||
pipeline.workflow_id = workflow.id
|
||||
session.add(pipeline)
|
||||
workflow_created_at = TimestampField().format(workflow.created_at)
|
||||
|
||||
session.commit()
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
@@ -661,7 +695,7 @@ class PublishedAllRagPipelineApi(Resource):
|
||||
raise Forbidden()
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
workflows, has_more = rag_pipeline_service.get_all_published_workflow(
|
||||
session=session,
|
||||
pipeline=pipeline,
|
||||
@@ -733,7 +767,7 @@ class RagPipelineByIdApi(Resource):
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
|
||||
# Create a session and manage the transaction
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow = rag_pipeline_service.update_workflow(
|
||||
session=session,
|
||||
workflow_id=workflow_id,
|
||||
@@ -745,6 +779,9 @@ class RagPipelineByIdApi(Resource):
|
||||
if not workflow:
|
||||
raise NotFound("Workflow not found")
|
||||
|
||||
# Commit the transaction in the controller
|
||||
session.commit()
|
||||
|
||||
return workflow
|
||||
|
||||
@setup_required
|
||||
@@ -761,13 +798,14 @@ class RagPipelineByIdApi(Resource):
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
try:
|
||||
workflow_service.delete_workflow(
|
||||
session=session,
|
||||
workflow_id=workflow_id,
|
||||
tenant_id=pipeline.tenant_id,
|
||||
)
|
||||
session.commit()
|
||||
except WorkflowInUseError as e:
|
||||
abort(400, description=str(e))
|
||||
except DraftWorkflowDeletionError as e:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
@@ -8,10 +9,13 @@ from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.dataset import Pipeline
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
def get_rag_pipeline[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
|
||||
def get_rag_pipeline(view_func: Callable[P, R]):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
if not kwargs.get("pipeline_id"):
|
||||
raise ValueError("missing pipeline_id in path parameters")
|
||||
|
||||
|
||||
@@ -2,10 +2,10 @@ import logging
|
||||
|
||||
from flask import request
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import TextToAudioPayload
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
@@ -32,6 +32,14 @@ from .. import console_ns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TextToAudioPayload(BaseModel):
|
||||
message_id: str | None = None
|
||||
voice: str | None = None
|
||||
text: str | None = None
|
||||
streaming: bool | None = Field(default=None, description="Enable streaming response")
|
||||
|
||||
|
||||
register_schema_model(console_ns, TextToAudioPayload)
|
||||
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from pydantic import BaseModel, Field, TypeAdapter, model_validator
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.controller_schemas import ConversationRenamePayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.explore.error import NotChatAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
@@ -33,6 +32,18 @@ class ConversationListQuery(BaseModel):
|
||||
pinned: bool | None = None
|
||||
|
||||
|
||||
class ConversationRenamePayload(BaseModel):
|
||||
name: str | None = None
|
||||
auto_generate: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_name_requirement(self):
|
||||
if not self.auto_generate:
|
||||
if self.name is None or not self.name.strip():
|
||||
raise ValueError("name is required when auto_generate is false")
|
||||
return self
|
||||
|
||||
|
||||
register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload)
|
||||
|
||||
|
||||
@@ -63,7 +74,7 @@ class ConversationListApi(InstalledAppResource):
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
pagination = WebConversationService.pagination_by_last_id(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
|
||||
@@ -3,10 +3,9 @@ from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.app.error import (
|
||||
AppMoreLikeThisDisabledError,
|
||||
@@ -26,6 +25,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse
|
||||
from libs import helper
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.enums import FeedbackRating
|
||||
from models.model import AppMode
|
||||
@@ -44,6 +44,17 @@ from .. import console_ns
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageListQuery(BaseModel):
|
||||
conversation_id: UUIDStrOrEmpty
|
||||
first_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class MessageFeedbackPayload(BaseModel):
|
||||
rating: Literal["like", "dislike"] | None = None
|
||||
content: str | None = None
|
||||
|
||||
|
||||
class MoreLikeThisQuery(BaseModel):
|
||||
response_mode: Literal["blocking", "streaming"]
|
||||
|
||||
|
||||
@@ -1,18 +1,28 @@
|
||||
from flask import request
|
||||
from pydantic import TypeAdapter
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.controller_schemas import SavedMessageCreatePayload, SavedMessageListQuery
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.explore.error import NotCompletionAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.login import current_account_with_tenant
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
|
||||
|
||||
class SavedMessageListQuery(BaseModel):
|
||||
last_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class SavedMessageCreatePayload(BaseModel):
|
||||
message_id: UUIDStrOrEmpty
|
||||
|
||||
|
||||
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from controllers.common.controller_schemas import WorkflowRunPayload
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
@@ -33,6 +34,12 @@ from .. import console_ns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowRunPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
files: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
register_schema_model(console_ns, WorkflowRunPayload)
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Concatenate
|
||||
from typing import Concatenate, ParamSpec, TypeVar
|
||||
|
||||
from flask import abort
|
||||
from flask_restx import Resource
|
||||
@@ -15,8 +15,12 @@ from models import AccountTrialAppRecord, App, InstalledApp, TrialApp
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
T = TypeVar("T")
|
||||
|
||||
def installed_app_required[**P, R](view: Callable[Concatenate[InstalledApp, P], R] | None = None):
|
||||
|
||||
def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | None = None):
|
||||
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||
@@ -45,7 +49,7 @@ def installed_app_required[**P, R](view: Callable[Concatenate[InstalledApp, P],
|
||||
return decorator
|
||||
|
||||
|
||||
def user_allowed_to_access_app[**P, R](view: Callable[Concatenate[InstalledApp, P], R] | None = None):
|
||||
def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] | None = None):
|
||||
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
|
||||
@@ -69,7 +73,7 @@ def user_allowed_to_access_app[**P, R](view: Callable[Concatenate[InstalledApp,
|
||||
return decorator
|
||||
|
||||
|
||||
def trial_app_required[**P, R](view: Callable[Concatenate[App, P], R] | None = None):
|
||||
def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
|
||||
def decorator(view: Callable[Concatenate[App, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||
@@ -102,7 +106,7 @@ def trial_app_required[**P, R](view: Callable[Concatenate[App, P], R] | None = N
|
||||
return decorator
|
||||
|
||||
|
||||
def trial_feature_enable[**P, R](view: Callable[P, R]):
|
||||
def trial_feature_enable(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
@@ -113,7 +117,7 @@ def trial_feature_enable[**P, R](view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def explore_banner_enabled[**P, R](view: Callable[P, R]):
|
||||
def explore_banner_enabled(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
|
||||
@@ -168,13 +168,12 @@ class ConsoleWorkflowEventsApi(Resource):
|
||||
else:
|
||||
msg_generator = MessageGenerator()
|
||||
generator: BaseAppGenerator
|
||||
match app.mode:
|
||||
case AppMode.ADVANCED_CHAT:
|
||||
generator = AdvancedChatAppGenerator()
|
||||
case AppMode.WORKFLOW:
|
||||
generator = WorkflowAppGenerator()
|
||||
case _:
|
||||
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
|
||||
if app.mode == AppMode.ADVANCED_CHAT:
|
||||
generator = AdvancedChatAppGenerator()
|
||||
elif app.mode == AppMode.WORKFLOW:
|
||||
generator = WorkflowAppGenerator()
|
||||
else:
|
||||
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
|
||||
|
||||
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import TypedDict
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -13,21 +11,6 @@ from services.billing_service import BillingService
|
||||
_FALLBACK_LANG = "en-US"
|
||||
|
||||
|
||||
class NotificationItemDict(TypedDict):
|
||||
notification_id: str | None
|
||||
frequency: str | None
|
||||
lang: str
|
||||
title: str
|
||||
subtitle: str
|
||||
body: str
|
||||
title_pic_url: str
|
||||
|
||||
|
||||
class NotificationResponseDict(TypedDict):
|
||||
should_show: bool
|
||||
notifications: list[NotificationItemDict]
|
||||
|
||||
|
||||
def _pick_lang_content(contents: dict, lang: str) -> dict:
|
||||
"""Return the single LangContent for *lang*, falling back to English."""
|
||||
return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {})
|
||||
@@ -62,30 +45,28 @@ class NotificationApi(Resource):
|
||||
result = BillingService.get_account_notification(str(current_user.id))
|
||||
|
||||
# Proto JSON uses camelCase field names (Kratos default marshaling).
|
||||
response: NotificationResponseDict
|
||||
if not result.get("shouldShow"):
|
||||
response = {"should_show": False, "notifications": []}
|
||||
return response, 200
|
||||
return {"should_show": False, "notifications": []}, 200
|
||||
|
||||
lang = current_user.interface_language or _FALLBACK_LANG
|
||||
|
||||
notifications: list[NotificationItemDict] = []
|
||||
notifications = []
|
||||
for notification in result.get("notifications") or []:
|
||||
contents: dict = notification.get("contents") or {}
|
||||
lang_content = _pick_lang_content(contents, lang)
|
||||
item: NotificationItemDict = {
|
||||
"notification_id": notification.get("notificationId"),
|
||||
"frequency": notification.get("frequency"),
|
||||
"lang": lang_content.get("lang", lang),
|
||||
"title": lang_content.get("title", ""),
|
||||
"subtitle": lang_content.get("subtitle", ""),
|
||||
"body": lang_content.get("body", ""),
|
||||
"title_pic_url": lang_content.get("titlePicUrl", ""),
|
||||
}
|
||||
notifications.append(item)
|
||||
notifications.append(
|
||||
{
|
||||
"notification_id": notification.get("notificationId"),
|
||||
"frequency": notification.get("frequency"),
|
||||
"lang": lang_content.get("lang", lang),
|
||||
"title": lang_content.get("title", ""),
|
||||
"subtitle": lang_content.get("subtitle", ""),
|
||||
"body": lang_content.get("body", ""),
|
||||
"title_pic_url": lang_content.get("titlePicUrl", ""),
|
||||
}
|
||||
)
|
||||
|
||||
response = {"should_show": bool(notifications), "notifications": notifications}
|
||||
return response, 200
|
||||
return {"should_show": bool(notifications), "notifications": notifications}, 200
|
||||
|
||||
|
||||
@console_ns.route("/notification/dismiss")
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
|
||||
@@ -1,108 +0,0 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from flask import Request as FlaskRequest
|
||||
|
||||
from extensions.ext_socketio import sio
|
||||
from libs.passport import PassportService
|
||||
from libs.token import extract_access_token
|
||||
from repositories.workflow_collaboration_repository import WorkflowCollaborationRepository
|
||||
from services.account_service import AccountService
|
||||
from services.workflow_collaboration_service import WorkflowCollaborationService
|
||||
|
||||
repository = WorkflowCollaborationRepository()
|
||||
collaboration_service = WorkflowCollaborationService(repository, sio)
|
||||
|
||||
|
||||
def _sio_on(event: str) -> Callable[[Callable[..., object]], Callable[..., object]]:
|
||||
return cast(Callable[[Callable[..., object]], Callable[..., object]], sio.on(event))
|
||||
|
||||
|
||||
@_sio_on("connect")
|
||||
def socket_connect(sid, environ, auth):
|
||||
"""
|
||||
WebSocket connect event, do authentication here.
|
||||
"""
|
||||
try:
|
||||
request_environ = FlaskRequest(environ)
|
||||
token = extract_access_token(request_environ)
|
||||
except Exception:
|
||||
logging.exception("Failed to extract token")
|
||||
token = None
|
||||
|
||||
if not token:
|
||||
logging.warning("Socket connect rejected: missing token (sid=%s)", sid)
|
||||
return False
|
||||
|
||||
try:
|
||||
decoded = PassportService().verify(token)
|
||||
user_id = decoded.get("user_id")
|
||||
if not user_id:
|
||||
logging.warning("Socket connect rejected: missing user_id (sid=%s)", sid)
|
||||
return False
|
||||
|
||||
with sio.app.app_context():
|
||||
user = AccountService.load_logged_in_account(account_id=user_id)
|
||||
if not user:
|
||||
logging.warning("Socket connect rejected: user not found (user_id=%s, sid=%s)", user_id, sid)
|
||||
return False
|
||||
if not user.has_edit_permission:
|
||||
logging.warning("Socket connect rejected: no edit permission (user_id=%s, sid=%s)", user_id, sid)
|
||||
return False
|
||||
|
||||
collaboration_service.save_session(sid, user)
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
logging.exception("Socket authentication failed")
|
||||
return False
|
||||
|
||||
|
||||
@_sio_on("user_connect")
|
||||
def handle_user_connect(sid, data):
|
||||
"""
|
||||
Handle user connect event. Each session (tab) is treated as an independent collaborator.
|
||||
"""
|
||||
workflow_id = data.get("workflow_id")
|
||||
if not workflow_id:
|
||||
return {"msg": "workflow_id is required"}, 400
|
||||
|
||||
result = collaboration_service.register_session(workflow_id, sid)
|
||||
if not result:
|
||||
return {"msg": "unauthorized"}, 401
|
||||
|
||||
user_id, is_leader = result
|
||||
return {"msg": "connected", "user_id": user_id, "sid": sid, "isLeader": is_leader}
|
||||
|
||||
|
||||
@_sio_on("disconnect")
|
||||
def handle_disconnect(sid):
|
||||
"""
|
||||
Handle session disconnect event. Remove the specific session from online users.
|
||||
"""
|
||||
collaboration_service.disconnect_session(sid)
|
||||
|
||||
|
||||
@_sio_on("collaboration_event")
|
||||
def handle_collaboration_event(sid, data):
|
||||
"""
|
||||
Handle general collaboration events, include:
|
||||
1. mouse_move
|
||||
2. vars_and_features_update
|
||||
3. sync_request (ask leader to update graph)
|
||||
4. app_state_update
|
||||
5. mcp_server_update
|
||||
6. workflow_update
|
||||
7. comments_update
|
||||
8. node_panel_presence
|
||||
"""
|
||||
return collaboration_service.relay_collaboration_event(sid, data)
|
||||
|
||||
|
||||
@_sio_on("graph_event")
|
||||
def handle_graph_event(sid, data):
|
||||
"""
|
||||
Handle graph events - simple broadcast relay.
|
||||
"""
|
||||
return collaboration_service.relay_graph_event(sid, data)
|
||||
@@ -9,14 +9,7 @@ from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.enums import TagType
|
||||
from services.tag_service import (
|
||||
SaveTagPayload,
|
||||
TagBindingCreatePayload,
|
||||
TagBindingDeletePayload,
|
||||
TagService,
|
||||
UpdateTagPayload,
|
||||
)
|
||||
from services.tag_service import TagService
|
||||
|
||||
dataset_tag_fields = {
|
||||
"id": fields.String,
|
||||
@@ -32,19 +25,19 @@ def build_dataset_tag_fields(api_or_ns: Namespace):
|
||||
|
||||
class TagBasePayload(BaseModel):
|
||||
name: str = Field(description="Tag name", min_length=1, max_length=50)
|
||||
type: TagType = Field(description="Tag type")
|
||||
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
|
||||
|
||||
|
||||
class TagBindingPayload(BaseModel):
|
||||
tag_ids: list[str] = Field(description="Tag IDs to bind")
|
||||
target_id: str = Field(description="Target ID to bind tags to")
|
||||
type: TagType = Field(description="Tag type")
|
||||
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
|
||||
|
||||
|
||||
class TagBindingRemovePayload(BaseModel):
|
||||
tag_id: str = Field(description="Tag ID to remove")
|
||||
target_id: str = Field(description="Target ID to unbind tag from")
|
||||
type: TagType = Field(description="Tag type")
|
||||
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
|
||||
|
||||
|
||||
class TagListQueryParam(BaseModel):
|
||||
@@ -89,7 +82,7 @@ class TagListApi(Resource):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagBasePayload.model_validate(console_ns.payload or {})
|
||||
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=payload.type))
|
||||
tag = TagService.save_tags(payload.model_dump())
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
|
||||
@@ -110,7 +103,7 @@ class TagUpdateDeleteApi(Resource):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagBasePayload.model_validate(console_ns.payload or {})
|
||||
tag = TagService.update_tags(UpdateTagPayload(name=payload.name, type=payload.type), tag_id)
|
||||
tag = TagService.update_tags(payload.model_dump(), tag_id)
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
|
||||
@@ -143,9 +136,7 @@ class TagBindingCreateApi(Resource):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagBindingPayload.model_validate(console_ns.payload or {})
|
||||
TagService.save_tag_binding(
|
||||
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=payload.type)
|
||||
)
|
||||
TagService.save_tag_binding(payload.model_dump())
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@@ -163,8 +154,6 @@ class TagBindingDeleteApi(Resource):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
|
||||
TagService.delete_tag_binding(
|
||||
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=payload.type)
|
||||
)
|
||||
TagService.delete_tag_binding(payload.model_dump())
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@@ -1,33 +1,36 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.account import TenantPluginPermission
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def plugin_permission_required(
|
||||
install_required: bool = False,
|
||||
debug_required: bool = False,
|
||||
):
|
||||
def interceptor[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
user = current_user
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
permission = session.scalar(
|
||||
select(TenantPluginPermission)
|
||||
with Session(db.engine) as session:
|
||||
permission = (
|
||||
session.query(TenantPluginPermission)
|
||||
.where(
|
||||
TenantPluginPermission.tenant_id == tenant_id,
|
||||
)
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not permission:
|
||||
|
||||
@@ -6,10 +6,9 @@ from typing import Literal
|
||||
import pytz
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from graphon.file import helpers as file_helpers
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
@@ -76,10 +75,6 @@ class AccountAvatarPayload(BaseModel):
|
||||
avatar: str
|
||||
|
||||
|
||||
class AccountAvatarQuery(BaseModel):
|
||||
avatar: str = Field(..., description="Avatar file ID")
|
||||
|
||||
|
||||
class AccountInterfaceLanguagePayload(BaseModel):
|
||||
interface_language: str
|
||||
|
||||
@@ -165,7 +160,6 @@ def reg(cls: type[BaseModel]):
|
||||
reg(AccountInitPayload)
|
||||
reg(AccountNamePayload)
|
||||
reg(AccountAvatarPayload)
|
||||
reg(AccountAvatarQuery)
|
||||
reg(AccountInterfaceLanguagePayload)
|
||||
reg(AccountInterfaceThemePayload)
|
||||
reg(AccountTimezonePayload)
|
||||
@@ -275,18 +269,6 @@ class AccountNameApi(Resource):
|
||||
|
||||
@console_ns.route("/account/avatar")
|
||||
class AccountAvatarApi(Resource):
|
||||
@console_ns.expect(console_ns.models[AccountAvatarQuery.__name__])
|
||||
@console_ns.doc("get_account_avatar")
|
||||
@console_ns.doc(description="Get account avatar url")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
args = AccountAvatarQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
avatar_url = file_helpers.get_signed_file_url(args.avatar)
|
||||
return {"avatar_url": avatar_url}
|
||||
|
||||
@console_ns.expect(console_ns.models[AccountAvatarPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -537,7 +519,7 @@ class EducationAutoCompleteApi(Resource):
|
||||
@cloud_edition_billing_enabled
|
||||
@marshal_with(data_fields)
|
||||
def get(self):
|
||||
payload = request.args.to_dict(flat=True)
|
||||
payload = request.args.to_dict(flat=True) # type: ignore
|
||||
args = EducationAutocompleteQuery.model_validate(payload)
|
||||
|
||||
return BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit)
|
||||
@@ -580,7 +562,7 @@ class ChangeEmailSendEmailApi(Resource):
|
||||
|
||||
user_email = current_user.email
|
||||
else:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
|
||||
if account is None:
|
||||
raise AccountNotFound()
|
||||
|
||||
@@ -99,7 +99,7 @@ class ModelProviderListApi(Resource):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
payload = request.args.to_dict(flat=True)
|
||||
payload = request.args.to_dict(flat=True) # type: ignore
|
||||
args = ParserModelList.model_validate(payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@@ -118,7 +118,7 @@ class ModelProviderCredentialApi(Resource):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
tenant_id = current_tenant_id
|
||||
# if credential_id is not provided, return current used credential
|
||||
payload = request.args.to_dict(flat=True)
|
||||
payload = request.args.to_dict(flat=True) # type: ignore
|
||||
args = ParserCredentialId.model_validate(payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
@@ -287,10 +287,12 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
provider=provider,
|
||||
)
|
||||
else:
|
||||
# Normalize model_type to the origin value stored in DB (e.g., "text-generation" for LLM)
|
||||
normalized_model_type = args.model_type.to_origin_model_type()
|
||||
available_credentials = model_provider_service.get_provider_model_available_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model_type=args.model_type,
|
||||
model_type=normalized_model_type,
|
||||
model=args.model,
|
||||
)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from flask import make_response, redirect, request, send_file
|
||||
from flask_restx import Resource
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel, Field, HttpUrl, field_validator, model_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
@@ -832,8 +832,7 @@ class ToolOAuthCallback(Resource):
|
||||
tool_provider = ToolProviderID(provider)
|
||||
plugin_id = tool_provider.plugin_id
|
||||
provider_name = tool_provider.provider_name
|
||||
user_id: str = context["user_id"]
|
||||
tenant_id: str = context["tenant_id"]
|
||||
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
|
||||
|
||||
oauth_handler = OAuthHandler()
|
||||
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id, provider)
|
||||
@@ -1019,7 +1018,7 @@ class ToolProviderMCPApi(Resource):
|
||||
|
||||
# Step 1: Get provider data for URL validation (short-lived session, no network I/O)
|
||||
validation_data = None
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
validation_data = service.get_provider_for_url_validation(
|
||||
tenant_id=current_tenant_id, provider_id=payload.provider_id
|
||||
@@ -1034,7 +1033,7 @@ class ToolProviderMCPApi(Resource):
|
||||
)
|
||||
|
||||
# Step 3: Perform database update in a transaction
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
service.update_provider(
|
||||
tenant_id=current_tenant_id,
|
||||
@@ -1061,7 +1060,7 @@ class ToolProviderMCPApi(Resource):
|
||||
payload = MCPProviderDeletePayload.model_validate(console_ns.payload or {})
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
service.delete_provider(tenant_id=current_tenant_id, provider_id=payload.provider_id)
|
||||
|
||||
@@ -1079,7 +1078,7 @@ class ToolMCPAuthApi(Resource):
|
||||
provider_id = payload.provider_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
if not db_provider:
|
||||
@@ -1100,7 +1099,7 @@ class ToolMCPAuthApi(Resource):
|
||||
sse_read_timeout=provider_entity.sse_read_timeout,
|
||||
):
|
||||
# Update credentials in new transaction
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
service.update_provider_credentials(
|
||||
provider_id=provider_id,
|
||||
@@ -1118,17 +1117,17 @@ class ToolMCPAuthApi(Resource):
|
||||
resource_metadata_url=e.resource_metadata_url,
|
||||
scope_hint=e.scope_hint,
|
||||
)
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
response = service.execute_auth_actions(auth_result)
|
||||
return response
|
||||
except MCPRefreshTokenError as e:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
|
||||
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
|
||||
except (MCPError, ValueError) as e:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
|
||||
raise ValueError(f"Failed to connect to MCP server: {e}") from e
|
||||
@@ -1141,7 +1140,7 @@ class ToolMCPDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, provider_id):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
|
||||
@@ -1155,7 +1154,7 @@ class ToolMCPListAllApi(Resource):
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
# Skip sensitive data decryption for list view to improve performance
|
||||
tools = service.list_providers(tenant_id=tenant_id, include_sensitive=False)
|
||||
@@ -1170,7 +1169,7 @@ class ToolMCPUpdateApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, provider_id):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
tools = service.list_provider_tools(
|
||||
tenant_id=tenant_id,
|
||||
@@ -1188,7 +1187,7 @@ class ToolMCPCallbackApi(Resource):
|
||||
authorization_code = query.code
|
||||
|
||||
# Create service instance for handle_callback
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session, session.begin():
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
# handle_callback now returns state data and tokens
|
||||
state_data, tokens = handle_callback(state_key, authorization_code)
|
||||
|
||||
@@ -5,7 +5,7 @@ from flask import make_response, redirect, request
|
||||
from flask_restx import Resource
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel, model_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
@@ -375,7 +375,7 @@ class TriggerSubscriptionDeleteApi(Resource):
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
try:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
# Delete trigger provider subscription
|
||||
TriggerProviderService.delete_trigger_provider(
|
||||
session=session,
|
||||
@@ -388,6 +388,7 @@ class TriggerSubscriptionDeleteApi(Resource):
|
||||
tenant_id=user.current_tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
)
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
@@ -498,9 +499,9 @@ class TriggerOAuthCallbackApi(Resource):
|
||||
provider_id = TriggerProviderID(provider)
|
||||
plugin_id = provider_id.plugin_id
|
||||
provider_name = provider_id.provider_name
|
||||
user_id: str = context["user_id"]
|
||||
tenant_id: str = context["tenant_id"]
|
||||
subscription_builder_id: str = context["subscription_builder_id"]
|
||||
user_id = context.get("user_id")
|
||||
tenant_id = context.get("tenant_id")
|
||||
subscription_builder_id = context.get("subscription_builder_id")
|
||||
|
||||
# Get OAuth client configuration
|
||||
oauth_client_params = TriggerProviderService.get_oauth_client(
|
||||
|
||||
@@ -28,7 +28,7 @@ from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.account import Tenant, TenantCustomConfigDict, TenantStatus
|
||||
from models.account import Tenant, TenantStatus
|
||||
from services.account_service import TenantService
|
||||
from services.billing_service import BillingService, SubscriptionPlan
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
@@ -155,7 +155,7 @@ class WorkspaceListApi(Resource):
|
||||
@setup_required
|
||||
@admin_required
|
||||
def get(self):
|
||||
payload = request.args.to_dict(flat=True)
|
||||
payload = request.args.to_dict(flat=True) # type: ignore
|
||||
args = WorkspaceListQuery.model_validate(payload)
|
||||
|
||||
stmt = select(Tenant).order_by(Tenant.created_at.desc())
|
||||
@@ -240,10 +240,8 @@ class CustomConfigWorkspaceApi(Resource):
|
||||
args = WorkspaceCustomConfigPayload.model_validate(payload)
|
||||
tenant = db.get_or_404(Tenant, current_tenant_id)
|
||||
|
||||
custom_config_dict: TenantCustomConfigDict = {
|
||||
"remove_webapp_brand": args.remove_webapp_brand
|
||||
if args.remove_webapp_brand is not None
|
||||
else tenant.custom_config_dict.get("remove_webapp_brand", False),
|
||||
custom_config_dict = {
|
||||
"remove_webapp_brand": args.remove_webapp_brand,
|
||||
"replace_webapp_logo": args.replace_webapp_logo
|
||||
if args.replace_webapp_logo is not None
|
||||
else tenant.custom_config_dict.get("replace_webapp_logo"),
|
||||
|
||||
@@ -4,6 +4,7 @@ import os
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import abort, request
|
||||
from sqlalchemy import select
|
||||
@@ -24,6 +25,9 @@ from services.operation_service import OperationService
|
||||
|
||||
from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
# Field names for decryption
|
||||
FIELD_NAME_PASSWORD = "password"
|
||||
FIELD_NAME_CODE = "code"
|
||||
@@ -33,7 +37,7 @@ ERROR_MSG_INVALID_ENCRYPTED_DATA = "Invalid encrypted data"
|
||||
ERROR_MSG_INVALID_ENCRYPTED_CODE = "Invalid encrypted code"
|
||||
|
||||
|
||||
def account_initialization_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
def account_initialization_required(view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
# check account initialization
|
||||
@@ -46,7 +50,7 @@ def account_initialization_required[**P, R](view: Callable[P, R]) -> Callable[P,
|
||||
return decorated
|
||||
|
||||
|
||||
def only_edition_cloud[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
def only_edition_cloud(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if dify_config.EDITION != "CLOUD":
|
||||
@@ -57,7 +61,7 @@ def only_edition_cloud[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
return decorated
|
||||
|
||||
|
||||
def only_edition_enterprise[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
def only_edition_enterprise(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if not dify_config.ENTERPRISE_ENABLED:
|
||||
@@ -68,7 +72,7 @@ def only_edition_enterprise[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
return decorated
|
||||
|
||||
|
||||
def only_edition_self_hosted[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
def only_edition_self_hosted(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if dify_config.EDITION != "SELF_HOSTED":
|
||||
@@ -79,7 +83,7 @@ def only_edition_self_hosted[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
return decorated
|
||||
|
||||
|
||||
def cloud_edition_billing_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
def cloud_edition_billing_enabled(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@@ -91,7 +95,7 @@ def cloud_edition_billing_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R
|
||||
return decorated
|
||||
|
||||
|
||||
def cloud_edition_billing_resource_check[**P, R](resource: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def cloud_edition_billing_resource_check(resource: str):
|
||||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
@@ -133,9 +137,7 @@ def cloud_edition_billing_resource_check[**P, R](resource: str) -> Callable[[Cal
|
||||
return interceptor
|
||||
|
||||
|
||||
def cloud_edition_billing_knowledge_limit_check[**P, R](
|
||||
resource: str,
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def cloud_edition_billing_knowledge_limit_check(resource: str):
|
||||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
@@ -158,7 +160,7 @@ def cloud_edition_billing_knowledge_limit_check[**P, R](
|
||||
return interceptor
|
||||
|
||||
|
||||
def cloud_edition_billing_rate_limit_check[**P, R](resource: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def cloud_edition_billing_rate_limit_check(resource: str):
|
||||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
@@ -194,7 +196,7 @@ def cloud_edition_billing_rate_limit_check[**P, R](resource: str) -> Callable[[C
|
||||
return interceptor
|
||||
|
||||
|
||||
def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
def cloud_utm_record(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
with contextlib.suppress(Exception):
|
||||
@@ -213,7 +215,7 @@ def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
return decorated
|
||||
|
||||
|
||||
def setup_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
def setup_required(view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
# check setup
|
||||
@@ -227,7 +229,7 @@ def setup_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
return decorated
|
||||
|
||||
|
||||
def enterprise_license_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
def enterprise_license_required(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
settings = FeatureService.get_system_features()
|
||||
@@ -239,7 +241,7 @@ def enterprise_license_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
return decorated
|
||||
|
||||
|
||||
def email_password_login_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
def email_password_login_enabled(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
@@ -252,7 +254,7 @@ def email_password_login_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]
|
||||
return decorated
|
||||
|
||||
|
||||
def email_register_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
def email_register_enabled(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
@@ -265,7 +267,7 @@ def email_register_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
return decorated
|
||||
|
||||
|
||||
def enable_change_email[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
def enable_change_email(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
@@ -278,7 +280,7 @@ def enable_change_email[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
return decorated
|
||||
|
||||
|
||||
def is_allow_transfer_owner[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
def is_allow_transfer_owner(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
from libs.workspace_permission import check_workspace_owner_transfer_permission
|
||||
@@ -291,7 +293,7 @@ def is_allow_transfer_owner[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
return decorated
|
||||
|
||||
|
||||
def knowledge_pipeline_publish_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
def knowledge_pipeline_publish_enabled(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@@ -303,7 +305,7 @@ def knowledge_pipeline_publish_enabled[**P, R](view: Callable[P, R]) -> Callable
|
||||
return decorated
|
||||
|
||||
|
||||
def edit_permission_required[**P, R](f: Callable[P, R]) -> Callable[P, R]:
|
||||
def edit_permission_required(f: Callable[P, R]):
|
||||
@wraps(f)
|
||||
def decorated_function(*args: P.args, **kwargs: P.kwargs):
|
||||
from werkzeug.exceptions import Forbidden
|
||||
@@ -321,7 +323,7 @@ def edit_permission_required[**P, R](f: Callable[P, R]) -> Callable[P, R]:
|
||||
return decorated_function
|
||||
|
||||
|
||||
def is_admin_or_owner_required[**P, R](f: Callable[P, R]) -> Callable[P, R]:
|
||||
def is_admin_or_owner_required(f: Callable[P, R]):
|
||||
@wraps(f)
|
||||
def decorated_function(*args: P.args, **kwargs: P.kwargs):
|
||||
from werkzeug.exceptions import Forbidden
|
||||
@@ -337,7 +339,7 @@ def is_admin_or_owner_required[**P, R](f: Callable[P, R]) -> Callable[P, R]:
|
||||
return decorated_function
|
||||
|
||||
|
||||
def annotation_import_rate_limit[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
def annotation_import_rate_limit(view: Callable[P, R]):
|
||||
"""
|
||||
Rate limiting decorator for annotation import operations.
|
||||
|
||||
@@ -386,7 +388,7 @@ def annotation_import_rate_limit[**P, R](view: Callable[P, R]) -> Callable[P, R]
|
||||
return decorated
|
||||
|
||||
|
||||
def annotation_import_concurrency_limit[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
def annotation_import_concurrency_limit(view: Callable[P, R]):
|
||||
"""
|
||||
Concurrency control decorator for annotation import operations.
|
||||
|
||||
@@ -453,7 +455,7 @@ def _decrypt_field(field_name: str, error_class: type[Exception], error_message:
|
||||
payload[field_name] = decoded_value
|
||||
|
||||
|
||||
def decrypt_password_field[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
def decrypt_password_field(view: Callable[P, R]):
|
||||
"""
|
||||
Decorator to decrypt password field in request payload.
|
||||
|
||||
@@ -475,7 +477,7 @@ def decrypt_password_field[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
return decorated
|
||||
|
||||
|
||||
def decrypt_code_field[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
def decrypt_code_field(view: Callable[P, R]):
|
||||
"""
|
||||
Decorator to decrypt verification code field in request payload.
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console.wraps import setup_required
|
||||
@@ -18,8 +18,7 @@ from controllers.inner_api.wraps import enterprise_inner_api_only
|
||||
from extensions.ext_database import db
|
||||
from models import Account, App
|
||||
from models.account import AccountStatus
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.entities.dsl_entities import ImportMode, ImportStatus
|
||||
from services.app_dsl_service import AppDslService, ImportMode, ImportStatus
|
||||
|
||||
|
||||
class InnerAppDSLImportPayload(BaseModel):
|
||||
@@ -56,7 +55,7 @@ class EnterpriseAppDSLImport(Resource):
|
||||
|
||||
account.set_tenant_id(workspace_id)
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
dsl_service = AppDslService(session)
|
||||
result = dsl_service.import_app(
|
||||
account=account,
|
||||
@@ -65,6 +64,7 @@ class EnterpriseAppDSLImport(Resource):
|
||||
name=args.name,
|
||||
description=args.description,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
if result.status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
|
||||
@@ -1,17 +1,21 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_user
|
||||
from models.account import Tenant
|
||||
from models.model import DefaultEndUserSessionID, EndUser
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class TenantUserPayload(BaseModel):
|
||||
tenant_id: str
|
||||
@@ -29,7 +33,7 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
try:
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
user_model = None
|
||||
|
||||
if is_anonymous:
|
||||
@@ -52,7 +56,7 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
||||
session_id=user_id,
|
||||
)
|
||||
session.add(user_model)
|
||||
session.flush()
|
||||
session.commit()
|
||||
session.refresh(user_model)
|
||||
|
||||
except Exception:
|
||||
@@ -61,9 +65,9 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
||||
return user_model
|
||||
|
||||
|
||||
def get_user_tenant[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
def get_user_tenant(view_func: Callable[P, R]):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
payload = TenantUserPayload.model_validate(request.get_json(silent=True) or {})
|
||||
|
||||
user_id = payload.user_id
|
||||
@@ -93,14 +97,10 @@ def get_user_tenant[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
return decorated_view
|
||||
|
||||
|
||||
def plugin_data[**P, R](
|
||||
view: Callable[P, R] | None = None,
|
||||
*,
|
||||
payload_type: type[BaseModel],
|
||||
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]):
|
||||
def decorator(view_func: Callable[P, R]):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
try:
|
||||
data = request.get_json()
|
||||
except Exception:
|
||||
|
||||
@@ -3,7 +3,10 @@ from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from hashlib import sha1
|
||||
from hmac import new as hmac_new
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
from flask import abort, request
|
||||
|
||||
from configs import dify_config
|
||||
@@ -11,9 +14,9 @@ from extensions.ext_database import db
|
||||
from models.model import EndUser
|
||||
|
||||
|
||||
def billing_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
def billing_inner_api_only(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if not dify_config.INNER_API:
|
||||
abort(404)
|
||||
|
||||
@@ -27,9 +30,9 @@ def billing_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
return decorated
|
||||
|
||||
|
||||
def enterprise_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
def enterprise_inner_api_only(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if not dify_config.INNER_API:
|
||||
abort(404)
|
||||
|
||||
@@ -43,9 +46,9 @@ def enterprise_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
return decorated
|
||||
|
||||
|
||||
def enterprise_inner_api_user_auth[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
def enterprise_inner_api_user_auth(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if not dify_config.INNER_API:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
@@ -79,9 +82,9 @@ def enterprise_inner_api_user_auth[**P, R](view: Callable[P, R]) -> Callable[P,
|
||||
return decorated
|
||||
|
||||
|
||||
def plugin_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
def plugin_inner_api_only(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if not dify_config.PLUGIN_DAEMON_KEY:
|
||||
abort(404)
|
||||
|
||||
|
||||
@@ -4,8 +4,7 @@ from flask import Response
|
||||
from flask_restx import Resource
|
||||
from graphon.variables.input_entities import VariableEntity
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.mcp import mcp_ns
|
||||
@@ -68,7 +67,7 @@ class MCPAppApi(Resource):
|
||||
request_id: Union[int, str] | None = args.id
|
||||
mcp_request = self._parse_mcp_request(args.model_dump(exclude_none=True))
|
||||
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Get MCP server and app
|
||||
mcp_server, app = self._get_mcp_server_and_app(server_code, session)
|
||||
self._validate_server_status(mcp_server)
|
||||
@@ -81,11 +80,11 @@ class MCPAppApi(Resource):
|
||||
|
||||
def _get_mcp_server_and_app(self, server_code: str, session: Session) -> tuple[AppMCPServer, App]:
|
||||
"""Get and validate MCP server and app in one query session"""
|
||||
mcp_server = session.scalar(select(AppMCPServer).where(AppMCPServer.server_code == server_code).limit(1))
|
||||
mcp_server = session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
|
||||
if not mcp_server:
|
||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server Not Found")
|
||||
|
||||
app = session.scalar(select(App).where(App.id == mcp_server.app_id).limit(1))
|
||||
app = session.query(App).where(App.id == mcp_server.app_id).first()
|
||||
if not app:
|
||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App Not Found")
|
||||
|
||||
@@ -175,7 +174,6 @@ class MCPAppApi(Resource):
|
||||
required=variable.get("required", False),
|
||||
max_length=variable.get("max_length"),
|
||||
options=variable.get("options") or [],
|
||||
json_schema=variable.get("json_schema"),
|
||||
)
|
||||
|
||||
def _parse_mcp_request(self, args: dict) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
|
||||
@@ -190,13 +188,13 @@ class MCPAppApi(Resource):
|
||||
|
||||
def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str) -> EndUser | None:
|
||||
"""Get end user - manages its own database session"""
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
return session.scalar(
|
||||
select(EndUser)
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
return (
|
||||
session.query(EndUser)
|
||||
.where(EndUser.tenant_id == tenant_id)
|
||||
.where(EndUser.session_id == mcp_server_id)
|
||||
.where(EndUser.type == "mcp")
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
|
||||
def _create_end_user(
|
||||
@@ -230,7 +228,9 @@ class MCPAppApi(Resource):
|
||||
if not end_user and isinstance(mcp_request.root, mcp_types.InitializeRequest):
|
||||
client_info = mcp_request.root.params.clientInfo
|
||||
client_name = f"{client_info.name}@{client_info.version}"
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as create_session:
|
||||
# Commit the session before creating end user to avoid transaction conflicts
|
||||
session.commit()
|
||||
with Session(db.engine, expire_on_commit=False) as create_session, create_session.begin():
|
||||
end_user = self._create_end_user(client_name, app.tenant_id, app.id, mcp_server.id, create_session)
|
||||
|
||||
return handle_mcp_request(app, mcp_request, user_input_form, mcp_server, end_user, request_id)
|
||||
|
||||
@@ -2,12 +2,11 @@ from typing import Any, Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import ConversationRenamePayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import NotChatAppError
|
||||
@@ -35,6 +34,18 @@ class ConversationListQuery(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class ConversationRenamePayload(BaseModel):
|
||||
name: str | None = Field(default=None, description="New conversation name (required if auto_generate is false)")
|
||||
auto_generate: bool = Field(default=False, description="Auto-generate conversation name")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_name_requirement(self):
|
||||
if not self.auto_generate:
|
||||
if self.name is None or not self.name.strip():
|
||||
raise ValueError("name is required when auto_generate is false")
|
||||
return self
|
||||
|
||||
|
||||
class ConversationVariablesQuery(BaseModel):
|
||||
last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last variable ID for pagination")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
|
||||
@@ -105,7 +116,7 @@ class ConversationApi(Resource):
|
||||
last_id = str(query_args.last_id) if query_args.last_id else None
|
||||
|
||||
try:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
pagination = ConversationService.pagination_by_last_id(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@@ -6,7 +7,6 @@ from pydantic import BaseModel, Field, TypeAdapter
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import NotChatAppError
|
||||
@@ -14,6 +14,7 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from models.enums import FeedbackRating
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.errors.message import (
|
||||
@@ -26,6 +27,17 @@ from services.message_service import MessageService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageListQuery(BaseModel):
|
||||
conversation_id: UUIDStrOrEmpty
|
||||
first_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return")
|
||||
|
||||
|
||||
class MessageFeedbackPayload(BaseModel):
|
||||
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
|
||||
content: str | None = Field(default=None, description="Feedback content")
|
||||
|
||||
|
||||
class FeedbackListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, description="Page number")
|
||||
limit: int = Field(default=20, ge=1, le=101, description="Number of feedbacks per page")
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from flask import request
|
||||
@@ -8,10 +8,9 @@ from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
from controllers.common.controller_schemas import WorkflowRunPayload as WorkflowRunPayloadBase
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import (
|
||||
@@ -47,7 +46,9 @@ from services.workflow_app_service import WorkflowAppService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowRunPayload(WorkflowRunPayloadBase):
|
||||
class WorkflowRunPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
files: list[dict[str, Any]] | None = None
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
|
||||
|
||||
@@ -313,7 +314,7 @@ class WorkflowAppLogApi(Resource):
|
||||
|
||||
# get paginate workflow app logs
|
||||
workflow_app_service = WorkflowAppService()
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
|
||||
@@ -22,17 +22,10 @@ from fields.tag_fields import DataSetTag
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from models.enums import TagType
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
|
||||
from services.tag_service import (
|
||||
SaveTagPayload,
|
||||
TagBindingCreatePayload,
|
||||
TagBindingDeletePayload,
|
||||
TagService,
|
||||
UpdateTagPayload,
|
||||
)
|
||||
from services.tag_service import TagService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
@@ -520,7 +513,7 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=TagType.KNOWLEDGE))
|
||||
tag = TagService.save_tags({"name": payload.name, "type": "knowledge"})
|
||||
|
||||
response = DataSetTag.model_validate(
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
@@ -543,8 +536,9 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||
params = {"name": payload.name, "type": "knowledge"}
|
||||
tag_id = payload.tag_id
|
||||
tag = TagService.update_tags(UpdateTagPayload(name=payload.name, type=TagType.KNOWLEDGE), tag_id)
|
||||
tag = TagService.update_tags(params, tag_id)
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
|
||||
@@ -591,9 +585,7 @@ class DatasetTagBindingApi(DatasetApiResource):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagBindingPayload.model_validate(service_api_ns.payload or {})
|
||||
TagService.save_tag_binding(
|
||||
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE)
|
||||
)
|
||||
TagService.save_tag_binding({"tag_ids": payload.tag_ids, "target_id": payload.target_id, "type": "knowledge"})
|
||||
|
||||
return "", 204
|
||||
|
||||
@@ -617,9 +609,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
|
||||
TagService.delete_tag_binding(
|
||||
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=TagType.KNOWLEDGE)
|
||||
)
|
||||
TagService.delete_tag_binding({"tag_id": payload.tag_id, "target_id": payload.target_id, "type": "knowledge"})
|
||||
|
||||
return "", 204
|
||||
|
||||
|
||||
@@ -31,7 +31,6 @@ from controllers.service_api.wraps import (
|
||||
cloud_edition_billing_resource_check,
|
||||
)
|
||||
from core.errors.error import ProviderTokenNotInitError
|
||||
from core.rag.entities import PreProcessingRule, Rule, Segmentation
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from fields.document_fields import document_fields, document_status_fields
|
||||
@@ -41,8 +40,11 @@ from models.enums import SegmentStatus
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
KnowledgeConfig,
|
||||
PreProcessingRule,
|
||||
ProcessRule,
|
||||
RetrievalModel,
|
||||
Rule,
|
||||
Segmentation,
|
||||
)
|
||||
from services.file_service import FileService
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
@@ -4,23 +4,13 @@ Serialization helpers for Service API knowledge pipeline endpoints.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, TypedDict
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models.model import UploadFile
|
||||
|
||||
|
||||
class UploadFileDict(TypedDict):
|
||||
id: str
|
||||
name: str
|
||||
size: int
|
||||
extension: str
|
||||
mime_type: str | None
|
||||
created_by: str
|
||||
created_at: str | None
|
||||
|
||||
|
||||
def serialize_upload_file(upload_file: UploadFile) -> UploadFileDict:
|
||||
def serialize_upload_file(upload_file: UploadFile) -> dict[str, Any]:
|
||||
return {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
|
||||
@@ -29,31 +29,6 @@ from services.entities.knowledge_entities.knowledge_entities import SegmentUpdat
|
||||
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
|
||||
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
|
||||
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
|
||||
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict:
|
||||
"""Marshal a single segment and enrich it with summary content."""
|
||||
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
|
||||
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
|
||||
segment_dict["summary"] = summary.summary_content if summary else None
|
||||
return segment_dict
|
||||
|
||||
|
||||
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict]:
|
||||
"""Marshal multiple segments and enrich them with summary content (batch query)."""
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
summaries: dict = {}
|
||||
if segment_ids:
|
||||
summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
|
||||
summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()}
|
||||
|
||||
result = []
|
||||
for segment in segments:
|
||||
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
|
||||
segment_dict["summary"] = summaries.get(segment.id)
|
||||
result.append(segment_dict)
|
||||
return result
|
||||
|
||||
|
||||
class SegmentCreatePayload(BaseModel):
|
||||
@@ -157,7 +132,7 @@ class SegmentApi(DatasetApiResource):
|
||||
for args_item in payload.segments:
|
||||
SegmentService.segment_create_args_validate(args_item, document)
|
||||
segments = SegmentService.multi_create_segment(payload.segments, document, dataset)
|
||||
return {"data": _marshal_segments_with_summary(segments, dataset_id), "doc_form": document.doc_form}, 200
|
||||
return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200
|
||||
else:
|
||||
return {"error": "Segments is required"}, 400
|
||||
|
||||
@@ -221,7 +196,7 @@ class SegmentApi(DatasetApiResource):
|
||||
)
|
||||
|
||||
response = {
|
||||
"data": _marshal_segments_with_summary(segments, dataset_id),
|
||||
"data": marshal(segments, segment_fields),
|
||||
"doc_form": document.doc_form,
|
||||
"total": total,
|
||||
"has_more": len(segments) == limit,
|
||||
@@ -321,7 +296,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset)
|
||||
return {"data": _marshal_segment_with_summary(updated_segment, dataset_id), "doc_form": document.doc_form}, 200
|
||||
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
|
||||
@service_api_ns.doc("get_segment")
|
||||
@service_api_ns.doc(description="Get a specific segment by ID")
|
||||
@@ -351,7 +326,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
return {"data": _marshal_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
|
||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
|
||||
|
||||
@service_api_ns.route(
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import inspect
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from enum import StrEnum, auto
|
||||
from functools import wraps
|
||||
from typing import cast, overload
|
||||
from typing import Concatenate, ParamSpec, TypeVar, cast, overload
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
@@ -24,6 +23,10 @@ from services.api_token_service import ApiTokenCache, fetch_token_with_single_fl
|
||||
from services.end_user_service import EndUserService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
T = TypeVar("T")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -43,16 +46,16 @@ class FetchUserArg(BaseModel):
|
||||
|
||||
|
||||
@overload
|
||||
def validate_app_token[**P, R](view: Callable[P, R]) -> Callable[P, R]: ...
|
||||
def validate_app_token(view: Callable[P, R]) -> Callable[P, R]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def validate_app_token[**P, R](
|
||||
def validate_app_token(
|
||||
view: None = None, *, fetch_user_arg: FetchUserArg | None = None
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
|
||||
|
||||
|
||||
def validate_app_token[**P, R](
|
||||
def validate_app_token(
|
||||
view: Callable[P, R] | None = None, *, fetch_user_arg: FetchUserArg | None = None
|
||||
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
@@ -133,10 +136,7 @@ def validate_app_token[**P, R](
|
||||
return decorator(view)
|
||||
|
||||
|
||||
def cloud_edition_billing_resource_check[**P, R](
|
||||
resource: str,
|
||||
api_token_type: str,
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
|
||||
def interceptor(view: Callable[P, R]):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
api_token = validate_and_get_api_token(api_token_type)
|
||||
@@ -166,10 +166,7 @@ def cloud_edition_billing_resource_check[**P, R](
|
||||
return interceptor
|
||||
|
||||
|
||||
def cloud_edition_billing_knowledge_limit_check[**P, R](
|
||||
resource: str,
|
||||
api_token_type: str,
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str):
|
||||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
@@ -191,10 +188,7 @@ def cloud_edition_billing_knowledge_limit_check[**P, R](
|
||||
return interceptor
|
||||
|
||||
|
||||
def cloud_edition_billing_rate_limit_check[**P, R](
|
||||
resource: str,
|
||||
api_token_type: str,
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
|
||||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
@@ -231,73 +225,99 @@ def cloud_edition_billing_rate_limit_check[**P, R](
|
||||
return interceptor
|
||||
|
||||
|
||||
def validate_dataset_token[R](view: Callable[..., R]) -> Callable[..., R]:
|
||||
positional_parameters = [
|
||||
parameter
|
||||
for parameter in inspect.signature(view).parameters.values()
|
||||
if parameter.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
|
||||
]
|
||||
expects_bound_instance = bool(positional_parameters and positional_parameters[0].name in {"self", "cls"})
|
||||
@overload
|
||||
def validate_dataset_token(view: Callable[Concatenate[T, P], R]) -> Callable[P, R]: ...
|
||||
|
||||
@wraps(view)
|
||||
def decorated(*args: object, **kwargs: object) -> R:
|
||||
api_token = validate_and_get_api_token("dataset")
|
||||
|
||||
# Flask may pass URL path parameters positionally, so inspect both kwargs and args.
|
||||
dataset_id = kwargs.get("dataset_id")
|
||||
@overload
|
||||
def validate_dataset_token(view: None = None) -> Callable[[Callable[Concatenate[T, P], R]], Callable[P, R]]: ...
|
||||
|
||||
if not dataset_id and args:
|
||||
potential_id = args[0]
|
||||
try:
|
||||
str_id = str(potential_id)
|
||||
if len(str_id) == 36 and str_id.count("-") == 4:
|
||||
dataset_id = str_id
|
||||
except Exception:
|
||||
logger.exception("Failed to parse dataset_id from positional args")
|
||||
|
||||
if dataset_id:
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset)
|
||||
.where(
|
||||
Dataset.id == dataset_id,
|
||||
Dataset.tenant_id == api_token.tenant_id,
|
||||
def validate_dataset_token(
|
||||
view: Callable[Concatenate[T, P], R] | None = None,
|
||||
) -> Callable[P, R] | Callable[[Callable[Concatenate[T, P], R]], Callable[P, R]]:
|
||||
def decorator(view_func: Callable[Concatenate[T, P], R]) -> Callable[P, R]:
|
||||
@wraps(view_func)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
api_token = validate_and_get_api_token("dataset")
|
||||
|
||||
# get url path dataset_id from positional args or kwargs
|
||||
# Flask passes URL path parameters as positional arguments
|
||||
dataset_id = None
|
||||
|
||||
# First try to get from kwargs (explicit parameter)
|
||||
dataset_id = kwargs.get("dataset_id")
|
||||
|
||||
# If not in kwargs, try to extract from positional args
|
||||
if not dataset_id and args:
|
||||
# For class methods: args[0] is self, args[1] is dataset_id (if exists)
|
||||
# Check if first arg is likely a class instance (has __dict__ or __class__)
|
||||
if len(args) > 1 and hasattr(args[0], "__dict__"):
|
||||
# This is a class method, dataset_id should be in args[1]
|
||||
potential_id = args[1]
|
||||
# Validate it's a string-like UUID, not another object
|
||||
try:
|
||||
# Try to convert to string and check if it's a valid UUID format
|
||||
str_id = str(potential_id)
|
||||
# Basic check: UUIDs are 36 chars with hyphens
|
||||
if len(str_id) == 36 and str_id.count("-") == 4:
|
||||
dataset_id = str_id
|
||||
except Exception:
|
||||
logger.exception("Failed to parse dataset_id from class method args")
|
||||
elif len(args) > 0:
|
||||
# Not a class method, check if args[0] looks like a UUID
|
||||
potential_id = args[0]
|
||||
try:
|
||||
str_id = str(potential_id)
|
||||
if len(str_id) == 36 and str_id.count("-") == 4:
|
||||
dataset_id = str_id
|
||||
except Exception:
|
||||
logger.exception("Failed to parse dataset_id from positional args")
|
||||
|
||||
# Validate dataset if dataset_id is provided
|
||||
if dataset_id:
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset)
|
||||
.where(
|
||||
Dataset.id == dataset_id,
|
||||
Dataset.tenant_id == api_token.tenant_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
if not dataset.enable_api:
|
||||
raise Forbidden("Dataset api access is not enabled.")
|
||||
|
||||
tenant_account_join = db.session.execute(
|
||||
select(Tenant, TenantAccountJoin)
|
||||
.where(Tenant.id == api_token.tenant_id)
|
||||
.where(TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.where(TenantAccountJoin.role.in_(["owner"]))
|
||||
.where(Tenant.status == TenantStatus.NORMAL)
|
||||
).one_or_none() # TODO: only owner information is required, so only one is returned.
|
||||
if tenant_account_join:
|
||||
tenant, ta = tenant_account_join
|
||||
account = db.session.get(Account, ta.account_id)
|
||||
# Login admin
|
||||
if account:
|
||||
account.current_tenant = tenant
|
||||
current_app.login_manager._update_request_context_with_user(account) # type: ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
if not dataset.enable_api:
|
||||
raise Forbidden("Dataset api access is not enabled.")
|
||||
tenant_account_join = db.session.execute(
|
||||
select(Tenant, TenantAccountJoin)
|
||||
.where(Tenant.id == api_token.tenant_id)
|
||||
.where(TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.where(TenantAccountJoin.role.in_(["owner"]))
|
||||
.where(Tenant.status == TenantStatus.NORMAL)
|
||||
).one_or_none() # TODO: only owner information is required, so only one is returned.
|
||||
if tenant_account_join:
|
||||
tenant, ta = tenant_account_join
|
||||
account = db.session.get(Account, ta.account_id)
|
||||
# Login admin
|
||||
if account:
|
||||
account.current_tenant = tenant
|
||||
current_app.login_manager._update_request_context_with_user(account) # type: ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
|
||||
else:
|
||||
raise Unauthorized("Tenant owner account does not exist.")
|
||||
else:
|
||||
raise Unauthorized("Tenant owner account does not exist.")
|
||||
else:
|
||||
raise Unauthorized("Tenant does not exist.")
|
||||
raise Unauthorized("Tenant does not exist.")
|
||||
return view_func(api_token.tenant_id, *args, **kwargs) # type: ignore[arg-type]
|
||||
|
||||
if expects_bound_instance:
|
||||
if not args:
|
||||
raise TypeError("validate_dataset_token expected a bound resource instance.")
|
||||
return view(args[0], api_token.tenant_id, *args[1:], **kwargs)
|
||||
return decorated
|
||||
|
||||
return view(api_token.tenant_id, *args, **kwargs)
|
||||
if view:
|
||||
return decorator(view)
|
||||
|
||||
return decorated
|
||||
# if view is None, it means that the decorator is used without parentheses
|
||||
# use the decorator as a function for method_decorators
|
||||
return decorator
|
||||
|
||||
|
||||
def validate_and_get_api_token(scope: str | None = None):
|
||||
|
||||
@@ -7,7 +7,7 @@ from werkzeug.exceptions import NotFound, RequestEntityTooLarge
|
||||
from controllers.trigger import bp
|
||||
from core.trigger.debug.event_bus import TriggerDebugEventBus
|
||||
from core.trigger.debug.events import WebhookDebugEvent, build_webhook_pool_key
|
||||
from services.trigger.webhook_service import RawWebhookDataDict, WebhookService
|
||||
from services.trigger.webhook_service import WebhookService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -23,7 +23,6 @@ def _prepare_webhook_execution(webhook_id: str, is_debug: bool = False):
|
||||
webhook_id, is_debug=is_debug
|
||||
)
|
||||
|
||||
webhook_data: RawWebhookDataDict
|
||||
try:
|
||||
# Use new unified extraction and validation
|
||||
webhook_data = WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)
|
||||
|
||||
@@ -3,11 +3,10 @@ import logging
|
||||
from flask import request
|
||||
from flask_restx import fields, marshal_with
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import field_validator
|
||||
from pydantic import BaseModel, field_validator
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import TextToAudioPayload as TextToAudioPayloadBase
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import (
|
||||
AppUnavailableError,
|
||||
@@ -35,7 +34,12 @@ from services.errors.audio import (
|
||||
from ..common.schema import register_schema_models
|
||||
|
||||
|
||||
class TextToAudioPayload(TextToAudioPayloadBase):
|
||||
class TextToAudioPayload(BaseModel):
|
||||
message_id: str | None = None
|
||||
voice: str | None = None
|
||||
text: str | None = None
|
||||
streaming: bool | None = None
|
||||
|
||||
@field_validator("message_id")
|
||||
@classmethod
|
||||
def validate_message_id(cls, value: str | None) -> str | None:
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.controller_schemas import ConversationRenamePayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import NotChatAppError
|
||||
@@ -38,6 +37,18 @@ class ConversationListQuery(BaseModel):
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class ConversationRenamePayload(BaseModel):
|
||||
name: str | None = None
|
||||
auto_generate: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_name_requirement(self):
|
||||
if not self.auto_generate:
|
||||
if self.name is None or not self.name.strip():
|
||||
raise ValueError("name is required when auto_generate is false")
|
||||
return self
|
||||
|
||||
|
||||
register_schema_models(web_ns, ConversationListQuery, ConversationRenamePayload)
|
||||
|
||||
|
||||
@@ -88,7 +99,7 @@ class ConversationListApi(WebApiResource):
|
||||
query = ConversationListQuery.model_validate(raw_args)
|
||||
|
||||
try:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
pagination = WebConversationService.pagination_by_last_id(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
|
||||
@@ -3,7 +3,8 @@ import secrets
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.auth.error import (
|
||||
@@ -18,15 +19,33 @@ from controllers.console.error import EmailSendIpLimitError
|
||||
from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required
|
||||
from controllers.web import web_ns
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.password import hash_password
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.password import hash_password, valid_password
|
||||
from models.account import Account
|
||||
from services.account_service import AccountService
|
||||
from services.entities.auth_entities import (
|
||||
ForgotPasswordCheckPayload,
|
||||
ForgotPasswordResetPayload,
|
||||
ForgotPasswordSendPayload,
|
||||
)
|
||||
|
||||
|
||||
class ForgotPasswordSendPayload(BaseModel):
|
||||
email: EmailStr
|
||||
language: str | None = None
|
||||
|
||||
|
||||
class ForgotPasswordCheckPayload(BaseModel):
|
||||
email: EmailStr
|
||||
code: str
|
||||
token: str = Field(min_length=1)
|
||||
|
||||
|
||||
class ForgotPasswordResetPayload(BaseModel):
|
||||
token: str = Field(min_length=1)
|
||||
new_password: str
|
||||
password_confirm: str
|
||||
|
||||
@field_validator("new_password", "password_confirm")
|
||||
@classmethod
|
||||
def validate_password(cls, value: str) -> str:
|
||||
return valid_password(value)
|
||||
|
||||
|
||||
register_schema_models(web_ns, ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload)
|
||||
|
||||
@@ -62,7 +81,7 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
else:
|
||||
language = "en-US"
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session)
|
||||
token = None
|
||||
if account is None:
|
||||
@@ -161,17 +180,18 @@ class ForgotPasswordResetApi(Resource):
|
||||
|
||||
email = reset_data.get("email", "")
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
||||
|
||||
if account:
|
||||
self._update_existing_account(account, password_hashed, salt)
|
||||
self._update_existing_account(account, password_hashed, salt, session)
|
||||
else:
|
||||
raise AuthenticationFailedError()
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
def _update_existing_account(self, account: Account, password_hashed, salt):
|
||||
def _update_existing_account(self, account: Account, password_hashed, salt, session):
|
||||
# Update existing account credentials
|
||||
account.password = base64.b64encode(password_hashed).decode()
|
||||
account.password_salt = base64.b64encode(salt).decode()
|
||||
session.commit()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user