Compare commits

..

3 Commits

Author SHA1 Message Date
yyh
9311150bd2 Merge branch 'main' into 4-2-no-global-loading 2026-04-02 19:16:40 +08:00
autofix-ci[bot]
c49201ee28 [autofix.ci] apply automated fixes 2026-04-02 09:40:01 +00:00
Stephen Zhou
d13e6901cf refactor: no global loading 2026-04-02 17:36:06 +08:00
1128 changed files with 16598 additions and 34187 deletions

9
.github/labeler.yml vendored
View File

@@ -1,10 +1,3 @@
web:
- changed-files:
- any-glob-to-any-file:
- 'web/**'
- 'packages/**'
- 'package.json'
- 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml'
- '.npmrc'
- '.nvmrc'
- any-glob-to-any-file: 'web/**'

View File

@@ -20,4 +20,4 @@
- [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
- [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
- [x] I've updated the documentation accordingly.
- [x] I ran `make lint` and `make type-check` (backend) and `cd web && pnpm exec vp staged` (frontend) to appease the lint gods
- [x] I ran `make lint` and `make type-check` (backend) and `cd web && npx lint-staged` (frontend) to appease the lint gods

View File

@@ -1,82 +0,0 @@
import { execFileSync } from 'node:child_process'
import fs from 'node:fs'
import path from 'node:path'
const repoRoot = process.cwd()
const baseSha = process.env.BASE_SHA || ''
const headSha = process.env.HEAD_SHA || ''
const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean)
const outputPath = process.env.I18N_CHANGES_OUTPUT_PATH || '/tmp/i18n-changes.json'
const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`)
const readCurrentJson = (fileStem) => {
const filePath = englishPath(fileStem)
if (!fs.existsSync(filePath))
return null
return JSON.parse(fs.readFileSync(filePath, 'utf8'))
}
const readBaseJson = (fileStem) => {
if (!baseSha)
return null
try {
const relativePath = `web/i18n/en-US/${fileStem}.json`
const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' })
return JSON.parse(content)
}
catch {
return null
}
}
const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue)
const changes = {}
for (const fileStem of files) {
const currentJson = readCurrentJson(fileStem)
const beforeJson = readBaseJson(fileStem) || {}
const afterJson = currentJson || {}
const added = {}
const updated = {}
const deleted = []
for (const [key, value] of Object.entries(afterJson)) {
if (!(key in beforeJson)) {
added[key] = value
continue
}
if (!compareJson(beforeJson[key], value)) {
updated[key] = {
before: beforeJson[key],
after: value,
}
}
}
for (const key of Object.keys(beforeJson)) {
if (!(key in afterJson))
deleted.push(key)
}
changes[fileStem] = {
fileDeleted: currentJson === null,
added,
updated,
deleted,
}
}
fs.writeFileSync(
outputPath,
JSON.stringify({
baseSha,
headSha,
files,
changes,
})
)

View File

@@ -39,11 +39,9 @@ jobs:
with:
files: |
web/**
packages/**
package.json
pnpm-lock.yaml
pnpm-workspace.yaml
.npmrc
.nvmrc
- name: Check api inputs
if: github.event_name != 'merge_group'

View File

@@ -65,7 +65,7 @@ jobs:
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Login to Docker Hub
uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4.1.0
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
with:
username: ${{ env.DOCKERHUB_USER }}
password: ${{ env.DOCKERHUB_TOKEN }}
@@ -130,7 +130,7 @@ jobs:
merge-multiple: true
- name: Login to Docker Hub
uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4.1.0
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
with:
username: ${{ env.DOCKERHUB_USER }}
password: ${{ env.DOCKERHUB_TOKEN }}

View File

@@ -8,11 +8,9 @@ on:
- api/Dockerfile
- web/docker/**
- web/Dockerfile
- packages/**
- package.json
- pnpm-lock.yaml
- pnpm-workspace.yaml
- .npmrc
- .nvmrc
concurrency:

View File

@@ -65,11 +65,9 @@ jobs:
- 'docker/volumes/sandbox/conf/**'
web:
- 'web/**'
- 'packages/**'
- 'package.json'
- 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml'
- '.npmrc'
- '.nvmrc'
- '.github/workflows/web-tests.yml'
- '.github/actions/setup-web/**'
@@ -79,11 +77,9 @@ jobs:
- 'api/uv.lock'
- 'e2e/**'
- 'web/**'
- 'packages/**'
- 'package.json'
- 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml'
- '.npmrc'
- '.nvmrc'
- 'docker/docker-compose.middleware.yaml'
- 'docker/middleware.env.example'

View File

@@ -77,11 +77,9 @@ jobs:
with:
files: |
web/**
packages/**
package.json
pnpm-lock.yaml
pnpm-workspace.yaml
.npmrc
.nvmrc
.github/workflows/style.yml
.github/actions/setup-web/**
@@ -151,7 +149,7 @@ jobs:
.editorconfig
- name: Super-linter
uses: super-linter/super-linter/slim@9e863354e3ff62e0727d37183162c4a88873df41 # v8.6.0
uses: super-linter/super-linter/slim@61abc07d755095a68f4987d1c2c3d1d64408f1f9 # v8.5.0
if: steps.changed-files.outputs.any_changed == 'true'
env:
BASH_SEVERITY: warning

View File

@@ -9,7 +9,6 @@ on:
- package.json
- pnpm-lock.yaml
- pnpm-workspace.yaml
- .npmrc
concurrency:
group: sdk-tests-${{ github.head_ref || github.run_id }}

View File

@@ -68,7 +68,89 @@ jobs:
" web/i18n-config/languages.ts | sed 's/[[:space:]]*$//')
generate_changes_json() {
node .github/scripts/generate-i18n-changes.mjs
node <<'NODE'
const { execFileSync } = require('node:child_process')
const fs = require('node:fs')
const path = require('node:path')
const repoRoot = process.cwd()
const baseSha = process.env.BASE_SHA || ''
const headSha = process.env.HEAD_SHA || ''
const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean)
const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`)
const readCurrentJson = (fileStem) => {
const filePath = englishPath(fileStem)
if (!fs.existsSync(filePath))
return null
return JSON.parse(fs.readFileSync(filePath, 'utf8'))
}
const readBaseJson = (fileStem) => {
if (!baseSha)
return null
try {
const relativePath = `web/i18n/en-US/${fileStem}.json`
const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' })
return JSON.parse(content)
}
catch (error) {
return null
}
}
const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue)
const changes = {}
for (const fileStem of files) {
const currentJson = readCurrentJson(fileStem)
const beforeJson = readBaseJson(fileStem) || {}
const afterJson = currentJson || {}
const added = {}
const updated = {}
const deleted = []
for (const [key, value] of Object.entries(afterJson)) {
if (!(key in beforeJson)) {
added[key] = value
continue
}
if (!compareJson(beforeJson[key], value)) {
updated[key] = {
before: beforeJson[key],
after: value,
}
}
}
for (const key of Object.keys(beforeJson)) {
if (!(key in afterJson))
deleted.push(key)
}
changes[fileStem] = {
fileDeleted: currentJson === null,
added,
updated,
deleted,
}
}
fs.writeFileSync(
'/tmp/i18n-changes.json',
JSON.stringify({
baseSha,
headSha,
files,
changes,
})
)
NODE
}
if [ "${{ github.event_name }}" = "repository_dispatch" ]; then
@@ -158,7 +240,7 @@ jobs:
- name: Run Claude Code for Translation Sync
if: steps.context.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@6e2bd52842c65e914eba5c8badd17560bd26b5de # v1.0.89
uses: anthropics/claude-code-action@88c168b39e7e64da0286d812b6e9fbebb6708185 # v1.0.82
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}
@@ -188,7 +270,7 @@ jobs:
Tool rules:
- Use Read for repository files.
- Use Edit for JSON updates.
- Use Bash only for `vp`.
- Use Bash only for `pnpm`.
- Do not use Bash for `git`, `gh`, or branch management.
Required execution plan:
@@ -210,7 +292,7 @@ jobs:
- Read the current English JSON file for any file that still exists so wording, placeholders, and surrounding terminology stay accurate.
- If `Structured change set available` is `false`, treat this as a scoped full sync and use the current English files plus scoped checks as the source of truth.
4. Run a scoped pre-check before editing:
- `vp run dify-web#i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
- `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
- Use this command as the source of truth for missing and extra keys inside the current scope.
5. Apply translations.
- For every target language and scoped file:
@@ -218,19 +300,19 @@ jobs:
- If the locale file does not exist yet, create it with `Write` and then continue with `Edit` as needed.
- ADD missing keys.
- UPDATE stale translations when the English value changed.
- DELETE removed keys. Prefer `vp run dify-web#i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }} --auto-remove` for extra keys so deletions stay in scope.
- DELETE removed keys. Prefer `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }} --auto-remove` for extra keys so deletions stay in scope.
- Preserve placeholders exactly: `{{variable}}`, `${variable}`, HTML tags, component tags, and variable names.
- Match the existing terminology and register used by each locale.
- Prefer one Edit per file when stable, but prioritize correctness over batching.
6. Verify only the edited files.
- Run `vp run dify-web#lint:fix --quiet -- <relative edited i18n file paths under web/>`
- Run `vp run dify-web#i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
- Run `pnpm --dir ${{ github.workspace }}/web lint:fix --quiet -- <relative edited i18n file paths>`
- Run `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
- If verification fails, fix the remaining problems before continuing.
7. Stop after the scoped locale files are updated and verification passes.
- Do not create branches, commits, or pull requests.
claude_args: |
--max-turns 120
--allowedTools "Read,Write,Edit,Bash(vp *),Bash(vp:*),Glob,Grep"
--allowedTools "Read,Write,Edit,Bash(pnpm *),Bash(pnpm:*),Glob,Grep"
- name: Prepare branch metadata
id: pr_meta
@@ -272,7 +354,6 @@ jobs:
- name: Create or update translation PR
if: steps.pr_meta.outputs.has_changes == 'true'
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
BRANCH_NAME: ${{ steps.pr_meta.outputs.branch_name }}
FILES_IN_SCOPE: ${{ steps.context.outputs.CHANGED_FILES }}
TARGET_LANGS: ${{ steps.context.outputs.TARGET_LANGS }}
@@ -321,8 +402,8 @@ jobs:
'',
'## Verification',
'',
`- \`vp run dify-web#i18n:check --file ${process.env.FILES_IN_SCOPE} --lang ${process.env.TARGET_LANGS}\``,
`- \`vp run dify-web#lint:fix --quiet -- <edited i18n files under web/>\``,
`- \`pnpm --dir web run i18n:check --file ${process.env.FILES_IN_SCOPE} --lang ${process.env.TARGET_LANGS}\``,
`- \`pnpm --dir web lint:fix --quiet -- <edited i18n files>\``,
'',
'## Notes',
'',

