mirror of
https://github.com/langgenius/dify.git
synced 2026-04-07 07:09:19 +08:00
Compare commits
6 Commits
deploy/age
...
move-trigg
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a30a64d51b | ||
|
|
34ef10c818 | ||
|
|
26fedca865 | ||
|
|
2fd4e9e259 | ||
|
|
9e8a4c8a71 | ||
|
|
238497b7ab |
@@ -187,12 +187,53 @@ const Template = useMemo(() => {
|
||||
|
||||
**When**: Component directly handles API calls, data transformation, or complex async operations.
|
||||
|
||||
**Dify Convention**:
|
||||
- This skill is for component decomposition, not query/mutation design.
|
||||
- When refactoring data fetching, follow `web/AGENTS.md`.
|
||||
- Use `frontend-query-mutation` for contracts, query shape, data-fetching wrappers, query/mutation call-site patterns, conditional queries, invalidation, and mutation error handling.
|
||||
- Do not introduce deprecated `useInvalid` / `useReset`.
|
||||
- Do not add thin passthrough `useQuery` wrappers during refactoring; only extract a custom hook when it truly orchestrates multiple queries/mutations or shared derived state.
|
||||
**Dify Convention**: Use `@tanstack/react-query` hooks from `web/service/use-*.ts` or create custom data hooks.
|
||||
|
||||
```typescript
|
||||
// ❌ Before: API logic in component
|
||||
const MCPServiceCard = () => {
|
||||
const [basicAppConfig, setBasicAppConfig] = useState({})
|
||||
|
||||
useEffect(() => {
|
||||
if (isBasicApp && appId) {
|
||||
(async () => {
|
||||
const res = await fetchAppDetail({ url: '/apps', id: appId })
|
||||
setBasicAppConfig(res?.model_config || {})
|
||||
})()
|
||||
}
|
||||
}, [appId, isBasicApp])
|
||||
|
||||
// More API-related logic...
|
||||
}
|
||||
|
||||
// ✅ After: Extract to data hook using React Query
|
||||
// use-app-config.ts
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import { get } from '@/service/base'
|
||||
|
||||
const NAME_SPACE = 'appConfig'
|
||||
|
||||
export const useAppConfig = (appId: string, isBasicApp: boolean) => {
|
||||
return useQuery({
|
||||
enabled: isBasicApp && !!appId,
|
||||
queryKey: [NAME_SPACE, 'detail', appId],
|
||||
queryFn: () => get<AppDetailResponse>(`/apps/${appId}`),
|
||||
select: data => data?.model_config || {},
|
||||
})
|
||||
}
|
||||
|
||||
// Component becomes cleaner
|
||||
const MCPServiceCard = () => {
|
||||
const { data: config, isLoading } = useAppConfig(appId, isBasicApp)
|
||||
// UI only
|
||||
}
|
||||
```
|
||||
|
||||
**React Query Best Practices in Dify**:
|
||||
- Define `NAME_SPACE` for query key organization
|
||||
- Use `enabled` option for conditional fetching
|
||||
- Use `select` for data transformation
|
||||
- Export invalidation hooks: `useInvalidXxx`
|
||||
|
||||
**Dify Examples**:
|
||||
- `web/service/use-workflow.ts`
|
||||
|
||||
@@ -155,14 +155,48 @@ const Configuration: FC = () => {
|
||||
|
||||
## Common Hook Patterns in Dify
|
||||
|
||||
### 1. Data Fetching / Mutation Hooks
|
||||
### 1. Data Fetching Hook (React Query)
|
||||
|
||||
When hook extraction touches query or mutation code, do not use this reference as the source of truth for data-layer patterns.
|
||||
```typescript
|
||||
// Pattern: Use @tanstack/react-query for data fetching
|
||||
import { useQuery, useQueryClient } from '@tanstack/react-query'
|
||||
import { get } from '@/service/base'
|
||||
import { useInvalid } from '@/service/use-base'
|
||||
|
||||
- Follow `web/AGENTS.md` first.
|
||||
- Use `frontend-query-mutation` for contracts, query shape, data-fetching wrappers, query/mutation call-site patterns, conditional queries, invalidation, and mutation error handling.
|
||||
- Do not introduce deprecated `useInvalid` / `useReset`.
|
||||
- Do not extract thin passthrough `useQuery` hooks; only extract orchestration hooks.
|
||||
const NAME_SPACE = 'appConfig'
|
||||
|
||||
// Query keys for cache management
|
||||
export const appConfigQueryKeys = {
|
||||
detail: (appId: string) => [NAME_SPACE, 'detail', appId] as const,
|
||||
}
|
||||
|
||||
// Main data hook
|
||||
export const useAppConfig = (appId: string) => {
|
||||
return useQuery({
|
||||
enabled: !!appId,
|
||||
queryKey: appConfigQueryKeys.detail(appId),
|
||||
queryFn: () => get<AppDetailResponse>(`/apps/${appId}`),
|
||||
select: data => data?.model_config || null,
|
||||
})
|
||||
}
|
||||
|
||||
// Invalidation hook for refreshing data
|
||||
export const useInvalidAppConfig = () => {
|
||||
return useInvalid([NAME_SPACE])
|
||||
}
|
||||
|
||||
// Usage in component
|
||||
const Component = () => {
|
||||
const { data: config, isLoading, error, refetch } = useAppConfig(appId)
|
||||
const invalidAppConfig = useInvalidAppConfig()
|
||||
|
||||
const handleRefresh = () => {
|
||||
invalidAppConfig() // Invalidates cache and triggers refetch
|
||||
}
|
||||
|
||||
return <div>...</div>
|
||||
}
|
||||
```
|
||||
|
||||
### 2. Form State Hook
|
||||
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
---
|
||||
name: frontend-query-mutation
|
||||
description: Guide for implementing Dify frontend query and mutation patterns with TanStack Query and oRPC. Trigger when creating or updating contracts in web/contract, wiring router composition, consuming consoleQuery or marketplaceQuery in components or services, deciding whether to call queryOptions() directly or extract a helper or use-* hook, handling conditional queries, cache invalidation, mutation error handling, or migrating legacy service calls to contract-first query and mutation helpers.
|
||||
---
|
||||
|
||||
# Frontend Query & Mutation
|
||||
|
||||
## Intent
|
||||
|
||||
- Keep contract as the single source of truth in `web/contract/*`.
|
||||
- Prefer contract-shaped `queryOptions()` and `mutationOptions()`.
|
||||
- Keep invalidation and mutation flow knowledge in the service layer.
|
||||
- Keep abstractions minimal to preserve TypeScript inference.
|
||||
|
||||
## Workflow
|
||||
|
||||
1. Identify the change surface.
|
||||
- Read `references/contract-patterns.md` for contract files, router composition, client helpers, and query or mutation call-site shape.
|
||||
- Read `references/runtime-rules.md` for conditional queries, invalidation, error handling, and legacy migrations.
|
||||
- Read both references when a task spans contract shape and runtime behavior.
|
||||
2. Implement the smallest abstraction that fits the task.
|
||||
- Default to direct `useQuery(...)` or `useMutation(...)` calls with oRPC helpers at the call site.
|
||||
- Extract a small shared query helper only when multiple call sites share the same extra options.
|
||||
- Create `web/service/use-{domain}.ts` only for orchestration or shared domain behavior.
|
||||
3. Preserve Dify conventions.
|
||||
- Keep contract inputs in `{ params, query?, body? }` shape.
|
||||
- Bind invalidation in the service-layer mutation definition.
|
||||
- Prefer `mutate(...)`; use `mutateAsync(...)` only when Promise semantics are required.
|
||||
|
||||
## Files Commonly Touched
|
||||
|
||||
- `web/contract/console/*.ts`
|
||||
- `web/contract/marketplace.ts`
|
||||
- `web/contract/router.ts`
|
||||
- `web/service/client.ts`
|
||||
- `web/service/use-*.ts`
|
||||
- component and hook call sites using `consoleQuery` or `marketplaceQuery`
|
||||
|
||||
## References
|
||||
|
||||
- Use `references/contract-patterns.md` for contract shape, router registration, query and mutation helpers, and anti-patterns that degrade inference.
|
||||
- Use `references/runtime-rules.md` for conditional queries, invalidation, `mutate` versus `mutateAsync`, and legacy migration rules.
|
||||
|
||||
Treat this skill as the single query and mutation entry point for Dify frontend work. Keep detailed rules in the reference files instead of duplicating them in project docs.
|
||||
@@ -1,4 +0,0 @@
|
||||
interface:
|
||||
display_name: "Frontend Query & Mutation"
|
||||
short_description: "Dify TanStack Query and oRPC patterns"
|
||||
default_prompt: "Use this skill when implementing or reviewing Dify frontend contracts, query and mutation call sites, conditional queries, invalidation, or legacy query/mutation migrations."
|
||||
@@ -1,98 +0,0 @@
|
||||
# Contract Patterns
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- Intent
|
||||
- Minimal structure
|
||||
- Core workflow
|
||||
- Query usage decision rule
|
||||
- Mutation usage decision rule
|
||||
- Anti-patterns
|
||||
- Contract rules
|
||||
- Type export
|
||||
|
||||
## Intent
|
||||
|
||||
- Keep contract as the single source of truth in `web/contract/*`.
|
||||
- Default query usage to call-site `useQuery(consoleQuery|marketplaceQuery.xxx.queryOptions(...))` when endpoint behavior maps 1:1 to the contract.
|
||||
- Keep abstractions minimal and preserve TypeScript inference.
|
||||
|
||||
## Minimal Structure
|
||||
|
||||
```text
|
||||
web/contract/
|
||||
├── base.ts
|
||||
├── router.ts
|
||||
├── marketplace.ts
|
||||
└── console/
|
||||
├── billing.ts
|
||||
└── ...other domains
|
||||
web/service/client.ts
|
||||
```
|
||||
|
||||
## Core Workflow
|
||||
|
||||
1. Define contract in `web/contract/console/{domain}.ts` or `web/contract/marketplace.ts`.
|
||||
- Use `base.route({...}).output(type<...>())` as the baseline.
|
||||
- Add `.input(type<...>())` only when the request has `params`, `query`, or `body`.
|
||||
- For `GET` without input, omit `.input(...)`; do not use `.input(type<unknown>())`.
|
||||
2. Register contract in `web/contract/router.ts`.
|
||||
- Import directly from domain files and nest by API prefix.
|
||||
3. Consume from UI call sites via oRPC query utilities.
|
||||
|
||||
```typescript
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import { consoleQuery } from '@/service/client'
|
||||
|
||||
const invoiceQuery = useQuery(consoleQuery.billing.invoices.queryOptions({
|
||||
staleTime: 5 * 60 * 1000,
|
||||
throwOnError: true,
|
||||
select: invoice => invoice.url,
|
||||
}))
|
||||
```
|
||||
|
||||
## Query Usage Decision Rule
|
||||
|
||||
1. Default to direct `*.queryOptions(...)` usage at the call site.
|
||||
2. If 3 or more call sites share the same extra options, extract a small query helper, not a `use-*` passthrough hook.
|
||||
3. Create `web/service/use-{domain}.ts` only for orchestration.
|
||||
- Combine multiple queries or mutations.
|
||||
- Share domain-level derived state or invalidation helpers.
|
||||
|
||||
```typescript
|
||||
const invoicesBaseQueryOptions = () =>
|
||||
consoleQuery.billing.invoices.queryOptions({ retry: false })
|
||||
|
||||
const invoiceQuery = useQuery({
|
||||
...invoicesBaseQueryOptions(),
|
||||
throwOnError: true,
|
||||
})
|
||||
```
|
||||
|
||||
## Mutation Usage Decision Rule
|
||||
|
||||
1. Default to mutation helpers from `consoleQuery` or `marketplaceQuery`, for example `useMutation(consoleQuery.billing.bindPartnerStack.mutationOptions(...))`.
|
||||
2. If the mutation flow is heavily custom, use oRPC clients as `mutationFn`, for example `consoleClient.xxx` or `marketplaceClient.xxx`, instead of handwritten non-oRPC mutation logic.
|
||||
|
||||
## Anti-Patterns
|
||||
|
||||
- Do not wrap `useQuery` with `options?: Partial<UseQueryOptions>`.
|
||||
- Do not split local `queryKey` and `queryFn` when oRPC `queryOptions` already exists and fits the use case.
|
||||
- Do not create thin `use-*` passthrough hooks for a single endpoint.
|
||||
- These patterns can degrade inference, especially around `throwOnError` and `select`, and add unnecessary indirection.
|
||||
|
||||
## Contract Rules
|
||||
|
||||
- Input structure: always use `{ params, query?, body? }`.
|
||||
- No-input `GET`: omit `.input(...)`; do not use `.input(type<unknown>())`.
|
||||
- Path params: use `{paramName}` in the path and match it in the `params` object.
|
||||
- Router nesting: group by API prefix, for example `/billing/*` becomes `billing: {}`.
|
||||
- No barrel files: import directly from specific files.
|
||||
- Types: import from `@/types/` and use the `type<T>()` helper.
|
||||
- Mutations: prefer `mutationOptions`; use explicit `mutationKey` mainly for defaults, filtering, and devtools.
|
||||
|
||||
## Type Export
|
||||
|
||||
```typescript
|
||||
export type ConsoleInputs = InferContractRouterInputs<typeof consoleRouterContract>
|
||||
```
|
||||
@@ -1,130 +0,0 @@
|
||||
# Runtime Rules
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- Conditional queries
|
||||
- Cache invalidation
|
||||
- Key API guide
|
||||
- `mutate` vs `mutateAsync`
|
||||
- Legacy migration
|
||||
|
||||
## Conditional Queries
|
||||
|
||||
Prefer contract-shaped `queryOptions(...)`.
|
||||
When required input is missing, prefer `input: skipToken` instead of placeholder params or non-null assertions.
|
||||
Use `enabled` only for extra business gating after the input itself is already valid.
|
||||
|
||||
```typescript
|
||||
import { skipToken, useQuery } from '@tanstack/react-query'
|
||||
|
||||
// Disable the query by skipping input construction.
|
||||
function useAccessMode(appId: string | undefined) {
|
||||
return useQuery(consoleQuery.accessControl.appAccessMode.queryOptions({
|
||||
input: appId
|
||||
? { params: { appId } }
|
||||
: skipToken,
|
||||
}))
|
||||
}
|
||||
|
||||
// Avoid runtime-only guards that bypass type checking.
|
||||
function useBadAccessMode(appId: string | undefined) {
|
||||
return useQuery(consoleQuery.accessControl.appAccessMode.queryOptions({
|
||||
input: { params: { appId: appId! } },
|
||||
enabled: !!appId,
|
||||
}))
|
||||
}
|
||||
```
|
||||
|
||||
## Cache Invalidation
|
||||
|
||||
Bind invalidation in the service-layer mutation definition.
|
||||
Components may add UI feedback in call-site callbacks, but they should not decide which queries to invalidate.
|
||||
|
||||
Use:
|
||||
|
||||
- `.key()` for namespace or prefix invalidation
|
||||
- `.queryKey(...)` only for exact cache reads or writes such as `getQueryData` and `setQueryData`
|
||||
- `queryClient.invalidateQueries(...)` in mutation `onSuccess`
|
||||
|
||||
Do not use deprecated `useInvalid` from `use-base.ts`.
|
||||
|
||||
```typescript
|
||||
// Service layer owns cache invalidation.
|
||||
export const useUpdateAccessMode = () => {
|
||||
const queryClient = useQueryClient()
|
||||
|
||||
return useMutation(consoleQuery.accessControl.updateAccessMode.mutationOptions({
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: consoleQuery.accessControl.appWhitelistSubjects.key(),
|
||||
})
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
// Component only adds UI behavior.
|
||||
updateAccessMode({ appId, mode }, {
|
||||
onSuccess: () => toast.success('...'),
|
||||
})
|
||||
|
||||
// Avoid putting invalidation knowledge in the component.
|
||||
mutate({ appId, mode }, {
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: consoleQuery.accessControl.appWhitelistSubjects.key(),
|
||||
})
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
## Key API Guide
|
||||
|
||||
- `.key(...)`
|
||||
- Use for partial matching operations.
|
||||
- Prefer it for invalidation, refetch, and cancel patterns.
|
||||
- Example: `queryClient.invalidateQueries({ queryKey: consoleQuery.billing.key() })`
|
||||
- `.queryKey(...)`
|
||||
- Use for a specific query's full key.
|
||||
- Prefer it for exact cache addressing and direct reads or writes.
|
||||
- `.mutationKey(...)`
|
||||
- Use for a specific mutation's full key.
|
||||
- Prefer it for mutation defaults registration, mutation-status filtering, and devtools grouping.
|
||||
|
||||
## `mutate` vs `mutateAsync`
|
||||
|
||||
Prefer `mutate` by default.
|
||||
Use `mutateAsync` only when Promise semantics are truly required, such as parallel mutations or sequential steps with result dependencies.
|
||||
|
||||
Rules:
|
||||
|
||||
- Event handlers should usually call `mutate(...)` with `onSuccess` or `onError`.
|
||||
- Every `await mutateAsync(...)` must be wrapped in `try/catch`.
|
||||
- Do not use `mutateAsync` when callbacks already express the flow clearly.
|
||||
|
||||
```typescript
|
||||
// Default case.
|
||||
mutation.mutate(data, {
|
||||
onSuccess: result => router.push(result.url),
|
||||
})
|
||||
|
||||
// Promise semantics are required.
|
||||
try {
|
||||
const order = await createOrder.mutateAsync(orderData)
|
||||
await confirmPayment.mutateAsync({ orderId: order.id, token })
|
||||
router.push(`/orders/${order.id}`)
|
||||
}
|
||||
catch (error) {
|
||||
toast.error(error instanceof Error ? error.message : 'Unknown error')
|
||||
}
|
||||
```
|
||||
|
||||
## Legacy Migration
|
||||
|
||||
When touching old code, migrate it toward these rules:
|
||||
|
||||
| Old pattern | New pattern |
|
||||
|---|---|
|
||||
| `useInvalid(key)` in service layer | `queryClient.invalidateQueries(...)` inside mutation `onSuccess` |
|
||||
| component-triggered invalidation after mutation | move invalidation into the service-layer mutation definition |
|
||||
| imperative fetch plus manual invalidation | wrap it in `useMutation(...mutationOptions(...))` |
|
||||
| `await mutateAsync()` without `try/catch` | switch to `mutate(...)` or add `try/catch` |
|
||||
@@ -63,8 +63,7 @@ pnpm analyze-component <path> --review
|
||||
|
||||
### File Naming
|
||||
|
||||
- Test files: `ComponentName.spec.tsx` inside a same-level `__tests__/` directory
|
||||
- Placement rule: Component, hook, and utility tests must live in a sibling `__tests__/` folder at the same level as the source under test. For example, `foo/index.tsx` maps to `foo/__tests__/index.spec.tsx`, and `foo/bar.ts` maps to `foo/__tests__/bar.spec.ts`.
|
||||
- Test files: `ComponentName.spec.tsx` (same directory as component)
|
||||
- Integration tests: `web/__tests__/` directory
|
||||
|
||||
## Test Structure Template
|
||||
|
||||
@@ -41,7 +41,7 @@ import userEvent from '@testing-library/user-event'
|
||||
// Router (if component uses useRouter, usePathname, useSearchParams)
|
||||
// WHY: Isolates tests from Next.js routing, enables testing navigation behavior
|
||||
// const mockPush = vi.fn()
|
||||
// vi.mock('@/next/navigation', () => ({
|
||||
// vi.mock('next/navigation', () => ({
|
||||
// useRouter: () => ({ push: mockPush }),
|
||||
// usePathname: () => '/test-path',
|
||||
// }))
|
||||
|
||||
103
.agents/skills/orpc-contract-first/SKILL.md
Normal file
103
.agents/skills/orpc-contract-first/SKILL.md
Normal file
@@ -0,0 +1,103 @@
|
||||
---
|
||||
name: orpc-contract-first
|
||||
description: Guide for implementing oRPC contract-first API patterns in Dify frontend. Trigger when creating or updating contracts in web/contract, wiring router composition, integrating TanStack Query with typed contracts, migrating legacy service calls to oRPC, or deciding whether to call queryOptions directly vs extracting a helper or use-* hook in web/service.
|
||||
---
|
||||
|
||||
# oRPC Contract-First Development
|
||||
|
||||
## Intent
|
||||
|
||||
- Keep contract as single source of truth in `web/contract/*`.
|
||||
- Default query usage: call-site `useQuery(consoleQuery|marketplaceQuery.xxx.queryOptions(...))` when endpoint behavior maps 1:1 to the contract.
|
||||
- Keep abstractions minimal and preserve TypeScript inference.
|
||||
|
||||
## Minimal Structure
|
||||
|
||||
```text
|
||||
web/contract/
|
||||
├── base.ts
|
||||
├── router.ts
|
||||
├── marketplace.ts
|
||||
└── console/
|
||||
├── billing.ts
|
||||
└── ...other domains
|
||||
web/service/client.ts
|
||||
```
|
||||
|
||||
## Core Workflow
|
||||
|
||||
1. Define contract in `web/contract/console/{domain}.ts` or `web/contract/marketplace.ts`
|
||||
- Use `base.route({...}).output(type<...>())` as baseline.
|
||||
- Add `.input(type<...>())` only when request has `params/query/body`.
|
||||
- For `GET` without input, omit `.input(...)` (do not use `.input(type<unknown>())`).
|
||||
2. Register contract in `web/contract/router.ts`
|
||||
- Import directly from domain files and nest by API prefix.
|
||||
3. Consume from UI call sites via oRPC query utils.
|
||||
|
||||
```typescript
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import { consoleQuery } from '@/service/client'
|
||||
|
||||
const invoiceQuery = useQuery(consoleQuery.billing.invoices.queryOptions({
|
||||
staleTime: 5 * 60 * 1000,
|
||||
throwOnError: true,
|
||||
select: invoice => invoice.url,
|
||||
}))
|
||||
```
|
||||
|
||||
## Query Usage Decision Rule
|
||||
|
||||
1. Default: call site directly uses `*.queryOptions(...)`.
|
||||
2. If 3+ call sites share the same extra options (for example `retry: false`), extract a small queryOptions helper, not a `use-*` passthrough hook.
|
||||
3. Create `web/service/use-{domain}.ts` only for orchestration:
|
||||
- Combine multiple queries/mutations.
|
||||
- Share domain-level derived state or invalidation helpers.
|
||||
|
||||
```typescript
|
||||
const invoicesBaseQueryOptions = () =>
|
||||
consoleQuery.billing.invoices.queryOptions({ retry: false })
|
||||
|
||||
const invoiceQuery = useQuery({
|
||||
...invoicesBaseQueryOptions(),
|
||||
throwOnError: true,
|
||||
})
|
||||
```
|
||||
|
||||
## Mutation Usage Decision Rule
|
||||
|
||||
1. Default: call mutation helpers from `consoleQuery` / `marketplaceQuery`, for example `useMutation(consoleQuery.billing.bindPartnerStack.mutationOptions(...))`.
|
||||
2. If mutation flow is heavily custom, use oRPC clients as `mutationFn` (for example `consoleClient.xxx` / `marketplaceClient.xxx`), instead of generic handwritten non-oRPC mutation logic.
|
||||
|
||||
## Key API Guide (`.key` vs `.queryKey` vs `.mutationKey`)
|
||||
|
||||
- `.key(...)`:
|
||||
- Use for partial matching operations (recommended for invalidation/refetch/cancel patterns).
|
||||
- Example: `queryClient.invalidateQueries({ queryKey: consoleQuery.billing.key() })`
|
||||
- `.queryKey(...)`:
|
||||
- Use for a specific query's full key (exact query identity / direct cache addressing).
|
||||
- `.mutationKey(...)`:
|
||||
- Use for a specific mutation's full key.
|
||||
- Typical use cases: mutation defaults registration, mutation-status filtering (`useIsMutating`, `queryClient.isMutating`), or explicit devtools grouping.
|
||||
|
||||
## Anti-Patterns
|
||||
|
||||
- Do not wrap `useQuery` with `options?: Partial<UseQueryOptions>`.
|
||||
- Do not split local `queryKey/queryFn` when oRPC `queryOptions` already exists and fits the use case.
|
||||
- Do not create thin `use-*` passthrough hooks for a single endpoint.
|
||||
- Reason: these patterns can degrade inference (`data` may become `unknown`, especially around `throwOnError`/`select`) and add unnecessary indirection.
|
||||
|
||||
## Contract Rules
|
||||
|
||||
- **Input structure**: Always use `{ params, query?, body? }` format
|
||||
- **No-input GET**: Omit `.input(...)`; do not use `.input(type<unknown>())`
|
||||
- **Path params**: Use `{paramName}` in path, match in `params` object
|
||||
- **Router nesting**: Group by API prefix (e.g., `/billing/*` -> `billing: {}`)
|
||||
- **No barrel files**: Import directly from specific files
|
||||
- **Types**: Import from `@/types/`, use `type<T>()` helper
|
||||
- **Mutations**: Prefer `mutationOptions`; use explicit `mutationKey` mainly for defaults/filtering/devtools
|
||||
|
||||
## Type Export
|
||||
|
||||
```typescript
|
||||
export type ConsoleInputs = InferContractRouterInputs<typeof consoleRouterContract>
|
||||
```
|
||||
@@ -1 +0,0 @@
|
||||
../../.agents/skills/frontend-query-mutation
|
||||
1
.claude/skills/orpc-contract-first
Symbolic link
1
.claude/skills/orpc-contract-first
Symbolic link
@@ -0,0 +1 @@
|
||||
../../.agents/skills/orpc-contract-first
|
||||
@@ -7,7 +7,7 @@ cd web && pnpm install
|
||||
pipx install uv
|
||||
|
||||
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
|
||||
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution\"" >> ~/.bashrc
|
||||
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc
|
||||
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev:inspect\"" >> ~/.bashrc
|
||||
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
|
||||
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
have_fun: false
|
||||
memory_config:
|
||||
disabled: false
|
||||
code_review:
|
||||
disable: true
|
||||
comment_severity_threshold: MEDIUM
|
||||
max_review_comments: -1
|
||||
pull_request_opened:
|
||||
help: false
|
||||
summary: false
|
||||
code_review: false
|
||||
include_drafts: false
|
||||
ignore_patterns: []
|
||||
9
.github/actions/setup-web/action.yml
vendored
9
.github/actions/setup-web/action.yml
vendored
@@ -4,9 +4,10 @@ runs:
|
||||
using: composite
|
||||
steps:
|
||||
- name: Setup Vite+
|
||||
uses: voidzero-dev/setup-vp@20553a7a7429c429a74894104a2835d7fed28a72 # v1.3.0
|
||||
uses: voidzero-dev/setup-vp@b5d848f5a62488f3d3d920f8aa6ac318a60c5f07 # v1
|
||||
with:
|
||||
working-directory: web
|
||||
node-version-file: .nvmrc
|
||||
node-version-file: "./web/.nvmrc"
|
||||
cache: true
|
||||
run-install: true
|
||||
run-install: |
|
||||
- cwd: ./web
|
||||
args: ['--frozen-lockfile']
|
||||
|
||||
2
.github/workflows/anti-slop.yml
vendored
2
.github/workflows/anti-slop.yml
vendored
@@ -12,7 +12,7 @@ jobs:
|
||||
anti-slop:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: peakoss/anti-slop@85daca1880e9e1af197fc06ea03349daf08f4202 # v0.2.1
|
||||
- uses: peakoss/anti-slop@v0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
close-pr: false
|
||||
|
||||
38
.github/workflows/api-tests.yml
vendored
38
.github/workflows/api-tests.yml
vendored
@@ -2,12 +2,6 @@ name: Run Pytest
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
secrets:
|
||||
CODECOV_TOKEN:
|
||||
required: false
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: api-tests-${{ github.head_ref || github.run_id }}
|
||||
@@ -17,8 +11,6 @@ jobs:
|
||||
test:
|
||||
name: API Tests
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
@@ -32,11 +24,10 @@ jobs:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@@ -88,12 +79,21 @@ jobs:
|
||||
api/tests/test_containers_integration_tests \
|
||||
api/tests/unit_tests
|
||||
|
||||
- name: Report coverage
|
||||
if: ${{ env.CODECOV_TOKEN != '' && matrix.python-version == '3.12' }}
|
||||
uses: codecov/codecov-action@1af58845a975a7985b0beb0cbe6fbbb71a41dbad # v5.5.3
|
||||
with:
|
||||
files: ./coverage.xml
|
||||
disable_search: true
|
||||
flags: api
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
|
||||
- name: Coverage Summary
|
||||
run: |
|
||||
set -x
|
||||
# Extract coverage percentage and create a summary
|
||||
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
|
||||
|
||||
# Create a detailed coverage summary
|
||||
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
|
||||
{
|
||||
echo ""
|
||||
echo "<details><summary>File-level coverage (click to expand)</summary>"
|
||||
echo ""
|
||||
echo '```'
|
||||
uv run --project api coverage report -m
|
||||
echo '```'
|
||||
echo "</details>"
|
||||
} >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
7
.github/workflows/autofix.yml
vendored
7
.github/workflows/autofix.yml
vendored
@@ -39,7 +39,7 @@ jobs:
|
||||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
- uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
|
||||
- name: Generate Docker Compose
|
||||
if: steps.docker-compose-changes.outputs.any_changed == 'true'
|
||||
@@ -94,6 +94,11 @@ jobs:
|
||||
find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \;
|
||||
find . -name "*.py.bak" -type f -delete
|
||||
|
||||
# mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
|
||||
- name: mdformat
|
||||
run: |
|
||||
uvx --python 3.13 mdformat . --exclude ".agents/skills/**"
|
||||
|
||||
- name: Setup web environment
|
||||
if: steps.web-changes.outputs.any_changed == 'true'
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
6
.github/workflows/build-push.yml
vendored
6
.github/workflows/build-push.yml
vendored
@@ -76,9 +76,7 @@ jobs:
|
||||
with:
|
||||
context: "{{defaultContext}}:${{ matrix.context }}"
|
||||
platforms: ${{ matrix.platform }}
|
||||
build-args: |
|
||||
COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }}
|
||||
ENABLE_PROD_SOURCEMAP=${{ matrix.context == 'web' && github.ref_name == 'deploy/dev' }}
|
||||
build-args: COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
outputs: type=image,name=${{ env[matrix.image_name_env] }},push-by-digest=true,name-canonical=true,push=true
|
||||
cache-from: type=gha,scope=${{ matrix.service_name }}
|
||||
@@ -115,7 +113,7 @@ jobs:
|
||||
context: "web"
|
||||
steps:
|
||||
- name: Download digests
|
||||
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.0
|
||||
with:
|
||||
path: /tmp/digests
|
||||
pattern: digests-${{ matrix.context }}-*
|
||||
|
||||
4
.github/workflows/db-migration-test.yml
vendored
4
.github/workflows/db-migration-test.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
@@ -69,7 +69,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
|
||||
7
.github/workflows/main-ci.yml
vendored
7
.github/workflows/main-ci.yml
vendored
@@ -28,7 +28,7 @@ jobs:
|
||||
migration-changed: ${{ steps.changes.outputs.migration }}
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- uses: dorny/paths-filter@fbd0ab8f3e69293af611ebaee6363fc25e6d187d # v4.0.1
|
||||
- uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2
|
||||
id: changes
|
||||
with:
|
||||
filters: |
|
||||
@@ -56,14 +56,15 @@ jobs:
|
||||
needs: check-changes
|
||||
if: needs.check-changes.outputs.api-changed == 'true'
|
||||
uses: ./.github/workflows/api-tests.yml
|
||||
secrets: inherit
|
||||
|
||||
web-tests:
|
||||
name: Web Tests
|
||||
needs: check-changes
|
||||
if: needs.check-changes.outputs.web-changed == 'true'
|
||||
uses: ./.github/workflows/web-tests.yml
|
||||
secrets: inherit
|
||||
with:
|
||||
base_sha: ${{ github.event_name == 'pull_request' && github.event.pull_request.base.sha || github.event.before }}
|
||||
head_sha: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
|
||||
|
||||
style-check:
|
||||
name: Style Check
|
||||
|
||||
2
.github/workflows/pyrefly-diff.yml
vendored
2
.github/workflows/pyrefly-diff.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python & UV
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
|
||||
31
.github/workflows/style.yml
vendored
31
.github/workflows/style.yml
vendored
@@ -33,7 +33,7 @@ jobs:
|
||||
|
||||
- name: Setup UV and Python
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
with:
|
||||
enable-cache: false
|
||||
python-version: "3.12"
|
||||
@@ -84,20 +84,20 @@ jobs:
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Restore ESLint cache
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
id: eslint-cache-restore
|
||||
uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
|
||||
with:
|
||||
path: web/.eslintcache
|
||||
key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
|
||||
|
||||
- name: Web style check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: vp run lint:ci
|
||||
run: |
|
||||
vp run lint:ci
|
||||
# pnpm run lint:report
|
||||
# continue-on-error: true
|
||||
|
||||
# - name: Annotate Code
|
||||
# if: steps.changed-files.outputs.any_changed == 'true' && github.event_name == 'pull_request'
|
||||
# uses: DerLev/eslint-annotations@51347b3a0abfb503fc8734d5ae31c4b151297fae
|
||||
# with:
|
||||
# eslint-report: web/eslint_report.json
|
||||
# github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Web tsslint
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
@@ -114,13 +114,6 @@ jobs:
|
||||
working-directory: ./web
|
||||
run: vp run knip
|
||||
|
||||
- name: Save ESLint cache
|
||||
if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true'
|
||||
uses: actions/cache/save@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
|
||||
with:
|
||||
path: web/.eslintcache
|
||||
key: ${{ steps.eslint-cache-restore.outputs.cache-primary-key }}
|
||||
|
||||
superlinter:
|
||||
name: SuperLinter
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
2
.github/workflows/translate-i18n-claude.yml
vendored
2
.github/workflows/translate-i18n-claude.yml
vendored
@@ -120,7 +120,7 @@ jobs:
|
||||
|
||||
- name: Run Claude Code for Translation Sync
|
||||
if: steps.detect_changes.outputs.CHANGED_FILES != ''
|
||||
uses: anthropics/claude-code-action@ff9acae5886d41a99ed4ec14b7dc147d55834722 # v1.0.77
|
||||
uses: anthropics/claude-code-action@26ec041249acb0a944c0a47b6c0c13f05dbc5b44 # v1.0.70
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
2
.github/workflows/vdb-tests.yml
vendored
2
.github/workflows/vdb-tests.yml
vendored
@@ -31,7 +31,7 @@ jobs:
|
||||
remove_tool_cache: true
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
355
.github/workflows/web-tests.yml
vendored
355
.github/workflows/web-tests.yml
vendored
@@ -2,9 +2,13 @@ name: Web Tests
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
secrets:
|
||||
CODECOV_TOKEN:
|
||||
inputs:
|
||||
base_sha:
|
||||
required: false
|
||||
type: string
|
||||
head_sha:
|
||||
required: false
|
||||
type: string
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -22,8 +26,8 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
shardIndex: [1, 2, 3, 4, 5, 6]
|
||||
shardTotal: [6]
|
||||
shardIndex: [1, 2, 3, 4]
|
||||
shardTotal: [4]
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
@@ -56,7 +60,7 @@ jobs:
|
||||
needs: [test]
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
VITEST_COVERAGE_SCOPE: app-components
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
@@ -73,23 +77,346 @@ jobs:
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Download blob reports
|
||||
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.0
|
||||
with:
|
||||
path: web/.vitest-reports
|
||||
pattern: blob-report-*
|
||||
merge-multiple: true
|
||||
|
||||
- name: Merge reports
|
||||
run: vp test --merge-reports --coverage --silent=passed-only
|
||||
run: vp test --merge-reports --reporter=json --reporter=agent --coverage
|
||||
|
||||
- name: Report coverage
|
||||
if: ${{ env.CODECOV_TOKEN != '' }}
|
||||
uses: codecov/codecov-action@1af58845a975a7985b0beb0cbe6fbbb71a41dbad # v5.5.3
|
||||
with:
|
||||
directory: web/coverage
|
||||
flags: web
|
||||
- name: Check app/components diff coverage
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
|
||||
BASE_SHA: ${{ inputs.base_sha }}
|
||||
HEAD_SHA: ${{ inputs.head_sha }}
|
||||
run: node ./scripts/check-components-diff-coverage.mjs
|
||||
|
||||
- name: Coverage Summary
|
||||
if: always()
|
||||
id: coverage-summary
|
||||
run: |
|
||||
set -eo pipefail
|
||||
|
||||
COVERAGE_FILE="coverage/coverage-final.json"
|
||||
COVERAGE_SUMMARY_FILE="coverage/coverage-summary.json"
|
||||
|
||||
if [ ! -f "$COVERAGE_FILE" ] && [ ! -f "$COVERAGE_SUMMARY_FILE" ]; then
|
||||
echo "has_coverage=false" >> "$GITHUB_OUTPUT"
|
||||
echo "### 🚨 Test Coverage Report :test_tube:" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo "Coverage data not found. Ensure Vitest runs with coverage enabled." >> "$GITHUB_STEP_SUMMARY"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "has_coverage=true" >> "$GITHUB_OUTPUT"
|
||||
|
||||
node <<'NODE' >> "$GITHUB_STEP_SUMMARY"
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
let libCoverage = null;
|
||||
|
||||
try {
|
||||
libCoverage = require('istanbul-lib-coverage');
|
||||
} catch (error) {
|
||||
libCoverage = null;
|
||||
}
|
||||
|
||||
const summaryPath = path.join('coverage', 'coverage-summary.json');
|
||||
const finalPath = path.join('coverage', 'coverage-final.json');
|
||||
|
||||
const hasSummary = fs.existsSync(summaryPath);
|
||||
const hasFinal = fs.existsSync(finalPath);
|
||||
|
||||
if (!hasSummary && !hasFinal) {
|
||||
console.log('### Test Coverage Summary :test_tube:');
|
||||
console.log('');
|
||||
console.log('No coverage data found.');
|
||||
process.exit(0);
|
||||
}
|
||||
|
||||
const summary = hasSummary
|
||||
? JSON.parse(fs.readFileSync(summaryPath, 'utf8'))
|
||||
: null;
|
||||
const coverage = hasFinal
|
||||
? JSON.parse(fs.readFileSync(finalPath, 'utf8'))
|
||||
: null;
|
||||
|
||||
const getLineCoverageFromStatements = (statementMap, statementHits) => {
|
||||
const lineHits = {};
|
||||
|
||||
if (!statementMap || !statementHits) {
|
||||
return lineHits;
|
||||
}
|
||||
|
||||
Object.entries(statementMap).forEach(([key, statement]) => {
|
||||
const line = statement?.start?.line;
|
||||
if (!line) {
|
||||
return;
|
||||
}
|
||||
const hits = statementHits[key] ?? 0;
|
||||
const previous = lineHits[line];
|
||||
lineHits[line] = previous === undefined ? hits : Math.max(previous, hits);
|
||||
});
|
||||
|
||||
return lineHits;
|
||||
};
|
||||
|
||||
const getFileCoverage = (entry) => (
|
||||
libCoverage ? libCoverage.createFileCoverage(entry) : null
|
||||
);
|
||||
|
||||
const getLineHits = (entry, fileCoverage) => {
|
||||
const lineHits = entry.l ?? {};
|
||||
if (Object.keys(lineHits).length > 0) {
|
||||
return lineHits;
|
||||
}
|
||||
if (fileCoverage) {
|
||||
return fileCoverage.getLineCoverage();
|
||||
}
|
||||
return getLineCoverageFromStatements(entry.statementMap ?? {}, entry.s ?? {});
|
||||
};
|
||||
|
||||
const getUncoveredLines = (entry, fileCoverage, lineHits) => {
|
||||
if (lineHits && Object.keys(lineHits).length > 0) {
|
||||
return Object.entries(lineHits)
|
||||
.filter(([, count]) => count === 0)
|
||||
.map(([line]) => Number(line))
|
||||
.sort((a, b) => a - b);
|
||||
}
|
||||
if (fileCoverage) {
|
||||
return fileCoverage.getUncoveredLines();
|
||||
}
|
||||
return [];
|
||||
};
|
||||
|
||||
const totals = {
|
||||
lines: { covered: 0, total: 0 },
|
||||
statements: { covered: 0, total: 0 },
|
||||
branches: { covered: 0, total: 0 },
|
||||
functions: { covered: 0, total: 0 },
|
||||
};
|
||||
const fileSummaries = [];
|
||||
|
||||
if (summary) {
|
||||
const totalEntry = summary.total ?? {};
|
||||
['lines', 'statements', 'branches', 'functions'].forEach((key) => {
|
||||
if (totalEntry[key]) {
|
||||
totals[key].covered = totalEntry[key].covered ?? 0;
|
||||
totals[key].total = totalEntry[key].total ?? 0;
|
||||
}
|
||||
});
|
||||
|
||||
Object.entries(summary)
|
||||
.filter(([file]) => file !== 'total')
|
||||
.forEach(([file, data]) => {
|
||||
fileSummaries.push({
|
||||
file,
|
||||
pct: data.lines?.pct ?? data.statements?.pct ?? 0,
|
||||
lines: {
|
||||
covered: data.lines?.covered ?? 0,
|
||||
total: data.lines?.total ?? 0,
|
||||
},
|
||||
});
|
||||
});
|
||||
} else if (coverage) {
|
||||
Object.entries(coverage).forEach(([file, entry]) => {
|
||||
const fileCoverage = getFileCoverage(entry);
|
||||
const lineHits = getLineHits(entry, fileCoverage);
|
||||
const statementHits = entry.s ?? {};
|
||||
const branchHits = entry.b ?? {};
|
||||
const functionHits = entry.f ?? {};
|
||||
|
||||
const lineTotal = Object.keys(lineHits).length;
|
||||
const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
|
||||
|
||||
const statementTotal = Object.keys(statementHits).length;
|
||||
const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
|
||||
|
||||
const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
|
||||
const branchCovered = Object.values(branchHits).reduce(
|
||||
(acc, branches) => acc + branches.filter((n) => n > 0).length,
|
||||
0,
|
||||
);
|
||||
|
||||
const functionTotal = Object.keys(functionHits).length;
|
||||
const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
|
||||
|
||||
totals.lines.total += lineTotal;
|
||||
totals.lines.covered += lineCovered;
|
||||
totals.statements.total += statementTotal;
|
||||
totals.statements.covered += statementCovered;
|
||||
totals.branches.total += branchTotal;
|
||||
totals.branches.covered += branchCovered;
|
||||
totals.functions.total += functionTotal;
|
||||
totals.functions.covered += functionCovered;
|
||||
|
||||
const pct = (covered, tot) => (tot > 0 ? (covered / tot) * 100 : 0);
|
||||
|
||||
fileSummaries.push({
|
||||
file,
|
||||
pct: pct(lineCovered || statementCovered, lineTotal || statementTotal),
|
||||
lines: {
|
||||
covered: lineCovered || statementCovered,
|
||||
total: lineTotal || statementTotal,
|
||||
},
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
const pct = (covered, tot) => (tot > 0 ? ((covered / tot) * 100).toFixed(2) : '0.00');
|
||||
|
||||
console.log('### Test Coverage Summary :test_tube:');
|
||||
console.log('');
|
||||
console.log('| Metric | Coverage | Covered / Total |');
|
||||
console.log('|--------|----------|-----------------|');
|
||||
console.log(`| Lines | ${pct(totals.lines.covered, totals.lines.total)}% | ${totals.lines.covered} / ${totals.lines.total} |`);
|
||||
console.log(`| Statements | ${pct(totals.statements.covered, totals.statements.total)}% | ${totals.statements.covered} / ${totals.statements.total} |`);
|
||||
console.log(`| Branches | ${pct(totals.branches.covered, totals.branches.total)}% | ${totals.branches.covered} / ${totals.branches.total} |`);
|
||||
console.log(`| Functions | ${pct(totals.functions.covered, totals.functions.total)}% | ${totals.functions.covered} / ${totals.functions.total} |`);
|
||||
|
||||
console.log('');
|
||||
console.log('<details><summary>File coverage (lowest lines first)</summary>');
|
||||
console.log('');
|
||||
console.log('```');
|
||||
fileSummaries
|
||||
.sort((a, b) => (a.pct - b.pct) || (b.lines.total - a.lines.total))
|
||||
.slice(0, 25)
|
||||
.forEach(({ file, pct, lines }) => {
|
||||
console.log(`${pct.toFixed(2)}%\t${lines.covered}/${lines.total}\t${file}`);
|
||||
});
|
||||
console.log('```');
|
||||
console.log('</details>');
|
||||
|
||||
if (coverage) {
|
||||
const pctValue = (covered, tot) => {
|
||||
if (tot === 0) {
|
||||
return '0';
|
||||
}
|
||||
return ((covered / tot) * 100)
|
||||
.toFixed(2)
|
||||
.replace(/\.?0+$/, '');
|
||||
};
|
||||
|
||||
const formatLineRanges = (lines) => {
|
||||
if (lines.length === 0) {
|
||||
return '';
|
||||
}
|
||||
const ranges = [];
|
||||
let start = lines[0];
|
||||
let end = lines[0];
|
||||
|
||||
for (let i = 1; i < lines.length; i += 1) {
|
||||
const current = lines[i];
|
||||
if (current === end + 1) {
|
||||
end = current;
|
||||
continue;
|
||||
}
|
||||
ranges.push(start === end ? `${start}` : `${start}-${end}`);
|
||||
start = current;
|
||||
end = current;
|
||||
}
|
||||
ranges.push(start === end ? `${start}` : `${start}-${end}`);
|
||||
return ranges.join(',');
|
||||
};
|
||||
|
||||
const tableTotals = {
|
||||
statements: { covered: 0, total: 0 },
|
||||
branches: { covered: 0, total: 0 },
|
||||
functions: { covered: 0, total: 0 },
|
||||
lines: { covered: 0, total: 0 },
|
||||
};
|
||||
const tableRows = Object.entries(coverage)
|
||||
.map(([file, entry]) => {
|
||||
const fileCoverage = getFileCoverage(entry);
|
||||
const lineHits = getLineHits(entry, fileCoverage);
|
||||
const statementHits = entry.s ?? {};
|
||||
const branchHits = entry.b ?? {};
|
||||
const functionHits = entry.f ?? {};
|
||||
|
||||
const lineTotal = Object.keys(lineHits).length;
|
||||
const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
|
||||
const statementTotal = Object.keys(statementHits).length;
|
||||
const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
|
||||
const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
|
||||
const branchCovered = Object.values(branchHits).reduce(
|
||||
(acc, branches) => acc + branches.filter((n) => n > 0).length,
|
||||
0,
|
||||
);
|
||||
const functionTotal = Object.keys(functionHits).length;
|
||||
const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
|
||||
|
||||
tableTotals.lines.total += lineTotal;
|
||||
tableTotals.lines.covered += lineCovered;
|
||||
tableTotals.statements.total += statementTotal;
|
||||
tableTotals.statements.covered += statementCovered;
|
||||
tableTotals.branches.total += branchTotal;
|
||||
tableTotals.branches.covered += branchCovered;
|
||||
tableTotals.functions.total += functionTotal;
|
||||
tableTotals.functions.covered += functionCovered;
|
||||
|
||||
const uncoveredLines = getUncoveredLines(entry, fileCoverage, lineHits);
|
||||
|
||||
const filePath = entry.path ?? file;
|
||||
const relativePath = path.isAbsolute(filePath)
|
||||
? path.relative(process.cwd(), filePath)
|
||||
: filePath;
|
||||
|
||||
return {
|
||||
file: relativePath || file,
|
||||
statements: pctValue(statementCovered, statementTotal),
|
||||
branches: pctValue(branchCovered, branchTotal),
|
||||
functions: pctValue(functionCovered, functionTotal),
|
||||
lines: pctValue(lineCovered, lineTotal),
|
||||
uncovered: formatLineRanges(uncoveredLines),
|
||||
};
|
||||
})
|
||||
.sort((a, b) => a.file.localeCompare(b.file));
|
||||
|
||||
const columns = [
|
||||
{ key: 'file', header: 'File', align: 'left' },
|
||||
{ key: 'statements', header: '% Stmts', align: 'right' },
|
||||
{ key: 'branches', header: '% Branch', align: 'right' },
|
||||
{ key: 'functions', header: '% Funcs', align: 'right' },
|
||||
{ key: 'lines', header: '% Lines', align: 'right' },
|
||||
{ key: 'uncovered', header: 'Uncovered Line #s', align: 'left' },
|
||||
];
|
||||
|
||||
const allFilesRow = {
|
||||
file: 'All files',
|
||||
statements: pctValue(tableTotals.statements.covered, tableTotals.statements.total),
|
||||
branches: pctValue(tableTotals.branches.covered, tableTotals.branches.total),
|
||||
functions: pctValue(tableTotals.functions.covered, tableTotals.functions.total),
|
||||
lines: pctValue(tableTotals.lines.covered, tableTotals.lines.total),
|
||||
uncovered: '',
|
||||
};
|
||||
|
||||
const rowsForOutput = [allFilesRow, ...tableRows];
|
||||
const formatRow = (row) => `| ${columns
|
||||
.map(({ key }) => String(row[key] ?? ''))
|
||||
.join(' | ')} |`;
|
||||
const headerRow = `| ${columns.map(({ header }) => header).join(' | ')} |`;
|
||||
const dividerRow = `| ${columns
|
||||
.map(({ align }) => (align === 'right' ? '---:' : ':---'))
|
||||
.join(' | ')} |`;
|
||||
|
||||
console.log('');
|
||||
console.log('<details><summary>Vitest coverage table</summary>');
|
||||
console.log('');
|
||||
console.log(headerRow);
|
||||
console.log(dividerRow);
|
||||
rowsForOutput.forEach((row) => console.log(formatRow(row)));
|
||||
console.log('</details>');
|
||||
}
|
||||
NODE
|
||||
|
||||
- name: Upload Coverage Artifact
|
||||
if: steps.coverage-summary.outputs.has_coverage == 'true'
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
with:
|
||||
name: web-coverage-report
|
||||
path: web/coverage
|
||||
retention-days: 30
|
||||
if-no-files-found: error
|
||||
|
||||
web-build:
|
||||
name: Web Build
|
||||
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -209,7 +209,6 @@ api/.vscode
|
||||
.history
|
||||
|
||||
.idea/
|
||||
web/migration/
|
||||
|
||||
# pnpm
|
||||
/.pnpm-store
|
||||
@@ -222,7 +221,6 @@ mise.toml
|
||||
|
||||
|
||||
# AI Assistant
|
||||
.sisyphus/
|
||||
.roo/
|
||||
/.claude/worktrees/
|
||||
api/.env.backup
|
||||
@@ -239,6 +237,3 @@ scripts/stress-test/reports/
|
||||
# settings
|
||||
*.local.json
|
||||
*.local.md
|
||||
|
||||
# Code Agent Folder
|
||||
.qoder/*
|
||||
@@ -97,8 +97,3 @@ Feel free to reach out if you encounter any issues during the setup process.
|
||||
## Getting Help
|
||||
|
||||
If you ever get stuck or get a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat.
|
||||
|
||||
## Automated Agent Contributions
|
||||
|
||||
> [!NOTE]
|
||||
> If you are an automated agent, just add 🤖🤖🤖 to the end of the PR title to opt-in.
|
||||
|
||||
@@ -22,10 +22,10 @@ APP_WEB_URL=http://localhost:3000
|
||||
# Files URL
|
||||
FILES_URL=http://localhost:5001
|
||||
|
||||
# INTERNAL_FILES_URL is used by services running in Docker to reach the API file endpoints.
|
||||
# For Docker Desktop (Mac/Windows), use http://host.docker.internal:5001 when the API runs on the host.
|
||||
# For Docker Compose on Linux, use http://api:5001 when the API runs inside the Docker network.
|
||||
INTERNAL_FILES_URL=http://host.docker.internal:5001
|
||||
# INTERNAL_FILES_URL is used for plugin daemon communication within Docker network.
|
||||
# Set this to the internal Docker service URL for proper plugin file access.
|
||||
# Example: INTERNAL_FILES_URL=http://api:5001
|
||||
INTERNAL_FILES_URL=http://127.0.0.1:5001
|
||||
|
||||
# TRIGGER URL
|
||||
TRIGGER_URL=http://localhost:5001
|
||||
@@ -33,9 +33,6 @@ TRIGGER_URL=http://localhost:5001
|
||||
# The time in seconds after the signature is rejected
|
||||
FILES_ACCESS_TIMEOUT=300
|
||||
|
||||
# Collaboration mode toggle
|
||||
ENABLE_COLLABORATION_MODE=false
|
||||
|
||||
# Access token expiration time in minutes
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=60
|
||||
|
||||
@@ -183,7 +180,7 @@ CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
|
||||
COOKIE_DOMAIN=
|
||||
|
||||
# Vector database configuration
|
||||
# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `hologres`.
|
||||
# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
|
||||
VECTOR_STORE=weaviate
|
||||
# Prefix used to create collection name in vector database
|
||||
VECTOR_INDEX_NAME_PREFIX=Vector_index
|
||||
@@ -220,20 +217,6 @@ COUCHBASE_PASSWORD=password
|
||||
COUCHBASE_BUCKET_NAME=Embeddings
|
||||
COUCHBASE_SCOPE_NAME=_default
|
||||
|
||||
# Hologres configuration
|
||||
# access_key_id is used as the PG username, access_key_secret is used as the PG password
|
||||
HOLOGRES_HOST=
|
||||
HOLOGRES_PORT=80
|
||||
HOLOGRES_DATABASE=
|
||||
HOLOGRES_ACCESS_KEY_ID=
|
||||
HOLOGRES_ACCESS_KEY_SECRET=
|
||||
HOLOGRES_SCHEMA=public
|
||||
HOLOGRES_TOKENIZER=jieba
|
||||
HOLOGRES_DISTANCE_METHOD=Cosine
|
||||
HOLOGRES_BASE_QUANTIZATION_TYPE=rabitq
|
||||
HOLOGRES_MAX_DEGREE=64
|
||||
HOLOGRES_EF_CONSTRUCTION=400
|
||||
|
||||
# Milvus configuration
|
||||
MILVUS_URI=http://127.0.0.1:19530
|
||||
MILVUS_TOKEN=
|
||||
@@ -356,9 +339,6 @@ BAIDU_VECTOR_DB_SHARD=1
|
||||
BAIDU_VECTOR_DB_REPLICAS=3
|
||||
BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER=DEFAULT_ANALYZER
|
||||
BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE=COARSE_MODE
|
||||
BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT=500
|
||||
BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO=0.05
|
||||
BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS=300
|
||||
|
||||
# Upstash configuration
|
||||
UPSTASH_VECTOR_URL=your-server-url
|
||||
@@ -647,11 +627,6 @@ INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y
|
||||
MARKETPLACE_ENABLED=true
|
||||
MARKETPLACE_API_URL=https://marketplace.dify.ai
|
||||
|
||||
# Creators Platform configuration
|
||||
CREATORS_PLATFORM_FEATURES_ENABLED=true
|
||||
CREATORS_PLATFORM_API_URL=https://creators.dify.ai
|
||||
CREATORS_PLATFORM_OAUTH_CLIENT_ID=
|
||||
|
||||
# Endpoint configuration
|
||||
ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}
|
||||
|
||||
@@ -747,54 +722,25 @@ SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL=200
|
||||
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL=90000
|
||||
|
||||
# Sandbox Dify CLI configuration
|
||||
# Directory containing dify CLI binaries (dify-cli-<os>-<arch>). Defaults to api/bin when unset.
|
||||
SANDBOX_DIFY_CLI_ROOT=
|
||||
|
||||
# CLI API URL for sandbox (dify-sandbox or e2b) to call back to Dify API.
|
||||
# This URL must be accessible from the sandbox environment.
|
||||
# For local development: use http://localhost:5001 or http://127.0.0.1:5001
|
||||
# For middleware docker stack (api on host): keep localhost/127.0.0.1 and use agentbox via 127.0.0.1:2222
|
||||
# For Docker deployment: use http://api:5001 (internal Docker network)
|
||||
# For external sandbox (e.g., e2b): use a publicly accessible URL
|
||||
CLI_API_URL=http://localhost:5001
|
||||
|
||||
# Base URL for storage file ticket API endpoints (upload/download).
|
||||
# Used by sandbox containers (internal or external like e2b) that need an absolute,
|
||||
# routable address to reach the Dify API file endpoints.
|
||||
# Required for sandbox runtime file access.
|
||||
# For local development: http://localhost:5001
|
||||
# For all-in-one Docker deployment with nginx: http://localhost
|
||||
# For public/remote sandbox environments (e.g., e2b): use a public domain or IP
|
||||
FILES_API_URL=http://localhost:5001
|
||||
|
||||
# Optional defaults for SSH sandbox provider setup (for manual config/CLI usage).
|
||||
# Middleware/local dev usually uses 127.0.0.1:2222; full docker deployment usually uses agentbox:22.
|
||||
SSH_SANDBOX_HOST=127.0.0.1
|
||||
SSH_SANDBOX_PORT=2222
|
||||
SSH_SANDBOX_USERNAME=agentbox
|
||||
SSH_SANDBOX_PASSWORD=agentbox
|
||||
SSH_SANDBOX_BASE_WORKING_PATH=/workspace/sandboxes
|
||||
|
||||
# Redis URL used for event bus between API and
|
||||
# Redis URL used for PubSub between API and
|
||||
# celery worker
|
||||
# defaults to url constructed from `REDIS_*`
|
||||
# configurations
|
||||
EVENT_BUS_REDIS_URL=
|
||||
# Event transport type. Options are:
|
||||
PUBSUB_REDIS_URL=
|
||||
# Pub/sub channel type for streaming events.
|
||||
# valid options are:
|
||||
#
|
||||
# - pubsub: normal Pub/Sub (at-most-once)
|
||||
# - sharded: sharded Pub/Sub (at-most-once)
|
||||
# - streams: Redis Streams (at-least-once, recommended to avoid subscriber races)
|
||||
# - pubsub: for normal Pub/Sub
|
||||
# - sharded: for sharded Pub/Sub
|
||||
#
|
||||
# Note: Before enabling 'streams' in production, estimate your expected event volume and retention needs.
|
||||
# Configure Redis memory limits and stream trimming appropriately (e.g., MAXLEN and key expiry) to reduce
|
||||
# the risk of data loss from Redis auto-eviction under memory pressure.
|
||||
# Also accepts ENV: EVENT_BUS_REDIS_CHANNEL_TYPE.
|
||||
EVENT_BUS_REDIS_CHANNEL_TYPE=pubsub
|
||||
# Whether to use Redis cluster mode while use redis as event bus.
|
||||
# It's highly recommended to use sharded Pub/Sub AND redis cluster
|
||||
# for large deployments.
|
||||
PUBSUB_REDIS_CHANNEL_TYPE=pubsub
|
||||
# Whether to use Redis cluster mode while running
|
||||
# PubSub.
|
||||
# It's highly recommended to enable this for large deployments.
|
||||
EVENT_BUS_REDIS_USE_CLUSTERS=false
|
||||
PUBSUB_REDIS_USE_CLUSTERS=false
|
||||
|
||||
# Whether to Enable human input timeout check task
|
||||
ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true
|
||||
|
||||
@@ -33,12 +33,6 @@ ignore_imports =
|
||||
# TODO(QuantumGhost): fix the import violation later
|
||||
dify_graph.entities.pause_reason -> dify_graph.nodes.human_input.entities
|
||||
|
||||
dify_graph.nodes.base.node -> core.workflow.node_factory
|
||||
dify_graph.nodes.tool.tool_node -> core.workflow.node_factory
|
||||
dify_graph.file.file_manager -> models.model
|
||||
dify_graph.file.file_manager -> models.tools
|
||||
dify_graph.file.file_manager -> extensions.ext_database
|
||||
|
||||
[importlinter:contract:workflow-infrastructure-dependencies]
|
||||
name = Workflow Infrastructure Dependencies
|
||||
type = forbidden
|
||||
@@ -49,11 +43,10 @@ forbidden_modules =
|
||||
extensions.ext_redis
|
||||
allow_indirect_imports = True
|
||||
ignore_imports =
|
||||
dify_graph.nodes.agent.agent_node -> extensions.ext_database
|
||||
dify_graph.nodes.llm.node -> extensions.ext_database
|
||||
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
|
||||
dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
|
||||
dify_graph.file.file_manager -> extensions.ext_database
|
||||
dify_graph.nodes.llm.llm_utils -> extensions.ext_database
|
||||
|
||||
[importlinter:contract:workflow-external-imports]
|
||||
name = Workflow External Imports
|
||||
@@ -97,6 +90,9 @@ forbidden_modules =
|
||||
core.trigger
|
||||
core.variables
|
||||
ignore_imports =
|
||||
dify_graph.nodes.agent.agent_node -> core.model_manager
|
||||
dify_graph.nodes.agent.agent_node -> core.provider_manager
|
||||
dify_graph.nodes.agent.agent_node -> core.tools.tool_manager
|
||||
dify_graph.nodes.llm.llm_utils -> core.model_manager
|
||||
dify_graph.nodes.llm.protocols -> core.model_manager
|
||||
dify_graph.nodes.llm.llm_utils -> dify_graph.model_runtime.model_providers.__base.large_language_model
|
||||
@@ -104,6 +100,9 @@ ignore_imports =
|
||||
dify_graph.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
|
||||
dify_graph.nodes.tool.tool_node -> core.tools.tool_engine
|
||||
dify_graph.nodes.tool.tool_node -> core.tools.tool_manager
|
||||
dify_graph.nodes.agent.agent_node -> core.agent.entities
|
||||
dify_graph.nodes.agent.agent_node -> core.agent.plugin_entities
|
||||
dify_graph.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
|
||||
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
|
||||
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform
|
||||
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> dify_graph.model_runtime.model_providers.__base.large_language_model
|
||||
@@ -111,10 +110,12 @@ ignore_imports =
|
||||
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
|
||||
dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager
|
||||
dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer
|
||||
dify_graph.nodes.agent.agent_node -> models.model
|
||||
dify_graph.nodes.llm.node -> core.helper.code_executor
|
||||
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors
|
||||
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.file_ref
|
||||
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output
|
||||
dify_graph.nodes.llm.node -> core.model_manager
|
||||
dify_graph.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
dify_graph.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
dify_graph.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
|
||||
dify_graph.nodes.llm.node -> core.prompt.utils.prompt_message_util
|
||||
@@ -123,29 +124,17 @@ ignore_imports =
|
||||
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util
|
||||
dify_graph.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util
|
||||
dify_graph.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods
|
||||
dify_graph.nodes.llm.node -> models.dataset
|
||||
dify_graph.nodes.agent.agent_node -> core.tools.utils.message_transformer
|
||||
dify_graph.nodes.llm.file_saver -> core.tools.signature
|
||||
dify_graph.nodes.llm.file_saver -> core.tools.tool_file_manager
|
||||
dify_graph.nodes.tool.tool_node -> core.tools.errors
|
||||
dify_graph.nodes.agent.agent_node -> extensions.ext_database
|
||||
dify_graph.nodes.llm.node -> extensions.ext_database
|
||||
dify_graph.nodes.agent.agent_node -> models
|
||||
dify_graph.nodes.llm.node -> models.model
|
||||
dify_graph.nodes.llm.node -> configs
|
||||
dify_graph.nodes.llm.node -> core.agent.entities
|
||||
dify_graph.nodes.llm.node -> core.agent.patterns
|
||||
dify_graph.nodes.llm.node -> core.app.entities.app_invoke_entities
|
||||
dify_graph.nodes.llm.node -> core.helper.code_executor
|
||||
dify_graph.nodes.llm.node -> core.memory.base
|
||||
dify_graph.nodes.llm.node -> core.sandbox
|
||||
dify_graph.nodes.llm.node -> core.sandbox.bash.session
|
||||
dify_graph.nodes.llm.node -> core.sandbox.entities.config
|
||||
dify_graph.nodes.llm.node -> core.skill.assembler
|
||||
dify_graph.nodes.llm.node -> core.skill.constants
|
||||
dify_graph.nodes.llm.node -> core.skill.entities.skill_bundle
|
||||
dify_graph.nodes.llm.node -> core.skill.entities.skill_document
|
||||
dify_graph.nodes.llm.node -> core.skill.entities.skill_metadata
|
||||
dify_graph.nodes.llm.node -> core.skill.entities.tool_dependencies
|
||||
dify_graph.nodes.llm.node -> core.tools.tool_file_manager
|
||||
dify_graph.nodes.llm.node -> core.tools.tool_manager
|
||||
dify_graph.nodes.agent.agent_node -> services
|
||||
dify_graph.nodes.tool.tool_node -> services
|
||||
dify_graph.model_runtime.model_providers.__base.ai_model -> configs
|
||||
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
|
||||
@@ -154,20 +143,6 @@ ignore_imports =
|
||||
dify_graph.model_runtime.model_providers.model_provider_factory -> configs
|
||||
dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
|
||||
dify_graph.model_runtime.model_providers.model_provider_factory -> models.provider_ids
|
||||
dify_graph.file.file_manager -> configs
|
||||
dify_graph.file.file_manager -> extensions.ext_database
|
||||
dify_graph.file.file_manager -> models.model
|
||||
dify_graph.file.file_manager -> models.tools
|
||||
dify_graph.nodes.llm.llm_utils -> core.app.llm.model_access
|
||||
dify_graph.nodes.llm.llm_utils -> core.app.llm.quota
|
||||
dify_graph.nodes.llm.llm_utils -> core.memory
|
||||
dify_graph.nodes.llm.llm_utils -> core.memory.base
|
||||
dify_graph.nodes.llm.llm_utils -> extensions.ext_database
|
||||
dify_graph.nodes.llm.llm_utils -> models.model
|
||||
dify_graph.nodes.llm.llm_utils -> core.prompt.entities.advanced_prompt_entities
|
||||
dify_graph.nodes.llm.entities -> core.agent.entities
|
||||
dify_graph.nodes.base.node -> core.workflow.node_factory
|
||||
dify_graph.nodes.tool.tool_node -> core.workflow.node_factory
|
||||
|
||||
[importlinter:contract:rsc]
|
||||
name = RSC
|
||||
|
||||
@@ -78,7 +78,7 @@ class UserProfile(TypedDict):
|
||||
nickname: NotRequired[str]
|
||||
```
|
||||
|
||||
- For classes, declare all member variables explicitly with types at the top of the class body (before `__init__`), even when the class is not a dataclass or Pydantic model, so the class shape is obvious at a glance:
|
||||
- For classes, declare member variables at the top of the class body (before `__init__`) so the class shape is obvious at a glance:
|
||||
|
||||
```python
|
||||
from datetime import datetime
|
||||
|
||||
@@ -97,7 +97,7 @@ ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
|
||||
|
||||
# Download nltk data
|
||||
RUN mkdir -p /usr/local/share/nltk_data \
|
||||
&& NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords')" \
|
||||
&& NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; from unstructured.nlp.tokenize import download_nltk_packages; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords'); download_nltk_packages()" \
|
||||
&& chmod -R 755 /usr/local/share/nltk_data
|
||||
|
||||
ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache
|
||||
|
||||
@@ -54,86 +54,6 @@ The scripts resolve paths relative to their location, so you can run them from a
|
||||
./dev/start-beat
|
||||
```
|
||||
|
||||
### Manual commands
|
||||
|
||||
<details>
|
||||
<summary>Show manual setup and run steps</summary>
|
||||
|
||||
These commands assume you start from the repository root.
|
||||
|
||||
1. Start the docker-compose stack.
|
||||
|
||||
The backend requires middleware, including PostgreSQL, Redis, and Weaviate, which can be started together using `docker-compose`.
|
||||
|
||||
```bash
|
||||
cp docker/middleware.env.example docker/middleware.env
|
||||
# Use mysql or another vector database profile if you are not using postgres/weaviate.
|
||||
docker compose -f docker/docker-compose.middleware.yaml --profile postgresql --profile weaviate -p dify up -d
|
||||
```
|
||||
|
||||
1. Copy env files.
|
||||
|
||||
```bash
|
||||
cp api/.env.example api/.env
|
||||
cp web/.env.example web/.env.local
|
||||
```
|
||||
|
||||
1. Install UV if needed.
|
||||
|
||||
```bash
|
||||
pip install uv
|
||||
# Or on macOS
|
||||
brew install uv
|
||||
```
|
||||
|
||||
1. Install API dependencies.
|
||||
|
||||
```bash
|
||||
cd api
|
||||
uv sync --group dev
|
||||
```
|
||||
|
||||
1. Install web dependencies.
|
||||
|
||||
```bash
|
||||
cd web
|
||||
pnpm install
|
||||
cd ..
|
||||
```
|
||||
|
||||
1. Start backend (runs migrations first, in a new terminal).
|
||||
|
||||
```bash
|
||||
cd api
|
||||
uv run flask db upgrade
|
||||
uv run flask run --host 0.0.0.0 --port=5001 --debug
|
||||
```
|
||||
|
||||
1. Start Dify [web](../web) service (in a new terminal).
|
||||
|
||||
```bash
|
||||
cd web
|
||||
pnpm dev:inspect
|
||||
```
|
||||
|
||||
1. Set up your application by visiting `http://localhost:3000`.
|
||||
|
||||
1. Optional: start the worker service (async tasks, in a new terminal).
|
||||
|
||||
```bash
|
||||
cd api
|
||||
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution
|
||||
```
|
||||
|
||||
1. Optional: start Celery Beat (scheduled tasks, in a new terminal).
|
||||
|
||||
```bash
|
||||
cd api
|
||||
uv run celery -A app.celery beat
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
### Environment notes
|
||||
|
||||
> [!IMPORTANT]
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
Summary:
|
||||
Summary:
|
||||
- Application configuration definitions, including file access settings.
|
||||
|
||||
Invariants:
|
||||
- File access settings drive signed URL expiration and base URLs.
|
||||
|
||||
Tests:
|
||||
- Config parsing tests under tests/unit_tests/configs.
|
||||
@@ -1,9 +0,0 @@
|
||||
Summary:
|
||||
- Registers file-related API namespaces and routes for files service.
|
||||
- Includes app-assets and sandbox archive proxy controllers.
|
||||
|
||||
Invariants:
|
||||
- files_ns must include all file controller modules to register routes.
|
||||
|
||||
Tests:
|
||||
- Coverage via controller unit tests and route registration smoke checks.
|
||||
@@ -1,14 +0,0 @@
|
||||
Summary:
|
||||
- App assets download proxy endpoint (signed URL verification, stream from storage).
|
||||
|
||||
Invariants:
|
||||
- Validates AssetPath fields (UUIDs, asset_type allowlist).
|
||||
- Verifies tenant-scoped signature and expiration before reading storage.
|
||||
- URL uses expires_at/nonce/sign query params.
|
||||
|
||||
Edge Cases:
|
||||
- Missing files return NotFound.
|
||||
- Invalid signature or expired link returns Forbidden.
|
||||
|
||||
Tests:
|
||||
- Verify signature validation and invalid/expired cases.
|
||||
@@ -1,13 +0,0 @@
|
||||
Summary:
|
||||
- App assets upload proxy endpoint (signed URL verification, upload to storage).
|
||||
|
||||
Invariants:
|
||||
- Validates AssetPath fields (UUIDs, asset_type allowlist).
|
||||
- Verifies tenant-scoped signature and expiration before writing storage.
|
||||
- URL uses expires_at/nonce/sign query params.
|
||||
|
||||
Edge Cases:
|
||||
- Invalid signature or expired link returns Forbidden.
|
||||
|
||||
Tests:
|
||||
- Verify signature validation and invalid/expired cases.
|
||||
@@ -1,14 +0,0 @@
|
||||
Summary:
|
||||
- Sandbox archive upload/download proxy endpoints (signed URL verification, stream to storage).
|
||||
|
||||
Invariants:
|
||||
- Validates tenant_id and sandbox_id UUIDs.
|
||||
- Verifies tenant-scoped signature and expiration before storage access.
|
||||
- URL uses expires_at/nonce/sign query params.
|
||||
|
||||
Edge Cases:
|
||||
- Missing archive returns NotFound.
|
||||
- Invalid signature or expired link returns Forbidden.
|
||||
|
||||
Tests:
|
||||
- Add unit tests for signature validation if needed.
|
||||
@@ -1,9 +0,0 @@
|
||||
Summary:
|
||||
Summary:
|
||||
- Collects file assets and emits FileAsset entries with storage keys.
|
||||
|
||||
Invariants:
|
||||
- Storage keys are derived via AppAssetStorage for draft files.
|
||||
|
||||
Tests:
|
||||
- Covered by asset build pipeline tests.
|
||||
@@ -1,14 +0,0 @@
|
||||
Summary:
|
||||
Summary:
|
||||
- Builds skill artifacts from markdown assets and uploads resolved outputs.
|
||||
|
||||
Invariants:
|
||||
- Reads draft asset content via AppAssetStorage refs.
|
||||
- Writes resolved artifacts via AppAssetStorage refs.
|
||||
- FileAsset storage keys are derived via AppAssetStorage.
|
||||
|
||||
Edge Cases:
|
||||
- Missing or invalid JSON content yields empty skill content/metadata.
|
||||
|
||||
Tests:
|
||||
- Build pipeline unit tests covering compile/upload paths.
|
||||
@@ -1,9 +0,0 @@
|
||||
Summary:
|
||||
Summary:
|
||||
- Converts AppAssetFileTree to FileAsset items for packaging.
|
||||
|
||||
Invariants:
|
||||
- Storage keys for assets are derived via AppAssetStorage.
|
||||
|
||||
Tests:
|
||||
- Used in packaging/service tests for asset bundles.
|
||||
@@ -1,14 +0,0 @@
|
||||
# Zip Packager Notes
|
||||
|
||||
## Purpose
|
||||
- Builds a ZIP archive of asset contents stored via the configured storage backend.
|
||||
|
||||
## Key Decisions
|
||||
- Packaging writes assets into an in-memory zip buffer returned as bytes.
|
||||
- Asset fetch + zip writing are executed via a thread pool with a lock guarding `ZipFile` writes.
|
||||
|
||||
## Edge Cases
|
||||
- ZIP writes are serialized by the lock; storage reads still run in parallel.
|
||||
|
||||
## Tests/Verification
|
||||
- None yet.
|
||||
@@ -1,9 +0,0 @@
|
||||
Summary:
|
||||
Summary:
|
||||
- Builds AssetItem entries for asset trees using AssetPath-derived storage keys.
|
||||
|
||||
Invariants:
|
||||
- Uses AssetPath to compute draft storage keys.
|
||||
|
||||
Tests:
|
||||
- Covered by asset parsing and packaging tests.
|
||||
@@ -1,20 +0,0 @@
|
||||
Summary:
|
||||
- Defines AssetPath facade + typed asset path classes for app-asset storage access.
|
||||
- Maps asset paths to storage keys and generates presigned or signed-proxy URLs.
|
||||
- Signs proxy URLs using tenant private keys and enforces expiration.
|
||||
- Exposes app_asset_storage singleton for reuse.
|
||||
|
||||
Invariants:
|
||||
- AssetPathBase fields (tenant_id/app_id/resource_id/node_id) must be UUIDs.
|
||||
- AssetPath.from_components enforces valid types and resolved node_id presence.
|
||||
- Storage keys are derived internally via AssetPathBase.get_storage_key; callers never supply raw paths.
|
||||
- AppAssetStorage.storage returns the cached presign wrapper (not the raw storage).
|
||||
|
||||
Edge Cases:
|
||||
- Storage backends without presign support must fall back to signed proxy URLs.
|
||||
- Signed proxy verification enforces expiration and tenant-scoped signing keys.
|
||||
- Upload URLs also fall back to signed proxy endpoints when presign is unsupported.
|
||||
- load_or_none treats SilentStorage "File Not Found" bytes as missing.
|
||||
|
||||
Tests:
|
||||
- Unit tests for ref validation, storage key mapping, and signed URL verification.
|
||||
@@ -1,10 +0,0 @@
|
||||
Summary:
|
||||
Summary:
|
||||
- Extracts asset files from a zip and persists them into app asset storage.
|
||||
|
||||
Invariants:
|
||||
- Rejects path traversal/absolute/backslash paths.
|
||||
- Saves extracted files via AppAssetStorage draft refs.
|
||||
|
||||
Tests:
|
||||
- Zip security edge cases and tree construction tests.
|
||||
@@ -1,9 +0,0 @@
|
||||
Summary:
|
||||
Summary:
|
||||
- Downloads published app asset zip into sandbox and extracts it.
|
||||
|
||||
Invariants:
|
||||
- Uses AppAssetStorage to generate download URLs for build zips (internal URL).
|
||||
|
||||
Tests:
|
||||
- Sandbox initialization integration tests.
|
||||
@@ -1,12 +0,0 @@
|
||||
Summary:
|
||||
Summary:
|
||||
- Downloads draft/resolved assets into sandbox for draft execution.
|
||||
|
||||
Invariants:
|
||||
- Uses AppAssetStorage to generate download URLs for draft/resolved refs (internal URL).
|
||||
|
||||
Edge Cases:
|
||||
- No nodes -> returns early.
|
||||
|
||||
Tests:
|
||||
- Sandbox draft initialization tests.
|
||||
@@ -1,9 +0,0 @@
|
||||
Summary:
|
||||
- Sandbox lifecycle wrapper (ready/cancel/fail signals, mount/unmount, release).
|
||||
|
||||
Invariants:
|
||||
- wait_ready raises with the original initialization error as the cause.
|
||||
- release always attempts unmount and environment release, logging failures.
|
||||
|
||||
Tests:
|
||||
- Covered by sandbox lifecycle/unit tests and workflow execution error handling.
|
||||
@@ -1,2 +0,0 @@
|
||||
Summary:
|
||||
- Sandbox security helper modules.
|
||||
@@ -1,13 +0,0 @@
|
||||
Summary:
|
||||
- Generates and verifies signed URLs for sandbox archive upload/download.
|
||||
|
||||
Invariants:
|
||||
- tenant_id and sandbox_id must be UUIDs.
|
||||
- Signatures are tenant-scoped and include operation, expiry, and nonce.
|
||||
|
||||
Edge Cases:
|
||||
- Missing tenant private key raises ValueError.
|
||||
- Expired or tampered signatures are rejected.
|
||||
|
||||
Tests:
|
||||
- Add unit tests if sandbox archive signature behavior expands.
|
||||
@@ -1,12 +0,0 @@
|
||||
Summary:
|
||||
- Manages sandbox archive uploads/downloads for workspace persistence.
|
||||
|
||||
Invariants:
|
||||
- Archive storage key is sandbox/<tenant_id>/<sandbox_id>.tar.gz.
|
||||
- Signed URLs are tenant-scoped and use external files URL.
|
||||
|
||||
Edge Cases:
|
||||
- Missing archive skips mount.
|
||||
|
||||
Tests:
|
||||
- Covered indirectly via sandbox integration tests.
|
||||
@@ -1,9 +0,0 @@
|
||||
Summary:
|
||||
Summary:
|
||||
- Loads/saves skill bundles to app asset storage.
|
||||
|
||||
Invariants:
|
||||
- Skill bundles use AppAssetStorage refs and JSON serialization.
|
||||
|
||||
Tests:
|
||||
- Covered by skill bundle build/load unit tests.
|
||||
@@ -1,16 +0,0 @@
|
||||
# E2B Sandbox Provider Notes
|
||||
|
||||
## Purpose
|
||||
- Implements the E2B-backed `VirtualEnvironment` provider and bootstraps sandbox metadata, file I/O, and command execution.
|
||||
|
||||
## Key Decisions
|
||||
- Sandbox metadata is gathered during `_construct_environment` using the E2B SDK before returning `Metadata`.
|
||||
- Architecture/OS detection uses a single `uname -m -s` call split by whitespace to reduce round-trips.
|
||||
- Command execution streams stdout/stderr through `QueueTransportReadCloser`; stdin is unsupported.
|
||||
|
||||
## Edge Cases
|
||||
- `release_environment` raises when sandbox termination fails.
|
||||
- `execute_command` runs in a background thread; consumers must read stdout/stderr until EOF.
|
||||
|
||||
## Tests/Verification
|
||||
- None yet. Add targeted service tests when behavior changes.
|
||||
@@ -1,14 +0,0 @@
|
||||
Summary:
|
||||
- App asset CRUD, publish/build pipeline, and presigned URL generation.
|
||||
|
||||
Invariants:
|
||||
- Asset storage access goes through AppAssetStorage + AssetPath, using app_asset_storage singleton.
|
||||
- Tree operations require tenant/app scoping and lock for mutation.
|
||||
- Asset zips are packaged via raw storage with storage keys from AppAssetStorage.
|
||||
|
||||
Edge Cases:
|
||||
- File nodes larger than preview limit are rejected.
|
||||
- Deletion runs asynchronously; storage failures are logged.
|
||||
|
||||
Tests:
|
||||
- Unit tests for storage URL generation and publish/build flows.
|
||||
@@ -1,10 +0,0 @@
|
||||
Summary:
|
||||
Summary:
|
||||
- Imports app bundles, including asset extraction into app asset storage.
|
||||
|
||||
Invariants:
|
||||
- Asset imports respect zip security checks and tenant/app scoping.
|
||||
- Draft asset packaging uses AppAssetStorage for key mapping.
|
||||
|
||||
Tests:
|
||||
- Bundle import unit tests and zip validation coverage.
|
||||
@@ -1,6 +0,0 @@
|
||||
Summary:
|
||||
Summary:
|
||||
- Unit tests for AppAssetStorage ref validation, key mapping, and signing.
|
||||
|
||||
Tests:
|
||||
- Covers valid/invalid refs, signature verify, expiration handling, and proxy URL generation.
|
||||
19
api/app.py
19
api/app.py
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
@@ -17,15 +16,10 @@ def is_db_command() -> bool:
|
||||
|
||||
|
||||
# create app
|
||||
flask_app = None
|
||||
socketio_app = None
|
||||
|
||||
if is_db_command():
|
||||
from app_factory import create_migrations_app
|
||||
|
||||
app = create_migrations_app()
|
||||
socketio_app = app
|
||||
flask_app = app
|
||||
else:
|
||||
# Gunicorn and Celery handle monkey patching automatically in production by
|
||||
# specifying the `gevent` worker class. Manual monkey patching is not required here.
|
||||
@@ -36,15 +30,8 @@ else:
|
||||
|
||||
from app_factory import create_app
|
||||
|
||||
socketio_app, flask_app = create_app()
|
||||
app = flask_app
|
||||
celery = cast("Celery", flask_app.extensions["celery"])
|
||||
app = create_app()
|
||||
celery = cast("Celery", app.extensions["celery"])
|
||||
|
||||
if __name__ == "__main__":
|
||||
from gevent import pywsgi
|
||||
from geventwebsocket.handler import WebSocketHandler # type: ignore[reportMissingTypeStubs]
|
||||
|
||||
host = os.environ.get("HOST", "0.0.0.0")
|
||||
port = int(os.environ.get("PORT", 5001))
|
||||
server = pywsgi.WSGIServer((host, port), socketio_app, handler_class=WebSocketHandler)
|
||||
server.serve_forever()
|
||||
app.run(host="0.0.0.0", port=5001)
|
||||
|
||||
@@ -1,47 +1,16 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import socketio # type: ignore[reportMissingTypeStubs]
|
||||
from flask import request
|
||||
from opentelemetry.trace import get_current_span
|
||||
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
|
||||
|
||||
from configs import dify_config
|
||||
from contexts.wrapper import RecyclableContextVar
|
||||
from controllers.console.error import UnauthorizedAndForceLogout
|
||||
from core.logging.context import init_request_context
|
||||
from dify_app import DifyApp
|
||||
from extensions.ext_socketio import sio
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Console bootstrap APIs exempt from license check.
|
||||
# Defined at module level to avoid per-request tuple construction.
|
||||
# - system-features: license status for expiry UI (GlobalPublicStoreProvider)
|
||||
# - setup: install/setup status check (AppInitializer)
|
||||
# - init: init password validation for fresh install (InitPasswordPopup)
|
||||
# - login: auto-login after setup completion (InstallForm)
|
||||
# - features: billing/plan features (ProviderContextProvider)
|
||||
# - account/profile: login check + user profile (AppContextProvider, useIsLogin)
|
||||
# - workspaces/current: workspace + model providers (AppContextProvider)
|
||||
# - version: version check (AppContextProvider)
|
||||
# - activate/check: invitation link validation (signin page)
|
||||
# Without these exemptions, the signin page triggers location.reload()
|
||||
# on unauthorized_and_force_logout, causing an infinite loop.
|
||||
_CONSOLE_EXEMPT_PREFIXES = (
|
||||
"/console/api/system-features",
|
||||
"/console/api/setup",
|
||||
"/console/api/init",
|
||||
"/console/api/login",
|
||||
"/console/api/features",
|
||||
"/console/api/account/profile",
|
||||
"/console/api/workspaces/current",
|
||||
"/console/api/version",
|
||||
"/console/api/activate/check",
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# Application Factory Function
|
||||
@@ -62,39 +31,6 @@ def create_flask_app_with_configs() -> DifyApp:
|
||||
init_request_context()
|
||||
RecyclableContextVar.increment_thread_recycles()
|
||||
|
||||
# Enterprise license validation for API endpoints (both console and webapp)
|
||||
# When license expires, block all API access except bootstrap endpoints needed
|
||||
# for the frontend to load the license expiration page without infinite reloads.
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
is_console_api = request.path.startswith("/console/api/")
|
||||
is_webapp_api = request.path.startswith("/api/")
|
||||
|
||||
if is_console_api or is_webapp_api:
|
||||
if is_console_api:
|
||||
is_exempt = any(request.path.startswith(p) for p in _CONSOLE_EXEMPT_PREFIXES)
|
||||
else: # webapp API
|
||||
is_exempt = request.path.startswith("/api/system-features")
|
||||
|
||||
if not is_exempt:
|
||||
try:
|
||||
# Check license status (cached — see EnterpriseService for TTL details)
|
||||
license_status = EnterpriseService.get_cached_license_status()
|
||||
if license_status in (LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST):
|
||||
raise UnauthorizedAndForceLogout(
|
||||
f"Enterprise license is {license_status}. Please contact your administrator."
|
||||
)
|
||||
if license_status is None:
|
||||
raise UnauthorizedAndForceLogout(
|
||||
"Unable to verify enterprise license. Please contact your administrator."
|
||||
)
|
||||
except UnauthorizedAndForceLogout:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to check enterprise license status")
|
||||
raise UnauthorizedAndForceLogout(
|
||||
"Unable to verify enterprise license. Please contact your administrator."
|
||||
)
|
||||
|
||||
# add after request hook for injecting trace headers from OpenTelemetry span context
|
||||
# Only adds headers when OTEL is enabled and has valid context
|
||||
@dify_app.after_request
|
||||
@@ -124,18 +60,14 @@ def create_flask_app_with_configs() -> DifyApp:
|
||||
return dify_app
|
||||
|
||||
|
||||
def create_app() -> tuple[socketio.WSGIApp, DifyApp]:
|
||||
def create_app() -> DifyApp:
|
||||
start_time = time.perf_counter()
|
||||
app = create_flask_app_with_configs()
|
||||
initialize_extensions(app)
|
||||
|
||||
sio.app = app
|
||||
socketio_app = socketio.WSGIApp(sio, app)
|
||||
|
||||
end_time = time.perf_counter()
|
||||
if dify_config.DEBUG:
|
||||
logger.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2))
|
||||
return socketio_app, app
|
||||
return app
|
||||
|
||||
|
||||
def initialize_extensions(app: DifyApp):
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,17 +1,15 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.engine import CursorResult
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper import encrypter
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.tools.utils.system_encryption import encrypt_system_params as encrypt_system_oauth_params
|
||||
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
|
||||
from extensions.ext_database import db
|
||||
from models import Tenant
|
||||
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
|
||||
@@ -50,15 +48,14 @@ def setup_system_tool_oauth_client(provider, client_params):
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
return
|
||||
|
||||
deleted_count = cast(
|
||||
CursorResult,
|
||||
db.session.execute(
|
||||
delete(ToolOAuthSystemClient).where(
|
||||
ToolOAuthSystemClient.provider == provider_name,
|
||||
ToolOAuthSystemClient.plugin_id == plugin_id,
|
||||
)
|
||||
),
|
||||
).rowcount
|
||||
deleted_count = (
|
||||
db.session.query(ToolOAuthSystemClient)
|
||||
.filter_by(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
if deleted_count > 0:
|
||||
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
|
||||
|
||||
@@ -100,15 +97,14 @@ def setup_system_trigger_oauth_client(provider, client_params):
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
return
|
||||
|
||||
deleted_count = cast(
|
||||
CursorResult,
|
||||
db.session.execute(
|
||||
delete(TriggerOAuthSystemClient).where(
|
||||
TriggerOAuthSystemClient.provider == provider_name,
|
||||
TriggerOAuthSystemClient.plugin_id == plugin_id,
|
||||
)
|
||||
),
|
||||
).rowcount
|
||||
deleted_count = (
|
||||
db.session.query(TriggerOAuthSystemClient)
|
||||
.filter_by(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
if deleted_count > 0:
|
||||
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
|
||||
|
||||
@@ -143,15 +139,14 @@ def setup_datasource_oauth_client(provider, client_params):
|
||||
return
|
||||
|
||||
click.echo(click.style(f"Ready to delete existing oauth client params: {provider_name}", fg="yellow"))
|
||||
deleted_count = cast(
|
||||
CursorResult,
|
||||
db.session.execute(
|
||||
delete(DatasourceOauthParamConfig).where(
|
||||
DatasourceOauthParamConfig.provider == provider_name,
|
||||
DatasourceOauthParamConfig.plugin_id == plugin_id,
|
||||
)
|
||||
),
|
||||
).rowcount
|
||||
deleted_count = (
|
||||
db.session.query(DatasourceOauthParamConfig)
|
||||
.filter_by(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
if deleted_count > 0:
|
||||
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
|
||||
|
||||
@@ -197,9 +192,7 @@ def transform_datasource_credentials(environment: str):
|
||||
|
||||
# deal notion credentials
|
||||
deal_notion_count = 0
|
||||
notion_credentials = db.session.scalars(
|
||||
select(DataSourceOauthBinding).where(DataSourceOauthBinding.provider == "notion")
|
||||
).all()
|
||||
notion_credentials = db.session.query(DataSourceOauthBinding).filter_by(provider="notion").all()
|
||||
if notion_credentials:
|
||||
notion_credentials_tenant_mapping: dict[str, list[DataSourceOauthBinding]] = {}
|
||||
for notion_credential in notion_credentials:
|
||||
@@ -208,7 +201,7 @@ def transform_datasource_credentials(environment: str):
|
||||
notion_credentials_tenant_mapping[tenant_id] = []
|
||||
notion_credentials_tenant_mapping[tenant_id].append(notion_credential)
|
||||
for tenant_id, notion_tenant_credentials in notion_credentials_tenant_mapping.items():
|
||||
tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id))
|
||||
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
|
||||
if not tenant:
|
||||
continue
|
||||
try:
|
||||
@@ -257,9 +250,7 @@ def transform_datasource_credentials(environment: str):
|
||||
db.session.commit()
|
||||
# deal firecrawl credentials
|
||||
deal_firecrawl_count = 0
|
||||
firecrawl_credentials = db.session.scalars(
|
||||
select(DataSourceApiKeyAuthBinding).where(DataSourceApiKeyAuthBinding.provider == "firecrawl")
|
||||
).all()
|
||||
firecrawl_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="firecrawl").all()
|
||||
if firecrawl_credentials:
|
||||
firecrawl_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
|
||||
for firecrawl_credential in firecrawl_credentials:
|
||||
@@ -268,7 +259,7 @@ def transform_datasource_credentials(environment: str):
|
||||
firecrawl_credentials_tenant_mapping[tenant_id] = []
|
||||
firecrawl_credentials_tenant_mapping[tenant_id].append(firecrawl_credential)
|
||||
for tenant_id, firecrawl_tenant_credentials in firecrawl_credentials_tenant_mapping.items():
|
||||
tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id))
|
||||
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
|
||||
if not tenant:
|
||||
continue
|
||||
try:
|
||||
@@ -321,9 +312,7 @@ def transform_datasource_credentials(environment: str):
|
||||
db.session.commit()
|
||||
# deal jina credentials
|
||||
deal_jina_count = 0
|
||||
jina_credentials = db.session.scalars(
|
||||
select(DataSourceApiKeyAuthBinding).where(DataSourceApiKeyAuthBinding.provider == "jinareader")
|
||||
).all()
|
||||
jina_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="jinareader").all()
|
||||
if jina_credentials:
|
||||
jina_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
|
||||
for jina_credential in jina_credentials:
|
||||
@@ -332,7 +321,7 @@ def transform_datasource_credentials(environment: str):
|
||||
jina_credentials_tenant_mapping[tenant_id] = []
|
||||
jina_credentials_tenant_mapping[tenant_id].append(jina_credential)
|
||||
for tenant_id, jina_tenant_credentials in jina_credentials_tenant_mapping.items():
|
||||
tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id))
|
||||
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
|
||||
if not tenant:
|
||||
continue
|
||||
try:
|
||||
|
||||
@@ -88,8 +88,6 @@ def clean_workflow_runs(
|
||||
"""
|
||||
Clean workflow runs and related workflow data for free tenants.
|
||||
"""
|
||||
from extensions.otel.runtime import flush_telemetry
|
||||
|
||||
if (start_from is None) ^ (end_before is None):
|
||||
raise click.UsageError("--start-from and --end-before must be provided together.")
|
||||
|
||||
@@ -106,27 +104,16 @@ def clean_workflow_runs(
|
||||
end_before = now - datetime.timedelta(days=to_days_ago)
|
||||
before_days = 0
|
||||
|
||||
if from_days_ago is not None and to_days_ago is not None:
|
||||
task_label = f"{from_days_ago}to{to_days_ago}"
|
||||
elif start_from is None:
|
||||
task_label = f"before-{before_days}"
|
||||
else:
|
||||
task_label = "custom"
|
||||
|
||||
start_time = datetime.datetime.now(datetime.UTC)
|
||||
click.echo(click.style(f"Starting workflow run cleanup at {start_time.isoformat()}.", fg="white"))
|
||||
|
||||
try:
|
||||
WorkflowRunCleanup(
|
||||
days=before_days,
|
||||
batch_size=batch_size,
|
||||
start_from=start_from,
|
||||
end_before=end_before,
|
||||
dry_run=dry_run,
|
||||
task_label=task_label,
|
||||
).run()
|
||||
finally:
|
||||
flush_telemetry()
|
||||
WorkflowRunCleanup(
|
||||
days=before_days,
|
||||
batch_size=batch_size,
|
||||
start_from=start_from,
|
||||
end_before=end_before,
|
||||
dry_run=dry_run,
|
||||
).run()
|
||||
|
||||
end_time = datetime.datetime.now(datetime.UTC)
|
||||
elapsed = end_time - start_time
|
||||
@@ -672,8 +659,6 @@ def clean_expired_messages(
|
||||
"""
|
||||
Clean expired messages and related data for tenants based on clean policy.
|
||||
"""
|
||||
from extensions.otel.runtime import flush_telemetry
|
||||
|
||||
click.echo(click.style("clean_messages: start clean messages.", fg="green"))
|
||||
|
||||
start_at = time.perf_counter()
|
||||
@@ -713,13 +698,6 @@ def clean_expired_messages(
|
||||
# NOTE: graceful_period will be ignored when billing is disabled.
|
||||
policy = create_message_clean_policy(graceful_period_days=graceful_period)
|
||||
|
||||
if from_days_ago is not None and before_days is not None:
|
||||
task_label = f"{from_days_ago}to{before_days}"
|
||||
elif start_from is None and before_days is not None:
|
||||
task_label = f"before-{before_days}"
|
||||
else:
|
||||
task_label = "custom"
|
||||
|
||||
# Create and run the cleanup service
|
||||
if abs_mode:
|
||||
assert start_from is not None
|
||||
@@ -730,7 +708,6 @@ def clean_expired_messages(
|
||||
end_before=end_before,
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
task_label=task_label,
|
||||
)
|
||||
elif from_days_ago is None:
|
||||
assert before_days is not None
|
||||
@@ -739,7 +716,6 @@ def clean_expired_messages(
|
||||
days=before_days,
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
task_label=task_label,
|
||||
)
|
||||
else:
|
||||
assert before_days is not None
|
||||
@@ -751,7 +727,6 @@ def clean_expired_messages(
|
||||
end_before=now - datetime.timedelta(days=before_days),
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
task_label=task_label,
|
||||
)
|
||||
stats = service.run()
|
||||
|
||||
@@ -777,8 +752,6 @@ def clean_expired_messages(
|
||||
)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
flush_telemetry()
|
||||
|
||||
click.echo(click.style("messages cleanup completed.", fg="green"))
|
||||
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
import json
|
||||
from typing import cast
|
||||
|
||||
import click
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.engine import CursorResult
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
@@ -743,17 +740,14 @@ def migrate_oss(
|
||||
else:
|
||||
try:
|
||||
source_storage_type = StorageType.LOCAL if is_source_local else StorageType.OPENDAL
|
||||
updated = cast(
|
||||
CursorResult,
|
||||
db.session.execute(
|
||||
update(UploadFile)
|
||||
.where(
|
||||
UploadFile.storage_type == source_storage_type,
|
||||
UploadFile.key.in_(copied_upload_file_keys),
|
||||
)
|
||||
.values(storage_type=dify_config.STORAGE_TYPE)
|
||||
),
|
||||
).rowcount
|
||||
updated = (
|
||||
db.session.query(UploadFile)
|
||||
.where(
|
||||
UploadFile.storage_type == source_storage_type,
|
||||
UploadFile.key.in_(copied_upload_file_keys),
|
||||
)
|
||||
.update({UploadFile.storage_type: dify_config.STORAGE_TYPE}, synchronize_session=False)
|
||||
)
|
||||
db.session.commit()
|
||||
click.echo(click.style(f"Updated storage_type for {updated} upload_files records.", fg="green"))
|
||||
except Exception as e:
|
||||
|
||||
@@ -2,7 +2,6 @@ import logging
|
||||
|
||||
import click
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import delete, select, update
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
@@ -42,7 +41,7 @@ def reset_encrypt_key_pair():
|
||||
click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red"))
|
||||
return
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
tenants = session.scalars(select(Tenant)).all()
|
||||
tenants = session.query(Tenant).all()
|
||||
for tenant in tenants:
|
||||
if not tenant:
|
||||
click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
|
||||
@@ -50,8 +49,8 @@ def reset_encrypt_key_pair():
|
||||
|
||||
tenant.encrypt_public_key = generate_key_pair(tenant.id)
|
||||
|
||||
session.execute(delete(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id))
|
||||
session.execute(delete(ProviderModel).where(ProviderModel.tenant_id == tenant.id))
|
||||
session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
|
||||
session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
@@ -94,7 +93,7 @@ def convert_to_agent_apps():
|
||||
app_id = str(i.id)
|
||||
if app_id not in proceeded_app_ids:
|
||||
proceeded_app_ids.append(app_id)
|
||||
app = db.session.scalar(select(App).where(App.id == app_id))
|
||||
app = db.session.query(App).where(App.id == app_id).first()
|
||||
if app is not None:
|
||||
apps.append(app)
|
||||
|
||||
@@ -109,8 +108,8 @@ def convert_to_agent_apps():
|
||||
db.session.commit()
|
||||
|
||||
# update conversation mode to agent
|
||||
db.session.execute(
|
||||
update(Conversation).where(Conversation.app_id == app.id).values(mode=AppMode.AGENT_CHAT)
|
||||
db.session.query(Conversation).where(Conversation.app_id == app.id).update(
|
||||
{Conversation.mode: AppMode.AGENT_CHAT}
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
@@ -178,7 +177,7 @@ where sites.id is null limit 1000"""
|
||||
continue
|
||||
|
||||
try:
|
||||
app = db.session.scalar(select(App).where(App.id == app_id))
|
||||
app = db.session.query(App).where(App.id == app_id).first()
|
||||
if not app:
|
||||
logger.info("App %s not found", app_id)
|
||||
continue
|
||||
|
||||
@@ -10,12 +10,10 @@ from configs import dify_config
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||
from core.rag.models.document import ChildDocument, Document
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.enums import DatasetMetadataType, IndexingStatus, SegmentStatus
|
||||
from models.model import App, AppAnnotationSetting, MessageAnnotation
|
||||
|
||||
|
||||
@@ -42,13 +40,14 @@ def migrate_annotation_vector_database():
|
||||
# get apps info
|
||||
per_page = 50
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
apps = session.scalars(
|
||||
select(App)
|
||||
apps = (
|
||||
session.query(App)
|
||||
.where(App.status == "normal")
|
||||
.order_by(App.created_at.desc())
|
||||
.limit(per_page)
|
||||
.offset((page - 1) * per_page)
|
||||
).all()
|
||||
.all()
|
||||
)
|
||||
if not apps:
|
||||
break
|
||||
except SQLAlchemyError:
|
||||
@@ -63,8 +62,8 @@ def migrate_annotation_vector_database():
|
||||
try:
|
||||
click.echo(f"Creating app annotation index: {app.id}")
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
app_annotation_setting = session.scalar(
|
||||
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).limit(1)
|
||||
app_annotation_setting = (
|
||||
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
|
||||
)
|
||||
|
||||
if not app_annotation_setting:
|
||||
@@ -72,10 +71,10 @@ def migrate_annotation_vector_database():
|
||||
click.echo(f"App annotation setting disabled: {app.id}")
|
||||
continue
|
||||
# get dataset_collection_binding info
|
||||
dataset_collection_binding = session.scalar(
|
||||
select(DatasetCollectionBinding).where(
|
||||
DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id
|
||||
)
|
||||
dataset_collection_binding = (
|
||||
session.query(DatasetCollectionBinding)
|
||||
.where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
|
||||
.first()
|
||||
)
|
||||
if not dataset_collection_binding:
|
||||
click.echo(f"App annotation collection binding not found: {app.id}")
|
||||
@@ -86,7 +85,7 @@ def migrate_annotation_vector_database():
|
||||
dataset = Dataset(
|
||||
id=app.id,
|
||||
tenant_id=app.tenant_id,
|
||||
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
|
||||
indexing_technique="high_quality",
|
||||
embedding_model_provider=dataset_collection_binding.provider_name,
|
||||
embedding_model=dataset_collection_binding.model_name,
|
||||
collection_binding_id=dataset_collection_binding.id,
|
||||
@@ -161,7 +160,6 @@ def migrate_knowledge_vector_database():
|
||||
}
|
||||
lower_collection_vector_types = {
|
||||
VectorType.ANALYTICDB,
|
||||
VectorType.HOLOGRES,
|
||||
VectorType.CHROMA,
|
||||
VectorType.MYSCALE,
|
||||
VectorType.PGVECTO_RS,
|
||||
@@ -178,9 +176,7 @@ def migrate_knowledge_vector_database():
|
||||
while True:
|
||||
try:
|
||||
stmt = (
|
||||
select(Dataset)
|
||||
.where(Dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY)
|
||||
.order_by(Dataset.created_at.desc())
|
||||
select(Dataset).where(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc())
|
||||
)
|
||||
|
||||
datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
|
||||
@@ -207,11 +203,11 @@ def migrate_knowledge_vector_database():
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
elif vector_type == VectorType.QDRANT:
|
||||
if dataset.collection_binding_id:
|
||||
dataset_collection_binding = db.session.execute(
|
||||
select(DatasetCollectionBinding).where(
|
||||
DatasetCollectionBinding.id == dataset.collection_binding_id
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
dataset_collection_binding = (
|
||||
db.session.query(DatasetCollectionBinding)
|
||||
.where(DatasetCollectionBinding.id == dataset.collection_binding_id)
|
||||
.one_or_none()
|
||||
)
|
||||
if dataset_collection_binding:
|
||||
collection_name = dataset_collection_binding.collection_name
|
||||
else:
|
||||
@@ -245,7 +241,7 @@ def migrate_knowledge_vector_database():
|
||||
dataset_documents = db.session.scalars(
|
||||
select(DatasetDocument).where(
|
||||
DatasetDocument.dataset_id == dataset.id,
|
||||
DatasetDocument.indexing_status == IndexingStatus.COMPLETED,
|
||||
DatasetDocument.indexing_status == "completed",
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
@@ -257,7 +253,7 @@ def migrate_knowledge_vector_database():
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.status == SegmentStatus.COMPLETED,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
).all()
|
||||
@@ -272,7 +268,7 @@ def migrate_knowledge_vector_database():
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
if dataset_document.doc_form == "hierarchical_model":
|
||||
child_chunks = segment.get_child_chunks()
|
||||
if child_chunks:
|
||||
child_documents = []
|
||||
@@ -336,7 +332,7 @@ def add_qdrant_index(field: str):
|
||||
create_count = 0
|
||||
|
||||
try:
|
||||
bindings = db.session.scalars(select(DatasetCollectionBinding)).all()
|
||||
bindings = db.session.query(DatasetCollectionBinding).all()
|
||||
if not bindings:
|
||||
click.echo(click.style("No dataset collection bindings found.", fg="red"))
|
||||
return
|
||||
@@ -423,22 +419,22 @@ def old_metadata_migration():
|
||||
if field.value == key:
|
||||
break
|
||||
else:
|
||||
dataset_metadata = db.session.scalar(
|
||||
select(DatasetMetadata)
|
||||
dataset_metadata = (
|
||||
db.session.query(DatasetMetadata)
|
||||
.where(DatasetMetadata.dataset_id == document.dataset_id, DatasetMetadata.name == key)
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
if not dataset_metadata:
|
||||
dataset_metadata = DatasetMetadata(
|
||||
tenant_id=document.tenant_id,
|
||||
dataset_id=document.dataset_id,
|
||||
name=key,
|
||||
type=DatasetMetadataType.STRING,
|
||||
type="string",
|
||||
created_by=document.created_by,
|
||||
)
|
||||
db.session.add(dataset_metadata)
|
||||
db.session.flush()
|
||||
dataset_metadata_binding: DatasetMetadataBinding | None = DatasetMetadataBinding(
|
||||
dataset_metadata_binding = DatasetMetadataBinding(
|
||||
tenant_id=document.tenant_id,
|
||||
dataset_id=document.dataset_id,
|
||||
metadata_id=dataset_metadata.id,
|
||||
@@ -447,14 +443,14 @@ def old_metadata_migration():
|
||||
)
|
||||
db.session.add(dataset_metadata_binding)
|
||||
else:
|
||||
dataset_metadata_binding = db.session.scalar(
|
||||
select(DatasetMetadataBinding)
|
||||
dataset_metadata_binding = (
|
||||
db.session.query(DatasetMetadataBinding) # type: ignore
|
||||
.where(
|
||||
DatasetMetadataBinding.dataset_id == document.dataset_id,
|
||||
DatasetMetadataBinding.document_id == document.id,
|
||||
DatasetMetadataBinding.metadata_id == dataset_metadata.id,
|
||||
)
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
if not dataset_metadata_binding:
|
||||
dataset_metadata_binding = DatasetMetadataBinding(
|
||||
|
||||
@@ -2,7 +2,6 @@ import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict, TomlConfigSettingsSource
|
||||
|
||||
@@ -83,17 +82,6 @@ class DifyConfig(
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
SANDBOX_DIFY_CLI_ROOT: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Filesystem directory containing dify CLI binaries named dify-cli-<os>-<arch>. "
|
||||
"Defaults to api/bin when unset."
|
||||
),
|
||||
)
|
||||
DIFY_PORT: int = Field(
|
||||
default=5001,
|
||||
description="Port used by Dify to communicate with the host machine.",
|
||||
)
|
||||
# Before adding any config,
|
||||
# please consider to arrange it in the proper config group of existed or added
|
||||
# for better readability and maintainability.
|
||||
|
||||
@@ -271,17 +271,6 @@ class PluginConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class CliApiConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for CLI API (for dify-cli to call back from external sandbox environments)
|
||||
"""
|
||||
|
||||
CLI_API_URL: str = Field(
|
||||
description="CLI API URL for external sandbox (e.g., e2b) to call back.",
|
||||
default="http://localhost:5001",
|
||||
)
|
||||
|
||||
|
||||
class MarketplaceConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for marketplace
|
||||
@@ -298,27 +287,6 @@ class MarketplaceConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class CreatorsPlatformConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for creators platform
|
||||
"""
|
||||
|
||||
CREATORS_PLATFORM_FEATURES_ENABLED: bool = Field(
|
||||
description="Enable or disable creators platform features",
|
||||
default=True,
|
||||
)
|
||||
|
||||
CREATORS_PLATFORM_API_URL: HttpUrl = Field(
|
||||
description="Creators Platform API URL",
|
||||
default=HttpUrl("https://creators.dify.ai"),
|
||||
)
|
||||
|
||||
CREATORS_PLATFORM_OAUTH_CLIENT_ID: str = Field(
|
||||
description="OAuth client_id for the Creators Platform app registered in Dify",
|
||||
default="",
|
||||
)
|
||||
|
||||
|
||||
class EndpointConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for various application endpoints and URLs
|
||||
@@ -373,15 +341,6 @@ class FileAccessConfig(BaseSettings):
|
||||
default="",
|
||||
)
|
||||
|
||||
FILES_API_URL: str = Field(
|
||||
description="Base URL for storage file ticket API endpoints."
|
||||
" Used by sandbox containers (internal or external like e2b) that need"
|
||||
" an absolute, routable address to upload/download files via the API."
|
||||
" For all-in-one Docker deployments, set to http://localhost."
|
||||
" For public sandbox environments, set to a public domain or IP.",
|
||||
default="",
|
||||
)
|
||||
|
||||
FILES_ACCESS_TIMEOUT: int = Field(
|
||||
description="Expiration time in seconds for file access URLs",
|
||||
default=300,
|
||||
@@ -1315,13 +1274,6 @@ class PositionConfig(BaseSettings):
|
||||
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}
|
||||
|
||||
|
||||
class CollaborationConfig(BaseSettings):
|
||||
ENABLE_COLLABORATION_MODE: bool = Field(
|
||||
description="Whether to enable collaboration mode features across the workspace",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class LoginConfig(BaseSettings):
|
||||
ENABLE_EMAIL_CODE_LOGIN: bool = Field(
|
||||
description="whether to enable email code login",
|
||||
@@ -1423,9 +1375,7 @@ class FeatureConfig(
|
||||
TriggerConfig,
|
||||
AsyncWorkflowConfig,
|
||||
PluginConfig,
|
||||
CliApiConfig,
|
||||
MarketplaceConfig,
|
||||
CreatorsPlatformConfig,
|
||||
DataSetConfig,
|
||||
EndpointConfig,
|
||||
FileAccessConfig,
|
||||
@@ -1449,7 +1399,6 @@ class FeatureConfig(
|
||||
WorkflowConfig,
|
||||
WorkflowNodeExecutionConfig,
|
||||
WorkspaceConfig,
|
||||
CollaborationConfig,
|
||||
LoginConfig,
|
||||
AccountConfig,
|
||||
SwaggerUIConfig,
|
||||
|
||||
@@ -26,7 +26,6 @@ from .vdb.chroma_config import ChromaConfig
|
||||
from .vdb.clickzetta_config import ClickzettaConfig
|
||||
from .vdb.couchbase_config import CouchbaseConfig
|
||||
from .vdb.elasticsearch_config import ElasticsearchConfig
|
||||
from .vdb.hologres_config import HologresConfig
|
||||
from .vdb.huawei_cloud_config import HuaweiCloudConfig
|
||||
from .vdb.iris_config import IrisVectorConfig
|
||||
from .vdb.lindorm_config import LindormConfig
|
||||
@@ -348,7 +347,6 @@ class MiddlewareConfig(
|
||||
AnalyticdbConfig,
|
||||
ChromaConfig,
|
||||
ClickzettaConfig,
|
||||
HologresConfig,
|
||||
HuaweiCloudConfig,
|
||||
IrisVectorConfig,
|
||||
MilvusConfig,
|
||||
|
||||
12
api/configs/middleware/cache/redis_config.py
vendored
12
api/configs/middleware/cache/redis_config.py
vendored
@@ -1,4 +1,4 @@
|
||||
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, field_validator
|
||||
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
@@ -116,13 +116,3 @@ class RedisConfig(BaseSettings):
|
||||
description="Maximum connections in the Redis connection pool (unset for library default)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
@field_validator("REDIS_MAX_CONNECTIONS", mode="before")
|
||||
@classmethod
|
||||
def _empty_string_to_none_for_max_conns(cls, v):
|
||||
"""Allow empty string in env/.env to mean 'unset' (None)."""
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, str) and v.strip() == "":
|
||||
return None
|
||||
return v
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Literal, Protocol, cast
|
||||
from typing import Literal, Protocol
|
||||
from urllib.parse import quote_plus, urlunparse
|
||||
|
||||
from pydantic import AliasChoices, Field
|
||||
@@ -12,13 +12,16 @@ class RedisConfigDefaults(Protocol):
|
||||
REDIS_PASSWORD: str | None
|
||||
REDIS_DB: int
|
||||
REDIS_USE_SSL: bool
|
||||
REDIS_USE_SENTINEL: bool | None
|
||||
REDIS_USE_CLUSTERS: bool
|
||||
|
||||
|
||||
def _redis_defaults(config: object) -> RedisConfigDefaults:
|
||||
return cast(RedisConfigDefaults, config)
|
||||
class RedisConfigDefaultsMixin:
|
||||
def _redis_defaults(self: RedisConfigDefaults) -> RedisConfigDefaults:
|
||||
return self
|
||||
|
||||
|
||||
class RedisPubSubConfig(BaseSettings):
|
||||
class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
|
||||
"""
|
||||
Configuration settings for event transport between API and workers.
|
||||
|
||||
@@ -38,10 +41,10 @@ class RedisPubSubConfig(BaseSettings):
|
||||
)
|
||||
|
||||
PUBSUB_REDIS_USE_CLUSTERS: bool = Field(
|
||||
validation_alias=AliasChoices("EVENT_BUS_REDIS_USE_CLUSTERS", "PUBSUB_REDIS_USE_CLUSTERS"),
|
||||
validation_alias=AliasChoices("EVENT_BUS_REDIS_CLUSTERS", "PUBSUB_REDIS_USE_CLUSTERS"),
|
||||
description=(
|
||||
"Enable Redis Cluster mode for pub/sub or streams transport. Recommended for large deployments. "
|
||||
"Also accepts ENV: EVENT_BUS_REDIS_USE_CLUSTERS."
|
||||
"Also accepts ENV: EVENT_BUS_REDIS_CLUSTERS."
|
||||
),
|
||||
default=False,
|
||||
)
|
||||
@@ -71,7 +74,7 @@ class RedisPubSubConfig(BaseSettings):
|
||||
)
|
||||
|
||||
def _build_default_pubsub_url(self) -> str:
|
||||
defaults = _redis_defaults(self)
|
||||
defaults = self._redis_defaults()
|
||||
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:
|
||||
raise ValueError("PUBSUB_REDIS_URL must be set when default Redis URL cannot be constructed")
|
||||
|
||||
@@ -88,9 +91,11 @@ class RedisPubSubConfig(BaseSettings):
|
||||
if userinfo:
|
||||
userinfo = f"{userinfo}@"
|
||||
|
||||
host = defaults.REDIS_HOST
|
||||
port = defaults.REDIS_PORT
|
||||
db = defaults.REDIS_DB
|
||||
|
||||
netloc = f"{userinfo}{defaults.REDIS_HOST}:{defaults.REDIS_PORT}"
|
||||
netloc = f"{userinfo}{host}:{port}"
|
||||
return urlunparse((scheme, netloc, f"/{db}", "", "", ""))
|
||||
|
||||
@property
|
||||
|
||||
@@ -51,18 +51,3 @@ class BaiduVectorDBConfig(BaseSettings):
|
||||
description="Parser mode for inverted index in Baidu Vector Database (default is COARSE_MODE)",
|
||||
default="COARSE_MODE",
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT: int = Field(
|
||||
description="Auto build row count increment threshold (default is 500)",
|
||||
default=500,
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO: float = Field(
|
||||
description="Auto build row count increment ratio threshold (default is 0.05)",
|
||||
default=0.05,
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS: int = Field(
|
||||
description="Timeout in seconds for rebuilding the index in Baidu Vector Database (default is 3600 seconds)",
|
||||
default=300,
|
||||
)
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class HologresConfig(BaseSettings):
|
||||
"""
|
||||
Configuration settings for Hologres vector database.
|
||||
|
||||
Hologres is compatible with PostgreSQL protocol.
|
||||
access_key_id is used as the PostgreSQL username,
|
||||
and access_key_secret is used as the PostgreSQL password.
|
||||
"""
|
||||
|
||||
HOLOGRES_HOST: str | None = Field(
|
||||
description="Hostname or IP address of the Hologres instance.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOLOGRES_PORT: int = Field(
|
||||
description="Port number for connecting to the Hologres instance.",
|
||||
default=80,
|
||||
)
|
||||
|
||||
HOLOGRES_DATABASE: str | None = Field(
|
||||
description="Name of the Hologres database to connect to.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOLOGRES_ACCESS_KEY_ID: str | None = Field(
|
||||
description="Alibaba Cloud AccessKey ID, also used as the PostgreSQL username.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOLOGRES_ACCESS_KEY_SECRET: str | None = Field(
|
||||
description="Alibaba Cloud AccessKey Secret, also used as the PostgreSQL password.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOLOGRES_SCHEMA: str = Field(
|
||||
description="Schema name in the Hologres database.",
|
||||
default="public",
|
||||
)
|
||||
|
||||
HOLOGRES_TOKENIZER: TokenizerType = Field(
|
||||
description="Tokenizer for full-text search index (e.g., 'jieba', 'ik', 'standard', 'simple').",
|
||||
default="jieba",
|
||||
)
|
||||
|
||||
HOLOGRES_DISTANCE_METHOD: DistanceType = Field(
|
||||
description="Distance method for vector index (e.g., 'Cosine', 'Euclidean', 'InnerProduct').",
|
||||
default="Cosine",
|
||||
)
|
||||
|
||||
HOLOGRES_BASE_QUANTIZATION_TYPE: BaseQuantizationType = Field(
|
||||
description="Base quantization type for vector index (e.g., 'rabitq', 'sq8', 'fp16', 'fp32').",
|
||||
default="rabitq",
|
||||
)
|
||||
|
||||
HOLOGRES_MAX_DEGREE: int = Field(
|
||||
description="Max degree (M) parameter for HNSW vector index.",
|
||||
default=64,
|
||||
)
|
||||
|
||||
HOLOGRES_EF_CONSTRUCTION: int = Field(
|
||||
description="ef_construction parameter for HNSW vector index.",
|
||||
default=400,
|
||||
)
|
||||
@@ -1,27 +0,0 @@
|
||||
from flask import Blueprint
|
||||
from flask_restx import Namespace
|
||||
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
bp = Blueprint("cli_api", __name__, url_prefix="/cli/api")
|
||||
|
||||
api = ExternalApi(
|
||||
bp,
|
||||
version="1.0",
|
||||
title="CLI API",
|
||||
description="APIs for Dify CLI to call back from external sandbox environments (e.g., e2b)",
|
||||
)
|
||||
|
||||
# Create namespace
|
||||
cli_api_ns = Namespace("cli_api", description="CLI API operations", path="/")
|
||||
|
||||
from .dify_cli import cli_api as _plugin
|
||||
|
||||
api.add_namespace(cli_api_ns)
|
||||
|
||||
__all__ = [
|
||||
"_plugin",
|
||||
"api",
|
||||
"bp",
|
||||
"cli_api_ns",
|
||||
]
|
||||
@@ -1,192 +0,0 @@
|
||||
from flask import abort
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
|
||||
from controllers.cli_api import cli_api_ns
|
||||
from controllers.cli_api.dify_cli.wraps import get_cli_user_tenant, plugin_data
|
||||
from controllers.cli_api.wraps import cli_api_only
|
||||
from controllers.console.wraps import setup_required
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
|
||||
from core.plugin.backwards_invocation.base import BaseBackwardsInvocationResponse
|
||||
from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
|
||||
from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation
|
||||
from core.plugin.entities.request import (
|
||||
RequestInvokeApp,
|
||||
RequestInvokeLLM,
|
||||
RequestInvokeTool,
|
||||
RequestRequestUploadFile,
|
||||
)
|
||||
from core.sandbox.bash.dify_cli import DifyCliToolConfig
|
||||
from core.session.cli_api import CliContext
|
||||
from core.skill.entities import ToolInvocationRequest
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from dify_graph.file.helpers import get_signed_file_url_for_plugin
|
||||
from libs.helper import length_prefixed_response
|
||||
from models.account import Account
|
||||
from models.model import EndUser, Tenant
|
||||
|
||||
|
||||
class FetchToolItem(BaseModel):
|
||||
tool_type: str
|
||||
tool_provider: str
|
||||
tool_name: str
|
||||
credential_id: str | None = None
|
||||
|
||||
|
||||
class FetchToolBatchRequest(BaseModel):
|
||||
tools: list[FetchToolItem]
|
||||
|
||||
|
||||
@cli_api_ns.route("/invoke/llm")
|
||||
class CliInvokeLLMApi(Resource):
|
||||
@cli_api_only
|
||||
@get_cli_user_tenant
|
||||
@setup_required
|
||||
@plugin_data(payload_type=RequestInvokeLLM)
|
||||
def post(
|
||||
self,
|
||||
user_model: Account | EndUser,
|
||||
tenant_model: Tenant,
|
||||
payload: RequestInvokeLLM,
|
||||
cli_context: CliContext,
|
||||
):
|
||||
def generator():
|
||||
response = PluginModelBackwardsInvocation.invoke_llm(user_model.id, tenant_model, payload)
|
||||
return PluginModelBackwardsInvocation.convert_to_event_stream(response)
|
||||
|
||||
return length_prefixed_response(0xF, generator())
|
||||
|
||||
|
||||
@cli_api_ns.route("/invoke/tool")
|
||||
class CliInvokeToolApi(Resource):
|
||||
@cli_api_only
|
||||
@get_cli_user_tenant
|
||||
@setup_required
|
||||
@plugin_data(payload_type=RequestInvokeTool)
|
||||
def post(
|
||||
self,
|
||||
user_model: Account | EndUser,
|
||||
tenant_model: Tenant,
|
||||
payload: RequestInvokeTool,
|
||||
cli_context: CliContext,
|
||||
):
|
||||
tool_type = ToolProviderType.value_of(payload.tool_type)
|
||||
|
||||
request = ToolInvocationRequest(
|
||||
tool_type=tool_type,
|
||||
provider=payload.provider,
|
||||
tool_name=payload.tool,
|
||||
credential_id=payload.credential_id,
|
||||
)
|
||||
if cli_context.tool_access and not cli_context.tool_access.is_allowed(request):
|
||||
abort(403, description=f"Access denied for tool: {payload.provider}/{payload.tool}")
|
||||
|
||||
def generator():
|
||||
return PluginToolBackwardsInvocation.convert_to_event_stream(
|
||||
PluginToolBackwardsInvocation.invoke_tool(
|
||||
tenant_id=tenant_model.id,
|
||||
user_id=user_model.id,
|
||||
tool_type=tool_type,
|
||||
provider=payload.provider,
|
||||
tool_name=payload.tool,
|
||||
tool_parameters=payload.tool_parameters,
|
||||
credential_id=payload.credential_id,
|
||||
),
|
||||
)
|
||||
|
||||
return length_prefixed_response(0xF, generator())
|
||||
|
||||
|
||||
@cli_api_ns.route("/invoke/app")
|
||||
class CliInvokeAppApi(Resource):
|
||||
@cli_api_only
|
||||
@get_cli_user_tenant
|
||||
@setup_required
|
||||
@plugin_data(payload_type=RequestInvokeApp)
|
||||
def post(
|
||||
self,
|
||||
user_model: Account | EndUser,
|
||||
tenant_model: Tenant,
|
||||
payload: RequestInvokeApp,
|
||||
cli_context: CliContext,
|
||||
):
|
||||
response = PluginAppBackwardsInvocation.invoke_app(
|
||||
app_id=payload.app_id,
|
||||
user_id=user_model.id,
|
||||
tenant_id=tenant_model.id,
|
||||
conversation_id=payload.conversation_id,
|
||||
query=payload.query,
|
||||
stream=payload.response_mode == "streaming",
|
||||
inputs=payload.inputs,
|
||||
files=payload.files,
|
||||
)
|
||||
|
||||
return length_prefixed_response(0xF, PluginAppBackwardsInvocation.convert_to_event_stream(response))
|
||||
|
||||
|
||||
@cli_api_ns.route("/upload/file/request")
|
||||
class CliUploadFileRequestApi(Resource):
|
||||
@cli_api_only
|
||||
@get_cli_user_tenant
|
||||
@setup_required
|
||||
@plugin_data(payload_type=RequestRequestUploadFile)
|
||||
def post(
|
||||
self,
|
||||
user_model: Account | EndUser,
|
||||
tenant_model: Tenant,
|
||||
payload: RequestRequestUploadFile,
|
||||
cli_context: CliContext,
|
||||
):
|
||||
url = get_signed_file_url_for_plugin(
|
||||
filename=payload.filename,
|
||||
mimetype=payload.mimetype,
|
||||
tenant_id=tenant_model.id,
|
||||
user_id=user_model.id,
|
||||
)
|
||||
return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()
|
||||
|
||||
|
||||
@cli_api_ns.route("/fetch/tools/batch")
|
||||
class CliFetchToolsBatchApi(Resource):
|
||||
@cli_api_only
|
||||
@get_cli_user_tenant
|
||||
@setup_required
|
||||
@plugin_data(payload_type=FetchToolBatchRequest)
|
||||
def post(
|
||||
self,
|
||||
user_model: Account | EndUser,
|
||||
tenant_model: Tenant,
|
||||
payload: FetchToolBatchRequest,
|
||||
cli_context: CliContext,
|
||||
):
|
||||
tools: list[dict] = []
|
||||
|
||||
for item in payload.tools:
|
||||
provider_type = ToolProviderType.value_of(item.tool_type)
|
||||
|
||||
request = ToolInvocationRequest(
|
||||
tool_type=provider_type,
|
||||
provider=item.tool_provider,
|
||||
tool_name=item.tool_name,
|
||||
credential_id=item.credential_id,
|
||||
)
|
||||
if cli_context.tool_access and not cli_context.tool_access.is_allowed(request):
|
||||
abort(403, description=f"Access denied for tool: {item.tool_provider}/{item.tool_name}")
|
||||
|
||||
try:
|
||||
tool_runtime = ToolManager.get_tool_runtime(
|
||||
tenant_id=tenant_model.id,
|
||||
provider_type=provider_type,
|
||||
provider_id=item.tool_provider,
|
||||
tool_name=item.tool_name,
|
||||
invoke_from=InvokeFrom.AGENT,
|
||||
credential_id=item.credential_id,
|
||||
)
|
||||
tool_config = DifyCliToolConfig.create_from_tool(tool_runtime)
|
||||
tools.append(tool_config.model_dump())
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return BaseBackwardsInvocationResponse(data={"tools": tools}).model_dump()
|
||||
@@ -1,137 +0,0 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import current_app, g, request
|
||||
from flask_login import user_logged_in
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.session.cli_api import CliApiSession, CliContext
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_user
|
||||
from models.account import Tenant
|
||||
from models.model import DefaultEndUserSessionID, EndUser
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class TenantUserPayload(BaseModel):
|
||||
tenant_id: str
|
||||
user_id: str
|
||||
|
||||
|
||||
def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
||||
"""
|
||||
Get current user
|
||||
|
||||
NOTE: user_id is not trusted, it could be maliciously set to any value.
|
||||
As a result, it could only be considered as an end user id.
|
||||
"""
|
||||
if not user_id:
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
user_model = None
|
||||
|
||||
if is_anonymous:
|
||||
user_model = (
|
||||
session.query(EndUser)
|
||||
.where(
|
||||
EndUser.session_id == user_id,
|
||||
EndUser.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
else:
|
||||
user_model = (
|
||||
session.query(EndUser)
|
||||
.where(
|
||||
EndUser.id == user_id,
|
||||
EndUser.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not user_model:
|
||||
user_model = EndUser(
|
||||
tenant_id=tenant_id,
|
||||
type="service_api",
|
||||
is_anonymous=is_anonymous,
|
||||
session_id=user_id,
|
||||
)
|
||||
session.add(user_model)
|
||||
session.commit()
|
||||
session.refresh(user_model)
|
||||
|
||||
except Exception:
|
||||
raise ValueError("user not found")
|
||||
|
||||
return user_model
|
||||
|
||||
|
||||
def get_cli_user_tenant(view_func: Callable[P, R]):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
session: CliApiSession | None = getattr(g, "cli_api_session", None)
|
||||
if session is None:
|
||||
raise ValueError("session not found")
|
||||
|
||||
user_id = session.user_id
|
||||
tenant_id = session.tenant_id
|
||||
cli_context = CliContext.model_validate(session.context)
|
||||
|
||||
if not user_id:
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
|
||||
try:
|
||||
tenant_model = (
|
||||
db.session.query(Tenant)
|
||||
.where(
|
||||
Tenant.id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
except Exception:
|
||||
raise ValueError("tenant not found")
|
||||
|
||||
if not tenant_model:
|
||||
raise ValueError("tenant not found")
|
||||
|
||||
kwargs["tenant_model"] = tenant_model
|
||||
kwargs["user_model"] = get_user(tenant_id, user_id)
|
||||
kwargs["cli_context"] = cli_context
|
||||
|
||||
current_app.login_manager._update_request_context_with_user(kwargs["user_model"]) # type: ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
|
||||
def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]):
|
||||
def decorator(view_func: Callable[P, R]):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
try:
|
||||
data = request.get_json()
|
||||
except Exception:
|
||||
raise ValueError("invalid json")
|
||||
|
||||
try:
|
||||
payload = payload_type.model_validate(data)
|
||||
except Exception as e:
|
||||
raise ValueError(f"invalid payload: {str(e)}")
|
||||
|
||||
kwargs["payload"] = payload
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
if view is None:
|
||||
return decorator
|
||||
else:
|
||||
return decorator(view)
|
||||
@@ -1,56 +0,0 @@
|
||||
import hashlib
|
||||
import hmac
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import abort, g, request
|
||||
|
||||
from core.session.cli_api import CliApiSessionManager
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
SIGNATURE_TTL_SECONDS = 300
|
||||
|
||||
|
||||
def _verify_signature(session_secret: str, timestamp: str, body: bytes, signature: str) -> bool:
|
||||
expected = hmac.new(
|
||||
session_secret.encode(),
|
||||
f"{timestamp}.".encode() + body,
|
||||
hashlib.sha256,
|
||||
).hexdigest()
|
||||
return hmac.compare_digest(f"sha256={expected}", signature)
|
||||
|
||||
|
||||
def cli_api_only(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
session_id = request.headers.get("X-Cli-Api-Session-Id")
|
||||
timestamp = request.headers.get("X-Cli-Api-Timestamp")
|
||||
signature = request.headers.get("X-Cli-Api-Signature")
|
||||
|
||||
if not session_id or not timestamp or not signature:
|
||||
abort(401)
|
||||
|
||||
try:
|
||||
ts = int(timestamp)
|
||||
if abs(time.time() - ts) > SIGNATURE_TTL_SECONDS:
|
||||
abort(401)
|
||||
except ValueError:
|
||||
abort(401)
|
||||
|
||||
session = CliApiSessionManager().get(session_id)
|
||||
if not session:
|
||||
abort(401)
|
||||
|
||||
body = request.get_data()
|
||||
if not _verify_signature(session.secret, timestamp, body, signature):
|
||||
abort(401)
|
||||
|
||||
g.cli_api_session = session
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
@@ -32,7 +32,6 @@ for module_name in RESOURCE_MODULES:
|
||||
|
||||
# Ensure resource modules are imported so route decorators are evaluated.
|
||||
# Import other controllers
|
||||
# Sandbox file browser
|
||||
from . import (
|
||||
admin,
|
||||
apikey,
|
||||
@@ -42,7 +41,6 @@ from . import (
|
||||
init_validate,
|
||||
notification,
|
||||
ping,
|
||||
sandbox_files,
|
||||
setup,
|
||||
spec,
|
||||
version,
|
||||
@@ -54,7 +52,6 @@ from .app import (
|
||||
agent,
|
||||
annotation,
|
||||
app,
|
||||
app_asset,
|
||||
audio,
|
||||
completion,
|
||||
conversation,
|
||||
@@ -65,11 +62,9 @@ from .app import (
|
||||
model_config,
|
||||
ops_trace,
|
||||
site,
|
||||
skills,
|
||||
statistic,
|
||||
workflow,
|
||||
workflow_app_log,
|
||||
workflow_comment,
|
||||
workflow_draft_variable,
|
||||
workflow_run,
|
||||
workflow_statistic,
|
||||
@@ -121,7 +116,6 @@ from .explore import (
|
||||
saved_message,
|
||||
trial,
|
||||
)
|
||||
from .socketio import workflow as socketio_workflow # pyright: ignore[reportUnusedImport]
|
||||
|
||||
# Import tag controllers
|
||||
from .tag import tags
|
||||
@@ -136,7 +130,6 @@ from .workspace import (
|
||||
model_providers,
|
||||
models,
|
||||
plugin,
|
||||
sandbox_providers,
|
||||
tool_providers,
|
||||
trigger_providers,
|
||||
workspace,
|
||||
@@ -155,7 +148,6 @@ __all__ = [
|
||||
"api",
|
||||
"apikey",
|
||||
"app",
|
||||
"app_asset",
|
||||
"audio",
|
||||
"banner",
|
||||
"billing",
|
||||
@@ -206,12 +198,9 @@ __all__ = [
|
||||
"rag_pipeline_import",
|
||||
"rag_pipeline_workflow",
|
||||
"recommended_app",
|
||||
"sandbox_files",
|
||||
"sandbox_providers",
|
||||
"saved_message",
|
||||
"setup",
|
||||
"site",
|
||||
"skills",
|
||||
"spec",
|
||||
"statistic",
|
||||
"tags",
|
||||
@@ -222,7 +211,6 @@ __all__ = [
|
||||
"website",
|
||||
"workflow",
|
||||
"workflow_app_log",
|
||||
"workflow_comment",
|
||||
"workflow_draft_variable",
|
||||
"workflow_run",
|
||||
"workflow_statistic",
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import flask_restx
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from flask_restx._http import HTTPStatus
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
@@ -9,7 +9,6 @@ from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import Dataset
|
||||
from models.enums import ApiTokenType
|
||||
from models.model import ApiToken, App
|
||||
from services.api_token_service import ApiTokenCache
|
||||
|
||||
@@ -34,10 +33,16 @@ api_key_list_model = console_ns.model(
|
||||
|
||||
|
||||
def _get_resource(resource_id, tenant_id, resource_model):
|
||||
with Session(db.engine) as session:
|
||||
resource = session.execute(
|
||||
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
|
||||
).scalar_one_or_none()
|
||||
if resource_model == App:
|
||||
with Session(db.engine) as session:
|
||||
resource = session.execute(
|
||||
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
|
||||
).scalar_one_or_none()
|
||||
else:
|
||||
with Session(db.engine) as session:
|
||||
resource = session.execute(
|
||||
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if resource is None:
|
||||
flask_restx.abort(HTTPStatus.NOT_FOUND, message=f"{resource_model.__name__} not found.")
|
||||
@@ -48,7 +53,7 @@ def _get_resource(resource_id, tenant_id, resource_model):
|
||||
class BaseApiKeyListResource(Resource):
|
||||
method_decorators = [account_initialization_required, login_required, setup_required]
|
||||
|
||||
resource_type: ApiTokenType | None = None
|
||||
resource_type: str | None = None
|
||||
resource_model: type | None = None
|
||||
resource_id_field: str | None = None
|
||||
token_prefix: str | None = None
|
||||
@@ -75,13 +80,10 @@ class BaseApiKeyListResource(Resource):
|
||||
resource_id = str(resource_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
current_key_count: int = (
|
||||
db.session.scalar(
|
||||
select(func.count(ApiToken.id)).where(
|
||||
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
|
||||
)
|
||||
)
|
||||
or 0
|
||||
current_key_count = (
|
||||
db.session.query(ApiToken)
|
||||
.where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
|
||||
.count()
|
||||
)
|
||||
|
||||
if current_key_count >= self.max_keys:
|
||||
@@ -92,7 +94,6 @@ class BaseApiKeyListResource(Resource):
|
||||
)
|
||||
|
||||
key = ApiToken.generate_api_key(self.token_prefix or "", 24)
|
||||
assert self.resource_type is not None, "resource_type must be set"
|
||||
api_token = ApiToken()
|
||||
setattr(api_token, self.resource_id_field, resource_id)
|
||||
api_token.tenant_id = current_tenant_id
|
||||
@@ -106,7 +107,7 @@ class BaseApiKeyListResource(Resource):
|
||||
class BaseApiKeyResource(Resource):
|
||||
method_decorators = [account_initialization_required, login_required, setup_required]
|
||||
|
||||
resource_type: ApiTokenType | None = None
|
||||
resource_type: str | None = None
|
||||
resource_model: type | None = None
|
||||
resource_id_field: str | None = None
|
||||
|
||||
@@ -118,14 +119,14 @@ class BaseApiKeyResource(Resource):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
key = db.session.scalar(
|
||||
select(ApiToken)
|
||||
key = (
|
||||
db.session.query(ApiToken)
|
||||
.where(
|
||||
getattr(ApiToken, self.resource_id_field) == resource_id,
|
||||
ApiToken.type == self.resource_type,
|
||||
ApiToken.id == api_key_id,
|
||||
)
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
|
||||
if key is None:
|
||||
@@ -136,7 +137,7 @@ class BaseApiKeyResource(Resource):
|
||||
assert key is not None # nosec - for type checker only
|
||||
ApiTokenCache.delete(key.token, key.type)
|
||||
|
||||
db.session.execute(delete(ApiToken).where(ApiToken.id == api_key_id))
|
||||
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
@@ -161,7 +162,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
|
||||
"""Create a new API key for an app"""
|
||||
return super().post(resource_id)
|
||||
|
||||
resource_type = ApiTokenType.APP
|
||||
resource_type = "app"
|
||||
resource_model = App
|
||||
resource_id_field = "app_id"
|
||||
token_prefix = "app-"
|
||||
@@ -177,7 +178,7 @@ class AppApiKeyResource(BaseApiKeyResource):
|
||||
"""Delete an API key for an app"""
|
||||
return super().delete(resource_id, api_key_id)
|
||||
|
||||
resource_type = ApiTokenType.APP
|
||||
resource_type = "app"
|
||||
resource_model = App
|
||||
resource_id_field = "app_id"
|
||||
|
||||
@@ -201,7 +202,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
|
||||
"""Create a new API key for a dataset"""
|
||||
return super().post(resource_id)
|
||||
|
||||
resource_type = ApiTokenType.DATASET
|
||||
resource_type = "dataset"
|
||||
resource_model = Dataset
|
||||
resource_id_field = "dataset_id"
|
||||
token_prefix = "ds-"
|
||||
@@ -217,6 +218,6 @@ class DatasetApiKeyResource(BaseApiKeyResource):
|
||||
"""Delete an API key for a dataset"""
|
||||
return super().delete(resource_id, api_key_id)
|
||||
|
||||
resource_type = ApiTokenType.DATASET
|
||||
resource_type = "dataset"
|
||||
resource_model = Dataset
|
||||
resource_id_field = "dataset_id"
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any, Literal, TypeAlias
|
||||
|
||||
from flask import request
|
||||
@@ -26,14 +25,12 @@ from controllers.console.wraps import (
|
||||
)
|
||||
from core.ops.ops_trace_manager import OpsTraceManager
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.trigger.constants import TRIGGER_NODE_TYPES
|
||||
from dify_graph.enums import WorkflowExecutionStatus
|
||||
from dify_graph.enums import NodeType, WorkflowExecutionStatus
|
||||
from dify_graph.file import helpers as file_helpers
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, DatasetPermissionEnum, Workflow
|
||||
from models.model import AppMode, IconType
|
||||
from models.workflow_features import WorkflowFeatures
|
||||
from models.model import IconType
|
||||
from services.app_dsl_service import AppDslService, ImportMode
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
@@ -61,11 +58,6 @@ register_enum_models(console_ns, IconType)
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RuntimeType(StrEnum):
|
||||
CLASSIC = "classic"
|
||||
SANDBOXED = "sandboxed"
|
||||
|
||||
|
||||
class AppListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
|
||||
@@ -102,7 +94,7 @@ class CreateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
|
||||
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
|
||||
icon_type: IconType | None = Field(default=None, description="Icon type")
|
||||
icon_type: str | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
|
||||
@@ -110,7 +102,7 @@ class CreateAppPayload(BaseModel):
|
||||
class UpdateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
|
||||
icon_type: IconType | None = Field(default=None, description="Icon type")
|
||||
icon_type: str | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
|
||||
@@ -120,7 +112,7 @@ class UpdateAppPayload(BaseModel):
|
||||
class CopyAppPayload(BaseModel):
|
||||
name: str | None = Field(default=None, description="Name for the copied app")
|
||||
description: str | None = Field(default=None, description="Description for the copied app", max_length=400)
|
||||
icon_type: IconType | None = Field(default=None, description="Icon type")
|
||||
icon_type: str | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
|
||||
@@ -130,11 +122,6 @@ class AppExportQuery(BaseModel):
|
||||
workflow_id: str | None = Field(default=None, description="Specific workflow ID to export")
|
||||
|
||||
|
||||
class AppExportBundleQuery(BaseModel):
|
||||
include_secret: bool = Field(default=False, description="Include secrets in export")
|
||||
workflow_id: str | None = Field(default=None, description="Specific workflow ID to export")
|
||||
|
||||
|
||||
class AppNamePayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="Name to check")
|
||||
|
||||
@@ -360,7 +347,6 @@ class AppPartial(ResponseModel):
|
||||
create_user_name: str | None = None
|
||||
author_name: str | None = None
|
||||
has_draft_trigger: bool | None = None
|
||||
runtime_type: RuntimeType = RuntimeType.CLASSIC
|
||||
|
||||
@computed_field(return_type=str | None) # type: ignore
|
||||
@property
|
||||
@@ -510,7 +496,6 @@ class AppListApi(Resource):
|
||||
str(app.id) for app in app_pagination.items if app.mode in {"workflow", "advanced-chat"}
|
||||
]
|
||||
draft_trigger_app_ids: set[str] = set()
|
||||
sandbox_app_ids: set[str] = set()
|
||||
if workflow_capable_app_ids:
|
||||
draft_workflows = (
|
||||
db.session.execute(
|
||||
@@ -523,12 +508,12 @@ class AppListApi(Resource):
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
trigger_node_types = TRIGGER_NODE_TYPES
|
||||
trigger_node_types = {
|
||||
NodeType.TRIGGER_WEBHOOK,
|
||||
NodeType.TRIGGER_SCHEDULE,
|
||||
NodeType.TRIGGER_PLUGIN,
|
||||
}
|
||||
for workflow in draft_workflows:
|
||||
# Check sandbox feature
|
||||
if workflow.get_feature(WorkflowFeatures.SANDBOX).enabled:
|
||||
sandbox_app_ids.add(str(workflow.app_id))
|
||||
|
||||
node_id = None
|
||||
try:
|
||||
for node_id, node_data in workflow.walk_nodes():
|
||||
@@ -541,7 +526,6 @@ class AppListApi(Resource):
|
||||
|
||||
for app in app_pagination.items:
|
||||
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
|
||||
app.runtime_type = RuntimeType.SANDBOXED if str(app.id) in sandbox_app_ids else RuntimeType.CLASSIC
|
||||
|
||||
pagination_model = AppPagination.model_validate(app_pagination, from_attributes=True)
|
||||
return pagination_model.model_dump(mode="json"), 200
|
||||
@@ -613,7 +597,7 @@ class AppApi(Resource):
|
||||
args_dict: AppService.ArgsDict = {
|
||||
"name": args.name,
|
||||
"description": args.description or "",
|
||||
"icon_type": args.icon_type,
|
||||
"icon_type": args.icon_type or "",
|
||||
"icon": args.icon or "",
|
||||
"icon_background": args.icon_background or "",
|
||||
"use_icon_as_answer_icon": args.use_icon_as_answer_icon or False,
|
||||
@@ -723,89 +707,6 @@ class AppExportApi(Resource):
|
||||
return payload.model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/export-bundle")
|
||||
class AppExportBundleApi(Resource):
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
from services.app_bundle_service import AppBundleService
|
||||
|
||||
args = AppExportBundleQuery.model_validate(request.args.to_dict(flat=True))
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
result = AppBundleService.export_bundle(
|
||||
app_model=app_model,
|
||||
account_id=str(current_user.id),
|
||||
include_secret=args.include_secret,
|
||||
workflow_id=args.workflow_id,
|
||||
)
|
||||
|
||||
return result.model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/publish-to-creators-platform")
|
||||
class AppPublishToCreatorsPlatformApi(Resource):
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
"""Export the app DSL and upload to Creators Platform, returning a redirect URL.
|
||||
|
||||
Classic apps export as YAML; sandboxed apps export as ZIP bundle.
|
||||
"""
|
||||
import httpx
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.creators import get_redirect_url, upload_dsl
|
||||
from services.app_bundle_service import AppBundleService
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
if not dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
|
||||
return {"message": "Creators Platform is not enabled"}, 403
|
||||
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
# Determine if the app is sandboxed by checking the draft workflow's sandbox feature
|
||||
is_sandboxed = False
|
||||
if app_model.mode in {"workflow", "advanced-chat"}:
|
||||
draft_workflow = WorkflowService().get_draft_workflow(app_model)
|
||||
if draft_workflow and draft_workflow.get_feature(WorkflowFeatures.SANDBOX).enabled:
|
||||
is_sandboxed = True
|
||||
|
||||
if is_sandboxed:
|
||||
# Sandboxed app: export as ZIP bundle
|
||||
bundle_result = AppBundleService.export_bundle(
|
||||
app_model=app_model,
|
||||
account_id=str(current_user.id),
|
||||
include_secret=False,
|
||||
)
|
||||
download_response = httpx.get(bundle_result.download_url, timeout=60, follow_redirects=True)
|
||||
download_response.raise_for_status()
|
||||
file_bytes = download_response.content
|
||||
filename = bundle_result.filename
|
||||
else:
|
||||
# Classic app: export as YAML
|
||||
dsl_content = AppDslService.export_dsl(
|
||||
app_model=app_model,
|
||||
include_secret=False,
|
||||
)
|
||||
file_bytes = dsl_content.encode("utf-8")
|
||||
filename = f"{app_model.name}.yml"
|
||||
|
||||
# Upload to Creators Platform
|
||||
claim_code = upload_dsl(file_bytes, filename=filename)
|
||||
|
||||
# Generate redirect URL (with optional OAuth code)
|
||||
redirect_url = get_redirect_url(str(current_user.id), claim_code)
|
||||
|
||||
return {"redirect_url": redirect_url}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/name")
|
||||
class AppNameApi(Resource):
|
||||
@console_ns.doc("check_app_name")
|
||||
@@ -929,26 +830,3 @@ class AppTraceApi(Resource):
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/upgrade-runtime")
|
||||
class AppRuntimeUpgradeApi(Resource):
|
||||
@console_ns.doc("upgrade_app_runtime")
|
||||
@console_ns.doc(description="Clone the app and upgrade the clone to sandboxed runtime")
|
||||
@console_ns.doc(params={"app_id": "Application ID to upgrade"})
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
"""Upgrade app runtime by cloning to sandboxed mode"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
from services.app_runtime_upgrade_service import AppRuntimeUpgradeService
|
||||
|
||||
result = AppRuntimeUpgradeService(session).upgrade(app_model, current_user)
|
||||
session.commit()
|
||||
|
||||
return result, 200
|
||||
|
||||
@@ -1,333 +0,0 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
AppAssetNodeNotFoundError,
|
||||
AppAssetPathConflictError,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.app.entities.app_asset_entities import BatchUploadNode
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
from services.app_asset_service import AppAssetService
|
||||
from services.errors.app_asset import (
|
||||
AppAssetNodeNotFoundError as ServiceNodeNotFoundError,
|
||||
)
|
||||
from services.errors.app_asset import (
|
||||
AppAssetParentNotFoundError,
|
||||
)
|
||||
from services.errors.app_asset import (
|
||||
AppAssetPathConflictError as ServicePathConflictError,
|
||||
)
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class CreateFolderPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
parent_id: str | None = None
|
||||
|
||||
|
||||
class CreateFilePayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
parent_id: str | None = None
|
||||
|
||||
@field_validator("name", mode="before")
|
||||
@classmethod
|
||||
def strip_name(cls, v: str) -> str:
|
||||
return v.strip() if isinstance(v, str) else v
|
||||
|
||||
@field_validator("parent_id", mode="before")
|
||||
@classmethod
|
||||
def empty_to_none(cls, v: str | None) -> str | None:
|
||||
return v or None
|
||||
|
||||
|
||||
class GetUploadUrlPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
size: int = Field(..., ge=0)
|
||||
parent_id: str | None = None
|
||||
|
||||
@field_validator("name", mode="before")
|
||||
@classmethod
|
||||
def strip_name(cls, v: str) -> str:
|
||||
return v.strip() if isinstance(v, str) else v
|
||||
|
||||
@field_validator("parent_id", mode="before")
|
||||
@classmethod
|
||||
def empty_to_none(cls, v: str | None) -> str | None:
|
||||
return v or None
|
||||
|
||||
|
||||
class BatchUploadPayload(BaseModel):
|
||||
children: list[BatchUploadNode] = Field(..., min_length=1)
|
||||
parent_id: str | None = None
|
||||
|
||||
@field_validator("parent_id", mode="before")
|
||||
@classmethod
|
||||
def empty_to_none(cls, v: str | None) -> str | None:
|
||||
return v or None
|
||||
|
||||
|
||||
class UpdateFileContentPayload(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class RenameNodePayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
|
||||
|
||||
class MoveNodePayload(BaseModel):
|
||||
parent_id: str | None = None
|
||||
|
||||
|
||||
class ReorderNodePayload(BaseModel):
|
||||
after_node_id: str | None = Field(default=None, description="Place after this node, None for first position")
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]) -> None:
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(CreateFolderPayload)
|
||||
reg(CreateFilePayload)
|
||||
reg(GetUploadUrlPayload)
|
||||
reg(BatchUploadNode)
|
||||
reg(BatchUploadPayload)
|
||||
reg(UpdateFileContentPayload)
|
||||
reg(RenameNodePayload)
|
||||
reg(MoveNodePayload)
|
||||
reg(ReorderNodePayload)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/tree")
|
||||
class AppAssetTreeResource(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
tree = AppAssetService.get_asset_tree(app_model, current_user.id)
|
||||
return {"children": [view.model_dump() for view in tree.transform()]}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/folders")
|
||||
class AppAssetFolderResource(Resource):
|
||||
@console_ns.expect(console_ns.models[CreateFolderPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = CreateFolderPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
node = AppAssetService.create_folder(app_model, current_user.id, payload.name, payload.parent_id)
|
||||
return node.model_dump(), 201
|
||||
except AppAssetParentNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
except ServicePathConflictError:
|
||||
raise AppAssetPathConflictError()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/files/<string:node_id>")
|
||||
class AppAssetFileDetailResource(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, node_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
try:
|
||||
content = AppAssetService.get_file_content(app_model, current_user.id, node_id)
|
||||
return {"content": content.decode("utf-8", errors="replace")}
|
||||
except ServiceNodeNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
|
||||
@console_ns.expect(console_ns.models[UpdateFileContentPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def put(self, app_model: App, node_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
file = request.files.get("file")
|
||||
if file:
|
||||
content = file.read()
|
||||
else:
|
||||
payload = UpdateFileContentPayload.model_validate(console_ns.payload or {})
|
||||
content = payload.content.encode("utf-8")
|
||||
|
||||
try:
|
||||
node = AppAssetService.update_file_content(app_model, current_user.id, node_id, content)
|
||||
return node.model_dump()
|
||||
except ServiceNodeNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/nodes/<string:node_id>")
|
||||
class AppAssetNodeResource(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def delete(self, app_model: App, node_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
try:
|
||||
AppAssetService.delete_node(app_model, current_user.id, node_id)
|
||||
return {"result": "success"}, 200
|
||||
except ServiceNodeNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/nodes/<string:node_id>/rename")
|
||||
class AppAssetNodeRenameResource(Resource):
|
||||
@console_ns.expect(console_ns.models[RenameNodePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App, node_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = RenameNodePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
node = AppAssetService.rename_node(app_model, current_user.id, node_id, payload.name)
|
||||
return node.model_dump()
|
||||
except ServiceNodeNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
except ServicePathConflictError:
|
||||
raise AppAssetPathConflictError()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/nodes/<string:node_id>/move")
|
||||
class AppAssetNodeMoveResource(Resource):
|
||||
@console_ns.expect(console_ns.models[MoveNodePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App, node_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = MoveNodePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
node = AppAssetService.move_node(app_model, current_user.id, node_id, payload.parent_id)
|
||||
return node.model_dump()
|
||||
except ServiceNodeNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
except AppAssetParentNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
except ServicePathConflictError:
|
||||
raise AppAssetPathConflictError()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/nodes/<string:node_id>/reorder")
|
||||
class AppAssetNodeReorderResource(Resource):
|
||||
@console_ns.expect(console_ns.models[ReorderNodePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App, node_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = ReorderNodePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
node = AppAssetService.reorder_node(app_model, current_user.id, node_id, payload.after_node_id)
|
||||
return node.model_dump()
|
||||
except ServiceNodeNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/files/<string:node_id>/download-url")
|
||||
class AppAssetFileDownloadUrlResource(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, node_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
try:
|
||||
download_url = AppAssetService.get_file_download_url(app_model, current_user.id, node_id)
|
||||
return {"download_url": download_url}
|
||||
except ServiceNodeNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/files/upload")
|
||||
class AppAssetFileUploadUrlResource(Resource):
|
||||
@console_ns.expect(console_ns.models[GetUploadUrlPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = GetUploadUrlPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
node, upload_url = AppAssetService.get_file_upload_url(
|
||||
app_model, current_user.id, payload.name, payload.size, payload.parent_id
|
||||
)
|
||||
return {"node": node.model_dump(), "upload_url": upload_url}, 201
|
||||
except AppAssetParentNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
except ServicePathConflictError:
|
||||
raise AppAssetPathConflictError()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/batch-upload")
|
||||
class AppAssetBatchUploadResource(Resource):
|
||||
@console_ns.expect(console_ns.models[BatchUploadPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
"""
|
||||
Create nodes from tree structure and return upload URLs.
|
||||
|
||||
Input:
|
||||
{
|
||||
"parent_id": "optional-target-folder-id",
|
||||
"children": [
|
||||
{"name": "folder1", "node_type": "folder", "children": [
|
||||
{"name": "file1.txt", "node_type": "file", "size": 1024}
|
||||
]},
|
||||
{"name": "root.txt", "node_type": "file", "size": 512}
|
||||
]
|
||||
}
|
||||
|
||||
Output:
|
||||
{
|
||||
"children": [
|
||||
{"id": "xxx", "name": "folder1", "node_type": "folder", "children": [
|
||||
{"id": "yyy", "name": "file1.txt", "node_type": "file", "size": 1024, "upload_url": "..."}
|
||||
]},
|
||||
{"id": "zzz", "name": "root.txt", "node_type": "file", "size": 512, "upload_url": "..."}
|
||||
]
|
||||
}
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = BatchUploadPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
result_children = AppAssetService.batch_create_from_tree(
|
||||
app_model,
|
||||
current_user.id,
|
||||
payload.children,
|
||||
parent_id=payload.parent_id,
|
||||
)
|
||||
return {"children": [child.model_dump() for child in result_children]}, 201
|
||||
except AppAssetParentNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
except ServicePathConflictError:
|
||||
raise AppAssetPathConflictError()
|
||||
@@ -51,14 +51,6 @@ class AppImportPayload(BaseModel):
|
||||
app_id: str | None = Field(None)
|
||||
|
||||
|
||||
class AppImportBundleConfirmPayload(BaseModel):
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
icon_type: str | None = None
|
||||
icon: str | None = None
|
||||
icon_background: str | None = None
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
AppImportPayload.__name__, AppImportPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
@@ -147,68 +139,3 @@ class AppImportCheckDependenciesApi(Resource):
|
||||
result = import_service.check_dependencies(app_model=app_model)
|
||||
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/apps/imports-bundle/prepare")
|
||||
class AppImportBundlePrepareApi(Resource):
|
||||
"""Step 1: Get upload URL for bundle import."""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
from services.app_bundle_service import AppBundleService
|
||||
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
result = AppBundleService.prepare_import(
|
||||
tenant_id=current_tenant_id,
|
||||
account_id=current_user.id,
|
||||
)
|
||||
|
||||
return {"import_id": result.import_id, "upload_url": result.upload_url}, 200
|
||||
|
||||
|
||||
@console_ns.route("/apps/imports-bundle/<string:import_id>/confirm")
|
||||
class AppImportBundleConfirmApi(Resource):
|
||||
"""Step 2: Confirm bundle import after upload."""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_import_model)
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
@edit_permission_required
|
||||
def post(self, import_id: str):
|
||||
from flask import request
|
||||
|
||||
from core.app.entities.app_bundle_entities import BundleFormatError
|
||||
from services.app_bundle_service import AppBundleService
|
||||
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = AppImportBundleConfirmPayload.model_validate(request.get_json() or {})
|
||||
|
||||
try:
|
||||
result = AppBundleService.confirm_import(
|
||||
import_id=import_id,
|
||||
account=current_user,
|
||||
name=args.name,
|
||||
description=args.description,
|
||||
icon_type=args.icon_type,
|
||||
icon=args.icon,
|
||||
icon_background=args.icon_background,
|
||||
)
|
||||
except BundleFormatError as e:
|
||||
return {"error": str(e)}, 400
|
||||
|
||||
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
|
||||
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
|
||||
|
||||
status = result.status
|
||||
if status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
elif status == ImportStatus.PENDING:
|
||||
return result.model_dump(mode="json"), 202
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
@@ -5,7 +5,7 @@ from flask import abort, request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import joinedload
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
@@ -376,12 +376,8 @@ class CompletionConversationApi(Resource):
|
||||
|
||||
# FIXME, the type ignore in this file
|
||||
if args.annotation_status == "annotated":
|
||||
query = (
|
||||
query.options(selectinload(Conversation.message_annotations)) # type: ignore[arg-type]
|
||||
.join( # type: ignore
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
)
|
||||
.distinct()
|
||||
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
)
|
||||
elif args.annotation_status == "not_annotated":
|
||||
query = (
|
||||
@@ -458,7 +454,9 @@ class ChatConversationApi(Resource):
|
||||
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
subquery = (
|
||||
sa.select(Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id"))
|
||||
db.session.query(
|
||||
Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id")
|
||||
)
|
||||
.outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id)
|
||||
.subquery()
|
||||
)
|
||||
@@ -513,12 +511,8 @@ class ChatConversationApi(Resource):
|
||||
|
||||
match args.annotation_status:
|
||||
case "annotated":
|
||||
query = (
|
||||
query.options(selectinload(Conversation.message_annotations)) # type: ignore[arg-type]
|
||||
.join( # type: ignore
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
)
|
||||
.distinct()
|
||||
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
)
|
||||
case "not_annotated":
|
||||
query = (
|
||||
@@ -593,8 +587,10 @@ class ChatConversationDetailApi(Resource):
|
||||
|
||||
def _get_conversation(app_model, conversation_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
conversation = db.session.scalar(
|
||||
sa.select(Conversation).where(Conversation.id == conversation_id, Conversation.app_id == app_model.id).limit(1)
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
|
||||
@@ -110,6 +110,8 @@ class TracingConfigCheckError(BaseHTTPException):
|
||||
|
||||
|
||||
class InvokeRateLimitError(BaseHTTPException):
|
||||
"""Raised when the Invoke returns rate limit error."""
|
||||
|
||||
error_code = "rate_limit_error"
|
||||
description = "Rate Limit Error"
|
||||
code = 429
|
||||
@@ -119,21 +121,3 @@ class NeedAddIdsError(BaseHTTPException):
|
||||
error_code = "need_add_ids"
|
||||
description = "Need to add ids."
|
||||
code = 400
|
||||
|
||||
|
||||
class AppAssetNodeNotFoundError(BaseHTTPException):
|
||||
error_code = "app_asset_node_not_found"
|
||||
description = "App asset node not found."
|
||||
code = 404
|
||||
|
||||
|
||||
class AppAssetFileRequiredError(BaseHTTPException):
|
||||
error_code = "app_asset_file_required"
|
||||
description = "File is required."
|
||||
code = 400
|
||||
|
||||
|
||||
class AppAssetPathConflictError(BaseHTTPException):
|
||||
error_code = "app_asset_path_conflict"
|
||||
description = "Path already exists."
|
||||
code = 409
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -17,11 +16,6 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.llm_generator.context_models import (
|
||||
AvailableVarPayload,
|
||||
CodeContextPayload,
|
||||
ParameterInfoPayload,
|
||||
)
|
||||
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeError
|
||||
@@ -47,34 +41,6 @@ class InstructionTemplatePayload(BaseModel):
|
||||
type: str = Field(..., description="Instruction template type")
|
||||
|
||||
|
||||
class ContextGeneratePayload(BaseModel):
|
||||
"""Payload for generating extractor code node."""
|
||||
|
||||
language: str = Field(default="python3", description="Code language (python3/javascript)")
|
||||
prompt_messages: list[dict[str, Any]] = Field(
|
||||
..., description="Multi-turn conversation history, last message is the current instruction"
|
||||
)
|
||||
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
|
||||
available_vars: list[AvailableVarPayload] = Field(..., description="Available variables from upstream nodes")
|
||||
parameter_info: ParameterInfoPayload = Field(..., description="Target parameter metadata from the frontend")
|
||||
code_context: CodeContextPayload = Field(description="Existing code node context for incremental generation")
|
||||
|
||||
|
||||
class SuggestedQuestionsPayload(BaseModel):
|
||||
"""Payload for generating suggested questions."""
|
||||
|
||||
language: str = Field(
|
||||
default="English", description="Language for generated questions (e.g. English, Chinese, Japanese)"
|
||||
)
|
||||
model_config_data: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
alias="model_config",
|
||||
description="Model configuration (optional, uses system default if not provided)",
|
||||
)
|
||||
available_vars: list[AvailableVarPayload] = Field(..., description="Available variables from upstream nodes")
|
||||
parameter_info: ParameterInfoPayload = Field(..., description="Target parameter metadata from the frontend")
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
@@ -84,8 +50,6 @@ reg(RuleCodeGeneratePayload)
|
||||
reg(RuleStructuredOutputPayload)
|
||||
reg(InstructionGeneratePayload)
|
||||
reg(InstructionTemplatePayload)
|
||||
reg(ContextGeneratePayload)
|
||||
reg(SuggestedQuestionsPayload)
|
||||
reg(ModelConfig)
|
||||
|
||||
|
||||
@@ -204,7 +168,7 @@ class InstructionGenerateApi(Resource):
|
||||
try:
|
||||
# Generate from nothing for a workflow node
|
||||
if (args.current in (code_template, "")) and args.node_id != "":
|
||||
app = db.session.get(App, args.flow_id)
|
||||
app = db.session.query(App).where(App.id == args.flow_id).first()
|
||||
if not app:
|
||||
return {"error": f"app {args.flow_id} not found"}, 400
|
||||
workflow = WorkflowService().get_draft_workflow(app_model=app)
|
||||
@@ -299,70 +263,3 @@ class InstructionGenerationTemplateApi(Resource):
|
||||
return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE}
|
||||
case _:
|
||||
raise ValueError(f"Invalid type: {args.type}")
|
||||
|
||||
|
||||
@console_ns.route("/context-generate")
|
||||
class ContextGenerateApi(Resource):
|
||||
@console_ns.doc("generate_with_context")
|
||||
@console_ns.doc(description="Generate with multi-turn conversation context")
|
||||
@console_ns.expect(console_ns.models[ContextGeneratePayload.__name__])
|
||||
@console_ns.response(200, "Content generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters or workflow not found")
|
||||
@console_ns.response(402, "Provider quota exceeded")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
from core.llm_generator.utils import deserialize_prompt_messages
|
||||
|
||||
args = ContextGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
return LLMGenerator.generate_with_context(
|
||||
tenant_id=current_tenant_id,
|
||||
language=args.language,
|
||||
prompt_messages=deserialize_prompt_messages(args.prompt_messages),
|
||||
model_config=args.model_config_data,
|
||||
available_vars=args.available_vars,
|
||||
parameter_info=args.parameter_info,
|
||||
code_context=args.code_context,
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
|
||||
|
||||
@console_ns.route("/context-generate/suggested-questions")
|
||||
class SuggestedQuestionsApi(Resource):
|
||||
@console_ns.doc("generate_suggested_questions")
|
||||
@console_ns.doc(description="Generate suggested questions for context generation")
|
||||
@console_ns.expect(console_ns.models[SuggestedQuestionsPayload.__name__])
|
||||
@console_ns.response(200, "Questions generated successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
args = SuggestedQuestionsPayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
try:
|
||||
return LLMGenerator.generate_suggested_questions(
|
||||
tenant_id=current_tenant_id,
|
||||
language=args.language,
|
||||
available_vars=args.available_vars,
|
||||
parameter_info=args.parameter_info,
|
||||
model_config=args.model_config_data,
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import json
|
||||
from enum import StrEnum
|
||||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
@@ -11,7 +11,6 @@ from controllers.console.wraps import account_initialization_required, edit_perm
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_server_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.enums import AppMCPServerStatus
|
||||
from models.model import AppMCPServer
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
@@ -20,6 +19,11 @@ DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
app_server_model = console_ns.model("AppServer", app_server_fields)
|
||||
|
||||
|
||||
class AppMCPServerStatus(StrEnum):
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
|
||||
|
||||
class MCPServerCreatePayload(BaseModel):
|
||||
description: str | None = Field(default=None, description="Server description")
|
||||
parameters: dict = Field(..., description="Server parameters configuration")
|
||||
@@ -48,7 +52,7 @@ class AppMCPServerController(Resource):
|
||||
@get_app_model
|
||||
@marshal_with(app_server_model)
|
||||
def get(self, app_model):
|
||||
server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.app_id == app_model.id).limit(1))
|
||||
server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first()
|
||||
return server
|
||||
|
||||
@console_ns.doc("create_app_mcp_server")
|
||||
@@ -99,24 +103,23 @@ class AppMCPServerController(Resource):
|
||||
@edit_permission_required
|
||||
def put(self, app_model):
|
||||
payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
|
||||
server = db.session.get(AppMCPServer, payload.id)
|
||||
server = db.session.query(AppMCPServer).where(AppMCPServer.id == payload.id).first()
|
||||
if not server:
|
||||
raise NotFound()
|
||||
|
||||
description = payload.description
|
||||
if description is None or not description:
|
||||
if description is None:
|
||||
pass
|
||||
elif not description:
|
||||
server.description = app_model.description or ""
|
||||
else:
|
||||
server.description = description
|
||||
|
||||
server.name = app_model.name
|
||||
|
||||
server.parameters = json.dumps(payload.parameters, ensure_ascii=False)
|
||||
if payload.status:
|
||||
try:
|
||||
server.status = AppMCPServerStatus(payload.status)
|
||||
except ValueError:
|
||||
if payload.status not in [status.value for status in AppMCPServerStatus]:
|
||||
raise ValueError("Invalid status")
|
||||
server.status = payload.status
|
||||
db.session.commit()
|
||||
return server
|
||||
|
||||
@@ -136,10 +139,11 @@ class AppMCPServerRefreshController(Resource):
|
||||
@edit_permission_required
|
||||
def get(self, server_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
server = db.session.scalar(
|
||||
select(AppMCPServer)
|
||||
.where(AppMCPServer.id == server_id, AppMCPServer.tenant_id == current_tenant_id)
|
||||
.limit(1)
|
||||
server = (
|
||||
db.session.query(AppMCPServer)
|
||||
.where(AppMCPServer.id == server_id)
|
||||
.where(AppMCPServer.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not server:
|
||||
raise NotFound()
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Literal
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import exists, func, select
|
||||
from sqlalchemy import exists, select
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
@@ -30,7 +30,6 @@ from fields.raws import FilesContainedField
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
@@ -213,7 +212,6 @@ message_detail_model = console_ns.model(
|
||||
"status": fields.String,
|
||||
"error": fields.String,
|
||||
"parent_message_id": fields.String,
|
||||
"generation_detail": fields.Raw,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -245,25 +243,27 @@ class ChatMessageListApi(Resource):
|
||||
def get(self, app_model):
|
||||
args = ChatMessagesQuery.model_validate(request.args.to_dict())
|
||||
|
||||
conversation = db.session.scalar(
|
||||
select(Conversation)
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id)
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
if args.first_id:
|
||||
first_message = db.session.scalar(
|
||||
select(Message).where(Message.conversation_id == conversation.id, Message.id == args.first_id).limit(1)
|
||||
first_message = (
|
||||
db.session.query(Message)
|
||||
.where(Message.conversation_id == conversation.id, Message.id == args.first_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not first_message:
|
||||
raise NotFound("First message not found")
|
||||
|
||||
history_messages = db.session.scalars(
|
||||
select(Message)
|
||||
history_messages = (
|
||||
db.session.query(Message)
|
||||
.where(
|
||||
Message.conversation_id == conversation.id,
|
||||
Message.created_at < first_message.created_at,
|
||||
@@ -271,14 +271,16 @@ class ChatMessageListApi(Resource):
|
||||
)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(args.limit)
|
||||
).all()
|
||||
.all()
|
||||
)
|
||||
else:
|
||||
history_messages = db.session.scalars(
|
||||
select(Message)
|
||||
history_messages = (
|
||||
db.session.query(Message)
|
||||
.where(Message.conversation_id == conversation.id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(args.limit)
|
||||
).all()
|
||||
.all()
|
||||
)
|
||||
|
||||
# Initialize has_more based on whether we have a full page
|
||||
if len(history_messages) == args.limit:
|
||||
@@ -323,9 +325,7 @@ class MessageFeedbackApi(Resource):
|
||||
|
||||
message_id = str(args.message_id)
|
||||
|
||||
message = db.session.scalar(
|
||||
select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1)
|
||||
)
|
||||
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||
|
||||
if not message:
|
||||
raise NotFound("Message Not Exists.")
|
||||
@@ -335,7 +335,7 @@ class MessageFeedbackApi(Resource):
|
||||
if not args.rating and feedback:
|
||||
db.session.delete(feedback)
|
||||
elif args.rating and feedback:
|
||||
feedback.rating = FeedbackRating(args.rating)
|
||||
feedback.rating = args.rating
|
||||
feedback.content = args.content
|
||||
elif not args.rating and not feedback:
|
||||
raise ValueError("rating cannot be None when feedback not exists")
|
||||
@@ -347,9 +347,9 @@ class MessageFeedbackApi(Resource):
|
||||
app_id=app_model.id,
|
||||
conversation_id=message.conversation_id,
|
||||
message_id=message.id,
|
||||
rating=FeedbackRating(rating_value),
|
||||
rating=rating_value,
|
||||
content=args.content,
|
||||
from_source=FeedbackFromSource.ADMIN,
|
||||
from_source="admin",
|
||||
from_account_id=current_user.id,
|
||||
)
|
||||
db.session.add(feedback)
|
||||
@@ -374,9 +374,7 @@ class MessageAnnotationCountApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
count = db.session.scalar(
|
||||
select(func.count(MessageAnnotation.id)).where(MessageAnnotation.app_id == app_model.id)
|
||||
)
|
||||
count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count()
|
||||
|
||||
return {"count": count}
|
||||
|
||||
@@ -480,9 +478,7 @@ class MessageApi(Resource):
|
||||
def get(self, app_model, message_id: str):
|
||||
message_id = str(message_id)
|
||||
|
||||
message = db.session.scalar(
|
||||
select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1)
|
||||
)
|
||||
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||
|
||||
if not message:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
@@ -69,7 +69,9 @@ class ModelConfigResource(Resource):
|
||||
|
||||
if app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent:
|
||||
# get original app model config
|
||||
original_app_model_config = db.session.get(AppModelConfig, app_model.app_model_config_id)
|
||||
original_app_model_config = (
|
||||
db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first()
|
||||
)
|
||||
if original_app_model_config is None:
|
||||
raise ValueError("Original app model config not found")
|
||||
agent_mode = original_app_model_config.agent_mode_dict
|
||||
|
||||
@@ -2,7 +2,6 @@ from typing import Literal
|
||||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from constants.languages import supported_language
|
||||
@@ -76,7 +75,7 @@ class AppSite(Resource):
|
||||
def post(self, app_model):
|
||||
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
|
||||
current_user, _ = current_account_with_tenant()
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
if not site:
|
||||
raise NotFound
|
||||
|
||||
@@ -125,7 +124,7 @@ class AppSiteAccessTokenReset(Resource):
|
||||
@marshal_with(app_site_model)
|
||||
def post(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
|
||||
if not site:
|
||||
raise NotFound
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, current_account_with_tenant, setup_required
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
from services.skill_service import SkillService
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/llm/skills")
|
||||
class NodeSkillsApi(Resource):
|
||||
"""Extract tool dependencies from an LLM node's skill prompts.
|
||||
|
||||
The client sends the full node ``data`` object in the request body.
|
||||
The server real-time builds a ``SkillBundle`` from the current draft
|
||||
``.md`` assets and resolves transitive tool dependencies — no cached
|
||||
bundle is used.
|
||||
"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
node_data = request.get_json(force=True)
|
||||
if not isinstance(node_data, dict):
|
||||
return {"tool_dependencies": []}
|
||||
|
||||
tool_deps = SkillService.extract_tool_dependencies(
|
||||
app=app_model,
|
||||
node_data=node_data,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
return {"tool_dependencies": [d.model_dump() for d in tool_deps]}
|
||||
@@ -7,7 +7,7 @@ from flask import abort, request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import console_ns
|
||||
@@ -22,7 +22,6 @@ from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from core.plugin.impl.exc import PluginInvokeError
|
||||
from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE
|
||||
from core.trigger.debug.event_selectors import (
|
||||
TriggerDebugEvent,
|
||||
TriggerDebugEventPoller,
|
||||
@@ -37,7 +36,6 @@ from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from factories import file_factory, variable_factory
|
||||
from fields.member_fields import simple_account_fields
|
||||
from fields.online_user_fields import online_user_list_fields
|
||||
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
|
||||
from libs import helper
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@@ -46,18 +44,14 @@ from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
from models.workflow import Workflow
|
||||
from repositories.workflow_collaboration_repository import WORKFLOW_ONLINE_USERS_PREFIX
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.workflow.entities import NestedNodeGraphRequest, NestedNodeParameterSchema
|
||||
from services.workflow.nested_node_graph_service import NestedNodeGraphService
|
||||
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
LISTENING_RETRY_IN = 2000
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published"
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
@@ -167,14 +161,6 @@ class WorkflowUpdatePayload(BaseModel):
|
||||
marked_comment: str | None = Field(default=None, max_length=100)
|
||||
|
||||
|
||||
class WorkflowFeaturesPayload(BaseModel):
|
||||
features: dict[str, Any] = Field(..., description="Workflow feature configuration")
|
||||
|
||||
|
||||
class WorkflowOnlineUsersQuery(BaseModel):
|
||||
workflow_ids: str = Field(..., description="Comma-separated workflow IDs")
|
||||
|
||||
|
||||
class DraftWorkflowTriggerRunPayload(BaseModel):
|
||||
node_id: str
|
||||
|
||||
@@ -183,15 +169,6 @@ class DraftWorkflowTriggerRunAllPayload(BaseModel):
|
||||
node_ids: list[str]
|
||||
|
||||
|
||||
class NestedNodeGraphPayload(BaseModel):
|
||||
"""Request payload for generating nested node graph."""
|
||||
|
||||
parent_node_id: str = Field(description="ID of the parent node that uses the extracted value")
|
||||
parameter_key: str = Field(description="Key of the parameter being extracted")
|
||||
context_source: list[str] = Field(description="Variable selector for the context source")
|
||||
parameter_schema: dict[str, Any] = Field(description="Schema of the parameter to extract")
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
@@ -207,11 +184,8 @@ reg(DefaultBlockConfigQuery)
|
||||
reg(ConvertToWorkflowPayload)
|
||||
reg(WorkflowListQuery)
|
||||
reg(WorkflowUpdatePayload)
|
||||
reg(WorkflowFeaturesPayload)
|
||||
reg(WorkflowOnlineUsersQuery)
|
||||
reg(DraftWorkflowTriggerRunPayload)
|
||||
reg(DraftWorkflowTriggerRunAllPayload)
|
||||
reg(NestedNodeGraphPayload)
|
||||
|
||||
|
||||
# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing
|
||||
@@ -309,9 +283,7 @@ class DraftWorkflowApi(Resource):
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
try:
|
||||
environment_variables_list = Workflow.normalize_environment_variable_mappings(
|
||||
args.get("environment_variables") or [],
|
||||
)
|
||||
environment_variables_list = args.get("environment_variables") or []
|
||||
environment_variables = [
|
||||
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
|
||||
]
|
||||
@@ -856,14 +828,13 @@ class PublishedWorkflowApi(Resource):
|
||||
"""
|
||||
Publish workflow
|
||||
"""
|
||||
from services.app_bundle_service import AppBundleService
|
||||
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = PublishWorkflowPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
with Session(db.engine) as session:
|
||||
workflow = AppBundleService.publish(
|
||||
workflow = workflow_service.publish_workflow(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
@@ -974,31 +945,6 @@ class ConvertToWorkflowApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/features")
|
||||
class WorkflowFeaturesApi(Resource):
|
||||
"""Update draft workflow features."""
|
||||
|
||||
@console_ns.expect(console_ns.models[WorkflowFeaturesPayload.__name__])
|
||||
@console_ns.doc("update_workflow_features")
|
||||
@console_ns.doc(description="Update draft workflow features")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Workflow features updated successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = WorkflowFeaturesPayload.model_validate(console_ns.payload or {})
|
||||
features = args.features
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
workflow_service.update_draft_workflow_features(app_model=app_model, features=features, account=current_user)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows")
|
||||
class PublishedAllWorkflowApi(Resource):
|
||||
@console_ns.expect(console_ns.models[WorkflowListQuery.__name__])
|
||||
@@ -1047,43 +993,6 @@ class PublishedAllWorkflowApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/<string:workflow_id>/restore")
|
||||
class DraftWorkflowRestoreApi(Resource):
|
||||
@console_ns.doc("restore_workflow_to_draft")
|
||||
@console_ns.doc(description="Restore a published workflow version into the draft workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "workflow_id": "Published workflow ID"})
|
||||
@console_ns.response(200, "Workflow restored successfully")
|
||||
@console_ns.response(400, "Source workflow must be published")
|
||||
@console_ns.response(404, "Workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, workflow_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
try:
|
||||
workflow = workflow_service.restore_published_workflow_to_draft(
|
||||
app_model=app_model,
|
||||
workflow_id=workflow_id,
|
||||
account=current_user,
|
||||
)
|
||||
except IsDraftWorkflowError as exc:
|
||||
raise BadRequest(RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE) from exc
|
||||
except WorkflowNotFoundError as exc:
|
||||
raise NotFound(str(exc)) from exc
|
||||
except ValueError as exc:
|
||||
raise BadRequest(str(exc)) from exc
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"hash": workflow.unique_hash,
|
||||
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/<string:workflow_id>")
|
||||
class WorkflowByIdApi(Resource):
|
||||
@console_ns.doc("update_workflow_by_id")
|
||||
@@ -1300,7 +1209,7 @@ class DraftWorkflowTriggerNodeApi(Resource):
|
||||
node_type: NodeType = draft_workflow.get_node_type_from_node_config(node_config)
|
||||
event: TriggerDebugEvent | None = None
|
||||
# for schedule trigger, when run single node, just execute directly
|
||||
if node_type == TRIGGER_SCHEDULE_NODE_TYPE:
|
||||
if node_type == NodeType.TRIGGER_SCHEDULE:
|
||||
event = TriggerDebugEvent(
|
||||
workflow_args={},
|
||||
node_id=node_id,
|
||||
@@ -1413,83 +1322,3 @@ class DraftWorkflowTriggerRunAllApi(Resource):
|
||||
"status": "error",
|
||||
}
|
||||
), 400
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nested-node-graph")
|
||||
class NestedNodeGraphApi(Resource):
|
||||
"""
|
||||
API for generating Nested Node LLM graph structures.
|
||||
|
||||
This endpoint creates a complete graph structure containing an LLM node
|
||||
configured to extract values from list[PromptMessage] variables.
|
||||
"""
|
||||
|
||||
@console_ns.doc("generate_nested_node_graph")
|
||||
@console_ns.doc(description="Generate a Nested Node LLM graph structure")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[NestedNodeGraphPayload.__name__])
|
||||
@console_ns.response(200, "Nested node graph generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
"""
|
||||
Generate a Nested Node LLM graph structure.
|
||||
|
||||
Returns a complete graph structure containing a single LLM node
|
||||
configured for extracting values from list[PromptMessage] context.
|
||||
"""
|
||||
|
||||
payload = NestedNodeGraphPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
parameter_schema = NestedNodeParameterSchema(
|
||||
name=payload.parameter_schema.get("name", payload.parameter_key),
|
||||
type=payload.parameter_schema.get("type", "string"),
|
||||
description=payload.parameter_schema.get("description", ""),
|
||||
)
|
||||
|
||||
request = NestedNodeGraphRequest(
|
||||
parent_node_id=payload.parent_node_id,
|
||||
parameter_key=payload.parameter_key,
|
||||
context_source=payload.context_source,
|
||||
parameter_schema=parameter_schema,
|
||||
)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
service = NestedNodeGraphService(session)
|
||||
response = service.generate_nested_node_graph(tenant_id=app_model.tenant_id, request=request)
|
||||
|
||||
return response.model_dump()
|
||||
|
||||
|
||||
@console_ns.route("/apps/workflows/online-users")
|
||||
class WorkflowOnlineUsersApi(Resource):
|
||||
@console_ns.expect(console_ns.models[WorkflowOnlineUsersQuery.__name__])
|
||||
@console_ns.doc("get_workflow_online_users")
|
||||
@console_ns.doc(description="Get workflow online users")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(online_user_list_fields)
|
||||
def get(self):
|
||||
args = WorkflowOnlineUsersQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
workflow_ids = [workflow_id.strip() for workflow_id in args.workflow_ids.split(",") if workflow_id.strip()]
|
||||
|
||||
results = []
|
||||
for workflow_id in workflow_ids:
|
||||
users_json = redis_client.hgetall(f"{WORKFLOW_ONLINE_USERS_PREFIX}{workflow_id}")
|
||||
|
||||
users = []
|
||||
for _, user_info_json in users_json.items():
|
||||
try:
|
||||
users.append(json.loads(user_info_json))
|
||||
except Exception:
|
||||
continue
|
||||
results.append({"workflow_id": workflow_id, "users": users})
|
||||
|
||||
return {"data": results}
|
||||
|
||||
@@ -1,322 +0,0 @@
|
||||
import logging
|
||||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.member_fields import AccountWithRole
|
||||
from fields.workflow_comment_fields import (
|
||||
workflow_comment_basic_fields,
|
||||
workflow_comment_create_fields,
|
||||
workflow_comment_detail_fields,
|
||||
workflow_comment_reply_create_fields,
|
||||
workflow_comment_reply_update_fields,
|
||||
workflow_comment_resolve_fields,
|
||||
workflow_comment_update_fields,
|
||||
)
|
||||
from libs.login import current_user, login_required
|
||||
from models import App
|
||||
from services.account_service import TenantService
|
||||
from services.workflow_comment_service import WorkflowCommentService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowCommentCreatePayload(BaseModel):
|
||||
position_x: float = Field(..., description="Comment X position")
|
||||
position_y: float = Field(..., description="Comment Y position")
|
||||
content: str = Field(..., description="Comment content")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
class WorkflowCommentUpdatePayload(BaseModel):
|
||||
content: str = Field(..., description="Comment content")
|
||||
position_x: float | None = Field(default=None, description="Comment X position")
|
||||
position_y: float | None = Field(default=None, description="Comment Y position")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
class WorkflowCommentReplyCreatePayload(BaseModel):
|
||||
content: str = Field(..., description="Reply content")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
class WorkflowCommentReplyUpdatePayload(BaseModel):
|
||||
content: str = Field(..., description="Reply content")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
class WorkflowCommentMentionUsersResponse(BaseModel):
|
||||
users: list[AccountWithRole] = Field(description="Mentionable users")
|
||||
|
||||
|
||||
for model in (
|
||||
WorkflowCommentCreatePayload,
|
||||
WorkflowCommentUpdatePayload,
|
||||
WorkflowCommentReplyCreatePayload,
|
||||
WorkflowCommentReplyUpdatePayload,
|
||||
):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
for model in (AccountWithRole, WorkflowCommentMentionUsersResponse):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
workflow_comment_basic_model = console_ns.model("WorkflowCommentBasic", workflow_comment_basic_fields)
|
||||
workflow_comment_detail_model = console_ns.model("WorkflowCommentDetail", workflow_comment_detail_fields)
|
||||
workflow_comment_create_model = console_ns.model("WorkflowCommentCreate", workflow_comment_create_fields)
|
||||
workflow_comment_update_model = console_ns.model("WorkflowCommentUpdate", workflow_comment_update_fields)
|
||||
workflow_comment_resolve_model = console_ns.model("WorkflowCommentResolve", workflow_comment_resolve_fields)
|
||||
workflow_comment_reply_create_model = console_ns.model(
|
||||
"WorkflowCommentReplyCreate", workflow_comment_reply_create_fields
|
||||
)
|
||||
workflow_comment_reply_update_model = console_ns.model(
|
||||
"WorkflowCommentReplyUpdate", workflow_comment_reply_update_fields
|
||||
)
|
||||
workflow_comment_mention_users_model = console_ns.models[WorkflowCommentMentionUsersResponse.__name__]
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments")
|
||||
class WorkflowCommentListApi(Resource):
|
||||
"""API for listing and creating workflow comments."""
|
||||
|
||||
@console_ns.doc("list_workflow_comments")
|
||||
@console_ns.doc(description="Get all comments for a workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Comments retrieved successfully", workflow_comment_basic_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_basic_model, envelope="data")
|
||||
def get(self, app_model: App):
|
||||
"""Get all comments for a workflow."""
|
||||
comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id)
|
||||
|
||||
return comments
|
||||
|
||||
@console_ns.doc("create_workflow_comment")
|
||||
@console_ns.doc(description="Create a new workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentCreatePayload.__name__])
|
||||
@console_ns.response(201, "Comment created successfully", workflow_comment_create_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_create_model)
|
||||
def post(self, app_model: App):
|
||||
"""Create a new workflow comment."""
|
||||
payload = WorkflowCommentCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.create_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
created_by=current_user.id,
|
||||
content=payload.content,
|
||||
position_x=payload.position_x,
|
||||
position_y=payload.position_y,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result, 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>")
|
||||
class WorkflowCommentDetailApi(Resource):
|
||||
"""API for managing individual workflow comments."""
|
||||
|
||||
@console_ns.doc("get_workflow_comment")
|
||||
@console_ns.doc(description="Get a specific workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(200, "Comment retrieved successfully", workflow_comment_detail_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_detail_model)
|
||||
def get(self, app_model: App, comment_id: str):
|
||||
"""Get a specific workflow comment."""
|
||||
comment = WorkflowCommentService.get_comment(
|
||||
tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id
|
||||
)
|
||||
|
||||
return comment
|
||||
|
||||
@console_ns.doc("update_workflow_comment")
|
||||
@console_ns.doc(description="Update a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Comment updated successfully", workflow_comment_update_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_update_model)
|
||||
def put(self, app_model: App, comment_id: str):
|
||||
"""Update a workflow comment."""
|
||||
payload = WorkflowCommentUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.update_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
content=payload.content,
|
||||
position_x=payload.position_x,
|
||||
position_y=payload.position_y,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@console_ns.doc("delete_workflow_comment")
|
||||
@console_ns.doc(description="Delete a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(204, "Comment deleted successfully")
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def delete(self, app_model: App, comment_id: str):
|
||||
"""Delete a workflow comment."""
|
||||
WorkflowCommentService.delete_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/resolve")
|
||||
class WorkflowCommentResolveApi(Resource):
|
||||
"""API for resolving and reopening workflow comments."""
|
||||
|
||||
@console_ns.doc("resolve_workflow_comment")
|
||||
@console_ns.doc(description="Resolve a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(200, "Comment resolved successfully", workflow_comment_resolve_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_resolve_model)
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
"""Resolve a workflow comment."""
|
||||
comment = WorkflowCommentService.resolve_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return comment
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies")
|
||||
class WorkflowCommentReplyApi(Resource):
|
||||
"""API for managing comment replies."""
|
||||
|
||||
@console_ns.doc("create_workflow_comment_reply")
|
||||
@console_ns.doc(description="Add a reply to a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentReplyCreatePayload.__name__])
|
||||
@console_ns.response(201, "Reply created successfully", workflow_comment_reply_create_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_reply_create_model)
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
"""Add a reply to a workflow comment."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
payload = WorkflowCommentReplyCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.create_reply(
|
||||
comment_id=comment_id,
|
||||
content=payload.content,
|
||||
created_by=current_user.id,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result, 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies/<string:reply_id>")
|
||||
class WorkflowCommentReplyDetailApi(Resource):
|
||||
"""API for managing individual comment replies."""
|
||||
|
||||
@console_ns.doc("update_workflow_comment_reply")
|
||||
@console_ns.doc(description="Update a comment reply")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentReplyUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Reply updated successfully", workflow_comment_reply_update_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_reply_update_model)
|
||||
def put(self, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Update a comment reply."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
payload = WorkflowCommentReplyUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
reply = WorkflowCommentService.update_reply(
|
||||
reply_id=reply_id,
|
||||
user_id=current_user.id,
|
||||
content=payload.content,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return reply
|
||||
|
||||
@console_ns.doc("delete_workflow_comment_reply")
|
||||
@console_ns.doc(description="Delete a comment reply")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
|
||||
@console_ns.response(204, "Reply deleted successfully")
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def delete(self, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Delete a comment reply."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
WorkflowCommentService.delete_reply(reply_id=reply_id, user_id=current_user.id)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/mention-users")
|
||||
class WorkflowCommentMentionUsersApi(Resource):
|
||||
"""API for getting mentionable users for workflow comments."""
|
||||
|
||||
@console_ns.doc("workflow_comment_mention_users")
|
||||
@console_ns.doc(description="Get all users in current tenant for mentions")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Mentionable users retrieved successfully", workflow_comment_mention_users_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model: App):
|
||||
"""Get all users in current tenant for mentions."""
|
||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||
member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
|
||||
response = WorkflowCommentMentionUsersResponse(users=member_models)
|
||||
return response.model_dump(mode="json"), 200
|
||||
@@ -18,15 +18,14 @@ from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from dify_graph.file import helpers as file_helpers
|
||||
from dify_graph.variables.segment_group import SegmentGroup
|
||||
from dify_graph.variables.segments import ArrayFileSegment, ArrayPromptMessageSegment, FileSegment, Segment
|
||||
from dify_graph.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from dify_graph.variables.types import SegmentType
|
||||
from extensions.ext_database import db
|
||||
from factories import variable_factory
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from libs.login import current_account_with_tenant, current_user, login_required
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from libs.login import login_required
|
||||
from models import App, AppMode
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.sandbox.sandbox_service import SandboxService
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
@@ -44,16 +43,6 @@ class WorkflowDraftVariableUpdatePayload(BaseModel):
|
||||
value: Any | None = Field(default=None, description="Variable value")
|
||||
|
||||
|
||||
class ConversationVariableUpdatePayload(BaseModel):
|
||||
conversation_variables: list[dict[str, Any]] = Field(
|
||||
..., description="Conversation variables for the draft workflow"
|
||||
)
|
||||
|
||||
|
||||
class EnvironmentVariableUpdatePayload(BaseModel):
|
||||
environment_variables: list[dict[str, Any]] = Field(..., description="Environment variables for the draft workflow")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkflowDraftVariableListQuery.__name__,
|
||||
WorkflowDraftVariableListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
@@ -62,14 +51,6 @@ console_ns.schema_model(
|
||||
WorkflowDraftVariableUpdatePayload.__name__,
|
||||
WorkflowDraftVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
ConversationVariableUpdatePayload.__name__,
|
||||
ConversationVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
EnvironmentVariableUpdatePayload.__name__,
|
||||
EnvironmentVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
def _convert_values_to_json_serializable_object(value: Segment):
|
||||
@@ -77,8 +58,6 @@ def _convert_values_to_json_serializable_object(value: Segment):
|
||||
return value.value.model_dump()
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
return [i.model_dump() for i in value.value]
|
||||
elif isinstance(value, ArrayPromptMessageSegment):
|
||||
return value.to_object()
|
||||
elif isinstance(value, SegmentGroup):
|
||||
return [_convert_values_to_json_serializable_object(i) for i in value.value]
|
||||
else:
|
||||
@@ -121,18 +100,6 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> dict | None:
|
||||
}
|
||||
|
||||
|
||||
def _ensure_variable_access(
|
||||
variable: WorkflowDraftVariable | None,
|
||||
app_id: str,
|
||||
variable_id: str,
|
||||
) -> WorkflowDraftVariable:
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_id or variable.user_id != current_user.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
return variable
|
||||
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
|
||||
"id": fields.String,
|
||||
"type": fields.String(attribute=lambda model: model.get_variable_type()),
|
||||
@@ -271,7 +238,6 @@ class WorkflowVariableCollectionApi(Resource):
|
||||
app_id=app_model.id,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return workflow_vars
|
||||
@@ -281,12 +247,10 @@ class WorkflowVariableCollectionApi(Resource):
|
||||
@console_ns.response(204, "Workflow variables deleted successfully")
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
SandboxService.delete_draft_storage(app_model.tenant_id, app_model.id, current_user.id)
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
draft_var_srv.delete_user_workflow_variables(app_model.id, user_id=current_user.id)
|
||||
draft_var_srv.delete_workflow_variables(app_model.id)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
@@ -323,7 +287,7 @@ class NodeVariableCollectionApi(Resource):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
node_vars = draft_var_srv.list_node_variables(app_model.id, node_id, user_id=current_user.id)
|
||||
node_vars = draft_var_srv.list_node_variables(app_model.id, node_id)
|
||||
|
||||
return node_vars
|
||||
|
||||
@@ -334,7 +298,7 @@ class NodeVariableCollectionApi(Resource):
|
||||
def delete(self, app_model: App, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
srv = WorkflowDraftVariableService(db.session())
|
||||
srv.delete_node_variables(app_model.id, node_id, user_id=current_user.id)
|
||||
srv.delete_node_variables(app_model.id, node_id)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
@@ -355,11 +319,11 @@ class VariableApi(Resource):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable = _ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id,
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_model.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
return variable
|
||||
|
||||
@console_ns.doc("update_variable")
|
||||
@@ -396,11 +360,11 @@ class VariableApi(Resource):
|
||||
)
|
||||
args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
variable = _ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id,
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_model.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
|
||||
new_name = args_model.name
|
||||
raw_value = args_model.value
|
||||
@@ -419,7 +383,7 @@ class VariableApi(Resource):
|
||||
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
|
||||
raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id)
|
||||
new_value = variable_factory.build_segment_with_type(variable.value_type, raw_value)
|
||||
new_value = build_segment_with_type(variable.value_type, raw_value)
|
||||
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
|
||||
db.session.commit()
|
||||
return variable
|
||||
@@ -433,11 +397,11 @@ class VariableApi(Resource):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable = _ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id,
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_model.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
draft_var_srv.delete_variable(variable)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
@@ -463,11 +427,11 @@ class VariableResetApi(Resource):
|
||||
raise NotFoundError(
|
||||
f"Draft workflow not found, app_id={app_model.id}",
|
||||
)
|
||||
variable = _ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id,
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_model.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
|
||||
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
|
||||
db.session.commit()
|
||||
@@ -483,15 +447,11 @@ def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
|
||||
session=session,
|
||||
)
|
||||
if node_id == CONVERSATION_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_conversation_variables(app_model.id, user_id=current_user.id)
|
||||
draft_vars = draft_var_srv.list_conversation_variables(app_model.id)
|
||||
elif node_id == SYSTEM_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_system_variables(app_model.id, user_id=current_user.id)
|
||||
draft_vars = draft_var_srv.list_system_variables(app_model.id)
|
||||
else:
|
||||
draft_vars = draft_var_srv.list_node_variables(
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
draft_vars = draft_var_srv.list_node_variables(app_id=app_model.id, node_id=node_id)
|
||||
return draft_vars
|
||||
|
||||
|
||||
@@ -512,39 +472,10 @@ class ConversationVariableCollectionApi(Resource):
|
||||
if draft_workflow is None:
|
||||
raise NotFoundError(description=f"draft workflow not found, id={app_model.id}")
|
||||
draft_var_srv = WorkflowDraftVariableService(db.session())
|
||||
draft_var_srv.prefill_conversation_variable_default_values(draft_workflow, user_id=current_user.id)
|
||||
draft_var_srv.prefill_conversation_variable_default_values(draft_workflow)
|
||||
db.session.commit()
|
||||
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
|
||||
|
||||
@console_ns.expect(console_ns.models[ConversationVariableUpdatePayload.__name__])
|
||||
@console_ns.doc("update_conversation_variables")
|
||||
@console_ns.doc(description="Update conversation variables for workflow draft")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Conversation variables updated successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=AppMode.ADVANCED_CHAT)
|
||||
def post(self, app_model: App):
|
||||
payload = ConversationVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
conversation_variables_list = payload.conversation_variables
|
||||
conversation_variables = [
|
||||
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||
]
|
||||
|
||||
current_user, _ = current_account_with_tenant()
|
||||
workflow_service.update_draft_workflow_conversation_variables(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/system-variables")
|
||||
class SystemVariableCollectionApi(Resource):
|
||||
@@ -596,32 +527,3 @@ class EnvironmentVariableCollectionApi(Resource):
|
||||
)
|
||||
|
||||
return {"items": env_vars_list}
|
||||
|
||||
@console_ns.expect(console_ns.models[EnvironmentVariableUpdatePayload.__name__])
|
||||
@console_ns.doc("update_environment_variables")
|
||||
@console_ns.doc(description="Update environment variables for workflow draft")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Environment variables updated successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
payload = EnvironmentVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
environment_variables_list = payload.environment_variables
|
||||
environment_variables = [
|
||||
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
|
||||
]
|
||||
|
||||
workflow_service.update_draft_workflow_environment_variables(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
environment_variables=environment_variables,
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@@ -2,8 +2,6 @@ from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar, Union
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from controllers.console.app.error import AppNotFoundError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant
|
||||
@@ -17,14 +15,16 @@ R1 = TypeVar("R1")
|
||||
|
||||
def _load_app_model(app_id: str) -> App | None:
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
app_model = db.session.scalar(
|
||||
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
|
||||
app_model = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
return app_model
|
||||
|
||||
|
||||
def _load_app_model_with_trial(app_id: str) -> App | None:
|
||||
app_model = db.session.scalar(select(App).where(App.id == app_id, App.status == "normal").limit(1))
|
||||
app_model = db.session.query(App).where(App.id == app_id, App.status == "normal").first()
|
||||
return app_model
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
@@ -73,7 +73,7 @@ class EmailRegisterSendEmailApi(Resource):
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
|
||||
token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)
|
||||
return {"result": "success", "data": token}
|
||||
@@ -145,7 +145,7 @@ class EmailRegisterResetApi(Resource):
|
||||
email = register_data.get("email", "")
|
||||
normalized_email = email.lower()
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
||||
|
||||
if account:
|
||||
|
||||
@@ -4,7 +4,7 @@ import secrets
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
@@ -102,7 +102,7 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
else:
|
||||
language = "en-US"
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
|
||||
|
||||
token = AccountService.send_reset_password_email(
|
||||
@@ -201,7 +201,7 @@ class ForgotPasswordResetApi(Resource):
|
||||
password_hashed = hash_password(args.new_password, salt)
|
||||
|
||||
email = reset_data.get("email", "")
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
||||
|
||||
if account:
|
||||
@@ -215,6 +215,7 @@ class ForgotPasswordResetApi(Resource):
|
||||
# Update existing account credentials
|
||||
account.password = base64.b64encode(password_hashed).decode()
|
||||
account.password_salt = base64.b64encode(salt).decode()
|
||||
session.commit()
|
||||
|
||||
# Create workspace if needed
|
||||
if (
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import logging
|
||||
import urllib.parse
|
||||
|
||||
import httpx
|
||||
from flask import current_app, redirect, request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
@@ -113,9 +112,6 @@ class OAuthCallback(Resource):
|
||||
error_text = e.response.text
|
||||
logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
|
||||
return {"error": "OAuth process failed"}, 400
|
||||
except ValueError as e:
|
||||
logger.warning("OAuth error with %s", provider, exc_info=True)
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={urllib.parse.quote(str(e))}")
|
||||
|
||||
if invite_token and RegisterService.is_valid_invite_token(invite_token):
|
||||
invitation = RegisterService.get_invitation_by_token(token=invite_token)
|
||||
@@ -180,7 +176,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
|
||||
account: Account | None = Account.get_by_openid(provider, user_info.id)
|
||||
|
||||
if not account:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine) as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session)
|
||||
|
||||
return account
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any, cast
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
@@ -29,7 +29,6 @@ from core.provider_manager import ProviderManager
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_database import db
|
||||
@@ -55,7 +54,6 @@ from fields.document_fields import document_status_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DatasetPermission, DatasetPermissionEnum
|
||||
from models.enums import ApiTokenType, SegmentStatus
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.api_token_service import ApiTokenCache
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
@@ -265,7 +263,6 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
|
||||
VectorType.BAIDU,
|
||||
VectorType.ALIBABACLOUD_MYSQL,
|
||||
VectorType.IRIS,
|
||||
VectorType.HOLOGRES,
|
||||
}
|
||||
|
||||
semantic_methods = {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||
@@ -356,7 +353,7 @@ class DatasetListApi(Resource):
|
||||
|
||||
for item in data:
|
||||
# convert embedding_model_provider to plugin standard format
|
||||
if item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY and item["embedding_model_provider"]:
|
||||
if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
|
||||
item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
|
||||
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
|
||||
if item_model in model_names:
|
||||
@@ -437,7 +434,7 @@ class DatasetApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.embedding_model_provider:
|
||||
provider_id = ModelProviderID(dataset.embedding_model_provider)
|
||||
data["embedding_model_provider"] = str(provider_id)
|
||||
@@ -455,7 +452,7 @@ class DatasetApi(Resource):
|
||||
for embedding_model in embedding_models:
|
||||
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
|
||||
|
||||
if data["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY:
|
||||
if data["indexing_technique"] == "high_quality":
|
||||
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
|
||||
if item_model in model_names:
|
||||
data["embedding_available"] = True
|
||||
@@ -486,7 +483,7 @@ class DatasetApi(Resource):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
# check embedding model setting
|
||||
if (
|
||||
payload.indexing_technique == IndexTechniqueType.HIGH_QUALITY
|
||||
payload.indexing_technique == "high_quality"
|
||||
and payload.embedding_model_provider is not None
|
||||
and payload.embedding_model is not None
|
||||
):
|
||||
@@ -739,23 +736,18 @@ class DatasetIndexingStatusApi(Resource):
|
||||
documents_status = []
|
||||
for document in documents:
|
||||
completed_segments = (
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != "re_segment",
|
||||
)
|
||||
or 0
|
||||
.count()
|
||||
)
|
||||
total_segments = (
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
or 0
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
|
||||
.count()
|
||||
)
|
||||
# Create a dictionary with document attributes and additional fields
|
||||
document_dict = {
|
||||
@@ -781,7 +773,7 @@ class DatasetIndexingStatusApi(Resource):
|
||||
class DatasetApiKeyApi(Resource):
|
||||
max_keys = 10
|
||||
token_prefix = "dataset-"
|
||||
resource_type = ApiTokenType.DATASET
|
||||
resource_type = "dataset"
|
||||
|
||||
@console_ns.doc("get_dataset_api_keys")
|
||||
@console_ns.doc(description="Get dataset API keys")
|
||||
@@ -806,12 +798,9 @@ class DatasetApiKeyApi(Resource):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
current_key_count = (
|
||||
db.session.scalar(
|
||||
select(func.count(ApiToken.id)).where(
|
||||
ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id
|
||||
)
|
||||
)
|
||||
or 0
|
||||
db.session.query(ApiToken)
|
||||
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
|
||||
.count()
|
||||
)
|
||||
|
||||
if current_key_count >= self.max_keys:
|
||||
@@ -833,7 +822,7 @@ class DatasetApiKeyApi(Resource):
|
||||
|
||||
@console_ns.route("/datasets/api-keys/<uuid:api_key_id>")
|
||||
class DatasetApiDeleteApi(Resource):
|
||||
resource_type = ApiTokenType.DATASET
|
||||
resource_type = "dataset"
|
||||
|
||||
@console_ns.doc("delete_dataset_api_key")
|
||||
@console_ns.doc(description="Delete dataset API key")
|
||||
@@ -846,14 +835,14 @@ class DatasetApiDeleteApi(Resource):
|
||||
def delete(self, api_key_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
api_key_id = str(api_key_id)
|
||||
key = db.session.scalar(
|
||||
select(ApiToken)
|
||||
key = (
|
||||
db.session.query(ApiToken)
|
||||
.where(
|
||||
ApiToken.tenant_id == current_tenant_id,
|
||||
ApiToken.type == self.resource_type,
|
||||
ApiToken.id == api_key_id,
|
||||
)
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
|
||||
if key is None:
|
||||
@@ -864,7 +853,7 @@ class DatasetApiDeleteApi(Resource):
|
||||
assert key is not None # nosec - for type checker only
|
||||
ApiTokenCache.delete(key.token, key.type)
|
||||
|
||||
db.session.delete(key)
|
||||
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
@@ -10,7 +10,7 @@ import sqlalchemy as sa
|
||||
from flask import request, send_file
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import asc, desc, func, select
|
||||
from sqlalchemy import asc, desc, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
@@ -27,7 +27,6 @@ from core.model_manager import ModelManager
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from extensions.ext_database import db
|
||||
@@ -43,7 +42,6 @@ from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DocumentPipelineExecutionLog
|
||||
from models.enums import IndexingStatus, SegmentStatus
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
|
||||
from services.file_service import FileService
|
||||
@@ -212,11 +210,12 @@ class GetProcessRuleApi(Resource):
|
||||
raise Forbidden(str(e))
|
||||
|
||||
# get the latest process rule
|
||||
dataset_process_rule = db.session.scalar(
|
||||
select(DatasetProcessRule)
|
||||
dataset_process_rule = (
|
||||
db.session.query(DatasetProcessRule)
|
||||
.where(DatasetProcessRule.dataset_id == document.dataset_id)
|
||||
.order_by(DatasetProcessRule.created_at.desc())
|
||||
.limit(1)
|
||||
.one_or_none()
|
||||
)
|
||||
if dataset_process_rule:
|
||||
mode = dataset_process_rule.mode
|
||||
@@ -298,7 +297,6 @@ class DatasetDocumentListApi(Resource):
|
||||
if sort == "hit_count":
|
||||
sub_query = (
|
||||
sa.select(DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
|
||||
.where(DocumentSegment.dataset_id == str(dataset_id))
|
||||
.group_by(DocumentSegment.document_id)
|
||||
.subquery()
|
||||
)
|
||||
@@ -330,23 +328,18 @@ class DatasetDocumentListApi(Resource):
|
||||
if fetch:
|
||||
for document in documents:
|
||||
completed_segments = (
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != "re_segment",
|
||||
)
|
||||
or 0
|
||||
.count()
|
||||
)
|
||||
total_segments = (
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
or 0
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
|
||||
.count()
|
||||
)
|
||||
document.completed_segments = completed_segments
|
||||
document.total_segments = total_segments
|
||||
@@ -450,7 +443,7 @@ class DatasetInitApi(Resource):
|
||||
raise Forbidden()
|
||||
|
||||
knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {})
|
||||
if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
if knowledge_config.indexing_technique == "high_quality":
|
||||
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
|
||||
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
|
||||
try:
|
||||
@@ -464,7 +457,7 @@ class DatasetInitApi(Resource):
|
||||
is_multimodal = DatasetService.check_is_multimodal_model(
|
||||
current_tenant_id, knowledge_config.embedding_model_provider, knowledge_config.embedding_model
|
||||
)
|
||||
knowledge_config.is_multimodal = is_multimodal # pyrefly: ignore[bad-assignment]
|
||||
knowledge_config.is_multimodal = is_multimodal
|
||||
except InvokeAuthorizationError:
|
||||
raise ProviderNotInitializeError(
|
||||
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
|
||||
@@ -510,7 +503,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||
document_id = str(document_id)
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}:
|
||||
if document.indexing_status in {"completed", "error"}:
|
||||
raise DocumentAlreadyFinishedError()
|
||||
|
||||
data_process_rule = document.dataset_process_rule
|
||||
@@ -523,10 +516,10 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||
if data_source_info and "upload_file_id" in data_source_info:
|
||||
file_id = data_source_info["upload_file_id"]
|
||||
|
||||
file = db.session.scalar(
|
||||
select(UploadFile)
|
||||
file = (
|
||||
db.session.query(UploadFile)
|
||||
.where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
|
||||
# raise error if file not found
|
||||
@@ -580,7 +573,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {}
|
||||
extract_settings = []
|
||||
for document in documents:
|
||||
if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}:
|
||||
if document.indexing_status in {"completed", "error"}:
|
||||
raise DocumentAlreadyFinishedError()
|
||||
data_source_info = document.data_source_info_dict
|
||||
match document.data_source_type:
|
||||
@@ -588,10 +581,10 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
if not data_source_info:
|
||||
continue
|
||||
file_id = data_source_info["upload_file_id"]
|
||||
file_detail = db.session.scalar(
|
||||
select(UploadFile)
|
||||
file_detail = (
|
||||
db.session.query(UploadFile)
|
||||
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
|
||||
if file_detail is None:
|
||||
@@ -674,28 +667,23 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
|
||||
documents_status = []
|
||||
for document in documents:
|
||||
completed_segments = (
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != "re_segment",
|
||||
)
|
||||
or 0
|
||||
.count()
|
||||
)
|
||||
total_segments = (
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
or 0
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
|
||||
.count()
|
||||
)
|
||||
# Create a dictionary with document attributes and additional fields
|
||||
document_dict = {
|
||||
"id": document.id,
|
||||
"indexing_status": IndexingStatus.PAUSED if document.is_paused else document.indexing_status,
|
||||
"indexing_status": "paused" if document.is_paused else document.indexing_status,
|
||||
"processing_started_at": document.processing_started_at,
|
||||
"parsing_completed_at": document.parsing_completed_at,
|
||||
"cleaning_completed_at": document.cleaning_completed_at,
|
||||
@@ -728,29 +716,24 @@ class DocumentIndexingStatusApi(DocumentResource):
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
completed_segments = (
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document_id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document_id),
|
||||
DocumentSegment.status != "re_segment",
|
||||
)
|
||||
or 0
|
||||
.count()
|
||||
)
|
||||
total_segments = (
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.document_id == str(document_id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
or 0
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment")
|
||||
.count()
|
||||
)
|
||||
|
||||
# Create a dictionary with document attributes and additional fields
|
||||
document_dict = {
|
||||
"id": document.id,
|
||||
"indexing_status": IndexingStatus.PAUSED if document.is_paused else document.indexing_status,
|
||||
"indexing_status": "paused" if document.is_paused else document.indexing_status,
|
||||
"processing_started_at": document.processing_started_at,
|
||||
"parsing_completed_at": document.parsing_completed_at,
|
||||
"cleaning_completed_at": document.cleaning_completed_at,
|
||||
@@ -972,7 +955,7 @@ class DocumentProcessingApi(DocumentResource):
|
||||
|
||||
match action:
|
||||
case "pause":
|
||||
if document.indexing_status != IndexingStatus.INDEXING:
|
||||
if document.indexing_status != "indexing":
|
||||
raise InvalidActionError("Document not in indexing state.")
|
||||
|
||||
document.paused_by = current_user.id
|
||||
@@ -981,7 +964,7 @@ class DocumentProcessingApi(DocumentResource):
|
||||
db.session.commit()
|
||||
|
||||
case "resume":
|
||||
if document.indexing_status not in {IndexingStatus.PAUSED, IndexingStatus.ERROR}:
|
||||
if document.indexing_status not in {"paused", "error"}:
|
||||
raise InvalidActionError("Document not in paused or error state.")
|
||||
|
||||
document.paused_by = None
|
||||
@@ -1186,7 +1169,7 @@ class DocumentRetryApi(DocumentResource):
|
||||
raise ArchivedDocumentImmutableError()
|
||||
|
||||
# 400 if document is completed
|
||||
if document.indexing_status == IndexingStatus.COMPLETED:
|
||||
if document.indexing_status == "completed":
|
||||
raise DocumentAlreadyFinishedError()
|
||||
retry_documents.append(document)
|
||||
except Exception:
|
||||
@@ -1268,11 +1251,11 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
log = db.session.scalar(
|
||||
select(DocumentPipelineExecutionLog)
|
||||
.where(DocumentPipelineExecutionLog.document_id == document_id)
|
||||
log = (
|
||||
db.session.query(DocumentPipelineExecutionLog)
|
||||
.filter_by(document_id=document_id)
|
||||
.order_by(DocumentPipelineExecutionLog.created_at.desc())
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
if not log:
|
||||
return {
|
||||
@@ -1338,7 +1321,7 @@ class DocumentGenerateSummaryApi(Resource):
|
||||
raise BadRequest("document_list cannot be empty.")
|
||||
|
||||
# Check if dataset configuration supports summary generation
|
||||
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
|
||||
if dataset.indexing_technique != "high_quality":
|
||||
raise ValueError(
|
||||
f"Summary generation is only available for 'high_quality' indexing technique. "
|
||||
f"Current indexing technique: {dataset.indexing_technique}"
|
||||
|
||||
@@ -26,7 +26,6 @@ from controllers.console.wraps import (
|
||||
)
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_manager import ModelManager
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
@@ -46,7 +45,7 @@ def _get_segment_with_summary(segment, dataset_id):
|
||||
"""Helper function to marshal segment and add summary information."""
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore
|
||||
segment_dict = dict(marshal(segment, segment_fields))
|
||||
# Query summary for this segment (only enabled summaries)
|
||||
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
|
||||
segment_dict["summary"] = summary.summary_content if summary else None
|
||||
@@ -207,7 +206,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
# Add summary to each segment
|
||||
segments_with_summary = []
|
||||
for segment in segments.items:
|
||||
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore
|
||||
segment_dict = dict(marshal(segment, segment_fields))
|
||||
segment_dict["summary"] = summaries.get(segment.id)
|
||||
segments_with_summary.append(segment_dict)
|
||||
|
||||
@@ -280,7 +279,7 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
# check embedding model setting
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
@@ -334,7 +333,7 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
# check embedding model setting
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
@@ -384,7 +383,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
# check embedding model setting
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
@@ -402,10 +401,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
@@ -448,10 +447,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
@@ -495,7 +494,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
payload = BatchImportPayload.model_validate(console_ns.payload or {})
|
||||
upload_file_id = payload.upload_file_id
|
||||
|
||||
upload_file = db.session.scalar(select(UploadFile).where(UploadFile.id == upload_file_id).limit(1))
|
||||
upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
|
||||
if not upload_file:
|
||||
raise NotFound("UploadFile not found.")
|
||||
|
||||
@@ -560,17 +559,17 @@ class ChildChunkAddApi(Resource):
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
# check embedding model setting
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
@@ -617,10 +616,10 @@ class ChildChunkAddApi(Resource):
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
@@ -667,10 +666,10 @@ class ChildChunkAddApi(Resource):
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
@@ -715,24 +714,24 @@ class ChildChunkUpdateApi(Resource):
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
# check child chunk
|
||||
child_chunk_id = str(child_chunk_id)
|
||||
child_chunk = db.session.scalar(
|
||||
select(ChildChunk)
|
||||
child_chunk = (
|
||||
db.session.query(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.id == str(child_chunk_id),
|
||||
ChildChunk.tenant_id == current_tenant_id,
|
||||
ChildChunk.segment_id == segment.id,
|
||||
ChildChunk.document_id == document_id,
|
||||
)
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
if not child_chunk:
|
||||
raise NotFound("Child chunk not found.")
|
||||
@@ -772,24 +771,24 @@ class ChildChunkUpdateApi(Resource):
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
# check child chunk
|
||||
child_chunk_id = str(child_chunk_id)
|
||||
child_chunk = db.session.scalar(
|
||||
select(ChildChunk)
|
||||
child_chunk = (
|
||||
db.session.query(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.id == str(child_chunk_id),
|
||||
ChildChunk.tenant_id == current_tenant_id,
|
||||
ChildChunk.segment_id == segment.id,
|
||||
ChildChunk.document_id == document_id,
|
||||
)
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
if not child_chunk:
|
||||
raise NotFound("Child chunk not found.")
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user