mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 20:22:39 +08:00
Compare commits
5 Commits
refactor/r
...
deploy/tri
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7dfe615613 | ||
|
|
a1a3fa0283 | ||
|
|
ff7344f3d3 | ||
|
|
bcd33be22a | ||
|
|
991f31f195 |
@@ -1,9 +0,0 @@
|
||||
{
|
||||
"enabledPlugins": {
|
||||
"feature-dev@claude-plugins-official": true,
|
||||
"context7@claude-plugins-official": true,
|
||||
"typescript-lsp@claude-plugins-official": true,
|
||||
"pyright-lsp@claude-plugins-official": true,
|
||||
"ralph-wiggum@claude-plugins-official": true
|
||||
}
|
||||
}
|
||||
19
.claude/settings.json.example
Normal file
19
.claude/settings.json.example
Normal file
@@ -0,0 +1,19 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allow": [],
|
||||
"deny": []
|
||||
},
|
||||
"env": {
|
||||
"__comment": "Environment variables for MCP servers. Override in .claude/settings.local.json with actual values.",
|
||||
"GITHUB_PERSONAL_ACCESS_TOKEN": "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
||||
},
|
||||
"enabledMcpjsonServers": [
|
||||
"context7",
|
||||
"sequential-thinking",
|
||||
"github",
|
||||
"fetch",
|
||||
"playwright",
|
||||
"ide"
|
||||
],
|
||||
"enableAllProjectMcpServers": true
|
||||
}
|
||||
@@ -1,483 +0,0 @@
|
||||
---
|
||||
name: component-refactoring
|
||||
description: Refactor high-complexity React components in Dify frontend. Use when `pnpm analyze-component --json` shows complexity > 50 or lineCount > 300, when the user asks for code splitting, hook extraction, or complexity reduction, or when `pnpm analyze-component` warns to refactor before testing; avoid for simple/well-structured components, third-party wrappers, or when the user explicitly wants testing without refactoring.
|
||||
---
|
||||
|
||||
# Dify Component Refactoring Skill
|
||||
|
||||
Refactor high-complexity React components in the Dify frontend codebase with the patterns and workflow below.
|
||||
|
||||
> **Complexity Threshold**: Components with complexity > 50 (measured by `pnpm analyze-component`) should be refactored before testing.
|
||||
|
||||
## Quick Reference
|
||||
|
||||
### Commands (run from `web/`)
|
||||
|
||||
Use paths relative to `web/` (e.g., `app/components/...`).
|
||||
Use `refactor-component` for refactoring prompts and `analyze-component` for testing prompts and metrics.
|
||||
|
||||
```bash
|
||||
cd web
|
||||
|
||||
# Generate refactoring prompt
|
||||
pnpm refactor-component <path>
|
||||
|
||||
# Output refactoring analysis as JSON
|
||||
pnpm refactor-component <path> --json
|
||||
|
||||
# Generate testing prompt (after refactoring)
|
||||
pnpm analyze-component <path>
|
||||
|
||||
# Output testing analysis as JSON
|
||||
pnpm analyze-component <path> --json
|
||||
```
|
||||
|
||||
### Complexity Analysis
|
||||
|
||||
```bash
|
||||
# Analyze component complexity
|
||||
pnpm analyze-component <path> --json
|
||||
|
||||
# Key metrics to check:
|
||||
# - complexity: normalized score 0-100 (target < 50)
|
||||
# - maxComplexity: highest single function complexity
|
||||
# - lineCount: total lines (target < 300)
|
||||
```
|
||||
|
||||
### Complexity Score Interpretation
|
||||
|
||||
| Score | Level | Action |
|
||||
|-------|-------|--------|
|
||||
| 0-25 | 🟢 Simple | Ready for testing |
|
||||
| 26-50 | 🟡 Medium | Consider minor refactoring |
|
||||
| 51-75 | 🟠 Complex | **Refactor before testing** |
|
||||
| 76-100 | 🔴 Very Complex | **Must refactor** |
|
||||
|
||||
## Core Refactoring Patterns
|
||||
|
||||
### Pattern 1: Extract Custom Hooks
|
||||
|
||||
**When**: Component has complex state management, multiple `useState`/`useEffect`, or business logic mixed with UI.
|
||||
|
||||
**Dify Convention**: Place hooks in a `hooks/` subdirectory or alongside the component as `use-<feature>.ts`.
|
||||
|
||||
```typescript
|
||||
// ❌ Before: Complex state logic in component
|
||||
const Configuration: FC = () => {
|
||||
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
|
||||
const [datasetConfigs, setDatasetConfigs] = useState<DatasetConfigs>(...)
|
||||
const [completionParams, setCompletionParams] = useState<FormValue>({})
|
||||
|
||||
// 50+ lines of state management logic...
|
||||
|
||||
return <div>...</div>
|
||||
}
|
||||
|
||||
// ✅ After: Extract to custom hook
|
||||
// hooks/use-model-config.ts
|
||||
export const useModelConfig = (appId: string) => {
|
||||
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
|
||||
const [completionParams, setCompletionParams] = useState<FormValue>({})
|
||||
|
||||
// Related state management logic here
|
||||
|
||||
return { modelConfig, setModelConfig, completionParams, setCompletionParams }
|
||||
}
|
||||
|
||||
// Component becomes cleaner
|
||||
const Configuration: FC = () => {
|
||||
const { modelConfig, setModelConfig } = useModelConfig(appId)
|
||||
return <div>...</div>
|
||||
}
|
||||
```
|
||||
|
||||
**Dify Examples**:
|
||||
- `web/app/components/app/configuration/hooks/use-advanced-prompt-config.ts`
|
||||
- `web/app/components/app/configuration/debug/hooks.tsx`
|
||||
- `web/app/components/workflow/hooks/use-workflow.ts`
|
||||
|
||||
### Pattern 2: Extract Sub-Components
|
||||
|
||||
**When**: Single component has multiple UI sections, conditional rendering blocks, or repeated patterns.
|
||||
|
||||
**Dify Convention**: Place sub-components in subdirectories or as separate files in the same directory.
|
||||
|
||||
```typescript
|
||||
// ❌ Before: Monolithic JSX with multiple sections
|
||||
const AppInfo = () => {
|
||||
return (
|
||||
<div>
|
||||
{/* 100 lines of header UI */}
|
||||
{/* 100 lines of operations UI */}
|
||||
{/* 100 lines of modals */}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// ✅ After: Split into focused components
|
||||
// app-info/
|
||||
// ├── index.tsx (orchestration only)
|
||||
// ├── app-header.tsx (header UI)
|
||||
// ├── app-operations.tsx (operations UI)
|
||||
// └── app-modals.tsx (modal management)
|
||||
|
||||
const AppInfo = () => {
|
||||
const { showModal, setShowModal } = useAppInfoModals()
|
||||
|
||||
return (
|
||||
<div>
|
||||
<AppHeader appDetail={appDetail} />
|
||||
<AppOperations onAction={handleAction} />
|
||||
<AppModals show={showModal} onClose={() => setShowModal(null)} />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Dify Examples**:
|
||||
- `web/app/components/app/configuration/` directory structure
|
||||
- `web/app/components/workflow/nodes/` per-node organization
|
||||
|
||||
### Pattern 3: Simplify Conditional Logic
|
||||
|
||||
**When**: Deep nesting (> 3 levels), complex ternaries, or multiple `if/else` chains.
|
||||
|
||||
```typescript
|
||||
// ❌ Before: Deeply nested conditionals
|
||||
const Template = useMemo(() => {
|
||||
if (appDetail?.mode === AppModeEnum.CHAT) {
|
||||
switch (locale) {
|
||||
case LanguagesSupported[1]:
|
||||
return <TemplateChatZh />
|
||||
case LanguagesSupported[7]:
|
||||
return <TemplateChatJa />
|
||||
default:
|
||||
return <TemplateChatEn />
|
||||
}
|
||||
}
|
||||
if (appDetail?.mode === AppModeEnum.ADVANCED_CHAT) {
|
||||
// Another 15 lines...
|
||||
}
|
||||
// More conditions...
|
||||
}, [appDetail, locale])
|
||||
|
||||
// ✅ After: Use lookup tables + early returns
|
||||
const TEMPLATE_MAP = {
|
||||
[AppModeEnum.CHAT]: {
|
||||
[LanguagesSupported[1]]: TemplateChatZh,
|
||||
[LanguagesSupported[7]]: TemplateChatJa,
|
||||
default: TemplateChatEn,
|
||||
},
|
||||
[AppModeEnum.ADVANCED_CHAT]: {
|
||||
[LanguagesSupported[1]]: TemplateAdvancedChatZh,
|
||||
// ...
|
||||
},
|
||||
}
|
||||
|
||||
const Template = useMemo(() => {
|
||||
const modeTemplates = TEMPLATE_MAP[appDetail?.mode]
|
||||
if (!modeTemplates) return null
|
||||
|
||||
const TemplateComponent = modeTemplates[locale] || modeTemplates.default
|
||||
return <TemplateComponent appDetail={appDetail} />
|
||||
}, [appDetail, locale])
|
||||
```
|
||||
|
||||
### Pattern 4: Extract API/Data Logic
|
||||
|
||||
**When**: Component directly handles API calls, data transformation, or complex async operations.
|
||||
|
||||
**Dify Convention**: Use `@tanstack/react-query` hooks from `web/service/use-*.ts` or create custom data hooks.
|
||||
|
||||
```typescript
|
||||
// ❌ Before: API logic in component
|
||||
const MCPServiceCard = () => {
|
||||
const [basicAppConfig, setBasicAppConfig] = useState({})
|
||||
|
||||
useEffect(() => {
|
||||
if (isBasicApp && appId) {
|
||||
(async () => {
|
||||
const res = await fetchAppDetail({ url: '/apps', id: appId })
|
||||
setBasicAppConfig(res?.model_config || {})
|
||||
})()
|
||||
}
|
||||
}, [appId, isBasicApp])
|
||||
|
||||
// More API-related logic...
|
||||
}
|
||||
|
||||
// ✅ After: Extract to data hook using React Query
|
||||
// use-app-config.ts
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import { get } from '@/service/base'
|
||||
|
||||
const NAME_SPACE = 'appConfig'
|
||||
|
||||
export const useAppConfig = (appId: string, isBasicApp: boolean) => {
|
||||
return useQuery({
|
||||
enabled: isBasicApp && !!appId,
|
||||
queryKey: [NAME_SPACE, 'detail', appId],
|
||||
queryFn: () => get<AppDetailResponse>(`/apps/${appId}`),
|
||||
select: data => data?.model_config || {},
|
||||
})
|
||||
}
|
||||
|
||||
// Component becomes cleaner
|
||||
const MCPServiceCard = () => {
|
||||
const { data: config, isLoading } = useAppConfig(appId, isBasicApp)
|
||||
// UI only
|
||||
}
|
||||
```
|
||||
|
||||
**React Query Best Practices in Dify**:
|
||||
- Define `NAME_SPACE` for query key organization
|
||||
- Use `enabled` option for conditional fetching
|
||||
- Use `select` for data transformation
|
||||
- Export invalidation hooks: `useInvalidXxx`
|
||||
|
||||
**Dify Examples**:
|
||||
- `web/service/use-workflow.ts`
|
||||
- `web/service/use-common.ts`
|
||||
- `web/service/knowledge/use-dataset.ts`
|
||||
- `web/service/knowledge/use-document.ts`
|
||||
|
||||
### Pattern 5: Extract Modal/Dialog Management
|
||||
|
||||
**When**: Component manages multiple modals with complex open/close states.
|
||||
|
||||
**Dify Convention**: Modals should be extracted with their state management.
|
||||
|
||||
```typescript
|
||||
// ❌ Before: Multiple modal states in component
|
||||
const AppInfo = () => {
|
||||
const [showEditModal, setShowEditModal] = useState(false)
|
||||
const [showDuplicateModal, setShowDuplicateModal] = useState(false)
|
||||
const [showConfirmDelete, setShowConfirmDelete] = useState(false)
|
||||
const [showSwitchModal, setShowSwitchModal] = useState(false)
|
||||
const [showImportDSLModal, setShowImportDSLModal] = useState(false)
|
||||
// 5+ more modal states...
|
||||
}
|
||||
|
||||
// ✅ After: Extract to modal management hook
|
||||
type ModalType = 'edit' | 'duplicate' | 'delete' | 'switch' | 'import' | null
|
||||
|
||||
const useAppInfoModals = () => {
|
||||
const [activeModal, setActiveModal] = useState<ModalType>(null)
|
||||
|
||||
const openModal = useCallback((type: ModalType) => setActiveModal(type), [])
|
||||
const closeModal = useCallback(() => setActiveModal(null), [])
|
||||
|
||||
return {
|
||||
activeModal,
|
||||
openModal,
|
||||
closeModal,
|
||||
isOpen: (type: ModalType) => activeModal === type,
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Pattern 6: Extract Form Logic
|
||||
|
||||
**When**: Complex form validation, submission handling, or field transformation.
|
||||
|
||||
**Dify Convention**: Use `@tanstack/react-form` patterns from `web/app/components/base/form/`.
|
||||
|
||||
```typescript
|
||||
// ✅ Use existing form infrastructure
|
||||
import { useAppForm } from '@/app/components/base/form'
|
||||
|
||||
const ConfigForm = () => {
|
||||
const form = useAppForm({
|
||||
defaultValues: { name: '', description: '' },
|
||||
onSubmit: handleSubmit,
|
||||
})
|
||||
|
||||
return <form.Provider>...</form.Provider>
|
||||
}
|
||||
```
|
||||
|
||||
## Dify-Specific Refactoring Guidelines
|
||||
|
||||
### 1. Context Provider Extraction
|
||||
|
||||
**When**: Component provides complex context values with multiple states.
|
||||
|
||||
```typescript
|
||||
// ❌ Before: Large context value object
|
||||
const value = {
|
||||
appId, isAPIKeySet, isTrailFinished, mode, modelModeType,
|
||||
promptMode, isAdvancedMode, isAgent, isOpenAI, isFunctionCall,
|
||||
// 50+ more properties...
|
||||
}
|
||||
return <ConfigContext.Provider value={value}>...</ConfigContext.Provider>
|
||||
|
||||
// ✅ After: Split into domain-specific contexts
|
||||
<ModelConfigProvider value={modelConfigValue}>
|
||||
<DatasetConfigProvider value={datasetConfigValue}>
|
||||
<UIConfigProvider value={uiConfigValue}>
|
||||
{children}
|
||||
</UIConfigProvider>
|
||||
</DatasetConfigProvider>
|
||||
</ModelConfigProvider>
|
||||
```
|
||||
|
||||
**Dify Reference**: `web/context/` directory structure
|
||||
|
||||
### 2. Workflow Node Components
|
||||
|
||||
**When**: Refactoring workflow node components (`web/app/components/workflow/nodes/`).
|
||||
|
||||
**Conventions**:
|
||||
- Keep node logic in `use-interactions.ts`
|
||||
- Extract panel UI to separate files
|
||||
- Use `_base` components for common patterns
|
||||
|
||||
```
|
||||
nodes/<node-type>/
|
||||
├── index.tsx # Node registration
|
||||
├── node.tsx # Node visual component
|
||||
├── panel.tsx # Configuration panel
|
||||
├── use-interactions.ts # Node-specific hooks
|
||||
└── types.ts # Type definitions
|
||||
```
|
||||
|
||||
### 3. Configuration Components
|
||||
|
||||
**When**: Refactoring app configuration components.
|
||||
|
||||
**Conventions**:
|
||||
- Separate config sections into subdirectories
|
||||
- Use existing patterns from `web/app/components/app/configuration/`
|
||||
- Keep feature toggles in dedicated components
|
||||
|
||||
### 4. Tool/Plugin Components
|
||||
|
||||
**When**: Refactoring tool-related components (`web/app/components/tools/`).
|
||||
|
||||
**Conventions**:
|
||||
- Follow existing modal patterns
|
||||
- Use service hooks from `web/service/use-tools.ts`
|
||||
- Keep provider-specific logic isolated
|
||||
|
||||
## Refactoring Workflow
|
||||
|
||||
### Step 1: Generate Refactoring Prompt
|
||||
|
||||
```bash
|
||||
pnpm refactor-component <path>
|
||||
```
|
||||
|
||||
This command will:
|
||||
- Analyze component complexity and features
|
||||
- Identify specific refactoring actions needed
|
||||
- Generate a prompt for AI assistant (auto-copied to clipboard on macOS)
|
||||
- Provide detailed requirements based on detected patterns
|
||||
|
||||
### Step 2: Analyze Details
|
||||
|
||||
```bash
|
||||
pnpm analyze-component <path> --json
|
||||
```
|
||||
|
||||
Identify:
|
||||
- Total complexity score
|
||||
- Max function complexity
|
||||
- Line count
|
||||
- Features detected (state, effects, API, etc.)
|
||||
|
||||
### Step 3: Plan
|
||||
|
||||
Create a refactoring plan based on detected features:
|
||||
|
||||
| Detected Feature | Refactoring Action |
|
||||
|------------------|-------------------|
|
||||
| `hasState: true` + `hasEffects: true` | Extract custom hook |
|
||||
| `hasAPI: true` | Extract data/service hook |
|
||||
| `hasEvents: true` (many) | Extract event handlers |
|
||||
| `lineCount > 300` | Split into sub-components |
|
||||
| `maxComplexity > 50` | Simplify conditional logic |
|
||||
|
||||
### Step 4: Execute Incrementally
|
||||
|
||||
1. **Extract one piece at a time**
|
||||
2. **Run lint, type-check, and tests after each extraction**
|
||||
3. **Verify functionality before next step**
|
||||
|
||||
```
|
||||
For each extraction:
|
||||
┌────────────────────────────────────────┐
|
||||
│ 1. Extract code │
|
||||
│ 2. Run: pnpm lint:fix │
|
||||
│ 3. Run: pnpm type-check:tsgo │
|
||||
│ 4. Run: pnpm test │
|
||||
│ 5. Test functionality manually │
|
||||
│ 6. PASS? → Next extraction │
|
||||
│ FAIL? → Fix before continuing │
|
||||
└────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Step 5: Verify
|
||||
|
||||
After refactoring:
|
||||
|
||||
```bash
|
||||
# Re-run refactor command to verify improvements
|
||||
pnpm refactor-component <path>
|
||||
|
||||
# If complexity < 25 and lines < 200, you'll see:
|
||||
# ✅ COMPONENT IS WELL-STRUCTURED
|
||||
|
||||
# For detailed metrics:
|
||||
pnpm analyze-component <path> --json
|
||||
|
||||
# Target metrics:
|
||||
# - complexity < 50
|
||||
# - lineCount < 300
|
||||
# - maxComplexity < 30
|
||||
```
|
||||
|
||||
## Common Mistakes to Avoid
|
||||
|
||||
### ❌ Over-Engineering
|
||||
|
||||
```typescript
|
||||
// ❌ Too many tiny hooks
|
||||
const useButtonText = () => useState('Click')
|
||||
const useButtonDisabled = () => useState(false)
|
||||
const useButtonLoading = () => useState(false)
|
||||
|
||||
// ✅ Cohesive hook with related state
|
||||
const useButtonState = () => {
|
||||
const [text, setText] = useState('Click')
|
||||
const [disabled, setDisabled] = useState(false)
|
||||
const [loading, setLoading] = useState(false)
|
||||
return { text, setText, disabled, setDisabled, loading, setLoading }
|
||||
}
|
||||
```
|
||||
|
||||
### ❌ Breaking Existing Patterns
|
||||
|
||||
- Follow existing directory structures
|
||||
- Maintain naming conventions
|
||||
- Preserve export patterns for compatibility
|
||||
|
||||
### ❌ Premature Abstraction
|
||||
|
||||
- Only extract when there's clear complexity benefit
|
||||
- Don't create abstractions for single-use code
|
||||
- Keep refactored code in the same domain area
|
||||
|
||||
## References
|
||||
|
||||
### Dify Codebase Examples
|
||||
|
||||
- **Hook extraction**: `web/app/components/app/configuration/hooks/`
|
||||
- **Component splitting**: `web/app/components/app/configuration/`
|
||||
- **Service hooks**: `web/service/use-*.ts`
|
||||
- **Workflow patterns**: `web/app/components/workflow/hooks/`
|
||||
- **Form patterns**: `web/app/components/base/form/`
|
||||
|
||||
### Related Skills
|
||||
|
||||
- `frontend-testing` - For testing refactored components
|
||||
- `web/testing/testing.md` - Testing specification
|
||||
@@ -1,493 +0,0 @@
|
||||
# Complexity Reduction Patterns
|
||||
|
||||
This document provides patterns for reducing cognitive complexity in Dify React components.
|
||||
|
||||
## Understanding Complexity
|
||||
|
||||
### SonarJS Cognitive Complexity
|
||||
|
||||
The `pnpm analyze-component` tool uses SonarJS cognitive complexity metrics:
|
||||
|
||||
- **Total Complexity**: Sum of all functions' complexity in the file
|
||||
- **Max Complexity**: Highest single function complexity
|
||||
|
||||
### What Increases Complexity
|
||||
|
||||
| Pattern | Complexity Impact |
|
||||
|---------|-------------------|
|
||||
| `if/else` | +1 per branch |
|
||||
| Nested conditions | +1 per nesting level |
|
||||
| `switch/case` | +1 per case |
|
||||
| `for/while/do` | +1 per loop |
|
||||
| `&&`/`||` chains | +1 per operator |
|
||||
| Nested callbacks | +1 per nesting level |
|
||||
| `try/catch` | +1 per catch |
|
||||
| Ternary expressions | +1 per nesting |
|
||||
|
||||
## Pattern 1: Replace Conditionals with Lookup Tables
|
||||
|
||||
**Before** (complexity: ~15):
|
||||
|
||||
```typescript
|
||||
const Template = useMemo(() => {
|
||||
if (appDetail?.mode === AppModeEnum.CHAT) {
|
||||
switch (locale) {
|
||||
case LanguagesSupported[1]:
|
||||
return <TemplateChatZh appDetail={appDetail} />
|
||||
case LanguagesSupported[7]:
|
||||
return <TemplateChatJa appDetail={appDetail} />
|
||||
default:
|
||||
return <TemplateChatEn appDetail={appDetail} />
|
||||
}
|
||||
}
|
||||
if (appDetail?.mode === AppModeEnum.ADVANCED_CHAT) {
|
||||
switch (locale) {
|
||||
case LanguagesSupported[1]:
|
||||
return <TemplateAdvancedChatZh appDetail={appDetail} />
|
||||
case LanguagesSupported[7]:
|
||||
return <TemplateAdvancedChatJa appDetail={appDetail} />
|
||||
default:
|
||||
return <TemplateAdvancedChatEn appDetail={appDetail} />
|
||||
}
|
||||
}
|
||||
if (appDetail?.mode === AppModeEnum.WORKFLOW) {
|
||||
// Similar pattern...
|
||||
}
|
||||
return null
|
||||
}, [appDetail, locale])
|
||||
```
|
||||
|
||||
**After** (complexity: ~3):
|
||||
|
||||
```typescript
|
||||
// Define lookup table outside component
|
||||
const TEMPLATE_MAP: Record<AppModeEnum, Record<string, FC<TemplateProps>>> = {
|
||||
[AppModeEnum.CHAT]: {
|
||||
[LanguagesSupported[1]]: TemplateChatZh,
|
||||
[LanguagesSupported[7]]: TemplateChatJa,
|
||||
default: TemplateChatEn,
|
||||
},
|
||||
[AppModeEnum.ADVANCED_CHAT]: {
|
||||
[LanguagesSupported[1]]: TemplateAdvancedChatZh,
|
||||
[LanguagesSupported[7]]: TemplateAdvancedChatJa,
|
||||
default: TemplateAdvancedChatEn,
|
||||
},
|
||||
[AppModeEnum.WORKFLOW]: {
|
||||
[LanguagesSupported[1]]: TemplateWorkflowZh,
|
||||
[LanguagesSupported[7]]: TemplateWorkflowJa,
|
||||
default: TemplateWorkflowEn,
|
||||
},
|
||||
// ...
|
||||
}
|
||||
|
||||
// Clean component logic
|
||||
const Template = useMemo(() => {
|
||||
if (!appDetail?.mode) return null
|
||||
|
||||
const templates = TEMPLATE_MAP[appDetail.mode]
|
||||
if (!templates) return null
|
||||
|
||||
const TemplateComponent = templates[locale] ?? templates.default
|
||||
return <TemplateComponent appDetail={appDetail} />
|
||||
}, [appDetail, locale])
|
||||
```
|
||||
|
||||
## Pattern 2: Use Early Returns
|
||||
|
||||
**Before** (complexity: ~10):
|
||||
|
||||
```typescript
|
||||
const handleSubmit = () => {
|
||||
if (isValid) {
|
||||
if (hasChanges) {
|
||||
if (isConnected) {
|
||||
submitData()
|
||||
} else {
|
||||
showConnectionError()
|
||||
}
|
||||
} else {
|
||||
showNoChangesMessage()
|
||||
}
|
||||
} else {
|
||||
showValidationError()
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**After** (complexity: ~4):
|
||||
|
||||
```typescript
|
||||
const handleSubmit = () => {
|
||||
if (!isValid) {
|
||||
showValidationError()
|
||||
return
|
||||
}
|
||||
|
||||
if (!hasChanges) {
|
||||
showNoChangesMessage()
|
||||
return
|
||||
}
|
||||
|
||||
if (!isConnected) {
|
||||
showConnectionError()
|
||||
return
|
||||
}
|
||||
|
||||
submitData()
|
||||
}
|
||||
```
|
||||
|
||||
## Pattern 3: Extract Complex Conditions
|
||||
|
||||
**Before** (complexity: high):
|
||||
|
||||
```typescript
|
||||
const canPublish = (() => {
|
||||
if (mode !== AppModeEnum.COMPLETION) {
|
||||
if (!isAdvancedMode)
|
||||
return true
|
||||
|
||||
if (modelModeType === ModelModeType.completion) {
|
||||
if (!hasSetBlockStatus.history || !hasSetBlockStatus.query)
|
||||
return false
|
||||
return true
|
||||
}
|
||||
return true
|
||||
}
|
||||
return !promptEmpty
|
||||
})()
|
||||
```
|
||||
|
||||
**After** (complexity: lower):
|
||||
|
||||
```typescript
|
||||
// Extract to named functions
|
||||
const canPublishInCompletionMode = () => !promptEmpty
|
||||
|
||||
const canPublishInChatMode = () => {
|
||||
if (!isAdvancedMode) return true
|
||||
if (modelModeType !== ModelModeType.completion) return true
|
||||
return hasSetBlockStatus.history && hasSetBlockStatus.query
|
||||
}
|
||||
|
||||
// Clean main logic
|
||||
const canPublish = mode === AppModeEnum.COMPLETION
|
||||
? canPublishInCompletionMode()
|
||||
: canPublishInChatMode()
|
||||
```
|
||||
|
||||
## Pattern 4: Replace Chained Ternaries
|
||||
|
||||
**Before** (complexity: ~5):
|
||||
|
||||
```typescript
|
||||
const statusText = serverActivated
|
||||
? t('status.running')
|
||||
: serverPublished
|
||||
? t('status.inactive')
|
||||
: appUnpublished
|
||||
? t('status.unpublished')
|
||||
: t('status.notConfigured')
|
||||
```
|
||||
|
||||
**After** (complexity: ~2):
|
||||
|
||||
```typescript
|
||||
const getStatusText = () => {
|
||||
if (serverActivated) return t('status.running')
|
||||
if (serverPublished) return t('status.inactive')
|
||||
if (appUnpublished) return t('status.unpublished')
|
||||
return t('status.notConfigured')
|
||||
}
|
||||
|
||||
const statusText = getStatusText()
|
||||
```
|
||||
|
||||
Or use lookup:
|
||||
|
||||
```typescript
|
||||
const STATUS_TEXT_MAP = {
|
||||
running: 'status.running',
|
||||
inactive: 'status.inactive',
|
||||
unpublished: 'status.unpublished',
|
||||
notConfigured: 'status.notConfigured',
|
||||
} as const
|
||||
|
||||
const getStatusKey = (): keyof typeof STATUS_TEXT_MAP => {
|
||||
if (serverActivated) return 'running'
|
||||
if (serverPublished) return 'inactive'
|
||||
if (appUnpublished) return 'unpublished'
|
||||
return 'notConfigured'
|
||||
}
|
||||
|
||||
const statusText = t(STATUS_TEXT_MAP[getStatusKey()])
|
||||
```
|
||||
|
||||
## Pattern 5: Flatten Nested Loops
|
||||
|
||||
**Before** (complexity: high):
|
||||
|
||||
```typescript
|
||||
const processData = (items: Item[]) => {
|
||||
const results: ProcessedItem[] = []
|
||||
|
||||
for (const item of items) {
|
||||
if (item.isValid) {
|
||||
for (const child of item.children) {
|
||||
if (child.isActive) {
|
||||
for (const prop of child.properties) {
|
||||
if (prop.value !== null) {
|
||||
results.push({
|
||||
itemId: item.id,
|
||||
childId: child.id,
|
||||
propValue: prop.value,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
```
|
||||
|
||||
**After** (complexity: lower):
|
||||
|
||||
```typescript
|
||||
// Use functional approach
|
||||
const processData = (items: Item[]) => {
|
||||
return items
|
||||
.filter(item => item.isValid)
|
||||
.flatMap(item =>
|
||||
item.children
|
||||
.filter(child => child.isActive)
|
||||
.flatMap(child =>
|
||||
child.properties
|
||||
.filter(prop => prop.value !== null)
|
||||
.map(prop => ({
|
||||
itemId: item.id,
|
||||
childId: child.id,
|
||||
propValue: prop.value,
|
||||
}))
|
||||
)
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
## Pattern 6: Extract Event Handler Logic
|
||||
|
||||
**Before** (complexity: high in component):
|
||||
|
||||
```typescript
|
||||
const Component = () => {
|
||||
const handleSelect = (data: DataSet[]) => {
|
||||
if (isEqual(data.map(item => item.id), dataSets.map(item => item.id))) {
|
||||
hideSelectDataSet()
|
||||
return
|
||||
}
|
||||
|
||||
formattingChangedDispatcher()
|
||||
let newDatasets = data
|
||||
if (data.find(item => !item.name)) {
|
||||
const newSelected = produce(data, (draft) => {
|
||||
data.forEach((item, index) => {
|
||||
if (!item.name) {
|
||||
const newItem = dataSets.find(i => i.id === item.id)
|
||||
if (newItem)
|
||||
draft[index] = newItem
|
||||
}
|
||||
})
|
||||
})
|
||||
setDataSets(newSelected)
|
||||
newDatasets = newSelected
|
||||
}
|
||||
else {
|
||||
setDataSets(data)
|
||||
}
|
||||
hideSelectDataSet()
|
||||
|
||||
// 40 more lines of logic...
|
||||
}
|
||||
|
||||
return <div>...</div>
|
||||
}
|
||||
```
|
||||
|
||||
**After** (complexity: lower):
|
||||
|
||||
```typescript
|
||||
// Extract to hook or utility
|
||||
const useDatasetSelection = (dataSets: DataSet[], setDataSets: SetState<DataSet[]>) => {
|
||||
const normalizeSelection = (data: DataSet[]) => {
|
||||
const hasUnloadedItem = data.some(item => !item.name)
|
||||
if (!hasUnloadedItem) return data
|
||||
|
||||
return produce(data, (draft) => {
|
||||
data.forEach((item, index) => {
|
||||
if (!item.name) {
|
||||
const existing = dataSets.find(i => i.id === item.id)
|
||||
if (existing) draft[index] = existing
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
const hasSelectionChanged = (newData: DataSet[]) => {
|
||||
return !isEqual(
|
||||
newData.map(item => item.id),
|
||||
dataSets.map(item => item.id)
|
||||
)
|
||||
}
|
||||
|
||||
return { normalizeSelection, hasSelectionChanged }
|
||||
}
|
||||
|
||||
// Component becomes cleaner
|
||||
const Component = () => {
|
||||
const { normalizeSelection, hasSelectionChanged } = useDatasetSelection(dataSets, setDataSets)
|
||||
|
||||
const handleSelect = (data: DataSet[]) => {
|
||||
if (!hasSelectionChanged(data)) {
|
||||
hideSelectDataSet()
|
||||
return
|
||||
}
|
||||
|
||||
formattingChangedDispatcher()
|
||||
const normalized = normalizeSelection(data)
|
||||
setDataSets(normalized)
|
||||
hideSelectDataSet()
|
||||
}
|
||||
|
||||
return <div>...</div>
|
||||
}
|
||||
```
|
||||
|
||||
## Pattern 7: Reduce Boolean Logic Complexity
|
||||
|
||||
**Before** (complexity: ~8):
|
||||
|
||||
```typescript
|
||||
const toggleDisabled = hasInsufficientPermissions
|
||||
|| appUnpublished
|
||||
|| missingStartNode
|
||||
|| triggerModeDisabled
|
||||
|| (isAdvancedApp && !currentWorkflow?.graph)
|
||||
|| (isBasicApp && !basicAppConfig.updated_at)
|
||||
```
|
||||
|
||||
**After** (complexity: ~3):
|
||||
|
||||
```typescript
|
||||
// Extract meaningful boolean functions
|
||||
const isAppReady = () => {
|
||||
if (isAdvancedApp) return !!currentWorkflow?.graph
|
||||
return !!basicAppConfig.updated_at
|
||||
}
|
||||
|
||||
const hasRequiredPermissions = () => {
|
||||
return isCurrentWorkspaceEditor && !hasInsufficientPermissions
|
||||
}
|
||||
|
||||
const canToggle = () => {
|
||||
if (!hasRequiredPermissions()) return false
|
||||
if (!isAppReady()) return false
|
||||
if (missingStartNode) return false
|
||||
if (triggerModeDisabled) return false
|
||||
return true
|
||||
}
|
||||
|
||||
const toggleDisabled = !canToggle()
|
||||
```
|
||||
|
||||
## Pattern 8: Simplify useMemo/useCallback Dependencies
|
||||
|
||||
**Before** (complexity: multiple recalculations):
|
||||
|
||||
```typescript
|
||||
const payload = useMemo(() => {
|
||||
let parameters: Parameter[] = []
|
||||
let outputParameters: OutputParameter[] = []
|
||||
|
||||
if (!published) {
|
||||
parameters = (inputs || []).map((item) => ({
|
||||
name: item.variable,
|
||||
description: '',
|
||||
form: 'llm',
|
||||
required: item.required,
|
||||
type: item.type,
|
||||
}))
|
||||
outputParameters = (outputs || []).map((item) => ({
|
||||
name: item.variable,
|
||||
description: '',
|
||||
type: item.value_type,
|
||||
}))
|
||||
}
|
||||
else if (detail && detail.tool) {
|
||||
parameters = (inputs || []).map((item) => ({
|
||||
// Complex transformation...
|
||||
}))
|
||||
outputParameters = (outputs || []).map((item) => ({
|
||||
// Complex transformation...
|
||||
}))
|
||||
}
|
||||
|
||||
return {
|
||||
icon: detail?.icon || icon,
|
||||
label: detail?.label || name,
|
||||
// ...more fields
|
||||
}
|
||||
}, [detail, published, workflowAppId, icon, name, description, inputs, outputs])
|
||||
```
|
||||
|
||||
**After** (complexity: separated concerns):
|
||||
|
||||
```typescript
|
||||
// Separate transformations
|
||||
const useParameterTransform = (inputs: InputVar[], detail?: ToolDetail, published?: boolean) => {
|
||||
return useMemo(() => {
|
||||
if (!published) {
|
||||
return inputs.map(item => ({
|
||||
name: item.variable,
|
||||
description: '',
|
||||
form: 'llm',
|
||||
required: item.required,
|
||||
type: item.type,
|
||||
}))
|
||||
}
|
||||
|
||||
if (!detail?.tool) return []
|
||||
|
||||
return inputs.map(item => ({
|
||||
name: item.variable,
|
||||
required: item.required,
|
||||
type: item.type === 'paragraph' ? 'string' : item.type,
|
||||
description: detail.tool.parameters.find(p => p.name === item.variable)?.llm_description || '',
|
||||
form: detail.tool.parameters.find(p => p.name === item.variable)?.form || 'llm',
|
||||
}))
|
||||
}, [inputs, detail, published])
|
||||
}
|
||||
|
||||
// Component uses hook
|
||||
const parameters = useParameterTransform(inputs, detail, published)
|
||||
const outputParameters = useOutputTransform(outputs, detail, published)
|
||||
|
||||
const payload = useMemo(() => ({
|
||||
icon: detail?.icon || icon,
|
||||
label: detail?.label || name,
|
||||
parameters,
|
||||
outputParameters,
|
||||
// ...
|
||||
}), [detail, icon, name, parameters, outputParameters])
|
||||
```
|
||||
|
||||
## Target Metrics After Refactoring
|
||||
|
||||
| Metric | Target |
|
||||
|--------|--------|
|
||||
| Total Complexity | < 50 |
|
||||
| Max Function Complexity | < 30 |
|
||||
| Function Length | < 30 lines |
|
||||
| Nesting Depth | ≤ 3 levels |
|
||||
| Conditional Chains | ≤ 3 conditions |
|
||||
@@ -1,477 +0,0 @@
|
||||
# Component Splitting Patterns
|
||||
|
||||
This document provides detailed guidance on splitting large components into smaller, focused components in Dify.
|
||||
|
||||
## When to Split Components
|
||||
|
||||
Split a component when you identify:
|
||||
|
||||
1. **Multiple UI sections** - Distinct visual areas with minimal coupling that can be composed independently
|
||||
1. **Conditional rendering blocks** - Large `{condition && <JSX />}` blocks
|
||||
1. **Repeated patterns** - Similar UI structures used multiple times
|
||||
1. **300+ lines** - Component exceeds manageable size
|
||||
1. **Modal clusters** - Multiple modals rendered in one component
|
||||
|
||||
## Splitting Strategies
|
||||
|
||||
### Strategy 1: Section-Based Splitting
|
||||
|
||||
Identify visual sections and extract each as a component.
|
||||
|
||||
```typescript
|
||||
// ❌ Before: Monolithic component (500+ lines)
|
||||
const ConfigurationPage = () => {
|
||||
return (
|
||||
<div>
|
||||
{/* Header Section - 50 lines */}
|
||||
<div className="header">
|
||||
<h1>{t('configuration.title')}</h1>
|
||||
<div className="actions">
|
||||
{isAdvancedMode && <Badge>Advanced</Badge>}
|
||||
<ModelParameterModal ... />
|
||||
<AppPublisher ... />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Config Section - 200 lines */}
|
||||
<div className="config">
|
||||
<Config />
|
||||
</div>
|
||||
|
||||
{/* Debug Section - 150 lines */}
|
||||
<div className="debug">
|
||||
<Debug ... />
|
||||
</div>
|
||||
|
||||
{/* Modals Section - 100 lines */}
|
||||
{showSelectDataSet && <SelectDataSet ... />}
|
||||
{showHistoryModal && <EditHistoryModal ... />}
|
||||
{showUseGPT4Confirm && <Confirm ... />}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// ✅ After: Split into focused components
|
||||
// configuration/
|
||||
// ├── index.tsx (orchestration)
|
||||
// ├── configuration-header.tsx
|
||||
// ├── configuration-content.tsx
|
||||
// ├── configuration-debug.tsx
|
||||
// └── configuration-modals.tsx
|
||||
|
||||
// configuration-header.tsx
|
||||
interface ConfigurationHeaderProps {
|
||||
isAdvancedMode: boolean
|
||||
onPublish: () => void
|
||||
}
|
||||
|
||||
const ConfigurationHeader: FC<ConfigurationHeaderProps> = ({
|
||||
isAdvancedMode,
|
||||
onPublish,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
return (
|
||||
<div className="header">
|
||||
<h1>{t('configuration.title')}</h1>
|
||||
<div className="actions">
|
||||
{isAdvancedMode && <Badge>Advanced</Badge>}
|
||||
<ModelParameterModal ... />
|
||||
<AppPublisher onPublish={onPublish} />
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// index.tsx (orchestration only)
|
||||
const ConfigurationPage = () => {
|
||||
const { modelConfig, setModelConfig } = useModelConfig()
|
||||
const { activeModal, openModal, closeModal } = useModalState()
|
||||
|
||||
return (
|
||||
<div>
|
||||
<ConfigurationHeader
|
||||
isAdvancedMode={isAdvancedMode}
|
||||
onPublish={handlePublish}
|
||||
/>
|
||||
<ConfigurationContent
|
||||
modelConfig={modelConfig}
|
||||
onConfigChange={setModelConfig}
|
||||
/>
|
||||
{!isMobile && (
|
||||
<ConfigurationDebug
|
||||
inputs={inputs}
|
||||
onSetting={handleSetting}
|
||||
/>
|
||||
)}
|
||||
<ConfigurationModals
|
||||
activeModal={activeModal}
|
||||
onClose={closeModal}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
### Strategy 2: Conditional Block Extraction
|
||||
|
||||
Extract large conditional rendering blocks.
|
||||
|
||||
```typescript
|
||||
// ❌ Before: Large conditional blocks
|
||||
const AppInfo = () => {
|
||||
return (
|
||||
<div>
|
||||
{expand ? (
|
||||
<div className="expanded">
|
||||
{/* 100 lines of expanded view */}
|
||||
</div>
|
||||
) : (
|
||||
<div className="collapsed">
|
||||
{/* 50 lines of collapsed view */}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// ✅ After: Separate view components
|
||||
const AppInfoExpanded: FC<AppInfoViewProps> = ({ appDetail, onAction }) => {
|
||||
return (
|
||||
<div className="expanded">
|
||||
{/* Clean, focused expanded view */}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const AppInfoCollapsed: FC<AppInfoViewProps> = ({ appDetail, onAction }) => {
|
||||
return (
|
||||
<div className="collapsed">
|
||||
{/* Clean, focused collapsed view */}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const AppInfo = () => {
|
||||
return (
|
||||
<div>
|
||||
{expand
|
||||
? <AppInfoExpanded appDetail={appDetail} onAction={handleAction} />
|
||||
: <AppInfoCollapsed appDetail={appDetail} onAction={handleAction} />
|
||||
}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
### Strategy 3: Modal Extraction
|
||||
|
||||
Extract modals with their trigger logic.
|
||||
|
||||
```typescript
|
||||
// ❌ Before: Multiple modals in one component
|
||||
const AppInfo = () => {
|
||||
const [showEdit, setShowEdit] = useState(false)
|
||||
const [showDuplicate, setShowDuplicate] = useState(false)
|
||||
const [showDelete, setShowDelete] = useState(false)
|
||||
const [showSwitch, setShowSwitch] = useState(false)
|
||||
|
||||
const onEdit = async (data) => { /* 20 lines */ }
|
||||
const onDuplicate = async (data) => { /* 20 lines */ }
|
||||
const onDelete = async () => { /* 15 lines */ }
|
||||
|
||||
return (
|
||||
<div>
|
||||
{/* Main content */}
|
||||
|
||||
{showEdit && <EditModal onConfirm={onEdit} onClose={() => setShowEdit(false)} />}
|
||||
{showDuplicate && <DuplicateModal onConfirm={onDuplicate} onClose={() => setShowDuplicate(false)} />}
|
||||
{showDelete && <DeleteConfirm onConfirm={onDelete} onClose={() => setShowDelete(false)} />}
|
||||
{showSwitch && <SwitchModal ... />}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// ✅ After: Modal manager component
|
||||
// app-info-modals.tsx
|
||||
type ModalType = 'edit' | 'duplicate' | 'delete' | 'switch' | null
|
||||
|
||||
interface AppInfoModalsProps {
|
||||
appDetail: AppDetail
|
||||
activeModal: ModalType
|
||||
onClose: () => void
|
||||
onSuccess: () => void
|
||||
}
|
||||
|
||||
const AppInfoModals: FC<AppInfoModalsProps> = ({
|
||||
appDetail,
|
||||
activeModal,
|
||||
onClose,
|
||||
onSuccess,
|
||||
}) => {
|
||||
const handleEdit = async (data) => { /* logic */ }
|
||||
const handleDuplicate = async (data) => { /* logic */ }
|
||||
const handleDelete = async () => { /* logic */ }
|
||||
|
||||
return (
|
||||
<>
|
||||
{activeModal === 'edit' && (
|
||||
<EditModal
|
||||
appDetail={appDetail}
|
||||
onConfirm={handleEdit}
|
||||
onClose={onClose}
|
||||
/>
|
||||
)}
|
||||
{activeModal === 'duplicate' && (
|
||||
<DuplicateModal
|
||||
appDetail={appDetail}
|
||||
onConfirm={handleDuplicate}
|
||||
onClose={onClose}
|
||||
/>
|
||||
)}
|
||||
{activeModal === 'delete' && (
|
||||
<DeleteConfirm
|
||||
onConfirm={handleDelete}
|
||||
onClose={onClose}
|
||||
/>
|
||||
)}
|
||||
{activeModal === 'switch' && (
|
||||
<SwitchModal
|
||||
appDetail={appDetail}
|
||||
onClose={onClose}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
// Parent component
|
||||
const AppInfo = () => {
|
||||
const { activeModal, openModal, closeModal } = useModalState()
|
||||
|
||||
return (
|
||||
<div>
|
||||
{/* Main content with openModal triggers */}
|
||||
<Button onClick={() => openModal('edit')}>Edit</Button>
|
||||
|
||||
<AppInfoModals
|
||||
appDetail={appDetail}
|
||||
activeModal={activeModal}
|
||||
onClose={closeModal}
|
||||
onSuccess={handleSuccess}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
### Strategy 4: List Item Extraction
|
||||
|
||||
Extract repeated item rendering.
|
||||
|
||||
```typescript
|
||||
// ❌ Before: Inline item rendering
|
||||
const OperationsList = () => {
|
||||
return (
|
||||
<div>
|
||||
{operations.map(op => (
|
||||
<div key={op.id} className="operation-item">
|
||||
<span className="icon">{op.icon}</span>
|
||||
<span className="title">{op.title}</span>
|
||||
<span className="description">{op.description}</span>
|
||||
<button onClick={() => op.onClick()}>
|
||||
{op.actionLabel}
|
||||
</button>
|
||||
{op.badge && <Badge>{op.badge}</Badge>}
|
||||
{/* More complex rendering... */}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// ✅ After: Extracted item component
|
||||
interface OperationItemProps {
|
||||
operation: Operation
|
||||
onAction: (id: string) => void
|
||||
}
|
||||
|
||||
const OperationItem: FC<OperationItemProps> = ({ operation, onAction }) => {
|
||||
return (
|
||||
<div className="operation-item">
|
||||
<span className="icon">{operation.icon}</span>
|
||||
<span className="title">{operation.title}</span>
|
||||
<span className="description">{operation.description}</span>
|
||||
<button onClick={() => onAction(operation.id)}>
|
||||
{operation.actionLabel}
|
||||
</button>
|
||||
{operation.badge && <Badge>{operation.badge}</Badge>}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const OperationsList = () => {
|
||||
const handleAction = useCallback((id: string) => {
|
||||
const op = operations.find(o => o.id === id)
|
||||
op?.onClick()
|
||||
}, [operations])
|
||||
|
||||
return (
|
||||
<div>
|
||||
{operations.map(op => (
|
||||
<OperationItem
|
||||
key={op.id}
|
||||
operation={op}
|
||||
onAction={handleAction}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
## Directory Structure Patterns
|
||||
|
||||
### Pattern A: Flat Structure (Simple Components)
|
||||
|
||||
For components with 2-3 sub-components:
|
||||
|
||||
```
|
||||
component-name/
|
||||
├── index.tsx # Main component
|
||||
├── sub-component-a.tsx
|
||||
├── sub-component-b.tsx
|
||||
└── types.ts # Shared types
|
||||
```
|
||||
|
||||
### Pattern B: Nested Structure (Complex Components)
|
||||
|
||||
For components with many sub-components:
|
||||
|
||||
```
|
||||
component-name/
|
||||
├── index.tsx # Main orchestration
|
||||
├── types.ts # Shared types
|
||||
├── hooks/
|
||||
│ ├── use-feature-a.ts
|
||||
│ └── use-feature-b.ts
|
||||
├── components/
|
||||
│ ├── header/
|
||||
│ │ └── index.tsx
|
||||
│ ├── content/
|
||||
│ │ └── index.tsx
|
||||
│ └── modals/
|
||||
│ └── index.tsx
|
||||
└── utils/
|
||||
└── helpers.ts
|
||||
```
|
||||
|
||||
### Pattern C: Feature-Based Structure (Dify Standard)
|
||||
|
||||
Following Dify's existing patterns:
|
||||
|
||||
```
|
||||
configuration/
|
||||
├── index.tsx # Main page component
|
||||
├── base/ # Base/shared components
|
||||
│ ├── feature-panel/
|
||||
│ ├── group-name/
|
||||
│ └── operation-btn/
|
||||
├── config/ # Config section
|
||||
│ ├── index.tsx
|
||||
│ ├── agent/
|
||||
│ └── automatic/
|
||||
├── dataset-config/ # Dataset section
|
||||
│ ├── index.tsx
|
||||
│ ├── card-item/
|
||||
│ └── params-config/
|
||||
├── debug/ # Debug section
|
||||
│ ├── index.tsx
|
||||
│ └── hooks.tsx
|
||||
└── hooks/ # Shared hooks
|
||||
└── use-advanced-prompt-config.ts
|
||||
```
|
||||
|
||||
## Props Design
|
||||
|
||||
### Minimal Props Principle
|
||||
|
||||
Pass only what's needed:
|
||||
|
||||
```typescript
|
||||
// ❌ Bad: Passing entire objects when only some fields needed
|
||||
<ConfigHeader appDetail={appDetail} modelConfig={modelConfig} />
|
||||
|
||||
// ✅ Good: Destructure to minimum required
|
||||
<ConfigHeader
|
||||
appName={appDetail.name}
|
||||
isAdvancedMode={modelConfig.isAdvanced}
|
||||
onPublish={handlePublish}
|
||||
/>
|
||||
```
|
||||
|
||||
### Callback Props Pattern
|
||||
|
||||
Use callbacks for child-to-parent communication:
|
||||
|
||||
```typescript
|
||||
// Parent
|
||||
const Parent = () => {
|
||||
const [value, setValue] = useState('')
|
||||
|
||||
return (
|
||||
<Child
|
||||
value={value}
|
||||
onChange={setValue}
|
||||
onSubmit={handleSubmit}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
// Child
|
||||
interface ChildProps {
|
||||
value: string
|
||||
onChange: (value: string) => void
|
||||
onSubmit: () => void
|
||||
}
|
||||
|
||||
const Child: FC<ChildProps> = ({ value, onChange, onSubmit }) => {
|
||||
return (
|
||||
<div>
|
||||
<input value={value} onChange={e => onChange(e.target.value)} />
|
||||
<button onClick={onSubmit}>Submit</button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
### Render Props for Flexibility
|
||||
|
||||
When sub-components need parent context:
|
||||
|
||||
```typescript
|
||||
interface ListProps<T> {
|
||||
items: T[]
|
||||
renderItem: (item: T, index: number) => React.ReactNode
|
||||
renderEmpty?: () => React.ReactNode
|
||||
}
|
||||
|
||||
function List<T>({ items, renderItem, renderEmpty }: ListProps<T>) {
|
||||
if (items.length === 0 && renderEmpty) {
|
||||
return <>{renderEmpty()}</>
|
||||
}
|
||||
|
||||
return (
|
||||
<div>
|
||||
{items.map((item, index) => renderItem(item, index))}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// Usage
|
||||
<List
|
||||
items={operations}
|
||||
renderItem={(op, i) => <OperationItem key={i} operation={op} />}
|
||||
renderEmpty={() => <EmptyState message="No operations" />}
|
||||
/>
|
||||
```
|
||||
@@ -1,317 +0,0 @@
|
||||
# Hook Extraction Patterns
|
||||
|
||||
This document provides detailed guidance on extracting custom hooks from complex components in Dify.
|
||||
|
||||
## When to Extract Hooks
|
||||
|
||||
Extract a custom hook when you identify:
|
||||
|
||||
1. **Coupled state groups** - Multiple `useState` hooks that are always used together
|
||||
1. **Complex effects** - `useEffect` with multiple dependencies or cleanup logic
|
||||
1. **Business logic** - Data transformations, validations, or calculations
|
||||
1. **Reusable patterns** - Logic that appears in multiple components
|
||||
|
||||
## Extraction Process
|
||||
|
||||
### Step 1: Identify State Groups
|
||||
|
||||
Look for state variables that are logically related:
|
||||
|
||||
```typescript
|
||||
// ❌ These belong together - extract to hook
|
||||
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
|
||||
const [completionParams, setCompletionParams] = useState<FormValue>({})
|
||||
const [modelModeType, setModelModeType] = useState<ModelModeType>(...)
|
||||
|
||||
// These are model-related state that should be in useModelConfig()
|
||||
```
|
||||
|
||||
### Step 2: Identify Related Effects
|
||||
|
||||
Find effects that modify the grouped state:
|
||||
|
||||
```typescript
|
||||
// ❌ These effects belong with the state above
|
||||
useEffect(() => {
|
||||
if (hasFetchedDetail && !modelModeType) {
|
||||
const mode = currModel?.model_properties.mode
|
||||
if (mode) {
|
||||
const newModelConfig = produce(modelConfig, (draft) => {
|
||||
draft.mode = mode
|
||||
})
|
||||
setModelConfig(newModelConfig)
|
||||
}
|
||||
}
|
||||
}, [textGenerationModelList, hasFetchedDetail, modelModeType, currModel])
|
||||
```
|
||||
|
||||
### Step 3: Create the Hook
|
||||
|
||||
```typescript
|
||||
// hooks/use-model-config.ts
|
||||
import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import type { ModelConfig } from '@/models/debug'
|
||||
import { produce } from 'immer'
|
||||
import { useEffect, useState } from 'react'
|
||||
import { ModelModeType } from '@/types/app'
|
||||
|
||||
interface UseModelConfigParams {
|
||||
initialConfig?: Partial<ModelConfig>
|
||||
currModel?: { model_properties?: { mode?: ModelModeType } }
|
||||
hasFetchedDetail: boolean
|
||||
}
|
||||
|
||||
interface UseModelConfigReturn {
|
||||
modelConfig: ModelConfig
|
||||
setModelConfig: (config: ModelConfig) => void
|
||||
completionParams: FormValue
|
||||
setCompletionParams: (params: FormValue) => void
|
||||
modelModeType: ModelModeType
|
||||
}
|
||||
|
||||
export const useModelConfig = ({
|
||||
initialConfig,
|
||||
currModel,
|
||||
hasFetchedDetail,
|
||||
}: UseModelConfigParams): UseModelConfigReturn => {
|
||||
const [modelConfig, setModelConfig] = useState<ModelConfig>({
|
||||
provider: 'langgenius/openai/openai',
|
||||
model_id: 'gpt-3.5-turbo',
|
||||
mode: ModelModeType.unset,
|
||||
// ... default values
|
||||
...initialConfig,
|
||||
})
|
||||
|
||||
const [completionParams, setCompletionParams] = useState<FormValue>({})
|
||||
|
||||
const modelModeType = modelConfig.mode
|
||||
|
||||
// Fill old app data missing model mode
|
||||
useEffect(() => {
|
||||
if (hasFetchedDetail && !modelModeType) {
|
||||
const mode = currModel?.model_properties?.mode
|
||||
if (mode) {
|
||||
setModelConfig(produce(modelConfig, (draft) => {
|
||||
draft.mode = mode
|
||||
}))
|
||||
}
|
||||
}
|
||||
}, [hasFetchedDetail, modelModeType, currModel])
|
||||
|
||||
return {
|
||||
modelConfig,
|
||||
setModelConfig,
|
||||
completionParams,
|
||||
setCompletionParams,
|
||||
modelModeType,
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Step 4: Update Component
|
||||
|
||||
```typescript
|
||||
// Before: 50+ lines of state management
|
||||
const Configuration: FC = () => {
|
||||
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
|
||||
// ... lots of related state and effects
|
||||
}
|
||||
|
||||
// After: Clean component
|
||||
const Configuration: FC = () => {
|
||||
const {
|
||||
modelConfig,
|
||||
setModelConfig,
|
||||
completionParams,
|
||||
setCompletionParams,
|
||||
modelModeType,
|
||||
} = useModelConfig({
|
||||
currModel,
|
||||
hasFetchedDetail,
|
||||
})
|
||||
|
||||
// Component now focuses on UI
|
||||
}
|
||||
```
|
||||
|
||||
## Naming Conventions
|
||||
|
||||
### Hook Names
|
||||
|
||||
- Use `use` prefix: `useModelConfig`, `useDatasetConfig`
|
||||
- Be specific: `useAdvancedPromptConfig` not `usePrompt`
|
||||
- Include domain: `useWorkflowVariables`, `useMCPServer`
|
||||
|
||||
### File Names
|
||||
|
||||
- Kebab-case: `use-model-config.ts`
|
||||
- Place in `hooks/` subdirectory when multiple hooks exist
|
||||
- Place alongside component for single-use hooks
|
||||
|
||||
### Return Type Names
|
||||
|
||||
- Suffix with `Return`: `UseModelConfigReturn`
|
||||
- Suffix params with `Params`: `UseModelConfigParams`
|
||||
|
||||
## Common Hook Patterns in Dify
|
||||
|
||||
### 1. Data Fetching Hook (React Query)
|
||||
|
||||
```typescript
|
||||
// Pattern: Use @tanstack/react-query for data fetching
|
||||
import { useQuery, useQueryClient } from '@tanstack/react-query'
|
||||
import { get } from '@/service/base'
|
||||
import { useInvalid } from '@/service/use-base'
|
||||
|
||||
const NAME_SPACE = 'appConfig'
|
||||
|
||||
// Query keys for cache management
|
||||
export const appConfigQueryKeys = {
|
||||
detail: (appId: string) => [NAME_SPACE, 'detail', appId] as const,
|
||||
}
|
||||
|
||||
// Main data hook
|
||||
export const useAppConfig = (appId: string) => {
|
||||
return useQuery({
|
||||
enabled: !!appId,
|
||||
queryKey: appConfigQueryKeys.detail(appId),
|
||||
queryFn: () => get<AppDetailResponse>(`/apps/${appId}`),
|
||||
select: data => data?.model_config || null,
|
||||
})
|
||||
}
|
||||
|
||||
// Invalidation hook for refreshing data
|
||||
export const useInvalidAppConfig = () => {
|
||||
return useInvalid([NAME_SPACE])
|
||||
}
|
||||
|
||||
// Usage in component
|
||||
const Component = () => {
|
||||
const { data: config, isLoading, error, refetch } = useAppConfig(appId)
|
||||
const invalidAppConfig = useInvalidAppConfig()
|
||||
|
||||
const handleRefresh = () => {
|
||||
invalidAppConfig() // Invalidates cache and triggers refetch
|
||||
}
|
||||
|
||||
return <div>...</div>
|
||||
}
|
||||
```
|
||||
|
||||
### 2. Form State Hook
|
||||
|
||||
```typescript
|
||||
// Pattern: Form state + validation + submission
|
||||
export const useConfigForm = (initialValues: ConfigFormValues) => {
|
||||
const [values, setValues] = useState(initialValues)
|
||||
const [errors, setErrors] = useState<Record<string, string>>({})
|
||||
const [isSubmitting, setIsSubmitting] = useState(false)
|
||||
|
||||
const validate = useCallback(() => {
|
||||
const newErrors: Record<string, string> = {}
|
||||
if (!values.name) newErrors.name = 'Name is required'
|
||||
setErrors(newErrors)
|
||||
return Object.keys(newErrors).length === 0
|
||||
}, [values])
|
||||
|
||||
const handleChange = useCallback((field: string, value: any) => {
|
||||
setValues(prev => ({ ...prev, [field]: value }))
|
||||
}, [])
|
||||
|
||||
const handleSubmit = useCallback(async (onSubmit: (values: ConfigFormValues) => Promise<void>) => {
|
||||
if (!validate()) return
|
||||
setIsSubmitting(true)
|
||||
try {
|
||||
await onSubmit(values)
|
||||
} finally {
|
||||
setIsSubmitting(false)
|
||||
}
|
||||
}, [values, validate])
|
||||
|
||||
return { values, errors, isSubmitting, handleChange, handleSubmit }
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Modal State Hook
|
||||
|
||||
```typescript
|
||||
// Pattern: Multiple modal management
|
||||
type ModalType = 'edit' | 'delete' | 'duplicate' | null
|
||||
|
||||
export const useModalState = () => {
|
||||
const [activeModal, setActiveModal] = useState<ModalType>(null)
|
||||
const [modalData, setModalData] = useState<any>(null)
|
||||
|
||||
const openModal = useCallback((type: ModalType, data?: any) => {
|
||||
setActiveModal(type)
|
||||
setModalData(data)
|
||||
}, [])
|
||||
|
||||
const closeModal = useCallback(() => {
|
||||
setActiveModal(null)
|
||||
setModalData(null)
|
||||
}, [])
|
||||
|
||||
return {
|
||||
activeModal,
|
||||
modalData,
|
||||
openModal,
|
||||
closeModal,
|
||||
isOpen: useCallback((type: ModalType) => activeModal === type, [activeModal]),
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 4. Toggle/Boolean Hook
|
||||
|
||||
```typescript
|
||||
// Pattern: Boolean state with convenience methods
|
||||
export const useToggle = (initialValue = false) => {
|
||||
const [value, setValue] = useState(initialValue)
|
||||
|
||||
const toggle = useCallback(() => setValue(v => !v), [])
|
||||
const setTrue = useCallback(() => setValue(true), [])
|
||||
const setFalse = useCallback(() => setValue(false), [])
|
||||
|
||||
return [value, { toggle, setTrue, setFalse, set: setValue }] as const
|
||||
}
|
||||
|
||||
// Usage
|
||||
const [isExpanded, { toggle, setTrue: expand, setFalse: collapse }] = useToggle()
|
||||
```
|
||||
|
||||
## Testing Extracted Hooks
|
||||
|
||||
After extraction, test hooks in isolation:
|
||||
|
||||
```typescript
|
||||
// use-model-config.spec.ts
|
||||
import { renderHook, act } from '@testing-library/react'
|
||||
import { useModelConfig } from './use-model-config'
|
||||
|
||||
describe('useModelConfig', () => {
|
||||
it('should initialize with default values', () => {
|
||||
const { result } = renderHook(() => useModelConfig({
|
||||
hasFetchedDetail: false,
|
||||
}))
|
||||
|
||||
expect(result.current.modelConfig.provider).toBe('langgenius/openai/openai')
|
||||
expect(result.current.modelModeType).toBe(ModelModeType.unset)
|
||||
})
|
||||
|
||||
it('should update model config', () => {
|
||||
const { result } = renderHook(() => useModelConfig({
|
||||
hasFetchedDetail: true,
|
||||
}))
|
||||
|
||||
act(() => {
|
||||
result.current.setModelConfig({
|
||||
...result.current.modelConfig,
|
||||
model_id: 'gpt-4',
|
||||
})
|
||||
})
|
||||
|
||||
expect(result.current.modelConfig.model_id).toBe('gpt-4')
|
||||
})
|
||||
})
|
||||
```
|
||||
@@ -1,73 +0,0 @@
|
||||
---
|
||||
name: frontend-code-review
|
||||
description: "Trigger when the user requests a review of frontend files (e.g., `.tsx`, `.ts`, `.js`). Support both pending-change reviews and focused file reviews while applying the checklist rules."
|
||||
---
|
||||
|
||||
# Frontend Code Review
|
||||
|
||||
## Intent
|
||||
Use this skill whenever the user asks to review frontend code (especially `.tsx`, `.ts`, or `.js` files). Support two review modes:
|
||||
|
||||
1. **Pending-change review** – inspect staged/working-tree files slated for commit and flag checklist violations before submission.
|
||||
2. **File-targeted review** – review the specific file(s) the user names and report the relevant checklist findings.
|
||||
|
||||
Stick to the checklist below for every applicable file and mode.
|
||||
|
||||
## Checklist
|
||||
See [references/code-quality.md](references/code-quality.md), [references/performance.md](references/performance.md), [references/business-logic.md](references/business-logic.md) for the living checklist split by category—treat it as the canonical set of rules to follow.
|
||||
|
||||
Flag each rule violation with urgency metadata so future reviewers can prioritize fixes.
|
||||
|
||||
## Review Process
|
||||
1. Open the relevant component/module. Gather lines that relate to class names, React Flow hooks, prop memoization, and styling.
|
||||
2. For each rule in the review point, note where the code deviates and capture a representative snippet.
|
||||
3. Compose the review section per the template below. Group violations first by **Urgent** flag, then by category order (Code Quality, Performance, Business Logic).
|
||||
|
||||
## Required output
|
||||
When invoked, the response must exactly follow one of the two templates:
|
||||
|
||||
### Template A (any findings)
|
||||
```
|
||||
# Code review
|
||||
Found <N> urgent issues need to be fixed:
|
||||
|
||||
## 1 <brief description of bug>
|
||||
FilePath: <path> line <line>
|
||||
<relevant code snippet or pointer>
|
||||
|
||||
|
||||
### Suggested fix
|
||||
<brief description of suggested fix>
|
||||
|
||||
---
|
||||
... (repeat for each urgent issue) ...
|
||||
|
||||
Found <M> suggestions for improvement:
|
||||
|
||||
## 1 <brief description of suggestion>
|
||||
FilePath: <path> line <line>
|
||||
<relevant code snippet or pointer>
|
||||
|
||||
|
||||
### Suggested fix
|
||||
<brief description of suggested fix>
|
||||
|
||||
---
|
||||
|
||||
... (repeat for each suggestion) ...
|
||||
```
|
||||
|
||||
If there are no urgent issues, omit that section. If there are no suggestions, omit that section.
|
||||
|
||||
If the issue number is more than 10, summarize as "10+ urgent issues" or "10+ suggestions" and just output the first 10 issues.
|
||||
|
||||
Don't compress the blank lines between sections; keep them as-is for readability.
|
||||
|
||||
If you use Template A (i.e., there are issues to fix) and at least one issue requires code changes, append a brief follow-up question after the structured output asking whether the user wants you to apply the suggested fix(es). For example: "Would you like me to use the Suggested fix section to address these issues?"
|
||||
|
||||
### Template B (no issues)
|
||||
```
|
||||
## Code review
|
||||
No issues found.
|
||||
```
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
# Rule Catalog — Business Logic
|
||||
|
||||
## Can't use workflowStore in Node components
|
||||
|
||||
IsUrgent: True
|
||||
|
||||
### Description
|
||||
|
||||
File path pattern of node components: `web/app/components/workflow/nodes/[nodeName]/node.tsx`
|
||||
|
||||
Node components are also used when creating a RAG Pipe from a template, but in that context there is no workflowStore Provider, which results in a blank screen. [This Issue](https://github.com/langgenius/dify/issues/29168) was caused by exactly this reason.
|
||||
|
||||
### Suggested Fix
|
||||
|
||||
Use `import { useNodes } from 'reactflow'` instead of `import useNodes from '@/app/components/workflow/store/workflow/use-nodes'`.
|
||||
@@ -1,44 +0,0 @@
|
||||
# Rule Catalog — Code Quality
|
||||
|
||||
## Conditional class names use utility function
|
||||
|
||||
IsUrgent: True
|
||||
Category: Code Quality
|
||||
|
||||
### Description
|
||||
|
||||
Ensure conditional CSS is handled via the shared `classNames` instead of custom ternaries, string concatenation, or template strings. Centralizing class logic keeps components consistent and easier to maintain.
|
||||
|
||||
### Suggested Fix
|
||||
|
||||
```ts
|
||||
import { cn } from '@/utils/classnames'
|
||||
const classNames = cn(isActive ? 'text-primary-600' : 'text-gray-500')
|
||||
```
|
||||
|
||||
## Tailwind-first styling
|
||||
|
||||
IsUrgent: True
|
||||
Category: Code Quality
|
||||
|
||||
### Description
|
||||
|
||||
Favor Tailwind CSS utility classes instead of adding new `.module.css` files unless a Tailwind combination cannot achieve the required styling. Keeping styles in Tailwind improves consistency and reduces maintenance overhead.
|
||||
|
||||
Update this file when adding, editing, or removing Code Quality rules so the catalog remains accurate.
|
||||
|
||||
## Classname ordering for easy overrides
|
||||
|
||||
### Description
|
||||
|
||||
When writing components, always place the incoming `className` prop after the component’s own class values so that downstream consumers can override or extend the styling. This keeps your component’s defaults but still lets external callers change or remove specific styles.
|
||||
|
||||
Example:
|
||||
|
||||
```tsx
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
const Button = ({ className }) => {
|
||||
return <div className={cn('bg-primary-600', className)}></div>
|
||||
}
|
||||
```
|
||||
@@ -1,45 +0,0 @@
|
||||
# Rule Catalog — Performance
|
||||
|
||||
## React Flow data usage
|
||||
|
||||
IsUrgent: True
|
||||
Category: Performance
|
||||
|
||||
### Description
|
||||
|
||||
When rendering React Flow, prefer `useNodes`/`useEdges` for UI consumption and rely on `useStoreApi` inside callbacks that mutate or read node/edge state. Avoid manually pulling Flow data outside of these hooks.
|
||||
|
||||
## Complex prop memoization
|
||||
|
||||
IsUrgent: True
|
||||
Category: Performance
|
||||
|
||||
### Description
|
||||
|
||||
Wrap complex prop values (objects, arrays, maps) in `useMemo` prior to passing them into child components to guarantee stable references and prevent unnecessary renders.
|
||||
|
||||
Update this file when adding, editing, or removing Performance rules so the catalog remains accurate.
|
||||
|
||||
Wrong:
|
||||
|
||||
```tsx
|
||||
<HeavyComp
|
||||
config={{
|
||||
provider: ...,
|
||||
detail: ...
|
||||
}}
|
||||
/>
|
||||
```
|
||||
|
||||
Right:
|
||||
|
||||
```tsx
|
||||
const config = useMemo(() => ({
|
||||
provider: ...,
|
||||
detail: ...
|
||||
}), [provider, detail]);
|
||||
|
||||
<HeavyComp
|
||||
config={config}
|
||||
/>
|
||||
```
|
||||
@@ -1,322 +0,0 @@
|
||||
---
|
||||
name: frontend-testing
|
||||
description: Generate Vitest + React Testing Library tests for Dify frontend components, hooks, and utilities. Triggers on testing, spec files, coverage, Vitest, RTL, unit tests, integration tests, or write/review test requests.
|
||||
---
|
||||
|
||||
# Dify Frontend Testing Skill
|
||||
|
||||
This skill enables Claude to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices.
|
||||
|
||||
> **⚠️ Authoritative Source**: This skill is derived from `web/testing/testing.md`. Use Vitest mock/timer APIs (`vi.*`).
|
||||
|
||||
## When to Apply This Skill
|
||||
|
||||
Apply this skill when the user:
|
||||
|
||||
- Asks to **write tests** for a component, hook, or utility
|
||||
- Asks to **review existing tests** for completeness
|
||||
- Mentions **Vitest**, **React Testing Library**, **RTL**, or **spec files**
|
||||
- Requests **test coverage** improvement
|
||||
- Uses `pnpm analyze-component` output as context
|
||||
- Mentions **testing**, **unit tests**, or **integration tests** for frontend code
|
||||
- Wants to understand **testing patterns** in the Dify codebase
|
||||
|
||||
**Do NOT apply** when:
|
||||
|
||||
- User is asking about backend/API tests (Python/pytest)
|
||||
- User is asking about E2E tests (Playwright/Cypress)
|
||||
- User is only asking conceptual questions without code context
|
||||
|
||||
## Quick Reference
|
||||
|
||||
### Tech Stack
|
||||
|
||||
| Tool | Version | Purpose |
|
||||
|------|---------|---------|
|
||||
| Vitest | 4.0.16 | Test runner |
|
||||
| React Testing Library | 16.0 | Component testing |
|
||||
| jsdom | - | Test environment |
|
||||
| nock | 14.0 | HTTP mocking |
|
||||
| TypeScript | 5.x | Type safety |
|
||||
|
||||
### Key Commands
|
||||
|
||||
```bash
|
||||
# Run all tests
|
||||
pnpm test
|
||||
|
||||
# Watch mode
|
||||
pnpm test:watch
|
||||
|
||||
# Run specific file
|
||||
pnpm test path/to/file.spec.tsx
|
||||
|
||||
# Generate coverage report
|
||||
pnpm test:coverage
|
||||
|
||||
# Analyze component complexity
|
||||
pnpm analyze-component <path>
|
||||
|
||||
# Review existing test
|
||||
pnpm analyze-component <path> --review
|
||||
```
|
||||
|
||||
### File Naming
|
||||
|
||||
- Test files: `ComponentName.spec.tsx` (same directory as component)
|
||||
- Integration tests: `web/__tests__/` directory
|
||||
|
||||
## Test Structure Template
|
||||
|
||||
```typescript
|
||||
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
|
||||
import Component from './index'
|
||||
|
||||
// ✅ Import real project components (DO NOT mock these)
|
||||
// import Loading from '@/app/components/base/loading'
|
||||
// import { ChildComponent } from './child-component'
|
||||
|
||||
// ✅ Mock external dependencies only
|
||||
vi.mock('@/service/api')
|
||||
vi.mock('next/navigation', () => ({
|
||||
useRouter: () => ({ push: vi.fn() }),
|
||||
usePathname: () => '/test',
|
||||
}))
|
||||
|
||||
// Shared state for mocks (if needed)
|
||||
let mockSharedState = false
|
||||
|
||||
describe('ComponentName', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks() // ✅ Reset mocks BEFORE each test
|
||||
mockSharedState = false // ✅ Reset shared state
|
||||
})
|
||||
|
||||
// Rendering tests (REQUIRED)
|
||||
describe('Rendering', () => {
|
||||
it('should render without crashing', () => {
|
||||
// Arrange
|
||||
const props = { title: 'Test' }
|
||||
|
||||
// Act
|
||||
render(<Component {...props} />)
|
||||
|
||||
// Assert
|
||||
expect(screen.getByText('Test')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// Props tests (REQUIRED)
|
||||
describe('Props', () => {
|
||||
it('should apply custom className', () => {
|
||||
render(<Component className="custom" />)
|
||||
expect(screen.getByRole('button')).toHaveClass('custom')
|
||||
})
|
||||
})
|
||||
|
||||
// User Interactions
|
||||
describe('User Interactions', () => {
|
||||
it('should handle click events', () => {
|
||||
const handleClick = vi.fn()
|
||||
render(<Component onClick={handleClick} />)
|
||||
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
|
||||
expect(handleClick).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
|
||||
// Edge Cases (REQUIRED)
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle null data', () => {
|
||||
render(<Component data={null} />)
|
||||
expect(screen.getByText(/no data/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle empty array', () => {
|
||||
render(<Component items={[]} />)
|
||||
expect(screen.getByText(/empty/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
## Testing Workflow (CRITICAL)
|
||||
|
||||
### ⚠️ Incremental Approach Required
|
||||
|
||||
**NEVER generate all test files at once.** For complex components or multi-file directories:
|
||||
|
||||
1. **Analyze & Plan**: List all files, order by complexity (simple → complex)
|
||||
1. **Process ONE at a time**: Write test → Run test → Fix if needed → Next
|
||||
1. **Verify before proceeding**: Do NOT continue to next file until current passes
|
||||
|
||||
```
|
||||
For each file:
|
||||
┌────────────────────────────────────────┐
|
||||
│ 1. Write test │
|
||||
│ 2. Run: pnpm test <file>.spec.tsx │
|
||||
│ 3. PASS? → Mark complete, next file │
|
||||
│ FAIL? → Fix first, then continue │
|
||||
└────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Complexity-Based Order
|
||||
|
||||
Process in this order for multi-file testing:
|
||||
|
||||
1. 🟢 Utility functions (simplest)
|
||||
1. 🟢 Custom hooks
|
||||
1. 🟡 Simple components (presentational)
|
||||
1. 🟡 Medium components (state, effects)
|
||||
1. 🔴 Complex components (API, routing)
|
||||
1. 🔴 Integration tests (index files - last)
|
||||
|
||||
### When to Refactor First
|
||||
|
||||
- **Complexity > 50**: Break into smaller pieces before testing
|
||||
- **500+ lines**: Consider splitting before testing
|
||||
- **Many dependencies**: Extract logic into hooks first
|
||||
|
||||
> 📖 See `references/workflow.md` for complete workflow details and todo list format.
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
### Path-Level Testing (Directory Testing)
|
||||
|
||||
When assigned to test a directory/path, test **ALL content** within that path:
|
||||
|
||||
- Test all components, hooks, utilities in the directory (not just `index` file)
|
||||
- Use incremental approach: one file at a time, verify each before proceeding
|
||||
- Goal: 100% coverage of ALL files in the directory
|
||||
|
||||
### Integration Testing First
|
||||
|
||||
**Prefer integration testing** when writing tests for a directory:
|
||||
|
||||
- ✅ **Import real project components** directly (including base components and siblings)
|
||||
- ✅ **Only mock**: API services (`@/service/*`), `next/navigation`, complex context providers
|
||||
- ❌ **DO NOT mock** base components (`@/app/components/base/*`)
|
||||
- ❌ **DO NOT mock** sibling/child components in the same directory
|
||||
|
||||
> See [Test Structure Template](#test-structure-template) for correct import/mock patterns.
|
||||
|
||||
## Core Principles
|
||||
|
||||
### 1. AAA Pattern (Arrange-Act-Assert)
|
||||
|
||||
Every test should clearly separate:
|
||||
|
||||
- **Arrange**: Setup test data and render component
|
||||
- **Act**: Perform user actions
|
||||
- **Assert**: Verify expected outcomes
|
||||
|
||||
### 2. Black-Box Testing
|
||||
|
||||
- Test observable behavior, not implementation details
|
||||
- Use semantic queries (getByRole, getByLabelText)
|
||||
- Avoid testing internal state directly
|
||||
- **Prefer pattern matching over hardcoded strings** in assertions:
|
||||
|
||||
```typescript
|
||||
// ❌ Avoid: hardcoded text assertions
|
||||
expect(screen.getByText('Loading...')).toBeInTheDocument()
|
||||
|
||||
// ✅ Better: role-based queries
|
||||
expect(screen.getByRole('status')).toBeInTheDocument()
|
||||
|
||||
// ✅ Better: pattern matching
|
||||
expect(screen.getByText(/loading/i)).toBeInTheDocument()
|
||||
```
|
||||
|
||||
### 3. Single Behavior Per Test
|
||||
|
||||
Each test verifies ONE user-observable behavior:
|
||||
|
||||
```typescript
|
||||
// ✅ Good: One behavior
|
||||
it('should disable button when loading', () => {
|
||||
render(<Button loading />)
|
||||
expect(screen.getByRole('button')).toBeDisabled()
|
||||
})
|
||||
|
||||
// ❌ Bad: Multiple behaviors
|
||||
it('should handle loading state', () => {
|
||||
render(<Button loading />)
|
||||
expect(screen.getByRole('button')).toBeDisabled()
|
||||
expect(screen.getByText('Loading...')).toBeInTheDocument()
|
||||
expect(screen.getByRole('button')).toHaveClass('loading')
|
||||
})
|
||||
```
|
||||
|
||||
### 4. Semantic Naming
|
||||
|
||||
Use `should <behavior> when <condition>`:
|
||||
|
||||
```typescript
|
||||
it('should show error message when validation fails')
|
||||
it('should call onSubmit when form is valid')
|
||||
it('should disable input when isReadOnly is true')
|
||||
```
|
||||
|
||||
## Required Test Scenarios
|
||||
|
||||
### Always Required (All Components)
|
||||
|
||||
1. **Rendering**: Component renders without crashing
|
||||
1. **Props**: Required props, optional props, default values
|
||||
1. **Edge Cases**: null, undefined, empty values, boundary conditions
|
||||
|
||||
### Conditional (When Present)
|
||||
|
||||
| Feature | Test Focus |
|
||||
|---------|-----------|
|
||||
| `useState` | Initial state, transitions, cleanup |
|
||||
| `useEffect` | Execution, dependencies, cleanup |
|
||||
| Event handlers | All onClick, onChange, onSubmit, keyboard |
|
||||
| API calls | Loading, success, error states |
|
||||
| Routing | Navigation, params, query strings |
|
||||
| `useCallback`/`useMemo` | Referential equality |
|
||||
| Context | Provider values, consumer behavior |
|
||||
| Forms | Validation, submission, error display |
|
||||
|
||||
## Coverage Goals (Per File)
|
||||
|
||||
For each test file generated, aim for:
|
||||
|
||||
- ✅ **100%** function coverage
|
||||
- ✅ **100%** statement coverage
|
||||
- ✅ **>95%** branch coverage
|
||||
- ✅ **>95%** line coverage
|
||||
|
||||
> **Note**: For multi-file directories, process one file at a time with full coverage each. See `references/workflow.md`.
|
||||
|
||||
## Detailed Guides
|
||||
|
||||
For more detailed information, refer to:
|
||||
|
||||
- `references/workflow.md` - **Incremental testing workflow** (MUST READ for multi-file testing)
|
||||
- `references/mocking.md` - Mock patterns and best practices
|
||||
- `references/async-testing.md` - Async operations and API calls
|
||||
- `references/domain-components.md` - Workflow, Dataset, Configuration testing
|
||||
- `references/common-patterns.md` - Frequently used testing patterns
|
||||
- `references/checklist.md` - Test generation checklist and validation steps
|
||||
|
||||
## Authoritative References
|
||||
|
||||
### Primary Specification (MUST follow)
|
||||
|
||||
- **`web/testing/testing.md`** - The canonical testing specification. This skill is derived from this document.
|
||||
|
||||
### Reference Examples in Codebase
|
||||
|
||||
- `web/utils/classnames.spec.ts` - Utility function tests
|
||||
- `web/app/components/base/button/index.spec.tsx` - Component tests
|
||||
- `web/__mocks__/provider-context.ts` - Mock factory example
|
||||
|
||||
### Project Configuration
|
||||
|
||||
- `web/vitest.config.ts` - Vitest configuration
|
||||
- `web/vitest.setup.ts` - Test environment setup
|
||||
- `web/scripts/analyze-component.js` - Component analysis tool
|
||||
- Modules are not mocked automatically. Global mocks live in `web/vitest.setup.ts` (for example `react-i18next`, `next/image`); mock other modules like `ky` or `mime` locally in test files.
|
||||
@@ -1,293 +0,0 @@
|
||||
/**
|
||||
* Test Template for React Components
|
||||
*
|
||||
* WHY THIS STRUCTURE?
|
||||
* - Organized sections make tests easy to navigate and maintain
|
||||
* - Mocks at top ensure consistent test isolation
|
||||
* - Factory functions reduce duplication and improve readability
|
||||
* - describe blocks group related scenarios for better debugging
|
||||
*
|
||||
* INSTRUCTIONS:
|
||||
* 1. Replace `ComponentName` with your component name
|
||||
* 2. Update import path
|
||||
* 3. Add/remove test sections based on component features (use analyze-component)
|
||||
* 4. Follow AAA pattern: Arrange → Act → Assert
|
||||
*
|
||||
* RUN FIRST: pnpm analyze-component <path> to identify required test scenarios
|
||||
*/
|
||||
|
||||
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
// import ComponentName from './index'
|
||||
|
||||
// ============================================================================
|
||||
// Mocks
|
||||
// ============================================================================
|
||||
// WHY: Mocks must be hoisted to top of file (Vitest requirement).
|
||||
// They run BEFORE imports, so keep them before component imports.
|
||||
|
||||
// i18n (automatically mocked)
|
||||
// WHY: Global mock in web/vitest.setup.ts is auto-loaded by Vitest setup
|
||||
// The global mock provides: useTranslation, Trans, useMixedTranslation, useGetLanguage
|
||||
// No explicit mock needed for most tests
|
||||
//
|
||||
// Override only if custom translations are required:
|
||||
// import { createReactI18nextMock } from '@/test/i18n-mock'
|
||||
// vi.mock('react-i18next', () => createReactI18nextMock({
|
||||
// 'my.custom.key': 'Custom Translation',
|
||||
// 'button.save': 'Save',
|
||||
// }))
|
||||
|
||||
// Router (if component uses useRouter, usePathname, useSearchParams)
|
||||
// WHY: Isolates tests from Next.js routing, enables testing navigation behavior
|
||||
// const mockPush = vi.fn()
|
||||
// vi.mock('next/navigation', () => ({
|
||||
// useRouter: () => ({ push: mockPush }),
|
||||
// usePathname: () => '/test-path',
|
||||
// }))
|
||||
|
||||
// API services (if component fetches data)
|
||||
// WHY: Prevents real network calls, enables testing all states (loading/success/error)
|
||||
// vi.mock('@/service/api')
|
||||
// import * as api from '@/service/api'
|
||||
// const mockedApi = vi.mocked(api)
|
||||
|
||||
// Shared mock state (for portal/dropdown components)
|
||||
// WHY: Portal components like PortalToFollowElem need shared state between
|
||||
// parent and child mocks to correctly simulate open/close behavior
|
||||
// let mockOpenState = false
|
||||
|
||||
// ============================================================================
|
||||
// Test Data Factories
|
||||
// ============================================================================
|
||||
// WHY FACTORIES?
|
||||
// - Avoid hard-coded test data scattered across tests
|
||||
// - Easy to create variations with overrides
|
||||
// - Type-safe when using actual types from source
|
||||
// - Single source of truth for default test values
|
||||
|
||||
// const createMockProps = (overrides = {}) => ({
|
||||
// // Default props that make component render successfully
|
||||
// ...overrides,
|
||||
// })
|
||||
|
||||
// const createMockItem = (overrides = {}) => ({
|
||||
// id: 'item-1',
|
||||
// name: 'Test Item',
|
||||
// ...overrides,
|
||||
// })
|
||||
|
||||
// ============================================================================
|
||||
// Test Helpers
|
||||
// ============================================================================
|
||||
|
||||
// const renderComponent = (props = {}) => {
|
||||
// return render(<ComponentName {...createMockProps(props)} />)
|
||||
// }
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
describe('ComponentName', () => {
|
||||
// WHY beforeEach with clearAllMocks?
|
||||
// - Ensures each test starts with clean slate
|
||||
// - Prevents mock call history from leaking between tests
|
||||
// - MUST be beforeEach (not afterEach) to reset BEFORE assertions like toHaveBeenCalledTimes
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
// Reset shared mock state if used (CRITICAL for portal/dropdown tests)
|
||||
// mockOpenState = false
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Rendering Tests (REQUIRED - Every component MUST have these)
|
||||
// --------------------------------------------------------------------------
|
||||
// WHY: Catches import errors, missing providers, and basic render issues
|
||||
describe('Rendering', () => {
|
||||
it('should render without crashing', () => {
|
||||
// Arrange - Setup data and mocks
|
||||
// const props = createMockProps()
|
||||
|
||||
// Act - Render the component
|
||||
// render(<ComponentName {...props} />)
|
||||
|
||||
// Assert - Verify expected output
|
||||
// Prefer getByRole for accessibility; it's what users "see"
|
||||
// expect(screen.getByRole('...')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render with default props', () => {
|
||||
// WHY: Verifies component works without optional props
|
||||
// render(<ComponentName />)
|
||||
// expect(screen.getByText('...')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Props Tests (REQUIRED - Every component MUST test prop behavior)
|
||||
// --------------------------------------------------------------------------
|
||||
// WHY: Props are the component's API contract. Test them thoroughly.
|
||||
describe('Props', () => {
|
||||
it('should apply custom className', () => {
|
||||
// WHY: Common pattern in Dify - components should merge custom classes
|
||||
// render(<ComponentName className="custom-class" />)
|
||||
// expect(screen.getByTestId('component')).toHaveClass('custom-class')
|
||||
})
|
||||
|
||||
it('should use default values for optional props', () => {
|
||||
// WHY: Verifies TypeScript defaults work at runtime
|
||||
// render(<ComponentName />)
|
||||
// expect(screen.getByRole('...')).toHaveAttribute('...', 'default-value')
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// User Interactions (if component has event handlers - on*, handle*)
|
||||
// --------------------------------------------------------------------------
|
||||
// WHY: Event handlers are core functionality. Test from user's perspective.
|
||||
describe('User Interactions', () => {
|
||||
it('should call onClick when clicked', async () => {
|
||||
// WHY userEvent over fireEvent?
|
||||
// - userEvent simulates real user behavior (focus, hover, then click)
|
||||
// - fireEvent is lower-level, doesn't trigger all browser events
|
||||
// const user = userEvent.setup()
|
||||
// const handleClick = vi.fn()
|
||||
// render(<ComponentName onClick={handleClick} />)
|
||||
//
|
||||
// await user.click(screen.getByRole('button'))
|
||||
//
|
||||
// expect(handleClick).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should call onChange when value changes', async () => {
|
||||
// const user = userEvent.setup()
|
||||
// const handleChange = vi.fn()
|
||||
// render(<ComponentName onChange={handleChange} />)
|
||||
//
|
||||
// await user.type(screen.getByRole('textbox'), 'new value')
|
||||
//
|
||||
// expect(handleChange).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// State Management (if component uses useState/useReducer)
|
||||
// --------------------------------------------------------------------------
|
||||
// WHY: Test state through observable UI changes, not internal state values
|
||||
describe('State Management', () => {
|
||||
it('should update state on interaction', async () => {
|
||||
// WHY test via UI, not state?
|
||||
// - State is implementation detail; UI is what users see
|
||||
// - If UI works correctly, state must be correct
|
||||
// const user = userEvent.setup()
|
||||
// render(<ComponentName />)
|
||||
//
|
||||
// // Initial state - verify what user sees
|
||||
// expect(screen.getByText('Initial')).toBeInTheDocument()
|
||||
//
|
||||
// // Trigger state change via user action
|
||||
// await user.click(screen.getByRole('button'))
|
||||
//
|
||||
// // New state - verify UI updated
|
||||
// expect(screen.getByText('Updated')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Async Operations (if component fetches data - useQuery, fetch)
|
||||
// --------------------------------------------------------------------------
|
||||
// WHY: Async operations have 3 states users experience: loading, success, error
|
||||
describe('Async Operations', () => {
|
||||
it('should show loading state', () => {
|
||||
// WHY never-resolving promise?
|
||||
// - Keeps component in loading state for assertion
|
||||
// - Alternative: use fake timers
|
||||
// mockedApi.fetchData.mockImplementation(() => new Promise(() => {}))
|
||||
// render(<ComponentName />)
|
||||
//
|
||||
// expect(screen.getByText(/loading/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show data on success', async () => {
|
||||
// WHY waitFor?
|
||||
// - Component updates asynchronously after fetch resolves
|
||||
// - waitFor retries assertion until it passes or times out
|
||||
// mockedApi.fetchData.mockResolvedValue({ items: ['Item 1'] })
|
||||
// render(<ComponentName />)
|
||||
//
|
||||
// await waitFor(() => {
|
||||
// expect(screen.getByText('Item 1')).toBeInTheDocument()
|
||||
// })
|
||||
})
|
||||
|
||||
it('should show error on failure', async () => {
|
||||
// mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
|
||||
// render(<ComponentName />)
|
||||
//
|
||||
// await waitFor(() => {
|
||||
// expect(screen.getByText(/error/i)).toBeInTheDocument()
|
||||
// })
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Edge Cases (REQUIRED - Every component MUST handle edge cases)
|
||||
// --------------------------------------------------------------------------
|
||||
// WHY: Real-world data is messy. Components must handle:
|
||||
// - Null/undefined from API failures or optional fields
|
||||
// - Empty arrays/strings from user clearing data
|
||||
// - Boundary values (0, MAX_INT, special characters)
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle null value', () => {
|
||||
// WHY test null specifically?
|
||||
// - API might return null for missing data
|
||||
// - Prevents "Cannot read property of null" in production
|
||||
// render(<ComponentName value={null} />)
|
||||
// expect(screen.getByText(/no data/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle undefined value', () => {
|
||||
// WHY test undefined separately from null?
|
||||
// - TypeScript treats them differently
|
||||
// - Optional props are undefined, not null
|
||||
// render(<ComponentName value={undefined} />)
|
||||
// expect(screen.getByText(/no data/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle empty array', () => {
|
||||
// WHY: Empty state often needs special UI (e.g., "No items yet")
|
||||
// render(<ComponentName items={[]} />)
|
||||
// expect(screen.getByText(/empty/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle empty string', () => {
|
||||
// WHY: Empty strings are truthy in JS but visually empty
|
||||
// render(<ComponentName text="" />)
|
||||
// expect(screen.getByText(/placeholder/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Accessibility (optional but recommended for Dify's enterprise users)
|
||||
// --------------------------------------------------------------------------
|
||||
// WHY: Dify has enterprise customers who may require accessibility compliance
|
||||
describe('Accessibility', () => {
|
||||
it('should have accessible name', () => {
|
||||
// WHY getByRole with name?
|
||||
// - Tests that screen readers can identify the element
|
||||
// - Enforces proper labeling practices
|
||||
// render(<ComponentName label="Test Label" />)
|
||||
// expect(screen.getByRole('button', { name: /test label/i })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should support keyboard navigation', async () => {
|
||||
// WHY: Some users can't use a mouse
|
||||
// const user = userEvent.setup()
|
||||
// render(<ComponentName />)
|
||||
//
|
||||
// await user.tab()
|
||||
// expect(screen.getByRole('button')).toHaveFocus()
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,207 +0,0 @@
|
||||
/**
|
||||
* Test Template for Custom Hooks
|
||||
*
|
||||
* Instructions:
|
||||
* 1. Replace `useHookName` with your hook name
|
||||
* 2. Update import path
|
||||
* 3. Add/remove test sections based on hook features
|
||||
*/
|
||||
|
||||
import { renderHook, act, waitFor } from '@testing-library/react'
|
||||
// import { useHookName } from './use-hook-name'
|
||||
|
||||
// ============================================================================
|
||||
// Mocks
|
||||
// ============================================================================
|
||||
|
||||
// API services (if hook fetches data)
|
||||
// vi.mock('@/service/api')
|
||||
// import * as api from '@/service/api'
|
||||
// const mockedApi = vi.mocked(api)
|
||||
|
||||
// ============================================================================
|
||||
// Test Helpers
|
||||
// ============================================================================
|
||||
|
||||
// Wrapper for hooks that need context
|
||||
// const createWrapper = (contextValue = {}) => {
|
||||
// return ({ children }: { children: React.ReactNode }) => (
|
||||
// <SomeContext.Provider value={contextValue}>
|
||||
// {children}
|
||||
// </SomeContext.Provider>
|
||||
// )
|
||||
// }
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
describe('useHookName', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Initial State
|
||||
// --------------------------------------------------------------------------
|
||||
describe('Initial State', () => {
|
||||
it('should return initial state', () => {
|
||||
// const { result } = renderHook(() => useHookName())
|
||||
//
|
||||
// expect(result.current.value).toBe(initialValue)
|
||||
// expect(result.current.isLoading).toBe(false)
|
||||
})
|
||||
|
||||
it('should accept initial value from props', () => {
|
||||
// const { result } = renderHook(() => useHookName({ initialValue: 'custom' }))
|
||||
//
|
||||
// expect(result.current.value).toBe('custom')
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// State Updates
|
||||
// --------------------------------------------------------------------------
|
||||
describe('State Updates', () => {
|
||||
it('should update value when setValue is called', () => {
|
||||
// const { result } = renderHook(() => useHookName())
|
||||
//
|
||||
// act(() => {
|
||||
// result.current.setValue('new value')
|
||||
// })
|
||||
//
|
||||
// expect(result.current.value).toBe('new value')
|
||||
})
|
||||
|
||||
it('should reset to initial value', () => {
|
||||
// const { result } = renderHook(() => useHookName({ initialValue: 'initial' }))
|
||||
//
|
||||
// act(() => {
|
||||
// result.current.setValue('changed')
|
||||
// })
|
||||
// expect(result.current.value).toBe('changed')
|
||||
//
|
||||
// act(() => {
|
||||
// result.current.reset()
|
||||
// })
|
||||
// expect(result.current.value).toBe('initial')
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Async Operations
|
||||
// --------------------------------------------------------------------------
|
||||
describe('Async Operations', () => {
|
||||
it('should fetch data on mount', async () => {
|
||||
// mockedApi.fetchData.mockResolvedValue({ data: 'test' })
|
||||
//
|
||||
// const { result } = renderHook(() => useHookName())
|
||||
//
|
||||
// // Initially loading
|
||||
// expect(result.current.isLoading).toBe(true)
|
||||
//
|
||||
// // Wait for data
|
||||
// await waitFor(() => {
|
||||
// expect(result.current.isLoading).toBe(false)
|
||||
// })
|
||||
//
|
||||
// expect(result.current.data).toEqual({ data: 'test' })
|
||||
})
|
||||
|
||||
it('should handle fetch error', async () => {
|
||||
// mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
|
||||
//
|
||||
// const { result } = renderHook(() => useHookName())
|
||||
//
|
||||
// await waitFor(() => {
|
||||
// expect(result.current.error).toBeTruthy()
|
||||
// })
|
||||
//
|
||||
// expect(result.current.error?.message).toBe('Network error')
|
||||
})
|
||||
|
||||
it('should refetch when dependency changes', async () => {
|
||||
// mockedApi.fetchData.mockResolvedValue({ data: 'test' })
|
||||
//
|
||||
// const { result, rerender } = renderHook(
|
||||
// ({ id }) => useHookName(id),
|
||||
// { initialProps: { id: '1' } }
|
||||
// )
|
||||
//
|
||||
// await waitFor(() => {
|
||||
// expect(mockedApi.fetchData).toHaveBeenCalledWith('1')
|
||||
// })
|
||||
//
|
||||
// rerender({ id: '2' })
|
||||
//
|
||||
// await waitFor(() => {
|
||||
// expect(mockedApi.fetchData).toHaveBeenCalledWith('2')
|
||||
// })
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Side Effects
|
||||
// --------------------------------------------------------------------------
|
||||
describe('Side Effects', () => {
|
||||
it('should call callback when value changes', () => {
|
||||
// const callback = vi.fn()
|
||||
// const { result } = renderHook(() => useHookName({ onChange: callback }))
|
||||
//
|
||||
// act(() => {
|
||||
// result.current.setValue('new value')
|
||||
// })
|
||||
//
|
||||
// expect(callback).toHaveBeenCalledWith('new value')
|
||||
})
|
||||
|
||||
it('should cleanup on unmount', () => {
|
||||
// const cleanup = vi.fn()
|
||||
// vi.spyOn(window, 'addEventListener')
|
||||
// vi.spyOn(window, 'removeEventListener')
|
||||
//
|
||||
// const { unmount } = renderHook(() => useHookName())
|
||||
//
|
||||
// expect(window.addEventListener).toHaveBeenCalled()
|
||||
//
|
||||
// unmount()
|
||||
//
|
||||
// expect(window.removeEventListener).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Edge Cases
|
||||
// --------------------------------------------------------------------------
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle null input', () => {
|
||||
// const { result } = renderHook(() => useHookName(null))
|
||||
//
|
||||
// expect(result.current.value).toBeNull()
|
||||
})
|
||||
|
||||
it('should handle rapid updates', () => {
|
||||
// const { result } = renderHook(() => useHookName())
|
||||
//
|
||||
// act(() => {
|
||||
// result.current.setValue('1')
|
||||
// result.current.setValue('2')
|
||||
// result.current.setValue('3')
|
||||
// })
|
||||
//
|
||||
// expect(result.current.value).toBe('3')
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// With Context (if hook uses context)
|
||||
// --------------------------------------------------------------------------
|
||||
describe('With Context', () => {
|
||||
it('should use context value', () => {
|
||||
// const wrapper = createWrapper({ someValue: 'context-value' })
|
||||
// const { result } = renderHook(() => useHookName(), { wrapper })
|
||||
//
|
||||
// expect(result.current.contextValue).toBe('context-value')
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,154 +0,0 @@
|
||||
/**
|
||||
* Test Template for Utility Functions
|
||||
*
|
||||
* Instructions:
|
||||
* 1. Replace `utilityFunction` with your function name
|
||||
* 2. Update import path
|
||||
* 3. Use test.each for data-driven tests
|
||||
*/
|
||||
|
||||
// import { utilityFunction } from './utility'
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
describe('utilityFunction', () => {
|
||||
// --------------------------------------------------------------------------
|
||||
// Basic Functionality
|
||||
// --------------------------------------------------------------------------
|
||||
describe('Basic Functionality', () => {
|
||||
it('should return expected result for valid input', () => {
|
||||
// expect(utilityFunction('input')).toBe('expected-output')
|
||||
})
|
||||
|
||||
it('should handle multiple arguments', () => {
|
||||
// expect(utilityFunction('a', 'b', 'c')).toBe('abc')
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Data-Driven Tests
|
||||
// --------------------------------------------------------------------------
|
||||
describe('Input/Output Mapping', () => {
|
||||
test.each([
|
||||
// [input, expected]
|
||||
['input1', 'output1'],
|
||||
['input2', 'output2'],
|
||||
['input3', 'output3'],
|
||||
])('should return %s for input %s', (input, expected) => {
|
||||
// expect(utilityFunction(input)).toBe(expected)
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Edge Cases
|
||||
// --------------------------------------------------------------------------
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle empty string', () => {
|
||||
// expect(utilityFunction('')).toBe('')
|
||||
})
|
||||
|
||||
it('should handle null', () => {
|
||||
// expect(utilityFunction(null)).toBe(null)
|
||||
// or
|
||||
// expect(() => utilityFunction(null)).toThrow()
|
||||
})
|
||||
|
||||
it('should handle undefined', () => {
|
||||
// expect(utilityFunction(undefined)).toBe(undefined)
|
||||
// or
|
||||
// expect(() => utilityFunction(undefined)).toThrow()
|
||||
})
|
||||
|
||||
it('should handle empty array', () => {
|
||||
// expect(utilityFunction([])).toEqual([])
|
||||
})
|
||||
|
||||
it('should handle empty object', () => {
|
||||
// expect(utilityFunction({})).toEqual({})
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Boundary Conditions
|
||||
// --------------------------------------------------------------------------
|
||||
describe('Boundary Conditions', () => {
|
||||
it('should handle minimum value', () => {
|
||||
// expect(utilityFunction(0)).toBe(0)
|
||||
})
|
||||
|
||||
it('should handle maximum value', () => {
|
||||
// expect(utilityFunction(Number.MAX_SAFE_INTEGER)).toBe(...)
|
||||
})
|
||||
|
||||
it('should handle negative numbers', () => {
|
||||
// expect(utilityFunction(-1)).toBe(...)
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Type Coercion (if applicable)
|
||||
// --------------------------------------------------------------------------
|
||||
describe('Type Handling', () => {
|
||||
it('should handle numeric string', () => {
|
||||
// expect(utilityFunction('123')).toBe(123)
|
||||
})
|
||||
|
||||
it('should handle boolean', () => {
|
||||
// expect(utilityFunction(true)).toBe(...)
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Error Cases
|
||||
// --------------------------------------------------------------------------
|
||||
describe('Error Handling', () => {
|
||||
it('should throw for invalid input', () => {
|
||||
// expect(() => utilityFunction('invalid')).toThrow('Error message')
|
||||
})
|
||||
|
||||
it('should throw with specific error type', () => {
|
||||
// expect(() => utilityFunction('invalid')).toThrow(ValidationError)
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Complex Objects (if applicable)
|
||||
// --------------------------------------------------------------------------
|
||||
describe('Object Handling', () => {
|
||||
it('should preserve object structure', () => {
|
||||
// const input = { a: 1, b: 2 }
|
||||
// expect(utilityFunction(input)).toEqual({ a: 1, b: 2 })
|
||||
})
|
||||
|
||||
it('should handle nested objects', () => {
|
||||
// const input = { nested: { deep: 'value' } }
|
||||
// expect(utilityFunction(input)).toEqual({ nested: { deep: 'transformed' } })
|
||||
})
|
||||
|
||||
it('should not mutate input', () => {
|
||||
// const input = { a: 1 }
|
||||
// const inputCopy = { ...input }
|
||||
// utilityFunction(input)
|
||||
// expect(input).toEqual(inputCopy)
|
||||
})
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Array Handling (if applicable)
|
||||
// --------------------------------------------------------------------------
|
||||
describe('Array Handling', () => {
|
||||
it('should process all elements', () => {
|
||||
// expect(utilityFunction([1, 2, 3])).toEqual([2, 4, 6])
|
||||
})
|
||||
|
||||
it('should handle single element array', () => {
|
||||
// expect(utilityFunction([1])).toEqual([2])
|
||||
})
|
||||
|
||||
it('should preserve order', () => {
|
||||
// expect(utilityFunction(['c', 'a', 'b'])).toEqual(['c', 'a', 'b'])
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,345 +0,0 @@
|
||||
# Async Testing Guide
|
||||
|
||||
## Core Async Patterns
|
||||
|
||||
### 1. waitFor - Wait for Condition
|
||||
|
||||
```typescript
|
||||
import { render, screen, waitFor } from '@testing-library/react'
|
||||
|
||||
it('should load and display data', async () => {
|
||||
render(<DataComponent />)
|
||||
|
||||
// Wait for element to appear
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('Loaded Data')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should hide loading spinner after load', async () => {
|
||||
render(<DataComponent />)
|
||||
|
||||
// Wait for element to disappear
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByText('Loading...')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
### 2. findBy\* - Async Queries
|
||||
|
||||
```typescript
|
||||
it('should show user name after fetch', async () => {
|
||||
render(<UserProfile />)
|
||||
|
||||
// findBy returns a promise, auto-waits up to 1000ms
|
||||
const userName = await screen.findByText('John Doe')
|
||||
expect(userName).toBeInTheDocument()
|
||||
|
||||
// findByRole with options
|
||||
const button = await screen.findByRole('button', { name: /submit/i })
|
||||
expect(button).toBeEnabled()
|
||||
})
|
||||
```
|
||||
|
||||
### 3. userEvent for Async Interactions
|
||||
|
||||
```typescript
|
||||
import userEvent from '@testing-library/user-event'
|
||||
|
||||
it('should submit form', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onSubmit = vi.fn()
|
||||
|
||||
render(<Form onSubmit={onSubmit} />)
|
||||
|
||||
// userEvent methods are async
|
||||
await user.type(screen.getByLabelText('Email'), 'test@example.com')
|
||||
await user.click(screen.getByRole('button', { name: /submit/i }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(onSubmit).toHaveBeenCalledWith({ email: 'test@example.com' })
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
## Fake Timers
|
||||
|
||||
### When to Use Fake Timers
|
||||
|
||||
- Testing components with `setTimeout`/`setInterval`
|
||||
- Testing debounce/throttle behavior
|
||||
- Testing animations or delayed transitions
|
||||
- Testing polling or retry logic
|
||||
|
||||
### Basic Fake Timer Setup
|
||||
|
||||
```typescript
|
||||
describe('Debounced Search', () => {
|
||||
beforeEach(() => {
|
||||
vi.useFakeTimers()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
it('should debounce search input', async () => {
|
||||
const onSearch = vi.fn()
|
||||
render(<SearchInput onSearch={onSearch} debounceMs={300} />)
|
||||
|
||||
// Type in the input
|
||||
fireEvent.change(screen.getByRole('textbox'), { target: { value: 'query' } })
|
||||
|
||||
// Search not called immediately
|
||||
expect(onSearch).not.toHaveBeenCalled()
|
||||
|
||||
// Advance timers
|
||||
vi.advanceTimersByTime(300)
|
||||
|
||||
// Now search is called
|
||||
expect(onSearch).toHaveBeenCalledWith('query')
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
### Fake Timers with Async Code
|
||||
|
||||
```typescript
|
||||
it('should retry on failure', async () => {
|
||||
vi.useFakeTimers()
|
||||
const fetchData = vi.fn()
|
||||
.mockRejectedValueOnce(new Error('Network error'))
|
||||
.mockResolvedValueOnce({ data: 'success' })
|
||||
|
||||
render(<RetryComponent fetchData={fetchData} retryDelayMs={1000} />)
|
||||
|
||||
// First call fails
|
||||
await waitFor(() => {
|
||||
expect(fetchData).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
// Advance timer for retry
|
||||
vi.advanceTimersByTime(1000)
|
||||
|
||||
// Second call succeeds
|
||||
await waitFor(() => {
|
||||
expect(fetchData).toHaveBeenCalledTimes(2)
|
||||
expect(screen.getByText('success')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
vi.useRealTimers()
|
||||
})
|
||||
```
|
||||
|
||||
### Common Fake Timer Utilities
|
||||
|
||||
```typescript
|
||||
// Run all pending timers
|
||||
vi.runAllTimers()
|
||||
|
||||
// Run only pending timers (not new ones created during execution)
|
||||
vi.runOnlyPendingTimers()
|
||||
|
||||
// Advance by specific time
|
||||
vi.advanceTimersByTime(1000)
|
||||
|
||||
// Get current fake time
|
||||
Date.now()
|
||||
|
||||
// Clear all timers
|
||||
vi.clearAllTimers()
|
||||
```
|
||||
|
||||
## API Testing Patterns
|
||||
|
||||
### Loading → Success → Error States
|
||||
|
||||
```typescript
|
||||
describe('DataFetcher', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('should show loading state', () => {
|
||||
mockedApi.fetchData.mockImplementation(() => new Promise(() => {})) // Never resolves
|
||||
|
||||
render(<DataFetcher />)
|
||||
|
||||
expect(screen.getByTestId('loading-spinner')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show data on success', async () => {
|
||||
mockedApi.fetchData.mockResolvedValue({ items: ['Item 1', 'Item 2'] })
|
||||
|
||||
render(<DataFetcher />)
|
||||
|
||||
// Use findBy* for multiple async elements (better error messages than waitFor with multiple assertions)
|
||||
const item1 = await screen.findByText('Item 1')
|
||||
const item2 = await screen.findByText('Item 2')
|
||||
expect(item1).toBeInTheDocument()
|
||||
expect(item2).toBeInTheDocument()
|
||||
|
||||
expect(screen.queryByTestId('loading-spinner')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show error on failure', async () => {
|
||||
mockedApi.fetchData.mockRejectedValue(new Error('Failed to fetch'))
|
||||
|
||||
render(<DataFetcher />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/failed to fetch/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should retry on error', async () => {
|
||||
mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
|
||||
|
||||
render(<DataFetcher />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByRole('button', { name: /retry/i })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
mockedApi.fetchData.mockResolvedValue({ items: ['Item 1'] })
|
||||
fireEvent.click(screen.getByRole('button', { name: /retry/i }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('Item 1')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
### Testing Mutations
|
||||
|
||||
```typescript
|
||||
it('should submit form and show success', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockedApi.createItem.mockResolvedValue({ id: '1', name: 'New Item' })
|
||||
|
||||
render(<CreateItemForm />)
|
||||
|
||||
await user.type(screen.getByLabelText('Name'), 'New Item')
|
||||
await user.click(screen.getByRole('button', { name: /create/i }))
|
||||
|
||||
// Button should be disabled during submission
|
||||
expect(screen.getByRole('button', { name: /creating/i })).toBeDisabled()
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/created successfully/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
expect(mockedApi.createItem).toHaveBeenCalledWith({ name: 'New Item' })
|
||||
})
|
||||
```
|
||||
|
||||
## useEffect Testing
|
||||
|
||||
### Testing Effect Execution
|
||||
|
||||
```typescript
|
||||
it('should fetch data on mount', async () => {
|
||||
const fetchData = vi.fn().mockResolvedValue({ data: 'test' })
|
||||
|
||||
render(<ComponentWithEffect fetchData={fetchData} />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(fetchData).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
### Testing Effect Dependencies
|
||||
|
||||
```typescript
|
||||
it('should refetch when id changes', async () => {
|
||||
const fetchData = vi.fn().mockResolvedValue({ data: 'test' })
|
||||
|
||||
const { rerender } = render(<ComponentWithEffect id="1" fetchData={fetchData} />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(fetchData).toHaveBeenCalledWith('1')
|
||||
})
|
||||
|
||||
rerender(<ComponentWithEffect id="2" fetchData={fetchData} />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(fetchData).toHaveBeenCalledWith('2')
|
||||
expect(fetchData).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
### Testing Effect Cleanup
|
||||
|
||||
```typescript
|
||||
it('should cleanup subscription on unmount', () => {
|
||||
const subscribe = vi.fn()
|
||||
const unsubscribe = vi.fn()
|
||||
subscribe.mockReturnValue(unsubscribe)
|
||||
|
||||
const { unmount } = render(<SubscriptionComponent subscribe={subscribe} />)
|
||||
|
||||
expect(subscribe).toHaveBeenCalledTimes(1)
|
||||
|
||||
unmount()
|
||||
|
||||
expect(unsubscribe).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
```
|
||||
|
||||
## Common Async Pitfalls
|
||||
|
||||
### ❌ Don't: Forget to await
|
||||
|
||||
```typescript
|
||||
// Bad - test may pass even if assertion fails
|
||||
it('should load data', () => {
|
||||
render(<Component />)
|
||||
waitFor(() => {
|
||||
expect(screen.getByText('Data')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// Good - properly awaited
|
||||
it('should load data', async () => {
|
||||
render(<Component />)
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('Data')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
### ❌ Don't: Use multiple assertions in single waitFor
|
||||
|
||||
```typescript
|
||||
// Bad - if first assertion fails, won't know about second
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('Title')).toBeInTheDocument()
|
||||
expect(screen.getByText('Description')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// Good - separate waitFor or use findBy
|
||||
const title = await screen.findByText('Title')
|
||||
const description = await screen.findByText('Description')
|
||||
expect(title).toBeInTheDocument()
|
||||
expect(description).toBeInTheDocument()
|
||||
```
|
||||
|
||||
### ❌ Don't: Mix fake timers with real async
|
||||
|
||||
```typescript
|
||||
// Bad - fake timers don't work well with real Promises
|
||||
vi.useFakeTimers()
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('Data')).toBeInTheDocument()
|
||||
}) // May timeout!
|
||||
|
||||
// Good - use runAllTimers or advanceTimersByTime
|
||||
vi.useFakeTimers()
|
||||
render(<Component />)
|
||||
vi.runAllTimers()
|
||||
expect(screen.getByText('Data')).toBeInTheDocument()
|
||||
```
|
||||
@@ -1,205 +0,0 @@
|
||||
# Test Generation Checklist
|
||||
|
||||
Use this checklist when generating or reviewing tests for Dify frontend components.
|
||||
|
||||
## Pre-Generation
|
||||
|
||||
- [ ] Read the component source code completely
|
||||
- [ ] Identify component type (component, hook, utility, page)
|
||||
- [ ] Run `pnpm analyze-component <path>` if available
|
||||
- [ ] Note complexity score and features detected
|
||||
- [ ] Check for existing tests in the same directory
|
||||
- [ ] **Identify ALL files in the directory** that need testing (not just index)
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
### ⚠️ Incremental Workflow (CRITICAL for Multi-File)
|
||||
|
||||
- [ ] **NEVER generate all tests at once** - process one file at a time
|
||||
- [ ] Order files by complexity: utilities → hooks → simple → complex → integration
|
||||
- [ ] Create a todo list to track progress before starting
|
||||
- [ ] For EACH file: write → run test → verify pass → then next
|
||||
- [ ] **DO NOT proceed** to next file until current one passes
|
||||
|
||||
### Path-Level Coverage
|
||||
|
||||
- [ ] **Test ALL files** in the assigned directory/path
|
||||
- [ ] List all components, hooks, utilities that need coverage
|
||||
- [ ] Decide: single spec file (integration) or multiple spec files (unit)
|
||||
|
||||
### Complexity Assessment
|
||||
|
||||
- [ ] Run `pnpm analyze-component <path>` for complexity score
|
||||
- [ ] **Complexity > 50**: Consider refactoring before testing
|
||||
- [ ] **500+ lines**: Consider splitting before testing
|
||||
- [ ] **30-50 complexity**: Use multiple describe blocks, organized structure
|
||||
|
||||
### Integration vs Mocking
|
||||
|
||||
- [ ] **DO NOT mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
|
||||
- [ ] Import real project components instead of mocking
|
||||
- [ ] Only mock: API calls, complex context providers, third-party libs with side effects
|
||||
- [ ] Prefer integration testing when using single spec file
|
||||
|
||||
## Required Test Sections
|
||||
|
||||
### All Components MUST Have
|
||||
|
||||
- [ ] **Rendering tests** - Component renders without crashing
|
||||
- [ ] **Props tests** - Required props, optional props, default values
|
||||
- [ ] **Edge cases** - null, undefined, empty values, boundaries
|
||||
|
||||
### Conditional Sections (Add When Feature Present)
|
||||
|
||||
| Feature | Add Tests For |
|
||||
|---------|---------------|
|
||||
| `useState` | Initial state, transitions, cleanup |
|
||||
| `useEffect` | Execution, dependencies, cleanup |
|
||||
| Event handlers | onClick, onChange, onSubmit, keyboard |
|
||||
| API calls | Loading, success, error states |
|
||||
| Routing | Navigation, params, query strings |
|
||||
| `useCallback`/`useMemo` | Referential equality |
|
||||
| Context | Provider values, consumer behavior |
|
||||
| Forms | Validation, submission, error display |
|
||||
|
||||
## Code Quality Checklist
|
||||
|
||||
### Structure
|
||||
|
||||
- [ ] Uses `describe` blocks to group related tests
|
||||
- [ ] Test names follow `should <behavior> when <condition>` pattern
|
||||
- [ ] AAA pattern (Arrange-Act-Assert) is clear
|
||||
- [ ] Comments explain complex test scenarios
|
||||
|
||||
### Mocks
|
||||
|
||||
- [ ] **DO NOT mock base components** (`@/app/components/base/*`)
|
||||
- [ ] `vi.clearAllMocks()` in `beforeEach` (not `afterEach`)
|
||||
- [ ] Shared mock state reset in `beforeEach`
|
||||
- [ ] i18n uses global mock (auto-loaded in `web/vitest.setup.ts`); only override locally for custom translations
|
||||
- [ ] Router mocks match actual Next.js API
|
||||
- [ ] Mocks reflect actual component conditional behavior
|
||||
- [ ] Only mock: API services, complex context providers, third-party libs
|
||||
|
||||
### Queries
|
||||
|
||||
- [ ] Prefer semantic queries (`getByRole`, `getByLabelText`)
|
||||
- [ ] Use `queryBy*` for absence assertions
|
||||
- [ ] Use `findBy*` for async elements
|
||||
- [ ] `getByTestId` only as last resort
|
||||
|
||||
### Async
|
||||
|
||||
- [ ] All async tests use `async/await`
|
||||
- [ ] `waitFor` wraps async assertions
|
||||
- [ ] Fake timers properly setup/teardown
|
||||
- [ ] No floating promises
|
||||
|
||||
### TypeScript
|
||||
|
||||
- [ ] No `any` types without justification
|
||||
- [ ] Mock data uses actual types from source
|
||||
- [ ] Factory functions have proper return types
|
||||
|
||||
## Coverage Goals (Per File)
|
||||
|
||||
For the current file being tested:
|
||||
|
||||
- [ ] 100% function coverage
|
||||
- [ ] 100% statement coverage
|
||||
- [ ] >95% branch coverage
|
||||
- [ ] >95% line coverage
|
||||
|
||||
## Post-Generation (Per File)
|
||||
|
||||
**Run these checks after EACH test file, not just at the end:**
|
||||
|
||||
- [ ] Run `pnpm test path/to/file.spec.tsx` - **MUST PASS before next file**
|
||||
- [ ] Fix any failures immediately
|
||||
- [ ] Mark file as complete in todo list
|
||||
- [ ] Only then proceed to next file
|
||||
|
||||
### After All Files Complete
|
||||
|
||||
- [ ] Run full directory test: `pnpm test path/to/directory/`
|
||||
- [ ] Check coverage report: `pnpm test:coverage`
|
||||
- [ ] Run `pnpm lint:fix` on all test files
|
||||
- [ ] Run `pnpm type-check:tsgo`
|
||||
|
||||
## Common Issues to Watch
|
||||
|
||||
### False Positives
|
||||
|
||||
```typescript
|
||||
// ❌ Mock doesn't match actual behavior
|
||||
vi.mock('./Component', () => () => <div>Mocked</div>)
|
||||
|
||||
// ✅ Mock matches actual conditional logic
|
||||
vi.mock('./Component', () => ({ isOpen }: any) =>
|
||||
isOpen ? <div>Content</div> : null
|
||||
)
|
||||
```
|
||||
|
||||
### State Leakage
|
||||
|
||||
```typescript
|
||||
// ❌ Shared state not reset
|
||||
let mockState = false
|
||||
vi.mock('./useHook', () => () => mockState)
|
||||
|
||||
// ✅ Reset in beforeEach
|
||||
beforeEach(() => {
|
||||
mockState = false
|
||||
})
|
||||
```
|
||||
|
||||
### Async Race Conditions
|
||||
|
||||
```typescript
|
||||
// ❌ Not awaited
|
||||
it('loads data', () => {
|
||||
render(<Component />)
|
||||
expect(screen.getByText('Data')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// ✅ Properly awaited
|
||||
it('loads data', async () => {
|
||||
render(<Component />)
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('Data')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
### Missing Edge Cases
|
||||
|
||||
Always test these scenarios:
|
||||
|
||||
- `null` / `undefined` inputs
|
||||
- Empty strings / arrays / objects
|
||||
- Boundary values (0, -1, MAX_INT)
|
||||
- Error states
|
||||
- Loading states
|
||||
- Disabled states
|
||||
|
||||
## Quick Commands
|
||||
|
||||
```bash
|
||||
# Run specific test
|
||||
pnpm test path/to/file.spec.tsx
|
||||
|
||||
# Run with coverage
|
||||
pnpm test:coverage path/to/file.spec.tsx
|
||||
|
||||
# Watch mode
|
||||
pnpm test:watch path/to/file.spec.tsx
|
||||
|
||||
# Update snapshots (use sparingly)
|
||||
pnpm test -u path/to/file.spec.tsx
|
||||
|
||||
# Analyze component
|
||||
pnpm analyze-component path/to/component.tsx
|
||||
|
||||
# Review existing test
|
||||
pnpm analyze-component path/to/component.tsx --review
|
||||
```
|
||||
@@ -1,449 +0,0 @@
|
||||
# Common Testing Patterns
|
||||
|
||||
## Query Priority
|
||||
|
||||
Use queries in this order (most to least preferred):
|
||||
|
||||
```typescript
|
||||
// 1. getByRole - Most recommended (accessibility)
|
||||
screen.getByRole('button', { name: /submit/i })
|
||||
screen.getByRole('textbox', { name: /email/i })
|
||||
screen.getByRole('heading', { level: 1 })
|
||||
|
||||
// 2. getByLabelText - Form fields
|
||||
screen.getByLabelText('Email address')
|
||||
screen.getByLabelText(/password/i)
|
||||
|
||||
// 3. getByPlaceholderText - When no label
|
||||
screen.getByPlaceholderText('Search...')
|
||||
|
||||
// 4. getByText - Non-interactive elements
|
||||
screen.getByText('Welcome to Dify')
|
||||
screen.getByText(/loading/i)
|
||||
|
||||
// 5. getByDisplayValue - Current input value
|
||||
screen.getByDisplayValue('current value')
|
||||
|
||||
// 6. getByAltText - Images
|
||||
screen.getByAltText('Company logo')
|
||||
|
||||
// 7. getByTitle - Tooltip elements
|
||||
screen.getByTitle('Close')
|
||||
|
||||
// 8. getByTestId - Last resort only!
|
||||
screen.getByTestId('custom-element')
|
||||
```
|
||||
|
||||
## Event Handling Patterns
|
||||
|
||||
### Click Events
|
||||
|
||||
```typescript
|
||||
// Basic click
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
|
||||
// With userEvent (preferred for realistic interaction)
|
||||
const user = userEvent.setup()
|
||||
await user.click(screen.getByRole('button'))
|
||||
|
||||
// Double click
|
||||
await user.dblClick(screen.getByRole('button'))
|
||||
|
||||
// Right click
|
||||
await user.pointer({ keys: '[MouseRight]', target: screen.getByRole('button') })
|
||||
```
|
||||
|
||||
### Form Input
|
||||
|
||||
```typescript
|
||||
const user = userEvent.setup()
|
||||
|
||||
// Type in input
|
||||
await user.type(screen.getByRole('textbox'), 'Hello World')
|
||||
|
||||
// Clear and type
|
||||
await user.clear(screen.getByRole('textbox'))
|
||||
await user.type(screen.getByRole('textbox'), 'New value')
|
||||
|
||||
// Select option
|
||||
await user.selectOptions(screen.getByRole('combobox'), 'option-value')
|
||||
|
||||
// Check checkbox
|
||||
await user.click(screen.getByRole('checkbox'))
|
||||
|
||||
// Upload file
|
||||
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
|
||||
await user.upload(screen.getByLabelText(/upload/i), file)
|
||||
```
|
||||
|
||||
### Keyboard Events
|
||||
|
||||
```typescript
|
||||
const user = userEvent.setup()
|
||||
|
||||
// Press Enter
|
||||
await user.keyboard('{Enter}')
|
||||
|
||||
// Press Escape
|
||||
await user.keyboard('{Escape}')
|
||||
|
||||
// Keyboard shortcut
|
||||
await user.keyboard('{Control>}a{/Control}') // Ctrl+A
|
||||
|
||||
// Tab navigation
|
||||
await user.tab()
|
||||
|
||||
// Arrow keys
|
||||
await user.keyboard('{ArrowDown}')
|
||||
await user.keyboard('{ArrowUp}')
|
||||
```
|
||||
|
||||
## Component State Testing
|
||||
|
||||
### Testing State Transitions
|
||||
|
||||
```typescript
|
||||
describe('Counter', () => {
|
||||
it('should increment count', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(<Counter initialCount={0} />)
|
||||
|
||||
// Initial state
|
||||
expect(screen.getByText('Count: 0')).toBeInTheDocument()
|
||||
|
||||
// Trigger transition
|
||||
await user.click(screen.getByRole('button', { name: /increment/i }))
|
||||
|
||||
// New state
|
||||
expect(screen.getByText('Count: 1')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
### Testing Controlled Components
|
||||
|
||||
```typescript
|
||||
describe('ControlledInput', () => {
|
||||
it('should call onChange with new value', async () => {
|
||||
const user = userEvent.setup()
|
||||
const handleChange = vi.fn()
|
||||
|
||||
render(<ControlledInput value="" onChange={handleChange} />)
|
||||
|
||||
await user.type(screen.getByRole('textbox'), 'a')
|
||||
|
||||
expect(handleChange).toHaveBeenCalledWith('a')
|
||||
})
|
||||
|
||||
it('should display controlled value', () => {
|
||||
render(<ControlledInput value="controlled" onChange={vi.fn()} />)
|
||||
|
||||
expect(screen.getByRole('textbox')).toHaveValue('controlled')
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
## Conditional Rendering Testing
|
||||
|
||||
```typescript
|
||||
describe('ConditionalComponent', () => {
|
||||
it('should show loading state', () => {
|
||||
render(<DataDisplay isLoading={true} data={null} />)
|
||||
|
||||
expect(screen.getByText(/loading/i)).toBeInTheDocument()
|
||||
expect(screen.queryByTestId('data-content')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show error state', () => {
|
||||
render(<DataDisplay isLoading={false} data={null} error="Failed to load" />)
|
||||
|
||||
expect(screen.getByText(/failed to load/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show data when loaded', () => {
|
||||
render(<DataDisplay isLoading={false} data={{ name: 'Test' }} />)
|
||||
|
||||
expect(screen.getByText('Test')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show empty state when no data', () => {
|
||||
render(<DataDisplay isLoading={false} data={[]} />)
|
||||
|
||||
expect(screen.getByText(/no data/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
## List Rendering Testing
|
||||
|
||||
```typescript
|
||||
describe('ItemList', () => {
|
||||
const items = [
|
||||
{ id: '1', name: 'Item 1' },
|
||||
{ id: '2', name: 'Item 2' },
|
||||
{ id: '3', name: 'Item 3' },
|
||||
]
|
||||
|
||||
it('should render all items', () => {
|
||||
render(<ItemList items={items} />)
|
||||
|
||||
expect(screen.getAllByRole('listitem')).toHaveLength(3)
|
||||
items.forEach(item => {
|
||||
expect(screen.getByText(item.name)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle item selection', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onSelect = vi.fn()
|
||||
|
||||
render(<ItemList items={items} onSelect={onSelect} />)
|
||||
|
||||
await user.click(screen.getByText('Item 2'))
|
||||
|
||||
expect(onSelect).toHaveBeenCalledWith(items[1])
|
||||
})
|
||||
|
||||
it('should handle empty list', () => {
|
||||
render(<ItemList items={[]} />)
|
||||
|
||||
expect(screen.getByText(/no items/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
## Modal/Dialog Testing
|
||||
|
||||
```typescript
|
||||
describe('Modal', () => {
|
||||
it('should not render when closed', () => {
|
||||
render(<Modal isOpen={false} onClose={vi.fn()} />)
|
||||
|
||||
expect(screen.queryByRole('dialog')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render when open', () => {
|
||||
render(<Modal isOpen={true} onClose={vi.fn()} />)
|
||||
|
||||
expect(screen.getByRole('dialog')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should call onClose when clicking overlay', async () => {
|
||||
const user = userEvent.setup()
|
||||
const handleClose = vi.fn()
|
||||
|
||||
render(<Modal isOpen={true} onClose={handleClose} />)
|
||||
|
||||
await user.click(screen.getByTestId('modal-overlay'))
|
||||
|
||||
expect(handleClose).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should call onClose when pressing Escape', async () => {
|
||||
const user = userEvent.setup()
|
||||
const handleClose = vi.fn()
|
||||
|
||||
render(<Modal isOpen={true} onClose={handleClose} />)
|
||||
|
||||
await user.keyboard('{Escape}')
|
||||
|
||||
expect(handleClose).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should trap focus inside modal', async () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
render(
|
||||
<Modal isOpen={true} onClose={vi.fn()}>
|
||||
<button>First</button>
|
||||
<button>Second</button>
|
||||
</Modal>
|
||||
)
|
||||
|
||||
// Focus should cycle within modal
|
||||
await user.tab()
|
||||
expect(screen.getByText('First')).toHaveFocus()
|
||||
|
||||
await user.tab()
|
||||
expect(screen.getByText('Second')).toHaveFocus()
|
||||
|
||||
await user.tab()
|
||||
expect(screen.getByText('First')).toHaveFocus() // Cycles back
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
## Form Testing
|
||||
|
||||
```typescript
|
||||
describe('LoginForm', () => {
|
||||
it('should submit valid form', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onSubmit = vi.fn()
|
||||
|
||||
render(<LoginForm onSubmit={onSubmit} />)
|
||||
|
||||
await user.type(screen.getByLabelText(/email/i), 'test@example.com')
|
||||
await user.type(screen.getByLabelText(/password/i), 'password123')
|
||||
await user.click(screen.getByRole('button', { name: /sign in/i }))
|
||||
|
||||
expect(onSubmit).toHaveBeenCalledWith({
|
||||
email: 'test@example.com',
|
||||
password: 'password123',
|
||||
})
|
||||
})
|
||||
|
||||
it('should show validation errors', async () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
render(<LoginForm onSubmit={vi.fn()} />)
|
||||
|
||||
// Submit empty form
|
||||
await user.click(screen.getByRole('button', { name: /sign in/i }))
|
||||
|
||||
expect(screen.getByText(/email is required/i)).toBeInTheDocument()
|
||||
expect(screen.getByText(/password is required/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should validate email format', async () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
render(<LoginForm onSubmit={vi.fn()} />)
|
||||
|
||||
await user.type(screen.getByLabelText(/email/i), 'invalid-email')
|
||||
await user.click(screen.getByRole('button', { name: /sign in/i }))
|
||||
|
||||
expect(screen.getByText(/invalid email/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should disable submit button while submitting', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onSubmit = vi.fn(() => new Promise(resolve => setTimeout(resolve, 100)))
|
||||
|
||||
render(<LoginForm onSubmit={onSubmit} />)
|
||||
|
||||
await user.type(screen.getByLabelText(/email/i), 'test@example.com')
|
||||
await user.type(screen.getByLabelText(/password/i), 'password123')
|
||||
await user.click(screen.getByRole('button', { name: /sign in/i }))
|
||||
|
||||
expect(screen.getByRole('button', { name: /signing in/i })).toBeDisabled()
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByRole('button', { name: /sign in/i })).toBeEnabled()
|
||||
})
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
## Data-Driven Tests with test.each
|
||||
|
||||
```typescript
|
||||
describe('StatusBadge', () => {
|
||||
test.each([
|
||||
['success', 'bg-green-500'],
|
||||
['warning', 'bg-yellow-500'],
|
||||
['error', 'bg-red-500'],
|
||||
['info', 'bg-blue-500'],
|
||||
])('should apply correct class for %s status', (status, expectedClass) => {
|
||||
render(<StatusBadge status={status} />)
|
||||
|
||||
expect(screen.getByTestId('status-badge')).toHaveClass(expectedClass)
|
||||
})
|
||||
|
||||
test.each([
|
||||
{ input: null, expected: 'Unknown' },
|
||||
{ input: undefined, expected: 'Unknown' },
|
||||
{ input: '', expected: 'Unknown' },
|
||||
{ input: 'invalid', expected: 'Unknown' },
|
||||
])('should show "Unknown" for invalid input: $input', ({ input, expected }) => {
|
||||
render(<StatusBadge status={input} />)
|
||||
|
||||
expect(screen.getByText(expected)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
## Debugging Tips
|
||||
|
||||
```typescript
|
||||
// Print entire DOM
|
||||
screen.debug()
|
||||
|
||||
// Print specific element
|
||||
screen.debug(screen.getByRole('button'))
|
||||
|
||||
// Log testing playground URL
|
||||
screen.logTestingPlaygroundURL()
|
||||
|
||||
// Pretty print DOM
|
||||
import { prettyDOM } from '@testing-library/react'
|
||||
console.log(prettyDOM(screen.getByRole('dialog')))
|
||||
|
||||
// Check available roles
|
||||
import { getRoles } from '@testing-library/react'
|
||||
console.log(getRoles(container))
|
||||
```
|
||||
|
||||
## Common Mistakes to Avoid
|
||||
|
||||
### ❌ Don't Use Implementation Details
|
||||
|
||||
```typescript
|
||||
// Bad - testing implementation
|
||||
expect(component.state.isOpen).toBe(true)
|
||||
expect(wrapper.find('.internal-class').length).toBe(1)
|
||||
|
||||
// Good - testing behavior
|
||||
expect(screen.getByRole('dialog')).toBeInTheDocument()
|
||||
```
|
||||
|
||||
### ❌ Don't Forget Cleanup
|
||||
|
||||
```typescript
|
||||
// Bad - may leak state between tests
|
||||
it('test 1', () => {
|
||||
render(<Component />)
|
||||
})
|
||||
|
||||
// Good - cleanup is automatic with RTL, but reset mocks
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
```
|
||||
|
||||
### ❌ Don't Use Exact String Matching (Prefer Black-Box Assertions)
|
||||
|
||||
```typescript
|
||||
// ❌ Bad - hardcoded strings are brittle
|
||||
expect(screen.getByText('Submit Form')).toBeInTheDocument()
|
||||
expect(screen.getByText('Loading...')).toBeInTheDocument()
|
||||
|
||||
// ✅ Good - role-based queries (most semantic)
|
||||
expect(screen.getByRole('button', { name: /submit/i })).toBeInTheDocument()
|
||||
expect(screen.getByRole('status')).toBeInTheDocument()
|
||||
|
||||
// ✅ Good - pattern matching (flexible)
|
||||
expect(screen.getByText(/submit/i)).toBeInTheDocument()
|
||||
expect(screen.getByText(/loading/i)).toBeInTheDocument()
|
||||
|
||||
// ✅ Good - test behavior, not exact UI text
|
||||
expect(screen.getByRole('button')).toBeDisabled()
|
||||
expect(screen.getByRole('alert')).toBeInTheDocument()
|
||||
```
|
||||
|
||||
**Why prefer black-box assertions?**
|
||||
|
||||
- Text content may change (i18n, copy updates)
|
||||
- Role-based queries test accessibility
|
||||
- Pattern matching is resilient to minor changes
|
||||
- Tests focus on behavior, not implementation details
|
||||
|
||||
### ❌ Don't Assert on Absence Without Query
|
||||
|
||||
```typescript
|
||||
// Bad - throws if not found
|
||||
expect(screen.getByText('Error')).not.toBeInTheDocument() // Error!
|
||||
|
||||
// Good - use queryBy for absence assertions
|
||||
expect(screen.queryByText('Error')).not.toBeInTheDocument()
|
||||
```
|
||||
@@ -1,523 +0,0 @@
|
||||
# Domain-Specific Component Testing
|
||||
|
||||
This guide covers testing patterns for Dify's domain-specific components.
|
||||
|
||||
## Workflow Components (`workflow/`)
|
||||
|
||||
Workflow components handle node configuration, data flow, and graph operations.
|
||||
|
||||
### Key Test Areas
|
||||
|
||||
1. **Node Configuration**
|
||||
1. **Data Validation**
|
||||
1. **Variable Passing**
|
||||
1. **Edge Connections**
|
||||
1. **Error Handling**
|
||||
|
||||
### Example: Node Configuration Panel
|
||||
|
||||
```typescript
|
||||
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import NodeConfigPanel from './node-config-panel'
|
||||
import { createMockNode, createMockWorkflowContext } from '@/__mocks__/workflow'
|
||||
|
||||
// Mock workflow context
|
||||
vi.mock('@/app/components/workflow/hooks', () => ({
|
||||
useWorkflowStore: () => mockWorkflowStore,
|
||||
useNodesInteractions: () => mockNodesInteractions,
|
||||
}))
|
||||
|
||||
let mockWorkflowStore = {
|
||||
nodes: [],
|
||||
edges: [],
|
||||
updateNode: vi.fn(),
|
||||
}
|
||||
|
||||
let mockNodesInteractions = {
|
||||
handleNodeSelect: vi.fn(),
|
||||
handleNodeDelete: vi.fn(),
|
||||
}
|
||||
|
||||
describe('NodeConfigPanel', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockWorkflowStore = {
|
||||
nodes: [],
|
||||
edges: [],
|
||||
updateNode: vi.fn(),
|
||||
}
|
||||
})
|
||||
|
||||
describe('Node Configuration', () => {
|
||||
it('should render node type selector', () => {
|
||||
const node = createMockNode({ type: 'llm' })
|
||||
render(<NodeConfigPanel node={node} />)
|
||||
|
||||
expect(screen.getByLabelText(/model/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should update node config on change', async () => {
|
||||
const user = userEvent.setup()
|
||||
const node = createMockNode({ type: 'llm' })
|
||||
|
||||
render(<NodeConfigPanel node={node} />)
|
||||
|
||||
await user.selectOptions(screen.getByLabelText(/model/i), 'gpt-4')
|
||||
|
||||
expect(mockWorkflowStore.updateNode).toHaveBeenCalledWith(
|
||||
node.id,
|
||||
expect.objectContaining({ model: 'gpt-4' })
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Data Validation', () => {
|
||||
it('should show error for invalid input', async () => {
|
||||
const user = userEvent.setup()
|
||||
const node = createMockNode({ type: 'code' })
|
||||
|
||||
render(<NodeConfigPanel node={node} />)
|
||||
|
||||
// Enter invalid code
|
||||
const codeInput = screen.getByLabelText(/code/i)
|
||||
await user.clear(codeInput)
|
||||
await user.type(codeInput, 'invalid syntax {{{')
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/syntax error/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should validate required fields', async () => {
|
||||
const node = createMockNode({ type: 'http', data: { url: '' } })
|
||||
|
||||
render(<NodeConfigPanel node={node} />)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: /save/i }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/url is required/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Variable Passing', () => {
|
||||
it('should display available variables from upstream nodes', () => {
|
||||
const upstreamNode = createMockNode({
|
||||
id: 'node-1',
|
||||
type: 'start',
|
||||
data: { outputs: [{ name: 'user_input', type: 'string' }] },
|
||||
})
|
||||
const currentNode = createMockNode({
|
||||
id: 'node-2',
|
||||
type: 'llm',
|
||||
})
|
||||
|
||||
mockWorkflowStore.nodes = [upstreamNode, currentNode]
|
||||
mockWorkflowStore.edges = [{ source: 'node-1', target: 'node-2' }]
|
||||
|
||||
render(<NodeConfigPanel node={currentNode} />)
|
||||
|
||||
// Variable selector should show upstream variables
|
||||
fireEvent.click(screen.getByRole('button', { name: /add variable/i }))
|
||||
|
||||
expect(screen.getByText('user_input')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should insert variable into prompt template', async () => {
|
||||
const user = userEvent.setup()
|
||||
const node = createMockNode({ type: 'llm' })
|
||||
|
||||
render(<NodeConfigPanel node={node} />)
|
||||
|
||||
// Click variable button
|
||||
await user.click(screen.getByRole('button', { name: /insert variable/i }))
|
||||
await user.click(screen.getByText('user_input'))
|
||||
|
||||
const promptInput = screen.getByLabelText(/prompt/i)
|
||||
expect(promptInput).toHaveValue(expect.stringContaining('{{user_input}}'))
|
||||
})
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
## Dataset Components (`dataset/`)
|
||||
|
||||
Dataset components handle file uploads, data display, and search/filter operations.
|
||||
|
||||
### Key Test Areas
|
||||
|
||||
1. **File Upload**
|
||||
1. **File Type Validation**
|
||||
1. **Pagination**
|
||||
1. **Search & Filtering**
|
||||
1. **Data Format Handling**
|
||||
|
||||
### Example: Document Uploader
|
||||
|
||||
```typescript
|
||||
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import DocumentUploader from './document-uploader'
|
||||
|
||||
vi.mock('@/service/datasets', () => ({
|
||||
uploadDocument: vi.fn(),
|
||||
parseDocument: vi.fn(),
|
||||
}))
|
||||
|
||||
import * as datasetService from '@/service/datasets'
|
||||
const mockedService = vi.mocked(datasetService)
|
||||
|
||||
describe('DocumentUploader', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('File Upload', () => {
|
||||
it('should accept valid file types', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onUpload = vi.fn()
|
||||
mockedService.uploadDocument.mockResolvedValue({ id: 'doc-1' })
|
||||
|
||||
render(<DocumentUploader onUpload={onUpload} />)
|
||||
|
||||
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
|
||||
const input = screen.getByLabelText(/upload/i)
|
||||
|
||||
await user.upload(input, file)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockedService.uploadDocument).toHaveBeenCalledWith(
|
||||
expect.any(FormData)
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
it('should reject invalid file types', async () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
render(<DocumentUploader />)
|
||||
|
||||
const file = new File(['content'], 'test.exe', { type: 'application/x-msdownload' })
|
||||
const input = screen.getByLabelText(/upload/i)
|
||||
|
||||
await user.upload(input, file)
|
||||
|
||||
expect(screen.getByText(/unsupported file type/i)).toBeInTheDocument()
|
||||
expect(mockedService.uploadDocument).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should show upload progress', async () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
// Mock upload with progress
|
||||
mockedService.uploadDocument.mockImplementation(() => {
|
||||
return new Promise((resolve) => {
|
||||
setTimeout(() => resolve({ id: 'doc-1' }), 100)
|
||||
})
|
||||
})
|
||||
|
||||
render(<DocumentUploader />)
|
||||
|
||||
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
|
||||
await user.upload(screen.getByLabelText(/upload/i), file)
|
||||
|
||||
expect(screen.getByRole('progressbar')).toBeInTheDocument()
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByRole('progressbar')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Error Handling', () => {
|
||||
it('should handle upload failure', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockedService.uploadDocument.mockRejectedValue(new Error('Upload failed'))
|
||||
|
||||
render(<DocumentUploader />)
|
||||
|
||||
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
|
||||
await user.upload(screen.getByLabelText(/upload/i), file)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/upload failed/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should allow retry after failure', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockedService.uploadDocument
|
||||
.mockRejectedValueOnce(new Error('Network error'))
|
||||
.mockResolvedValueOnce({ id: 'doc-1' })
|
||||
|
||||
render(<DocumentUploader />)
|
||||
|
||||
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
|
||||
await user.upload(screen.getByLabelText(/upload/i), file)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByRole('button', { name: /retry/i })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
await user.click(screen.getByRole('button', { name: /retry/i }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/uploaded successfully/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
### Example: Document List with Pagination
|
||||
|
||||
```typescript
|
||||
describe('DocumentList', () => {
|
||||
describe('Pagination', () => {
|
||||
it('should load first page on mount', async () => {
|
||||
mockedService.getDocuments.mockResolvedValue({
|
||||
data: [{ id: '1', name: 'Doc 1' }],
|
||||
total: 50,
|
||||
page: 1,
|
||||
pageSize: 10,
|
||||
})
|
||||
|
||||
render(<DocumentList datasetId="ds-1" />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('Doc 1')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
expect(mockedService.getDocuments).toHaveBeenCalledWith('ds-1', { page: 1 })
|
||||
})
|
||||
|
||||
it('should navigate to next page', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockedService.getDocuments.mockResolvedValue({
|
||||
data: [{ id: '1', name: 'Doc 1' }],
|
||||
total: 50,
|
||||
page: 1,
|
||||
pageSize: 10,
|
||||
})
|
||||
|
||||
render(<DocumentList datasetId="ds-1" />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('Doc 1')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
mockedService.getDocuments.mockResolvedValue({
|
||||
data: [{ id: '11', name: 'Doc 11' }],
|
||||
total: 50,
|
||||
page: 2,
|
||||
pageSize: 10,
|
||||
})
|
||||
|
||||
await user.click(screen.getByRole('button', { name: /next/i }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('Doc 11')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Search & Filtering', () => {
|
||||
it('should filter by search query', async () => {
|
||||
const user = userEvent.setup()
|
||||
vi.useFakeTimers()
|
||||
|
||||
render(<DocumentList datasetId="ds-1" />)
|
||||
|
||||
await user.type(screen.getByPlaceholderText(/search/i), 'test query')
|
||||
|
||||
// Debounce
|
||||
vi.advanceTimersByTime(300)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockedService.getDocuments).toHaveBeenCalledWith(
|
||||
'ds-1',
|
||||
expect.objectContaining({ search: 'test query' })
|
||||
)
|
||||
})
|
||||
|
||||
vi.useRealTimers()
|
||||
})
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
## Configuration Components (`app/configuration/`, `config/`)
|
||||
|
||||
Configuration components handle forms, validation, and data persistence.
|
||||
|
||||
### Key Test Areas
|
||||
|
||||
1. **Form Validation**
|
||||
1. **Save/Reset**
|
||||
1. **Required vs Optional Fields**
|
||||
1. **Configuration Persistence**
|
||||
1. **Error Feedback**
|
||||
|
||||
### Example: App Configuration Form
|
||||
|
||||
```typescript
|
||||
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import AppConfigForm from './app-config-form'
|
||||
|
||||
vi.mock('@/service/apps', () => ({
|
||||
updateAppConfig: vi.fn(),
|
||||
getAppConfig: vi.fn(),
|
||||
}))
|
||||
|
||||
import * as appService from '@/service/apps'
|
||||
const mockedService = vi.mocked(appService)
|
||||
|
||||
describe('AppConfigForm', () => {
|
||||
const defaultConfig = {
|
||||
name: 'My App',
|
||||
description: '',
|
||||
icon: 'default',
|
||||
openingStatement: '',
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockedService.getAppConfig.mockResolvedValue(defaultConfig)
|
||||
})
|
||||
|
||||
describe('Form Validation', () => {
|
||||
it('should require app name', async () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
render(<AppConfigForm appId="app-1" />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
|
||||
})
|
||||
|
||||
// Clear name field
|
||||
await user.clear(screen.getByLabelText(/name/i))
|
||||
await user.click(screen.getByRole('button', { name: /save/i }))
|
||||
|
||||
expect(screen.getByText(/name is required/i)).toBeInTheDocument()
|
||||
expect(mockedService.updateAppConfig).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should validate name length', async () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
render(<AppConfigForm appId="app-1" />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByLabelText(/name/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// Enter very long name
|
||||
await user.clear(screen.getByLabelText(/name/i))
|
||||
await user.type(screen.getByLabelText(/name/i), 'a'.repeat(101))
|
||||
|
||||
expect(screen.getByText(/name must be less than 100 characters/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should allow empty optional fields', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockedService.updateAppConfig.mockResolvedValue({ success: true })
|
||||
|
||||
render(<AppConfigForm appId="app-1" />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
|
||||
})
|
||||
|
||||
// Leave description empty (optional)
|
||||
await user.click(screen.getByRole('button', { name: /save/i }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockedService.updateAppConfig).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Save/Reset Functionality', () => {
|
||||
it('should save configuration', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockedService.updateAppConfig.mockResolvedValue({ success: true })
|
||||
|
||||
render(<AppConfigForm appId="app-1" />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
|
||||
})
|
||||
|
||||
await user.clear(screen.getByLabelText(/name/i))
|
||||
await user.type(screen.getByLabelText(/name/i), 'Updated App')
|
||||
await user.click(screen.getByRole('button', { name: /save/i }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockedService.updateAppConfig).toHaveBeenCalledWith(
|
||||
'app-1',
|
||||
expect.objectContaining({ name: 'Updated App' })
|
||||
)
|
||||
})
|
||||
|
||||
expect(screen.getByText(/saved successfully/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should reset to default values', async () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
render(<AppConfigForm appId="app-1" />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
|
||||
})
|
||||
|
||||
// Make changes
|
||||
await user.clear(screen.getByLabelText(/name/i))
|
||||
await user.type(screen.getByLabelText(/name/i), 'Changed Name')
|
||||
|
||||
// Reset
|
||||
await user.click(screen.getByRole('button', { name: /reset/i }))
|
||||
|
||||
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
|
||||
})
|
||||
|
||||
it('should show unsaved changes warning', async () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
render(<AppConfigForm appId="app-1" />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
|
||||
})
|
||||
|
||||
// Make changes
|
||||
await user.type(screen.getByLabelText(/name/i), ' Updated')
|
||||
|
||||
expect(screen.getByText(/unsaved changes/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Error Handling', () => {
|
||||
it('should show error on save failure', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockedService.updateAppConfig.mockRejectedValue(new Error('Server error'))
|
||||
|
||||
render(<AppConfigForm appId="app-1" />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
|
||||
})
|
||||
|
||||
await user.click(screen.getByRole('button', { name: /save/i }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/failed to save/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
```
|
||||
@@ -1,349 +0,0 @@
|
||||
# Mocking Guide for Dify Frontend Tests
|
||||
|
||||
## ⚠️ Important: What NOT to Mock
|
||||
|
||||
### DO NOT Mock Base Components
|
||||
|
||||
**Never mock components from `@/app/components/base/`** such as:
|
||||
|
||||
- `Loading`, `Spinner`
|
||||
- `Button`, `Input`, `Select`
|
||||
- `Tooltip`, `Modal`, `Dropdown`
|
||||
- `Icon`, `Badge`, `Tag`
|
||||
|
||||
**Why?**
|
||||
|
||||
- Base components will have their own dedicated tests
|
||||
- Mocking them creates false positives (tests pass but real integration fails)
|
||||
- Using real components tests actual integration behavior
|
||||
|
||||
```typescript
|
||||
// ❌ WRONG: Don't mock base components
|
||||
vi.mock('@/app/components/base/loading', () => () => <div>Loading</div>)
|
||||
vi.mock('@/app/components/base/button', () => ({ children }: any) => <button>{children}</button>)
|
||||
|
||||
// ✅ CORRECT: Import and use real base components
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import Button from '@/app/components/base/button'
|
||||
// They will render normally in tests
|
||||
```
|
||||
|
||||
### What TO Mock
|
||||
|
||||
Only mock these categories:
|
||||
|
||||
1. **API services** (`@/service/*`) - Network calls
|
||||
1. **Complex context providers** - When setup is too difficult
|
||||
1. **Third-party libraries with side effects** - `next/navigation`, external SDKs
|
||||
1. **i18n** - Always mock to return keys
|
||||
|
||||
## Mock Placement
|
||||
|
||||
| Location | Purpose |
|
||||
|----------|---------|
|
||||
| `web/vitest.setup.ts` | Global mocks shared by all tests (for example `react-i18next`, `next/image`) |
|
||||
| `web/__mocks__/` | Reusable mock factories shared across multiple test files |
|
||||
| Test file | Test-specific mocks, inline with `vi.mock()` |
|
||||
|
||||
Modules are not mocked automatically. Use `vi.mock` in test files, or add global mocks in `web/vitest.setup.ts`.
|
||||
|
||||
## Essential Mocks
|
||||
|
||||
### 1. i18n (Auto-loaded via Global Mock)
|
||||
|
||||
A global mock is defined in `web/vitest.setup.ts` and is auto-loaded by Vitest setup.
|
||||
|
||||
The global mock provides:
|
||||
|
||||
- `useTranslation` - returns translation keys with namespace prefix
|
||||
- `Trans` component - renders i18nKey and components
|
||||
- `useMixedTranslation` (from `@/app/components/plugins/marketplace/hooks`)
|
||||
- `useGetLanguage` (from `@/context/i18n`) - returns `'en-US'`
|
||||
|
||||
**Default behavior**: Most tests should use the global mock (no local override needed).
|
||||
|
||||
**For custom translations**: Use the helper function from `@/test/i18n-mock`:
|
||||
|
||||
```typescript
|
||||
import { createReactI18nextMock } from '@/test/i18n-mock'
|
||||
|
||||
vi.mock('react-i18next', () => createReactI18nextMock({
|
||||
'my.custom.key': 'Custom translation',
|
||||
'button.save': 'Save',
|
||||
}))
|
||||
```
|
||||
|
||||
**Avoid**: Manually defining `useTranslation` mocks that just return the key - the global mock already does this.
|
||||
|
||||
### 2. Next.js Router
|
||||
|
||||
```typescript
|
||||
const mockPush = vi.fn()
|
||||
const mockReplace = vi.fn()
|
||||
|
||||
vi.mock('next/navigation', () => ({
|
||||
useRouter: () => ({
|
||||
push: mockPush,
|
||||
replace: mockReplace,
|
||||
back: vi.fn(),
|
||||
prefetch: vi.fn(),
|
||||
}),
|
||||
usePathname: () => '/current-path',
|
||||
useSearchParams: () => new URLSearchParams('?key=value'),
|
||||
}))
|
||||
|
||||
describe('Component', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('should navigate on click', () => {
|
||||
render(<Component />)
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
expect(mockPush).toHaveBeenCalledWith('/expected-path')
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
### 3. Portal Components (with Shared State)
|
||||
|
||||
```typescript
|
||||
// ⚠️ Important: Use shared state for components that depend on each other
|
||||
let mockPortalOpenState = false
|
||||
|
||||
vi.mock('@/app/components/base/portal-to-follow-elem', () => ({
|
||||
PortalToFollowElem: ({ children, open, ...props }: any) => {
|
||||
mockPortalOpenState = open || false // Update shared state
|
||||
return <div data-testid="portal" data-open={open}>{children}</div>
|
||||
},
|
||||
PortalToFollowElemContent: ({ children }: any) => {
|
||||
// ✅ Matches actual: returns null when portal is closed
|
||||
if (!mockPortalOpenState) return null
|
||||
return <div data-testid="portal-content">{children}</div>
|
||||
},
|
||||
PortalToFollowElemTrigger: ({ children }: any) => (
|
||||
<div data-testid="portal-trigger">{children}</div>
|
||||
),
|
||||
}))
|
||||
|
||||
describe('Component', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockPortalOpenState = false // ✅ Reset shared state
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
### 4. API Service Mocks
|
||||
|
||||
```typescript
|
||||
import * as api from '@/service/api'
|
||||
|
||||
vi.mock('@/service/api')
|
||||
|
||||
const mockedApi = vi.mocked(api)
|
||||
|
||||
describe('Component', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
|
||||
// Setup default mock implementation
|
||||
mockedApi.fetchData.mockResolvedValue({ data: [] })
|
||||
})
|
||||
|
||||
it('should show data on success', async () => {
|
||||
mockedApi.fetchData.mockResolvedValue({ data: [{ id: 1 }] })
|
||||
|
||||
render(<Component />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('1')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should show error on failure', async () => {
|
||||
mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
|
||||
|
||||
render(<Component />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/error/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
### 5. HTTP Mocking with Nock
|
||||
|
||||
```typescript
|
||||
import nock from 'nock'
|
||||
|
||||
const GITHUB_HOST = 'https://api.github.com'
|
||||
const GITHUB_PATH = '/repos/owner/repo'
|
||||
|
||||
const mockGithubApi = (status: number, body: Record<string, unknown>, delayMs = 0) => {
|
||||
return nock(GITHUB_HOST)
|
||||
.get(GITHUB_PATH)
|
||||
.delay(delayMs)
|
||||
.reply(status, body)
|
||||
}
|
||||
|
||||
describe('GithubComponent', () => {
|
||||
afterEach(() => {
|
||||
nock.cleanAll()
|
||||
})
|
||||
|
||||
it('should display repo info', async () => {
|
||||
mockGithubApi(200, { name: 'dify', stars: 1000 })
|
||||
|
||||
render(<GithubComponent />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('dify')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle API error', async () => {
|
||||
mockGithubApi(500, { message: 'Server error' })
|
||||
|
||||
render(<GithubComponent />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/error/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
### 6. Context Providers
|
||||
|
||||
```typescript
|
||||
import { ProviderContext } from '@/context/provider-context'
|
||||
import { createMockProviderContextValue, createMockPlan } from '@/__mocks__/provider-context'
|
||||
|
||||
describe('Component with Context', () => {
|
||||
it('should render for free plan', () => {
|
||||
const mockContext = createMockPlan('sandbox')
|
||||
|
||||
render(
|
||||
<ProviderContext.Provider value={mockContext}>
|
||||
<Component />
|
||||
</ProviderContext.Provider>
|
||||
)
|
||||
|
||||
expect(screen.getByText('Upgrade')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render for pro plan', () => {
|
||||
const mockContext = createMockPlan('professional')
|
||||
|
||||
render(
|
||||
<ProviderContext.Provider value={mockContext}>
|
||||
<Component />
|
||||
</ProviderContext.Provider>
|
||||
)
|
||||
|
||||
expect(screen.queryByText('Upgrade')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
### 7. React Query
|
||||
|
||||
```typescript
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
|
||||
const createTestQueryClient = () => new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: { retry: false },
|
||||
mutations: { retry: false },
|
||||
},
|
||||
})
|
||||
|
||||
const renderWithQueryClient = (ui: React.ReactElement) => {
|
||||
const queryClient = createTestQueryClient()
|
||||
return render(
|
||||
<QueryClientProvider client={queryClient}>
|
||||
{ui}
|
||||
</QueryClientProvider>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
## Mock Best Practices
|
||||
|
||||
### ✅ DO
|
||||
|
||||
1. **Use real base components** - Import from `@/app/components/base/` directly
|
||||
1. **Use real project components** - Prefer importing over mocking
|
||||
1. **Reset mocks in `beforeEach`**, not `afterEach`
|
||||
1. **Match actual component behavior** in mocks (when mocking is necessary)
|
||||
1. **Use factory functions** for complex mock data
|
||||
1. **Import actual types** for type safety
|
||||
1. **Reset shared mock state** in `beforeEach`
|
||||
|
||||
### ❌ DON'T
|
||||
|
||||
1. **Don't mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
|
||||
1. Don't mock components you can import directly
|
||||
1. Don't create overly simplified mocks that miss conditional logic
|
||||
1. Don't forget to clean up nock after each test
|
||||
1. Don't use `any` types in mocks without necessity
|
||||
|
||||
### Mock Decision Tree
|
||||
|
||||
```
|
||||
Need to use a component in test?
|
||||
│
|
||||
├─ Is it from @/app/components/base/*?
|
||||
│ └─ YES → Import real component, DO NOT mock
|
||||
│
|
||||
├─ Is it a project component?
|
||||
│ └─ YES → Prefer importing real component
|
||||
│ Only mock if setup is extremely complex
|
||||
│
|
||||
├─ Is it an API service (@/service/*)?
|
||||
│ └─ YES → Mock it
|
||||
│
|
||||
├─ Is it a third-party lib with side effects?
|
||||
│ └─ YES → Mock it (next/navigation, external SDKs)
|
||||
│
|
||||
└─ Is it i18n?
|
||||
└─ YES → Uses shared mock (auto-loaded). Override only for custom translations
|
||||
```
|
||||
|
||||
## Factory Function Pattern
|
||||
|
||||
```typescript
|
||||
// __mocks__/data-factories.ts
|
||||
import type { User, Project } from '@/types'
|
||||
|
||||
export const createMockUser = (overrides: Partial<User> = {}): User => ({
|
||||
id: 'user-1',
|
||||
name: 'Test User',
|
||||
email: 'test@example.com',
|
||||
role: 'member',
|
||||
createdAt: new Date().toISOString(),
|
||||
...overrides,
|
||||
})
|
||||
|
||||
export const createMockProject = (overrides: Partial<Project> = {}): Project => ({
|
||||
id: 'project-1',
|
||||
name: 'Test Project',
|
||||
description: 'A test project',
|
||||
owner: createMockUser(),
|
||||
members: [],
|
||||
createdAt: new Date().toISOString(),
|
||||
...overrides,
|
||||
})
|
||||
|
||||
// Usage in tests
|
||||
it('should display project owner', () => {
|
||||
const project = createMockProject({
|
||||
owner: createMockUser({ name: 'John Doe' }),
|
||||
})
|
||||
|
||||
render(<ProjectCard project={project} />)
|
||||
expect(screen.getByText('John Doe')).toBeInTheDocument()
|
||||
})
|
||||
```
|
||||
@@ -1,269 +0,0 @@
|
||||
# Testing Workflow Guide
|
||||
|
||||
This guide defines the workflow for generating tests, especially for complex components or directories with multiple files.
|
||||
|
||||
## Scope Clarification
|
||||
|
||||
This guide addresses **multi-file workflow** (how to process multiple test files). For coverage requirements within a single test file, see `web/testing/testing.md` § Coverage Goals.
|
||||
|
||||
| Scope | Rule |
|
||||
|-------|------|
|
||||
| **Single file** | Complete coverage in one generation (100% function, >95% branch) |
|
||||
| **Multi-file directory** | Process one file at a time, verify each before proceeding |
|
||||
|
||||
## ⚠️ Critical Rule: Incremental Approach for Multi-File Testing
|
||||
|
||||
When testing a **directory with multiple files**, **NEVER generate all test files at once.** Use an incremental, verify-as-you-go approach.
|
||||
|
||||
### Why Incremental?
|
||||
|
||||
| Batch Approach (❌) | Incremental Approach (✅) |
|
||||
|---------------------|---------------------------|
|
||||
| Generate 5+ tests at once | Generate 1 test at a time |
|
||||
| Run tests only at the end | Run test immediately after each file |
|
||||
| Multiple failures compound | Single point of failure, easy to debug |
|
||||
| Hard to identify root cause | Clear cause-effect relationship |
|
||||
| Mock issues affect many files | Mock issues caught early |
|
||||
| Messy git history | Clean, atomic commits possible |
|
||||
|
||||
## Single File Workflow
|
||||
|
||||
When testing a **single component, hook, or utility**:
|
||||
|
||||
```
|
||||
1. Read source code completely
|
||||
2. Run `pnpm analyze-component <path>` (if available)
|
||||
3. Check complexity score and features detected
|
||||
4. Write the test file
|
||||
5. Run test: `pnpm test <file>.spec.tsx`
|
||||
6. Fix any failures
|
||||
7. Verify coverage meets goals (100% function, >95% branch)
|
||||
```
|
||||
|
||||
## Directory/Multi-File Workflow (MUST FOLLOW)
|
||||
|
||||
When testing a **directory or multiple files**, follow this strict workflow:
|
||||
|
||||
### Step 1: Analyze and Plan
|
||||
|
||||
1. **List all files** that need tests in the directory
|
||||
1. **Categorize by complexity**:
|
||||
- 🟢 **Simple**: Utility functions, simple hooks, presentational components
|
||||
- 🟡 **Medium**: Components with state, effects, or event handlers
|
||||
- 🔴 **Complex**: Components with API calls, routing, or many dependencies
|
||||
1. **Order by dependency**: Test dependencies before dependents
|
||||
1. **Create a todo list** to track progress
|
||||
|
||||
### Step 2: Determine Processing Order
|
||||
|
||||
Process files in this recommended order:
|
||||
|
||||
```
|
||||
1. Utility functions (simplest, no React)
|
||||
2. Custom hooks (isolated logic)
|
||||
3. Simple presentational components (few/no props)
|
||||
4. Medium complexity components (state, effects)
|
||||
5. Complex components (API, routing, many deps)
|
||||
6. Container/index components (integration tests - last)
|
||||
```
|
||||
|
||||
**Rationale**:
|
||||
|
||||
- Simpler files help establish mock patterns
|
||||
- Hooks used by components should be tested first
|
||||
- Integration tests (index files) depend on child components working
|
||||
|
||||
### Step 3: Process Each File Incrementally
|
||||
|
||||
**For EACH file in the ordered list:**
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────┐
|
||||
│ 1. Write test file │
|
||||
│ 2. Run: pnpm test <file>.spec.tsx │
|
||||
│ 3. If FAIL → Fix immediately, re-run │
|
||||
│ 4. If PASS → Mark complete in todo list │
|
||||
│ 5. ONLY THEN proceed to next file │
|
||||
└─────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
**DO NOT proceed to the next file until the current one passes.**
|
||||
|
||||
### Step 4: Final Verification
|
||||
|
||||
After all individual tests pass:
|
||||
|
||||
```bash
|
||||
# Run all tests in the directory together
|
||||
pnpm test path/to/directory/
|
||||
|
||||
# Check coverage
|
||||
pnpm test:coverage path/to/directory/
|
||||
```
|
||||
|
||||
## Component Complexity Guidelines
|
||||
|
||||
Use `pnpm analyze-component <path>` to assess complexity before testing.
|
||||
|
||||
### 🔴 Very Complex Components (Complexity > 50)
|
||||
|
||||
**Consider refactoring BEFORE testing:**
|
||||
|
||||
- Break component into smaller, testable pieces
|
||||
- Extract complex logic into custom hooks
|
||||
- Separate container and presentational layers
|
||||
|
||||
**If testing as-is:**
|
||||
|
||||
- Use integration tests for complex workflows
|
||||
- Use `test.each()` for data-driven testing
|
||||
- Multiple `describe` blocks for organization
|
||||
- Consider testing major sections separately
|
||||
|
||||
### 🟡 Medium Complexity (Complexity 30-50)
|
||||
|
||||
- Group related tests in `describe` blocks
|
||||
- Test integration scenarios between internal parts
|
||||
- Focus on state transitions and side effects
|
||||
- Use helper functions to reduce test complexity
|
||||
|
||||
### 🟢 Simple Components (Complexity < 30)
|
||||
|
||||
- Standard test structure
|
||||
- Focus on props, rendering, and edge cases
|
||||
- Usually straightforward to test
|
||||
|
||||
### 📏 Large Files (500+ lines)
|
||||
|
||||
Regardless of complexity score:
|
||||
|
||||
- **Strongly consider refactoring** before testing
|
||||
- If testing as-is, test major sections separately
|
||||
- Create helper functions for test setup
|
||||
- May need multiple test files
|
||||
|
||||
## Todo List Format
|
||||
|
||||
When testing multiple files, use a todo list like this:
|
||||
|
||||
```
|
||||
Testing: path/to/directory/
|
||||
|
||||
Ordered by complexity (simple → complex):
|
||||
|
||||
☐ utils/helper.ts [utility, simple]
|
||||
☐ hooks/use-custom-hook.ts [hook, simple]
|
||||
☐ empty-state.tsx [component, simple]
|
||||
☐ item-card.tsx [component, medium]
|
||||
☐ list.tsx [component, complex]
|
||||
☐ index.tsx [integration]
|
||||
|
||||
Progress: 0/6 complete
|
||||
```
|
||||
|
||||
Update status as you complete each:
|
||||
|
||||
- ☐ → ⏳ (in progress)
|
||||
- ⏳ → ✅ (complete and verified)
|
||||
- ⏳ → ❌ (blocked, needs attention)
|
||||
|
||||
## When to Stop and Verify
|
||||
|
||||
**Always run tests after:**
|
||||
|
||||
- Completing a test file
|
||||
- Making changes to fix a failure
|
||||
- Modifying shared mocks
|
||||
- Updating test utilities or helpers
|
||||
|
||||
**Signs you should pause:**
|
||||
|
||||
- More than 2 consecutive test failures
|
||||
- Mock-related errors appearing
|
||||
- Unclear why a test is failing
|
||||
- Test passing but coverage unexpectedly low
|
||||
|
||||
## Common Pitfalls to Avoid
|
||||
|
||||
### ❌ Don't: Generate Everything First
|
||||
|
||||
```
|
||||
# BAD: Writing all files then testing
|
||||
Write component-a.spec.tsx
|
||||
Write component-b.spec.tsx
|
||||
Write component-c.spec.tsx
|
||||
Write component-d.spec.tsx
|
||||
Run pnpm test ← Multiple failures, hard to debug
|
||||
```
|
||||
|
||||
### ✅ Do: Verify Each Step
|
||||
|
||||
```
|
||||
# GOOD: Incremental with verification
|
||||
Write component-a.spec.tsx
|
||||
Run pnpm test component-a.spec.tsx ✅
|
||||
Write component-b.spec.tsx
|
||||
Run pnpm test component-b.spec.tsx ✅
|
||||
...continue...
|
||||
```
|
||||
|
||||
### ❌ Don't: Skip Verification for "Simple" Components
|
||||
|
||||
Even simple components can have:
|
||||
|
||||
- Import errors
|
||||
- Missing mock setup
|
||||
- Incorrect assumptions about props
|
||||
|
||||
**Always verify, regardless of perceived simplicity.**
|
||||
|
||||
### ❌ Don't: Continue When Tests Fail
|
||||
|
||||
Failing tests compound:
|
||||
|
||||
- A mock issue in file A affects files B, C, D
|
||||
- Fixing A later requires revisiting all dependent tests
|
||||
- Time wasted on debugging cascading failures
|
||||
|
||||
**Fix failures immediately before proceeding.**
|
||||
|
||||
## Integration with Claude's Todo Feature
|
||||
|
||||
When using Claude for multi-file testing:
|
||||
|
||||
1. **Ask Claude to create a todo list** before starting
|
||||
1. **Request one file at a time** or ensure Claude processes incrementally
|
||||
1. **Verify each test passes** before asking for the next
|
||||
1. **Mark todos complete** as you progress
|
||||
|
||||
Example prompt:
|
||||
|
||||
```
|
||||
Test all components in `path/to/directory/`.
|
||||
First, analyze the directory and create a todo list ordered by complexity.
|
||||
Then, process ONE file at a time, waiting for my confirmation that tests pass
|
||||
before proceeding to the next.
|
||||
```
|
||||
|
||||
## Summary Checklist
|
||||
|
||||
Before starting multi-file testing:
|
||||
|
||||
- [ ] Listed all files needing tests
|
||||
- [ ] Ordered by complexity (simple → complex)
|
||||
- [ ] Created todo list for tracking
|
||||
- [ ] Understand dependencies between files
|
||||
|
||||
During testing:
|
||||
|
||||
- [ ] Processing ONE file at a time
|
||||
- [ ] Running tests after EACH file
|
||||
- [ ] Fixing failures BEFORE proceeding
|
||||
- [ ] Updating todo list progress
|
||||
|
||||
After completion:
|
||||
|
||||
- [ ] All individual tests pass
|
||||
- [ ] Full directory test run passes
|
||||
- [ ] Coverage goals met
|
||||
- [ ] Todo list shows all complete
|
||||
@@ -1 +0,0 @@
|
||||
../.claude/skills
|
||||
@@ -1,5 +0,0 @@
|
||||
[run]
|
||||
omit =
|
||||
api/tests/*
|
||||
api/migrations/*
|
||||
api/core/rag/datasource/vdb/*
|
||||
@@ -1,5 +1,6 @@
|
||||
# Cursor Rules for Dify Project
|
||||
|
||||
## Automated Test Generation
|
||||
|
||||
- Use `web/testing/testing.md` as the canonical instruction set for generating frontend automated tests.
|
||||
- When proposing or saving tests, re-read that document and follow every requirement.
|
||||
- All frontend tests MUST also comply with the `frontend-testing` skill. Treat the skill as a mandatory constraint, not optional guidance.
|
||||
@@ -6,9 +6,6 @@
|
||||
"context": "..",
|
||||
"dockerfile": "Dockerfile"
|
||||
},
|
||||
"mounts": [
|
||||
"source=dify-dev-tmp,target=/tmp,type=volume"
|
||||
],
|
||||
"features": {
|
||||
"ghcr.io/devcontainers/features/node:1": {
|
||||
"nodeGypDependencies": true,
|
||||
@@ -37,13 +34,19 @@
|
||||
},
|
||||
"postStartCommand": "./.devcontainer/post_start_command.sh",
|
||||
"postCreateCommand": "./.devcontainer/post_create_command.sh"
|
||||
|
||||
// Features to add to the dev container. More info: https://containers.dev/features.
|
||||
// "features": {},
|
||||
|
||||
// Use 'forwardPorts' to make a list of ports inside the container available locally.
|
||||
// "forwardPorts": [],
|
||||
|
||||
// Use 'postCreateCommand' to run commands after the container is created.
|
||||
// "postCreateCommand": "python --version",
|
||||
|
||||
// Configure tool-specific properties.
|
||||
// "customizations": {},
|
||||
|
||||
// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
|
||||
}
|
||||
// "remoteUser": "root"
|
||||
}
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
#!/bin/bash
|
||||
WORKSPACE_ROOT=$(pwd)
|
||||
|
||||
export COREPACK_ENABLE_DOWNLOAD_PROMPT=0
|
||||
corepack enable
|
||||
cd web && pnpm install
|
||||
pipx install uv
|
||||
|
||||
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
|
||||
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc
|
||||
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor\"" >> ~/.bashrc
|
||||
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
|
||||
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
|
||||
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
|
||||
|
||||
309
.github/CODEOWNERS
vendored
309
.github/CODEOWNERS
vendored
@@ -6,244 +6,221 @@
|
||||
|
||||
* @crazywoola @laipz8200 @Yeuoly
|
||||
|
||||
# CODEOWNERS file
|
||||
/.github/CODEOWNERS @laipz8200 @crazywoola
|
||||
|
||||
# Docs
|
||||
/docs/ @crazywoola
|
||||
|
||||
# Backend (default owner, more specific rules below will override)
|
||||
/api/ @QuantumGhost
|
||||
|
||||
# Backend - MCP
|
||||
/api/core/mcp/ @Nov1c444
|
||||
/api/core/entities/mcp_provider.py @Nov1c444
|
||||
/api/services/tools/mcp_tools_manage_service.py @Nov1c444
|
||||
/api/controllers/mcp/ @Nov1c444
|
||||
/api/controllers/console/app/mcp_server.py @Nov1c444
|
||||
/api/tests/**/*mcp* @Nov1c444
|
||||
api/ @QuantumGhost
|
||||
|
||||
# Backend - Workflow - Engine (Core graph execution engine)
|
||||
/api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
|
||||
/api/core/workflow/runtime/ @laipz8200 @QuantumGhost
|
||||
/api/core/workflow/graph/ @laipz8200 @QuantumGhost
|
||||
/api/core/workflow/graph_events/ @laipz8200 @QuantumGhost
|
||||
/api/core/workflow/node_events/ @laipz8200 @QuantumGhost
|
||||
/api/core/model_runtime/ @laipz8200 @QuantumGhost
|
||||
api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
|
||||
api/core/workflow/runtime/ @laipz8200 @QuantumGhost
|
||||
api/core/workflow/graph/ @laipz8200 @QuantumGhost
|
||||
api/core/workflow/graph_events/ @laipz8200 @QuantumGhost
|
||||
api/core/workflow/node_events/ @laipz8200 @QuantumGhost
|
||||
api/core/model_runtime/ @laipz8200 @QuantumGhost
|
||||
|
||||
# Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM)
|
||||
/api/core/workflow/nodes/agent/ @Nov1c444
|
||||
/api/core/workflow/nodes/iteration/ @Nov1c444
|
||||
/api/core/workflow/nodes/loop/ @Nov1c444
|
||||
/api/core/workflow/nodes/llm/ @Nov1c444
|
||||
api/core/workflow/nodes/agent/ @Nov1c444
|
||||
api/core/workflow/nodes/iteration/ @Nov1c444
|
||||
api/core/workflow/nodes/loop/ @Nov1c444
|
||||
api/core/workflow/nodes/llm/ @Nov1c444
|
||||
|
||||
# Backend - RAG (Retrieval Augmented Generation)
|
||||
/api/core/rag/ @JohnJyong
|
||||
/api/services/rag_pipeline/ @JohnJyong
|
||||
/api/services/dataset_service.py @JohnJyong
|
||||
/api/services/knowledge_service.py @JohnJyong
|
||||
/api/services/external_knowledge_service.py @JohnJyong
|
||||
/api/services/hit_testing_service.py @JohnJyong
|
||||
/api/services/metadata_service.py @JohnJyong
|
||||
/api/services/vector_service.py @JohnJyong
|
||||
/api/services/entities/knowledge_entities/ @JohnJyong
|
||||
/api/services/entities/external_knowledge_entities/ @JohnJyong
|
||||
/api/controllers/console/datasets/ @JohnJyong
|
||||
/api/controllers/service_api/dataset/ @JohnJyong
|
||||
/api/models/dataset.py @JohnJyong
|
||||
/api/tasks/rag_pipeline/ @JohnJyong
|
||||
/api/tasks/add_document_to_index_task.py @JohnJyong
|
||||
/api/tasks/batch_clean_document_task.py @JohnJyong
|
||||
/api/tasks/clean_document_task.py @JohnJyong
|
||||
/api/tasks/clean_notion_document_task.py @JohnJyong
|
||||
/api/tasks/document_indexing_task.py @JohnJyong
|
||||
/api/tasks/document_indexing_sync_task.py @JohnJyong
|
||||
/api/tasks/document_indexing_update_task.py @JohnJyong
|
||||
/api/tasks/duplicate_document_indexing_task.py @JohnJyong
|
||||
/api/tasks/recover_document_indexing_task.py @JohnJyong
|
||||
/api/tasks/remove_document_from_index_task.py @JohnJyong
|
||||
/api/tasks/retry_document_indexing_task.py @JohnJyong
|
||||
/api/tasks/sync_website_document_indexing_task.py @JohnJyong
|
||||
/api/tasks/batch_create_segment_to_index_task.py @JohnJyong
|
||||
/api/tasks/create_segment_to_index_task.py @JohnJyong
|
||||
/api/tasks/delete_segment_from_index_task.py @JohnJyong
|
||||
/api/tasks/disable_segment_from_index_task.py @JohnJyong
|
||||
/api/tasks/disable_segments_from_index_task.py @JohnJyong
|
||||
/api/tasks/enable_segment_to_index_task.py @JohnJyong
|
||||
/api/tasks/enable_segments_to_index_task.py @JohnJyong
|
||||
/api/tasks/clean_dataset_task.py @JohnJyong
|
||||
/api/tasks/deal_dataset_index_update_task.py @JohnJyong
|
||||
/api/tasks/deal_dataset_vector_index_task.py @JohnJyong
|
||||
api/core/rag/ @JohnJyong
|
||||
api/services/rag_pipeline/ @JohnJyong
|
||||
api/services/dataset_service.py @JohnJyong
|
||||
api/services/knowledge_service.py @JohnJyong
|
||||
api/services/external_knowledge_service.py @JohnJyong
|
||||
api/services/hit_testing_service.py @JohnJyong
|
||||
api/services/metadata_service.py @JohnJyong
|
||||
api/services/vector_service.py @JohnJyong
|
||||
api/services/entities/knowledge_entities/ @JohnJyong
|
||||
api/services/entities/external_knowledge_entities/ @JohnJyong
|
||||
api/controllers/console/datasets/ @JohnJyong
|
||||
api/controllers/service_api/dataset/ @JohnJyong
|
||||
api/models/dataset.py @JohnJyong
|
||||
api/tasks/rag_pipeline/ @JohnJyong
|
||||
api/tasks/add_document_to_index_task.py @JohnJyong
|
||||
api/tasks/batch_clean_document_task.py @JohnJyong
|
||||
api/tasks/clean_document_task.py @JohnJyong
|
||||
api/tasks/clean_notion_document_task.py @JohnJyong
|
||||
api/tasks/document_indexing_task.py @JohnJyong
|
||||
api/tasks/document_indexing_sync_task.py @JohnJyong
|
||||
api/tasks/document_indexing_update_task.py @JohnJyong
|
||||
api/tasks/duplicate_document_indexing_task.py @JohnJyong
|
||||
api/tasks/recover_document_indexing_task.py @JohnJyong
|
||||
api/tasks/remove_document_from_index_task.py @JohnJyong
|
||||
api/tasks/retry_document_indexing_task.py @JohnJyong
|
||||
api/tasks/sync_website_document_indexing_task.py @JohnJyong
|
||||
api/tasks/batch_create_segment_to_index_task.py @JohnJyong
|
||||
api/tasks/create_segment_to_index_task.py @JohnJyong
|
||||
api/tasks/delete_segment_from_index_task.py @JohnJyong
|
||||
api/tasks/disable_segment_from_index_task.py @JohnJyong
|
||||
api/tasks/disable_segments_from_index_task.py @JohnJyong
|
||||
api/tasks/enable_segment_to_index_task.py @JohnJyong
|
||||
api/tasks/enable_segments_to_index_task.py @JohnJyong
|
||||
api/tasks/clean_dataset_task.py @JohnJyong
|
||||
api/tasks/deal_dataset_index_update_task.py @JohnJyong
|
||||
api/tasks/deal_dataset_vector_index_task.py @JohnJyong
|
||||
|
||||
# Backend - Plugins
|
||||
/api/core/plugin/ @Mairuis @Yeuoly @Stream29
|
||||
/api/services/plugin/ @Mairuis @Yeuoly @Stream29
|
||||
/api/controllers/console/workspace/plugin.py @Mairuis @Yeuoly @Stream29
|
||||
/api/controllers/inner_api/plugin/ @Mairuis @Yeuoly @Stream29
|
||||
/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @Mairuis @Yeuoly @Stream29
|
||||
api/core/plugin/ @Mairuis @Yeuoly @Stream29
|
||||
api/services/plugin/ @Mairuis @Yeuoly @Stream29
|
||||
api/controllers/console/workspace/plugin.py @Mairuis @Yeuoly @Stream29
|
||||
api/controllers/inner_api/plugin/ @Mairuis @Yeuoly @Stream29
|
||||
api/tasks/process_tenant_plugin_autoupgrade_check_task.py @Mairuis @Yeuoly @Stream29
|
||||
|
||||
# Backend - Trigger/Schedule/Webhook
|
||||
/api/controllers/trigger/ @Mairuis @Yeuoly
|
||||
/api/controllers/console/app/workflow_trigger.py @Mairuis @Yeuoly
|
||||
/api/controllers/console/workspace/trigger_providers.py @Mairuis @Yeuoly
|
||||
/api/core/trigger/ @Mairuis @Yeuoly
|
||||
/api/core/app/layers/trigger_post_layer.py @Mairuis @Yeuoly
|
||||
/api/services/trigger/ @Mairuis @Yeuoly
|
||||
/api/models/trigger.py @Mairuis @Yeuoly
|
||||
/api/fields/workflow_trigger_fields.py @Mairuis @Yeuoly
|
||||
/api/repositories/workflow_trigger_log_repository.py @Mairuis @Yeuoly
|
||||
/api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis @Yeuoly
|
||||
/api/libs/schedule_utils.py @Mairuis @Yeuoly
|
||||
/api/services/workflow/scheduler.py @Mairuis @Yeuoly
|
||||
/api/schedule/trigger_provider_refresh_task.py @Mairuis @Yeuoly
|
||||
/api/schedule/workflow_schedule_task.py @Mairuis @Yeuoly
|
||||
/api/tasks/trigger_processing_tasks.py @Mairuis @Yeuoly
|
||||
/api/tasks/trigger_subscription_refresh_tasks.py @Mairuis @Yeuoly
|
||||
/api/tasks/workflow_schedule_tasks.py @Mairuis @Yeuoly
|
||||
/api/tasks/workflow_cfs_scheduler/ @Mairuis @Yeuoly
|
||||
/api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis @Yeuoly
|
||||
/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis @Yeuoly
|
||||
/api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis @Yeuoly
|
||||
/api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis @Yeuoly
|
||||
api/controllers/trigger/ @Mairuis @Yeuoly
|
||||
api/controllers/console/app/workflow_trigger.py @Mairuis @Yeuoly
|
||||
api/controllers/console/workspace/trigger_providers.py @Mairuis @Yeuoly
|
||||
api/core/trigger/ @Mairuis @Yeuoly
|
||||
api/core/app/layers/trigger_post_layer.py @Mairuis @Yeuoly
|
||||
api/services/trigger/ @Mairuis @Yeuoly
|
||||
api/models/trigger.py @Mairuis @Yeuoly
|
||||
api/fields/workflow_trigger_fields.py @Mairuis @Yeuoly
|
||||
api/repositories/workflow_trigger_log_repository.py @Mairuis @Yeuoly
|
||||
api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis @Yeuoly
|
||||
api/libs/schedule_utils.py @Mairuis @Yeuoly
|
||||
api/services/workflow/scheduler.py @Mairuis @Yeuoly
|
||||
api/schedule/trigger_provider_refresh_task.py @Mairuis @Yeuoly
|
||||
api/schedule/workflow_schedule_task.py @Mairuis @Yeuoly
|
||||
api/tasks/trigger_processing_tasks.py @Mairuis @Yeuoly
|
||||
api/tasks/trigger_subscription_refresh_tasks.py @Mairuis @Yeuoly
|
||||
api/tasks/workflow_schedule_tasks.py @Mairuis @Yeuoly
|
||||
api/tasks/workflow_cfs_scheduler/ @Mairuis @Yeuoly
|
||||
api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis @Yeuoly
|
||||
api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis @Yeuoly
|
||||
api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis @Yeuoly
|
||||
api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis @Yeuoly
|
||||
|
||||
# Backend - Async Workflow
|
||||
/api/services/async_workflow_service.py @Mairuis @Yeuoly
|
||||
/api/tasks/async_workflow_tasks.py @Mairuis @Yeuoly
|
||||
api/services/async_workflow_service.py @Mairuis @Yeuoly
|
||||
api/tasks/async_workflow_tasks.py @Mairuis @Yeuoly
|
||||
|
||||
# Backend - Billing
|
||||
/api/services/billing_service.py @hj24 @zyssyz123
|
||||
/api/controllers/console/billing/ @hj24 @zyssyz123
|
||||
api/services/billing_service.py @hj24 @zyssyz123
|
||||
api/controllers/console/billing/ @hj24 @zyssyz123
|
||||
|
||||
# Backend - Enterprise
|
||||
/api/configs/enterprise/ @GarfieldDai @GareArc
|
||||
/api/services/enterprise/ @GarfieldDai @GareArc
|
||||
/api/services/feature_service.py @GarfieldDai @GareArc
|
||||
/api/controllers/console/feature.py @GarfieldDai @GareArc
|
||||
/api/controllers/web/feature.py @GarfieldDai @GareArc
|
||||
api/configs/enterprise/ @GarfieldDai @GareArc
|
||||
api/services/enterprise/ @GarfieldDai @GareArc
|
||||
api/services/feature_service.py @GarfieldDai @GareArc
|
||||
api/controllers/console/feature.py @GarfieldDai @GareArc
|
||||
api/controllers/web/feature.py @GarfieldDai @GareArc
|
||||
|
||||
# Backend - Database Migrations
|
||||
/api/migrations/ @snakevash @laipz8200 @MRZHUH
|
||||
|
||||
# Backend - Vector DB Middleware
|
||||
/api/configs/middleware/vdb/* @JohnJyong
|
||||
api/migrations/ @snakevash @laipz8200
|
||||
|
||||
# Frontend
|
||||
/web/ @iamjoel
|
||||
|
||||
# Frontend - Web Tests
|
||||
/.github/workflows/web-tests.yml @iamjoel
|
||||
web/ @iamjoel
|
||||
|
||||
# Frontend - App - Orchestration
|
||||
/web/app/components/workflow/ @iamjoel @zxhlyh
|
||||
/web/app/components/workflow-app/ @iamjoel @zxhlyh
|
||||
/web/app/components/app/configuration/ @iamjoel @zxhlyh
|
||||
/web/app/components/app/app-publisher/ @iamjoel @zxhlyh
|
||||
web/app/components/workflow/ @iamjoel @zxhlyh
|
||||
web/app/components/workflow-app/ @iamjoel @zxhlyh
|
||||
web/app/components/app/configuration/ @iamjoel @zxhlyh
|
||||
web/app/components/app/app-publisher/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - WebApp - Chat
|
||||
/web/app/components/base/chat/ @iamjoel @zxhlyh
|
||||
web/app/components/base/chat/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - WebApp - Completion
|
||||
/web/app/components/share/text-generation/ @iamjoel @zxhlyh
|
||||
web/app/components/share/text-generation/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - App - List and Creation
|
||||
/web/app/components/apps/ @JzoNgKVO @iamjoel
|
||||
/web/app/components/app/create-app-dialog/ @JzoNgKVO @iamjoel
|
||||
/web/app/components/app/create-app-modal/ @JzoNgKVO @iamjoel
|
||||
/web/app/components/app/create-from-dsl-modal/ @JzoNgKVO @iamjoel
|
||||
web/app/components/apps/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/create-app-dialog/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/create-app-modal/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/create-from-dsl-modal/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - App - API Documentation
|
||||
/web/app/components/develop/ @JzoNgKVO @iamjoel
|
||||
web/app/components/develop/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - App - Logs and Annotations
|
||||
/web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel
|
||||
/web/app/components/app/log/ @JzoNgKVO @iamjoel
|
||||
/web/app/components/app/log-annotation/ @JzoNgKVO @iamjoel
|
||||
/web/app/components/app/annotation/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/log/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/log-annotation/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/annotation/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - App - Monitoring
|
||||
/web/app/(commonLayout)/app/(appDetailLayout)/\[appId\]/overview/ @JzoNgKVO @iamjoel
|
||||
/web/app/components/app/overview/ @JzoNgKVO @iamjoel
|
||||
web/app/(commonLayout)/app/(appDetailLayout)/\[appId\]/overview/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/overview/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - App - Settings
|
||||
/web/app/components/app-sidebar/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app-sidebar/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - RAG - Hit Testing
|
||||
/web/app/components/datasets/hit-testing/ @JzoNgKVO @iamjoel
|
||||
web/app/components/datasets/hit-testing/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - RAG - List and Creation
|
||||
/web/app/components/datasets/list/ @iamjoel @WTW0313
|
||||
/web/app/components/datasets/create/ @iamjoel @WTW0313
|
||||
/web/app/components/datasets/create-from-pipeline/ @iamjoel @WTW0313
|
||||
/web/app/components/datasets/external-knowledge-base/ @iamjoel @WTW0313
|
||||
web/app/components/datasets/list/ @iamjoel @WTW0313
|
||||
web/app/components/datasets/create/ @iamjoel @WTW0313
|
||||
web/app/components/datasets/create-from-pipeline/ @iamjoel @WTW0313
|
||||
web/app/components/datasets/external-knowledge-base/ @iamjoel @WTW0313
|
||||
|
||||
# Frontend - RAG - Orchestration (general rule first, specific rules below override)
|
||||
/web/app/components/rag-pipeline/ @iamjoel @WTW0313
|
||||
/web/app/components/rag-pipeline/components/rag-pipeline-main.tsx @iamjoel @zxhlyh
|
||||
/web/app/components/rag-pipeline/store/ @iamjoel @zxhlyh
|
||||
web/app/components/rag-pipeline/ @iamjoel @WTW0313
|
||||
web/app/components/rag-pipeline/components/rag-pipeline-main.tsx @iamjoel @zxhlyh
|
||||
web/app/components/rag-pipeline/store/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - RAG - Documents List
|
||||
/web/app/components/datasets/documents/list.tsx @iamjoel @WTW0313
|
||||
/web/app/components/datasets/documents/create-from-pipeline/ @iamjoel @WTW0313
|
||||
web/app/components/datasets/documents/list.tsx @iamjoel @WTW0313
|
||||
web/app/components/datasets/documents/create-from-pipeline/ @iamjoel @WTW0313
|
||||
|
||||
# Frontend - RAG - Segments List
|
||||
/web/app/components/datasets/documents/detail/ @iamjoel @WTW0313
|
||||
web/app/components/datasets/documents/detail/ @iamjoel @WTW0313
|
||||
|
||||
# Frontend - RAG - Settings
|
||||
/web/app/components/datasets/settings/ @iamjoel @WTW0313
|
||||
web/app/components/datasets/settings/ @iamjoel @WTW0313
|
||||
|
||||
# Frontend - Ecosystem - Plugins
|
||||
/web/app/components/plugins/ @iamjoel @zhsama
|
||||
web/app/components/plugins/ @iamjoel @zhsama
|
||||
|
||||
# Frontend - Ecosystem - Tools
|
||||
/web/app/components/tools/ @iamjoel @Yessenia-d
|
||||
web/app/components/tools/ @iamjoel @Yessenia-d
|
||||
|
||||
# Frontend - Ecosystem - MarketPlace
|
||||
/web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d
|
||||
web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d
|
||||
|
||||
# Frontend - Login and Registration
|
||||
/web/app/signin/ @douxc @iamjoel
|
||||
/web/app/signup/ @douxc @iamjoel
|
||||
/web/app/reset-password/ @douxc @iamjoel
|
||||
/web/app/install/ @douxc @iamjoel
|
||||
/web/app/init/ @douxc @iamjoel
|
||||
/web/app/forgot-password/ @douxc @iamjoel
|
||||
/web/app/account/ @douxc @iamjoel
|
||||
web/app/signin/ @douxc @iamjoel
|
||||
web/app/signup/ @douxc @iamjoel
|
||||
web/app/reset-password/ @douxc @iamjoel
|
||||
web/app/install/ @douxc @iamjoel
|
||||
web/app/init/ @douxc @iamjoel
|
||||
web/app/forgot-password/ @douxc @iamjoel
|
||||
web/app/account/ @douxc @iamjoel
|
||||
|
||||
# Frontend - Service Authentication
|
||||
/web/service/base.ts @douxc @iamjoel
|
||||
web/service/base.ts @douxc @iamjoel
|
||||
|
||||
# Frontend - WebApp Authentication and Access Control
|
||||
/web/app/(shareLayout)/components/ @douxc @iamjoel
|
||||
/web/app/(shareLayout)/webapp-signin/ @douxc @iamjoel
|
||||
/web/app/(shareLayout)/webapp-reset-password/ @douxc @iamjoel
|
||||
/web/app/components/app/app-access-control/ @douxc @iamjoel
|
||||
web/app/(shareLayout)/components/ @douxc @iamjoel
|
||||
web/app/(shareLayout)/webapp-signin/ @douxc @iamjoel
|
||||
web/app/(shareLayout)/webapp-reset-password/ @douxc @iamjoel
|
||||
web/app/components/app/app-access-control/ @douxc @iamjoel
|
||||
|
||||
# Frontend - Explore Page
|
||||
/web/app/components/explore/ @CodingOnStar @iamjoel
|
||||
web/app/components/explore/ @CodingOnStar @iamjoel
|
||||
|
||||
# Frontend - Personal Settings
|
||||
/web/app/components/header/account-setting/ @CodingOnStar @iamjoel
|
||||
/web/app/components/header/account-dropdown/ @CodingOnStar @iamjoel
|
||||
web/app/components/header/account-setting/ @CodingOnStar @iamjoel
|
||||
web/app/components/header/account-dropdown/ @CodingOnStar @iamjoel
|
||||
|
||||
# Frontend - Analytics
|
||||
/web/app/components/base/ga/ @CodingOnStar @iamjoel
|
||||
web/app/components/base/ga/ @CodingOnStar @iamjoel
|
||||
|
||||
# Frontend - Base Components
|
||||
/web/app/components/base/ @iamjoel @zxhlyh
|
||||
web/app/components/base/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - Utils and Hooks
|
||||
/web/utils/classnames.ts @iamjoel @zxhlyh
|
||||
/web/utils/time.ts @iamjoel @zxhlyh
|
||||
/web/utils/format.ts @iamjoel @zxhlyh
|
||||
/web/utils/clipboard.ts @iamjoel @zxhlyh
|
||||
/web/hooks/use-document-title.ts @iamjoel @zxhlyh
|
||||
web/utils/classnames.ts @iamjoel @zxhlyh
|
||||
web/utils/time.ts @iamjoel @zxhlyh
|
||||
web/utils/format.ts @iamjoel @zxhlyh
|
||||
web/utils/clipboard.ts @iamjoel @zxhlyh
|
||||
web/hooks/use-document-title.ts @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - Billing and Education
|
||||
/web/app/components/billing/ @iamjoel @zxhlyh
|
||||
/web/app/education-apply/ @iamjoel @zxhlyh
|
||||
web/app/components/billing/ @iamjoel @zxhlyh
|
||||
web/app/education-apply/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - Workspace
|
||||
/web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh
|
||||
|
||||
# Docker
|
||||
/docker/* @laipz8200
|
||||
web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh
|
||||
|
||||
14
.github/ISSUE_TEMPLATE/refactor.yml
vendored
14
.github/ISSUE_TEMPLATE/refactor.yml
vendored
@@ -1,6 +1,8 @@
|
||||
name: "✨ Refactor or Chore"
|
||||
description: Refactor existing code or perform maintenance chores to improve readability and reliability.
|
||||
title: "[Refactor/Chore] "
|
||||
name: "✨ Refactor"
|
||||
description: Refactor existing code for improved readability and maintainability.
|
||||
title: "[Chore/Refactor] "
|
||||
labels:
|
||||
- refactor
|
||||
body:
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
@@ -9,7 +11,7 @@ body:
|
||||
options:
|
||||
- label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542).
|
||||
required: true
|
||||
- label: This is only for refactors or chores; if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
|
||||
- label: This is only for refactoring, if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
|
||||
required: true
|
||||
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
|
||||
required: true
|
||||
@@ -23,14 +25,14 @@ body:
|
||||
id: description
|
||||
attributes:
|
||||
label: Description
|
||||
placeholder: "Describe the refactor or chore you are proposing."
|
||||
placeholder: "Describe the refactor you are proposing."
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: motivation
|
||||
attributes:
|
||||
label: Motivation
|
||||
placeholder: "Explain why this refactor or chore is necessary."
|
||||
placeholder: "Explain why this refactor is necessary."
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
|
||||
13
.github/ISSUE_TEMPLATE/tracker.yml
vendored
Normal file
13
.github/ISSUE_TEMPLATE/tracker.yml
vendored
Normal file
@@ -0,0 +1,13 @@
|
||||
name: "👾 Tracker"
|
||||
description: For inner usages, please do not use this template.
|
||||
title: "[Tracker] "
|
||||
labels:
|
||||
- tracker
|
||||
body:
|
||||
- type: textarea
|
||||
id: content
|
||||
attributes:
|
||||
label: Blockers
|
||||
placeholder: "- [ ] ..."
|
||||
validations:
|
||||
required: true
|
||||
12
.github/copilot-instructions.md
vendored
Normal file
12
.github/copilot-instructions.md
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
# Copilot Instructions
|
||||
|
||||
GitHub Copilot must follow the unified frontend testing requirements documented in `web/testing/testing.md`.
|
||||
|
||||
Key reminders:
|
||||
|
||||
- Generate tests using the mandated tech stack, naming, and code style (AAA pattern, `fireEvent`, descriptive test names, cleans up mocks).
|
||||
- Cover rendering, prop combinations, and edge cases by default; extend coverage for hooks, routing, async flows, and domain-specific components when applicable.
|
||||
- Target >95% line and branch coverage and 100% function/statement coverage.
|
||||
- Apply the project's mocking conventions for i18n, toast notifications, and Next.js utilities.
|
||||
|
||||
Any suggestions from Copilot that conflict with `web/testing/testing.md` should be revised before acceptance.
|
||||
39
.github/workflows/api-tests.yml
vendored
39
.github/workflows/api-tests.yml
vendored
@@ -22,12 +22,12 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@@ -57,7 +57,7 @@ jobs:
|
||||
run: sh .github/workflows/expose_service_ports.sh
|
||||
|
||||
- name: Set up Sandbox
|
||||
uses: hoverkraft-tech/compose-action@v2
|
||||
uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
@@ -71,18 +71,18 @@ jobs:
|
||||
run: |
|
||||
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
|
||||
|
||||
- name: Run API Tests
|
||||
env:
|
||||
STORAGE_TYPE: opendal
|
||||
OPENDAL_SCHEME: fs
|
||||
OPENDAL_FS_ROOT: /tmp/dify-storage
|
||||
- name: Run Workflow
|
||||
run: uv run --project api bash dev/pytest/pytest_workflow.sh
|
||||
|
||||
- name: Run Tool
|
||||
run: uv run --project api bash dev/pytest/pytest_tools.sh
|
||||
|
||||
- name: Run TestContainers
|
||||
run: uv run --project api bash dev/pytest/pytest_testcontainers.sh
|
||||
|
||||
- name: Run Unit tests
|
||||
run: |
|
||||
uv run --project api pytest \
|
||||
--timeout "${PYTEST_TIMEOUT:-180}" \
|
||||
api/tests/integration_tests/workflow \
|
||||
api/tests/integration_tests/tools \
|
||||
api/tests/test_containers_integration_tests \
|
||||
api/tests/unit_tests
|
||||
uv run --project api bash dev/pytest/pytest_unit_tests.sh
|
||||
|
||||
- name: Coverage Summary
|
||||
run: |
|
||||
@@ -93,12 +93,5 @@ jobs:
|
||||
# Create a detailed coverage summary
|
||||
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
|
||||
{
|
||||
echo ""
|
||||
echo "<details><summary>File-level coverage (click to expand)</summary>"
|
||||
echo ""
|
||||
echo '```'
|
||||
uv run --project api coverage report -m
|
||||
echo '```'
|
||||
echo "</details>"
|
||||
} >> $GITHUB_STEP_SUMMARY
|
||||
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
|
||||
59
.github/workflows/autofix.yml
vendored
59
.github/workflows/autofix.yml
vendored
@@ -12,29 +12,12 @@ jobs:
|
||||
if: github.repository == 'langgenius/dify'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Check Docker Compose inputs
|
||||
id: docker-compose-changes
|
||||
uses: tj-actions/changed-files@v46
|
||||
with:
|
||||
files: |
|
||||
docker/generate_docker_compose
|
||||
docker/.env.example
|
||||
docker/docker-compose-template.yaml
|
||||
docker/docker-compose.yaml
|
||||
- uses: actions/setup-python@v5
|
||||
# Use uv to ensure we have the same ruff version in CI and locally.
|
||||
- uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- uses: astral-sh/setup-uv@v7
|
||||
|
||||
- name: Generate Docker Compose
|
||||
if: steps.docker-compose-changes.outputs.any_changed == 'true'
|
||||
run: |
|
||||
cd docker
|
||||
./generate_docker_compose
|
||||
|
||||
- run: |
|
||||
cd api
|
||||
uv sync --dev
|
||||
@@ -52,11 +35,10 @@ jobs:
|
||||
|
||||
- name: ast-grep
|
||||
run: |
|
||||
# ast-grep exits 1 if no matches are found; allow idempotent runs.
|
||||
uvx --from ast-grep-cli ast-grep --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all || true
|
||||
uvx --from ast-grep-cli ast-grep --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all || true
|
||||
uvx --from ast-grep-cli ast-grep -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all || true
|
||||
uvx --from ast-grep-cli ast-grep -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all || true
|
||||
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
|
||||
uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
|
||||
uvx --from ast-grep-cli sg -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all
|
||||
uvx --from ast-grep-cli sg -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all
|
||||
# Convert Optional[T] to T | None (ignoring quoted types)
|
||||
cat > /tmp/optional-rule.yml << 'EOF'
|
||||
id: convert-optional-to-union
|
||||
@@ -74,14 +56,35 @@ jobs:
|
||||
pattern: $T
|
||||
fix: $T | None
|
||||
EOF
|
||||
uvx --from ast-grep-cli ast-grep scan . --inline-rules "$(cat /tmp/optional-rule.yml)" --update-all
|
||||
uvx --from ast-grep-cli sg scan --inline-rules "$(cat /tmp/optional-rule.yml)" --update-all
|
||||
# Fix forward references that were incorrectly converted (Python doesn't support "Type" | None syntax)
|
||||
find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \;
|
||||
find . -name "*.py.bak" -type f -delete
|
||||
|
||||
# mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
|
||||
- name: mdformat
|
||||
run: |
|
||||
uvx --python 3.13 mdformat . --exclude ".claude/skills/**/SKILL.md"
|
||||
uvx mdformat .
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
package_json_file: web/package.json
|
||||
run_install: false
|
||||
|
||||
- name: Setup NodeJS
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/package.json
|
||||
|
||||
- name: Web dependencies
|
||||
working-directory: ./web
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: oxlint
|
||||
working-directory: ./web
|
||||
run: |
|
||||
pnpx oxlint --fix
|
||||
|
||||
- uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27
|
||||
|
||||
2
.github/workflows/build-push.yml
vendored
2
.github/workflows/build-push.yml
vendored
@@ -90,7 +90,7 @@ jobs:
|
||||
touch "/tmp/digests/${sanitized_digest}"
|
||||
|
||||
- name: Upload digest
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: digests-${{ matrix.context }}-${{ env.PLATFORM_PAIR }}
|
||||
path: /tmp/digests/*
|
||||
|
||||
8
.github/workflows/db-migration-test.yml
vendored
8
.github/workflows/db-migration-test.yml
vendored
@@ -13,13 +13,13 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
@@ -63,13 +63,13 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
|
||||
3
.github/workflows/main-ci.yml
vendored
3
.github/workflows/main-ci.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
||||
vdb-changed: ${{ steps.changes.outputs.vdb }}
|
||||
migration-changed: ${{ steps.changes.outputs.migration }}
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- uses: dorny/paths-filter@v3
|
||||
id: changes
|
||||
with:
|
||||
@@ -38,7 +38,6 @@ jobs:
|
||||
- '.github/workflows/api-tests.yml'
|
||||
web:
|
||||
- 'web/**'
|
||||
- '.github/workflows/web-tests.yml'
|
||||
vdb:
|
||||
- 'api/core/rag/datasource/**'
|
||||
- 'docker/**'
|
||||
|
||||
21
.github/workflows/semantic-pull-request.yml
vendored
21
.github/workflows/semantic-pull-request.yml
vendored
@@ -1,21 +0,0 @@
|
||||
name: Semantic Pull Request
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types:
|
||||
- opened
|
||||
- edited
|
||||
- reopened
|
||||
- synchronize
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
name: Validate PR title
|
||||
permissions:
|
||||
pull-requests: read
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check title
|
||||
uses: amannn/action-semantic-pull-request@v6.1.1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
58
.github/workflows/style.yml
vendored
58
.github/workflows/style.yml
vendored
@@ -19,13 +19,13 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v47
|
||||
uses: tj-actions/changed-files@v46
|
||||
with:
|
||||
files: |
|
||||
api/**
|
||||
@@ -33,7 +33,7 @@ jobs:
|
||||
|
||||
- name: Setup UV and Python
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
enable-cache: false
|
||||
python-version: "3.12"
|
||||
@@ -68,17 +68,15 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v47
|
||||
uses: tj-actions/changed-files@v46
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
.github/workflows/style.yml
|
||||
files: web/**
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
@@ -87,12 +85,12 @@ jobs:
|
||||
run_install: false
|
||||
|
||||
- name: Setup NodeJS
|
||||
uses: actions/setup-node@v6
|
||||
uses: actions/setup-node@v4
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
with:
|
||||
node-version: 22
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
cache-dependency-path: ./web/package.json
|
||||
|
||||
- name: Web dependencies
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
@@ -108,17 +106,37 @@ jobs:
|
||||
- name: Web type check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm run type-check:tsgo
|
||||
run: pnpm run type-check
|
||||
|
||||
- name: Web dead code check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm run knip
|
||||
docker-compose-template:
|
||||
name: Docker Compose Template
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
- name: Web build check
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v46
|
||||
with:
|
||||
files: |
|
||||
docker/generate_docker_compose
|
||||
docker/.env.example
|
||||
docker/docker-compose-template.yaml
|
||||
docker/docker-compose.yaml
|
||||
|
||||
- name: Generate Docker Compose
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm run build
|
||||
run: |
|
||||
cd docker
|
||||
./generate_docker_compose
|
||||
|
||||
- name: Check for changes
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: git diff --exit-code
|
||||
|
||||
superlinter:
|
||||
name: SuperLinter
|
||||
@@ -126,14 +144,14 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v47
|
||||
uses: tj-actions/changed-files@v46
|
||||
with:
|
||||
files: |
|
||||
**.sh
|
||||
|
||||
4
.github/workflows/tool-test-sdks.yaml
vendored
4
.github/workflows/tool-test-sdks.yaml
vendored
@@ -25,12 +25,12 @@ jobs:
|
||||
working-directory: sdks/nodejs-client
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Use Node.js ${{ matrix.node-version }}
|
||||
uses: actions/setup-node@v6
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: ${{ matrix.node-version }}
|
||||
cache: ''
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
name: Translate i18n Files Based on English
|
||||
name: Check i18n Files and Create PR
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
paths:
|
||||
- 'web/i18n/en-US/*.json'
|
||||
workflow_dispatch:
|
||||
- 'web/i18n/en-US/*.ts'
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
@@ -19,7 +18,6 @@ jobs:
|
||||
run:
|
||||
working-directory: web
|
||||
steps:
|
||||
# Keep use old checkout action version for https://github.com/peter-evans/create-pull-request/issues/4272
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
@@ -28,28 +26,21 @@ jobs:
|
||||
- name: Check for file changes in i18n/en-US
|
||||
id: check_files
|
||||
run: |
|
||||
# Skip check for manual trigger, translate all files
|
||||
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
|
||||
git fetch origin "${{ github.event.before }}" || true
|
||||
git fetch origin "${{ github.sha }}" || true
|
||||
changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.ts')
|
||||
echo "Changed files: $changed_files"
|
||||
if [ -n "$changed_files" ]; then
|
||||
echo "FILES_CHANGED=true" >> $GITHUB_ENV
|
||||
echo "FILE_ARGS=" >> $GITHUB_ENV
|
||||
echo "Manual trigger: translating all files"
|
||||
file_args=""
|
||||
for file in $changed_files; do
|
||||
filename=$(basename "$file" .ts)
|
||||
file_args="$file_args --file $filename"
|
||||
done
|
||||
echo "FILE_ARGS=$file_args" >> $GITHUB_ENV
|
||||
echo "File arguments: $file_args"
|
||||
else
|
||||
git fetch origin "${{ github.event.before }}" || true
|
||||
git fetch origin "${{ github.sha }}" || true
|
||||
changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.json')
|
||||
echo "Changed files: $changed_files"
|
||||
if [ -n "$changed_files" ]; then
|
||||
echo "FILES_CHANGED=true" >> $GITHUB_ENV
|
||||
file_args=""
|
||||
for file in $changed_files; do
|
||||
filename=$(basename "$file" .json)
|
||||
file_args="$file_args --file $filename"
|
||||
done
|
||||
echo "FILE_ARGS=$file_args" >> $GITHUB_ENV
|
||||
echo "File arguments: $file_args"
|
||||
else
|
||||
echo "FILES_CHANGED=false" >> $GITHUB_ENV
|
||||
fi
|
||||
echo "FILES_CHANGED=false" >> $GITHUB_ENV
|
||||
fi
|
||||
|
||||
- name: Install pnpm
|
||||
@@ -60,11 +51,11 @@ jobs:
|
||||
|
||||
- name: Set up Node.js
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
uses: actions/setup-node@v6
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 'lts/*'
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
cache-dependency-path: ./web/package.json
|
||||
|
||||
- name: Install dependencies
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
@@ -74,7 +65,12 @@ jobs:
|
||||
- name: Generate i18n translations
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm run i18n:gen ${{ env.FILE_ARGS }}
|
||||
run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }}
|
||||
|
||||
- name: Generate i18n type definitions
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm run gen:i18n-types
|
||||
|
||||
- name: Create Pull Request
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
@@ -82,13 +78,14 @@ jobs:
|
||||
with:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
commit-message: 'chore(i18n): update translations based on en-US changes'
|
||||
title: 'chore(i18n): translate i18n files based on en-US changes'
|
||||
title: 'chore(i18n): translate i18n files and update type definitions'
|
||||
body: |
|
||||
This PR was automatically created to update i18n translation files based on changes in en-US locale.
|
||||
This PR was automatically created to update i18n files and TypeScript type definitions based on changes in en-US locale.
|
||||
|
||||
**Triggered by:** ${{ github.sha }}
|
||||
|
||||
**Changes included:**
|
||||
- Updated translation files for all locales
|
||||
- Regenerated TypeScript type definitions for type safety
|
||||
branch: chore/automated-i18n-updates-${{ github.sha }}
|
||||
delete-branch: true
|
||||
|
||||
6
.github/workflows/vdb-tests.yml
vendored
6
.github/workflows/vdb-tests.yml
vendored
@@ -19,19 +19,19 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Free Disk Space
|
||||
uses: endersonmenezes/free-disk-space@v3
|
||||
uses: endersonmenezes/free-disk-space@v2
|
||||
with:
|
||||
remove_dotnet: true
|
||||
remove_haskell: true
|
||||
remove_tool_cache: true
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
352
.github/workflows/web-tests.yml
vendored
352
.github/workflows/web-tests.yml
vendored
@@ -13,356 +13,46 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: ./web
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v46
|
||||
with:
|
||||
files: web/**
|
||||
|
||||
- name: Install pnpm
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
package_json_file: web/package.json
|
||||
run_install: false
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
uses: actions/setup-node@v4
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
with:
|
||||
node-version: 22
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
cache-dependency-path: ./web/package.json
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Check i18n types synchronization
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm run check:i18n-types
|
||||
|
||||
- name: Run tests
|
||||
run: pnpm test:coverage
|
||||
|
||||
- name: Coverage Summary
|
||||
if: always()
|
||||
id: coverage-summary
|
||||
run: |
|
||||
set -eo pipefail
|
||||
|
||||
COVERAGE_FILE="coverage/coverage-final.json"
|
||||
COVERAGE_SUMMARY_FILE="coverage/coverage-summary.json"
|
||||
|
||||
if [ ! -f "$COVERAGE_FILE" ] && [ ! -f "$COVERAGE_SUMMARY_FILE" ]; then
|
||||
echo "has_coverage=false" >> "$GITHUB_OUTPUT"
|
||||
echo "### 🚨 Test Coverage Report :test_tube:" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo "Coverage data not found. Ensure Vitest runs with coverage enabled." >> "$GITHUB_STEP_SUMMARY"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "has_coverage=true" >> "$GITHUB_OUTPUT"
|
||||
|
||||
node <<'NODE' >> "$GITHUB_STEP_SUMMARY"
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
let libCoverage = null;
|
||||
|
||||
try {
|
||||
libCoverage = require('istanbul-lib-coverage');
|
||||
} catch (error) {
|
||||
libCoverage = null;
|
||||
}
|
||||
|
||||
const summaryPath = path.join('coverage', 'coverage-summary.json');
|
||||
const finalPath = path.join('coverage', 'coverage-final.json');
|
||||
|
||||
const hasSummary = fs.existsSync(summaryPath);
|
||||
const hasFinal = fs.existsSync(finalPath);
|
||||
|
||||
if (!hasSummary && !hasFinal) {
|
||||
console.log('### Test Coverage Summary :test_tube:');
|
||||
console.log('');
|
||||
console.log('No coverage data found.');
|
||||
process.exit(0);
|
||||
}
|
||||
|
||||
const summary = hasSummary
|
||||
? JSON.parse(fs.readFileSync(summaryPath, 'utf8'))
|
||||
: null;
|
||||
const coverage = hasFinal
|
||||
? JSON.parse(fs.readFileSync(finalPath, 'utf8'))
|
||||
: null;
|
||||
|
||||
const getLineCoverageFromStatements = (statementMap, statementHits) => {
|
||||
const lineHits = {};
|
||||
|
||||
if (!statementMap || !statementHits) {
|
||||
return lineHits;
|
||||
}
|
||||
|
||||
Object.entries(statementMap).forEach(([key, statement]) => {
|
||||
const line = statement?.start?.line;
|
||||
if (!line) {
|
||||
return;
|
||||
}
|
||||
const hits = statementHits[key] ?? 0;
|
||||
const previous = lineHits[line];
|
||||
lineHits[line] = previous === undefined ? hits : Math.max(previous, hits);
|
||||
});
|
||||
|
||||
return lineHits;
|
||||
};
|
||||
|
||||
const getFileCoverage = (entry) => (
|
||||
libCoverage ? libCoverage.createFileCoverage(entry) : null
|
||||
);
|
||||
|
||||
const getLineHits = (entry, fileCoverage) => {
|
||||
const lineHits = entry.l ?? {};
|
||||
if (Object.keys(lineHits).length > 0) {
|
||||
return lineHits;
|
||||
}
|
||||
if (fileCoverage) {
|
||||
return fileCoverage.getLineCoverage();
|
||||
}
|
||||
return getLineCoverageFromStatements(entry.statementMap ?? {}, entry.s ?? {});
|
||||
};
|
||||
|
||||
const getUncoveredLines = (entry, fileCoverage, lineHits) => {
|
||||
if (lineHits && Object.keys(lineHits).length > 0) {
|
||||
return Object.entries(lineHits)
|
||||
.filter(([, count]) => count === 0)
|
||||
.map(([line]) => Number(line))
|
||||
.sort((a, b) => a - b);
|
||||
}
|
||||
if (fileCoverage) {
|
||||
return fileCoverage.getUncoveredLines();
|
||||
}
|
||||
return [];
|
||||
};
|
||||
|
||||
const totals = {
|
||||
lines: { covered: 0, total: 0 },
|
||||
statements: { covered: 0, total: 0 },
|
||||
branches: { covered: 0, total: 0 },
|
||||
functions: { covered: 0, total: 0 },
|
||||
};
|
||||
const fileSummaries = [];
|
||||
|
||||
if (summary) {
|
||||
const totalEntry = summary.total ?? {};
|
||||
['lines', 'statements', 'branches', 'functions'].forEach((key) => {
|
||||
if (totalEntry[key]) {
|
||||
totals[key].covered = totalEntry[key].covered ?? 0;
|
||||
totals[key].total = totalEntry[key].total ?? 0;
|
||||
}
|
||||
});
|
||||
|
||||
Object.entries(summary)
|
||||
.filter(([file]) => file !== 'total')
|
||||
.forEach(([file, data]) => {
|
||||
fileSummaries.push({
|
||||
file,
|
||||
pct: data.lines?.pct ?? data.statements?.pct ?? 0,
|
||||
lines: {
|
||||
covered: data.lines?.covered ?? 0,
|
||||
total: data.lines?.total ?? 0,
|
||||
},
|
||||
});
|
||||
});
|
||||
} else if (coverage) {
|
||||
Object.entries(coverage).forEach(([file, entry]) => {
|
||||
const fileCoverage = getFileCoverage(entry);
|
||||
const lineHits = getLineHits(entry, fileCoverage);
|
||||
const statementHits = entry.s ?? {};
|
||||
const branchHits = entry.b ?? {};
|
||||
const functionHits = entry.f ?? {};
|
||||
|
||||
const lineTotal = Object.keys(lineHits).length;
|
||||
const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
|
||||
|
||||
const statementTotal = Object.keys(statementHits).length;
|
||||
const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
|
||||
|
||||
const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
|
||||
const branchCovered = Object.values(branchHits).reduce(
|
||||
(acc, branches) => acc + branches.filter((n) => n > 0).length,
|
||||
0,
|
||||
);
|
||||
|
||||
const functionTotal = Object.keys(functionHits).length;
|
||||
const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
|
||||
|
||||
totals.lines.total += lineTotal;
|
||||
totals.lines.covered += lineCovered;
|
||||
totals.statements.total += statementTotal;
|
||||
totals.statements.covered += statementCovered;
|
||||
totals.branches.total += branchTotal;
|
||||
totals.branches.covered += branchCovered;
|
||||
totals.functions.total += functionTotal;
|
||||
totals.functions.covered += functionCovered;
|
||||
|
||||
const pct = (covered, tot) => (tot > 0 ? (covered / tot) * 100 : 0);
|
||||
|
||||
fileSummaries.push({
|
||||
file,
|
||||
pct: pct(lineCovered || statementCovered, lineTotal || statementTotal),
|
||||
lines: {
|
||||
covered: lineCovered || statementCovered,
|
||||
total: lineTotal || statementTotal,
|
||||
},
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
const pct = (covered, tot) => (tot > 0 ? ((covered / tot) * 100).toFixed(2) : '0.00');
|
||||
|
||||
console.log('### Test Coverage Summary :test_tube:');
|
||||
console.log('');
|
||||
console.log('| Metric | Coverage | Covered / Total |');
|
||||
console.log('|--------|----------|-----------------|');
|
||||
console.log(`| Lines | ${pct(totals.lines.covered, totals.lines.total)}% | ${totals.lines.covered} / ${totals.lines.total} |`);
|
||||
console.log(`| Statements | ${pct(totals.statements.covered, totals.statements.total)}% | ${totals.statements.covered} / ${totals.statements.total} |`);
|
||||
console.log(`| Branches | ${pct(totals.branches.covered, totals.branches.total)}% | ${totals.branches.covered} / ${totals.branches.total} |`);
|
||||
console.log(`| Functions | ${pct(totals.functions.covered, totals.functions.total)}% | ${totals.functions.covered} / ${totals.functions.total} |`);
|
||||
|
||||
console.log('');
|
||||
console.log('<details><summary>File coverage (lowest lines first)</summary>');
|
||||
console.log('');
|
||||
console.log('```');
|
||||
fileSummaries
|
||||
.sort((a, b) => (a.pct - b.pct) || (b.lines.total - a.lines.total))
|
||||
.slice(0, 25)
|
||||
.forEach(({ file, pct, lines }) => {
|
||||
console.log(`${pct.toFixed(2)}%\t${lines.covered}/${lines.total}\t${file}`);
|
||||
});
|
||||
console.log('```');
|
||||
console.log('</details>');
|
||||
|
||||
if (coverage) {
|
||||
const pctValue = (covered, tot) => {
|
||||
if (tot === 0) {
|
||||
return '0';
|
||||
}
|
||||
return ((covered / tot) * 100)
|
||||
.toFixed(2)
|
||||
.replace(/\.?0+$/, '');
|
||||
};
|
||||
|
||||
const formatLineRanges = (lines) => {
|
||||
if (lines.length === 0) {
|
||||
return '';
|
||||
}
|
||||
const ranges = [];
|
||||
let start = lines[0];
|
||||
let end = lines[0];
|
||||
|
||||
for (let i = 1; i < lines.length; i += 1) {
|
||||
const current = lines[i];
|
||||
if (current === end + 1) {
|
||||
end = current;
|
||||
continue;
|
||||
}
|
||||
ranges.push(start === end ? `${start}` : `${start}-${end}`);
|
||||
start = current;
|
||||
end = current;
|
||||
}
|
||||
ranges.push(start === end ? `${start}` : `${start}-${end}`);
|
||||
return ranges.join(',');
|
||||
};
|
||||
|
||||
const tableTotals = {
|
||||
statements: { covered: 0, total: 0 },
|
||||
branches: { covered: 0, total: 0 },
|
||||
functions: { covered: 0, total: 0 },
|
||||
lines: { covered: 0, total: 0 },
|
||||
};
|
||||
const tableRows = Object.entries(coverage)
|
||||
.map(([file, entry]) => {
|
||||
const fileCoverage = getFileCoverage(entry);
|
||||
const lineHits = getLineHits(entry, fileCoverage);
|
||||
const statementHits = entry.s ?? {};
|
||||
const branchHits = entry.b ?? {};
|
||||
const functionHits = entry.f ?? {};
|
||||
|
||||
const lineTotal = Object.keys(lineHits).length;
|
||||
const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
|
||||
const statementTotal = Object.keys(statementHits).length;
|
||||
const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
|
||||
const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
|
||||
const branchCovered = Object.values(branchHits).reduce(
|
||||
(acc, branches) => acc + branches.filter((n) => n > 0).length,
|
||||
0,
|
||||
);
|
||||
const functionTotal = Object.keys(functionHits).length;
|
||||
const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
|
||||
|
||||
tableTotals.lines.total += lineTotal;
|
||||
tableTotals.lines.covered += lineCovered;
|
||||
tableTotals.statements.total += statementTotal;
|
||||
tableTotals.statements.covered += statementCovered;
|
||||
tableTotals.branches.total += branchTotal;
|
||||
tableTotals.branches.covered += branchCovered;
|
||||
tableTotals.functions.total += functionTotal;
|
||||
tableTotals.functions.covered += functionCovered;
|
||||
|
||||
const uncoveredLines = getUncoveredLines(entry, fileCoverage, lineHits);
|
||||
|
||||
const filePath = entry.path ?? file;
|
||||
const relativePath = path.isAbsolute(filePath)
|
||||
? path.relative(process.cwd(), filePath)
|
||||
: filePath;
|
||||
|
||||
return {
|
||||
file: relativePath || file,
|
||||
statements: pctValue(statementCovered, statementTotal),
|
||||
branches: pctValue(branchCovered, branchTotal),
|
||||
functions: pctValue(functionCovered, functionTotal),
|
||||
lines: pctValue(lineCovered, lineTotal),
|
||||
uncovered: formatLineRanges(uncoveredLines),
|
||||
};
|
||||
})
|
||||
.sort((a, b) => a.file.localeCompare(b.file));
|
||||
|
||||
const columns = [
|
||||
{ key: 'file', header: 'File', align: 'left' },
|
||||
{ key: 'statements', header: '% Stmts', align: 'right' },
|
||||
{ key: 'branches', header: '% Branch', align: 'right' },
|
||||
{ key: 'functions', header: '% Funcs', align: 'right' },
|
||||
{ key: 'lines', header: '% Lines', align: 'right' },
|
||||
{ key: 'uncovered', header: 'Uncovered Line #s', align: 'left' },
|
||||
];
|
||||
|
||||
const allFilesRow = {
|
||||
file: 'All files',
|
||||
statements: pctValue(tableTotals.statements.covered, tableTotals.statements.total),
|
||||
branches: pctValue(tableTotals.branches.covered, tableTotals.branches.total),
|
||||
functions: pctValue(tableTotals.functions.covered, tableTotals.functions.total),
|
||||
lines: pctValue(tableTotals.lines.covered, tableTotals.lines.total),
|
||||
uncovered: '',
|
||||
};
|
||||
|
||||
const rowsForOutput = [allFilesRow, ...tableRows];
|
||||
const formatRow = (row) => `| ${columns
|
||||
.map(({ key }) => String(row[key] ?? ''))
|
||||
.join(' | ')} |`;
|
||||
const headerRow = `| ${columns.map(({ header }) => header).join(' | ')} |`;
|
||||
const dividerRow = `| ${columns
|
||||
.map(({ align }) => (align === 'right' ? '---:' : ':---'))
|
||||
.join(' | ')} |`;
|
||||
|
||||
console.log('');
|
||||
console.log('<details><summary>Vitest coverage table</summary>');
|
||||
console.log('');
|
||||
console.log(headerRow);
|
||||
console.log(dividerRow);
|
||||
rowsForOutput.forEach((row) => console.log(formatRow(row)));
|
||||
console.log('</details>');
|
||||
}
|
||||
NODE
|
||||
|
||||
- name: Upload Coverage Artifact
|
||||
if: steps.coverage-summary.outputs.has_coverage == 'true'
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: web-coverage-report
|
||||
path: web/coverage
|
||||
retention-days: 30
|
||||
if-no-files-found: error
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm test
|
||||
|
||||
14
.gitignore
vendored
14
.gitignore
vendored
@@ -139,6 +139,7 @@ pyrightconfig.json
|
||||
.idea/'
|
||||
|
||||
.DS_Store
|
||||
web/.vscode/settings.json
|
||||
|
||||
# Intellij IDEA Files
|
||||
.idea/*
|
||||
@@ -188,14 +189,12 @@ docker/volumes/matrixone/*
|
||||
docker/volumes/mysql/*
|
||||
docker/volumes/seekdb/*
|
||||
!docker/volumes/oceanbase/init.d
|
||||
docker/volumes/iris/*
|
||||
|
||||
docker/nginx/conf.d/default.conf
|
||||
docker/nginx/ssl/*
|
||||
!docker/nginx/ssl/.gitkeep
|
||||
docker/middleware.env
|
||||
docker/docker-compose.override.yaml
|
||||
docker/env-backup/*
|
||||
|
||||
sdks/python-client/build
|
||||
sdks/python-client/dist
|
||||
@@ -205,6 +204,7 @@ sdks/python-client/dify_client.egg-info
|
||||
!.vscode/launch.json.template
|
||||
!.vscode/README.md
|
||||
api/.vscode
|
||||
web/.vscode
|
||||
# vscode Code History Extension
|
||||
.history
|
||||
|
||||
@@ -219,6 +219,15 @@ plugins.jsonl
|
||||
# mise
|
||||
mise.toml
|
||||
|
||||
# Next.js build output
|
||||
.next/
|
||||
|
||||
# PWA generated files
|
||||
web/public/sw.js
|
||||
web/public/sw.js.map
|
||||
web/public/workbox-*.js
|
||||
web/public/workbox-*.js.map
|
||||
web/public/fallback-*.js
|
||||
|
||||
# AI Assistant
|
||||
.roo/
|
||||
@@ -235,4 +244,3 @@ scripts/stress-test/reports/
|
||||
|
||||
# settings
|
||||
*.local.json
|
||||
*.local.md
|
||||
|
||||
34
.mcp.json
Normal file
34
.mcp.json
Normal file
@@ -0,0 +1,34 @@
|
||||
{
|
||||
"mcpServers": {
|
||||
"context7": {
|
||||
"type": "http",
|
||||
"url": "https://mcp.context7.com/mcp"
|
||||
},
|
||||
"sequential-thinking": {
|
||||
"type": "stdio",
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-sequential-thinking"],
|
||||
"env": {}
|
||||
},
|
||||
"github": {
|
||||
"type": "stdio",
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-github"],
|
||||
"env": {
|
||||
"GITHUB_PERSONAL_ACCESS_TOKEN": "${GITHUB_PERSONAL_ACCESS_TOKEN}"
|
||||
}
|
||||
},
|
||||
"fetch": {
|
||||
"type": "stdio",
|
||||
"command": "uvx",
|
||||
"args": ["mcp-server-fetch"],
|
||||
"env": {}
|
||||
},
|
||||
"playwright": {
|
||||
"type": "stdio",
|
||||
"command": "npx",
|
||||
"args": ["-y", "@playwright/mcp@latest"],
|
||||
"env": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
2
.vscode/launch.json.template
vendored
2
.vscode/launch.json.template
vendored
@@ -37,7 +37,7 @@
|
||||
"-c",
|
||||
"1",
|
||||
"-Q",
|
||||
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention",
|
||||
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor",
|
||||
"--loglevel",
|
||||
"INFO"
|
||||
],
|
||||
|
||||
5
.windsurf/rules/testing.md
Normal file
5
.windsurf/rules/testing.md
Normal file
@@ -0,0 +1,5 @@
|
||||
# Windsurf Testing Rules
|
||||
|
||||
- Use `web/testing/testing.md` as the single source of truth for frontend automated testing.
|
||||
- Honor every requirement in that document when generating or accepting tests.
|
||||
- When proposing or saving tests, re-read that document and follow every requirement.
|
||||
@@ -24,8 +24,8 @@ The codebase is split into:
|
||||
|
||||
```bash
|
||||
cd web
|
||||
pnpm lint
|
||||
pnpm lint:fix
|
||||
pnpm type-check:tsgo
|
||||
pnpm test
|
||||
```
|
||||
|
||||
@@ -39,7 +39,7 @@ pnpm test
|
||||
## Language Style
|
||||
|
||||
- **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`).
|
||||
- **TypeScript**: Use the strict config, rely on ESLint (`pnpm lint:fix` preferred) plus `pnpm type-check:tsgo`, and avoid `any` types.
|
||||
- **TypeScript**: Use the strict config, lean on ESLint + Prettier workflows, and avoid `any` types.
|
||||
|
||||
## General Practices
|
||||
|
||||
|
||||
13
README.md
13
README.md
@@ -139,19 +139,6 @@ Star Dify on GitHub and be instantly notified of new releases.
|
||||
|
||||
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
|
||||
|
||||
#### Customizing Suggested Questions
|
||||
|
||||
You can now customize the "Suggested Questions After Answer" feature to better fit your use case. For example, to generate longer, more technical questions:
|
||||
|
||||
```bash
|
||||
# In your .env file
|
||||
SUGGESTED_QUESTIONS_PROMPT='Please help me predict the five most likely technical follow-up questions a developer would ask. Focus on implementation details, best practices, and architecture considerations. Keep each question between 40-60 characters. Output must be JSON array: ["question1","question2","question3","question4","question5"]'
|
||||
SUGGESTED_QUESTIONS_MAX_TOKENS=512
|
||||
SUGGESTED_QUESTIONS_TEMPERATURE=0.3
|
||||
```
|
||||
|
||||
See the [Suggested Questions Configuration Guide](docs/suggested-questions-configuration.md) for detailed examples and usage instructions.
|
||||
|
||||
### Metrics Monitoring with Grafana
|
||||
|
||||
Import the dashboard to Grafana, using Dify's PostgreSQL database as data source, to monitor metrics in granularity of apps, tenants, messages, and more.
|
||||
|
||||
@@ -101,15 +101,6 @@ S3_ACCESS_KEY=your-access-key
|
||||
S3_SECRET_KEY=your-secret-key
|
||||
S3_REGION=your-region
|
||||
|
||||
# Workflow run and Conversation archive storage (S3-compatible)
|
||||
ARCHIVE_STORAGE_ENABLED=false
|
||||
ARCHIVE_STORAGE_ENDPOINT=
|
||||
ARCHIVE_STORAGE_ARCHIVE_BUCKET=
|
||||
ARCHIVE_STORAGE_EXPORT_BUCKET=
|
||||
ARCHIVE_STORAGE_ACCESS_KEY=
|
||||
ARCHIVE_STORAGE_SECRET_KEY=
|
||||
ARCHIVE_STORAGE_REGION=auto
|
||||
|
||||
# Azure Blob Storage configuration
|
||||
AZURE_BLOB_ACCOUNT_NAME=your-account-name
|
||||
AZURE_BLOB_ACCOUNT_KEY=your-account-key
|
||||
@@ -125,7 +116,6 @@ ALIYUN_OSS_AUTH_VERSION=v1
|
||||
ALIYUN_OSS_REGION=your-region
|
||||
# Don't start with '/'. OSS doesn't support leading slash in object names.
|
||||
ALIYUN_OSS_PATH=your-path
|
||||
ALIYUN_CLOUDBOX_ID=your-cloudbox-id
|
||||
|
||||
# Google Storage configuration
|
||||
GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name
|
||||
@@ -137,14 +127,12 @@ TENCENT_COS_SECRET_KEY=your-secret-key
|
||||
TENCENT_COS_SECRET_ID=your-secret-id
|
||||
TENCENT_COS_REGION=your-region
|
||||
TENCENT_COS_SCHEME=your-scheme
|
||||
TENCENT_COS_CUSTOM_DOMAIN=your-custom-domain
|
||||
|
||||
# Huawei OBS Storage Configuration
|
||||
HUAWEI_OBS_BUCKET_NAME=your-bucket-name
|
||||
HUAWEI_OBS_SECRET_KEY=your-secret-key
|
||||
HUAWEI_OBS_ACCESS_KEY=your-access-key
|
||||
HUAWEI_OBS_SERVER=your-server-url
|
||||
HUAWEI_OBS_PATH_STYLE=false
|
||||
|
||||
# Baidu OBS Storage Configuration
|
||||
BAIDU_OBS_BUCKET_NAME=your-bucket-name
|
||||
@@ -502,8 +490,6 @@ LOG_FILE_BACKUP_COUNT=5
|
||||
LOG_DATEFORMAT=%Y-%m-%d %H:%M:%S
|
||||
# Log Timezone
|
||||
LOG_TZ=UTC
|
||||
# Log output format: text or json
|
||||
LOG_OUTPUT_FORMAT=text
|
||||
# Log format
|
||||
LOG_FORMAT=%(asctime)s,%(msecs)d %(levelname)-2s [%(filename)s:%(lineno)d] %(req_id)s %(message)s
|
||||
|
||||
@@ -557,25 +543,6 @@ APP_MAX_EXECUTION_TIME=1200
|
||||
APP_DEFAULT_ACTIVE_REQUESTS=0
|
||||
APP_MAX_ACTIVE_REQUESTS=0
|
||||
|
||||
# Aliyun SLS Logstore Configuration
|
||||
# Aliyun Access Key ID
|
||||
ALIYUN_SLS_ACCESS_KEY_ID=
|
||||
# Aliyun Access Key Secret
|
||||
ALIYUN_SLS_ACCESS_KEY_SECRET=
|
||||
# Aliyun SLS Endpoint (e.g., cn-hangzhou.log.aliyuncs.com)
|
||||
ALIYUN_SLS_ENDPOINT=
|
||||
# Aliyun SLS Region (e.g., cn-hangzhou)
|
||||
ALIYUN_SLS_REGION=
|
||||
# Aliyun SLS Project Name
|
||||
ALIYUN_SLS_PROJECT_NAME=
|
||||
# Number of days to retain workflow run logs (default: 365 days, 3650 for permanent storage)
|
||||
ALIYUN_SLS_LOGSTORE_TTL=365
|
||||
# Enable dual-write to both SLS LogStore and SQL database (default: false)
|
||||
LOGSTORE_DUAL_WRITE_ENABLED=false
|
||||
# Enable dual-read fallback to SQL database when LogStore returns no results (default: true)
|
||||
# Useful for migration scenarios where historical data exists only in SQL database
|
||||
LOGSTORE_DUAL_READ_ENABLED=true
|
||||
|
||||
# Celery beat configuration
|
||||
CELERY_BEAT_SCHEDULER_TIME=1
|
||||
|
||||
@@ -666,45 +633,8 @@ SWAGGER_UI_PATH=/swagger-ui.html
|
||||
# Set to false to export dataset IDs as plain text for easier cross-environment import
|
||||
DSL_EXPORT_ENCRYPT_DATASET_ID=true
|
||||
|
||||
# Suggested Questions After Answer Configuration
|
||||
# These environment variables allow customization of the suggested questions feature
|
||||
#
|
||||
# Custom prompt for generating suggested questions (optional)
|
||||
# If not set, uses the default prompt that generates 3 questions under 20 characters each
|
||||
# Example: "Please help me predict the five most likely technical follow-up questions a developer would ask. Focus on implementation details, best practices, and architecture considerations. Keep each question between 40-60 characters. Output must be JSON array: [\"question1\",\"question2\",\"question3\",\"question4\",\"question5\"]"
|
||||
# SUGGESTED_QUESTIONS_PROMPT=
|
||||
|
||||
# Maximum number of tokens for suggested questions generation (default: 256)
|
||||
# Adjust this value for longer questions or more questions
|
||||
# SUGGESTED_QUESTIONS_MAX_TOKENS=256
|
||||
|
||||
# Temperature for suggested questions generation (default: 0.0)
|
||||
# Higher values (0.5-1.0) produce more creative questions, lower values (0.0-0.3) produce more focused questions
|
||||
# SUGGESTED_QUESTIONS_TEMPERATURE=0
|
||||
|
||||
# Tenant isolated task queue configuration
|
||||
TENANT_ISOLATED_TASK_CONCURRENCY=1
|
||||
|
||||
# Maximum number of segments for dataset segments API (0 for unlimited)
|
||||
DATASET_MAX_SEGMENTS_PER_REQUEST=0
|
||||
|
||||
# Multimodal knowledgebase limit
|
||||
SINGLE_CHUNK_ATTACHMENT_LIMIT=10
|
||||
ATTACHMENT_IMAGE_FILE_SIZE_LIMIT=2
|
||||
ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT=60
|
||||
IMAGE_FILE_BATCH_LIMIT=10
|
||||
|
||||
# Maximum allowed CSV file size for annotation import in megabytes
|
||||
ANNOTATION_IMPORT_FILE_SIZE_LIMIT=2
|
||||
#Maximum number of annotation records allowed in a single import
|
||||
ANNOTATION_IMPORT_MAX_RECORDS=10000
|
||||
# Minimum number of annotation records required in a single import
|
||||
ANNOTATION_IMPORT_MIN_RECORDS=1
|
||||
ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE=5
|
||||
ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20
|
||||
# Maximum number of concurrent annotation import tasks per tenant
|
||||
ANNOTATION_IMPORT_MAX_CONCURRENT=5
|
||||
# Sandbox expired records clean configuration
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
|
||||
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30
|
||||
|
||||
@@ -1,8 +1,4 @@
|
||||
exclude = [
|
||||
"migrations/*",
|
||||
".git",
|
||||
".git/**",
|
||||
]
|
||||
exclude = ["migrations/*"]
|
||||
line-length = 120
|
||||
|
||||
[format]
|
||||
@@ -40,20 +36,17 @@ select = [
|
||||
"UP", # pyupgrade rules
|
||||
"W191", # tab-indentation
|
||||
"W605", # invalid-escape-sequence
|
||||
"G001", # don't use str format to logging messages
|
||||
"G003", # don't use + in logging messages
|
||||
"G004", # don't use f-strings to format logging messages
|
||||
"UP042", # use StrEnum,
|
||||
"S110", # disallow the try-except-pass pattern.
|
||||
|
||||
# security related linting rules
|
||||
# RCE proctection (sort of)
|
||||
"S102", # exec-builtin, disallow use of `exec`
|
||||
"S307", # suspicious-eval-usage, disallow use of `eval` and `ast.literal_eval`
|
||||
"S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers.
|
||||
"S302", # suspicious-marshal-usage, disallow use of `marshal` module
|
||||
"S311", # suspicious-non-cryptographic-random-usage,
|
||||
|
||||
"S311", # suspicious-non-cryptographic-random-usage
|
||||
"G001", # don't use str format to logging messages
|
||||
"G003", # don't use + in logging messages
|
||||
"G004", # don't use f-strings to format logging messages
|
||||
"UP042", # use StrEnum
|
||||
]
|
||||
|
||||
ignore = [
|
||||
@@ -98,16 +91,18 @@ ignore = [
|
||||
"configs/*" = [
|
||||
"N802", # invalid-function-name
|
||||
]
|
||||
"core/model_runtime/callbacks/base_callback.py" = ["T201"]
|
||||
"core/workflow/callbacks/workflow_logging_callback.py" = ["T201"]
|
||||
"core/model_runtime/callbacks/base_callback.py" = [
|
||||
"T201",
|
||||
]
|
||||
"core/workflow/callbacks/workflow_logging_callback.py" = [
|
||||
"T201",
|
||||
]
|
||||
"libs/gmpy2_pkcs10aep_cipher.py" = [
|
||||
"N803", # invalid-argument-name
|
||||
]
|
||||
"tests/*" = [
|
||||
"F811", # redefined-while-unused
|
||||
"T201", # allow print in tests,
|
||||
"S110", # allow ignoring exceptions in tests code (currently)
|
||||
|
||||
"T201", # allow print in tests
|
||||
]
|
||||
|
||||
[lint.pyflakes]
|
||||
|
||||
@@ -84,7 +84,7 @@
|
||||
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
|
||||
|
||||
```bash
|
||||
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
|
||||
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor
|
||||
```
|
||||
|
||||
Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
from opentelemetry.trace import get_current_span
|
||||
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
|
||||
|
||||
from configs import dify_config
|
||||
from contexts.wrapper import RecyclableContextVar
|
||||
from core.logging.context import init_request_context
|
||||
from dify_app import DifyApp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -27,35 +23,11 @@ def create_flask_app_with_configs() -> DifyApp:
|
||||
# add before request hook
|
||||
@dify_app.before_request
|
||||
def before_request():
|
||||
# Initialize logging context for this request
|
||||
init_request_context()
|
||||
# add an unique identifier to each request
|
||||
RecyclableContextVar.increment_thread_recycles()
|
||||
|
||||
# add after request hook for injecting trace headers from OpenTelemetry span context
|
||||
# Only adds headers when OTEL is enabled and has valid context
|
||||
@dify_app.after_request
|
||||
def add_trace_headers(response):
|
||||
try:
|
||||
span = get_current_span()
|
||||
ctx = span.get_span_context() if span else None
|
||||
|
||||
if not ctx or not ctx.is_valid:
|
||||
return response
|
||||
|
||||
# Inject trace headers from OTEL context
|
||||
if ctx.trace_id != INVALID_TRACE_ID and "X-Trace-Id" not in response.headers:
|
||||
response.headers["X-Trace-Id"] = format(ctx.trace_id, "032x")
|
||||
if ctx.span_id != INVALID_SPAN_ID and "X-Span-Id" not in response.headers:
|
||||
response.headers["X-Span-Id"] = format(ctx.span_id, "016x")
|
||||
|
||||
except Exception:
|
||||
# Never break the response due to tracing header injection
|
||||
logger.warning("Failed to add trace headers to response", exc_info=True)
|
||||
return response
|
||||
|
||||
# Capture the decorator's return value to avoid pyright reportUnusedFunction
|
||||
_ = before_request
|
||||
_ = add_trace_headers
|
||||
|
||||
return dify_app
|
||||
|
||||
@@ -84,7 +56,6 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_import_modules,
|
||||
ext_logging,
|
||||
ext_login,
|
||||
ext_logstore,
|
||||
ext_mail,
|
||||
ext_migrate,
|
||||
ext_orjson,
|
||||
@@ -93,7 +64,6 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_redis,
|
||||
ext_request_logging,
|
||||
ext_sentry,
|
||||
ext_session_factory,
|
||||
ext_set_secretkey,
|
||||
ext_storage,
|
||||
ext_timezone,
|
||||
@@ -115,7 +85,6 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_migrate,
|
||||
ext_redis,
|
||||
ext_storage,
|
||||
ext_logstore, # Initialize logstore after storage, before celery
|
||||
ext_celery,
|
||||
ext_login,
|
||||
ext_mail,
|
||||
@@ -126,7 +95,6 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_commands,
|
||||
ext_otel,
|
||||
ext_request_logging,
|
||||
ext_session_factory,
|
||||
]
|
||||
for ext in extensions:
|
||||
short_name = ext.__name__.split(".")[-1]
|
||||
|
||||
@@ -1139,7 +1139,6 @@ def remove_orphaned_files_on_storage(force: bool):
|
||||
click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error fetching keys: {str(e)}", fg="red"))
|
||||
return
|
||||
|
||||
all_files_on_storage = []
|
||||
for storage_path in storage_paths:
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
from configs.extra.archive_config import ArchiveStorageConfig
|
||||
from configs.extra.notion_config import NotionConfig
|
||||
from configs.extra.sentry_config import SentryConfig
|
||||
|
||||
|
||||
class ExtraServiceConfig(
|
||||
# place the configs in alphabet order
|
||||
ArchiveStorageConfig,
|
||||
NotionConfig,
|
||||
SentryConfig,
|
||||
):
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class ArchiveStorageConfig(BaseSettings):
|
||||
"""
|
||||
Configuration settings for workflow run logs archiving storage.
|
||||
"""
|
||||
|
||||
ARCHIVE_STORAGE_ENABLED: bool = Field(
|
||||
description="Enable workflow run logs archiving to S3-compatible storage",
|
||||
default=False,
|
||||
)
|
||||
|
||||
ARCHIVE_STORAGE_ENDPOINT: str | None = Field(
|
||||
description="URL of the S3-compatible storage endpoint (e.g., 'https://storage.example.com')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
ARCHIVE_STORAGE_ARCHIVE_BUCKET: str | None = Field(
|
||||
description="Name of the bucket to store archived workflow logs",
|
||||
default=None,
|
||||
)
|
||||
|
||||
ARCHIVE_STORAGE_EXPORT_BUCKET: str | None = Field(
|
||||
description="Name of the bucket to store exported workflow runs",
|
||||
default=None,
|
||||
)
|
||||
|
||||
ARCHIVE_STORAGE_ACCESS_KEY: str | None = Field(
|
||||
description="Access key ID for authenticating with storage",
|
||||
default=None,
|
||||
)
|
||||
|
||||
ARCHIVE_STORAGE_SECRET_KEY: str | None = Field(
|
||||
description="Secret access key for authenticating with storage",
|
||||
default=None,
|
||||
)
|
||||
|
||||
ARCHIVE_STORAGE_REGION: str = Field(
|
||||
description="Region for storage (use 'auto' if the provider supports it)",
|
||||
default="auto",
|
||||
)
|
||||
@@ -218,7 +218,7 @@ class PluginConfig(BaseSettings):
|
||||
|
||||
PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field(
|
||||
description="Timeout in seconds for requests to the plugin daemon (set to None to disable)",
|
||||
default=600.0,
|
||||
default=300.0,
|
||||
)
|
||||
|
||||
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
|
||||
@@ -360,57 +360,6 @@ class FileUploadConfig(BaseSettings):
|
||||
default=10,
|
||||
)
|
||||
|
||||
IMAGE_FILE_BATCH_LIMIT: PositiveInt = Field(
|
||||
description="Maximum number of files allowed in a image batch upload operation",
|
||||
default=10,
|
||||
)
|
||||
|
||||
SINGLE_CHUNK_ATTACHMENT_LIMIT: PositiveInt = Field(
|
||||
description="Maximum number of files allowed in a single chunk attachment",
|
||||
default=10,
|
||||
)
|
||||
|
||||
ATTACHMENT_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field(
|
||||
description="Maximum allowed image file size for attachments in megabytes",
|
||||
default=2,
|
||||
)
|
||||
|
||||
ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT: NonNegativeInt = Field(
|
||||
description="Timeout for downloading image attachments in seconds",
|
||||
default=60,
|
||||
)
|
||||
|
||||
# Annotation Import Security Configurations
|
||||
ANNOTATION_IMPORT_FILE_SIZE_LIMIT: NonNegativeInt = Field(
|
||||
description="Maximum allowed CSV file size for annotation import in megabytes",
|
||||
default=2,
|
||||
)
|
||||
|
||||
ANNOTATION_IMPORT_MAX_RECORDS: PositiveInt = Field(
|
||||
description="Maximum number of annotation records allowed in a single import",
|
||||
default=10000,
|
||||
)
|
||||
|
||||
ANNOTATION_IMPORT_MIN_RECORDS: PositiveInt = Field(
|
||||
description="Minimum number of annotation records required in a single import",
|
||||
default=1,
|
||||
)
|
||||
|
||||
ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE: PositiveInt = Field(
|
||||
description="Maximum number of annotation import requests per minute per tenant",
|
||||
default=5,
|
||||
)
|
||||
|
||||
ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR: PositiveInt = Field(
|
||||
description="Maximum number of annotation import requests per hour per tenant",
|
||||
default=20,
|
||||
)
|
||||
|
||||
ANNOTATION_IMPORT_MAX_CONCURRENT: PositiveInt = Field(
|
||||
description="Maximum number of concurrent annotation import tasks per tenant",
|
||||
default=2,
|
||||
)
|
||||
|
||||
inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field(
|
||||
description=(
|
||||
"Comma-separated list of file extensions that are blocked from upload. "
|
||||
@@ -587,11 +536,6 @@ class LoggingConfig(BaseSettings):
|
||||
default="INFO",
|
||||
)
|
||||
|
||||
LOG_OUTPUT_FORMAT: Literal["text", "json"] = Field(
|
||||
description="Log output format: 'text' for human-readable, 'json' for structured JSON logs.",
|
||||
default="text",
|
||||
)
|
||||
|
||||
LOG_FILE: str | None = Field(
|
||||
description="File path for log output.",
|
||||
default=None,
|
||||
@@ -609,10 +553,7 @@ class LoggingConfig(BaseSettings):
|
||||
|
||||
LOG_FORMAT: str = Field(
|
||||
description="Format string for log messages",
|
||||
default=(
|
||||
"%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] "
|
||||
"[%(filename)s:%(lineno)d] %(trace_id)s - %(message)s"
|
||||
),
|
||||
default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s",
|
||||
)
|
||||
|
||||
LOG_DATEFORMAT: str | None = Field(
|
||||
@@ -1275,21 +1216,6 @@ class TenantIsolatedTaskQueueConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class SandboxExpiredRecordsCleanConfig(BaseSettings):
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: NonNegativeInt = Field(
|
||||
description="Graceful period in days for sandbox records clean after subscription expiration",
|
||||
default=21,
|
||||
)
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: PositiveInt = Field(
|
||||
description="Maximum number of records to process in each batch",
|
||||
default=1000,
|
||||
)
|
||||
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: PositiveInt = Field(
|
||||
description="Retention days for sandbox expired workflow_run records and message records",
|
||||
default=30,
|
||||
)
|
||||
|
||||
|
||||
class FeatureConfig(
|
||||
# place the configs in alphabet order
|
||||
AppExecutionConfig,
|
||||
@@ -1315,7 +1241,6 @@ class FeatureConfig(
|
||||
PositionConfig,
|
||||
RagEtlConfig,
|
||||
RepositoryConfig,
|
||||
SandboxExpiredRecordsCleanConfig,
|
||||
SecurityConfig,
|
||||
TenantIsolatedTaskQueueConfig,
|
||||
ToolConfig,
|
||||
|
||||
@@ -26,7 +26,6 @@ from .vdb.clickzetta_config import ClickzettaConfig
|
||||
from .vdb.couchbase_config import CouchbaseConfig
|
||||
from .vdb.elasticsearch_config import ElasticsearchConfig
|
||||
from .vdb.huawei_cloud_config import HuaweiCloudConfig
|
||||
from .vdb.iris_config import IrisVectorConfig
|
||||
from .vdb.lindorm_config import LindormConfig
|
||||
from .vdb.matrixone_config import MatrixoneConfig
|
||||
from .vdb.milvus_config import MilvusConfig
|
||||
@@ -107,7 +106,7 @@ class KeywordStoreConfig(BaseSettings):
|
||||
|
||||
class DatabaseConfig(BaseSettings):
|
||||
# Database type selector
|
||||
DB_TYPE: Literal["postgresql", "mysql", "oceanbase", "seekdb"] = Field(
|
||||
DB_TYPE: Literal["postgresql", "mysql", "oceanbase"] = Field(
|
||||
description="Database type to use. OceanBase is MySQL-compatible.",
|
||||
default="postgresql",
|
||||
)
|
||||
@@ -337,7 +336,6 @@ class MiddlewareConfig(
|
||||
ChromaConfig,
|
||||
ClickzettaConfig,
|
||||
HuaweiCloudConfig,
|
||||
IrisVectorConfig,
|
||||
MilvusConfig,
|
||||
AlibabaCloudMySQLConfig,
|
||||
MyScaleConfig,
|
||||
|
||||
@@ -41,8 +41,3 @@ class AliyunOSSStorageConfig(BaseSettings):
|
||||
description="Base path within the bucket to store objects (e.g., 'my-app-data/')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
ALIYUN_CLOUDBOX_ID: str | None = Field(
|
||||
description="Cloudbox id for aliyun cloudbox service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
@@ -26,8 +26,3 @@ class HuaweiCloudOBSStorageConfig(BaseSettings):
|
||||
description="Endpoint URL for Huawei Cloud OBS (e.g., 'https://obs.cn-north-4.myhuaweicloud.com')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HUAWEI_OBS_PATH_STYLE: bool = Field(
|
||||
description="Flag to indicate whether to use path-style URLs for OBS requests",
|
||||
default=False,
|
||||
)
|
||||
|
||||
@@ -31,8 +31,3 @@ class TencentCloudCOSStorageConfig(BaseSettings):
|
||||
description="Protocol scheme for COS requests: 'https' (recommended) or 'http'",
|
||||
default=None,
|
||||
)
|
||||
|
||||
TENCENT_COS_CUSTOM_DOMAIN: str | None = Field(
|
||||
description="Tencent Cloud COS custom domain setting",
|
||||
default=None,
|
||||
)
|
||||
|
||||
@@ -1,91 +0,0 @@
|
||||
"""Configuration for InterSystems IRIS vector database."""
|
||||
|
||||
from pydantic import Field, PositiveInt, model_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class IrisVectorConfig(BaseSettings):
|
||||
"""Configuration settings for IRIS vector database connection and pooling."""
|
||||
|
||||
IRIS_HOST: str | None = Field(
|
||||
description="Hostname or IP address of the IRIS server.",
|
||||
default="localhost",
|
||||
)
|
||||
|
||||
IRIS_SUPER_SERVER_PORT: PositiveInt | None = Field(
|
||||
description="Port number for IRIS connection.",
|
||||
default=1972,
|
||||
)
|
||||
|
||||
IRIS_USER: str | None = Field(
|
||||
description="Username for IRIS authentication.",
|
||||
default="_SYSTEM",
|
||||
)
|
||||
|
||||
IRIS_PASSWORD: str | None = Field(
|
||||
description="Password for IRIS authentication.",
|
||||
default="Dify@1234",
|
||||
)
|
||||
|
||||
IRIS_SCHEMA: str | None = Field(
|
||||
description="Schema name for IRIS tables.",
|
||||
default="dify",
|
||||
)
|
||||
|
||||
IRIS_DATABASE: str | None = Field(
|
||||
description="Database namespace for IRIS connection.",
|
||||
default="USER",
|
||||
)
|
||||
|
||||
IRIS_CONNECTION_URL: str | None = Field(
|
||||
description="Full connection URL for IRIS (overrides individual fields if provided).",
|
||||
default=None,
|
||||
)
|
||||
|
||||
IRIS_MIN_CONNECTION: PositiveInt = Field(
|
||||
description="Minimum number of connections in the pool.",
|
||||
default=1,
|
||||
)
|
||||
|
||||
IRIS_MAX_CONNECTION: PositiveInt = Field(
|
||||
description="Maximum number of connections in the pool.",
|
||||
default=3,
|
||||
)
|
||||
|
||||
IRIS_TEXT_INDEX: bool = Field(
|
||||
description="Enable full-text search index using %iFind.Index.Basic.",
|
||||
default=True,
|
||||
)
|
||||
|
||||
IRIS_TEXT_INDEX_LANGUAGE: str = Field(
|
||||
description="Language for full-text search index (e.g., 'en', 'ja', 'zh', 'de').",
|
||||
default="en",
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
"""Validate IRIS configuration values.
|
||||
|
||||
Args:
|
||||
values: Configuration dictionary
|
||||
|
||||
Returns:
|
||||
Validated configuration dictionary
|
||||
|
||||
Raises:
|
||||
ValueError: If required fields are missing or pool settings are invalid
|
||||
"""
|
||||
# Only validate required fields if IRIS is being used as the vector store
|
||||
# This allows the config to be loaded even when IRIS is not in use
|
||||
|
||||
# vector_store = os.environ.get("VECTOR_STORE", "")
|
||||
# We rely on Pydantic defaults for required fields if they are missing from env.
|
||||
# Strict existence check is removed to allow defaults to work.
|
||||
|
||||
min_conn = values.get("IRIS_MIN_CONNECTION", 1)
|
||||
max_conn = values.get("IRIS_MAX_CONNECTION", 3)
|
||||
if min_conn > max_conn:
|
||||
raise ValueError("IRIS_MIN_CONNECTION must be less than or equal to IRIS_MAX_CONNECTION")
|
||||
|
||||
return values
|
||||
@@ -20,7 +20,6 @@ language_timezone_mapping = {
|
||||
"sl-SI": "Europe/Ljubljana",
|
||||
"th-TH": "Asia/Bangkok",
|
||||
"id-ID": "Asia/Jakarta",
|
||||
"ar-TN": "Africa/Tunis",
|
||||
}
|
||||
|
||||
languages = list(language_timezone_mapping.keys())
|
||||
|
||||
@@ -1,59 +1,62 @@
|
||||
from __future__ import annotations
|
||||
from flask_restx import Api, Namespace, fields
|
||||
|
||||
from typing import Any, TypeAlias
|
||||
from libs.helper import AppIconUrlField
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, computed_field
|
||||
|
||||
from core.file import helpers as file_helpers
|
||||
from models.model import IconType
|
||||
|
||||
JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
|
||||
JSONObject: TypeAlias = dict[str, Any]
|
||||
parameters__system_parameters = {
|
||||
"image_file_size_limit": fields.Integer,
|
||||
"video_file_size_limit": fields.Integer,
|
||||
"audio_file_size_limit": fields.Integer,
|
||||
"file_size_limit": fields.Integer,
|
||||
"workflow_file_upload_limit": fields.Integer,
|
||||
}
|
||||
|
||||
|
||||
class SystemParameters(BaseModel):
|
||||
image_file_size_limit: int
|
||||
video_file_size_limit: int
|
||||
audio_file_size_limit: int
|
||||
file_size_limit: int
|
||||
workflow_file_upload_limit: int
|
||||
def build_system_parameters_model(api_or_ns: Api | Namespace):
|
||||
"""Build the system parameters model for the API or Namespace."""
|
||||
return api_or_ns.model("SystemParameters", parameters__system_parameters)
|
||||
|
||||
|
||||
class Parameters(BaseModel):
|
||||
opening_statement: str | None = None
|
||||
suggested_questions: list[str]
|
||||
suggested_questions_after_answer: JSONObject
|
||||
speech_to_text: JSONObject
|
||||
text_to_speech: JSONObject
|
||||
retriever_resource: JSONObject
|
||||
annotation_reply: JSONObject
|
||||
more_like_this: JSONObject
|
||||
user_input_form: list[JSONObject]
|
||||
sensitive_word_avoidance: JSONObject
|
||||
file_upload: JSONObject
|
||||
system_parameters: SystemParameters
|
||||
parameters_fields = {
|
||||
"opening_statement": fields.String,
|
||||
"suggested_questions": fields.Raw,
|
||||
"suggested_questions_after_answer": fields.Raw,
|
||||
"speech_to_text": fields.Raw,
|
||||
"text_to_speech": fields.Raw,
|
||||
"retriever_resource": fields.Raw,
|
||||
"annotation_reply": fields.Raw,
|
||||
"more_like_this": fields.Raw,
|
||||
"user_input_form": fields.Raw,
|
||||
"sensitive_word_avoidance": fields.Raw,
|
||||
"file_upload": fields.Raw,
|
||||
"system_parameters": fields.Nested(parameters__system_parameters),
|
||||
}
|
||||
|
||||
|
||||
class Site(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
def build_parameters_model(api_or_ns: Api | Namespace):
|
||||
"""Build the parameters model for the API or Namespace."""
|
||||
copied_fields = parameters_fields.copy()
|
||||
copied_fields["system_parameters"] = fields.Nested(build_system_parameters_model(api_or_ns))
|
||||
return api_or_ns.model("Parameters", copied_fields)
|
||||
|
||||
title: str
|
||||
chat_color_theme: str | None = None
|
||||
chat_color_theme_inverted: bool
|
||||
icon_type: str | None = None
|
||||
icon: str | None = None
|
||||
icon_background: str | None = None
|
||||
description: str | None = None
|
||||
copyright: str | None = None
|
||||
privacy_policy: str | None = None
|
||||
custom_disclaimer: str | None = None
|
||||
default_language: str
|
||||
show_workflow_steps: bool
|
||||
use_icon_as_answer_icon: bool
|
||||
|
||||
@computed_field(return_type=str | None) # type: ignore
|
||||
@property
|
||||
def icon_url(self) -> str | None:
|
||||
if self.icon and self.icon_type == IconType.IMAGE:
|
||||
return file_helpers.get_signed_file_url(self.icon)
|
||||
return None
|
||||
site_fields = {
|
||||
"title": fields.String,
|
||||
"chat_color_theme": fields.String,
|
||||
"chat_color_theme_inverted": fields.Boolean,
|
||||
"icon_type": fields.String,
|
||||
"icon": fields.String,
|
||||
"icon_background": fields.String,
|
||||
"icon_url": AppIconUrlField,
|
||||
"description": fields.String,
|
||||
"copyright": fields.String,
|
||||
"privacy_policy": fields.String,
|
||||
"custom_disclaimer": fields.String,
|
||||
"default_language": fields.String,
|
||||
"show_workflow_steps": fields.Boolean,
|
||||
"use_icon_as_answer_icon": fields.Boolean,
|
||||
}
|
||||
|
||||
|
||||
def build_site_model(api_or_ns: Api | Namespace):
|
||||
"""Build the site model for the API or Namespace."""
|
||||
return api_or_ns.model("Site", site_fields)
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
import os
|
||||
from email.message import Message
|
||||
from urllib.parse import quote
|
||||
|
||||
from flask import Response
|
||||
|
||||
HTML_MIME_TYPES = frozenset({"text/html", "application/xhtml+xml"})
|
||||
HTML_EXTENSIONS = frozenset({"html", "htm"})
|
||||
|
||||
|
||||
def _normalize_mime_type(mime_type: str | None) -> str:
|
||||
if not mime_type:
|
||||
return ""
|
||||
message = Message()
|
||||
message["Content-Type"] = mime_type
|
||||
return message.get_content_type().strip().lower()
|
||||
|
||||
|
||||
def _is_html_extension(extension: str | None) -> bool:
|
||||
if not extension:
|
||||
return False
|
||||
return extension.lstrip(".").lower() in HTML_EXTENSIONS
|
||||
|
||||
|
||||
def is_html_content(mime_type: str | None, filename: str | None, extension: str | None = None) -> bool:
|
||||
normalized_mime_type = _normalize_mime_type(mime_type)
|
||||
if normalized_mime_type in HTML_MIME_TYPES:
|
||||
return True
|
||||
|
||||
if _is_html_extension(extension):
|
||||
return True
|
||||
|
||||
if filename:
|
||||
return _is_html_extension(os.path.splitext(filename)[1])
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def enforce_download_for_html(
|
||||
response: Response,
|
||||
*,
|
||||
mime_type: str | None,
|
||||
filename: str | None,
|
||||
extension: str | None = None,
|
||||
) -> bool:
|
||||
if not is_html_content(mime_type, filename, extension):
|
||||
return False
|
||||
|
||||
if filename:
|
||||
encoded_filename = quote(filename)
|
||||
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||
else:
|
||||
response.headers["Content-Disposition"] = "attachment"
|
||||
|
||||
response.headers["Content-Type"] = "application/octet-stream"
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
return True
|
||||
@@ -1,26 +0,0 @@
|
||||
"""Helpers for registering Pydantic models with Flask-RESTX namespaces."""
|
||||
|
||||
from flask_restx import Namespace
|
||||
from pydantic import BaseModel
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
def register_schema_model(namespace: Namespace, model: type[BaseModel]) -> None:
|
||||
"""Register a single BaseModel with a namespace for Swagger documentation."""
|
||||
|
||||
namespace.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
def register_schema_models(namespace: Namespace, *models: type[BaseModel]) -> None:
|
||||
"""Register multiple BaseModels with a namespace."""
|
||||
|
||||
for model in models:
|
||||
register_schema_model(namespace, model)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_REF_TEMPLATE_SWAGGER_2_0",
|
||||
"register_schema_model",
|
||||
"register_schema_models",
|
||||
]
|
||||
@@ -3,47 +3,21 @@ from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import only_edition_cloud
|
||||
from core.db.session_factory import session_factory
|
||||
from extensions.ext_database import db
|
||||
from libs.token import extract_access_token
|
||||
from models.model import App, InstalledApp, RecommendedApp
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class InsertExploreAppPayload(BaseModel):
|
||||
app_id: str = Field(...)
|
||||
desc: str | None = None
|
||||
copyright: str | None = None
|
||||
privacy_policy: str | None = None
|
||||
custom_disclaimer: str | None = None
|
||||
language: str = Field(...)
|
||||
category: str = Field(...)
|
||||
position: int = Field(...)
|
||||
|
||||
@field_validator("language")
|
||||
@classmethod
|
||||
def validate_language(cls, value: str) -> str:
|
||||
return supported_language(value)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
InsertExploreAppPayload.__name__,
|
||||
InsertExploreAppPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
def admin_required(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
@@ -66,34 +40,59 @@ def admin_required(view: Callable[P, R]):
|
||||
class InsertExploreAppListApi(Resource):
|
||||
@console_ns.doc("insert_explore_app")
|
||||
@console_ns.doc(description="Insert or update an app in the explore list")
|
||||
@console_ns.expect(console_ns.models[InsertExploreAppPayload.__name__])
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"InsertExploreAppRequest",
|
||||
{
|
||||
"app_id": fields.String(required=True, description="Application ID"),
|
||||
"desc": fields.String(description="App description"),
|
||||
"copyright": fields.String(description="Copyright information"),
|
||||
"privacy_policy": fields.String(description="Privacy policy"),
|
||||
"custom_disclaimer": fields.String(description="Custom disclaimer"),
|
||||
"language": fields.String(required=True, description="Language code"),
|
||||
"category": fields.String(required=True, description="App category"),
|
||||
"position": fields.Integer(required=True, description="Display position"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "App updated successfully")
|
||||
@console_ns.response(201, "App inserted successfully")
|
||||
@console_ns.response(404, "App not found")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
payload = InsertExploreAppPayload.model_validate(console_ns.payload)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("app_id", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("desc", type=str, location="json")
|
||||
.add_argument("copyright", type=str, location="json")
|
||||
.add_argument("privacy_policy", type=str, location="json")
|
||||
.add_argument("custom_disclaimer", type=str, location="json")
|
||||
.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
|
||||
.add_argument("category", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("position", type=int, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
app = db.session.execute(select(App).where(App.id == payload.app_id)).scalar_one_or_none()
|
||||
app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none()
|
||||
if not app:
|
||||
raise NotFound(f"App '{payload.app_id}' is not found")
|
||||
raise NotFound(f"App '{args['app_id']}' is not found")
|
||||
|
||||
site = app.site
|
||||
if not site:
|
||||
desc = payload.desc or ""
|
||||
copy_right = payload.copyright or ""
|
||||
privacy_policy = payload.privacy_policy or ""
|
||||
custom_disclaimer = payload.custom_disclaimer or ""
|
||||
desc = args["desc"] or ""
|
||||
copy_right = args["copyright"] or ""
|
||||
privacy_policy = args["privacy_policy"] or ""
|
||||
custom_disclaimer = args["custom_disclaimer"] or ""
|
||||
else:
|
||||
desc = site.description or payload.desc or ""
|
||||
copy_right = site.copyright or payload.copyright or ""
|
||||
privacy_policy = site.privacy_policy or payload.privacy_policy or ""
|
||||
custom_disclaimer = site.custom_disclaimer or payload.custom_disclaimer or ""
|
||||
desc = site.description or args["desc"] or ""
|
||||
copy_right = site.copyright or args["copyright"] or ""
|
||||
privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
|
||||
custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
with Session(db.engine) as session:
|
||||
recommended_app = session.execute(
|
||||
select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id)
|
||||
select(RecommendedApp).where(RecommendedApp.app_id == args["app_id"])
|
||||
).scalar_one_or_none()
|
||||
|
||||
if not recommended_app:
|
||||
@@ -103,9 +102,9 @@ class InsertExploreAppListApi(Resource):
|
||||
copyright=copy_right,
|
||||
privacy_policy=privacy_policy,
|
||||
custom_disclaimer=custom_disclaimer,
|
||||
language=payload.language,
|
||||
category=payload.category,
|
||||
position=payload.position,
|
||||
language=args["language"],
|
||||
category=args["category"],
|
||||
position=args["position"],
|
||||
)
|
||||
|
||||
db.session.add(recommended_app)
|
||||
@@ -119,9 +118,9 @@ class InsertExploreAppListApi(Resource):
|
||||
recommended_app.copyright = copy_right
|
||||
recommended_app.privacy_policy = privacy_policy
|
||||
recommended_app.custom_disclaimer = custom_disclaimer
|
||||
recommended_app.language = payload.language
|
||||
recommended_app.category = payload.category
|
||||
recommended_app.position = payload.position
|
||||
recommended_app.language = args["language"]
|
||||
recommended_app.category = args["category"]
|
||||
recommended_app.position = args["position"]
|
||||
|
||||
app.is_public = True
|
||||
|
||||
@@ -139,7 +138,7 @@ class InsertExploreAppApi(Resource):
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def delete(self, app_id):
|
||||
with session_factory.create_session() as session:
|
||||
with Session(db.engine) as session:
|
||||
recommended_app = session.execute(
|
||||
select(RecommendedApp).where(RecommendedApp.app_id == str(app_id))
|
||||
).scalar_one_or_none()
|
||||
@@ -147,13 +146,13 @@ class InsertExploreAppApi(Resource):
|
||||
if not recommended_app:
|
||||
return {"result": "success"}, 204
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
with Session(db.engine) as session:
|
||||
app = session.execute(select(App).where(App.id == recommended_app.app_id)).scalar_one_or_none()
|
||||
|
||||
if app:
|
||||
app.is_public = False
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
with Session(db.engine) as session:
|
||||
installed_apps = (
|
||||
session.execute(
|
||||
select(InstalledApp).where(
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
@@ -10,21 +8,10 @@ from libs.login import login_required
|
||||
from models.model import AppMode
|
||||
from services.agent_service import AgentService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class AgentLogQuery(BaseModel):
|
||||
message_id: str = Field(..., description="Message UUID")
|
||||
conversation_id: str = Field(..., description="Conversation UUID")
|
||||
|
||||
@field_validator("message_id", "conversation_id")
|
||||
@classmethod
|
||||
def validate_uuid(cls, value: str) -> str:
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
AgentLogQuery.__name__, AgentLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", type=uuid_value, required=True, location="args", help="Message UUID")
|
||||
.add_argument("conversation_id", type=uuid_value, required=True, location="args", help="Conversation UUID")
|
||||
)
|
||||
|
||||
|
||||
@@ -33,7 +20,7 @@ class AgentLogApi(Resource):
|
||||
@console_ns.doc("get_agent_logs")
|
||||
@console_ns.doc(description="Get agent execution logs for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[AgentLogQuery.__name__])
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.response(
|
||||
200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries"))
|
||||
)
|
||||
@@ -44,6 +31,6 @@ class AgentLogApi(Resource):
|
||||
@get_app_model(mode=[AppMode.AGENT_CHAT])
|
||||
def get(self, app_model):
|
||||
"""Get agent logs"""
|
||||
args = AgentLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = parser.parse_args()
|
||||
|
||||
return AgentService.get_agent_logs(app_model, args.conversation_id, args.message_id)
|
||||
return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"])
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
from typing import Any, Literal
|
||||
from typing import Literal
|
||||
|
||||
from flask import abort, make_response, request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||
|
||||
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
annotation_import_concurrency_limit,
|
||||
annotation_import_rate_limit,
|
||||
cloud_edition_billing_resource_check,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
@@ -24,79 +21,22 @@ from libs.helper import uuid_value
|
||||
from libs.login import login_required
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class AnnotationReplyPayload(BaseModel):
|
||||
score_threshold: float = Field(..., description="Score threshold for annotation matching")
|
||||
embedding_provider_name: str = Field(..., description="Embedding provider name")
|
||||
embedding_model_name: str = Field(..., description="Embedding model name")
|
||||
|
||||
|
||||
class AnnotationSettingUpdatePayload(BaseModel):
|
||||
score_threshold: float = Field(..., description="Score threshold")
|
||||
|
||||
|
||||
class AnnotationListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, description="Page number")
|
||||
limit: int = Field(default=20, ge=1, description="Page size")
|
||||
keyword: str = Field(default="", description="Search keyword")
|
||||
|
||||
|
||||
class CreateAnnotationPayload(BaseModel):
|
||||
message_id: str | None = Field(default=None, description="Message ID")
|
||||
question: str | None = Field(default=None, description="Question text")
|
||||
answer: str | None = Field(default=None, description="Answer text")
|
||||
content: str | None = Field(default=None, description="Content text")
|
||||
annotation_reply: dict[str, Any] | None = Field(default=None, description="Annotation reply data")
|
||||
|
||||
@field_validator("message_id")
|
||||
@classmethod
|
||||
def validate_message_id(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class UpdateAnnotationPayload(BaseModel):
|
||||
question: str | None = None
|
||||
answer: str | None = None
|
||||
content: str | None = None
|
||||
annotation_reply: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class AnnotationReplyStatusQuery(BaseModel):
|
||||
action: Literal["enable", "disable"]
|
||||
|
||||
|
||||
class AnnotationFilePayload(BaseModel):
|
||||
message_id: str = Field(..., description="Message ID")
|
||||
|
||||
@field_validator("message_id")
|
||||
@classmethod
|
||||
def validate_message_id(cls, value: str) -> str:
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
def reg(model: type[BaseModel]) -> None:
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(AnnotationReplyPayload)
|
||||
reg(AnnotationSettingUpdatePayload)
|
||||
reg(AnnotationListQuery)
|
||||
reg(CreateAnnotationPayload)
|
||||
reg(UpdateAnnotationPayload)
|
||||
reg(AnnotationReplyStatusQuery)
|
||||
reg(AnnotationFilePayload)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>")
|
||||
class AnnotationReplyActionApi(Resource):
|
||||
@console_ns.doc("annotation_reply_action")
|
||||
@console_ns.doc(description="Enable or disable annotation reply for an app")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"})
|
||||
@console_ns.expect(console_ns.models[AnnotationReplyPayload.__name__])
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AnnotationReplyActionRequest",
|
||||
{
|
||||
"score_threshold": fields.Float(required=True, description="Score threshold for annotation matching"),
|
||||
"embedding_provider_name": fields.String(required=True, description="Embedding provider name"),
|
||||
"embedding_model_name": fields.String(required=True, description="Embedding model name"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "Action completed successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@@ -106,9 +46,15 @@ class AnnotationReplyActionApi(Resource):
|
||||
@edit_permission_required
|
||||
def post(self, app_id, action: Literal["enable", "disable"]):
|
||||
app_id = str(app_id)
|
||||
args = AnnotationReplyPayload.model_validate(console_ns.payload)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("score_threshold", required=True, type=float, location="json")
|
||||
.add_argument("embedding_provider_name", required=True, type=str, location="json")
|
||||
.add_argument("embedding_model_name", required=True, type=str, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
if action == "enable":
|
||||
result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
|
||||
result = AppAnnotationService.enable_app_annotation(args, app_id)
|
||||
elif action == "disable":
|
||||
result = AppAnnotationService.disable_app_annotation(app_id)
|
||||
return result, 200
|
||||
@@ -136,7 +82,16 @@ class AppAnnotationSettingUpdateApi(Resource):
|
||||
@console_ns.doc("update_annotation_setting")
|
||||
@console_ns.doc(description="Update annotation settings for an app")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"})
|
||||
@console_ns.expect(console_ns.models[AnnotationSettingUpdatePayload.__name__])
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AnnotationSettingUpdateRequest",
|
||||
{
|
||||
"score_threshold": fields.Float(required=True, description="Score threshold"),
|
||||
"embedding_provider_name": fields.String(required=True, description="Embedding provider"),
|
||||
"embedding_model_name": fields.String(required=True, description="Embedding model"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "Settings updated successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@@ -147,9 +102,10 @@ class AppAnnotationSettingUpdateApi(Resource):
|
||||
app_id = str(app_id)
|
||||
annotation_setting_id = str(annotation_setting_id)
|
||||
|
||||
args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload)
|
||||
parser = reqparse.RequestParser().add_argument("score_threshold", required=True, type=float, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args.model_dump())
|
||||
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
|
||||
return result, 200
|
||||
|
||||
|
||||
@@ -186,7 +142,12 @@ class AnnotationApi(Resource):
|
||||
@console_ns.doc("list_annotations")
|
||||
@console_ns.doc(description="Get annotations for an app with pagination")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[AnnotationListQuery.__name__])
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("page", type=int, location="args", default=1, help="Page number")
|
||||
.add_argument("limit", type=int, location="args", default=20, help="Page size")
|
||||
.add_argument("keyword", type=str, location="args", default="", help="Search keyword")
|
||||
)
|
||||
@console_ns.response(200, "Annotations retrieved successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@@ -194,10 +155,9 @@ class AnnotationApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_id):
|
||||
args = AnnotationListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
page = args.page
|
||||
limit = args.limit
|
||||
keyword = args.keyword
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
keyword = request.args.get("keyword", default="", type=str)
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
|
||||
@@ -213,7 +173,18 @@ class AnnotationApi(Resource):
|
||||
@console_ns.doc("create_annotation")
|
||||
@console_ns.doc(description="Create a new annotation for an app")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[CreateAnnotationPayload.__name__])
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"CreateAnnotationRequest",
|
||||
{
|
||||
"message_id": fields.String(description="Message ID (optional)"),
|
||||
"question": fields.String(description="Question text (required when message_id not provided)"),
|
||||
"answer": fields.String(description="Answer text (use 'answer' or 'content')"),
|
||||
"content": fields.String(description="Content text (use 'answer' or 'content')"),
|
||||
"annotation_reply": fields.Raw(description="Annotation reply data"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns))
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@@ -224,9 +195,16 @@ class AnnotationApi(Resource):
|
||||
@edit_permission_required
|
||||
def post(self, app_id):
|
||||
app_id = str(app_id)
|
||||
args = CreateAnnotationPayload.model_validate(console_ns.payload)
|
||||
data = args.model_dump(exclude_none=True)
|
||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", required=False, type=uuid_value, location="json")
|
||||
.add_argument("question", required=False, type=str, location="json")
|
||||
.add_argument("answer", required=False, type=str, location="json")
|
||||
.add_argument("content", required=False, type=str, location="json")
|
||||
.add_argument("annotation_reply", required=False, type=dict, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
|
||||
return annotation
|
||||
|
||||
@setup_required
|
||||
@@ -259,7 +237,7 @@ class AnnotationApi(Resource):
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/export")
|
||||
class AnnotationExportApi(Resource):
|
||||
@console_ns.doc("export_annotations")
|
||||
@console_ns.doc(description="Export all annotations for an app with CSV injection protection")
|
||||
@console_ns.doc(description="Export all annotations for an app")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
@@ -274,14 +252,15 @@ class AnnotationExportApi(Resource):
|
||||
def get(self, app_id):
|
||||
app_id = str(app_id)
|
||||
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
|
||||
response_data = {"data": marshal(annotation_list, annotation_fields)}
|
||||
response = {"data": marshal(annotation_list, annotation_fields)}
|
||||
return response, 200
|
||||
|
||||
# Create response with secure headers for CSV export
|
||||
response = make_response(response_data, 200)
|
||||
response.headers["Content-Type"] = "application/json; charset=utf-8"
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
|
||||
return response
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("question", required=True, type=str, location="json")
|
||||
.add_argument("answer", required=True, type=str, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
|
||||
@@ -292,7 +271,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
@console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns))
|
||||
@console_ns.response(204, "Annotation deleted successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.expect(console_ns.models[UpdateAnnotationPayload.__name__])
|
||||
@console_ns.expect(parser)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -302,10 +281,8 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
def post(self, app_id, annotation_id):
|
||||
app_id = str(app_id)
|
||||
annotation_id = str(annotation_id)
|
||||
args = UpdateAnnotationPayload.model_validate(console_ns.payload)
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(
|
||||
args.model_dump(exclude_none=True), app_id, annotation_id
|
||||
)
|
||||
args = parser.parse_args()
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
|
||||
return annotation
|
||||
|
||||
@setup_required
|
||||
@@ -322,25 +299,18 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import")
|
||||
class AnnotationBatchImportApi(Resource):
|
||||
@console_ns.doc("batch_import_annotations")
|
||||
@console_ns.doc(description="Batch import annotations from CSV file with rate limiting and security checks")
|
||||
@console_ns.doc(description="Batch import annotations from CSV file")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Batch import started successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(400, "No file uploaded or too many files")
|
||||
@console_ns.response(413, "File too large")
|
||||
@console_ns.response(429, "Too many requests or concurrent imports")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@annotation_import_rate_limit
|
||||
@annotation_import_concurrency_limit
|
||||
@edit_permission_required
|
||||
def post(self, app_id):
|
||||
from configs import dify_config
|
||||
|
||||
app_id = str(app_id)
|
||||
|
||||
# check file
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
@@ -350,27 +320,9 @@ class AnnotationBatchImportApi(Resource):
|
||||
|
||||
# get file from request
|
||||
file = request.files["file"]
|
||||
|
||||
# check file type
|
||||
if not file.filename or not file.filename.lower().endswith(".csv"):
|
||||
raise ValueError("Invalid file type. Only CSV files are allowed")
|
||||
|
||||
# Check file size before processing
|
||||
file.seek(0, 2) # Seek to end of file
|
||||
file_size = file.tell()
|
||||
file.seek(0) # Reset to beginning
|
||||
|
||||
max_size_bytes = dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT * 1024 * 1024
|
||||
if file_size > max_size_bytes:
|
||||
abort(
|
||||
413,
|
||||
f"File size exceeds maximum limit of {dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT}MB. "
|
||||
f"Please reduce the file size and try again.",
|
||||
)
|
||||
|
||||
if file_size == 0:
|
||||
raise ValueError("The uploaded file is empty")
|
||||
|
||||
return AppAnnotationService.batch_import_app_annotations(app_id, file)
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import re
|
||||
import uuid
|
||||
from typing import Literal
|
||||
|
||||
@@ -32,6 +31,7 @@ from fields.app_fields import (
|
||||
from fields.workflow_fields import workflow_partial_fields as _workflow_partial_fields_dict
|
||||
from libs.helper import AppIconUrlField, TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.validators import validate_description_length
|
||||
from models import App, Workflow
|
||||
from services.app_dsl_service import AppDslService, ImportMode
|
||||
from services.app_service import AppService
|
||||
@@ -74,88 +74,52 @@ class AppListQuery(BaseModel):
|
||||
raise ValueError("Invalid UUID format in tag_ids.") from exc
|
||||
|
||||
|
||||
# XSS prevention: patterns that could lead to XSS attacks
|
||||
# Includes: script tags, iframe tags, javascript: protocol, SVG with onload, etc.
|
||||
_XSS_PATTERNS = [
|
||||
r"<script[^>]*>.*?</script>", # Script tags
|
||||
r"<iframe\b[^>]*?(?:/>|>.*?</iframe>)", # Iframe tags (including self-closing)
|
||||
r"javascript:", # JavaScript protocol
|
||||
r"<svg[^>]*?\s+onload\s*=[^>]*>", # SVG with onload handler (attribute-aware, flexible whitespace)
|
||||
r"<.*?on\s*\w+\s*=", # Event handlers like onclick, onerror, etc.
|
||||
r"<object\b[^>]*(?:\s*/>|>.*?</object\s*>)", # Object tags (opening tag)
|
||||
r"<embed[^>]*>", # Embed tags (self-closing)
|
||||
r"<link[^>]*>", # Link tags with javascript
|
||||
]
|
||||
|
||||
|
||||
def _validate_xss_safe(value: str | None, field_name: str = "Field") -> str | None:
|
||||
"""
|
||||
Validate that a string value doesn't contain potential XSS payloads.
|
||||
|
||||
Args:
|
||||
value: The string value to validate
|
||||
field_name: Name of the field for error messages
|
||||
|
||||
Returns:
|
||||
The original value if safe
|
||||
|
||||
Raises:
|
||||
ValueError: If the value contains XSS patterns
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
value_lower = value.lower()
|
||||
for pattern in _XSS_PATTERNS:
|
||||
if re.search(pattern, value_lower, re.DOTALL | re.IGNORECASE):
|
||||
raise ValueError(
|
||||
f"{field_name} contains invalid characters or patterns. "
|
||||
"HTML tags, JavaScript, and other potentially dangerous content are not allowed."
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
class CreateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)")
|
||||
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
|
||||
icon_type: str | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
|
||||
@field_validator("name", "description", mode="before")
|
||||
@field_validator("description")
|
||||
@classmethod
|
||||
def validate_xss_safe(cls, value: str | None, info) -> str | None:
|
||||
return _validate_xss_safe(value, info.field_name)
|
||||
def validate_description(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return validate_description_length(value)
|
||||
|
||||
|
||||
class UpdateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)")
|
||||
icon_type: str | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
|
||||
max_active_requests: int | None = Field(default=None, description="Maximum active requests")
|
||||
|
||||
@field_validator("name", "description", mode="before")
|
||||
@field_validator("description")
|
||||
@classmethod
|
||||
def validate_xss_safe(cls, value: str | None, info) -> str | None:
|
||||
return _validate_xss_safe(value, info.field_name)
|
||||
def validate_description(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return validate_description_length(value)
|
||||
|
||||
|
||||
class CopyAppPayload(BaseModel):
|
||||
name: str | None = Field(default=None, description="Name for the copied app")
|
||||
description: str | None = Field(default=None, description="Description for the copied app", max_length=400)
|
||||
description: str | None = Field(default=None, description="Description for the copied app")
|
||||
icon_type: str | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
|
||||
@field_validator("name", "description", mode="before")
|
||||
@field_validator("description")
|
||||
@classmethod
|
||||
def validate_xss_safe(cls, value: str | None, info) -> str | None:
|
||||
return _validate_xss_safe(value, info.field_name)
|
||||
def validate_description(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return validate_description_length(value)
|
||||
|
||||
|
||||
class AppExportQuery(BaseModel):
|
||||
@@ -182,14 +146,7 @@ class AppApiStatusPayload(BaseModel):
|
||||
|
||||
class AppTracePayload(BaseModel):
|
||||
enabled: bool = Field(..., description="Enable or disable tracing")
|
||||
tracing_provider: str | None = Field(default=None, description="Tracing provider")
|
||||
|
||||
@field_validator("tracing_provider")
|
||||
@classmethod
|
||||
def validate_tracing_provider(cls, value: str | None, info) -> str | None:
|
||||
if info.data.get("enabled") and not value:
|
||||
raise ValueError("tracing_provider is required when enabled is True")
|
||||
return value
|
||||
tracing_provider: str = Field(..., description="Tracing provider")
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
@@ -367,13 +324,10 @@ class AppListApi(Resource):
|
||||
NodeType.TRIGGER_PLUGIN,
|
||||
}
|
||||
for workflow in draft_workflows:
|
||||
try:
|
||||
for _, node_data in workflow.walk_nodes():
|
||||
if node_data.get("type") in trigger_node_types:
|
||||
draft_trigger_app_ids.add(str(workflow.app_id))
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
for _, node_data in workflow.walk_nodes():
|
||||
if node_data.get("type") in trigger_node_types:
|
||||
draft_trigger_app_ids.add(str(workflow.app_id))
|
||||
break
|
||||
|
||||
for app in app_pagination.items:
|
||||
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
@@ -36,29 +35,23 @@ app_import_check_dependencies_model = console_ns.model(
|
||||
"AppImportCheckDependencies", app_import_check_dependencies_fields_copy
|
||||
)
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class AppImportPayload(BaseModel):
|
||||
mode: str = Field(..., description="Import mode")
|
||||
yaml_content: str | None = None
|
||||
yaml_url: str | None = None
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
icon_type: str | None = None
|
||||
icon: str | None = None
|
||||
icon_background: str | None = None
|
||||
app_id: str | None = None
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
AppImportPayload.__name__, AppImportPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("mode", type=str, required=True, location="json")
|
||||
.add_argument("yaml_content", type=str, location="json")
|
||||
.add_argument("yaml_url", type=str, location="json")
|
||||
.add_argument("name", type=str, location="json")
|
||||
.add_argument("description", type=str, location="json")
|
||||
.add_argument("icon_type", type=str, location="json")
|
||||
.add_argument("icon", type=str, location="json")
|
||||
.add_argument("icon_background", type=str, location="json")
|
||||
.add_argument("app_id", type=str, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/imports")
|
||||
class AppImportApi(Resource):
|
||||
@console_ns.expect(console_ns.models[AppImportPayload.__name__])
|
||||
@console_ns.expect(parser)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -68,7 +61,7 @@ class AppImportApi(Resource):
|
||||
def post(self):
|
||||
# Check user role first
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = AppImportPayload.model_validate(console_ns.payload)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create service with session
|
||||
with Session(db.engine) as session:
|
||||
@@ -77,15 +70,15 @@ class AppImportApi(Resource):
|
||||
account = current_user
|
||||
result = import_service.import_app(
|
||||
account=account,
|
||||
import_mode=args.mode,
|
||||
yaml_content=args.yaml_content,
|
||||
yaml_url=args.yaml_url,
|
||||
name=args.name,
|
||||
description=args.description,
|
||||
icon_type=args.icon_type,
|
||||
icon=args.icon,
|
||||
icon_background=args.icon_background,
|
||||
app_id=args.app_id,
|
||||
import_mode=args["mode"],
|
||||
yaml_content=args.get("yaml_content"),
|
||||
yaml_url=args.get("yaml_url"),
|
||||
name=args.get("name"),
|
||||
description=args.get("description"),
|
||||
icon_type=args.get("icon_type"),
|
||||
icon=args.get("icon"),
|
||||
icon_background=args.get("icon_background"),
|
||||
app_id=args.get("app_id"),
|
||||
)
|
||||
session.commit()
|
||||
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
@@ -33,27 +32,6 @@ from services.errors.audio import (
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class TextToSpeechPayload(BaseModel):
|
||||
message_id: str | None = Field(default=None, description="Message ID")
|
||||
text: str = Field(..., description="Text to convert")
|
||||
voice: str | None = Field(default=None, description="Voice name")
|
||||
streaming: bool | None = Field(default=None, description="Whether to stream audio")
|
||||
|
||||
|
||||
class TextToSpeechVoiceQuery(BaseModel):
|
||||
language: str = Field(..., description="Language code")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
TextToSpeechPayload.__name__, TextToSpeechPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
TextToSpeechVoiceQuery.__name__,
|
||||
TextToSpeechVoiceQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/audio-to-text")
|
||||
@@ -114,7 +92,17 @@ class ChatMessageTextApi(Resource):
|
||||
@console_ns.doc("chat_message_text_to_speech")
|
||||
@console_ns.doc(description="Convert text to speech for chat messages")
|
||||
@console_ns.doc(params={"app_id": "App ID"})
|
||||
@console_ns.expect(console_ns.models[TextToSpeechPayload.__name__])
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"TextToSpeechRequest",
|
||||
{
|
||||
"message_id": fields.String(description="Message ID"),
|
||||
"text": fields.String(required=True, description="Text to convert to speech"),
|
||||
"voice": fields.String(description="Voice to use for TTS"),
|
||||
"streaming": fields.Boolean(description="Whether to stream the audio"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "Text to speech conversion successful")
|
||||
@console_ns.response(400, "Bad request - Invalid parameters")
|
||||
@get_app_model
|
||||
@@ -123,14 +111,21 @@ class ChatMessageTextApi(Resource):
|
||||
@account_initialization_required
|
||||
def post(self, app_model: App):
|
||||
try:
|
||||
payload = TextToSpeechPayload.model_validate(console_ns.payload)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", type=str, location="json")
|
||||
.add_argument("text", type=str, location="json")
|
||||
.add_argument("voice", type=str, location="json")
|
||||
.add_argument("streaming", type=bool, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
voice = args.get("voice", None)
|
||||
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model,
|
||||
text=payload.text,
|
||||
voice=payload.voice,
|
||||
message_id=payload.message_id,
|
||||
is_draft=True,
|
||||
app_model=app_model, text=text, voice=voice, message_id=message_id, is_draft=True
|
||||
)
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
@@ -164,7 +159,9 @@ class TextModesApi(Resource):
|
||||
@console_ns.doc("get_text_to_speech_voices")
|
||||
@console_ns.doc(description="Get available TTS voices for a specific language")
|
||||
@console_ns.doc(params={"app_id": "App ID"})
|
||||
@console_ns.expect(console_ns.models[TextToSpeechVoiceQuery.__name__])
|
||||
@console_ns.expect(
|
||||
console_ns.parser().add_argument("language", type=str, required=True, location="args", help="Language code")
|
||||
)
|
||||
@console_ns.response(
|
||||
200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices"))
|
||||
)
|
||||
@@ -175,11 +172,12 @@ class TextModesApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
try:
|
||||
args = TextToSpeechVoiceQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
parser = reqparse.RequestParser().add_argument("language", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
response = AudioService.transcript_tts_voices(
|
||||
tenant_id=app_model.tenant_id,
|
||||
language=args.language,
|
||||
language=args["language"],
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@@ -13,6 +13,7 @@ from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import MessageTextField
|
||||
from fields.raws import FilesContainedField
|
||||
from libs.datetime_utils import naive_utc_now, parse_time_range
|
||||
from libs.helper import TimestampField
|
||||
@@ -48,6 +49,7 @@ class CompletionConversationQuery(BaseConversationQuery):
|
||||
|
||||
|
||||
class ChatConversationQuery(BaseConversationQuery):
|
||||
message_count_gte: int | None = Field(default=None, ge=1, description="Minimum message count")
|
||||
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
|
||||
default="-updated_at", description="Sort field and direction"
|
||||
)
|
||||
@@ -176,12 +178,6 @@ annotation_hit_history_model = console_ns.model(
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class MessageTextField(fields.Raw):
|
||||
def format(self, value):
|
||||
return value[0]["text"] if value else ""
|
||||
|
||||
|
||||
# Simple message detail model
|
||||
simple_message_detail_model = console_ns.model(
|
||||
"SimpleMessageDetail",
|
||||
@@ -513,6 +509,14 @@ class ChatConversationApi(Resource):
|
||||
.having(func.count(MessageAnnotation.id) == 0)
|
||||
)
|
||||
|
||||
if args.message_count_gte and args.message_count_gte >= 1:
|
||||
query = (
|
||||
query.options(joinedload(Conversation.messages)) # type: ignore
|
||||
.join(Message, Message.conversation_id == Conversation.id)
|
||||
.group_by(Conversation.id)
|
||||
.having(func.count(Message.id) >= args.message_count_gte)
|
||||
)
|
||||
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)
|
||||
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import json
|
||||
from enum import StrEnum
|
||||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
@@ -13,8 +12,6 @@ from fields.app_fields import app_server_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import AppMCPServer
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
# Register model for flask_restx to avoid dict type issues in Swagger
|
||||
app_server_model = console_ns.model("AppServer", app_server_fields)
|
||||
|
||||
@@ -24,22 +21,6 @@ class AppMCPServerStatus(StrEnum):
|
||||
INACTIVE = "inactive"
|
||||
|
||||
|
||||
class MCPServerCreatePayload(BaseModel):
|
||||
description: str | None = Field(default=None, description="Server description")
|
||||
parameters: dict = Field(..., description="Server parameters configuration")
|
||||
|
||||
|
||||
class MCPServerUpdatePayload(BaseModel):
|
||||
id: str = Field(..., description="Server ID")
|
||||
description: str | None = Field(default=None, description="Server description")
|
||||
parameters: dict = Field(..., description="Server parameters configuration")
|
||||
status: str | None = Field(default=None, description="Server status")
|
||||
|
||||
|
||||
for model in (MCPServerCreatePayload, MCPServerUpdatePayload):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/server")
|
||||
class AppMCPServerController(Resource):
|
||||
@console_ns.doc("get_app_mcp_server")
|
||||
@@ -58,7 +39,15 @@ class AppMCPServerController(Resource):
|
||||
@console_ns.doc("create_app_mcp_server")
|
||||
@console_ns.doc(description="Create MCP server configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[MCPServerCreatePayload.__name__])
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"MCPServerCreateRequest",
|
||||
{
|
||||
"description": fields.String(description="Server description"),
|
||||
"parameters": fields.Raw(required=True, description="Server parameters configuration"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(201, "MCP server configuration created successfully", app_server_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@account_initialization_required
|
||||
@@ -69,16 +58,21 @@ class AppMCPServerController(Resource):
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
payload = MCPServerCreatePayload.model_validate(console_ns.payload or {})
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("description", type=str, required=False, location="json")
|
||||
.add_argument("parameters", type=dict, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
description = payload.description
|
||||
description = args.get("description")
|
||||
if not description:
|
||||
description = app_model.description or ""
|
||||
|
||||
server = AppMCPServer(
|
||||
name=app_model.name,
|
||||
description=description,
|
||||
parameters=json.dumps(payload.parameters, ensure_ascii=False),
|
||||
parameters=json.dumps(args["parameters"], ensure_ascii=False),
|
||||
status=AppMCPServerStatus.ACTIVE,
|
||||
app_id=app_model.id,
|
||||
tenant_id=current_tenant_id,
|
||||
@@ -91,7 +85,17 @@ class AppMCPServerController(Resource):
|
||||
@console_ns.doc("update_app_mcp_server")
|
||||
@console_ns.doc(description="Update MCP server configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[MCPServerUpdatePayload.__name__])
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"MCPServerUpdateRequest",
|
||||
{
|
||||
"id": fields.String(required=True, description="Server ID"),
|
||||
"description": fields.String(description="Server description"),
|
||||
"parameters": fields.Raw(required=True, description="Server parameters configuration"),
|
||||
"status": fields.String(description="Server status"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "MCP server configuration updated successfully", app_server_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "Server not found")
|
||||
@@ -102,12 +106,19 @@ class AppMCPServerController(Resource):
|
||||
@marshal_with(app_server_model)
|
||||
@edit_permission_required
|
||||
def put(self, app_model):
|
||||
payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
|
||||
server = db.session.query(AppMCPServer).where(AppMCPServer.id == payload.id).first()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("id", type=str, required=True, location="json")
|
||||
.add_argument("description", type=str, required=False, location="json")
|
||||
.add_argument("parameters", type=dict, required=True, location="json")
|
||||
.add_argument("status", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first()
|
||||
if not server:
|
||||
raise NotFound()
|
||||
|
||||
description = payload.description
|
||||
description = args.get("description")
|
||||
if description is None:
|
||||
pass
|
||||
elif not description:
|
||||
@@ -115,11 +126,11 @@ class AppMCPServerController(Resource):
|
||||
else:
|
||||
server.description = description
|
||||
|
||||
server.parameters = json.dumps(payload.parameters, ensure_ascii=False)
|
||||
if payload.status:
|
||||
if payload.status not in [status.value for status in AppMCPServerStatus]:
|
||||
server.parameters = json.dumps(args["parameters"], ensure_ascii=False)
|
||||
if args["status"]:
|
||||
if args["status"] not in [status.value for status in AppMCPServerStatus]:
|
||||
raise ValueError("Invalid status")
|
||||
server.status = payload.status
|
||||
server.status = args["status"]
|
||||
db.session.commit()
|
||||
return server
|
||||
|
||||
|
||||
@@ -61,7 +61,6 @@ class ChatMessagesQuery(BaseModel):
|
||||
class MessageFeedbackPayload(BaseModel):
|
||||
message_id: str = Field(..., description="Message ID")
|
||||
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
|
||||
content: str | None = Field(default=None, description="Feedback content")
|
||||
|
||||
@field_validator("message_id")
|
||||
@classmethod
|
||||
@@ -325,7 +324,6 @@ class MessageFeedbackApi(Resource):
|
||||
db.session.delete(feedback)
|
||||
elif args.rating and feedback:
|
||||
feedback.rating = args.rating
|
||||
feedback.content = args.content
|
||||
elif not args.rating and not feedback:
|
||||
raise ValueError("rating cannot be None when feedback not exists")
|
||||
else:
|
||||
@@ -337,7 +335,6 @@ class MessageFeedbackApi(Resource):
|
||||
conversation_id=message.conversation_id,
|
||||
message_id=message.id,
|
||||
rating=rating_value,
|
||||
content=args.content,
|
||||
from_source="admin",
|
||||
from_account_id=current_user.id,
|
||||
)
|
||||
|
||||
@@ -1,8 +1,4 @@
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.console import console_ns
|
||||
@@ -11,26 +7,6 @@ from controllers.console.wraps import account_initialization_required, setup_req
|
||||
from libs.login import login_required
|
||||
from services.ops_service import OpsService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class TraceProviderQuery(BaseModel):
|
||||
tracing_provider: str = Field(..., description="Tracing provider name")
|
||||
|
||||
|
||||
class TraceConfigPayload(BaseModel):
|
||||
tracing_provider: str = Field(..., description="Tracing provider name")
|
||||
tracing_config: dict[str, Any] = Field(..., description="Tracing configuration data")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
TraceProviderQuery.__name__,
|
||||
TraceProviderQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
TraceConfigPayload.__name__, TraceConfigPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/trace-config")
|
||||
class TraceAppConfigApi(Resource):
|
||||
@@ -41,7 +17,11 @@ class TraceAppConfigApi(Resource):
|
||||
@console_ns.doc("get_trace_app_config")
|
||||
@console_ns.doc(description="Get tracing configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[TraceProviderQuery.__name__])
|
||||
@console_ns.expect(
|
||||
console_ns.parser().add_argument(
|
||||
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
|
||||
)
|
||||
)
|
||||
@console_ns.response(
|
||||
200, "Tracing configuration retrieved successfully", fields.Raw(description="Tracing configuration data")
|
||||
)
|
||||
@@ -50,10 +30,11 @@ class TraceAppConfigApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
|
||||
trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
|
||||
if not trace_config:
|
||||
return {"has_not_configured": True}
|
||||
return trace_config
|
||||
@@ -63,7 +44,15 @@ class TraceAppConfigApi(Resource):
|
||||
@console_ns.doc("create_trace_app_config")
|
||||
@console_ns.doc(description="Create a new tracing configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[TraceConfigPayload.__name__])
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"TraceConfigCreateRequest",
|
||||
{
|
||||
"tracing_provider": fields.String(required=True, description="Tracing provider name"),
|
||||
"tracing_config": fields.Raw(required=True, description="Tracing configuration data"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(
|
||||
201, "Tracing configuration created successfully", fields.Raw(description="Created configuration data")
|
||||
)
|
||||
@@ -73,11 +62,16 @@ class TraceAppConfigApi(Resource):
|
||||
@account_initialization_required
|
||||
def post(self, app_id):
|
||||
"""Create a new trace app configuration"""
|
||||
args = TraceConfigPayload.model_validate(console_ns.payload)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||
.add_argument("tracing_config", type=dict, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
result = OpsService.create_tracing_app_config(
|
||||
app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
|
||||
app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
|
||||
)
|
||||
if not result:
|
||||
raise TracingConfigIsExist()
|
||||
@@ -90,7 +84,15 @@ class TraceAppConfigApi(Resource):
|
||||
@console_ns.doc("update_trace_app_config")
|
||||
@console_ns.doc(description="Update an existing tracing configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[TraceConfigPayload.__name__])
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"TraceConfigUpdateRequest",
|
||||
{
|
||||
"tracing_provider": fields.String(required=True, description="Tracing provider name"),
|
||||
"tracing_config": fields.Raw(required=True, description="Updated tracing configuration data"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response"))
|
||||
@console_ns.response(400, "Invalid request parameters or configuration not found")
|
||||
@setup_required
|
||||
@@ -98,11 +100,16 @@ class TraceAppConfigApi(Resource):
|
||||
@account_initialization_required
|
||||
def patch(self, app_id):
|
||||
"""Update an existing trace app configuration"""
|
||||
args = TraceConfigPayload.model_validate(console_ns.payload)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||
.add_argument("tracing_config", type=dict, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
result = OpsService.update_tracing_app_config(
|
||||
app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
|
||||
app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
|
||||
)
|
||||
if not result:
|
||||
raise TracingConfigNotExist()
|
||||
@@ -113,7 +120,11 @@ class TraceAppConfigApi(Resource):
|
||||
@console_ns.doc("delete_trace_app_config")
|
||||
@console_ns.doc(description="Delete an existing tracing configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[TraceProviderQuery.__name__])
|
||||
@console_ns.expect(
|
||||
console_ns.parser().add_argument(
|
||||
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
|
||||
)
|
||||
)
|
||||
@console_ns.response(204, "Tracing configuration deleted successfully")
|
||||
@console_ns.response(400, "Invalid request parameters or configuration not found")
|
||||
@setup_required
|
||||
@@ -121,10 +132,11 @@ class TraceAppConfigApi(Resource):
|
||||
@account_initialization_required
|
||||
def delete(self, app_id):
|
||||
"""Delete an existing trace app configuration"""
|
||||
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
|
||||
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
|
||||
if not result:
|
||||
raise TracingConfigNotExist()
|
||||
return {"result": "success"}, 204
|
||||
|
||||
@@ -1,7 +1,4 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from constants.languages import supported_language
|
||||
@@ -19,50 +16,69 @@ from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Site
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class AppSiteUpdatePayload(BaseModel):
|
||||
title: str | None = Field(default=None)
|
||||
icon_type: str | None = Field(default=None)
|
||||
icon: str | None = Field(default=None)
|
||||
icon_background: str | None = Field(default=None)
|
||||
description: str | None = Field(default=None)
|
||||
default_language: str | None = Field(default=None)
|
||||
chat_color_theme: str | None = Field(default=None)
|
||||
chat_color_theme_inverted: bool | None = Field(default=None)
|
||||
customize_domain: str | None = Field(default=None)
|
||||
copyright: str | None = Field(default=None)
|
||||
privacy_policy: str | None = Field(default=None)
|
||||
custom_disclaimer: str | None = Field(default=None)
|
||||
customize_token_strategy: Literal["must", "allow", "not_allow"] | None = Field(default=None)
|
||||
prompt_public: bool | None = Field(default=None)
|
||||
show_workflow_steps: bool | None = Field(default=None)
|
||||
use_icon_as_answer_icon: bool | None = Field(default=None)
|
||||
|
||||
@field_validator("default_language")
|
||||
@classmethod
|
||||
def validate_language(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return supported_language(value)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
AppSiteUpdatePayload.__name__,
|
||||
AppSiteUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
# Register model for flask_restx to avoid dict type issues in Swagger
|
||||
app_site_model = console_ns.model("AppSite", app_site_fields)
|
||||
|
||||
|
||||
def parse_app_site_args():
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("title", type=str, required=False, location="json")
|
||||
.add_argument("icon_type", type=str, required=False, location="json")
|
||||
.add_argument("icon", type=str, required=False, location="json")
|
||||
.add_argument("icon_background", type=str, required=False, location="json")
|
||||
.add_argument("description", type=str, required=False, location="json")
|
||||
.add_argument("default_language", type=supported_language, required=False, location="json")
|
||||
.add_argument("chat_color_theme", type=str, required=False, location="json")
|
||||
.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
|
||||
.add_argument("customize_domain", type=str, required=False, location="json")
|
||||
.add_argument("copyright", type=str, required=False, location="json")
|
||||
.add_argument("privacy_policy", type=str, required=False, location="json")
|
||||
.add_argument("custom_disclaimer", type=str, required=False, location="json")
|
||||
.add_argument(
|
||||
"customize_token_strategy",
|
||||
type=str,
|
||||
choices=["must", "allow", "not_allow"],
|
||||
required=False,
|
||||
location="json",
|
||||
)
|
||||
.add_argument("prompt_public", type=bool, required=False, location="json")
|
||||
.add_argument("show_workflow_steps", type=bool, required=False, location="json")
|
||||
.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/site")
|
||||
class AppSite(Resource):
|
||||
@console_ns.doc("update_app_site")
|
||||
@console_ns.doc(description="Update application site configuration")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[AppSiteUpdatePayload.__name__])
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AppSiteRequest",
|
||||
{
|
||||
"title": fields.String(description="Site title"),
|
||||
"icon_type": fields.String(description="Icon type"),
|
||||
"icon": fields.String(description="Icon"),
|
||||
"icon_background": fields.String(description="Icon background color"),
|
||||
"description": fields.String(description="Site description"),
|
||||
"default_language": fields.String(description="Default language"),
|
||||
"chat_color_theme": fields.String(description="Chat color theme"),
|
||||
"chat_color_theme_inverted": fields.Boolean(description="Inverted chat color theme"),
|
||||
"customize_domain": fields.String(description="Custom domain"),
|
||||
"copyright": fields.String(description="Copyright text"),
|
||||
"privacy_policy": fields.String(description="Privacy policy"),
|
||||
"custom_disclaimer": fields.String(description="Custom disclaimer"),
|
||||
"customize_token_strategy": fields.String(
|
||||
enum=["must", "allow", "not_allow"], description="Token strategy"
|
||||
),
|
||||
"prompt_public": fields.Boolean(description="Make prompt public"),
|
||||
"show_workflow_steps": fields.Boolean(description="Show workflow steps"),
|
||||
"use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "Site configuration updated successfully", app_site_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "App not found")
|
||||
@@ -73,7 +89,7 @@ class AppSite(Resource):
|
||||
@get_app_model
|
||||
@marshal_with(app_site_model)
|
||||
def post(self, app_model):
|
||||
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
|
||||
args = parse_app_site_args()
|
||||
current_user, _ = current_account_with_tenant()
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
if not site:
|
||||
@@ -97,7 +113,7 @@ class AppSite(Resource):
|
||||
"show_workflow_steps",
|
||||
"use_icon_as_answer_icon",
|
||||
]:
|
||||
value = getattr(args, attr_name)
|
||||
value = args.get(attr_name)
|
||||
if value is not None:
|
||||
setattr(site, attr_name, value)
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any, NoReturn, ParamSpec, TypeVar
|
||||
from typing import NoReturn, ParamSpec, TypeVar
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from flask import Response
|
||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import console_ns
|
||||
@@ -30,27 +29,6 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList,
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowDraftVariableListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=100_000, description="Page number")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Items per page")
|
||||
|
||||
|
||||
class WorkflowDraftVariableUpdatePayload(BaseModel):
|
||||
name: str | None = Field(default=None, description="Variable name")
|
||||
value: Any | None = Field(default=None, description="Variable value")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkflowDraftVariableListQuery.__name__,
|
||||
WorkflowDraftVariableListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
WorkflowDraftVariableUpdatePayload.__name__,
|
||||
WorkflowDraftVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
def _convert_values_to_json_serializable_object(value: Segment):
|
||||
@@ -79,6 +57,22 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
|
||||
return _convert_values_to_json_serializable_object(value)
|
||||
|
||||
|
||||
def _create_pagination_parser():
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"page",
|
||||
type=inputs.int_range(1, 100_000),
|
||||
required=False,
|
||||
default=1,
|
||||
location="args",
|
||||
help="the page of data requested",
|
||||
)
|
||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
|
||||
value_type = workflow_draft_var.value_type
|
||||
return value_type.exposed_type().value
|
||||
@@ -207,7 +201,7 @@ def _api_prerequisite(f: Callable[P, R]):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables")
|
||||
class WorkflowVariableCollectionApi(Resource):
|
||||
@console_ns.expect(console_ns.models[WorkflowDraftVariableListQuery.__name__])
|
||||
@console_ns.expect(_create_pagination_parser())
|
||||
@console_ns.doc("get_workflow_variables")
|
||||
@console_ns.doc(description="Get draft workflow variables")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@@ -221,7 +215,8 @@ class WorkflowVariableCollectionApi(Resource):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
parser = _create_pagination_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
workflow_service = WorkflowService()
|
||||
@@ -328,7 +323,15 @@ class VariableApi(Resource):
|
||||
|
||||
@console_ns.doc("update_variable")
|
||||
@console_ns.doc(description="Update a workflow variable")
|
||||
@console_ns.expect(console_ns.models[WorkflowDraftVariableUpdatePayload.__name__])
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"UpdateVariableRequest",
|
||||
{
|
||||
"name": fields.String(description="Variable name"),
|
||||
"value": fields.Raw(description="Variable value"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model)
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
@@ -355,10 +358,16 @@ class VariableApi(Resource):
|
||||
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
||||
# }
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
||||
.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
args = parser.parse_args(strict=True)
|
||||
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
@@ -366,8 +375,8 @@ class VariableApi(Resource):
|
||||
if variable.app_id != app_model.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
|
||||
new_name = args_model.name
|
||||
raw_value = args_model.value
|
||||
new_name = args.get(self._PATCH_NAME_FIELD, None)
|
||||
raw_value = args.get(self._PATCH_VALUE_FIELD, None)
|
||||
if new_name is None and raw_value is None:
|
||||
return variable
|
||||
|
||||
|
||||
@@ -114,7 +114,7 @@ class AppTriggersApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/trigger-enable")
|
||||
class AppTriggerEnableApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ParserEnable.__name__])
|
||||
@console_ns.expect(console_ns.models[ParserEnable.__name__], validate=True)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
||||
@@ -1,53 +1,28 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.error import AlreadyActivateError
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import EmailStr, timezone
|
||||
from libs.helper import StrLen, email, extract_remote_ip, timezone
|
||||
from models import AccountStatus
|
||||
from services.account_service import RegisterService
|
||||
from services.account_service import AccountService, RegisterService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ActivateCheckQuery(BaseModel):
|
||||
workspace_id: str | None = Field(default=None)
|
||||
email: EmailStr | None = Field(default=None)
|
||||
token: str
|
||||
|
||||
|
||||
class ActivatePayload(BaseModel):
|
||||
workspace_id: str | None = Field(default=None)
|
||||
email: EmailStr | None = Field(default=None)
|
||||
token: str
|
||||
name: str = Field(..., max_length=30)
|
||||
interface_language: str = Field(...)
|
||||
timezone: str = Field(...)
|
||||
|
||||
@field_validator("interface_language")
|
||||
@classmethod
|
||||
def validate_lang(cls, value: str) -> str:
|
||||
return supported_language(value)
|
||||
|
||||
@field_validator("timezone")
|
||||
@classmethod
|
||||
def validate_tz(cls, value: str) -> str:
|
||||
return timezone(value)
|
||||
|
||||
|
||||
for model in (ActivateCheckQuery, ActivatePayload):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
active_check_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID")
|
||||
.add_argument("email", type=email, required=False, nullable=True, location="args", help="Email address")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="args", help="Activation token")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/activate/check")
|
||||
class ActivateCheckApi(Resource):
|
||||
@console_ns.doc("check_activation_token")
|
||||
@console_ns.doc(description="Check if activation token is valid")
|
||||
@console_ns.expect(console_ns.models[ActivateCheckQuery.__name__])
|
||||
@console_ns.expect(active_check_parser)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
@@ -60,11 +35,11 @@ class ActivateCheckApi(Resource):
|
||||
),
|
||||
)
|
||||
def get(self):
|
||||
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = active_check_parser.parse_args()
|
||||
|
||||
workspaceId = args.workspace_id
|
||||
reg_email = args.email
|
||||
token = args.token
|
||||
workspaceId = args["workspace_id"]
|
||||
reg_email = args["email"]
|
||||
token = args["token"]
|
||||
|
||||
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
|
||||
if invitation:
|
||||
@@ -81,11 +56,22 @@ class ActivateCheckApi(Resource):
|
||||
return {"is_valid": False}
|
||||
|
||||
|
||||
active_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("email", type=email, required=False, nullable=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
.add_argument("interface_language", type=supported_language, required=True, nullable=False, location="json")
|
||||
.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/activate")
|
||||
class ActivateApi(Resource):
|
||||
@console_ns.doc("activate_account")
|
||||
@console_ns.doc(description="Activate account with invitation token")
|
||||
@console_ns.expect(console_ns.models[ActivatePayload.__name__])
|
||||
@console_ns.expect(active_parser)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Account activated successfully",
|
||||
@@ -93,27 +79,30 @@ class ActivateApi(Resource):
|
||||
"ActivationResponse",
|
||||
{
|
||||
"result": fields.String(description="Operation result"),
|
||||
"data": fields.Raw(description="Login token data"),
|
||||
},
|
||||
),
|
||||
)
|
||||
@console_ns.response(400, "Already activated or invalid token")
|
||||
def post(self):
|
||||
args = ActivatePayload.model_validate(console_ns.payload)
|
||||
args = active_parser.parse_args()
|
||||
|
||||
invitation = RegisterService.get_invitation_if_token_valid(args.workspace_id, args.email, args.token)
|
||||
invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"])
|
||||
if invitation is None:
|
||||
raise AlreadyActivateError()
|
||||
|
||||
RegisterService.revoke_token(args.workspace_id, args.email, args.token)
|
||||
RegisterService.revoke_token(args["workspace_id"], args["email"], args["token"])
|
||||
|
||||
account = invitation["account"]
|
||||
account.name = args.name
|
||||
account.name = args["name"]
|
||||
|
||||
account.interface_language = args.interface_language
|
||||
account.timezone = args.timezone
|
||||
account.interface_language = args["interface_language"]
|
||||
account.timezone = args["timezone"]
|
||||
account.interface_theme = "light"
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.initialized_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}
|
||||
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
|
||||
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
|
||||
@@ -1,26 +1,12 @@
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource, reqparse
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import ApiKeyAuthFailedError
|
||||
from controllers.console.wraps import is_admin_or_owner_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
|
||||
from .. import console_ns
|
||||
from ..auth.error import ApiKeyAuthFailedError
|
||||
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ApiKeyAuthBindingPayload(BaseModel):
|
||||
category: str = Field(...)
|
||||
provider: str = Field(...)
|
||||
credentials: dict = Field(...)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ApiKeyAuthBindingPayload.__name__,
|
||||
ApiKeyAuthBindingPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
from ..wraps import account_initialization_required, setup_required
|
||||
|
||||
|
||||
@console_ns.route("/api-key-auth/data-source")
|
||||
@@ -54,15 +40,19 @@ class ApiKeyAuthDataSourceBinding(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@is_admin_or_owner_required
|
||||
@console_ns.expect(console_ns.models[ApiKeyAuthBindingPayload.__name__])
|
||||
def post(self):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
payload = ApiKeyAuthBindingPayload.model_validate(console_ns.payload)
|
||||
data = payload.model_dump()
|
||||
ApiKeyAuthService.validate_api_key_auth_args(data)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("category", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
try:
|
||||
ApiKeyAuthService.create_provider_auth(current_tenant_id, data)
|
||||
ApiKeyAuthService.create_provider_auth(current_tenant_id, args)
|
||||
except Exception as e:
|
||||
raise ApiKeyAuthFailedError(str(e))
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@@ -5,11 +5,12 @@ from flask import current_app, redirect, request
|
||||
from flask_restx import Resource, fields
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import is_admin_or_owner_required
|
||||
from libs.login import login_required
|
||||
from libs.oauth_data_source import NotionOAuth
|
||||
|
||||
from .. import console_ns
|
||||
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from ..wraps import account_initialization_required, setup_required
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from flask_restx import Resource, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -15,45 +14,16 @@ from controllers.console.auth.error import (
|
||||
InvalidTokenError,
|
||||
PasswordMismatchError,
|
||||
)
|
||||
from controllers.console.error import AccountInFreezeError, EmailSendIpLimitError
|
||||
from controllers.console.wraps import email_password_login_enabled, email_register_enabled, setup_required
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from libs.password import valid_password
|
||||
from models import Account
|
||||
from services.account_service import AccountService
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.account import AccountNotFoundError, AccountRegisterError
|
||||
|
||||
from ..error import AccountInFreezeError, EmailSendIpLimitError
|
||||
from ..wraps import email_password_login_enabled, email_register_enabled, setup_required
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class EmailRegisterSendPayload(BaseModel):
|
||||
email: EmailStr = Field(..., description="Email address")
|
||||
language: str | None = Field(default=None, description="Language code")
|
||||
|
||||
|
||||
class EmailRegisterValidityPayload(BaseModel):
|
||||
email: EmailStr = Field(...)
|
||||
code: str = Field(...)
|
||||
token: str = Field(...)
|
||||
|
||||
|
||||
class EmailRegisterResetPayload(BaseModel):
|
||||
token: str = Field(...)
|
||||
new_password: str = Field(...)
|
||||
password_confirm: str = Field(...)
|
||||
|
||||
@field_validator("new_password", "password_confirm")
|
||||
@classmethod
|
||||
def validate_password(cls, value: str) -> str:
|
||||
return valid_password(value)
|
||||
|
||||
|
||||
for model in (EmailRegisterSendPayload, EmailRegisterValidityPayload, EmailRegisterResetPayload):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
@console_ns.route("/email-register/send-email")
|
||||
class EmailRegisterSendEmailApi(Resource):
|
||||
@@ -61,22 +31,27 @@ class EmailRegisterSendEmailApi(Resource):
|
||||
@email_password_login_enabled
|
||||
@email_register_enabled
|
||||
def post(self):
|
||||
args = EmailRegisterSendPayload.model_validate(console_ns.payload)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
raise EmailSendIpLimitError()
|
||||
language = "en-US"
|
||||
if args.language in languages:
|
||||
language = args.language
|
||||
if args["language"] in languages:
|
||||
language = args["language"]
|
||||
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
|
||||
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
|
||||
token = None
|
||||
token = AccountService.send_email_register_email(email=args.email, account=account, language=language)
|
||||
token = AccountService.send_email_register_email(email=args["email"], account=account, language=language)
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
@@ -86,34 +61,40 @@ class EmailRegisterCheckApi(Resource):
|
||||
@email_password_login_enabled
|
||||
@email_register_enabled
|
||||
def post(self):
|
||||
args = EmailRegisterValidityPayload.model_validate(console_ns.payload)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
user_email = args.email
|
||||
user_email = args["email"]
|
||||
|
||||
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email)
|
||||
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args["email"])
|
||||
if is_email_register_error_rate_limit:
|
||||
raise EmailRegisterLimitError()
|
||||
|
||||
token_data = AccountService.get_email_register_data(args.token)
|
||||
token_data = AccountService.get_email_register_data(args["token"])
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if user_email != token_data.get("email"):
|
||||
raise InvalidEmailError()
|
||||
|
||||
if args.code != token_data.get("code"):
|
||||
AccountService.add_email_register_error_rate_limit(args.email)
|
||||
if args["code"] != token_data.get("code"):
|
||||
AccountService.add_email_register_error_rate_limit(args["email"])
|
||||
raise EmailCodeError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
AccountService.revoke_email_register_token(args.token)
|
||||
AccountService.revoke_email_register_token(args["token"])
|
||||
|
||||
# Refresh token data by generating a new token
|
||||
_, new_token = AccountService.generate_email_register_token(
|
||||
user_email, code=args.code, additional_data={"phase": "register"}
|
||||
user_email, code=args["code"], additional_data={"phase": "register"}
|
||||
)
|
||||
|
||||
AccountService.reset_email_register_error_rate_limit(args.email)
|
||||
AccountService.reset_email_register_error_rate_limit(args["email"])
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
|
||||
|
||||
@@ -123,14 +104,20 @@ class EmailRegisterResetApi(Resource):
|
||||
@email_password_login_enabled
|
||||
@email_register_enabled
|
||||
def post(self):
|
||||
args = EmailRegisterResetPayload.model_validate(console_ns.payload)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
||||
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate passwords match
|
||||
if args.new_password != args.password_confirm:
|
||||
if args["new_password"] != args["password_confirm"]:
|
||||
raise PasswordMismatchError()
|
||||
|
||||
# Validate token and get register data
|
||||
register_data = AccountService.get_email_register_data(args.token)
|
||||
register_data = AccountService.get_email_register_data(args["token"])
|
||||
if not register_data:
|
||||
raise InvalidTokenError()
|
||||
# Must use token in reset phase
|
||||
@@ -138,7 +125,7 @@ class EmailRegisterResetApi(Resource):
|
||||
raise InvalidTokenError()
|
||||
|
||||
# Revoke token to prevent reuse
|
||||
AccountService.revoke_email_register_token(args.token)
|
||||
AccountService.revoke_email_register_token(args["token"])
|
||||
|
||||
email = register_data.get("email", "")
|
||||
|
||||
@@ -148,7 +135,7 @@ class EmailRegisterResetApi(Resource):
|
||||
if account:
|
||||
raise EmailAlreadyInUseError()
|
||||
else:
|
||||
account = self._create_new_account(email, args.password_confirm)
|
||||
account = self._create_new_account(email, args["password_confirm"])
|
||||
if not account:
|
||||
raise AccountNotFoundError()
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
|
||||
@@ -2,8 +2,7 @@ import base64
|
||||
import secrets
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -19,46 +18,26 @@ from controllers.console.error import AccountNotFound, EmailSendIpLimitError
|
||||
from controllers.console.wraps import email_password_login_enabled, setup_required
|
||||
from events.tenant_event import tenant_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from libs.password import hash_password, valid_password
|
||||
from models import Account
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ForgotPasswordSendPayload(BaseModel):
|
||||
email: EmailStr = Field(...)
|
||||
language: str | None = Field(default=None)
|
||||
|
||||
|
||||
class ForgotPasswordCheckPayload(BaseModel):
|
||||
email: EmailStr = Field(...)
|
||||
code: str = Field(...)
|
||||
token: str = Field(...)
|
||||
|
||||
|
||||
class ForgotPasswordResetPayload(BaseModel):
|
||||
token: str = Field(...)
|
||||
new_password: str = Field(...)
|
||||
password_confirm: str = Field(...)
|
||||
|
||||
@field_validator("new_password", "password_confirm")
|
||||
@classmethod
|
||||
def validate_password(cls, value: str) -> str:
|
||||
return valid_password(value)
|
||||
|
||||
|
||||
for model in (ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
@console_ns.route("/forgot-password")
|
||||
class ForgotPasswordSendEmailApi(Resource):
|
||||
@console_ns.doc("send_forgot_password_email")
|
||||
@console_ns.doc(description="Send password reset email")
|
||||
@console_ns.expect(console_ns.models[ForgotPasswordSendPayload.__name__])
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"ForgotPasswordEmailRequest",
|
||||
{
|
||||
"email": fields.String(required=True, description="Email address"),
|
||||
"language": fields.String(description="Language for email (zh-Hans/en-US)"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Email sent successfully",
|
||||
@@ -75,23 +54,28 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
args = ForgotPasswordSendPayload.model_validate(console_ns.payload)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
raise EmailSendIpLimitError()
|
||||
|
||||
if args.language is not None and args.language == "zh-Hans":
|
||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
||||
language = "zh-Hans"
|
||||
else:
|
||||
language = "en-US"
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
|
||||
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
|
||||
|
||||
token = AccountService.send_reset_password_email(
|
||||
account=account,
|
||||
email=args.email,
|
||||
email=args["email"],
|
||||
language=language,
|
||||
is_allow_register=FeatureService.get_system_features().is_allow_register,
|
||||
)
|
||||
@@ -103,7 +87,16 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
class ForgotPasswordCheckApi(Resource):
|
||||
@console_ns.doc("check_forgot_password_code")
|
||||
@console_ns.doc(description="Verify password reset code")
|
||||
@console_ns.expect(console_ns.models[ForgotPasswordCheckPayload.__name__])
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"ForgotPasswordCheckRequest",
|
||||
{
|
||||
"email": fields.String(required=True, description="Email address"),
|
||||
"code": fields.String(required=True, description="Verification code"),
|
||||
"token": fields.String(required=True, description="Reset token"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Code verified successfully",
|
||||
@@ -120,34 +113,40 @@ class ForgotPasswordCheckApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
args = ForgotPasswordCheckPayload.model_validate(console_ns.payload)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
user_email = args.email
|
||||
user_email = args["email"]
|
||||
|
||||
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email)
|
||||
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"])
|
||||
if is_forgot_password_error_rate_limit:
|
||||
raise EmailPasswordResetLimitError()
|
||||
|
||||
token_data = AccountService.get_reset_password_data(args.token)
|
||||
token_data = AccountService.get_reset_password_data(args["token"])
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if user_email != token_data.get("email"):
|
||||
raise InvalidEmailError()
|
||||
|
||||
if args.code != token_data.get("code"):
|
||||
AccountService.add_forgot_password_error_rate_limit(args.email)
|
||||
if args["code"] != token_data.get("code"):
|
||||
AccountService.add_forgot_password_error_rate_limit(args["email"])
|
||||
raise EmailCodeError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
AccountService.revoke_reset_password_token(args.token)
|
||||
AccountService.revoke_reset_password_token(args["token"])
|
||||
|
||||
# Refresh token data by generating a new token
|
||||
_, new_token = AccountService.generate_reset_password_token(
|
||||
user_email, code=args.code, additional_data={"phase": "reset"}
|
||||
user_email, code=args["code"], additional_data={"phase": "reset"}
|
||||
)
|
||||
|
||||
AccountService.reset_forgot_password_error_rate_limit(args.email)
|
||||
AccountService.reset_forgot_password_error_rate_limit(args["email"])
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
|
||||
|
||||
@@ -155,7 +154,16 @@ class ForgotPasswordCheckApi(Resource):
|
||||
class ForgotPasswordResetApi(Resource):
|
||||
@console_ns.doc("reset_password")
|
||||
@console_ns.doc(description="Reset password with verification token")
|
||||
@console_ns.expect(console_ns.models[ForgotPasswordResetPayload.__name__])
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"ForgotPasswordResetRequest",
|
||||
{
|
||||
"token": fields.String(required=True, description="Verification token"),
|
||||
"new_password": fields.String(required=True, description="New password"),
|
||||
"password_confirm": fields.String(required=True, description="Password confirmation"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Password reset successfully",
|
||||
@@ -165,14 +173,20 @@ class ForgotPasswordResetApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
args = ForgotPasswordResetPayload.model_validate(console_ns.payload)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
||||
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate passwords match
|
||||
if args.new_password != args.password_confirm:
|
||||
if args["new_password"] != args["password_confirm"]:
|
||||
raise PasswordMismatchError()
|
||||
|
||||
# Validate token and get reset data
|
||||
reset_data = AccountService.get_reset_password_data(args.token)
|
||||
reset_data = AccountService.get_reset_password_data(args["token"])
|
||||
if not reset_data:
|
||||
raise InvalidTokenError()
|
||||
# Must use token in reset phase
|
||||
@@ -180,11 +194,11 @@ class ForgotPasswordResetApi(Resource):
|
||||
raise InvalidTokenError()
|
||||
|
||||
# Revoke token to prevent reuse
|
||||
AccountService.revoke_reset_password_token(args.token)
|
||||
AccountService.revoke_reset_password_token(args["token"])
|
||||
|
||||
# Generate secure salt and hash password
|
||||
salt = secrets.token_bytes(16)
|
||||
password_hashed = hash_password(args.new_password, salt)
|
||||
password_hashed = hash_password(args["new_password"], salt)
|
||||
|
||||
email = reset_data.get("email", "")
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import flask_login
|
||||
from flask import make_response, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource, reqparse
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
@@ -22,14 +21,9 @@ from controllers.console.error import (
|
||||
NotAllowedCreateWorkspace,
|
||||
WorkspacesLimitExceeded,
|
||||
)
|
||||
from controllers.console.wraps import (
|
||||
decrypt_code_field,
|
||||
decrypt_password_field,
|
||||
email_password_login_enabled,
|
||||
setup_required,
|
||||
)
|
||||
from controllers.console.wraps import email_password_login_enabled, setup_required
|
||||
from events.tenant_event import tenant_was_created
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from libs.login import current_account_with_tenant
|
||||
from libs.token import (
|
||||
clear_access_token_from_cookie,
|
||||
@@ -46,36 +40,6 @@ from services.errors.account import AccountRegisterError
|
||||
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class LoginPayload(BaseModel):
|
||||
email: EmailStr = Field(..., description="Email address")
|
||||
password: str = Field(..., description="Password")
|
||||
remember_me: bool = Field(default=False, description="Remember me flag")
|
||||
invite_token: str | None = Field(default=None, description="Invitation token")
|
||||
|
||||
|
||||
class EmailPayload(BaseModel):
|
||||
email: EmailStr = Field(...)
|
||||
language: str | None = Field(default=None)
|
||||
|
||||
|
||||
class EmailCodeLoginPayload(BaseModel):
|
||||
email: EmailStr = Field(...)
|
||||
code: str = Field(...)
|
||||
token: str = Field(...)
|
||||
language: str | None = Field(default=None)
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(LoginPayload)
|
||||
reg(EmailPayload)
|
||||
reg(EmailCodeLoginPayload)
|
||||
|
||||
|
||||
@console_ns.route("/login")
|
||||
class LoginApi(Resource):
|
||||
@@ -83,37 +47,41 @@ class LoginApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@console_ns.expect(console_ns.models[LoginPayload.__name__])
|
||||
@decrypt_password_field
|
||||
def post(self):
|
||||
"""Authenticate user and login."""
|
||||
args = LoginPayload.model_validate(console_ns.payload)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("password", type=str, required=True, location="json")
|
||||
.add_argument("remember_me", type=bool, required=False, default=False, location="json")
|
||||
.add_argument("invite_token", type=str, required=False, default=None, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args.email)
|
||||
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args["email"])
|
||||
if is_login_error_rate_limit:
|
||||
raise EmailPasswordLoginLimitError()
|
||||
|
||||
# TODO: why invitation is re-assigned with different type?
|
||||
invitation = args.invite_token # type: ignore
|
||||
invitation = args["invite_token"]
|
||||
if invitation:
|
||||
invitation = RegisterService.get_invitation_if_token_valid(None, args.email, invitation) # type: ignore
|
||||
invitation = RegisterService.get_invitation_if_token_valid(None, args["email"], invitation)
|
||||
|
||||
try:
|
||||
if invitation:
|
||||
data = invitation.get("data", {}) # type: ignore
|
||||
data = invitation.get("data", {})
|
||||
invitee_email = data.get("email") if data else None
|
||||
if invitee_email != args.email:
|
||||
if invitee_email != args["email"]:
|
||||
raise InvalidEmailError()
|
||||
account = AccountService.authenticate(args.email, args.password, args.invite_token)
|
||||
account = AccountService.authenticate(args["email"], args["password"], args["invite_token"])
|
||||
else:
|
||||
account = AccountService.authenticate(args.email, args.password)
|
||||
account = AccountService.authenticate(args["email"], args["password"])
|
||||
except services.errors.account.AccountLoginError:
|
||||
raise AccountBannedError()
|
||||
except services.errors.account.AccountPasswordError:
|
||||
AccountService.add_login_error_rate_limit(args.email)
|
||||
AccountService.add_login_error_rate_limit(args["email"])
|
||||
raise AuthenticationFailedError()
|
||||
# SELF_HOSTED only have one workspace
|
||||
tenants = TenantService.get_join_tenants(account)
|
||||
@@ -129,7 +97,7 @@ class LoginApi(Resource):
|
||||
}
|
||||
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(args.email)
|
||||
AccountService.reset_login_error_rate_limit(args["email"])
|
||||
|
||||
# Create response with cookies instead of returning tokens in body
|
||||
response = make_response({"result": "success"})
|
||||
@@ -166,21 +134,25 @@ class LogoutApi(Resource):
|
||||
class ResetPasswordSendEmailApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@console_ns.expect(console_ns.models[EmailPayload.__name__])
|
||||
def post(self):
|
||||
args = EmailPayload.model_validate(console_ns.payload)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.language is not None and args.language == "zh-Hans":
|
||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
||||
language = "zh-Hans"
|
||||
else:
|
||||
language = "en-US"
|
||||
try:
|
||||
account = AccountService.get_user_through_email(args.email)
|
||||
account = AccountService.get_user_through_email(args["email"])
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
token = AccountService.send_reset_password_email(
|
||||
email=args.email,
|
||||
email=args["email"],
|
||||
account=account,
|
||||
language=language,
|
||||
is_allow_register=FeatureService.get_system_features().is_allow_register,
|
||||
@@ -192,26 +164,30 @@ class ResetPasswordSendEmailApi(Resource):
|
||||
@console_ns.route("/email-code-login")
|
||||
class EmailCodeLoginSendEmailApi(Resource):
|
||||
@setup_required
|
||||
@console_ns.expect(console_ns.models[EmailPayload.__name__])
|
||||
def post(self):
|
||||
args = EmailPayload.model_validate(console_ns.payload)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
raise EmailSendIpLimitError()
|
||||
|
||||
if args.language is not None and args.language == "zh-Hans":
|
||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
||||
language = "zh-Hans"
|
||||
else:
|
||||
language = "en-US"
|
||||
try:
|
||||
account = AccountService.get_user_through_email(args.email)
|
||||
account = AccountService.get_user_through_email(args["email"])
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
if account is None:
|
||||
if FeatureService.get_system_features().is_allow_register:
|
||||
token = AccountService.send_email_code_login_email(email=args.email, language=language)
|
||||
token = AccountService.send_email_code_login_email(email=args["email"], language=language)
|
||||
else:
|
||||
raise AccountNotFound()
|
||||
else:
|
||||
@@ -223,25 +199,30 @@ class EmailCodeLoginSendEmailApi(Resource):
|
||||
@console_ns.route("/email-code-login/validity")
|
||||
class EmailCodeLoginApi(Resource):
|
||||
@setup_required
|
||||
@console_ns.expect(console_ns.models[EmailCodeLoginPayload.__name__])
|
||||
@decrypt_code_field
|
||||
def post(self):
|
||||
args = EmailCodeLoginPayload.model_validate(console_ns.payload)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
user_email = args.email
|
||||
language = args.language
|
||||
user_email = args["email"]
|
||||
language = args["language"]
|
||||
|
||||
token_data = AccountService.get_email_code_login_data(args.token)
|
||||
token_data = AccountService.get_email_code_login_data(args["token"])
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if token_data["email"] != args.email:
|
||||
if token_data["email"] != args["email"]:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if token_data["code"] != args.code:
|
||||
if token_data["code"] != args["code"]:
|
||||
raise EmailCodeError()
|
||||
|
||||
AccountService.revoke_email_code_login_token(args.token)
|
||||
AccountService.revoke_email_code_login_token(args["token"])
|
||||
try:
|
||||
account = AccountService.get_user_through_email(user_email)
|
||||
except AccountRegisterError:
|
||||
@@ -274,7 +255,7 @@ class EmailCodeLoginApi(Resource):
|
||||
except WorkspacesLimitExceededError:
|
||||
raise WorkspacesLimitExceeded()
|
||||
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(args.email)
|
||||
AccountService.reset_login_error_rate_limit(args["email"])
|
||||
|
||||
# Create response with cookies instead of returning tokens in body
|
||||
response = make_response({"result": "success"})
|
||||
|
||||
@@ -124,7 +124,7 @@ class OAuthCallback(Resource):
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
|
||||
|
||||
try:
|
||||
account, oauth_new_user = _generate_account(provider, user_info)
|
||||
account = _generate_account(provider, user_info)
|
||||
except AccountNotFoundError:
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.")
|
||||
except (WorkSpaceNotFoundError, WorkSpaceNotAllowedCreateError):
|
||||
@@ -159,10 +159,7 @@ class OAuthCallback(Resource):
|
||||
ip_address=extract_remote_ip(request),
|
||||
)
|
||||
|
||||
base_url = dify_config.CONSOLE_WEB_URL
|
||||
query_char = "&" if "?" in base_url else "?"
|
||||
target_url = f"{base_url}{query_char}oauth_new_user={str(oauth_new_user).lower()}"
|
||||
response = redirect(target_url)
|
||||
response = redirect(f"{dify_config.CONSOLE_WEB_URL}")
|
||||
|
||||
set_access_token_to_cookie(request, response, token_pair.access_token)
|
||||
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
|
||||
@@ -180,10 +177,9 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
|
||||
return account
|
||||
|
||||
|
||||
def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account, bool]:
|
||||
def _generate_account(provider: str, user_info: OAuthUserInfo):
|
||||
# Get account by openid or email.
|
||||
account = _get_account_by_openid_or_email(provider, user_info)
|
||||
oauth_new_user = False
|
||||
|
||||
if account:
|
||||
tenants = TenantService.get_join_tenants(account)
|
||||
@@ -197,7 +193,6 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account,
|
||||
tenant_was_created.send(new_tenant)
|
||||
|
||||
if not account:
|
||||
oauth_new_user = True
|
||||
if not FeatureService.get_system_features().is_allow_register:
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email):
|
||||
raise AccountRegisterError(
|
||||
@@ -225,4 +220,4 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account,
|
||||
# Link account
|
||||
AccountService.link_account_integrate(provider, user_info.id, account)
|
||||
|
||||
return account, oauth_new_user
|
||||
return account
|
||||
|
||||
@@ -3,8 +3,7 @@ from functools import wraps
|
||||
from typing import Concatenate, ParamSpec, TypeVar
|
||||
|
||||
from flask import jsonify, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
@@ -21,34 +20,15 @@ R = TypeVar("R")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class OAuthClientPayload(BaseModel):
|
||||
client_id: str
|
||||
|
||||
|
||||
class OAuthProviderRequest(BaseModel):
|
||||
client_id: str
|
||||
redirect_uri: str
|
||||
|
||||
|
||||
class OAuthTokenRequest(BaseModel):
|
||||
client_id: str
|
||||
grant_type: str
|
||||
code: str | None = None
|
||||
client_secret: str | None = None
|
||||
redirect_uri: str | None = None
|
||||
refresh_token: str | None = None
|
||||
|
||||
|
||||
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
|
||||
json_data = request.get_json()
|
||||
if json_data is None:
|
||||
parser = reqparse.RequestParser().add_argument("client_id", type=str, required=True, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
client_id = parsed_args.get("client_id")
|
||||
if not client_id:
|
||||
raise BadRequest("client_id is required")
|
||||
|
||||
payload = OAuthClientPayload.model_validate(json_data)
|
||||
client_id = payload.client_id
|
||||
|
||||
oauth_provider_app = OAuthServerService.get_oauth_provider_app(client_id)
|
||||
if not oauth_provider_app:
|
||||
raise NotFound("client_id is invalid")
|
||||
@@ -109,8 +89,9 @@ class OAuthServerAppApi(Resource):
|
||||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
payload = OAuthProviderRequest.model_validate(request.get_json())
|
||||
redirect_uri = payload.redirect_uri
|
||||
parser = reqparse.RequestParser().add_argument("redirect_uri", type=str, required=True, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
redirect_uri = parsed_args.get("redirect_uri")
|
||||
|
||||
# check if redirect_uri is valid
|
||||
if redirect_uri not in oauth_provider_app.redirect_uris:
|
||||
@@ -149,25 +130,33 @@ class OAuthServerUserTokenApi(Resource):
|
||||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
payload = OAuthTokenRequest.model_validate(request.get_json())
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("grant_type", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=False, location="json")
|
||||
.add_argument("client_secret", type=str, required=False, location="json")
|
||||
.add_argument("redirect_uri", type=str, required=False, location="json")
|
||||
.add_argument("refresh_token", type=str, required=False, location="json")
|
||||
)
|
||||
parsed_args = parser.parse_args()
|
||||
|
||||
try:
|
||||
grant_type = OAuthGrantType(payload.grant_type)
|
||||
grant_type = OAuthGrantType(parsed_args["grant_type"])
|
||||
except ValueError:
|
||||
raise BadRequest("invalid grant_type")
|
||||
|
||||
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
|
||||
if not payload.code:
|
||||
if not parsed_args["code"]:
|
||||
raise BadRequest("code is required")
|
||||
|
||||
if payload.client_secret != oauth_provider_app.client_secret:
|
||||
if parsed_args["client_secret"] != oauth_provider_app.client_secret:
|
||||
raise BadRequest("client_secret is invalid")
|
||||
|
||||
if payload.redirect_uri not in oauth_provider_app.redirect_uris:
|
||||
if parsed_args["redirect_uri"] not in oauth_provider_app.redirect_uris:
|
||||
raise BadRequest("redirect_uri is invalid")
|
||||
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type, code=payload.code, client_id=oauth_provider_app.client_id
|
||||
grant_type, code=parsed_args["code"], client_id=oauth_provider_app.client_id
|
||||
)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
@@ -178,11 +167,11 @@ class OAuthServerUserTokenApi(Resource):
|
||||
}
|
||||
)
|
||||
elif grant_type == OAuthGrantType.REFRESH_TOKEN:
|
||||
if not payload.refresh_token:
|
||||
if not parsed_args["refresh_token"]:
|
||||
raise BadRequest("refresh_token is required")
|
||||
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
|
||||
grant_type, refresh_token=parsed_args["refresh_token"], client_id=oauth_provider_app.client_id
|
||||
)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
import base64
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.console import console_ns
|
||||
@@ -12,21 +9,6 @@ from enums.cloud_plan import CloudPlan
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.billing_service import BillingService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class SubscriptionQuery(BaseModel):
|
||||
plan: Literal[CloudPlan.PROFESSIONAL, CloudPlan.TEAM] = Field(..., description="Subscription plan")
|
||||
interval: Literal["month", "year"] = Field(..., description="Billing interval")
|
||||
|
||||
|
||||
class PartnerTenantsPayload(BaseModel):
|
||||
click_id: str = Field(..., description="Click Id from partner referral link")
|
||||
|
||||
|
||||
for model in (SubscriptionQuery, PartnerTenantsPayload):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
@console_ns.route("/billing/subscription")
|
||||
class Subscription(Resource):
|
||||
@@ -36,9 +18,20 @@ class Subscription(Resource):
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
args = SubscriptionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"plan",
|
||||
type=str,
|
||||
required=True,
|
||||
location="args",
|
||||
choices=[CloudPlan.PROFESSIONAL, CloudPlan.TEAM],
|
||||
)
|
||||
.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
|
||||
)
|
||||
args = parser.parse_args()
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
return BillingService.get_subscription(args.plan, args.interval, current_user.email, current_tenant_id)
|
||||
return BillingService.get_subscription(args["plan"], args["interval"], current_user.email, current_tenant_id)
|
||||
|
||||
|
||||
@console_ns.route("/billing/invoices")
|
||||
@@ -72,10 +65,11 @@ class PartnerTenants(Resource):
|
||||
@only_edition_cloud
|
||||
def put(self, partner_key: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("click_id", required=True, type=str, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
args = PartnerTenantsPayload.model_validate(console_ns.payload or {})
|
||||
click_id = args.click_id
|
||||
click_id = args["click_id"]
|
||||
decoded_partner_key = base64.b64decode(partner_key).decode("utf-8")
|
||||
except Exception:
|
||||
raise BadRequest("Invalid partner_key")
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource, reqparse
|
||||
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
@@ -10,28 +9,16 @@ from .. import console_ns
|
||||
from ..wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
|
||||
|
||||
class ComplianceDownloadQuery(BaseModel):
|
||||
doc_name: str = Field(..., description="Compliance document name")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ComplianceDownloadQuery.__name__,
|
||||
ComplianceDownloadQuery.model_json_schema(ref_template="#/definitions/{model}"),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/compliance/download")
|
||||
class ComplianceApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ComplianceDownloadQuery.__name__])
|
||||
@console_ns.doc("download_compliance_document")
|
||||
@console_ns.doc(description="Get compliance document download link")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
args = ComplianceDownloadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
parser = reqparse.RequestParser().add_argument("doc_name", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
device_info = request.headers.get("User-Agent", "Unknown device")
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
from typing import cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
from core.indexing_runner import IndexingRunner
|
||||
@@ -25,19 +25,6 @@ from services.dataset_service import DatasetService, DocumentService
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from tasks.document_indexing_sync_task import document_indexing_sync_task
|
||||
|
||||
from .. import console_ns
|
||||
from ..wraps import account_initialization_required, setup_required
|
||||
|
||||
|
||||
class NotionEstimatePayload(BaseModel):
|
||||
notion_info_list: list[dict[str, Any]]
|
||||
process_rule: dict[str, Any]
|
||||
doc_form: str = Field(default="text_model")
|
||||
doc_language: str = Field(default="English")
|
||||
|
||||
|
||||
register_schema_model(console_ns, NotionEstimatePayload)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/data-source/integrates",
|
||||
@@ -140,18 +127,6 @@ class DataSourceNotionListApi(Resource):
|
||||
credential_id = request.args.get("credential_id", default=None, type=str)
|
||||
if not credential_id:
|
||||
raise ValueError("Credential id is required.")
|
||||
|
||||
# Get datasource_parameters from query string (optional, for GitHub and other datasources)
|
||||
datasource_parameters_str = request.args.get("datasource_parameters", default=None, type=str)
|
||||
datasource_parameters = {}
|
||||
if datasource_parameters_str:
|
||||
try:
|
||||
datasource_parameters = json.loads(datasource_parameters_str)
|
||||
if not isinstance(datasource_parameters, dict):
|
||||
raise ValueError("datasource_parameters must be a JSON object.")
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("Invalid datasource_parameters JSON format.")
|
||||
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=current_tenant_id,
|
||||
@@ -199,7 +174,7 @@ class DataSourceNotionListApi(Resource):
|
||||
online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
|
||||
datasource_runtime.get_online_document_pages(
|
||||
user_id=current_user.id,
|
||||
datasource_parameters=datasource_parameters,
|
||||
datasource_parameters={},
|
||||
provider_type=datasource_runtime.datasource_provider_type(),
|
||||
)
|
||||
)
|
||||
@@ -230,14 +205,14 @@ class DataSourceNotionListApi(Resource):
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/notion/pages/<uuid:page_id>/<string:page_type>/preview",
|
||||
"/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview",
|
||||
"/datasets/notion-indexing-estimate",
|
||||
)
|
||||
class DataSourceNotionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, page_id, page_type):
|
||||
def get(self, workspace_id, page_id, page_type):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
credential_id = request.args.get("credential_id", default=None, type=str)
|
||||
@@ -251,10 +226,11 @@ class DataSourceNotionApi(Resource):
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
)
|
||||
|
||||
workspace_id = str(workspace_id)
|
||||
page_id = str(page_id)
|
||||
|
||||
extractor = NotionExtractor(
|
||||
notion_workspace_id="",
|
||||
notion_workspace_id=workspace_id,
|
||||
notion_obj_id=page_id,
|
||||
notion_page_type=page_type,
|
||||
notion_access_token=credential.get("integration_secret"),
|
||||
@@ -267,15 +243,20 @@ class DataSourceNotionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.expect(console_ns.models[NotionEstimatePayload.__name__])
|
||||
def post(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
payload = NotionEstimatePayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
|
||||
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
||||
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
# validate args
|
||||
DocumentService.estimate_args_validate(args)
|
||||
notion_info_list = payload.notion_info_list
|
||||
notion_info_list = args["notion_info_list"]
|
||||
extract_settings = []
|
||||
for notion_info in notion_info_list:
|
||||
workspace_id = notion_info["workspace_id"]
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
from typing import Any, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.apikey import (
|
||||
api_key_item_model,
|
||||
@@ -50,6 +48,7 @@ from fields.dataset_fields import (
|
||||
)
|
||||
from fields.document_fields import document_status_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.validators import validate_description_length
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from models.provider_ids import ModelProviderID
|
||||
@@ -108,75 +107,10 @@ related_app_list_copy["data"] = fields.List(fields.Nested(app_detail_kernel_mode
|
||||
related_app_list_model = _get_or_create_model("RelatedAppList", related_app_list_copy)
|
||||
|
||||
|
||||
def _validate_indexing_technique(value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
if value not in Dataset.INDEXING_TECHNIQUE_LIST:
|
||||
raise ValueError("Invalid indexing technique.")
|
||||
return value
|
||||
|
||||
|
||||
class DatasetCreatePayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=40)
|
||||
description: str = Field("", max_length=400)
|
||||
indexing_technique: str | None = None
|
||||
permission: DatasetPermissionEnum | None = DatasetPermissionEnum.ONLY_ME
|
||||
provider: str = "vendor"
|
||||
external_knowledge_api_id: str | None = None
|
||||
external_knowledge_id: str | None = None
|
||||
|
||||
@field_validator("indexing_technique")
|
||||
@classmethod
|
||||
def validate_indexing(cls, value: str | None) -> str | None:
|
||||
return _validate_indexing_technique(value)
|
||||
|
||||
@field_validator("provider")
|
||||
@classmethod
|
||||
def validate_provider(cls, value: str) -> str:
|
||||
if value not in Dataset.PROVIDER_LIST:
|
||||
raise ValueError("Invalid provider.")
|
||||
return value
|
||||
|
||||
|
||||
class DatasetUpdatePayload(BaseModel):
|
||||
name: str | None = Field(None, min_length=1, max_length=40)
|
||||
description: str | None = Field(None, max_length=400)
|
||||
permission: DatasetPermissionEnum | None = None
|
||||
indexing_technique: str | None = None
|
||||
embedding_model: str | None = None
|
||||
embedding_model_provider: str | None = None
|
||||
retrieval_model: dict[str, Any] | None = None
|
||||
partial_member_list: list[dict[str, str]] | None = None
|
||||
external_retrieval_model: dict[str, Any] | None = None
|
||||
external_knowledge_id: str | None = None
|
||||
external_knowledge_api_id: str | None = None
|
||||
icon_info: dict[str, Any] | None = None
|
||||
is_multimodal: bool | None = False
|
||||
|
||||
@field_validator("indexing_technique")
|
||||
@classmethod
|
||||
def validate_indexing(cls, value: str | None) -> str | None:
|
||||
return _validate_indexing_technique(value)
|
||||
|
||||
|
||||
class IndexingEstimatePayload(BaseModel):
|
||||
info_list: dict[str, Any]
|
||||
process_rule: dict[str, Any]
|
||||
indexing_technique: str
|
||||
doc_form: str = "text_model"
|
||||
dataset_id: str | None = None
|
||||
doc_language: str = "English"
|
||||
|
||||
@field_validator("indexing_technique")
|
||||
@classmethod
|
||||
def validate_indexing(cls, value: str) -> str:
|
||||
result = _validate_indexing_technique(value)
|
||||
if result is None:
|
||||
raise ValueError("indexing_technique is required.")
|
||||
return result
|
||||
|
||||
|
||||
register_schema_models(console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload)
|
||||
def _validate_name(name: str) -> str:
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
return name
|
||||
|
||||
|
||||
def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
|
||||
@@ -223,7 +157,6 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
|
||||
VectorType.COUCHBASE,
|
||||
VectorType.OPENGAUSS,
|
||||
VectorType.OCEANBASE,
|
||||
VectorType.SEEKDB,
|
||||
VectorType.TABLESTORE,
|
||||
VectorType.HUAWEI_CLOUD,
|
||||
VectorType.TENCENT,
|
||||
@@ -231,7 +164,6 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
|
||||
VectorType.CLICKZETTA,
|
||||
VectorType.BAIDU,
|
||||
VectorType.ALIBABACLOUD_MYSQL,
|
||||
VectorType.IRIS,
|
||||
}
|
||||
|
||||
semantic_methods = {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||
@@ -323,7 +255,20 @@ class DatasetListApi(Resource):
|
||||
|
||||
@console_ns.doc("create_dataset")
|
||||
@console_ns.doc(description="Create a new dataset")
|
||||
@console_ns.expect(console_ns.models[DatasetCreatePayload.__name__])
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"CreateDatasetRequest",
|
||||
{
|
||||
"name": fields.String(required=True, description="Dataset name (1-40 characters)"),
|
||||
"description": fields.String(description="Dataset description (max 400 characters)"),
|
||||
"indexing_technique": fields.String(description="Indexing technique"),
|
||||
"permission": fields.String(description="Dataset permission"),
|
||||
"provider": fields.String(description="Provider"),
|
||||
"external_knowledge_api_id": fields.String(description="External knowledge API ID"),
|
||||
"external_knowledge_id": fields.String(description="External knowledge ID"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(201, "Dataset created successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@@ -331,7 +276,52 @@ class DatasetListApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
payload = DatasetCreatePayload.model_validate(console_ns.payload or {})
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="type is required. Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument(
|
||||
"description",
|
||||
type=validate_description_length,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
.add_argument(
|
||||
"indexing_technique",
|
||||
type=str,
|
||||
location="json",
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
nullable=True,
|
||||
help="Invalid indexing technique.",
|
||||
)
|
||||
.add_argument(
|
||||
"external_knowledge_api_id",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
)
|
||||
.add_argument(
|
||||
"provider",
|
||||
type=str,
|
||||
nullable=True,
|
||||
choices=Dataset.PROVIDER_LIST,
|
||||
required=False,
|
||||
default="vendor",
|
||||
)
|
||||
.add_argument(
|
||||
"external_knowledge_id",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
@@ -341,14 +331,14 @@ class DatasetListApi(Resource):
|
||||
try:
|
||||
dataset = DatasetService.create_empty_dataset(
|
||||
tenant_id=current_tenant_id,
|
||||
name=payload.name,
|
||||
description=payload.description,
|
||||
indexing_technique=payload.indexing_technique,
|
||||
name=args["name"],
|
||||
description=args["description"],
|
||||
indexing_technique=args["indexing_technique"],
|
||||
account=current_user,
|
||||
permission=payload.permission or DatasetPermissionEnum.ONLY_ME,
|
||||
provider=payload.provider,
|
||||
external_knowledge_api_id=payload.external_knowledge_api_id,
|
||||
external_knowledge_id=payload.external_knowledge_id,
|
||||
permission=DatasetPermissionEnum.ONLY_ME,
|
||||
provider=args["provider"],
|
||||
external_knowledge_api_id=args["external_knowledge_api_id"],
|
||||
external_knowledge_id=args["external_knowledge_id"],
|
||||
)
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
@@ -409,7 +399,18 @@ class DatasetApi(Resource):
|
||||
|
||||
@console_ns.doc("update_dataset")
|
||||
@console_ns.doc(description="Update dataset details")
|
||||
@console_ns.expect(console_ns.models[DatasetUpdatePayload.__name__])
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"UpdateDatasetRequest",
|
||||
{
|
||||
"name": fields.String(description="Dataset name"),
|
||||
"description": fields.String(description="Dataset description"),
|
||||
"permission": fields.String(description="Dataset permission"),
|
||||
"indexing_technique": fields.String(description="Indexing technique"),
|
||||
"external_retrieval_model": fields.Raw(description="External retrieval model settings"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "Dataset updated successfully", dataset_detail_model)
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@@ -423,25 +424,93 @@ class DatasetApi(Resource):
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
payload = DatasetUpdatePayload.model_validate(console_ns.payload or {})
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
help="type is required. Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument("description", location="json", store_missing=False, type=validate_description_length)
|
||||
.add_argument(
|
||||
"indexing_technique",
|
||||
type=str,
|
||||
location="json",
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
nullable=True,
|
||||
help="Invalid indexing technique.",
|
||||
)
|
||||
.add_argument(
|
||||
"permission",
|
||||
type=str,
|
||||
location="json",
|
||||
choices=(
|
||||
DatasetPermissionEnum.ONLY_ME,
|
||||
DatasetPermissionEnum.ALL_TEAM,
|
||||
DatasetPermissionEnum.PARTIAL_TEAM,
|
||||
),
|
||||
help="Invalid permission.",
|
||||
)
|
||||
.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
|
||||
.add_argument(
|
||||
"embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
|
||||
)
|
||||
.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
|
||||
.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
|
||||
.add_argument(
|
||||
"external_retrieval_model",
|
||||
type=dict,
|
||||
required=False,
|
||||
nullable=True,
|
||||
location="json",
|
||||
help="Invalid external retrieval model.",
|
||||
)
|
||||
.add_argument(
|
||||
"external_knowledge_id",
|
||||
type=str,
|
||||
required=False,
|
||||
nullable=True,
|
||||
location="json",
|
||||
help="Invalid external knowledge id.",
|
||||
)
|
||||
.add_argument(
|
||||
"external_knowledge_api_id",
|
||||
type=str,
|
||||
required=False,
|
||||
nullable=True,
|
||||
location="json",
|
||||
help="Invalid external knowledge api id.",
|
||||
)
|
||||
.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
required=False,
|
||||
nullable=True,
|
||||
location="json",
|
||||
help="Invalid icon info.",
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
data = request.get_json()
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check embedding model setting
|
||||
if (
|
||||
payload.indexing_technique == "high_quality"
|
||||
and payload.embedding_model_provider is not None
|
||||
and payload.embedding_model is not None
|
||||
data.get("indexing_technique") == "high_quality"
|
||||
and data.get("embedding_model_provider") is not None
|
||||
and data.get("embedding_model") is not None
|
||||
):
|
||||
is_multimodal = DatasetService.check_is_multimodal_model(
|
||||
dataset.tenant_id, payload.embedding_model_provider, payload.embedding_model
|
||||
DatasetService.check_embedding_model_setting(
|
||||
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
|
||||
)
|
||||
payload.is_multimodal = is_multimodal
|
||||
payload_data = payload.model_dump(exclude_unset=True)
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
DatasetPermissionService.check_permission(
|
||||
current_user, dataset, payload.permission, payload.partial_member_list
|
||||
current_user, dataset, data.get("permission"), data.get("partial_member_list")
|
||||
)
|
||||
|
||||
dataset = DatasetService.update_dataset(dataset_id_str, payload_data, current_user)
|
||||
dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
|
||||
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
@@ -449,10 +518,15 @@ class DatasetApi(Resource):
|
||||
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
if payload.partial_member_list is not None and payload.permission == DatasetPermissionEnum.PARTIAL_TEAM:
|
||||
DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id_str, payload.partial_member_list)
|
||||
if data.get("partial_member_list") and data.get("permission") == "partial_members":
|
||||
DatasetPermissionService.update_partial_member_list(
|
||||
tenant_id, dataset_id_str, data.get("partial_member_list")
|
||||
)
|
||||
# clear partial member list when permission is only_me or all_team_members
|
||||
elif payload.permission in {DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM}:
|
||||
elif (
|
||||
data.get("permission") == DatasetPermissionEnum.ONLY_ME
|
||||
or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
|
||||
):
|
||||
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
|
||||
|
||||
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
@@ -541,10 +615,24 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.expect(console_ns.models[IndexingEstimatePayload.__name__])
|
||||
def post(self):
|
||||
payload = IndexingEstimatePayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("info_list", type=dict, required=True, nullable=True, location="json")
|
||||
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
||||
.add_argument(
|
||||
"indexing_technique",
|
||||
type=str,
|
||||
required=True,
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
nullable=True,
|
||||
location="json",
|
||||
)
|
||||
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||
.add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
|
||||
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
# validate args
|
||||
DocumentService.estimate_args_validate(args)
|
||||
|
||||
@@ -6,14 +6,31 @@ from typing import Literal, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel
|
||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||
from sqlalchemy import asc, desc, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.datasets.error import (
|
||||
ArchivedDocumentImmutableError,
|
||||
DocumentAlreadyFinishedError,
|
||||
DocumentIndexingError,
|
||||
IndexingEstimateError,
|
||||
InvalidActionError,
|
||||
InvalidMetadataError,
|
||||
)
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
)
|
||||
from core.errors.error import (
|
||||
LLMBadRequestError,
|
||||
ModelCurrentlyNotSupportError,
|
||||
@@ -38,30 +55,10 @@ from fields.document_fields import (
|
||||
)
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DocumentPipelineExecutionLog
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
|
||||
|
||||
from ..app.error import (
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from ..datasets.error import (
|
||||
ArchivedDocumentImmutableError,
|
||||
DocumentAlreadyFinishedError,
|
||||
DocumentIndexingError,
|
||||
IndexingEstimateError,
|
||||
InvalidActionError,
|
||||
InvalidMetadataError,
|
||||
)
|
||||
from ..wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
)
|
||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -96,24 +93,6 @@ dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(docume
|
||||
dataset_and_document_model = _get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
|
||||
|
||||
|
||||
class DocumentRetryPayload(BaseModel):
|
||||
document_ids: list[str]
|
||||
|
||||
|
||||
class DocumentRenamePayload(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
KnowledgeConfig,
|
||||
ProcessRule,
|
||||
RetrievalModel,
|
||||
DocumentRetryPayload,
|
||||
DocumentRenamePayload,
|
||||
)
|
||||
|
||||
|
||||
class DocumentResource(Resource):
|
||||
def get_document(self, dataset_id: str, document_id: str) -> Document:
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@@ -222,9 +201,8 @@ class DatasetDocumentListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id):
|
||||
def get(self, dataset_id: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
search = request.args.get("keyword", default=None, type=str)
|
||||
@@ -332,7 +310,6 @@ class DatasetDocumentListApi(Resource):
|
||||
@marshal_with(dataset_and_document_model)
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
|
||||
def post(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
@@ -351,7 +328,23 @@ class DatasetDocumentListApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {})
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
|
||||
)
|
||||
.add_argument("data_source", type=dict, required=False, location="json")
|
||||
.add_argument("process_rule", type=dict, required=False, location="json")
|
||||
.add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
|
||||
.add_argument("original_document_id", type=str, required=False, location="json")
|
||||
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
||||
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
|
||||
if not dataset.indexing_technique and not knowledge_config.indexing_technique:
|
||||
raise ValueError("indexing_technique is required.")
|
||||
@@ -397,7 +390,17 @@ class DatasetDocumentListApi(Resource):
|
||||
class DatasetInitApi(Resource):
|
||||
@console_ns.doc("init_dataset")
|
||||
@console_ns.doc(description="Initialize dataset with documents")
|
||||
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"DatasetInitRequest",
|
||||
{
|
||||
"upload_file_id": fields.String(required=True, description="Upload file ID"),
|
||||
"indexing_technique": fields.String(description="Indexing technique"),
|
||||
"process_rule": fields.Raw(description="Processing rules"),
|
||||
"data_source": fields.Raw(description="Data source configuration"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(201, "Dataset initialized successfully", dataset_and_document_model)
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@@ -412,7 +415,27 @@ class DatasetInitApi(Resource):
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {})
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"indexing_technique",
|
||||
type=str,
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
required=True,
|
||||
nullable=False,
|
||||
location="json",
|
||||
)
|
||||
.add_argument("data_source", type=dict, required=True, nullable=True, location="json")
|
||||
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
||||
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
|
||||
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
||||
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
if knowledge_config.indexing_technique == "high_quality":
|
||||
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
|
||||
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
|
||||
@@ -420,14 +443,10 @@ class DatasetInitApi(Resource):
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_tenant_id,
|
||||
provider=knowledge_config.embedding_model_provider,
|
||||
provider=args["embedding_model_provider"],
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=knowledge_config.embedding_model,
|
||||
model=args["embedding_model"],
|
||||
)
|
||||
is_multimodal = DatasetService.check_is_multimodal_model(
|
||||
current_tenant_id, knowledge_config.embedding_model_provider, knowledge_config.embedding_model
|
||||
)
|
||||
knowledge_config.is_multimodal = is_multimodal
|
||||
except InvokeAuthorizationError:
|
||||
raise ProviderNotInitializeError(
|
||||
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
|
||||
@@ -572,7 +591,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
datasource_type=DatasourceType.NOTION,
|
||||
notion_info=NotionInfo.model_validate(
|
||||
{
|
||||
"credential_id": data_source_info.get("credential_id"),
|
||||
"credential_id": data_source_info["credential_id"],
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
"notion_page_type": data_source_info["type"],
|
||||
@@ -1057,16 +1076,19 @@ class DocumentRetryApi(DocumentResource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[DocumentRetryPayload.__name__])
|
||||
def post(self, dataset_id):
|
||||
"""retry document."""
|
||||
payload = DocumentRetryPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"document_ids", type=list, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
retry_documents = []
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
for document_id in payload.document_ids:
|
||||
for document_id in args["document_ids"]:
|
||||
try:
|
||||
document_id = str(document_id)
|
||||
|
||||
@@ -1099,7 +1121,6 @@ class DocumentRenameApi(DocumentResource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(document_fields)
|
||||
@console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
|
||||
def post(self, dataset_id, document_id):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@@ -1109,10 +1130,11 @@ class DocumentRenameApi(DocumentResource):
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_operator_permission(current_user, dataset)
|
||||
payload = DocumentRenamePayload.model_validate(console_ns.payload or {})
|
||||
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
document = DocumentService.rename_document(dataset_id, document_id, payload.name)
|
||||
document = DocumentService.rename_document(dataset_id, document_id, args["name"])
|
||||
except services.errors.document.DocumentIndexingError:
|
||||
raise DocumentIndexingError("Cannot delete document during indexing.")
|
||||
|
||||
|
||||
@@ -1,15 +1,11 @@
|
||||
import uuid
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import String, cast, func, or_, select
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from flask_restx import Resource, marshal, reqparse
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
from controllers.console.datasets.error import (
|
||||
@@ -40,58 +36,6 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
|
||||
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
|
||||
|
||||
|
||||
class SegmentListQuery(BaseModel):
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
status: list[str] = Field(default_factory=list)
|
||||
hit_count_gte: int | None = None
|
||||
enabled: str = Field(default="all")
|
||||
keyword: str | None = None
|
||||
page: int = Field(default=1, ge=1)
|
||||
|
||||
|
||||
class SegmentCreatePayload(BaseModel):
|
||||
content: str
|
||||
answer: str | None = None
|
||||
keywords: list[str] | None = None
|
||||
attachment_ids: list[str] | None = None
|
||||
|
||||
|
||||
class SegmentUpdatePayload(BaseModel):
|
||||
content: str
|
||||
answer: str | None = None
|
||||
keywords: list[str] | None = None
|
||||
regenerate_child_chunks: bool = False
|
||||
attachment_ids: list[str] | None = None
|
||||
|
||||
|
||||
class BatchImportPayload(BaseModel):
|
||||
upload_file_id: str
|
||||
|
||||
|
||||
class ChildChunkCreatePayload(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class ChildChunkUpdatePayload(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class ChildChunkBatchUpdatePayload(BaseModel):
|
||||
chunks: list[ChildChunkUpdateArgs]
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
SegmentListQuery,
|
||||
SegmentCreatePayload,
|
||||
SegmentUpdatePayload,
|
||||
BatchImportPayload,
|
||||
ChildChunkCreatePayload,
|
||||
ChildChunkUpdatePayload,
|
||||
ChildChunkBatchUpdatePayload,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
|
||||
class DatasetDocumentSegmentListApi(Resource):
|
||||
@setup_required
|
||||
@@ -116,18 +60,23 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
args = SegmentListQuery.model_validate(
|
||||
{
|
||||
**request.args.to_dict(),
|
||||
"status": request.args.getlist("status"),
|
||||
}
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("limit", type=int, default=20, location="args")
|
||||
.add_argument("status", type=str, action="append", default=[], location="args")
|
||||
.add_argument("hit_count_gte", type=int, default=None, location="args")
|
||||
.add_argument("enabled", type=str, default="all", location="args")
|
||||
.add_argument("keyword", type=str, default=None, location="args")
|
||||
.add_argument("page", type=int, default=1, location="args")
|
||||
)
|
||||
|
||||
page = args.page
|
||||
limit = min(args.limit, 100)
|
||||
status_list = args.status
|
||||
hit_count_gte = args.hit_count_gte
|
||||
keyword = args.keyword
|
||||
args = parser.parse_args()
|
||||
|
||||
page = args["page"]
|
||||
limit = min(args["limit"], 100)
|
||||
status_list = args["status"]
|
||||
hit_count_gte = args["hit_count_gte"]
|
||||
keyword = args["keyword"]
|
||||
|
||||
query = (
|
||||
select(DocumentSegment)
|
||||
@@ -145,34 +94,12 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
query = query.where(DocumentSegment.hit_count >= hit_count_gte)
|
||||
|
||||
if keyword:
|
||||
# Search in both content and keywords fields
|
||||
# Use database-specific methods for JSON array search
|
||||
if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
|
||||
# PostgreSQL: Use jsonb_array_elements_text to properly handle Unicode/Chinese text
|
||||
keywords_condition = func.array_to_string(
|
||||
func.array(
|
||||
select(func.jsonb_array_elements_text(cast(DocumentSegment.keywords, JSONB)))
|
||||
.correlate(DocumentSegment)
|
||||
.scalar_subquery()
|
||||
),
|
||||
",",
|
||||
).ilike(f"%{keyword}%")
|
||||
else:
|
||||
# MySQL: Cast JSON to string for pattern matching
|
||||
# MySQL stores Chinese text directly in JSON without Unicode escaping
|
||||
keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{keyword}%")
|
||||
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
|
||||
|
||||
query = query.where(
|
||||
or_(
|
||||
DocumentSegment.content.ilike(f"%{keyword}%"),
|
||||
keywords_condition,
|
||||
)
|
||||
)
|
||||
|
||||
if args.enabled.lower() != "all":
|
||||
if args.enabled.lower() == "true":
|
||||
if args["enabled"].lower() != "all":
|
||||
if args["enabled"].lower() == "true":
|
||||
query = query.where(DocumentSegment.enabled == True)
|
||||
elif args.enabled.lower() == "false":
|
||||
elif args["enabled"].lower() == "false":
|
||||
query = query.where(DocumentSegment.enabled == False)
|
||||
|
||||
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||
@@ -283,7 +210,6 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[SegmentCreatePayload.__name__])
|
||||
def post(self, dataset_id, document_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@@ -320,10 +246,15 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
payload = SegmentCreatePayload.model_validate(console_ns.payload or {})
|
||||
payload_dict = payload.model_dump(exclude_none=True)
|
||||
SegmentService.segment_create_args_validate(payload_dict, document)
|
||||
segment = SegmentService.create_segment(payload_dict, document, dataset)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("answer", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("keywords", type=list, required=False, nullable=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
SegmentService.segment_create_args_validate(args, document)
|
||||
segment = SegmentService.create_segment(args, document, dataset)
|
||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
|
||||
|
||||
@@ -334,7 +265,6 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[SegmentUpdatePayload.__name__])
|
||||
def patch(self, dataset_id, document_id, segment_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@@ -383,12 +313,18 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
payload = SegmentUpdatePayload.model_validate(console_ns.payload or {})
|
||||
payload_dict = payload.model_dump(exclude_none=True)
|
||||
SegmentService.segment_create_args_validate(payload_dict, document)
|
||||
segment = SegmentService.update_segment(
|
||||
SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("answer", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("keywords", type=list, required=False, nullable=True, location="json")
|
||||
.add_argument(
|
||||
"regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
SegmentService.segment_create_args_validate(args, document)
|
||||
segment = SegmentService.update_segment(SegmentUpdateArgs.model_validate(args), segment, document, dataset)
|
||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
|
||||
@setup_required
|
||||
@@ -441,7 +377,6 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[BatchImportPayload.__name__])
|
||||
def post(self, dataset_id, document_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@@ -456,8 +391,11 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
payload = BatchImportPayload.model_validate(console_ns.payload or {})
|
||||
upload_file_id = payload.upload_file_id
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"upload_file_id", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
upload_file_id = args["upload_file_id"]
|
||||
|
||||
upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
|
||||
if not upload_file:
|
||||
@@ -508,7 +446,6 @@ class ChildChunkAddApi(Resource):
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[ChildChunkCreatePayload.__name__])
|
||||
def post(self, dataset_id, document_id, segment_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@@ -554,9 +491,13 @@ class ChildChunkAddApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"content", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
payload = ChildChunkCreatePayload.model_validate(console_ns.payload or {})
|
||||
child_chunk = SegmentService.create_child_chunk(payload.content, segment, document, dataset)
|
||||
content = args["content"]
|
||||
child_chunk = SegmentService.create_child_chunk(content, segment, document, dataset)
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
|
||||
@@ -588,17 +529,18 @@ class ChildChunkAddApi(Resource):
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
args = SegmentListQuery.model_validate(
|
||||
{
|
||||
"limit": request.args.get("limit", default=20, type=int),
|
||||
"keyword": request.args.get("keyword"),
|
||||
"page": request.args.get("page", default=1, type=int),
|
||||
}
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("limit", type=int, default=20, location="args")
|
||||
.add_argument("keyword", type=str, default=None, location="args")
|
||||
.add_argument("page", type=int, default=1, location="args")
|
||||
)
|
||||
|
||||
page = args.page
|
||||
limit = min(args.limit, 100)
|
||||
keyword = args.keyword
|
||||
args = parser.parse_args()
|
||||
|
||||
page = args["page"]
|
||||
limit = min(args["limit"], 100)
|
||||
keyword = args["keyword"]
|
||||
|
||||
child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
|
||||
return {
|
||||
@@ -646,9 +588,14 @@ class ChildChunkAddApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
payload = ChildChunkBatchUpdatePayload.model_validate(console_ns.payload or {})
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"chunks", type=list, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
child_chunks = SegmentService.update_child_chunks(payload.chunks, segment, document, dataset)
|
||||
chunks_data = args["chunks"]
|
||||
chunks = [ChildChunkUpdateArgs.model_validate(chunk) for chunk in chunks_data]
|
||||
child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
return {"data": marshal(child_chunks, child_chunk_fields)}, 200
|
||||
@@ -718,7 +665,6 @@ class ChildChunkUpdateApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[ChildChunkUpdatePayload.__name__])
|
||||
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@@ -765,9 +711,13 @@ class ChildChunkUpdateApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"content", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
payload = ChildChunkUpdatePayload.model_validate(console_ns.payload or {})
|
||||
child_chunk = SegmentService.update_child_chunk(payload.content, child_chunk, segment, document, dataset)
|
||||
content = args["content"]
|
||||
child_chunk = SegmentService.update_child_chunk(content, child_chunk, segment, document, dataset)
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource, fields, marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
@@ -73,38 +71,10 @@ except KeyError:
|
||||
dataset_detail_model = _build_dataset_detail_model()
|
||||
|
||||
|
||||
class ExternalKnowledgeApiPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=40)
|
||||
settings: dict[str, object]
|
||||
|
||||
|
||||
class ExternalDatasetCreatePayload(BaseModel):
|
||||
external_knowledge_api_id: str
|
||||
external_knowledge_id: str
|
||||
name: str = Field(..., min_length=1, max_length=40)
|
||||
description: str | None = Field(None, max_length=400)
|
||||
external_retrieval_model: dict[str, object] | None = None
|
||||
|
||||
|
||||
class ExternalHitTestingPayload(BaseModel):
|
||||
query: str
|
||||
external_retrieval_model: dict[str, object] | None = None
|
||||
metadata_filtering_conditions: dict[str, object] | None = None
|
||||
|
||||
|
||||
class BedrockRetrievalPayload(BaseModel):
|
||||
retrieval_setting: dict[str, object]
|
||||
query: str
|
||||
knowledge_id: str
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
ExternalKnowledgeApiPayload,
|
||||
ExternalDatasetCreatePayload,
|
||||
ExternalHitTestingPayload,
|
||||
BedrockRetrievalPayload,
|
||||
)
|
||||
def _validate_name(name: str) -> str:
|
||||
if not name or len(name) < 1 or len(name) > 100:
|
||||
raise ValueError("Name must be between 1 to 100 characters.")
|
||||
return name
|
||||
|
||||
|
||||
@console_ns.route("/datasets/external-knowledge-api")
|
||||
@@ -143,12 +113,28 @@ class ExternalApiTemplateListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
|
||||
def post(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name is required. Name must be between 1 to 100 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument(
|
||||
"settings",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=False,
|
||||
required=True,
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
ExternalDatasetService.validate_api_list(payload.settings)
|
||||
ExternalDatasetService.validate_api_list(args["settings"])
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
@@ -156,7 +142,7 @@ class ExternalApiTemplateListApi(Resource):
|
||||
|
||||
try:
|
||||
external_knowledge_api = ExternalDatasetService.create_external_knowledge_api(
|
||||
tenant_id=current_tenant_id, user_id=current_user.id, args=payload.model_dump()
|
||||
tenant_id=current_tenant_id, user_id=current_user.id, args=args
|
||||
)
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
@@ -185,19 +171,35 @@ class ExternalApiTemplateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
|
||||
def patch(self, external_knowledge_api_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
external_knowledge_api_id = str(external_knowledge_api_id)
|
||||
|
||||
payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
|
||||
ExternalDatasetService.validate_api_list(payload.settings)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="type is required. Name must be between 1 to 100 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument(
|
||||
"settings",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=False,
|
||||
required=True,
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
ExternalDatasetService.validate_api_list(args["settings"])
|
||||
|
||||
external_knowledge_api = ExternalDatasetService.update_external_knowledge_api(
|
||||
tenant_id=current_tenant_id,
|
||||
user_id=current_user.id,
|
||||
external_knowledge_api_id=external_knowledge_api_id,
|
||||
args=payload.model_dump(),
|
||||
args=args,
|
||||
)
|
||||
|
||||
return external_knowledge_api.to_dict(), 200
|
||||
@@ -238,7 +240,17 @@ class ExternalApiUseCheckApi(Resource):
|
||||
class ExternalDatasetCreateApi(Resource):
|
||||
@console_ns.doc("create_external_dataset")
|
||||
@console_ns.doc(description="Create external knowledge dataset")
|
||||
@console_ns.expect(console_ns.models[ExternalDatasetCreatePayload.__name__])
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"CreateExternalDatasetRequest",
|
||||
{
|
||||
"external_knowledge_api_id": fields.String(required=True, description="External knowledge API ID"),
|
||||
"external_knowledge_id": fields.String(required=True, description="External knowledge ID"),
|
||||
"name": fields.String(required=True, description="Dataset name"),
|
||||
"description": fields.String(description="Dataset description"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(201, "External dataset created successfully", dataset_detail_model)
|
||||
@console_ns.response(400, "Invalid parameters")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@@ -249,8 +261,22 @@ class ExternalDatasetCreateApi(Resource):
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
payload = ExternalDatasetCreatePayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="name is required. Name must be between 1 to 100 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument("description", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
@@ -273,7 +299,16 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
||||
@console_ns.doc("test_external_knowledge_retrieval")
|
||||
@console_ns.doc(description="Test external knowledge retrieval for dataset")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.expect(console_ns.models[ExternalHitTestingPayload.__name__])
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"ExternalHitTestingRequest",
|
||||
{
|
||||
"query": fields.String(required=True, description="Query text for testing"),
|
||||
"retrieval_model": fields.Raw(description="Retrieval model configuration"),
|
||||
"external_retrieval_model": fields.Raw(description="External retrieval model configuration"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "External hit testing completed successfully")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(400, "Invalid parameters")
|
||||
@@ -292,16 +327,23 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
payload = ExternalHitTestingPayload.model_validate(console_ns.payload or {})
|
||||
HitTestingService.hit_testing_args_check(payload.model_dump())
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("query", type=str, location="json")
|
||||
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
|
||||
try:
|
||||
response = HitTestingService.external_retrieve(
|
||||
dataset=dataset,
|
||||
query=payload.query,
|
||||
query=args["query"],
|
||||
account=current_user,
|
||||
external_retrieval_model=payload.external_retrieval_model,
|
||||
metadata_filtering_conditions=payload.metadata_filtering_conditions,
|
||||
external_retrieval_model=args["external_retrieval_model"],
|
||||
metadata_filtering_conditions=args["metadata_filtering_conditions"],
|
||||
)
|
||||
|
||||
return response
|
||||
@@ -314,13 +356,33 @@ class BedrockRetrievalApi(Resource):
|
||||
# this api is only for internal testing
|
||||
@console_ns.doc("bedrock_retrieval_test")
|
||||
@console_ns.doc(description="Bedrock retrieval test (internal use only)")
|
||||
@console_ns.expect(console_ns.models[BedrockRetrievalPayload.__name__])
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"BedrockRetrievalTestRequest",
|
||||
{
|
||||
"retrieval_setting": fields.Raw(required=True, description="Retrieval settings"),
|
||||
"query": fields.String(required=True, description="Query text"),
|
||||
"knowledge_id": fields.String(required=True, description="Knowledge ID"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "Bedrock retrieval test completed")
|
||||
def post(self):
|
||||
payload = BedrockRetrievalPayload.model_validate(console_ns.payload or {})
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
|
||||
.add_argument(
|
||||
"query",
|
||||
nullable=False,
|
||||
required=True,
|
||||
type=str,
|
||||
)
|
||||
.add_argument("knowledge_id", nullable=False, required=True, type=str)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Call the knowledge retrieval service
|
||||
result = ExternalDatasetTestService.knowledge_retrieval(
|
||||
payload.retrieval_setting, payload.query, payload.knowledge_id
|
||||
args["retrieval_setting"], args["query"], args["knowledge_id"]
|
||||
)
|
||||
return result, 200
|
||||
|
||||
@@ -1,17 +1,13 @@
|
||||
from flask_restx import Resource
|
||||
from flask_restx import Resource, fields
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
from libs.login import login_required
|
||||
|
||||
from .. import console_ns
|
||||
from ..datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload
|
||||
from ..wraps import (
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
setup_required,
|
||||
)
|
||||
|
||||
register_schema_model(console_ns, HitTestingPayload)
|
||||
from libs.login import login_required
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
|
||||
@@ -19,7 +15,17 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
|
||||
@console_ns.doc("test_dataset_retrieval")
|
||||
@console_ns.doc(description="Test dataset knowledge retrieval")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.expect(console_ns.models[HitTestingPayload.__name__])
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"HitTestingRequest",
|
||||
{
|
||||
"query": fields.String(required=True, description="Query text for testing"),
|
||||
"retrieval_model": fields.Raw(description="Retrieval model configuration"),
|
||||
"top_k": fields.Integer(description="Number of top results to return"),
|
||||
"score_threshold": fields.Float(description="Score threshold for filtering results"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "Hit testing completed successfully")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(400, "Invalid parameters")
|
||||
@@ -31,8 +37,7 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
dataset = self.get_and_validate_dataset(dataset_id_str)
|
||||
payload = HitTestingPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
args = self.parse_args()
|
||||
self.hit_testing_args_check(args)
|
||||
|
||||
return self.perform_hit_testing(dataset, args)
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from flask_restx import marshal, reqparse
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
@@ -29,13 +27,6 @@ from services.hit_testing_service import HitTestingService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HitTestingPayload(BaseModel):
|
||||
query: str = Field(max_length=250)
|
||||
retrieval_model: dict[str, Any] | None = None
|
||||
external_retrieval_model: dict[str, Any] | None = None
|
||||
attachment_ids: list[str] | None = None
|
||||
|
||||
|
||||
class DatasetsHitTestingBase:
|
||||
@staticmethod
|
||||
def get_and_validate_dataset(dataset_id: str):
|
||||
@@ -52,15 +43,14 @@ class DatasetsHitTestingBase:
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def hit_testing_args_check(args: dict[str, Any]):
|
||||
def hit_testing_args_check(args):
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
|
||||
@staticmethod
|
||||
def parse_args():
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("query", type=str, required=False, location="json")
|
||||
.add_argument("attachment_ids", type=list, required=False, location="json")
|
||||
.add_argument("query", type=str, location="json")
|
||||
.add_argument("retrieval_model", type=dict, required=False, location="json")
|
||||
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
)
|
||||
@@ -72,11 +62,10 @@ class DatasetsHitTestingBase:
|
||||
try:
|
||||
response = HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=args.get("query"),
|
||||
query=args["query"],
|
||||
account=current_user,
|
||||
retrieval_model=args.get("retrieval_model"),
|
||||
external_retrieval_model=args.get("external_retrieval_model"),
|
||||
attachment_ids=args.get("attachment_ids"),
|
||||
retrieval_model=args["retrieval_model"],
|
||||
external_retrieval_model=args["external_retrieval_model"],
|
||||
limit=10,
|
||||
)
|
||||
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_model, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
|
||||
from fields.dataset_fields import dataset_metadata_fields
|
||||
@@ -17,14 +15,6 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
||||
from services.metadata_service import MetadataService
|
||||
|
||||
|
||||
class MetadataUpdatePayload(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
register_schema_models(console_ns, MetadataArgs, MetadataOperationData)
|
||||
register_schema_model(console_ns, MetadataUpdatePayload)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/metadata")
|
||||
class DatasetMetadataCreateApi(Resource):
|
||||
@setup_required
|
||||
@@ -32,10 +22,15 @@ class DatasetMetadataCreateApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@marshal_with(dataset_metadata_fields)
|
||||
@console_ns.expect(console_ns.models[MetadataArgs.__name__])
|
||||
def post(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
metadata_args = MetadataArgs.model_validate(console_ns.payload or {})
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("type", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
metadata_args = MetadataArgs.model_validate(args)
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@@ -65,11 +60,11 @@ class DatasetMetadataApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@marshal_with(dataset_metadata_fields)
|
||||
@console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__])
|
||||
def patch(self, dataset_id, metadata_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = MetadataUpdatePayload.model_validate(console_ns.payload or {})
|
||||
name = payload.name
|
||||
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
name = args["name"]
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
metadata_id_str = str(metadata_id)
|
||||
@@ -136,7 +131,6 @@ class DocumentMetadataEditApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.expect(console_ns.models[MetadataOperationData.__name__])
|
||||
def post(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
@@ -145,7 +139,11 @@ class DocumentMetadataEditApi(Resource):
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
metadata_args = MetadataOperationData.model_validate(console_ns.payload or {})
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"operation_data", type=list, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
metadata_args = MetadataOperationData.model_validate(args)
|
||||
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
|
||||
|
||||
@@ -1,63 +1,20 @@
|
||||
from typing import Any
|
||||
|
||||
from flask import make_response, redirect, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from libs.helper import StrLen
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.provider_ids import DatasourceProviderID
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
|
||||
|
||||
class DatasourceCredentialPayload(BaseModel):
|
||||
name: str | None = Field(default=None, max_length=100)
|
||||
credentials: dict[str, Any]
|
||||
|
||||
|
||||
class DatasourceCredentialDeletePayload(BaseModel):
|
||||
credential_id: str
|
||||
|
||||
|
||||
class DatasourceCredentialUpdatePayload(BaseModel):
|
||||
credential_id: str
|
||||
name: str | None = Field(default=None, max_length=100)
|
||||
credentials: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class DatasourceCustomClientPayload(BaseModel):
|
||||
client_params: dict[str, Any] | None = None
|
||||
enable_oauth_custom_client: bool | None = None
|
||||
|
||||
|
||||
class DatasourceDefaultPayload(BaseModel):
|
||||
id: str
|
||||
|
||||
|
||||
class DatasourceUpdateNamePayload(BaseModel):
|
||||
credential_id: str
|
||||
name: str = Field(max_length=100)
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
DatasourceCredentialPayload,
|
||||
DatasourceCredentialDeletePayload,
|
||||
DatasourceCredentialUpdatePayload,
|
||||
DatasourceCustomClientPayload,
|
||||
DatasourceDefaultPayload,
|
||||
DatasourceUpdateNamePayload,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/oauth/plugin/<path:provider_id>/datasource/get-authorization-url")
|
||||
class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
||||
@setup_required
|
||||
@@ -164,9 +121,16 @@ class DatasourceOAuthCallback(Resource):
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
|
||||
parser_datasource = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None)
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>")
|
||||
class DatasourceAuth(Resource):
|
||||
@console_ns.expect(console_ns.models[DatasourceCredentialPayload.__name__])
|
||||
@console_ns.expect(parser_datasource)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -174,7 +138,7 @@ class DatasourceAuth(Resource):
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
payload = DatasourceCredentialPayload.model_validate(console_ns.payload or {})
|
||||
args = parser_datasource.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
|
||||
@@ -182,8 +146,8 @@ class DatasourceAuth(Resource):
|
||||
datasource_provider_service.add_datasource_api_key_provider(
|
||||
tenant_id=current_tenant_id,
|
||||
provider_id=datasource_provider_id,
|
||||
credentials=payload.credentials,
|
||||
name=payload.name,
|
||||
credentials=args["credentials"],
|
||||
name=args["name"],
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
@@ -205,9 +169,14 @@ class DatasourceAuth(Resource):
|
||||
return {"result": datasources}, 200
|
||||
|
||||
|
||||
parser_datasource_delete = reqparse.RequestParser().add_argument(
|
||||
"credential_id", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete")
|
||||
class DatasourceAuthDeleteApi(Resource):
|
||||
@console_ns.expect(console_ns.models[DatasourceCredentialDeletePayload.__name__])
|
||||
@console_ns.expect(parser_datasource_delete)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -219,20 +188,28 @@ class DatasourceAuthDeleteApi(Resource):
|
||||
plugin_id = datasource_provider_id.plugin_id
|
||||
provider_name = datasource_provider_id.provider_name
|
||||
|
||||
payload = DatasourceCredentialDeletePayload.model_validate(console_ns.payload or {})
|
||||
args = parser_datasource_delete.parse_args()
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.remove_datasource_credentials(
|
||||
tenant_id=current_tenant_id,
|
||||
auth_id=payload.credential_id,
|
||||
auth_id=args["credential_id"],
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
parser_datasource_update = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
|
||||
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update")
|
||||
class DatasourceAuthUpdateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[DatasourceCredentialUpdatePayload.__name__])
|
||||
@console_ns.expect(parser_datasource_update)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -241,16 +218,16 @@ class DatasourceAuthUpdateApi(Resource):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
payload = DatasourceCredentialUpdatePayload.model_validate(console_ns.payload or {})
|
||||
args = parser_datasource_update.parse_args()
|
||||
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.update_datasource_credentials(
|
||||
tenant_id=current_tenant_id,
|
||||
auth_id=payload.credential_id,
|
||||
auth_id=args["credential_id"],
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
credentials=payload.credentials or {},
|
||||
name=payload.name,
|
||||
credentials=args.get("credentials", {}),
|
||||
name=args.get("name", None),
|
||||
)
|
||||
return {"result": "success"}, 201
|
||||
|
||||
@@ -281,9 +258,16 @@ class DatasourceHardCodeAuthListApi(Resource):
|
||||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
|
||||
|
||||
parser_datasource_custom = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client")
|
||||
class DatasourceAuthOauthCustomClient(Resource):
|
||||
@console_ns.expect(console_ns.models[DatasourceCustomClientPayload.__name__])
|
||||
@console_ns.expect(parser_datasource_custom)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -291,14 +275,14 @@ class DatasourceAuthOauthCustomClient(Resource):
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
payload = DatasourceCustomClientPayload.model_validate(console_ns.payload or {})
|
||||
args = parser_datasource_custom.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.setup_oauth_custom_client_params(
|
||||
tenant_id=current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
client_params=payload.client_params or {},
|
||||
enabled=payload.enable_oauth_custom_client or False,
|
||||
client_params=args.get("client_params", {}),
|
||||
enabled=args.get("enable_oauth_custom_client", False),
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@@ -317,9 +301,12 @@ class DatasourceAuthOauthCustomClient(Resource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
parser_default = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/default")
|
||||
class DatasourceAuthDefaultApi(Resource):
|
||||
@console_ns.expect(console_ns.models[DatasourceDefaultPayload.__name__])
|
||||
@console_ns.expect(parser_default)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -327,20 +314,27 @@ class DatasourceAuthDefaultApi(Resource):
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
payload = DatasourceDefaultPayload.model_validate(console_ns.payload or {})
|
||||
args = parser_default.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.set_default_datasource_provider(
|
||||
tenant_id=current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
credential_id=payload.id,
|
||||
credential_id=args["id"],
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
parser_update_name = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
|
||||
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name")
|
||||
class DatasourceUpdateProviderNameApi(Resource):
|
||||
@console_ns.expect(console_ns.models[DatasourceUpdateNamePayload.__name__])
|
||||
@console_ns.expect(parser_update_name)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -348,13 +342,13 @@ class DatasourceUpdateProviderNameApi(Resource):
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
payload = DatasourceUpdateNamePayload.model_validate(console_ns.payload or {})
|
||||
args = parser_update_name.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.update_datasource_provider_name(
|
||||
tenant_id=current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
name=payload.name,
|
||||
credential_id=payload.credential_id,
|
||||
name=args["name"],
|
||||
credential_id=args["credential_id"],
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@@ -26,7 +26,7 @@ console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=D
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
|
||||
class DataSourceContentPreviewApi(Resource):
|
||||
@console_ns.expect(console_ns.models[Parser.__name__])
|
||||
@console_ns.expect(console_ns.models[Parser.__name__], validate=True)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
@@ -22,6 +20,18 @@ from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _validate_name(name: str) -> str:
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description: str) -> str:
|
||||
if len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/templates")
|
||||
class PipelineTemplateListApi(Resource):
|
||||
@setup_required
|
||||
@@ -49,15 +59,6 @@ class PipelineTemplateDetailApi(Resource):
|
||||
return pipeline_template, 200
|
||||
|
||||
|
||||
class Payload(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=40)
|
||||
description: str = Field(default="", max_length=400)
|
||||
icon_info: dict[str, object] | None = None
|
||||
|
||||
|
||||
register_schema_models(console_ns, Payload)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/customized/templates/<string:template_id>")
|
||||
class CustomizedPipelineTemplateApi(Resource):
|
||||
@setup_required
|
||||
@@ -65,8 +66,31 @@ class CustomizedPipelineTemplateApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def patch(self, template_id: str):
|
||||
payload = Payload.model_validate(console_ns.payload or {})
|
||||
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(payload.model_dump())
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument(
|
||||
"description",
|
||||
type=_validate_description_length,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=True,
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args)
|
||||
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
|
||||
return 200
|
||||
|
||||
@@ -95,14 +119,36 @@ class CustomizedPipelineTemplateApi(Resource):
|
||||
|
||||
@console_ns.route("/rag/pipelines/<string:pipeline_id>/customized/publish")
|
||||
class PublishCustomizedPipelineTemplateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[Payload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@knowledge_pipeline_publish_enabled
|
||||
def post(self, pipeline_id: str):
|
||||
payload = Payload.model_validate(console_ns.payload or {})
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument(
|
||||
"description",
|
||||
type=_validate_description_length,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=True,
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, payload.model_dump())
|
||||
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args)
|
||||
return {"result": "success"}
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
from flask_restx import Resource, marshal
|
||||
from pydantic import BaseModel
|
||||
from flask_restx import Resource, marshal, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.wraps import (
|
||||
@@ -21,22 +19,22 @@ from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo,
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
|
||||
|
||||
class RagPipelineDatasetImportPayload(BaseModel):
|
||||
yaml_content: str
|
||||
|
||||
|
||||
register_schema_model(console_ns, RagPipelineDatasetImportPayload)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/dataset")
|
||||
class CreateRagPipelineDatasetApi(Resource):
|
||||
@console_ns.expect(console_ns.models[RagPipelineDatasetImportPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
payload = RagPipelineDatasetImportPayload.model_validate(console_ns.payload or {})
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"yaml_content",
|
||||
type=str,
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="yaml_content is required.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
@@ -51,7 +49,7 @@ class CreateRagPipelineDatasetApi(Resource):
|
||||
),
|
||||
permission=DatasetPermissionEnum.ONLY_ME,
|
||||
partial_member_list=None,
|
||||
yaml_content=payload.yaml_content,
|
||||
yaml_content=args["yaml_content"],
|
||||
)
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
import logging
|
||||
from typing import Any, NoReturn
|
||||
from typing import NoReturn
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from flask import Response
|
||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
DraftWorkflowNotExist,
|
||||
@@ -35,21 +33,19 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _create_pagination_parser():
|
||||
class PaginationQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=100_000)
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
register_schema_models(console_ns, PaginationQuery)
|
||||
|
||||
return PaginationQuery
|
||||
|
||||
|
||||
class WorkflowDraftVariablePatchPayload(BaseModel):
|
||||
name: str | None = None
|
||||
value: Any | None = None
|
||||
|
||||
|
||||
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"page",
|
||||
type=inputs.int_range(1, 100_000),
|
||||
required=False,
|
||||
default=1,
|
||||
location="args",
|
||||
help="the page of data requested",
|
||||
)
|
||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
|
||||
@@ -97,8 +93,8 @@ class RagPipelineVariableCollectionApi(Resource):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
pagination = _create_pagination_parser()
|
||||
query = pagination.model_validate(request.args.to_dict())
|
||||
parser = _create_pagination_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
@@ -113,8 +109,8 @@ class RagPipelineVariableCollectionApi(Resource):
|
||||
)
|
||||
workflow_vars = draft_var_srv.list_variables_without_values(
|
||||
app_id=pipeline.id,
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
)
|
||||
|
||||
return workflow_vars
|
||||
@@ -190,7 +186,6 @@ class RagPipelineVariableApi(Resource):
|
||||
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
@console_ns.expect(console_ns.models[WorkflowDraftVariablePatchPayload.__name__])
|
||||
def patch(self, pipeline: Pipeline, variable_id: str):
|
||||
# Request payload for file types:
|
||||
#
|
||||
@@ -213,11 +208,16 @@ class RagPipelineVariableApi(Resource):
|
||||
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
||||
# }
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
||||
.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
payload = WorkflowDraftVariablePatchPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
args = parser.parse_args(strict=True)
|
||||
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user