View File

@@ -42,7 +42,88 @@ jobs:
fi
export BASE_SHA HEAD_SHA CHANGED_FILES
node .github/scripts/generate-i18n-changes.mjs
node <<'NODE'
const { execFileSync } = require('node:child_process')
const fs = require('node:fs')
const path = require('node:path')
const repoRoot = process.cwd()
const baseSha = process.env.BASE_SHA || ''
const headSha = process.env.HEAD_SHA || ''
const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean)
const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`)
const readCurrentJson = (fileStem) => {
const filePath = englishPath(fileStem)
if (!fs.existsSync(filePath))
return null
return JSON.parse(fs.readFileSync(filePath, 'utf8'))
}
const readBaseJson = (fileStem) => {
if (!baseSha)
return null
try {
const relativePath = `web/i18n/en-US/${fileStem}.json`
const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' })
return JSON.parse(content)
}
catch (error) {
return null
}
}
const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue)
const changes = {}
for (const fileStem of files) {
const beforeJson = readBaseJson(fileStem) || {}
const afterJson = readCurrentJson(fileStem) || {}
const added = {}
const updated = {}
const deleted = []
for (const [key, value] of Object.entries(afterJson)) {
if (!(key in beforeJson)) {
added[key] = value
continue
}
if (!compareJson(beforeJson[key], value)) {
updated[key] = {
before: beforeJson[key],
after: value,
}
}
}
for (const key of Object.keys(beforeJson)) {
if (!(key in afterJson))
deleted.push(key)
}
changes[fileStem] = {
fileDeleted: readCurrentJson(fileStem) === null,
added,
updated,
deleted,
}
}
fs.writeFileSync(
'/tmp/i18n-changes.json',
JSON.stringify({
baseSha,
headSha,
files,
changes,
})
)
NODE
if [ -n "$CHANGED_FILES" ]; then
echo "has_changes=true" >> "$GITHUB_OUTPUT"

View File

@@ -36,7 +36,7 @@ jobs:
remove_tool_cache: true
- name: Setup UV and Python
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}

View File

@@ -81,18 +81,38 @@ if $web_modified; then
if $web_ts_modified; then
echo "Running TypeScript type-check:tsgo"
if ! npm run type-check:tsgo; then
echo "Type check failed. Please run 'npm run type-check:tsgo' to fix the errors."
if ! pnpm run type-check:tsgo; then
echo "Type check failed. Please run 'pnpm run type-check:tsgo' to fix the errors."
exit 1
fi
else
echo "No staged TypeScript changes detected, skipping type-check:tsgo"
fi
echo "Running knip"
if ! npm run knip; then
echo "Knip check failed. Please run 'npm run knip' to fix the errors."
exit 1
echo "Running unit tests check"
modified_files=$(git diff --cached --name-only -- utils | grep -v '\.spec\.ts$' || true)
if [ -n "$modified_files" ]; then
for file in $modified_files; do
test_file="${file%.*}.spec.ts"
echo "Checking for test file: $test_file"
# check if the test file exists
if [ -f "../$test_file" ]; then
echo "Detected changes in $file, running corresponding unit tests..."
pnpm run test "../$test_file"
if [ $? -ne 0 ]; then
echo "Unit tests failed. Please fix the errors before committing."
exit 1
fi
echo "Unit tests for $file passed."
else
echo "Warning: $file does not have a corresponding test file."
fi
done
echo "All unit tests for modified web/utils files have passed."
fi
cd ../

View File

@@ -1,18 +0,0 @@
# This module provides a lightweight Celery instance for use in Docker health checks.
# Unlike celery_entrypoint.py, this does NOT import app.py and therefore avoids
# initializing all Flask extensions (DB, Redis, storage, blueprints, etc.).
# Using this module keeps the health check fast and low-cost.
from celery import Celery
from configs import dify_config
from extensions.ext_celery import get_celery_broker_transport_options, get_celery_ssl_options
celery = Celery(broker=dify_config.CELERY_BROKER_URL)
broker_transport_options = get_celery_broker_transport_options()
if broker_transport_options:
celery.conf.update(broker_transport_options=broker_transport_options)
ssl_options = get_celery_ssl_options()
if ssl_options:
celery.conf.update(broker_use_ssl=ssl_options)

View File

@@ -1,7 +1,7 @@
import datetime
import logging
import time
from typing import TypedDict
from typing import Any
import click
import sqlalchemy as sa
@@ -503,19 +503,7 @@ def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
return [row[0] for row in result]
class _AppOrphanCounts(TypedDict):
variables: int
files: int
class OrphanedDraftVariableStatsDict(TypedDict):
total_orphaned_variables: int
total_orphaned_files: int
orphaned_app_count: int
orphaned_by_app: dict[str, _AppOrphanCounts]
def _count_orphaned_draft_variables() -> OrphanedDraftVariableStatsDict:
def _count_orphaned_draft_variables() -> dict[str, Any]:
"""
Count orphaned draft variables by app, including associated file counts.
@@ -538,7 +526,7 @@ def _count_orphaned_draft_variables() -> OrphanedDraftVariableStatsDict:
with db.engine.connect() as conn:
result = conn.execute(sa.text(variables_query))
orphaned_by_app: dict[str, _AppOrphanCounts] = {}
orphaned_by_app = {}
total_files = 0
for row in result:

View File

@@ -1,63 +0,0 @@
from typing import Any, Literal
from pydantic import BaseModel, Field, model_validator
from libs.helper import UUIDStrOrEmpty
# --- Conversation schemas ---
class ConversationRenamePayload(BaseModel):
name: str | None = None
auto_generate: bool = False
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
# --- Message schemas ---
class MessageListQuery(BaseModel):
conversation_id: UUIDStrOrEmpty
first_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class MessageFeedbackPayload(BaseModel):
rating: Literal["like", "dislike"] | None = None
content: str | None = None
# --- Saved message schemas ---
class SavedMessageListQuery(BaseModel):
last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class SavedMessageCreatePayload(BaseModel):
message_id: UUIDStrOrEmpty
# --- Workflow schemas ---
class WorkflowRunPayload(BaseModel):
inputs: dict[str, Any]
files: list[dict[str, Any]] | None = None
# --- Audio schemas ---
class TextToAudioPayload(BaseModel):
message_id: str | None = None
voice: str | None = None
text: str | None = None
streaming: bool | None = None

View File

@@ -2,7 +2,6 @@ import csv
import io
from collections.abc import Callable
from functools import wraps
from typing import cast
from flask import request
from flask_restx import Resource
@@ -18,7 +17,7 @@ from core.db.session_factory import session_factory
from extensions.ext_database import db
from libs.token import extract_access_token
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
from services.billing_service import BillingService, LangContentDict
from services.billing_service import BillingService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@@ -329,7 +328,7 @@ class UpsertNotificationApi(Resource):
def post(self):
payload = UpsertNotificationPayload.model_validate(console_ns.payload)
result = BillingService.upsert_notification(
contents=[cast(LangContentDict, c.model_dump()) for c in payload.contents],
contents=[c.model_dump() for c in payload.contents],
frequency=payload.frequency,
status=payload.status,
notification_id=payload.notification_id,

View File

@@ -7,7 +7,7 @@ from flask import request
from flask_restx import Resource
from graphon.enums import WorkflowExecutionStatus
from graphon.file import helpers as file_helpers
from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest
@@ -26,11 +26,9 @@ from controllers.console.wraps import (
setup_required,
)
from core.ops.ops_trace_manager import OpsTraceManager
from core.rag.entities import PreProcessingRule, Rule, Segmentation
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.trigger.constants import TRIGGER_NODE_TYPES
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.login import current_account_with_tenant, login_required
from models import App, DatasetPermissionEnum, Workflow
from models.model import IconType
@@ -43,7 +41,10 @@ from services.entities.knowledge_entities.knowledge_entities import (
NotionIcon,
NotionInfo,
NotionPage,
PreProcessingRule,
RerankingModel,
Rule,
Segmentation,
WebsiteInfo,
WeightKeywordSetting,
WeightModel,
@@ -154,6 +155,16 @@ class AppTracePayload(BaseModel):
type JSONValue = Any
class ResponseModel(BaseModel):
model_config = ConfigDict(
from_attributes=True,
extra="ignore",
populate_by_name=True,
serialize_by_alias=True,
protected_namespaces=(),
)
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())

View File

@@ -8,7 +8,6 @@ from pydantic import BaseModel, Field, field_validator
from sqlalchemy import exists, func, select
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.controller_schemas import MessageFeedbackPayload as _MessageFeedbackPayloadBase
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import (
@@ -60,8 +59,10 @@ class ChatMessagesQuery(BaseModel):
return uuid_value(value)
class MessageFeedbackPayload(_MessageFeedbackPayloadBase):
class MessageFeedbackPayload(BaseModel):
message_id: str = Field(..., description="Message ID")
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
content: str | None = Field(default=None, description="Feedback content")
@field_validator("message_id")
@classmethod

View File

@@ -66,13 +66,13 @@ class WebhookTriggerApi(Resource):
with sessionmaker(db.engine).begin() as session:
# Get webhook trigger for this app and node
webhook_trigger = session.scalar(
select(WorkflowWebhookTrigger)
webhook_trigger = (
session.query(WorkflowWebhookTrigger)
.where(
WorkflowWebhookTrigger.app_id == app_model.id,
WorkflowWebhookTrigger.node_id == node_id,
)
.limit(1)
.first()
)
if not webhook_trigger:

View File

@@ -3,7 +3,7 @@ import secrets
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
@@ -20,18 +20,35 @@ from controllers.console.wraps import email_password_login_enabled, setup_requir
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password
from libs.password import hash_password, valid_password
from services.account_service import AccountService, TenantService
from services.entities.auth_entities import (
ForgotPasswordCheckPayload,
ForgotPasswordResetPayload,
ForgotPasswordSendPayload,
)
from services.feature_service import FeatureService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ForgotPasswordSendPayload(BaseModel):
email: EmailStr = Field(...)
language: str | None = Field(default=None)
class ForgotPasswordCheckPayload(BaseModel):
email: EmailStr = Field(...)
code: str = Field(...)
token: str = Field(...)
class ForgotPasswordResetPayload(BaseModel):
token: str = Field(...)
new_password: str = Field(...)
password_confirm: str = Field(...)
@field_validator("new_password", "password_confirm")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)
class ForgotPasswordEmailResponse(BaseModel):
result: str = Field(description="Operation result")
data: str | None = Field(default=None, description="Reset token")

View File

@@ -1,3 +1,5 @@
from typing import Any
import flask_login
from flask import make_response, request
from flask_restx import Resource
@@ -40,9 +42,8 @@ from libs.token import (
set_csrf_token_to_cookie,
set_refresh_token_to_cookie,
)
from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService
from services.account_service import AccountService, RegisterService, TenantService
from services.billing_service import BillingService
from services.entities.auth_entities import LoginPayloadBase
from services.errors.account import AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
from services.feature_service import FeatureService
@@ -50,7 +51,9 @@ from services.feature_service import FeatureService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class LoginPayload(LoginPayloadBase):
class LoginPayload(BaseModel):
email: EmailStr = Field(..., description="Email address")
password: str = Field(..., description="Password")
remember_me: bool = Field(default=False, description="Remember me flag")
invite_token: str | None = Field(default=None, description="Invitation token")
@@ -98,7 +101,7 @@ class LoginApi(Resource):
raise EmailPasswordLoginLimitError()
invite_token = args.invite_token
invitation_data: InvitationDetailDict | None = None
invitation_data: dict[str, Any] | None = None
if invite_token:
invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token)
if invitation_data is None:

View File

@@ -1,6 +1,4 @@
import base64
import json
from datetime import UTC, datetime, timedelta
from typing import Literal
from flask import request
@@ -11,7 +9,6 @@ from werkzeug.exceptions import BadRequest
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from enums.cloud_plan import CloudPlan
from extensions.ext_redis import redis_client
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
@@ -87,39 +84,3 @@ class PartnerTenants(Resource):
raise BadRequest("Invalid partner information")
return BillingService.sync_partner_tenants_bindings(current_user.id, decoded_partner_key, click_id)
_DEBUG_KEY = "billing:debug"
_DEBUG_TTL = timedelta(days=7)
class DebugDataPayload(BaseModel):
type: str = Field(..., min_length=1, description="Data type key")
data: str = Field(..., min_length=1, description="Data value to append")
@console_ns.route("/billing/debug/data")
class DebugData(Resource):
def post(self):
body = DebugDataPayload.model_validate(request.get_json(force=True))
item = json.dumps({
"type": body.type,
"data": body.data,
"createTime": datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ"),
})
redis_client.lpush(_DEBUG_KEY, item)
redis_client.expire(_DEBUG_KEY, _DEBUG_TTL)
return {"result": "ok"}, 201
def get(self):
recent = request.args.get("recent", 10, type=int)
items = redis_client.lrange(_DEBUG_KEY, 0, recent - 1)
return {
"data": [
json.loads(item.decode("utf-8") if isinstance(item, bytes) else item) for item in items
]
}
def delete(self):
redis_client.delete(_DEBUG_KEY)
return {"result": "ok"}

View File

@@ -158,11 +158,10 @@ class DataSourceApi(Resource):
@login_required
@account_initialization_required
def patch(self, binding_id, action: Literal["enable", "disable"]):
_, current_tenant_id = current_account_with_tenant()
binding_id = str(binding_id)
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
data_source_binding = session.execute(
select(DataSourceOauthBinding).filter_by(id=binding_id, tenant_id=current_tenant_id)
select(DataSourceOauthBinding).filter_by(id=binding_id)
).scalar_one_or_none()
if data_source_binding is None:
raise NotFound("Data source binding not found.")

View File

@@ -3,7 +3,6 @@ import logging
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
@@ -87,8 +86,8 @@ class CustomizedPipelineTemplateApi(Resource):
@enterprise_license_required
def post(self, template_id: str):
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
template = session.scalar(
select(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).limit(1)
template = (
session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first()
)
if not template:
raise ValueError("Customized pipeline template not found.")

View File

@@ -2,10 +2,10 @@ import logging
from flask import request
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError
import services
from controllers.common.controller_schemas import TextToAudioPayload
from controllers.common.schema import register_schema_model
from controllers.console.app.error import (
AppUnavailableError,
@@ -32,6 +32,14 @@ from .. import console_ns
logger = logging.getLogger(__name__)
class TextToAudioPayload(BaseModel):
message_id: str | None = None
voice: str | None = None
text: str | None = None
streaming: bool | None = Field(default=None, description="Enable streaming response")
register_schema_model(console_ns, TextToAudioPayload)

View File

@@ -1,11 +1,10 @@
from typing import Any
from flask import request
from pydantic import BaseModel, Field, TypeAdapter
from pydantic import BaseModel, Field, TypeAdapter, model_validator
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import ConversationRenamePayload
from controllers.common.schema import register_schema_models
from controllers.console.explore.error import NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource
@@ -33,6 +32,18 @@ class ConversationListQuery(BaseModel):
pinned: bool | None = None
class ConversationRenamePayload(BaseModel):
name: str | None = None
auto_generate: bool = False
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload)

View File

@@ -3,10 +3,9 @@ from typing import Literal
from flask import request
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, TypeAdapter
from pydantic import BaseModel, Field, TypeAdapter
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery
from controllers.common.schema import register_schema_models
from controllers.console.app.error import (
AppMoreLikeThisDisabledError,
@@ -26,6 +25,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from fields.conversation_fields import ResultResponse
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse
from libs import helper
from libs.helper import UUIDStrOrEmpty
from libs.login import current_account_with_tenant
from models.enums import FeedbackRating
from models.model import AppMode
@@ -44,6 +44,17 @@ from .. import console_ns
logger = logging.getLogger(__name__)
class MessageListQuery(BaseModel):
conversation_id: UUIDStrOrEmpty
first_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class MessageFeedbackPayload(BaseModel):
rating: Literal["like", "dislike"] | None = None
content: str | None = None
class MoreLikeThisQuery(BaseModel):
response_mode: Literal["blocking", "streaming"]

View File

@@ -1,18 +1,28 @@
from flask import request
from pydantic import TypeAdapter
from pydantic import BaseModel, Field, TypeAdapter
from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import SavedMessageCreatePayload, SavedMessageListQuery
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import ResultResponse
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
from libs.helper import UUIDStrOrEmpty
from libs.login import current_account_with_tenant
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
class SavedMessageListQuery(BaseModel):
last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class SavedMessageCreatePayload(BaseModel):
message_id: UUIDStrOrEmpty
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)

View File

@@ -1,10 +1,11 @@
import logging
from typing import Any
from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel
from werkzeug.exceptions import InternalServerError
from controllers.common.controller_schemas import WorkflowRunPayload
from controllers.common.schema import register_schema_model
from controllers.console.app.error import (
CompletionRequestError,
@@ -33,6 +34,12 @@ from .. import console_ns
logger = logging.getLogger(__name__)
class WorkflowRunPayload(BaseModel):
inputs: dict[str, Any]
files: list[dict[str, Any]] | None = None
register_schema_model(console_ns, WorkflowRunPayload)

View File

@@ -1,5 +1,3 @@
from typing import TypedDict
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
@@ -13,21 +11,6 @@ from services.billing_service import BillingService
_FALLBACK_LANG = "en-US"
class NotificationItemDict(TypedDict):
notification_id: str | None
frequency: str | None
lang: str
title: str
subtitle: str
body: str
title_pic_url: str
class NotificationResponseDict(TypedDict):
should_show: bool
notifications: list[NotificationItemDict]
def _pick_lang_content(contents: dict, lang: str) -> dict:
"""Return the single LangContent for *lang*, falling back to English."""
return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {})
@@ -62,30 +45,28 @@ class NotificationApi(Resource):
result = BillingService.get_account_notification(str(current_user.id))
# Proto JSON uses camelCase field names (Kratos default marshaling).
response: NotificationResponseDict
if not result.get("shouldShow"):
response = {"should_show": False, "notifications": []}
return response, 200
return {"should_show": False, "notifications": []}, 200
lang = current_user.interface_language or _FALLBACK_LANG
notifications: list[NotificationItemDict] = []
notifications = []
for notification in result.get("notifications") or []:
contents: dict = notification.get("contents") or {}
lang_content = _pick_lang_content(contents, lang)
item: NotificationItemDict = {
"notification_id": notification.get("notificationId"),
"frequency": notification.get("frequency"),
"lang": lang_content.get("lang", lang),
"title": lang_content.get("title", ""),
"subtitle": lang_content.get("subtitle", ""),
"body": lang_content.get("body", ""),
"title_pic_url": lang_content.get("titlePicUrl", ""),
}
notifications.append(item)
notifications.append(
{
"notification_id": notification.get("notificationId"),
"frequency": notification.get("frequency"),
"lang": lang_content.get("lang", lang),
"title": lang_content.get("title", ""),
"subtitle": lang_content.get("subtitle", ""),
"body": lang_content.get("body", ""),
"title_pic_url": lang_content.get("titlePicUrl", ""),
}
)
response = {"should_show": bool(notifications), "notifications": notifications}
return response, 200
return {"should_show": bool(notifications), "notifications": notifications}, 200
@console_ns.route("/notification/dismiss")

View File

@@ -9,14 +9,7 @@ from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from libs.login import current_account_with_tenant, login_required
from models.enums import TagType
from services.tag_service import (
SaveTagPayload,
TagBindingCreatePayload,
TagBindingDeletePayload,
TagService,
UpdateTagPayload,
)
from services.tag_service import TagService
dataset_tag_fields = {
"id": fields.String,
@@ -32,19 +25,19 @@ def build_dataset_tag_fields(api_or_ns: Namespace):
class TagBasePayload(BaseModel):
name: str = Field(description="Tag name", min_length=1, max_length=50)
type: TagType = Field(description="Tag type")
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
class TagBindingPayload(BaseModel):
tag_ids: list[str] = Field(description="Tag IDs to bind")
target_id: str = Field(description="Target ID to bind tags to")
type: TagType = Field(description="Tag type")
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
class TagBindingRemovePayload(BaseModel):
tag_id: str = Field(description="Tag ID to remove")
target_id: str = Field(description="Target ID to unbind tag from")
type: TagType = Field(description="Tag type")
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
class TagListQueryParam(BaseModel):
@@ -89,7 +82,7 @@ class TagListApi(Resource):
raise Forbidden()
payload = TagBasePayload.model_validate(console_ns.payload or {})
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=payload.type))
tag = TagService.save_tags(payload.model_dump())
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
@@ -110,7 +103,7 @@ class TagUpdateDeleteApi(Resource):
raise Forbidden()
payload = TagBasePayload.model_validate(console_ns.payload or {})
tag = TagService.update_tags(UpdateTagPayload(name=payload.name, type=payload.type), tag_id)
tag = TagService.update_tags(payload.model_dump(), tag_id)
binding_count = TagService.get_tag_binding_count(tag_id)
@@ -143,9 +136,7 @@ class TagBindingCreateApi(Resource):
raise Forbidden()
payload = TagBindingPayload.model_validate(console_ns.payload or {})
TagService.save_tag_binding(
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=payload.type)
)
TagService.save_tag_binding(payload.model_dump())
return {"result": "success"}, 200
@@ -163,8 +154,6 @@ class TagBindingDeleteApi(Resource):
raise Forbidden()
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=payload.type)
)
TagService.delete_tag_binding(payload.model_dump())
return {"result": "success"}, 200

View File

@@ -1,7 +1,6 @@
from collections.abc import Callable
from functools import wraps
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Forbidden
@@ -22,12 +21,12 @@ def plugin_permission_required(
tenant_id = current_tenant_id
with sessionmaker(db.engine).begin() as session:
permission = session.scalar(
select(TenantPluginPermission)
permission = (
session.query(TenantPluginPermission)
.where(
TenantPluginPermission.tenant_id == tenant_id,
)
.limit(1)
.first()
)
if not permission:

View File

@@ -28,7 +28,7 @@ from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.account import Tenant, TenantCustomConfigDict, TenantStatus
from models.account import Tenant, TenantStatus
from services.account_service import TenantService
from services.billing_service import BillingService, SubscriptionPlan
from services.enterprise.enterprise_service import EnterpriseService
@@ -240,10 +240,8 @@ class CustomConfigWorkspaceApi(Resource):
args = WorkspaceCustomConfigPayload.model_validate(payload)
tenant = db.get_or_404(Tenant, current_tenant_id)
custom_config_dict: TenantCustomConfigDict = {
"remove_webapp_brand": args.remove_webapp_brand
if args.remove_webapp_brand is not None
else tenant.custom_config_dict.get("remove_webapp_brand", False),
custom_config_dict = {
"remove_webapp_brand": args.remove_webapp_brand,
"replace_webapp_logo": args.replace_webapp_logo
if args.replace_webapp_logo is not None
else tenant.custom_config_dict.get("replace_webapp_logo"),

View File

@@ -9,7 +9,7 @@ from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_model
from controllers.console.wraps import setup_required
@@ -55,7 +55,7 @@ class EnterpriseAppDSLImport(Resource):
account.set_tenant_id(workspace_id)
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
dsl_service = AppDslService(session)
result = dsl_service.import_app(
account=account,
@@ -64,6 +64,7 @@ class EnterpriseAppDSLImport(Resource):
name=args.name,
description=args.description,
)
session.commit()
if result.status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400

View File

@@ -4,7 +4,6 @@ from flask import Response
from flask_restx import Resource
from graphon.variables.input_entities import VariableEntity
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from controllers.common.schema import register_schema_model
@@ -81,11 +80,11 @@ class MCPAppApi(Resource):
def _get_mcp_server_and_app(self, server_code: str, session: Session) -> tuple[AppMCPServer, App]:
"""Get and validate MCP server and app in one query session"""
mcp_server = session.scalar(select(AppMCPServer).where(AppMCPServer.server_code == server_code).limit(1))
mcp_server = session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
if not mcp_server:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server Not Found")
app = session.scalar(select(App).where(App.id == mcp_server.app_id).limit(1))
app = session.query(App).where(App.id == mcp_server.app_id).first()
if not app:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App Not Found")
@@ -191,12 +190,12 @@ class MCPAppApi(Resource):
def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str) -> EndUser | None:
"""Get end user - manages its own database session"""
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
return session.scalar(
select(EndUser)
return (
session.query(EndUser)
.where(EndUser.tenant_id == tenant_id)
.where(EndUser.session_id == mcp_server_id)
.where(EndUser.type == "mcp")
.limit(1)
.first()
)
def _create_end_user(

View File

@@ -2,12 +2,11 @@ from typing import Any, Literal
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, NotFound
import services
from controllers.common.controller_schemas import ConversationRenamePayload
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotChatAppError
@@ -35,6 +34,18 @@ class ConversationListQuery(BaseModel):
)
class ConversationRenamePayload(BaseModel):
name: str | None = Field(default=None, description="New conversation name (required if auto_generate is false)")
auto_generate: bool = Field(default=False, description="Auto-generate conversation name")
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
class ConversationVariablesQuery(BaseModel):
last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last variable ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")

View File

@@ -1,4 +1,5 @@
import logging
from typing import Literal
from flask import request
from flask_restx import Resource
@@ -6,7 +7,6 @@ from pydantic import BaseModel, Field, TypeAdapter
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services
from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotChatAppError
@@ -14,6 +14,7 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import ResultResponse
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
from libs.helper import UUIDStrOrEmpty
from models.enums import FeedbackRating
from models.model import App, AppMode, EndUser
from services.errors.message import (
@@ -26,6 +27,17 @@ from services.message_service import MessageService
logger = logging.getLogger(__name__)
class MessageListQuery(BaseModel):
conversation_id: UUIDStrOrEmpty
first_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return")
class MessageFeedbackPayload(BaseModel):
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
content: str | None = Field(default=None, description="Feedback content")
class FeedbackListQuery(BaseModel):
page: int = Field(default=1, ge=1, description="Page number")
limit: int = Field(default=20, ge=1, le=101, description="Number of feedbacks per page")

View File

@@ -1,5 +1,5 @@
import logging
from typing import Literal
from typing import Any, Literal
from dateutil.parser import isoparse
from flask import request
@@ -11,7 +11,6 @@ from pydantic import BaseModel, Field
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
from controllers.common.controller_schemas import WorkflowRunPayload as WorkflowRunPayloadBase
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import (
@@ -47,7 +46,9 @@ from services.workflow_app_service import WorkflowAppService
logger = logging.getLogger(__name__)
class WorkflowRunPayload(WorkflowRunPayloadBase):
class WorkflowRunPayload(BaseModel):
inputs: dict[str, Any]
files: list[dict[str, Any]] | None = None
response_mode: Literal["blocking", "streaming"] | None = None

View File

@@ -22,17 +22,10 @@ from fields.tag_fields import DataSetTag
from libs.login import current_user
from models.account import Account
from models.dataset import DatasetPermissionEnum
from models.enums import TagType
from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
from services.tag_service import (
SaveTagPayload,
TagBindingCreatePayload,
TagBindingDeletePayload,
TagService,
UpdateTagPayload,
)
from services.tag_service import TagService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@@ -520,7 +513,7 @@ class DatasetTagsApi(DatasetApiResource):
raise Forbidden()
payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=TagType.KNOWLEDGE))
tag = TagService.save_tags({"name": payload.name, "type": "knowledge"})
response = DataSetTag.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
@@ -543,8 +536,9 @@ class DatasetTagsApi(DatasetApiResource):
raise Forbidden()
payload = TagUpdatePayload.model_validate(service_api_ns.payload or {})
params = {"name": payload.name, "type": "knowledge"}
tag_id = payload.tag_id
tag = TagService.update_tags(UpdateTagPayload(name=payload.name, type=TagType.KNOWLEDGE), tag_id)
tag = TagService.update_tags(params, tag_id)
binding_count = TagService.get_tag_binding_count(tag_id)
@@ -591,9 +585,7 @@ class DatasetTagBindingApi(DatasetApiResource):
raise Forbidden()
payload = TagBindingPayload.model_validate(service_api_ns.payload or {})
TagService.save_tag_binding(
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE)
)
TagService.save_tag_binding({"tag_ids": payload.tag_ids, "target_id": payload.target_id, "type": "knowledge"})
return "", 204
@@ -617,9 +609,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
raise Forbidden()
payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=TagType.KNOWLEDGE)
)
TagService.delete_tag_binding({"tag_id": payload.tag_id, "target_id": payload.target_id, "type": "knowledge"})
return "", 204

View File

@@ -31,7 +31,6 @@ from controllers.service_api.wraps import (
cloud_edition_billing_resource_check,
)
from core.errors.error import ProviderTokenNotInitError
from core.rag.entities import PreProcessingRule, Rule, Segmentation
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from fields.document_fields import document_fields, document_status_fields
@@ -41,8 +40,11 @@ from models.enums import SegmentStatus
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import (
KnowledgeConfig,
PreProcessingRule,
ProcessRule,
RetrievalModel,
Rule,
Segmentation,
)
from services.file_service import FileService
from services.summary_index_service import SummaryIndexService

View File

@@ -4,23 +4,13 @@ Serialization helpers for Service API knowledge pipeline endpoints.
from __future__ import annotations
from typing import TYPE_CHECKING, TypedDict
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from models.model import UploadFile
class UploadFileDict(TypedDict):
id: str
name: str
size: int
extension: str
mime_type: str | None
created_by: str
created_at: str | None
def serialize_upload_file(upload_file: UploadFile) -> UploadFileDict:
def serialize_upload_file(upload_file: UploadFile) -> dict[str, Any]:
return {
"id": upload_file.id,
"name": upload_file.name,

View File

@@ -7,7 +7,7 @@ from werkzeug.exceptions import NotFound, RequestEntityTooLarge
from controllers.trigger import bp
from core.trigger.debug.event_bus import TriggerDebugEventBus
from core.trigger.debug.events import WebhookDebugEvent, build_webhook_pool_key
from services.trigger.webhook_service import RawWebhookDataDict, WebhookService
from services.trigger.webhook_service import WebhookService
logger = logging.getLogger(__name__)
@@ -23,7 +23,6 @@ def _prepare_webhook_execution(webhook_id: str, is_debug: bool = False):
webhook_id, is_debug=is_debug
)
webhook_data: RawWebhookDataDict
try:
# Use new unified extraction and validation
webhook_data = WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)

View File

@@ -3,11 +3,10 @@ import logging
from flask import request
from flask_restx import fields, marshal_with
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import field_validator
from pydantic import BaseModel, field_validator
from werkzeug.exceptions import InternalServerError
import services
from controllers.common.controller_schemas import TextToAudioPayload as TextToAudioPayloadBase
from controllers.web import web_ns
from controllers.web.error import (
AppUnavailableError,
@@ -35,7 +34,12 @@ from services.errors.audio import (
from ..common.schema import register_schema_models
class TextToAudioPayload(TextToAudioPayloadBase):
class TextToAudioPayload(BaseModel):
message_id: str | None = None
voice: str | None = None
text: str | None = None
streaming: bool | None = None
@field_validator("message_id")
@classmethod
def validate_message_id(cls, value: str | None) -> str | None:

View File

@@ -1,11 +1,10 @@
from typing import Literal
from flask import request
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import ConversationRenamePayload
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import NotChatAppError
@@ -38,6 +37,18 @@ class ConversationListQuery(BaseModel):
return uuid_value(value)
class ConversationRenamePayload(BaseModel):
name: str | None = None
auto_generate: bool = False
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
register_schema_models(web_ns, ConversationListQuery, ConversationRenamePayload)

View File

@@ -3,6 +3,7 @@ import secrets
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
@@ -18,15 +19,33 @@ from controllers.console.error import EmailSendIpLimitError
from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required
from controllers.web import web_ns
from extensions.ext_database import db
from libs.helper import extract_remote_ip
from libs.password import hash_password
from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password
from models.account import Account
from services.account_service import AccountService
from services.entities.auth_entities import (
ForgotPasswordCheckPayload,
ForgotPasswordResetPayload,
ForgotPasswordSendPayload,
)
class ForgotPasswordSendPayload(BaseModel):
email: EmailStr
language: str | None = None
class ForgotPasswordCheckPayload(BaseModel):
email: EmailStr
code: str
token: str = Field(min_length=1)
class ForgotPasswordResetPayload(BaseModel):
token: str = Field(min_length=1)
new_password: str
password_confirm: str
@field_validator("new_password", "password_confirm")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)
register_schema_models(web_ns, ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload)

View File

@@ -29,11 +29,13 @@ from libs.token import (
)
from services.account_service import AccountService
from services.app_service import AppService
from services.entities.auth_entities import LoginPayloadBase
from services.webapp_auth_service import WebAppAuthService
class LoginPayload(LoginPayloadBase):
class LoginPayload(BaseModel):
email: EmailStr
password: str
@field_validator("password")
@classmethod
def validate_password(cls, value: str) -> str:

View File

@@ -6,7 +6,6 @@ from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.controller_schemas import MessageFeedbackPayload
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import (
@@ -54,6 +53,11 @@ class MessageListQuery(BaseModel):
return uuid_value(value)
class MessageFeedbackPayload(BaseModel):
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
content: str | None = Field(default=None, description="Feedback content")
class MessageMoreLikeThisQuery(BaseModel):
response_mode: Literal["blocking", "streaming"] = Field(
description="Response mode",

View File

@@ -1,17 +1,27 @@
from flask import request
from pydantic import TypeAdapter
from pydantic import BaseModel, Field, TypeAdapter
from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import SavedMessageCreatePayload, SavedMessageListQuery
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import NotCompletionAppError
from controllers.web.wraps import WebApiResource
from fields.conversation_fields import ResultResponse
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
from libs.helper import UUIDStrOrEmpty
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
class SavedMessageListQuery(BaseModel):
last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class SavedMessageCreatePayload(BaseModel):
message_id: UUIDStrOrEmpty
register_schema_models(web_ns, SavedMessageListQuery, SavedMessageCreatePayload)

View File

@@ -1,10 +1,11 @@
import logging
from typing import Any
from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError
from controllers.common.controller_schemas import WorkflowRunPayload
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import (
@@ -29,6 +30,12 @@ from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError
class WorkflowRunPayload(BaseModel):
inputs: dict[str, Any] = Field(description="Input variables for the workflow")
files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed by the workflow")
logger = logging.getLogger(__name__)
register_schema_models(web_ns, WorkflowRunPayload)

View File

@@ -79,18 +79,21 @@ class CotChatAgentRunner(CotAgentRunner):
if not agent_scratchpad:
assistant_messages = []
else:
content = ""
assistant_message = AssistantPromptMessage(content="")
assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str
for unit in agent_scratchpad:
if unit.is_final():
content += f"Final Answer: {unit.agent_response}"
assert isinstance(assistant_message.content, str)
assistant_message.content += f"Final Answer: {unit.agent_response}"
else:
content += f"Thought: {unit.thought}\n\n"
assert isinstance(assistant_message.content, str)
assistant_message.content += f"Thought: {unit.thought}\n\n"
if unit.action_str:
content += f"Action: {unit.action_str}\n\n"
assistant_message.content += f"Action: {unit.action_str}\n\n"
if unit.observation:
content += f"Observation: {unit.observation}\n\n"
assistant_message.content += f"Observation: {unit.observation}\n\n"
assistant_messages = [AssistantPromptMessage(content=content)]
assistant_messages = [assistant_message]
# query messages
query_messages = self._organize_user_query(self._query, [])

View File

@@ -5,10 +5,6 @@ from configs import dify_config
from constants import DEFAULT_FILE_NUMBER_LIMITS
class FeatureToggleDict(TypedDict):
enabled: bool
class SystemParametersDict(TypedDict):
image_file_size_limit: int
video_file_size_limit: int
@@ -20,12 +16,12 @@ class SystemParametersDict(TypedDict):
class AppParametersDict(TypedDict):
opening_statement: str | None
suggested_questions: list[str]
suggested_questions_after_answer: FeatureToggleDict
speech_to_text: FeatureToggleDict
text_to_speech: FeatureToggleDict
retriever_resource: FeatureToggleDict
annotation_reply: FeatureToggleDict
more_like_this: FeatureToggleDict
suggested_questions_after_answer: dict[str, Any]
speech_to_text: dict[str, Any]
text_to_speech: dict[str, Any]
retriever_resource: dict[str, Any]
annotation_reply: dict[str, Any]
more_like_this: dict[str, Any]
user_input_form: list[dict[str, Any]]
sensitive_word_avoidance: dict[str, Any]
file_upload: dict[str, Any]

View File

@@ -1,3 +1,4 @@
from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any, Literal
@@ -8,7 +9,6 @@ from graphon.variables.input_entities import VariableEntity as WorkflowVariableE
from pydantic import BaseModel, Field
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
from core.rag.entities import MetadataFilteringCondition
from models.model import AppMode
@@ -111,6 +111,31 @@ class ExternalDataVariableEntity(BaseModel):
config: dict[str, Any] = Field(default_factory=dict)
SupportedComparisonOperator = Literal[
# for string or array
"contains",
"not contains",
"start with",
"end with",
"is",
"is not",
"empty",
"not empty",
"in",
"not in",
# for number
"=",
"",
">",
"<",
"",
"",
# for time
"before",
"after",
]
class ModelConfig(BaseModel):
provider: str
name: str
@@ -118,6 +143,25 @@ class ModelConfig(BaseModel):
completion_params: dict[str, Any] = Field(default_factory=dict)
class Condition(BaseModel):
"""
Condition detail
"""
name: str
comparison_operator: SupportedComparisonOperator
value: str | Sequence[str] | None | int | float = None
class MetadataFilteringCondition(BaseModel):
"""
Metadata Filtering Condition.
"""
logical_operator: Literal["and", "or"] | None = "and"
conditions: list[Condition] | None = Field(default=None, deprecated=True)
class DatasetRetrieveConfigEntity(BaseModel):
"""
Dataset Retrieve Config Entity.

View File

@@ -107,13 +107,13 @@ class AppGenerateResponseConverter(ABC):
return metadata
@classmethod
def _error_to_stream_response(cls, e: Exception) -> dict[str, Any]:
def _error_to_stream_response(cls, e: Exception):
"""
Error to stream response.
:param e: exception
:return:
"""
error_responses: dict[type[Exception], dict[str, Any]] = {
error_responses = {
ValueError: {"code": "invalid_param", "status": 400},
ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400},
QuotaExceededError: {
@@ -127,7 +127,7 @@ class AppGenerateResponseConverter(ABC):
}
# Determine the response based on the type of exception
data: dict[str, Any] | None = None
data = None
for k, v in error_responses.items():
if isinstance(e, k):
data = v

View File

@@ -66,7 +66,7 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.rag.entities import RetrievalSourceMetadata
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
from core.workflow.system_variables import (
build_bootstrap_variables,

View File

@@ -10,7 +10,7 @@ from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChun
from pydantic import BaseModel, ConfigDict, Field
from core.app.entities.agent_strategy import AgentStrategyInfo
from core.rag.entities import RetrievalSourceMetadata
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
class QueueEvent(StrEnum):

View File

@@ -9,7 +9,7 @@ from graphon.nodes.human_input.entities import FormInput, UserAction
from pydantic import BaseModel, ConfigDict, Field
from core.app.entities.agent_strategy import AgentStrategyInfo
from core.rag.entities import RetrievalSourceMetadata
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
class AnnotationReplyAccount(BaseModel):

View File

@@ -509,8 +509,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
:return:
"""
with Session(db.engine, expire_on_commit=False) as session:
agent_thought: MessageAgentThought | None = session.scalar(
select(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).limit(1)
agent_thought: MessageAgentThought | None = (
session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
)
if agent_thought:

View File

@@ -6,7 +6,7 @@ from sqlalchemy import select, update
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
from core.rag.entities import RetrievalSourceMetadata
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.models.document import Document
from extensions.ext_database import db

View File

@@ -345,8 +345,8 @@ class DatasourceManager:
@classmethod
def get_upload_file_by_id(cls, file_id: str, tenant_id: str) -> File:
with session_factory.create_session() as session:
upload_file = session.scalar(
select(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id).limit(1)
upload_file = (
session.query(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id).first()
)
if not upload_file:
raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")

View File

@@ -1,3 +1,22 @@
from core.tools.entities.common_entities import I18nObject, I18nObjectDict
from pydantic import BaseModel, Field, model_validator
__all__ = ["I18nObject", "I18nObjectDict"]
class I18nObject(BaseModel):
"""
Model class for i18n object.
"""
en_US: str
zh_Hans: str | None = Field(default=None)
pt_BR: str | None = Field(default=None)
ja_JP: str | None = Field(default=None)
@model_validator(mode="after")
def _(self):
self.zh_Hans = self.zh_Hans or self.en_US
self.pt_BR = self.pt_BR or self.en_US
self.ja_JP = self.ja_JP or self.en_US
return self
def to_dict(self) -> dict:
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}

View File

@@ -9,7 +9,7 @@ from yarl import URL
from configs import dify_config
from core.entities.provider_entities import ProviderConfig
from core.plugin.entities import OAuthSchema
from core.plugin.entities.oauth import OAuthSchema
from core.plugin.entities.parameters import (
PluginParameter,
PluginParameterOption,

View File

@@ -1,8 +1 @@
from core.entities.plugin_credential_type import PluginCredentialType
DEFAULT_PLUGIN_ID = "langgenius"
__all__ = [
"DEFAULT_PLUGIN_ID",
"PluginCredentialType",
]

View File

@@ -1,9 +0,0 @@
import enum
class PluginCredentialType(enum.Enum):
MODEL = 0 # must be 0 for API contract compatibility
TOOL = 1 # must be 1 for API contract compatibility
def to_number(self):
return self.value

View File

@@ -22,7 +22,6 @@ from sqlalchemy import func, select
from sqlalchemy.orm import Session
from constants import HIDDEN_VALUE
from core.entities import PluginCredentialType
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
from core.entities.provider_entities import (
CustomConfiguration,
@@ -47,6 +46,7 @@ from models.provider import (
TenantPreferredModelProvider,
)
from models.provider_ids import ModelProviderID
from services.enterprise.plugin_manager_service import PluginCredentialType
logger = logging.getLogger(__name__)

View File

@@ -2,7 +2,7 @@
Credential utility functions for checking credential existence and policy compliance.
"""
from core.entities import PluginCredentialType
from services.enterprise.plugin_manager_service import PluginCredentialType
def is_credential_exists(credential_id: str, credential_type: "PluginCredentialType") -> bool:

View File

@@ -2,7 +2,7 @@ import json
import logging
import re
from collections.abc import Sequence
from typing import Protocol, TypedDict, cast
from typing import Protocol, cast
import json_repair
from graphon.enums import WorkflowNodeExecutionMetadataKey
@@ -49,17 +49,6 @@ class WorkflowServiceInterface(Protocol):
pass
class CodeGenerateResultDict(TypedDict):
code: str
language: str
error: str
class StructuredOutputResultDict(TypedDict):
output: str
error: str
class LLMGenerator:
@classmethod
def generate_conversation_name(
@@ -304,7 +293,7 @@ class LLMGenerator:
cls,
tenant_id: str,
args: RuleCodeGeneratePayload,
) -> CodeGenerateResultDict:
):
if args.code_language == "python":
prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE)
else:
@@ -373,9 +362,7 @@ class LLMGenerator:
return answer.strip()
@classmethod
def generate_structured_output(
cls, tenant_id: str, args: RuleStructuredOutputPayload
) -> StructuredOutputResultDict:
def generate_structured_output(cls, tenant_id: str, args: RuleStructuredOutputPayload):
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
@@ -467,7 +454,7 @@ class LLMGenerator:
):
session = db.session()
app: App | None = session.scalar(select(App).where(App.id == flow_id).limit(1))
app: App | None = session.query(App).where(App.id == flow_id).first()
if not app:
raise ValueError("App not found.")
workflow = workflow_service.get_draft_workflow(app_model=app)

View File

@@ -6,7 +6,6 @@ import logging
import flask
from core.logging.context import get_request_id, get_trace_id
from core.logging.structured_formatter import IdentityDict
class TraceContextFilter(logging.Filter):
@@ -61,7 +60,7 @@ class IdentityContextFilter(logging.Filter):
record.user_type = identity.get("user_type", "")
return True
def _extract_identity(self) -> IdentityDict:
def _extract_identity(self) -> dict[str, str]:
"""Extract identity from current_user if in request context."""
try:
if not flask.has_request_context():
@@ -78,7 +77,7 @@ class IdentityContextFilter(logging.Filter):
from models import Account
from models.model import EndUser
identity: IdentityDict = {}
identity: dict[str, str] = {}
if isinstance(user, Account):
if user.current_tenant_id:

View File

@@ -3,19 +3,13 @@
import logging
import traceback
from datetime import UTC, datetime
from typing import Any, TypedDict
from typing import Any
import orjson
from configs import dify_config
class IdentityDict(TypedDict, total=False):
tenant_id: str
user_id: str
user_type: str
class StructuredJSONFormatter(logging.Formatter):
"""
JSON log formatter following the specified schema:
@@ -90,7 +84,7 @@ class StructuredJSONFormatter(logging.Formatter):
return log_dict
def _extract_identity(self, record: logging.LogRecord) -> IdentityDict | None:
def _extract_identity(self, record: logging.LogRecord) -> dict[str, str] | None:
tenant_id = getattr(record, "tenant_id", None)
user_id = getattr(record, "user_id", None)
user_type = getattr(record, "user_type", None)
@@ -98,7 +92,7 @@ class StructuredJSONFormatter(logging.Formatter):
if not any([tenant_id, user_id, user_type]):
return None
identity: IdentityDict = {}
identity: dict[str, str] = {}
if tenant_id:
identity["tenant_id"] = tenant_id
if user_id:

View File

@@ -1,7 +1,7 @@
import json
import logging
from collections.abc import Mapping
from typing import Any, NotRequired, TypedDict, cast
from typing import Any, cast
from graphon.variables.input_entities import VariableEntity, VariableEntityType
@@ -15,17 +15,6 @@ from services.app_generate_service import AppGenerateService
logger = logging.getLogger(__name__)
class ToolParameterSchemaDict(TypedDict):
type: str
properties: dict[str, Any]
required: list[str]
class ToolArgumentsDict(TypedDict):
query: NotRequired[str]
inputs: dict[str, Any]
def handle_mcp_request(
app: App,
request: mcp_types.ClientRequest,
@@ -130,7 +119,7 @@ def handle_list_tools(
mcp_types.Tool(
name=app_name,
description=description,
inputSchema=cast(dict[str, Any], parameter_schema),
inputSchema=parameter_schema,
)
],
)
@@ -165,7 +154,7 @@ def build_parameter_schema(
app_mode: str,
user_input_form: list[VariableEntity],
parameters_dict: dict[str, str],
) -> ToolParameterSchemaDict:
) -> dict[str, Any]:
"""Build parameter schema for the tool"""
parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict)
@@ -185,7 +174,7 @@ def build_parameter_schema(
}
def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> ToolArgumentsDict:
def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> dict[str, Any]:
"""Prepare arguments based on app mode"""
if app.mode == AppMode.WORKFLOW:
return {"inputs": arguments}

View File

@@ -4,7 +4,7 @@ from collections.abc import Callable
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
from datetime import timedelta
from types import TracebackType
from typing import Any, Self
from typing import Any, Self, cast
from httpx import HTTPStatusError
from pydantic import BaseModel
@@ -338,11 +338,12 @@ class BaseSession[
validated_request = self._receive_request_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
validated_request = cast(ReceiveRequestT, validated_request)
responder = RequestResponder[ReceiveRequestT, SendResultT](
request_id=message.message.root.id,
request_meta=validated_request.root.params.meta if validated_request.root.params else None,
request=validated_request, # type: ignore[arg-type] # mypy can't narrow constrained TypeVar from model_validate
request=validated_request,
session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
)
@@ -358,14 +359,15 @@ class BaseSession[
notification = self._receive_notification_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
notification = cast(ReceiveNotificationT, notification)
# Handle cancellation notifications
if isinstance(notification.root, CancelledNotification):
cancelled_id = notification.root.params.requestId
if cancelled_id in self._in_flight:
self._in_flight[cancelled_id].cancel()
else:
self._received_notification(notification) # type: ignore[arg-type]
self._handle_incoming(notification) # type: ignore[arg-type]
self._received_notification(notification)
self._handle_incoming(notification)
except Exception as e:
# For other validation errors, log and continue
logger.warning("Failed to validate notification: %s. Message was: %s", e, message.message.root)

View File

@@ -17,7 +17,6 @@ from graphon.model_runtime.model_providers.__base.text_embedding_model import Te
from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
from configs import dify_config
from core.entities import PluginCredentialType
from core.entities.embedding_type import EmbeddingInputType
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import ModelLoadBalancingConfiguration
@@ -26,6 +25,7 @@ from core.plugin.impl.model_runtime_factory import create_plugin_provider_manage
from core.provider_manager import ProviderManager
from extensions.ext_redis import redis_client
from models.provider import ProviderType
from services.enterprise.plugin_manager_service import PluginCredentialType
logger = logging.getLogger(__name__)

View File

@@ -1,6 +1,6 @@
import json
from collections.abc import Mapping
from typing import Any, TypedDict
from typing import Any
from graphon.entities import WorkflowNodeExecution
from graphon.enums import WorkflowNodeExecutionStatus
@@ -56,22 +56,10 @@ def create_links_from_trace_id(trace_id: str | None) -> list[Link]:
return links
class RetrievalDocumentMetadataDict(TypedDict):
dataset_id: Any
doc_id: Any
document_id: Any
class RetrievalDocumentDict(TypedDict):
content: str
metadata: RetrievalDocumentMetadataDict
score: Any
def extract_retrieval_documents(documents: list[Document]) -> list[RetrievalDocumentDict]:
documents_data: list[RetrievalDocumentDict] = []
def extract_retrieval_documents(documents: list[Document]) -> list[dict[str, Any]]:
documents_data = []
for document in documents:
document_data: RetrievalDocumentDict = {
document_data = {
"content": document.page_content,
"metadata": {
"dataset_id": document.metadata.get("dataset_id"),
@@ -95,7 +83,7 @@ def create_common_span_attributes(
framework: str = DEFAULT_FRAMEWORK_NAME,
inputs: str = "",
outputs: str = "",
) -> dict[str, str]:
) -> dict[str, Any]:
return {
GEN_AI_SESSION_ID: session_id,
GEN_AI_USER_ID: user_id,

View File

@@ -56,10 +56,8 @@ class BaseTraceInstance(ABC):
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
current_tenant = session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.account_id == service_account.id, TenantAccountJoin.current.is_(True))
.limit(1)
current_tenant = (
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first()
)
if not current_tenant:
raise ValueError(f"Current tenant not found for account {service_account.id}")

View File

@@ -241,10 +241,8 @@ class TencentDataTrace(BaseTraceInstance):
if not service_account:
raise ValueError(f"Creator account not found for app {app_id}")
current_tenant = session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.account_id == service_account.id, TenantAccountJoin.current.is_(True))
.limit(1)
current_tenant = (
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first()
)
if not current_tenant:
raise ValueError(f"Current tenant not found for account {service_account.id}")

View File

@@ -1,5 +0,0 @@
from core.plugin.entities.oauth import OAuthSchema
__all__ = [
"OAuthSchema",
]

View File

@@ -1,3 +1,5 @@
from collections.abc import Sequence
from pydantic import BaseModel, Field
from core.entities.provider_entities import ProviderConfig
@@ -8,12 +10,12 @@ class OAuthSchema(BaseModel):
OAuth schema
"""
client_schema: list[ProviderConfig] = Field(
client_schema: Sequence[ProviderConfig] = Field(
default_factory=list,
description="client schema like client_id, client_secret, etc.",
)
credentials_schema: list[ProviderConfig] = Field(
credentials_schema: Sequence[ProviderConfig] = Field(
default_factory=list,
description="credentials schema like access_token, refresh_token, etc.",
)

View File

@@ -209,10 +209,7 @@ class PluginInstaller(BasePluginClient):
"GET",
f"plugin/{tenant_id}/management/decode/from_identifier",
PluginDecodeResponse,
params={
"plugin_unique_identifier": plugin_unique_identifier,
"PluginUniqueIdentifier": plugin_unique_identifier, # compat with daemon <= 0.5.4
},
params={"plugin_unique_identifier": plugin_unique_identifier},
)
def fetch_plugin_installation_by_ids(

View File

@@ -1,10 +1,11 @@
from __future__ import annotations
import contextlib
import json
from collections import defaultdict
from collections.abc import Sequence
from json import JSONDecodeError
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.entities.provider_entities import (
@@ -14,7 +15,6 @@ from graphon.model_runtime.entities.provider_entities import (
ProviderEntity,
)
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from pydantic import TypeAdapter
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
@@ -58,8 +58,6 @@ from services.feature_service import FeatureService
if TYPE_CHECKING:
from graphon.model_runtime.runtime import ModelRuntime
_credentials_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any])
class ProviderManager:
"""
@@ -877,8 +875,8 @@ class ProviderManager:
return {"openai_api_key": encrypted_config}
try:
credentials = _credentials_adapter.validate_json(encrypted_config)
except (ValueError, JSONDecodeError):
credentials = cast(dict, json.loads(encrypted_config))
except JSONDecodeError:
return {}
# Decrypt secret variables
@@ -1017,7 +1015,7 @@ class ProviderManager:
if not cached_provider_credentials:
provider_credentials: dict[str, Any] = {}
if provider_records and provider_records[0].encrypted_config:
provider_credentials = _credentials_adapter.validate_json(provider_records[0].encrypted_config)
provider_credentials = json.loads(provider_records[0].encrypted_config)
# Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables(
@@ -1164,10 +1162,8 @@ class ProviderManager:
if not cached_provider_model_credentials:
try:
provider_model_credentials = _credentials_adapter.validate_json(
load_balancing_model_config.encrypted_config
)
except (ValueError, JSONDecodeError):
provider_model_credentials = json.loads(load_balancing_model_config.encrypted_config)
except JSONDecodeError:
continue
# Get decoding rsa key and cipher for decrypting credentials
@@ -1180,7 +1176,7 @@ class ProviderManager:
if variable in provider_model_credentials:
try:
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_model_credentials.get(variable) or "",
provider_model_credentials.get(variable),
self.decoding_rsa_key,
self.decoding_cipher_rsa,
)

View File

@@ -15,7 +15,7 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor,
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.embedding.retrieval import AttachmentInfoDict, RetrievalChildChunk, RetrievalSegments
from core.rag.entities import MetadataFilteringCondition
from core.rag.entities.metadata_entities import MetadataCondition
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.query_type import QueryType
@@ -182,9 +182,7 @@ class RetrievalService:
if not dataset:
return []
metadata_condition = (
MetadataFilteringCondition.model_validate(metadata_filtering_conditions)
if metadata_filtering_conditions
else None
MetadataCondition.model_validate(metadata_filtering_conditions) if metadata_filtering_conditions else None
)
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
dataset.tenant_id,
@@ -242,7 +240,7 @@ class RetrievalService:
@classmethod
def _get_dataset(cls, dataset_id: str) -> Dataset | None:
with Session(db.engine) as session:
return session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
return session.query(Dataset).where(Dataset.id == dataset_id).first()
@classmethod
def keyword_search(
@@ -575,13 +573,15 @@ class RetrievalService:
# Batch query summaries for segments retrieved via summary (only enabled summaries)
if summary_segment_ids:
summaries = session.scalars(
select(DocumentSegmentSummary).where(
summaries = (
session.query(DocumentSegmentSummary)
.filter(
DocumentSegmentSummary.chunk_id.in_(list(summary_segment_ids)),
DocumentSegmentSummary.status == "completed",
DocumentSegmentSummary.enabled.is_(True), # Only retrieve enabled summaries
DocumentSegmentSummary.enabled == True, # Only retrieve enabled summaries
)
).all()
.all()
)
for summary in summaries:
if summary.summary_content:
segment_summary_map[summary.chunk_id] = summary.summary_content
@@ -851,12 +851,12 @@ class RetrievalService:
def get_segment_attachment_info(
cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session
) -> SegmentAttachmentResult | None:
upload_file = session.scalar(select(UploadFile).where(UploadFile.id == attachment_id).limit(1))
upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first()
if upload_file:
attachment_binding = session.scalar(
select(SegmentAttachmentBinding)
attachment_binding = (
session.query(SegmentAttachmentBinding)
.where(SegmentAttachmentBinding.attachment_id == upload_file.id)
.limit(1)
.first()
)
if attachment_binding:
attachment_info: AttachmentInfoDict = {
@@ -875,12 +875,14 @@ class RetrievalService:
cls, attachment_ids: list[str], session: Session
) -> list[SegmentAttachmentInfoResult]:
attachment_infos: list[SegmentAttachmentInfoResult] = []
upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(attachment_ids))).all()
upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
if upload_files:
upload_file_ids = [upload_file.id for upload_file in upload_files]
attachment_bindings = session.scalars(
select(SegmentAttachmentBinding).where(SegmentAttachmentBinding.attachment_id.in_(upload_file_ids))
).all()
attachment_bindings = (
session.query(SegmentAttachmentBinding)
.where(SegmentAttachmentBinding.attachment_id.in_(upload_file_ids))
.all()
)
attachment_binding_map = {binding.attachment_id: binding for binding in attachment_bindings}
if attachment_bindings:

View File

@@ -1,5 +1,5 @@
import json
from typing import Any, TypedDict
from typing import Any
from pydantic import BaseModel, model_validator
@@ -13,13 +13,6 @@ from core.rag.models.document import Document
from extensions.ext_redis import redis_client
class AnalyticdbClientParamsDict(TypedDict):
access_key_id: str
access_key_secret: str
region_id: str
read_timeout: int
class AnalyticdbVectorOpenAPIConfig(BaseModel):
access_key_id: str
access_key_secret: str
@@ -51,14 +44,13 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel):
raise ValueError("config ANALYTICDB_NAMESPACE_PASSWORD is required")
return values
def to_analyticdb_client_params(self) -> AnalyticdbClientParamsDict:
result: AnalyticdbClientParamsDict = {
def to_analyticdb_client_params(self):
return {
"access_key_id": self.access_key_id,
"access_key_secret": self.access_key_secret,
"region_id": self.region_id,
"read_timeout": self.read_timeout,
}
return result
class AnalyticdbVectorOpenAPI:

View File

@@ -30,7 +30,7 @@ from pymochow.model.table import AnnSearch, BM25SearchRequest, HNSWSearchParams,
from configs import dify_config
from core.rag.datasource.vdb.field import Field as VDBField
from core.rag.datasource.vdb.field import parse_metadata_json
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
@@ -85,12 +85,8 @@ class BaiduVector(BaseVector):
def get_type(self) -> str:
return VectorType.BAIDU
def to_index_struct(self) -> VectorIndexStructDict:
result: VectorIndexStructDict = {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name},
}
return result
def to_index_struct(self):
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self._create_table(len(embeddings[0]))

View File

@@ -1,12 +1,12 @@
import json
from typing import Any, TypedDict
from typing import Any
import chromadb
from chromadb import QueryResult, Settings
from pydantic import BaseModel
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
@@ -15,15 +15,6 @@ from extensions.ext_redis import redis_client
from models.dataset import Dataset
class ChromaParamsDict(TypedDict):
host: str
port: int
ssl: bool
tenant: str
database: str
settings: Settings
class ChromaConfig(BaseModel):
host: str
port: int
@@ -32,13 +23,14 @@ class ChromaConfig(BaseModel):
auth_provider: str | None = None
auth_credentials: str | None = None
def to_chroma_params(self) -> ChromaParamsDict:
def to_chroma_params(self):
settings = Settings(
# auth
chroma_client_auth_provider=self.auth_provider,
chroma_client_auth_credentials=self.auth_credentials,
)
result: ChromaParamsDict = {
return {
"host": self.host,
"port": self.port,
"ssl": False,
@@ -46,7 +38,6 @@ class ChromaConfig(BaseModel):
"database": self.database,
"settings": settings,
}
return result
class ChromaVector(BaseVector):
@@ -154,10 +145,7 @@ class ChromaVectorFactory(AbstractVectorFactory):
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
index_struct_dict: VectorIndexStructDict = {
"type": VectorType.CHROMA,
"vector_store": {"class_prefix": collection_name},
}
index_struct_dict = {"type": VectorType.CHROMA, "vector_store": {"class_prefix": collection_name}}
dataset.index_struct = json.dumps(index_struct_dict)
return ChromaVector(

View File

@@ -1,6 +1,6 @@
import json
import logging
from typing import Any, TypedDict
from typing import Any
from packaging import version
from pydantic import BaseModel, model_validator
@@ -20,15 +20,6 @@ from models.dataset import Dataset
logger = logging.getLogger(__name__)
class MilvusParamsDict(TypedDict):
uri: str
token: str | None
user: str | None
password: str | None
db_name: str
analyzer_params: str | None
class MilvusConfig(BaseModel):
"""
Configuration class for Milvus connection.
@@ -59,11 +50,11 @@ class MilvusConfig(BaseModel):
raise ValueError("config MILVUS_PASSWORD is required")
return values
def to_milvus_params(self) -> MilvusParamsDict:
def to_milvus_params(self):
"""
Convert the configuration to a dictionary of Milvus connection parameters.
"""
result: MilvusParamsDict = {
return {
"uri": self.uri,
"token": self.token,
"user": self.user,
@@ -71,7 +62,6 @@ class MilvusConfig(BaseModel):
"db_name": self.database,
"analyzer_params": self.analyzer_params,
}
return result
class MilvusVector(BaseVector):
@@ -362,7 +352,6 @@ class MilvusVector(BaseVector):
# Create Index params for the collection
index_params_obj = IndexParams()
assert index_params is not None
index_params_obj.add_index(field_name=Field.VECTOR, **index_params)
# Create Sparse Vector Index for the collection

View File

@@ -22,7 +22,7 @@ from sqlalchemy import select
from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
@@ -94,12 +94,8 @@ class QdrantVector(BaseVector):
def get_type(self) -> str:
return VectorType.QDRANT
def to_index_struct(self) -> VectorIndexStructDict:
result: VectorIndexStructDict = {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name},
}
return result
def to_index_struct(self):
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
if texts:

View File

@@ -1,7 +1,7 @@
import json
import logging
import math
from typing import Any, TypedDict
from typing import Any
from pydantic import BaseModel
from tcvdb_text.encoder import BM25Encoder # type: ignore
@@ -12,7 +12,7 @@ from tcvectordb.model.document import AnnSearch, Filter, KeywordSearch, Weighted
from configs import dify_config
from core.rag.datasource.vdb.field import parse_metadata_json
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
@@ -23,13 +23,6 @@ from models.dataset import Dataset
logger = logging.getLogger(__name__)
class TencentParamsDict(TypedDict):
url: str
username: str | None
key: str | None
timeout: float
class TencentConfig(BaseModel):
url: str
api_key: str | None = None
@@ -43,14 +36,8 @@ class TencentConfig(BaseModel):
max_upsert_batch_size: int = 128
enable_hybrid_search: bool = False # Flag to enable hybrid search
def to_tencent_params(self) -> TencentParamsDict:
result: TencentParamsDict = {
"url": self.url,
"username": self.username,
"key": self.api_key,
"timeout": self.timeout,
}
return result
def to_tencent_params(self):
return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout}
bm25 = BM25Encoder.default("zh")
@@ -96,12 +83,8 @@ class TencentVector(BaseVector):
def get_type(self) -> str:
return VectorType.TENCENT
def to_index_struct(self) -> VectorIndexStructDict:
result: VectorIndexStructDict = {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name},
}
return result
def to_index_struct(self):
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def _has_collection(self) -> bool:
return bool(

View File

@@ -25,7 +25,7 @@ from sqlalchemy import select
from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
@@ -91,12 +91,8 @@ class TidbOnQdrantVector(BaseVector):
def get_type(self) -> str:
return VectorType.TIDB_ON_QDRANT
def to_index_struct(self) -> VectorIndexStructDict:
result: VectorIndexStructDict = {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name},
}
return result
def to_index_struct(self):
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
if texts:

View File

@@ -1,20 +1,11 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, TypedDict
from typing import Any
from core.rag.models.document import Document
class VectorStoreDict(TypedDict):
class_prefix: str
class VectorIndexStructDict(TypedDict):
type: str
vector_store: VectorStoreDict
class BaseVector(ABC):
def __init__(self, collection_name: str):
self._collection_name = collection_name

View File

@@ -9,7 +9,7 @@ from sqlalchemy import select
from configs import dify_config
from core.model_manager import ModelManager
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.cached_embedding import CacheEmbedding
from core.rag.embedding.embedding_base import Embeddings
@@ -30,11 +30,8 @@ class AbstractVectorFactory(ABC):
raise NotImplementedError
@staticmethod
def gen_index_struct_dict(vector_type: VectorType, collection_name: str) -> VectorIndexStructDict:
index_struct_dict: VectorIndexStructDict = {
"type": vector_type,
"vector_store": {"class_prefix": collection_name},
}
def gen_index_struct_dict(vector_type: VectorType, collection_name: str):
index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}}
return index_struct_dict

View File

@@ -24,7 +24,7 @@ from weaviate.exceptions import UnexpectedStatusCodeError
from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
@@ -184,13 +184,9 @@ class WeaviateVector(BaseVector):
dataset_id = dataset.id
return Dataset.gen_collection_name_by_id(dataset_id)
def to_index_struct(self) -> VectorIndexStructDict:
def to_index_struct(self) -> dict:
"""Returns the index structure dictionary for persistence."""
result: VectorIndexStructDict = {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name},
}
return result
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
"""

View File

@@ -1,28 +0,0 @@
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
from core.rag.entities.event import DatasourceCompletedEvent, DatasourceErrorEvent, DatasourceProcessingEvent
from core.rag.entities.index_entities import EconomySetting, EmbeddingSetting, IndexMethod
from core.rag.entities.metadata_entities import Condition, MetadataFilteringCondition, SupportedComparisonOperator
from core.rag.entities.processing_entities import ParentMode, PreProcessingRule, Rule, Segmentation
from core.rag.entities.retrieval_settings import KeywordSetting, VectorSetting, WeightedScoreConfig
__all__ = [
"Condition",
"DatasourceCompletedEvent",
"DatasourceErrorEvent",
"DatasourceProcessingEvent",
"DocumentContext",
"EconomySetting",
"EmbeddingSetting",
"IndexMethod",
"KeywordSetting",
"MetadataFilteringCondition",
"ParentMode",
"PreProcessingRule",
"RetrievalSourceMetadata",
"Rule",
"Segmentation",
"SupportedComparisonOperator",
"VectorSetting",
"WeightedScoreConfig",
]

View File

@@ -1,30 +0,0 @@
from typing import Literal
from pydantic import BaseModel
class EmbeddingSetting(BaseModel):
"""
Embedding Setting.
"""
embedding_provider_name: str
embedding_model_name: str
class EconomySetting(BaseModel):
"""
Economy Setting.
"""
keyword_number: int
class IndexMethod(BaseModel):
"""
Knowledge Index Setting.
"""
indexing_technique: Literal["high_quality", "economy"]
embedding_setting: EmbeddingSetting
economy_setting: EconomySetting

View File

@@ -38,9 +38,9 @@ class Condition(BaseModel):
value: str | Sequence[str] | None | int | float = None
class MetadataFilteringCondition(BaseModel):
class MetadataCondition(BaseModel):
"""
Metadata Filtering Condition.
Metadata Condition.
"""
logical_operator: Literal["and", "or"] | None = "and"

View File

@@ -1,27 +0,0 @@
from enum import StrEnum
from typing import Literal
from pydantic import BaseModel
class ParentMode(StrEnum):
FULL_DOC = "full-doc"
PARAGRAPH = "paragraph"
class PreProcessingRule(BaseModel):
id: str
enabled: bool
class Segmentation(BaseModel):
separator: str = "\n"
max_tokens: int
chunk_overlap: int = 0
class Rule(BaseModel):
pre_processing_rules: list[PreProcessingRule] | None = None
segmentation: Segmentation | None = None
parent_mode: Literal["full-doc", "paragraph"] | None = None
subchunk_segmentation: Segmentation | None = None

View File

@@ -1,28 +0,0 @@
from pydantic import BaseModel
class VectorSetting(BaseModel):
"""
Vector Setting.
"""
vector_weight: float
embedding_provider_name: str
embedding_model_name: str
class KeywordSetting(BaseModel):
"""
Keyword Setting.
"""
keyword_weight: float
class WeightedScoreConfig(BaseModel):
"""
Weighted score Config.
"""
vector_setting: VectorSetting
keyword_setting: KeywordSetting

View File

@@ -12,7 +12,7 @@ from core.db.session_factory import session_factory
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError
from core.workflow.nodes.knowledge_index.protocols import IndexingResultDict, Preview, PreviewItem, QaPreview
from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview
from models.dataset import Dataset, Document, DocumentSegment
from .index_processor_factory import IndexProcessorFactory
@@ -61,7 +61,7 @@ class IndexProcessor:
chunks: Mapping[str, Any],
batch: Any,
summary_index_setting: SummaryIndexSettingDict | None = None,
) -> IndexingResultDict:
):
with session_factory.create_session() as session:
document = session.query(Document).filter_by(id=document_id).first()
if not document:
@@ -129,7 +129,7 @@ class IndexProcessor:
}
)
result: IndexingResultDict = {
return {
"dataset_id": dataset_id,
"dataset_name": dataset_name_value,
"batch": batch,
@@ -138,7 +138,6 @@ class IndexProcessor:
"created_at": created_at_value.timestamp(),
"display_status": "completed",
}
return result
def get_preview_output(
self,

View File

@@ -32,7 +32,6 @@ from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.entities import Rule
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.doc_type import DocType
@@ -50,6 +49,7 @@ from models.account import Account
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import Rule
from services.summary_index_service import SummaryIndexService
_file_access_controller = DatabaseFileAccessController()

View File

@@ -17,7 +17,6 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.entities import ParentMode, Rule
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.doc_type import DocType
@@ -31,6 +30,7 @@ from models import Account
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
from services.summary_index_service import SummaryIndexService
logger = logging.getLogger(__name__)

View File

@@ -19,7 +19,6 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.entities import Rule
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
@@ -31,6 +30,7 @@ from libs import helper
from models.account import Account
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import Rule
from services.summary_index_service import SummaryIndexService
logger = logging.getLogger(__name__)

View File

@@ -1,6 +1,16 @@
from pydantic import BaseModel
from core.rag.entities import KeywordSetting, VectorSetting
class VectorSetting(BaseModel):
vector_weight: float
embedding_provider_name: str
embedding_model_name: str
class KeywordSetting(BaseModel):
keyword_weight: float
class Weights(BaseModel):

Some files were not shown because too many files have changed in this diff Show More